"""
BiCodec Batch Decoder - Batches concurrent decode requests into single GPU forward passes.

OPTIMIZATION: Instead of N sequential ~30ms decodes (one per concurrent stream),
batches them into 1 GPU call of ~35-40ms. For 10 concurrent streams, this saves
~260ms of serialized GPU time per decode round.

Architecture:
    Each streaming coroutine submits a decode request to the batch queue.
    A background worker collects requests within a timeout window,
    pads semantic sequences to equal length, runs a single batched decode,
    then dispatches results back to individual futures.

Usage:
    # At startup:
    batch_decoder = BiCodecBatchDecoder(bicodec_decoder, max_batch=16, timeout_ms=10)
    batch_decoder.start()
    
    # In streaming hot loop:
    audio_bytes = await batch_decoder.decode_async(semantic_ids, global_ids)
    
    # At shutdown:
    batch_decoder.stop()
"""

import asyncio
import logging
import queue
import threading
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch

# Module-level logger; handlers/levels are expected to be configured by the application.
logger = logging.getLogger(__name__)


@dataclass
class DecodeRequest:
    """A single decode request from a streaming coroutine.

    Created on the event loop thread and consumed by the batch worker
    thread, which fulfils `future` via `loop.call_soon_threadsafe`.
    """
    # Semantic token ids to decode (variable length per request).
    semantic_ids: List[int]
    # Speaker/global token ids (expected to be 32 tokens -- see the batch decoder).
    global_ids: List[int]
    # Resolved with the decoded audio bytes (or an exception) when the batch completes.
    future: asyncio.Future
    # Event loop the future belongs to; results must be set via this loop, thread-safely.
    loop: asyncio.AbstractEventLoop
    # Wall-clock time.time() at submission (for latency accounting).
    submitted_at: float


