# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py

import argparse
import dataclasses
import json
import os
from typing import cast

from sglang.multimodal_gen import DiffGenerator
from sglang.multimodal_gen.configs.sample.sampling_params import (
    SamplingParams,
    generate_request_id,
)
from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand
from sglang.multimodal_gen.runtime.entrypoints.cli.utils import (
    RaiseNotImplementedAction,
)
from sglang.multimodal_gen.runtime.entrypoints.utils import GenerationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.runtime.utils.perf_logger import (
    MemorySnapshot,
    PerformanceLogger,
    RequestMetrics,
)
from sglang.multimodal_gen.utils import FlexibleArgumentParser

logger = init_logger(__name__)


def add_multimodal_gen_generate_args(parser: argparse.ArgumentParser):
    """Attach every CLI option used by the `generate` subcommand.

    Registers the config-file and perf-dump options first, then delegates to
    ``ServerArgs`` and ``SamplingParams`` for their own argument groups, and
    finally adds placeholder flags that are not yet implemented.

    Returns the (possibly replaced) parser so callers can keep chaining.
    """
    parser.add_argument(
        "--config",
        type=str,
        default="",
        required=False,
        help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional.",
    )
    parser.add_argument(
        "--perf-dump-path",
        type=str,
        default=None,
        required=False,
        help="Path to dump the performance metrics (JSON) for the run.",
    )

    # Each helper may wrap/replace the parser, so rebind on every call.
    for add_group in (ServerArgs.add_cli_args, SamplingParams.add_cli_args):
        parser = add_group(parser)

    parser.add_argument(
        "--text-encoder-configs",
        action=RaiseNotImplementedAction,
        help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
    )

    return parser


def maybe_dump_performance(
    args: argparse.Namespace,
    server_args,
    prompt: str,
    results: GenerationResult | list[GenerationResult] | None,
):
    """Write a performance report to ``args.perf_dump_path`` when requested.

    Reconstructs a ``RequestMetrics`` object from the serialized metrics dict
    attached to the first result and hands it to
    ``PerformanceLogger.dump_benchmark_report``. A no-op when no dump path was
    given, when there are no results, or when the result carries no metrics.

    Args:
        args: Parsed CLI namespace; only ``perf_dump_path`` is read here.
        server_args: Server configuration; only ``model_path`` is read here.
        prompt: The prompt used for generation, recorded in the report meta.
        results: A single result, a list of results, or None.
    """
    if not args.perf_dump_path or not results:
        return

    # generator.generate may hand back a single result or a list; report on
    # the first one only.
    result = results[0] if isinstance(results, list) else results
    # Guard explicitly: the original fell through to `result.metrics` even on
    # the (dead) `result = None` path, which would raise AttributeError.
    if result is None:
        return

    metrics_dict = result.metrics
    if not metrics_dict:
        return

    metrics = RequestMetrics(request_id=metrics_dict.get("request_id"))
    metrics.stages = metrics_dict.get("stages", {})
    metrics.steps = metrics_dict.get("steps", [])
    metrics.total_duration_ms = metrics_dict.get("total_duration_ms", 0)

    # Restore memory snapshots from their serialized dict form.
    for checkpoint_name, snapshot_dict in metrics_dict.get(
        "memory_snapshots", {}
    ).items():
        metrics.memory_snapshots[checkpoint_name] = MemorySnapshot(
            allocated_mb=snapshot_dict.get("allocated_mb", 0.0),
            reserved_mb=snapshot_dict.get("reserved_mb", 0.0),
            peak_allocated_mb=snapshot_dict.get("peak_allocated_mb", 0.0),
            peak_reserved_mb=snapshot_dict.get("peak_reserved_mb", 0.0),
        )

    PerformanceLogger.dump_benchmark_report(
        file_path=args.perf_dump_path,
        metrics=metrics,
        meta={
            "prompt": prompt,
            "model": server_args.model_path,
        },
        tag="cli_generate",
    )


def generate_cmd(args: argparse.Namespace):
    """Entry point for the `generate` command: run one offline generation.

    Builds server and sampling configuration from the parsed CLI namespace,
    instantiates a local DiffGenerator, runs generation, and optionally dumps
    performance metrics.
    """
    args.request_id = "mocked_fake_id_for_offline_generate"

    server_args = ServerArgs.from_cli_args(args)

    sampling_kwargs = SamplingParams.get_cli_args(args)
    sampling_kwargs["request_id"] = generate_request_id()

    # --diffusers-kwargs arrives as a raw JSON string; decode it up front so a
    # malformed value fails fast with an actionable error.
    raw_diffusers_kwargs = getattr(args, "diffusers_kwargs", None)
    if raw_diffusers_kwargs:
        try:
            sampling_kwargs["diffusers_kwargs"] = json.loads(raw_diffusers_kwargs)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse --diffusers-kwargs as JSON: %s", e)
            raise ValueError(
                f"--diffusers-kwargs must be valid JSON. Got: {args.diffusers_kwargs}"
            ) from e
        logger.info(
            "Parsed diffusers_kwargs: %s",
            sampling_kwargs["diffusers_kwargs"],
        )

    generator = DiffGenerator.from_pretrained(
        model_path=server_args.model_path, server_args=server_args, local_mode=True
    )
    results = generator.generate(sampling_params_kwargs=sampling_kwargs)

    maybe_dump_performance(args, server_args, sampling_kwargs.get("prompt"), results)


class GenerateSubcommand(CLISubcommand):
    """The `generate` subcommand for the sglang-diffusion CLI"""

    def __init__(self) -> None:
        # Name must be set before super().__init__() so the base class can
        # register the subcommand under it.
        self.name = "generate"
        super().__init__()
        self.init_arg_names = self._get_init_arg_names()
        self.generation_arg_names = self._get_generation_arg_names()

    def _get_init_arg_names(self) -> list[str]:
        """Return the argument names used for DiffGenerator initialization."""
        return ["num_gpus", "tp_size", "sp_size", "model_path"]

    def _get_generation_arg_names(self) -> list[str]:
        """Return the argument names forwarded to the generation call."""
        return [f.name for f in dataclasses.fields(SamplingParams)]

    def cmd(self, args: argparse.Namespace) -> None:
        """Run the generate command with the parsed arguments."""
        generate_cmd(args)

    def validate(self, args: argparse.Namespace) -> None:
        """Sanity-check parsed arguments before execution.

        Raises:
            ValueError: if ``num_gpus`` is non-positive or the supplied
                ``--config`` file does not exist.
        """
        num_gpus = args.num_gpus
        if num_gpus is not None and num_gpus <= 0:
            raise ValueError("Number of gpus must be positive")

        if args.config and not os.path.exists(args.config):
            raise ValueError(f"Config file not found: {args.config}")

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the `generate` parser on *subparsers* and return it."""
        parser = subparsers.add_parser(
            "generate",
            help="Run inference on a model",
            usage="sgl_diffusion generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]",
        )
        return cast(FlexibleArgumentParser, add_multimodal_gen_generate_args(parser))
