"""Common utilities for testing and benchmarking"""

import argparse
import asyncio
import copy
import doctest
import inspect
import json
import logging
import os
import random
import re
import shlex
import subprocess
import sys
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial, wraps
from io import BytesIO
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Awaitable, Callable, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image

from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.srt.environ import envs
from sglang.srt.utils import (
    get_bool_env_var,
    get_device,
    is_cuda,
    is_port_available,
    is_xpu,
    kill_process_tree,
    retry,
)
from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback

# General test models
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_SCORE = "Qwen/Qwen3-Reranker-0.6B"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE = "Qwen/Qwen1.5-MoE-A2.7B"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT = "Qwen/Qwen1.5-MoE-A2.7B-Chat"

# Embedding and cross-encoder test models
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"

# MLA test models
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"

# Hybrid Mamba models
DEFAULT_HYBRID_MAMBA_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-Next-80B-A3B-Instruct"

# VL test models
DEFAULT_MODEL_NAME_FOR_TEST_VL_PP = "Qwen/Qwen3-VL-2B-Thinking"
DEFAULT_MODEL_NAME_FOR_TEST_GLM_41V_PP = "zai-org/GLM-4.1V-9B-Thinking"

# NVFP4 models
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-V3-0324-FP4"
DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4 = "nvidia/Qwen3-30B-A3B-FP4"

# FP8 models
DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST_FP8 = (
    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
)
DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8 = (
    "nvidia/Llama-3.1-8B-Instruct-FP8"
)
DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 = "Qwen/Qwen3-1.7B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE = "gaunernst/DeepSeek-V2-Lite-Chat-FP8"

# MXFP4 models
# Standard MXFP4 MoE test model
DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE = "openai/gpt-oss-20b"

# W8A8 models
DEFAULT_MODEL_NAME_FOR_TEST_W8A8 = "RedHatAI/Llama-3.2-3B-quantized.w8a8"
DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"

# INT4 models
DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
)

# EAGLE2 algorithm models
DEFAULT_TARGET_MODEL_EAGLE = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_DRAFT_MODEL_EAGLE = "lmsys/sglang-EAGLE-llama2-chat-7B"

# EAGLE3 model
DEFAULT_TARGET_MODEL_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_DRAFT_MODEL_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"

# EAGLE2 with DP-Attention models
DEFAULT_TARGET_MODEL_EAGLE_DP_ATTN = "Qwen/Qwen3-30B-A3B"
DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN = "Tengyunw/qwen3_30b_moe_eagle3"

# Standalone speculative decoding models
DEFAULT_TARGET_MODEL_STANDALONE = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_DRAFT_MODEL_STANDALONE = "meta-llama/Llama-3.2-1B-Instruct"

# N-gram speculative decoding models
DEFAULT_TARGET_MODEL_NGRAM = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Other use cases
# NOTE: this constant is a tuple of model names, not a single string.
DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST = (
    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc",  # auto_round:auto_gptq
    "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound",  # auto_round:auto_awq
)
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
    "meta-llama/Llama-4-Scout-17B-16E-Instruct"
)
# NOTE: re-assigns the embedding constant already defined in the embedding
# section above, with the same value; the duplication is harmless.
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEFAULT_DEEPEP_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-V3-0324"
DEFAULT_DEEPEP_MODEL_NAME_FOR_TEST_NEXTN = "lmsys/DeepSeek-V3-NextN"
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
)
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-30B-A3B"
DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST = "Barrrrry/DeepSeek-R1-W4AFP8"
DEFAULT_ENABLE_ROUTED_EXPERTS_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-30B-A3B"

# Nightly tests (each value is a comma-separated list of model names)
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,zai-org/GLM-4.5-Air-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"

# Additional small test models
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-VL-3B-Instruct"

# Default image and video asset URLs
DEFAULT_IMAGE_URL = "https://raw.githubusercontent.com/sgl-project/sglang/main/examples/assets/example_image.png"
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"

# Baseline server-launch timeout (presumably seconds - confirm against callers);
# it is re-assigned further down in this module for specific CI platforms.
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600


def download_image_with_retry(image_url: str, max_retries: int = 3) -> Image.Image:
    """Download and fully decode an image, retrying on any failure.

    Each attempt issues an HTTP GET with a 30 second timeout, checks the
    status, and decodes the payload with PIL. Between failed attempts the
    function sleeps ``2**attempt`` seconds (1, 2, 4, ...).

    Args:
        image_url: URL of the image to fetch.
        max_retries: Total number of attempts to make.

    Returns:
        The loaded ``PIL.Image.Image``. If ``max_retries`` is not positive the
        loop never runs and ``None`` is returned implicitly.

    Raises:
        RuntimeError: When the final attempt fails; the underlying exception
            is attached as the cause.
    """
    attempt = 0
    while attempt < max_retries:
        try:
            resp = requests.get(image_url, timeout=30)
            resp.raise_for_status()
            picture = Image.open(BytesIO(resp.content))
            # Force decoding now so corrupt payloads fail inside the retry loop.
            picture.load()
            return picture
        except Exception as err:
            out_of_attempts = attempt == max_retries - 1
            if out_of_attempts:
                raise RuntimeError(
                    f"Failed to download image after {max_retries} retries: {image_url}"
                ) from err
            time.sleep(2**attempt)
        attempt += 1


def is_in_ci():
    """Tell whether the ``SGLANG_IS_IN_CI`` flag marks this as a CI runner."""
    in_ci = get_bool_env_var("SGLANG_IS_IN_CI")
    return in_ci


def is_in_amd_ci():
    """Tell whether the ``SGLANG_IS_IN_CI_AMD`` flag marks this as an AMD CI runner."""
    in_amd_ci = get_bool_env_var("SGLANG_IS_IN_CI_AMD")
    return in_amd_ci


def is_blackwell_system():
    """Tell whether the ``IS_BLACKWELL`` env setting flags a Blackwell (B200) system."""
    blackwell_flag = envs.IS_BLACKWELL
    return blackwell_flag.get()


def is_h200_system():
    """Tell whether the ``IS_H200`` env setting flags an H200 system."""
    h200_flag = envs.IS_H200
    return h200_flag.get()


def _use_cached_default_models(model_repo: str):
    """Resolve a model repo name to a directory inside the local model cache.

    The cache root is read from the ``DEFAULT_MODEL_CACHE_DIR`` environment
    variable.

    Args:
        model_repo: Repo identifier, joined onto the cache root as a sub-path.

    Returns:
        Absolute path of ``<cache root>/<model_repo>`` when that directory
        exists; otherwise an empty string (also when the variable is unset or
        ``model_repo`` is empty).
    """
    cache_root = os.getenv("DEFAULT_MODEL_CACHE_DIR")
    if not cache_root or not model_repo:
        return ""

    candidate = os.path.join(cache_root, model_repo)
    return os.path.abspath(candidate) if os.path.isdir(candidate) else ""


def _leading_visible_device_digit() -> int:
    """Return the first character of ``CUDA_VISIBLE_DEVICES`` as an integer.

    Only the first character is used (so ``"3,4"`` yields 3), matching how the
    port offsets below have always been derived. An unset variable counts as
    device 0. An empty value (commonly used to hide all GPUs) or a value that
    does not start with a digit (e.g. a ``GPU-<uuid>`` entry) also falls back
    to 0 instead of failing at import time.
    """
    try:
        return int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0])
    except (IndexError, ValueError):
        return 0


# Derive the base port from the first visible device so that runs pinned to
# different GPUs get different port ranges. CI uses a lower base and wider
# per-device stride than local runs.
if is_in_ci():
    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 10000 + _leading_visible_device_digit() * 2000
else:
    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 20000 + _leading_visible_device_digit() * 1000
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"

# Platform-specific launch timeouts. The checks run in order, so when several
# flags are set the last matching assignment wins.
if is_in_amd_ci():
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3600  # Match H200 timeout for large models

if is_blackwell_system():
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000

if is_h200_system():
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3600

def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
    """Send one generation request to a LightLLM-style endpoint.

    Args:
        prompt: Text sent as ``inputs``.
        temperature: Sampling temperature.
        max_tokens: Sent as ``max_new_tokens``.
        stop: Sent as ``stop_sequences``.
        url: Endpoint URL; must not be ``None``.

    Returns:
        The first element of the response's ``generated_text`` field.

    Raises:
        AssertionError: If ``url`` is ``None`` or the HTTP status is not 200.
    """
    assert url is not None

    parameters = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop_sequences": stop,
    }
    response = requests.post(url, json={"inputs": prompt, "parameters": parameters})
    assert response.status_code == 200
    return response.json()["generated_text"][0]


def find_available_port(base_port: int):
    """Probe for a free port starting from a random offset above ``base_port``.

    The first candidate is ``base_port`` plus a random value in [100, 1000].
    While the candidate is unavailable it is moved up by 42 when below 60000
    and down by 43 otherwise. The loop only ends once a port is available.

    Args:
        base_port: Port number the random starting offset is added to.

    Returns:
        The first candidate port reported available by ``is_port_available``.
    """
    candidate = base_port + random.randint(100, 1000)
    while not is_port_available(candidate):
        candidate += 42 if candidate < 60000 else -43
    return candidate


def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
    """Send one generation request to a vLLM-style endpoint.

    The response's ``text`` entries are expected to start with the prompt; the
    first ``len(prompt)`` characters are sliced off each completion.

    Args:
        prompt: Prompt text.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        stop: Stop value forwarded as-is.
        n: Number of completions requested.
        url: Endpoint URL; must not be ``None``.

    Returns:
        A single string when ``n == 1``, otherwise a list with one string per
        returned completion.

    Raises:
        AssertionError: If ``url`` is ``None`` or the HTTP status is not 200.
    """
    assert url is not None

    payload = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        n=n,
    )
    response = requests.post(url, json=payload)
    assert response.status_code == 200

    completions = response.json()["text"]
    prefix_len = len(prompt)
    if n == 1:
        return completions[0][prefix_len:]
    return [completion[prefix_len:] for completion in completions]


def call_generate_outlines(
    prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
):
    """Send one (optionally regex-constrained) generation request to an Outlines-style endpoint.

    The response's ``text`` entries are expected to start with the prompt; the
    first ``len(prompt)`` characters are sliced off each completion.

    Args:
        prompt: Prompt text.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        stop: Stop value forwarded as-is.
        regex: Regex forwarded in the ``regex`` field.
        n: Number of completions requested.
        url: Endpoint URL; must not be ``None``.

    Returns:
        A single string when ``n == 1``, otherwise a list with one string per
        returned completion.

    Raises:
        AssertionError: If ``url`` is ``None`` or the HTTP status is not 200.
    """
    assert url is not None

    payload = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        regex=regex,
        n=n,
    )
    response = requests.post(url, json=payload)
    assert response.status_code == 200

    completions = response.json()["text"]
    prefix_len = len(prompt)
    if n == 1:
        return completions[0][prefix_len:]
    return [completion[prefix_len:] for completion in completions]


def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
    """Send one raw generation request to an SRT ``/generate``-style endpoint.

    Args:
        prompt: Text sent in the ``text`` field.
        temperature: Sampling temperature.
        max_tokens: Sent as ``max_new_tokens``.
        stop: Stop value forwarded as-is.
        url: Endpoint URL; must not be ``None``.

    Returns:
        The ``text`` field of the JSON response.

    Raises:
        AssertionError: If ``url`` is ``None`` or the HTTP status is not 200.
    """
    assert url is not None

    sampling_params = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop": stop,
    }
    response = requests.post(
        url, json={"text": prompt, "sampling_params": sampling_params}
    )
    assert response.status_code == 200
    return response.json()["text"]


