import yaml
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field


@dataclass
class DatasetConfig:
    """Configuration for a single dataset entry from the `hf_datasets` YAML list."""
    name: str                          # full dataset identifier, e.g. "org/dataset"
    text_column_name: str
    audio_column_name: str
    speaker_column_name: Optional[str]
    add_constant: Optional[List[Dict[str, str]]] = None  # list of {'key': ..., 'value': ...} pairs
    split: str = "train"  # Default to 'train' split
    sub_name: Optional[str] = None  # Dataset subset/configuration name

    @property
    def dataset_prefix(self) -> str:
        """Extract dataset prefix from name (the part after the last '/')."""
        return self.name.rsplit('/', 1)[-1]

    def get_constant_columns(self) -> Dict[str, str]:
        """Flatten `add_constant` entries into a {key: value} dictionary.

        Returns an empty dict when `add_constant` is None or empty.
        """
        entries = self.add_constant or []
        return {entry['key']: entry['value'] for entry in entries}

@dataclass
class BaseSettings:
    """Base pipeline settings.

    Populated by ConfigManager from the `base_settings` section of the
    YAML config via keyword expansion, so field names must match the
    YAML keys exactly.
    """
    audio_codec: str        # codec identifier for audio encoding — exact semantics defined by the consumer, not here
    num_readers: int        # number of reader workers (presumably parallel dataset readers — confirm in pipeline code)
    qsize: int              # queue size — NOTE(review): likely the reader/writer queue bound; verify against usage
    OUT_DIR: str            # output directory; non-PEP8 name kept because it must match the YAML key
    gzip_level: int         # gzip compression level
    buffer_size: int
    lines_per_file: int     # how many lines go into each output shard
    load_dataset_num_proc: int = 5  # Default to 5 if not specified


@dataclass
class SaveSettings:
    """Settings for saving/uploading datasets.

    Built by ConfigManager from the `save_settings` YAML section.
    Both fields are Optional — a None value presumably disables that
    output target (confirm against the save/upload step).
    """
    local: Optional[str]      # local save path, or None
    hf_upload: Optional[str]  # NOTE(review): looks like a Hugging Face upload target/repo id — verify


@dataclass
class GroupEmbeddingSettings:
    """Settings for the per-speaker averaged embedding step."""
    do_this: bool                   # whether to run the grouped-embedding step (YAML default: False)
    group_by_column_name: str       # column to group rows by (YAML default: 'speaker_id')
    grouped_embedding_column: str   # output column name (YAML default: 'grouped_sp_emb')


@dataclass
class SpeakerEmbeddingSettings:
    """Settings for the speaker embedding pipeline step.

    Assembled by ConfigManager from the `speaker_embedding_settings`
    YAML section; the defaults noted below are the ones ConfigManager
    falls back to when a key is absent.
    """
    add_speaker_emb: bool      # master switch for the embedding step (default: False)
    model_name: str            # embedding model name from model.name (default: '')
    embedding_column: str      # output column for per-row embeddings (default: 'wavlm_embedding')
    target_sr: int             # audio resample target in Hz (default: 16000)
    max_audio_sec: float       # max audio duration in seconds (default: 30.0)
    batch_size: int            # inference batch size (default: 2)
    use_multiprocessing: bool  # default: True
    do_clusters: bool          # whether to run clustering (default: False)
    umap_params: dict          # raw kwargs from clustering.UMAP — passed through, shape not validated here
    hdbscan_params: dict       # raw kwargs from clustering.HDBSCAN — passed through, shape not validated here
    # Resolved at runtime if a naming conflict is detected (see PipelineManager)
    clustering_speaker_column: str = "speaker_id"
    group_emb: Optional[GroupEmbeddingSettings] = None


