# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
import argparse
import dataclasses
import hashlib
import json
import math
import os
import os.path
import re
import time
import unicodedata
import uuid
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any

from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import StoreBoolean

logger = init_logger(__name__)


def _json_safe(obj: Any):
    """
    Recursively convert objects to JSON-serializable forms.
    - Enums -> their name
    - Sets/Tuples -> lists
    - Dicts/Lists -> recursively processed
    """
    if isinstance(obj, Enum):
        return obj.name
    if isinstance(obj, dict):
        return {k: _json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_json_safe(v) for v in obj]
    return obj


def generate_request_id() -> str:
    """Return a new random (version 4) UUID in its canonical string form."""
    fresh_id = uuid.uuid4()
    return f"{fresh_id}"


def _sanitize_filename(name: str, replacement: str = "_", max_length: int = 150) -> str:
    """Create a filesystem- and ffmpeg-friendly filename.

    - Normalize to ASCII (drop accents and unsupported chars)
    - Replace spaces with underscores
    - Replace any char not in [A-Za-z0-9_.-] with replacement
    - Collapse multiple underscores
    - Trim leading/trailing dots/underscores and limit length
    """
    normalized = unicodedata.normalize("NFKD", name)
    ascii_name = normalized.encode("ascii", "ignore").decode("ascii")
    ascii_name = ascii_name.replace(" ", "_")
    ascii_name = re.sub(r"[^A-Za-z0-9._-]", replacement, ascii_name)
    ascii_name = re.sub(r"_+", "_", ascii_name).strip("._")
    if not ascii_name:
        ascii_name = "output"
    if max_length and len(ascii_name) > max_length:
        ascii_name = ascii_name[:max_length]
    return ascii_name


class DataType(Enum):
    """Kind of media a generation request produces."""

    IMAGE = auto()
    VIDEO = auto()

    def get_default_extension(self) -> str:
        """Return the extension (without the dot) used when none is given.

        Images map to ``png``; every other member falls back to ``mp4``.
        """
        return "png" if self is DataType.IMAGE else "mp4"