def call_generate_guidance(
    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
    """Generate ``n`` completions with a guidance model, one call per sample.

    Args:
        prompt: Prompt appended to the model state.
        temperature: Sampling temperature passed to ``gen``.
        max_tokens: Token budget passed to ``gen``.
        stop: Stop value passed to ``gen``.
        n: Number of samples to draw.
        regex: Regex passed to ``gen``.
        model: Guidance model object; must not be ``None``.

    Returns:
        A list of answers when ``n > 1``, otherwise the single answer.

    Raises:
        AssertionError: If ``model`` is ``None``.
    """
    assert model is not None
    from guidance import gen

    def _sample_once():
        state = model + prompt + gen(
            name="answer",
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop,
            regex=regex,
        )
        return state["answer"]

    answers = [_sample_once() for _ in range(n)]
    if n > 1:
        return answers
    return answers[0]


def call_select_lightllm(context, choices, url=None):
    """Issue one LightLLM request per choice and return the selected index.

    Every choice is given a constant score of 0 (the response body is not
    inspected), so for a non-empty ``choices`` the returned index is 0.

    Args:
        context: Shared prefix text.
        choices: Candidate continuations; each is appended to ``context``.
        url: Endpoint URL; must not be ``None``.

    Returns:
        ``np.argmax`` over the per-choice scores.

    Raises:
        AssertionError: If ``url`` is ``None`` or any HTTP status is not 200.
    """
    assert url is not None

    scores = []
    for choice in choices:
        payload = {
            "inputs": context + choice,
            "parameters": {
                "max_new_tokens": 1,
            },
        }
        response = requests.post(url, json=payload)
        assert response.status_code == 200
        # Placeholder score: nothing is read back from the response.
        scores.append(0)
    return np.argmax(scores)


def call_select_vllm(context, choices, url=None):
    """Score each choice through a vLLM endpoint and return the best index.

    One request is sent per choice with ``prompt_logprobs`` enabled; the
    ``prompt_score`` field of each response (0 when absent) is used as the
    score.

    Note:
        A stock vLLM server does not return ``prompt_score``. To use this
        helper, modify ``vllm/entrypoints/api_server.py``::

            if final_output.prompt_logprobs is not None:
                score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
                ret["prompt_score"] = score

    Args:
        context: Shared prefix text.
        choices: Candidate continuations; each is appended to ``context``.
        url: Endpoint URL; must not be ``None``.

    Returns:
        ``np.argmax`` over the per-choice scores.

    Raises:
        AssertionError: If ``url`` is ``None`` or any HTTP status is not 200.
    """
    assert url is not None

    scores = []
    for choice in choices:
        data = {
            "prompt": context + choice,
            "max_tokens": 1,
            "prompt_logprobs": 1,
        }
        res = requests.post(url, json=data)
        assert res.status_code == 200
        scores.append(res.json().get("prompt_score", 0))
    return np.argmax(scores)


def call_select_guidance(context, choices, model=None):
    """Let a guidance model pick one of ``choices`` and return its index.

    Args:
        context: Text appended to the model state before selecting.
        choices: Candidate options passed to ``guidance.select``.
        model: Guidance model object; must not be ``None``.

    Returns:
        Index within ``choices`` of the selected answer.

    Raises:
        AssertionError: If ``model`` is ``None``.
    """
    assert model is not None
    from guidance import select

    state = model + context + select(choices, name="answer")
    picked = state["answer"]
    return choices.index(picked)


def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
    """Register the shared non-sglang benchmark options and parse the command line.

    When ``--port`` is omitted, a per-backend default port is filled in;
    backends without an entry in the table keep ``port`` as ``None``.

    Args:
        parser: Parser to extend; ``parse_args()`` is called on it.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    supported_backends = [
        "vllm",
        "outlines",
        "lightllm",
        "gserver",
        "guidance",
        "srt-raw",
        "llama.cpp",
    ]

    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument(
        "--backend", type=str, required=True, choices=supported_backends
    )
    parser.add_argument("--n-ctx", type=int, default=4096)
    parser.add_argument(
        "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    parsed = parser.parse_args()

    if parsed.port is not None:
        return parsed

    fallback_ports = {
        "vllm": 21000,
        "outlines": 21000,
        "lightllm": 22000,
        "srt-raw": 30000,
        "gserver": 9988,
    }
    parsed.port = fallback_ports.get(parsed.backend, None)
    return parsed


def auto_config_device() -> str:
    """Pick the device reported by ``get_device()``, defaulting to CPU on failure.

    Returns:
        The device returned by ``get_device()``, or ``"cpu"`` if that call
        raises ``RuntimeError`` or ``ImportError`` (a warning is printed).
    """
    try:
        return get_device()
    except (RuntimeError, ImportError) as err:
        print(f"Warning: {err} - Falling back to CPU")
        return "cpu"


def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
    """Register the shared sglang benchmark options and parse the command line.

    Args:
        parser: Parser to extend; ``parse_args()`` is called on it.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        ("--parallel", dict(type=int, default=64)),
        ("--host", dict(type=str, default="http://127.0.0.1")),
        ("--port", dict(type=int, default=30000)),
        ("--backend", dict(type=str, default="srt")),
        (
            "--device",
            dict(
                type=str,
                default="auto",
                choices=["auto", "cuda", "rocm", "cpu"],
                help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
            ),
        ),
        ("--result-file", dict(type=str, default="result.jsonl")),
        ("--raw-result-file", dict(type=str)),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def select_sglang_backend(args: argparse.Namespace):
    """Build the sglang frontend backend named by ``args.backend``.

    Backends starting with ``"srt"`` produce a ``RuntimeEndpoint`` pointed at
    ``{args.host}:{args.port}``; the exact name ``"srt-no-parallel"`` also
    turns off ``global_config.enable_parallel_encoding``. Backends starting
    with ``"gpt-"`` produce an ``OpenAI`` backend for that name.

    Args:
        args: Parsed arguments providing ``backend``, ``host`` and ``port``.

    Returns:
        The constructed backend object.

    Raises:
        ValueError: If the backend name matches neither prefix.
    """
    from sglang.lang.backend.openai import OpenAI
    from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

    name = args.backend
    if name.startswith("srt"):
        if name == "srt-no-parallel":
            global_config.enable_parallel_encoding = False
        return RuntimeEndpoint(f"{args.host}:{args.port}")
    if name.startswith("gpt-"):
        return OpenAI(name)
    raise ValueError(f"Invalid backend: {args.backend}")


def _get_call_generate(args: argparse.Namespace):
    """Return the generation callable for ``args.backend``.

    HTTP backends get their ``call_generate_*`` function bound to
    ``{host}:{port}/generate``. The ``guidance`` backend loads a LlamaCpp
    model and runs one short generation before returning.

    Args:
        args: Parsed arguments providing ``backend``, ``host``, ``port`` and,
            for guidance, ``model_path`` and ``n_ctx``.

    Returns:
        A ``functools.partial`` wrapping the backend-specific function.

    Raises:
        ValueError: If the backend is not supported.
    """
    http_generators = {
        "lightllm": call_generate_lightllm,
        "vllm": call_generate_vllm,
        "srt-raw": call_generate_srt_raw,
        "outlines": call_generate_outlines,
    }

    backend = args.backend
    if backend in http_generators:
        endpoint = f"{args.host}:{args.port}/generate"
        return partial(http_generators[backend], url=endpoint)

    if backend == "guidance":
        from guidance import models

        llama = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        generate = partial(call_generate_guidance, model=llama)
        # One throwaway call before handing the callable out.
        generate("Hello,", 1.0, 8, ".")
        return generate

    raise ValueError(f"Invalid backend: {args.backend}")


def _get_call_select(args: argparse.Namespace):
    """Return the selection callable for ``args.backend``.

    HTTP backends get their ``call_select_*`` function bound to
    ``{host}:{port}/generate``. The ``guidance`` backend loads a LlamaCpp
    model and runs one selection before returning.

    Args:
        args: Parsed arguments providing ``backend``, ``host``, ``port`` and,
            for guidance, ``model_path`` and ``n_ctx``.

    Returns:
        A ``functools.partial`` wrapping the backend-specific function.

    Raises:
        ValueError: If the backend is not supported.
    """
    http_selectors = {
        "lightllm": call_select_lightllm,
        "vllm": call_select_vllm,
    }

    backend = args.backend
    if backend in http_selectors:
        endpoint = f"{args.host}:{args.port}/generate"
        return partial(http_selectors[backend], url=endpoint)

    if backend == "guidance":
        from guidance import models

        llama = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        choose = partial(call_select_guidance, model=llama)
        # One throwaway call before handing the callable out.
        choose("Hello,", ["world", "earth"])
        return choose

    raise ValueError(f"Invalid backend: {args.backend}")


def get_call_generate(args: argparse.Namespace):
    """Return the backend's generation callable wrapped with traceback logging.

    Args:
        args: Parsed arguments forwarded to ``_get_call_generate``.

    Returns:
        A function that forwards all arguments to the backend callable. On any
        ``Exception`` it prints the traceback and re-raises.
    """
    backend_generate = _get_call_generate(args)

    def func(*call_args, **call_kwargs):
        try:
            result = backend_generate(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_generate:\n" + get_exception_traceback())
            raise
        return result

    return func


def get_call_select(args: argparse.Namespace):
    """Return the backend's selection callable wrapped with traceback logging.

    Args:
        args: Parsed arguments forwarded to ``_get_call_select``.

    Returns:
        A function that forwards all arguments to the backend callable. On any
        ``Exception`` it prints the traceback and re-raises.
    """
    backend_select = _get_call_select(args)

    def func(*call_args, **call_kwargs):
        try:
            result = backend_select(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_select:\n" + get_exception_traceback())
            raise
        return result

    return func


def _get_default_models():
    """Collect the model names held in this module's ``DEFAULT_*MODEL_*`` string constants.

    Every module attribute whose name contains both ``"DEFAULT_"`` and
    ``"MODEL_"`` and whose value is a string is considered. Values are split
    on commas and each piece is stripped, so comma-separated lists contribute
    every entry. Non-string values (such as tuples) are ignored.

    Returns:
        A JSON array string of the de-duplicated names; the order is that of
        set iteration and is not guaranteed.
    """
    import inspect

    this_module = inspect.getmodule(_get_default_models)
    collected = set()
    for attr_name, attr_value in this_module.__dict__.items():
        if not (isinstance(attr_name, str) and isinstance(attr_value, str)):
            continue
        if "DEFAULT_" not in attr_name or "MODEL_" not in attr_name:
            continue
        # Splitting a comma-free value yields just that value, so a single
        # code path covers both single names and comma-separated lists.
        collected.update(piece.strip() for piece in attr_value.split(","))
    return json.dumps(list(collected))


def try_cached_model(model_repo: str):
    """Prefer a locally cached copy of ``model_repo`` when one exists.

    Args:
        model_repo: Model repo identifier.

    Returns:
        The cached directory from ``_use_cached_default_models`` if it is
        non-empty, otherwise ``model_repo`` unchanged.
    """
    cached_dir = _use_cached_default_models(model_repo)
    return cached_dir or model_repo


def popen_with_error_check(command: list[str], allow_exit: bool = False):
    """Start ``command`` and watch it from a background thread.

    The watcher thread collects stdout/stderr until the process ends and then
    raises an ``Exception`` (inside that thread) unless the exit was
    acceptable: a zero exit code with ``allow_exit=True``. With
    ``allow_exit=False`` any exit is treated as an error.

    Args:
        command: Argument list passed to ``subprocess.Popen``.
        allow_exit: Whether a clean (code 0) exit is acceptable.

    Returns:
        The started ``subprocess.Popen`` object (returned immediately).
    """
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def _run_and_check():
        # NOTE: the names ``stdout``/``stderr`` appear verbatim in the error
        # message through the ``{name=}`` f-string form.
        stdout, stderr = process.communicate()

        while process.poll() is None:
            time.sleep(5)

        exit_is_acceptable = allow_exit and process.returncode == 0
        if exit_is_acceptable:
            return
        raise Exception(
            f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
        )

    watcher = threading.Thread(target=_run_and_check)
    watcher.start()
    return process


def _try_enable_offline_mode_if_cache_complete(
    model_name_or_path: str, env: dict, other_args: Optional[list[str]] = None
) -> Optional[str]:
    """
    CI helper: turn on ``HF_HUB_OFFLINE=1`` when the local model cache validates.

    The decision is made per launch. A per-run marker for the snapshot is only
    a hint that it was validated earlier in this run; the lightweight
    validation is ALWAYS re-run against the current launch's requirements
    (e.g. hf_quant_config.json for modelopt quantization), with or without a
    marker.

    Offline mode is skipped (returning None) when:
      * LoRA flags are present in ``other_args``;
      * ``env`` already has HF_HUB_OFFLINE=1;
      * ``model_name_or_path`` is already a local directory;
      * no local snapshot directory can be found;
      * the lightweight validation fails.

    Args:
        model_name_or_path: Model identifier or path
        env: Environment dict to modify (HF_HUB_OFFLINE=1 is added if validation passes)
        other_args: Launch command arguments (used to detect LoRA and quantization)

    Returns:
        Per-run marker path if offline mode was enabled, None otherwise
    """
    from sglang.srt.model_loader.ci_weight_validation import (
        _get_per_run_marker_path,
        _read_per_run_marker,
        _write_per_run_marker,
        validate_cache_lightweight,
    )
    from sglang.srt.utils import find_local_repo_dir

    launch_args = other_args or []

    # LoRA: dynamic adapter loading may need online access.
    if "--enable-lora" in launch_args or "--lora-paths" in launch_args:
        print(f"CI_OFFLINE: LoRA enabled, skip offline mode - {model_name_or_path}")
        return None

    # Nothing to do when the subprocess env is already offline.
    if env.get("HF_HUB_OFFLINE") == "1":
        print(
            f"CI_OFFLINE: Subprocess env already has HF_HUB_OFFLINE=1, skip - {model_name_or_path}"
        )
        return None

    # A local directory does not go through the hub cache.
    if os.path.isdir(model_name_or_path):
        return None

    # Locate the local snapshot; any failure means "stay online".
    try:
        snapshot_dir = find_local_repo_dir(model_name_or_path, revision=None)
        if not snapshot_dir or not os.path.isdir(snapshot_dir):
            return None
    except Exception:
        return None

    # modelopt quantization needs hf_quant_config.json in the snapshot. This is
    # determined before looking at the marker so the current launch's
    # requirements are always known.
    modelopt_methods = ["modelopt_fp4", "modelopt_fp8", "modelopt"]
    requires_hf_quant_config = any(
        flag == "--quantization" and value.lower() in modelopt_methods
        for flag, value in zip(launch_args, launch_args[1:])
    )

    # The marker only changes logging and whether a new marker is written;
    # validation runs either way so that an earlier launch without
    # --quantization cannot enable offline mode for a later one that needs
    # hf_quant_config.json.
    marker_present = _read_per_run_marker(snapshot_dir) is not None
    cache_is_valid = validate_cache_lightweight(snapshot_dir, requires_hf_quant_config)

    if not cache_is_valid:
        if marker_present:
            print(
                f"CI_OFFLINE: Per-run marker found but current validation failed "
                f"(requires_hf_quant_config={requires_hf_quant_config}), "
                f"will use online mode - {model_name_or_path}"
            )
        else:
            print(
                f"CI_OFFLINE: Cache validation failed "
                f"(requires_hf_quant_config={requires_hf_quant_config}), "
                f"will use online mode - {model_name_or_path}"
            )
        return None

    env["HF_HUB_OFFLINE"] = "1"

    if marker_present:
        marker_path = _get_per_run_marker_path(snapshot_dir)
        print(
            f"CI_OFFLINE: Per-run marker found and current validation passed "
            f"(requires_hf_quant_config={requires_hf_quant_config}), "
            f"enabling offline mode - {model_name_or_path}"
        )
        return marker_path

    # First successful validation in this run: record it for later launches.
    _write_per_run_marker(snapshot_dir, model_name_or_path)

    # The caller may delete this marker if the offline launch fails.
    marker_path = _get_per_run_marker_path(snapshot_dir)

    snapshot_basename = os.path.basename(snapshot_dir)
    print(
        f"CI_OFFLINE: Enabled HF_HUB_OFFLINE=1 for subprocess - "
        f"validation passed for {model_name_or_path} "
        f"(snapshot={snapshot_basename}, requires_hf_quant_config={requires_hf_quant_config})"
    )

    return marker_path


def _create_clean_subprocess_env(env: dict) -> dict:
    """Copy ``env`` and strip the internal CI bookkeeping keys.

    Keys starting with ``'_CI_OFFLINE_'`` or ``'CI_OFFLINE_'`` are dropped so
    they do not leak into the server subprocess. The input dict is not
    modified.

    Args:
        env: Source environment dict

    Returns:
        Cleaned copy of the environment dict
    """
    internal_prefixes = ("_CI_OFFLINE_", "CI_OFFLINE_")
    child_env = env.copy()
    # Iterate over a snapshot of the keys so entries can be deleted in place.
    for key in list(child_env):
        if key.startswith(internal_prefixes):
            del child_env[key]
    return child_env


def _launch_server_process(
    command: List[str],
    env: dict,
    return_stdout_stderr: Optional[tuple],
    model: str,
) -> subprocess.Popen:
    """Start the server subprocess using a cleaned copy of ``env``.

    Without ``return_stdout_stderr`` the child inherits the parent's
    stdout/stderr. With it, both streams are piped in line-buffered text mode
    and two daemon threads copy each line to the matching capture sink and to
    the parent's own stream.

    Args:
        command: Command list for subprocess
        env: Environment dict (cleaned before use)
        return_stdout_stderr: Optional (stdout_sink, stderr_sink) pair; each
            sink must support ``write`` and ``flush``
        model: Model name for logging

    Returns:
        Started subprocess.Popen object
    """
    child_env = _create_clean_subprocess_env(env)

    hf_hub_offline = child_env.get("HF_HUB_OFFLINE", "0")
    print(f"CI_OFFLINE: Launching server HF_HUB_OFFLINE={hf_hub_offline} model={model}")

    if not return_stdout_stderr:
        return subprocess.Popen(command, stdout=None, stderr=None, env=child_env)

    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=child_env,
        text=True,
        bufsize=1,
    )

    def _tee(src, sinks):
        # Copy line by line until EOF (readline returns ""), flushing each
        # sink per line, then close the pipe.
        for line in iter(src.readline, ""):
            for sink in sinks:
                sink.write(line)
                sink.flush()
        src.close()

    pipes_and_sinks = (
        (proc.stdout, [return_stdout_stderr[0], sys.stdout]),
        (proc.stderr, [return_stdout_stderr[1], sys.stderr]),
    )
    for pipe, sinks in pipes_and_sinks:
        threading.Thread(target=_tee, args=(pipe, sinks), daemon=True).start()

    return proc


def _wait_for_server_health(
    proc: subprocess.Popen,
    base_url: str,
    api_key: Optional[str],
    timeout_duration: float,
) -> Tuple[bool, Optional[str]]:
    """Poll ``{base_url}/health_generate`` until it returns 200, the process dies, or time runs out.

    Each round checks that the process is still alive, issues one GET with a
    5 second timeout, re-checks the process, then sleeps 10 seconds. Request
    errors are treated as "not ready yet".

    Args:
        proc: Server subprocess
        base_url: Base URL for health check
        api_key: Optional API key; when empty or None no Authorization header
            is sent (previously the literal ``Bearer None`` was sent)
        timeout_duration: Maximum wait time in seconds

    Returns:
        Tuple of (success, error_message). ``error_message`` is None on
        success. Messages for a dead process contain the word "exited".
    """
    # The headers do not change between polls, so build them once.
    headers = {"Content-Type": "application/json; charset=utf-8"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    start_time = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start_time < timeout_duration:
            return_code = proc.poll()
            if return_code is not None:
                return False, f"Server process exited with code {return_code}"

            try:
                response = session.get(
                    f"{base_url}/health_generate",
                    headers=headers,
                    timeout=5,
                )
                if response.status_code == 200:
                    return True, None
            except requests.RequestException:
                # Server not accepting connections yet; keep polling.
                pass

            # Re-check so a crash during the request is reported right away
            # rather than after the sleep.
            return_code = proc.poll()
            if return_code is not None:
                return False, f"Server unexpectedly exited (return_code={return_code})"

            time.sleep(10)

    return False, "Server failed to start within the timeout period"


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: Optional[list[str]] = None,
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
    device: str = "auto",
    pd_separated: bool = False,
    num_replicas: Optional[int] = None,
) -> subprocess.Popen:
    """Launch a server process with automatic device detection and offline/online retry.

    The server is started, then polled via ``_wait_for_server_health`` until it
    is healthy. If the launch ran with ``HF_HUB_OFFLINE=1`` and did not become
    healthy, the process is cleaned up, the per-run cache marker (if any) is
    removed, and the launch is retried once with ``HF_HUB_OFFLINE=0``.

    Args:
        model: Model path or identifier
        base_url: Base URL for the server. Must have the form
            ``scheme://host:port``: it is split on ":" into exactly three parts
        timeout: Timeout for server startup (applied to each launch attempt)
        api_key: Optional API key for authentication
        other_args: Additional command line arguments
        env: Extra environment entries; they are merged over a copy of
            ``os.environ``
        return_stdout_stderr: Optional tuple for output capture
        device: Device type ("auto", "cuda", "rocm" or "cpu"). Only "auto"
            causes a ``--device`` argument to be appended
        pd_separated: Whether to use PD separated mode
        num_replicas: Number of replicas for mixed PD mode

    Returns:
        Started subprocess.Popen object

    Raises:
        Exception: If the final attempt failed because the server process exited
        TimeoutError: If the final attempt did not become healthy in time
    """
    other_args = other_args or []

    # Auto-detect device if needed
    if device == "auto":
        device = auto_config_device()
        # Copy first so the caller's list is not mutated by the append below.
        other_args = list(other_args)
        other_args += ["--device", str(device)]

    # Build the subprocess environment: always start from os.environ, with
    # caller-provided entries taking precedence.
    if env is None:
        env = os.environ.copy()
    else:
        merged = os.environ.copy()
        merged.update(env)
        env = merged

    # CI-specific: validate the cache and enable offline mode if complete.
    # The returned per-run marker path is kept for potential invalidation.
    # Any failure here is non-fatal: the launch simply proceeds as-is.
    per_run_marker_path = None
    try:
        # NOTE: this local import binds is_in_ci inside this function only.
        from sglang.utils import is_in_ci

        if is_in_ci():
            per_run_marker_path = _try_enable_offline_mode_if_cache_complete(
                model, env, other_args
            )
    except Exception as e:
        print(f"CI cache validation failed (non-fatal): {e}")

    # Build server command
    # "scheme://host:port" -> ["scheme", "//host", "port"]; drop the "//".
    _, host, port = base_url.split(":")
    host = host[2:]

    # Mixed PD mode: not PD-separated, but a replica count was given.
    use_mixed_pd_engine = not pd_separated and num_replicas is not None
    if pd_separated or use_mixed_pd_engine:
        command = "sglang.launch_pd_server"
    else:
        command = "sglang.launch_server"

    command = [
        "python3",
        "-m",
        command,
        "--model-path",
        model,
        *[str(x) for x in other_args],
    ]

    # The PD launcher takes the address as --lb-host/--lb-port.
    if pd_separated or use_mixed_pd_engine:
        command.extend(["--lb-host", host, "--lb-port", port])
    else:
        command.extend(["--host", host, "--port", port])

    if use_mixed_pd_engine:
        command.extend(["--mixed", "--num-replicas", str(num_replicas)])

    if api_key:
        command += ["--api-key", api_key]

    print(f"command={shlex.join(command)}")

    # Track if offline mode was enabled for potential retry
    offline_enabled = env.get("HF_HUB_OFFLINE") == "1"

    # First launch attempt
    process = _launch_server_process(command, env, return_stdout_stderr, model)
    success, error_msg = _wait_for_server_health(process, base_url, api_key, timeout)

    # If offline launch failed and offline was enabled, retry with online mode
    if not success and offline_enabled:
        print(
            f"CI_OFFLINE: Offline launch failed ({error_msg}), retrying with online mode..."
        )

        # Kill the failed process if still running; otherwise just reap it.
        try:
            if process.poll() is None:
                kill_process_tree(process.pid)
            else:
                process.wait(timeout=5)
        except Exception as e:
            print(f"CI_OFFLINE: Error cleaning up failed offline process: {e}")

        # Invalidate per-run marker to prevent subsequent tests from using offline
        if per_run_marker_path and os.path.exists(per_run_marker_path):
            try:
                os.remove(per_run_marker_path)
                print("CI_OFFLINE: Invalidated per-run marker due to offline failure")
            except Exception as e:
                print(f"CI_OFFLINE: Failed to remove per-run marker: {e}")

        # Retry with online mode
        env["HF_HUB_OFFLINE"] = "0"
        process = _launch_server_process(command, env, return_stdout_stderr, model)
        success, error_msg = _wait_for_server_health(
            process, base_url, api_key, timeout
        )

        if success:
            print("CI_OFFLINE: Online retry succeeded")
            return process

        # Online retry also failed
        try:
            kill_process_tree(process.pid)
        except Exception as e:
            print(f"CI_OFFLINE: Error killing process after online retry failure: {e}")

        # Health-check messages for a dead process contain "exited"; anything
        # else is reported as a timeout.
        if "exited" in error_msg:
            raise Exception(error_msg + ". Check server logs for errors.")
        raise TimeoutError(error_msg)

    # First attempt succeeded or offline was not enabled
    if success:
        return process

    # First attempt failed and offline was not enabled
    try:
        kill_process_tree(process.pid)
    except Exception as e:
        print(f"CI_OFFLINE: Error killing process after first attempt failure: {e}")

    # Same classification as above: "exited" -> Exception, otherwise timeout.
    if "exited" in error_msg:
        raise Exception(error_msg + ". Check server logs for errors.")
    raise TimeoutError(error_msg)


def popen_launch_pd_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
):
    """Spawn ``python3 -m sglang.launch_server`` and return the process handle.

    The function returns immediately after ``subprocess.Popen``; it does not
    wait for the server to become ready, and ``timeout`` is accepted but not
    used here.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL shaped like ``scheme://host:port``. It must contain
            exactly two ``:`` characters, otherwise the unpacking below raises
            ``ValueError``.
        timeout: Unused by this function.
        api_key: When truthy, forwarded as ``--api-key``.
        other_args: Extra command-line arguments; each one is passed through
            ``str``.
        env: Environment mapping handed to ``subprocess.Popen``.

    Returns:
        The ``subprocess.Popen`` object for the launched server.
    """
    _scheme, slashed_host, port = base_url.split(":")
    # The middle piece still carries the leading "//" of the URL.
    host = slashed_host[2:]

    command = ["python3", "-m", "sglang.launch_server", "--model-path", model]
    command += [str(arg) for arg in other_args]
    command += ["--host", host, "--port", port]
    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    # stdout/stderr are inherited from the current process.
    return subprocess.Popen(command, stdout=None, stderr=None, env=env)


def get_similarities(vec1, vec2):
    """Return the cosine similarity (along dim 0) of two vectors as a tensor."""
    first = torch.tensor(vec1)
    second = torch.tensor(vec2)
    return F.cosine_similarity(first, second, dim=0)


def get_benchmark_args(
    base_url="",
    backend="sglang",
    dataset_name="",
    dataset_path="",
    tokenizer="",
    num_prompts=500,
    sharegpt_output_len=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    request_rate=float("inf"),
    disable_stream=False,
    disable_ignore_eos=False,
    seed: int = 0,
    device="auto",
    pd_separated: bool = False,
    lora_name=None,
    lora_request_distribution="uniform",
    lora_zipf_alpha=1.5,
    gsp_num_groups=4,
    gsp_prompts_per_group=4,
    gsp_system_prompt_len=128,
    gsp_question_len=32,
    gsp_output_len=32,
    gsp_num_turns=1,
    header=None,
):
    """Build the argument namespace consumed by the serving benchmark.

    Every keyword argument is copied onto the returned ``SimpleNamespace``
    under the same name. The remaining attributes are fixed values that
    callers of this helper cannot override: ``host``, ``port``, ``model``,
    ``multi``, ``output_file``, ``extra_request_body`` and ``profile`` are
    ``None``; ``disable_tqdm``, ``return_logprob``, ``return_routed_experts``
    and ``apply_chat_template`` are ``False``; ``random_range_ratio`` is
    ``0.0`` and ``prompt_suffix`` is ``""``.

    Returns:
        A ``types.SimpleNamespace`` holding all benchmark settings.
    """
    # Key order mirrors the attribute order of the resulting namespace.
    fields = {
        "backend": backend,
        "base_url": base_url,
        "host": None,
        "port": None,
        "dataset_name": dataset_name,
        "dataset_path": dataset_path,
        "model": None,
        "tokenizer": tokenizer,
        "num_prompts": num_prompts,
        "sharegpt_output_len": sharegpt_output_len,
        "sharegpt_context_len": sharegpt_context_len,
        "random_input_len": random_input_len,
        "random_output_len": random_output_len,
        "random_range_ratio": 0.0,
        "request_rate": request_rate,
        "multi": None,
        "output_file": None,
        "disable_tqdm": False,
        "disable_stream": disable_stream,
        "return_logprob": False,
        "return_routed_experts": False,
        "seed": seed,
        "disable_ignore_eos": disable_ignore_eos,
        "extra_request_body": None,
        "apply_chat_template": False,
        "profile": None,
        "lora_name": lora_name,
        "lora_request_distribution": lora_request_distribution,
        "lora_zipf_alpha": lora_zipf_alpha,
        "prompt_suffix": "",
        "device": device,
        "pd_separated": pd_separated,
        "gsp_num_groups": gsp_num_groups,
        "gsp_prompts_per_group": gsp_prompts_per_group,
        "gsp_system_prompt_len": gsp_system_prompt_len,
        "gsp_question_len": gsp_question_len,
        "gsp_output_len": gsp_output_len,
        "gsp_num_turns": gsp_num_turns,
        "header": header,
    }
    return SimpleNamespace(**fields)


def run_bench_serving(
    model,
    num_prompts,
    request_rate,
    other_server_args,
    dataset_name="random",
    dataset_path="",
    tokenizer=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    disable_stream=False,
    disable_ignore_eos=False,
    need_warmup=False,
    seed: int = 0,
    device="auto",
    background_task: Optional[
        Callable[[str, asyncio.Event, asyncio.Event], Awaitable[None]]
    ] = None,
    lora_name: Optional[str] = None,
):
    """Launch a server, run the serving benchmark against it, and return the result.

    The server is started with ``popen_launch_server`` at
    ``DEFAULT_URL_FOR_TEST`` and is always killed before this function returns
    or raises.

    Args:
        model: Model passed to ``popen_launch_server``.
        num_prompts: Number of prompts to benchmark; the result's
            ``"completed"`` count must equal this value.
        request_rate: Request rate forwarded to the benchmark arguments.
        other_server_args: Extra arguments for the server launch.
        need_warmup: When true, a 16-prompt copy of the benchmark runs first.
        device: Device name, or ``"auto"`` to use ``auto_config_device()``.
        background_task: Optional coroutine function invoked as
            ``background_task(base_url, start_event, stop_event)``. It is
            scheduled before the measured run; ``start_event`` is set right
            before the benchmark starts and ``stop_event`` is set once it has
            finished (successfully or not), after which the task is awaited.
        lora_name: Forwarded to the benchmark arguments.

        The remaining keyword arguments are forwarded unchanged to
        ``get_benchmark_args``.

    Returns:
        The value returned by ``run_benchmark`` for the measured run.

    Raises:
        AssertionError: If ``res["completed"]`` differs from ``num_prompts``.
    """
    if device == "auto":
        device = auto_config_device()
    # Launch the server
    base_url = DEFAULT_URL_FOR_TEST
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
    )

    # Run benchmark
    args = get_benchmark_args(
        base_url=base_url,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        num_prompts=num_prompts,
        random_input_len=random_input_len,
        random_output_len=random_output_len,
        sharegpt_context_len=sharegpt_context_len,
        request_rate=request_rate,
        disable_stream=disable_stream,
        disable_ignore_eos=disable_ignore_eos,
        seed=seed,
        device=device,
        lora_name=lora_name,
    )

    async def _run():
        if need_warmup:
            warmup_args = copy.deepcopy(args)
            warmup_args.num_prompts = 16
            # run_benchmark is executed in a worker thread so the event loop
            # stays free for the background task.
            await asyncio.to_thread(run_benchmark, warmup_args)

        start_event = asyncio.Event()
        stop_event = asyncio.Event()
        task_handle = (
            asyncio.create_task(background_task(base_url, start_event, stop_event))
            if background_task
            else None
        )

        try:
            start_event.set()
            result = await asyncio.to_thread(run_benchmark, args)
        finally:
            # Always tell the background task to stop and wait for it, even
            # when the benchmark itself failed.
            if task_handle:
                stop_event.set()
                await task_handle

        return result

    try:
        res = asyncio.run(_run())
    finally:
        kill_process_tree(process.pid)

    assert res["completed"] == num_prompts
    return res


