"""Orpheus TTS Client - Stream speech from the Orpheus model."""

from __future__ import annotations

import asyncio
import json
import os
import queue
import threading
import urllib.error
import urllib.request
import uuid
from pathlib import Path
from typing import AsyncIterator, Iterator, Optional
from urllib.parse import quote, urlparse, urlunparse

try:
    import websockets
    from websockets.exceptions import ConnectionClosed
except ImportError:
    raise ImportError(
        "websockets package is required. Install with: pip install orpheus-tts"
    )

from orpheus_tts.exceptions import (
    AuthenticationError,
    ConnectionError,
    OrpheusError,
    StreamingError,
)

# Default ``pool_size`` for OrpheusClient.connect(): connections per endpoint in
# legacy mode, and the default multiplex ``websocket_count``.
_DEFAULT_POOL_SIZE = 16
# Environment variable read by connect() for a multiplex gateway websocket URL.
_DEFAULT_MULTIPLEX_WS_ENV = "ORPHEUS_TTS_MULTIPLEX_WS_URL"
# Client path segment in config-service URLs (<host>/<client>/<provider>).
_DEFAULT_CLIENT_NAME = "demo"
# Default config-service host.
# NOTE(review): plain HTTP to a raw IP. The returned voice->endpoint map decides
# where the API key (sent as an Authorization header) goes, so a network
# attacker could redirect it. Confirm whether HTTPS/TLS is available here.
_DEFAULT_CONFIG_SERVICE_HOST = "http://35.232.65.178"
# Port appended when the configured config host does not specify one.
_DEFAULT_CONFIG_SERVICE_PORT = 8000
# Timeout, in seconds, passed to urlopen() for config-service requests.
_DEFAULT_CONFIG_REQUEST_TIMEOUT_SECONDS = 10.0


async def _connect_websocket(
    url: str,
    headers: Optional[dict[str, str]] = None,
):
    """Open a websocket to *url*, translating handshake rejections into SDK errors.

    Headers are passed as ``additional_headers`` first. On TypeError they are
    retried as ``extra_headers``, so websockets releases that accept either
    spelling work. Frames up to 16 MiB are allowed. Pings go out every 30s with
    a 10s timeout, and closing waits at most 5s.

    Raises:
        AuthenticationError: the handshake was rejected with HTTP 401 or 403.
        ConnectionError: the handshake was rejected with any other HTTP status.
        Exception: any failure without a recognizable HTTP status is re-raised
            unchanged.
    """
    connect_options = {
        "max_size": 16 * 1024 * 1024,
        "ping_interval": 30,
        "ping_timeout": 10,
        "close_timeout": 5,
    }
    try:
        if not headers:
            return await websockets.connect(url, **connect_options)
        try:
            return await websockets.connect(url, additional_headers=headers, **connect_options)
        except TypeError:
            return await websockets.connect(url, extra_headers=headers, **connect_options)
    except Exception as exc:
        status_code = _extract_websocket_status_code(exc)
        if status_code is None:
            raise
        if status_code in (401, 403):
            raise AuthenticationError(
                "Authentication failed while connecting to the Orpheus streaming endpoint "
                f"(HTTP {status_code}). Please verify your API key."
            ) from exc
        if status_code == 404:
            reason = (
                "The Orpheus streaming endpoint was not found (HTTP 404). "
                "Please verify the configured websocket URL."
            )
        else:
            reason = (
                "The Orpheus streaming endpoint rejected the websocket connection "
                f"(HTTP {status_code}). Please verify endpoint availability and service health."
            )
        raise ConnectionError(reason) from exc


def _get_endpoint(voice: str, voice_endpoint_map: dict[str, str]) -> str:
    """Resolve a voice name to its WebSocket endpoint URL.

    The voice is stripped and lower-cased before lookup, so ``" Josh "`` finds
    the ``"josh"`` entry.

    Raises:
        ValueError: If the voice is not recognized. The message lists the
            valid voices in sorted order.
    """
    try:
        return voice_endpoint_map[voice.strip().lower()]
    except KeyError:
        valid = ", ".join(sorted(voice_endpoint_map))
        # ``from None``: the internal KeyError is an implementation detail and
        # would only add a confusing "During handling..." section to tracebacks.
        raise ValueError(f"Unknown voice '{voice}'. Valid voices: {valid}") from None


def _write_local_config(data: dict, local_config_path: Path) -> None:
    """Write *data* as indented, key-sorted JSON to *local_config_path*.

    Parent directories are created as needed. The JSON first goes to a
    uniquely named temporary file in the same directory, which is then moved
    over the destination with ``Path.replace`` (``os.replace``). A crash
    mid-write therefore leaves the previous file, or no file, rather than a
    truncated one that later loads would fail to parse.
    """
    local_config_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, sort_keys=True)
    tmp_path = local_config_path.with_name(
        f".{local_config_path.name}.{uuid.uuid4().hex}.tmp"
    )
    try:
        tmp_path.write_text(serialized, encoding="utf-8")
        tmp_path.replace(local_config_path)
    except BaseException:
        # Don't leave stray temp files behind; re-raise the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


def _load_voice_endpoint_map_from_local_config(local_config_path: Path) -> dict[str, str]:
    """Read a voice -> endpoint mapping from a local JSON file.

    Accepts either ``{"voice_endpoint_map": {...}}`` (the shape this module
    writes when caching) or a bare ``{voice: url}`` object. An empty file is
    treated as ``{}``. Voice names are stripped and lower-cased and URLs are
    stripped. Entries with an empty name, a ``null`` URL, or an empty URL are
    skipped.

    Raises:
        ValueError: if the file is missing, is not valid JSON, does not hold a
            JSON object mapping, or yields no usable entries.
    """
    try:
        raw_text = local_config_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        raise ValueError(f"Local config file not found: {local_config_path}") from None

    try:
        data = json.loads(raw_text or "{}")
    except json.JSONDecodeError as exc:
        raise ValueError(
            f"Invalid JSON in local config {local_config_path}: {exc}"
        ) from exc

    if isinstance(data, dict) and isinstance(data.get("voice_endpoint_map"), dict):
        voice_endpoint_map = data["voice_endpoint_map"]
    else:
        voice_endpoint_map = data
    if not isinstance(voice_endpoint_map, dict):
        raise ValueError(
            f"Invalid local config at {local_config_path}: "
            "expected voice mapping JSON object"
        )

    normalized: dict[str, str] = {}
    for voice_name, endpoint_url in voice_endpoint_map.items():
        # A JSON null must not become the literal endpoint "None".
        if endpoint_url is None:
            continue
        voice_key = str(voice_name).strip().lower()
        endpoint_value = str(endpoint_url).strip()
        if not voice_key or not endpoint_value:
            continue
        normalized[voice_key] = endpoint_value
    if not normalized:
        raise ValueError(f"No voice endpoint mappings found in {local_config_path}")
    return normalized


