from __future__ import annotations

import asyncio
import math
import threading
from typing import TYPE_CHECKING

from typing_extensions import override

from wandb.proto import wandb_server_pb2 as spb

from .mailbox_handle import HandleAbandonedError, MailboxHandle

# Necessary to break an import loop.
if TYPE_CHECKING:
    from wandb.sdk.interface import interface


class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]):
    """A general handle for any ServerResponse.

    A response is delivered at most once via `deliver`. Callers can wait for
    it by blocking a thread (`wait_or`) or by awaiting in an asyncio loop
    (`wait_async`). The handle can instead be abandoned, which wakes all
    waiters without a response.

    State is guarded by a `threading.Lock`, and asyncio waiters are woken
    through `call_soon_threadsafe`. As a result, `deliver` and `abandon` may
    be called from a thread other than the waiters'.
    """

    def __init__(self, address: str) -> None:
        self._address = address

        # Guards _abandoned, _response and _asyncio_events.
        self._lock = threading.Lock()

        # Set by _signal_done once a response is delivered or the handle is
        # abandoned; threads in `wait_or` block on it.
        self._event = threading.Event()

        self._abandoned = False
        self._response: spb.ServerResponse | None = None

        # Pending `wait_async` waiters, keyed by the asyncio.Event each one
        # awaits. Emptied by _signal_done.
        self._asyncio_events: dict[asyncio.Event, _AsyncioEvent] = {}

    def deliver(self, response: spb.ServerResponse) -> None:
        """Deliver the response.

        This may only be called once. It is an error to respond to the same
        request more than once. It is a no-op if the handle has been abandoned.

        Raises:
            ValueError: If a response was already delivered.
        """
        with self._lock:
            if self._abandoned:
                return

            # Explicit None check: the field is `ServerResponse | None`, and
            # whether a response exists must not depend on the message's
            # truthiness.
            if self._response is not None:
                raise ValueError(
                    f"A response has already been delivered to {self._address}."
                )

            self._response = response
            self._signal_done()

    @override
    def cancel(self, iface: interface.InterfaceBase) -> None:
        """Publish a cancel for this handle's address, then abandon it."""
        iface.publish_cancel(self._address)
        self.abandon()

    @override
    def abandon(self) -> None:
        """Mark the handle abandoned and wake every waiter.

        Later `deliver` calls become no-ops. A waiter that wakes and finds no
        response raises HandleAbandonedError.
        """
        with self._lock:
            self._abandoned = True
            self._signal_done()

    def _signal_done(self) -> None:
        """Indicate that the handle either got a response or became abandoned.

        The lock must be held.
        """
        # Unblock threads blocked on `wait_or`.
        self._event.set()

        # Unblock asyncio loops blocked on `wait_async`.
        for asyncio_event in self._asyncio_events.values():
            asyncio_event.set_threadsafe()
        self._asyncio_events.clear()

    @override
    def check(self) -> spb.ServerResponse | None:
        """Return the delivered response, or None if there is none yet.

        This does not wait for a response to arrive.
        """
        with self._lock:
            return self._response

    @override
    def wait_or(self, *, timeout: float | None) -> spb.ServerResponse:
        """Block the calling thread until the handle is done.

        Args:
            timeout: Maximum number of seconds to wait, or None to wait
                without a limit.

        Returns:
            The delivered response.

        Raises:
            ValueError: If `timeout` is infinite or NaN.
            TimeoutError: If no response or abandonment happened within
                `timeout`.
            HandleAbandonedError: If the handle was abandoned without a
                response.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        if not self._event.wait(timeout=timeout):
            raise TimeoutError(
                f"Timed out waiting for response on {self._address}",
            )

        with self._lock:
            if self._response is not None:
                return self._response

            # _event is only set by _signal_done, which runs after either a
            # delivery or an abandonment.
            assert self._abandoned
            raise HandleAbandonedError()

    @override
    async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse:
        """Wait in the running asyncio loop until the handle is done.

        Args:
            timeout: Maximum number of seconds to wait, or None to wait
                without a limit.

        Returns:
            The delivered response. A response that is present when the
            timeout fires is still returned.

        Raises:
            ValueError: If `timeout` is infinite or NaN.
            TimeoutError: If no response or abandonment happened within
                `timeout`.
            HandleAbandonedError: If the handle was abandoned without a
                response.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        evt = asyncio.Event()
        # Inside a coroutine, get_running_loop() is the documented way to get
        # the current loop; get_event_loop() has policy-dependent behavior.
        self._add_asyncio_event(asyncio.get_running_loop(), evt)

        try:
            await asyncio.wait_for(evt.wait(), timeout=timeout)

        # Before Python 3.11, asyncio.TimeoutError is distinct from the
        # builtin TimeoutError, so both are caught.
        except (asyncio.TimeoutError, TimeoutError) as e:
            # `evt.set` is scheduled via call_soon_threadsafe, so the handle
            # may already be done even though the wait timed out. Re-check
            # under the lock.
            with self._lock:
                if self._response is not None:
                    return self._response
                elif self._abandoned:
                    raise HandleAbandonedError()
                else:
                    raise TimeoutError(
                        f"Timed out waiting for response on {self._address}"
                    ) from e

        else:
            with self._lock:
                if self._response is not None:
                    return self._response

                # evt is only set once a response or abandonment occurred.
                assert self._abandoned
                raise HandleAbandonedError()

        finally:
            # Drop the waiter on every exit path, including cancellation.
            self._forget_asyncio_event(evt)

    def _add_asyncio_event(
        self,
        loop: asyncio.AbstractEventLoop,
        event: asyncio.Event,
    ) -> None:
        """Add an event to signal when a response is received.

        If the handle is already done (a response exists or it was
        abandoned), this notifies the event loop immediately instead of
        registering the event.
        """
        asyncio_event = _AsyncioEvent(loop, event)

        with self._lock:
            if self._response is not None or self._abandoned:
                asyncio_event.set_threadsafe()
            else:
                self._asyncio_events[event] = asyncio_event

    def _forget_asyncio_event(self, event: asyncio.Event) -> None:
        """Cancel signalling an event when a response is received."""
        with self._lock:
            self._asyncio_events.pop(event, None)


class _AsyncioEvent:
    """An asyncio.Event paired with the event loop that awaits it.

    asyncio primitives are not thread-safe, so the event is set by scheduling
    `Event.set` on its own loop rather than by calling it directly.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        event: asyncio.Event,
    ):
        self._loop = loop
        self._event = event

    def set_threadsafe(self) -> None:
        """Set the asyncio event in its own loop.

        This may be called from any thread. If the loop has already been
        closed, no coroutine on it can run to observe the event, so the call
        is a no-op instead of raising.
        """
        try:
            self._loop.call_soon_threadsafe(self._event.set)
        except RuntimeError:
            # call_soon_threadsafe raises RuntimeError("Event loop is closed")
            # on a closed loop. Letting that escape would propagate out of the
            # notifying thread and skip any remaining waiters. Re-raise any
            # RuntimeError that is not about a closed loop.
            if not self._loop.is_closed():
                raise
