# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""

import contextlib
import json
import logging
import os
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

import torch
from huggingface_hub import snapshot_download

from sglang.srt.utils import get_bool_env_var

# Conditional import based on SGLANG_USE_MODELSCOPE environment variable
if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
    from modelscope import AutoConfig, GenerationConfig
else:
    from transformers import AutoConfig, GenerationConfig

from transformers import (
    AutoProcessor,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from sglang.srt.configs import (
    AfmoeConfig,
    BailingHybridConfig,
    ChatGLMConfig,
    DbrxConfig,
    DeepseekVL2Config,
    DotsOCRConfig,
    DotsVLMConfig,
    ExaoneConfig,
    FalconH1Config,
    GraniteMoeHybridConfig,
    JetNemotronConfig,
    JetVLMConfig,
    KimiK25Config,
    KimiLinearConfig,
    KimiVLConfig,
    LongcatFlashConfig,
    MultiModalityConfig,
    NemotronH_Nano_VL_V2_Config,
    NemotronHConfig,
    Olmo3Config,
    Qwen3_5Config,
    Qwen3_5MoeConfig,
    Qwen3NextConfig,
    Step3p5Config,
    Step3VLConfig,
)
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils
from sglang.srt.utils.patch_tokenizer import patch_tokenizer

# sglang-provided config classes, keyed by their ``model_type``. ``get_config``
# reloads a checkpoint's config with the matching class when the model type is
# present here. Built directly as a dict so the annotation matches the value.
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    config_cls.model_type: config_cls
    for config_cls in [
        AfmoeConfig,
        BailingHybridConfig,
        ChatGLMConfig,
        DbrxConfig,
        ExaoneConfig,
        DeepseekVL2Config,
        MultiModalityConfig,
        KimiVLConfig,
        InternVLChatConfig,
        Step3VLConfig,
        LongcatFlashConfig,
        Olmo3Config,
        KimiLinearConfig,
        Qwen3NextConfig,
        FalconH1Config,
        GraniteMoeHybridConfig,
        DotsVLMConfig,
        DotsOCRConfig,
        NemotronH_Nano_VL_V2_Config,
        NemotronHConfig,
        DeepseekVLV2Config,
        Qwen3_5Config,
        Qwen3_5MoeConfig,
        JetNemotronConfig,
        JetVLMConfig,
        KimiK25Config,
        Step3p5Config,
    ]
}

# Also make these classes resolvable through ``AutoConfig``. A ``ValueError``
# from registration (e.g. a model type transformers already knows about) is
# ignored so that importing this module never fails because of it.
for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)


