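"""Centralized configuration for the encoding pipeline.

All tunables live in one place. Credentials and deployment-specific values are
read from environment variables (a local ``.env`` file is loaded via
python-dotenv). CLI overrides are applied via argparse flags.
"""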

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from dotenv import load_dotenv

# Populate os.environ from a local .env file (python-dotenv) at import time,
# so the os.getenv() lookups in the config classes below can see those values.
load_dotenv()


@dataclass
class R2Config:
    """Object-storage (R2) credentials, bucket names and download tuning.

    Credentials are read from the environment each time an instance is
    created, not once at import time. This way, environment variables that
    are set or loaded after this module is imported take effect, for example
    those loaded by the ``load_dotenv()`` call in ``PipelineConfig.from_env``.
    Unset variables default to an empty string.
    """

    # default_factory defers the lookup to construction time; a plain
    # `= os.getenv(...)` default would be frozen when the class is defined.
    endpoint_url: str = field(
        default_factory=lambda: os.getenv("R2_ENDPOINT_URL", "")
    )
    access_key_id: str = field(
        default_factory=lambda: os.getenv("R2_ACCESS_KEY_ID", "")
    )
    secret_access_key: str = field(
        default_factory=lambda: os.getenv("R2_SECRET_ACCESS_KEY", "")
    )

    # Source group ("english", "indic") -> bucket name; a fresh dict per instance.
    source_buckets: dict[str, str] = field(default_factory=lambda: {
        "english": "pt-english",
        "indic": "pt-indic",
    })
    output_bucket: str = "pretrain-encoded"
    xcodec_bucket: str = "xcodec"
    # presumably the checkpoint's object key within `xcodec_bucket` — confirm with callers
    xcodec_ckpt_key: str = "nikil_new/indic_step_00198000.ckpt"
    metafiles_bucket: str = "metafiles"

    max_download_workers: int = 4
    download_chunk_size: int = 8 * 1024 * 1024  # 8MB multipart chunks


@dataclass
class SupabaseConfig:
    """Supabase connection settings, table names and worker-liveness timing.

    Keys are read from the environment each time an instance is created,
    not once at import time. This way, variables set or loaded after import
    (e.g. via ``load_dotenv()``) are honoured. Unset variables default to an
    empty string.
    """

    # NOTE(review): reads the generic `URL` variable, not a Supabase-specific
    # name; it may collide with other tools' env vars — confirm intended.
    url: str = field(default_factory=lambda: os.getenv("URL", ""))
    admin_key: str = field(default_factory=lambda: os.getenv("SUPABASE_ADMIN", ""))
    anon_key: str = field(default_factory=lambda: os.getenv("SUPABASE_ANON", ""))

    videos_table: str = "encoding_videos"
    workers_table: str = "worker_encoders"
    shards_table: str = "encoding_shards"

    heartbeat_interval_s: int = 30
    claim_timeout_s: int = 600  # 10 min — if no heartbeat, release video


@dataclass
class VADConfig:
    """Voice-activity-detection and segmentation parameters.

    Units are encoded in each field's suffix: ``_s`` for seconds and ``_ms``
    for milliseconds.
    """

    # Segment / speech duration limits (seconds).
    min_segment_s: float = field(default=3.0)
    max_segment_s: float = field(default=30.0)
    min_speech_duration_s: float = field(default=0.5)

    # Silero-VAD detector settings.
    threshold: float = field(default=0.5)
    min_silence_duration_ms: int = field(default=300)
    speech_pad_ms: int = field(default=100)
    sample_rate: int = field(default=16000)

    # Long audio is split into overlapping chunks and run through VAD in parallel.
    chunk_threshold_s: float = field(default=5 * 60.0)  # chunk only audio longer than this
    chunk_size_s: float = field(default=5 * 60.0)       # length of each chunk
    chunk_overlap_s: float = field(default=2.0)         # shared margin so boundary speech is not cut


@dataclass
class CodecConfig:
    xcodec2_model_id: str = "HKUSTAudio/xcodec2"
    xcodec2_custom_ckpt: str | None = None  # path to custom .ckpt if set
    bicodec_model_dir: str = "repos/Spark-TTS-0.5B"
    spark_tts_repo: str = "repos/Spark-TTS"
    target_sr: int = 16_000
    chunk_seconds: float = 6.0
    xcodec_batch_size: int = 2
    bicodec_batch_size: int = 1


@dataclass
class WorkerConfig:
    """Per-worker identity, concurrency and retry settings.

    ``offer_id`` is read from ``VAST_OFFER_ID`` each time an instance is
    created, not once at import time, and falls back to ``"local"``.
    ``worker_id`` and ``gpu_name`` start empty and are meant to be filled in
    at runtime.
    """

    worker_id: str = ""  # set at runtime: <offer_id>_<gpu_name>_<pid>
    # default_factory defers the env lookup until construction time.
    offer_id: str = field(default_factory=lambda: os.getenv("VAST_OFFER_ID", "local"))
    gpu_name: str = ""  # filled at runtime
    prefetch_videos: int = 4  # how many videos to keep pre-downloaded
    extract_workers: int = 4  # parallel ffmpeg+VAD workers (suggested: CPU cores // 2)
    ffmpeg_threads: int = 2  # threads per ffmpeg process
    ready_queue_depth: int = 8  # max PreparedVideos buffered for GPU (keep >> 1 to avoid GPU starvation)
    shard_pack_count: int = 50  # pack N videos' tokens per shard upload
    local_tmp_dir: str = "/tmp/pipeline"
    parallel_encode: bool = True  # dual-codec parallel threads+streams
    oom_segment_threshold: int = 500  # only preemptive OOM-safe for extreme videos; normal OOM retry handles the rest
    max_retries: int = 3
    use_async_pipeline: bool = True  # use 3-stage async pipeline vs serial


@dataclass
class PipelineConfig:
    """Root configuration object bundling every pipeline section.

    Each section defaults to a freshly constructed instance of its own
    config class, so two ``PipelineConfig`` objects never share sub-configs.
    """

    r2: R2Config = field(default_factory=R2Config)
    supabase: SupabaseConfig = field(default_factory=SupabaseConfig)
    vad: VADConfig = field(default_factory=VADConfig)
    codec: CodecConfig = field(default_factory=CodecConfig)
    worker: WorkerConfig = field(default_factory=WorkerConfig)

    @classmethod
    def from_env(cls) -> "PipelineConfig":
        """Load a ``.env`` file into the process environment, then build a
        config whose sections all use their default construction."""
        load_dotenv()
        config = cls()
        return config
