from __future__ import annotations

import re
from typing import Literal

import wandb
import wandb.util

_gym_version_lt_0_26: bool | None = None
_gymnasium_version_lt_1_0_0: bool | None = None

_required_error_msg = (
    "Couldn't import the gymnasium python package, install with `pip install gymnasium`"
)
GymLib = Literal["gym", "gymnasium"]


def monitor():
    """Monitor a gym environment.

    Supports both gym and gymnasium.
    """
    gym_lib: GymLib | None = None

    # gym is not maintained anymore, gymnasium is the drop-in replacement - prefer it
    if wandb.util.get_module("gymnasium") is not None:
        gym_lib = "gymnasium"
    elif wandb.util.get_module("gym") is not None:
        gym_lib = "gym"

    if gym_lib is None:
        raise wandb.Error(_required_error_msg)

    global _gym_version_lt_0_26
    global _gymnasium_version_lt_1_0_0

    if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None:
        if gym_lib == "gym":
            import gym
        else:
            import gymnasium as gym  # type: ignore

        from packaging.version import parse

        gym_lib_version = parse(gym.__version__)
        _gym_version_lt_0_26 = gym_lib_version < parse("0.26.0")
        _gymnasium_version_lt_1_0_0 = gym_lib_version < parse("1.0.0a1")

    path = "path"  # Default path
    if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
        vcr_recorder_attribute = "RecordVideo"
        wrappers = wandb.util.get_module(
            f"{gym_lib}.wrappers",
            required=_required_error_msg,
        )
        recorder = getattr(wrappers, vcr_recorder_attribute)
    else:
        vcr = wandb.util.get_module(
            f"{gym_lib}.wrappers.monitoring.video_recorder",
            required=_required_error_msg,
        )
        # Breaking change in gym 0.26.0
        if _gym_version_lt_0_26:
            vcr_recorder_attribute = "ImageEncoder"
            recorder = getattr(vcr, vcr_recorder_attribute)
            path = "output_path"  # Override path for older gym versions
        else:
            vcr_recorder_attribute = "VideoRecorder"
            recorder = getattr(vcr, vcr_recorder_attribute)

    recorder.orig_close = recorder.close

    def close(self):
        recorder.orig_close(self)
        if not self.enabled:
            return
        if wandb.run:
            m = re.match(r".+(video\.\d+).+", getattr(self, path))
            key = m.group(1) if m else "videos"
            wandb.log({key: wandb.Video(getattr(self, path))})

    def del_(self):
        self.orig_close()

    if not _gym_version_lt_0_26:
        recorder.__del__ = del_
    recorder.close = close

    if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
        wrapper_name = vcr_recorder_attribute
    else:
        wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}"

    wandb.patched["gym"].append(
        [
            f"{gym_lib}.wrappers.{wrapper_name}",
            "close",
        ]
    )
