# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py

"""Utilities for downloading and initializing model weights."""

import collections
import concurrent.futures
import fnmatch
import glob
import hashlib
import itertools
import json
import logging
import os
import tempfile
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

import filelock
import huggingface_hub.constants
import numpy as np
import safetensors.torch
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from tqdm.auto import tqdm

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
)
from sglang.srt.model_loader.ci_weight_validation import (
    ci_download_with_validation_and_retry,
    ci_validate_and_cleanup_local_snapshot,
)
from sglang.srt.utils import (
    BAR_FORMAT,
    find_local_repo_dir,
    is_cpu,
    log_info_on_rank0,
    print_warning_once,
)
from sglang.utils import is_in_ci

# fastsafetensors is an optional dependency. When it is not installed, both
# names are bound to None; fastsafetensors_weights_iterator checks for that
# and raises an ImportError with install instructions.
try:
    from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError as e:
    SafeTensorsFileLoader = SingleGroup = None

logger = logging.getLogger(__name__)


def enable_hf_transfer():
    """Switch on hf_transfer-backed downloads when the package is importable.

    An explicit ``HF_HUB_ENABLE_HF_TRANSFER`` environment variable always
    wins: if it is present (with any value) this function does nothing.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        return
    try:
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        # hf_transfer is optional; leave the hub constant untouched.
        return
    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True


enable_hf_transfer()


# Use the system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# Lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files.
# get_lock() falls back to this directory when no cache_dir is given.
temp_dir = tempfile.gettempdir()


def get_lock(
    model_name_or_path: str, cache_dir: Optional[str] = None, suffix: str = ""
):
    """Build the inter-process file lock associated with a model.

    Args:
        model_name_or_path: Model identifier or path; "/" is replaced by "-"
            and the result is embedded (with its SHA-256) in the lock file
            name.
        cache_dir: Directory that will hold the lock file. When empty or
            None, the system temp directory (``temp_dir``) is used.
        suffix: Optional text inserted before the ``.lock`` extension, so
            that independent locks can exist for the same model.

    Returns:
        A ``filelock.FileLock`` for the computed path. The lock is returned
        un-acquired; callers use it as a context manager.
    """
    lock_dir = cache_dir or temp_dir
    # Create the directory that holds the lock file itself. Creating only
    # os.path.dirname(lock_dir) leaves lock_dir missing, and fails outright
    # for a bare relative cache_dir such as "cache" whose dirname is "".
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + suffix + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock


def _shared_pointers(tensors):
    """Group entry names whose values report the same ``data_ptr()``.

    Args:
        tensors: Mapping from name to an object exposing ``data_ptr()``.

    Returns:
        A list of name lists, one for every ``data_ptr()`` value that is
        reported by more than one entry. Groups and the names inside them
        follow the iteration order of ``tensors``.
    """
    names_by_ptr = defaultdict(list)
    for name, tensor in tensors.items():
        names_by_ptr[tensor.data_ptr()].append(name)
    return [group for group in names_by_ptr.values() if len(group) > 1]


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch checkpoint file into a safetensors file and verify it.

    Args:
        pt_filename: Path of the checkpoint read with ``torch.load``.
        sf_filename: Destination path of the safetensors file; its parent
            directory is created when missing.

    Raises:
        RuntimeError: If the written file is more than 1% larger than the
            source file, or if any tensor read back from the safetensors file
            differs from the tensor that was saved.
    """
    state = torch.load(pt_filename, map_location="cpu", weights_only=True)
    if "state_dict" in state:
        state = state["state_dict"]

    # For every group of names that share a data pointer, keep only the
    # first name and drop the rest.
    for alias_group in _shared_pointers(state):
        for duplicate in alias_group[1:]:
            state.pop(duplicate)

    # For tensors to be contiguous
    tensors = {name: tensor.contiguous() for name, tensor in state.items()}

    os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
    safetensors.torch.save_file(tensors, sf_filename, metadata={"format": "pt"})

    # Sanity check on the resulting file size.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Read the file back and compare every tensor with what was saved.
    reloaded = safetensors.torch.load_file(sf_filename)
    for key, expected in tensors.items():
        if not torch.equal(expected, reloaded[key]):
            raise RuntimeError(f"The output tensors do not match for key {key}")


def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
    """Rewrite the leading part of ``key`` according to ``prefix_mapping``.

    Entries are tried in mapping order against the *current* key, so a key
    rewritten by one entry can be rewritten again by a later one.
    """
    for old_prefix, new_prefix in prefix_mapping.items():
        if not key.startswith(old_prefix):
            continue
        # Swap only the leading occurrence; the remainder is kept verbatim.
        key = new_prefix + key[len(old_prefix):]
    return key


def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
    """Apply every ``old -> new`` substitution of the mapping to ``key``.

    Substitutions run in mapping order on the progressively rewritten key and
    replace all occurrences; entries that do not occur leave it unchanged.
    """
    for old, new in substring_mapping.items():
        # str.replace is a no-op when `old` is absent.
        key = key.replace(old, new)
    return key


class DisabledTqdm(tqdm):
    """tqdm subclass that always constructs itself with ``disable=True``."""

    def __init__(self, *args, **kwargs):
        # Override whatever `disable` value the caller supplied.
        super().__init__(*args, **{**kwargs, "disable": True})


# TODO(woosuk): Move this to other place.
def get_quant_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    packed_modules_mapping: Dict[str, List[str]],
    remap_prefix: Dict[str, str] | None = None,
) -> QuantizationConfig:
    """Build the quantization config object for ``model_config.quantization``.

    Lookup order:
      1. GGUF: an empty config is used.
      2. A quantization config embedded in the HF config
         (``quantization_config``, then ``text_config.quantization_config``,
         then ``compression_config``).
      3. A standalone JSON file in the model folder (or, for bitsandbytes,
         the QLoRA adapter folder), whose name ends with one of
         ``quant_cls.get_config_filenames()``. Remote repositories have their
         ``*.json`` files downloaded first.

    Args:
        model_config: Supplies the quantization method name, HF config,
            model path and revision.
        load_config: Supplies the download directory and the optional
            ``model_loader_extra_config``.
        packed_modules_mapping: Stored under the ``"packed_modules_mapping"``
            key of the config dict handed to ``from_config``.
        remap_prefix: When not None, applied with ``replace_prefix`` to every
            entry of ``config["quantization"]["exclude_modules"]`` of the JSON
            config file.

    Returns:
        The quantization config. NOTE: despite the annotation, ``None`` is
        returned for a modelopt config whose ``quant_algo`` is None when the
        first architecture is ``LlamaForCausalLMEagle3``.

    Raises:
        ValueError: If no config file, or more than one, matches; or if a
            modelopt config has ``quant_algo`` None for any other architecture.
    """
    quant_cls = get_quantization_config(model_config.quantization)

    # GGUF doesn't have config file
    if model_config.quantization == "gguf":
        return quant_cls.from_config({})

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    # some vision model may keep quantization_config in their text_config
    hf_text_config = getattr(model_config.hf_config, "text_config", None)
    if hf_quant_config is None and hf_text_config is not None:
        hf_quant_config = getattr(hf_text_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compression_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
    if hf_quant_config is not None:
        if not isinstance(hf_quant_config, dict):
            hf_quant_config = hf_quant_config.to_dict()
        # NOTE(review): when hf_quant_config was already a dict, this writes
        # into the object held by the HF config rather than into a copy.
        hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
        return quant_cls.from_config(hf_quant_config)

    # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
    if model_config.quantization == "bitsandbytes":
        if (
            not load_config.model_loader_extra_config
            or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
        ):
            return quant_cls.from_config({"adapter_name_or_path": ""})
        model_name_or_path = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"
        ]
    else:
        model_name_or_path = model_config.model_path

    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization method declares no config filenames, use the
    # default config.
    if not possible_config_filenames:
        if model_config.quantization == "mxfp8":
            return Fp8Config(use_mxfp8=True, is_checkpoint_fp8_serialized=False)
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    # Keep only the JSON files whose path ends with an expected filename.
    quant_config_files = [
        f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file) as f:
        config = json.load(f)
        if remap_prefix is not None:
            # NOTE(review): assumes the file has a "quantization" section
            # with "exclude_modules"; a KeyError is raised otherwise.
            exclude_modules = [
                replace_prefix(key, remap_prefix)
                for key in config["quantization"]["exclude_modules"]
            ]
            config["quantization"]["exclude_modules"] = exclude_modules
        config["packed_modules_mapping"] = packed_modules_mapping

        if model_config.quantization == "bitsandbytes":
            config["adapter_name_or_path"] = model_name_or_path
        elif model_config.quantization.startswith("modelopt") and (
            config.get("producer", {}).get("name", "").startswith("modelopt")
        ):
            # Files produced by modelopt select the concrete config class
            # from the recorded quant_algo.
            quant_algo = config["quantization"]["quant_algo"]
            if quant_algo is None:
                # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
                if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
                    raise ValueError(
                        f"Invalid quant_config, quantization method: {model_config.quantization},"
                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
                    )
                return None
            elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
                return ModelOptFp8Config.from_config(config)
            elif "FP4" in quant_algo:
                return ModelOptFp4Config.from_config(config)
        return quant_cls.from_config(config)


def _check_index_files_exist(snapshot_dir: str) -> Tuple[bool, Optional[str]]:
    """
    Check if all files listed in safetensors index files actually exist on disk.

    This catches cases where the snapshot directory exists but files are missing
    (e.g., due to incomplete downloads or corrupted cache).

    Index files are examined in sorted name order and missing shard names are
    sorted before being reported, so the returned message is deterministic.
    Index files that cannot be read or parsed are logged and skipped; they do
    not make the check fail.

    Args:
        snapshot_dir: Path to the model snapshot directory

    Returns:
        Tuple of (all_exist, error_message). ``error_message`` is None when
        ``all_exist`` is True; otherwise it names the first index file with
        missing shards and lists up to three of them.
    """
    index_files = sorted(
        f for f in os.listdir(snapshot_dir) if f.endswith(".safetensors.index.json")
    )

    if not index_files:
        return True, None  # Not a sharded model

    for index_file in index_files:
        index_path = os.path.join(snapshot_dir, index_file)
        # False for a broken symlink or a file removed since the listing.
        if not os.path.exists(index_path):
            continue
        try:
            with open(index_path) as f:
                weight_map = json.load(f).get("weight_map", {})
            if not weight_map:
                continue
            # Many weights map to the same shard; de-duplicate, then sort so
            # the sample shown in the message does not depend on set order.
            missing_files = sorted(
                fn
                for fn in set(weight_map.values())
                if not os.path.exists(os.path.join(snapshot_dir, fn))
            )
            if missing_files:
                return (
                    False,
                    f"Missing {len(missing_files)} file(s) from index {index_file}: "
                    f"{missing_files[:3]}{'...' if len(missing_files) > 3 else ''}",
                )
        except Exception as e:
            # Best effort: an unreadable index must not block loading.
            logger.warning("Failed to read index file %s: %s", index_file, e)
            continue

    return True, None


