"""Events that trigger W&B Automations."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from pydantic import AfterValidator, Field
from typing_extensions import Annotated, get_args

from wandb._pydantic import GQLBase, model_validator, pydantic_isinstance
from wandb._strutils import nameof

from ._filters import And, MongoLikeFilter
from ._filters.expressions import FilterableField
from ._filters.run_metrics import (
    MetricChangeFilter,
    MetricThresholdFilter,
    MetricVal,
    MetricZScoreFilter,
)
from ._filters.run_states import StateFilter, StateOperand
from ._generated import FilterEventFields
from ._validators import (
    JsonEncoded,
    LenientStrEnum,
    ensure_json,
    wrap_mutation_event_filter,
    wrap_run_event_run_filter,
)
from .actions import InputAction, InputActionTypes, SavedActionTypes
from .scopes import ArtifactCollectionScope, AutomationScope, ProjectScope

if TYPE_CHECKING:
    from .automations import NewAutomation


# NOTE: Re-defined publicly with a more readable name for easier access
class EventType(LenientStrEnum):
    """The type of event that triggers an automation.

    Member names are the public, user-facing identifiers. Some member values
    differ from their names: `LINK_ARTIFACT` has the value "LINK_MODEL" (the
    value expected by the backend), and `RUN_METRIC_THRESHOLD` has the value
    "RUN_METRIC".
    """

    # ---------------------------------------------------------------------------
    # Events triggered by GraphQL mutations
    UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS"  # NOTE: Avoid in new automations

    CREATE_ARTIFACT = "CREATE_ARTIFACT"
    ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS"
    # Note: "LINK_MODEL" is the (legacy) value expected by the backend, but we
    # name it "LINK_ARTIFACT" here in the public API for clarity and consistency.
    LINK_ARTIFACT = "LINK_MODEL"

    # ---------------------------------------------------------------------------
    # Events triggered by Run conditions
    # The public name spells out that this is the *threshold* variant of a run
    # metric event; its value is "RUN_METRIC".
    RUN_METRIC_THRESHOLD = "RUN_METRIC"
    RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
    RUN_STATE = "RUN_STATE"
    RUN_METRIC_ZSCORE = "RUN_METRIC_ZSCORE"


# ------------------------------------------------------------------------------
# Saved types: for parsing response data from saved automations


# Note: In GQL responses containing saved automation data, the filter is wrapped
# in an extra `filter` key.
class _WrappedSavedEventFilter(GQLBase):  # from: TriggeringFilterEvent
    """Saved-event filter shape for filters that are a plain `MongoLikeFilter`.

    Models GQL response data for saved automations, in which the event filter
    is nested under an extra `filter` key.
    """

    # JSON-encoded filter expression; defaults to `And()` when not provided.
    filter: JsonEncoded[MongoLikeFilter] = And()


class _WrappedMetricThresholdFilter(GQLBase):  # from: RunMetricFilter
    """Holds a `MetricThresholdFilter` under the `threshold_filter` key.

    The backend schema expects the threshold condition one level deeper than
    users typically write it, so a bare `MetricThresholdFilter` passed in is
    wrapped automatically during validation.
    """

    # Identifies the metric-event kind this wrapper represents; it is excluded
    # from serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_THRESHOLD],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_THRESHOLD

    threshold_filter: MetricThresholdFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        # Wrap a bare threshold filter; any other input is returned unchanged
        # so that regular field validation can handle it.
        needs_wrapping = pydantic_isinstance(v, MetricThresholdFilter)
        return cls(threshold_filter=v) if needs_wrapping else v


class _WrappedMetricChangeFilter(GQLBase):  # from: RunMetricFilter
    """Holds a `MetricChangeFilter` under the `change_filter` key.

    The backend schema expects the change condition one level deeper than
    users typically write it, so a bare `MetricChangeFilter` passed in is
    wrapped automatically during validation.
    """

    # Identifies the metric-event kind this wrapper represents; it is excluded
    # from serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_CHANGE],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_CHANGE

    change_filter: MetricChangeFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        # Wrap a bare change filter; any other input is returned unchanged
        # so that regular field validation can handle it.
        needs_wrapping = pydantic_isinstance(v, MetricChangeFilter)
        return cls(change_filter=v) if needs_wrapping else v


class _WrappedMetricZScoreFilter(GQLBase):  # from: RunMetricFilter
    """Holds a `MetricZScoreFilter` under the `zscore_filter` key.

    The backend schema expects the z-score condition one level deeper than
    users typically write it, so a bare `MetricZScoreFilter` passed in is
    wrapped automatically during validation.
    """

    # Identifies the metric-event kind this wrapper represents; it is excluded
    # from serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_ZSCORE],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_ZSCORE

    zscore_filter: MetricZScoreFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        # Wrap a bare z-score filter; any other input is returned unchanged
        # so that regular field validation can handle it.
        needs_wrapping = pydantic_isinstance(v, MetricZScoreFilter)
        return cls(zscore_filter=v) if needs_wrapping else v


class RunMetricFilter(GQLBase):  # from: TriggeringRunMetricEvent
    """Run filter plus metric condition for a run-metric triggering event.

    A bare metric condition (threshold, change, or z-score) may be passed
    directly; it is promoted to `RunMetricFilter(metric=...)` using the
    default run filter.
    """

    run: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_run_event_run_filter),
        Field(alias="run_filter"),
    ] = And()
    """Filters that must match any runs that will trigger this event."""

    metric: Annotated[
        Union[
            _WrappedMetricThresholdFilter,
            _WrappedMetricChangeFilter,
            _WrappedMetricZScoreFilter,
        ],
        Field(alias="run_metric_filter"),
    ]
    """Metric condition(s) that must be satisfied for this event to trigger."""

    # ------------------------------------------------------------------------------
    legacy_metric_filter: Annotated[
        Optional[JsonEncoded[MetricThresholdFilter]],
        Field(alias="metric_filter", deprecated=True),
    ] = None
    """Deprecated legacy field for defining run metric threshold events.

    For new automations, use the `metric` field (JSON alias `run_metric_filter`).
    """

    @model_validator(mode="before")
    @classmethod
    def _nest_metric_filter(cls, v: Any) -> Any:
        # Anything other than a bare metric condition is left for regular
        # validation. A bare condition is nested under `metric`; the wrapper
        # models' own validators then add the remaining level of nesting.
        bare_metric_types = (
            MetricThresholdFilter,
            MetricChangeFilter,
            MetricZScoreFilter,
        )
        if not pydantic_isinstance(v, bare_metric_types):
            return v
        return cls(metric=v)


class RunStateFilter(GQLBase):  # from: TriggeringRunStateEvent
    """Represents a filter for triggering events based on changes in run states.

    A bare `StateFilter` may be passed directly; it is promoted to
    `RunStateFilter(state=...)` using the default run filter.
    """

    run: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_run_event_run_filter),
        Field(alias="run_filter"),
    ] = And()
    """Filters that must match any runs that will trigger this event."""

    state: Annotated[StateFilter, Field(alias="run_state_filter")]
    """Run state condition(s) that must be satisfied for this event to trigger."""

    @model_validator(mode="before")
    @classmethod
    def _nest_state_filter(cls, v: Any) -> Any:
        # Input that is not a bare state condition goes through regular
        # validation. A bare condition is nested under `state` so that inner
        # validators can reshape it further as needed.
        if not pydantic_isinstance(v, StateFilter):
            return v
        return cls(state=v)


class SavedEvent(FilterEventFields):  # from: FilterEventTriggeringCondition
    """A triggering event from a saved automation.

    Parsed from GQL response data. The JSON-encoded `filter` is validated
    against one of the known shapes: a generic wrapped filter
    (`_WrappedSavedEventFilter`), a `RunMetricFilter`, or a `RunStateFilter`.
    """

    # Overrides the generated field's type; `frozen=True` disallows reassigning
    # the event type on an existing instance.
    event_type: Annotated[EventType, Field(frozen=True)]  # type: ignore[assignment]

    # We override the type of the `filter` field in order to enforce the expected
    # structure for the JSON data when validating and serializing.
    filter: JsonEncoded[
        Union[_WrappedSavedEventFilter, RunMetricFilter, RunStateFilter]
    ]
    """The condition(s) under which this event triggers an automation."""


# ------------------------------------------------------------------------------
# Input types: for creating or updating automations


# Note: The GQL input for `eventFilter` does NOT wrap the filter in an extra
# `filter` key, unlike the `eventFilter` in GQL responses for saved automations.
class _BaseEventInput(GQLBase):
    """Common base for all event inputs used when defining new automations.

    Subclasses narrow `event_type`, `scope`, and `filter` to what is valid for
    one specific kind of triggering event.
    """

    event_type: EventType

    scope: AutomationScope
    """The scope of the event."""

    filter: JsonEncoded[Any]

    def then(self, action: InputAction) -> NewAutomation:
        """Create a `NewAutomation` in which this event triggers `action`.

        Raises:
            TypeError: If `action` is neither an input action nor a saved action.
        """
        from .automations import NewAutomation

        accepted_types = (InputActionTypes, SavedActionTypes)
        if not isinstance(action, accepted_types):
            raise TypeError(f"Expected a valid action, got: {nameof(type(action))!r}")

        return NewAutomation(event=self, action=action)

    def __rshift__(self, other: InputAction) -> NewAutomation:
        """Support the `event >> action` shorthand; equivalent to `event.then(action)`."""
        return self.then(other)


# ------------------------------------------------------------------------------
# Events that trigger on specific mutations in the backend
class _BaseMutationEventInput(_BaseEventInput):
    """Base for events that trigger on specific mutations in the backend.

    The optional `filter` defaults to `And()` and, after validation, is passed
    through `wrap_mutation_event_filter`.
    """

    filter: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_mutation_event_filter),
    ] = And()
    """Additional condition(s), if any, that are required for this event to trigger."""


class OnLinkArtifact(_BaseMutationEventInput):
    """A new artifact is linked to a collection.

    An optional `filter` (e.g. on `ArtifactEvent.alias`) narrows which link
    events trigger the automation.

    Examples:
    Define an event that triggers when an artifact is linked to the
    collection "my-collection" with the alias "prod":

    ```python
    from wandb import Api
    from wandb.automations import OnLinkArtifact, ArtifactEvent

    api = Api()
    collection = api.artifact_collection(name="my-collection", type_name="model")

    event = OnLinkArtifact(
        scope=collection,
        filter=ArtifactEvent.alias.eq("prod"),
    )
    ```
    """

    # The enum member's value is "LINK_MODEL", the legacy name expected by the
    # backend (see `EventType.LINK_ARTIFACT`).
    event_type: Literal[EventType.LINK_ARTIFACT] = EventType.LINK_ARTIFACT


class OnAddArtifactAlias(_BaseMutationEventInput):
    """A new alias is assigned to an artifact.

    An optional `filter` (e.g. on `ArtifactEvent.alias`) narrows which alias
    assignments trigger the automation.

    Examples:
    Define an event that triggers whenever the alias "prod" is assigned to
    any artifact in the collection "my-collection":

    ```python
    from wandb import Api
    from wandb.automations import OnAddArtifactAlias, ArtifactEvent

    api = Api()
    collection = api.artifact_collection(name="my-collection", type_name="model")

    event = OnAddArtifactAlias(
        scope=collection,
        filter=ArtifactEvent.alias.eq("prod"),
    )
    ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.ADD_ARTIFACT_ALIAS] = EventType.ADD_ARTIFACT_ALIAS


class OnCreateArtifact(_BaseMutationEventInput):
    """A new artifact is created.

    Unlike the other artifact events, the scope of this event is restricted to
    an artifact collection.

    Examples:
    Define an event that triggers when a new artifact is created in the
    collection "my-collection":

    ```python
    from wandb import Api
    from wandb.automations import OnCreateArtifact

    api = Api()
    collection = api.artifact_collection(name="my-collection", type_name="model")

    event = OnCreateArtifact(scope=collection)
    ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.CREATE_ARTIFACT] = EventType.CREATE_ARTIFACT

    # Narrows the inherited `AutomationScope` field to collections only.
    scope: ArtifactCollectionScope
    """The scope of the event: must be an artifact collection."""