async def _run_api_benchmark_requests(
    base_url: str,
    endpoint: str,
    test_requests: List[dict],
    num_requests: int,
    response_validator: Callable[[dict], bool],
):
    """
    Helper function to run API benchmark requests and collect metrics.

    Requests are sent one at a time over a single ``aiohttp`` session, each
    with a 30 second total timeout. A request counts as successful only when
    the response status is 200 and ``response_validator`` accepts the decoded
    JSON body. Failures never abort the run; they are logged and skipped.

    Args:
        base_url: The base URL of the server
        endpoint: The API endpoint to test (e.g., "/v1/score", "/v1/embeddings")
        test_requests: List of request payloads to send
        num_requests: Total number of requests expected (only echoed back as
            ``"total_requests"``)
        response_validator: Function to validate if response contains expected data

    Returns:
        Dictionary with benchmark metrics: ``completed`` and
        ``successful_requests`` (the same count), ``total_requests``,
        ``throughput`` (successes per second of wall time),
        ``avg_latency_ms`` and ``p95_latency_ms``. All metrics are ``0`` when
        no request succeeded.
    """
    url = f"{base_url}{endpoint}"
    latencies = []
    start_time = time.monotonic()

    async with aiohttp.ClientSession() as session:
        for index, request_data in enumerate(test_requests):
            try:
                request_start = time.monotonic()
                async with session.post(
                    url,
                    json=request_data,
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status != 200:
                        logging.warning(
                            "Benchmark request %d to %s returned HTTP %s",
                            index,
                            endpoint,
                            response.status,
                        )
                        continue

                    response_data = await response.json()
                    request_end = time.monotonic()

                    if response_validator(response_data):
                        latencies.append((request_end - request_start) * 1000)
                    else:
                        logging.warning(
                            "Benchmark request %d to %s failed response validation",
                            index,
                            endpoint,
                        )
            except Exception as exc:
                # Best effort: a failed request is reported and skipped, it
                # must not abort the whole benchmark.
                logging.warning(
                    "Benchmark request %d to %s failed: %r", index, endpoint, exc
                )
                continue

    total_time = time.monotonic() - start_time
    successful_requests = len(latencies)

    if successful_requests > 0:
        throughput = successful_requests / total_time
        avg_latency = sum(latencies) / successful_requests
        p95_latency = np.percentile(latencies, 95)
    else:
        throughput = 0
        avg_latency = 0
        p95_latency = 0

    return {
        "completed": successful_requests,
        "total_requests": num_requests,
        "throughput": throughput,
        "avg_latency_ms": avg_latency,
        "p95_latency_ms": p95_latency,
        "successful_requests": successful_requests,
    }


def run_score_benchmark(
    model,
    num_requests=100,
    batch_size=5,
    other_server_args=None,
    need_warmup=False,
    device="auto",
):
    """Score API benchmark function compatible with run_bench_serving pattern.

    Launches a server at ``DEFAULT_URL_FOR_TEST``, sends ``num_requests``
    sequential ``/v1/score`` requests (each with one query and ``batch_size``
    items built from a repeated special token), and always kills the server
    afterwards.

    Args:
        model: Model to serve; also used to load the tokenizer and sent as the
            ``"model"`` field of every request.
        num_requests: Number of score requests to send.
        batch_size: Number of items per request.
        other_server_args: Extra server launch arguments (``None`` means none).
        need_warmup: When true, one 3-item request is sent before measuring;
            its outcome is ignored.
        device: ``"auto"`` triggers ``auto_config_device()``; the resolved
            value is not otherwise used by this function.

    Returns:
        The metrics dictionary produced by ``_run_api_benchmark_requests``.
    """
    if other_server_args is None:
        other_server_args = []

    if device == "auto":
        device = auto_config_device()

    # Launch the server (consistent with run_bench_serving)
    base_url = DEFAULT_URL_FOR_TEST
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
    )

    async def _run_benchmark():
        # Load tokenizer for generating test data
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        tokenizer = get_tokenizer(model)

        # Score API configuration
        score_query_tokens = 120
        score_item_tokens = 180
        score_label_token_ids = [9454, 2753]  # Yes/No token IDs
        special_token = "<|im_start|>"

        def generate_text_with_token_count(num_tokens):
            """Generate text with precise token count using replicated token."""
            text = special_token * num_tokens
            actual_tokens = len(tokenizer.encode(text, add_special_tokens=False))
            if actual_tokens != num_tokens:
                # The special token spans several tokens for this tokenizer:
                # repeat it fewer times to approach the requested count.
                text = special_token * (
                    num_tokens
                    // len(tokenizer.encode(special_token, add_special_tokens=False))
                )
            return text

        # The generated text depends only on the token count, so build it once
        # instead of re-tokenizing for every request.
        query_text = generate_text_with_token_count(score_query_tokens)
        item_text = generate_text_with_token_count(score_item_tokens)

        def build_request(num_items):
            return {
                "query": query_text,
                "items": [item_text] * num_items,
                "label_token_ids": score_label_token_ids,
                "model": model,
                "apply_softmax": True,
            }

        if need_warmup:
            async with aiohttp.ClientSession() as session:
                try:
                    async with session.post(
                        f"{base_url}/v1/score",
                        json=build_request(3),
                        timeout=aiohttp.ClientTimeout(total=30),
                    ) as warmup_response:
                        await warmup_response.read()
                except Exception as exc:
                    # Warmup is best effort. `Exception` (not a bare except)
                    # lets task cancellation and KeyboardInterrupt propagate.
                    logging.warning("Score warmup request failed: %r", exc)

        test_requests = [build_request(batch_size) for _ in range(num_requests)]

        # Run benchmark requests using shared helper
        return await _run_api_benchmark_requests(
            base_url=base_url,
            endpoint="/v1/score",
            test_requests=test_requests,
            num_requests=num_requests,
            response_validator=lambda resp: "scores" in resp or "logprobs" in resp,
        )

    try:
        res = asyncio.run(_run_benchmark())
    finally:
        kill_process_tree(process.pid)

    # NOTE(review): both fields are filled from the same counter by
    # _run_api_benchmark_requests, so this check cannot fail; comparing with
    # num_requests may have been intended - confirm before tightening.
    assert res["completed"] == res["successful_requests"]
    return res


def run_embeddings_benchmark(
    model,
    num_requests=100,
    batch_size=1,
    input_tokens=500,
    other_server_args=None,
    need_warmup=False,
    device="auto",
):
    """Embeddings API benchmark function compatible with run_bench_serving pattern.

    Launches a server with ``--is-embedding`` at ``DEFAULT_URL_FOR_TEST``,
    sends ``num_requests`` sequential ``/v1/embeddings`` requests, and always
    kills the server afterwards.

    Args:
        model: Model to serve; also sent as the ``"model"`` field.
        num_requests: Number of embedding requests to send.
        batch_size: ``1`` sends the input as a single string; any other value
            sends a list with ``batch_size`` copies of it.
        input_tokens: How many times the special token is repeated to form
            the input text.
        other_server_args: Extra server launch arguments (``None`` means none).
        need_warmup: When true, one request is sent before measuring; its
            outcome is ignored.
        device: ``"auto"`` triggers ``auto_config_device()``; the resolved
            value is not otherwise used by this function.

    Returns:
        The metrics dictionary produced by ``_run_api_benchmark_requests``.
    """
    if other_server_args is None:
        other_server_args = []

    if device == "auto":
        device = auto_config_device()

    # Add --is-embedding flag for embedding models. Unpacking (instead of
    # list concatenation) also accepts a tuple of extra arguments.
    server_args = ["--is-embedding", *other_server_args]

    # Launch the server (consistent with run_bench_serving)
    base_url = DEFAULT_URL_FOR_TEST
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=server_args,
    )

    async def _run_benchmark():
        # Load tokenizer for generating test data
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        tokenizer = get_tokenizer(model)

        special_token = "<|im_start|>"
        # The input length is only exact when the special token encodes to a
        # single token; report it when that assumption does not hold.
        token_ids = tokenizer.encode(special_token, add_special_tokens=False)
        if len(token_ids) != 1:
            logging.warning(
                "%r encodes to %d tokens; embeddings input will not be exactly "
                "%d tokens long",
                special_token,
                len(token_ids),
                input_tokens,
            )

        # Generate input text
        input_text = special_token * input_tokens

        if need_warmup:
            warmup_data = {
                "input": input_text,
                "model": model,
            }

            async with aiohttp.ClientSession() as session:
                try:
                    async with session.post(
                        f"{base_url}/v1/embeddings",
                        json=warmup_data,
                        timeout=aiohttp.ClientTimeout(total=30),
                    ) as warmup_response:
                        await warmup_response.read()
                except Exception as exc:
                    # Warmup is best effort. `Exception` (not a bare except)
                    # lets task cancellation and KeyboardInterrupt propagate.
                    logging.warning("Embeddings warmup request failed: %r", exc)

        test_requests = []
        for _ in range(num_requests):
            if batch_size == 1:
                input_data = input_text
            else:
                input_data = [input_text for _ in range(batch_size)]

            test_requests.append(
                {
                    "input": input_data,
                    "model": model,
                }
            )

        # Run benchmark requests using shared helper
        return await _run_api_benchmark_requests(
            base_url=base_url,
            endpoint="/v1/embeddings",
            test_requests=test_requests,
            num_requests=num_requests,
            response_validator=lambda resp: "data" in resp,
        )

    try:
        res = asyncio.run(_run_benchmark())
    finally:
        kill_process_tree(process.pid)

    # NOTE(review): both fields are filled from the same counter by
    # _run_api_benchmark_requests, so this check cannot fail; comparing with
    # num_requests may have been intended - confirm before tightening.
    assert res["completed"] == res["successful_requests"]
    return res


