"""
Inference for LeWM TTS v5 (codec-based JEPA with duration predictor).
Text → duration predict → expand text → JEPA predict → proj_out → EnCodec decode → waveform
"""

import torch
import torchaudio
import numpy as np
import argparse
from pathlib import Path

from model_v5 import LeWMTTSv5, length_regulate


class LeWMTTSv5Inference:
    """Inference wrapper for LeWM TTS v5 (codec-based JEPA with duration predictor).

    Pipeline: text bytes -> text encoder -> duration predictor ->
    length-regulated (frame-level) text embeddings -> autoregressive JEPA
    predictor with a KV cache -> ``proj_out`` into the EnCodec continuous
    latent space -> frozen EnCodec decoder -> waveform.

    NOTE(review): the ``/75`` second estimates printed below assume the 24 kHz
    EnCodec model produces 75 latent frames per second — confirm against the
    codec configuration.
    """

    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        """Load the v5 checkpoint and a frozen EnCodec codec onto `device`.

        Falls back to CPU silently when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # weights_only=False because the checkpoint stores a config object next
        # to the state dict; only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        config = ckpt["config"]

        self.model = LeWMTTSv5(config).to(self.device)
        self.model.load_state_dict(ckpt["model"])
        self.model.eval()

        # EnCodec decoder (frozen) — imported lazily so the module can be
        # imported without the `encodec` package installed.
        from encodec import EncodecModel
        self.codec = EncodecModel.encodec_model_24khz()
        self.codec.set_target_bandwidth(6.0)
        self.codec.eval().to(self.device)

        self.config = config
        print(f"v5 model loaded from {checkpoint_path}")

    def text_to_tokens(self, text: str) -> torch.Tensor:
        """Encode `text` as raw UTF-8 byte IDs; returns a LongTensor [1, n_bytes]."""
        tokens = list(text.encode("utf-8"))
        return torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def synthesize(self, text: str, duration_scale: float = 1.0, temperature: float = 0.0):
        """AR synthesis with duration predictor for text alignment.

        Args:
            text: input text (byte-level tokenized).
            duration_scale: multiplier on predicted per-token durations
                (>1.0 slows speech down, <1.0 speeds it up).
            temperature: stddev of Gaussian noise added to each predicted
                embedding; 0.0 is fully deterministic.

        Returns:
            (waveform, sample_rate) — 1-D numpy float array at 24000 Hz.
        """
        text_tokens = self.text_to_tokens(text)
        text_emb = self.model.text_encoder(text_tokens)

        # Predict durations and expand text.
        # clamp(min=1) guarantees every text token covers at least one frame,
        # so total_frames > 0 for any non-empty text.
        durations = self.model.predict_durations(text_emb)
        durations = (durations.float() * duration_scale).round().long().clamp(min=1)
        total_frames = durations.sum().item()
        print(f"Text: {len(text)} chars → {total_frames} frames ({total_frames/75:.2f}s)")

        # Expand text to frame level
        text_emb_expanded = length_regulate(text_emb, durations)  # [1, T_total, d_model]

        # AR generation: each step consumes the previous predicted embedding
        # plus the frame-aligned text embedding, with a KV cache for attention.
        start_emb = self.model.start_emb.to(self.device)
        cache = self.model.init_ar_cache()
        all_embeddings = [start_emb]

        # First step: learned start embedding is the initial AR input.
        text_frame = text_emb_expanded[:, 0:1]  # [1, 1, d_model]
        next_emb, cache = self.model.predict_next_cached(start_emb, text_frame, 0, cache)
        if temperature > 0:
            next_emb = next_emb + torch.randn_like(next_emb) * temperature
        all_embeddings.append(next_emb)

        for step in range(1, total_frames):
            # Get text embedding for this frame; zeros past the expanded text
            # length (defensive — steps are bounded by total_frames, which
            # should equal text_emb_expanded.shape[1]).
            if step < text_emb_expanded.shape[1]:
                text_frame = text_emb_expanded[:, step:step+1]
            else:
                text_frame = torch.zeros(1, 1, text_emb.shape[-1], device=self.device)

            next_emb, cache = self.model.predict_next_cached(next_emb, text_frame, step, cache)
            if temperature > 0:
                next_emb = next_emb + torch.randn_like(next_emb) * temperature
            all_embeddings.append(next_emb)

        # Concat predicted embeddings and project to codec space
        pred_embs = torch.cat(all_embeddings, dim=1)  # [1, T, d_model]
        codec_emb = self.model.proj_out(pred_embs)     # [1, T, 128]
        codec_emb = codec_emb.transpose(1, 2)          # [1, 128, T]

        # EnCodec decode: continuous latents straight into the codec decoder
        # (no quantizer round-trip).
        waveform = self.codec.decoder(codec_emb)
        waveform = waveform.squeeze().cpu().numpy()

        return waveform, 24000

    @torch.no_grad()
    def synthesize_prompted(self, text, prompt_audio_path, prompt_text=None,
                            duration_scale=1.0, temperature=0.0):
        """AR synthesis with audio prompt as prefix for voice/style reference.

        The prompt audio is encoded to codec latents and fed through the
        predictor teacher-forced to build the KV cache; generation then
        continues from the last prompt frame. If `prompt_text` is given it is
        prepended to `text` so durations cover the prompt region too.

        Returns:
            (waveform, sample_rate) — includes the reconstructed prompt
            followed by the generated continuation, at 24000 Hz.
        """
        from encodec.utils import convert_audio

        # Encode prompt audio (resampled to mono 24 kHz for EnCodec).
        wav, sr = torchaudio.load(prompt_audio_path)
        wav = convert_audio(wav, sr, 24000, 1).to(self.device)
        prompt_codec = self.codec.encoder(wav.unsqueeze(0))  # [1, 128, T_prompt]
        T_prompt = prompt_codec.shape[2]
        prompt_z = self.model.proj_in(prompt_codec.transpose(1, 2))  # [1, T_prompt, d_model]

        # Text encoding (full text = prompt_text + synthesis text)
        if prompt_text:
            full_text = prompt_text + text
        else:
            full_text = text
        text_tokens = self.text_to_tokens(full_text)
        text_emb = self.model.text_encoder(text_tokens)

        # Predict durations and expand text to frame level
        durations = self.model.predict_durations(text_emb)
        durations = (durations.float() * duration_scale).round().long().clamp(min=1)
        total_frames = durations.sum().item()
        text_emb_expanded = length_regulate(text_emb, durations)

        # How many new frames to generate beyond the prompt.
        # The total_frames // 2 floor guards against the degenerate case where
        # predicted durations barely exceed (or undershoot) the prompt length.
        gen_frames = max(total_frames - T_prompt, total_frames // 2)
        print(f"Prompt: {T_prompt} frames ({T_prompt/75:.2f}s) | "
              f"Generate: {gen_frames} frames ({gen_frames/75:.2f}s) | "
              f"Total text frames: {total_frames}")

        # Feed prompt through predictor to build KV cache
        cache = self.model.init_ar_cache()

        # Process prompt: start_emb + prompt_z[:-1] as input — standard AR
        # shift, so position t's input is the embedding of frame t-1.
        start_emb = self.model.start_emb.to(self.device)
        prompt_input = torch.cat([start_emb, prompt_z[:, :-1]], dim=1)  # [1, T_prompt, d_model]

        # Feed prompt frames one by one to build cache (teacher-forced; the
        # predictor outputs during this phase are discarded).
        for step in range(T_prompt):
            frame_emb = prompt_input[:, step:step+1]
            if step < text_emb_expanded.shape[1]:
                text_frame = text_emb_expanded[:, step:step+1]
            else:
                text_frame = torch.zeros(1, 1, text_emb.shape[-1], device=self.device)
            next_emb, cache = self.model.predict_next_cached(frame_emb, text_frame, step, cache)

        # AR generation continuing from prompt. The cache now covers positions
        # 0..T_prompt-1, so generation starts at absolute step T_prompt with
        # the *actual* last prompt frame as input.
        all_embeddings = [prompt_z]  # keep prompt embeddings
        cur_emb = prompt_z[:, -1:]   # last prompt frame as input

        for step in range(gen_frames):
            abs_step = T_prompt + step
            if abs_step < text_emb_expanded.shape[1]:
                text_frame = text_emb_expanded[:, abs_step:abs_step+1]
            else:
                text_frame = torch.zeros(1, 1, text_emb.shape[-1], device=self.device)

            next_emb, cache = self.model.predict_next_cached(cur_emb, text_frame, abs_step, cache)
            if temperature > 0:
                next_emb = next_emb + torch.randn_like(next_emb) * temperature
            all_embeddings.append(next_emb)
            cur_emb = next_emb

        # Decode prompt + generated latents back to audio in one pass.
        pred_embs = torch.cat(all_embeddings, dim=1)
        codec_emb = self.model.proj_out(pred_embs).transpose(1, 2)
        waveform = self.codec.decoder(codec_emb).squeeze().cpu().numpy()

        return waveform, 24000

    @torch.no_grad()
    def reconstruct(self, audio_path):
        """Encode audio → proj_in → proj_out → decode. Tests roundtrip quality.

        Measures how much quality the model's projection layers alone lose,
        with no JEPA prediction involved.
        """
        from encodec.utils import convert_audio
        wav, sr = torchaudio.load(audio_path)
        wav = convert_audio(wav, sr, 24000, 1).to(self.device)

        codec_emb = self.codec.encoder(wav.unsqueeze(0))
        z = self.model.proj_in(codec_emb.transpose(1, 2))
        codec_recon = self.model.proj_out(z).transpose(1, 2)
        waveform = self.codec.decoder(codec_recon).squeeze().cpu().numpy()
        return waveform, 24000

    @torch.no_grad()
    def reconstruct_codec_only(self, audio_path):
        """Pure EnCodec roundtrip (no JEPA). Baseline quality.

        Upper bound on achievable quality: encoder latents go straight back
        into the decoder without touching the v5 model.
        """
        from encodec.utils import convert_audio
        wav, sr = torchaudio.load(audio_path)
        wav = convert_audio(wav, sr, 24000, 1).to(self.device)

        codec_emb = self.codec.encoder(wav.unsqueeze(0))
        waveform = self.codec.decoder(codec_emb).squeeze().cpu().numpy()
        return waveform, 24000

    @torch.no_grad()
    def reconstruct_with_prediction(self, audio_path, text):
        """Encode audio → proj_in → teacher-forced predict with text alignment → proj_out → decode.

        One-step-ahead prediction quality check: the predictor sees the true
        previous frame at every position (no error accumulation as in AR
        synthesis), with text spread uniformly over the audio frames.
        """
        from encodec.utils import convert_audio
        from model_v5 import compute_uniform_durations
        wav, sr = torchaudio.load(audio_path)
        wav = convert_audio(wav, sr, 24000, 1).to(self.device)

        codec_emb = self.codec.encoder(wav.unsqueeze(0))  # [1, 128, T]
        T = codec_emb.shape[2]
        z = self.model.proj_in(codec_emb.transpose(1, 2))  # [1, T, d_model]

        text_tokens = self.text_to_tokens(text)
        text_emb = self.model.text_encoder(text_tokens)

        # Compute uniform durations and expand text; rounding in the uniform
        # split can over/undershoot T, so trim or zero-pad to exactly T frames.
        text_len = torch.tensor([text_tokens.shape[1]], device=self.device)
        audio_len = torch.tensor([T], device=self.device)
        durations = compute_uniform_durations(text_len, audio_len)
        text_emb_expanded = length_regulate(text_emb, durations)
        if text_emb_expanded.shape[1] > T:
            text_emb_expanded = text_emb_expanded[:, :T]
        elif text_emb_expanded.shape[1] < T:
            pad = torch.zeros(1, T - text_emb_expanded.shape[1], text_emb.shape[-1],
                            device=self.device)
            text_emb_expanded = torch.cat([text_emb_expanded, pad], dim=1)

        # Teacher forcing: input at position t is the true embedding of frame
        # t-1 (start token at t=0), full sequence processed in one call.
        B = z.shape[0]
        start = self.model.start_emb.expand(B, -1, -1)
        input_emb = torch.cat([start, z[:, :-1]], dim=1)
        predicted = self.model.predictor(input_emb, text_emb_expanded)

        codec_recon = self.model.proj_out(predicted).transpose(1, 2)
        waveform = self.codec.decoder(codec_recon).squeeze().cpu().numpy()
        return waveform, 24000

    def save_audio(self, waveform, sr, output_path):
        """Peak-normalize to 0.95 full scale and write a WAV file.

        Args:
            waveform: 1-D numpy float array.
            sr: sample rate in Hz.
            output_path: destination path for torchaudio.save.
        """
        mx = np.abs(waveform).max()
        if mx > 0:  # skip normalization for all-silence output (avoid 0/0)
            waveform = waveform / mx * 0.95
        wav_tensor = torch.from_numpy(waveform).unsqueeze(0)
        torchaudio.save(output_path, wav_tensor, sr)
        print(f"Saved: {output_path} ({len(waveform)/sr:.2f}s)")


def main():
    """CLI entry point for v5 inference.

    Modes:
      synthesize  — TTS from --text (default).
      prompted    — TTS from --text with --audio as a voice/style prompt
                    (optionally --prompt_text, the prompt's transcript).
      reconstruct — audio → proj_in → proj_out → audio roundtrip.
      codec_only  — pure EnCodec roundtrip baseline.
      predict     — teacher-forced prediction over --audio aligned to --text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--text", default=None)
    parser.add_argument("--audio", default=None, help="Audio path for reconstruction")
    parser.add_argument("--prompt_text", default=None,
                        help="Transcript of the prompt audio (prompted mode)")
    parser.add_argument("--mode",
                        choices=["synthesize", "prompted", "reconstruct",
                                 "codec_only", "predict"],
                        default="synthesize")
    parser.add_argument("--output", default="output_v5.wav")
    parser.add_argument("--duration_scale", type=float, default=1.0)
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    # Validate via parser.error, not assert: asserts are stripped under
    # `python -O`, and parser.error gives a proper usage message + exit code 2.
    if args.mode == "synthesize" and not args.text:
        parser.error("--mode synthesize requires --text")
    if args.mode == "prompted" and not (args.text and args.audio):
        parser.error("--mode prompted requires --text and --audio")
    if args.mode in ("reconstruct", "codec_only") and not args.audio:
        parser.error(f"--mode {args.mode} requires --audio")
    if args.mode == "predict" and not (args.audio and args.text):
        parser.error("--mode predict requires --audio and --text")

    tts = LeWMTTSv5Inference(args.checkpoint)

    if args.mode == "synthesize":
        waveform, sr = tts.synthesize(args.text, args.duration_scale, args.temperature)
    elif args.mode == "prompted":
        # Previously implemented but unreachable from the CLI.
        waveform, sr = tts.synthesize_prompted(args.text, args.audio, args.prompt_text,
                                               args.duration_scale, args.temperature)
    elif args.mode == "reconstruct":
        waveform, sr = tts.reconstruct(args.audio)
    elif args.mode == "codec_only":
        waveform, sr = tts.reconstruct_codec_only(args.audio)
    elif args.mode == "predict":
        waveform, sr = tts.reconstruct_with_prediction(args.audio, args.text)

    tts.save_audio(waveform, sr, args.output)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