class BiCodecBatchDecoder:
    """
    Batches concurrent BiCodec decode requests into single GPU forward passes.

    Thread Safety:
        - decode_async() is called from async coroutines (event loop thread)
        - _worker_loop() runs in a background thread
        - Communication via a thread-safe queue.Queue. (asyncio.Queue is NOT
          thread-safe and cannot be consumed from a foreign thread, which is
          why a plain queue.Queue is used as the bridge.)
        - Results dispatched back to each request's own loop via
          loop.call_soon_threadsafe()
    """

    def __init__(
        self,
        bicodec_decoder,
        max_batch: int = 16,
        timeout_ms: float = 10.0,
    ):
        """
        Args:
            bicodec_decoder: BiCodecDecoder instance (used for single-item decode fallback)
            max_batch: Max requests to batch together
            timeout_ms: Max time to wait for batch to fill before executing
        """
        self.decoder = bicodec_decoder
        self.max_batch = max_batch
        self.timeout_ms = timeout_ms

        # Thread-safe bridge between the event loop (producer) and the
        # worker thread (consumer). Unbounded, so put_nowait() never raises.
        self._queue: queue.Queue = queue.Queue()
        self._running = False
        self._worker_thread: Optional[threading.Thread] = None

        # Stats
        self.total_batches = 0
        self.total_requests = 0
        self.total_batch_time_ms = 0.0

    def start(self):
        """Start the background batch worker (no-op if already running)."""
        if self._running:
            return
        self._running = True
        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()
        logger.info(f"BiCodecBatchDecoder started (max_batch={self.max_batch}, timeout={self.timeout_ms}ms)")

    def stop(self):
        """Stop the background worker; it notices within its ~1s poll interval."""
        self._running = False
        if self._worker_thread:
            self._worker_thread.join(timeout=5)
            self._worker_thread = None

    async def decode_async(
        self,
        semantic_ids: List[int],
        global_ids: List[int],
    ) -> Optional[bytes]:
        """
        Submit a decode request and wait for the batched result.

        This is the main API for streaming coroutines.
        Returns audio_bytes (int16 PCM), or None if extraction of this
        item failed; re-raises any exception the batch execution raised.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        request = DecodeRequest(
            semantic_ids=semantic_ids,
            global_ids=global_ids,
            future=future,
            loop=loop,
            submitted_at=time.time(),
        )

        # The queue is unbounded, so this never blocks the event loop.
        self._queue.put_nowait(request)
        return await future

    def _worker_loop(self):
        """Background thread that collects and executes batched decodes."""
        while self._running:
            # Block for the first request, with a short poll timeout so
            # stop() is honored promptly.
            try:
                first = self._queue.get(timeout=1.0)
            except queue.Empty:
                continue

            # Try to fill the batch within timeout_ms (monotonic deadline).
            batch: List["DecodeRequest"] = [first]
            deadline = time.monotonic() + self.timeout_ms / 1000.0
            while len(batch) < self.max_batch:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    batch.append(self._queue.get(timeout=remaining))
                except queue.Empty:
                    break

            # Execute the batch.
            t_start = time.monotonic()
            try:
                if len(batch) == 1:
                    # Single request: no padding needed.
                    result = self.decoder.decode_streaming(
                        semantic_ids=batch[0].semantic_ids,
                        global_ids=batch[0].global_ids,
                        use_sliding_window=False,
                        trim_warmup=False,
                    )
                    self._dispatch_result(batch[0], result)
                else:
                    # Batched decode: pad sequences and run together.
                    results = self._batched_decode(batch)
                    for req, result in zip(batch, results):
                        self._dispatch_result(req, result)
            except Exception as e:
                for req in batch:
                    self._dispatch_error(req, e)

            self.total_batches += 1
            self.total_requests += len(batch)
            self.total_batch_time_ms += (time.monotonic() - t_start) * 1000

    @staticmethod
    def _dispatch_result(req: "DecodeRequest", result: Optional[bytes]):
        """Fulfil a request's future on its own event loop (thread-safe)."""
        def _set():
            # Checked on the loop thread: the future may have been cancelled
            # (e.g. client disconnect) while the decode was running, and
            # set_result on a done future raises InvalidStateError.
            if not req.future.done():
                req.future.set_result(result)
        req.loop.call_soon_threadsafe(_set)

    @staticmethod
    def _dispatch_error(req: "DecodeRequest", exc: Exception):
        """Fail a request's future on its own event loop (thread-safe)."""
        def _set():
            if not req.future.done():
                req.future.set_exception(exc)
        req.loop.call_soon_threadsafe(_set)

    def _batched_decode(self, batch: List["DecodeRequest"]) -> List[Optional[bytes]]:
        """
        Execute a batched BiCodec decode.

        Pads semantic sequences with zeros to equal length, runs a single
        forward pass, then splits and trims results back to individual
        outputs. Falls back to sequential single-item decodes if the batched
        forward pass raises; per-item extraction failures yield None.
        """
        # All global_ids should be exactly 32 tokens (extras are clamped,
        # missing ones zero-padded). Semantic sequences may vary in length
        # -- pad to the batch max.
        max_semantic_len = max(len(req.semantic_ids) for req in batch)
        batch_size = len(batch)

        # Build batched tensors.
        semantic_batch = torch.zeros(batch_size, max_semantic_len, dtype=torch.long, device=self.decoder.device)
        global_batch = torch.zeros(batch_size, 32, dtype=torch.long, device=self.decoder.device)
        seq_lengths = []

        for i, req in enumerate(batch):
            sem_len = len(req.semantic_ids)
            semantic_batch[i, :sem_len] = torch.tensor(req.semantic_ids, dtype=torch.long)
            # Clamp to 32 first, then size the destination slice to match;
            # slicing by the unclamped length crashed on >32-token inputs.
            gids = req.global_ids[:32]
            global_batch[i, :len(gids)] = torch.tensor(gids, dtype=torch.long)
            seq_lengths.append(sem_len)

        # Run batched decode through BiCodec.
        try:
            with torch.inference_mode():
                wav_batch = self.decoder.audio_tokenizer.detokenize(
                    global_batch,
                    semantic_batch,
                )
        except Exception as e:
            logger.error(f"Batched decode failed: {e}")
            # Fallback to sequential decode.
            return [
                self.decoder.decode_streaming(
                    semantic_ids=req.semantic_ids,
                    global_ids=req.global_ids,
                    use_sliding_window=False,
                    trim_warmup=False,
                )
                for req in batch
            ]

        # Split results: trim each output to its actual length.
        # BiCodec produces ~320 samples per semantic token -- TODO confirm
        # against the model's actual hop size.
        results: List[Optional[bytes]] = []
        for i, (req, sem_len) in enumerate(zip(batch, seq_lengths)):
            try:
                audio = wav_batch[i] if wav_batch.ndim > 1 else wav_batch
                if isinstance(audio, torch.Tensor):
                    # clone() escapes inference mode: calling .numpy()
                    # directly on an inference tensor raises.
                    audio = audio.detach().clone().cpu().numpy()
                # Flatten any leftover channel dim so the sample-count trim
                # below operates on samples, not channels.
                audio = np.asarray(audio).reshape(-1)

                # Trim to actual length (remove padding artifacts).
                expected_samples = sem_len * 320
                if len(audio) > expected_samples:
                    audio = audio[:expected_samples]

                # Convert to int16 PCM bytes.
                audio = np.clip(audio, -1.0, 1.0)
                results.append((audio * 32767).astype(np.int16).tobytes())
            except Exception as e:
                logger.error(f"Batch result extraction failed for item {i}: {e}")
                results.append(None)

        return results

    def get_stats(self):
        """Return batch decoder statistics (counts and rounded averages)."""
        avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
        avg_batch_time = self.total_batch_time_ms / self.total_batches if self.total_batches > 0 else 0
        return {
            "total_batches": self.total_batches,
            "total_requests": self.total_requests,
            "avg_batch_size": round(avg_batch_size, 1),
            "avg_batch_time_ms": round(avg_batch_time, 1),
        }
