import torch
import multiprocessing as mp
import os
from datasets import load_dataset, concatenate_datasets
from typing import List
from ..config_manager import ConfigManager, DatasetConfig
from .dataset_processor import DatasetProcessor
from .audio_worker import worker_process, AudioWorker
from .reader_worker import reader_worker_process
from ..speaker_emb.processor import SpeakerEmbeddingProcessor
from ..speaker_emb.group_embeddings import SpeakerEmbeddingGrouper


class PipelineManager:
    """Manages the entire audio processing pipeline.

    Per configured dataset this class: loads the dataset, optionally adds
    speaker embeddings (with clustering and/or grouping), then tokenizes
    the audio through a reader -> bounded queue -> GPU-worker
    multiprocessing pipeline that writes gzipped JSONL shard files into
    OUT_DIR. After all datasets are processed, the shards are assembled
    into one dataset and saved locally and/or pushed to the HuggingFace
    Hub, depending on the save settings.
    """

    def __init__(self, config_path: str = "config.yaml") -> None:
        """Load configuration, create the output directory, and set the
        multiprocessing start method.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_manager = ConfigManager(config_path)
        self.base_settings = self.config_manager.get_base_settings()
        self.save_settings = self.config_manager.get_save_settings()
        self.speaker_emb_settings = self.config_manager.get_speaker_embedding_settings()
        # One GPU worker process will be started per CUDA device.
        self.num_gpus = torch.cuda.device_count()

        # Ensure output directory exists
        os.makedirs(self.base_settings.OUT_DIR, exist_ok=True)

        # Set spawn start method once — used by both speaker embedding
        # (WavLMEmbedder multi-GPU) and tokenization workers.
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            # Defensive: with force=True this normally succeeds; if the
            # context cannot be changed, proceed with the existing method.
            pass

    def _check_clustering_conflict(self) -> None:
        """
        Check whether the clustering speaker column name conflicts with any
        add_constant key in the dataset configs. If a conflict is found, prompt
        the user to supply an alternative column name.

        Interactive: blocks on ``input()`` when a conflict exists, so this
        should only run in a terminal session. On conflict, mutates
        ``self.speaker_emb_settings.clustering_speaker_column`` in place.
        """
        # Only relevant when both speaker embeddings and clustering are enabled.
        if not self.speaker_emb_settings.add_speaker_emb:
            return
        if not self.speaker_emb_settings.do_clusters:
            return

        conflict_column = self.speaker_emb_settings.clustering_speaker_column
        datasets = self.config_manager.get_datasets()

        # Conflict: some dataset also defines a constant column with the same
        # name the clustering step would write its integer cluster IDs to.
        has_conflict = any(
            conflict_column in ds.get_constant_columns()
            for ds in datasets
        )

        if not has_conflict:
            return

        print("\n" + "=" * 60)
        print("⚠️  WARNING: Clustering column naming conflict detected!")
        print("=" * 60)
        print(
            f"  'do_clusters: true' will produce a '{conflict_column}' column "
            f"(integer cluster IDs),"
        )
        print(
            f"  but '{conflict_column}' is also defined in 'add_constant' "
            f"(string constant value)."
        )
        print("  Please provide a different name for the clustering output column.")
        print("-" * 60)

        new_name = input(
            f"  New clustering column name [default: cluster_id]: "
        ).strip()

        # Empty answer (just Enter) falls back to the default name.
        if not new_name:
            new_name = "cluster_id"

        self.speaker_emb_settings.clustering_speaker_column = new_name
        print(f"  ✅ Clustering column will be named: '{new_name}'")
        print("=" * 60 + "\n")

    def _check_group_emb_feasibility(self) -> None:
        """
        Warn the user early if 'group_sp_emb.do_this: true' but the
        group_by_column_name cannot be traced to any known source:
          - add_constant in any dataset config
          - clustering output (do_clusters: true + matching column name)
          - speaker_column_name of any dataset config

        If no source is found, prompt the user to disable grouping via
        interactive input.

        Interactive: may block on ``input()``. May mutate
        ``group_emb.do_this`` to False if the user accepts disabling.
        """
        group_emb = self.speaker_emb_settings.group_emb
        # Nothing to validate when grouping is absent or disabled.
        if not group_emb or not group_emb.do_this:
            return

        group_col = group_emb.group_by_column_name
        datasets  = self.config_manager.get_datasets()

        # Human-readable descriptions of every place the column can come from.
        sources = []

        # Source 1: add_constant in any dataset
        for ds in datasets:
            if group_col in ds.get_constant_columns():
                sources.append(f"add_constant in '{ds.name}'")

        # Source 2: clustering output
        sp = self.speaker_emb_settings
        if (sp.add_speaker_emb and sp.do_clusters
                and sp.clustering_speaker_column == group_col):
            sources.append("clustering (do_clusters: true)")

        # Source 3: speaker_column_name of any dataset
        for ds in datasets:
            if ds.speaker_column_name == group_col:
                sources.append(f"speaker_column_name in '{ds.name}'")

        if sources:
            # Column will definitely be present — all good
            print(
                f"[Validation] group_sp_emb: column '{group_col}' "
                f"provided by: {', '.join(sources)}"
            )
            return

        # No source found — warn and offer to disable
        print("\n" + "=" * 60)
        print("⚠️  WARNING: group_sp_emb column source not found!")
        print("=" * 60)
        print(
            f"  'group_sp_emb.do_this: true' requires column '{group_col}'"
        )
        print(
            f"  but it will NOT be produced by add_constant, clustering, "
            f"or speaker_column_name."
        )
        print(
            "  It may already exist in the source dataset — proceed only "
            "if you are sure."
        )
        print("-" * 60)

        answer = input(
            "  Disable group_sp_emb for this run? [Y/n]: "
        ).strip().lower()

        # Default answer (Enter, or anything other than 'n') takes the safe
        # path: disable grouping rather than risk a runtime failure later.
        if answer != "n":
            group_emb.do_this = False
            print("  ✅ group_sp_emb disabled.")
        else:
            print("  ⚠️  Proceeding — will fail at runtime if column is absent.")
        print("=" * 60 + "\n")

    def validate(self) -> None:
        """Validate configuration and environment.

        Runs dataset-config validation plus the two interactive pre-flight
        checks, then verifies at least one CUDA device is available.

        Raises:
            RuntimeError: If no CUDA devices are found.
        """
        print("🔍 Validating configuration...")
        self.config_manager.validate_datasets()
        self._check_clustering_conflict()
        self._check_group_emb_feasibility()

        # GPU workers are mandatory — the pipeline spawns one per device.
        if self.num_gpus == 0:
            raise RuntimeError("❌ ERROR: No CUDA devices found!")

        print(f"✅ Found {self.num_gpus} GPU(s)")

    def process_single_dataset(self, dataset_config: DatasetConfig) -> None:
        """Process a single dataset through the pipeline.

        Stages:
          1. Load the dataset via a ``DatasetProcessor``.
          2. Optionally add speaker embeddings (and clustering/grouping),
             or run standalone embedding grouping if only that is enabled.
          3. Fan rows out from reader processes through a bounded queue to
             one GPU worker per CUDA device; workers write gzipped JSONL
             shard files (prefixed with ``dataset_config.dataset_prefix``)
             into OUT_DIR.

        Args:
            dataset_config: Configuration of the dataset to process.

        Raises:
            KeyboardInterrupt: Re-raised after terminating all child
                processes, so the caller (``run``) also stops.
        """
        print(f"\n{'='*60}")
        print(f"🎯 Processing dataset: {dataset_config.name}")
        print(f"📝 Dataset prefix: {dataset_config.dataset_prefix}")
        print(f"{'='*60}")

        # Load dataset
        processor = DatasetProcessor(dataset_config)
        processor.load_dataset(num_proc=self.base_settings.load_dataset_num_proc)

        # ── Speaker Embedding Step ─────────────────────────────────────────────
        group_emb = self.speaker_emb_settings.group_emb

        if self.speaker_emb_settings.add_speaker_emb:
            print(f"\n{'─'*60}")
            print("🎙️  Speaker Embedding: starting...")
            print(f"{'─'*60}")

            # Full embedding pass; the processor also reports which columns
            # must survive the tokenization step.
            sp_processor = SpeakerEmbeddingProcessor(self.speaker_emb_settings)
            updated_dataset = sp_processor.process(
                processor.get_dataset(),
                dataset_config.audio_column_name,
            )
            processor.dataset = updated_dataset
            processor.preserve_columns = sp_processor.get_preserved_columns()

            print(f"✅ Speaker Embedding done.")
            print(f"   Columns preserved through tokenization: {processor.preserve_columns}")

        elif group_emb and group_emb.do_this:
            # Embeddings already exist in the dataset; only run grouping.
            print(f"\n{'─'*60}")
            print("📊  Group Embeddings (standalone): starting...")
            print(f"{'─'*60}")

            grouper = SpeakerEmbeddingGrouper(
                settings=group_emb,
                embedding_column=self.speaker_emb_settings.embedding_column,
            )
            updated_dataset = grouper.group(processor.get_dataset())
            processor.dataset = updated_dataset
            processor.preserve_columns = grouper.get_preserved_columns()

            print(f"✅ Group Embeddings done.")
            print(f"   Columns preserved through tokenization: {processor.preserve_columns}")
        # ──────────────────────────────────────────────────────────────────────

        # Setup multiprocessing queue; the maxsize bound provides backpressure
        # so readers cannot run arbitrarily far ahead of the GPU workers.
        q = mp.Queue(maxsize=self.base_settings.qsize)

        print(f"\n🚀 Starting processing pipeline")
        print(f"💻 CUDA available: {torch.cuda.is_available()}")
        print(f"🔥 GPU workers: {self.num_gpus}")
        print(f"📖 Reader workers: {self.base_settings.num_readers}")
        print(f"⚙️  Dataset load processes: {self.base_settings.load_dataset_num_proc}")
        print(f"📁 Output directory: {self.base_settings.OUT_DIR}")
        print(f"🗂️  Lines per file: {self.base_settings.lines_per_file:,}")
        print(f"📦 Queue size: {self.base_settings.qsize}")
        print("-" * 60)

        # Start GPU worker processes (one per CUDA device; worker index `i`
        # doubles as the device index inside worker_process).
        workers = [
            mp.Process(
                target=worker_process,
                args=(
                    i,
                    q,
                    self.base_settings.OUT_DIR,
                    dataset_config.dataset_prefix,
                    self.base_settings.gzip_level,
                    self.base_settings.buffer_size,
                    self.base_settings.lines_per_file,
                    self.base_settings.num_readers,
                    self.base_settings.audio_codec  # codec/model identifier from config, consumed by the worker
                )
            )
            for i in range(self.num_gpus)
        ]

        for p in workers:
            p.start()

        # Shard the dataset for readers — each reader gets a disjoint shard.
        dataset = processor.get_dataset()
        sharded_datasets = [
            dataset.shard(num_shards=self.base_settings.num_readers, index=i)
            for i in range(self.base_settings.num_readers)
        ]

        # Create dataset processors for each shard (they share the config but have different shards)
        shard_processors = []
        for i in range(self.base_settings.num_readers):
            shard_proc = DatasetProcessor(dataset_config)
            shard_proc.dataset = sharded_datasets[i]
            shard_proc.preserve_columns = processor.preserve_columns
            shard_processors.append(shard_proc)

        # Start reader processes (they feed rows from their shard into the queue)
        readers = [
            mp.Process(
                target=reader_worker_process,
                args=(i, self.base_settings.num_readers, shard_processors[i], q)
            )
            for i in range(self.base_settings.num_readers)
        ]

        for pr in readers:
            pr.start()

        try:
            # Wait for all readers to complete
            for pr in readers:
                pr.join()

            # Send SENTINEL to all workers after all readers are done —
            # one sentinel per worker so every GPU worker sees the shutdown
            # signal exactly once.
            for i in range(self.num_gpus):
                q.put(AudioWorker.SENTINEL)

            # Wait for workers to complete
            for p in workers:
                p.join()

            print("\n" + "=" * 60)
            print(f"🎉 Dataset {dataset_config.name} processed successfully!")

            # Show file statistics — the prefix filter restricts the report
            # to files produced for THIS dataset only.
            if os.path.exists(self.base_settings.OUT_DIR):
                files = [f for f in os.listdir(self.base_settings.OUT_DIR)
                        if f.startswith(dataset_config.dataset_prefix)]
                if files:
                    total_size = sum(
                        os.path.getsize(os.path.join(self.base_settings.OUT_DIR, f))
                        for f in files
                    )
                    print(f"📊 Generated {len(files)} files for this dataset, "
                          f"size: {total_size / 1024**3:.2f} GB")

        except KeyboardInterrupt:
            print("\n⚠️  Interrupted! Terminating processes...")

            # Forcefully terminate all child processes (readers first, then
            # the GPU workers).
            for pr in readers:
                pr.terminate()
            for p in workers:
                p.terminate()

            # Wait for completion with timeout; a hung child is abandoned
            # after 10 seconds rather than blocking shutdown forever.
            for pr in readers:
                pr.join(timeout=10)
            for p in workers:
                p.join(timeout=10)

            print("🛑 All processes terminated")
            # Re-raise so the caller (run) also aborts the remaining datasets.
            raise

    def assemble_and_save_final_dataset(self) -> None:
        """Assemble all processed shards into final dataset and save/upload.

        Loads every ``*.jsonl.gz`` file found in OUT_DIR — i.e. the shards
        of ALL datasets processed this run (plus any pre-existing shard
        files in the directory) — then saves to disk and/or pushes to the
        HuggingFace Hub according to the save settings.
        """
        print(f"\n{'='*60}")
        print("🔨 Assembling final dataset from all shards...")
        print(f"{'='*60}")

        # Load all JSONL.gz files from output directory.
        # NOTE: this glob path is only used for the log line below;
        # load_dataset receives data_dir and data_files separately.
        shard_files = os.path.join(self.base_settings.OUT_DIR, "*.jsonl.gz")
        print(f"📂 Loading shards from: {shard_files}")

        # verification_mode='no_checks' skips checksum/split verification,
        # since these are locally produced files.
        final_dataset = load_dataset(
            "json",
            data_dir=self.base_settings.OUT_DIR,
            data_files="*.jsonl.gz",
            split='train',
            verification_mode='no_checks'
        )

        print(f"✅ Final dataset assembled: {len(final_dataset)} samples")

        # Save locally if specified
        if self.save_settings.local:
            print(f"\n💾 Saving dataset locally to: {self.save_settings.local}")
            final_dataset.save_to_disk(self.save_settings.local)
            print(f"✅ Dataset saved to disk")

        # Upload to HuggingFace if specified (always as a private repo)
        if self.save_settings.hf_upload:
            print(f"\n☁️  Uploading dataset to HuggingFace: {self.save_settings.hf_upload}")
            final_dataset.push_to_hub(self.save_settings.hf_upload, private=True)
            print(f"✅ Dataset uploaded to HuggingFace Hub")

        print(f"\n{'='*60}")
        print("🎊 Pipeline completed successfully!")
        print(f"{'='*60}")

    def run(self) -> None:
        """Run the complete pipeline for all datasets.

        Validates first (may raise or prompt interactively), processes each
        configured dataset sequentially, then assembles and saves/uploads
        the combined result.
        """
        # Validate configuration
        self.validate()

        # Get all datasets to process
        datasets = self.config_manager.get_datasets()
        print(f"\n📋 Found {len(datasets)} dataset(s) to process")

        # Process each dataset sequentially
        for idx, dataset_config in enumerate(datasets, 1):
            print(f"\n🔄 Processing dataset {idx}/{len(datasets)}")
            self.process_single_dataset(dataset_config)

        # After all datasets are processed, assemble and save
        self.assemble_and_save_final_dataset()

        print("\n👋 Pipeline finished!")