def _normalize_config_host(config_host: str) -> str:
    """Turn a user-supplied config host into a base URL with an explicit port.

    A missing scheme defaults to ``http``. A missing port becomes
    ``_DEFAULT_CONFIG_SERVICE_PORT``, with IPv6 literals re-bracketed. Any
    trailing slash on the path is removed. Query and fragment are dropped.

    Raises:
        ValueError: if the host is blank or has no usable network location.
    """
    candidate = config_host.strip()
    if not candidate:
        raise ValueError("Config host must not be empty")

    parts = urlparse(candidate if "://" in candidate else "http://" + candidate)
    if not parts.netloc:
        raise ValueError(
            f"Invalid config host '{config_host}'. "
            "Expected something like http://35.232.65.178"
        )

    host = parts.hostname
    if not host:
        raise ValueError(f"Invalid config host '{config_host}'")

    if parts.port is None:
        # urlparse strips IPv6 brackets from .hostname; put them back.
        needs_brackets = ":" in host and not host.startswith("[")
        shown_host = f"[{host}]" if needs_brackets else host
        netloc = f"{shown_host}:{_DEFAULT_CONFIG_SERVICE_PORT}"
    else:
        netloc = parts.netloc

    return urlunparse((parts.scheme or "http", netloc, parts.path.rstrip("/"), "", "", ""))


def _build_config_service_url(config_host: str, provider: str) -> str:
    """Return ``<normalized host>/<client name>/<provider>``.

    Both path segments are fully percent-encoded (``safe=""``, so ``/`` is
    escaped too). Whitespace around *provider* is stripped first.
    """
    segments = (
        _normalize_config_host(config_host),
        quote(_DEFAULT_CLIENT_NAME, safe=""),
        quote(provider.strip(), safe=""),
    )
    return "/".join(segments)


def _format_config_service_http_error(
    *,
    status_code: int,
    service_url: str,
    provider: str,
) -> str:
    """Build the user-facing message for a failed config-service HTTP request.

    Gives distinct wording for access denied (401/403), not found (404), and
    any other status. ``service_url`` is accepted but not currently included
    in the message.
    """
    if status_code == 404:
        return (
            f"Voice configuration was not found for provider {provider} (HTTP 404). "
            "Please verify the provider value and confirm this configuration is published."
        )
    if status_code not in (401, 403):
        return (
            f"Unable to load voice configuration for provider {provider} "
            f"(HTTP {status_code}). "
            "Please try again or contact support if the issue persists."
        )
    return (
        "Unable to load voice configuration from the config service because access was denied "
        f"(HTTP {status_code}) for provider {provider}. "
        "Please verify your SDK credentials and service authorization settings."
    )


def _format_multiplex_connection_error(ws_url: str, websocket_count: int) -> str:
    """Describe a failure to open any multiplex websocket.

    ``ws_url`` is accepted but not included in the message.
    """
    parts = (
        "Unable to establish a connection to the Orpheus streaming endpoint",
        f"using {websocket_count} websocket(s).",
        "Please verify your API key, endpoint availability, and network connectivity.",
    )
    return " ".join(parts)


def _extract_websocket_status_code(exc: Exception) -> Optional[int]:
    """Best-effort lookup of the HTTP status behind a failed websocket handshake.

    Checks ``exc.response.status_code`` and then ``exc.status_code``; either is
    used only if it is already an ``int``. Otherwise ``exc.status`` is coerced
    with ``int()``. Returns None when none of these yields an integer.
    """
    response = getattr(exc, "response", None)
    for holder in (response, exc):
        candidate = getattr(holder, "status_code", None)
        if isinstance(candidate, int):
            return candidate

    raw_status = getattr(exc, "status", None)
    if raw_status is None:
        return None
    try:
        return int(raw_status)
    except (TypeError, ValueError):
        return None


def _load_voice_endpoint_map_from_config_service(
    *,
    provider: str,
    config_host: str,
    timeout_seconds: float = _DEFAULT_CONFIG_REQUEST_TIMEOUT_SECONDS,
    local_config_path: Optional[str] = None,
) -> dict[str, str]:
    """Download the voice -> endpoint mapping for *provider* from the config service.

    GETs ``<config_host>/<client>/<provider>`` and expects a flat JSON object
    of ``voice name -> websocket URL``. Names are stripped and lower-cased and
    URLs are stripped. Entries with an empty name, a ``null`` URL, or an empty
    URL are dropped. When *local_config_path* is given, the mapping is saved
    there as ``{"voice_endpoint_map": ...}`` and the saved file is read back.

    Raises:
        ValueError: on HTTP errors, network failures (including timeouts while
            reading the body), a body that is not UTF-8 JSON, a non-object
            body, or an empty mapping.
    """
    service_url = _build_config_service_url(config_host, provider)
    request = urllib.request.Request(service_url, headers={"Accept": "application/json"})
    try:
        with urllib.request.urlopen(request, timeout=timeout_seconds) as response:
            raw_body = response.read()
    except urllib.error.HTTPError as exc:
        raise ValueError(
            _format_config_service_http_error(
                status_code=exc.code,
                service_url=service_url,
                provider=provider,
            )
        ) from exc
    except OSError as exc:
        # URLError subclasses OSError. Catching OSError also covers socket
        # timeouts and connection resets raised while reading the body, which
        # urlopen() does not wrap in URLError.
        raise ValueError(
            f"Failed to download voice configuration for provider {provider}"
        ) from exc

    try:
        data = json.loads(raw_body.decode("utf-8") or "{}")
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise ValueError(
            f"Config service at {service_url} did not return valid JSON"
        ) from exc
    if not isinstance(data, dict):
        raise ValueError(f"Config service at {service_url} must return a JSON object")

    normalized: dict[str, str] = {}
    for voice_name, endpoint_url in data.items():
        # A JSON null must not become the literal endpoint "None".
        if endpoint_url is None:
            continue
        voice_key = str(voice_name).strip().lower()
        endpoint_value = str(endpoint_url).strip()
        if not voice_key or not endpoint_value:
            continue
        normalized[voice_key] = endpoint_value
    if not normalized:
        raise ValueError(
            f"Config service at {service_url} returned no voice endpoint mappings"
        )

    if local_config_path:
        resolved_local_path = Path(local_config_path).expanduser().resolve()
        _write_local_config({"voice_endpoint_map": normalized}, resolved_local_path)
        return _load_voice_endpoint_map_from_local_config(resolved_local_path)

    return normalized