def _find_local_hf_snapshot_dir_unlocked(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Find a usable local HF snapshot directory without locking.

    IMPORTANT: Caller MUST hold the model lock before calling this function
    to prevent race conditions during validation and cleanup.

    The custom ``cache_dir`` (if given) is searched first, then the default HF
    cache via ``find_local_repo_dir``. A snapshot is accepted only if it
    contains at least one existing file matching ``allow_patterns`` and every
    shard referenced by its safetensors index files exists; in CI it must
    additionally pass ``ci_validate_and_cleanup_local_snapshot``.

    Args:
        model_name_or_path: HF repo id. If this is an existing local
            directory, None is returned (there is no HF snapshot to find).
        cache_dir: Optional custom HF cache root to search first.
        allow_patterns: Glob patterns identifying weight files.
        revision: Snapshot revision; when empty, the content of the repo's
            ``refs/main`` file is used for the custom cache lookup.

    Returns:
        The snapshot directory path when the download can be skipped,
        otherwise None.
    """
    if os.path.isdir(model_name_or_path):
        return None

    found_local_snapshot_dir = None

    # Check custom cache_dir (if provided)
    if cache_dir:
        try:
            # Repo folder name is "models" + the repo id parts, joined by
            # the hub's REPO_ID_SEPARATOR.
            repo_folder = os.path.join(
                cache_dir,
                huggingface_hub.constants.REPO_ID_SEPARATOR.join(
                    ["models", *model_name_or_path.split("/")]
                ),
            )
            rev_to_use = revision
            if not rev_to_use:
                # No explicit revision: resolve it from refs/main.
                ref_main = os.path.join(repo_folder, "refs", "main")
                if os.path.isfile(ref_main):
                    with open(ref_main) as f:
                        rev_to_use = f.read().strip()
            if rev_to_use:
                rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
                if os.path.isdir(rev_dir):
                    found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well
    if not found_local_snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    # if local snapshot exists, validate it contains at least one weight file
    # matching allow_patterns before skipping download.
    if found_local_snapshot_dir is None:
        return None

    # Check if snapshot dir exists (might have been cleaned by another process
    # before we acquired the lock)
    if not os.path.isdir(found_local_snapshot_dir):
        return None

    local_weight_files: List[str] = []
    try:
        for pattern in allow_patterns:
            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
            for f in matched_files:
                # os.path.exists returns False for broken symlinks.
                if not os.path.exists(f):
                    continue
                local_weight_files.append(f)
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            found_local_snapshot_dir,
            allow_patterns,
            e,
        )
        # Treat a failed scan like an empty snapshot so a download is attempted.
        local_weight_files = []

    # Check for missing files from index (lightweight, for all users)
    # This catches incomplete downloads before they cause cryptic load errors
    if local_weight_files:
        is_complete, error_msg = _check_index_files_exist(found_local_snapshot_dir)
        if not is_complete:
            log_info_on_rank0(
                logger,
                f"Local snapshot incomplete for {model_name_or_path}: {error_msg}. "
                f"Will download missing files.",
            )
            return None  # Triggers snapshot_download() which handles partial downloads

    # Only perform cache validation and cleanup in CI to avoid
    # unnecessary overhead for regular users
    if is_in_ci() and local_weight_files:
        is_valid = ci_validate_and_cleanup_local_snapshot(
            model_name_or_path, found_local_snapshot_dir, local_weight_files
        )
        if not is_valid:
            return None

    if len(local_weight_files) > 0:
        log_info_on_rank0(
            logger,
            f"Found local HF snapshot for {model_name_or_path} at "
            f"{found_local_snapshot_dir}; skipping download.",
        )
        return found_local_snapshot_dir
    else:
        log_info_on_rank0(
            logger,
            f"Local HF snapshot at {found_local_snapshot_dir} has no files matching "
            f"{allow_patterns}; will attempt download.",
        )
        return None


def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
    ignore_patterns: Optional[Union[str, List[str]]] = None,
    max_retries: int = 3,
) -> str:
    """Fetch model weights from the Hugging Face Hub unless already available.

    Args:
        model_name_or_path (str): Hub repo id, or a local directory (which is
            returned as-is without touching the Hub).
        cache_dir (Optional[str]): Cache directory for the weights; HF
            defaults apply when None.
        allow_patterns (List[str]): Weight file patterns in priority order.
            When online, only the first pattern that matches a file in the
            repo is downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, List[str]]]): Patterns of files
            to leave out of the download.
        max_retries (int): Retry budget passed to the CI download path, which
            re-downloads when corruption is detected. Defaults to 3.

    Returns:
        str: Directory containing the model weights.
    """
    # Local directories need no Hub interaction at all.
    if os.path.isdir(model_name_or_path):
        return model_name_or_path

    # One lock spans cache validation, cleanup and download. Otherwise
    # process A could delete a corrupted file, process B could then see the
    # gap and wipe the whole cache, and A's download would find it gone.
    with get_lock(model_name_or_path, cache_dir):
        # A valid local snapshot (validated and cleaned up if needed) means
        # there is nothing to download.
        cached_dir = _find_local_hf_snapshot_dir_unlocked(
            model_name_or_path, cache_dir, allow_patterns, revision
        )
        if cached_dir is not None:
            return cached_dir

        # Offline mode skips the Hub listing; the local cache was already
        # checked above, so from here on a download is required.
        if not huggingface_hub.constants.HF_HUB_OFFLINE:
            # Inspect what the repo offers and narrow the request to the
            # first pattern that matches anything.
            remote_files = HfFileSystem().ls(
                model_name_or_path, detail=False, revision=revision
            )
            first_hit = next(
                (
                    candidate
                    for candidate in allow_patterns
                    if len(fnmatch.filter(remote_files, candidate)) > 0
                ),
                None,
            )
            if first_hit is not None:
                allow_patterns = [first_hit]

        log_info_on_rank0(logger, f"Using model weights format {allow_patterns}")

        if is_in_ci():
            # Validation and retry are CI-only to spare regular users the
            # overhead.
            return ci_download_with_validation_and_retry(
                model_name_or_path=model_name_or_path,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                cache_dir=cache_dir,
                revision=revision,
                max_retries=max_retries,
            )

        # Plain download, no validation, outside CI.
        return snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )


def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    A missing index file is not an error: only some models ship one, so the
    "not found" cases are logged at info level and swallowed.

    Args:
        model_name_or_path (str): The model name or path.
        index_file (str): Name of the index file inside the repository.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=index_file,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have index_file.
        # The local-cache error must be handled first: in huggingface_hub,
        # LocalEntryNotFoundError derives from EntryNotFoundError, so listing
        # the base class first would shadow this handler and log "remote"
        # for an offline cache miss.
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)


# Some checkpoints (e.g. Mistral-7B-v0.3) ship both sharded safetensors
# files and a single consolidated safetensors file. Feeding both to the
# weight loader breaks, so the index file is used to decide which
# safetensors files should actually be loaded.
def filter_duplicate_safetensors_files(
    hf_weights_files: List[str], hf_folder: str, index_file: str
) -> List[str]:
    """Drop safetensors files that are not referenced by the index file.

    Args:
        hf_weights_files: Candidate weight file paths. NOTE: when no index
            file exists and exactly two paths are given, this list is sorted
            in place.
        hf_folder: Folder containing the checkpoint.
        index_file: Index file name, relative to ``hf_folder``.

    Returns:
        With an index file: the candidates whose path equals
        ``hf_folder`` joined with a file named in ``weight_map``, in their
        original order. Without one: the candidates unchanged, except that a
        ``consolidated.safetensors`` + ``model.safetensors`` pair is reduced
        to the ``model.safetensors`` entry.
    """
    index_path = os.path.join(hf_folder, index_file)

    if os.path.isfile(index_path):
        # The index maps state_dict keys to the safetensors file that holds
        # each weight; collect the full paths of all referenced files.
        with open(index_path) as fp:
            weight_map = json.load(fp)["weight_map"]
        indexed_paths = {
            os.path.join(hf_folder, weight_map[weight_name])
            for weight_name in weight_map
        }
        return [path for path in hf_weights_files if path in indexed_paths]

    # NOTE: this is a trick of handling mistral model
    # skip the unsupported consolidated.safetensors file
    if len(hf_weights_files) == 2:
        hf_weights_files.sort()
        first, second = hf_weights_files
        if first.endswith("consolidated.safetensors") and second.endswith(
            "model.safetensors"
        ):
            return [second]
    return hf_weights_files


def maybe_add_mtp_safetensors(
    hf_weights_files: List[str], hf_folder: str, index_file: str, hf_config
) -> List[str]:
    """
    Auto-detect and add mtp.safetensors for GLM4Moe MTP/NextN models if:
    1. mtp.safetensors exists in the model directory
    2. mtp.safetensors is NOT in the index (checkpoint packaging bug)
    3. Model architecture is Glm4MoeForCausalLM with num_nextn_predict_layers > 0

    This works around incorrectly packaged FP4 checkpoints like
    baseten-admin/glm-4.7-fp4 where mtp.safetensors exists but
    isn't referenced in model.safetensors.index.json.

    Args:
        hf_weights_files: Weight file paths selected so far (not modified).
        hf_folder: Folder that may contain ``mtp.safetensors``.
        index_file: Index file name, used only in the warning message.
        hf_config: HF config object; ``architectures`` and
            ``num_nextn_predict_layers`` are read if present. A missing,
            ``None`` or empty value is treated as "not a GLM4Moe NextN model".

    Returns:
        ``hf_weights_files`` itself when nothing needs to be added, otherwise
        a new list with the ``mtp.safetensors`` path appended.
    """
    # Only apply for GLM4Moe architecture with nextn layers.
    # `architectures` may be absent, None, or empty on some configs; fall
    # back to a placeholder instead of failing on the [0] lookup.
    architectures = getattr(hf_config, "architectures", None) or [None]
    arch = architectures[0]
    # Treat an explicit None the same as the attribute being absent.
    num_nextn_layers = getattr(hf_config, "num_nextn_predict_layers", 0) or 0
    if not (
        arch in ["Glm4MoeForCausalLM", "Glm4MoeForCausalLMNextN"]
        and num_nextn_layers > 0
    ):
        return hf_weights_files

    # Check if mtp.safetensors exists and is not already in the file list
    mtp_path = os.path.join(hf_folder, "mtp.safetensors")
    if not os.path.isfile(mtp_path) or mtp_path in hf_weights_files:
        return hf_weights_files

    # mtp.safetensors exists but not in index - this is a bug
    logger.warning(
        "Found mtp.safetensors but it's not referenced in %s. "
        "This is a checkpoint packaging bug. Auto-adding it for loading. "
        "Please report this to the checkpoint provider.",
        index_file,
    )

    # Add it to the files list
    return hf_weights_files + [mtp_path]


def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
    """
    Return the given paths minus training-only artifacts.

    A path is dropped when it ends with one of the trainer file names below.
    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    training_only_suffixes = (
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    )
    # str.endswith accepts a tuple and matches if any suffix matches.
    return [
        path for path in hf_weights_files if not path.endswith(training_only_suffixes)
    ]


