# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/mistral.py
# SPDX-License-Identifier: Apache-2.0
import json
from pathlib import Path
from typing import Any

from transformers import PretrainedConfig, WhisperConfig

from sglang.srt.utils import logger


def _validate_llama_4_scaling(config_dict: dict[str, Any]) -> None:
    """Assert that ``config_dict["llama_4_scaling"]`` defines the required keys.

    A missing or falsy ``llama_4_scaling`` entry is treated as an empty mapping,
    so it fails the assertion with a readable message instead of a TypeError.
    """
    required_keys = ["original_max_position_embeddings", "beta"]
    scaling = config_dict.get("llama_4_scaling") or {}
    assert all(key in scaling for key in required_keys), (
        "llama_4_scaling config should define the keys: "
        f"{','.join(required_keys)}"
    )


def adapt_config_dict(
    config_dict: dict[str, Any], model: str | Path, **kwargs
) -> tuple[dict, PretrainedConfig]:
    """Convert a Mistral-format ``params.json`` dict into an HF-style config.

    Args:
        config_dict: Parsed ``params.json`` contents. It is updated in place
            with ``kwargs`` and by the key-remapping helpers; the vision and
            audio paths then wrap it in a new top-level dict.
        model: Model name or path. Only used to detect EAGLE draft models
            (case-insensitive substring ``"eagle"``). May be a ``Path``.
        **kwargs: Extra entries merged into ``config_dict`` before remapping.

    Returns:
        ``(config_dict, config)``: the remapped dict and the
        ``PretrainedConfig`` built from it.

    Raises:
        AssertionError: for a Mistral Large 3 config without a complete
            ``llama_4_scaling`` section, an incomplete ``llama_4_scaling``
            section on any model, a config that is both vision and audio, or
            validation failures inside the remapping helpers.
        ValueError: for an unsupported ``quantization`` section.
    """
    config_dict.update(kwargs)
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    is_moe = bool(config_dict.get("moe"))
    # The MoE variant with shared experts is Mistral Large 3, which is mapped
    # onto the deepseek_v3 model type below.
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )
    # ``model`` may be a ``Path`` (MistralConfigParser.parse accepts one), and
    # ``Path`` has no ``.lower()``.
    is_eagle = "eagle" in str(model).lower()

    if not is_moe:
        config_dict["architectures"] = ["MistralForCausalLM"]
    elif not is_mistral_large_3:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict = _remap_moe_args(config_dict)
        config_dict["model_type"] = "deepseek_v3"
        config_dict["architectures"] = [
            "MistralLarge3ForCausalLMEagle" if is_eagle else "MistralLarge3ForCausalLM"
        ]
        assert (
            "llama_4_scaling" in config_dict
        ), "MistralLarge3 expect llama4 scaling config."

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    # Mistral Large 3 always requires a complete llama_4_scaling section;
    # other models are validated only when they provide one.
    if is_mistral_large_3 or config_dict.get("llama_4_scaling"):
        _validate_llama_4_scaling(config_dict)

    multimodal = config_dict.get("multimodal") or {}
    is_vision = bool(
        multimodal.get("vision_encoder_args") or config_dict.get("vision_encoder")
    )
    is_audio = bool((multimodal.get("whisper_model_args") or {}).get("encoder_args"))

    assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    config = PretrainedConfig.from_dict(config_dict)

    logger.debug("Initialized config %s", config)

    return config_dict, config


def _remap_mistral_vision_args(config: dict) -> dict:
    if config.get("multimodal"):
        vision_config = config.pop("multimodal")
    else:
        vision_config = config.pop("vision_encoder")

    quant_config = config.get("quantization_config")

    config = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": config,
        "vision_config": {"model_type": "pixtral", **vision_config},
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config


def _remap_mistral_yarn_args(config: dict) -> dict:
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
        "beta": "beta_fast",
        "alpha": "beta_slow",
        "apply_scale": None,
    }
    yarn_config = config.get("yarn") or {}
    config["rope_scaling"] = {
        "rope_type": "yarn",
        "mscale_all_dim": 1,
    }
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
            value = yarn_config.pop(old_name)
            if new_name is not None:
                config["rope_scaling"][new_name] = value

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

    return config


def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": ("max_seq_len", 128_000),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

    for new_key, (key, default_value) in top_level_mapping_with_default.items():
        config[new_key] = config.pop(key, default_value)

    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
    if config.get("quantization"):
        quantization = config.pop("quantization", {})
        if quantization.get("qformat_weight") == "fp8_e4m3":
            qscheme_act = quantization.get("qscheme_act")
            assert qscheme_act in (
                "NO_SCALES",
                "TENSOR",
                None,
            ), "Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
            is_dynamic = qscheme_act == "NO_SCALES"
            config["quantization_config"] = {
                "quant_method": "fp8",
                "activation_scheme": "dynamic" if is_dynamic else "static",
            }
        else:
            raise ValueError(f"Found unknown quantization='{quantization}' in config")

    return config


def _remap_mistral_audio_args(config: dict) -> dict:
    """Wrap a Mistral audio (Whisper-encoder) model into a Voxtral-style config.

    ``whisper_model_args`` is popped from ``config["multimodal"]``. The rest of
    ``config`` becomes a ``PretrainedConfig`` text config, and the encoder and
    downsample arguments are mapped onto a ``WhisperConfig``. A
    ``quantization_config`` on the text config is copied to the top level.
    """
    whisper_args = config["multimodal"].pop("whisper_model_args")
    encoder_args = whisper_args["encoder_args"]
    downsample_args = whisper_args["downsample_args"]

    quant_config = config.get("quantization_config")
    text_config = PretrainedConfig.from_dict(config)

    encoding_args = encoder_args["audio_encoding_args"]
    audio_config = WhisperConfig(
        num_mel_bins=encoding_args["num_mel_bins"],
        window_size=encoding_args["window_size"],
        sampling_rate=encoding_args["sampling_rate"],
        hop_length=encoding_args["hop_length"],
        downsample_factor=downsample_args["downsample_factor"],
        d_model=encoder_args["dim"],
        encoder_layers=encoder_args["n_layers"],
        encoder_ffn_dim=encoder_args["hidden_dim"],
        encoder_attention_heads=encoder_args["n_heads"],
        vocab_size=encoder_args["vocab_size"],
        max_source_positions=encoder_args["max_source_positions"],
        is_encoder_decoder=False,  # Override WhisperConfig default
    )

    remapped = {
        "model_type": "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config": text_config,
        "audio_config": audio_config,
    }
    if quant_config:
        remapped["quantization_config"] = quant_config
    return remapped


def _remap_moe_args(config: dict) -> dict:
    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "first_k_dense_replace": "first_k_dense_replace",
        "num_experts_per_tok": "num_experts_per_tok",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "routed_scale": "routed_scaling_factor",
        "num_shared_experts": "n_shared_experts",
        "num_expert_groups": "n_group",
        "num_expert_groups_per_tok": "topk_group",
    }
    moe_config = config.get("moe", {})
    for old_name, new_name in moe_config_map.items():
        if old_name in moe_config:
            value = moe_config.pop(old_name)
            config[new_name] = value

    config["topk_method"] = None
    config["scoring_func"] = "softmax"
    config["routing_method_type"] = 1  # RoutingMethodType.Renormalize

    return config


class MistralConfigParser:
    """Load a Mistral-format ``params.json`` and adapt it to an HF config."""

    def get_hf_file_to_dict(
        self, file_name: str, model: str | Path, revision: str | None = "main"
    ):
        """Load ``file_name`` from the local model directory ``model`` as JSON.

        Args:
            file_name: File name inside the model directory.
            model: Local model directory.
            revision: Accepted for interface compatibility; currently unused
                because there is no Hub download path yet.

        Returns:
            The decoded JSON value.

        Raises:
            FileNotFoundError: if the file does not exist locally.
        """
        file_path = Path(model) / file_name
        if not file_path.is_file():
            # TODO: Add logic to download from HF in case file is not locally found
            raise FileNotFoundError(f"File not found {model}, {file_name}")

        # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
        with open(file_path, encoding="utf-8") as file:
            return json.load(file)

    def _download_mistral_config_file(self, model, revision) -> dict:
        """Load ``params.json`` for ``model`` and check that it is a JSON object.

        Raises:
            FileNotFoundError: if ``params.json`` is missing.
            ValueError: if the file decodes to ``null``.
        """
        config_file_name = "params.json"
        config_dict = self.get_hf_file_to_dict(config_file_name, model, revision)
        if config_dict is None:
            raise ValueError(
                f"Failed to load mistral '{config_file_name}' config for model "
                f"{model}. Please check if the model is a mistral-format model "
                f"and if the config file exists."
            )
        assert isinstance(config_dict, dict)
        return config_dict

    def parse(
        self,
        model: str | Path,
        revision: str | None = None,
        **kwargs,
    ) -> tuple[dict, PretrainedConfig]:
        """Load ``params.json`` from ``model`` and convert it to an HF config.

        This is meant for models stored in Mistral format.

        Args:
            model: Local directory of a Mistral-format checkpoint.
            revision: Passed to the file loader (currently unused there).
            **kwargs: Accepted for interface compatibility; not forwarded to
                ``adapt_config_dict``.

        Returns:
            ``(config_dict, config)`` from ``adapt_config_dict``. A list-valued
            ``sliding_window`` on ``config`` is normalized to an int plus an HF
            ``layer_types`` list.
        """
        config_dict = self._download_mistral_config_file(model, revision)
        if config_dict.get("max_position_embeddings") is None:
            logger.warning(
                "The params.json file is missing 'max_position_embeddings'."
                " Defaulting to 128000"
            )
            config_dict["max_position_embeddings"] = 128_000

        config_dict, config = adapt_config_dict(config_dict, model)

        # Mistral configs may define sliding_window as a per-layer pattern
        # (list[int | None], None meaning full attention). Convert it to a
        # single int and add the HF-compatible layer_types list[str].
        sliding_window = getattr(config, "sliding_window", None)
        if sliding_window and isinstance(sliding_window, list):
            pattern_len = len(sliding_window)
            # Repeat the pattern cyclically so there is exactly one entry per
            # hidden layer, even when num_hidden_layers is not a multiple of
            # the pattern length (plain list repetition would truncate).
            config.layer_types = [
                (
                    "full_attention"
                    if sliding_window[layer_idx % pattern_len] is None
                    else "sliding_attention"
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
            config.sliding_window = next(filter(None, sliding_window), None)

        return config_dict, config
