"""
WebSocket handler for TTS streaming.

Provides bidirectional WebSocket communication for TTS:
- Client sends JSON request with text and parameters
- Server streams audio chunks as binary frames
- Client can send control messages (cancel, ping)

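Protocol:
1. Client connects to /v1/tts/ws (optionally passing an ``api_key`` query
   parameter or an ``Authorization: Bearer <key>`` header)
2. Client sends JSON: {"text": "...", "speaker": "...", ...}
3. Server sends JSON: {"event": "header", "sample_rate": ..., ...}
4. Server sends binary audio chunks, plus a JSON
   {"event": "progress", ...} message after every 10 chunks
5. Server sends JSON: {"event": "complete", "metrics": {...}}
6. Connection closes

On failure the server sends {"event": "error", "error": {...}} and closes
the connection instead.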

Control messages:
- {"event": "cancel"} - Stop generation
- {"event": "ping"} - Keep-alive (server responds with pong)
"""

import asyncio
import json
import time
import uuid
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, Optional

from veena3modal.shared.logging import (
    get_logger,
    set_request_context,
    clear_request_context,
    log_request_received,
    log_first_audio_emitted,
    log_request_completed,
    log_request_failed,
)
from veena3modal.shared.metrics import (
    record_request_received as metrics_request_received,
    record_request_completed as metrics_request_completed,
    record_request_failed as metrics_request_failed,
    record_ttfb,
    record_rtf,
    record_audio_duration,
    record_chunks_sent,
)

# Module-level logger obtained from the project's shared logging helpers.
logger = get_logger(__name__)


class WSMessageType(str, Enum):
    """Event names exchanged over the TTS WebSocket.

    Members mix in ``str``, so each one compares equal to its wire value
    (e.g. ``WSMessageType.PING == "ping"``).
    """

    # --- sent by the client ---
    REQUEST = "request"
    CANCEL = "cancel"
    PING = "ping"

    # --- sent by the server ---
    AUDIO_CHUNK = "audio_chunk"  # audio travels as binary frames
    HEADER = "header"            # stream metadata, sent before any audio
    PROGRESS = "progress"        # periodic streaming progress
    COMPLETE = "complete"        # final event carrying metrics
    ERROR = "error"              # failure report
    PONG = "pong"                # reply to a client ping


@dataclass
class WSRequest:
    """Parsed WebSocket TTS request.

    The field defaults below are also the defaults for keys that are missing
    from (or null in) the client's JSON request; see :meth:`from_dict`.
    """
    text: str
    speaker: str = "male1"
    temperature: float = 0.8
    top_k: int = 50
    top_p: float = 1.0
    max_tokens: int = 4096
    repetition_penalty: float = 1.05
    seed: Optional[int] = None
    chunking: bool = False
    normalize: bool = True
    format: str = "wav"
    sample_rate: int = 16000

    # Maximum accepted text length in characters. Left unannotated so that
    # @dataclass treats it as a class constant, not as a field.
    MAX_TEXT_LENGTH = 50000

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WSRequest":
        """Build a request from a decoded JSON object.

        Unknown keys are ignored. Keys that are missing, or explicitly JSON
        ``null``, fall back to the field defaults, so the defaults live in one
        place. A missing ``text`` becomes ``""`` so that :meth:`validate`
        reports it. Value types are not checked here; call :meth:`validate`.
        """
        kwargs = {
            field.name: data[field.name]
            for field in fields(cls)
            if data.get(field.name) is not None
        }
        # ``text`` has no default, so the constructor needs it explicitly.
        kwargs.setdefault("text", "")
        return cls(**kwargs)

    def validate(self) -> Optional[str]:
        """Check the request before it reaches the model.

        Returns:
            A human-readable message describing the first problem found, or
            None if the request is valid.
        """
        if not self.text:
            return "Text is required and cannot be empty"
        # JSON clients can send any type. Reject non-strings here rather
        # than failing later with an AttributeError that surfaces as a 500.
        if not isinstance(self.text, str):
            return "Text must be a string"
        if not self.text.strip():
            return "Text is required and cannot be empty"
        if len(self.text) > self.MAX_TEXT_LENGTH:
            return f"Text too long: {len(self.text)} chars (max {self.MAX_TEXT_LENGTH})"
        if self.format != "wav":
            return f"WebSocket streaming only supports WAV format (got {self.format})"

        # bool is a subclass of int, so it has to be excluded explicitly.
        for name in ("temperature", "top_p", "repetition_penalty"):
            value = getattr(self, name)
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                return f"{name} must be a number (got {value!r})"
        for name in ("top_k", "max_tokens"):
            value = getattr(self, name)
            if isinstance(value, bool) or not isinstance(value, int):
                return f"{name} must be an integer (got {value!r})"
        if self.seed is not None and (
            isinstance(self.seed, bool) or not isinstance(self.seed, int)
        ):
            return f"seed must be an integer or null (got {self.seed!r})"
        return None


def create_error_message(code: str, message: str, request_id: str) -> str:
    """Serialize an ``error`` event.

    The details sit under an ``error`` object holding the machine-readable
    ``code``, the human-readable ``message`` and the connection's
    ``request_id``.
    """
    error_details = dict(code=code, message=message, request_id=request_id)
    payload = {"event": WSMessageType.ERROR.value, "error": error_details}
    return json.dumps(payload)


def create_header_message(
    request_id: str,
    sample_rate: int,
    model_version: str,
) -> str:
    """Serialize the ``header`` event that precedes the first audio frame.

    The advertised format is always ``"wav"``.
    """
    payload: Dict[str, Any] = dict(
        event=WSMessageType.HEADER.value,
        request_id=request_id,
        sample_rate=sample_rate,
        format="wav",
        model_version=model_version,
    )
    return json.dumps(payload)


def create_progress_message(
    chunks_sent: int,
    bytes_sent: int,
    elapsed_ms: int,
) -> str:
    """Serialize a ``progress`` event with running chunk/byte totals and elapsed time."""
    payload: Dict[str, Any] = dict(
        event=WSMessageType.PROGRESS.value,
        chunks_sent=chunks_sent,
        bytes_sent=bytes_sent,
        elapsed_ms=elapsed_ms,
    )
    return json.dumps(payload)


def create_complete_message(
    request_id: str,
    metrics: Dict[str, Any],
) -> str:
    """Serialize the final ``complete`` event, embedding the ``metrics`` dict as-is."""
    payload: Dict[str, Any] = dict(
        event=WSMessageType.COMPLETE.value,
        request_id=request_id,
        metrics=metrics,
    )
    return json.dumps(payload)


# Seconds to wait for each client message before the TTS request arrives.
_REQUEST_TIMEOUT_SECONDS = 30.0

# A JSON progress event is sent after every this-many audio chunks.
_PROGRESS_EVERY_N_CHUNKS = 10


async def _send_error_and_close(
    websocket,
    *,
    code: str,
    message: str,
    request_id: str,
    close_code: int,
) -> None:
    """Send a JSON ``error`` event, then close the socket with ``close_code``."""
    await websocket.send_text(create_error_message(
        code=code,
        message=message,
        request_id=request_id,
    ))
    await websocket.close(code=close_code)


async def _authorize(websocket, api_key: Optional[str], request_id: str) -> bool:
    """Apply API-key auth and per-key rate limiting when those features are on.

    Returns True if the connection may proceed. On rejection, an error event
    has already been sent and the socket closed with 1008 (policy violation).
    """
    from veena3modal.api.auth import get_api_validator, hash_api_key
    from veena3modal.api.rate_limiter import get_rate_limiter
    from veena3modal.api.error_handlers import FeatureFlags

    if not FeatureFlags.is_auth_enabled():
        return True

    # When auth is enabled, a missing key must be rejected. Skipping the
    # check would let any client bypass auth simply by omitting the key.
    if not api_key:
        await _send_error_and_close(
            websocket,
            code="AUTH_FAILED",
            message="API key required",
            request_id=request_id,
            close_code=1008,
        )
        return False

    key_hash = hash_api_key(api_key)
    auth_result = get_api_validator().validate(key_hash)
    if not auth_result.is_valid:
        await _send_error_and_close(
            websocket,
            code=auth_result.error_code or "AUTH_FAILED",
            message=auth_result.error_message or "Authentication failed",
            request_id=request_id,
            close_code=1008,
        )
        return False

    if FeatureFlags.is_rate_limiting_enabled():
        allowed, _remaining, reset_after = get_rate_limiter().check(key_hash)
        if not allowed:
            await _send_error_and_close(
                websocket,
                code="RATE_LIMIT_EXCEEDED",
                message=f"Rate limit exceeded. Retry after {int(reset_after) + 1} seconds.",
                request_id=request_id,
                close_code=1008,
            )
            return False

    return True


async def _receive_request_message(
    websocket,
    request_id: str,
    timeout: float = _REQUEST_TIMEOUT_SECONDS,
) -> Optional[Dict[str, Any]]:
    """Wait for the client's TTS request, answering any pings sent first.

    Each message must arrive within ``timeout`` seconds, so a ping restarts
    the wait. Returns the request object, or None once the connection has
    already been closed. The close sends an error event for a timeout,
    malformed JSON or a non-object payload. A cancel sent before any
    request closes the connection normally.
    """
    while True:
        try:
            raw_message = await asyncio.wait_for(
                websocket.receive_text(),
                timeout=timeout,
            )
            message = json.loads(raw_message)
        except asyncio.TimeoutError:
            await _send_error_and_close(
                websocket,
                code="TIMEOUT",
                message=f"No request received within {timeout:g} seconds",
                request_id=request_id,
                close_code=1000,
            )
            return None
        except json.JSONDecodeError as e:
            await _send_error_and_close(
                websocket,
                code="INVALID_JSON",
                message=f"Invalid JSON: {e}",
                request_id=request_id,
                close_code=1003,  # Unsupported data
            )
            return None

        if not isinstance(message, dict):
            await _send_error_and_close(
                websocket,
                code="INVALID_JSON",
                message="Request must be a JSON object",
                request_id=request_id,
                close_code=1003,
            )
            return None

        event_type = message.get("event", WSMessageType.REQUEST.value)
        if event_type == WSMessageType.PING.value:
            await websocket.send_text(json.dumps({"event": WSMessageType.PONG.value}))
            continue
        if event_type == WSMessageType.CANCEL.value:
            logger.info(f"Client cancelled before sending a request: {request_id}")
            await websocket.close(code=1000)
            return None
        return message


async def handle_websocket_tts(websocket, api_key: Optional[str] = None):
    """
    Handle WebSocket TTS streaming connection.

    This is the main handler called from the FastAPI WebSocket route. It
    accepts the socket, authorizes it, waits for one JSON request, and
    streams the audio as binary frames. The audio is preceded by a header
    event, interleaved with progress events, and followed by a completion
    event. The streaming admission slot is released exactly once, after the
    audio generator has been closed.

    Args:
        websocket: FastAPI WebSocket instance
        api_key: Optional API key for authentication
    """
    from fastapi import WebSocketDisconnect
    from veena3modal.api.schemas import resolve_speaker_name
    from veena3modal.services import tts_runtime

    request_id = str(uuid.uuid4())
    set_request_context(request_id=request_id)

    await websocket.accept()
    logger.info(f"WebSocket connection accepted: {request_id}")

    # Set by the control listener during streaming.
    cancelled = False
    client_disconnected = False
    admission_lease = None

    try:
        # === AUTH / RATE LIMIT ===
        if not await _authorize(websocket, api_key, request_id):
            return

        # === WAIT FOR REQUEST ===
        message = await _receive_request_message(websocket, request_id)
        if message is None:
            return

        # === PARSE REQUEST ===
        req = WSRequest.from_dict(message)
        validation_error = req.validate()
        if validation_error:
            await _send_error_and_close(
                websocket,
                code="VALIDATION_ERROR",
                message=validation_error,
                request_id=request_id,
                close_code=1003,
            )
            return

        speaker = resolve_speaker_name(req.speaker)

        log_request_received(
            request_id=request_id,
            text_length=len(req.text),
            speaker=speaker,
            stream=True,
            format="wav",
        )
        metrics_request_received(speaker=speaker, stream=True, format="wav")

        # === CHECK MODEL ===
        if not tts_runtime.is_initialized():
            await _send_error_and_close(
                websocket,
                code="MODEL_NOT_LOADED",
                message="TTS model not initialized",
                request_id=request_id,
                close_code=1011,  # Internal error
            )
            return

        runtime = tts_runtime.get_runtime()
        model_version = runtime.model_version if runtime else "unknown"

        # === ADMISSION ===
        try:
            admission_lease = await tts_runtime.acquire_streaming_slot()
        except tts_runtime.StreamingAdmissionError as exc:
            log_request_failed(
                request_id=request_id,
                status_code=429,
                error_code="STREAMING_OVERLOADED",
                error_message=str(exc),
            )
            metrics_request_failed(
                status_code=429,
                error_code="STREAMING_OVERLOADED",
                speaker=speaker,
            )
            await _send_error_and_close(
                websocket,
                code="STREAMING_OVERLOADED",
                message="Streaming capacity exhausted. Retry shortly.",
                request_id=request_id,
                close_code=1013,  # Try again later
            )
            return

        # === SEND HEADER INFO ===
        await websocket.send_text(create_header_message(
            request_id=request_id,
            sample_rate=req.sample_rate,
            model_version=model_version,
        ))

        # === STREAM AUDIO ===
        # perf_counter is monotonic; wall-clock time can jump and skew durations.
        start_time = time.perf_counter()
        chunks_sent = 0
        total_bytes = 0
        ttfb_ms: Optional[int] = None
        final_metrics: Dict[str, Any] = {}

        async def listen_for_control() -> None:
            """Consume client control messages while audio is streaming.

            A cancel event sets ``cancelled``, and pings are answered. If the
            client disconnects, both ``cancelled`` and ``client_disconnected``
            are set so generation stops. The listener never raises, because
            it is awaited during cleanup.
            """
            nonlocal cancelled, client_disconnected
            try:
                while not cancelled:
                    try:
                        raw_control = await websocket.receive_text()
                    except WebSocketDisconnect:
                        client_disconnected = True
                        cancelled = True
                        return
                    try:
                        control = json.loads(raw_control)
                    except json.JSONDecodeError:
                        # Skip garbage instead of silently ceasing to honour cancels.
                        logger.warning(f"Ignoring malformed control message: {request_id}")
                        continue
                    event = control.get("event") if isinstance(control, dict) else None
                    if event == WSMessageType.CANCEL.value:
                        cancelled = True
                        logger.info(f"Client requested cancel: {request_id}")
                        return
                    if event == WSMessageType.PING.value:
                        await websocket.send_text(json.dumps({"event": WSMessageType.PONG.value}))
            except Exception:
                # Other receive/send failures (e.g. socket already closed) end
                # listening; the streaming loop surfaces real send errors itself.
                logger.debug(f"Control listener stopped: {request_id}")

        cancel_task = asyncio.create_task(listen_for_control())
        stream = None

        try:
            # Apply text normalization if enabled
            text = req.text
            if req.normalize:
                try:
                    from veena3modal.processing.text_normalizer import normalize_text
                    text = normalize_text(text)
                except ImportError:
                    pass  # Normalizer not available

            stream = tts_runtime.generate_speech_streaming(
                text=text,
                speaker=speaker,
                temperature=req.temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                max_tokens=req.max_tokens,
                repetition_penalty=req.repetition_penalty,
                seed=req.seed,
                enable_chunking=req.chunking,
                admission_lease=admission_lease,
                release_admission_lease=False,
            )
            async for audio_chunk, metrics in stream:
                if cancelled:
                    logger.info(f"Generation cancelled: {request_id}")
                    break

                final_metrics = metrics or final_metrics
                chunks_sent += 1
                total_bytes += len(audio_chunk)

                if ttfb_ms is None:
                    ttfb_ms = int((time.perf_counter() - start_time) * 1000)
                    log_first_audio_emitted(
                        request_id=request_id,
                        ttfb_ms=ttfb_ms,
                        chunk_size_bytes=len(audio_chunk),
                    )
                    record_ttfb(ttfb_seconds=ttfb_ms / 1000, speaker=speaker, stream=True)

                await websocket.send_bytes(audio_chunk)

                if chunks_sent % _PROGRESS_EVERY_N_CHUNKS == 0:
                    await websocket.send_text(create_progress_message(
                        chunks_sent=chunks_sent,
                        bytes_sent=total_bytes,
                        elapsed_ms=int((time.perf_counter() - start_time) * 1000),
                    ))

        finally:
            try:
                # `break` (or an error) leaves the generator suspended. Close
                # it explicitly so generation stops before its slot is returned.
                if stream is not None and hasattr(stream, "aclose"):
                    await stream.aclose()
            finally:
                cancel_task.cancel()
                try:
                    await cancel_task
                except asyncio.CancelledError:
                    pass
                tts_runtime.release_streaming_slot(admission_lease)
                # Already released: clear it so the outer `finally` won't release twice.
                admission_lease = None

        if client_disconnected:
            logger.info(f"Client disconnected during generation: {request_id}")
            return

        # === SEND COMPLETION ===
        total_time = time.perf_counter() - start_time
        audio_duration = final_metrics.get("audio_duration_seconds") or 0
        rtf = total_time / audio_duration if audio_duration > 0 else 0

        completion_metrics = {
            "request_id": request_id,
            "chunks_sent": chunks_sent,
            "total_bytes": total_bytes,
            "audio_duration_seconds": round(audio_duration, 2),
            "total_time_seconds": round(total_time, 2),
            "ttfb_ms": ttfb_ms if ttfb_ms is not None else 0,
            "rtf": round(rtf, 3),
            "cancelled": cancelled,
        }

        await websocket.send_text(create_complete_message(
            request_id=request_id,
            metrics=completion_metrics,
        ))

        if not cancelled:
            log_request_completed(
                request_id=request_id,
                status_code=200,
                total_duration_ms=int(total_time * 1000),
                audio_duration_seconds=audio_duration,
                rtf=rtf,
                chunks_sent=chunks_sent,
            )
            metrics_request_completed(
                status_code=200,
                duration_seconds=total_time,
                speaker=speaker,
                stream=True,
            )
            record_rtf(rtf=rtf, speaker=speaker)
            record_audio_duration(duration_seconds=audio_duration, speaker=speaker)
            record_chunks_sent(chunks=chunks_sent, speaker=speaker)

        await websocket.close(code=1000)

    except WebSocketDisconnect:
        # The client went away mid-exchange; there is nobody to send an error to.
        logger.info(f"WebSocket client disconnected: {request_id}")

    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
        log_request_failed(
            request_id=request_id,
            status_code=500,
            error_code="WEBSOCKET_ERROR",
            error_message=str(e),
        )
        metrics_request_failed(
            status_code=500,
            error_code="WEBSOCKET_ERROR",
        )

        try:
            await websocket.send_text(create_error_message(
                code="INTERNAL_ERROR",
                message=str(e),
                request_id=request_id,
            ))
            await websocket.close(code=1011)
        except Exception:
            pass  # Best effort: the socket may already be closed.

    finally:
        # None unless the slot was acquired and not yet released above.
        tts_runtime.release_streaming_slot(admission_lease)
        clear_request_context()


def add_websocket_routes(app):
    """
    Add WebSocket routes to FastAPI app.

    Call this from create_app() to enable WebSocket support.

    Args:
        app: FastAPI application instance
    """
    from fastapi import WebSocket, Query

    @app.websocket("/v1/tts/ws")
    async def websocket_tts(
        websocket: WebSocket,
        api_key: Optional[str] = Query(None, alias="api_key"),
    ):
        """
        WebSocket endpoint for TTS streaming.

        Connect with optional api_key query parameter:
        ws://host/v1/tts/ws?api_key=your-key
        (an ``Authorization: Bearer your-key`` header takes precedence)

        Then send JSON request:
        {"text": "Hello world", "speaker": "male1"}

        Receive:
        - JSON header message
        - Binary audio chunks
        - JSON progress messages (every 10 chunks)
        - JSON completion message
        """
        # A Bearer token in the Authorization header overrides the query
        # parameter. The scheme name is case-insensitive (RFC 7235), and an
        # empty token must not overwrite a key supplied via the query string.
        scheme, _, token = websocket.headers.get("authorization", "").partition(" ")
        token = token.strip()
        if scheme.lower() == "bearer" and token:
            api_key = token

        await handle_websocket_tts(websocket, api_key)

    logger.info("WebSocket route /v1/tts/ws registered")
