"""
Spark TTS Generation Pipeline

End-to-end pipeline for TTS generation with BiCodec audio tokenizer.
Migrated from Veena3/Orpheus (SNAC) to Spark TTS (BiCodec).
"""

import logging
import os
import re
import time
import uuid
from dataclasses import asdict, is_dataclass
from typing import Optional, List, Dict, Any, Tuple
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind
import torch

logger = logging.getLogger(__name__)

# Values (after strip + lowercase) that switch an opt-in env flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str) -> bool:
    """Return True when env var ``name`` holds a truthy value (1/true/yes/on).

    Matching ignores surrounding whitespace and case. An unset, empty, or any
    other value yields False.
    """
    return os.environ.get(name, "").strip().lower() in _TRUTHY_ENV_VALUES


# Opt-in diagnostic / behavior flags, evaluated once at import time.
_PERF_STAGE_LOG_ENABLED = _env_flag("VEENA3_PERF_STAGE_LOG")
_PERF_GPU_TIMING_ENABLED = _env_flag("VEENA3_PERF_GPU_TIMING")
_ASYNC_BICODEC_DECODE_ENABLED = _env_flag("VEENA3_ASYNC_BICODEC_DECODE")
_NON_STREAM_FINAL_ONLY_ENABLED = _env_flag("VEENA3_NON_STREAM_FINAL_ONLY")


def _log_perf_stage(event: str, extra: Dict[str, Any]) -> None:
    """Log a hot-path perf event: INFO when perf-stage logging is opted in, else DEBUG.

    Nothing is emitted unless ``_PERF_STAGE_LOG_ENABLED`` is set or the module
    logger has DEBUG enabled.
    """
    if not _PERF_STAGE_LOG_ENABLED and not logger.isEnabledFor(logging.DEBUG):
        return
    level = logging.INFO if _PERF_STAGE_LOG_ENABLED else logging.DEBUG
    logger.log(level, event, extra=extra)

from veena3modal.core.constants import (
    TRAINING_STOP_TOKEN_IDS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_SEED,
    AUDIO_SAMPLE_RATE,
)


