from __future__ import annotations

import time
from collections.abc import Coroutine
from typing import Any, Callable, TypeVar, cast

from wandb.sdk.lib import asyncio_compat

from .mailbox_handle import MailboxHandle

_T = TypeVar("_T")


def wait_with_progress(
    handle: MailboxHandle[_T],
    *,
    timeout: float | None,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> _T:
    """Wait for one handle, possibly displaying progress to the user.

    This is a convenience wrapper: the handle is placed in a one-element
    list, forwarded to `wait_all_with_progress` together with `timeout`
    and `display_progress`, and the first item of the returned list is
    handed back.
    """
    results = wait_all_with_progress(
        [handle],
        timeout=timeout,
        display_progress=display_progress,
    )
    return results[0]


def wait_all_with_progress(
    handle_list: list[MailboxHandle[_T]],
    *,
    timeout: float | None,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> list[_T]:
    """Wait for several handles while a progress display runs.

    Args:
        handle_list: The handles to wait for. An empty list returns
            an empty result immediately.
        timeout: A number of seconds after which to raise a TimeoutError,
            or None to wait indefinitely. The clock starts when this
            function is called, not when the waiting coroutine starts.
        display_progress: An asyncio function that displays progress to
            the user. Its coroutine is run on the handles' shared asyncer
            and is cancelled once waiting is over.

    Returns:
        A list where the Nth item is the Nth handle's result.

    Raises:
        ValueError: If the handles do not all share the same asyncer.
        TimeoutError: If the overall timeout expires.
        HandleAbandonedError: If any handle becomes abandoned.
        Exception: Any exception from the display function is propagated.
    """
    if not handle_list:
        return []

    # All handles must be driven by one and the same asyncer object.
    shared_asyncer = handle_list[0].asyncer
    if any(h.asyncer is not shared_asyncer for h in handle_list):
        raise ValueError("Handles have different AsyncioManagers.")

    began_at = time.monotonic()

    async def show_progress_until_done() -> list[_T]:
        async with asyncio_compat.cancel_on_exit(display_progress()):
            # Deduct the time already spent getting here from the budget.
            time_left = (
                None
                if timeout is None
                else timeout - (time.monotonic() - began_at)
            )
            return await _wait_handles_async(handle_list, timeout=time_left)

    return shared_asyncer.run(show_progress_until_done)


async def _wait_handles_async(
    handle_list: list[MailboxHandle[_T]],
    *,
    timeout: float | None,
) -> list[_T]:
    """Asynchronously wait for multiple mailbox handles.

    One waiting coroutine per handle is started in a shared task group,
    and each stores its handle's result at the matching position.

    Args:
        handle_list: The handles to wait for.
        timeout: Passed unchanged as the `timeout` of every handle's
            `wait_async` call.

    Returns:
        A list where the Nth item is the Nth handle's result.

    Raises:
        Exception: Whatever `wait_async` raises for a handle is not
            caught here.
    """
    # Pre-sized so each task can write its own slot regardless of the
    # order in which the waits finish.
    results: list[_T | None] = [None] * len(handle_list)

    async def wait_single(index: int, handle: MailboxHandle[_T]) -> None:
        results[index] = await handle.wait_async(timeout=timeout)

    # NOTE(review): presumably leaving the task group waits for every
    # started task, so all slots are filled afterwards -- confirm against
    # asyncio_compat.open_task_group.
    async with asyncio_compat.open_task_group() as task_group:
        for index, handle in enumerate(handle_list):
            task_group.start_soon(wait_single(index, handle))

    # NOTE: Unlike annotations, cast()'s first argument is evaluated at
    # runtime, and the builtin `list` is not subscriptable before
    # Python 3.9. A string forward reference is never evaluated, so it is
    # safe on every supported version.
    return cast("list[_T]", results)
