"""Custom WebSocket transport support for the Deepgram Python SDK.

Allows users to swap in custom transports (BiDi/SSE, test doubles, proxied
connections) by providing a ``transport_factory`` when constructing a
:class:`DeepgramClient` or :class:`AsyncDeepgramClient`.

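A transport factory is a callable: ``factory(url, headers)`` that returns a
transport object with ``send()``, ``recv()``, ``close()``, and iteration
support. The SDK wraps that object in a context manager (sync) or async
context manager (async) and calls ``close()`` (awaited for async transports)
when the connection block exits.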

Usage::

    from deepgram import DeepgramClient

    client = DeepgramClient(api_key="...", transport_factory=MySyncTransport)
"""

import importlib
import sys
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Callable, Dict, Optional

from .transport_interface import AsyncTransport, SyncTransport

# Re-export so existing ``from deepgram.transport import SyncTransport`` still works.
# ``SyncTransport`` and ``AsyncTransport`` are imported from ``.transport_interface``;
# ``install_transport`` and ``restore_transport`` are defined in this module.
__all__ = ["SyncTransport", "AsyncTransport", "install_transport", "restore_transport"]

# ---------------------------------------------------------------------------
# Dotted paths of the modules whose websocket references get patched.
# Each websocket-capable package contributes a ``raw_client`` and a ``client``
# module, giving 8 paths in total. All of them are auto-generated by Fern —
# we never modify their source.
# ---------------------------------------------------------------------------
_TARGET_MODULES = [
    f"deepgram.{package}.{leaf}"
    for package in ("listen.v1", "listen.v2", "speak.v1", "agent.v1")
    for leaf in ("raw_client", "client")
]

# Factories currently installed (``None`` when nothing is installed); compared
# by identity in install_transport() to detect conflicting installs.
_active_async_factory: Optional[Callable] = None
_active_sync_factory: Optional[Callable] = None

# Pre-patch attribute values keyed by module path, then by attribute name.
# Filled by install_transport() and consumed by restore_transport().
_originals: Dict[str, Dict[str, Any]] = {}


class _SyncTransportShim:
    """Stand-in for ``websockets.sync.client`` in the generated modules.

    The auto-generated code opens a connection with::

        websockets_sync_client.connect(ws_url, additional_headers=headers)

    so this class exposes a ``connect()`` with that call shape and hands the
    request to the user-supplied factory.
    """

    def __init__(self, factory: Callable) -> None:
        # Called as ``factory(url, headers)``; its return value is the transport.
        self._factory = factory

    @contextmanager
    def connect(self, url: str, additional_headers: Optional[Dict[str, str]] = None):
        """Yield a transport built by the factory and close it on exit.

        A missing or empty ``additional_headers`` is passed to the factory as
        an empty dict. ``close()`` is called on the transport when the
        ``with`` block ends, whether it ends normally or with an exception.
        """
        if additional_headers:
            request_headers = additional_headers
        else:
            request_headers = {}

        connection = self._factory(url, request_headers)
        try:
            yield connection
        finally:
            connection.close()


class _AsyncTransportShim:
    """Stand-in for ``websockets_client_connect`` in the generated modules.

    The auto-generated code opens a connection with::

        websockets_client_connect(ws_url, extra_headers=headers)

    so instances are callable with that shape and return an async context
    manager that hands the request to the user-supplied factory.
    """

    def __init__(self, factory: Callable) -> None:
        # Called as ``factory(url, headers)``; its return value is the transport.
        self._factory = factory

    def __call__(self, url: str, extra_headers: Optional[Dict[str, str]] = None):
        """Return the async context manager for one connection."""
        return self._connect(url, extra_headers=extra_headers)

    @asynccontextmanager
    async def _connect(self, url: str, extra_headers: Optional[Dict[str, str]] = None):
        """Yield a transport built by the factory and await its ``close()`` on exit.

        A missing or empty ``extra_headers`` is passed to the factory as an
        empty dict. The factory itself is called synchronously (not awaited).
        """
        if extra_headers:
            request_headers = extra_headers
        else:
            request_headers = {}

        connection = self._factory(url, request_headers)
        try:
            yield connection
        finally:
            await connection.close()


def install_transport(
    *,
    sync_factory: Optional[Callable] = None,
    async_factory: Optional[Callable] = None,
) -> None:
    """Monkey-patch the 8 auto-generated modules to use custom transports.

    Every module listed in ``_TARGET_MODULES`` is looked up in ``sys.modules``
    (and imported if absent). Its current ``websockets_sync_client`` and
    ``websockets_client_connect`` attributes are stashed the first time the
    module is seen, then replaced with shims that delegate to the supplied
    factories. Only attributes the module already defines are replaced.

    Parameters
    ----------
    sync_factory
        ``factory(url, headers) -> transport`` used for sync WebSocket clients.
        ``None`` leaves the sync side untouched.
    async_factory
        ``factory(url, headers) -> transport`` used for async WebSocket clients.
        ``None`` leaves the async side untouched.

    Raises
    ------
    RuntimeError
        If a different transport factory is already installed. Re-installing
        the same factory is allowed (idempotent). Call :func:`restore_transport`
        first to switch factories.
    """
    global _active_sync_factory, _active_async_factory

    # Validate both requests before mutating any state, so a conflict on
    # either side leaves the current installation untouched.
    for kind, active, requested in (
        ("sync", _active_sync_factory, sync_factory),
        ("async", _active_async_factory, async_factory),
    ):
        if requested is not None and active is not None and active is not requested:
            raise RuntimeError(
                f"A different {kind} transport factory is already installed. "
                "Only one transport factory per process is supported because "
                "transport patching is global. Call restore_transport() before "
                "creating a new client with a different transport_factory."
            )

    # Decide with ``is not None`` rather than truthiness: a callable object
    # may be falsy (``__bool__`` / ``__len__``), and such a factory must not
    # be recorded as active while its shim is silently left uninstalled.
    sync_shim = None
    async_shim = None
    if sync_factory is not None:
        _active_sync_factory = sync_factory
        sync_shim = _SyncTransportShim(sync_factory)
    if async_factory is not None:
        _active_async_factory = async_factory
        async_shim = _AsyncTransportShim(async_factory)

    for mod_path in _TARGET_MODULES:
        mod = sys.modules.get(mod_path)
        if mod is None:
            try:
                mod = importlib.import_module(mod_path)
            except ImportError:
                # Best effort: a target module that cannot be imported is skipped.
                continue

        # Stash originals on first install only; a later call would otherwise
        # record an already-installed shim as the "original".
        if mod_path not in _originals:
            _originals[mod_path] = {
                "websockets_sync_client": getattr(mod, "websockets_sync_client", None),
                "websockets_client_connect": getattr(mod, "websockets_client_connect", None),
            }

        if sync_shim is not None and hasattr(mod, "websockets_sync_client"):
            mod.websockets_sync_client = sync_shim  # type: ignore[attr-defined]

        if async_shim is not None and hasattr(mod, "websockets_client_connect"):
            mod.websockets_client_connect = async_shim  # type: ignore[attr-defined]


def restore_transport() -> None:
    """Revert the module patches made by :func:`install_transport`.

    For each stashed module that is still present in ``sys.modules``, every
    stashed attribute whose saved value is not ``None`` is written back.
    Afterwards the stash is emptied and both active-factory records are
    reset, so a different factory can be installed.
    """
    global _active_sync_factory, _active_async_factory

    for module_name, stashed in _originals.items():
        module = sys.modules.get(module_name)
        if module is not None:
            for attr_name, value in stashed.items():
                # A stashed ``None`` is never written back.
                if value is None:
                    continue
                setattr(module, attr_name, value)

    _active_sync_factory = _active_async_factory = None
    _originals.clear()
