import shutil
import tempfile
from pathlib import Path
from typing import List, Union

import imageio
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars
from rich.progress import Progress, SpinnerColumn, TextColumn

from ltx_trainer import logger
from ltx_trainer.config import LtxTrainerConfig


def push_to_hub(weights_path: Path, sampled_videos_paths: List[Path], config: LtxTrainerConfig) -> None:
    """Push the trained LoRA weights to HuggingFace Hub."""
    if not config.hub.hub_model_id:
        logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
        return

    api = HfApi()

    # Save original progress bar state
    original_progress_state = are_progress_bars_disabled()
    disable_progress_bars()  # Disable during our custom progress tracking

    try:
        # Try to create repo if it doesn't exist
        try:
            repo = create_repo(
                repo_id=config.hub.hub_model_id,
                repo_type="model",
                exist_ok=True,  # Don't raise error if repo exists
            )
            repo_id = repo.repo_id
            logger.info(f"🤗 Successfully created HuggingFace model repository at: {repo.url}")
        except Exception as e:
            logger.error(f"❌ Failed to create HuggingFace model repository: {e}")
            return

        # Create a single temporary directory for all files
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=True,
            ) as progress:
                try:
                    # Copy weights
                    task_copy = progress.add_task("Copying weights...", total=None)
                    weights_dest = temp_path / weights_path.name
                    shutil.copy2(weights_path, weights_dest)
                    progress.update(task_copy, description="✓ Weights copied")

                    # Create model card and save samples
                    task_card = progress.add_task("Creating model card and samples...", total=None)
                    _create_model_card(
                        output_dir=temp_path,
                        videos=sampled_videos_paths,
                        config=config,
                    )
                    progress.update(task_card, description="✓ Model card and samples created")

                    # Upload everything at once
                    task_upload = progress.add_task("Pushing files to HuggingFace Hub...", total=None)
                    api.upload_folder(
                        folder_path=str(temp_path),
                        repo_id=repo_id,
                        repo_type="model",
                    )
                    progress.update(task_upload, description="✓ Files pushed to HuggingFace Hub")
                    logger.info("✅ Successfully pushed files to HuggingFace Hub")

                except Exception as e:
                    logger.error(f"❌ Failed to process and push files to HuggingFace Hub: {e}")
                    raise  # Re-raise to handle in outer try block

    finally:
        # Restore original progress bar state
        if not original_progress_state:
            enable_progress_bars()


def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
    """Convert a video file to GIF format."""
    try:
        # Read the video file
        reader = imageio.get_reader(str(video_path))
        fps = reader.get_meta_data()["fps"]

        # Write GIF file with infinite loop
        writer = imageio.get_writer(
            str(output_path),
            fps=min(fps, 15),  # Cap FPS at 15 for reasonable file size
            loop=0,  # 0 means infinite loop
        )

        for frame in reader:
            writer.append_data(frame)

        writer.close()
        reader.close()
    except Exception as e:
        logger.error(f"Failed to convert video to GIF: {e}")


def _create_model_card(
    output_dir: Union[str, Path],
    videos: List[Path],
    config: LtxTrainerConfig,
) -> Path:
    """Generate and save a model card for the trained model."""

    repo_id = config.hub.hub_model_id
    pretrained_model_name_or_path = config.model.model_path
    validation_prompts = config.validation.prompts
    output_dir = Path(output_dir)
    template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"

    # Read the template
    template = template_path.read_text()

    # Get model name from repo_id
    model_name = repo_id.split("/")[-1]

    # Get base model information
    base_model_link = str(pretrained_model_name_or_path)
    model_path_str = str(pretrained_model_name_or_path)
    is_url = model_path_str.startswith(("http://", "https://"))

    # For URLs, extract the filename from the URL. For local paths, use the filename stem
    base_model_name = model_path_str.split("/")[-1] if is_url else Path(pretrained_model_name_or_path).name

    # Format validation prompts and create grid layout
    prompts_text = ""
    sample_grid = []

    if validation_prompts and videos:
        prompts_text = "Example prompts used during validation:\n\n"

        # Create samples directory
        samples_dir = output_dir / "samples"
        samples_dir.mkdir(exist_ok=True, parents=True)

        # Process videos and create cells
        cells = []
        for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
            if video.exists():
                # Add prompt to text section
                prompts_text += f"- `{prompt}`\n"

                # Convert video to GIF
                gif_path = samples_dir / f"sample_{i}.gif"
                try:
                    convert_video_to_gif(video, gif_path)

                    # Create grid cell with collapsible description
                    cell = (
                        f"![example{i + 1}](./samples/sample_{i}.gif)"
                        "<br>"
                        '<details style="max-width: 300px; margin: auto;">'
                        f"<summary>Prompt</summary>"
                        f"{prompt}"
                        "</details>"
                    )
                    cells.append(cell)
                except Exception as e:
                    logger.error(f"Failed to process video {video}: {e}")

        # Calculate optimal grid dimensions
        num_cells = len(cells)
        if num_cells > 0:
            # Aim for a roughly square grid, with max 4 columns
            num_cols = min(4, num_cells)
            num_rows = (num_cells + num_cols - 1) // num_cols  # Ceiling division

            # Create grid rows
            for row in range(num_rows):
                start_idx = row * num_cols
                end_idx = min(start_idx + num_cols, num_cells)
                row_cells = cells[start_idx:end_idx]
                # Properly format the row with table markers and exact number of cells
                formatted_row = "| " + " | ".join(row_cells) + " |"
                sample_grid.append(formatted_row)

    # Join grid rows with just the content, no headers needed
    grid_text = "\n".join(sample_grid) if sample_grid else ""

    # Fill in the template
    model_card_content = template.format(
        base_model=base_model_name,
        base_model_link=base_model_link,
        model_name=model_name,
        training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
        training_steps=config.optimization.steps,
        learning_rate=config.optimization.learning_rate,
        batch_size=config.optimization.batch_size,
        validation_prompts=prompts_text,
        sample_grid=grid_text,
    )

    # Save the model card directly
    model_card_path = output_dir / "README.md"
    model_card_path.write_text(model_card_content)

    return model_card_path
