import dataclasses
import random
from typing import List, Optional

import torch

from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l


@dataclasses.dataclass
class LoRAAdaptor:
    """A LoRA adapter under test, with optional per-adapter tolerance overrides.

    A tolerance left as ``None`` means "not overridden"; the test runners in this
    module then fall back to the matching tolerance on the ``LoRAModelCase``.
    """

    # Adapter path/identifier, passed as a LoRA path to the SRT and HF runners.
    name: str
    # Max allowed absolute difference of prefill (input) top logprobs, HF vs SRT.
    prefill_tolerance: Optional[float] = None
    # Max allowed absolute difference of decode (output) top logprobs, HF vs SRT.
    decode_tolerance: Optional[float] = None
    # Minimum ROUGE-L score required between SRT and HF generated text.
    rouge_l_tolerance: Optional[float] = None


@dataclasses.dataclass
class LoRAModelCase:
    """A base model paired with the LoRA adapters to test on top of it.

    The model-level tolerances act as defaults for adapters that do not set
    their own. Construction fails when more adapters are listed than
    ``max_loras_per_batch`` allows.
    """

    # Base model path/identifier handed to the SRT and HF runners.
    base: str
    # Adapters under test for this base model.
    adaptors: List[LoRAAdaptor]
    # Tensor-parallel size passed to SRTRunner.
    tp_size: int = 1
    # Default tolerances: max logprob differences (prefill / decode) and the
    # minimum ROUGE-L score.
    prefill_tolerance: float = 1e-1
    decode_tolerance: float = 1e-1
    rouge_l_tolerance: float = 1.0
    # LoRA capacity settings passed to SRTRunner by the single-run helpers.
    max_loras_per_batch: int = 1
    max_loaded_loras: Optional[int] = None
    # Not read by the helpers in this module; presumably consulted by callers.
    skip_long_prompt: bool = False

    def __post_init__(self):
        adaptor_count = len(self.adaptors)
        batch_limit = self.max_loras_per_batch
        if adaptor_count <= batch_limit:
            return
        raise ValueError(
            f"For base '{self.base}', number of adaptors ({adaptor_count}) "
            f"must be <= max_loras_per_batch ({batch_limit})"
        )


# Dtypes iterated over by the multi-adapter test runners in this module.
TORCH_DTYPES = [torch.float16]
# LoRA backend names; the single-run helpers forward their ``backend`` argument
# to SRTRunner as ``lora_backend``.
BACKENDS = ["triton", "csgmv"]
# Default prompts: one short completion prompt and one long instruction-style
# prompt (the long one ends inside an open "### Answer:" section).
DEFAULT_PROMPTS = [
    "AI is a field of computer science focused on",
    """
    ### Instruction:
    Tell me about llamas and alpacas
    ### Response:
    Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
    ### Question 2:
    What do you know about llamas?
    ### Answer:
    """,
]

# Single-adapter model cases; by name, the set intended for CI (confirm against
# the test modules that import it). Uses the model-level default tolerances.
CI_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
            ),
        ],
        max_loras_per_batch=1,
    ),
]

# Additional single-adapter model cases (by name, not part of the CI set).
ALL_OTHER_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                prefill_tolerance=1e-1,
            ),
        ],
        max_loras_per_batch=1,
    ),
    LoRAModelCase(
        base="meta-llama/Llama-2-7b-hf",
        adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
        max_loras_per_batch=2,
    ),
]

# Two-adapter model cases (by name, intended for CI). The multi-adapter runners
# require at least two adaptors per case.
CI_MULTI_LORA_MODELS = [
    # multi-rank case
    LoRAModelCase(
        base="meta-llama/Llama-2-7b-hf",
        adaptors=[
            LoRAAdaptor(
                name="winddude/wizardLM-LlaMA-LoRA-7B",
                prefill_tolerance=1e-1,
            ),
            LoRAAdaptor(
                name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa",
                prefill_tolerance=3e-1,
            ),
        ],
        max_loras_per_batch=2,
        # Forwarded to SRTRunner as ``max_loaded_loras``.
        max_loaded_loras=4,
    ),
]