def _load_or_download_voice_endpoint_map(
    *,
    provider: str,
    config_host: str = _DEFAULT_CONFIG_SERVICE_HOST,
    timeout_seconds: float = _DEFAULT_CONFIG_REQUEST_TIMEOUT_SECONDS,
    local_config_path: Optional[str] = None,
    refresh: bool = False,
) -> dict[str, str]:
    """Return the voice -> endpoint map, preferring a cached local file.

    Without *local_config_path*, the map is always downloaded and nothing is
    written. With it, an existing file is read unless *refresh* is true.
    Otherwise the map is downloaded and saved to that path (after ``~``
    expansion and resolution).
    """
    if not local_config_path:
        return _load_voice_endpoint_map_from_config_service(
            provider=provider,
            config_host=config_host,
            timeout_seconds=timeout_seconds,
        )

    cache_file = Path(local_config_path).expanduser().resolve()
    use_cache = cache_file.exists() and not refresh
    if use_cache:
        return _load_voice_endpoint_map_from_local_config(cache_file)

    return _load_voice_endpoint_map_from_config_service(
        provider=provider,
        config_host=config_host,
        timeout_seconds=timeout_seconds,
        local_config_path=str(cache_file),
    )


def _get_health_url(endpoint: str) -> str:
    """Map a websocket endpoint URL to the ``/health`` URL on the same host.

    ``wss`` maps to ``https``. Every other scheme (``ws``, and also ``https``)
    maps to ``http``. The host and port are kept; path, params, query and
    fragment are replaced.
    """
    parts = urlparse(endpoint)
    http_scheme = {"wss": "https"}.get(parts.scheme, "http")
    return urlunparse((http_scheme, parts.netloc, "/health", "", "", ""))


def _fire_health_ping(url: str) -> None:
    """Send a best-effort GET to *url*, ignoring the response and any error.

    This is a blocking call; within this module it is scheduled through
    ``run_in_executor`` so it does not block the event loop.
    """
    try:
        # Close the response so its socket is released now rather than
        # whenever the object happens to be garbage-collected.
        with urllib.request.urlopen(url, timeout=5):
            pass
    except Exception:
        # Deliberately swallowed: the ping is fire-and-forget and must never
        # affect the request that triggered it.
        pass


def _log_request_json(payload: dict) -> str:
    """Serialize an outbound request payload to a JSON string.

    Despite the name, nothing is printed or logged. This only returns
    ``json.dumps(payload)`` with default settings.
    """
    return json.dumps(payload)


# Audio format constants.
# NOTE(review): not referenced in the visible part of this module; presumably
# they describe the raw PCM bytes yielded by the streaming APIs. Confirm
# against the server's output format.
SAMPLE_RATE = 48000  # Hz
SAMPLE_WIDTH = 2  # bytes per sample (int16)
CHANNELS = 1  # mono


class _MultiplexRequestState:
    """Bookkeeping for one in-flight multiplexed request.

    ``queue`` holds audio chunks (bytes) followed by a ``None`` sentinel that
    ends the stream. ``error`` is set before the sentinel when the request
    failed.
    """

    def __init__(self, request_id: str):
        self.request_id: str = request_id
        self.error: Optional[Exception] = None
        self.queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue()


class _MultiplexedWebSocketConnection:
    """Single websocket that can carry many concurrent TTS requests.

    Each request is sent as a JSON message tagged with a client-generated
    ``request_id``. A background receive task demultiplexes inbound traffic:

    - Binary frames carry a 4-byte big-endian server request id followed by
      audio bytes.
    - JSON frames carry ``accepted`` (server id -> client id mapping), error,
      and ``done`` notifications.

    Every in-flight request owns a ``_MultiplexRequestState``. A ``None`` on
    its queue ends that request's stream.
    """

    def __init__(self, ws_url: str, headers: Optional[dict[str, str]] = None):
        self.ws_url = ws_url
        self.headers = headers
        self.ws = None
        self._connected = False
        self._request_counter = 0  # NOTE(review): not read anywhere in this class
        self._send_lock = asyncio.Lock()  # serializes ws.send() across requests
        # Guards _states and _server_to_client, and the connected -> disconnected
        # transition made by _fail_all_requests().
        self._state_lock = asyncio.Lock()
        self._receiver_task = None
        self._states: dict[str, _MultiplexRequestState] = {}
        self._server_to_client: dict[int, str] = {}

    async def connect(self) -> None:
        """Open the websocket and start the background receive task."""
        self.ws = await _connect_websocket(self.ws_url, self.headers)
        self._connected = True
        self._receiver_task = asyncio.create_task(self._receive_loop())

    async def close(self) -> None:
        """Stop receiving, fail every in-flight request, and close the socket."""
        self._connected = False
        if self._receiver_task is not None:
            self._receiver_task.cancel()
            try:
                await self._receiver_task
            except asyncio.CancelledError:
                pass
            self._receiver_task = None
        # Cancelling the receiver skips its failure path, so wake any
        # stream_request() still waiting on its queue; otherwise it would
        # block forever.
        await self._fail_all_requests(ConnectionError("Multiplex websocket closed"))
        if self.ws is not None:
            try:
                await self.ws.close()
            except Exception:
                pass
            self.ws = None

    async def _receive_loop(self) -> None:
        """Dispatch inbound frames until the socket ends, then fail pending requests."""
        try:
            async for message in self.ws:
                if isinstance(message, bytes):
                    await self._handle_binary_message(message)
                else:
                    await self._handle_json_message(message)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            failure = ConnectionError(f"Multiplex receive failed: {e}")
        else:
            # A normal close ends the iteration without raising. Without this,
            # pending requests would wait on their queues forever.
            failure = ConnectionError("Multiplex websocket closed by server")
        await self._fail_all_requests(failure)

    async def _handle_binary_message(self, message: bytes) -> None:
        """Route an audio frame (4-byte big-endian server id + payload) to its request."""
        if len(message) < 4:
            return
        server_request_id = int.from_bytes(message[:4], "big")
        audio = message[4:]
        if not audio:
            return
        async with self._state_lock:
            request_id = self._server_to_client.get(server_request_id)
            if request_id is None:
                # Backward-compatible fallback for gateways that stream binary audio
                # before sending an explicit "accepted" mapping.
                mapped_request_ids = set(self._server_to_client.values())
                for candidate_request_id in self._states.keys():
                    if candidate_request_id not in mapped_request_ids:
                        request_id = candidate_request_id
                        self._server_to_client[server_request_id] = candidate_request_id
                        break
            state = self._states.get(request_id) if request_id is not None else None
        if state is not None:
            await state.queue.put(audio)

    @staticmethod
    def _coerce_server_request_id(value) -> Optional[int]:
        """Return *value* as an int server id, or None if it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    async def _handle_json_message(self, message: str) -> None:
        """Apply a JSON control frame: id mapping, error, or end-of-stream.

        Frames that are not JSON objects, or that carry an unusable server id,
        are ignored. Raising here would end the receive loop and fail every
        request on this socket.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            return
        if not isinstance(data, dict):
            return

        msg_type = data.get("type")
        server_request_id = data.get("server_request_id")
        client_request_id = data.get("client_request_id")
        if server_request_id is None and isinstance(data.get("request_id"), int):
            server_request_id = data.get("request_id")

        if msg_type == "accepted" and server_request_id is not None and client_request_id is not None:
            server_key = self._coerce_server_request_id(server_request_id)
            if server_key is not None:
                async with self._state_lock:
                    self._server_to_client[server_key] = str(client_request_id)
            return

        request_id = None
        if client_request_id is not None:
            request_id = str(client_request_id)
        elif isinstance(data.get("request_id"), str):
            request_id = data.get("request_id")
        elif server_request_id is not None:
            server_key = self._coerce_server_request_id(server_request_id)
            if server_key is not None:
                async with self._state_lock:
                    request_id = self._server_to_client.get(server_key)
        if request_id is None:
            return

        async with self._state_lock:
            state = self._states.get(request_id)
        if state is None:
            return

        if data.get("error"):
            error_message = str(data.get("error"))
            # Heuristic: messages mentioning "auth" or "key" are treated as
            # authentication failures.
            if "auth" in error_message.lower() or "key" in error_message.lower():
                state.error = AuthenticationError(error_message)
            else:
                state.error = StreamingError(error_message)
            await state.queue.put(None)
            return

        if data.get("done"):
            await state.queue.put(None)

    async def _fail_all_requests(self, error: Exception) -> None:
        """Mark the socket disconnected and end every registered request with *error*.

        A request that already recorded an error keeps its original one.
        """
        async with self._state_lock:
            # Flip _connected under the same lock stream_request() registers
            # under, so no request can be added after this snapshot and hang.
            self._connected = False
            states = list(self._states.values())
        for state in states:
            if state.error is None:
                state.error = error
            await state.queue.put(None)

    async def stream_request(
        self,
        text: str,
        voice: str,
        max_tokens: int,
        temperature: float,
        repetition_penalty: float,
    ) -> AsyncIterator[bytes]:
        """Send one TTS request and yield its audio chunks as they arrive.

        Raises:
            ConnectionError: if the socket is not (or no longer) connected.
            Exception: the error recorded for this request (e.g. a server error
                or a lost connection), raised after its stream ends.
        """
        if not self._connected or self.ws is None:
            raise ConnectionError("Multiplex websocket is not connected")

        async with self._state_lock:
            # Re-check under the lock: the receive loop may have failed between
            # the check above and acquiring the lock.
            if not self._connected:
                raise ConnectionError("Multiplex websocket is not connected")
            # Use globally unique IDs to avoid collisions on shared gateway state.
            request_id = f"req_{uuid.uuid4().hex}"
            state = _MultiplexRequestState(request_id)
            self._states[request_id] = state

        payload = {
            "request_id": request_id,
            "prompt": text,
            "voice": voice,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }

        try:
            async with self._send_lock:
                await self.ws.send(_log_request_json(payload))

            while True:
                item = await state.queue.get()
                if item is None:
                    break
                yield item

            if state.error is not None:
                raise state.error
        finally:
            async with self._state_lock:
                self._states.pop(request_id, None)
                stale_server_ids = [
                    server_id
                    for server_id, client_id in self._server_to_client.items()
                    if client_id == request_id
                ]
                for server_id in stale_server_ids:
                    self._server_to_client.pop(server_id, None)