def run_bench_serving_multi(
    model,
    base_url,
    other_server_args,
    benchmark_args,
    need_warmup=False,
    pd_separated=False,
):
    """Run several serving benchmarks against one server instance.

    A single server is launched, every entry of ``benchmark_args`` is passed
    to ``run_benchmark`` in order (preceded by a 16-prompt warmup copy when
    ``need_warmup`` is set), and the server is killed afterwards even if a
    benchmark raises.

    Returns:
        A list of ``(args, result)`` tuples, one per benchmark configuration.
    """
    server = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
        pd_separated=pd_separated,
    )

    collected = []
    try:
        for bench_args in benchmark_args:
            if need_warmup:
                # Warm up with a small copy so the caller's args stay untouched.
                warmup = copy.deepcopy(bench_args)
                warmup.num_prompts = 16
                run_benchmark(warmup)

            outcome = run_benchmark(bench_args)
            collected.append((bench_args, outcome))
    finally:
        kill_process_tree(server.pid)

    return collected


def run_bench_one_batch(model, other_args):
    """Run ``sglang.bench_one_batch`` in a subprocess and parse its metrics.

    The device is chosen by ``auto_config_device()`` and passed to the
    benchmark as ``--device``. The benchmark runs with batch size 1, input
    length 128 and output length 8.

    Args:
        model: Model path; when not ``None`` it is passed as ``--model-path``.
        other_args: Extra command-line arguments for the benchmark. The
            caller's sequence is not modified.

    Returns:
        A ``(prefill_latency, decode_throughput, decode_latency)`` tuple of
        floats parsed from the benchmark's standard output.

    Raises:
        RuntimeError: If any of the three values cannot be found in the output.
    """
    device = auto_config_device()
    print(f"Auto-configed device: {device}", flush=True)
    # Build a fresh list: an in-place `+=` would mutate the caller's list and
    # pile up one more `--device` pair on every call that reuses it.
    cli_args = [*other_args, "--device", str(device)]

    command = [
        "python3",
        "-m",
        "sglang.bench_one_batch",
        "--batch-size",
        "1",
        "--input",
        "128",
        "--output",
        "8",
        *[str(x) for x in cli_args],
    ]
    if model is not None:
        command += ["--model-path", model]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    prefill_latency = None
    decode_throughput = None
    decode_latency = None

    try:
        stdout, stderr = process.communicate()
        output = stdout.decode(errors="backslashreplace")
        error = stderr.decode(errors="backslashreplace")
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        # Only the span from the first "Benchmark" to the last "Total" is
        # searched for the metric lines.
        section = re.search(r"Benchmark[\s\S]*Total", output)
        bench_output = section[0] if section else ""
        metric_pattern = r".*?latency: (?P<latency>\d+\.\d+).*?throughput:\s*(?P<throughput>\d+\.\d+)"

        prefill = re.search(r"Prefill." + metric_pattern, bench_output)
        if prefill:
            prefill_latency = float(prefill.group("latency"))

        decode = re.search(r"Decode." + metric_pattern, bench_output)
        if decode:
            decode_latency = float(decode.group("latency"))
            decode_throughput = float(decode.group("throughput"))
    finally:
        kill_process_tree(process.pid)

    if prefill_latency is None or decode_throughput is None or decode_latency is None:
        raise RuntimeError(
            f"Failed to parse benchmark output. "
            f"prefill_latency={prefill_latency}, decode_throughput={decode_throughput}, decode_latency={decode_latency}"
        )

    return prefill_latency, decode_throughput, decode_latency


