"""
Validation pipeline config: model IDs, language maps, scoring thresholds, env vars.
All tunables live here so Docker ENV overrides work cleanly.
"""
from __future__ import annotations

import os
import uuid
from dataclasses import dataclass, field
from pathlib import Path

from dotenv import load_dotenv

# Load a .env file from two directories above this one (presumably the project
# root — confirm against repo layout) so the os.getenv lookups below and the
# ValidationConfig overlay pick up local development overrides.
load_dotenv(Path(__file__).resolve().parent.parent / ".env")

# === MODEL IDs ===
# HuggingFace Hub identifiers for the models this pipeline loads.
MMS_LID_MODEL = "facebook/mms-lid-256"  # Meta MMS spoken-language ID (256-language head)
VOXLINGUA_MODEL = "speechbrain/lang-id-voxlingua107-ecapa"  # SpeechBrain ECAPA LID (VoxLingua107)
CONFORMER_MULTI_MODEL = "ai4bharat/indic-conformer-600m-multilingual"  # multilingual Indic Conformer

# IndicWav2Vec per-language models (HuggingFace AutoModelForCTC)
# NOTE(review): only 7 of the 12 LANGUAGE_MAP targets appear here (no kn/ml/
# pa/as/en); presumably those languages rely on the other models — confirm
# against the worker that consumes this dict.
INDICWAV2VEC_MODELS: dict[str, str] = {
    "hi": "ai4bharat/indicwav2vec-hindi",
    "bn": "ai4bharat/indicwav2vec_v1_bengali",
    "gu": "ai4bharat/indicwav2vec_v1_gujarati",
    "ta": "ai4bharat/indicwav2vec_v1_tamil",
    "te": "ai4bharat/indicwav2vec_v1_telugu",
    "mr": "ai4bharat/indicwav2vec_v1_marathi",
    "or": "ai4bharat/indicwav2vec_v1_odia",
}

# IndicConformer per-language pattern (NeMo — added later if NeMo is installed)
# "{lang}" is substituted with a language code at load time; the exact code
# scheme (ISO-1 vs ISO-3) is not visible here — verify at the call site.
CONFORMER_LANG_PATTERN = "ai4bharat/indicconformer_stt_{lang}_hybrid_ctc_rnnt_large"

# === LANGUAGE MAPPINGS ===
# Our 12 target languages
# Keyed by ISO 639-1 code; values are (display name, writing script) pairs.
LANGUAGE_MAP: dict[str, tuple[str, str]] = {
    "hi": ("Hindi", "Devanagari"),
    "mr": ("Marathi", "Devanagari"),
    "te": ("Telugu", "Telugu"),
    "ta": ("Tamil", "Tamil"),
    "kn": ("Kannada", "Kannada"),
    "ml": ("Malayalam", "Malayalam"),
    "gu": ("Gujarati", "Gujarati"),
    "pa": ("Punjabi", "Gurmukhi"),
    "bn": ("Bengali", "Bengali"),
    "as": ("Assamese", "Assamese"),
    "or": ("Odia", "Odia"),
    "en": ("English", "Latin"),
}
# Convenience set of the ISO 639-1 codes above, for membership tests.
TARGET_LANGUAGES = set(LANGUAGE_MAP.keys())

# ISO 639-1 ↔ ISO 639-3 (MMS uses ISO-3)
LANG_TO_ISO3: dict[str, str] = {
    "hi": "hin", "mr": "mar", "te": "tel", "ta": "tam",
    "kn": "kan", "ml": "mal", "gu": "guj", "pa": "pan",
    "bn": "ben", "as": "asm", "or": "ory", "en": "eng",
}
# Reverse lookup: ISO 639-3 → ISO 639-1 (safe: the forward map is injective).
ISO3_TO_LANG: dict[str, str] = {v: k for k, v in LANG_TO_ISO3.items()}

# VoxLingua107 label → ISO-1 mapping
# VoxLingua returns labels like "hi: Hindi" or just the 2-letter code
# NOTE(review): the last four entries (ur/sd/ne/sa) are NOT in LANGUAGE_MAP —
# presumably included so close-neighbour predictions are normalized rather
# than dropped; confirm how consumers treat non-target codes.
VOXLINGUA_LABEL_MAP: dict[str, str] = {
    "hi: Hindi": "hi", "mr: Marathi": "mr", "te: Telugu": "te",
    "ta: Tamil": "ta", "kn: Kannada": "kn", "ml: Malayalam": "ml",
    "gu: Gujarati": "gu", "pa: Punjabi": "pa", "bn: Bengali": "bn",
    "as: Assamese": "as", "or: Odia": "or", "en: English": "en",
    "ur: Urdu": "ur", "sd: Sindhi": "sd", "ne: Nepali": "ne",
    "sa: Sanskrit": "sa",
}

