"""
Spark TTS BiCodec Decoder

Decodes BiCodec audio tokens (semantic + global) to audio waveforms.
Replaces SNAC decoder for Spark TTS model.
"""

import asyncio
import logging
import math
import os
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

# Ensure `sparktts` (vendored under external/) is importable in local/dev runs.
# In Modal, PYTHONPATH already includes `/root/external/sparktts`.
try:
    from sparktts.models.audio_tokenizer import BiCodecTokenizer
except ImportError:  # pragma: no cover - environment-dependent
    # Search this file's directory and then every ancestor directory for an
    # `external/sparktts` folder; the nearest match wins.
    # (`_this_file.parent` is also `_this_file.parents[0]`, so the first
    # directory is simply checked twice — harmless.)
    _this_file = Path(__file__).resolve()
    _sparktts_path = None
    for _p in [_this_file.parent] + list(_this_file.parents):
        _candidate = _p / "external" / "sparktts"
        if _candidate.is_dir():
            _sparktts_path = str(_candidate)
            break
    # Prepend the vendored directory to the import path (once).
    if _sparktts_path and _sparktts_path not in sys.path:
        sys.path.insert(0, _sparktts_path)
    # Retry; if no vendored copy was found this presumably raises ImportError again.
    from sparktts.models.audio_tokenizer import BiCodecTokenizer

from veena3modal.core.constants import (
    BICODEC_TOKENIZER_PATH,
    AUDIO_SAMPLE_RATE,
)

logger = logging.getLogger(__name__)


def _env_flag(name: str, default: bool) -> bool:
    """
    Read a boolean flag from environment variable ``name``.

    Matching is case-insensitive and ignores surrounding whitespace:
    ``1/true/yes/on`` -> True, ``0/false/no/off`` -> False. Unset, blank or
    unrecognised values return ``default`` — the same fallback policy that
    ``_env_int`` / ``_env_float`` apply to unparsable input, so a typo can no
    longer silently switch off a default-on flag.
    """
    value = os.environ.get(name, "").strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    return default


def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """
    Read an integer from environment variable ``name``, clamped to ``minimum``.

    Unset, empty, or non-integer values yield ``default`` as-is (the default
    is not clamped).
    """
    text = os.environ.get(name, "")
    if text:
        try:
            parsed = int(text)
        except ValueError:
            parsed = None
        if parsed is not None:
            return max(minimum, parsed)
    return default


def _env_float(name: str, default: float, minimum: float = 0.0) -> float:
    """
    Read a float from environment variable ``name``, clamped to ``minimum``.

    Unset, empty, unparsable, or non-finite ("inf", "nan") values yield
    ``default`` unchanged. Non-finite values are rejected because ``float()``
    accepts them and e.g. an infinite batch timeout would make the decoder's
    micro-batcher wait until a batch is completely full.
    """
    raw = os.environ.get(name, "")
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        return default
    if not math.isfinite(value):
        return default
    return max(minimum, value)


@dataclass
class _BatchDecodeRequest:
    """One decode_single_async() request waiting in the BiCodec micro-batch queue."""

    semantic_ids: List[int]  # semantic token ids for this request
    global_ids: List[int]  # global token ids; normalised to 32 only at decode time
    use_sliding_window: bool  # forwarded to decode_streaming
    trim_warmup: bool  # forwarded to decode_streaming (documented there as legacy/unused)
    future: asyncio.Future  # resolved by a batch worker with the decoded bytes/None, or an exception
    submitted_at: float  # time.perf_counter() at enqueue; used for queue-wait statistics