def download_from_hf(
    model_path: str,
    allow_patterns: Optional[Union[str, list]] = None,
):
    """Return a local path for ``model_path``, downloading it if necessary.

    A path that already exists on disk is returned unchanged. Otherwise
    ``model_path`` is passed to ``snapshot_download`` as a Hub repo id,
    restricted to ``allow_patterns``. An empty or missing ``allow_patterns``
    defaults to JSON, ``.bin`` and ``.model`` files.
    """
    if os.path.exists(model_path):
        return model_path
    patterns = allow_patterns or ["*.json", "*.bin", "*.model"]
    return snapshot_download(model_path, allow_patterns=patterns)


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multi-modal models.

    No-op for pure text models: ``config`` itself is returned. Otherwise the
    first matching sub-config is returned, checked in this order:
    ``text_config``, ``llm_config``, ``language_config``, ``thinker_config``
    (its ``text_config`` when present).

    Non-HF Llava checkpoints (first architecture ``Llava*ForCausalLM``) are
    returned as-is after their ``dtype`` is forced to ``torch.float16``.
    """
    # A truthy check (rather than ``is not None``) also skips an empty list,
    # which would otherwise raise IndexError on ``[0]``.
    if config.architectures:
        class_name = config.architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set the config's `dtype` to `torch.float16` for the weights, as
            # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
            setattr(config, "dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config

    if hasattr(config, "llm_config"):
        # PointsV1.5 Chat Model
        assert hasattr(config.llm_config, "num_attention_heads")
        return config.llm_config

    if hasattr(config, "language_config"):
        return config.language_config

    if hasattr(config, "thinker_config"):
        # qwen2.5 omni: propagate the thinker's torch_dtype to its text config.
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config

    # Pure text model. (The former second ``llm_config`` check here was
    # unreachable: that attribute is already handled above.)
    return config


# Temporary hack for DeepSeek-V3.2 model
def _load_deepseek_v32_model(
    model_path: str,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
):
    """Load a DeepSeek-V3.2 checkpoint's config as a DeepSeek-V3 config.

    The checkpoint's ``config.json`` is read and rewritten with
    ``architectures=["DeepseekV3ForCausalLM"]`` and
    ``model_type="deepseek_v3"``. The rewritten copy is saved under the system
    temp directory (one file per process id) and loaded with ``AutoConfig``.

    Raises:
        RuntimeError: If the checkpoint directory has no ``config.json``.
    """
    checkpoint_dir = download_from_hf(model_path)
    source_json = os.path.join(checkpoint_dir, "config.json")
    if not os.path.exists(source_json):
        raise RuntimeError(f"Can't find config file in {checkpoint_dir}.")

    with open(source_json, "r") as src:
        patched = json.load(src)
    patched.update(architectures=["DeepseekV3ForCausalLM"], model_type="deepseek_v3")

    scratch_dir = os.path.join(tempfile.gettempdir(), "_tmp_config_folder")
    os.makedirs(scratch_dir, exist_ok=True)
    # NOTE(review): this file is never removed after loading, and the shared
    # folder name may collide across users on the same host -- confirm intent.
    patched_json = os.path.join(scratch_dir, f"deepseek_v32_{os.getpid()}")
    with open(patched_json, "w") as dst:
        json.dump(patched, dst)

    return AutoConfig.from_pretrained(
        patched_json,
        trust_remote_code=trust_remote_code,
        revision=revision,
        **kwargs,
    )


# Temporary hack for Mistral Large
def _load_mistral_large_3_for_causal_LM(
    model_path: str,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
):
    """Build an ``AutoConfig`` for a Mistral Large 3 checkpoint.

    The checkpoint directory is parsed by ``mistral_utils.MistralConfigParser``
    into a config dict. That dict is written to a temporary JSON file and
    loaded with ``AutoConfig.from_pretrained``. Any ``text_config`` or
    ``vision_config`` that comes back as a plain dict is then promoted to a
    config object via ``AutoConfig.for_model``.
    """
    checkpoint_dir = download_from_hf(model_path)
    parsed_config, _ = mistral_utils.MistralConfigParser().parse(checkpoint_dir)

    with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as tmp_json:
        json.dump(parsed_config, tmp_json)
        tmp_json.flush()  # make the content visible to the reader below
        loaded_config = AutoConfig.from_pretrained(
            tmp_json.name,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

    for sub_name in ("text_config", "vision_config"):
        sub_config = getattr(loaded_config, sub_name, None)
        if isinstance(sub_config, dict):
            setattr(loaded_config, sub_name, AutoConfig.for_model(**sub_config))

    return loaded_config


def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool:
    """Return True when ``config.auto_map["AutoModel"]`` names the DeepSeek-OCR model class."""
    # TODO: Remove this workaround related when AutoConfig correctly identifies deepseek-ocr.
    # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
    expected_model = "modeling_deepseekocr.DeepseekOCRForCausalLM"
    auto_map = getattr(config, "auto_map", None) or {}
    return expected_model == auto_map.get("AutoModel")


def _is_deepseek_ocr2_model(config: PretrainedConfig) -> bool:
    """Return True when ``config.auto_map["AutoModel"]`` names the DeepSeek-OCR2 model class."""
    expected_model = "modeling_deepseekocr2.DeepseekOCR2ForCausalLM"
    auto_map = getattr(config, "auto_map", None) or {}
    return expected_model == auto_map.get("AutoModel")


def _override_deepseek_ocr_v_head_dim(config: DeepseekVLV2Config) -> None:
    """Replace a zero ``config.text_config.v_head_dim`` with 128, logging a warning.

    Any non-zero value is left untouched.
    """
    # FIXME: deepseek-ocr's v_head_dim is set to 0 in its config file.
    # https://huggingface.co/deepseek-ai/DeepSeek-OCR/blob/main/config.json#L116
    if config.text_config.v_head_dim != 0:
        return
    patched_dim = 128
    config.text_config.v_head_dim = patched_dim
    logger.warning(
        f"Overriding deepseek-ocr's v_head_dim from 0 to {patched_dim} to avoid potential issues."
    )


def _override_v_head_dim_if_zero(config: PretrainedConfig, patch: int = 128) -> None:
    """Set ``v_head_dim`` to ``patch`` on the text sub-config when it equals 0.

    The sub-config is ``config.text_config`` when truthy, otherwise
    ``config.language_config``. Nothing happens when neither is present or when
    ``v_head_dim`` is missing or non-zero. A warning is logged when the value is
    patched.
    """
    text_cfg = getattr(config, "text_config", None)
    language_cfg = getattr(config, "language_config", None)
    target = text_cfg if text_cfg else language_cfg
    if target is None or getattr(target, "v_head_dim", None) != 0:
        return
    setattr(target, "v_head_dim", patch)
    logger.warning(
        f"Overriding v_head_dim from 0 to {patch} to avoid potential issues."
    )


def _ensure_llama_flash_attention2_compat() -> None:
    """Ensure LlamaFlashAttention2 symbol exists for remote code compatibility."""
    try:
        from transformers.models.llama import modeling_llama
    except Exception:
        return
    if not hasattr(modeling_llama, "LlamaFlashAttention2"):
        if hasattr(modeling_llama, "LlamaAttention"):
            modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention


@lru_cache_frozenset(maxsize=32)
def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
    **kwargs,
):
    """Load and normalize the HuggingFace config for ``model``.

    Args:
        model: Local directory, Hub repo id, remote URL, or path to a GGUF file.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        revision: Revision forwarded to the config loaders.
        model_override_args: If given, applied with ``config.update`` after the
            built-in normalization (only the GGUF architecture mapping runs later).
        **kwargs: Extra keyword arguments for ``AutoConfig.from_pretrained``.

    Returns:
        The loaded config. NOTE(review): results are memoized by
        ``lru_cache_frozenset``, so callers presumably share the same object
        for identical arguments -- confirm before mutating it.

    Raises:
        RuntimeError: For a GGUF model whose ``model_type`` has no entry in
            transformers' causal-LM mapping.
    """
    is_gguf = check_gguf_file(model)
    if is_gguf:
        # Load from the containing directory and name the GGUF file explicitly.
        kwargs["gguf_file"] = model
        model = Path(model).parent

    if is_remote_url(model):
        # BaseConnector implements __del__() to clean up the local dir.
        # Since config files need to exist all the time, so we DO NOT use
        # with statement to avoid closing the client.
        client = create_remote_connector(model)
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        model = client.get_local_dir()

    if "mistral-large-3" in str(model).lower():
        config = _load_mistral_large_3_for_causal_LM(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    else:
        _ensure_llama_flash_attention2_compat()
        try:
            config = AutoConfig.from_pretrained(
                model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
            )
        except ValueError as e:
            # Only an error mentioning deepseek_v32 triggers the V3.2 fallback
            # loader; anything else is re-raised unchanged.
            if "deepseek_v32" not in str(e):
                raise
            config = _load_deepseek_v32_model(
                model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
            )

    # Truthy check: an empty ``architectures`` list must not raise IndexError.
    if config.architectures and config.architectures[0] == "Phi4MMForCausalLM":
        # Phi4MMForCausalLM uses a hard-coded vision_config. See:
        # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
        # We set it here to support cases where num_attention_heads is not divisible by the TP size.
        from transformers import SiglipVisionConfig

        vision_config = {
            "hidden_size": 1152,
            "image_size": 448,
            "intermediate_size": 4304,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_hidden_layers": 26,
            # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
            "patch_size": 14,
        }
        config.vision_config = SiglipVisionConfig(**vision_config)
    text_config = get_hf_text_config(config=config)

    if isinstance(model, str) and text_config is not None:
        # Expose text sub-config fields on the top-level config, without
        # overwriting attributes the top-level config already defines.
        items = (
            text_config.items()
            if hasattr(text_config, "items")
            else vars(text_config).items()
        )
        for key, val in items:
            if not hasattr(config, key) and val is not None:
                setattr(config, key, val)

    if _is_deepseek_ocr2_model(config):
        # Temporary hack for loading deepseek-ocr2: reload with the deepseek-ocr
        # config class, patch v_head_dim, and force the OCR architecture.
        # (Mutating the previous config first was dead work -- it is replaced here.)
        config = DeepseekVLV2Config.from_pretrained(model, revision=revision)
        _override_v_head_dim_if_zero(config)
        config.update({"architectures": ["DeepseekOCRForCausalLM"]})
        setattr(config, "_name_or_path", model)
    elif config.model_type in _CONFIG_REGISTRY:
        model_type = config.model_type
        if model_type == "deepseek_vl_v2":
            if _is_deepseek_ocr_model(config) or _is_deepseek_ocr2_model(config):
                model_type = "deepseek-ocr"
        config_class = _CONFIG_REGISTRY[model_type]
        config = config_class.from_pretrained(model, revision=revision)

        if _is_deepseek_ocr_model(config):
            _override_deepseek_ocr_v_head_dim(config)
            config.update({"architectures": ["DeepseekOCRForCausalLM"]})
        elif _is_deepseek_ocr2_model(config):
            _override_v_head_dim_if_zero(config)
            config.update({"architectures": ["DeepseekOCRForCausalLM"]})

        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
        setattr(config, "_name_or_path", model)

    if isinstance(model, str) and config.model_type == "internvl_chat":
        # Surface llm_config fields on the top-level config (no overwrites).
        for key, val in config.llm_config.__dict__.items():
            if not hasattr(config, key):
                setattr(config, key, val)

    if config.model_type == "multi_modality":
        config.update({"architectures": ["MultiModalityCausalLM"]})

    if model_override_args:
        config.update(model_override_args)

    # Special architecture mapping check for GGUF models
    if is_gguf:
        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
        architecture = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [architecture]})

    return config


@lru_cache_frozenset(maxsize=32)
def get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    **kwargs,
):
    """Load the ``GenerationConfig`` for ``model``, or return None.

    An ``OSError`` from ``GenerationConfig.from_pretrained`` (typically a
    checkpoint without a generation config file) is treated as "no generation
    config". Any other exception propagates.
    """
    try:
        return GenerationConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    except OSError:
        return None


# Qwen-1M related
def get_sparse_attention_config(
    model: str,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Return the parsed sparse-attention config file of ``model``.

    ``model`` is either a local directory or a Hub repo id. For a repo id, the
    repo's ``*.json`` files are downloaded first. Returns ``{}`` when the file
    does not exist.
    """
    if os.path.isdir(model):
        model_dir = model
    else:
        # Download the config files.
        model_dir = download_from_hf(model, allow_patterns=["*.json"])

    config_path = os.path.join(model_dir, sparse_attention_config_filename)
    if os.path.exists(config_path):
        with open(config_path) as fp:
            return json.load(fp)
    return {}


# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_seq_len",
    "model_max_length",
    "max_position_embeddings",
]


def get_context_length(config):
    """Get the context length of a model from a huggingface model config.

    The first attribute from ``CONTEXT_LENGTH_KEYS`` that is set (not None) on
    ``config`` is used. It is multiplied by ``rope_scaling["factor"]`` when
    RoPE scaling is configured.

    The factor is treated as 1 when:
      * it is missing or null;
      * the scaling dict contains ``original_max_position_embeddings``
        (presumably the listed length is already the scaled one);
      * ``rope_type`` is ``"llama3"``.

    Returns 2048 when none of the keys are set.
    """
    rope_scaling_factor = 1
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        factor = rope_scaling.get("factor")
        scaling_applies = (
            factor is not None
            and "original_max_position_embeddings" not in rope_scaling
            and rope_scaling.get("rope_type", None) != "llama3"
        )
        if scaling_applies:
            rope_scaling_factor = factor

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048


# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


# Filter warnings like: https://github.com/sgl-project/sglang/issues/8082
class TokenizerWarningsFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return "Calling super().encode with" not in record.getMessage()


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    A ``tokenizer_name`` ending in ``.json`` is loaded as a ``TiktokenTokenizer``.
    GGUF files and remote URLs are resolved to a local directory first. Any
    other name goes through ``AutoTokenizer.from_pretrained``.

    Args:
        tokenizer_name: Repo id, local path, GGUF file, remote URL or tiktoken JSON.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        tokenizer_revision: Revision to load. A ``revision`` passed in
            ``kwargs`` takes precedence (same rule as ``get_processor``).
        **kwargs: Extra arguments for ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast`` is truthy.
        RuntimeError: Wrapping a ``TypeError`` from loading, or a ``ValueError``
            suggesting ``--trust-remote-code`` for custom tokenizers.
    """
    if tokenizer_name.endswith(".json"):
        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

        return TiktokenTokenizer(tokenizer_name)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # ``AutoTokenizer.from_pretrained`` takes the revision as ``revision``; the
    # previous ``tokenizer_revision=`` keyword was silently ignored by it.
    revision = kwargs.pop("revision", tokenizer_revision)

    # TODO(Xinyuan): Remove this once we have a proper tokenizer for Devstral
    if tokenizer_name == "mistralai/Devstral-Small-2505":
        tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = tokenizer_name
        tokenizer_name = Path(tokenizer_name).parent

    if is_remote_url(tokenizer_name):
        # BaseConnector implements __del__() to clean up the local dir.
        # Since config files need to exist all the time, so we DO NOT use
        # with statement to avoid closing the client.
        client = create_remote_connector(tokenizer_name)
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        tokenizer_name = client.get_local_dir()

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            clean_up_tokenization_spaces=False,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        raise

    # Filter tokenizer warnings. Install the filter only once per logger:
    # each call creates a new instance, so addFilter alone would stack duplicates.
    tokenizer_logger = logging.getLogger(tokenizer.__class__.__module__)
    if not any(
        isinstance(existing, TokenizerWarningsFilter)
        for existing in tokenizer_logger.filters
    ):
        tokenizer_logger.addFilter(TokenizerWarningsFilter())

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )

    attach_additional_stop_token_ids(tokenizer)
    tokenizer = patch_tokenizer(tokenizer)
    return tokenizer


# Some models don't have an available processor, e.g. InternVL.
def get_tokenizer_from_processor(processor):
    """Return the tokenizer behind ``processor``.

    A ``processor`` that is already a ``PreTrainedTokenizerBase`` is returned
    as-is. Otherwise its ``tokenizer`` attribute is returned.
    """
    is_tokenizer = isinstance(processor, PreTrainedTokenizerBase)
    return processor if is_tokenizer else processor.tokenizer


def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    use_fast: Optional[bool] = True,
    **kwargs,
):
    """Load the multimodal processor (or, for some models, the tokenizer).

    The model config is loaded first to choose the loader:
      * ``AutoTokenizer`` for ``InternVL3_5`` names;
      * a customized processor for model types in ``_CUSTOMIZED_MM_PROCESSOR``;
      * ``AutoProcessor`` otherwise.

    If loading fails because no slow version exists, it is retried once with
    ``use_fast=True`` using the same loader. Stop-token ids are attached to the
    underlying tokenizer before returning.

    Args:
        tokenizer_name: Repo id or local path of the model.
        tokenizer_mode: Accepted for signature parity; not used here.
        trust_remote_code: Forwarded to the config and processor loaders.
        tokenizer_revision: Revision to load. A ``revision`` kwarg takes precedence.
        use_fast: Forwarded as ``use_fast`` except for ``llava``/``clip`` models.
        **kwargs: Extra arguments for the config and processor loaders.
    """
    # pop 'revision' from kwargs if present.
    revision = kwargs.pop("revision", tokenizer_revision)
    if "mistral-large-3" in str(tokenizer_name).lower():
        config = _load_mistral_large_3_for_causal_LM(
            tokenizer_name,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
    else:
        _ensure_llama_flash_attention2_compat()
        config = AutoConfig.from_pretrained(
            tokenizer_name,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
    if _is_deepseek_ocr_model(config):
        # Temporary hack for load deepseek-ocr
        config.model_type = "deepseek-ocr"
        config.update({"architectures": ["DeepseekOCRForCausalLM"]})
    elif _is_deepseek_ocr2_model(config):
        # Temporary hack for load deepseek-ocr2
        config.model_type = "deepseek-ocr"
        config.update({"architectures": ["DeepseekOCRForCausalLM"]})
        _override_v_head_dim_if_zero(config)

    # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
    if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
        kwargs.setdefault("size", {"shortest_edge": 3136, "longest_edge": 1003520})

    if config.model_type not in {"llava", "clip"}:
        kwargs["use_fast"] = use_fast

    # Choose the loader once so the slow->fast retry below uses the same one
    # (previously the retry always fell back to AutoProcessor).
    if "InternVL3_5" in tokenizer_name:
        processor_cls = AutoTokenizer
    elif config.model_type in _CUSTOMIZED_MM_PROCESSOR:
        processor_cls = _CUSTOMIZED_MM_PROCESSOR[config.model_type]
    else:
        processor_cls = AutoProcessor

    try:
        processor = processor_cls.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
    except ValueError as e:
        if "does not have a slow version" not in str(e):
            raise
        logger.info(
            f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
        )
        kwargs["use_fast"] = True
        processor = processor_cls.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
    tokenizer = get_tokenizer_from_processor(processor)

    attach_additional_stop_token_ids(tokenizer)
    return processor


def attach_additional_stop_token_ids(tokenizer):
    """Set ``tokenizer.additional_stop_token_ids`` from the added vocabulary.

    When ``<|eom_id|>`` (emitted by llama 3 tool use) is in
    ``tokenizer.get_added_vocab()``, the attribute becomes a set containing its
    id. Otherwise it is set to None.
    """
    added_vocab = tokenizer.get_added_vocab()
    if "<|eom_id|>" not in added_vocab:
        tokenizer.additional_stop_token_ids = None
        return
    tokenizer.additional_stop_token_ids = {added_vocab["<|eom_id|>"]}


def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
    """Check if the file is a GGUF model.

    Returns False for anything that is not an existing regular file. Returns
    True for a ``.gguf`` suffix. Otherwise returns whether the first four bytes
    are the ``b"GGUF"`` magic.
    """
    candidate = Path(model)
    if candidate.is_file():
        if candidate.suffix == ".gguf":
            return True
        with open(candidate, "rb") as fh:
            magic = fh.read(4)
        return magic == b"GGUF"
    return False
