from __future__ import annotations

import asyncio
import functools
import inspect
import logging
from collections.abc import Sequence
from typing import Any, Optional, Protocol, TypeVar

import wandb.sdk
import wandb.util
from wandb.sdk.lib import telemetry as wb_telemetry
from wandb.sdk.lib.timer import Timer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Type of the optional ``init`` argument accepted by the autolog entry points:
# either ``None`` or a dict of keyword arguments that is unpacked into
# ``wandb.init()``.
AutologInitArgs = Optional[dict[str, Any]]


# Type variables parameterizing the ``Response`` protocol: ``K`` is the key
# type (bounded by ``str``) and ``V`` is the value type.
K = TypeVar("K", bound=str)
V = TypeVar("V")


class Response(Protocol[K, V]):
    """Structural type for the value returned by a patched API call.

    Any object that offers mapping-style read access -- subscripting and
    ``get`` with an optional default -- satisfies this protocol; no
    inheritance is required.
    """

    # Subscript access: ``response[key]`` yields the value for ``key``.
    def __getitem__(self, key: K) -> V: ...  # pragma: no cover

    # Lookup with a fallback: yields ``default`` (``None`` unless given)
    # in place of a value, mirroring the signature of ``dict.get``.
    def get(self, key: K, default: V | None = None) -> V | None: ...  # pragma: no cover


class ArgumentResponseResolver(Protocol):
    """Callable that turns one patched API call into a loggable dictionary.

    ``PatchAPI`` invokes the resolver after every call to a patched symbol.

    Args:
        args: Positional arguments the patched callable was invoked with.
        kwargs: Keyword arguments the patched callable was invoked with.
        response: The value the original callable returned.
        start_time: ``start_time`` of the ``Timer`` wrapped around the call
            (presumably a wall-clock timestamp -- confirm against ``Timer``).
        time_elapsed: ``elapsed`` of the same ``Timer`` (presumably the call
            duration in seconds -- confirm against ``Timer``).

    Returns:
        A dictionary that is passed to ``run.log()``, or ``None`` when
        nothing should be logged for this call.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> dict[str, Any] | None: ...  # pragma: no cover


class PatchAPI:
    """Monkey-patches callables of a library so that each call is logged to W&B.

    Every symbol listed in ``symbols`` is replaced by a wrapper that calls the
    original, hands the arguments, the result and the timing to ``resolver``,
    and logs whatever dictionary the resolver returns to the given run.
    ``unpatch`` puts the original callables back.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
    ) -> None:
        """Patches the API to log wandb Media or metrics.

        Args:
            name: Name of the LLM provider or package, e.g. "Cohere", "OpenAI"
                or "Transformers". Its lower-cased form is the module that
                gets imported and patched.
            symbols: Dotted attribute paths, relative to the module, of the
                callables to patch, e.g. ["Client.generate", "Edit.create"]
                or ["Pipeline.__call__"].
            resolver: Callable that converts args/response into a dictionary
                of wandb media objects or metrics (or ``None`` to skip).
        """
        self.name = name
        # lazily imported api module, e.g. `cohere`, `openai`, `transformers`
        self._api = None
        # original callables keyed by symbol, recorded while they are patched
        self.original_methods: dict[str, Any] = {}
        self.symbols = symbols
        self.resolver = resolver

    @property
    def set_api(self) -> Any:
        """Returns the API module, importing it on first access.

        The module name is ``self.name`` lower-cased; the ``required`` message
        tells the user how to install it if the import fails.
        """
        lib_name = self.name.lower()
        if self._api is None:
            self._api = wandb.util.get_module(
                name=lib_name,
                required=f"To use the W&B {self.name} Autolog, "
                f"you need to have the `{lib_name}` python "
                f"package installed. Please install it with `pip install {lib_name}`.",
                lazy=False,
            )
        return self._api

    def _owner_of(self, symbol_parts: Sequence[str]) -> Any:
        """Returns the object holding the last attribute of a dotted symbol.

        For ["Client", "generate"] this is ``api.Client``; for a single-part
        symbol the reduction over an empty sequence yields the module itself.
        """
        return functools.reduce(getattr, symbol_parts[:-1], self.set_api)

    def _log_call(
        self,
        run: Any,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Any,
        timer: Any,
    ) -> None:
        """Resolves one finished call into a loggable dict and logs it.

        Logging is best effort: any failure in the resolver or in ``run.log``
        is reported as a warning and never reaches the caller of the API.
        """
        try:
            loggable_dict = self.resolver(
                args, kwargs, response, timer.start_time, timer.elapsed
            )
            if loggable_dict is not None:
                run.log(loggable_dict)
        except Exception as e:
            logger.warning(e)

    def _wrap(self, original_method: Any, run: Any) -> Any:
        """Builds the logging wrapper (sync or async) around ``original_method``."""

        async def async_method(*args, **kwargs):
            # Await the original coroutine directly: its result is returned
            # and its exceptions propagate to the caller. Going through a
            # separately scheduled task and a hand-made Future would leave
            # the caller awaiting forever whenever the call or the resolver
            # raised.
            with Timer() as timer:
                result = await original_method(*args, **kwargs)
            # The timer is read only after its `with` block has exited.
            # NOTE(review): assumes Timer finalizes `elapsed` on exit --
            # confirm against wandb.sdk.lib.timer.
            self._log_call(run, args, kwargs, result, timer)
            return result

        def sync_method(*args, **kwargs):
            with Timer() as timer:
                result = original_method(*args, **kwargs)
            # Same as above: read the timer once the timed block is closed.
            self._log_call(run, args, kwargs, result, timer)
            return result

        if inspect.iscoroutinefunction(original_method):
            return functools.wraps(original_method)(async_method)
        return functools.wraps(original_method)(sync_method)

    def patch(self, run: wandb.Run) -> None:
        """Patches the API to log media or metrics to W&B.

        Calling ``patch`` again without an intervening ``unpatch`` rebuilds
        the wrappers around the saved originals (bound to the new ``run``)
        instead of wrapping the previous wrappers.

        Args:
            run: The run that receives the logged dictionaries.
        """
        for symbol in self.symbols:
            # split on dots, e.g. "Client.generate" -> ["Client", "generate"]
            symbol_parts = symbol.split(".")
            owner = self._owner_of(symbol_parts)
            attr_name = symbol_parts[-1]

            if symbol in self.original_methods:
                # already patched by us: the attribute currently holds our
                # wrapper, so the true original is the one saved earlier
                original = self.original_methods[symbol]
            else:
                original = getattr(owner, attr_name)
                # save original method
                self.original_methods[symbol] = original

            # monkey patch the method
            setattr(owner, attr_name, self._wrap(original, run))

    def unpatch(self) -> None:
        """Unpatches the API by restoring every saved original callable."""
        for symbol, original in self.original_methods.items():
            # split on dots, e.g. "Client.generate" -> ["Client", "generate"]
            symbol_parts = symbol.split(".")
            setattr(self._owner_of(symbol_parts), symbol_parts[-1], original)
        # Forget the saved originals so that a later patch() captures
        # whatever the attributes hold at that time.
        self.original_methods.clear()


