"""
Production inference engine for LeWM TTS.
Features:
  - Batched inference (many concurrent requests in parallel)
  - KV-cache (O(1) per step instead of O(n))
  - FP16
  - Audio prompt support (for multi-speaker / voice cloning)
  - Per-request early stopping inside a batch (finished requests stop accumulating frames)
  - Vocos vocoder pipeline (direct mel → waveform decode, no Griffin-Lim)
"""

import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
import time
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path

from model import LeWMTTS


# ─── Vocoder ────────────────────────────────────────────────────────────────

class Vocoder:
    """Wrapper around the pretrained Vocos 24 kHz mel vocoder.

    Decodes mel spectrograms directly into waveforms; there is no
    Griffin-Lim stage.
    """

    def __init__(self):
        # Function-scope import: vocos is only imported when a Vocoder is built.
        from vocos import Vocos
        pretrained = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        pretrained.eval()
        self.vocos = pretrained

    def __call__(self, mel):
        """Decode a batch of mels into a list of 1-D numpy waveforms.

        mel: [B, 100, T] tensor.
        Returns: list with B entries, each a [T_audio] numpy array.
        """
        with torch.no_grad():
            # Each item is decoded on its own; the slice keeps a batch dim of 1,
            # which is squeezed away again before converting to numpy.
            return [
                self.vocos.decode(mel[idx:idx + 1]).squeeze(0).numpy()
                for idx in range(mel.shape[0])
            ]


# ─── Request / Result ───────────────────────────────────────────────────────

@dataclass
class TTSRequest:
    """Container for the parameters of one synthesis request.

    NOTE(review): this dataclass is not referenced anywhere else in this
    file; the field meanings below are assumed to mirror the same-named
    arguments of InferenceEngine.synthesize — confirm against callers.
    """
    # Text to synthesize.
    text: str
    # Upper bound on autoregressive steps (same default as synthesize()).
    max_steps: int = 300
    # Sampling temperature; synthesize() documents 0 as greedy.
    temperature: float = 0.0
    # Optional prompt mel for voice cloning, shaped [1, n_mels, T].
    # NOTE(review): the original comment said 80 mel bins, but
    # load_prompt_audio in this file builds 100-bin mels — confirm.
    prompt_mel: Optional[torch.Tensor] = None


@dataclass
class TTSResult:
    """Synthesized audio together with bookkeeping about its generation."""
    waveform: np.ndarray
    sample_rate: int = 24000
    steps_taken: int = 0
    generation_time: float = 0.0

    def save(self, path):
        """Write the waveform to ``path`` with torchaudio as one channel."""
        # torchaudio.save is given a [channels, samples] tensor, so a leading
        # channel axis is added to the 1-D waveform.
        channels_first = torch.from_numpy(self.waveform)[None, ...]
        torchaudio.save(path, channels_first, self.sample_rate)

    @property
    def duration(self):
        """Audio length in seconds (sample count divided by sample rate)."""
        n_samples = len(self.waveform)
        return n_samples / self.sample_rate


# ─── Inference Engine ───────────────────────────────────────────────────────