# IndicConformer multilingual accepted language codes
# Superset of the targets: includes additional scheduled Indian languages
# (Bodo, Dogri, Kashmiri, Konkani, Maithili, Manipuri, Nepali, Sanskrit,
# Santali, Sindhi, Urdu) that the multilingual model accepts.
CONFORMER_LANG_CODES: set[str] = {
    "hi", "mr", "te", "ta", "kn", "ml", "gu", "pa", "bn", "as", "or",
    "brx", "doi", "ks", "kok", "mai", "mni", "ne", "sa", "sat", "sd", "ur",
}

# === AUDIO SETTINGS ===
AUDIO_SAMPLE_RATE = 16000  # Hz — presumably what all models expect; resampling assumed upstream
MAX_AUDIO_DURATION_S = 30.0  # seconds; clips above this are presumably truncated/skipped — confirm
MIN_AUDIO_DURATION_S = 0.5   # seconds; clips below this are presumably rejected — confirm

# === BATCH SIZES (default tuned for RTX 3090 24GB with all 4 models loaded ~2.7GB) ===
# Each is overridable via the named env var; parsed eagerly at import time,
# so a non-integer value fails fast with ValueError.
MMS_BATCH_SIZE = int(os.getenv("MMS_BATCH_SIZE", "8"))
# NOTE(review): env key is "VOX_BATCH_SIZE", not "VOXLINGUA_BATCH_SIZE" —
# keep deployment docs in sync with this shorter name.
VOXLINGUA_BATCH_SIZE = int(os.getenv("VOX_BATCH_SIZE", "16"))
CONFORMER_BATCH_SIZE = int(os.getenv("CONFORMER_BATCH_SIZE", "8"))
WAV2VEC_BATCH_SIZE = int(os.getenv("WAV2VEC_BATCH_SIZE", "8"))

# === WORKER SETTINGS ===
PREFETCH_QUEUE_SIZE = int(os.getenv("PREFETCH_QUEUE_SIZE", "3"))  # queued items ahead of the GPU
PARQUET_SHARD_SIZE = int(os.getenv("PARQUET_SHARD_SIZE", "50"))   # results per output shard
HEARTBEAT_INTERVAL_S = 60  # seconds between worker heartbeats; not env-overridable
# 0 means "no limit" (see ValidationConfig.max_videos, which reads the same env var).
MAX_VIDEOS = int(os.getenv("MAX_VIDEOS", "0"))

# Max per-language models cached in VRAM simultaneously
LRU_MODEL_CACHE_SIZE = int(os.getenv("LRU_MODEL_CACHE_SIZE", "5"))


def _env(key: str, default: str = "") -> str:
    """Read environment variable *key*, returning *default* when unset."""
    return os.getenv(key, default)