# Additional two-adapter model cases (by name, not part of the CI set).
ALL_OTHER_MULTI_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                prefill_tolerance=1e-1,
            ),
            LoRAAdaptor(
                name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                prefill_tolerance=1e-1,
            ),
        ],
        max_loras_per_batch=2,
    ),
]

# Qwen3-4B with two adapters; both loosen the prefill tolerance to 3e-1.
LORA_MODELS_QWEN3 = [
    LoRAModelCase(
        base="Qwen/Qwen3-4B",
        adaptors=[
            LoRAAdaptor(
                name="nissenj/Qwen3-4B-lora-v2",
                prefill_tolerance=3e-1,
            ),
            LoRAAdaptor(
                name="TanXS/Qwen3-4B-LoRA-ZH-WebNovelty-v0.0",
                prefill_tolerance=3e-1,
            ),
        ],
        max_loras_per_batch=2,
    ),
]


def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute ``a @ b`` in float32 and return the product in ``a``'s dtype.

    Both operands are upcast before multiplying, so reduced-precision inputs
    (e.g. float16) are multiplied in float32; the result is cast back to the
    dtype of ``a``.
    """
    target_dtype = a.dtype
    product = a.float() @ b.float()
    return product.to(target_dtype)


def reference_sgmv_shrink(
    x: torch.Tensor,
    weights: torch.Tensor,
    weight_indices: torch.Tensor,
    seq_lengths: torch.Tensor,
    lora_ranks: torch.Tensor,
    lora_scalings: torch.Tensor,
    num_slices: int = 1,
) -> torch.Tensor:
    """
    Sequence-by-sequence reference for the SGMV shrink (LoRA A) operation.

    Tokens of consecutive sequences are packed along dim 0 of ``x``. For each
    sequence whose adapter ``k`` has rank ``r > 0``, the first
    ``num_slices * r`` rows of ``weights[k]`` are applied and the result is
    multiplied by that adapter's scaling. Output columns past
    ``num_slices * r``, and all rows of rank-0 adapters, stay zero. Matmuls
    run in float32 and are cast back to ``x.dtype``.

    Args:
        x: (total_seq_len, input_dim) - Input activations
        weights: (num_loras, num_slices * max_rank, input_dim) - LoRA A weights
        weight_indices: LoRA idx for each sequence
        seq_lengths: Length of each sequence
        lora_ranks: LoRA rank for each LoRA adapter (indexed by LoRA idx)
        lora_scalings: LoRA scaling for each LoRA adapter (indexed by LoRA idx)
        num_slices: Number of slices (3 for QKV, 2 for gate_up, 1 for others)

    Returns:
        (total_seq_len, num_slices * max_rank) intermediate activations, or a
        (total_seq_len, 0) tensor when ``weights`` is empty.
    """
    if weights.numel() == 0:
        return torch.zeros(x.shape[0], 0, dtype=x.dtype, device=x.device)

    num_tokens, _ = x.shape
    _, packed_rows, _ = weights.shape
    max_rank = packed_rows // num_slices
    out = torch.zeros(
        num_tokens, num_slices * max_rank, dtype=x.dtype, device=x.device
    )

    # Rank and scaling of every sequence, gathered through its LoRA idx.
    seq_ranks = lora_ranks[weight_indices]
    seq_scalings = lora_scalings[weight_indices]

    start = 0
    for adapter, length, r, scale in zip(
        weight_indices, seq_lengths, seq_ranks, seq_scalings
    ):
        if length == 0:
            continue
        stop = start + length
        if r > 0:
            width = num_slices * r
            lhs = x[start:stop, :]
            rhs = weights[adapter, :width, :].t()
            # Same computation as safe_matmul: float32 matmul, cast back.
            product = torch.matmul(lhs.float(), rhs.float()).to(lhs.dtype)
            out[start:stop, :width] = scale * product
        start = stop

    return out


def reference_sgmv_expand(
    x: torch.Tensor,
    weights: torch.Tensor,
    weight_indices: torch.Tensor,
    seq_lengths: torch.Tensor,
    lora_ranks: torch.Tensor,
    slice_offsets: torch.Tensor,
    base_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Simple sequence-level reference implementation of SGMV expand operation.

    For each sequence whose adapter has rank ``r > 0``, slice ``s`` of the
    intermediate activations (columns ``[s*r, (s+1)*r)``) is multiplied by the
    matching output rows ``slice_offsets[s]:slice_offsets[s+1]`` of that
    adapter's B weights (first ``r`` columns) and added to the output.
    Matmuls run in float32 and are cast back to ``x.dtype``.

    Args:
        x: (total_seq_len, num_slices * max_rank) - Intermediate activations
        weights: (num_loras, output_dim, max_rank) - LoRA B weights
        weight_indices: LoRA idx for each sequence
        seq_lengths: Length of each sequence
        lora_ranks: LoRA rank for each LoRA adapter (indexed by LoRA idx)
        slice_offsets: Tensor defining slice boundaries; its last entry is the
            total output dim
        base_output: Optional base output to accumulate into. It is copied,
            never modified in place.

    Returns:
        output: (total_seq_len, total_output_dim) - Final output. When
        ``weights`` is empty there is no LoRA contribution, so this is a copy
        of ``base_output`` if given, otherwise zeros.
    """
    if weights.numel() == 0:
        # No LoRA weights means nothing to add; honor base_output exactly as
        # the main path would with an all-zero contribution.
        if base_output is not None:
            return base_output.clone()
        total_seq_len = x.shape[0]
        total_output_dim = slice_offsets[-1].item() if len(slice_offsets) > 0 else 0
        return torch.zeros(
            total_seq_len, total_output_dim, dtype=x.dtype, device=x.device
        )

    total_seq_len, _ = x.shape

    num_slices = len(slice_offsets) - 1

    if base_output is not None:
        # Accumulate into a copy so the caller's tensor is left untouched.
        output = base_output.clone()
    else:
        total_output_dim = slice_offsets[-1].item()
        output = torch.zeros(
            total_seq_len, total_output_dim, dtype=x.dtype, device=x.device
        )

    token_offset = 0
    for lora_idx, seq_len, rank in zip(
        weight_indices, seq_lengths, lora_ranks[weight_indices]
    ):
        if seq_len == 0:
            continue

        if rank > 0:
            # Only the first num_slices * rank columns carry this adapter's data.
            x_seq = x[
                token_offset : token_offset + seq_len, : num_slices * rank
            ]  # (seq_len, num_slices * rank)

            for slice_idx in range(num_slices):
                slice_start_input = slice_idx * rank
                slice_end_input = slice_start_input + rank

                slice_start_output = slice_offsets[slice_idx].item()
                slice_end_output = slice_offsets[slice_idx + 1].item()

                x_slice = x_seq[:, slice_start_input:slice_end_input]  # (seq_len, rank)
                w_slice = weights[
                    lora_idx, slice_start_output:slice_end_output, :rank
                ]  # (slice_dim, rank)

                # Same computation as safe_matmul: float32 matmul, cast back.
                result = torch.matmul(x_slice.float(), w_slice.t().float()).to(
                    x_slice.dtype
                )  # (seq_len, slice_dim)
                output[
                    token_offset : token_offset + seq_len,
                    slice_start_output:slice_end_output,
                ] += result

        token_offset += seq_len

    return output