class InferenceEngine:
    """
    High-throughput batched TTS inference engine.

    Usage:
        engine = InferenceEngine("checkpoint.pt")

        # Single request
        result = engine.synthesize("नमस्ते भारत")

        # Batched (high throughput)
        results = engine.synthesize_batch(["text1", "text2", ...])

        # With voice prompt
        result = engine.synthesize("text", prompt_audio="ref.wav")

        # Benchmark
        engine.benchmark()
    """

    def __init__(self, checkpoint_path, device="cuda", dtype=torch.float16):
        """Load a checkpoint, build the model and vocoder, and print a summary.

        Args:
            checkpoint_path: path passed to torch.load; the loaded object must
                provide "config" and "model" entries.
            device: requested device; "cpu" is used whenever CUDA is unavailable.
            dtype: dtype the model is cast to before the weights are loaded.
        """
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device(device if cuda_ok else "cpu")
        self.dtype = dtype

        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path, map_location=self.device, weights_only=False,
        )
        self.config = checkpoint["config"]

        # Build on the target device/dtype first, then restore the weights.
        self.model = LeWMTTS(self.config).to(self.device).to(dtype)
        self.model.load_state_dict(checkpoint["model"])
        self.model.eval()

        self.d_model = self.config["d_model"]

        # The vocoder is constructed without a device move (it stays on CPU).
        self.vocoder = Vocoder()

        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"Engine ready: {checkpoint_path}")
        print(f"  Device: {self.device}, Dtype: {dtype}")
        print(f"  Model: {n_params/1e6:.1f}M params")

    # ─── Text encoding ──────────────────────────────────────────────────

    def _encode_texts(self, texts):
        """Encode list of texts → padded embeddings + mask.
        Returns: text_emb [B, T_max, d], text_mask [B, T_max] (True=padding)
        """
        # Byte-level tokenization
        token_lists = [list(t.encode("utf-8")) for t in texts]
        max_len = max(len(t) for t in token_lists)

        # Pad to same length
        padded = torch.zeros(len(texts), max_len, dtype=torch.long, device=self.device)
        mask = torch.ones(len(texts), max_len, dtype=torch.bool, device=self.device)

        for i, tokens in enumerate(token_lists):
            padded[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            mask[i, :len(tokens)] = False

        with torch.no_grad():
            text_emb = self.model.text_encoder(padded, text_mask=mask)

        return text_emb, mask

    # ─── Audio prompt encoding ──────────────────────────────────────────

    def _encode_prompt(self, prompt_mel):
        """Encode prompt mel → embeddings for AR seeding.
        prompt_mel: [1, 80, T]
        Returns: prompt_embs [1, T_down, d]
        """
        prompt_mel = prompt_mel.to(self.device).to(self.dtype)
        with torch.no_grad():
            embs = self.model.encode_audio(prompt_mel)  # [1, T_down, d]
        return embs

    def load_prompt_audio(self, audio_path, max_seconds=3.0):
        """Read an audio file and convert it to a log-mel prompt.

        The audio is resampled to 24 kHz when needed, reduced to its first
        channel, and truncated to at most ``max_seconds`` seconds.

        Returns: log-mel tensor [1, 100, T] (100 mel bins, hop length 256).
        """
        target_sr = 24000

        waveform, source_sr = torchaudio.load(audio_path)
        if source_sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, source_sr, target_sr)

        # Keep only the first channel and the leading max_seconds of samples.
        keep = int(max_seconds * target_sr)
        waveform = waveform[:1, :keep]

        to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sr,
            n_fft=1024,
            hop_length=256,
            n_mels=100,
            power=1.0,
        )
        # Clamp before the log so silent bins do not produce -inf.
        return torch.log(to_mel(waveform).clamp(min=1e-7))

    # ─── Core AR generation (batched) ───────────────────────────────────

    def _generate_batch(self, text_emb, text_mask, batch_size,
                        max_steps=300, temperature=0.0,
                        prompt_embs=None):
        """
        Core batched AR generation with KV-cache.

        Args:
            text_emb: [B, T_text, d]
            text_mask: [B, T_text]
            batch_size: int
            max_steps: int
            temperature: float
            prompt_embs: list of [1, T_prompt, d] or None per request

        Returns:
            all_embs: [B, T_total, d]
            steps_taken: list of int per request
        """
        d = self.d_model
        cache = self.model.init_ar_cache()

        # Track per-request state
        active = torch.ones(batch_size, dtype=torch.bool, device=self.device)
        steps_taken = [0] * batch_size
        all_emb_lists = [[] for _ in range(batch_size)]
        prev_norms = [[] for _ in range(batch_size)]

        # Handle prompts: seed cache with prompt embeddings
        if prompt_embs is not None:
            max_prompt_len = max(
                (e.shape[1] for e in prompt_embs if e is not None), default=0
            )
            if max_prompt_len > 0:
                # Process prompt embeddings through cache step by step
                # Pad prompts to same length
                prompt_padded = torch.zeros(
                    batch_size, max_prompt_len, d,
                    device=self.device, dtype=self.dtype,
                )
                prompt_lengths = []
                for i, pe in enumerate(prompt_embs):
                    if pe is not None:
                        L = pe.shape[1]
                        prompt_padded[i, :L] = pe[0]
                        prompt_lengths.append(L)
                    else:
                        prompt_lengths.append(0)

                # Feed prompt through cache one step at a time
                for t in range(max_prompt_len):
                    step_emb = prompt_padded[:, t:t+1, :]  # [B, 1, d]
                    _, cache = self.model.predict_next_cached(
                        step_emb, text_emb, t, cache, text_mask=text_mask,
                    )
                    # Store embeddings
                    for i in range(batch_size):
                        if t < prompt_lengths[i]:
                            all_emb_lists[i].append(step_emb[i:i+1])

                step_offset = max_prompt_len
                # Use last prompt prediction as first AR input
                last_pred, cache = self.model.predict_next_cached(
                    prompt_padded[:, -1:, :], text_emb, step_offset, cache,
                    text_mask=text_mask,
                )
                current_emb = last_pred
                step_offset += 1
            else:
                step_offset = 0
                current_emb = None
        else:
            step_offset = 0
            current_emb = None

        # If no prompt, start with learned start embedding
        if current_emb is None:
            start_emb = self.model.start_emb.expand(batch_size, -1, -1).to(self.dtype)
            current_emb, cache = self.model.predict_next_cached(
                start_emb, text_emb, 0, cache, text_mask=text_mask,
            )
            for i in range(batch_size):
                all_emb_lists[i].append(start_emb[i:i+1])
            step_offset = 1

        if temperature > 0:
            current_emb = current_emb + torch.randn_like(current_emb) * temperature

        for i in range(batch_size):
            all_emb_lists[i].append(current_emb[i:i+1])

        # AR loop
        for step in range(1, max_steps):
            if not active.any():
                break

            next_emb, cache = self.model.predict_next_cached(
                current_emb, text_emb, step + step_offset, cache,
                text_mask=text_mask,
            )

            if temperature > 0:
                next_emb = next_emb + torch.randn_like(next_emb) * temperature

            # Store and check stopping per request
            for i in range(batch_size):
                if not active[i]:
                    continue

                all_emb_lists[i].append(next_emb[i:i+1])
                steps_taken[i] = step

                norm = next_emb[i].norm().item()
                prev_norms[i].append(norm)
                if len(prev_norms[i]) > 20:
                    recent = prev_norms[i][-20:]
                    if np.std(recent) < 0.01 * np.mean(recent) and step > 30:
                        active[i] = False

            current_emb = next_emb

        # Concatenate embeddings per request
        results = []
        for i in range(batch_size):
            embs = torch.cat(all_emb_lists[i], dim=1)  # [1, T, d]
            results.append(embs)

        return results, steps_taken

    # ─── Public API ─────────────────────────────────────────────────────

    def synthesize(self, text, max_steps=300, temperature=0.0,
                   prompt_audio=None, prompt_seconds=3.0):
        """Synthesize one text by delegating to synthesize_batch.

        Args:
            text: text string to synthesize
            max_steps: max AR steps
            temperature: sampling temperature (0 = greedy)
            prompt_audio: path to reference audio for voice cloning; a falsy
                value means no prompt
            prompt_seconds: max seconds of prompt to use

        Returns: TTSResult
        """
        # A batch of one; the prompt list is omitted entirely without a prompt.
        prompts = [prompt_audio] if prompt_audio else None
        batch_results = self.synthesize_batch(
            [text],
            max_steps=max_steps,
            temperature=temperature,
            prompt_audios=prompts,
            prompt_seconds=prompt_seconds,
        )
        return batch_results[0]

    def synthesize_batch(self, texts, max_steps=300, temperature=0.0,
                         prompt_audios=None, prompt_seconds=3.0):
        """Synthesize a batch of texts in parallel.

        Texts are encoded together, generated in one batched AR pass, and then
        decoded to mel and vocoded one request at a time (generated lengths can
        differ per request). Each waveform is peak-normalized to 0.95.

        Args:
            texts: list of text strings
            max_steps: max AR steps
            temperature: sampling temperature
            prompt_audios: list of audio paths (or None) per text; must have
                one entry per text when given
            prompt_seconds: max seconds of prompt

        Returns: list of TTSResult, one per text (empty for empty input).
            ``generation_time`` of each result is the wall time from the start
            of this call until that result finished vocoding.

        Raises:
            ValueError: if ``prompt_audios`` is given and its length differs
                from ``len(texts)``.
        """
        B = len(texts)
        if B == 0:
            # Nothing to do; avoids max() over an empty batch during encoding.
            return []
        if prompt_audios is not None and len(prompt_audios) != B:
            # Fail early with a clear message instead of an IndexError deep
            # inside prompt seeding.
            raise ValueError(
                f"prompt_audios has {len(prompt_audios)} entries "
                f"but {B} texts were given"
            )

        t0 = time.time()

        # Encode texts
        text_emb, text_mask = self._encode_texts(texts)

        # Encode prompts if given (None entries stay None)
        prompt_embs = None
        if prompt_audios is not None:
            prompt_embs = []
            for pa in prompt_audios:
                if pa is None:
                    prompt_embs.append(None)
                    continue
                mel = self.load_prompt_audio(pa, max_seconds=prompt_seconds)
                prompt_embs.append(self._encode_prompt(mel))

        # Generate
        with torch.no_grad():
            emb_results, steps_taken = self._generate_batch(
                text_emb, text_mask, B,
                max_steps=max_steps, temperature=temperature,
                prompt_embs=prompt_embs,
            )

        # Decode to mel → waveform
        results = []
        for i in range(B):
            with torch.no_grad():
                # mel for one request, shaped [1, n_mels, T]
                mel = self.model.mel_decoder(emb_results[i].to(self.device))
            # The vocoder is fed float32 tensors on the CPU.
            waveform = self.vocoder(mel.float().cpu())[0]

            # Peak-normalize; silent or empty output is left untouched.
            mx = np.abs(waveform).max() if waveform.size else 0
            if mx > 0:
                waveform = waveform / mx * 0.95

            results.append(TTSResult(
                waveform=waveform,
                steps_taken=steps_taken[i],
                generation_time=time.time() - t0,
            ))

        return results

    # ─── Benchmark ──────────────────────────────────────────────────────

    def benchmark(self, batch_sizes=None, steps=300):
        """Benchmark throughput at various batch sizes."""
        if batch_sizes is None:
            batch_sizes = [1, 4, 16, 32, 64, 128, 256]

        text = "भारत एक महान देश है और हम सब भारतवासी हैं।"
        audio_per_step = 4 * 256 / 24000  # seconds of audio per AR step

        print(f"\n{'='*70}")
        print(f"  LeWM TTS Inference Benchmark — {self.device}, {self.dtype}")
        print(f"  Model: {self.config['d_model']}d, {self.config['predictor_layers']}L predictor")
        print(f"  AR steps: {steps}")
        print(f"{'='*70}")
        print(f"  {'Batch':>6} {'Time(s)':>8} {'Audio(s)':>9} {'RTF':>8} "
              f"{'Speed':>8} {'Throughput':>12} {'Per-req':>10}")
        print(f"  {'-'*6} {'-'*8} {'-'*9} {'-'*8} {'-'*8} {'-'*12} {'-'*10}")

        for B in batch_sizes:
            try:
                texts = [text] * B
                token_lists = [list(t.encode("utf-8")) for t in texts]
                max_len = max(len(t) for t in token_lists)
                padded = torch.zeros(B, max_len, dtype=torch.long, device=self.device)
                mask = torch.ones(B, max_len, dtype=torch.bool, device=self.device)
                for i, tokens in enumerate(token_lists):
                    padded[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)
                    mask[i, :len(tokens)] = False

                with torch.no_grad():
                    text_emb = self.model.text_encoder(padded, text_mask=mask)

                # Warmup
                if B == batch_sizes[0]:
                    cache = self.model.init_ar_cache()
                    start = self.model.start_emb.expand(B, -1, -1).to(self.dtype)
                    nxt = start
                    for s in range(5):
                        nxt, cache = self.model.predict_next_cached(
                            nxt, text_emb, s, cache, text_mask=mask)
                    torch.cuda.synchronize()

                # Benchmark
                cache = self.model.init_ar_cache()
                start = self.model.start_emb.expand(B, -1, -1).to(self.dtype)
                nxt = start

                torch.cuda.synchronize()
                t0 = time.time()
                for s in range(steps):
                    nxt, cache = self.model.predict_next_cached(
                        nxt, text_emb, s, cache, text_mask=mask)
                torch.cuda.synchronize()
                elapsed = time.time() - t0

                total_audio = B * steps * audio_per_step
                rtf = elapsed / (steps * audio_per_step)
                speed = (steps * audio_per_step) / elapsed
                throughput = total_audio / elapsed
                per_req = elapsed / B

                print(f"  {B:>6} {elapsed:>8.3f} {total_audio:>9.1f} {rtf:>8.4f} "
                      f"{speed:>7.0f}x {throughput:>10.0f}x {per_req:>9.3f}s")

            except torch.cuda.OutOfMemoryError:
                print(f"  {B:>6} {'OOM':>8}")
                torch.cuda.empty_cache()
                break

        print(f"{'='*70}")
        print(f"  Throughput = total audio-seconds generated per wall-second")
        print(f"  Speed = single-request real-time factor")
        print(f"  Per-req = wall-time per request in batch")
        print()