def run_bench_offline_throughput(model, other_args):
    """Run ``sglang.bench_offline_throughput`` and return its last reported throughput.

    The benchmark is run in a subprocess with one random prompt of 256 input
    and 256 output tokens. Its stdout and stderr are printed after it exits.

    Args:
        model: Value passed to ``--model-path``.
        other_args: Extra command-line arguments; each is passed through ``str``.

    Returns:
        The number following the last ``Last generation throughput (tok/s):``
        line of the output, or ``-1`` when no such line exists.
    """
    marker = "Last generation throughput (tok/s):"
    command = [
        "python3",
        "-m",
        "sglang.bench_offline_throughput",
        "--num-prompts",
        "1",
        "--dataset-name",
        "random",
        "--random-input-len",
        "256",
        "--random-output-len",
        "256",
        "--model-path",
        model,
    ]
    command.extend(str(arg) for arg in other_args)

    print(f"command={' '.join(command)}")
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    try:
        raw_stdout, raw_stderr = process.communicate()
        output = raw_stdout.decode(errors="backslashreplace")
        error = raw_stderr.decode(errors="backslashreplace")
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        output_throughput = -1
        for line in output.split("\n"):
            if marker not in line:
                continue
            # The value is whatever follows the final colon on the line.
            output_throughput = float(line.rsplit(":", 1)[-1])
    finally:
        kill_process_tree(process.pid)

    return output_throughput


def run_bench_one_batch_server(
    model,
    base_url,
    server_args,
    bench_args,
    other_server_args,
    simulate_spec_acc_lens=None,
):
    """Launch a server and run the one-batch server benchmark against it.

    When ``simulate_spec_acc_lens`` is given, the server is started with a
    copy of the current environment in which ``SIMULATE_ACC_LEN`` is set to
    its string form; otherwise no environment is passed. The server is killed
    once the benchmark finishes or raises. Nothing is returned.
    """
    from sglang.bench_one_batch_server import run_benchmark

    env = None
    if simulate_spec_acc_lens is not None:
        env = dict(os.environ, SIMULATE_ACC_LEN=str(simulate_spec_acc_lens))

    server = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
        env=env,
    )
    try:
        run_benchmark(server_args=server_args, bench_args=bench_args)
    finally:
        kill_process_tree(server.pid)


def lcs(X, Y):
    """Return the length of the longest common subsequence of ``X`` and ``Y``.

    Uses the classic dynamic-programming table of size
    ``(len(X) + 1) x (len(Y) + 1)``; row 0 and column 0 stay at zero.
    """
    rows, cols = len(X), len(Y)
    table = [[0] * (cols + 1) for _ in range(rows + 1)]

    for r, left in enumerate(X, start=1):
        current, above = table[r], table[r - 1]
        for c, right in enumerate(Y, start=1):
            if left == right:
                current[c] = above[c - 1] + 1
            else:
                current[c] = max(above[c], current[c - 1])

    return table[rows][cols]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """Compute a ROUGE-L F-measure for each aligned pair of outputs.

    For every pair, the ``lcs`` length is divided by the length of the first
    element (precision) and of the second element (recall); the score is
    their harmonic mean, or ``0.0`` when both are zero. Pairs are formed with
    ``zip``, so extra entries of the longer list are ignored.

    Returns:
        A list with one score per pair.
    """

    def _pair_score(first, second):
        common = lcs(first, second)
        precision = common / len(first) if len(first) > 0 else 0
        recall = common / len(second) if len(second) > 0 else 0
        if precision + recall > 0:
            return (2 * precision * recall) / (precision + recall)
        return 0.0

    return [
        _pair_score(first, second)
        for first, second in zip(output_strs_list1, output_strs_list2)
    ]


# Files used to capture a child process's stderr / stdout so that
# `read_output` can stream them while the process is running.
STDERR_FILENAME = "/tmp/stderr.txt"
STDOUT_FILENAME = "/tmp/stdout.txt"


def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
    """Print the output in real time with another thread.

    Waits until ``filename`` exists, then re-reads it every 0.1 seconds,
    printing each line not seen before and appending it to ``output_lines``.
    Removing the file is the stop signal: the function returns as soon as the
    file can no longer be opened.

    Args:
        output_lines: List that receives every line read, in order.
        filename: Path of the file to follow.
    """
    while not os.path.exists(filename):
        time.sleep(0.01)

    seen = 0  # number of lines already forwarded
    while True:
        try:
            # `with` closes the handle on every poll instead of leaking it.
            with open(filename) as stream:
                lines = stream.readlines()
        except FileNotFoundError:
            # The file was removed, possibly before any line was read:
            # treat that as a normal shutdown instead of raising.
            break

        for line in lines[seen:]:
            print(line, end="", flush=True)
            output_lines.append(line)
        seen = max(seen, len(lines))
        time.sleep(0.1)


def run_and_check_memory_leak(
    workload_func,
    disable_radix_cache,
    enable_mixed_chunk,
    disable_overlap,
    chunked_prefill_size,
    assert_has_abort,
    api_key: Optional[str] = None,
):
    """Run a workload against a fresh server and scan its stderr log.

    A server for ``DEFAULT_MODEL_NAME_FOR_TEST`` is launched on a random port
    between 4000 and 5000 with debug logging. Its stderr is written to
    ``STDERR_FILENAME`` and streamed by a ``read_output`` thread. After
    ``workload_func(base_url, model)`` returns, the server is killed and the
    captured lines are checked.

    Args:
        workload_func: Callable invoked as ``workload_func(base_url, model)``.
        disable_radix_cache: Adds ``--disable-radix-cache`` when true.
        enable_mixed_chunk: Adds ``--enable-mixed-chunk`` when true.
        disable_overlap: Adds ``--disable-overlap-schedule`` when true.
        chunked_prefill_size: Value for ``--chunked-prefill-size``.
        assert_has_abort: When true, some log line must contain ``"Abort"``.
        api_key: Forwarded to ``popen_launch_server``.

    Raises:
        AssertionError: If no line contains ``"Uvicorn running"``, if any line
            contains ``"leak"``, or if an expected ``"Abort"`` line is missing.
    """
    other_args = [
        "--chunked-prefill-size",
        str(chunked_prefill_size),
        "--log-level",
        "debug",
    ]
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    if enable_mixed_chunk:
        other_args += ["--enable-mixed-chunk"]
    if disable_overlap:
        other_args += ["--disable-overlap-schedule"]

    model = DEFAULT_MODEL_NAME_FOR_TEST
    port = random.randint(4000, 5000)
    base_url = f"http://127.0.0.1:{port}"

    # Create files and launch the server
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    process = None
    reader = None
    output_lines = []
    try:
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            return_stdout_stderr=(stdout, stderr),
            api_key=api_key,
        )

        # Launch a thread to stream the output
        reader = threading.Thread(target=read_output, args=(output_lines,))
        reader.start()

        # Run the workload
        workload_func(base_url, model)

        kill_process_tree(process.pid)

        # The reader polls every 0.1s and stops once the file is removed, so
        # let it catch up with the final log lines before cleaning up.
        try:
            with open(STDERR_FILENAME) as log_file:
                expected_lines = len(log_file.readlines())
        except FileNotFoundError:
            expected_lines = 0
        while reader.is_alive() and len(output_lines) < expected_lines:
            time.sleep(0.01)
    finally:
        # Clean up everything, also when the launch or the workload raised;
        # otherwise the server keeps running and the reader thread never ends.
        if process is not None:
            kill_process_tree(process.pid)
        stdout.close()
        stderr.close()
        for filename in (STDOUT_FILENAME, STDERR_FILENAME):
            if os.path.exists(filename):
                os.remove(filename)
        if reader is not None:
            reader.join()

    # Assert success
    has_new_server = False
    has_leak = False
    has_abort = False
    for line in output_lines:
        if "Uvicorn running" in line:
            has_new_server = True
        if "leak" in line:
            has_leak = True
        if "Abort" in line:
            has_abort = True

    assert has_new_server
    assert not has_leak
    if assert_has_abort:
        assert has_abort


def run_command_and_capture_output(command, env: Optional[dict] = None):
    """Run ``command`` to completion and return the lines it printed.

    The child's stdout and stderr are both written to ``STDOUT_FILENAME``,
    which a ``read_output`` thread streams to the console while collecting
    the lines. Both capture files are removed before returning.

    Args:
        command: Command passed to ``subprocess.Popen``.
        env: Optional environment for the child process.

    Returns:
        The list of captured output lines.
    """
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    process = None
    reader = None
    output_lines = []
    try:
        # NOTE(review): stderr is redirected into the *stdout* file, so the
        # stderr file created above stays empty. This looks like a deliberate
        # merge of both streams - confirm.
        process = subprocess.Popen(
            command, stdout=stdout, stderr=stdout, env=env, text=True
        )

        # Launch a thread to stream the output
        reader = threading.Thread(
            target=read_output, args=(output_lines, STDOUT_FILENAME)
        )
        reader.start()

        # Join the process
        process.wait()

        # The reader polls every 0.1s and stops once the file is removed. Wait
        # until it has seen every line, otherwise the tail of the output (or
        # all of it, for a fast command) would be lost.
        try:
            with open(STDOUT_FILENAME) as captured:
                expected_lines = len(captured.readlines())
        except FileNotFoundError:
            expected_lines = 0
        while reader.is_alive() and len(output_lines) < expected_lines:
            time.sleep(0.01)
    finally:
        # Runs on failure too, so files are closed and the reader can stop.
        stdout.close()
        stderr.close()
        for filename in (STDOUT_FILENAME, STDERR_FILENAME):
            if os.path.exists(filename):
                os.remove(filename)
        if process is not None:
            kill_process_tree(process.pid)
        if reader is not None:
            reader.join()

    return output_lines


def run_mmlu_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    disable_overlap=False,
    chunked_prefill_size=32,
):
    """Run a 128-example MMLU eval through ``run_and_check_memory_leak``.

    The workload requires an eval score of at least 0.65. The server flags
    are forwarded to ``run_and_check_memory_leak``, which is called with
    ``assert_has_abort=False``.
    """

    def workload_func(base_url, model):
        # Run the eval
        eval_args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=128,
            num_threads=128,
        )
        metrics = run_eval(eval_args)
        assert metrics["score"] >= 0.65, f"{metrics=}"

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        disable_overlap,
        chunked_prefill_size,
        assert_has_abort=False,
    )


