"""Progress tracking for LTX training.
This module provides a unified progress display for training and validation sampling,
encapsulating all Rich progress bar logic in one place.
"""

from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)


class SamplingContext:
    """Context for validation sampling progress tracking.
    Provides a unified progress display showing current video and denoising step.
    Display format: "Sampling X/Y [████████████] step Z/W"
    The progress bar shows the denoising progress for the current video.
    """

    def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress)."""
        if self._progress is None or self._task is None:
            return
        # Reset task for new video: completed=0, total=num_steps
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")

    def cleanup(self) -> None:
        """Hide sampling task when done."""
        if self._progress is None or self._task is None:
            return
        self._progress.update(self._task, visible=False)


class StandaloneSamplingProgress:
    """Standalone progress display for inference scripts.
    Unlike SamplingContext (which integrates with TrainingProgress), this class
    manages its own Rich Progress instance for use in standalone inference scripts.
    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Initialize standalone sampling progress.
        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._num_steps = num_steps
        self._description = description
        self._progress: Progress | None = None
        self._task: TaskID | None = None

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Start the progress display."""
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        self._progress.__enter__()
        self._task = self._progress.add_task(
            self._description,
            total=self._num_steps,
            info=f"step 0/{self._num_steps}",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")


class TrainingProgress:
    """Manages Rich progress display for training and validation.
    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.
    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.
        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None

        if not enabled:
            self._progress = None
            return

        # Single Progress instance with flexible columns
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display."""
        if self._progress is not None:
            self._progress.__enter__()
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.
        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return

        info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        self._progress.update(
            self._train_task,
            advance=1 if advance else 0,
            info=info,
        )
        # Update step count in video column
        completed = int(self._progress.tasks[self._train_task].completed)
        self._progress.update(self._train_task, video=f"{completed}/{self._total_steps}")

    def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
        """Start validation sampling progress tracking.
        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [████████████] step Z/W"
        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample
        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Return a no-op context when progress is disabled
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )

        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )
