from collections.abc import Callable
from typing import Any

import cloudpickle
from pydantic import ValidationError
from tqdm import tqdm

# External library imports (vLLM)
from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field
from vllm.entrypoints.llm import LLM
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.v1.engine.llm_engine import LLMEngine

from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors

# Internal imports (our code)
from vllm_omni.engine.arg_utils import OmniEngineArgs
from vllm_omni.engine.input_processor import OmniInputProcessor
from vllm_omni.engine.output_processor import MultimodalOutputProcessor
from vllm_omni.entrypoints.utils import (
    filter_dataclass_kwargs,
    load_stage_configs_from_model,
    load_stage_configs_from_yaml,
    resolve_model_config_path,
)

logger = init_logger(__name__)


class OmniLLM(LLM):
    """Main entry point for vLLM-Omni inference.

    This class extends the base vLLM LLM class with omni-specific
    processors for handling multimodal inputs and outputs. It provides
    configuration loading for multi-stage pipelines, while stage management
    is handled by the Omni class.

    Args:
        model: Model name or path to load
        stage_configs_path: Optional path to YAML file containing stage
            configurations. If None, configurations are loaded from the model.
        log_stats: Whether to enable statistics logging
        compilation_config: Optional compilation configuration. Can be an
            integer (compilation level), dict, or CompilationConfig instance.
        hf_overrides: Optional HuggingFace model configuration overrides
        structured_outputs_config: Optional structured outputs configuration.
            Can be a dict or StructuredOutputsConfig instance.
        init_sleep_seconds: Number of seconds to sleep between starting
            each stage process during initialization (used by Omni class)
        shm_threshold_bytes: Threshold in bytes for using shared memory
            for IPC. Objects larger than this threshold will use shared memory.
        batch_timeout: Timeout in seconds for batching requests within a stage
        init_timeout: Timeout in seconds for waiting for all stages to initialize
        **kwargs: Additional keyword arguments passed to the base LLM class
            and engine

    Example:
        >>> llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B")
        >>> # Stage management is handled by Omni class
    """

    def __init__(
        self,
        model: str,
        stage_configs_path: str | None = None,
        log_stats: bool = False,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        hf_overrides: dict[str, Any] | None = None,
        structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None,
        init_sleep_seconds: int = 20,
        shm_threshold_bytes: int = 65536,
        batch_timeout: int = 10,
        init_timeout: int = 300,
        **kwargs: Any,
    ):
        """LLM constructor with omni-specific configuration loading."""
        # Store stage management parameters (used by Omni class)
        self.worker_backend = kwargs.get("worker_backend", "multi_process")
        self.ray_address = kwargs.get("ray_address", None)
        self.batch_timeout = batch_timeout
        self.log_stats: bool = bool(log_stats)

        # Load stage configurations
        if stage_configs_path is None:
            self.config_path = resolve_model_config_path(model)
            self.stage_configs = load_stage_configs_from_model(model)
        else:
            self.config_path = stage_configs_path
            self.stage_configs = load_stage_configs_from_yaml(stage_configs_path)

        # Initialize connectors
        self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors(
            self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes
        )

        # Initialize LLM engine
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        # Extract omni_kv_config from kwargs if present (injected by Omni)
        omni_kv_config = kwargs.pop("omni_kv_config", None)

        if compilation_config is not None:
            if isinstance(compilation_config, int):
                compilation_config_instance = CompilationConfig(level=compilation_config)
            elif isinstance(compilation_config, dict):
                compilation_config_instance = CompilationConfig(
                    **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)}
                )
            else:
                compilation_config_instance = compilation_config
        else:
            compilation_config_instance = CompilationConfig()

        if structured_outputs_config is not None:
            if isinstance(structured_outputs_config, dict):
                structured_outputs_instance = StructuredOutputsConfig(
                    **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)}
                )
            else:
                structured_outputs_instance = structured_outputs_config
        else:
            structured_outputs_instance = StructuredOutputsConfig()

        engine_args = OmniEngineArgs(
            model=model,
            compilation_config=compilation_config_instance,
            structured_outputs_config=structured_outputs_instance,
            omni_kv_config=omni_kv_config,
            **filter_dataclass_kwargs(OmniEngineArgs, kwargs),
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine.output_processor = MultimodalOutputProcessor(
            tokenizer=self.llm_engine.tokenizer,
            log_stats=self.llm_engine.log_stats,
            engine_core_output_type=engine_args.engine_output_type,
        )
        self.llm_engine.input_processor = OmniInputProcessor(vllm_config=self.llm_engine.vllm_config)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()  # type: ignore

        logger.info("Supported_tasks: %s", supported_tasks)

        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin if any
        io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin)
        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor

    def close(self) -> None:
        """Close resources.

        Note: Stage management is now handled by Omni class.
        This method closes the LLM engine but not stages.
        """
        # Close the LLM engine if it exists
        if hasattr(self, "llm_engine") and self.llm_engine is not None:
            if hasattr(self.llm_engine, "shutdown"):
                self.llm_engine.shutdown()

    def __del__(self) -> None:  # best-effort
        try:
            self.close()
        except Exception as e:
            logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True)

    def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
                            pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s"
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by the int part of request ID which is in format of 'int-uuid'.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0]))