def np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str],
    hf_folder: str,
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from a numpy dump of the checkpoint.

    The dump lives in ``<hf_folder>/np``: one ``np.save`` file per weight plus
    ``weight_names.json`` listing the weight names. If that JSON file does not
    exist yet, the dump is created first from ``hf_weights_files`` while
    holding the model's file lock.
    """
    show_progress = (
        torch.distributed.get_rank() == 0
        if torch.distributed.is_initialized()
        else True
    )
    # Weights are stored as numpy arrays for faster loading.
    np_folder = os.path.join(hf_folder, "np")
    os.makedirs(np_folder, exist_ok=True)
    weight_names_file = os.path.join(np_folder, "weight_names.json")

    # The lock keeps several processes from dumping the same weights at the
    # same time.
    with get_lock(model_name_or_path, cache_dir):
        if not os.path.exists(weight_names_file):
            dumped_names: List[str] = []
            shards = tqdm(
                hf_weights_files,
                desc="Loading np_cache checkpoint shards",
                disable=not show_progress,
                bar_format=BAR_FORMAT,
                position=tqdm._get_free_pos(),
            )
            for bin_file in shards:
                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                for name, param in state.items():
                    with open(os.path.join(np_folder, name), "wb") as out:
                        np.save(out, param.cpu().detach().numpy())
                    dumped_names.append(name)
            # Written last: its existence marks the dump as present.
            with open(weight_names_file, "w") as out:
                json.dump(dumped_names, out)

    with open(weight_names_file) as src:
        weight_names = json.load(src)

    for name in weight_names:
        with open(os.path.join(np_folder, name), "rb") as src:
            array = np.load(src)
        yield name, torch.from_numpy(array)


def safetensors_weights_iterator(
    hf_weights_files: List[str],
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from each safetensors file in turn.

    Args:
        hf_weights_files: Safetensors file paths, processed in list order.
        disable_mmap: When True, each file is read fully into memory and
            parsed with ``safetensors.torch.load`` (names yielded in sorted
            order). Otherwise the file is opened with ``safe_open`` and
            tensors are fetched one by one in the order of ``keys()``.
    """
    show_progress = (
        torch.distributed.get_rank() == 0
        if torch.distributed.is_initialized()
        else True
    )
    shards = tqdm(
        hf_weights_files,
        desc="Loading safetensors checkpoint shards",
        disable=not show_progress,
        bar_format=BAR_FORMAT,
        position=tqdm._get_free_pos(),
    )
    for st_file in shards:
        if not disable_mmap:
            with safetensors.safe_open(st_file, framework="pt", device="cpu") as handle:
                for name in handle.keys():
                    yield name, handle.get_tensor(name)
            continue

        with open(st_file, "rb") as raw:
            tensors = safetensors.torch.load(raw.read())
            for name in sorted(tensors.keys()):
                yield name, tensors[name]


def fastsafetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Yield ``(name, tensor)`` pairs loaded through the fastsafetensors library,
    which can accelerate loading via GPU Direct Storage (if available).

    Files are processed in batches of ``pg.size()``; within a batch, the i-th
    file is assigned to rank i. Tensors are placed on ``cuda:<rank>``.

    Raises:
        ImportError: If fastsafetensors is not installed.
    """
    if SafeTensorsFileLoader is None:
        raise ImportError(
            "Please install fastsafetensors via `pip install fastsafetensors`"
        )

    pg = (
        torch.distributed.group.WORLD
        if torch.distributed.is_initialized()
        else SingleGroup()
    )

    try:
        rank = pg.rank()
    except Exception:
        # Fall back to rank 0 when the group cannot report a rank.
        rank = 0

    device = torch.device(f"cuda:{rank}")

    # Split the file list into consecutive batches, one file per rank.
    group_size = pg.size()
    weight_files_sub_lists = [
        hf_weights_files[start : start + group_size]
        for start in range(0, len(hf_weights_files), group_size)
    ]

    _BAR_FORMAT = (
        "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
    )

    batches = tqdm(
        weight_files_sub_lists,
        desc="Loading safetensors using Fastsafetensor loader",
        disable=False,
        bar_format=_BAR_FORMAT,
    )
    for batch in batches:
        loader = SafeTensorsFileLoader(pg, device)
        loader.add_filenames({idx: [path] for idx, path in enumerate(batch)})
        try:
            fb = loader.copy_files_to_device()
            # The returned buffer is not explicitly closed here; only the
            # loader is closed once the batch has been consumed.
            for key in list(fb.key_to_rank_lidx.keys()):
                yield key, fb.get_tensor(key)
        finally:
            loader.close()


def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Load safetensors files on a thread pool and yield ``(name, tensor)``.

    All files are submitted up front and consumed in completion order, so the
    order of yielded shards is not the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Safetensors file paths.
        max_workers: Thread pool size.
        disable_mmap: When True, read each file fully and parse the bytes
            instead of using ``safetensors.torch.load_file``.
    """
    show_progress = (
        torch.distributed.get_rank() == 0
        if torch.distributed.is_initialized()
        else True
    )

    def _read_shard(path: str):
        if not disable_mmap:
            return safetensors.torch.load_file(path, device="cpu")
        with open(path, "rb") as raw:
            return safetensors.torch.load(raw.read())

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        jobs = [pool.submit(_read_shard, path) for path in hf_weights_files]

        completed = concurrent.futures.as_completed(jobs)
        if show_progress:
            # Wrap the completion stream in a progress bar when enabled.
            completed = tqdm(
                completed,
                total=len(hf_weights_files),
                desc="Multi-thread loading shards",
                disable=not show_progress,
                bar_format=BAR_FORMAT,
            )

        for job in completed:
            yield from job.result().items()


def buffered_multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Threaded safetensors loader that keeps a bounded window of shards.

    Shards are submitted to the pool through a sliding window of
    (max_workers + 1) entries: max_workers loading concurrently plus one
    prefetched and ready to yield. Shards are consumed in the order of
    ``hf_weights_files`` and tensor names within a shard are yielded sorted.
    Peak CPU RAM ≈ (max_workers + 2) × shard_file_size.
    """
    show_progress = (
        torch.distributed.get_rank() == 0
        if torch.distributed.is_initialized()
        else True
    )

    def _read_shard(path: str):
        if disable_mmap:
            with open(path, "rb") as raw:
                tensors = safetensors.torch.load(raw.read())
        else:
            with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
                tensors = {key: handle.get_tensor(key) for key in handle.keys()}
        return tensors

    # Window size: max_workers loading + 1 prefetched.
    window = max_workers + 1

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        remaining = iter(hf_weights_files)
        in_flight: collections.deque = collections.deque()

        # Fill the window with the first shards.
        for path in itertools.islice(remaining, window):
            in_flight.append(pool.submit(_read_shard, path))

        with tqdm(
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not show_progress,
            bar_format=BAR_FORMAT,
            position=tqdm._get_free_pos(),
        ) as pbar:
            while in_flight:
                # Take the oldest submission without keeping a reference to
                # the Future, so its stored result can be reclaimed.
                tensors = in_flight.popleft().result()

                # Top the window back up before yielding from this shard.
                upcoming = next(remaining, None)
                if upcoming is not None:
                    in_flight.append(pool.submit(_read_shard, upcoming))

                for name in sorted(tensors.keys()):
                    yield name, tensors[name]
                del tensors
                pbar.update(1)


def _load_pt_file(bin_file: str) -> dict:
    """Load a PyTorch checkpoint file, handling legacy tar format.

    PyTorch 2.6 changed the default of weights_only from False to True.
    Legacy tar format files cannot be loaded with weights_only=True.
    This function tries weights_only=True first, then falls back to False
    for legacy tar format files from trusted sources (HuggingFace Hub).
    """
    try:
        return torch.load(bin_file, map_location="cpu", weights_only=True)
    except RuntimeError as e:
        if "legacy .tar format" in str(e):
            logger.warning(
                "Loading %s with weights_only=False (legacy tar format)",
                os.path.basename(bin_file),
            )
            return torch.load(bin_file, map_location="cpu", weights_only=False)
        raise


def pt_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, value)`` pairs from each bin/pt checkpoint file in turn.

    Files are loaded one at a time through ``_load_pt_file`` in the order
    given. The progress bar is enabled only when ``torch.distributed`` is not
    initialized or this process is rank 0.

    Args:
        hf_weights_files: Paths of the checkpoint files to read.

    Yields:
        The items of each loaded checkpoint dictionary.
    """
    if torch.distributed.is_initialized():
        show_progress = torch.distributed.get_rank() == 0
    else:
        show_progress = True

    progress = tqdm(
        hf_weights_files,
        desc="Loading pt checkpoint shards",
        disable=not show_progress,
        bar_format=BAR_FORMAT,
        position=tqdm._get_free_pos(),
    )
    for checkpoint_path in progress:
        shard = _load_pt_file(checkpoint_path)
        yield from shard.items()
        # Drop the shard before the next file is loaded.
        del shard