class _MultiplexWebSocketPool:
    """Fixed set of multiplex websockets; each request goes to the least-loaded one.

    Load is the number of in-flight requests per socket. When
    ``max_inflight_per_socket`` is positive, sockets at that count are skipped
    and callers wait on a condition until a slot frees. ``0`` means no cap.
    Ties go to the earliest socket in creation order.
    """

    def __init__(
        self,
        ws_url: str,
        websocket_count: int,
        headers: Optional[dict[str, str]] = None,
        max_inflight_per_socket: int = 0,
    ):
        self.ws_url = ws_url
        self.headers = headers
        self.websocket_count = max(1, websocket_count)
        self.max_inflight_per_socket = max(0, max_inflight_per_socket)
        self._clients: list[_MultiplexedWebSocketConnection] = []
        self._inflight: dict[_MultiplexedWebSocketConnection, int] = {}
        self._cond = asyncio.Condition()

    async def initialize(self) -> None:
        """Open ``websocket_count`` sockets concurrently and keep every one that connects.

        Raises:
            Exception: the first connection error (in creation order) when no
                socket connected.
            ConnectionError: when no socket connected and no error was recorded.
        """
        async def _open_one():
            connection = _MultiplexedWebSocketConnection(self.ws_url, self.headers)
            try:
                await connection.connect()
            except Exception as exc:
                return None, exc
            return connection, None

        outcomes = await asyncio.gather(*(_open_one() for _ in range(self.websocket_count)))

        failures: list[Exception] = []
        for connection, exc in outcomes:
            if connection is None:
                if exc is not None:
                    failures.append(exc)
                continue
            self._clients.append(connection)
            self._inflight[connection] = 0

        if self._clients:
            return
        if failures:
            raise failures[0]
        raise ConnectionError(
            _format_multiplex_connection_error(
                ws_url=self.ws_url,
                websocket_count=self.websocket_count,
            )
        )

    async def close(self) -> None:
        """Drop all sockets, wake every waiter (their acquire then fails), and close each socket."""
        async with self._cond:
            doomed = self._clients[:]
            del self._clients[:]
            self._inflight.clear()
            self._cond.notify_all()
        for connection in doomed:
            await connection.close()

    def has_connections(self) -> bool:
        """Return whether the pool currently holds at least one socket."""
        return len(self._clients) > 0

    async def _acquire_client(self) -> _MultiplexedWebSocketConnection:
        """Reserve a slot on the least-loaded eligible socket, waiting if all are capped.

        Raises:
            ConnectionError: if the pool has no sockets (e.g. after close()).
        """
        async with self._cond:
            while True:
                if not self._clients:
                    raise ConnectionError("No multiplex websocket connections available")

                cap = self.max_inflight_per_socket
                eligible = [
                    connection
                    for connection in self._clients
                    if cap <= 0 or self._inflight.get(connection, 0) < cap
                ]
                if eligible:
                    # min() returns the first minimal element, so ties go to
                    # the earliest socket.
                    chosen = min(eligible, key=lambda c: self._inflight.get(c, 0))
                    self._inflight[chosen] = self._inflight.get(chosen, 0) + 1
                    return chosen
                await self._cond.wait()

    async def _release_client(self, client: _MultiplexedWebSocketConnection) -> None:
        """Give back one slot on *client* and wake one waiter."""
        async with self._cond:
            if client in self._inflight:
                self._inflight[client] = max(self._inflight[client] - 1, 0)
            self._cond.notify()

    async def stream_request(
        self,
        text: str,
        voice: str,
        max_tokens: int,
        temperature: float,
        repetition_penalty: float,
    ) -> AsyncIterator[bytes]:
        """Stream one request on a pooled socket, releasing the slot when done."""
        connection = await self._acquire_client()
        try:
            async for chunk in connection.stream_request(
                text=text,
                voice=voice,
                max_tokens=max_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
            ):
                yield chunk
        finally:
            await self._release_client(connection)