@dataclass
class SamplingParams:
    """
    Sampling parameters for generation.

    ``from_user_sampling_params_args`` builds a finalized instance: it loads the
    model-specific defaults, merges user overrides on top
    (``_merge_with_user_params``), resolves server-dependent values
    (``_adjust``) and checks compatibility with the pipeline config.
    """

    data_type: DataType = DataType.VIDEO

    request_id: str | None = None

    # All fields below are copied from ForwardBatch

    # Image inputs
    image_path: str | list[str] | None = None

    # Text inputs
    prompt: str | list[str] | None = None
    negative_prompt: str = (
        "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    )
    prompt_path: str | None = None
    output_path: str | None = None
    output_file_name: str | None = None
    output_quality: str | None = "default"
    output_compression: int | None = None

    # Batch info
    num_outputs_per_prompt: int = 1
    seed: int = 42
    generator_device: str = "cuda"  # Device for random generator: "cuda" or "cpu"

    # Original dimensions (before VAE scaling)
    num_frames: int = 1  # Default for image models
    num_frames_round_down: bool = (
        False  # Whether to round down num_frames if it's not divisible by num_gpus
    )
    height: int | None = None
    width: int | None = None
    # NOTE: this is temporary, we need a way to know if width or height is not provided, or do the image resize earlier
    height_not_provided: bool = False
    width_not_provided: bool = False
    fps: int = 24

    # Resolution validation
    supported_resolutions: list[tuple[int, int]] | None = (
        None  # None means all resolutions allowed
    )

    # Denoising parameters
    num_inference_steps: int | None = None
    guidance_scale: float = 1.0
    guidance_scale_2: float | None = None
    true_cfg_scale: float | None = (
        None  # for CFG vs guidance distillation (e.g., QwenImage)
    )
    guidance_rescale: float = 0.0
    cfg_normalization: float | bool = 0.0
    boundary_ratio: float | None = None

    # TeaCache parameters
    enable_teacache: bool = False

    # Profiling
    profile: bool = False
    num_profiled_timesteps: int = 5
    profile_all_stages: bool = False

    # Debugging
    debug: bool = False
    perf_dump_path: str | None = None

    # Misc
    save_output: bool = True
    return_frames: bool = False
    return_trajectory_latents: bool = False  # returns all latents for each timestep
    return_trajectory_decoded: bool = False  # returns decoded latents for each timestep
    # if True, disallow user params to override subclass-defined protected fields
    no_override_protected_fields: bool = False
    # whether to adjust num_frames for multi-GPU friendly splitting (default: True)
    adjust_frames: bool = True
    # if True, suppress verbose logging for this request
    suppress_logs: bool = False

    return_file_paths_only: bool = True
    enable_sequence_shard: bool = False

    def _set_output_file_ext(self):
        """Append the data type's default extension unless a known one is present.

        The check is case-insensitive so a name like ``clip.MP4`` is kept as-is
        instead of becoming ``clip.MP4.mp4``.
        """
        known_extensions = (".mp4", ".jpg", ".jpeg", ".png", ".webp")
        if not self.output_file_name.lower().endswith(known_extensions):
            self.output_file_name = (
                f"{self.output_file_name}.{self.data_type.get_default_extension()}"
            )

    def _set_output_file_name(self):
        """Settle ``output_file_name``: derive it if missing, sanitize, add extension.

        Without an explicit name and with a non-empty string ``prompt``, the name
        is the first 100 prompt characters, a timestamp and an 8-hex-digit hash
        of the current params; otherwise ``output_<timestamp>`` is used.
        """
        if (
            self.output_file_name is None
            and self.prompt
            and isinstance(self.prompt, str)
        ):
            # Hash the current params so different settings yield different names.
            params_dict = dataclasses.asdict(self)
            # Avoid recursion: the name being generated must not feed its own hash.
            params_dict["output_file_name"] = ""

            # Convert to a stable JSON string
            params_str = json.dumps(_json_safe(params_dict), sort_keys=True)
            hasher = hashlib.sha256()
            hasher.update(params_str.encode("utf-8"))
            param_hash = hasher.hexdigest()[:8]

            timestamp = time.strftime("%Y%m%d-%H%M%S")
            self.output_file_name = f"{self.prompt[:100]}_{timestamp}_{param_hash}"

        if self.output_file_name is None:
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            self.output_file_name = f"output_{timestamp}"

        self.output_file_name = _sanitize_filename(self.output_file_name)

        # Ensure a proper extension is present
        self._set_output_file_ext()

    def __post_init__(self) -> None:
        """Derive dependent fields and validate.

        Also re-run by ``_merge_with_user_params`` after user values are merged.
        """
        assert self.num_frames >= 1

        if self.width is None:
            self.width_not_provided = True
        if self.height is None:
            self.height_not_provided = True

        # Handle output_quality to output_compression conversion
        if self.output_compression is None and self.output_quality is not None:
            self.output_compression = self._adjust_output_quality(
                self.output_quality, self.data_type
            )

        self._validate()

        # Allow env var to override num_inference_steps (for faster CI testing on AMD).
        # Only applied when the step count is already set.
        env_steps = os.environ.get("SGLANG_TEST_NUM_INFERENCE_STEPS")
        if env_steps is not None and self.num_inference_steps is not None:
            self.num_inference_steps = int(env_steps)

    def _adjust_output_quality(
        self, output_quality: str, data_type: DataType
    ) -> int | None:
        """Convert an ``output_quality`` preset to a compression level.

        ``"default"`` maps to 50 for video and 75 otherwise. Unrecognized
        presets log a warning and return ``None`` (no explicit compression).
        """
        output_quality_mapper = {"maximum": 100, "high": 90, "medium": 55, "low": 35}
        if output_quality == "default":
            return 50 if data_type == DataType.VIDEO else 75
        compression = output_quality_mapper.get(output_quality)
        if compression is None:
            logger.warning(
                "Unknown output_quality %r (expected one of %s); "
                "no explicit compression will be set",
                output_quality,
                ", ".join(["default", *output_quality_mapper]),
            )
        return compression

    def _validate(self):
        """
        Check that the sampling params are correct by themselves.

        Independent of any pipeline. As a side effect, ``cfg_normalization`` is
        normalized from ``None``/bool to a float.

        Raises:
            ValueError: if any field is out of range or of the wrong type.
        """
        if self.prompt_path and not self.prompt_path.endswith(".txt"):
            raise ValueError(
                f"prompt_path must be a txt file, got {self.prompt_path!r}"
            )

        # These are always required to be sane regardless of pipeline.
        if (
            not isinstance(self.num_outputs_per_prompt, int)
            or self.num_outputs_per_prompt <= 0
        ):
            raise ValueError(
                f"num_outputs_per_prompt must be a positive int, got {self.num_outputs_per_prompt!r}"
            )

        # Used by seconds() and video writer; fps <= 0 is always invalid.
        if not isinstance(self.fps, int) or self.fps <= 0:
            raise ValueError(f"fps must be a positive int, got {self.fps!r}")

        # num_frames is already asserted in __post_init__, but keep a friendly error here too
        # (e.g., when validation is triggered from other code paths).
        if not isinstance(self.num_frames, int) or self.num_frames <= 0:
            raise ValueError(
                f"num_frames must be a positive int, got {self.num_frames!r}"
            )

        if self.num_inference_steps is not None:
            if (
                not isinstance(self.num_inference_steps, int)
                or self.num_inference_steps <= 0
            ):
                raise ValueError(
                    f"num_inference_steps must be a positive int, got {self.num_inference_steps!r}"
                )

        # Numeric hyperparams should not be NaN/Inf and should be within basic ranges.
        # Note: bool is a subclass of int; reject it explicitly to avoid silent surprises.
        def _finite_non_negative_float(
            name: str, value: Any, allow_none: bool = True
        ) -> None:
            if value is None and allow_none:
                return
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(f"{name} must be a number, got {value!r}")
            if not math.isfinite(float(value)):
                raise ValueError(f"{name} must be finite, got {value!r}")
            if float(value) < 0.0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")

        _finite_non_negative_float(
            "guidance_scale", self.guidance_scale, allow_none=True
        )
        _finite_non_negative_float(
            "guidance_scale_2", self.guidance_scale_2, allow_none=True
        )
        _finite_non_negative_float(
            "true_cfg_scale", self.true_cfg_scale, allow_none=True
        )
        _finite_non_negative_float(
            "guidance_rescale", self.guidance_rescale, allow_none=False
        )

        if self.cfg_normalization is None:
            self.cfg_normalization = 0.0
        elif isinstance(self.cfg_normalization, bool):
            self.cfg_normalization = 1.0 if self.cfg_normalization else 0.0

        if self.boundary_ratio is not None:
            if isinstance(self.boundary_ratio, bool) or not isinstance(
                self.boundary_ratio, (int, float)
            ):
                raise ValueError(
                    f"boundary_ratio must be a number, got {self.boundary_ratio!r}"
                )
            if not math.isfinite(float(self.boundary_ratio)):
                raise ValueError(
                    f"boundary_ratio must be finite, got {self.boundary_ratio!r}"
                )
            if not (0.0 <= float(self.boundary_ratio) <= 1.0):
                raise ValueError(
                    f"boundary_ratio must be within [0, 1], got {self.boundary_ratio!r}"
                )

    def check_sampling_param(self):
        """Validate the params by themselves (alias of ``_validate``)."""
        # Keep backward-compatibility for old call sites.
        self._validate()

    def _validate_with_pipeline_config(self, pipeline_config):
        """
        Check that the sampling params are compatible with the pipeline config.

        Raises:
            ValueError: if the task type requires an image but ``image_path`` is None.
        """
        if pipeline_config.task_type.requires_image_input():
            # requires image input
            if self.image_path is None:
                raise ValueError(
                    f"Served model with task type '{pipeline_config.task_type.name}' requires an 'image_path' input, but none was provided"
                )

    def _adjust(
        self,
        server_args,
    ):
        """
        Final adjustment, called after merging with user params.

        Resolves values that depend on the server configuration: ``data_type``,
        ``output_path``, default resolution, ``num_frames`` (model constraints
        and multi-GPU alignment) and, outside ComfyUI mode, the output file name.

        Raises:
            TypeError: if ``prompt`` is not a string.
            ValueError: if ``num_frames`` is not positive.
        """
        # TODO: SamplingParams should not rely on ServerArgs
        pipeline_config = server_args.pipeline_config
        if not isinstance(self.prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(self.prompt)}")

        self.data_type = pipeline_config.task_type.data_type()

        if self.output_path is None and server_args.output_path is not None:
            self.output_path = server_args.output_path
            logger.debug(
                f"Overriding output_path with server configuration: {self.output_path}"
            )

        # Process negative prompt
        if self.negative_prompt is not None and not self.negative_prompt.isspace():
            # avoid stripping default negative prompt: ' ' for qwen-image
            self.negative_prompt = self.negative_prompt.strip()

        # Validate frame count (height/width may legitimately still be None here)
        if self.num_frames <= 0:
            raise ValueError(
                f"num_frames must be a positive integer, got num_frames={self.num_frames}"
            )

        # Validate resolution against pipeline-specific supported resolutions
        if self.height is None and self.width is None:
            if self.supported_resolutions is not None:
                self.width, self.height = self.supported_resolutions[0]
                logger.info(
                    f"Resolution unspecified, using default: {self.supported_resolutions[0]}"
                )

        if self.height is not None and self.width is not None:
            if self.supported_resolutions is not None:
                if (self.width, self.height) not in self.supported_resolutions:
                    supported_str = ", ".join(
                        [f"{w}x{h}" for w, h in self.supported_resolutions]
                    )
                    error_msg = (
                        f"Unsupported resolution: {self.width}x{self.height}, output quality may suffer. "
                        f"Supported resolutions: {supported_str}"
                    )
                    logger.warning(error_msg)

        if self.enable_sequence_shard:
            self.adjust_frames = False
            logger.info(
                "Sequence dimension shard is enabled, disabling frame adjustment"
            )

        if pipeline_config.task_type.is_image_gen():
            # settle num_frames
            if not pipeline_config.allow_set_num_frames():
                logger.debug("Setting `num_frames` to 1 for image generation model")
                self.num_frames = 1

        elif self.adjust_frames:
            # NOTE: We must apply adjust_num_frames BEFORE the SP alignment logic below.
            # If we apply it after, adjust_num_frames might modify the frame count
            # and break the divisibility constraint (alignment) required by num_gpus.
            self.num_frames = pipeline_config.adjust_num_frames(self.num_frames)
            self._align_num_frames_to_gpus(pipeline_config, server_args.num_gpus)

        if not server_args.comfyui_mode:
            self._set_output_file_name()

    def _align_num_frames_to_gpus(self, pipeline_config, num_gpus: int) -> None:
        """Adjust ``num_frames`` so its latent frame count is divisible by ``num_gpus``.

        With temporal scaling, latent frames are ``(num_frames - 1) // r + 1``
        for the VAE ``temporal_compression_ratio`` ``r``; without it, each frame
        is counted as one latent frame. Rounds up by default, or down (keeping
        at least one latent frame per GPU) when ``num_frames_round_down`` is set.
        """
        use_temporal_scaling_frames = (
            pipeline_config.vae_config.use_temporal_scaling_frames
        )
        if use_temporal_scaling_frames:
            temporal_scale_factor = (
                pipeline_config.vae_config.arch_config.temporal_compression_ratio
            )
            orig_latent_num_frames = (
                self.num_frames - 1
            ) // temporal_scale_factor + 1
        else:
            # Previously this path left the variable unbound (UnboundLocalError).
            orig_latent_num_frames = self.num_frames

        if orig_latent_num_frames % num_gpus == 0:
            return

        # Adjust latent frames to be divisible by number of GPUs
        if self.num_frames_round_down:
            # Ensure we have at least 1 latent frame per GPU
            new_latent_num_frames = (
                max(1, orig_latent_num_frames // num_gpus) * num_gpus
            )
        else:
            new_latent_num_frames = (
                math.ceil(orig_latent_num_frames / num_gpus) * num_gpus
            )

        if use_temporal_scaling_frames:
            # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
            new_num_frames = (new_latent_num_frames - 1) * temporal_scale_factor + 1
        else:
            new_num_frames = new_latent_num_frames

        logger.info(
            "Adjusting number of frames from %s to %s based on number of GPUs (%s)",
            self.num_frames,
            new_num_frames,
            num_gpus,
        )
        self.num_frames = new_num_frames

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
        """Instantiate the model-specific SamplingParams class found in the registry.

        ``backend`` is popped from ``kwargs`` and forwarded to ``get_model_info``;
        remaining ``kwargs`` are passed to the sampling-params constructor.
        """
        from sglang.multimodal_gen.registry import get_model_info

        backend = kwargs.pop("backend", None)
        model_info = get_model_info(model_path, backend=backend)
        sampling_params: SamplingParams = model_info.sampling_param_cls(**kwargs)
        return sampling_params

    @staticmethod
    def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs):
        """Build the final SamplingParams for a request.

        Model defaults come from ``from_pretrained``. For a standalone
        ``.safetensors`` file where that fails, the sampling-params class is
        looked up by ``server_args.pipeline_class_name``. User ``args``/``kwargs``
        (minus ``diffusers_kwargs``) are then merged on top, adjusted against
        ``server_args`` and validated against the pipeline config.

        Raises:
            ValueError: if no pipeline config classes are found for a
                safetensors file, or if validation fails.
        """
        try:
            sampling_params = SamplingParams.from_pretrained(
                model_path, backend=server_args.backend
            )
        except (AttributeError, ValueError) as e:
            # Handle safetensors files or other cases where model_index.json is not available
            # Use appropriate SamplingParams based on pipeline_class_name from registry
            if not (os.path.isfile(model_path) and model_path.endswith(".safetensors")):
                # Re-raise if it's not a safetensors file issue
                raise

            pipeline_class_name = getattr(server_args, "pipeline_class_name", None)

            # Try to get SamplingParams from registry
            from sglang.multimodal_gen.registry import get_pipeline_config_classes

            config_classes = (
                get_pipeline_config_classes(pipeline_class_name)
                if pipeline_class_name
                else None
            )
            if config_classes is None:
                raise ValueError(
                    f"Could not get pipeline config classes for {pipeline_class_name}"
                ) from e

            _, sampling_params_cls = config_classes
            try:
                sampling_params = sampling_params_cls()
                logger.info(
                    "Using %s for %s safetensors file (no model_index.json): %s",
                    sampling_params_cls.__name__,
                    pipeline_class_name,
                    model_path,
                )
            except Exception as import_error:
                # Best effort: fall back to the generic defaults.
                logger.warning(
                    f"Failed to instantiate {sampling_params_cls.__name__}: {import_error}. "
                    "Using default SamplingParams"
                )
                sampling_params = SamplingParams()

        user_kwargs = dict(kwargs)
        user_kwargs.pop("diffusers_kwargs", None)
        user_sampling_params = SamplingParams(*args, **user_kwargs)
        # TODO: refactor
        sampling_params._merge_with_user_params(user_sampling_params)
        sampling_params._adjust(server_args)

        sampling_params._validate_with_pipeline_config(server_args.pipeline_config)

        return sampling_params

    def output_size_str(self) -> str:
        """Return the resolution as ``"<width>x<height>"``."""
        return f"{self.width}x{self.height}"

    def seconds(self) -> float:
        """Return the output duration in seconds (``num_frames / fps``)."""
        return self.num_frames / self.fps

    @staticmethod
    def _parse_data_type(value: str) -> DataType:
        """argparse ``type=`` converter: case-insensitive ``DataType`` member name."""
        try:
            return DataType[value.strip().upper()]
        except KeyError:
            choices = ", ".join(member.name.lower() for member in DataType)
            raise argparse.ArgumentTypeError(
                f"invalid data type {value!r} (choose from {choices})"
            ) from None

    @staticmethod
    def add_cli_args(parser: Any) -> Any:
        """Add CLI arguments for SamplingParam fields"""
        # Previously `type=str, nargs="+"`, which produced a list of strings
        # rather than a DataType.
        parser.add_argument(
            "--data-type",
            type=SamplingParams._parse_data_type,
            default=SamplingParams.data_type,
            help="Output data type (image or video)",
        )
        parser.add_argument(
            "--num-frames-round-down",
            action="store_true",
            default=SamplingParams.num_frames_round_down,
        )
        parser.add_argument(
            "--enable-teacache",
            action="store_true",
            default=SamplingParams.enable_teacache,
        )

        # profiling
        parser.add_argument(
            "--profile",
            action="store_true",
            default=SamplingParams.profile,
            help="Enable torch profiler for denoising stage",
        )
        parser.add_argument(
            "--num-profiled-timesteps",
            type=int,
            default=SamplingParams.num_profiled_timesteps,
            help="Number of timesteps to profile after warmup",
        )
        parser.add_argument(
            "--profile-all-stages",
            action="store_true",
            dest="profile_all_stages",
            default=SamplingParams.profile_all_stages,
            help="Used with --profile, profile all pipeline stages",
        )

        parser.add_argument(
            "--debug",
            action="store_true",
            default=SamplingParams.debug,
            help="",
        )

        parser.add_argument(
            "--prompt",
            type=str,
            default=SamplingParams.prompt,
            help="Text prompt for generation",
        )
        parser.add_argument(
            "--negative-prompt",
            type=str,
            default=SamplingParams.negative_prompt,
            help="Negative text prompt for generation",
        )
        parser.add_argument(
            "--prompt-path",
            type=str,
            default=SamplingParams.prompt_path,
            help="Path to a text file containing the prompt",
        )
        parser.add_argument(
            "--output-file-name",
            type=str,
            default=SamplingParams.output_file_name,
            help="Name of the output file",
        )
        parser.add_argument(
            "--output-quality",
            type=str,
            default=SamplingParams.output_quality,
            help="Output quality setting (default, low, medium, high, maximum)",
        )
        parser.add_argument(
            "--output-compression",
            type=int,
            default=SamplingParams.output_compression,
            help="Output compression level (0-100, higher means better quality but larger file size)",
        )
        parser.add_argument(
            "--num-outputs-per-prompt",
            type=int,
            default=SamplingParams.num_outputs_per_prompt,
            help="Number of outputs to generate per prompt",
        )
        parser.add_argument(
            "--seed",
            type=int,
            default=SamplingParams.seed,
            help="Random seed for generation",
        )
        parser.add_argument(
            "--generator-device",
            type=str,
            default=SamplingParams.generator_device,
            choices=["cuda", "musa", "cpu"],
            help="Device for random generator (cuda, musa or cpu). Default: cuda",
        )
        parser.add_argument(
            "--num-frames",
            type=int,
            default=SamplingParams.num_frames,
            help="Number of frames to generate",
        )
        parser.add_argument(
            "--height",
            type=int,
            default=SamplingParams.height,
            help="Height of generated output",
        )
        parser.add_argument(
            "--width",
            type=int,
            default=SamplingParams.width,
            help="Width of generated output",
        )
        # resolution shortcuts
        parser.add_argument(
            "--4k",
            action="store_true",
            dest="resolution_4k",
            help="Set resolution to 4K (3840x2160)",
        )
        parser.add_argument(
            "--2k",
            action="store_true",
            dest="resolution_2k",
            help="Set resolution to 2K (2560x1440)",
        )
        parser.add_argument(
            "--1080p",
            action="store_true",
            dest="resolution_1080p",
            help="Set resolution to 1080p (1920x1080)",
        )
        parser.add_argument(
            "--720p",
            action="store_true",
            dest="resolution_720p",
            help="Set resolution to 720p (1280x720)",
        )

        parser.add_argument(
            "--fps",
            type=int,
            default=SamplingParams.fps,
            help="Frames per second for saved output",
        )
        parser.add_argument(
            "--num-inference-steps",
            type=int,
            default=SamplingParams.num_inference_steps,
            help="Number of denoising steps",
        )
        parser.add_argument(
            "--guidance-scale",
            type=float,
            default=SamplingParams.guidance_scale,
            help="Classifier-free guidance scale",
        )
        parser.add_argument(
            "--guidance-scale-2",
            type=float,
            default=SamplingParams.guidance_scale_2,
            dest="guidance_scale_2",
            help="Secondary guidance scale for dual-guidance models (e.g., Wan low-noise expert)",
        )
        parser.add_argument(
            "--guidance-rescale",
            type=float,
            default=SamplingParams.guidance_rescale,
            help="Guidance rescale factor",
        )
        parser.add_argument(
            "--cfg-normalization",
            type=float,
            default=SamplingParams.cfg_normalization,  # type: ignore[arg-type]
            dest="cfg_normalization",
            help="CFG renormalization factor (for Z-Image).",
        )
        parser.add_argument(
            "--boundary-ratio",
            type=float,
            default=SamplingParams.boundary_ratio,
            help="Boundary timestep ratio",
        )
        parser.add_argument(
            "--save-output",
            action="store_true",
            default=SamplingParams.save_output,
            help="Whether to save the output to disk",
        )
        parser.add_argument(
            "--no-save-output",
            action="store_false",
            dest="save_output",
            help="Don't save the output to disk",
        )
        parser.add_argument(
            "--return-frames",
            action="store_true",
            default=SamplingParams.return_frames,
            help="Whether to return the raw frames",
        )
        parser.add_argument(
            "--image-path",
            type=str,
            nargs="+",
            default=SamplingParams.image_path,
            help=(
                "Path(s) to input image(s) for image-to-image / image-to-video "
                "generation. For multiple images, pass them as space-separated "
                "values, e.g.: "
                '--image-path "img1.png" "img2.png"'
            ),
        )
        parser.add_argument(
            "--moba-config-path",
            type=str,
            default=None,
            help="Path to a JSON file containing V-MoBA specific configurations.",
        )
        parser.add_argument(
            "--return-trajectory-latents",
            action="store_true",
            default=SamplingParams.return_trajectory_latents,
            help="Whether to return the trajectory",
        )
        parser.add_argument(
            "--return-trajectory-decoded",
            action="store_true",
            default=SamplingParams.return_trajectory_decoded,
            help="Whether to return the decoded trajectory",
        )
        parser.add_argument(
            "--diffusers-kwargs",
            type=str,
            default=None,
            help="JSON string of extra kwargs to pass to diffusers pipeline. "
            'Example: \'{"output_type": "latent", "clip_skip": 2}\'',
        )
        parser.add_argument(
            "--no-override-protected-fields",
            action="store_true",
            default=SamplingParams.no_override_protected_fields,
            help=(
                "If set, disallow user params to override fields defined in subclasses."
            ),
        )
        parser.add_argument(
            "--adjust-frames",
            action=StoreBoolean,
            default=SamplingParams.adjust_frames,
            help=(
                "Enable/disable adjusting num_frames to evenly split latent frames across GPUs "
                "and satisfy model temporal constraints. If disabled, tokens might be padded for SP. "
                "Default: true. Examples: --adjust-frames, --adjust-frames true, --adjust-frames false."
            ),
        )
        parser.add_argument(
            "--return-file-paths-only",
            action=StoreBoolean,
            default=SamplingParams.return_file_paths_only,
            help="If set, output file will be saved early to get a performance boost, while output tensors will not be returned.",
        )
        parser.add_argument(
            "--enable-sequence-shard",
            action=StoreBoolean,
            default=SamplingParams.enable_sequence_shard,
            help="Enable sequence dimension shard with sequence parallelism.",
        )
        return parser

    @classmethod
    def get_cli_args(cls, args: argparse.Namespace):
        """Extract constructor kwargs for ``cls`` from parsed CLI ``args``.

        Resolution shortcut flags (first match among 4k/2k/1080p/720p) overwrite
        ``args.width``/``args.height`` in place before extraction.
        """
        resolution_shortcuts = (
            ("resolution_4k", 3840, 2160),
            ("resolution_2k", 2560, 1440),
            ("resolution_1080p", 1920, 1080),
            ("resolution_720p", 1280, 720),
        )
        for flag, width, height in resolution_shortcuts:
            if getattr(args, flag, False):
                args.width, args.height = width, height
                break

        sampling_params_fields = {attr.name for attr in dataclasses.fields(cls)}
        attrs = sampling_params_fields & set(vars(args).keys())
        args.height_not_provided = False
        args.width_not_provided = False
        return {attr: getattr(args, attr) for attr in attrs}

    def output_file_path(self):
        """Return ``output_path`` joined with ``output_file_name`` (both must be set)."""
        return os.path.join(self.output_path, self.output_file_name)

    def _merge_with_user_params(self, user_params: "SamplingParams"):
        """
        Merge parameters from a user-provided SamplingParams object.

        A user field is applied when it differs from the ``SamplingParams`` class
        default. If ``user_params.no_override_protected_fields`` is set, fields
        annotated directly on ``type(self)`` are left untouched. ``__post_init__``
        is re-run afterwards to refresh derived values and validate.
        """
        if user_params is None:
            return

        # On Python 3.10+, a class's __annotations__ holds only the annotations
        # declared in that class body (not inherited ones).
        predefined_fields = set(type(self).__annotations__.keys())

        # global switch: if True, allow overriding protected fields
        allow_override_protected = not user_params.no_override_protected_fields
        for user_field in dataclasses.fields(user_params):
            field_name = user_field.name
            user_value = getattr(user_params, field_name)
            default_class_value = getattr(SamplingParams, field_name)

            # A field is considered user-modified if its value is different from the default
            is_user_modified = user_value != default_class_value
            is_protected_field = field_name in predefined_fields
            if is_user_modified and (
                allow_override_protected or not is_protected_field
            ):
                setattr(self, field_name, user_value)
        self.height_not_provided = user_params.height_not_provided
        self.width_not_provided = user_params.width_not_provided
        self.__post_init__()

    @property
    def n_tokens(self) -> int:
        """Approximate latent token count, or -1 when height/width is unset.

        Uses fixed factors: 4x along frames and 8x along height and width.
        """
        if self.height and self.width:
            latents_size = [
                (self.num_frames - 1) // 4 + 1,
                self.height // 8,
                self.width // 8,
            ]
            n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
        else:
            n_tokens = -1
        return n_tokens


@dataclass
class CacheParams:
    """Cache configuration for a generation request.

    ``cache_type`` is stored as given (no validation here); it defaults to
    ``"none"``.
    """

    cache_type: str = dataclasses.field(default="none")