def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Multi-thread iterate over the weights in the model bin/pt files.

    Loading is bounded by a sliding window instead of submitting every file
    up front: at most ``max_workers + 1`` shard loads are submitted but not
    yet consumed (``max_workers`` loading concurrently plus one prefetched),
    and each shard's state dict is released as soon as its items have been
    yielded. Shards are yielded in the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Paths of the checkpoint files to read.
        max_workers: Number of loader threads; passed to
            ``ThreadPoolExecutor``.

    Yields:
        The items of each loaded checkpoint dictionary.
    """
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    # Sliding window: max_workers loading + 1 prefetched.
    buffer_size = max_workers + 1

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        file_iter = iter(hf_weights_files)

        # Seed the window with the first ``buffer_size`` files.
        pending: collections.deque = collections.deque(
            executor.submit(_load_pt_file, bin_file)
            for bin_file in itertools.islice(file_iter, buffer_size)
        )

        with tqdm(
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=BAR_FORMAT,
        ) as pbar:
            while pending:
                future = pending.popleft()
                state = future.result()
                # Drop the Future so it no longer pins the loaded state dict.
                del future

                # Replenish: submit the next file to keep the window full.
                next_file = next(file_iter, None)
                if next_file is not None:
                    pending.append(executor.submit(_load_pt_file, next_file))

                yield from state.items()
                del state
                pbar.update(1)


def get_gguf_extra_tensor_names(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]:
    """List the HF names whose GGUF key is mapped but absent from the file.

    Args:
        gguf_file: Path of the GGUF file to inspect.
        gguf_to_hf_name_map: Mapping from GGUF tensor names to HF names.

    Returns:
        The mapped names for every key of ``gguf_to_hf_name_map`` that does
        not appear among the tensor names read from ``gguf_file``.
    """
    import gguf

    gguf_reader = gguf.GGUFReader(gguf_file)
    mapped_keys = set(gguf_to_hf_name_map.keys())
    keys_in_file = {entry.name for entry in gguf_reader.tensors}
    missing_keys = mapped_keys - keys_in_file
    return [gguf_to_hf_name_map[gguf_key] for gguf_key in missing_keys]


def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield the mapped tensors of a GGUF file as torch tensors, in two passes.

    Only tensors whose GGUF name is a key of ``gguf_to_hf_name_map`` are
    considered. The first pass yields, for every such tensor whose type is not
    ``F32``, a ``qweight_type`` entry holding the tensor type. The second pass
    yields the tensor data itself, renamed to ``qweight`` for non-``F32``
    tensors.

    Args:
        gguf_file: Path of the GGUF file to read.
        gguf_to_hf_name_map: Mapping from GGUF tensor names to HF names.
    """

    import gguf

    reader = gguf.GGUFReader(gguf_file)

    # Pass 1: quantization-type tags for every mapped non-F32 tensor.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        quant_type = entry.tensor_type
        if quant_type.name == "F32":
            continue
        hf_name = gguf_to_hf_name_map[entry.name]
        yield hf_name.replace("weight", "qweight_type"), torch.tensor(quant_type)

    # Pass 2: the tensor payloads themselves.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        hf_name = gguf_to_hf_name_map[entry.name]
        if entry.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(entry.data)


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors PySafeSlice as a ``torch.Tensor``.

    A PySafeSlice can be indexed before the data is read, which limits how
    much is loaded into memory, but it lacks tensor operations such as
    ``.view()`` or ``.t()``. Callers that need those operations convert first.

    Args:
        x: A ``torch.Tensor`` or an object supporting full-slice indexing.

    Returns:
        ``x`` itself when it is already a ``torch.Tensor``; otherwise the
        result of ``x[:]``.
    """
    return x if isinstance(x, torch.Tensor) else x[:]


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    When both tensors hold exactly one element, the value is written with
    ``fill_`` so that a shape difference between the two single-element
    tensors does not matter. Otherwise the sizes must match exactly.

    Args:
        param: Destination parameter; its ``.data`` is overwritten.
        loaded_weight: Source tensor read from the checkpoint.

    Raises:
        AssertionError: If the tensors are not both single-element and their
            sizes differ.
    """
    try:
        both_single_element = param.numel() == 1 and loaded_weight.numel() == 1
        if not both_single_element:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})"
            )
            param.data.copy_(loaded_weight)
            return

        # Scalars are sometimes not treated as tensors with shapes, so when
        # both sides are scalar-like the value is "broadcast" rather than copied.
        param.data.fill_(loaded_weight.item())
    except Exception:
        # NOTE: This handler exists only as a convenient place to set a
        # breakpoint when debugging weight loading issues.
        raise


def row_parallel_weight_loader(
    param: torch.Tensor, loaded_weight: torch.Tensor
) -> None:
    """Load this tensor-parallel rank's slice of a row-parallel weight.

    Unless ``param`` is 1-D, ``loaded_weight`` is narrowed along dimension 0
    to the block of ``param.shape[0]`` rows that starts at
    ``tp_rank * param.shape[0]``. A 1-D ``param`` receives ``loaded_weight``
    as is. The copy itself is delegated to ``default_weight_loader``.
    """
    tp_rank = get_tensor_model_parallel_rank()

    if param.dim() != 1:
        rows_per_rank = param.data.shape[0]
        loaded_weight = loaded_weight.narrow(
            0, tp_rank * rows_per_rank, rows_per_rank
        )

    return default_weight_loader(param, loaded_weight)


# Signature shared by weight loaders: ``(param, loaded_weight) -> None``.
# Loaders write into ``param`` in place and do not return a tensor.
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]


def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Build a weight loader that takes this rank's shard along ``shard_axis``.

    The returned loader uses the rank from ``get_attention_tp_rank()`` and the
    size of ``param`` along ``shard_axis`` to locate the shard inside
    ``loaded_weight``. On CPU, when ``loaded_weight`` is 1-D and its length is
    not divisible by the tensor-parallel world size, the shard is taken with
    ``narrow_padded_param_and_loaded_weight`` so uneven (padded) shards are
    handled; otherwise a plain ``narrow`` is used.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        rank = get_attention_tp_rank()
        per_rank = param.data.shape[shard_axis]
        offset = rank * per_rank

        needs_padding = (
            is_cpu()
            and loaded_weight.size(0) % get_tensor_model_parallel_world_size() != 0
            and loaded_weight.dim() == 1
        )
        if not needs_padding:
            shard = loaded_weight.narrow(shard_axis, offset, per_rank)
            return default_weight_loader(param, shard)

        # Work on a view of the parameter so the padded tail can be handled.
        padded_param, padded_weight = narrow_padded_param_and_loaded_weight(
            param.data,
            loaded_weight,
            0,  # param_data_start
            offset,
            shard_axis,
            per_rank,
        )
        return default_weight_loader(padded_param, padded_weight)

    return loader


def composed_weight_loader(
    loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
) -> LoaderFunction:
    """Wrap ``loader`` so that ``fn`` is applied to the parameter after loading.

    Args:
        loader: Loader that first writes the checkpoint weight into ``param``.
        fn: Post-processing function; it receives the loaded ``param`` and its
            result is copied back into ``param.data``.

    Returns:
        A loader with the same ``(param, loaded_weight)`` signature.
    """

    def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # Step 1: regular load.
        loader(param, loaded_weight)
        # Step 2: overwrite the parameter with its post-processed value.
        destination = param.data
        destination.copy_(fn(param))

    return composed_loader