# ------------------------------------------------------------------------------
# Events that trigger on run conditions
class _BaseRunEventInput(_BaseEventInput):
    """Base for events that trigger on run conditions (project-scoped)."""

    # Narrows the inherited `AutomationScope` field to projects only.
    scope: ProjectScope
    """The scope of the event: must be a project."""


class OnRunMetric(_BaseRunEventInput):
    """A run metric satisfies a user-defined condition.

    The condition may be a threshold, change, or z-score metric filter. When the
    input is a dict with a `filter`, the matching `event_type` is inferred from
    that filter automatically.

    Examples:
    Define an event that triggers for any run in project "my-project" when
    the average of the last 5 values of metric "my-metric" exceeds 123.45:

    ```python
    from wandb import Api
    from wandb.automations import OnRunMetric, RunEvent

    api = Api()
    project = api.project(name="my-project")

    event = OnRunMetric(
        scope=project,
        filter=RunEvent.metric("my-metric").avg(5).gt(123.45),
    )
    ```
    """

    event_type: Literal[
        EventType.RUN_METRIC_THRESHOLD,
        EventType.RUN_METRIC_CHANGE,
        EventType.RUN_METRIC_ZSCORE,
    ]

    filter: JsonEncoded[RunMetricFilter]
    """Run and/or metric condition(s) that must be satisfied for this event to trigger."""

    @model_validator(mode="before")
    @classmethod
    def _infer_event_type(cls, data: Any) -> Any:
        """Infer the event type from the inner filter during validation.

        This supports "threshold", "change", and "z-score" metric filters. The
        specific kind can only be determined after parsing and validating the
        inner JSON data. Any `event_type` already present in `data` is replaced
        by the one derived from the filter.
        """
        if isinstance(data, dict) and (raw_filter := data.get("filter")):
            # At this point, `raw_filter` may or may not be JSON-serialized, so
            # pass it through `ensure_json` before parsing it as JSON.
            parsed_filter = RunMetricFilter.model_validate_json(ensure_json(raw_filter))
            return {**data, "event_type": parsed_filter.metric.event_type}

        # Non-dict input, or a missing/empty filter: defer to normal validation.
        return data


