# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import logging
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union

import orjson

from sglang.srt.configs.modelopt_config import ModelOptConfig
from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)


class LoadFormat(str, enum.Enum):
    """Weight-loading formats recognised by ``LoadConfig.load_format``.

    Members subclass ``str``; their lowercase values are the accepted
    spellings. ``LoadConfig`` lowercases a user-supplied string and looks it
    up by value, so e.g. ``"SafeTensors"`` resolves to ``SAFETENSORS``.
    """

    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"
    FLASH_RL = "flash_rl"  # For RL training with quantized models
    JAX = "jax"
    REMOTE = "remote"
    REMOTE_INSTANCE = "remote_instance"
    RDMA = "rdma"
    LOCAL_CACHED = "local_cached"
    FASTSAFETENSORS = "fastsafetensors"
    PRIVATE = "private"


@dataclass
class LoadConfig:
    """Options controlling how model weights are located and loaded.

    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load. Accepts a
        ``LoadFormat`` member or its string value (case-insensitive); it is
        normalized to a ``LoadFormat`` member in ``__post_init__``.
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
        "flash_rl" will load weights with support for RL training
            with quantized models, enabling efficient weight reloading.
    model_loader_extra_config: Extra loader options, given either as a dict
        or as a JSON string (parsed with orjson). ``None`` or an empty value
        is normalized to an empty dict.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints. An empty value is also replaced by that default.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    decrypt_max_concurrency: The maximum number of concurrent processes to
        decrypt the safetensor files. -1 means no limit.
    tp_rank: Optional rank of this worker (presumably the tensor-parallel
        rank; confirm against callers).
    remote_instance_weight_loader_*: Settings used when loading weights from a
        remote seed instance (seed IP, service port, weight-transfer group
        ports, backend name and transfer-engine object).
    modelopt_checkpoint_restore_path, modelopt_checkpoint_save_path,
    modelopt_export_path: ModelOpt-specific loading options, used to build
        ``modelopt_config`` when it is not supplied explicitly.
    modelopt_config: ModelOpt configuration object. If ``None``, it is created
        from the three ``modelopt_*`` path fields.
    rl_quant_profile: Path to a rollout quantization profile for
        FlashRL-style quantization (e.g. /root/profile.7b.pt).
    draft_model_idx: Draft model index, used for multi-layer MTP.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None
    decrypt_max_concurrency: int = -1
    tp_rank: Optional[int] = None
    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
    remote_instance_weight_loader_backend: Optional[str] = None
    remote_instance_weight_loader_transfer_engine: Optional[Any] = None

    # ModelOpt-specific loading options
    modelopt_checkpoint_restore_path: Optional[str] = None
    modelopt_checkpoint_save_path: Optional[str] = None
    modelopt_export_path: Optional[str] = None

    # ModelOpt configuration object
    modelopt_config: Optional[ModelOptConfig] = None

    # QuantizedRL-specific options (for FlashRL-style quantization).
    # Path to rollout quantization profile (e.g., /root/profile.7b.pt)
    rl_quant_profile: Optional[str] = None

    # For multi-layer MTP
    draft_model_idx: Optional[int] = None

    def __post_init__(self):
        """Normalize and validate fields after dataclass initialization.

        Raises:
            ValueError: if ``load_format`` is unknown or unsupported on ROCm,
                or if ``model_loader_extra_config`` is a malformed JSON string
                (orjson's decode error subclasses ValueError).
        """
        # Always store a mapping, so downstream code can treat the field as a
        # dict. Falsy values (None, "", {}) become an empty dict; JSON strings
        # are parsed.
        extra_config = self.model_loader_extra_config
        if not extra_config:
            self.model_loader_extra_config = {}
        elif isinstance(extra_config, str):
            self.model_loader_extra_config = orjson.loads(extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

        # Create ModelOptConfig if not provided
        if self.modelopt_config is None:
            self.modelopt_config = ModelOptConfig(
                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
                checkpoint_save_path=self.modelopt_checkpoint_save_path,
                export_path=self.modelopt_export_path,
            )

    def _verify_load_format(self) -> None:
        """Convert ``load_format`` to a ``LoadFormat`` member and validate it.

        Non-string values are left untouched. Note that ``LoadFormat`` members
        are themselves ``str``, so they are re-normalized too, which is harmless.

        Raises:
            ValueError: if the format is not a known ``LoadFormat`` value, or
                is not supported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        try:
            self.load_format = LoadFormat(load_format)
        except ValueError:
            supported_load_formats = [f.value for f in LoadFormat]
            raise ValueError(
                f"Unknown load format '{load_format}'. "
                f"Supported load formats are {supported_load_formats}"
            ) from None

        # Formats that cannot be used on ROCm, listed by enum *value*
        # (e.g. "gguf"). Currently empty, so nothing is rejected.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Filter and report enum values, not member names: the membership
            # test above uses the lowercase value, and iterating
            # ``LoadFormat.__members__`` would yield uppercase names that
            # never match the exclusion list.
            rocm_supported_load_format = [
                f.value
                for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
