"""
Fast parallel preprocessing: HuggingFace dataset → EnCodec 8-level tokens.
Uses multiprocessing for audio loading + GPU for EnCodec encoding.
"""
import argparse
import json
import os
import time
from pathlib import Path

import torch
import torchaudio
from datasets import load_dataset
from encodec import EncodecModel


def process_batch_gpu(batch_wavs, batch_meta, encodec_model, tok_dir, start_idx):
    """Encode a batch of waveforms on GPU, save tokens."""
    results = []
    for i, (wav, meta) in enumerate(zip(batch_wavs, batch_meta)):
        idx = start_idx + i
        try:
            with torch.no_grad():
                wav_tensor = wav.unsqueeze(0).unsqueeze(0).cuda()  # [1, 1, T]
                encoded = encodec_model.encode(wav_tensor)
                codes = encoded[0][0].squeeze(0).cpu()  # [n_rvq, T_codec]

            tok_path = tok_dir / f"{idx:06d}.pt"
            torch.save(codes, str(tok_path))

            results.append({
                "token_path": str(tok_path),
                "text": meta["text"],
                "speaker_id": meta["speaker_id"],
                "duration": meta["duration"],
                "n_tokens": codes.shape[1],
                "n_rvq": codes.shape[0],
                "language": meta.get("language", "hi"),
            })
        except Exception as e:
            # Skip corrupt samples, but leave a trace instead of failing silently.
            print(f"  [warn] sample {idx} skipped: {e}")
    return results
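
# The per-sample tensors written above can be read back on the training side
# roughly like this (a minimal sketch; the training dataset class is not part
# of this script):
#   codes = torch.load("tokens/000000.pt")  # LongTensor, shape [n_rvq, T_codec]
#   assert codes.shape[0] == 8              # 8 RVQ levels at 6.0 kbps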


def main(args):
    print("Loading EnCodec...")
    encodec = EncodecModel.encodec_model_24khz().cuda()
    encodec.set_target_bandwidth(6.0)  # 8 RVQ levels
    encodec.eval()

    out_dir = Path(args.output_dir)
    tok_dir = out_dir / "tokens"
    tok_dir.mkdir(parents=True, exist_ok=True)

    print("Loading dataset (streaming)...")
    ds = load_dataset(
        args.dataset, token=args.token, streaming=True, split="train"
    )

    # Per-speaker totals are only known after the full streaming pass, so the
    # --min_speaker_hours filter is applied when speaker IDs are assigned below.
    min_hours = args.min_speaker_hours
    target_sr = 24000

    manifest = []
    speaker_dur = {}
    idx = 0
    batch_wavs = []
    batch_meta = []
    batch_size = args.batch_size
    t0 = time.time()
    skipped = 0

    for sample in ds:
        spk = sample.get("speaker_id", "unknown")
        dur = sample.get("duration_s", 0)
        text = sample.get("text", "")

        if not text or dur < 1.0 or dur > 30.0:
            skipped += 1
            continue

        # Load and resample audio
        audio = sample["audio"]
        wav = torch.tensor(audio["array"], dtype=torch.float32)
        sr = audio["sampling_rate"]

        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        batch_wavs.append(wav)
        batch_meta.append({
            "text": text, "speaker_id": spk,
            "duration": dur, "language": sample.get("language", "hi"),
        })

        if len(batch_wavs) >= batch_size:
            results = process_batch_gpu(batch_wavs, batch_meta, encodec, tok_dir, idx)
            manifest.extend(results)
            idx += len(results)
            for r in results:
                s = r["speaker_id"]
                speaker_dur[s] = speaker_dur.get(s, 0) + r["duration"]
            batch_wavs = []
            batch_meta = []

            # idx advances by up to batch_size per batch, so test whether a
            # multiple of 1000 was crossed rather than hit exactly.
            if idx % 1000 < len(results):
                elapsed = time.time() - t0
                rate = idx / elapsed
                print(f"  {idx} samples | {sum(speaker_dur.values())/3600:.1f}h | "
                      f"{rate:.0f} samples/s | {elapsed:.0f}s")

        if args.max_samples and idx >= args.max_samples:
            break

    # Process remaining
    if batch_wavs:
        results = process_batch_gpu(batch_wavs, batch_meta, encodec, tok_dir, idx)
        manifest.extend(results)
        for r in results:
            s = r["speaker_id"]
            speaker_dur[s] = speaker_dur.get(s, 0) + r["duration"]

    # Keep only speakers with at least --min_speaker_hours of audio, then assign
    # numeric IDs ordered by total duration (most data first). Token files of
    # dropped speakers stay on disk but are not referenced by the manifest.
    kept = [s for s in speaker_dur if speaker_dur[s] / 3600 >= min_hours]
    speakers_sorted = sorted(kept, key=lambda s: -speaker_dur[s])
    spk_to_id = {s: i for i, s in enumerate(speakers_sorted)}

    manifest = [e for e in manifest if e["speaker_id"] in spk_to_id]
    for entry in manifest:
        entry["speaker_id_num"] = spk_to_id[entry["speaker_id"]]

    # Save manifest
    with open(out_dir / "manifest.json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False)

    # Save speaker map
    speaker_info = {s: {"id": spk_to_id[s], "hours": speaker_dur[s]/3600}
                    for s in speakers_sorted}
    with open(out_dir / "speakers.json", "w") as f:
        json.dump(speaker_info, f, indent=2)

    total_h = sum(speaker_dur[s] for s in speakers_sorted) / 3600
    print(f"\nDone: {len(manifest)} samples, {total_h:.1f}h, {len(speakers_sorted)} speakers kept")
    print(f"Skipped during streaming (bad text/duration): {skipped}")
    print("Speakers:")
    for s in speakers_sorted[:15]:
        print(f"  {s}: {speaker_dur[s]/3600:.1f}h (id={spk_to_id[s]})")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", default="saichranreddy/internal_indic_h")
    p.add_argument("--token", default="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")
    p.add_argument("--output_dir", default="/home/ubuntu/lewm-tts/processed_data_v7_multi")
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--min_speaker_hours", type=float, default=5.0)
    p.add_argument("--max_samples", type=int, default=0)
    main(p.parse_args())
