#!/usr/bin/env python3
"""
Download and prepare evaluation datasets for XCodec2 Indic training.

Languages: telugu, hindi, english, tamil, kannada, malayalam, assamese, odia, marathi, punjabi, gujarati, bengali

Sources:
- FLEURS (Google): High-quality read speech in 102 languages (primary)
- LibriSpeech: Studio audiobook recordings (English backup)
- CommonVoice: Additional crowd-sourced data if needed

Target: 500 samples per language for evaluation

Output structure:
    data/evaluation/
    ├── telugu/
    │   ├── audio/
    │   │   ├── 000000.wav
    │   │   ├── 000001.wav
    │   │   └── ...
    │   └── metadata.json
    ├── hindi/
    │   └── ...
    └── evaluation_manifest.json

Usage:
    python scripts/data_prep/prepare_evaluation_data.py
    python scripts/data_prep/prepare_evaluation_data.py --samples-per-lang 500
"""

import os
import sys
import json
import random
import argparse
from pathlib import Path
from typing import Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import torchaudio
from tqdm import tqdm

# ==============================================================================
# Configuration
# ==============================================================================

# The 12 target languages of the XCodec2 Indic evaluation set.
# Order here also drives iteration/reporting order in main().
LANGUAGES: List[str] = [
    "telugu", "hindi", "english", "tamil", "kannada", "malayalam",
    "assamese", "odia", "marathi", "punjabi", "gujarati", "bengali"
]

# FLEURS language codes (Google's dataset)
# Values are the "<lang>_<region>" config names accepted by
# load_dataset("google/fleurs", <code>). Keys match LANGUAGES.
FLEURS_LANG_CODES: Dict[str, str] = {
    "telugu": "te_in",
    "hindi": "hi_in",
    "english": "en_us",
    "tamil": "ta_in",
    "bengali": "bn_in",
    "gujarati": "gu_in",
    "kannada": "kn_in",
    "malayalam": "ml_in",
    "marathi": "mr_in",
    "punjabi": "pa_in",
    "odia": "or_in",
    "assamese": "as_in",
}

# CommonVoice language codes (backup source)
# Config names for load_dataset("mozilla-foundation/common_voice_16_1", <code>).
# Note Punjabi uses the region-qualified "pa-IN" code on CommonVoice.
COMMONVOICE_LANG_CODES: Dict[str, str] = {
    "telugu": "te",
    "hindi": "hi",
    "english": "en",
    "tamil": "ta",
    "bengali": "bn",
    "gujarati": "gu",
    "kannada": "kn",
    "malayalam": "ml",
    "marathi": "mr",
    "punjabi": "pa-IN",
    "odia": "or",
    "assamese": "as",
}

# Minimum audio duration in seconds for evaluation
# Clips outside [MIN_DURATION, MAX_DURATION] are discarded by the downloaders.
MIN_DURATION: float = 2.0
MAX_DURATION: float = 15.0
# All audio is resampled to this rate (Hz) before saving.
TARGET_SR: int = 16000

def resample_audio(waveform: torch.Tensor, orig_sr: int, target_sr: int = 16000) -> torch.Tensor:
    """Resample ``waveform`` to ``target_sr``.

    Args:
        waveform: Audio tensor, shape (channels, samples).
        orig_sr: Sample rate of ``waveform``.
        target_sr: Desired output sample rate in Hz (default 16000).

    Returns:
        The resampled waveform, or the input unchanged when the rates
        already match.
    """
    if orig_sr == target_sr:
        return waveform
    # torchaudio.functional.resample uses the same sinc-interpolation kernel
    # as transforms.Resample but avoids allocating a transform object on
    # every call (this function is invoked once per downloaded clip).
    return torchaudio.functional.resample(waveform, orig_sr, target_sr)