class BiCodecDecoder:
    """
    BiCodec Decoder for Spark TTS.

    Decodes semantic and global tokens to audio waveforms using BiCodec.
    This replaces the SNAC decoder used in the old Indic Orpheus model.

    Supports streaming with sliding window approach (like SNAC). When enabled
    (CUDA only), ``decode_single_async`` routes non-sliding-window requests
    through a per-event-loop async micro-batcher that decodes concurrent
    requests in one batched BiCodec forward pass.
    """

    # BiCodec conditions every decode on exactly this many global tokens.
    EXPECTED_GLOBAL_TOKENS = 32
    # decode_streaming refuses shorter semantic sequences (unstable decode).
    MIN_SEMANTIC_TOKENS = 8
    # ~320 output samples per semantic token (16 kHz / 50 TPS).
    SAMPLES_PER_SEMANTIC_TOKEN = 320
    # Sliding-window decodes keep at most this many centre samples.
    SLIDING_WINDOW_SAMPLES = 4096

    def __init__(
        self,
        device: str = "cuda",
        model_path: str = BICODEC_TOKENIZER_PATH,
        enable_batching: bool = False,  # For compatibility with streaming pipeline
    ):
        """
        Initialize BiCodec decoder.

        Args:
            device: Device for BiCodec model (cuda/cpu). "cpu" is used instead
                whenever ``torch.cuda.is_available()`` is False.
            model_path: Path to BiCodec model checkpoint
            enable_batching: Request the async micro-batcher used by
                ``decode_single_async``. Can be overridden with
                VEENA3_BICODEC_BATCHING_FORCE_ON / VEENA3_BICODEC_BATCHING_FORCE_OFF
                (OFF wins if both are set) and is only honoured on CUDA.

        Environment (micro-batcher tuning):
            VEENA3_BICODEC_BATCH_MAX: max requests per batch (default 16).
            VEENA3_BICODEC_BATCH_TIMEOUT_MS: how long a worker keeps filling a
                batch after its first request (default 6.0, min 0.1).
            VEENA3_BICODEC_BATCH_WORKERS: worker tasks per event loop (default 1).
            VEENA3_BICODEC_BATCH_SCALE_PENDING: pending-request threshold for
                running more than one worker; 0 means always run all
                configured workers (default 0).
            VEENA3_BICODEC_BATCH_SCALE_MODE: "sticky" (default) or "dynamic".
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        requested_batching = bool(enable_batching)
        if _env_flag("VEENA3_BICODEC_BATCHING_FORCE_ON", False):
            requested_batching = True
        # Evaluated second so FORCE_OFF wins when both flags are set.
        if _env_flag("VEENA3_BICODEC_BATCHING_FORCE_OFF", False):
            requested_batching = False
        self.enable_batching = requested_batching and self.device.type == "cuda"
        self.batch_max_size = _env_int("VEENA3_BICODEC_BATCH_MAX", 16, minimum=1)
        self.batch_timeout_s = _env_float("VEENA3_BICODEC_BATCH_TIMEOUT_MS", 6.0, minimum=0.1) / 1000.0
        self.batch_worker_count = _env_int("VEENA3_BICODEC_BATCH_WORKERS", 1, minimum=1)
        self.batch_worker_scale_pending = _env_int("VEENA3_BICODEC_BATCH_SCALE_PENDING", 0, minimum=0)
        self.batch_worker_scale_mode = os.environ.get("VEENA3_BICODEC_BATCH_SCALE_MODE", "sticky").strip().lower()
        if self.batch_worker_scale_mode not in {"sticky", "dynamic"}:
            self.batch_worker_scale_mode = "sticky"

        # Async micro-batcher state, keyed by id(event loop). The loop objects
        # are kept in _batch_loops so state for closed loops can be pruned.
        self._batch_map_lock = threading.Lock()
        self._batch_loops: Dict[int, asyncio.AbstractEventLoop] = {}
        self._batch_queues: Dict[int, asyncio.Queue] = {}
        self._batch_workers: Dict[int, List[asyncio.Task]] = {}
        self._batch_scaled_loops: Dict[int, bool] = {}
        self._batch_total_batches = 0
        self._batch_total_requests = 0
        self._batch_max_seen = 0
        self._batch_queue_wait_ms_total = 0.0
        self._batch_queue_wait_ms_min: Optional[float] = None
        self._batch_queue_wait_ms_max = 0.0
        self._batch_compute_ms_total = 0.0
        self._batch_compute_ms_min: Optional[float] = None
        self._batch_compute_ms_max = 0.0
        self._batch_queue_depth_total = 0
        self._batch_queue_depth_max = 0
        self._batch_inflight_requests = 0
        self._batch_workers_target_last = 1

        print(f"🎵 Loading BiCodec audio tokenizer from {model_path}...")
        print(f"   Device: {self.device}")

        # Initialize BiCodec tokenizer
        self.audio_tokenizer = BiCodecTokenizer(model_path, device=str(self.device))

        # OPTIMIZATION: Ensure model is on device once at init, not per-decode call
        self.audio_tokenizer.device = self.device
        self.audio_tokenizer.model.to(self.device)

        # NOTE: torch.compile disabled for BiCodec -- dynamic semantic sequence lengths
        # cause repeated recompilation (~120s each), worse than eager mode.
        # torch.compile works best with static shapes. If needed, pre-warm with
        # representative lengths or use mode="default" instead of "reduce-overhead".
        # The multi-engine approach (Tier 3) provides more impactful gains.

        print(f"✅ BiCodec decoder initialized (sample rate: {AUDIO_SAMPLE_RATE}Hz)")
        if self.enable_batching:
            print(
                "   ℹ️  Micro-batching enabled "
                f"(workers={self.batch_worker_count}, scale_mode={self.batch_worker_scale_mode}, "
                f"scale_pending={self.batch_worker_scale_pending}, "
                f"max_batch={self.batch_max_size}, timeout_ms={self.batch_timeout_s * 1000:.1f})"
            )

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @classmethod
    def _normalize_global_ids(cls, global_ids: List[int]) -> List[int]:
        """Return exactly EXPECTED_GLOBAL_TOKENS ids: extras truncated, shortfall right-padded with 0."""
        expected = cls.EXPECTED_GLOBAL_TOKENS
        normalized = list(global_ids[:expected])
        normalized.extend([0] * (expected - len(normalized)))
        return normalized

    @staticmethod
    def _to_pcm16_bytes(audio) -> bytes:
        """Clip float audio to [-1.0, 1.0] and serialise it as int16 PCM bytes (native byte order)."""
        clipped = np.clip(audio, -1.0, 1.0)
        return (clipped * 32767).astype(np.int16).tobytes()

    # ------------------------------------------------------------------
    # Synchronous decode paths
    # ------------------------------------------------------------------

    def decode(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[np.ndarray]:
        """
        Decode BiCodec tokens to audio waveform.

        CRITICAL: BiCodec expects EXACTLY 32 global tokens!
        The decoder pools them via speaker_encoder → d_vector, then broadcasts
        via d_vector.unsqueeze(-1) across the time dimension. Any other count
        is truncated or zero-padded to 32, with a warning.

        Args:
            semantic_ids: List of semantic token IDs (variable length)
            global_ids: List of global token IDs (MUST be exactly 32!)

        Returns:
            Audio waveform as numpy array (float32, 16kHz mono)
            Shape: (samples,)
            Returns None if either token list is empty/None or decode fails
        """
        if not semantic_ids or not global_ids:
            # Count defensively so a None input produces this warning rather
            # than a TypeError from len(None).
            semantic_count = len(semantic_ids) if semantic_ids else 0
            global_count = len(global_ids) if global_ids else 0
            print(f"⚠️  Empty token lists: semantic={semantic_count}, global={global_count}")
            return None

        expected = self.EXPECTED_GLOBAL_TOKENS
        if len(global_ids) != expected:
            print(f"❌ BiCodec requires EXACTLY {expected} global tokens, got {len(global_ids)}")
            action = "Truncated" if len(global_ids) > expected else "Padded"
            global_ids = self._normalize_global_ids(global_ids)
            print(f"   └─ {action} to {expected}")

        try:
            # Shapes follow the official Spark-TTS implementation:
            #   global tokens:   (1, 32) - ALWAYS 32
            #   semantic tokens: (1, N)  - variable length
            # BiCodecTokenizer.detokenize() unsqueezes the global tokens internally.
            # The model was moved to self.device once in __init__.
            pred_semantic = torch.tensor(semantic_ids).long().unsqueeze(0).to(self.device)
            pred_global = torch.tensor(global_ids).long().unsqueeze(0).to(self.device)
            with torch.inference_mode():
                return self.audio_tokenizer.detokenize(pred_global, pred_semantic)
        except Exception as e:
            # Best-effort contract: report the failure and return None.
            print(f"❌ BiCodec decode error: {e}")
            import traceback
            traceback.print_exc()
            return None

    def decode_to_bytes(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens to audio bytes (int16 PCM).

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            Audio as bytes (int16 PCM, 16kHz mono)
            Returns None if decode fails
        """
        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None
        return self._to_pcm16_bytes(audio)

    async def decode_to_bytes_async(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[bytes]:
        """
        Async wrapper for decode_to_bytes.

        Runs decode in the default executor so the event loop stays responsive
        under high request concurrency.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            self.decode_to_bytes,
            semantic_ids,
            global_ids,
        )

    def decode_streaming(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        use_sliding_window: bool = False,
        trim_warmup: bool = False,
    ) -> Optional[bytes]:
        """
        Decode BiCodec tokens with streaming support.

        CRITICAL: BiCodec decoder expects EXACTLY 32 global tokens always!
        The decoder internally broadcasts d_vector via d_vector.unsqueeze(-1).

        From sparktts/models/bicodec.py line 184-186:
            d_vector = self.speaker_encoder.detokenize(global_tokens)  # Expects 32 tokens
            x = self.prenet(z_q, d_vector)
            x = x + d_vector.unsqueeze(-1)  # Broadcasts across time automatically!

        Args:
            semantic_ids: List of semantic token IDs (variable length, 50 TPS)
            global_ids: List of global token IDs (MUST be exactly 32 tokens!)
            use_sliding_window: If True, return only middle samples
            trim_warmup: Legacy parameter (not used for BiCodec)

        Returns:
            Audio bytes (int16 PCM) or None if decode fails or fewer than
            MIN_SEMANTIC_TOKENS semantic tokens are given
        """
        semantic_count = len(semantic_ids) if semantic_ids is not None else 0
        if semantic_count < self.MIN_SEMANTIC_TOKENS:
            print(f"⚠️  Too few semantic tokens ({semantic_count}) for decode, need >= {self.MIN_SEMANTIC_TOKENS}")
            return None

        # NOTE(review): unlike decode(), an EMPTY global list is zero-padded to
        # 32 tokens here rather than rejected — confirm this is intended.
        expected = self.EXPECTED_GLOBAL_TOKENS
        if len(global_ids) != expected:
            print(f"⚠️  WARNING: Got {len(global_ids)} global tokens, expected {expected}")
            action = "Truncated" if len(global_ids) > expected else "Padded"
            global_ids = self._normalize_global_ids(global_ids)
            print(f"   └─ {action} to {expected} tokens")

        audio = self.decode(semantic_ids, global_ids)
        if audio is None:
            return None

        # Sliding window: keep only the centre portion for overlap-add streaming.
        if use_sliding_window and len(audio) > self.SLIDING_WINDOW_SAMPLES:
            total_samples = len(audio)
            keep_samples = min(self.SLIDING_WINDOW_SAMPLES, total_samples // 2)
            start = (total_samples - keep_samples) // 2
            audio = audio[start:start + keep_samples]

        return self._to_pcm16_bytes(audio)

    # ------------------------------------------------------------------
    # Batched decode (runs in an executor thread)
    # ------------------------------------------------------------------

    def _decode_streaming_batch_sync(
        self,
        batch_items: List[Tuple[List[int], List[int], bool, bool]],
    ) -> List[Optional[bytes]]:
        """
        Run a single batched BiCodec forward pass for multiple stream decode requests.

        Each item is ``(semantic_ids, global_ids, use_sliding_window, trim_warmup)``.
        Falls back to sequential decode_streaming if shapes or settings are not
        batch-safe, or if the batched pass raises.

        Returns:
            One int16 PCM ``bytes`` (or None) per item, in input order.
        """
        if not batch_items:
            return []

        # A single request gains nothing from batching; sliding-window decodes
        # need per-request centre trimming; and decode_streaming is left to
        # reject (and report) inputs that are too short to decode.
        if (
            len(batch_items) == 1
            or any(use_sw for _, _, use_sw, _ in batch_items)
            or any(len(sem) < self.MIN_SEMANTIC_TOKENS for sem, _, _, _ in batch_items)
        ):
            return self._decode_streaming_sequential(batch_items)

        try:
            return self._decode_padded_batch(batch_items)
        except Exception:
            logger.exception("BiCodec batched decode failed; falling back to sequential decode")
            return self._decode_streaming_sequential(batch_items)

    def _decode_streaming_sequential(
        self,
        batch_items: List[Tuple[List[int], List[int], bool, bool]],
    ) -> List[Optional[bytes]]:
        """Decode each item independently via decode_streaming."""
        return [
            self.decode_streaming(sem, glob, use_sliding_window=use_sw, trim_warmup=trim)
            for sem, glob, use_sw, trim in batch_items
        ]

    def _decode_padded_batch(
        self,
        batch_items: List[Tuple[List[int], List[int], bool, bool]],
    ) -> List[bytes]:
        """
        Decode two or more non-sliding-window requests in one forward pass.

        Semantic sequences are right-padded with token id 0 to a common length
        and each output is cut back to ``len(semantic_ids) * 320`` samples.
        NOTE(review): padded positions may still influence the tail of shorter
        requests if the decoder mixes context across time — confirm acceptable.

        Raises:
            RuntimeError: if the model output does not have one row per request.
        """
        batch_size = len(batch_items)
        semantic_lengths = [len(sem) for sem, _, _, _ in batch_items]
        max_sem_len = max(semantic_lengths)

        # Build rows on the host and transfer each tensor once, instead of one
        # small host->device copy per request.
        semantic_rows = [
            list(sem) + [0] * (max_sem_len - len(sem))
            for sem, _, _, _ in batch_items
        ]
        global_rows = [self._normalize_global_ids(glob) for _, glob, _, _ in batch_items]
        semantic_batch = torch.tensor(semantic_rows, dtype=torch.long, device=self.device)
        global_batch = torch.tensor(global_rows, dtype=torch.long, device=self.device)

        with torch.inference_mode():
            wav_batch = self.audio_tokenizer.detokenize(global_batch, semantic_batch)

        if isinstance(wav_batch, torch.Tensor):
            wav_batch = wav_batch.detach().float().cpu().numpy()
        wav_batch = np.asarray(wav_batch)
        # A 1-D result would otherwise be handed, whole, to every request.
        if wav_batch.ndim < 2 or wav_batch.shape[0] != batch_size:
            raise RuntimeError(
                f"unexpected batched BiCodec output shape {wav_batch.shape} for batch size {batch_size}"
            )
        wav_batch = wav_batch.reshape(batch_size, -1)

        return [
            self._to_pcm16_bytes(wav_batch[row, :sem_len * self.SAMPLES_PER_SEMANTIC_TOKEN])
            for row, sem_len in enumerate(semantic_lengths)
        ]

    # ------------------------------------------------------------------
    # Async micro-batcher
    # ------------------------------------------------------------------

    def _forget_closed_loops_locked(self) -> None:
        """
        Drop micro-batcher state that belongs to event loops that are closed.

        Caller must hold ``_batch_map_lock``. Without this, a closed loop's
        queue (and any requests stranded in it) stays referenced forever and
        keeps being counted by get_pending_requests().
        """
        closed_keys = [key for key, owner in self._batch_loops.items() if owner.is_closed()]
        for key in closed_keys:
            self._batch_loops.pop(key, None)
            self._batch_queues.pop(key, None)
            self._batch_workers.pop(key, None)
            self._batch_scaled_loops.pop(key, None)

    def _ensure_batch_worker(self, loop: asyncio.AbstractEventLoop) -> asyncio.Queue:
        """
        Return the request queue for ``loop``, creating it and spawning worker
        tasks as needed.

        Target worker count: ``batch_worker_count``, unless pending-based
        scaling is configured (workers > 1 and scale_pending > 0), in which
        case only one worker runs until queued + in-flight requests reach
        ``batch_worker_scale_pending``. In "sticky" mode a loop that has
        scaled up stays scaled up. Existing live workers are never stopped here.
        """
        key = id(loop)
        with self._batch_map_lock:
            self._forget_closed_loops_locked()
            self._batch_loops[key] = loop

            queue = self._batch_queues.get(key)
            if queue is None:
                queue = asyncio.Queue(maxsize=max(self.batch_max_size * 8, 256))
                self._batch_queues[key] = queue
            live_workers = [worker for worker in self._batch_workers.get(key, []) if not worker.done()]

            target_workers = self.batch_worker_count
            if self.batch_worker_count > 1 and self.batch_worker_scale_pending > 0:
                pending_now = queue.qsize() + self._batch_inflight_requests
                if self.batch_worker_scale_mode == "sticky":
                    should_scale = (
                        self._batch_scaled_loops.get(key, False)
                        or pending_now >= self.batch_worker_scale_pending
                    )
                    self._batch_scaled_loops[key] = should_scale
                    if not should_scale:
                        target_workers = 1
                elif pending_now < self.batch_worker_scale_pending:
                    target_workers = 1

            while len(live_workers) < target_workers:
                worker_idx = len(live_workers)
                live_workers.append(
                    loop.create_task(
                        self._batch_worker(loop, queue, worker_idx),
                        name=f"bicodec-batch-worker-{key}-{worker_idx}",
                    )
                )

            self._batch_workers_target_last = len(live_workers)
            self._batch_workers[key] = live_workers
            return queue

    def _desired_worker_count(self, queue: Optional[asyncio.Queue] = None) -> int:
        """
        Number of workers the dynamic scaler wants right now.

        All configured workers when scaling is disabled or when queued +
        in-flight requests have reached ``batch_worker_scale_pending``;
        otherwise 1.
        """
        if self.batch_worker_count <= 1:
            return 1
        if self.batch_worker_scale_pending <= 0:
            return self.batch_worker_count
        with self._batch_map_lock:
            pending_now = self._batch_inflight_requests
        if queue is not None:
            pending_now += queue.qsize()
        if pending_now >= self.batch_worker_scale_pending:
            return self.batch_worker_count
        return 1

    async def _batch_worker(
        self,
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
        worker_idx: int,
    ) -> None:
        """
        Long-running task: repeatedly collect a batch from ``queue`` and decode it.

        In "dynamic" scale mode a worker whose index is at or above the current
        desired worker count polls (every <= 2 ms) instead of taking requests.
        If the task is cancelled, requests it has already dequeued get their
        futures cancelled so their callers are not left awaiting forever.
        """
        while True:
            if self.batch_worker_scale_mode == "dynamic":
                target_workers = self._desired_worker_count(queue)
                self._batch_workers_target_last = target_workers
                if worker_idx >= target_workers:
                    await asyncio.sleep(min(self.batch_timeout_s, 0.002))
                    continue

            batch: List[_BatchDecodeRequest] = []
            try:
                await self._fill_batch(loop, queue, batch)
                # Callers cancelled while queued already have done futures;
                # decoding for them would only waste GPU time.
                live_batch = [req for req in batch if not req.future.done()]
                if live_batch:
                    await self._run_batch(loop, queue, live_batch, worker_idx)
            except asyncio.CancelledError:
                for req in batch:
                    if not req.future.done():
                        req.future.cancel()
                raise

    async def _fill_batch(
        self,
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
        batch: List[_BatchDecodeRequest],
    ) -> None:
        """
        Append requests from ``queue`` to ``batch`` in place.

        Blocks for the first request, then keeps collecting until the batch
        holds ``batch_max_size`` requests or ``batch_timeout_s`` has elapsed
        since the first one. Appending in place lets the caller see partially
        collected requests if this coroutine is cancelled.
        """
        batch.append(await queue.get())
        deadline = loop.time() + self.batch_timeout_s
        while len(batch) < self.batch_max_size:
            try:
                # Take already-queued requests without wait_for's per-call task overhead.
                batch.append(queue.get_nowait())
                continue
            except asyncio.QueueEmpty:
                pass
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break

    async def _run_batch(
        self,
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
        batch: List[_BatchDecodeRequest],
        worker_idx: int,
    ) -> None:
        """
        Decode one collected batch in the default executor and resolve its futures.

        Decode failures are delivered to every still-waiting caller as the
        future's exception. Statistics and the in-flight counter are updated
        whether the decode succeeds, fails, or is cancelled.
        """
        batch_inputs = [
            (req.semantic_ids, req.global_ids, req.use_sliding_window, req.trim_warmup)
            for req in batch
        ]
        self._record_queue_stats(queue, batch, time.perf_counter())
        with self._batch_map_lock:
            self._batch_inflight_requests += len(batch)

        t0 = time.perf_counter()
        try:
            results = await loop.run_in_executor(None, self._decode_streaming_batch_sync, batch_inputs)
            if len(results) != len(batch):
                raise RuntimeError("batch decode result size mismatch")
            for req, result in zip(batch, results):
                if not req.future.done():
                    req.future.set_result(result)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.exception("BiCodec batch worker %d failed", worker_idx)
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(exc)
        finally:
            self._record_compute_stats((time.perf_counter() - t0) * 1000.0, len(batch))
            with self._batch_map_lock:
                self._batch_inflight_requests = max(0, self._batch_inflight_requests - len(batch))

    def _record_queue_stats(
        self,
        queue: asyncio.Queue,
        batch: List[_BatchDecodeRequest],
        batch_start: float,
    ) -> None:
        """Accumulate remaining-queue-depth and per-request queue-wait statistics for one batch."""
        queue_depth_now = queue.qsize()
        self._batch_queue_depth_total += queue_depth_now
        self._batch_queue_depth_max = max(self._batch_queue_depth_max, queue_depth_now)
        for req in batch:
            wait_ms = max(0.0, (batch_start - req.submitted_at) * 1000.0)
            self._batch_queue_wait_ms_total += wait_ms
            if self._batch_queue_wait_ms_min is None or wait_ms < self._batch_queue_wait_ms_min:
                self._batch_queue_wait_ms_min = wait_ms
            self._batch_queue_wait_ms_max = max(self._batch_queue_wait_ms_max, wait_ms)

    def _record_compute_stats(self, compute_ms: float, batch_size: int) -> None:
        """Accumulate compute-time and batch-size statistics for one batch."""
        self._batch_compute_ms_total += compute_ms
        if self._batch_compute_ms_min is None or compute_ms < self._batch_compute_ms_min:
            self._batch_compute_ms_min = compute_ms
        self._batch_compute_ms_max = max(self._batch_compute_ms_max, compute_ms)
        self._batch_total_batches += 1
        self._batch_total_requests += batch_size
        self._batch_max_seen = max(self._batch_max_seen, batch_size)

    def get_pending_requests(self) -> int:
        """
        Queued + in-flight micro-batch requests across all live event loops.

        Always 0 when batching is disabled. Queues belonging to closed loops
        are ignored: their requests can never be served.
        """
        if not self.enable_batching:
            return 0
        with self._batch_map_lock:
            queued = 0
            for key, queue in self._batch_queues.items():
                owner = self._batch_loops.get(key)
                if owner is not None and owner.is_closed():
                    continue
                queued += queue.qsize()
            return int(max(0, queued + self._batch_inflight_requests))

    def get_batching_stats(self) -> Dict[str, float]:
        """
        Snapshot of micro-batcher counters.

        Averages divide by max(1, count), so they read 0.0 before the first
        batch. NOTE: "batch_worker_scale_mode" is a str despite the
        ``Dict[str, float]`` annotation.
        """
        batches = max(1, self._batch_total_batches)
        reqs = max(1, self._batch_total_requests)
        with self._batch_map_lock:
            workers_live = sum(
                1
                for workers in self._batch_workers.values()
                for worker in workers
                if not worker.done()
            )
        return {
            "enabled": 1.0 if self.enable_batching else 0.0,
            "batch_total_batches": float(self._batch_total_batches),
            "batch_total_requests": float(self._batch_total_requests),
            "batch_avg_size": float(self._batch_total_requests / batches),
            "batch_max_seen": float(self._batch_max_seen),
            "batch_workers_configured": float(self.batch_worker_count),
            "batch_workers_live": float(workers_live),
            "batch_workers_target": float(self._batch_workers_target_last),
            "batch_worker_scale_pending": float(self.batch_worker_scale_pending),
            "batch_worker_scale_mode": self.batch_worker_scale_mode,
            "batch_queue_wait_ms_avg": float(self._batch_queue_wait_ms_total / reqs),
            "batch_queue_wait_ms_min": float(self._batch_queue_wait_ms_min or 0.0),
            "batch_queue_wait_ms_max": float(self._batch_queue_wait_ms_max),
            "batch_compute_ms_avg": float(self._batch_compute_ms_total / batches),
            "batch_compute_ms_min": float(self._batch_compute_ms_min or 0.0),
            "batch_compute_ms_max": float(self._batch_compute_ms_max),
            "batch_queue_depth_avg": float(self._batch_queue_depth_total / batches),
            "batch_queue_depth_max": float(self._batch_queue_depth_max),
            "batch_pending_now": float(self.get_pending_requests()),
        }

    # ------------------------------------------------------------------
    # Misc helpers
    # ------------------------------------------------------------------

    def validate_tokens(self, semantic_ids: List[int], global_ids: List[int]) -> bool:
        """
        Validate BiCodec tokens before decoding.

        Only checks that both lists are non-empty; counts are not validated
        (decode paths truncate/pad global tokens themselves).

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs

        Returns:
            True if valid, False otherwise
        """
        if not semantic_ids:
            print("❌ No semantic tokens")
            return False
        if not global_ids:
            print("❌ No global tokens")
            return False
        return True

    def get_audio_duration(self, semantic_ids: List[int]) -> float:
        """
        Estimate audio duration from semantic tokens.

        Assumes SAMPLES_PER_SEMANTIC_TOKEN (320) output samples per token; the
        real duration depends on the model output.

        Args:
            semantic_ids: List of semantic token IDs

        Returns:
            Estimated duration in seconds
        """
        estimated_samples = len(semantic_ids) * self.SAMPLES_PER_SEMANTIC_TOKEN
        return estimated_samples / AUDIO_SAMPLE_RATE

    async def decode_single_async(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
        trim_warmup: bool = False,
        use_sliding_window: bool = False,
    ) -> Optional[bytes]:
        """
        Async streaming decode that never blocks the event loop on GPU work.

        With batching enabled and ``use_sliding_window`` False, the request is
        queued for this event loop's micro-batcher and may be decoded together
        with other concurrent requests. Otherwise decode_streaming runs in the
        default executor.

        Args:
            semantic_ids: List of semantic token IDs
            global_ids: List of global token IDs
            trim_warmup: Legacy parameter (not used)
            use_sliding_window: Use sliding window mode (bypasses batching)

        Returns:
            Audio bytes (int16 PCM) or None

        Raises:
            Whatever exception a failed batched decode propagated to the
            request's future (batched path only).
        """
        loop = asyncio.get_running_loop()
        if self.enable_batching and not use_sliding_window:
            queue = self._ensure_batch_worker(loop)
            future: asyncio.Future = loop.create_future()
            req = _BatchDecodeRequest(
                # Copy so later caller-side mutation cannot affect the queued request.
                semantic_ids=list(semantic_ids),
                global_ids=list(global_ids),
                use_sliding_window=use_sliding_window,
                trim_warmup=trim_warmup,
                future=future,
                submitted_at=time.perf_counter(),
            )
            await queue.put(req)  # waits (backpressure) if the queue is full
            return await future

        return await loop.run_in_executor(
            None,
            self.decode_streaming,
            semantic_ids,
            global_ids,
            use_sliding_window,
            trim_warmup,
        )