from __future__ import annotations

import logging
from collections.abc import Iterator

import torch

from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.components.schedulers import LTX2Scheduler
from ltx_core.conditioning.types.noise_mask_cond import TemporalRegionMask
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.loader.registry import Registry
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_core.types import (
    SpatioTemporalScaleFactors,
)
from ltx_pipelines.utils.args import video_editing_arg_parser
from ltx_pipelines.utils.blocks import (
    AudioConditioner,
    AudioDecoder,
    DiffusionStage,
    ImageConditioner,
    PromptEncoder,
    VideoDecoder,
)
from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, detect_params
from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
from ltx_pipelines.utils.helpers import (
    audio_latent_from_file,
    get_device,
    video_latent_from_file,
)
from ltx_pipelines.utils.media_io import (
    encode_video,
    get_videostream_metadata,
)
from ltx_pipelines.utils.types import ModalitySpec


class RetakePipeline:
    """Regenerate ("retake") a time window of an existing video.

    ``__call__`` receives a source video file and a window
    ``[start_time, end_time]`` in seconds. The source is encoded to video
    (and, when present, audio) latents, a ``TemporalRegionMask`` covering the
    window is attached as conditioning, the diffusion stage is run with a text
    prompt, and the resulting latents are decoded. The intent is that content
    outside the window is kept while the window itself is regenerated.

    Parameters
    ----------
    checkpoint_path : str
        Path to the LTX-2 model checkpoint; forwarded to every component.
    gemma_root : str
        Root directory containing Gemma text-encoder weights (forwarded to
        the prompt encoder).
    loras : list[LoraPathStrengthAndSDOps]
        LoRA configs; converted to a tuple and handed to the diffusion stage.
    device : torch.device
        Target device. When falsy, the result of ``get_device()`` is used.
    quantization : QuantizationPolicy | None
        Optional quantization policy handed to the diffusion stage.
    registry : Registry | None
        Optional registry forwarded unchanged to every component.
    distilled : bool
        Set to ``True`` if using distilled model or passing distillation
        lora with full model. When ``True``, ``__call__`` uses the
        ``DISTILLED_SIGMA_VALUES`` schedule together with ``SimpleDenoiser``;
        otherwise it uses ``LTX2Scheduler`` together with ``GuidedDenoiser``.
    torch_compile : bool
        Forwarded to the diffusion stage.
    """

    def __init__(
        self,
        checkpoint_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device | None = None,
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        distilled: bool = True,
        torch_compile: bool = False,
    ):
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        self.distilled = distilled

        # Keyword arguments that every checkpoint-backed component receives.
        shared = {
            "checkpoint_path": checkpoint_path,
            "dtype": self.dtype,
            "device": self.device,
            "registry": registry,
        }

        self.prompt_encoder = PromptEncoder(gemma_root=gemma_root, **shared)
        self.image_conditioner = ImageConditioner(**shared)
        self.audio_conditioner = AudioConditioner(**shared)
        self.stage = DiffusionStage(
            loras=tuple(loras),
            quantization=quantization,
            torch_compile=torch_compile,
            **shared,
        )
        self.video_decoder = VideoDecoder(**shared)
        self.audio_decoder = AudioDecoder(**shared)

    # --------------------------------------------------------------------- #
    #  Internal helpers                                                       #
    # --------------------------------------------------------------------- #

    def _select_denoiser(
        self,
        contexts,
        num_inference_steps: int,
        video_guider_params: MultiModalGuiderParams | None,
        audio_guider_params: MultiModalGuiderParams | None,
    ):
        """Return ``(sigmas, denoiser)`` for the configured sampling mode.

        ``contexts[0]`` is the positive prompt encoding. ``contexts[1]`` (the
        negative prompt encoding) is only read in non-distilled mode, because
        distilled mode encodes a single prompt.
        """
        positive = contexts[0]
        v_positive = positive.video_encoding
        a_positive = positive.audio_encoding

        if self.distilled:
            schedule = torch.tensor(DISTILLED_SIGMA_VALUES)
            schedule = schedule.to(dtype=torch.float32, device=self.device)
            return schedule, SimpleDenoiser(v_context=v_positive, a_context=a_positive)

        schedule = LTX2Scheduler().execute(steps=num_inference_steps)
        schedule = schedule.to(dtype=torch.float32, device=self.device)
        negative = contexts[1]
        guided = GuidedDenoiser(
            v_context=v_positive,
            a_context=a_positive,
            video_guider=MultiModalGuider(
                params=video_guider_params,
                negative_context=negative.video_encoding,
            ),
            audio_guider=MultiModalGuider(
                params=audio_guider_params,
                negative_context=negative.audio_encoding,
            ),
        )
        return schedule, guided

    # --------------------------------------------------------------------- #
    #  Public entry point                                                     #
    # --------------------------------------------------------------------- #

    def __call__(  # noqa: PLR0913
        self,
        video_path: str,
        prompt: str,
        start_time: float,
        end_time: float,
        seed: int,
        *,
        negative_prompt: str = "",
        num_inference_steps: int = 40,
        video_guider_params: MultiModalGuiderParams | None = None,
        audio_guider_params: MultiModalGuiderParams | None = None,
        regenerate_video: bool = True,
        regenerate_audio: bool = True,
        enhance_prompt: bool = False,
        tiling_config: TilingConfig | None = None,
        streaming_prefetch_count: int | None = None,
        max_batch_size: int = 1,
    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
        """Regenerate ``[start_time, end_time]`` of the source video (retake).

        Parameters
        ----------
        video_path : str
            Path to the source video file (must contain video; audio is
            optional). Used for both the video and the audio latents.
        prompt : str
            Text prompt describing the *regenerated* section.
        start_time, end_time : float
            Time window (in seconds) of the section to regenerate.
        seed : int
            Seed for the ``torch.Generator`` shared by the noiser and the
            video decoder; also passed as the prompt-enhancement seed.
        negative_prompt : str
            Negative prompt; only encoded and used in non-distilled mode.
        num_inference_steps : int
            Step count passed to ``LTX2Scheduler`` (ignored in distilled
            mode, which uses ``DISTILLED_SIGMA_VALUES``).
        video_guider_params, audio_guider_params : MultiModalGuiderParams | None
            Guidance parameters for the video and audio guiders. Ignored in
            distilled mode.
        regenerate_video : bool
            If ``True`` (default), the window mask is attached to the video
            modality. If ``False``, the video modality is marked frozen and
            gets no conditioning.
        regenerate_audio : bool
            Same as ``regenerate_video`` but for audio. Only has an effect
            when an audio latent could be produced from the source.
        enhance_prompt : bool
            Forwarded to the prompt encoder as ``enhance_first_prompt``.
        tiling_config : TilingConfig | None
            Forwarded to the video decoder.
        streaming_prefetch_count : int | None
            Forwarded to the prompt encoder and to the diffusion stage.
        max_batch_size : int
            Forwarded to the diffusion stage.

        Returns
        -------
        tuple[Iterator[torch.Tensor], torch.Tensor]
            ``(video_frames_iterator, audio_waveform)`` as produced by the
            video and audio decoders.

        Raises
        ------
        ValueError
            If ``start_time`` is not strictly less than ``end_time``.
        """
        if end_time <= start_time:
            raise ValueError(f"start_time ({start_time}) must be less than end_time ({end_time})")

        # One generator feeds both the noiser and, later, the video decoder.
        rng = torch.Generator(device=self.device).manual_seed(seed)
        noise_source = GaussianNoiser(generator=rng)
        latent_dtype = self.dtype

        source_meta = get_videostream_metadata(video_path)

        def _video_latent(encoder):
            return video_latent_from_file(
                video_encoder=encoder,
                file_path=video_path,
                output_shape=source_meta,
                dtype=latent_dtype,
                device=self.device,
            )

        def _audio_latent(encoder):
            return audio_latent_from_file(
                audio_encoder=encoder,
                file_path=video_path,
                output_shape=source_meta,
                dtype=latent_dtype,
                device=self.device,
            )

        source_video_latent = self.image_conditioner(_video_latent)
        source_audio_latent = self.audio_conditioner(_audio_latent)

        # Distilled mode has no guidance, so the negative prompt is not encoded.
        if self.distilled:
            prompt_batch = [prompt]
        else:
            prompt_batch = [prompt, negative_prompt]
        contexts = self.prompt_encoder(
            prompt_batch,
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_seed=seed,
            streaming_prefetch_count=streaming_prefetch_count,
        )
        positive = contexts[0]

        def _window_conditionings(enabled) -> list:
            """Return a fresh one-element mask list, or an empty list when disabled."""
            if not enabled:
                return []
            return [TemporalRegionMask(start_time=start_time, end_time=end_time, fps=source_meta.fps)]

        video_spec = ModalitySpec(
            context=positive.video_encoding,
            conditionings=_window_conditionings(regenerate_video),
            initial_latent=source_video_latent,
            frozen=not regenerate_video,
        )

        # Audio is only masked or frozen when the source yielded an audio latent.
        has_source_audio = source_audio_latent is not None
        audio_spec = ModalitySpec(
            context=positive.audio_encoding,
            conditionings=_window_conditionings(has_source_audio and regenerate_audio),
            initial_latent=source_audio_latent,
            frozen=has_source_audio and not regenerate_audio,
        )

        sigmas, denoiser = self._select_denoiser(
            contexts,
            num_inference_steps,
            video_guider_params,
            audio_guider_params,
        )

        # Run diffusion stage
        video_state, audio_state = self.stage(
            denoiser=denoiser,
            sigmas=sigmas,
            noiser=noise_source,
            width=source_meta.width,
            height=source_meta.height,
            frames=source_meta.frames,
            fps=source_meta.fps,
            video=video_spec,
            audio=audio_spec,
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # Decode (video first, then audio)
        frames_out = self.video_decoder(video_state.latent, tiling_config, rng)
        audio_out = self.audio_decoder(audio_state.latent)
        return frames_out, audio_out


@torch.inference_mode()
def main() -> None:
    """CLI entry point for retake (regenerate a time region).

    Parses the command line, validates the time window and the source video's
    frame count and resolution, builds a ``RetakePipeline``, runs it, and
    writes the result to ``args.output_path`` via ``encode_video``.

    Raises
    ------
    ValueError
        If ``start_time >= end_time``, if the source frame count is not of the
        form ``time_scale * k + 1``, or if the width/height is not a multiple
        of 32.
    """
    logging.getLogger().setLevel(logging.INFO)
    parser = video_editing_arg_parser(distilled=True)
    parser.description = "Retake: regenerate a time region of a video with LTX-2."
    args = parser.parse_args()

    # Fail fast, before any model component is constructed.
    if args.start_time >= args.end_time:
        raise ValueError("start_time must be less than end_time")

    # Validate frame count (time_scale * k + 1) and resolution (multiples of 32) at CLI stage
    video_scale = SpatioTemporalScaleFactors.default()
    src = get_videostream_metadata(args.video_path)
    temporal_step = video_scale.time
    if (src.frames - 1) % temporal_step != 0:
        # Largest valid count not above the source; clamp so a source that
        # reports fewer than one frame never produces a negative suggestion.
        snapped = max((src.frames - 1) // temporal_step, 0) * temporal_step + 1
        # Derive the message from the same factor the check uses (identical
        # text to the previous hard-coded "8k+1 (e.g. 97, 193)" when it is 8).
        raise ValueError(
            f"Video frame count must satisfy {temporal_step}k+1 "
            f"(e.g. {12 * temporal_step + 1}, {24 * temporal_step + 1}). "
            f"Got {src.frames}; use a video with {snapped} frames."
        )
    resolution_multiple = 32
    if src.width % resolution_multiple != 0 or src.height % resolution_multiple != 0:
        raise ValueError(f"Video width and height must be multiples of 32. Got {src.width}x{src.height}.")

    pipeline = RetakePipeline(
        checkpoint_path=args.distilled_checkpoint_path,
        gemma_root=args.gemma_root,
        loras=tuple(args.lora) if args.lora else (),
        quantization=args.quantization,
        distilled=args.distilled,
        torch_compile=args.compile,
    )
    params = detect_params(args.distilled_checkpoint_path)
    tiling_config = TilingConfig.default()
    video_iter, audio = pipeline(
        video_path=args.video_path,
        prompt=args.prompt,
        start_time=args.start_time,
        end_time=args.end_time,
        seed=args.seed,
        video_guider_params=params.video_guider_params,
        audio_guider_params=params.audio_guider_params,
        tiling_config=tiling_config,
        streaming_prefetch_count=args.streaming_prefetch_count,
        max_batch_size=args.max_batch_size,
    )
    video_chunks_number = get_video_chunks_number(src.frames, tiling_config)
    # Round rather than truncate: int() turns e.g. 29.97 fps into 29, which
    # stretches the encoded video by ~3% relative to the source timing.
    output_fps = round(src.fps)
    encode_video(
        video=video_iter,
        fps=output_fps,
        audio=audio,
        output_path=args.output_path,
        video_chunks_number=video_chunks_number,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