def download_fleurs_lang(
    lang: str,
    output_dir: Path,
    max_samples: int = 500,
    splits: Optional[List[str]] = None,
) -> Dict:
    """Download FLEURS audio for one language into ``<output_dir>/<lang>/``.

    Each clip is resampled to TARGET_SR, filtered to the
    [MIN_DURATION, MAX_DURATION] window, saved as ``audio/NNNNNN.wav``,
    and recorded in the language's ``metadata.json``.

    Args:
        lang: Language name; must be a key of FLEURS_LANG_CODES.
        output_dir: Root of the evaluation data tree.
        max_samples: Stop after this many accepted clips (across splits).
        splits: Dataset splits to read, in priority order. Defaults to
            ["test", "validation"].

    Returns:
        The metadata dict that was written to disk, or ``{"error": ...}``
        when the language has no FLEURS config.
    """
    from datasets import load_dataset

    # Fix: the default was previously a mutable list literal in the
    # signature (shared across calls); resolve the default here instead.
    if splits is None:
        splits = ["test", "validation"]

    lang_code = FLEURS_LANG_CODES.get(lang)
    if not lang_code:
        return {"error": f"Language {lang} not in FLEURS"}

    lang_dir = output_dir / lang
    audio_dir = lang_dir / "audio"
    audio_dir.mkdir(parents=True, exist_ok=True)

    samples: List[Dict] = []
    total_duration = 0.0
    sample_idx = 0

    for split in splits:
        if sample_idx >= max_samples:
            break

        try:
            dataset = load_dataset(
                "google/fleurs",
                lang_code,
                split=split,
                trust_remote_code=True
            )
        except Exception as e:
            # Best-effort: a missing/broken split should not abort the
            # language — fall through to the next split.
            print(f"    ⚠️ Failed to load {split} split: {e}")
            continue

        for item in dataset:
            if sample_idx >= max_samples:
                break

            audio = item["audio"]
            waveform = torch.tensor(audio["array"]).float().unsqueeze(0)
            sr = audio["sampling_rate"]

            # Normalize every clip to the evaluation sample rate.
            waveform = resample_audio(waveform, sr, TARGET_SR)

            # Duration in seconds, computed after resampling.
            duration = waveform.shape[1] / TARGET_SR

            # Discard clips outside the accepted duration window.
            if duration < MIN_DURATION or duration > MAX_DURATION:
                continue

            audio_path = audio_dir / f"{sample_idx:06d}.wav"
            torchaudio.save(str(audio_path), waveform, TARGET_SR)

            samples.append({
                "id": f"{lang}_{sample_idx:06d}",
                # Path is stored relative to output_dir so the manifest
                # stays valid if the tree is moved.
                "path": str(audio_path.relative_to(output_dir)),
                "duration": round(duration, 3),
                "language": lang,
                "source": "fleurs",
                "split": split,
                "transcription": item.get("transcription", ""),
            })

            total_duration += duration
            sample_idx += 1

    # Persist per-language metadata next to the audio.
    metadata = {
        "language": lang,
        "source": "fleurs",
        "lang_code": lang_code,
        "num_samples": len(samples),
        "total_duration_hours": round(total_duration / 3600, 4),
        "samples": samples,
    }

    metadata_path = lang_dir / "metadata.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    return metadata


def download_commonvoice_lang(
    lang: str,
    output_dir: Path,
    max_samples: int = 500,
    existing_count: int = 0
) -> Dict:
    """Top up a language's evaluation set from CommonVoice (backup source).

    Resumes from any ``metadata.json`` already present for the language
    (e.g. written by the FLEURS downloader) and appends clips until
    ``max_samples`` is reached, continuing the NNNNNN.wav numbering.

    Args:
        lang: Language name; must be a key of COMMONVOICE_LANG_CODES.
        output_dir: Root of the evaluation data tree.
        max_samples: Target total number of samples for the language.
        existing_count: Starting file index when no metadata exists yet.

    Returns:
        The updated metadata dict, a ``{"skipped": ...}`` dict when the
        quota is already met, or ``{"error": ...}`` on failure.
    """
    from datasets import load_dataset

    code = COMMONVOICE_LANG_CODES.get(lang)
    if not code:
        return {"error": f"Language {lang} not in CommonVoice"}

    target_dir = output_dir / lang
    clips_dir = target_dir / "audio"
    clips_dir.mkdir(parents=True, exist_ok=True)

    # Resume from previously written metadata, if any.
    meta_file = target_dir / "metadata.json"
    if meta_file.exists():
        with open(meta_file) as fh:
            collected = json.load(fh).get("samples", [])
        next_idx = len(collected)
    else:
        collected = []
        next_idx = existing_count

    if max_samples - next_idx <= 0:
        return {"skipped": True, "reason": "Already have enough samples"}

    elapsed = sum(entry["duration"] for entry in collected)

    try:
        # CommonVoice test split
        corpus = load_dataset(
            "mozilla-foundation/common_voice_16_1",
            code,
            split="test",
            trust_remote_code=True
        )
    except Exception as exc:
        return {"error": f"Failed to load CommonVoice: {exc}"}

    for record in corpus:
        if len(collected) >= max_samples:
            break

        clip = record["audio"]
        wav = resample_audio(
            torch.tensor(clip["array"]).float().unsqueeze(0),
            clip["sampling_rate"],
            TARGET_SR,
        )
        seconds = wav.shape[1] / TARGET_SR

        # Keep only clips inside the accepted duration window.
        if not (MIN_DURATION <= seconds <= MAX_DURATION):
            continue

        wav_file = clips_dir / f"{next_idx:06d}.wav"
        torchaudio.save(str(wav_file), wav, TARGET_SR)

        collected.append({
            "id": f"{lang}_{next_idx:06d}",
            "path": str(wav_file.relative_to(output_dir)),
            "duration": round(seconds, 3),
            "language": lang,
            "source": "commonvoice",
            "transcription": record.get("sentence", ""),
        })

        elapsed += seconds
        next_idx += 1

    # Rewrite the metadata file with the merged sample list.
    merged = {
        "language": lang,
        "source": "fleurs+commonvoice",
        "num_samples": len(collected),
        "total_duration_hours": round(elapsed / 3600, 4),
        "samples": collected,
    }

    with open(meta_file, 'w', encoding='utf-8') as fh:
        json.dump(merged, fh, indent=2, ensure_ascii=False)

    return merged


