import functools
import os
from pathlib import Path
from typing import Any, Dict

import polars as pl
import torch


class DumpLoader:
    """Load previously dumped objects from disk, looked up by name and metadata.

    The loader is active only when the ``SGLANG_DUMP_LOADER_DIR`` environment
    variable is set; in that case the directory is indexed once, at
    construction time, via ``read_meta``.
    """

    def __init__(self):
        dump_dir = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = dump_dir is not None
        if not self._enable:
            # Disabled: no directory to index, so nothing else is set up.
            return

        self._directory = Path(dump_dir)
        self._df = read_meta(dump_dir)

    @property
    def enable(self):
        """Whether a dump directory was configured through the environment."""
        return self._enable

    def load(self, name, **kwargs):
        """Load the dumped object recorded under ``name`` for the current pass.

        Args:
            name: Value of the ``name`` field the dump was saved with.
            **kwargs: Extra metadata fields the dump file must match.

        Returns:
            The deserialized object. When the file holds a dict with a
            ``"value"`` key, the content of that key is returned instead of
            the dict itself.

        Raises:
            AssertionError: If the loader is disabled, or if ``find_row``
                does not yield a row for the query.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily, inside the method, rather than at module import.
        from sglang.srt.debug_utils.dumper import dumper

        # The dumper's current forward pass id is always part of the query.
        query = dict(name=name, forward_pass_id=dumper._forward_pass_id, **kwargs)
        matched = find_row(self._df, conditions=query)
        assert (
            matched is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / matched["filename"]
        # NOTE: weights_only=False lets torch unpickle arbitrary objects, so
        # the dump directory must contain trusted files only.
        loaded = torch.load(path, weights_only=False)
        is_wrapped = isinstance(loaded, dict) and "value" in loaded
        result = loaded["value"] if is_wrapped else loaded

        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(result)})"
        )
        return result


def read_meta(directory):
    """Build a metadata table describing every ``*.pt`` file in ``directory``.

    Each file stem is parsed as ``key=value`` pairs joined by ``___``. The
    pairs become columns (initially strings) next to a ``filename`` column;
    ``forward_pass_id``, ``rank`` and ``dump_index`` are then cast to
    integers and a ``duplicate_index`` column is added.

    Args:
        directory: Path (``str`` or ``Path``) of the dump directory.

    Returns:
        A ``pl.DataFrame`` with one row per successfully parsed file, sorted
        by ``rank`` then ``dump_index``. When no file could be parsed, an
        empty frame with the core columns is returned so that later lookups
        simply find nothing.

    Raises:
        AssertionError: If ``directory`` is not an existing directory.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        try:
            # maxsplit=1 keeps values that themselves contain "=" intact;
            # a piece with no "=" at all still raises ValueError and the
            # file is skipped.
            full_kwargs = dict(kv.split("=", 1) for kv in p.stem.split("___"))
        except ValueError as e:
            print(f"[DumpLoader] skip loading {p} due to error {e}")
            continue
        rows.append({"filename": p.name, **full_kwargs})

    if not rows:
        # Without any row the frame would have no columns and the casts
        # below would fail with an opaque "column not found" error.
        print(f"[DumpLoader] no dump file could be parsed in {directory=}")
        return pl.DataFrame(
            schema={
                "filename": pl.String,
                "forward_pass_id": pl.Int64,
                "rank": pl.Int64,
                "dump_index": pl.Int64,
                "duplicate_index": pl.UInt32,
            }
        )

    # Scan every row when inferring the schema: glob order is arbitrary, so
    # a key may first appear beyond the default inference window.
    df = pl.DataFrame(rows, infer_schema_length=None)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    df = _add_duplicate_index(df)
    df = df.sort("rank", "dump_index")
    return df


def _add_duplicate_index(df: pl.DataFrame) -> pl.DataFrame:
    """Append a ``duplicate_index`` column numbering rows with equal metadata.

    Rows are grouped by every column except ``filename`` and ``dump_index``.
    After sorting by those columns and then ``dump_index``, each row receives
    the running count of ``dump_index`` within its group, minus one
    (presumably yielding 0, 1, 2, ... per group -- this relies on
    ``pl.cum_count`` starting at 1; confirm for the polars version in use).
    """
    excluded = ("filename", "dump_index")
    keys = [column for column in df.columns if column not in excluded]
    occurrence = (
        pl.cum_count("dump_index").over(keys).sub(1).alias("duplicate_index")
    )
    return df.sort(keys + ["dump_index"]).with_columns(occurrence)


def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` that satisfies every condition.

    Args:
        df: The ``pl.DataFrame`` to search.
        conditions: Mapping of column name to required value. A ``None``
            value matches null cells; any other value is converted to the
            column's dtype before the equality comparison. Keys that are not
            columns of ``df`` are ignored.

    Returns:
        The matching row as a ``dict``, or ``None`` when no row matches or
        when more than one row matches (the ambiguous rows are printed).
    """
    columns = set(df.columns)
    predicates = [
        (
            pl.col(col) == _cast_to_polars_dtype(value, df.schema[col])
            if value is not None
            else pl.col(col).is_null()
        )
        for col, value in conditions.items()
        if col in columns
    ]

    if predicates:
        df_sub = df.filter(functools.reduce(lambda a, b: a & b, predicates))
    else:
        # reduce() without an initial value raises TypeError on an empty
        # sequence. With no applicable condition nothing constrains the
        # search, so every row is a candidate.
        df_sub = df

    if len(df_sub) > 1:
        print(f"find_row find ambiguous results: {df_sub=}")
        return None
    return df_sub.to_dicts()[0] if len(df_sub) > 0 else None


def _cast_to_polars_dtype(value, target_dtype):
    """Convert ``value`` to the Python type matching a polars dtype.

    64/32-bit signed and unsigned integer dtypes map to ``int``, 64/32-bit
    float dtypes to ``float``, ``pl.Boolean`` to ``bool`` and ``pl.String``
    to ``str``. For any other dtype ``value`` is returned unchanged.
    """
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    if target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    if target_dtype == pl.Boolean:
        return bool(value)
    if target_dtype == pl.String:
        return str(value)
    # No conversion rule for this dtype: hand the value back untouched.
    return value


# Module-level shared instance. It is constructed at import time, so the
# SGLANG_DUMP_LOADER_DIR environment variable is read -- and, when it is set,
# the dump directory is scanned by read_meta -- as soon as this module is
# imported.
dump_loader = DumpLoader()
