"""
Inference pipeline for LeWM TTS.
Text → JEPA predictor → mel reconstruction → Vocos vocoder → waveform
"""

import torch
import torchaudio
import numpy as np
import argparse
from pathlib import Path

from model import LeWMTTS


class LeWMTTSInference:
    """Full inference pipeline."""

    def __init__(self, checkpoint_path, device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # Load checkpoint
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        config = ckpt["config"]

        # Build model
        self.model = LeWMTTS(config).to(self.device)
        self.model.load_state_dict(ckpt["model"])
        self.model.eval()

        # Vocos vocoder: decodes 100-bin mel spectrograms directly to waveform
        from vocos import Vocos
        self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        self.vocos.eval()
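        # Note: Vocos is kept on CPU; mels are moved to CPU before decoding
        # (see the vocos_decode call sites below), so no device transfer is
        # done here.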

        self.config = config
        print(f"Model loaded from {checkpoint_path}")

    def vocos_decode(self, mel):
        """Decode mel [B, 100, T] → waveform via Vocos."""
        with torch.no_grad():
            waveform = self.vocos.decode(mel)  # [B, T_audio]
        return waveform

    def text_to_tokens(self, text):
        """Byte-level tokenization: raw UTF-8 bytes of the text as token IDs."""
        tokens = list(text.encode("utf-8"))
        return torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def synthesize(self, text, max_steps=300, temperature=0.0, speaker_id=None):
        """Generate speech from text using AR prediction with KV-cache."""
        text_tokens = self.text_to_tokens(text)
        text_emb = self.model.text_encoder(text_tokens)

        # Speaker conditioning
        spk_cond = None
        if self.model.speaker_embed is not None and speaker_id is not None:
            spk_id = torch.tensor([speaker_id], device=self.device)
            spk_cond = self.model.speaker_embed(spk_id)  # [1, d]
            # FiLM on text embeddings (matches training)
            text_emb = self.model._apply_speaker_film(
                text_emb, spk_cond, self.model.latent_film
            )

        # Use learned start embedding
        start_emb = self.model.start_emb.to(self.device)
        all_embeddings = [start_emb]

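        # Autoregressive rollout: prime the KV-cache with the learned start
        # embedding, then feed each (FiLM-conditioned) prediction back in as
        # the next input until max_steps or the early-stop check below fires.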
        cache = self.model.init_ar_cache()
        next_emb, cache = self.model.predict_next_cached(start_emb, text_emb, 0, cache)
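        # temperature > 0 injects Gaussian noise into the predicted latent for
        # stochastic sampling; temperature == 0 gives a deterministic rollout.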
        if temperature > 0:
            next_emb = next_emb + torch.randn_like(next_emb) * temperature

        # FiLM on predicted latent (matches training where z gets FiLM'd)
        emb_for_decode = self.model._apply_speaker_film(
            next_emb, spk_cond, self.model.latent_film
        ) if spk_cond is not None else next_emb
        all_embeddings.append(emb_for_decode)

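        # Track latent norms so the loop can stop once they plateau (see below).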
        prev_norms = [next_emb.norm().item()]

        for step in range(1, max_steps):
            # Feed FiLM-conditioned embedding as input (matches training input_emb)
            next_emb, cache = self.model.predict_next_cached(
                emb_for_decode, text_emb, step, cache
            )
            if temperature > 0:
                next_emb = next_emb + torch.randn_like(next_emb) * temperature

            emb_for_decode = self.model._apply_speaker_film(
                next_emb, spk_cond, self.model.latent_film
            ) if spk_cond is not None else next_emb
            all_embeddings.append(emb_for_decode)

            # Early stop: past a minimum length, break once the latent norms
            # plateau (std of the last 20 norms below 1% of their mean),
            # taken as a signal that the utterance has ended.
            prev_norms.append(next_emb.norm().item())
            if step > 30 and len(prev_norms) > 20:
                recent = prev_norms[-20:]
                if np.std(recent) < 0.01 * np.mean(recent):
                    break

        audio_embs = torch.cat(all_embeddings, dim=1)
        # Pass speaker conditioning to mel decoder (FiLM-conditioned reconstruction)
        mel = self.model.mel_decoder(audio_embs, spk_cond)
        waveform = self.vocos_decode(mel.cpu())
        return waveform.squeeze().numpy(), 24000  # sample rate fixed by the 24 kHz Vocos checkpoint

    @torch.no_grad()
    def reconstruct(self, mel_path, speaker_id=None):
        """Teacher-forced reconstruction: encode real mel → decode back."""
        mel = torch.load(mel_path, weights_only=True).unsqueeze(0).to(self.device)

        spk_cond = None
        if self.model.speaker_embed is not None and speaker_id is not None:
            spk_id = torch.tensor([speaker_id], device=self.device)
            spk_cond = self.model.speaker_embed(spk_id)

        z, mu, logvar = self.model.audio_encoder(mel)
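        # Decode from the posterior mean mu (not the sampled z) so the
        # reconstruction is deterministic.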
        # FiLM condition the latent before decoding
        if spk_cond is not None:
            mu = self.model._apply_speaker_film(mu, spk_cond, self.model.latent_film)
        mel_recon = self.model.mel_decoder(mu, spk_cond)

        waveform = self.vocos_decode(mel_recon.cpu())
        return waveform.squeeze().numpy(), 24000

    @torch.no_grad()
    def reconstruct_with_prediction(self, mel_path, speaker_id=None):
        """Hybrid: encode real mel → predictor → decode predicted embeddings."""
        mel = torch.load(mel_path, weights_only=True).unsqueeze(0).to(self.device)

        spk_cond = None
        if self.model.speaker_embed is not None and speaker_id is not None:
            spk_id = torch.tensor([speaker_id], device=self.device)
            spk_cond = self.model.speaker_embed(spk_id)

        z, mu, logvar = self.model.audio_encoder(mel)

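        # Placeholder text conditioning: a single zero token. This mode probes
        # the predictor's latent continuation rather than text alignment.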
        text_tokens = torch.zeros(1, 1, dtype=torch.long, device=self.device)
        text_emb = self.model.text_encoder(text_tokens)
        if spk_cond is not None:
            text_emb = self.model._apply_speaker_film(
                text_emb, spk_cond, self.model.latent_film
            )

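        # One-step teacher forcing: the predictor sees the ground-truth latents
        # mu[:, :-1] and predicts each next frame; the encoder's first frame is
        # prepended so the sequence length matches the original.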
        predicted = self.model.predictor(mu[:, :-1], text_emb)
        combined = torch.cat([mu[:, :1], predicted], dim=1)

        if spk_cond is not None:
            combined = self.model._apply_speaker_film(
                combined, spk_cond, self.model.latent_film
            )
        mel_recon = self.model.mel_decoder(combined, spk_cond)
        waveform = self.vocos_decode(mel_recon.cpu())
        return waveform.squeeze().numpy(), 24000

    def save_audio(self, waveform, sr, output_path):
        """Peak-normalize the waveform and write it to a WAV file."""
        # Scale the peak to 0.95 full scale for headroom; skip silent audio.
        mx = np.abs(waveform).max()
        if mx > 0:
            waveform = waveform / mx * 0.95
        wav_tensor = torch.from_numpy(waveform).unsqueeze(0)
        torchaudio.save(output_path, wav_tensor, sr)
        print(f"Saved: {output_path} ({len(waveform)/sr:.2f}s)")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--text", default=None, help="Text to synthesize")
    parser.add_argument("--mel", default=None, help="Mel path for reconstruction test")
    parser.add_argument("--mode", choices=["synthesize", "reconstruct", "predict"],
                        default="synthesize")
    parser.add_argument("--output", default="output.wav")
    parser.add_argument("--max_steps", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--speaker_id", type=int, default=None)
    args = parser.parse_args()

    tts = LeWMTTSInference(args.checkpoint)

    if args.mode == "synthesize":
        assert args.text, "Need --text for synthesis"
        waveform, sr = tts.synthesize(args.text, args.max_steps, args.temperature,
                                      speaker_id=args.speaker_id)
    elif args.mode == "reconstruct":
        assert args.mel, "Need --mel for reconstruction"
        waveform, sr = tts.reconstruct(args.mel, speaker_id=args.speaker_id)
    elif args.mode == "predict":
        assert args.mel, "Need --mel for prediction test"
        waveform, sr = tts.reconstruct_with_prediction(args.mel, speaker_id=args.speaker_id)

    tts.save_audio(waveform, sr, args.output)


if __name__ == "__main__":
    main()