def download_librispeech_english(output_dir: Path, max_samples: int = 500) -> Dict:
    """Build the English evaluation set from LibriSpeech test-clean.

    Clips are resampled to TARGET_SR, filtered to the
    [MIN_DURATION, MAX_DURATION] window, written under
    ``<output_dir>/english/audio/``, and described in ``metadata.json``.

    Args:
        output_dir: Root of the evaluation data tree.
        max_samples: Maximum number of accepted clips.

    Returns:
        The metadata dict written to disk, or ``{"error": ...}`` when the
        dataset cannot be loaded.
    """
    from datasets import load_dataset

    lang = "english"
    lang_dir = output_dir / lang
    clips_dir = lang_dir / "audio"
    clips_dir.mkdir(parents=True, exist_ok=True)

    try:
        corpus = load_dataset(
            "librispeech_asr",
            "clean",
            split="test",
            trust_remote_code=True
        )
    except Exception as exc:
        return {"error": f"Failed to load LibriSpeech: {exc}"}

    kept: List[Dict] = []
    seconds_total = 0.0
    idx = 0

    for record in corpus:
        if idx >= max_samples:
            break

        clip = record["audio"]
        wav = resample_audio(
            torch.tensor(clip["array"]).float().unsqueeze(0),
            clip["sampling_rate"],
            TARGET_SR,
        )
        length = wav.shape[1] / TARGET_SR

        # Keep only clips inside the accepted duration window.
        if not (MIN_DURATION <= length <= MAX_DURATION):
            continue

        wav_file = clips_dir / f"{idx:06d}.wav"
        torchaudio.save(str(wav_file), wav, TARGET_SR)

        kept.append({
            "id": f"{lang}_{idx:06d}",
            "path": str(wav_file.relative_to(output_dir)),
            "duration": round(length, 3),
            "language": lang,
            "source": "librispeech",
            "transcription": record.get("text", ""),
        })

        seconds_total += length
        idx += 1

    metadata = {
        "language": lang,
        "source": "librispeech",
        "num_samples": len(kept),
        "total_duration_hours": round(seconds_total / 3600, 4),
        "samples": kept,
    }

    with open(lang_dir / "metadata.json", 'w', encoding='utf-8') as fh:
        json.dump(metadata, fh, indent=2, ensure_ascii=False)

    return metadata


def create_evaluation_manifest(output_dir: Path) -> Path:
    """Combine all per-language metadata into one evaluation manifest.

    Reads ``<output_dir>/<lang>/metadata.json`` for each language in
    LANGUAGES and writes two files:
      * ``evaluation_manifest.json`` — full manifest with per-language stats
      * ``evaluation.tsv`` — simple "path<TAB>language<TAB>duration" filelist

    Languages with no metadata file are reported and skipped.

    Args:
        output_dir: Root of the evaluation data tree.

    Returns:
        Path to the written manifest JSON.
    """
    all_samples: List[Dict] = []
    lang_stats: Dict[str, Dict] = {}

    for lang in LANGUAGES:
        metadata_path = output_dir / lang / "metadata.json"

        if not metadata_path.exists():
            print(f"  ⚠️ Missing metadata for {lang}")
            continue

        with open(metadata_path, encoding='utf-8') as f:
            metadata = json.load(f)

        samples = metadata.get("samples", [])
        all_samples.extend(samples)

        # Fix: use .get with fallbacks so a partially-written metadata
        # file (e.g. from an interrupted run) cannot raise KeyError here.
        lang_stats[lang] = {
            "num_samples": metadata.get("num_samples", len(samples)),
            "total_duration_hours": metadata.get("total_duration_hours", 0.0),
            "source": metadata.get("source", "unknown"),
        }

    # Create manifest
    manifest = {
        "name": "xcodec2_indic_evaluation",
        "version": "1.0",
        "description": "Evaluation dataset for XCodec2 Indic model (500 samples per language)",
        "languages": LANGUAGES,
        "total_samples": len(all_samples),
        "language_stats": lang_stats,
        "samples": all_samples,
    }

    manifest_path = output_dir / "evaluation_manifest.json"
    with open(manifest_path, 'w', encoding='utf-8') as f:
        json.dump(manifest, f, indent=2, ensure_ascii=False)

    # Also create a simple TSV filelist (UTF-8 so paths/language names are
    # encoded consistently across platforms).
    tsv_path = output_dir / "evaluation.tsv"
    with open(tsv_path, 'w', encoding='utf-8') as f:
        for sample in all_samples:
            full_path = output_dir / sample["path"]
            f.write(f"{full_path}\t{sample['language']}\t{sample['duration']}\n")

    return manifest_path


