"""
RTF optimization v2: properly skip CFG + torch.compile with safe mode.

Key insight: cfg_scale=1.0 still runs negative LM + doubles diffusion batch.
We patch sample_speech_tokens to run single-path when cfg disabled.
We also patch generate() to skip negative LM forward entirely.
"""

import time
import threading
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def patch_no_cfg(model):
    """Patch ``model.sample_speech_tokens`` to truly skip CFG.

    The replacement runs single-path diffusion with only the positive
    condition, avoiding the doubled batch that classifier-free guidance
    requires.  The original method is stashed on ``model._original_sample``
    so :func:`unpatch_no_cfg` can restore it.

    Idempotent: a repeated call no longer clobbers the saved original
    with the already-patched function.
    """

    @torch.no_grad()
    def sample_speech_tokens_no_cfg(condition, neg_condition=None, cfg_scale=1.0):
        # neg_condition / cfg_scale are accepted (and ignored) so this is a
        # drop-in replacement for the original CFG-aware signature.
        model.model.noise_scheduler.set_timesteps(model.ddpm_inference_steps)
        device = model.model.prediction_head.device
        condition = condition.to(device)
        # Start from pure noise in the acoustic VAE latent space.
        speech = torch.randn(condition.shape[0], model.config.acoustic_vae_dim, device=device, dtype=condition.dtype)
        for t in model.model.noise_scheduler.timesteps:
            # Single forward per step: positive condition only, no CFG mixing.
            eps = model.model.prediction_head(speech, t.repeat(speech.shape[0]).to(speech), condition=condition)
            speech = model.model.noise_scheduler.step(eps, t, speech).prev_sample
        return speech

    # Only remember the original on the first patch; re-patching must not
    # overwrite it with an already-patched function.
    if not hasattr(model, '_original_sample'):
        model._original_sample = model.sample_speech_tokens
    model.sample_speech_tokens = sample_speech_tokens_no_cfg


def unpatch_no_cfg(model):
    """Restore the ``sample_speech_tokens`` saved by :func:`patch_no_cfg`.

    No-op when the model was never patched.  The saved reference is
    removed afterwards so repeated patch/unpatch cycles stay clean and a
    stale attribute cannot be restored twice.
    """
    if hasattr(model, '_original_sample'):
        model.sample_speech_tokens = model._original_sample
        del model._original_sample


