#!/usr/bin/env python3
"""
Run inference with trained Hindi Soprano checkpoint and save sample audios.
Reference: official flow in soprano_infernce/ (tts.py, backends/transformers.py).
We use custom generation so that only AUDIO-token hidden states (IDs 3..8002) are
sent to the decoder; the official code assumes the base model emits only audio
tokens after [START]. The decoder output trim matches the official code: keep the
last (L*TOKEN_SIZE - TOKEN_SIZE) samples.
Usage:
  python inference.py --checkpoint-dir /path/to/checkpoints --out-dir /path/to/wavs [--num-samples 5]
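  python inference.py --checkpoint-dir /path/to/checkpoints --sanity-check    # decode ground-truth tokens, no LM generation
  python inference.py --checkpoint-dir /path/to/checkpoints --english-test    # base-model end-to-end check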
"""
import argparse
import json
import numpy as np
import os
import sys

# Audio token ID range (same as train.py: 3..8003 where 8003=stop, 3..8002=codec)
AUDIO_TOKEN_MIN = 3
AUDIO_TOKEN_MAX = 8002
EOS_AUDIO = 8003
TOKEN_SIZE = 2048
SAMPLE_RATE = 32000
BASE_MODEL_ID = "ekwek/Soprano-1.1-80M"


def load_model_and_decoder(checkpoint_dir, device):
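    """Load the fine-tuned LM and tokenizer from checkpoint_dir, plus the
    SopranoDecoder vocoder (local decoder.pth if present, else from the Hub)."""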
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,  # torch_dtype is deprecated in newer transformers in favor of dtype
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    # Decoder from soprano (same as SopranoTTS)
    from soprano.vocos.decoder import SopranoDecoder
    decoder = SopranoDecoder().to(device)
    decoder_path = os.path.join(checkpoint_dir, "decoder.pth")
    if not os.path.isfile(decoder_path):
        from huggingface_hub import hf_hub_download
        decoder_path = hf_hub_download(repo_id="ekwek/Soprano-1.1-80M", filename="decoder.pth")
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    decoder.eval()

    return model, tokenizer, decoder


def generate_audio_token_ids(model, tokenizer, text, device, max_new_tokens=512,
                             temperature=0.3, top_p=0.95, repetition_penalty=1.2):
    """Generate only the sequence of audio token IDs (3..8002); stop at EOS (8003)."""
    import torch
    from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor
    from transformers import TemperatureLogitsWarper, TopPLogitsWarper

    prompt = f"[STOP][TEXT]{text}[START]"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    generated = inputs["input_ids"]
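    # Fall back to the audio stop token if the config defines no eos_token_id.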
    eos_token_id = getattr(model.config, "eos_token_id", None) or EOS_AUDIO
    logits_processor = LogitsProcessorList()
    if repetition_penalty != 1.0:
        logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    logits_warper = LogitsProcessorList()
    if temperature and temperature != 1.0:
        logits_warper.append(TemperatureLogitsWarper(temperature=max(1e-5, temperature)))
    if top_p < 1.0:
        logits_warper.append(TopPLogitsWarper(top_p=top_p))

    audio_token_ids = []
    with torch.no_grad():
        past_key_values = None
        for _ in range(max_new_tokens):
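            # First step runs the full prompt; later steps feed only the newest
            # token and reuse the KV cache.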
            if past_key_values is None:
                outputs = model(generated, use_cache=True, output_hidden_states=True)
            else:
                outputs = model(generated[:, -1:], past_key_values=past_key_values,
                                use_cache=True, output_hidden_states=True)
            past_key_values = outputs.past_key_values
            next_logits = outputs.logits[:, -1, :]
            next_logits = logits_processor(generated, next_logits)
            next_logits = logits_warper(generated, next_logits)
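            # Sample from the temperature/top-p warped distribution.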
            probs = torch.nn.functional.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)
            token_id = next_token.item()
            if token_id == eos_token_id or token_id == EOS_AUDIO:
                break
            if AUDIO_TOKEN_MIN <= token_id <= AUDIO_TOKEN_MAX:
                audio_token_ids.append(token_id)
    return audio_token_ids


def decode_with_base_model_hidden_states(
    audio_token_ids, text, base_model, base_tokenizer, decoder, device
):
    """
    Run base Soprano model in teacher-forced mode on prompt + our token IDs,
    take its hidden states at audio positions, decode to audio.
    The decoder was trained on base-model hidden states, so this should sound correct.
    """
    import torch
    prompt = f"[STOP][TEXT]{text}[START]"
    prompt_ids = base_tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).input_ids.to(device)[0]
    if not audio_token_ids:
        return None
    audio_ids = torch.tensor(audio_token_ids, dtype=torch.long, device=device)
    full_ids = torch.cat([prompt_ids, audio_ids]).unsqueeze(0)
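    # Teacher-forced pass: one forward call over prompt + all audio token IDs.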
    with torch.no_grad():
        out = base_model(full_ids, output_hidden_states=True)
    # Hidden states at positions corresponding to audio tokens (after prompt)
    start = prompt_ids.size(0)
    end = start + len(audio_token_ids)
    hidden = out.hidden_states[-1][0, start:end, :].float()  # (L, 512)
    L = hidden.size(0)
    hidden = hidden.unsqueeze(0).transpose(1, 2).to(device)  # (1, 512, L)
    with torch.no_grad():
        audio = decoder(hidden)
    raw = audio[0].squeeze().cpu().float()
    # Official trim: keep the last (L*TOKEN_SIZE - TOKEN_SIZE) samples. Guard
    # against L == 1, where raw[-0:] would silently keep everything.
    n_keep = L * TOKEN_SIZE - TOKEN_SIZE
    if 0 < n_keep <= raw.numel():
        raw = raw[-n_keep:]
    raw = raw.numpy()
    # Decoder output scale is arbitrary; peak-normalize so it's audible (training scales to target).
    peak = np.abs(raw).max()
    if peak > 1e-8:
        raw = raw / peak
    return raw


def generate_audio_tokens_only(model, tokenizer, decoder, text, device, max_new_tokens=512,
                               temperature=0.3, top_p=0.95, repetition_penalty=1.2):
    """
    Build prompt [STOP][TEXT]{text}[START], generate, and pass ONLY hidden states
    for tokens in AUDIO_TOKEN_MIN..AUDIO_TOKEN_MAX to the decoder. Stop at EOS (8003).
    """
    import torch
    from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor
    from transformers import TemperatureLogitsWarper, TopPLogitsWarper

    prompt = f"[STOP][TEXT]{text}[START]"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    input_ids = inputs["input_ids"]

    eos_token_id = getattr(model.config, "eos_token_id", None) or EOS_AUDIO
    logits_processor = LogitsProcessorList()
    if repetition_penalty != 1.0:
        logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    logits_warper = LogitsProcessorList()
    if temperature and temperature != 1.0:
        logits_warper.append(TemperatureLogitsWarper(temperature=max(1e-5, temperature)))
    if top_p < 1.0:
        logits_warper.append(TopPLogitsWarper(top_p=top_p))

    audio_hidden_states = []
    with torch.no_grad():
        past_key_values = None
        generated = input_ids
        for _ in range(max_new_tokens):
            if past_key_values is None:
                outputs = model(
                    input_ids=generated,
                    use_cache=True,
                    output_hidden_states=True,
                )
            else:
                outputs = model(
                    input_ids=generated[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True,
                )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            next_token_logits = logits_processor(generated, next_token_logits)
            next_token_logits = logits_warper(generated, next_token_logits)
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)

            token_id = next_token.item()
            if token_id == eos_token_id or token_id == EOS_AUDIO:
                break
            if AUDIO_TOKEN_MIN <= token_id <= AUDIO_TOKEN_MAX:
                # Last layer, last position, batch 0
                h = outputs.hidden_states[-1][0, -1, :].float()
                audio_hidden_states.append(h)

    if not audio_hidden_states:
        return None
    L = len(audio_hidden_states)
    hidden = torch.stack(audio_hidden_states, dim=0).unsqueeze(0).to(device)
    # Decoder expects (B, hidden_dim, T) — same as soprano_infernce/soprano/tts.py
    hidden = hidden.transpose(1, 2)
    with torch.no_grad():
        audio = decoder(hidden)
    raw = audio[0].squeeze().cpu().float()
    # Official trim: last (L*TOKEN_SIZE - TOKEN_SIZE) samples (tts.py line 194).
    # Guard against L == 1, where raw[-0:] would silently keep everything.
    n_keep = L * TOKEN_SIZE - TOKEN_SIZE
    if 0 < n_keep <= raw.numel():
        raw = raw[-n_keep:]
    raw = raw.numpy()
    # Decoder output scale is arbitrary; peak-normalize so it's audible.
    peak = np.abs(raw).max()
    if peak > 1e-8:
        raw = raw / peak
    return raw


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-dir", type=str,
                        default="/home/ubuntu/soprano_data/hindi_unified/checkpoints",
                        help="Path to trained checkpoint (model + tokenizer + decoder.pth)")
    parser.add_argument("--out-dir", type=str,
                        default="/home/ubuntu/soprano_data/hindi_unified/inference_output",
                        help="Directory to save generated WAVs")
    parser.add_argument("--val-json", type=str,
                        default="/home/ubuntu/soprano_data/hindi_unified/val.json",
                        help="Val JSON to sample prompts from")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to generate")
    parser.add_argument("--max-new-tokens", type=int, default=512,
                        help="Max audio tokens to generate (default 512)")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--official-preprocess", action="store_true",
                        help="Use clean_text from soprano_infernce/ if available")
    parser.add_argument("--use-base-hidden", action="store_true",
                        help="Use base Soprano model to get hidden states for decoding (fixes noise: our LM predicts tokens, base LM provides decoder-friendly hidden states)")
    parser.add_argument("--sanity-check", action="store_true",
                        help="Use ground-truth token IDs from val.json + base model hidden states (no LM). If this is still noise, base tokenizer/context is the issue.")
    parser.add_argument("--english-test", action="store_true",
                        help="Base model only: generate and decode for English text. If this is clear, the issue is Hindi->UNK context.")
    parser.add_argument("--english-with-hindi-model", action="store_true",
                        help="Hindi-trained model generates token IDs for English text; base model provides hidden states for decode. Checks that Hindi model still has usable English token predictions.")
    args = parser.parse_args()

    with open(args.val_json, encoding="utf-8") as f:
        val_data = json.load(f)

    if args.english_with_hindi_model:
        # Hindi model generates token IDs; use same Hindi model for hidden states -> decoder.
        # Decoder was trained on Hindi LM hidden states, so we must use Hindi model here (not base).
        english_text = "Hello, this is a test."
        print("English with Hindi model: Hindi LM generates tokens and provides hidden states for decode (matches decoder training)")
        os.makedirs(args.out_dir, exist_ok=True)
        import torch
        import soundfile as sf
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Hindi model from {args.checkpoint_dir} ...")
        model, tokenizer, decoder = load_model_and_decoder(args.checkpoint_dir, device)
        audio_token_ids = generate_audio_token_ids(
            model, tokenizer, english_text, device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print(f"  Hindi model generated {len(audio_token_ids)} audio tokens for English prompt")
        audio = decode_with_base_model_hidden_states(
            audio_token_ids, english_text, model, tokenizer, decoder, device
        )
        out_path = os.path.join(args.out_dir, "english_hindi_model.wav")
        if audio is not None:
            audio = np.clip(audio, -1.0, 1.0)
            sf.write(out_path, (audio * 32767).astype(np.int16), SAMPLE_RATE, subtype="PCM_16")
            print(f"Saved {out_path} ({len(audio)/SAMPLE_RATE:.2f}s)")
        else:
            print("Decode returned None")
        return

    if args.english_test:
        # Base model generates for English; then we decode with base hidden states. Proves the stack works.
        english_text = "Hello, this is a test."
        print("English test: base model generate + decode (no Hindi, no fine-tuned LM)")
        os.makedirs(args.out_dir, exist_ok=True)
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import soundfile as sf
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _, _, decoder = load_model_and_decoder(args.checkpoint_dir, device)
        print(f"Loading base model {BASE_MODEL_ID} ...")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
        ).to(device)
        base_model.eval()
        base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
        if base_tokenizer.pad_token is None:
            base_tokenizer.pad_token = base_tokenizer.unk_token
        audio_token_ids = generate_audio_token_ids(
            base_model, base_tokenizer, english_text, device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print(f"  Base model generated {len(audio_token_ids)} audio tokens")
        audio = decode_with_base_model_hidden_states(
            audio_token_ids, english_text, base_model, base_tokenizer, decoder, device
        )
        out_path = os.path.join(args.out_dir, "english_test.wav")
        if audio is not None:
            audio = np.clip(audio, -1.0, 1.0)
            sf.write(out_path, (audio * 32767).astype(np.int16), SAMPLE_RATE, subtype="PCM_16")
            print(f"Saved {out_path} ({len(audio)/SAMPLE_RATE:.2f}s)")
        else:
            print("Decode returned None")
        return

    if args.sanity_check:
        # Ground-truth tokens + Hindi model hidden states -> decoder (matches decoder training).
        if not val_data:
            print("val.json is empty, cannot run sanity check")
            sys.exit(1)
        text = val_data[0][0].strip()
        gt_token_ids = val_data[0][1]
        if len(gt_token_ids) > 2000:
            gt_token_ids = gt_token_ids[:2000]  # cap length
        print("Sanity check: Hindi model hidden states + decoder on ground-truth tokens (no LM generation)")
        os.makedirs(args.out_dir, exist_ok=True)
        import torch
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Hindi model from {args.checkpoint_dir} ...")
        model, tokenizer, decoder = load_model_and_decoder(args.checkpoint_dir, device)
        audio = decode_with_base_model_hidden_states(
            gt_token_ids, text, model, tokenizer, decoder, device
        )
        out_path = os.path.join(args.out_dir, "sanity_check.wav")
        if audio is not None:
            import soundfile as sf
            audio = np.clip(audio, -1.0, 1.0)
            sf.write(out_path, (audio * 32767).astype(np.int16), SAMPLE_RATE, subtype="PCM_16")
            print(f"Saved {out_path} ({len(audio)/SAMPLE_RATE:.2f}s)")
        else:
            print("Decode returned None")
        return

    # Pick evenly spaced validation prompts, honoring --num-samples (capped at the dataset size).
    n = min(args.num_samples, len(val_data))
    step = max(1, (len(val_data) - 1) // max(1, n - 1))
    indices = list(range(0, len(val_data), step))[:n]
    raw_prompts = [val_data[i][0] for i in indices]
    if args.official_preprocess:
        official_path = os.path.join(os.path.dirname(__file__) or ".", "..", "soprano_infernce")
        if os.path.isdir(official_path) and official_path not in sys.path:
            sys.path.insert(0, official_path)
        try:
            from soprano.utils.text_normalizer import clean_text
            prompts = [clean_text(t.strip()) for t in raw_prompts]
        except Exception as e:
            print(f"Warning: official clean_text unavailable ({e}); falling back to raw prompts.")
            prompts = [t.strip() for t in raw_prompts]
    else:
        prompts = [t.strip() for t in raw_prompts]

    os.makedirs(args.out_dir, exist_ok=True)
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model from {args.checkpoint_dir} ...")
    model, tokenizer, decoder = load_model_and_decoder(args.checkpoint_dir, device)
    base_model = base_tokenizer = None
    if args.use_base_hidden:
        print(f"Loading base model {BASE_MODEL_ID} for decoder-friendly hidden states ...")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
        ).to(device)
        base_model.eval()
        base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
        if base_tokenizer.pad_token is None:
            base_tokenizer.pad_token = base_tokenizer.unk_token
    mode = "base-hidden decoding" if args.use_base_hidden else "audio-token-only decoding"
    print(f"Generating {len(prompts)} samples ({mode}) ...")

    import soundfile as sf
    for i, text in enumerate(prompts):
        if len(text) > 200:
            text = text[:200].rsplit(" ", 1)[0] if " " in text[:200] else text[:200]
        out_path = os.path.join(args.out_dir, f"sample_{i}.wav")
        print(f"  [{i+1}/{len(prompts)}] {text[:50]}... -> {out_path}")
        if args.use_base_hidden:
            audio_token_ids = generate_audio_token_ids(
                model, tokenizer, text, device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            audio = decode_with_base_model_hidden_states(
                audio_token_ids, text, base_model, base_tokenizer, decoder, device
            )
        else:
            audio = generate_audio_tokens_only(
                model, tokenizer, decoder, text, device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
        if audio is None:
            print(f"    Warning: no audio tokens generated, skipping.")
            continue
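        # Scale the float waveform in [-1, 1] to 16-bit PCM for writing.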
        audio = np.clip(audio, -1.0, 1.0)
        sf.write(out_path, (audio * 32767).astype(np.int16), SAMPLE_RATE, subtype="PCM_16")
        dur = len(audio) / SAMPLE_RATE
        print(f"    -> {dur:.2f}s")

    print(f"\nDone. Audios saved to {args.out_dir}")
    for i in range(len(prompts)):
        p = os.path.join(args.out_dir, f"sample_{i}.wav")
        if os.path.isfile(p):
            print(f"  {p} ({os.path.getsize(p)//1024} KB)")


if __name__ == "__main__":
    main()