def runai_safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield the tensors of each safetensors file via the Run:ai Model Streamer.

    A single ``SafetensorsStreamer`` is kept open for the whole iteration;
    each file is handed to ``stream_file`` and everything produced by
    ``get_tensors`` is yielded before moving to the next file. The progress
    bar is enabled only when ``torch.distributed`` is not initialized or this
    process is rank 0.
    """
    from runai_model_streamer import SafetensorsStreamer

    if torch.distributed.is_initialized():
        show_progress = torch.distributed.get_rank() == 0
    else:
        show_progress = True

    with SafetensorsStreamer() as streamer:
        progress = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            disable=not show_progress,
            bar_format=BAR_FORMAT,
            position=tqdm._get_free_pos(),
        )
        for shard_path in progress:
            streamer.stream_file(shard_path)
            yield from streamer.get_tensors()


def set_runai_streamer_env(load_config: LoadConfig):
    """Export Run:ai streamer settings as environment variables.

    From ``load_config.model_loader_extra_config`` (when truthy), the
    ``concurrency`` and ``memory_limit`` entries are written to
    ``RUNAI_STREAMER_CONCURRENCY`` and ``RUNAI_STREAMER_MEMORY_LIMIT``, but
    only when present and of type ``int``. Independently of that,
    ``RUNAI_STREAMER_S3_ENDPOINT`` is set from ``AWS_ENDPOINT_URL`` when the
    former is unset and the latter is set.
    """
    extra_config = load_config.model_loader_extra_config
    if extra_config:
        option_to_env = (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        )
        for option, env_name in option_to_env:
            if option in extra_config and isinstance(extra_config.get(option), int):
                os.environ[env_name] = str(extra_config.get(option))

    # An explicitly configured streamer endpoint always wins.
    if os.getenv("RUNAI_STREAMER_S3_ENDPOINT") is not None:
        return

    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
    if aws_endpoint_url is not None:
        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    """Fill every floating-point tensor of ``model`` with uniform random values.

    Random weights are needed for accurate performance measurements, and they
    should not produce NaNs in the forward pass; values between -1e-3 and 1e-3
    were found empirically to work well for most models.

    A fresh generator seeded with ``seed`` is created for each tensor, so the
    dummy weights stay consistent even when the model is partitioned across
    devices: for a fixed seed the values depend only on the tensor's number of
    elements and its data type.

    Args:
        model: Module whose ``state_dict()`` tensors are overwritten in place.
        low: Lower bound passed to ``uniform_``.
        high: Upper bound passed to ``uniform_``.
        seed: Seed applied to each per-tensor generator.
    """
    for tensor in model.state_dict().values():
        if not torch.is_floating_point(tensor):
            continue

        rng = torch.Generator(device=tensor.data.device)
        rng.manual_seed(seed)

        target_dtype = tensor.data.dtype
        if torch.finfo(target_dtype).bits >= 16:
            tensor.uniform_(low, high, generator=rng)
            continue

        # uniform_ doesn't support < 16-bit datatypes (FP8): sample in
        # float16, cast back, then copy into the original storage.
        staging = tensor.data.to(torch.float16)
        staging = staging.uniform_(low, high, generator=rng).to(target_dtype)
        tensor.data.copy_(staging)


def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Remap the name of FP8 k/v_scale parameters.

    Checkpoint names are handled in this order:

    1. ``*.kv_scale`` (deprecated): remapped to ``*.attn.k_scale``.
    2. ``*.k_scale`` / ``*.v_scale``: remapped to ``*.attn.{k,v}_scale``. When
       the scale sits under ``k_proj``/``v_proj`` below ``.self_attn.`` or
       ``.mixer.`` (modelopt style), the ``{k,v}_proj`` component is replaced
       by ``attn``.
    3. Quark ``*.{q,k,v}_proj.output_scale`` and
       ``*self_attn.prob_output_scale``: remapped to ``*.attn.{q,k,v}_scale``
       and ``*self_attn.attn.prob_scale``.

    For cases 1 and 2 the remapped name must exist in ``params_dict``;
    otherwise a warning is printed once and None is returned. Case 3 does not
    consult ``params_dict``.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        print_warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale"
        )
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            print_warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded."
            )
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    # Patterns where modelopt stores scales under k_proj/v_proj
    # but the model expects them under attn (RadixAttention)
    modelopt_attn_prefixes = [".self_attn.", ".mixer."]
    for scale_name in possible_scale_names:
        if not name.endswith(scale_name):
            continue

        # ``scale_name[1]`` is "k" or "v", giving e.g. "k_proj.k_scale".
        proj_suffix = f"{scale_name[1]}_proj{scale_name}"
        matched_prefix = next(
            (
                attn_prefix
                for attn_prefix in modelopt_attn_prefixes
                if f"{attn_prefix}{proj_suffix}" in name
            ),
            None,
        )

        if matched_prefix is not None:
            # modelopt-style scale stored under k_proj/v_proj
            remapped_name = name.replace(
                f"{matched_prefix}{proj_suffix}",
                f"{matched_prefix}attn{scale_name}",
            )
        else:
            remapped_name = name.replace(scale_name, f".attn{scale_name}")

        if remapped_name not in params_dict:
            print_warning_once(
                f"Found {scale_name} in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). {scale_name} is "
                "not loaded."
            )
            return None
        return remapped_name

    # Each entry: (suffix the checkpoint name must end with,
    #              substring to replace, replacement).
    # For prob_output_scale only the ".prob_output_scale" part is replaced so
    # the "self_attn" component is kept; replacing the whole
    # "self_attn.prob_output_scale" suffix with ".attn.prob_scale" would drop
    # it and leave a ".." in the name.
    quark_scale_remaps = (
        (".q_proj.output_scale", ".q_proj.output_scale", ".attn.q_scale"),
        (".k_proj.output_scale", ".k_proj.output_scale", ".attn.k_scale"),
        (".v_proj.output_scale", ".v_proj.output_scale", ".attn.v_scale"),
        ("self_attn.prob_output_scale", ".prob_output_scale", ".attn.prob_scale"),
    )
    for required_suffix, old_part, new_part in quark_scale_remaps:
        if name.endswith(required_suffix):
            return name.replace(old_part, new_part)

    # If there were no matches, return the untouched param name
    return name


# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
class KVCacheQuantSchema(BaseModel):
    # Schema of the serialized KV cache quantization parameters.
    # The validators below read ``tp_size``, ``tp_rank`` and
    # ``num_hidden_layers`` from the validation context; when no context is
    # supplied, the context-dependent checks are skipped.
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        """Require the scales to target the ``float8_e4m3fn`` KV cache dtype."""
        assert self.dtype == "float8_e4m3fn", (
            "Loaded scaling factors intended for KV cache dtype = "
            f"{self.dtype} rather than float8_e4m3fn!"
        )
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check the rank count, per-rank layer count, and rank keys 0..tp_size-1."""
        ctx = info.context
        if not ctx:
            return self

        tp_size = ctx["tp_size"]
        num_hidden_layers = ctx["num_hidden_layers"]

        loaded_tp_size = len(self.scaling_factor)
        assert loaded_tp_size == tp_size, (
            f"Loaded dictionary has TP size {loaded_tp_size} "
            f"but LLM engine is currently running with TP size {tp_size}."
        )

        for tp_rank, layer_maps in self.scaling_factor.items():
            loaded_layers = len(layer_maps)
            assert loaded_layers == num_hidden_layers, (
                f"KV cache scales map for TP rank {tp_rank} is malformed. "
                f"Expected {num_hidden_layers} layers, got "
                f"{loaded_layers}."
            )

        for rank in range(tp_size):
            assert (
                rank in self.scaling_factor
            ), f"KV cache scales map for TP rank {rank} not found."
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that the current TP rank has a scale for every layer index."""
        ctx = info.context
        if not ctx:
            return self

        tp_rank = ctx["tp_rank"]
        num_hidden_layers = ctx["num_hidden_layers"]
        scales_for_rank = self.scaling_factor[tp_rank]
        for layer_idx in range(num_hidden_layers):
            assert layer_idx in scales_for_rank, (
                f"Could not find KV cache scales for layer {layer_idx} in "
                f"TP rank {tp_rank}."
            )
        return self


class QuantParamSchema(BaseModel):
    # Top-level schema of the serialized quantization parameter file.
    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        """Compare ``model_type`` with the context's, when the context gives one."""
        ctx = info.context
        if not ctx:
            return self

        expected_type = ctx.get("model_type")
        if expected_type is None:
            return self

        assert expected_type == self.model_type, (
            f"Model type is {expected_type} but loaded "
            f"scaling factors belonging to different "
            f"model type {self.model_type}!"
        )
        return self