def patch_generate_skip_negative(model):
    """Replace ``model.generate`` with a CFG-free greedy generation loop.

    The patched loop runs the LM forward once per step (no negative-prompt
    branch) and every diffusion call uses only the positive condition.  The
    original bound ``generate`` is stashed on ``model._original_generate``
    so it can be restored later.

    Idempotent: a repeated call does not overwrite the saved original.
    """
    import types
    from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceGenerationOutput
    from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
    from transformers.generation import LogitsProcessorList
    from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceTokenConstraintProcessor

    @torch.no_grad()
    def fast_generate(
        self,
        inputs=None, generation_config=None, logits_processor=None,
        stopping_criteria=None, prefix_allowed_tokens_fn=None,
        synced_gpus=None, assistant_model=None, audio_streamer=None,
        negative_prompt_ids=None, negative_prompt_attention_mask=None,
        speech_tensors=None, speech_masks=None, speech_input_mask=None,
        is_prefill=True, return_speech=True, cfg_scale=1.0,
        stop_check_fn=None, tqdm_class=None, **kwargs,
    ):
        """Greedy token-by-token generation with no negative/CFG pass.

        Mirrors the original ``generate`` signature so callers need no
        changes; ``cfg_scale`` and the negative-prompt arguments are
        accepted but ignored.
        """
        # Pop kwargs the original generate consumed so they never reach
        # _build_generate_config_model_kwargs.
        tokenizer = kwargs.pop("tokenizer", None)
        parsed_scripts = kwargs.pop("parsed_scripts", None)
        all_speakers_list = kwargs.pop("all_speakers_list", None)
        max_length_times = kwargs.pop("max_length_times", 2)

        if kwargs.get('max_new_tokens', None) is None:
            # Fill the remaining context window when the caller did not cap it.
            kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]

        generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
            generation_config, inputs, tokenizer, return_processors=True, **kwargs
        )

        # Streaming caches for the acoustic decoder and semantic encoder.
        acoustic_cache = VibeVoiceTokenizerStreamingCache()
        semantic_cache = VibeVoiceTokenizerStreamingCache()

        batch_size = input_ids.shape[0]
        device = input_ids.device
        finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
        inputs_embeds = None
        audio_chunks = [[] for _ in range(batch_size)]

        initial_length = input_ids.shape[-1]
        # Per-sample prompt lengths (padding excluded) for per-sample step caps.
        initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)

        # Constrain decoding to the special speech-control tokens plus EOS.
        valid_tokens = [
            generation_config.speech_start_id, generation_config.speech_end_id,
            generation_config.speech_diffusion_id, generation_config.eos_token_id,
        ]
        if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
            valid_tokens.append(generation_config.bos_token_id)
        token_constraint = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(token_constraint)

        # Global and per-sample step budgets.
        max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
        max_step_per_sample = torch.min(
            generation_config.max_length - initial_length_per_sample,
            (max_length_times * initial_length_per_sample).long()
        )
        reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(max_steps):
            if finished_tags.all():
                break
            if input_ids.shape[-1] >= generation_config.max_length:
                break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            if is_prefill:
                # First step only: feed reference-voice speech features.
                prefill_inputs = {}
                if speech_tensors is not None:
                    prefill_inputs["speech_tensors"] = speech_tensors.to(device=device)
                if speech_masks is not None:
                    prefill_inputs["speech_masks"] = speech_masks.to(device)
                if speech_input_mask is not None:
                    prefill_inputs["speech_input_mask"] = speech_input_mask.to(device)
                is_prefill = False
            else:
                # Decode steps: replace token embeddings with the embeds built
                # from the previous step (acoustic + semantic for speech tokens).
                _ = model_inputs.pop('inputs_embeds', None)
                prefill_inputs = {'inputs_embeds': inputs_embeds}

            # Single positive-path LM forward (this is where CFG's negative
            # forward is skipped).
            outputs = self(
                **model_inputs, **prefill_inputs, logits_to_keep=1,
                return_dict=True, output_attentions=False, output_hidden_states=False,
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False,
            )

            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=device)
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_token_scores, dim=-1)

            # Finished samples keep emitting EOS.
            next_tokens[finished_tags] = generation_config.eos_token_id
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            if (next_tokens == generation_config.eos_token_id).any():
                eos_idx = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
                new_eos = eos_idx[~finished_tags[eos_idx]]
                if new_eos.numel() > 0:
                    finished_tags[new_eos] = True
                    if audio_streamer is not None:
                        audio_streamer.end(new_eos)

            # A speech segment ended: reset the streaming tokenizer caches.
            diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
            if diffusion_end_indices.numel() > 0:
                acoustic_cache.set_to_zero(diffusion_end_indices)
                semantic_cache.set_to_zero(diffusion_end_indices)

            # Per-sample step-budget exhaustion also finishes a sample.
            max_length_reached = step >= max_step_per_sample
            new_max = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
            if new_max.numel() > 0:
                finished_tags[new_max] = True
                reach_max_step_sample[new_max] = True
                if audio_streamer is not None:
                    audio_streamer.end(new_max)

            next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)

            diffusion_indices = torch.arange(batch_size, device=device)[
                ~finished_tags & (next_tokens == generation_config.speech_diffusion_id)
            ]

            if diffusion_indices.numel() > 0:
                positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]

                # NO negative pass, NO CFG - single path diffusion
                speech_latent = self.sample_speech_tokens(positive_condition).unsqueeze(1)

                scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
                audio_chunk = self.model.acoustic_tokenizer.decode(
                    scaled_latent.to(self.model.acoustic_tokenizer.device),
                    cache=acoustic_cache, sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
                    use_cache=True, debug=False,
                )

                for i, sample_idx in enumerate(diffusion_indices):
                    idx = sample_idx.item()
                    if not finished_tags[idx]:
                        audio_chunks[idx].append(audio_chunk[i])

                if audio_streamer is not None:
                    audio_streamer.put(audio_chunk, diffusion_indices)

                # Re-encode the decoded audio into semantic features so the
                # next LM step conditions on both acoustic and semantic embeds.
                semantic_features = self.model.semantic_tokenizer.encode(
                    audio_chunk, cache=semantic_cache, sample_indices=diffusion_indices,
                    use_cache=True, debug=False,
                ).mean

                acoustic_embed = self.model.acoustic_connector(speech_latent)
                semantic_embed = self.model.semantic_connector(semantic_features)
                next_inputs_embeds[diffusion_indices] = acoustic_embed + semantic_embed

            inputs_embeds = next_inputs_embeds

        if audio_streamer is not None:
            audio_streamer.end()

        final_audio = []
        for chunks in audio_chunks:
            if chunks:
                final_audio.append(torch.cat(chunks, dim=-1))
            else:
                final_audio.append(None)

        return VibeVoiceGenerationOutput(
            sequences=input_ids, speech_outputs=final_audio if return_speech else None,
            reach_max_step_sample=reach_max_step_sample,
        )

    # Only remember the original on the first patch; re-patching must not
    # overwrite it with the already-patched bound method.
    if not hasattr(model, '_original_generate'):
        model._original_generate = model.generate
    model.generate = types.MethodType(fast_generate, model)


def measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3):
    """Run one streamed generation and return latency/throughput metrics.

    Args:
        model: VibeVoice inference model, already on CUDA.
        processor: matching VibeVoiceProcessor.
        voice_path: path to the reference voice sample.
        text: script text; a "Speaker 1:" prefix is added if missing.
        ddpm_steps: diffusion denoising steps per speech token.
        cfg_scale: classifier-free guidance scale forwarded to generate().

    Returns:
        dict with ttfb_ms (time to first audio chunk, -1 if none arrived),
        gen_s (wall-clock generation seconds), audio_s (audio duration at
        24 kHz), rtf (real-time factor, gen_s / audio_s).
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"
    inputs = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    # One-element lists act as mutable cells shared with the consumer thread.
    first_chunk = [None]
    total_samp = [0]
    start = [None]

    def consumer():
        # Drain the stream, recording first-chunk arrival time and total samples.
        for chunk in streamer.get_stream(0):
            t = time.perf_counter()
            if first_chunk[0] is None:
                first_chunk[0] = t
            total_samp[0] += chunk.shape[-1]

    th = threading.Thread(target=consumer, daemon=True)
    th.start()

    # Fixed seed so runs are comparable across configurations.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    torch.cuda.synchronize()
    start[0] = time.perf_counter()
    model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )
    torch.cuda.synchronize()
    end = time.perf_counter()
    th.join(timeout=5)

    gen = end - start[0]
    # Explicit None check: a timestamp is a float and 0.0 would be falsy.
    ttfb = (first_chunk[0] - start[0]) * 1000 if first_chunk[0] is not None else -1
    dur = total_samp[0] / 24000.0
    rtf = gen / dur if dur > 0 else float("inf")
    return {"ttfb_ms": ttfb, "gen_s": gen, "audio_s": dur, "rtf": rtf}


def main():
    """Benchmark RTF/TTFB across progressively more aggressive optimizations.

    Runs a baseline (CFG on) and then layers in the no-CFG patches and
    torch.compile on the LM and diffusion head, printing a summary table.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    processor = VibeVoiceProcessor.from_pretrained(model_path)
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        # flash-attn not installed / unsupported GPU: fall back to SDPA.
        # (Narrow except so Ctrl-C / SystemExit are not swallowed.)
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    configs = []

    # Warmup
    print("Warming up...")
    model.set_ddpm_inference_steps(20)
    _ = measure(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=20)

    # 1. Baseline
    print("[1] Baseline: original generate, cfg=1.3, 20 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Baseline (cfg=1.3, 20 steps)", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 2. Patched: skip negative LM + single-path diffusion, 20 steps
    print("\n[2] NO-CFG patched: skip negative LM + single-path diffusion, 20 steps")
    patch_no_cfg(model)
    patch_generate_skip_negative(model)
    _ = measure(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=20)
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("NO-CFG patched, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 3. NO-CFG + torch.compile (default mode) on LM
    print("\n[3] NO-CFG + torch.compile(LM, mode=default), 20 steps")
    model.model.language_model = torch.compile(model.model.language_model, mode="default")
    # Two warmup runs so compile happens outside the measured pass.
    _ = measure(model, processor, voice_path, "Speaker 1: compile warmup.", ddpm_steps=20)
    _ = measure(model, processor, voice_path, "Speaker 1: compile warmup2.", ddpm_steps=20)
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("NO-CFG + compile(LM), 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 4. NO-CFG + compile LM + compile diffusion head
    print("\n[4] NO-CFG + compile(LM+diffusion), 20 steps")
    model.model.prediction_head = torch.compile(model.model.prediction_head, mode="default")
    _ = measure(model, processor, voice_path, "Speaker 1: diff compile warmup.", ddpm_steps=20)
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("NO-CFG + compile(LM+diff), 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # 5. NO-CFG + compile all + 10 steps
    print("\n[5] NO-CFG + compile(LM+diff), 10 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.0)
    configs.append(("NO-CFG + compile all, 10 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms")

    # Summary
    print(f"\n{'='*90}")
    print(f"{'Config':<45} {'RTF':>7} {'TTFB':>8} {'Audio':>8} {'Stream?':>8}")
    print(f"{'-'*90}")
    for name, r in configs:
        # "Streamable" means faster than real time (RTF < 1).
        can = "YES" if r["rtf"] < 1.0 else "NO"
        print(f"{name:<45} {r['rtf']:>6.3f}x {r['ttfb_ms']:>7.0f}ms {r['audio_s']:>7.2f}s {can:>7}")
    print(f"{'='*90}")


if __name__ == "__main__":
    main()