class OnRunState(_BaseRunEventInput):
    """A run state changes.

    Examples:
    Define an event that triggers for any run in project "my-project" when
    its state changes to "finished" (i.e. succeeded) or "failed":

    ```python
    from wandb import Api
    from wandb.automations import OnRunState, RunEvent

    api = Api()
    project = api.project(name="my-project")

    event = OnRunState(
        scope=project,
        filter=RunEvent.state.in_(["finished", "failed"]),
    )
    ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.RUN_STATE] = EventType.RUN_STATE

    filter: JsonEncoded[RunStateFilter]
    """Run state condition(s) that must be satisfied for this event to trigger."""


# for type annotations
InputEvent = Annotated[
    Union[
        OnLinkArtifact,
        OnAddArtifactAlias,
        OnCreateArtifact,
        OnRunMetric,
        OnRunState,
    ],
    Field(discriminator="event_type"),
]
# Runtime counterpart of `InputEvent` for `isinstance()` checks: `get_args`
# first unpacks `Annotated[...]` (its first arg is the `Union`), then the
# `Union[...]` itself into the tuple of member classes.
InputEventTypes: tuple[type, ...] = get_args(get_args(InputEvent)[0])


# ----------------------------------------------------------------------------


class RunEvent:
    """Namespace of run fields that can be used to build run-event filters.

    For example, `RunEvent.metric("my-metric").avg(5).gt(123.45)` or
    `RunEvent.state.in_(["finished", "failed"])`.
    """

    # Users filter on the run's `name`, but the backend filters on
    # `display_name`. Users can't be expected to know that a priori, so the
    # mapping is applied automatically here.
    name = FilterableField(server_name="display_name")

    state = StateOperand()

    @staticmethod
    def metric(name: str) -> MetricVal:
        """Start a filter condition on the run metric called `name`."""
        metric_val = MetricVal(name=name)
        return metric_val


class ArtifactEvent:
    """Namespace of artifact fields that can be used to build artifact-event filters.

    For example, `ArtifactEvent.alias.eq("prod")`.
    """

    # Filterable field for the artifact alias.
    alias = FilterableField()


# Explicitly finalize the model schemas at import time, so that any unresolved
# annotation surfaces here rather than on first use. Every `InputEvent` member
# and both run-event filter models are included.
MetricThresholdFilter.model_rebuild()
RunMetricFilter.model_rebuild()
RunStateFilter.model_rebuild()
_WrappedSavedEventFilter.model_rebuild()

OnLinkArtifact.model_rebuild()
OnAddArtifactAlias.model_rebuild()
OnCreateArtifact.model_rebuild()
OnRunMetric.model_rebuild()
OnRunState.model_rebuild()

__all__ = [
    "EventType",
    *(nameof(cls) for cls in InputEventTypes),
    "RunEvent",
    "ArtifactEvent",
    "MetricThresholdFilter",
    "MetricChangeFilter",
    "MetricZScoreFilter",
]
