"""Encode ALL parquet shards to EnCodec tokens."""
import json, torch, torchaudio, time, io
import pyarrow.parquet as pq
from encodec import EncodecModel
from pathlib import Path

def decode_audio_bytes(audio_bytes, target_sr=24000):
    """Decode wav bytes from parquet → tensor."""
    buf = io.BytesIO(audio_bytes)
    wav, sr = torchaudio.load(buf)
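    # downmix multi-channel audio to mono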
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
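    # resample to EnCodec's expected 24 kHz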
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav.squeeze(0)  # [T]

def main():
    data_dir = Path("/home/ubuntu/indic_dataset_raw/data")
    out_dir = Path("/home/ubuntu/lewm-tts/processed_data_v7_multi")
    tok_dir = out_dir / "tokens"
    tok_dir.mkdir(parents=True, exist_ok=True)

    shards = sorted(data_dir.glob("train-*.parquet"))
    print(f"{len(shards)} shards", flush=True)

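    # EnCodec 24 kHz model; at 6 kbps it emits 8 RVQ codebooks at 75 frames/s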
    encodec = EncodecModel.encodec_model_24khz().cuda()
    encodec.set_target_bandwidth(6.0)
    encodec.eval()

    manifest = []
    idx = 0
    t0 = time.time()
    errors = 0

    for sid, shard_path in enumerate(shards):
        table = pq.read_table(shard_path)
        shard_count = 0

        for i in range(len(table)):
            try:
                row = table.slice(i, 1).to_pydict()
                text = row["text"][0]
                spk = row["speaker_id"][0]
                dur = row["duration_s"][0]
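                # skip empty transcripts and clips shorter than 1 s or longer than 30 s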
                if not text or dur < 1.0 or dur > 30.0:
                    continue

                audio_dict = row["audio"][0]
                wav = decode_audio_bytes(audio_dict["bytes"])

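                # encode() returns a list of (codes, scale) frames; the non-streamed
                # 24 kHz model yields a single frame whose codes tensor is [1, n_q, T]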
                with torch.no_grad():
                    encoded_frames = encodec.encode(wav.unsqueeze(0).unsqueeze(0).cuda())
                codes = encoded_frames[0][0].squeeze(0).cpu()  # [n_q, T]

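                # one token file per utterance, named by its global sample index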
                tok_path = tok_dir / f"{idx:06d}.pt"
                torch.save(codes, str(tok_path))

                manifest.append({
                    "token_path": str(tok_path), "text": text,
                    "speaker_id": spk, "duration": float(dur),
                    "n_tokens": int(codes.shape[1]), "n_rvq": int(codes.shape[0]),
                })
                idx += 1
                shard_count += 1
            except Exception as e:
                errors += 1
                if errors <= 5:
                    print(f"  Error: {e}", flush=True)
                continue

        elapsed = time.time() - t0
        total_h = sum(e["duration"] for e in manifest) / 3600
        rate = idx / max(elapsed, 1)
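        # 189327 is presumably the expected total sample count; used only for the ETA estimate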
        remaining = (189327 - idx) / max(rate, 1)
        print(f"Shard {sid+1}/{len(shards)} | +{shard_count} | "
              f"Total: {idx} samples, {total_h:.1f}h | "
              f"{rate:.0f} samp/s | ETA: {remaining/60:.0f}min | err: {errors}", flush=True)

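        # checkpoint the manifest after every shard so partial progress survives a crash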
        with open(out_dir / "manifest.json", "w", encoding="utf-8") as f:
            json.dump(manifest, f, ensure_ascii=False)
        del table

    # Speaker map: rank speakers by total speech duration and assign dense integer ids (0 = most data)
    spks = {}
    for e in manifest:
        spks[e["speaker_id"]] = spks.get(e["speaker_id"], 0) + e["duration"]
    spk_sorted = sorted(spks.keys(), key=lambda s: -spks[s])
    spk_map = {s: i for i, s in enumerate(spk_sorted)}
    for e in manifest:
        e["speaker_id_num"] = spk_map[e["speaker_id"]]

    with open(out_dir / "manifest.json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False)
    with open(out_dir / "speakers.json", "w") as f:
        json.dump({s: {"id": spk_map[s], "hours": round(spks[s]/3600, 2)} for s in spk_sorted}, f, indent=2)

    print(f"\nDONE: {len(manifest)} samples, {sum(spks.values())/3600:.1f}h, {len(spks)} speakers | errors: {errors}", flush=True)

if __name__ == "__main__":
    main()
