"""Configuration dataclasses and YAML/JSON helpers for model, training, and sampling."""

import json
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Any, TypeVar


@dataclass
class ModelConfig:
    """Architecture hyperparameters for the backbone and its text/speaker encoders."""

    latent_dim: int = 128
    latent_patch_size: int = 1
    model_dim: int = 2048
    num_layers: int = 24
    num_heads: int = 16
    mlp_ratio: float = 2.875
    text_mlp_ratio: float | None = 2.6
    speaker_mlp_ratio: float | None = 2.6
    dropout: float = 0.0
    # Text encoder / tokenizer
    text_vocab_size: int = 102400
    text_tokenizer_repo: str = "sbintuitions/sarashina2.2-0.5b"
    text_add_bos: bool = True
    text_dim: int = 1280
    text_layers: int = 14
    text_heads: int = 10
    # Speaker encoder
    speaker_dim: int = 1280
    speaker_layers: int = 14
    speaker_heads: int = 10
    speaker_patch_size: int = 1
    timestep_embed_dim: int = 512
    adaln_rank: int = 256
    norm_eps: float = 1e-5

    @property
    def patched_latent_dim(self) -> int:
        return self.latent_dim * self.latent_patch_size

    @property
    def speaker_patched_latent_dim(self) -> int:
        return self.patched_latent_dim * self.speaker_patch_size

    @property
    def text_mlp_ratio_resolved(self) -> float:
        if self.text_mlp_ratio is None:
            return self.mlp_ratio
        return float(self.text_mlp_ratio)

    @property
    def speaker_mlp_ratio_resolved(self) -> float:
        if self.speaker_mlp_ratio is None:
            return self.mlp_ratio
        return float(self.speaker_mlp_ratio)

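# Usage sketch (illustrative): the `*_mlp_ratio_resolved` properties fall back
# to the shared `mlp_ratio` when the branch-specific ratio is None:
#
#     ModelConfig().text_mlp_ratio_resolved                     # 2.6 (explicit)
#     ModelConfig(text_mlp_ratio=None).text_mlp_ratio_resolved  # 2.875 (fallback)
#
# `patched_latent_dim` is the per-token feature size after folding
# `latent_patch_size` latent frames into one token (128 * 1 = 128 by default).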

@dataclass
class TrainConfig:
    """Training-loop hyperparameters: data loading, optimizer, LR schedule, and logging."""

    manifest_path: str = ""
    output_dir: str = "outputs"
    batch_size: int = 8
    num_workers: int = 2
    dataloader_persistent_workers: bool = False
    dataloader_prefetch_factor: int = 2
    allow_tf32: bool = False
    compile_model: bool = False
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    optimizer: str = "muon"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_eps: float = 1e-8
    muon_momentum: float = 0.95
    muon_adjust_lr_fn: str = "match_rms_adamw"
    lr_scheduler: str = "none"
    warmup_steps: int = 0
    stable_steps: int = 0
    min_lr_scale: float = 0.1
    max_steps: int = 200000
    log_every: int = 100
    save_every: int = 1000
    checkpoint_best_n: int = 0
    valid_ratio: float = 0.0
    valid_every: int = 0
    progress: bool = True
    progress_all_ranks: bool = False
    precision: str = "bf16"
    grad_clip_norm: float = 1.0
    gradient_accumulation_steps: int = 1
    max_text_len: int = 256
    text_condition_dropout: float = 0.1
    speaker_condition_dropout: float = 0.1
    max_latent_steps: int = 750
    fixed_target_latent_steps: int | None = 750
    fixed_target_full_mask: bool = True
    timestep_logit_mean: float = 0.0
    timestep_logit_std: float = 1.0
    timestep_stratified: bool = True
    timestep_min: float = 0.001
    timestep_max: float = 0.999
    wandb_enabled: bool = False
    wandb_project: str = "Irodori-TTS"
    wandb_entity: str | None = None
    wandb_run_name: str | None = None
    wandb_mode: str = "online"
    ddp_find_unused_parameters: bool = False
    seed: int = 0

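# Hedged sketch (assumed semantics, not necessarily the trainer's code): the
# `warmup_steps` / `stable_steps` / `min_lr_scale` fields suggest a
# warmup-stable-decay LR shape when `lr_scheduler` is not "none":
#
#     def lr_scale(step: int, cfg: TrainConfig) -> float:
#         if step < cfg.warmup_steps:  # linear warmup from 0 to 1
#             return step / max(cfg.warmup_steps, 1)
#         if step < cfg.warmup_steps + cfg.stable_steps:  # hold at peak LR
#             return 1.0
#         done = step - cfg.warmup_steps - cfg.stable_steps
#         total = max(cfg.max_steps - cfg.warmup_steps - cfg.stable_steps, 1)
#         return 1.0 - (1.0 - cfg.min_lr_scale) * min(done / total, 1.0)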

@dataclass
class SamplingConfig:
    """Inference-time sampling hyperparameters: step count, classifier-free guidance, and KV-cache options."""

    num_steps: int = 40
    cfg_scale_text: float = 3.0
    cfg_scale_speaker: float = 5.0
    cfg_guidance_mode: str = "independent"
    cfg_scale: float | None = None
    cfg_min_t: float = 0.5
    cfg_max_t: float = 1.0
    truncation_factor: float | None = None
    rescale_k: float | None = None
    rescale_sigma: float | None = None
    context_kv_cache: bool = True
    speaker_kv_scale: float | None = None
    speaker_kv_min_t: float | None = 0.9
    speaker_kv_max_layers: int | None = None
    # Deprecated: inference length is derived from --seconds and codec hop_length.
    sequence_length: int | None = None
    seed: int = 0

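# Hedged sketch of dual classifier-free guidance: with
# `cfg_guidance_mode="independent"`, one common formulation combines three
# forward passes as (assumed form; the actual sampler may differ):
#
#     v = v_null \
#         + cfg_scale_text * (v_text - v_null) \
#         + cfg_scale_speaker * (v_speaker - v_null)
#
# `cfg_min_t` / `cfg_max_t` presumably restrict guidance to that timestep
# range, with the plain conditional prediction used outside it.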

def save_json(path: str | Path, payload: dict) -> None:
    """Write `payload` as indented JSON, creating parent directories as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")


def dump_configs(path: str | Path, model_cfg: ModelConfig, train_cfg: TrainConfig) -> None:
    """Serialize the model and train configs together into a single JSON file."""
    save_json(path, {"model": asdict(model_cfg), "train": asdict(train_cfg)})

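# Usage sketch: snapshot the exact configs used for a run (hypothetical path):
#
#     dump_configs("outputs/run1/config.json", ModelConfig(), TrainConfig())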

T = TypeVar("T")


def load_experiment_yaml(path: str | Path) -> dict[str, Any]:
    """
    Load experiment config YAML. Returns {} for an empty document.
    """
    try:
        import yaml
    except ImportError as exc:
        raise RuntimeError(
            "PyYAML is required for --config support. Install with `pip install pyyaml`."
        ) from exc

    payload = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
    if payload is None:
        return {}
    if not isinstance(payload, dict):
        raise ValueError(f"Config root must be a mapping: {path}")
    return payload


def merge_dataclass_overrides(base: T, overrides: dict[str, Any] | None, section: str) -> T:
    """
    Merge mapping overrides into a dataclass instance, raising ValueError on
    unknown keys so config typos fail loudly instead of being silently ignored.
    """
    if overrides is None:
        return base
    if not isinstance(overrides, dict):
        raise ValueError(f"Config section '{section}' must be a mapping.")

    allowed = {f.name for f in fields(base)}
    unknown = sorted(set(overrides) - allowed)
    if unknown:
        raise ValueError(f"Unknown keys in '{section}' config: {unknown}")

    # asdict() recursively converts nested dataclasses to plain dicts, so this
    # constructor round-trip assumes flat configs, which holds for this module.
    merged = asdict(base)
    merged.update(overrides)
    return type(base)(**merged)

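if __name__ == "__main__":
    # Smoke-test sketch: load optional YAML overrides (hypothetical invocation;
    # the "model"/"train" section names mirror dump_configs above) and dump
    # the merged result for inspection.
    import sys

    raw = load_experiment_yaml(sys.argv[1]) if len(sys.argv) > 1 else {}
    model_cfg = merge_dataclass_overrides(ModelConfig(), raw.get("model"), "model")
    train_cfg = merge_dataclass_overrides(TrainConfig(), raw.get("train"), "train")
    dump_configs("outputs/config_dump.json", model_cfg, train_cfg)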