"""Fast parallel EnCodec encoding from downloaded parquet files."""
import os, json, torch, torchaudio, time, sys
import pyarrow.parquet as pq
import numpy as np
from encodec import EncodecModel
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp


def encode_shard(shard_path, out_dir, start_idx, shard_id):
    """Encode one parquet shard to EnCodec token files on GPU 0.

    Args:
        shard_path: parquet file with columns ``text``, ``speaker_id``,
            ``duration_s``, and ``audio`` (a dict holding ``array`` and
            ``sampling_rate``); ``language`` is optional.
        out_dir: directory that contains (or will contain) a ``tokens``
            subdirectory where ``{idx:06d}.pt`` tensors are written.
        start_idx: first numeric index used for token filenames.
        shard_id: shard number, used only for logging.

    Returns:
        List of manifest dicts (token_path, text, speaker_id, duration,
        n_tokens, n_rvq, language) for the rows that were encoded.
    """
    torch.cuda.set_device(0)
    encodec = EncodecModel.encodec_model_24khz().cuda()
    encodec.set_target_bandwidth(6.0)
    encodec.eval()

    tok_dir = Path(out_dir) / "tokens"
    table = pq.read_table(shard_path)

    results = []
    idx = start_idx

    for i in range(len(table)):
        try:
            row = table.slice(i, 1).to_pydict()
            text = row["text"][0]
            spk = row["speaker_id"][0]
            dur = row["duration_s"][0]

            # Skip empty transcripts and clips outside the 1-30 s range.
            if not text or dur < 1.0 or dur > 30.0:
                continue

            # Decode audio stored inline in the parquet row.
            audio_dict = row["audio"][0]
            wav = torch.tensor(np.array(audio_dict["array"]), dtype=torch.float32)
            audio_sr = audio_dict["sampling_rate"]

            # The 24 kHz EnCodec model requires 24 kHz input.
            if audio_sr != 24000:
                wav = torchaudio.functional.resample(wav, audio_sr, 24000)

            with torch.no_grad():
                # encode() -> [(codes, scale)]; codes is (1, n_rvq, n_tokens).
                codes = encodec.encode(wav.unsqueeze(0).unsqueeze(0).cuda())[0][0].squeeze(0).cpu()

            tok_path = tok_dir / f"{idx:06d}.pt"
            torch.save(codes, str(tok_path))

            results.append({
                "token_path": str(tok_path),
                "text": text,
                "speaker_id": spk,
                "duration": float(dur),
                "n_tokens": int(codes.shape[1]),
                "n_rvq": int(codes.shape[0]),
                "language": row.get("language", ["hi"])[0],
            })
            idx += 1

        except Exception as e:
            # Best-effort: skip the bad row, but say why instead of hiding it —
            # a systematic error (e.g. a renamed column) would otherwise
            # silently drop the whole shard.
            print(f"  Shard {shard_id} row {i}: skipped ({e})", flush=True)
            continue

    print(f"  Shard {shard_id}: {len(results)} samples encoded", flush=True)
    return results


def _encode_row(encodec, row):
    """Encode one parquet row's inline audio to EnCodec codes.

    Returns a (n_rvq, n_tokens) code tensor on CPU.
    """
    audio_dict = row["audio"][0]
    wav = torch.tensor(np.array(audio_dict["array"]), dtype=torch.float32)
    audio_sr = audio_dict["sampling_rate"]
    # The 24 kHz EnCodec model requires 24 kHz input.
    if audio_sr != 24000:
        wav = torchaudio.functional.resample(wav, audio_sr, 24000)
    with torch.no_grad():
        # encode() -> [(codes, scale)]; codes is (1, n_rvq, n_tokens).
        return encodec.encode(wav.unsqueeze(0).unsqueeze(0).cuda())[0][0].squeeze(0).cpu()


def _assign_speaker_ids(all_results):
    """Assign dense numeric speaker ids, ordered by total speech time (desc).

    Mutates each entry in place (adds ``speaker_id_num``).

    Returns:
        (seconds_by_speaker, speaker_id_map) where the map goes
        raw speaker_id -> dense int id.
    """
    spk_seconds = {}
    for e in all_results:
        spk_seconds[e["speaker_id"]] = spk_seconds.get(e["speaker_id"], 0) + e["duration"]
    spk_sorted = sorted(spk_seconds, key=lambda s: -spk_seconds[s])
    spk_map = {s: i for i, s in enumerate(spk_sorted)}
    for e in all_results:
        e["speaker_id_num"] = spk_map[e["speaker_id"]]
    return spk_seconds, spk_map


