"""
Wandb audio evaluation callback.
Generates test audio samples at checkpoint saves and logs to wandb.
Uses reference audio tokens + IPA phonemes (same format as training data).
"""

import torch
import wandb
import numpy as np
import soundfile as sf_lib
from transformers import TrainerCallback
from nemo.collections.tts.models import AudioCodecModel
from .ipa import text_to_ipa


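# Special control tokens sit just past the text vocabulary (TL is assumed to
# be the base tokenizer's vocab size); codebook-1 audio tokens start at
# AUDIO_START. END_OF_TEXT is id 2 in the base vocabulary (assumed to be EOS).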
TL = 64400
END_OF_TEXT = 2
START_OF_SPEECH = TL + 1
END_OF_SPEECH = TL + 2
START_OF_HUMAN = TL + 3
END_OF_HUMAN = TL + 4
START_OF_AI = TL + 5
END_OF_AI = TL + 6
AUDIO_START = TL + 10

# Hindi test prompts (English glosses in the trailing comments).
TEST_SENTENCES = [
    "नमस्ते, मैं आपकी कैसे मदद कर सकता हूँ?",  # "Hello, how can I help you?"
    "आज मौसम बहुत अच्छा है, चलो बाहर घूमने चलते हैं।",  # "The weather is very nice today, let's go out for a walk."
    "समय दुनिया की सबसे मूल्यवान चीज़ है जिसे कभी वापस नहीं पाया जा सकता।",  # "Time is the most valuable thing in the world; it can never be regained."
    "शिक्षा हमारे जीवन का सबसे महत्वपूर्ण हिस्सा है।",  # "Education is the most important part of our lives."
    "प्रकृति हमें बहुत कुछ सिखाती है और हमें उसका सम्मान करना चाहिए।",  # "Nature teaches us a great deal, and we should respect it."
    "हमें अपने सभी कार्य सही समय पर पूरे करने की आदत डालनी चाहिए।",  # "We should make a habit of completing all our tasks on time."
]

REFERENCE_AUDIO = {
    "iisc_female": "/home/ubuntu/soprano_data/IISc_SYSPIN_Data/IISc_SYSPINProject_Hindi_Female_Spk001_HC/wav/IISc_SYSPINProject_hi_f_AGRI_00036.wav",
    "iisc_male": "/home/ubuntu/soprano_data/IISc_SYSPIN_Data/IISc_SYSPINProject_Hindi_Male_Spk001_HC/wav/IISc_SYSPINProject_hi_m_AGRI_00041.wav",
    "sarvam_priya": "/home/ubuntu/soprano_data/sarvam_data/wavs/priya_000012.wav",
    "elevenlabs_nikita": "/home/ubuntu/soprano_data/speaker_embeddings/elevenlabs_nikita_ref_22k.wav",
}

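# One reference speaker per test sentence; the first two speakers are reused
# so all six prompts are covered.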
EVAL_SPEAKERS = ["iisc_female", "iisc_male", "sarvam_priya", "elevenlabs_nikita", "iisc_female", "iisc_male"]

REF_FRAMES = 62  # ~5 s of reference audio at the codec's 12.5 frames/s


class AudioEvalCallback(TrainerCallback):
    """Generate and log test audio to wandb at each checkpoint save."""

    def __init__(self, tokenizer, eval_steps=None, temperature=0.7, max_frames=200):
        self.tokenizer = tokenizer
        self.eval_steps = eval_steps  # currently unused; eval runs on every checkpoint save
        self.temperature = temperature
        self.max_frames = max_frames
        self.codec = None
        self._refs_logged = False
        self._ref_tokens_cache = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _load_codec(self):
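        # Load the NeMo codec lazily, on the first checkpoint save, so that
        # constructing the callback stays cheap.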
        if self.codec is None:
            self.codec = AudioCodecModel.from_pretrained(
                "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
            ).eval().to(self.device)

    def _encode_reference(self, wav_path):
        """Encode a wav file into cb1 reference tokens (offset by AUDIO_START)."""
        if wav_path in self._ref_tokens_cache:
            return self._ref_tokens_cache[wav_path]

        audio, sr = sf_lib.read(wav_path, dtype="float32")
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
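        # Crude nearest-neighbor resample to the codec's 22.05 kHz input rate.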
        if sr != 22050:
            ratio = 22050 / sr
            n_out = int(len(audio) * ratio)
            idx = np.clip((np.arange(n_out) / ratio).astype(np.int64), 0, len(audio) - 1)
            audio = audio[idx]

        audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
        audio_len = torch.tensor([audio_t.shape[1]], device=self.device)

        with torch.no_grad():
            codes = self.codec.encode(audio=audio_t, audio_len=audio_len)[0]

        # Keep only the first codebook and cap the reference length.
        cb1 = codes[0, 0].tolist()[:REF_FRAMES]

        ref_tokens = [tok + AUDIO_START for tok in cb1]
        self._ref_tokens_cache[wav_path] = ref_tokens
        return ref_tokens

    def _prepare_input(self, text, ref_tokens):
        """Build input sequence matching training format:
        [ref_cb1] [START_HUMAN] ipa_text [END_TEXT] [END_HUMAN] [START_AI] [START_SPEECH]
        """
        ipa_text = text_to_ipa(text, language="hi")
        text_prompt = f"hi: {ipa_text}"
        text_ids = self.tokenizer.encode(text_prompt, add_special_tokens=True)
        text_ids.append(END_OF_TEXT)

        input_ids = (
            ref_tokens
            + [START_OF_HUMAN]
            + text_ids
            + [END_OF_HUMAN, START_OF_AI, START_OF_SPEECH]
        )
        return torch.tensor([input_ids], dtype=torch.long, device=self.device)

    @torch.no_grad()
    def _generate(self, model, text, speaker_key):
        wav_path = REFERENCE_AUDIO.get(speaker_key)
        if wav_path is None:
            return None

        ref_tokens = self._encode_reference(wav_path)
        input_ids = self._prepare_input(text, ref_tokens)
        cb1, cb2, cb3, cb4 = [], [], [], []

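        # Autoregressive frame loop. The backbone re-encodes the whole prefix
        # each step (no KV cache), which is acceptable for short eval clips.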
        for _ in range(self.max_frames):
            backbone_out = model.model(input_ids=input_ids)
            h = backbone_out.last_hidden_state[:, -1, :]

            lm_logits = model.lm_head(h) / self.temperature
            probs = torch.softmax(lm_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_id = next_token.item()

            if next_id == END_OF_SPEECH:
                break

            if next_id >= AUDIO_START:
                # Audio frame: record cb1, then predict the remaining codebooks
                # hierarchically from the same hidden state: cb2 directly,
                # cb3 conditioned on cb2's embedding, cb4 on the sum of the
                # cb2 and cb3 embeddings.
                cb1.append(next_id - AUDIO_START)

                cb2_pred = torch.argmax(model.cb2_head(h), dim=-1)
                cb2_emb = model.cb2_embed(cb2_pred)

                cb3_in = model.cb3_mlp(torch.cat([h, cb2_emb], dim=-1))
                cb3_pred = torch.argmax(model.cb3_head(cb3_in), dim=-1)
                cb3_emb = model.cb3_embed(cb3_pred)

                cb4_in = model.cb4_mlp(torch.cat([h, cb2_emb + cb3_emb], dim=-1))
                cb4_pred = torch.argmax(model.cb4_head(cb4_in), dim=-1)

                cb2.append(cb2_pred.item())
                cb3.append(cb3_pred.item())
                cb4.append(cb4_pred.item())

            # next_token has shape [1, 1], matching input_ids' 2-D layout.
            input_ids = torch.cat([input_ids, next_token], dim=1)

        if len(cb1) == 0:
            return None

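        # Stack the four codebook streams into [B=1, C=4, T] and decode to audio.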
        codes = torch.tensor([cb1, cb2, cb3, cb4], dtype=torch.long).unsqueeze(0).to(self.device)
        codes_len = torch.tensor([codes.shape[-1]], device=self.device)
        audio, _ = self.codec.decode(tokens=codes, tokens_len=codes_len)
        return audio.detach().cpu().numpy().squeeze()

    def _log_references_once(self):
        """Log reference audio clips to wandb once (first checkpoint only)."""
        if self._refs_logged:
            return
        self._refs_logged = True
        for spk, path in REFERENCE_AUDIO.items():
            try:
                audio, sr = sf_lib.read(path)
                if sr != 22050:
                    ratio = 22050 / sr
                    n_out = int(len(audio) * ratio)
                    idx = np.clip((np.arange(n_out) / ratio).astype(np.int64), 0, len(audio) - 1)
                    audio = audio[idx]
                    sr = 22050
                wandb.log({f"audio/reference_{spk}": wandb.Audio(audio, sample_rate=sr, caption=f"REF: {spk}")})
            except Exception as e:
                print(f"Failed to log reference {spk}: {e}")

    def on_save(self, args, state, control, model=None, **kwargs):
        if not wandb.run:
            return

        try:
            self._load_codec()
            self._log_references_once()
            model.eval()

            try:
                for i, (text, speaker_key) in enumerate(zip(TEST_SENTENCES, EVAL_SPEAKERS)):
                    try:
                        waveform = self._generate(model, text, speaker_key)
                        if waveform is not None:
                            audio_wandb = wandb.Audio(
                                waveform, sample_rate=22050,
                                caption=f"{speaker_key}: {text[:50]}..."
                            )
                            wandb.log({
                                f"audio/sample_{i}_{speaker_key}": audio_wandb,
                                "train/global_step": state.global_step,
                            })
                        else:
                            print(f"Audio eval: sample {i} produced no audio tokens")
                    except Exception as e:
                        print(f"Audio eval failed for sample {i}: {e}")
            finally:
                # Always restore training mode, even if a sample unexpectedly raises.
                model.train()

            print(f"Audio samples logged to wandb (step {state.global_step})")

        except Exception as e:
            print(f"Audio eval callback error: {e}")

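# Usage sketch: wiring the callback into a Hugging Face Trainer. `model`,
# `training_args`, `train_dataset`, and `data_collator` are assumed to come
# from the surrounding training script; only the callback wiring is shown.
#
#     from transformers import Trainer
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#         callbacks=[AudioEvalCallback(tokenizer, temperature=0.7, max_frames=200)],
#     )
#     trainer.train()  # samples are generated and logged on every checkpoint save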