# ─── CLI ────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Command-line entry point. Mode precedence: --benchmark, then --texts
    # (batch synthesis into a directory), then --text (single file).
    import argparse

    parser = argparse.ArgumentParser(description="LeWM TTS Inference Engine")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--text", default=None, help="Text to synthesize")
    parser.add_argument("--texts", nargs="+", help="Multiple texts for batch synthesis")
    parser.add_argument("--prompt", default=None, help="Reference audio for voice cloning")
    parser.add_argument("--output", default="output.wav")
    parser.add_argument("--output_dir", default=None, help="Output dir for batch mode")
    parser.add_argument("--max_steps", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--benchmark", action="store_true", help="Run throughput benchmark")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--fp32", action="store_true", help="Use FP32 instead of FP16")
    args = parser.parse_args()

    dtype = torch.float32 if args.fp32 else torch.float16
    engine = InferenceEngine(args.checkpoint, device=args.device, dtype=dtype)

    if args.benchmark:
        engine.benchmark()

    elif args.texts:
        # The same prompt (if any) is applied to every text in the batch.
        results = engine.synthesize_batch(
            args.texts, max_steps=args.max_steps, temperature=args.temperature,
            prompt_audios=[args.prompt] * len(args.texts) if args.prompt else None,
        )
        out_dir = Path(args.output_dir or "output")
        # parents=True so a nested --output_dir (e.g. runs/exp1/wavs) works
        # even when the intermediate directories do not exist yet.
        out_dir.mkdir(parents=True, exist_ok=True)
        for i, r in enumerate(results):
            path = out_dir / f"batch_{i:03d}.wav"
            r.save(str(path))
            print(f"  [{i}] {path} — {r.duration:.2f}s, {r.steps_taken} steps")

    elif args.text:
        result = engine.synthesize(
            args.text, max_steps=args.max_steps, temperature=args.temperature,
            prompt_audio=args.prompt,
        )
        result.save(args.output)
        print(f"Saved: {args.output} — {result.duration:.2f}s, "
              f"{result.steps_taken} steps, {result.generation_time:.2f}s")
    else:
        # No mode selected: show usage.
        parser.print_help()