def main():
    """Encode all parquet shards sequentially on one GPU, with crude resume.

    Writes per-sample token tensors to ``tokens/``, checkpoints
    ``manifest.json`` after every shard, and finally writes
    ``speakers.json`` with per-speaker hours and dense ids.
    """
    data_dir = Path("/home/ubuntu/indic_dataset_raw/data")
    out_dir = Path("/home/ubuntu/lewm-tts/processed_data_v7_multi")
    tok_dir = out_dir / "tokens"
    tok_dir.mkdir(parents=True, exist_ok=True)

    shards = sorted(data_dir.glob("train-*.parquet"))
    print(f"Found {len(shards)} shards")

    # Count already-written token files to estimate where we left off.
    existing = len(list(tok_dir.glob("*.pt")))
    print(f"Existing tokens: {existing}")

    all_results = []
    manifest_path = out_dir / "manifest.json"
    if manifest_path.exists():
        with open(manifest_path, encoding="utf-8") as f:
            all_results = json.load(f)
        print(f"Loaded existing manifest: {len(all_results)} entries")

    idx = existing
    t0 = time.time()

    # Crude resume: assume ~4000 samples per shard (189K / 48).
    # NOTE(review): a partially finished shard is re-processed from its first
    # row, which can duplicate manifest entries — confirm this is acceptable
    # before relying on mid-shard resume.
    start_shard = existing // 4000
    print(f"Starting from shard {start_shard}")

    encodec = EncodecModel.encodec_model_24khz().cuda()
    encodec.set_target_bandwidth(6.0)
    encodec.eval()

    for shard_id, shard_path in enumerate(shards):
        if shard_id < start_shard:
            continue

        print(f"Processing shard {shard_id}/{len(shards)}: {shard_path.name}", flush=True)
        table = pq.read_table(shard_path)
        shard_results = []

        for i in range(len(table)):
            try:
                row = table.slice(i, 1).to_pydict()
                text = row["text"][0]
                spk = row["speaker_id"][0]
                dur = row["duration_s"][0]

                # Skip empty transcripts and clips outside the 1-30 s range.
                if not text or dur < 1.0 or dur > 30.0:
                    continue

                codes = _encode_row(encodec, row)
                tok_path = tok_dir / f"{idx:06d}.pt"
                torch.save(codes, str(tok_path))

                shard_results.append({
                    "token_path": str(tok_path),
                    "text": text,
                    "speaker_id": spk,
                    "duration": float(dur),
                    "n_tokens": int(codes.shape[1]),
                    "n_rvq": int(codes.shape[0]),
                })
                idx += 1

            except Exception as e:
                # Best-effort: skip the bad row, but say why instead of
                # hiding it — a systematic error (e.g. a renamed column)
                # would otherwise silently drop whole shards.
                print(f"  Shard {shard_id} row {i}: skipped ({e})", flush=True)
                continue

        all_results.extend(shard_results)
        elapsed = time.time() - t0
        total_h = sum(e["duration"] for e in all_results) / 3600
        print(f"  Shard {shard_id} done: +{len(shard_results)} samples | "
              f"Total: {len(all_results)} samples, {total_h:.1f}h | "
              f"{elapsed:.0f}s elapsed", flush=True)

        # Checkpoint after every shard so a crash loses at most one shard.
        with open(manifest_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, ensure_ascii=False)

        # Free the decoded shard before loading the next one.
        del table

    # Final pass: dense speaker ids ordered by total hours.
    spks, spk_map = _assign_speaker_ids(all_results)
    spk_sorted = sorted(spks, key=lambda s: -spks[s])

    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False)
    with open(out_dir / "speakers.json", "w", encoding="utf-8") as f:
        json.dump({s: {"id": spk_map[s], "hours": round(spks[s] / 3600, 2)} for s in spk_sorted}, f, indent=2)

    print(f"\nDONE: {len(all_results)} samples, {sum(spks.values())/3600:.1f}h, {len(spks)} speakers", flush=True)
    for s in spk_sorted[:15]:
        print(f"  {s}: {spks[s]/3600:.1f}h (id={spk_map[s]})")


if __name__ == "__main__":
    main()
