# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Tuple

from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_zimage_layer(n: str, m) -> bool:
    """Tell whether a named module is a numbered Z-Image block to shard.

    A name matches when its last dot-separated component consists only of
    digits and the full name contains one of ``"layers"``,
    ``"noise_refiner"`` or ``"context_refiner"``.

    Args:
        n: Dotted module name.
        m: The module object; accepted but not inspected here.

    Returns:
        True if the name matches the rule above, False otherwise.
    """
    # Only the trailing path component decides whether this is an indexed block.
    leaf = n.rpartition(".")[2]
    if not leaf.isdigit():
        return False
    return any(
        container in n for container in ("layers", "noise_refiner", "context_refiner")
    )


@dataclass
class ZImageArchConfig(DiTArchConfig):
    """Architecture settings and weight-name mappings for the Z-Image DiT.

    Extends ``DiTArchConfig`` with Z-Image specific defaults. After
    construction, ``__post_init__`` fills ``out_channels`` from
    ``in_channels`` when it is unset and mirrors ``in_channels`` / ``dim``
    into ``num_channels_latents`` / ``hidden_size``.
    """

    # NOTE(review): the per-field notes below are inferred from the field
    # names only; confirm them against the Z-Image model implementation.
    # Presumably the supported spatial / frame patch sizes.
    all_patch_size: Tuple[int, ...] = (2,)
    all_f_patch_size: Tuple[int, ...] = (1,)
    # Also copied into num_channels_latents, and used as the out_channels
    # fallback, in __post_init__.
    in_channels: int = 16
    # None (or any falsy value) means "same as in_channels"; see __post_init__.
    out_channels: int | None = None
    # Copied into hidden_size in __post_init__.
    dim: int = 3840
    num_layers: int = 30
    n_refiner_layers: int = 2
    num_attention_heads: int = 30
    n_kv_heads: int = 30
    norm_eps: float = 1e-5
    qk_norm: bool = True
    cap_feat_dim: int = 2560
    rope_theta: float = 256.0
    t_scale: float = 1000.0
    # Three-element tuples; presumably one entry per rotary-embedding axis
    # -- TODO confirm.
    axes_dims: Tuple[int, int, int] = (32, 48, 48)
    axes_lens: Tuple[int, int, int] = (1024, 512, 512)

    # Predicates taking (name, module); is_zimage_layer matches numbered
    # entries under "layers", "noise_refiner" and "context_refiner".
    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_zimage_layer])

    # Relates the separate feed_forward.w1 / feed_forward.w3 names to the fused
    # feed_forward.w13 name, tagged "gate" and "up" respectively.
    stacked_params_mapping: list[tuple[str, str, str]] = field(
        default_factory=lambda: [
            # (param_name, shard_name, shard_id)
            (".feed_forward.w13", ".feed_forward.w1", "gate"),
            (".feed_forward.w13", ".feed_forward.w3", "up"),
        ]
    )

    # Regex pattern -> (replacement template, int, int). The patterns rewrite
    # feed_forward.w1 / feed_forward.w3 ".weight", ".lora_A" and ".lora_B" names
    # to the matching feed_forward.w13 name.
    # NOTE(review): the two integers look like (shard index, shard count) for
    # the fused parameter -- confirm against the weight loader.
    param_names_mapping: dict = field(
        default_factory=lambda: {
            r"(.*)\.feed_forward\.w1\.weight$": (r"\1.feed_forward.w13.weight", 0, 2),
            r"(.*)\.feed_forward\.w3\.weight$": (r"\1.feed_forward.w13.weight", 1, 2),
            r"(.*)\.feed_forward\.w1\.(lora_A|lora_B)$": (
                r"\1.feed_forward.w13.\2",
                0,
                2,
            ),
            r"(.*)\.feed_forward\.w3\.(lora_A|lora_B)$": (
                r"\1.feed_forward.w13.\2",
                1,
                2,
            ),
        }
    )

    def __post_init__(self) -> None:
        """Run the base-class hook, then derive the dependent attributes."""
        super().__post_init__()
        # `or` falls back on any falsy value, so 0 is treated like None here.
        self.out_channels = self.out_channels or self.in_channels
        self.num_channels_latents = self.in_channels
        self.hidden_size = self.dim


@dataclass
class ZImageDitConfig(DiTConfig):
    """``DiTConfig`` specialisation that uses ``ZImageArchConfig`` by default."""

    # default_factory builds a separate ZImageArchConfig for each config instance.
    arch_config: ZImageArchConfig = field(default_factory=ZImageArchConfig)

    # NOTE(review): how this prefix is consumed is defined by DiTConfig or its
    # callers, which are not visible here.
    prefix: str = "zimage"