class ConfigManager:
    """Loads, validates, and exposes the pipeline configuration from a YAML file."""

    def __init__(self, config_path: str = "config.yaml"):
        """Parse *config_path* and build the typed settings objects.

        Raises:
            FileNotFoundError: if the config file does not exist.
            KeyError: if a required top-level section
                (`base_settings`, `save_settings`, `hf_datasets`) is missing.
        """
        # Explicit encoding: YAML configs may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.base_settings = BaseSettings(**self.config['base_settings'])
        self.save_settings = SaveSettings(**self.config['save_settings'])
        self.datasets = [DatasetConfig(**ds) for ds in self.config['hf_datasets']]
        self.speaker_embedding_settings = self._build_speaker_embedding_settings()

    def _build_speaker_embedding_settings(self) -> SpeakerEmbeddingSettings:
        """Assemble SpeakerEmbeddingSettings from the optional YAML section.

        An empty YAML section (e.g. a bare `speaker_embedding_settings:` key)
        parses as None, so every nested lookup is guarded with `or {}` to
        fall back to the documented defaults instead of raising
        AttributeError on None.
        """
        sp = self.config.get('speaker_embedding_settings') or {}
        model = sp.get('model') or {}
        audio = sp.get('audio') or {}
        processing = sp.get('processing') or {}
        clustering = sp.get('clustering') or {}
        grp = sp.get('group_sp_emb') or {}

        group_emb = GroupEmbeddingSettings(
            do_this=grp.get('do_this', False),
            group_by_column_name=grp.get('group_by_column_name', 'speaker_id'),
            grouped_embedding_column=grp.get('grouped_embedding_column', 'grouped_sp_emb'),
        )
        return SpeakerEmbeddingSettings(
            add_speaker_emb=sp.get('add_speaker_emb', False),
            model_name=model.get('name', ''),
            embedding_column=model.get('embedding_column', 'wavlm_embedding'),
            target_sr=audio.get('target_sample_rate', 16000),
            max_audio_sec=audio.get('max_duration_sec', 30.0),
            batch_size=processing.get('batch_size', 2),
            use_multiprocessing=processing.get('use_multiprocessing', True),
            do_clusters=clustering.get('do_clusters', False),
            umap_params=clustering.get('UMAP', {}),
            hdbscan_params=clustering.get('HDBSCAN', {}),
            clustering_speaker_column=clustering.get('speaker_column', 'speaker_id'),
            group_emb=group_emb,
        )

    def validate_datasets(self) -> None:
        """
        Validate that all datasets have matching additional columns.
        This prevents conflicts when merging the final dataset.

        Raises:
            ValueError: if no datasets are configured, or if any dataset is
                missing a constant-column key that another dataset declares.
        """
        if not self.datasets:
            raise ValueError("No datasets specified in configuration")

        # Collect the union of constant-column keys and each dataset's own keys.
        all_constant_keys = set()
        dataset_constant_keys = []

        for ds in self.datasets:
            constant_keys = set(ds.get_constant_columns().keys())
            dataset_constant_keys.append((ds.name, constant_keys))
            all_constant_keys.update(constant_keys)

        # Every dataset must declare the full union of keys; a subset would
        # produce rows with missing columns after merging.
        if all_constant_keys:
            for ds_name, keys in dataset_constant_keys:
                missing_keys = all_constant_keys - keys
                if missing_keys:
                    raise ValueError(
                        f"Dataset '{ds_name}' is missing constant columns: {missing_keys}. "
                        f"All datasets must have the same constant columns for proper merging."
                    )

        print("✅ Dataset validation passed: All datasets have matching additional columns")

    def get_datasets(self) -> List[DatasetConfig]:
        """Get list of dataset configurations"""
        return self.datasets

    def get_base_settings(self) -> BaseSettings:
        """Get base pipeline settings"""
        return self.base_settings

    def get_save_settings(self) -> SaveSettings:
        """Get save/upload settings"""
        return self.save_settings

    def get_speaker_embedding_settings(self) -> SpeakerEmbeddingSettings:
        """Get speaker embedding settings"""
        return self.speaker_embedding_settings