class OrpheusClient:
    """
    Client for streaming speech from the Orpheus TTS model.

    Example:
        >>> from orpheus_tts import OrpheusClient
        >>>
        >>> # Option 1: Auto-connect (includes connection time in first request)
        >>> client = OrpheusClient(provider="PROVIDER_NAME")
        >>> for chunk in client.stream("Hello!", voice="josh"):
        ...     audio_player.write(chunk)
        >>>
        >>> # Option 2: Pre-connect for lowest latency (recommended)
        >>> client = OrpheusClient(provider="PROVIDER_NAME")
        >>> client.connect()  # Pool of 16 connections per endpoint
        >>> for chunk in client.stream("Hello!", voice="josh"):
        ...     audio_player.write(chunk)  # TTFA excludes handshake!
        >>> client.close()
        >>>
        >>> # Option 3: Async usage (same pool, same connect)
        >>> client = OrpheusClient(provider="PROVIDER_NAME")
        >>> client.connect()
        >>> async for chunk in client.stream_async("Hello!", voice="josh"):
        ...     audio_player.write(chunk)
        >>> client.close()
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        max_tokens: int = 3000,
        temperature: float = 1.0,
        repetition_penalty: float = 1.1,
        *,
        provider: Optional[str] = None,
        voice_endpoint_map: Optional[dict[str, str]] = None,
    ):
        """
        Initialize the Orpheus TTS client.

        Args:
            api_key: API key for authentication (optional for now). When set,
                it is sent as ``Authorization: Api-Key <key>`` on websocket
                handshakes.
            max_tokens: Maximum tokens to generate (default: 3000).
            temperature: Sampling temperature (default: 1.0).
            repetition_penalty: Repetition penalty (default: 1.1).
            provider: Provider used to load the voice configuration
                (for example "PROVIDER_NAME"). Ignored when
                ``voice_endpoint_map`` is given.
            voice_endpoint_map: Explicit voice->url mapping override. Voice
                names are stripped and lower-cased and URLs are stripped.
                Entries with an empty name, a None URL, or an empty URL are
                dropped.

        Raises:
            ValueError: if neither ``provider`` nor ``voice_endpoint_map`` is
                given, or if loading the provider configuration fails.
        """
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty

        if voice_endpoint_map is not None:
            # Normalize the same way the config loaders in this module do, so
            # an explicit map behaves like a downloaded one. None URLs are
            # dropped instead of becoming the string "None".
            normalized: dict[str, str] = {}
            for voice_name, endpoint_url in voice_endpoint_map.items():
                if endpoint_url is None:
                    continue
                voice_key = str(voice_name).strip().lower()
                endpoint_value = str(endpoint_url).strip()
                if voice_key and endpoint_value:
                    normalized[voice_key] = endpoint_value
            self._voice_endpoint_map = normalized
        elif provider:
            self._voice_endpoint_map = _load_or_download_voice_endpoint_map(
                provider=provider,
            )
        else:
            raise ValueError(
                "A provider is required unless voice_endpoint_map is provided."
            )

        # Connection pool state (shared by sync and async)
        self._pools: dict[str, asyncio.Queue] = {}  # endpoint -> queue of ws
        self._all_connections: list = []  # flat list for cleanup
        self._multiplex_pool: Optional[_MultiplexWebSocketPool] = None
        self._multiplex_ws_url: Optional[str] = None
        self._websocket_count: int = 0
        self._max_inflight_per_socket: int = 0
        self._loop = None
        self._thread = None
        self._connected = False

    def _get_auth_headers(self) -> Optional[dict[str, str]]:
        """Return ``{"Authorization": "Api-Key <key>"}`` when an API key is set, else None."""
        api_key = self.api_key
        return {"Authorization": f"Api-Key {api_key}"} if api_key else None

    def connect(
        self,
        pool_size: int = _DEFAULT_POOL_SIZE,
        *,
        ws_url: Optional[str] = None,
        voice: Optional[str] = None,
        websocket_count: Optional[int] = None,
        max_inflight_per_socket: int = 0,
        timeout: float = 30.0,
    ) -> None:
        """
        Pre-establish WebSocket connections for lowest latency.

        Multiplex mode (recommended):
            - Pass ws_url, or pass voice, or set ORPHEUS_TTS_MULTIPLEX_WS_URL.
            - Creates websocket_count multiplex sockets to a single endpoint.
            - A single socket can serve many concurrent requests.

        Legacy mode (compat):
            - If ws_url/voice/env are not provided, uses per-voice endpoint pools.
            - Creates pool_size connections per unique endpoint.

        Calling connect() on an already-connected client does nothing, even
        if different arguments are passed.

        Args:
            pool_size: Legacy per-endpoint pool size.
            ws_url: Multiplex gateway websocket URL.
            voice: Voice name to resolve to a multiplex endpoint URL.
            websocket_count: Number of multiplex sockets (default: pool_size).
            max_inflight_per_socket: Optional cap per socket (0 = unlimited).
            timeout: Seconds to wait for the connections to be established
                (default: 30.0).

        Raises:
            ValueError: if both ws_url and voice are given, or voice is unknown.
            AuthenticationError: if the endpoint rejected the credentials.
            ConnectionError: if no connection could be established, including
                on timeout.

        Example:
            >>> client = OrpheusClient(provider="PROVIDER_NAME")
            >>> client.connect()  # 16 connections per endpoint
            >>> for chunk in client.stream("Hello!", voice="josh"):
            ...     process(chunk)
        """
        if self._connected:
            return

        if ws_url is not None and voice is not None:
            raise ValueError("Pass either ws_url or voice to connect(), not both.")

        resolved_ws_url = ws_url
        if resolved_ws_url is None and voice is not None:
            resolved_ws_url = _get_endpoint(voice, self._voice_endpoint_map)
        if resolved_ws_url is None:
            resolved_ws_url = os.getenv(_DEFAULT_MULTIPLEX_WS_ENV)

        self._multiplex_ws_url = resolved_ws_url
        self._websocket_count = max(1, websocket_count if websocket_count is not None else pool_size)
        self._max_inflight_per_socket = max(0, max_inflight_per_socket)

        # Create background event loop. The thread uses the captured `loop`
        # rather than reading self._loop, which close() may reset to None.
        loop = asyncio.new_event_loop()
        self._loop = loop

        def run_loop():
            asyncio.set_event_loop(loop)
            loop.run_forever()

        self._thread = threading.Thread(target=run_loop, daemon=True)
        self._thread.start()

        # Create connection pools on background loop
        future = asyncio.run_coroutine_threadsafe(
            self._create_pools(pool_size=pool_size), loop
        )
        try:
            future.result(timeout=timeout)
            has_legacy_connections = bool(self._all_connections)
            has_multiplex_connections = (
                self._multiplex_pool is not None and self._multiplex_pool.has_connections()
            )
            if not has_legacy_connections and not has_multiplex_connections:
                raise ConnectionError("Failed to connect to any endpoint")
            self._connected = True
        except (ConnectionError, AuthenticationError):
            self.close()
            raise
        except Exception as e:
            timed_out = not future.done()
            # On timeout, _create_pools is still running on the background loop
            # and may keep opening sockets. Cancel it before tearing down.
            # cancel() is a no-op if the future has already finished.
            future.cancel()
            self.close()
            if timed_out:
                raise ConnectionError(
                    f"Timed out after {timeout}s establishing websocket connections"
                ) from e
            raise ConnectionError(f"Failed to connect: {e}") from e

    async def _create_pools(self, pool_size: int) -> None:
        """Open websockets on the background loop.

        Multiplex mode (``self._multiplex_ws_url`` set): build and initialize a
        ``_MultiplexWebSocketPool``. Legacy mode: concurrently open
        ``pool_size`` connections to each distinct endpoint in the voice map,
        and queue every successful one under its endpoint.

        Raises:
            Exception: in legacy mode, the first connection error, but only
                when not a single connection could be opened.
        """
        auth_headers = self._get_auth_headers()

        if self._multiplex_ws_url:
            # Assigned before initialize(), so close() also sees (and closes)
            # a pool whose initialization fails.
            self._multiplex_pool = _MultiplexWebSocketPool(
                ws_url=self._multiplex_ws_url,
                websocket_count=self._websocket_count,
                headers=auth_headers,
                max_inflight_per_socket=self._max_inflight_per_socket,
            )
            await self._multiplex_pool.initialize()
            return

        unique_endpoints = set(self._voice_endpoint_map.values())
        for endpoint in unique_endpoints:
            self._pools[endpoint] = asyncio.Queue()

        async def _open(endpoint: str):
            try:
                ws = await _connect_websocket(endpoint, auth_headers)
            except Exception as exc:
                return endpoint, None, exc
            return endpoint, ws, None

        outcomes = await asyncio.gather(
            *(_open(endpoint) for endpoint in unique_endpoints for _ in range(pool_size))
        )

        first_failure: Optional[Exception] = None
        for endpoint, ws, failure in outcomes:
            if ws is None:
                if first_failure is None:
                    first_failure = failure
                continue
            self._all_connections.append(ws)
            await self._pools[endpoint].put(ws)

        if first_failure is not None and not self._all_connections:
            raise first_failure

    async def _acquire(self, voice: str):
        """Check out a legacy-pool websocket for *voice*.

        Returns ``(endpoint, ws)``. ``_release(ws, endpoint)`` is the
        counterpart that hands the connection back. Waits until the endpoint's
        queue has a connection.

        NOTE(review): if every connection for the endpoint was dropped (a
        failed reconnect in ``_release`` shrinks the pool), this wait never
        completes.

        Raises:
            ValueError: if *voice* is unknown.
            ConnectionError: if no legacy pool exists for the voice's endpoint
                (connect() not called, or connect() used multiplex mode).
        """
        endpoint = _get_endpoint(voice, self._voice_endpoint_map)
        pool = self._pools.get(endpoint)
        if pool is not None:
            ws = await pool.get()
            return endpoint, ws
        raise ConnectionError(
            f"No connection pool for voice '{voice}'. "
            f"Call connect() first or check voice name."
        )

    @staticmethod
    def _ws_is_open(ws) -> bool:
        """Best-effort check that *ws* is still open, across websockets versions.

        Prefers ``ws.state.name == "OPEN"``, which works on connection objects
        that expose neither ``closed`` nor ``open`` (newer websockets releases
        removed them). Falls back to ``closed`` / ``open``. Any error counts as
        "not open".
        """
        try:
            state_name = getattr(getattr(ws, "state", None), "name", None)
            if isinstance(state_name, str):
                return state_name == "OPEN"
            if hasattr(ws, "closed"):
                return not ws.closed
            return bool(ws.open)
        except Exception:
            return False

    async def _release(self, ws, endpoint: str):
        """Return a legacy-pool websocket to its endpoint queue.

        An open connection goes straight back. A connection that is no longer
        open is closed and replaced with a fresh one; if reconnecting fails,
        the pool shrinks by one. Nothing happens when no pool exists for
        *endpoint* (e.g. after close()).
        """
        is_open = self._ws_is_open(ws)

        pool = self._pools.get(endpoint)
        if pool is None:
            return

        if is_open:
            await pool.put(ws)
            return

        # Release the dead connection's transport instead of leaving it to GC.
        try:
            await asyncio.wait_for(ws.close(), timeout=1.0)
        except Exception:
            pass

        try:
            new_ws = await _connect_websocket(endpoint, self._get_auth_headers())
        except Exception:
            # Best-effort: the pool is left with one less connection.
            return

        # Keep _all_connections pointing at live sockets, so close() also shuts
        # down replacements and the list does not grow with every replacement.
        try:
            self._all_connections.remove(ws)
        except ValueError:
            pass
        self._all_connections.append(new_ws)
        await pool.put(new_ws)

    def close(self) -> None:
        """Close all connections, stop and close the background loop, and reset state.

        Safe to call repeatedly and on a client that never connected. Blocks
        the calling thread for up to 10s while connections close, and up to 2s
        while the loop thread exits.
        """
        loop = self._loop
        if loop is not None and (self._pools or self._multiplex_pool is not None):
            async def _close_all():
                if self._multiplex_pool is not None:
                    await self._multiplex_pool.close()
                    self._multiplex_pool = None

                # Drain pools
                for pool in self._pools.values():
                    while not pool.empty():
                        try:
                            ws = pool.get_nowait()
                            if ws not in self._all_connections:
                                self._all_connections.append(ws)
                        except asyncio.QueueEmpty:
                            break

                # Close all connections
                for ws in self._all_connections:
                    try:
                        await asyncio.wait_for(ws.close(), timeout=1.0)
                    except Exception:
                        pass

            try:
                future = asyncio.run_coroutine_threadsafe(_close_all(), loop)
                future.result(timeout=10.0)
            except Exception:
                # Best-effort: a slow or failing shutdown must not prevent the
                # loop teardown below.
                pass

            self._all_connections.clear()
            self._pools.clear()
            self._multiplex_ws_url = None
            self._websocket_count = 0
            self._max_inflight_per_socket = 0

        if loop is not None:
            loop.call_soon_threadsafe(loop.stop)
            thread = self._thread
            if thread is not None:
                thread.join(timeout=2.0)
            # Close the loop once run_forever() has returned, so its resources
            # and default executor are released. Closing a running loop raises
            # RuntimeError, so a thread that did not stop in time is left alone.
            if (thread is None or not thread.is_alive()) and not loop.is_running():
                loop.close()
            self._loop = None
            self._thread = None

        self._connected = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    async def _stream_on_connection(
        self,
        ws,
        text: str,
        voice: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
        repetition_penalty: Optional[float],
    ) -> AsyncIterator[bytes]:
        """Internal: send one request on ``ws`` and yield its audio frames.

        Per-call overrides that are ``None`` fall back to the client defaults.
        Binary frames are yielded unchanged. Text frames are treated as JSON
        control messages:

        - an object with a non-null ``"error"`` aborts the stream;
        - an object with a truthy ``"done"`` ends it.

        Text frames that are not valid JSON, or JSON that is not an object,
        are ignored.

        Raises:
            AuthenticationError: If the server error mentions "auth" or "key".
            StreamingError: For any other server-reported error.
        """
        request_body = {
            "prompt": text,
            "voice": voice,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "temperature": temperature if temperature is not None else self.temperature,
            "repetition_penalty": repetition_penalty
            if repetition_penalty is not None
            else self.repetition_penalty,
        }

        await ws.send(_log_request_json(request_body))

        # Fire-and-forget health ping for autoscaler tracking
        try:
            health_url = _get_health_url(_get_endpoint(voice, self._voice_endpoint_map))
            ping = asyncio.get_running_loop().run_in_executor(
                None, _fire_health_ping, health_url
            )
            # Retrieve the outcome so a failed ping is not logged as
            # "Future exception was never retrieved".
            ping.add_done_callback(lambda fut: fut.cancelled() or fut.exception())
        except Exception:
            pass

        async for message in ws:
            if isinstance(message, bytes):
                yield message
                continue

            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                continue  # ignore non-JSON text frames
            if not isinstance(data, dict):
                continue  # only JSON objects carry control fields

            error = data.get("error")
            if error is not None:
                error_msg = str(error)
                lowered = error_msg.lower()
                if "auth" in lowered or "key" in lowered:
                    raise AuthenticationError(error_msg)
                raise StreamingError(error_msg)
            if data.get("done"):
                break

    def stream(
        self,
        text: str,
        voice: str = "josh",
        *,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
    ) -> Iterator[bytes]:
        """
        Stream audio synchronously from the Orpheus TTS model.

        When connect() has been called, the request reuses the client's
        persistent connections (multiplex pool if configured, otherwise the
        per-voice pool). Without connect(), a fresh connection is opened for
        this request only.

        Args:
            text: The text to convert to speech.
            voice: Voice to use. See orpheus_tts.VOICES for options.
            max_tokens: Override default max_tokens.
            temperature: Override default temperature.
            repetition_penalty: Override default repetition_penalty.

        Yields:
            bytes: Raw PCM audio chunks (int16, 48kHz, mono).

        Raises:
            ValueError: If voice is not recognized.
            AuthenticationError: If API key is invalid.
            ConnectionError: If connection to service fails.
            StreamingError: If an error occurs during streaming.
        """
        voice_key = voice.strip().lower()
        endpoint = _get_endpoint(voice_key, self._voice_endpoint_map)
        overrides = (max_tokens, temperature, repetition_penalty)

        has_background_loop = self._connected and self._loop is not None
        if not has_background_loop:
            # No persistent connections: the handshake is paid on every call.
            yield from self._stream_with_new_connection(
                text, voice_key, endpoint, *overrides
            )
            return

        if self._multiplex_pool is not None:
            yield from self._stream_from_multiplex_pool(text, voice_key, *overrides)
        else:
            yield from self._stream_from_pool(endpoint, text, voice_key, *overrides)

    def _stream_from_multiplex_pool(
        self,
        text: str,
        voice: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
        repetition_penalty: Optional[float],
    ) -> Iterator[bytes]:
        """Stream using the multiplex websocket pool."""
        chunk_queue: queue.Queue = queue.Queue()
        error_holder: list = []

        async def _run():
            if self._multiplex_pool is None:
                raise ConnectionError("Multiplex pool is not initialized")
            async for chunk in self._multiplex_pool.stream_request(
                text=text,
                voice=voice,
                max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
                temperature=temperature if temperature is not None else self.temperature,
                repetition_penalty=(
                    repetition_penalty
                    if repetition_penalty is not None
                    else self.repetition_penalty
                ),
            ):
                chunk_queue.put(chunk)

        async def _run_wrapper():
            try:
                await _run()
            except Exception as e:
                error_holder.append(e)
            finally:
                chunk_queue.put(None)

        future = asyncio.run_coroutine_threadsafe(_run_wrapper(), self._loop)

        while True:
            chunk = chunk_queue.get()
            if chunk is None:
                break
            yield chunk

        try:
            future.result(timeout=300.0)
        except Exception as e:
            if not error_holder:
                error_holder.append(e)

        if error_holder:
            raise error_holder[0]

    def _stream_from_pool(
        self,
        endpoint: str,
        text: str,
        voice: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
        repetition_penalty: Optional[float],
    ) -> Iterator[bytes]:
        """Stream using a pooled connection (lowest latency)."""
        chunk_queue: queue.Queue = queue.Queue()
        error_holder: list = []

        async def _run():
            max_retries = 3
            for attempt in range(max_retries):
                ws = None
                try:
                    _, ws = await self._acquire(voice)
                    async for chunk in self._stream_on_connection(
                        ws, text, voice, max_tokens, temperature, repetition_penalty
                    ):
                        chunk_queue.put(chunk)
                    # Success
                    await self._release(ws, endpoint)
                    return
                except ConnectionClosed:
                    if ws is not None:
                        await self._release(ws, endpoint)
                    if attempt < max_retries - 1:
                        continue
                    raise ConnectionError(
                        "Connection closed unexpectedly after "
                        f"{max_retries} retries"
                    )
                except Exception:
                    if ws is not None:
                        await self._release(ws, endpoint)
                    raise

        async def _run_wrapper():
            try:
                await _run()
            except Exception as e:
                error_holder.append(e)
            finally:
                chunk_queue.put(None)

        future = asyncio.run_coroutine_threadsafe(_run_wrapper(), self._loop)

        while True:
            chunk = chunk_queue.get()
            if chunk is None:
                break
            yield chunk

        try:
            future.result(timeout=300.0)
        except Exception as e:
            if not error_holder:
                error_holder.append(e)

        if error_holder:
            raise error_holder[0]

    async def stream_async(
        self,
        text: str,
        voice: str = "josh",
        *,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
    ) -> AsyncIterator[bytes]:
        """
        Stream audio asynchronously from the Orpheus TTS model.

        If connect() was called first, uses a pooled connection for lowest
        latency. Otherwise, creates a new connection per request.

        Args:
            text: The text to convert to speech.
            voice: Voice to use. See orpheus_tts.VOICES for options.
            max_tokens: Override default max_tokens.
            temperature: Override default temperature.
            repetition_penalty: Override default repetition_penalty.

        Yields:
            bytes: Raw PCM audio chunks (int16, 48kHz, mono).

        Raises:
            ValueError: If voice is not recognized.
            AuthenticationError: If API key is invalid.
            ConnectionError: If connection to service fails.
            StreamingError: If an error occurs during streaming.
        """
        voice = voice.strip().lower()
        endpoint = _get_endpoint(voice, self._voice_endpoint_map)

        if self._connected and self._loop is not None:
            # Use pool — stream on the background loop and bridge chunks back
            # through a thread-safe queue (None marks end-of-stream).
            chunk_queue: queue.Queue = queue.Queue()
            error_holder: list = []

            async def _run():
                if self._multiplex_pool is not None:
                    async for chunk in self._multiplex_pool.stream_request(
                        text=text,
                        voice=voice,
                        max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
                        temperature=temperature if temperature is not None else self.temperature,
                        repetition_penalty=(
                            repetition_penalty
                            if repetition_penalty is not None
                            else self.repetition_penalty
                        ),
                    ):
                        chunk_queue.put(chunk)
                    return

                max_retries = 3
                for attempt in range(max_retries):
                    ws = None
                    emitted = False
                    try:
                        _, ws = await self._acquire(voice)
                        async for chunk in self._stream_on_connection(
                            ws, text, voice, max_tokens, temperature,
                            repetition_penalty
                        ):
                            emitted = True
                            chunk_queue.put(chunk)
                    except ConnectionClosed as exc:
                        if ws is not None:
                            await self._release(ws, endpoint)
                        if emitted:
                            # Audio was already delivered; replaying the
                            # request from the start would duplicate it.
                            raise ConnectionError(
                                "Connection closed unexpectedly mid-stream"
                            ) from exc
                        if attempt < max_retries - 1:
                            continue
                        raise ConnectionError(
                            "Connection closed unexpectedly after "
                            f"{max_retries} attempts"
                        ) from exc
                    except Exception:
                        if ws is not None:
                            await self._release(ws, endpoint)
                        raise
                    else:
                        # Success
                        await self._release(ws, endpoint)
                        return

            async def _run_wrapper():
                try:
                    await _run()
                except Exception as e:
                    error_holder.append(e)
                finally:
                    chunk_queue.put(None)

            future = asyncio.run_coroutine_threadsafe(_run_wrapper(), self._loop)

            loop = asyncio.get_running_loop()
            while True:
                # queue.Queue.get blocks, so wait for it off this event loop.
                chunk = await loop.run_in_executor(None, chunk_queue.get)
                if chunk is None:
                    break
                yield chunk

            try:
                future.result(timeout=5.0)
            except Exception as e:
                if not error_holder:
                    error_holder.append(e)

            if error_holder:
                raise error_holder[0]
        else:
            # No pool - create new connection (higher latency)
            ws = None
            try:
                ws = await _connect_websocket(endpoint, self._get_auth_headers())
                async for chunk in self._stream_on_connection(
                    ws, text, voice, max_tokens, temperature,
                    repetition_penalty
                ):
                    yield chunk

            except ConnectionClosed as e:
                raise ConnectionError(
                    f"Connection closed unexpectedly: {e}"
                ) from e
            except (AuthenticationError, ConnectionError, StreamingError):
                # Already-typed client errors (e.g. raised by
                # _connect_websocket) pass through unchanged instead of being
                # rewrapped by the generic handlers below.
                raise
            except OSError as e:
                raise ConnectionError(
                    f"Failed to connect to TTS service: {e}"
                ) from e
            except Exception as e:
                raise OrpheusError(f"Unexpected error: {e}") from e
            finally:
                if ws is not None:
                    try:
                        await ws.close()
                    except Exception:
                        pass

    def _stream_with_new_connection(
        self,
        text: str,
        voice: str,
        endpoint: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
        repetition_penalty: Optional[float],
    ) -> Iterator[bytes]:
        """Stream over a fresh, single-use connection (handshake included in latency).

        A private event loop runs in a daemon thread. It opens the
        connection, forwards every audio chunk through a thread-safe queue
        (``None`` marks end-of-stream), and always closes the socket. The
        first error captured by that thread is re-raised here once all
        queued chunks have been yielded.
        """
        handoff: queue.Queue = queue.Queue()
        failures: list = []

        async def _produce():
            connection = None
            try:
                connection = await _connect_websocket(endpoint, self._get_auth_headers())
                async for piece in self._stream_on_connection(
                    connection, text, voice, max_tokens, temperature,
                    repetition_penalty
                ):
                    handoff.put(piece)
            except Exception as exc:
                failures.append(exc)
            finally:
                if connection is not None:
                    try:
                        await connection.close()
                    except Exception:
                        pass
                handoff.put(None)  # always unblock the consumer

        def _thread_main():
            private_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(private_loop)
            try:
                private_loop.run_until_complete(_produce())
            finally:
                private_loop.run_until_complete(private_loop.shutdown_asyncgens())
                private_loop.close()

        worker = threading.Thread(target=_thread_main, daemon=True)
        worker.start()

        for piece in iter(handoff.get, None):
            yield piece

        worker.join(timeout=5.0)

        if failures:
            raise failures[0]

    def stream_to_bytes(
        self,
        text: str,
        voice: str = "josh",
        **kwargs,
    ) -> bytes:
        """
        Synthesize ``text`` and return the whole utterance as one bytes object.

        Thin wrapper over stream() that concatenates every chunk it yields.

        Args:
            text: The text to convert to speech.
            voice: Voice to use.
            **kwargs: Forwarded unchanged to stream().

        Returns:
            bytes: Complete PCM audio (int16, 48kHz, mono).
        """
        return b"".join(self.stream(text, voice, **kwargs))

    async def stream_to_bytes_async(
        self,
        text: str,
        voice: str = "josh",
        **kwargs,
    ) -> bytes:
        """
        Asynchronously synthesize ``text`` and return the whole utterance as bytes.

        Thin wrapper over stream_async() that concatenates every chunk.

        Args:
            text: The text to convert to speech.
            voice: Voice to use.
            **kwargs: Forwarded unchanged to stream_async().

        Returns:
            bytes: Complete PCM audio (int16, 48kHz, mono).
        """
        pieces = [piece async for piece in self.stream_async(text, voice, **kwargs)]
        return b"".join(pieces)
