#!/usr/bin/env python3
"""Quick local generation from a checkpoint with reference audio tokens."""

import torch
import numpy as np
import soundfile as sf
from transformers import AutoTokenizer
from nemo.collections.tts.models import AudioCodecModel
from safetensors.torch import load_file
from utils.model import ChainedHeadLfm2ForCausalLM
from utils.config import ModelConfig
from utils.ipa import text_to_ipa

# --- Vocabulary layout ----------------------------------------------------
# Control tokens are placed at fixed offsets from TL (presumably the size of
# the text vocabulary -- TODO confirm against the tokenizer config).
TL = 64400
END_OF_TEXT = 2
START_OF_SPEECH, END_OF_SPEECH = TL + 1, TL + 2
START_OF_HUMAN, END_OF_HUMAN = TL + 3, TL + 4
START_OF_AI, END_OF_AI = TL + 5, TL + 6
# Codebook-1 audio codes live in the LM vocabulary as `code + AUDIO_START`.
AUDIO_START = TL + 10

# Reference prompt length in codec frames (12.5 frames/s, so 62 is ~5 s).
# Chosen to match the training distribution: median=60, range 24-93.
DEFAULT_REF_FRAMES = 62

DEVICE = "cuda"


def load_model(checkpoint_path, config_path="configs/model_config.yaml"):
    """Build the chained-head model, load checkpoint weights, and return it with its tokenizer.

    Args:
        checkpoint_path: Directory containing ``model.safetensors`` and the tokenizer
            files read by ``AutoTokenizer.from_pretrained``.
        config_path: Model config YAML passed to ``ModelConfig.from_yaml``. The default
            is the previously hard-coded path, resolved relative to the working directory.

    Returns:
        ``(model, tokenizer)``: the model is cast to bfloat16, moved to ``DEVICE``
        and put in eval mode.
    """
    model_config = ModelConfig.from_yaml(config_path)
    model = ChainedHeadLfm2ForCausalLM.from_pretrained_with_chained_heads(model_config)

    state = load_file(f"{checkpoint_path}/model.safetensors")
    # strict=False so that absent keys do not abort loading; the warning below
    # assumes those are tied weights, which tie_weights() then re-links.
    missing, unexpected = model.load_state_dict(state, strict=False)
    del state  # drop the CPU copy of the weights before moving the model
    if missing:
        print(f"  Warning: {len(missing)} missing keys (tied weights are normal)")
    if unexpected:
        # Unexpected keys are checkpoint tensors that were silently dropped --
        # usually a config/architecture mismatch, so make them visible.
        print(f"  Warning: {len(unexpected)} unexpected keys ignored, "
              f"e.g. {list(unexpected)[:5]}")
    model.tie_weights()
    model = model.to(device=DEVICE, dtype=torch.bfloat16).eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    return model, tokenizer