@dataclass
class ValidationConfig:
    """Resolved config for validation pipeline.

    Field defaults are fallbacks only: ``__post_init__`` overlays the
    corresponding environment variables (loaded from ``.env`` at import
    time) so Docker ENV settings win over the declared defaults.
    Credential-style string fields and boolean toggles preserve a value
    passed explicitly to the constructor.
    """
    # R2
    r2_endpoint_url: str = ""
    r2_bucket_source: str = "transcribed"
    r2_model_bucket: str = "validation-results"
    r2_bucket_output: str = "validation-results"
    r2_reference_bucket: str = "validation-results"
    r2_access_key_id: str = ""
    r2_secret_access_key: str = ""
    r2_skip_upload: bool = False

    # Supabase (only for validation_status + validation_results table)
    database_url: str = ""
    supabase_url: str = ""
    supabase_admin_key: str = ""
    recover_queue_table: str = "validation_recover_queue"
    recover_reference_mode: str = "database"
    recover_tx_parquet_key: str = "reference-data/transcription_results_recover.parquet"
    recover_flags_parquet_key: str = "reference-data/transcription_flags_recover.parquet"
    recover_validated_parquet_key: str = "reference-data/validated_segment_ids.parquet"
    recover_reference_manifest_key: str = "reference-data/recover_reference_manifest.json"
    recover_reference_download_concurrency: int = 16
    recover_replay_ledger_prefix: str = "recover-replay-ledgers"
    r2_shard_prefix: str = "shards"

    # HuggingFace (gated model access)
    hf_token: str = ""

    # Worker identity
    worker_id: str = ""
    gpu_type: str = "unknown"
    mock_mode: bool = False
    max_videos: int = 0

    # Model toggles
    enable_mms_lid: bool = True
    enable_voxlingua: bool = True
    enable_conformer_multi: bool = True
    enable_wav2vec_lang: bool = True

    def __post_init__(self):
        """Overlay environment variables onto the declared field values.

        Raises:
            ValueError: if an integer-valued env var holds a non-integer.
        """
        # Secrets/URLs: keep an explicitly passed value, else read the env.
        self.r2_endpoint_url = self.r2_endpoint_url or _env("R2_ENDPOINT_URL")
        # Bucket names: env always wins; the field value is only the default.
        self.r2_bucket_source = _env("R2_VALIDATION_SOURCE", self.r2_bucket_source)
        self.r2_model_bucket = _env("R2_VALIDATION_MODEL_BUCKET", self.r2_model_bucket)
        self.r2_bucket_output = _env("R2_VALIDATION_OUTPUT", self.r2_bucket_output)
        self.r2_reference_bucket = _env("R2_VALIDATION_REFERENCE_BUCKET", self.r2_reference_bucket)
        self.r2_access_key_id = self.r2_access_key_id or _env("R2_ACCESS_KEY_ID")
        self.r2_secret_access_key = self.r2_secret_access_key or _env("R2_SECRET_ACCESS_KEY")
        # FIX: mirror the mock_mode handling below — only consult the env when
        # the caller did not explicitly enable the flag. Previously an explicit
        # r2_skip_upload=True was clobbered back to False whenever the env var
        # was unset or not "true".
        if not self.r2_skip_upload:
            self.r2_skip_upload = _env("R2_SKIP_UPLOAD", "false").lower() == "true"
        self.database_url = self.database_url or _env("DATABASE_URL")
        # NOTE(review): env key is literally "URL" — unusually generic and easy
        # to collide with; confirm this matches the deployment's variable name.
        self.supabase_url = self.supabase_url or _env("URL")
        self.supabase_admin_key = self.supabase_admin_key or _env("SUPABASE_ADMIN")
        self.recover_queue_table = _env("VALIDATION_RECOVER_QUEUE_TABLE", self.recover_queue_table)
        # Normalized to lowercase so validate() can compare against a fixed set.
        self.recover_reference_mode = _env("RECOVER_REFERENCE_MODE", self.recover_reference_mode).lower()
        self.recover_tx_parquet_key = _env("RECOVER_TX_PARQUET_KEY", self.recover_tx_parquet_key)
        self.recover_flags_parquet_key = _env("RECOVER_FLAGS_PARQUET_KEY", self.recover_flags_parquet_key)
        self.recover_validated_parquet_key = _env("RECOVER_VALIDATED_PARQUET_KEY", self.recover_validated_parquet_key)
        self.recover_reference_manifest_key = _env(
            "RECOVER_REFERENCE_MANIFEST_KEY",
            self.recover_reference_manifest_key,
        )
        self.recover_reference_download_concurrency = int(
            _env(
                "RECOVER_REFERENCE_DOWNLOAD_CONCURRENCY",
                str(self.recover_reference_download_concurrency),
            )
        )
        self.recover_replay_ledger_prefix = _env(
            "RECOVER_REPLAY_LEDGER_PREFIX",
            self.recover_replay_ledger_prefix,
        )
        # Stripped so prefix + "/" + key concatenation can't double a slash.
        self.r2_shard_prefix = _env("R2_SHARD_PREFIX", self.r2_shard_prefix).strip("/")
        self.hf_token = self.hf_token or _env("HF_TOKEN")
        # Random 12-char id when neither the caller nor the env supplies one.
        self.worker_id = self.worker_id or _env("WORKER_ID", str(uuid.uuid4())[:12])
        self.gpu_type = _env("GPU_TYPE", self.gpu_type)
        # An explicitly passed mock_mode=True is preserved; env only enables.
        if not self.mock_mode:
            self.mock_mode = _env("MOCK_MODE", "false").lower() == "true"
        self.max_videos = int(_env("MAX_VIDEOS", str(self.max_videos)))

    def validate(self) -> list[str]:
        """Return a list of human-readable config errors (empty if valid).

        Checks are skipped entirely in mock mode, which needs no credentials.
        """
        errors = []
        if not self.mock_mode:
            if not self.r2_endpoint_url:
                errors.append("R2_ENDPOINT_URL required")
            if self.recover_reference_mode not in {"database", "parquet"}:
                errors.append("RECOVER_REFERENCE_MODE must be 'database' or 'parquet'")
        return errors