#!/usr/bin/env python3
"""
Fine-tune the Soprano decoder on hidden states from the Hindi fine-tuned LM.
The base decoder was trained on base-model hidden states; after LM fine-tuning,
hidden states are out-of-distribution. This script trains the decoder to map
our LM's hidden states -> waveform using (text, token_ids, wav) from the dataset.

Usage:
  python train_decoder.py --checkpoint-dir /path/to/checkpoints --input-dir /path/to/data --save-dir /path/to/decoder_out
  python train_decoder.py ... --max-samples 2000   # quick test (~10–15 min)
  python train_decoder.py ... --epochs 2            # full run (~3–6 hours for ~70k samples, 2 epochs)

Time estimate (single GPU, ~70k train samples):
  - 1 epoch, batch_size 4: ~2–3 hours
  - 2 epochs: ~4–6 hours
  - --max-samples 2000, 1 epoch: ~10–15 min (sanity check)

Decoder training know-how (when output is still noisy):
  - Decoder needs enough steps: 1 full epoch is a minimum; 2+ full epochs recommended.
  - When resuming, use a lower LR via --resume-lr (e.g. --resume-lr 5e-5) for stable fine-tuning.
  - Run to completion: use --resume-decoder and --start-epoch 2 to finish the second epoch, or
    --resume-decoder with default epochs to do 2 more full epochs from the checkpoint.
  - Inference: use peak-normalized decoder output (already in inference.py) and Hindi model
    hidden states (not base) so train/inference match.
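
  Example resume command (all paths are placeholders):
    python train_decoder.py --checkpoint-dir CKPT --input-dir DATA --save-dir OUT \
      --resume-decoder OUT/decoder_latest.pth --start-epoch 2 --resume-lr 5e-5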
"""
import argparse
import json
import os
import pathlib
import time

import numpy as np
import torch
import torchaudio
from scipy.io import wavfile
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from inference import SAMPLE_RATE, TOKEN_SIZE, load_model_and_decoder

# Decoder training defaults (increase BATCH_SIZE for higher GPU utilization on large GPUs)
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16  # 4 was underusing GPU; 16–32 recommended for A100/80GB
EPOCHS = 2
LR = 1e-4
MAX_TOKEN_LEN = 800  # cap sequence length per sample (memory)
VAL_EVERY = 500
SAVE_EVERY = 2000


def build_decoder_index(input_dir, train_path, val_path):
    """
    Match train.json / val.json (text, tokens) to wav paths using metadata.txt.
    metadata format: filename|transcript, one per line. The first filename seen
    per transcript is used (duplicate transcripts map to the same wav; ideally
    each transcript is unique).
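    Example line ("utt_0001" is a hypothetical filename; it resolves to wavs/utt_0001.wav):
        utt_0001|<Hindi transcript>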
    Returns (train_list, val_list) of (text, token_ids, wav_path).
    """
    meta_path = pathlib.Path(input_dir) / "metadata.txt"
    if not meta_path.is_file():
        raise FileNotFoundError(f"Need {meta_path} (filename|transcript per line)")
    transcript_to_file = {}
    with open(meta_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or "|" not in line:
                continue
            filename, transcript = line.split("|", maxsplit=1)
            t = transcript.strip()
            if t not in transcript_to_file:
                transcript_to_file[t] = filename.strip()

    input_dir = pathlib.Path(input_dir)
    wavs_dir = input_dir / "wavs"

    def add_paths(json_path):
        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)
        out = []
        for item in data:
            text = item[0].strip()
            tokens = item[1]
            fn = transcript_to_file.get(text)
            if fn is None:
                continue
            wav_path = wavs_dir / f"{fn}.wav"
            if not wav_path.is_file():
                continue
            out.append((text, tokens, str(wav_path)))
        return out

    train_list = add_paths(train_path)
    val_list = add_paths(val_path)
    return train_list, val_list


class DecoderDataset(Dataset):
    """Returns (text, token_ids, audio) for decoder training."""

    def __init__(self, items, max_token_len=MAX_TOKEN_LEN, sample_rate=SAMPLE_RATE):
        self.items = items
        self.max_token_len = max_token_len
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text, token_ids, wav_path = self.items[idx]
        # Cap token length
        if len(token_ids) > self.max_token_len:
            token_ids = token_ids[: self.max_token_len]
        # Load audio: normalize integer PCM to [-1, 1], downmix to mono,
        # resample to SAMPLE_RATE. The dtype check must happen on the numpy
        # array, since after .float() the tensor is always float32.
        sr, audio = wavfile.read(wav_path)
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype == np.int32:
            audio = audio.astype(np.float32) / 2147483648.0
        audio = torch.from_numpy(audio).float()
        if audio.dim() == 2:  # (n_samples, n_channels) -> mono
            audio = audio.mean(dim=-1)
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        return text, token_ids, audio.numpy()


def collate_decoder_batch(samples):
    """Collate to padded (texts, token_ids_list, audio_list, lengths)."""
    texts = [s[0] for s in samples]
    token_lists = [s[1] for s in samples]
    audios = [s[2] for s in samples]
    L_list = [len(t) for t in token_lists]
    return texts, token_lists, audios, L_list


def get_hidden_states_for_tokens(model, tokenizer, text, token_ids, device):
    """Teacher-forced forward; return hidden states at audio token positions (L, 512)."""
    prompt = f"[STOP][TEXT]{text}[START]"
    prompt_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).input_ids.to(device)[0]
    audio_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
    full_ids = torch.cat([prompt_ids, audio_ids]).unsqueeze(0)
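    # full_ids = [prompt ids | audio token ids]; teacher forcing yields the
    # hidden states at the audio positions that the decoder consumes.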
    with torch.no_grad():
        out = model(full_ids, output_hidden_states=True)
    start = prompt_ids.size(0)
    end = start + len(token_ids)
    hidden = out.hidden_states[-1][0, start:end, :].float()
    return hidden


def trim_decoder_target(L, audio_np):
    """Target length for L tokens: (L*TOKEN_SIZE - TOKEN_SIZE) samples. Trim/crop audio."""
    n_keep = L * TOKEN_SIZE - TOKEN_SIZE
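    # e.g. L = 4 tokens -> n_keep = 3 * TOKEN_SIZE samples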
    if n_keep <= 0:
        n_keep = 1
    audio = torch.from_numpy(audio_np).float()
    if audio.numel() >= n_keep:
        return audio[-(n_keep):]
    return torch.nn.functional.pad(audio, (0, n_keep - audio.numel()))


def train_step(model, tokenizer, decoder, batch, device, optimizer, criterion):
    texts, token_lists, audios, L_list = batch
    optimizer.zero_grad()
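    # Per-sample backward: gradients are summed (not averaged) across the batch,
    # clipped once below, and applied with a single optimizer.step().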
    total_loss = 0.0
    n = 0
    for i in range(len(texts)):
        text, tokens, audio_np = texts[i], token_lists[i], audios[i]
        L = len(tokens)
        if L < 2:
            continue
        hidden = get_hidden_states_for_tokens(model, tokenizer, text, tokens, device)
        # (1, 512, L)
        hidden = hidden.unsqueeze(0).transpose(1, 2).to(device)
        pred = decoder(hidden)
        raw = pred[0].squeeze().float()
        n_keep = L * TOKEN_SIZE - TOKEN_SIZE
        if raw.numel() >= n_keep:
            raw = raw[-(n_keep):]
        target = trim_decoder_target(L, audio_np).to(device)
        if raw.numel() != target.numel():
            min_len = min(raw.numel(), target.numel())
            raw = raw[:min_len]
            target = target[:min_len]
        # Scale pred to target range (decoder output can be unbounded)
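        # (this makes the L1 loss scale-invariant; inference.py peak-normalizes
        # the decoder output the same way)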
        t_max = target.abs().max().clamp(min=1e-6)
        raw = raw * (t_max / raw.abs().max().clamp(min=1e-6))
        loss = criterion(raw, target)
        loss.backward()
        total_loss += loss.item()
        n += 1
    if n > 0:
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        optimizer.step()
    return total_loss / max(n, 1)


@torch.no_grad()
def val_pass(model, tokenizer, decoder, val_loader, device, criterion):
    decoder.eval()
    total_loss = 0.0
    n = 0
    for batch in val_loader:
        texts, token_lists, audios, _ = batch
        for i in range(len(texts)):
            text, tokens, audio_np = texts[i], token_lists[i], audios[i]
            L = len(tokens)
            if L < 2:
                continue
            hidden = get_hidden_states_for_tokens(model, tokenizer, text, tokens, device)
            hidden = hidden.unsqueeze(0).transpose(1, 2).to(device)
            pred = decoder(hidden)
            raw = pred[0].squeeze().float()
            n_keep = L * TOKEN_SIZE - TOKEN_SIZE
            if raw.numel() >= n_keep:
                raw = raw[-(n_keep):]
            target = trim_decoder_target(L, audio_np).to(device)
            min_len = min(raw.numel(), target.numel())
            if min_len == 0:
                continue
            raw, target = raw[:min_len], target[:min_len]
            # Apply the same peak-scaling as train_step so val loss measures
            # the objective actually being optimized
            t_max = target.abs().max().clamp(min=1e-6)
            raw = raw * (t_max / raw.abs().max().clamp(min=1e-6))
            total_loss += criterion(raw, target).item()
            n += 1
    decoder.train()
    return total_loss / max(n, 1)


def main():
    parser = argparse.ArgumentParser(description="Fine-tune decoder on Hindi LM hidden states")
    parser.add_argument("--checkpoint-dir", type=str, required=True,
                        help="Path to Hindi checkpoint (model, tokenizer, decoder.pth)")
    parser.add_argument("--input-dir", type=str, required=True,
                        help="Dataset dir with metadata.txt, wavs/, train.json, val.json")
    parser.add_argument("--save-dir", type=str, required=True,
                        help="Where to save fine-tuned decoder.pth")
    parser.add_argument("--epochs", type=int, default=EPOCHS)
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LR)
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Cap train/val size (for quick test)")
    parser.add_argument("--val-every", type=int, default=VAL_EVERY)
    parser.add_argument("--save-every", type=int, default=SAVE_EVERY)
    parser.add_argument("--max-token-len", type=int, default=MAX_TOKEN_LEN)
    parser.add_argument("--resume-decoder", type=str, default=None,
                        help="Resume from this decoder checkpoint (e.g. save-dir/decoder_latest.pth)")
    parser.add_argument("--start-epoch", type=int, default=1,
                        help="Start from this epoch (1-based). Use with --resume-decoder to skip completed epochs (e.g. --start-epoch 2)")
    parser.add_argument("--resume-lr", type=float, default=None,
                        help="When resuming, use this LR instead of --lr (e.g. 5e-5 for gentler fine-tuning)")
    args = parser.parse_args()

    device = torch.device(DEVICE)
    train_path = pathlib.Path(args.input_dir) / "train.json"
    val_path = pathlib.Path(args.input_dir) / "val.json"
    if not train_path.is_file() or not val_path.is_file():
        raise FileNotFoundError(f"Need {train_path} and {val_path}")

    print("Building decoder dataset index (matching transcripts to wavs)...")
    train_list, val_list = build_decoder_index(args.input_dir, train_path, val_path)
    print(f"  Train samples with wav: {len(train_list)}, val: {len(val_list)}")
    if not train_list or not val_list:
        raise RuntimeError("No (text, tokens, wav) matches. Check metadata.txt and wavs/.")

    if args.max_samples is not None:
        train_list = train_list[: args.max_samples]
        val_list = val_list[: min(len(val_list), max(100, args.max_samples // 10))]
        print(f"  Capped to train={len(train_list)}, val={len(val_list)}")

    train_ds = DecoderDataset(train_list, max_token_len=args.max_token_len)
    val_ds = DecoderDataset(val_list, max_token_len=args.max_token_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_decoder_batch,
        pin_memory=(device.type == "cuda"),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=collate_decoder_batch,
        pin_memory=(device.type == "cuda"),
    )

    print("Loading Hindi model and decoder...")
    model, tokenizer, decoder = load_model_and_decoder(args.checkpoint_dir, device)
    if args.resume_decoder and os.path.isfile(args.resume_decoder):
        print(f"Resuming decoder from {args.resume_decoder}")
        decoder.load_state_dict(torch.load(args.resume_decoder, map_location=device))
        if args.resume_lr is not None:
            args.lr = args.resume_lr
            print(f"Using resume LR: {args.lr}")
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
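    # The LM is frozen and only supplies hidden states; only the decoder trains.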
    decoder.train()
    decoder.to(device)

    optimizer = torch.optim.AdamW(decoder.parameters(), lr=args.lr)
    criterion = torch.nn.L1Loss()

    os.makedirs(args.save_dir, exist_ok=True)
    global_step = 0
    best_val = float("inf")
    if args.start_epoch > 1:
        global_step = (args.start_epoch - 1) * len(train_loader)
        print(f"Starting from epoch {args.start_epoch} (global_step ~{global_step})")

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * args.epochs
    print(f"Decoder fine-tuning: {args.epochs} epochs, ~{steps_per_epoch} steps/epoch, ~{total_steps} steps total")
    print("Rough time: 1 epoch ~2–3 h (70k samples, batch 4); --max-samples 2000 ~10–15 min")
    print("Starting training...")

    for epoch in range(args.start_epoch - 1, args.epochs):
        epoch_loss = 0.0
        epoch_count = 0
        t0 = time.perf_counter()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for batch in pbar:
            loss = train_step(model, tokenizer, decoder, batch, device, optimizer, criterion)
            epoch_loss += loss
            epoch_count += 1
            global_step += 1
            pbar.set_postfix(loss=f"{loss:.4f}")
            if global_step % args.val_every == 0:
                v = val_pass(model, tokenizer, decoder, val_loader, device, criterion)
                tqdm.write(f"  step {global_step} val_loss={v:.4f}")
                if v < best_val:
                    best_val = v
                    torch.save(decoder.state_dict(), os.path.join(args.save_dir, "decoder_best.pth"))
            if global_step % args.save_every == 0:
                torch.save(decoder.state_dict(), os.path.join(args.save_dir, "decoder_latest.pth"))
        elapsed = time.perf_counter() - t0
        avg = epoch_loss / max(epoch_count, 1)
        print(f"Epoch {epoch+1} done in {elapsed/60:.1f} min, avg train loss={avg:.4f}")

    torch.save(decoder.state_dict(), os.path.join(args.save_dir, "decoder.pth"))
    print(f"Saved decoder to {args.save_dir}/decoder.pth. Use this in inference.py (copy to checkpoint-dir or set decoder path).")


if __name__ == "__main__":
    main()