def encode_reference(codec, wav_path, max_ref_frames=DEFAULT_REF_FRAMES):
    """Encode a reference recording into codebook-1 tokens for the LM prompt.

    The file is read as float32, multi-channel audio is averaged to mono, and the
    signal is resampled to 22050 Hz if needed. It is then encoded with ``codec``.
    Only the first codebook of the codes is kept.

    Args:
        codec: Audio codec exposing ``encode(audio=..., audio_len=...)``.
        wav_path: Path to the reference audio file.
        max_ref_frames: Keep at most this many leading frames; ``None`` keeps all.

    Returns:
        List of LM token ids (``code + AUDIO_START``).

    Raises:
        ValueError: If the file contains no samples.
    """
    from math import gcd

    from scipy.signal import resample_poly

    target_sr = 22050
    audio, sr = sf.read(wav_path, dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if audio.size == 0:
        raise ValueError(f"Reference audio is empty: {wav_path}")
    if sr != target_sr:
        # Polyphase resampling low-pass filters the signal. Nearest-sample
        # indexing aliases high frequencies when downsampling 44.1/48 kHz input,
        # which degrades the reference the codec encodes.
        g = gcd(target_sr, int(sr))
        audio = resample_poly(audio, target_sr // g, int(sr) // g).astype(np.float32, copy=False)

    audio_t = torch.from_numpy(np.ascontiguousarray(audio)).float().unsqueeze(0).to(DEVICE)
    audio_len = torch.tensor([audio_t.shape[1]], device=DEVICE)
    with torch.no_grad():
        codes = codec.encode(audio=audio_t, audio_len=audio_len)[0]
    cb1 = codes[0, 0].tolist()
    if max_ref_frames is not None:
        cb1 = cb1[:max_ref_frames]
    print(f"  Ref tokens: {len(cb1)} frames ({len(cb1)/12.5:.1f}s)")
    return [tok + AUDIO_START for tok in cb1]


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Apply a CTRL-style repetition penalty to ``logits`` in place and return it.

    ``logits`` is indexed as ``(1, vocab)``. For each distinct id in ``token_ids``
    that lies inside the vocabulary, a positive logit is divided by ``penalty``
    and a non-positive one is multiplied by it. So ``penalty > 1`` always makes
    the token less likely. This runs as a single gather/scatter rather than a
    per-token Python loop, which would force a device sync per comparison.
    """
    if penalty == 1.0 or token_ids.numel() == 0:
        return logits
    ids = torch.unique(token_ids)
    ids = ids[ids < logits.shape[-1]]
    if ids.numel() == 0:
        return logits
    vals = logits[0, ids]
    logits[0, ids] = torch.where(vals > 0, vals / penalty, vals * penalty)
    return logits


def _top_p_filter(logits, top_p):
    """Return ``logits`` with every token outside the top-p nucleus set to ``-inf``.

    A token is dropped once the probability mass of strictly more likely tokens
    already reaches ``top_p``. The single most likely token is always kept, so
    the result is never entirely ``-inf``.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    remove_sorted = mass_before >= top_p
    remove_sorted[..., 0] = False
    # Map the mask from sorted order back to vocabulary order.
    remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
    return logits.masked_fill(remove, float("-inf"))


def _sample_next_token(logits, temperature, top_p):
    """Pick the next token id, shaped ``(1, 1)``, from ``(1, vocab)`` logits.

    ``temperature <= 0`` selects greedy decoding instead of dividing by zero.
    Otherwise the logits are temperature-scaled and top-p filtered, and one token
    is sampled from the resulting distribution.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    filtered = _top_p_filter(logits / temperature, top_p)
    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _predict_residual_codebooks(model, h):
    """Greedily predict codebooks 2-4 from hidden state ``h`` through the chained heads.

    Each head sees ``h`` plus the embeddings of the codebooks predicted before it.
    Returns the three codes as Python ints.
    """
    cb2_pred = torch.argmax(model.cb2_head(h), dim=-1)
    cb2_emb = model.cb2_embed(cb2_pred)

    cb3_in = model.cb3_mlp(torch.cat([h, cb2_emb], dim=-1))
    cb3_pred = torch.argmax(model.cb3_head(cb3_in), dim=-1)
    cb3_emb = model.cb3_embed(cb3_pred)

    cb4_in = model.cb4_mlp(torch.cat([h, cb2_emb + cb3_emb], dim=-1))
    cb4_pred = torch.argmax(model.cb4_head(cb4_in), dim=-1)

    return cb2_pred.item(), cb3_pred.item(), cb4_pred.item()


@torch.no_grad()
def generate(model, tokenizer, codec, text, ref_wav,
             temperature=0.8, top_p=0.92, repetition_penalty=1.1,
             max_frames=300, ref_frames=DEFAULT_REF_FRAMES,
             language="hi", rep_window=50):
    """Synthesize speech for ``text``, conditioned on a reference recording.

    The prompt is::

        ref_tokens + [START_OF_HUMAN] + text_ids
                   + [END_OF_HUMAN, START_OF_AI, START_OF_SPEECH]

    ``text_ids`` tokenizes ``"{language}: {ipa}"`` and appends ``END_OF_TEXT``.
    The LM samples one token per step. Each audio token (id >= ``AUDIO_START``)
    becomes a codebook-1 code, and codebooks 2-4 are predicted greedily from the
    same hidden state. Generation stops at ``END_OF_SPEECH`` or after
    ``max_frames`` steps.

    Args:
        model: Chained-head LM; uses ``model.model``, ``lm_head`` and the cb2-4 heads.
        tokenizer: Text tokenizer for the IPA prompt.
        codec: Audio codec used to encode the reference and decode the result.
        text: Text to speak.
        ref_wav: Path to the reference recording.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to recently generated tokens (1.0 disables it).
        max_frames: Maximum number of decoding steps.
        ref_frames: Maximum number of reference frames kept in the prompt.
        language: Language code passed to ``text_to_ipa`` and used as the prompt prefix.
        rep_window: How many of the most recent generated tokens the penalty covers.

    Returns:
        The decoded waveform as a squeezed numpy array, or ``None`` if no audio
        tokens were produced.
    """
    ref_tokens = encode_reference(codec, ref_wav, max_ref_frames=ref_frames)

    ipa_text = text_to_ipa(text, language=language)
    print(f"  IPA: {ipa_text}")

    text_prompt = f"{language}: {ipa_text}"
    text_ids = tokenizer.encode(text_prompt, add_special_tokens=True)
    text_ids.append(END_OF_TEXT)

    input_ids_list = (
        ref_tokens
        + [START_OF_HUMAN] + text_ids + [END_OF_HUMAN, START_OF_AI, START_OF_SPEECH]
    )
    input_ids = torch.tensor([input_ids_list], dtype=torch.long, device=DEVICE)
    prompt_len = input_ids.shape[1]

    cb1, cb2, cb3, cb4 = [], [], [], []
    print(f"  Input length: {input_ids.shape[1]} tokens ({len(ref_tokens)} ref + {len(text_ids)+4} text/special)")

    for step in range(max_frames):
        # NOTE(review): no KV cache -- the full sequence is re-run every step, so
        # cost grows quadratically with length. Consider past_key_values once the
        # cache API of this LFM2 variant is confirmed.
        h = model.model(input_ids=input_ids).last_hidden_state[:, -1, :]

        # Sample in float32: bf16 keeps only ~3 significant digits, which is too
        # coarse for the cumulative probabilities top-p works with.
        lm_logits = model.lm_head(h).float()

        if repetition_penalty != 1.0:
            # Penalize only tokens generated so far (never the reference/text
            # prompt), limited to the last `rep_window` of them.
            window_start = max(prompt_len, input_ids.shape[1] - rep_window)
            _apply_repetition_penalty(
                lm_logits, input_ids[0, window_start:], repetition_penalty
            )

        next_token = _sample_next_token(lm_logits, temperature, top_p)
        next_id = next_token.item()

        if next_id == END_OF_SPEECH:
            print(f"  EOS at step {step}")
            break

        if next_id >= AUDIO_START:
            cb1.append(next_id - AUDIO_START)
            c2, c3, c4 = _predict_residual_codebooks(model, h)
            cb2.append(c2)
            cb3.append(c3)
            cb4.append(c4)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        if step % 50 == 0 and step > 0:
            print(f"  Step {step}, audio frames: {len(cb1)}")

    if len(cb1) == 0:
        print("  No audio tokens generated!")
        return None

    print(f"  Generated {len(cb1)} frames ({len(cb1)/12.5:.1f}s)")

    codes = torch.tensor([cb1, cb2, cb3, cb4], dtype=torch.long, device=DEVICE).unsqueeze(0)
    codes_len = torch.tensor([codes.shape[-1]], device=DEVICE)
    audio, _ = codec.decode(tokens=codes, tokens_len=codes_len)
    return audio.detach().cpu().numpy().squeeze()


def main():
    """CLI entry point: load model and codec, synthesize one utterance, write a wav.

    Returns:
        Process exit status: 0 on success, 1 if no audio tokens were produced.
    """
    import argparse
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints_nextgen/checkpoint-48438")
    parser.add_argument("--ref-wav", default="/home/ubuntu/soprano_data/IISc_SYSPIN_Data/IISc_SYSPINProject_Hindi_Female_Spk001_HC/wav/IISc_SYSPINProject_hi_f_AGRI_00036.wav")
    parser.add_argument("--text", default="नमस्ते, मैं आपकी कैसे मदद कर सकता हूँ?")
    parser.add_argument("--output", default="output_test.wav")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.92)
    parser.add_argument("--rep-penalty", type=float, default=1.1)
    parser.add_argument("--max-frames", type=int, default=300)
    parser.add_argument("--ref-frames", type=int, default=DEFAULT_REF_FRAMES)
    args = parser.parse_args()

    print("Loading model...")
    model, tokenizer = load_model(args.checkpoint)

    print("Loading NanoCodec...")
    codec = AudioCodecModel.from_pretrained(
        "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    ).eval().to(DEVICE)

    print(f"\nGenerating: \"{args.text}\"")
    print(f"Reference: {args.ref_wav}")
    print(f"Params: temp={args.temperature}, top_p={args.top_p}, rep={args.rep_penalty}, ref_frames={args.ref_frames}")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    waveform = generate(model, tokenizer, codec, args.text, args.ref_wav,
                        temperature=args.temperature, top_p=args.top_p,
                        repetition_penalty=args.rep_penalty,
                        max_frames=args.max_frames, ref_frames=args.ref_frames)
    elapsed = time.perf_counter() - t0

    if waveform is None:
        print("\nGeneration failed - no audio tokens produced")
        return 1  # non-zero so shell scripts/CI can detect the failure

    sample_rate = 22050  # output rate of the 22 kHz NanoCodec loaded above
    sf.write(args.output, waveform, sample_rate)
    dur = len(waveform) / sample_rate
    print(f"\nSaved: {args.output} ({dur:.2f}s audio, generated in {elapsed:.1f}s)")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())