def run_lora_test_one_by_one(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str = "csgmv",
    enable_lora_overlap_loading: Optional[bool] = None,
    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = False,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Input a batch of prompts, and run lora tests one by one with several generate requests
    (each request will have bs=1).
    For prompt0, prompt1, ..., promptN,
    we will use adaptor0, adaptor1, ..., adaptorN included in model case.
    If number of prompts is larger than number of adaptors,
    the prompt i will use adaptor i % (number of adaptors).

    For each prompt, the HF and SRT outputs with LoRA are compared on:
    - prefill top logprobs (only when there are at most 100 input positions),
    - decode top logprobs (only when there are at most 100 output positions),
    - the ROUGE-L score of the generated text.
    Outputs of both runners without LoRA are also generated. They are only
    printed for reference and are never asserted.

    Tolerances set on an adaptor take precedence. Unset ones fall back to the
    model case's defaults.

    Args:
        prompts (List[str]): The batch of prompts to test.
        model_case (LoRAModelCase): The model case to test. Must contain at least one adaptor.
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
        enable_lora_overlap_loading (Optional[bool], optional): Forwarded as-is to the
            LoRA-enabled SRTRunner. Defaults to None.
        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".

    Raises:
        ValueError: If ``model_case`` has no adaptors.
        AssertionError: If a prompt exceeds a logprob tolerance or falls below the
            ROUGE-L tolerance. The message names that prompt and its adaptor.
    """
    base_path = model_case.base
    if not model_case.adaptors:
        raise ValueError(f"Model case for base '{base_path}' has no adaptors to test")

    # Create used adaptors for each prompt in batch: prompt i gets adaptor i % N.
    num_adaptors = len(model_case.adaptors)
    adaptors = [model_case.adaptors[j % num_adaptors] for j in range(len(prompts))]
    adaptor_names = [adaptor.name for adaptor in adaptors]

    print(
        f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
        f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
    )
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        lora_paths=[
            adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
        ],
        enable_lora_overlap_loading=enable_lora_overlap_loading,
        max_loras_per_batch=model_case.max_loras_per_batch,
        max_loaded_loras=model_case.max_loaded_loras,
        lora_backend=backend,
        disable_cuda_graph=disable_cuda_graph,
        disable_radix_cache=disable_radix_cache,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_outputs = srt_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )

    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_no_lora_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

    with HFRunner(
        base_path, torch_dtype=torch_dtype, model_type="generation"
    ) as hf_runner:
        hf_outputs = hf_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )
        hf_no_lora_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

    def _resolve(override, default):
        # An adaptor-level tolerance wins when set; otherwise use the model default.
        return default if override is None else override

    for i, adaptor in enumerate(adaptors):
        prefill_tol = _resolve(adaptor.prefill_tolerance, model_case.prefill_tolerance)
        decode_tol = _resolve(adaptor.decode_tolerance, model_case.decode_tolerance)
        rouge_tol = _resolve(adaptor.rouge_l_tolerance, model_case.rouge_l_tolerance)

        # Names the failing prompt/adaptor pair (not the whole batch) in errors.
        failure_context = (
            f"for base '{base_path}', adaptor '{adaptor.name}', "
            f"backend '{backend}', prompt: '{prompts[i][:50]}...'"
        )

        # Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
        hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i])
        srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i])
        max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill))
        print("Max prefill diff (HF vs SRT):", max_prefill_diff)

        # Compare decode stage logprobs
        hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i])
        srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i])
        max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode))
        print("Max decode diff (HF vs SRT):", max_decode_diff)

        srt_output_str = srt_outputs.output_strs[i].strip()
        hf_output_str = hf_outputs.output_strs[i].strip()
        rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
        print("ROUGE-L score:", rouge_score)
        print("SRT output:", srt_output_str)
        print("HF output:", hf_output_str)

        # Additional: compare prefill outputs between base model (no LoRA) and LoRA model for reference
        hf_no_lora_prefill = torch.tensor(hf_no_lora_outputs.top_input_logprobs[i])
        srt_no_lora_prefill = torch.tensor(srt_no_lora_outputs.top_input_logprobs[i])
        print(
            "Max diff (SRT base vs SRT LoRA prefill):",
            torch.max(torch.abs(srt_no_lora_prefill - srt_prefill)),
        )
        print(
            "Max diff (HF base vs HF LoRA prefill):",
            torch.max(torch.abs(hf_no_lora_prefill - hf_prefill)),
        )

        # Logprob checks are only applied to sequences of at most 100 positions.
        if hf_prefill.shape[0] <= 100:
            assert torch.all(torch.abs(hf_prefill - srt_prefill) < prefill_tol), (
                f"Prefill logprobs mismatch {failure_context} "
                f"(max diff {max_prefill_diff}, tolerance {prefill_tol})"
            )

        if hf_decode.shape[0] <= 100:
            assert torch.all(torch.abs(hf_decode - srt_decode) < decode_tol), (
                f"Decode logprobs mismatch {failure_context} "
                f"(max diff {max_decode_diff}, tolerance {decode_tol})"
            )

        if rouge_score < rouge_tol:
            raise AssertionError(
                f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
                f"{failure_context}"
            )


def run_lora_test_by_batch(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str = "csgmv",
    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = False,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Run all prompts through SRT as one batch and require exact agreement with HF.

    Prompt ``i`` is paired with adaptor ``i % len(model_case.adaptors)``.
    Four generations are produced: SRT with LoRA (batched), SRT without LoRA
    (batched), HF with LoRA and HF without LoRA. For every prompt, the SRT and
    HF texts must match exactly (after stripping surrounding spaces), both with
    and without LoRA. ROUGE-L and all outputs are printed for diagnosis only.

    Args:
        prompts (List[str]): The batch of prompts to test.
        model_case (LoRAModelCase): The model case to test.
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".
    """
    base_path = model_case.base

    # Cycle through the model case's adaptors until every prompt has one.
    pool = model_case.adaptors
    adaptors = []
    cursor = 0
    while len(adaptors) < len(prompts):
        adaptors.append(pool[cursor])
        cursor = (cursor + 1) % len(pool)
    adaptor_names = [adaptor.name for adaptor in adaptors]

    print(
        f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
        f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
    )

    lora_srt_kwargs = dict(
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        lora_paths=[a.name for a in pool if a.name is not None],
        max_loras_per_batch=model_case.max_loras_per_batch,
        max_loaded_loras=model_case.max_loaded_loras,
        lora_backend=backend,
        disable_cuda_graph=disable_cuda_graph,
        disable_radix_cache=disable_radix_cache,
        mem_fraction_static=mem_fraction_static,
    )
    plain_srt_kwargs = dict(
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        mem_fraction_static=mem_fraction_static,
    )
    hf_kwargs = dict(torch_dtype=torch_dtype, model_type="generation")

    with SRTRunner(base_path, **lora_srt_kwargs) as runner:
        srt_outputs = runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )

    with SRTRunner(base_path, **plain_srt_kwargs) as runner:
        srt_no_lora_outputs = runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens
        )

    # HF with and without LoRA each get a fresh runner.
    with HFRunner(base_path, **hf_kwargs) as runner:
        hf_outputs = runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )

    with HFRunner(base_path, **hf_kwargs) as runner:
        hf_no_lora_outputs = runner.forward(prompts, max_new_tokens=max_new_tokens)

    for idx in range(len(prompts)):
        srt_text = srt_outputs.output_strs[idx]
        hf_text = hf_outputs.output_strs[idx]

        srt_output_str = srt_text.strip()
        hf_output_str = hf_text.strip()
        rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
        print("ROUGE-L score:", rouge_score)
        print("SRT output:", srt_output_str)
        print("HF output:", hf_output_str)

        srt_base_text = srt_no_lora_outputs.output_strs[idx]
        print("SRT no lora output:", srt_base_text.strip())
        hf_base_text = hf_no_lora_outputs.output_strs[idx]
        print("HF no lora output:", hf_base_text.strip())

        # Exact match required; only spaces (not newlines) are trimmed here.
        srt_lora, hf_lora = srt_text.strip(" "), hf_text.strip(" ")
        assert srt_lora == hf_lora, (srt_lora, hf_lora)

        srt_base, hf_base = srt_base_text.strip(" "), hf_base_text.strip(" ")
        assert srt_base == hf_base, (srt_base, hf_base)