def run_mulit_request_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    """Send a few short generate requests through ``run_and_check_memory_leak``.

    The workload issues four ``/generate`` requests from a pool of two
    threads, each asking for at most 8 new tokens at temperature 0.

    Args:
        disable_radix_cache: Forwarded to ``run_and_check_memory_leak``.
        enable_mixed_chunk: Forwarded to ``run_and_check_memory_leak``.
        enable_overlap: Whether the overlap schedule should stay enabled. It
            is translated to the ``disable_overlap`` flag of
            ``run_and_check_memory_leak``.
        chunked_prefill_size: Forwarded to ``run_and_check_memory_leak``.
    """

    def workload_func(base_url, model):
        def run_one(_):
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """

            response = requests.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 8,
                    },
                },
            )
            # Decoding checks that the server answered with valid JSON.
            response.json()

        with ThreadPoolExecutor(2) as executor:
            list(executor.map(run_one, list(range(4))))

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache=disable_radix_cache,
        enable_mixed_chunk=enable_mixed_chunk,
        # The callee takes a *disable* flag; passing `enable_overlap` straight
        # through would invert its meaning.
        disable_overlap=not enable_overlap,
        chunked_prefill_size=chunked_prefill_size,
        assert_has_abort=False,
    )


def write_github_step_summary(content):
    """Append ``content`` to the file named by ``GITHUB_STEP_SUMMARY``.

    When that environment variable is unset or empty, a warning is logged and
    nothing is written.
    """
    summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
    if summary_path:
        with open(summary_path, "a") as summary_file:
            summary_file.write(content)
        return

    logging.warning("GITHUB_STEP_SUMMARY environment variable not set")


def run_logprob_check(self: unittest.TestCase, arg: Tuple):
    """Send one ``/generate`` request and check the shape of its logprob output.

    Args:
        self: Test case providing ``base_url`` and the assertion helpers.
        arg: Tuple of ``(input_len, output_len, temperature,
            logprob_start_len, return_logprob, top_logprobs_num)``. The prompt
            is the token ids ``0 .. input_len - 1``.

    The check verifies prompt/completion token counts, the lengths of the
    returned logprob lists and, at temperature 0, that each sampled token
    equals the top-ranked candidate, allowing a lower-ranked candidate only
    when it is tied on its first field with every candidate ranked above it.
    """
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg
    input_ids = list(range(input_len))

    response = requests.post(
        self.base_url + "/generate",
        json={
            "input_ids": input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
            },
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        },
    )
    meta_info = response.json()["meta_info"]

    self.assertEqual(meta_info["prompt_tokens"], input_len)
    self.assertEqual(meta_info["completion_tokens"], output_len)

    # Test the number of tokens are correct
    if not return_logprob:
        return

    self.assertEqual(
        len(meta_info["input_token_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_token_logprobs"]), output_len)

    if not top_logprobs_num:
        return

    self.assertEqual(
        len(meta_info["input_top_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_top_logprobs"]), output_len)

    for i in range(output_len):
        candidates = meta_info["output_top_logprobs"][i]
        self.assertEqual(len(candidates), top_logprobs_num)

        # Test the top-1 tokens are the same as output tokens if temperature == 0
        if temperature != 0:
            continue

        sampled = meta_info["output_token_logprobs"][i]
        rank = 0
        while rank < len(candidates):
            try:
                self.assertListEqual(sampled, candidates[rank])
                break
            except AssertionError:
                # There's a tie. Allow the next item in this case. The bounds
                # check keeps a mismatch at the last rank an assertion
                # failure instead of an IndexError.
                tied_with_next = (
                    rank + 1 < len(candidates)
                    and candidates[rank][0] == candidates[rank + 1][0]
                )
                if not tied_with_next:
                    raise
                rank += 1


def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Sends generate request serially and returns status codes. Max concurrency is 1.

    Args:
        base_url: Server base URL; requests go to ``{base_url}/generate``.
        num_requests: How many requests to send, one after another.

    Returns:
        The HTTP status code (an ``int``) of each response, in request order.
    """

    def generate():
        prompt = """
        System: You are a helpful assistant.
        User: What is the capital of France?
        Assistant: The capital of France is
        """
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 500,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]


