#!/usr/bin/env python3
"""
Create WebDataset tar shards from evaluation data and upload to R2.

This script:
1. Reads evaluation data from data/evaluation/
2. Creates tar shards (one per language for easy management)
3. Uploads to R2 bucket 'xcodec' with prefix 'evaluation/'

Usage:
    python scripts/data_prep/create_and_upload_eval_shards.py
"""

import os
import sys
import json
import tarfile
import tempfile
from pathlib import Path
from dotenv import load_dotenv
import boto3
from botocore.config import Config
from tqdm import tqdm

# Populate os.environ from a .env file (python-dotenv) before reading settings
load_dotenv()

# R2 connection settings, all taken from the environment.
# The endpoint falls back to S3_API when R2_ENDPOINT_URL is unset or empty;
# the bucket name defaults to "xcodec".
R2_ENDPOINT_URL = os.environ.get("R2_ENDPOINT_URL") or os.environ.get("S3_API")
R2_ACCESS_KEY_ID = os.environ.get("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.environ.get("R2_SECRET_ACCESS_KEY")
R2_BUCKET = os.environ.get("R2_PROCESSED_BUCKET", "xcodec")

# Local evaluation data lives three directory levels above this file,
# under data/evaluation; uploaded objects are keyed under R2_PREFIX.
EVALUATION_DIR = Path(__file__).parent.parent.parent.joinpath("data", "evaluation")
R2_PREFIX = "evaluation"

# One shard is built per entry, in this order.
LANGUAGES = [
    "telugu",
    "hindi",
    "english",
    "tamil",
    "kannada",
    "malayalam",
    "assamese",
    "odia",
    "marathi",
    "punjabi",
    "gujarati",
    "bengali",
]


def get_s3_client():
    """Build and return a boto3 S3 client configured for the R2 endpoint.

    The endpoint URL and credentials come from the module-level R2_*
    settings. The client is configured for SigV4 signing and path-style
    bucket addressing.
    """
    r2_config = Config(
        signature_version='s3v4',
        s3={'addressing_style': 'path'},
    )
    connection = {
        'endpoint_url': R2_ENDPOINT_URL,
        'aws_access_key_id': R2_ACCESS_KEY_ID,
        'aws_secret_access_key': R2_SECRET_ACCESS_KEY,
        'config': r2_config,
    }
    return boto3.client('s3', **connection)


def create_tar_shard(lang: str, lang_dir: Path, output_path: Path) -> dict:
    """
    Create a WebDataset-compatible tar shard for a language.

    Reads ``lang_dir / "metadata.json"`` and, for every entry of its
    ``"samples"`` list whose audio file exists on disk, writes two members
    to the tar at ``output_path``:

    - {sample_id}.wav - the audio file
    - {sample_id}.json - metadata (id, language, duration, transcription, source)

    Audio paths from the metadata (``sample["path"]``) are resolved relative
    to EVALUATION_DIR, not to ``lang_dir``. Samples whose audio file is
    missing are reported on stdout and skipped.

    Args:
        lang: Language name, used in messages and in the result.
        lang_dir: Directory containing the language's ``metadata.json``.
        output_path: Destination path of the tar shard.

    Returns:
        On success, a dict with keys ``language``, ``samples`` (number of
        samples written), ``size_bytes`` (size of the tar on disk) and
        ``path``. On failure, ``{"error": <message>}`` when the metadata file
        is missing, lists no samples, or none of the listed audio files
        exist (in which case the empty tar is removed).

    Raises:
        KeyError: If a sample lacks ``path``, ``id``, ``language`` or
            ``duration``.
    """
    # Local import: io is not part of this file's top-level imports.
    import io

    metadata_path = lang_dir / "metadata.json"

    if not metadata_path.exists():
        return {"error": f"No metadata found for {lang}"}

    # Read as UTF-8 explicitly; the per-sample metadata is re-serialised
    # below with ensure_ascii=False, so non-ASCII text must not depend on the
    # platform's default encoding.
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    samples = metadata.get("samples", [])
    if not samples:
        return {"error": f"No samples found for {lang}"}

    sample_count = 0

    with tarfile.open(output_path, 'w') as tar:
        for sample in tqdm(samples, desc=f"  Creating shard for {lang}", leave=False):
            audio_path = EVALUATION_DIR / sample["path"]

            if not audio_path.exists():
                print(f"    ⚠️ Missing audio: {audio_path}")
                continue

            # The sample ID is the shared basename of both tar members.
            sample_id = sample["id"]

            # Add audio file to tar
            tar.add(str(audio_path), arcname=f"{sample_id}.wav")

            # Create metadata JSON for this sample
            sample_meta = {
                "id": sample_id,
                "language": sample["language"],
                "duration": sample["duration"],
                "transcription": sample.get("transcription", ""),
                "source": sample.get("source", "unknown"),
            }

            # Add metadata JSON to tar as an in-memory member
            meta_bytes = json.dumps(sample_meta, ensure_ascii=False).encode('utf-8')
            meta_info = tarfile.TarInfo(name=f"{sample_id}.json")
            meta_info.size = len(meta_bytes)
            tar.addfile(meta_info, io.BytesIO(meta_bytes))

            sample_count += 1

    if sample_count == 0:
        # Every listed audio file was missing: do not report an empty shard
        # as a success, and do not leave the empty tar behind.
        if output_path.exists():
            output_path.unlink()
        return {"error": f"No audio files found for {lang}"}

    return {
        "language": lang,
        "samples": sample_count,
        "size_bytes": output_path.stat().st_size,
        "path": str(output_path),
    }


def upload_to_r2(local_path: Path, r2_key: str, s3_client, content_type: "str | None" = None) -> bool:
    """Upload a local file to the R2 bucket named by R2_BUCKET.

    Args:
        local_path: File to upload.
        r2_key: Object key to store the file under.
        s3_client: Client object providing ``upload_fileobj`` (as returned by
            ``boto3.client('s3', ...)``).
        content_type: MIME type to set on the object. When omitted it is
            derived from the file suffix: ``.tar`` -> ``application/x-tar``,
            ``.json`` -> ``application/json``, otherwise whatever
            ``mimetypes`` guesses, falling back to
            ``application/octet-stream``.

    Returns:
        True if the upload completed, False if any exception occurred (the
        error is printed rather than raised so callers can continue with the
        remaining files).
    """
    # Local import: mimetypes is not part of this file's top-level imports.
    import mimetypes

    # Suffixes this script uploads, pinned so they do not depend on the
    # system MIME database.
    known_types = {
        '.tar': 'application/x-tar',
        '.json': 'application/json',
    }

    try:
        if content_type is None:
            content_type = (
                known_types.get(local_path.suffix.lower())
                or mimetypes.guess_type(local_path.name)[0]
                or 'application/octet-stream'
            )

        with open(local_path, 'rb') as f:
            s3_client.upload_fileobj(
                f,
                R2_BUCKET,
                r2_key,
                ExtraArgs={'ContentType': content_type},
            )

        return True
    except Exception as e:
        # Best-effort boundary: report and let the caller decide what to do.
        print(f"  ❌ Upload failed: {e}")
        return False


def main():
    """Build one tar shard per language, upload them to R2, and print a summary.

    Steps:
    1. Verify the R2 settings are present and that the bucket is reachable.
    2. For each entry in LANGUAGES with a directory under EVALUATION_DIR,
       create ``{lang}_eval.tar`` in a temporary directory and upload it
       under ``R2_PREFIX/``.
    3. Upload ``evaluation_manifest.json`` if it exists in EVALUATION_DIR.
    4. Write and upload an ``index.json`` describing the uploaded shards.
    5. Print a per-language summary table.

    Returns:
        0 when every attempted shard, the manifest (if present) and the index
        were uploaded; 1 when credentials are missing, the bucket cannot be
        reached, or any shard creation or upload failed. Languages whose
        directory does not exist are skipped and not treated as failures.
    """
    print("=" * 70)
    print("📦 CREATING EVALUATION SHARDS AND UPLOADING TO R2")
    print("=" * 70)
    print(f"Evaluation data: {EVALUATION_DIR}")
    print(f"R2 Bucket: {R2_BUCKET}")
    print(f"R2 Prefix: {R2_PREFIX}/")
    print()

    # Verify R2 credentials
    if not all([R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY]):
        print("❌ Missing R2 credentials in .env file")
        print("   Required: R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY")
        return 1

    # Create S3 client
    print("🔗 Connecting to R2...")
    s3_client = get_s3_client()

    # Test connection
    try:
        s3_client.head_bucket(Bucket=R2_BUCKET)
        print(f"  ✅ Connected to bucket: {R2_BUCKET}")
    except Exception as e:
        print(f"  ❌ Failed to connect to R2: {e}")
        return 1

    uploaded_shards = []
    # Number of shard/manifest/index steps that failed; drives the exit code.
    failures = 0

    # Shards are built in a temp directory that is removed on exit
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)

        for lang in LANGUAGES:
            lang_dir = EVALUATION_DIR / lang

            if not lang_dir.exists():
                print(f"  ⚠️ {lang}: Directory not found, skipping")
                continue

            print(f"\n📁 Processing: {lang.upper()}")

            # Create tar shard
            shard_name = f"{lang}_eval.tar"
            shard_path = tmpdir / shard_name

            result = create_tar_shard(lang, lang_dir, shard_path)

            if "error" in result:
                print(f"  ❌ {result['error']}")
                failures += 1
                continue

            print(f"  ✅ Created shard: {result['samples']} samples, {result['size_bytes'] / 1024 / 1024:.2f} MB")

            # Upload to R2
            r2_key = f"{R2_PREFIX}/{shard_name}"
            print(f"  ☁️  Uploading to R2: {r2_key}...")

            if upload_to_r2(shard_path, r2_key, s3_client):
                print("  ✅ Uploaded successfully")
                uploaded_shards.append({
                    "language": lang,
                    "r2_key": r2_key,
                    "samples": result["samples"],
                    "size_bytes": result["size_bytes"],
                })
            else:
                print(f"  ❌ Upload failed for {lang}")
                failures += 1

        # Also upload the manifest
        print("\n📋 Uploading evaluation manifest...")
        manifest_path = EVALUATION_DIR / "evaluation_manifest.json"
        if manifest_path.exists():
            manifest_r2_key = f"{R2_PREFIX}/evaluation_manifest.json"
            if upload_to_r2(manifest_path, manifest_r2_key, s3_client):
                print(f"  ✅ Manifest uploaded: {manifest_r2_key}")
            else:
                failures += 1
        else:
            print(f"  ⚠️ Manifest not found, skipping: {manifest_path}")

        # Create and upload index file; it lists only shards that uploaded
        index = {
            "name": "xcodec2_indic_evaluation",
            "description": "Evaluation dataset for XCodec2 Indic (500 samples per language)",
            "total_languages": len(uploaded_shards),
            "total_samples": sum(s["samples"] for s in uploaded_shards),
            "total_size_mb": sum(s["size_bytes"] for s in uploaded_shards) / 1024 / 1024,
            "shards": uploaded_shards,
        }

        index_path = tmpdir / "index.json"
        with open(index_path, 'w', encoding='utf-8') as f:
            json.dump(index, f, indent=2)

        index_r2_key = f"{R2_PREFIX}/index.json"
        if upload_to_r2(index_path, index_r2_key, s3_client):
            print(f"  ✅ Index uploaded: {index_r2_key}")
        else:
            failures += 1

    # Final summary
    print("\n" + "=" * 70)
    print("📊 UPLOAD SUMMARY")
    print("=" * 70)

    print(f"\n{'Language':<12} {'Samples':>8} {'Size (MB)':>10} {'R2 Key':<30}")
    print("-" * 70)

    for shard in uploaded_shards:
        size_mb = shard["size_bytes"] / 1024 / 1024
        print(f"{shard['language']:<12} {shard['samples']:>8} {size_mb:>10.2f} {shard['r2_key']:<30}")

    total_samples = sum(shard["samples"] for shard in uploaded_shards)
    total_size = sum(shard["size_bytes"] for shard in uploaded_shards)

    print("-" * 70)
    print(f"{'TOTAL':<12} {total_samples:>8} {total_size / 1024 / 1024:>10.2f}")

    print(f"\n✅ Uploaded {len(uploaded_shards)} shards to R2")
    print(f"   Bucket: {R2_BUCKET}")
    print(f"   Prefix: {R2_PREFIX}/")
    print("\n🔗 Access shards at:")
    for shard in uploaded_shards:
        print(f"   - {R2_BUCKET}/{shard['r2_key']}")

    if failures:
        # Surface partial failure to the shell instead of always exiting 0.
        print(f"\n❌ {failures} step(s) failed; see messages above")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
