"""
"""
import inspect
import warnings
from datetime import timedelta
from functools import partial
from typing import Callable
from typing import Literal
from typing import TypeVar
from typing import overload
from typing_extensions import ParamSpec
from typing_extensions import Unpack

from ..config import Config
from .api import GPUSize
from .types import DynamicDuration
from .types import EmptyKwargs


P = ParamSpec('P')
R = TypeVar('R')


decorated_cache: dict[Callable, Callable] = {}


@overload
def GPU(
    task: None = None, *,
    duration: DynamicDuration[P] = None,
    size: Literal['large', 'xlarge'] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    ...
@overload
def GPU(
    task: Callable[P, R], *,
    duration: DynamicDuration[P] = None,
    size: Literal['large', 'xlarge'] | None = None,
) -> Callable[P, R]:
    ...
def GPU(
    task: Callable[P, R] | None = None, *,
    duration: DynamicDuration[P] = None,
    size: Literal['large', 'xlarge'] | None = None,
    **kwargs: Unpack[EmptyKwargs],
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
    """
    ZeroGPU decorator

    Args:
        task (`Callable | None`): Python function that requires CUDA
        duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
        size (`"large" | "xlarge" | None`): Defaults to **large** when running on ZeroGPU

    Returns:
        `Callable`: GPU-ready function

    Examples:
        Basic usage:
        ```
        @spaces.GPU
        def fn(...):
            # CUDA is available here
            pass
        ```

        With custom duration:
        ```
        @spaces.GPU(duration=45) # Expressed in seconds
        def fn(...):
            pass
        ```

        With custom duration and size:
        ```
        @spaces.GPU(duration=45, size='xlarge')
        def fn(...):
            pass
        ```
    """
    if "enable_queue" in kwargs:
        warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
    if not callable(task):
        if isinstance(task, str): # pragma: no cover (@spaces.GPU('xlarge'))
            size = task
        elif task is not None: # pragma: no cover (@spaces.GPU(45))
            duration = task
        return partial(_GPU, duration=duration, size=size)
    return _GPU(task, duration, size)


def _GPU(
    task: Callable[P, R],
    duration: DynamicDuration[P],
    size: GPUSize | None,
) -> Callable[P, R]:

    if not Config.zero_gpu:
        return task

    from . import client
    from .wrappers import regular_function_wrapper
    from .wrappers import generator_function_wrapper

    if task in decorated_cache:
        # TODO: Assert same duration ?
        return decorated_cache[task] # type: ignore

    if inspect.iscoroutinefunction(task):
        raise NotImplementedError

    if inspect.isgeneratorfunction(task):
        decorated = generator_function_wrapper(task, duration, size)
    else:
        decorated = regular_function_wrapper(task, duration, size)

    setattr(decorated, 'zerogpu', None)

    decorated_cache.update({
        task:      decorated,
        decorated: decorated,
    })

    return decorated # type: ignore