class AutologAPI:
    """User-facing switch that turns autologging for one library on and off.

    Enabling obtains a W&B run (creating one when needed), patches the
    library through ``PatchAPI`` and optionally records a telemetry feature
    flag. Disabling restores the library and finishes the run if this object
    created it.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
        telemetry_feature: str | None = None,
    ) -> None:
        """Autolog API calls to W&B.

        Args:
            name: Name of the provider/package; forwarded to ``PatchAPI``.
            symbols: Dotted paths of the callables to patch; forwarded to
                ``PatchAPI``.
            resolver: Converts a call into a loggable dictionary; forwarded
                to ``PatchAPI``.
            telemetry_feature: Optional attribute name that is set to ``True``
                on the run's telemetry feature record when enabling.
        """
        self._telemetry_feature = telemetry_feature
        self._patch_api = PatchAPI(
            name=name,
            symbols=symbols,
            resolver=resolver,
        )
        self._name = self._patch_api.name
        # run that receives the logs; None while autologging is disabled
        self._run: wandb.Run | None = None
        # whether disable() is responsible for finishing self._run
        self.__run_created_by_autolog: bool = False

    @property
    def _is_enabled(self) -> bool:
        """Returns whether autologging is enabled."""
        return self._run is not None

    def __call__(self, init: AutologInitArgs = None) -> None:
        """Enable autologging (shorthand for ``enable(init=init)``)."""
        self.enable(init=init)

    def _run_init(self, init: AutologInitArgs = None) -> None:
        """Handle wandb run initialization.

        Sets ``self._run`` and records whether the run was created here.
        """
        # - autolog(init: dict = {...}) calls wandb.init(**{...})
        #   regardless of whether there is a wandb.run or not,
        #   we only track if the run was created by autolog
        #    - todo: autolog(init: dict | run = run) would use the user-provided run
        # - autolog() uses the wandb.run if there is one, otherwise it calls wandb.init()
        if init:
            _wandb_run = wandb.run
            # we delegate dealing with the init dict to wandb.init()
            self._run = wandb.init(**init)
            if _wandb_run != self._run:
                self.__run_created_by_autolog = True
        elif wandb.run is None:
            self._run = wandb.init()
            self.__run_created_by_autolog = True
        else:
            self._run = wandb.run

    def enable(self, init: AutologInitArgs = None) -> None:
        """Enable autologging.

        If autologging is already enabled it is disabled first and then
        enabled again. If patching the library fails, everything done so far
        is rolled back before the error is re-raised.

        Args:
            init: Optional dictionary of arguments to pass to wandb.init().

        """
        if self._is_enabled:
            logger.info(
                f"{self._name} autologging is already enabled, disabling and re-enabling."
            )
            self.disable()

        logger.info(f"Enabling {self._name} autologging.")
        self._run_init(init=init)

        try:
            self._patch_api.patch(self._run)
        except Exception:
            # e.g. the target library is not installed: do not leave a
            # half-patched API or a dangling autolog-created run behind
            self.disable()
            raise

        if self._telemetry_feature:
            with wb_telemetry.context(self._run) as tel:
                setattr(tel.feature, self._telemetry_feature, True)

    def disable(self) -> None:
        """Disable autologging.

        Restores the patched library and finishes the run if it was created
        by autolog. Does nothing when autologging is not enabled. An error
        raised by ``run.finish()`` still propagates, but only after the
        library has been unpatched and this object marked as disabled.
        """
        run = self._run
        if run is None:
            return

        logger.info(f"Disabling {self._name} autologging.")

        # Reset our own state first so that a failure further down cannot
        # leave this object claiming to be enabled.
        finish_run = self.__run_created_by_autolog
        self._run = None
        self.__run_created_by_autolog = False

        try:
            # unpatch before finishing so no wrapper logs to a finished run
            self._patch_api.unpatch()
        finally:
            if finish_run:
                run.finish()