def main():
    """CLI entry point: download per-language data, then build the manifest.

    Returns:
        Process exit code (0 on completion).
    """
    parser = argparse.ArgumentParser(description="Prepare evaluation datasets for XCodec2 Indic")
    parser.add_argument("--output-dir", type=str, default="data/evaluation",
                        help="Output directory for evaluation data")
    parser.add_argument("--samples-per-lang", type=int, default=500,
                        help="Number of samples per language")
    parser.add_argument("--langs", type=str, nargs="+", default=None,
                        help="Specific languages to download (default: all 12)")
    # Fix: the flag was declared store_true with default=True, making it a
    # no-op — LibriSpeech could never be disabled. Keep --use-librispeech
    # for backward compatibility and add --no-librispeech to opt out.
    parser.add_argument("--use-librispeech", dest="use_librispeech",
                        action="store_true", default=True,
                        help="Use LibriSpeech for English instead of FLEURS")
    parser.add_argument("--no-librispeech", dest="use_librispeech",
                        action="store_false",
                        help="Use FLEURS for English instead of LibriSpeech")
    args = parser.parse_args()

    # Resolve output relative to the repository root (three levels above
    # this script: scripts/data_prep/<file> -> repo root).
    script_dir = Path(__file__).parent.parent.parent
    output_dir = script_dir / args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    languages = args.langs or LANGUAGES
    max_samples = args.samples_per_lang

    print("=" * 70)
    print("📥 DOWNLOADING EVALUATION DATASETS FOR XCODEC2 INDIC")
    print("=" * 70)
    print(f"Output directory: {output_dir}")
    print(f"Languages: {', '.join(languages)}")
    print(f"Samples per language: {max_samples}")
    print()

    results = {}

    for lang in languages:
        print(f"\n{'='*60}")
        print(f"📥 Processing: {lang.upper()}")
        print(f"{'='*60}")

        # For English, use LibriSpeech (higher quality)
        if lang == "english" and args.use_librispeech:
            print("  Using LibriSpeech test-clean (studio quality)...")
            result = download_librispeech_english(output_dir, max_samples)
        else:
            # For Indic languages, use FLEURS (Google's high-quality dataset)
            print("  Downloading from FLEURS...")
            result = download_fleurs_lang(lang, output_dir, max_samples)

        if "error" in result:
            print(f"  ❌ Error: {result['error']}")
            # Try CommonVoice as backup
            print("  Trying CommonVoice as backup...")
            result = download_commonvoice_lang(lang, output_dir, max_samples)

        results[lang] = result

        if "error" not in result:
            num = result.get("num_samples", 0)
            hours = result.get("total_duration_hours", 0)
            print(f"  ✅ {lang}: {num} samples ({hours:.2f} hours)")
        else:
            print(f"  ❌ {lang}: Failed - {result.get('error')}")

    # Create combined manifest
    print("\n" + "=" * 70)
    print("📋 CREATING COMBINED EVALUATION MANIFEST")
    print("=" * 70)

    manifest_path = create_evaluation_manifest(output_dir)

    # Print summary
    print("\n" + "=" * 70)
    print("📊 EVALUATION DATA SUMMARY")
    print("=" * 70)

    with open(manifest_path, encoding='utf-8') as f:
        manifest = json.load(f)

    print(f"\n{'Language':<12} {'Samples':>8} {'Hours':>8} {'Source':<15}")
    print("-" * 50)

    total_samples = 0
    total_hours = 0

    for lang in LANGUAGES:
        stats = manifest["language_stats"].get(lang, {})
        num = stats.get("num_samples", 0)
        hours = stats.get("total_duration_hours", 0)
        source = stats.get("source", "N/A")
        print(f"{lang:<12} {num:>8} {hours:>8.2f} {source:<15}")
        total_samples += num
        total_hours += hours

    print("-" * 50)
    print(f"{'TOTAL':<12} {total_samples:>8} {total_hours:>8.2f}")

    print(f"\n✅ Manifest: {manifest_path}")
    print(f"✅ TSV: {output_dir / 'evaluation.tsv'}")

    return 0


# Propagate main()'s return value (0 on success) as the process exit code.
if __name__ == "__main__":
    sys.exit(main())