"""
RTF optimization v3: keep CFG for voice quality, but optimize compute.

Approach 1: torch.compile(mode=default) on LM + diffusion WITH full CFG
Approach 2: Cached negative condition - compute once, reuse for all tokens
            (saves the negative LM forward on every token)
Approach 3: Combined
"""

import time
import types
import threading
import torch
from collections import defaultdict

from vibevoice.modular.modeling_vibevoice_inference import (
    VibeVoiceForConditionalGenerationInference,
    VibeVoiceGenerationOutput,
    VibeVoiceTokenConstraintProcessor,
)
from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from transformers.generation import LogitsProcessorList


def patch_cached_negative(model):
    """
    Patch ``model.generate`` to cache the negative (unconditional) condition.

    The stock VibeVoice generate loop runs a second "negative" LM forward on
    every step to obtain the unconditional hidden state used for
    classifier-free guidance (CFG).  This patch computes that negative hidden
    state exactly once — from a single ``speech_start`` token — and reuses it
    for every diffusion token, saving one LM forward per generated token.

    The original generate is preserved on ``model._orig_generate`` so the
    patch can be reverted by reassigning it.
    """

    @torch.no_grad()
    def fast_generate_cached_neg(
        self, inputs=None, generation_config=None, logits_processor=None,
        stopping_criteria=None, prefix_allowed_tokens_fn=None,
        synced_gpus=None, assistant_model=None, audio_streamer=None,
        negative_prompt_ids=None, negative_prompt_attention_mask=None,
        speech_tensors=None, speech_masks=None, speech_input_mask=None,
        is_prefill=True, return_speech=True, cfg_scale=1.3,
        stop_check_fn=None, tqdm_class=None, **kwargs,
    ):
        # Drop-in replacement for VibeVoice generate(): same greedy decode
        # loop, but only the positive LM branch runs per step; the negative
        # CFG condition comes from `cached_neg_condition` computed below.
        tokenizer = kwargs.pop("tokenizer", None)
        # Accepted for signature compatibility with the original generate();
        # unused by this fast path.
        kwargs.pop("parsed_scripts", None)
        kwargs.pop("all_speakers_list", None)
        max_length_times = kwargs.pop("max_length_times", 2)

        # NOTE(review): assumes 'input_ids' is always present in kwargs when
        # max_new_tokens is omitted — confirm against callers.
        if kwargs.get('max_new_tokens', None) is None:
            kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]

        generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
            generation_config, inputs, tokenizer, return_processors=True, **kwargs
        )

        # Build negative prompt once: a batch of single speech_start tokens,
        # i.e. the "unconditional" context for CFG.
        negative_kwargs = {
            'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
            'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
            'max_new_tokens': kwargs.get('max_new_tokens', 100)
        }
        negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **negative_kwargs
        )

        # Run negative forward ONCE to get the baseline condition.  The last
        # hidden state is cloned so it survives any later cache mutation.
        # NOTE(review): this is an approximation — the original code recomputes
        # the negative branch with its own growing KV cache each step.
        neg_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
        negative_outputs = self(
            **neg_model_inputs, logits_to_keep=0, return_dict=True,
            output_attentions=False, output_hidden_states=False,
        )
        cached_neg_condition = negative_outputs.last_hidden_state[:, -1, :].clone()

        # Streaming caches for the acoustic decoder / semantic encoder so
        # audio can be produced chunk-by-chunk.
        acoustic_cache = VibeVoiceTokenizerStreamingCache()
        semantic_cache = VibeVoiceTokenizerStreamingCache()

        batch_size = input_ids.shape[0]
        device = input_ids.device
        finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
        inputs_embeds = None
        verbose = kwargs.get("verbose", False)
        # Per-sample list of decoded audio chunks, concatenated at the end.
        audio_chunks = [[] for _ in range(batch_size)]

        initial_length = input_ids.shape[-1]
        # Per-sample prompt length (excludes padding via the attention mask).
        initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)

        # Constrain sampling to the speech-control vocabulary.
        valid_tokens = [
            generation_config.speech_start_id, generation_config.speech_end_id,
            generation_config.speech_diffusion_id, generation_config.eos_token_id,
        ]
        if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
            valid_tokens.append(generation_config.bos_token_id)
        token_constraint = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(token_constraint)

        # Step budget: bounded both globally and per sample by
        # max_length_times x prompt length.
        max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
        max_step_per_sample = torch.min(
            generation_config.max_length - initial_length_per_sample,
            (max_length_times * initial_length_per_sample).long()
        )
        reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(max_steps):
            if finished_tags.all():
                break
            if input_ids.shape[-1] >= generation_config.max_length:
                break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            if is_prefill:
                # First iteration: feed the voice-prompt speech features so
                # they are embedded into the prompt positions.
                prefill_inputs = {}
                if speech_tensors is not None:
                    prefill_inputs["speech_tensors"] = speech_tensors.to(device=device)
                if speech_masks is not None:
                    prefill_inputs["speech_masks"] = speech_masks.to(device)
                if speech_input_mask is not None:
                    prefill_inputs["speech_input_mask"] = speech_input_mask.to(device)
                is_prefill = False
            else:
                # Decode steps: replace the token embedding with the
                # acoustic+semantic embedding built at the end of the
                # previous iteration.
                _ = model_inputs.pop('inputs_embeds', None)
                prefill_inputs = {'inputs_embeds': inputs_embeds}

            # Single positive LM forward (no negative!)
            outputs = self(
                **model_inputs, **prefill_inputs, logits_to_keep=1,
                return_dict=True, output_attentions=False, output_hidden_states=False,
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False,
            )

            # Copy so the logits processors cannot mutate the model output.
            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=device)
            next_token_scores = logits_processor(input_ids, next_token_logits)
            # Greedy decoding only — do_sample is not supported here.
            next_tokens = torch.argmax(next_token_scores, dim=-1)

            # Finished samples keep emitting EOS so the batch stays aligned.
            next_tokens[finished_tags] = generation_config.eos_token_id
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # Mark newly finished samples and close their audio streams.
            if (next_tokens == generation_config.eos_token_id).any():
                eos_idx = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
                new_eos = eos_idx[~finished_tags[eos_idx]]
                if new_eos.numel() > 0:
                    finished_tags[new_eos] = True
                    if audio_streamer is not None:
                        audio_streamer.end(new_eos)

            # speech_end resets the streaming tokenizer caches for that sample
            # so the next speech segment starts from a clean state.
            diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
            if diffusion_end_indices.numel() > 0:
                acoustic_cache.set_to_zero(diffusion_end_indices)
                semantic_cache.set_to_zero(diffusion_end_indices)

            # Per-sample step budget exhausted -> force-finish those samples.
            max_length_reached = step >= max_step_per_sample
            new_max = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
            if new_max.numel() > 0:
                finished_tags[new_max] = True
                reach_max_step_sample[new_max] = True
                if audio_streamer is not None:
                    audio_streamer.end(new_max)

            # Default next-step embedding: the plain token embedding; replaced
            # below for samples that emitted a diffusion token.
            next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)

            # Unfinished samples that asked for a diffusion (audio) step.
            diffusion_indices = torch.arange(batch_size, device=device)[
                ~finished_tags & (next_tokens == generation_config.speech_diffusion_id)
            ]

            if diffusion_indices.numel() > 0:
                positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
                negative_condition = cached_neg_condition[diffusion_indices]

                # Full CFG diffusion using cached negative
                speech_latent = self.sample_speech_tokens(
                    positive_condition, negative_condition, cfg_scale=cfg_scale,
                ).unsqueeze(1)

                # Presumably inverts the latent normalization (scale + bias)
                # applied during training — verify against the tokenizer.
                scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
                audio_chunk = self.model.acoustic_tokenizer.decode(
                    scaled_latent.to(self.model.acoustic_tokenizer.device),
                    cache=acoustic_cache, sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
                    use_cache=True, debug=False,
                )

                # Collect chunks for the final per-sample waveform.
                for i, sample_idx in enumerate(diffusion_indices):
                    idx = sample_idx.item()
                    if not finished_tags[idx]:
                        audio_chunks[idx].append(audio_chunk[i])

                if audio_streamer is not None:
                    audio_streamer.put(audio_chunk, diffusion_indices)

                # Semantic features of the audio just decoded, fed back as
                # part of the next-step input embedding.
                semantic_features = self.model.semantic_tokenizer.encode(
                    audio_chunk, cache=semantic_cache, sample_indices=diffusion_indices,
                    use_cache=True, debug=False,
                ).mean

                acoustic_embed = self.model.acoustic_connector(speech_latent)
                semantic_embed = self.model.semantic_connector(semantic_features)
                next_inputs_embeds[diffusion_indices] = acoustic_embed + semantic_embed

            inputs_embeds = next_inputs_embeds

        # Close any streams still open (e.g. loop exhausted max_steps).
        if audio_streamer is not None:
            audio_streamer.end()

        # Concatenate chunks per sample; None for samples with no audio.
        final_audio = []
        for chunks in audio_chunks:
            final_audio.append(torch.cat(chunks, dim=-1) if chunks else None)

        return VibeVoiceGenerationOutput(
            sequences=input_ids, speech_outputs=final_audio if return_speech else None,
            reach_max_step_sample=reach_max_step_sample,
        )

    # Keep the original so the patch is reversible, then bind the fast path.
    model._orig_generate = model.generate
    model.generate = types.MethodType(fast_generate_cached_neg, model)


def measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3):
    """Run one streamed generation and return timing metrics.

    Args:
        model: patched or unpatched VibeVoice inference model (on CUDA).
        processor: VibeVoiceProcessor used for tokenization and audio I/O.
        voice_path: path to the reference voice sample.
        text: script text; auto-prefixed with "Speaker 1: " if missing.
        ddpm_steps: diffusion inference steps per audio token.
        cfg_scale: classifier-free guidance scale passed to generate().

    Returns:
        (metrics, outputs) where metrics is a dict with keys
        ``ttfb_ms`` (time-to-first-byte, -1 if no audio arrived),
        ``gen_s`` (wall-clock generation time), ``audio_s`` (audio duration
        at 24 kHz), ``rtf`` (real-time factor, inf if no audio), and
        outputs is the model's generation output.
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"
    inputs = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    # One-element lists act as mutable cells written by the consumer thread.
    first_chunk = [None]  # perf_counter timestamp of the first audio chunk
    total_samp = [0]      # cumulative audio samples received

    def consumer():
        # Drain the stream; record arrival time of the first chunk.
        for chunk in streamer.get_stream(0):
            now = time.perf_counter()
            if first_chunk[0] is None:
                first_chunk[0] = now
            total_samp[0] += chunk.shape[-1]

    th = threading.Thread(target=consumer, daemon=True)
    th.start()

    # Fixed seed for reproducible comparisons across configurations.
    torch.manual_seed(42); torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    # Synchronize so queued GPU work doesn't pollute the timing window.
    torch.cuda.synchronize()
    start = time.perf_counter()
    outputs = model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )
    torch.cuda.synchronize()
    end = time.perf_counter()
    th.join(timeout=5)

    gen = end - start
    # Explicit None check (not truthiness): first_chunk[0] is a timestamp,
    # and None only when no audio was ever streamed.
    ttfb = (first_chunk[0] - start) * 1000 if first_chunk[0] is not None else -1
    dur = total_samp[0] / 24000.0
    rtf = gen / dur if dur > 0 else float("inf")
    return {"ttfb_ms": ttfb, "gen_s": gen, "audio_s": dur, "rtf": rtf}, outputs


def main():
    """Benchmark RTF/TTFB across progressively optimized configurations.

    Runs four configurations (baseline, cached negative, + compiled LM,
    + compiled diffusion head), prints a summary table, and saves sample
    audio to ``samples_v3/``.  Each new configuration gets warmup runs so
    torch.compile / CUDA graph capture costs are excluded from timings.
    """
    import os
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    processor = VibeVoiceProcessor.from_pretrained(model_path)
    # Prefer flash-attention; fall back to SDPA when it is unavailable.
    # `except Exception` (not bare `except:`) so Ctrl-C / SystemExit during
    # a long model download still propagate.
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है, जहाँ टेक्नोलॉजी और इनोवेशन हमारी ताकत बन रही है."

    configs = []

    # Warmup
    print("Warming up baseline...")
    _ = measure(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=20)

    # 1. Baseline
    print("[1] Baseline: original, cfg=1.3, 20 steps")
    r, _ = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Baseline (full CFG, 20 steps)", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 2. Cached negative + full CFG
    print("\n[2] Cached negative + cfg=1.3, 20 steps")
    patch_cached_negative(model)
    _ = measure(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=20, cfg_scale=1.3)
    r, _ = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Cached neg + CFG=1.3, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 3. Cached neg + torch.compile(LM, default)
    print("\n[3] Cached neg + compile(LM), cfg=1.3, 20 steps")
    model.model.language_model = torch.compile(model.model.language_model, mode="default")
    # Two warmup passes: first triggers compilation, second hits recompiles.
    _ = measure(model, processor, voice_path, "Speaker 1: compile warmup1.", ddpm_steps=20, cfg_scale=1.3)
    _ = measure(model, processor, voice_path, "Speaker 1: compile warmup2.", ddpm_steps=20, cfg_scale=1.3)
    r, _ = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Cached neg + compile(LM), 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 4. Cached neg + compile(LM + diffusion)
    print("\n[4] Cached neg + compile(LM+diff), cfg=1.3, 20 steps")
    model.model.prediction_head = torch.compile(model.model.prediction_head, mode="default")
    _ = measure(model, processor, voice_path, "Speaker 1: diff warmup.", ddpm_steps=20, cfg_scale=1.3)
    r, out = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Cached neg + compile(LM+diff), 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # Save the best quality sample
    os.makedirs("samples_v3", exist_ok=True)
    if out.speech_outputs[0] is not None:
        processor.save_audio(out.speech_outputs[0], output_path="samples_v3/modi_cached_cfg_compiled.wav")
        print("    Saved: samples_v3/modi_cached_cfg_compiled.wav")

    # Summary
    print(f"\n{'='*90}")
    print(f"{'Config':<50} {'RTF':>7} {'TTFB':>8} {'Audio':>8} {'Stream?':>8}")
    print(f"{'-'*90}")
    for name, r in configs:
        # RTF < 1.0 means audio is produced faster than real time, i.e.
        # the configuration can stream live.
        can = "YES" if r["rtf"] < 1.0 else "NO"
        print(f"{name:<50} {r['rtf']:>6.3f}x {r['ttfb_ms']:>7.0f}ms {r['audio_s']:>7.2f}s {can:>7}")
    print(f"{'='*90}")

    # Generate more samples with best config
    print("\nGenerating samples with best config (cached neg + compile + CFG=1.3)...")
    texts = [
        ('short', 'Speaker 1: नमस्ते, मेरे प्यारे देशवासियों.'),
        ('medium', 'Speaker 1: आज हम डिजिटल इंडिया की बात करते हैं. गाँव गाँव में इंटरनेट पहुँच रहा है. किसान अपने फोन से मंडी के भाव देख रहा है. यह बदलाव छोटा नहीं है, यह एक क्रांति है.'),
        ('speech', 'Speaker 1: भारत आज दुनिया की पाँचवीं सबसे बड़ी अर्थव्यवस्था है. हमारे युवाओं की ऊर्जा, हमारे वैज्ञानिकों की प्रतिभा, और हमारे किसानों की मेहनत, यही हमारी असली ताकत है.'),
    ]
    for label, txt in texts:
        torch.manual_seed(42); torch.cuda.manual_seed_all(42)
        r, out = measure(model, processor, voice_path, txt, ddpm_steps=20, cfg_scale=1.3)
        if out.speech_outputs[0] is not None:
            path = f"samples_v3/modi_{label}.wav"
            processor.save_audio(out.speech_outputs[0], output_path=path)
            print(f"  {label}: {r['audio_s']:.2f}s | RTF={r['rtf']:.3f}x | {path}")


if __name__ == "__main__":
    main()
