"""Inference for LeWM TTS v2 (DAC-based)."""

import torch
import torchaudio
import dac
import numpy as np
import argparse
from model import LeWMTTS


class LeWMTTSInference:
    """Inference wrapper for LeWM TTS v2: text -> DAC latents -> waveform.

    Loads a trained LeWMTTS checkpoint plus the pretrained 24 kHz DAC codec
    and exposes synthesis and two reconstruction paths (model round-trip vs.
    pure-codec baseline) for quality comparison.
    """

    # Sample rate of the pretrained DAC codec; all audio in/out is at this rate.
    DAC_SAMPLE_RATE = 24000

    def __init__(self, checkpoint_path, device="cuda"):
        """Load the TTS checkpoint and the DAC codec onto `device` (falls back to CPU)."""
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # SECURITY: weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        config = ckpt["config"]
        self.model = LeWMTTS(config).to(self.device)
        self.model.load_state_dict(ckpt["model"])
        self.model.eval()
        self.config = config

        print("Loading DAC 24kHz...")
        model_path = dac.utils.download(model_type='24khz')
        self.dac_model = dac.DAC.load(model_path)
        self.dac_model.eval()
        self.dac_model = self.dac_model.to(self.device)

        # Estimate the typical latent magnitude from one second of noise so the
        # random start frame in synthesize() matches the encoder's statistics.
        with torch.no_grad():
            dummy = torch.randn(1, 1, self.DAC_SAMPLE_RATE).to(self.device)
            dummy = self.dac_model.preprocess(dummy, self.DAC_SAMPLE_RATE)
            z = self.dac_model.encoder(dummy)
            self._start_scale = z.std().item()
        print(f"Model loaded. Start scale: {self._start_scale:.3f}")

    def text_to_tokens(self, text):
        """Encode `text` as UTF-8 bytes -> LongTensor of shape [1, n_bytes] on self.device."""
        tokens = list(text.encode("utf-8"))
        return torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)

    def _load_audio(self, audio_path):
        """Load `audio_path`, mono-mix, resample to 24 kHz, and DAC-preprocess.

        Returns a [1, 1, T] tensor on self.device, ready for the DAC encoder.
        """
        wav, sr = torchaudio.load(audio_path)  # [channels, T]
        if wav.shape[0] > 1:
            # DAC expects a single channel; average stereo/multichannel down.
            wav = wav.mean(dim=0, keepdim=True)
        if sr != self.DAC_SAMPLE_RATE:
            # DAC's preprocess asserts on the codec sample rate, so resample first.
            wav = torchaudio.functional.resample(wav, sr, self.DAC_SAMPLE_RATE)
        wav = wav.unsqueeze(0).to(self.device)  # [1, 1, T]
        return self.dac_model.preprocess(wav, self.DAC_SAMPLE_RATE)

    @torch.no_grad()
    def synthesize(self, text, max_steps=500, temperature=0.0):
        """Autoregressively generate audio for `text`.

        Args:
            text: input string (byte-level tokenized).
            max_steps: hard cap on the number of generated latent frames.
            temperature: stddev of Gaussian noise added to each predicted
                embedding; 0.0 is fully deterministic given the start frame.

        Returns:
            (waveform as 1-D float numpy array, sample rate 24000).
        """
        text_tokens = self.text_to_tokens(text)
        text_emb = self.model.text_encoder(text_tokens)

        # Random start frame scaled to match real encoder latents.
        start_dac = torch.randn(1, 1, 1024, device=self.device) * self._start_scale
        start_emb = self.model.dac_in_proj(start_dac)

        all_embs = [start_emb]
        prev_norms = []

        for step in range(max_steps):
            context = torch.cat(all_embs, dim=1)
            next_emb = self.model.predict_next(context, text_emb)
            if temperature > 0:
                next_emb = next_emb + torch.randn_like(next_emb) * temperature
            all_embs.append(next_emb)

            # Early stop once embedding norms plateau: if the last 20 norms vary
            # by < 1% of their mean (and we are past warm-up), generation has
            # settled into silence/repetition.
            norm = next_emb.norm().item()
            prev_norms.append(norm)
            if len(prev_norms) > 20:
                recent = prev_norms[-20:]
                if np.std(recent) < 0.01 * np.mean(recent) and step > 30:
                    break

        audio_embs = torch.cat(all_embs, dim=1)
        dac_latents = self.model.latents_to_dac(audio_embs)
        dac_latents = dac_latents.transpose(1, 2)  # [1, 1024, T]

        waveform = self.dac_model.decode(dac_latents)
        return waveform.squeeze().cpu().numpy(), self.DAC_SAMPLE_RATE

    @torch.no_grad()
    def reconstruct(self, audio_path):
        """DAC encode -> project through model -> DAC decode.

        Measures how much quality the model's in/out projections lose
        relative to the pure-codec baseline (`reconstruct_direct`).
        """
        wav = self._load_audio(audio_path)

        z = self.dac_model.encoder(wav)  # [1, 1024, T]
        z_in = z.transpose(1, 2)  # [1, T, 1024]
        h = self.model.dac_in_proj(z_in)
        z_out = self.model.latents_to_dac(h)
        z_out = z_out.transpose(1, 2)  # [1, 1024, T]

        waveform = self.dac_model.decode(z_out)
        return waveform.squeeze().cpu().numpy(), self.DAC_SAMPLE_RATE

    @torch.no_grad()
    def reconstruct_direct(self, audio_path):
        """Pure DAC encode -> decode. Baseline quality."""
        wav = self._load_audio(audio_path)
        z = self.dac_model.encoder(wav)
        waveform = self.dac_model.decode(z)
        return waveform.squeeze().cpu().numpy(), self.DAC_SAMPLE_RATE

    def save_audio(self, waveform, sr, output_path):
        """Write a 1-D float numpy waveform to `output_path` at sample rate `sr`."""
        wav_tensor = torch.from_numpy(waveform).unsqueeze(0)
        torchaudio.save(output_path, wav_tensor, sr)
        print(f"Saved: {output_path} ({len(waveform)/sr:.2f}s)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LeWM TTS v2 inference (DAC-based)")
    parser.add_argument("--checkpoint", required=True, help="Path to a trained LeWMTTS checkpoint")
    parser.add_argument("--text", default=None, help="Text to synthesize (synthesize mode)")
    parser.add_argument("--audio", default=None, help="Input audio path (reconstruction modes)")
    parser.add_argument("--mode", choices=["synthesize", "reconstruct", "dac_baseline"],
                        default="synthesize")
    parser.add_argument("--output", default="output.wav")
    parser.add_argument("--max_steps", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of a deep traceback when the
    # mode-specific input is missing.
    if args.mode == "synthesize" and args.text is None:
        parser.error("--text is required when --mode is 'synthesize'")
    if args.mode in ("reconstruct", "dac_baseline") and args.audio is None:
        parser.error(f"--audio is required when --mode is '{args.mode}'")

    tts = LeWMTTSInference(args.checkpoint)
    if args.mode == "synthesize":
        wav, sr = tts.synthesize(args.text, args.max_steps, args.temperature)
    elif args.mode == "reconstruct":
        wav, sr = tts.reconstruct(args.audio)
    elif args.mode == "dac_baseline":
        wav, sr = tts.reconstruct_direct(args.audio)
    tts.save_audio(wav, sr, args.output)
