"""
Preprocess audio files → EnCodec continuous embeddings.
Saves 128-dim embeddings at 75Hz for each audio file.
"""

import os
import json
import torch
import torchaudio
import argparse
from pathlib import Path
from encodec import EncodecModel
from encodec.utils import convert_audio


def _encode_audio(codec, audio_path, device):
    """Load one audio file and encode it with the EnCodec encoder.

    Returns a tuple ``(emb, duration)``:
    - ``emb`` is the continuous encoder output on the CPU, with shape
      ``[128, T]`` for the 24 kHz model.
    - ``duration`` is the length in seconds of the resampled waveform.
    """
    wav, sr = torchaudio.load(audio_path)
    # Resample / remix to the codec's native format (24 kHz mono for the
    # 24 kHz model).
    wav = convert_audio(wav, sr, codec.sample_rate, codec.channels).to(device)

    with torch.no_grad():
        # Continuous embeddings, taken before vector quantization.
        emb = codec.encoder(wav.unsqueeze(0))  # [1, 128, T]

    return emb.squeeze(0).cpu(), wav.shape[-1] / codec.sample_rate


def preprocess(args):
    """Encode every audio file in a manifest into continuous EnCodec embeddings.

    Reads the JSON manifest at ``args.manifest``: a list of entries, each with
    at least ``"audio_path"`` and ``"text"`` keys. Each file is run through the
    24 kHz EnCodec encoder, and the function writes:

    * ``<output_dir>/embeddings/NNNNNN.pt``: one ``[128, T]`` tensor per file.
      ``NNNNNN`` is the entry's index in the input manifest.
    * ``<output_dir>/manifest.json``: a new manifest with ``emb_path``,
      ``text``, ``audio_path``, ``emb_frames`` and ``duration`` per file.

    Entries whose audio file does not exist, or that fail to load or encode,
    are counted as skipped instead of aborting the run. Failures are printed.

    Args:
        args: Namespace with ``manifest`` (input manifest path) and
            ``output_dir`` (output directory) attributes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load EnCodec. Target bandwidth only configures the quantizer, which this
    # script never calls (it reads the encoder output directly).
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)
    codec.eval().to(device)

    # Load original manifest. Explicit UTF-8 so transcripts decode the same
    # way regardless of the platform's locale default.
    with open(args.manifest, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    emb_dir = out_dir / "embeddings"
    emb_dir.mkdir(exist_ok=True)

    new_manifest = []
    skipped = 0
    total = len(manifest)

    for i, entry in enumerate(manifest):
        audio_path = entry["audio_path"]
        if not os.path.exists(audio_path):
            skipped += 1
        else:
            try:
                emb, duration = _encode_audio(codec, audio_path, device)

                # File name keeps the entry's position in the input manifest.
                emb_path = str(emb_dir / f"{i:06d}.pt")
                torch.save(emb, emb_path)

                new_manifest.append({
                    "emb_path": emb_path,
                    "text": entry["text"],
                    "audio_path": audio_path,
                    "emb_frames": emb.shape[1],
                    "duration": duration,
                })
            except Exception as e:
                # Best-effort: report and keep going with the remaining files.
                print(f"  Error on {audio_path}: {e}")
                skipped += 1

        # Report progress for every entry, even ones that were skipped or failed.
        if (i + 1) % 500 == 0:
            print(f"  Processed {i+1}/{total} files")

    # Save new manifest (UTF-8 because ensure_ascii=False emits raw non-ASCII).
    manifest_path = str(out_dir / "manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(new_manifest, f, indent=2, ensure_ascii=False)

    print(f"\nDone: {len(new_manifest)} files processed, {skipped} skipped")
    print(f"Manifest: {manifest_path}")
    print(f"Embeddings: {emb_dir}")

    # Stats. Guard against an empty result, which would divide by zero.
    if not new_manifest:
        print("Total: 0.00h, 0 frames (no files processed)")
        return

    total_frames = sum(e["emb_frames"] for e in new_manifest)
    total_dur = sum(e["duration"] for e in new_manifest)
    print(f"Total: {total_dur/3600:.2f}h, {total_frames} frames, avg {total_frames/len(new_manifest):.0f} frames/sample")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", default="processed_data_100mel/manifest.json")
    parser.add_argument("--output_dir", default="processed_data_codec")
    args = parser.parse_args()
    preprocess(args)