class SparkTTSPipeline:
    """
    End-to-end TTS pipeline for Spark TTS with BiCodec.

    Flow: build a text prompt for the speaker, generate with the vLLM engine,
    extract BiCodec semantic/global token ids from the generated text with
    regexes, validate and decode them with the BiCodec decoder, and wrap the
    decoded audio bytes in a WAV header.

    Replaces SNAC-based token extraction with BiCodec regex-based extraction.
    """

    # Compiled once; used by _extract_bicodec_tokens on every request.
    _SEMANTIC_TOKEN_RE = re.compile(r"<\|bicodec_semantic_(\d+)\|>")
    _GLOBAL_TOKEN_RE = re.compile(r"<\|bicodec_global_(\d+)\|>")

    # Request-level vLLM metrics reported in seconds -> perf keys (ms).
    # Ordered pairs so perf keys are inserted in a stable order.
    _SECONDS_METRIC_FIELDS: Tuple[Tuple[str, str], ...] = (
        ("time_in_queue", "llm_time_in_queue_ms"),
        ("scheduler_time", "llm_scheduler_ms"),
        ("model_forward_time", "llm_model_forward_ms"),
        ("model_execute_time", "llm_model_execute_ms"),
        ("first_token_latency", "llm_first_token_ms"),
    )

    def __init__(
        self,
        model,
        prompt_builder,
        bicodec_decoder,
    ):
        """
        Initialize pipeline.

        Args:
            model: SparkTTS Model instance. Must expose ``engine.generate`` and
                ``tokenizer.encode`` (both used below).
            prompt_builder: IndicPromptBuilder instance (with Spark TTS format);
                must provide ``build_prefix(speaker, text)``.
            bicodec_decoder: BiCodecDecoder instance; must provide
                ``validate_tokens`` and ``decode_to_bytes`` (and
                ``decode_to_bytes_async`` when async decode is enabled).
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.bicodec_decoder = bicodec_decoder

        print("🚀 SparkTTSPipeline initialized")

    @staticmethod
    def _p50(values: List[float]) -> float:
        """Return the lower median of ``values`` (0.0 for an empty list)."""
        if not values:
            return 0.0
        ordered = sorted(values)
        return ordered[(len(ordered) - 1) // 2]

    @staticmethod
    def _to_number(value: Any) -> Optional[float]:
        """Best-effort conversion of a scalar to float; None if not convertible."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: ints too large for a float must not abort metric parsing.
            return None

    def _extract_request_metrics(self, raw_metrics: Any) -> Dict[str, float]:
        """
        Convert vLLM request_output.metrics into a numeric dictionary.

        vLLM may return a dataclass-like object or a plain dict depending on version.
        Only values convertible by ``_to_number`` are kept; the first occurrence
        of each key wins.
        """
        if raw_metrics is None:
            return {}

        numeric: Dict[str, float] = {}
        items: List[Tuple[str, Any]] = []

        if isinstance(raw_metrics, dict):
            items.extend(raw_metrics.items())
        else:
            # Modern vLLM versions expose RequestStateStats dataclasses with
            # fields like queued_ts/scheduled_ts/first_token_ts/last_token_ts.
            if is_dataclass(raw_metrics):
                try:
                    items.extend(asdict(raw_metrics).items())
                except Exception:
                    # Best-effort: asdict may fail on non-copyable fields.
                    pass
            if hasattr(raw_metrics, "__dict__"):
                try:
                    items.extend(dict(raw_metrics.__dict__).items())
                except Exception:
                    pass

            # Fallback: scan public scalar attributes (covers __slots__ / properties).
            for name in dir(raw_metrics):
                if name.startswith("_"):
                    continue
                try:
                    value = getattr(raw_metrics, name)
                except Exception:
                    continue
                if callable(value):
                    continue
                items.append((name, value))

        seen: set[str] = set()
        for key, value in items:
            key_s = str(key)
            if key_s in seen:
                continue
            seen.add(key_s)
            parsed = self._to_number(value)
            if parsed is not None:
                numeric[key_s] = parsed
        return numeric

    def _apply_stats(self, perf: Dict[str, Any], prefix: str, values: List[float]) -> None:
        """Populate total/min/max/p50 stats for a list of ms samples (no-op when empty)."""
        if not values:
            return
        perf[f"{prefix}_total"] = float(sum(values))
        perf[f"{prefix}_min"] = float(min(values))
        perf[f"{prefix}_max"] = float(max(values))
        perf[f"{prefix}_p50"] = float(self._p50(values))

    @staticmethod
    def _delta_ms(start_s: Optional[float], end_s: Optional[float]) -> Optional[float]:
        """Return ``end_s - start_s`` in ms, or None if either is missing or negative."""
        if start_s is None or end_s is None or end_s < start_s:
            return None
        return (end_s - start_s) * 1000.0

    def _apply_request_metrics(self, perf: Dict[str, Any], final_request_metrics: Dict[str, float]) -> None:
        """Translate request-level vLLM metrics (seconds) into ms fields on ``perf``.

        Supports both legacy field names and v0.15+ RequestStateStats timestamps.
        Later sources deliberately overwrite earlier ones for the same perf key.
        """
        for src, dst in self._SECONDS_METRIC_FIELDS:
            if src in final_request_metrics:
                perf[dst] = float(final_request_metrics[src]) * 1000

        # New vLLM fields (monotonic seconds).
        queued_ts = final_request_metrics.get("queued_ts")
        scheduled_ts = final_request_metrics.get("scheduled_ts")
        first_token_ts = final_request_metrics.get("first_token_ts")
        last_token_ts = final_request_metrics.get("last_token_ts")
        q_to_sched = self._delta_ms(queued_ts, scheduled_ts)
        sched_to_first = self._delta_ms(scheduled_ts, first_token_ts)
        first_to_last = self._delta_ms(first_token_ts, last_token_ts)
        queued_to_last = self._delta_ms(queued_ts, last_token_ts)
        if q_to_sched is not None:
            perf["llm_queued_to_scheduled_ms"] = q_to_sched
            perf["llm_time_in_queue_ms"] = q_to_sched
        if sched_to_first is not None:
            perf["llm_scheduled_to_first_token_ms"] = sched_to_first
            perf["llm_scheduler_ms"] = sched_to_first
        if first_to_last is not None:
            perf["llm_first_to_last_token_ms"] = first_to_last
            # For new stats objects, this is the best available model-execute proxy.
            perf["llm_model_execute_ms"] = first_to_last
        if queued_to_last is not None:
            perf["llm_queued_to_last_token_ms"] = queued_to_last
            perf["llm_request_lifecycle_ms"] = queued_to_last

        if "num_generation_tokens" in final_request_metrics and perf.get("llm_token_total", 0) <= 0:
            perf["llm_token_total"] = int(final_request_metrics["num_generation_tokens"])

        # Legacy wall-clock fields.
        arrival = final_request_metrics.get("arrival_time")
        first_token = final_request_metrics.get("first_token_time")
        finished = final_request_metrics.get("finished_time")
        if arrival is not None and first_token is not None and first_token >= arrival:
            perf["llm_first_token_ms"] = (first_token - arrival) * 1000
        if arrival is not None and finished is not None and finished >= arrival:
            perf["llm_request_lifecycle_ms"] = (finished - arrival) * 1000

    async def generate_speech_profiled(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = 1.05,
        seed: Optional[int] = None,
    ) -> Tuple[Optional[bytes], Dict[str, Any]]:
        """
        Generate speech with detailed per-request timing metrics.

        Stages: prompt build -> SamplingParams -> vLLM generation (per-batch
        timing) -> BiCodec token parse -> token validation -> BiCodec decode
        (optionally CUDA-event timed) -> WAV header.

        Args:
            speaker: Speaker name passed to ``prompt_builder.build_prefix``.
            text: Text to synthesize.
            temperature: Sampling temperature (forwarded to SamplingParams).
            top_k: Top-k sampling.
            top_p: Nucleus sampling.
            max_tokens: Maximum tokens to generate.
            repetition_penalty: Repetition penalty.
            seed: Random seed for reproducibility.

        Returns:
            (wav_bytes, perf_dict). ``wav_bytes`` is None when generation produced
            no usable output, no semantic or no global tokens were found, token
            validation failed, or the decoder returned None. ``perf_dict`` always
            includes ``pipeline_total_ms`` and ``timeline_*`` fields gathered up to
            the point of return.
        """
        from veena3modal.audio.utils import add_wav_header

        perf: Dict[str, Any] = {
            "llm_prompt_token_total": 0,
            "llm_token_total": 0,
            "semantic_token_total": 0,
            "global_token_total": 0,
            "llm_batch_count": 0,
            "llm_parse_ms": 0.0,
            "llm_decode_calls": 0,
            "bicodec_decode_wall_ms": 0.0,
            "bicodec_decode_gpu_ms": 0.0,
            "bicodec_decode_cpu_ms": 0.0,
            "wav_pack_ms": 0.0,
        }

        t_request_start = time.perf_counter()
        timeline_marks: List[Dict[str, float]] = [{"stage": "request_start", "t_ms": 0.0, "dt_ms": 0.0}]
        last_mark_ms = 0.0
        perf["timeline_markers"] = timeline_marks
        perf["timeline_request_start_ms"] = 0.0

        def _mark(stage: str) -> float:
            """Record a timeline marker (ms since request start) and return it."""
            nonlocal last_mark_ms
            now_ms = (time.perf_counter() - t_request_start) * 1000
            dt_ms = now_ms - last_mark_ms
            last_mark_ms = now_ms
            timeline_marks.append(
                {
                    "stage": stage,
                    "t_ms": float(now_ms),
                    "dt_ms": float(dt_ms),
                }
            )
            perf[f"timeline_{stage}_ms"] = float(now_ms)
            perf[f"timeline_{stage}_delta_ms"] = float(dt_ms)
            return float(now_ms)

        def _finalize_timeline() -> None:
            """Close the timeline with request_done and derive inter-stage spans."""
            if not timeline_marks or timeline_marks[-1].get("stage") != "request_done":
                _mark("request_done")

            # Keep a compact, ordered timeline payload for downstream benchmarking.
            perf["timeline_markers"] = list(timeline_marks)
            perf["timeline_total_ms"] = float(perf.get("timeline_request_done_ms", 0.0))

            first_batch_ms = perf.get("timeline_llm_first_batch_ms")
            llm_done_ms = perf.get("timeline_llm_done_ms")
            parse_done_ms = perf.get("timeline_parse_done_ms")
            bicodec_done_ms = perf.get("timeline_bicodec_done_ms")
            request_done_ms = perf.get("timeline_request_done_ms")

            if isinstance(first_batch_ms, (int, float)):
                perf["timeline_to_first_batch_ms"] = float(first_batch_ms)
            if isinstance(first_batch_ms, (int, float)) and isinstance(llm_done_ms, (int, float)) and llm_done_ms >= first_batch_ms:
                perf["timeline_first_batch_to_llm_done_ms"] = float(llm_done_ms - first_batch_ms)
            if isinstance(llm_done_ms, (int, float)) and isinstance(request_done_ms, (int, float)) and request_done_ms >= llm_done_ms:
                perf["timeline_post_llm_ms"] = float(request_done_ms - llm_done_ms)
            if isinstance(parse_done_ms, (int, float)) and isinstance(bicodec_done_ms, (int, float)) and bicodec_done_ms >= parse_done_ms:
                perf["timeline_parse_to_bicodec_done_ms"] = float(bicodec_done_ms - parse_done_ms)
            if isinstance(bicodec_done_ms, (int, float)) and isinstance(request_done_ms, (int, float)) and request_done_ms >= bicodec_done_ms:
                perf["timeline_bicodec_to_request_done_ms"] = float(request_done_ms - bicodec_done_ms)

        def _finish(result: Optional[bytes]) -> Tuple[Optional[bytes], Dict[str, Any]]:
            """Stamp total wall time, close the timeline, and build the return tuple."""
            perf["pipeline_total_ms"] = (time.perf_counter() - t_request_start) * 1000
            _finalize_timeline()
            return result, perf

        # Stage 1: prompt build
        t0 = time.perf_counter()
        prompt = self.prompt_builder.build_prefix(speaker, text)
        perf["prompt_build_ms"] = (time.perf_counter() - t0) * 1000
        _mark("prompt_ready")

        # Prompt token count (for TPS normalization); best-effort only.
        try:
            perf["llm_prompt_token_total"] = len(
                self.model.tokenizer.encode(prompt, add_special_tokens=False)
            )
        except Exception:
            perf["llm_prompt_token_total"] = 0

        # Stage 2: sampling params
        t0 = time.perf_counter()
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,
            skip_special_tokens=False,
            seed=seed,
            # CUMULATIVE: each yielded output carries all tokens so far, which the
            # per-batch delta accounting below relies on.
            output_kind=(RequestOutputKind.FINAL_ONLY if _NON_STREAM_FINAL_ONLY_ENABLED else RequestOutputKind.CUMULATIVE),
        )
        perf["sampling_params_ms"] = (time.perf_counter() - t0) * 1000
        _mark("sampling_ready")

        request_id = f"req-{uuid.uuid4().hex[:12]}"
        _log_perf_stage(
            "perf_non_stream_start",
            extra={
                "request_id": request_id,
                "speaker": speaker,
                "text_length": len(text),
                "max_tokens": max_tokens,
            },
        )
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        final_output = None
        final_request_metrics: Dict[str, float] = {}

        llm_batch_wall_ms: List[float] = []
        llm_batch_gpu_ms: List[float] = []
        llm_decode_wall_ms: List[float] = []
        llm_decode_gpu_ms: List[float] = []
        tokens_per_batch: List[float] = []

        prev_generated_count = 0
        prev_gpu_cumulative_s: Optional[float] = None

        t_gen_start = time.perf_counter()
        t_last_yield = t_gen_start
        async for request_output in results_generator:
            now = time.perf_counter()
            batch_wall = (now - t_last_yield) * 1000
            t_last_yield = now

            outputs = request_output.outputs or []
            if not outputs:
                # Nothing parseable here; don't let it replace the last output
                # that had completions (outputs[0] would raise IndexError later).
                continue

            output = outputs[0]
            generated_count = len(output.token_ids)
            new_tokens = max(0, generated_count - prev_generated_count)
            prev_generated_count = generated_count

            raw_req_metrics = self._extract_request_metrics(getattr(request_output, "metrics", None))
            if raw_req_metrics:
                final_request_metrics = raw_req_metrics

            if new_tokens <= 0:
                final_output = request_output
                continue

            perf["llm_batch_count"] += 1
            if perf["llm_batch_count"] == 1:
                first_batch_timeline_ms = _mark("llm_first_batch")
                _log_perf_stage(
                    "perf_first_llm_batch",
                    extra={
                        "request_id": request_id,
                        "batch_wall_ms": round(batch_wall, 3),
                        "new_tokens": new_tokens,
                        "timeline_ms": round(first_batch_timeline_ms, 3),
                    },
                )
            tokens_per_batch.append(float(new_tokens))
            llm_batch_wall_ms.append(float(batch_wall))
            # The first batch is treated as prefill; later ones as decode steps.
            if perf["llm_batch_count"] > 1:
                llm_decode_wall_ms.append(float(batch_wall))

            cumulative_gpu_s: Optional[float] = None
            if raw_req_metrics:
                if "model_execute_time" in raw_req_metrics:
                    cumulative_gpu_s = raw_req_metrics["model_execute_time"]
                elif "model_forward_time" in raw_req_metrics:
                    cumulative_gpu_s = raw_req_metrics["model_forward_time"]

            if cumulative_gpu_s is not None:
                if prev_gpu_cumulative_s is None:
                    gpu_delta_ms = cumulative_gpu_s * 1000
                elif cumulative_gpu_s >= prev_gpu_cumulative_s:
                    gpu_delta_ms = (cumulative_gpu_s - prev_gpu_cumulative_s) * 1000
                else:
                    # Some versions report non-cumulative per-step values.
                    gpu_delta_ms = cumulative_gpu_s * 1000
                prev_gpu_cumulative_s = cumulative_gpu_s

                if gpu_delta_ms >= 0:
                    llm_batch_gpu_ms.append(float(gpu_delta_ms))
                    if perf["llm_batch_count"] > 1:
                        llm_decode_gpu_ms.append(float(gpu_delta_ms))

            final_output = request_output

        perf["llm_generation_wall_ms"] = (time.perf_counter() - t_gen_start) * 1000
        llm_done_timeline_ms = _mark("llm_done")
        _log_perf_stage(
            "perf_llm_generation_done",
            extra={
                "request_id": request_id,
                "llm_generation_wall_ms": round(perf["llm_generation_wall_ms"], 3),
                "llm_batch_count": perf["llm_batch_count"],
                "timeline_ms": round(llm_done_timeline_ms, 3),
            },
        )

        if final_output is None:
            return _finish(None)

        generated_text = final_output.outputs[0].text
        generated_ids = final_output.outputs[0].token_ids
        perf["llm_token_total"] = len(generated_ids)

        # Stage 3: parse BiCodec tokens from generated text
        t0 = time.perf_counter()
        semantic_ids, global_ids = self._extract_bicodec_tokens(generated_text)
        perf["llm_parse_ms"] = (time.perf_counter() - t0) * 1000
        perf["semantic_token_total"] = len(semantic_ids)
        perf["global_token_total"] = len(global_ids)
        parse_done_timeline_ms = _mark("parse_done")
        _log_perf_stage(
            "perf_token_parse_done",
            extra={
                "request_id": request_id,
                "parse_ms": round(perf["llm_parse_ms"], 3),
                "semantic_tokens": perf["semantic_token_total"],
                "global_tokens": perf["global_token_total"],
                "timeline_ms": round(parse_done_timeline_ms, 3),
            },
        )

        # Aggregate LLM batch stats
        self._apply_stats(perf, "llm_batch_wall_ms", llm_batch_wall_ms)
        self._apply_stats(perf, "llm_batch_gpu_ms", llm_batch_gpu_ms)
        self._apply_stats(perf, "llm_decode_wall_ms", llm_decode_wall_ms)
        self._apply_stats(perf, "llm_decode_gpu_ms", llm_decode_gpu_ms)
        self._apply_stats(perf, "tokens_per_batch", tokens_per_batch)

        if llm_batch_wall_ms:
            perf["llm_prefill_wall_ms"] = float(llm_batch_wall_ms[0])
            perf["llm_prefill_gpu_ms"] = float(llm_batch_gpu_ms[0]) if llm_batch_gpu_ms else 0.0
        else:
            perf["llm_prefill_wall_ms"] = float(perf["llm_generation_wall_ms"])
            perf["llm_prefill_gpu_ms"] = 0.0

        if not llm_batch_gpu_ms and final_request_metrics.get("model_execute_time") is not None:
            total_gpu_ms = float(final_request_metrics["model_execute_time"]) * 1000
            perf["llm_batch_gpu_ms_total"] = total_gpu_ms
            if perf["llm_batch_count"] > 1:
                perf["llm_decode_gpu_ms_total"] = total_gpu_ms

        if perf["llm_token_total"] > 0:
            perf["llm_time_per_token_ms"] = float(perf["llm_generation_wall_ms"]) / float(perf["llm_token_total"])
        if perf["llm_batch_count"] > 0:
            perf["llm_time_per_batch_wall_ms"] = float(perf.get("llm_batch_wall_ms_total", 0.0)) / float(perf["llm_batch_count"])
            if perf.get("llm_batch_gpu_ms_total", 0.0) > 0:
                perf["llm_time_per_batch_gpu_ms"] = float(perf["llm_batch_gpu_ms_total"]) / float(perf["llm_batch_count"])

        # Pull useful request-level vLLM metrics (seconds -> ms).
        self._apply_request_metrics(perf, final_request_metrics)

        if not semantic_ids or not global_ids:
            return _finish(None)

        # Stage 4: validate BiCodec tokens
        t0 = time.perf_counter()
        tokens_valid = self.bicodec_decoder.validate_tokens(semantic_ids, global_ids)
        perf["token_validation_ms"] = (time.perf_counter() - t0) * 1000
        _mark("validation_done")
        if not tokens_valid:
            return _finish(None)

        # Stage 5: BiCodec decode (wall + GPU/CPU split)
        decode_t0 = time.perf_counter()
        decode_gpu_ms = 0.0
        perf["llm_decode_calls"] = 1

        use_cuda_timing = (
            _PERF_GPU_TIMING_ENABLED
            and torch.cuda.is_available()
            and hasattr(self.bicodec_decoder, "device")
            and getattr(self.bicodec_decoder.device, "type", "") == "cuda"
        )
        audio_bytes: Optional[bytes] = None
        if use_cuda_timing:
            decode_completed = False
            try:
                device = self.bicodec_decoder.device
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                torch.cuda.synchronize(device=device)
                start_event.record()
                audio_bytes = self.bicodec_decoder.decode_to_bytes(semantic_ids, global_ids)
                decode_completed = True
                end_event.record()
                torch.cuda.synchronize(device=device)
                decode_gpu_ms = float(start_event.elapsed_time(end_event))
            except Exception:
                logger.debug("BiCodec CUDA timing failed; falling back to untimed decode", exc_info=True)
                # Only redo the (expensive) decode when it did not already finish;
                # a failure in event recording/elapsed_time keeps the decoded audio.
                if not decode_completed:
                    audio_bytes = self.bicodec_decoder.decode_to_bytes(semantic_ids, global_ids)
                decode_gpu_ms = 0.0
        elif _ASYNC_BICODEC_DECODE_ENABLED:
            # Optional: keep event loop responsive by offloading decode.
            audio_bytes = await self.bicodec_decoder.decode_to_bytes_async(semantic_ids, global_ids)
        else:
            audio_bytes = self.bicodec_decoder.decode_to_bytes(semantic_ids, global_ids)

        decode_wall_ms = (time.perf_counter() - decode_t0) * 1000
        perf["bicodec_decode_wall_ms"] = float(decode_wall_ms)
        perf["bicodec_decode_gpu_ms"] = float(max(0.0, decode_gpu_ms))
        perf["bicodec_decode_cpu_ms"] = float(max(0.0, decode_wall_ms - decode_gpu_ms))
        bicodec_done_timeline_ms = _mark("bicodec_done")
        _log_perf_stage(
            "perf_bicodec_decode_done",
            extra={
                "request_id": request_id,
                "decode_wall_ms": round(perf["bicodec_decode_wall_ms"], 3),
                "decode_gpu_ms": round(perf["bicodec_decode_gpu_ms"], 3),
                "decode_cpu_ms": round(perf["bicodec_decode_cpu_ms"], 3),
                "timeline_ms": round(bicodec_done_timeline_ms, 3),
            },
        )

        if audio_bytes is None:
            return _finish(None)

        # Stage 6: WAV header
        t0 = time.perf_counter()
        wav_bytes = add_wav_header(audio_bytes, sample_rate=AUDIO_SAMPLE_RATE)
        perf["wav_pack_ms"] = (time.perf_counter() - t0) * 1000
        _mark("wav_done")

        result = _finish(wav_bytes)
        _log_perf_stage(
            "perf_non_stream_done",
            extra={
                "request_id": request_id,
                "pipeline_total_ms": round(perf["pipeline_total_ms"], 3),
                "llm_token_total": perf.get("llm_token_total"),
                "timeline_total_ms": round(float(perf.get("timeline_total_ms", 0.0)), 3),
            },
        )
        return result

    async def generate_speech(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = 1.05,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio (non-streaming) using Spark TTS.

        NOTE: This method signature changed from description-based to speaker-based
        to align with Spark TTS architecture.

        Thin wrapper over ``generate_speech_profiled`` that discards the perf dict.

        Args:
            speaker: Speaker name understood by the prompt builder.
            text: Text to synthesize (with optional [emotion] tags)
            temperature: Sampling temperature (default: DEFAULT_TEMPERATURE)
            top_k: Top-k sampling (default: DEFAULT_TOP_K)
            top_p: Nucleus sampling (default: DEFAULT_TOP_P)
            max_tokens: Max tokens to generate (default: DEFAULT_MAX_TOKENS)
            repetition_penalty: Repetition penalty (default: 1.05)
            seed: Random seed for reproducibility

        Returns:
            WAV bytes (decoder audio with a header at AUDIO_SAMPLE_RATE) or None if failed
        """
        wav_bytes, _ = await self.generate_speech_profiled(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
        return wav_bytes

    async def generate_speech_indic(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = 1.05,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio for Spark TTS (speaker-based).

        This method is kept for backward compatibility with existing API.
        It delegates to generate_speech() with the same arguments.

        Args:
            speaker: Speaker name understood by the prompt builder.
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome."
                - "[laughs] Hello there!"
                - "Hello <laugh> this is fun!" (normalization, if any, is up to the prompt builder)
                - "नमस्ते! [excited] आज का दिन बहुत अच्छा है।"
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Repetition penalty
            seed: Random seed for reproducibility

        Returns:
            WAV bytes (decoder audio with a header at AUDIO_SAMPLE_RATE) or None if failed
        """
        # Delegate to generate_speech (same implementation for Spark TTS)
        return await self.generate_speech(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )

    def _extract_bicodec_tokens(self, generated_text: str) -> Tuple[List[int], List[int]]:
        """
        Extract BiCodec semantic and global tokens from generated text using regex.

        Spark TTS generates tokens in the format:
        - <|bicodec_semantic_{id}|>
        - <|bicodec_global_{id}|>

        Args:
            generated_text: Generated text containing BiCodec token markers

        Returns:
            Tuple of (semantic_ids, global_ids) in order of appearance.
            Returns ([], []) if no tokens found (and logs an error).
        """
        semantic_ids = [int(t) for t in self._SEMANTIC_TOKEN_RE.findall(generated_text)]
        global_ids = [int(t) for t in self._GLOBAL_TOKEN_RE.findall(generated_text)]

        # Only log if there's an issue
        if not semantic_ids and not global_ids:
            logger.error(
                "❌ No BiCodec tokens found! Generated text (first 1000 chars):\n%s",
                generated_text[:1000],
            )

        return semantic_ids, global_ids