async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Sends generate request concurrently and returns status codes. Max concurrency is num_requests.

    Every request runs in its own task with its own ``aiohttp`` session.

    Args:
        base_url: Server base URL; requests go to ``{base_url}/generate``.
        num_requests: Number of requests to issue at once.

    Returns:
        The HTTP status code (an ``int``) of each response, in task order.
    """

    async def async_generate():
        async with aiohttp.ClientSession() as session:
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 500,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)


async def send_concurrent_generate_requests_with_custom_params(
    base_url: str,
    custom_params: List[dict[str, Any]],
) -> List[Tuple[int, Any]]:
    """Sends generate requests concurrently with custom parameters.

    One request is sent per entry of ``custom_params``, all at once, so the
    max concurrency is ``len(custom_params)``. Each payload is a shallow copy
    of the base payload updated with the entry, so an entry that supplies
    ``"text"`` or ``"sampling_params"`` replaces that top-level value entirely.

    Args:
        base_url: Server base URL; requests go to ``{base_url}/generate``.
        custom_params: Per-request overrides or additions to the base payload.

    Returns:
        A list of ``(status_code, response_json)`` tuples, in the order of
        ``custom_params``.
    """

    base_payload = {
        "text": """
                System: You are a helpful assistant.
                User: What is the capital of France?
                Assistant: The capital of France is
                """,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 500,
        },
    }

    async def async_generate_with_priority(req):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{base_url}/generate",
                json=req,
            ) as response:
                resp_json = await response.json()
                return (response.status, resp_json)

    tasks = []
    for overrides in custom_params:
        req = base_payload.copy()
        req.update(overrides)
        tasks.append(asyncio.create_task(async_generate_with_priority(req)))
    return await asyncio.gather(*tasks)


def run_distributed_test(func, world_size=2, backend="nccl", **kwargs):
    """Spawn ``world_size`` processes, initialise torch.distributed in each,
    run *func(rank, **kwargs)*, and propagate any worker exception to the caller.

    Args:
        func: Callable executed in every worker as ``func(rank, **kwargs)``.
            NOTE(review): workers use the "spawn" start method, so it
            presumably has to be picklable - confirm.
        world_size: Number of worker processes.
        backend: Backend name passed to the workers.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Raises:
        AssertionError: If any worker reported an error or exited with a
            non-zero exit code; the message joins all reports.
    """
    import queue

    import torch.multiprocessing as mp

    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    port = find_available_port(29500)

    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=_distributed_worker,
            args=(rank, world_size, backend, port, func, result_queue, kwargs),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A worker that died without reporting (crash, failed start-up) leaves no
    # queue entry; an unbounded get() would block forever in that case.
    errors = []
    for _ in range(world_size):
        try:
            report = result_queue.get(timeout=10)
        except queue.Empty:
            break
        if report:
            errors.append(report)

    # Surface workers that failed without delivering a report of their own.
    for rank, p in enumerate(processes):
        if p.exitcode != 0:
            errors.append(f"Rank {rank}: worker process exited with code {p.exitcode}")

    if errors:
        raise AssertionError("\n".join(errors))


def _distributed_worker(rank, world_size, backend, port, func, result_queue, kwargs):
    """Worker entry point: set up torch.distributed, run ``func``, report the outcome.

    Exactly one item is put on ``result_queue`` when the worker finishes
    normally or fails with an ``Exception``: ``None`` on success, otherwise a
    string with the rank, the exception and its traceback. Failures during
    device selection or process-group initialisation are reported the same
    way as failures inside ``func``.

    Args:
        rank: Rank of this worker; also the CUDA device index for ``"nccl"``.
        world_size: Total number of workers.
        backend: torch.distributed backend name.
        port: TCP port on 127.0.0.1 used for the rendezvous.
        func: Callable executed as ``func(rank, **kwargs)``.
        result_queue: Queue that receives the outcome.
        kwargs: Keyword arguments for ``func``.
    """
    import traceback

    import torch.distributed as dist

    try:
        if backend == "nccl":
            torch.cuda.set_device(rank)
        dist.init_process_group(
            backend=backend,
            init_method=f"tcp://127.0.0.1:{port}",
            world_size=world_size,
            rank=rank,
        )
        func(rank, **kwargs)
        result_queue.put(None)
    except Exception as e:
        result_queue.put(f"Rank {rank}: {e}\n{traceback.format_exc()}")
    finally:
        # Initialisation itself may have failed; only tear down what exists.
        if dist.is_initialized():
            dist.destroy_process_group()


class CustomTestCase(unittest.TestCase):
    def _callTestMethod(self, method):
        max_retry = envs.SGLANG_TEST_MAX_RETRY.get()
        if max_retry is None:
            max_retry = 1 if is_in_ci() else 0
        retry(
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )

    def setUp(self):
        print(
            f"[CI Test Method] {self.__class__.__name__}.{self._testMethodName}",
            flush=True,
        )


def dump_bench_raw_result(
    path: str,
    states,
    preds,
    labels,
):
    """Write one JSON line per benchmark sample to ``path``.

    Does nothing when ``path`` is empty or ``None``. For each state the row holds
    the sample index, the prompt (``state.text()`` with the answer suffix removed),
    the answer (``state["answer"]``) and whether ``preds[i] == labels[i]``.
    """
    if not path:
        return

    rows = []
    for idx, state in enumerate(states):
        answer = state["answer"]
        rows.append(
            {
                "prompt_id": idx,
                "prompt": _ensure_remove_suffix(state.text(), answer),
                "output": answer,
                "correct": bool(preds[idx] == labels[idx]),
            }
        )

    print(f"BenchRawResultDumper save results to {path}")
    serialized = "\n".join(json.dumps(row) for row in rows)
    Path(path).write_text(serialized)


def _ensure_remove_suffix(text: str, suffix: str):
    assert text.endswith(suffix)
    return text.removesuffix(suffix)


class ModelLaunchSettings:
    """Launch configuration for one model: path, TP size, extra CLI args, env and variant.

    The caller's ``extra_args`` list is copied, never mutated. ``--tp <tp_size>`` is
    appended when ``tp_size > 1`` and ``--tp`` is not already present, followed by
    ``--enable-multimodal`` and ``--trust-remote-code`` when they are missing.
    """

    def __init__(
        self,
        model_path: str,
        tp_size: int = 1,
        extra_args: Optional[List[str]] = None,
        env: Optional[dict] = None,
        variant: Optional[str] = None,
    ):
        launch_args = list(extra_args) if extra_args else []

        if tp_size > 1 and "--tp" not in launch_args:
            launch_args += ["--tp", str(tp_size)]

        # Flags every launch must carry; only the missing ones are appended.
        always_on = ("--enable-multimodal", "--trust-remote-code")
        launch_args += [flag for flag in always_on if flag not in launch_args]

        self.model_path = model_path
        self.tp_size = tp_size
        self.extra_args = launch_args
        self.env = env
        self.variant = variant


class ModelEvalMetrics:
    """Plain holder for the accuracy and evaluation time of one model run."""

    def __init__(self, accuracy: float, eval_time: float):
        self.accuracy: float = accuracy
        self.eval_time: float = eval_time


def extract_trace_link_from_bench_one_batch_server_output(output: str) -> Optional[str]:
    """Return the target of the first ``[Profile](...)`` markdown link in ``output``.

    Returns:
        The text between the parentheses of the first match, or ``None`` when
        ``output`` contains no such link.
    """
    match = re.search(r"\[Profile\]\((.*?)\)", output)
    return match.group(1) if match else None


def parse_models(model_string: str):
    """Split a comma-separated model list, trimming whitespace and dropping empty entries."""
    trimmed = (piece.strip() for piece in model_string.split(","))
    return [name for name in trimmed if name]


def check_evaluation_test_results(
    results,
    test_name,
    model_accuracy_thresholds,
    model_latency_thresholds=None,
    model_count=None,
):
    """
    Compare per-model evaluation results with their thresholds and fail on any miss.

    A markdown summary table is printed (and written to the GitHub step summary
    when running in CI) before the outcome is decided.

    Args:
        results: list of tuple of (model_path, accuracy, latency) or
            (model_path, accuracy, latency, error)
        test_name: Title used for the GitHub step summary section.
        model_accuracy_thresholds: dict mapping model -> minimum accepted score.
            Every model listed here is expected to appear in ``results``.
        model_latency_thresholds: Optional dict mapping model -> maximum accepted
            latency. A model without an entry (or with a ``None`` entry) is
            treated as having no practical latency limit.
        model_count: Unused; kept so existing callers remain valid.

    Raises:
        AssertionError: If any model is below its score threshold, above its
            latency threshold, reported an error, or was not evaluated at all.
    """
    # Stand-in limit for models that have no latency threshold configured.
    no_latency_limit = 1e9
    with_latency = model_latency_thresholds is not None

    failed_models = []
    if with_latency:
        summary = " | model | status | score | score_threshold | latency | latency_threshold | error | \n"
        summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- | ----- | \n"
    else:
        summary = " | model | status | score | score_threshold | error | \n"
        summary += "| ----- | ------ | ----- | --------------- | ----- | \n"

    results_dict = {
        res[0]: (res[1], res[2], res[3] if len(res) == 4 else None) for res in results
    }

    for model, accuracy_threshold in sorted(model_accuracy_thresholds.items()):
        latency_threshold = None
        if with_latency:
            latency_threshold = model_latency_thresholds.get(model)
        if latency_threshold is None:
            # Without this, comparing a latency against None raises TypeError.
            latency_threshold = no_latency_limit

        evaluated = model in results_dict
        error = results_dict[model][2] if evaluated else None

        if evaluated and error is None:
            accuracy, latency, _ = results_dict[model]
            is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
            status_emoji = "✅" if is_success else "❌"

            if not is_success:
                if accuracy < accuracy_threshold:
                    failed_models.append(
                        f"\nScore Check Failed: {model}\n"
                        f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
                    )
                if latency > latency_threshold:
                    failed_models.append(
                        f"\nLatency Check Failed: {model}\n"
                        f"Model {model} latency ({latency:.4f}) is above threshold ({latency_threshold:.4f})"
                    )

            if with_latency:
                line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold} | - |\n"
            else:
                line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | - |\n"
        else:
            # Either the model is missing from the results or it carries an error.
            status_emoji = "❌"
            error_display = error if error else "Model not evaluated"
            failed_models.append(f"Model failed to launch or be evaluated: {model}")
            if with_latency:
                line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold} | N/A | {latency_threshold} | {error_display} |\n"
            else:
                line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold} | {error_display} |\n"

        summary += line

    print(summary)

    if is_in_ci():
        write_github_step_summary(f"## {test_name}\n{summary}")

    if failed_models:
        print("Some models failed the evaluation.")
        raise AssertionError("\n".join(failed_models))


# Bench knobs for bench_one_batch_server (override by env)
def _parse_int_list_env(name: str, default_val: str):
    val = os.environ.get(name, default_val)
    return [int(x) for x in val.split(",") if x]


# Return filenames
def find_traces_under_path(path: str) -> List[str]:
    """Walk ``path`` recursively and return the bare names of all ``*.trace.json.gz`` files."""
    trace_suffix = ".trace.json.gz"
    found = []
    for _root, _subdirs, filenames in os.walk(path):
        found.extend(f"{name}" for name in filenames if name.endswith(trace_suffix))
    return found


def write_results_to_json(model, metrics, mode="a"):
    """Record one evaluation result in ``results.json`` in the current directory.

    Args:
        model: Model identifier stored with the result.
        metrics: Dict of metrics; must contain ``"score"``. ``"latency"`` is
            copied to the top level of the record when present.
        mode: ``"a"`` appends to the results already in the file; any other
            value starts a fresh list.

    Raises:
        KeyError: If ``metrics`` has no ``"score"`` entry.
    """
    result = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "metrics": metrics,
        "score": metrics["score"],
    }

    if "latency" in metrics:
        # Store the scalar itself; a trailing comma here would wrap it in a
        # tuple and serialise it as a one-element JSON list.
        result["latency"] = metrics["latency"]

    existing_results = []
    if mode == "a" and os.path.exists("results.json"):
        try:
            with open("results.json", "r") as f:
                existing_results = json.load(f)
        except json.JSONDecodeError:
            # An unreadable file is treated as empty rather than failing the run.
            existing_results = []

    if isinstance(existing_results, list):
        existing_results.append(result)
    else:
        # The file held something other than a list; start over.
        existing_results = [result]

    with open("results.json", "w") as f:
        json.dump(existing_results, f, indent=2)


def intel_amx_benchmark(extra_args=None, min_throughput=None):
    """Decorator factory that turns a test method returning a model into a one-batch benchmark.

    The decorated method is called to obtain the model, which is then passed to
    ``run_bench_one_batch`` with the intel_amx attention backend flags plus
    ``extra_args``. The three measurements are printed; in CI, when
    ``min_throughput`` is given, decode throughput must exceed it.
    """

    def decorator(test_func):
        @wraps(test_func)
        def bench_wrapper(self):
            full_args = [
                "--attention-backend",
                "intel_amx",
                "--disable-radix",
                "--trust-remote-code",
            ] + (extra_args or [])

            model = test_func(self)
            measurements = run_bench_one_batch(model, full_args)
            prefill_latency, decode_throughput, decode_latency = measurements

            print(f"{model=}")
            print(f"{prefill_latency=}")
            print(f"{decode_throughput=}")
            print(f"{decode_latency=}")

            if is_in_ci() and min_throughput is not None:
                self.assertGreater(decode_throughput, min_throughput)

        return bench_wrapper

    return decorator


def get_gpu_count():
    """Return 0 when the current device is "cpu", otherwise ``torch.accelerator.device_count()``."""
    if get_device() == "cpu":
        return 0
    return torch.accelerator.device_count()


def empty_gpu_cache():
    """
    Release cached accelerator memory using whichever API this PyTorch build offers.

    ``torch.accelerator.empty_cache()`` is used when it exists. Otherwise the
    first available backend among CUDA and XPU (Intel) is emptied; with neither
    available this is a no-op.
    """
    accelerator = getattr(torch, "accelerator", None)
    if accelerator is not None and hasattr(accelerator, "empty_cache"):
        return torch.accelerator.empty_cache()

    # Per-backend fallback, tried in order: CUDA first, then XPU.
    for backend_name in ("cuda", "xpu"):
        if not hasattr(torch, backend_name):
            continue
        backend = getattr(torch, backend_name)
        if backend.is_available():
            backend.empty_cache()
            return

    return


def get_gpu_memory_gb():
    """Return GPU memory usage in GiB for CUDA or XPU, or 0 on any other platform.

    CUDA reports ``torch.cuda.device_memory_used()``; XPU reports
    ``torch.xpu.memory_allocated()``.
    """
    if is_cuda():
        return torch.cuda.device_memory_used() / 1024**3
    if is_xpu():
        return torch.xpu.memory_allocated() / 1024**3
    return 0


def run_doctests(obj: Callable[..., Any] | ModuleType):
    mod = inspect.getmodule(obj)
    globals = dict(mod.__dict__)
    finder = doctest.DocTestFinder()
    runner = doctest.DocTestRunner(verbose=True)
    tests = finder.find(obj, obj.__name__, globs=globals)
    assert len(tests) >= 1, f"No tests found for {obj.__name__}"
    for test in tests:
        result = runner.run(test)
        assert result.failed == 0, f"Test {test.name} failed"


def dump_metric(metric_name: str, value: Any, labels: Optional[dict] = None):
    """
    Emit a test metric to a JSONL file and to stdout for CI performance tracking.

    Record fields:
      - filename, test_case: taken from the current test context
      - metric_name, value: as passed in (value normalised, see below)
      - ts: wall-clock time of the call
      - labels: only when labels are given; kept as a dict when it is
        JSON-serializable, otherwise stored as its string form

    Value normalisation:
      - objects exposing ``.item()`` (e.g. numpy/torch scalars) are unwrapped first
      - bool becomes int (True=1, False=0)
      - int, float and str are kept unchanged
      - anything else is converted to float if possible, else to str

    Output channels:
      - JSONL: ${SGLANG_TEST_METRICS_OUTPUT}.${pid}.jsonl (only if the env var is set)
      - stdout: [METRIC] metric_name=value [labels=...]

    Errors never propagate to the caller; they are logged as warnings instead.

    Args:
        metric_name: Metric name (e.g., "gsm8k_accuracy", "cache_hit_rate")
        value: Metric value
        labels: Optional label dict (e.g., {"backend": "fa3"})
    """
    try:
        # Where the metric comes from.
        source_file, case_name = _get_test_context()

        # Unwrap scalar containers, then reduce to int/float/str.
        raw = value.item() if hasattr(value, "item") else value

        if isinstance(raw, bool):
            # bool is checked first because it is also an int.
            final_value = int(raw)
        elif isinstance(raw, (int, float, str)):
            final_value = raw
        else:
            try:
                final_value = float(raw)
            except (ValueError, TypeError):
                final_value = str(raw)

        entry = {
            "filename": source_file,
            "test_case": case_name,
            "metric_name": metric_name,
            "value": final_value,
            "ts": time.time(),
        }

        # Labels go out as given when serializable, otherwise as their string form.
        emitted_labels = None
        if labels:
            try:
                json.dumps(labels, ensure_ascii=False)
            except (TypeError, ValueError):
                emitted_labels = str(labels)
            else:
                emitted_labels = labels
            entry["labels"] = emitted_labels

        # JSONL channel, one file per process.
        output_base = os.getenv("SGLANG_TEST_METRICS_OUTPUT")
        if output_base:
            try:
                jsonl_path = f"{output_base}.{os.getpid()}.jsonl"
                with open(jsonl_path, "a", encoding="utf-8") as sink:
                    sink.write(json.dumps(entry, ensure_ascii=False) + "\n")
            except Exception as write_error:
                logging.warning(
                    f"sglang.test.dump_metric: failed to write to {jsonl_path}: {write_error}"
                )

        # stdout channel, showing the same labels that went into the record.
        if not emitted_labels:
            labels_str = ""
        elif isinstance(emitted_labels, str):
            labels_str = f" labels='{emitted_labels}'"
        else:
            labels_str = f" labels={json.dumps(emitted_labels, ensure_ascii=False)}"
        print(f"[METRIC] {metric_name}={final_value}{labels_str}")

    except Exception as dump_error:
        # Metrics are best-effort: log and carry on so tests are never broken.
        logging.warning(
            f"sglang.test.dump_metric: failed to dump metric '{metric_name}': {dump_error}",
            exc_info=True,
        )


def _get_test_context() -> tuple[str, str]:
    """
    Get current test's filename and test_case.

    Tries PYTEST_CURRENT_TEST first, falls back to inspecting the call stack.

    Returns:
        ``(filename, test_case)``. From pytest the node id is split at the first
        ``::`` and the remaining separators become dots. From the stack the
        test_case is ``ClassName.function`` when the calling frame has a
        ``self`` local, else just the function name. ``("unknown.py",
        "unknown_test")`` is returned when neither source is usable.
    """
    # 1. Try parsing PYTEST_CURRENT_TEST
    pytest_current = os.getenv("PYTEST_CURRENT_TEST")
    if pytest_current:
        # Format: "path/to/test.py::TestClass::test_method (call)"
        # Strip only the trailing stage marker; cutting at the first space would
        # truncate node ids whose path or parametrize id contains spaces.
        node_id = re.sub(r" \((?:setup|call|teardown)\)$", "", pytest_current)
        parts = node_id.split("::", 1)
        if len(parts) == 2:
            filename = _repo_relative_path(parts[0])
            test_case = parts[1].replace("::", ".")
            return filename, test_case

    # 2. Fallback to inspect
    frame = inspect.currentframe()
    # Assumes direct callsite: frame -> _get_test_context -> dump_metric -> caller
    # If dump_metric gets wrapped, may need to scan upward
    if frame and frame.f_back and frame.f_back.f_back:
        caller = frame.f_back.f_back
        filename = _repo_relative_path(caller.f_code.co_filename)

        # Try to get class name from self. Compare against None explicitly so a
        # test object that happens to be falsy still contributes its class name.
        test_self = caller.f_locals.get("self")
        if test_self is not None:
            test_case = f"{test_self.__class__.__name__}.{caller.f_code.co_name}"
        else:
            test_case = caller.f_code.co_name

        return filename, test_case

    return "unknown.py", "unknown_test"


def _repo_relative_path(filepath: str) -> str:
    """Convert absolute path to repo-relative, preferring GITHUB_WORKSPACE"""
    # Path is imported at module top (line 20)
    try:
        abs_path = Path(filepath).resolve()

        # Try GITHUB_WORKSPACE first
        workspace = os.getenv("GITHUB_WORKSPACE")
        if workspace:
            try:
                return str(abs_path.relative_to(Path(workspace).resolve()))
            except ValueError:
                pass

        # Fallback to cwd
        try:
            return str(abs_path.relative_to(Path.cwd()))
        except ValueError:
            return abs_path.name

    except Exception:
        return Path(filepath).name