def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """Read serialized KV cache scaling factors for one TP rank.

    The file is expected to hold JSON matching ``QuantParamSchema``: a
    dictionary whose keys are the TP ranks and whose values map layers to
    their KV cache scaling factors. The arguments other than ``filename`` are
    passed to the schema as validation context.

    Any failure (missing file, invalid JSON, validation error, ...) is logged
    rather than raised, and an empty list is returned so that callers fall
    back to scaling factors of 1.0.

    Args:
        filename: Path of the JSON file to read.
        tp_rank: Tensor-parallel rank whose scales are requested.
        tp_size: Tensor-parallel size the file is validated against.
        num_hidden_layers: Number of layers the file is validated against.
        model_type: Model type the file is validated against, or None.

    Returns:
        ``(layer index, scaling factor)`` pairs for ``tp_rank``, or an empty
        list if loading failed.
    """
    context = {
        "model_type": model_type,
        "num_hidden_layers": num_hidden_layers,
        "tp_rank": tp_rank,
        "tp_size": tp_size,
    }
    try:
        with open(filename) as f:
            schema_dct = json.load(f)
        schema = QuantParamSchema.model_validate(schema_dct, context=context)
        layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
        return layer_scales_map.items()
    except FileNotFoundError:
        logger.error("File or directory '%s' not found.", filename)
    except json.JSONDecodeError:
        logger.error("Error decoding JSON in file '%s'.", filename)
    except Exception:
        # Log the traceback as well: this branch covers schema validation
        # failures, whose cause would otherwise be lost.
        logger.exception("An error occurred while reading '%s'.", filename)
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 for all "
        "layers in TP rank %d as an error occurred during loading.",
        tp_rank,
    )
    return []


def get_actual_shard_size(shard_size, weight_start, weight_end):
    """Return how much of a shard lies inside ``[weight_start, weight_end)``.

    The result is ``shard_size`` capped at ``weight_end - weight_start``, or 0
    when ``weight_end`` is smaller than ``weight_start``.
    """
    available = weight_end - weight_start
    return 0 if available < 0 else min(shard_size, available)


def reset_param_data_if_needed(param_data, dim, start, length):
    """Zero ``length`` entries of ``param_data`` along ``dim`` from ``start``.

    A ``length`` of 0 is a no-op.

    Raises:
        AssertionError: If ``length`` is negative.
    """
    if length != 0:
        assert length > 0, f"Length should be positive, but got {length}"
        param_data.narrow(dim, start, length).zero_()


def narrow_padded_param_and_loaded_weight(
    param_data,
    loaded_weight,
    param_data_start,
    weight_start,
    dim,
    shard_size,
    narrow_weight=True,
):
    """Narrow a padded parameter and its checkpoint weight to the real shard.

    The real shard size is ``shard_size`` capped by what remains of
    ``loaded_weight`` along ``dim`` after ``weight_start``. The part of
    ``param_data`` between the real shard and the full ``shard_size`` (the
    padding) is zeroed in place.

    Args:
        param_data: Destination tensor; its padded region is zeroed in place.
        loaded_weight: Source tensor from the checkpoint.
        param_data_start: Offset of the shard inside ``param_data`` along ``dim``.
        weight_start: Offset of the shard inside ``loaded_weight`` along ``dim``.
        dim: Dimension along which both tensors are narrowed.
        shard_size: Nominal (padded) shard size.
        narrow_weight: When False, ``loaded_weight`` is returned unmodified.

    Returns:
        ``(param_data, loaded_weight)`` where ``param_data`` is narrowed to the
        real shard and ``loaded_weight`` is narrowed likewise (or replaced by
        an empty zero tensor when there is no real data) if ``narrow_weight``
        is set.
    """
    real_size = get_actual_shard_size(
        shard_size, weight_start, loaded_weight.size(dim)
    )

    if narrow_weight and real_size > 0:
        loaded_weight = loaded_weight.narrow(dim, weight_start, real_size)
    elif narrow_weight:
        # No real data to load; stand in with zeros shaped like the (empty)
        # destination slice.
        empty_slice = param_data.narrow(dim, param_data_start, real_size)
        loaded_weight = torch.zeros_like(empty_slice)

    # [Note] Reset padded weights to zero.
    # When the real shard is smaller than shard_size, the remaining padded
    # region of param_data is cleared; the loaded weight is copied into the
    # narrowed view returned below.
    padding_start = param_data_start + real_size
    padding_length = shard_size - real_size
    reset_param_data_if_needed(param_data, dim, padding_start, padding_length)

    narrowed_param = param_data.narrow(dim, param_data_start, real_size)
    return narrowed_param, loaded_weight