def ensure_reproducibility(seed: int = 42):
    """Seed Python's ``random`` module and torch's CPU and CUDA RNGs.

    Args:
        seed: Seed applied to every RNG. Defaults to 42, the value the tests in
            this module have always used.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Prompt pool for the multi-adapter tests. run_lora_multiple_batch_on_model_cases
# samples batches from it through create_multiple_batch_test_samples.
# run_lora_batch_splitting_equivalence_test repeats its first entry.
TEST_MULTIPLE_BATCH_PROMPTS = [
    """
    ### Instruction:
    Tell me about llamas and alpacas
    ### Response:
    Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
    ### Question 2:
    What do you know about llamas?
    ### Answer:
    """,
    """
    ### Instruction:
    Write a poem about the transformers Python library.
    Mention the word "large language models" in that poem.
    ### Response:
    The Transformers are large language models,
    They're used to make predictions on text.
    """,
    "AI is a field of computer science focused on",
    "Computer science is the study of",
    "Write a short story.",
    "What are the main components of a computer?",
]


def create_multiple_batch_test_samples(
    prompts: List[str], lora_adapter_paths: List[str], seed: int = 42
):
    """Build deterministic batches that mix LoRA adapters with the base model.

    Each batch has three prompts drawn with replacement from ``prompts``, plus
    a matching list of adapter paths. An adapter path of ``None`` means the
    base model (no LoRA).

    Sampling uses a private ``random.Random(seed)``. The draws are therefore
    the same as seeding the global RNG with ``seed``, but the module-level
    ``random`` state is left untouched.

    Args:
        prompts: Pool of prompts to sample from (must be non-empty).
        lora_adapter_paths: Adapter paths. Only the first two are used.
        seed: Seed for prompt sampling. Defaults to 42.

    Returns:
        A list of ``(batch_prompts, batch_lora_paths)`` tuples, each holding
        two lists of length 3.
    """
    rng = random.Random(seed)
    first, second = lora_adapter_paths[0], lora_adapter_paths[1]

    adapter_layouts = [
        [None, first, second],
        # The layouts [first, None, second] and [None, second, None] pass only
        # about half the time on CI, so these flaky cases are skipped for now.
        [first, second, None],
        [None, None, None],
    ]
    # Prompts are drawn batch by batch, three at a time, in layout order.
    return [
        ([rng.choice(prompts) for _ in range(3)], lora_paths)
        for lora_paths in adapter_layouts
    ]


def _multi_batch_spec_args(use_spec_decoding: bool) -> dict:
    """Return SRTRunner kwargs enabling NGRAM speculative decoding, or ``{}``."""
    if not use_spec_decoding:
        return {}
    return {
        "speculative_algorithm": "NGRAM",
        "speculative_num_draft_tokens": 5,
        "speculative_ngram_min_match_window_size": 2,
        "speculative_ngram_max_match_window_size": 15,
    }


def run_lora_multiple_batch_on_model_cases(
    model_cases: List[LoRAModelCase],
    use_spec_decoding: bool = False,
    attention_backend: str = "torch_native",
    disable_cuda_graph: bool = True,
    enable_deterministic_inference: bool = False,
    disable_radix_cache: bool = True,
    enable_lora_overlap_loading: Optional[bool] = None,
):
    """
    Compare SRT and HF generations on batches mixing LoRA adapters and the base model.

    For every model case and dtype in ``TORCH_DTYPES``, batches come from
    ``create_multiple_batch_test_samples``; a ``None`` adapter path means the
    base model. Each batch runs through SRT (``batch_forward``) and HF
    (``forward``). Every output pair must reach the ROUGE-L tolerance: the
    adaptor's ``rouge_l_tolerance`` if set, otherwise the model case's.

    Args:
        model_cases: Cases to test. Each needs at least two adaptors.
        use_spec_decoding: Run SRT with NGRAM speculative decoding.
        attention_backend: Attention backend passed to SRTRunner.
        disable_cuda_graph: Whether to disable CUDA graph.
        enable_deterministic_inference: Passed to SRTRunner.
        disable_radix_cache: Whether to disable radix cache.
        enable_lora_overlap_loading: Forwarded as-is to SRTRunner.

    Raises:
        AssertionError: On output-count mismatch, or when an output's ROUGE-L
            score is below tolerance. The message names that prompt and adapter.
    """
    for model_case in model_cases:
        # ROUGE-L threshold per adapter path. Adaptor overrides win; base-model
        # requests (path None) fall back to the model case's default below.
        rouge_tol_by_path = {
            adaptor.name: (
                adaptor.rouge_l_tolerance
                if adaptor.rouge_l_tolerance is not None
                else model_case.rouge_l_tolerance
            )
            for adaptor in model_case.adaptors
        }
        for torch_dtype in TORCH_DTYPES:
            max_new_tokens = 32
            base_path = model_case.base
            lora_adapter_paths = [a.name for a in model_case.adaptors]
            assert len(lora_adapter_paths) >= 2

            batches = create_multiple_batch_test_samples(
                TEST_MULTIPLE_BATCH_PROMPTS, lora_adapter_paths
            )

            print(
                f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---"
            )

            # Initialize runners
            ensure_reproducibility()
            srt_runner = SRTRunner(
                base_path,
                torch_dtype=torch_dtype,
                model_type="generation",
                lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                enable_lora_overlap_loading=enable_lora_overlap_loading,
                max_loras_per_batch=len(lora_adapter_paths) + 1,
                max_loaded_loras=model_case.max_loaded_loras,
                sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                attention_backend=attention_backend,
                enable_deterministic_inference=enable_deterministic_inference,
                disable_cuda_graph=disable_cuda_graph,
                disable_radix_cache=disable_radix_cache,
                **_multi_batch_spec_args(use_spec_decoding),
            )

            # Enter the SRT runner's context before building the HF runner so
            # the SRT runner is still exited if HFRunner(...) raises.
            with srt_runner:
                ensure_reproducibility()
                hf_runner = HFRunner(
                    base_path,
                    torch_dtype=torch_dtype,
                    model_type="generation",
                    patch_model_do_sample_false=True,
                )

                with hf_runner:
                    for i, (prompts, lora_paths) in enumerate(batches):
                        print(
                            f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
                        )

                        srt_outputs = srt_runner.batch_forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )

                        hf_outputs = hf_runner.forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )

                        print("SRT outputs:", list(srt_outputs.output_strs))
                        print("HF outputs:", list(hf_outputs.output_strs))

                        # zip() below would silently drop unmatched outputs.
                        num_srt = len(srt_outputs.output_strs)
                        num_hf = len(hf_outputs.output_strs)
                        if not num_srt == num_hf == len(prompts):
                            raise AssertionError(
                                f"Output count mismatch for base '{base_path}', batch {i+1}: "
                                f"{len(prompts)} prompts, {num_srt} SRT outputs, {num_hf} HF outputs"
                            )

                        for prompt, lora_path, srt_out, hf_out in zip(
                            prompts,
                            lora_paths,
                            srt_outputs.output_strs,
                            hf_outputs.output_strs,
                        ):
                            srt_str = srt_out.strip()
                            hf_str = hf_out.strip()
                            rouge_tol = rouge_tol_by_path.get(
                                lora_path, model_case.rouge_l_tolerance
                            )
                            rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
                            if rouge_score < rouge_tol:
                                raise AssertionError(
                                    f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
                                    f"for base '{base_path}', adaptor '{lora_path}', "
                                    f"batch {i+1}, prompt: '{prompt[:50]}...'"
                                )

                        print(f"--- Batch {i+1} Comparison Passed --- ")


def run_lora_batch_splitting_equivalence_test(
    model_cases: List[LoRAModelCase],
    attention_backend: str = "torch_native",
    disable_cuda_graph: bool = True,
    disable_radix_cache: bool = True,
    enable_lora_overlap_loading: Optional[bool] = None,
):
    """
    Test that SRT correctly handles batch splitting with multiple LoRA adapters.

    The SRT runner is started with ``max_loras_per_batch=2``. Batches whose set
    of distinct adapters (``None``, the base model, counts as one) is larger
    than that are split into microbatches internally by SRT.

    For every model case and dtype, six three-request batches are run. Each
    batch repeats ``TEST_MULTIPLE_BATCH_PROMPTS[0]`` under a different adapter
    layout. The test validates:
    1. SRT processes every batch, including those that need splitting,
       without errors.
    2. Whenever a batch uses at least two distinct adapters, its outputs are
       not all identical. Identical prompts with different adapters should
       produce at least one different completion.

    No HF reference is involved.

    Args:
        model_cases: LoRAModelCase configurations to test (each needs >= 2 adaptors)
        attention_backend: Attention backend to use
        disable_cuda_graph: Whether to disable CUDA graph
        disable_radix_cache: Whether to disable radix cache
        enable_lora_overlap_loading: Forwarded as-is to SRTRunner
    """
    max_loras_per_batch = 2
    max_new_tokens = 64

    for model_case in model_cases:
        for torch_dtype in TORCH_DTYPES:
            adapter_paths = [a.name for a in model_case.adaptors]
            assert (
                len(adapter_paths) >= max_loras_per_batch
            ), f"Need at least {max_loras_per_batch} adapters for this test"

            base_path = model_case.base
            print(
                f"\n========== Testing batch splitting on base '{base_path}', "
                f"dtype={torch_dtype} =========="
            )

            # Every batch reuses the same prompt list; only adapters vary.
            shared_prompts = [TEST_MULTIPLE_BATCH_PROMPTS[0]] * 3
            first, second = adapter_paths[0], adapter_paths[1]
            adapter_layouts = [
                [None, first, second],
                [first, None, second],
                [first, second, None],
                [None, second, None],
                [first, second, first],
                [None, None, None],
            ]

            ensure_reproducibility()
            with SRTRunner(
                base_path,
                torch_dtype=torch_dtype,
                model_type="generation",
                lora_paths=adapter_paths,
                enable_lora_overlap_loading=enable_lora_overlap_loading,
                max_loras_per_batch=max_loras_per_batch,
                max_loaded_loras=model_case.max_loaded_loras,
                sleep_on_idle=True,
                attention_backend=attention_backend,
                disable_cuda_graph=disable_cuda_graph,
                disable_radix_cache=disable_radix_cache,
            ) as runner:
                for batch_no, layout in enumerate(adapter_layouts, start=1):
                    print(f"\n--- Batch {batch_no} ---")
                    print(f"  Adapters: {layout}")

                    generated = runner.batch_forward(
                        shared_prompts,
                        max_new_tokens=max_new_tokens,
                        lora_paths=layout,
                    )

                    # With mixed adapters, at least one output must differ
                    # from the first one.
                    if len(set(layout)) >= 2:
                        texts = [s.strip() for s in generated.output_strs]
                        reference = texts[0]
                        assert any(text != reference for text in texts), (
                            f"Every output was identical despite using different adapters for "
                            f"base '{base_path}', batch {batch_no}: "
                            f"adapters={layout}. Expected at least one output to differ."
                        )

                    print(f"--- Batch {batch_no} passed ---")
