from __future__ import annotations

import copy
import inspect
import pickle
from functools import wraps
from pathlib import Path

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry

from . import errors

try:
    from metaflow import current
except ImportError as e:
    raise Exception(
        "Error: `metaflow` not installed >> This integration requires metaflow!"
        " To fix, please `pip install -Uqq metaflow`"
    ) from e


try:
    from . import data_pandas
except errors.MissingDependencyError as e:
    e.warn()
    data_pandas = None

try:
    from . import data_pytorch
except errors.MissingDependencyError as e:
    e.warn()
    data_pytorch = None

try:
    from . import data_sklearn
except errors.MissingDependencyError as e:
    e.warn()
    data_sklearn = None


class ArtifactProxy:
    """Stand-in for a flow instance that records the artifacts a step touches.

    Attribute reads are forwarded to the wrapped flow. Names that were not in
    ``dir(flow)`` when the proxy was created, and that have not been written
    through the proxy, are recorded in ``inputs``. Attribute writes are
    forwarded to the flow and recorded in ``outputs``.
    """

    def __init__(self, flow):
        # Populate __dict__ directly: plain `self.x = ...` assignments would go
        # through __setattr__ below, which itself needs `self.outputs`.
        self.__dict__.update(
            {
                "flow": flow,
                "inputs": {},
                "outputs": {},
                # attribute names visible on the flow before the step body runs
                "base": set(dir(flow)),
                # current values of the flow's Metaflow parameters
                "params": {p: getattr(flow, p) for p in current.parameter_names},
            }
        )

    def __setattr__(self, key, val):
        """Forward the write to the flow and remember it as a step output."""
        self.outputs[key] = val
        return setattr(self.flow, key, val)

    def __getattr__(self, key):
        """Forward the read to the flow, recording new artifacts as inputs.

        Only invoked for names not found in the proxy's own ``__dict__``.
        """
        # Look the value up once: attribute access on the flow may be costly
        # or have side effects, so it must not be repeated per read.
        value = getattr(self.flow, key)
        if key not in self.base and key not in self.outputs:
            self.inputs[key] = value
        return value


def _track_scalar(
    name: str,
    data: dict | list | set | str | int | float | bool,
    run,
    testing: bool = False,
) -> str | None:
    if testing:
        return "scalar"

    run.log({name: data})
    return None


def _track_path(
    name: str,
    data: Path,
    run,
    testing: bool = False,
) -> str | None:
    if testing:
        return "Path"

    artifact = wandb.Artifact(name, type="dataset")
    if data.is_dir():
        artifact.add_dir(data)
    elif data.is_file():
        artifact.add_file(data)
    run.log_artifact(artifact)
    wandb.termlog(f"Logging artifact: {name} ({type(data)})")
    return None


def _track_generic(
    name: str,
    data,
    run,
    testing: bool = False,
) -> str | None:
    if testing:
        return "generic"

    artifact = wandb.Artifact(name, type="other")
    with artifact.new_file(f"{name}.pkl", "wb") as f:
        pickle.dump(data, f)
    run.log_artifact(artifact)
    wandb.termlog(f"Logging artifact: {name} ({type(data)})")
    return None


def wandb_track(
    name: str,
    data,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    run: wandb.Run | None = None,
    testing: bool = False,
) -> str | None:
    """Track a step output in W&B, picking a handler from the value's type.

    Handlers are tried in this order; the first whose type matches and whose
    flag is enabled is used:

    1. pandas DataFrame (``datasets``) -> ``data_pandas.track_dataframe``
    2. PyTorch module (``models``) -> ``data_pytorch.track_nn_module``
    3. scikit-learn estimator (``models``) -> ``data_sklearn.track_estimator``
    4. ``pathlib.Path`` (``datasets``) -> dataset artifact
    5. dict/list/set/str/int/float/bool (no flag needed) -> ``run.log``
    6. fallback for anything not handled above (``others``) -> pickle artifact

    Integration modules that failed to import are skipped.

    Returns:
        Whatever the selected handler returns (the local helpers return a tag
        string when ``testing`` is set), or None when no handler applies.
    """
    if data_pandas and data_pandas.is_dataframe(data):
        if datasets:
            return data_pandas.track_dataframe(name, data, run, testing)

    if data_pytorch and data_pytorch.is_nn_module(data):
        if models:
            return data_pytorch.track_nn_module(name, data, run, testing)

    if data_sklearn and data_sklearn.is_estimator(data):
        if models:
            return data_sklearn.track_estimator(name, data, run, testing)

    if isinstance(data, Path):
        if datasets:
            return _track_path(name, data, run, testing)

    scalar_types = (dict, list, set, str, int, float, bool)
    if isinstance(data, scalar_types):
        return _track_scalar(name, data, run, testing)

    return _track_generic(name, data, run, testing) if others else None


def wandb_use(
    name: str,
    data,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    run=None,
    testing: bool = False,
) -> str | None:
    """Mark a step input as consumed from W&B, picking a handler by type.

    Plain values (dict/list/set/str/int/float/bool) are ignored. Otherwise the
    first handler whose type matches and whose flag is enabled is used, in the
    order: DataFrame (``datasets``), PyTorch module (``models``), scikit-learn
    estimator (``models``), ``pathlib.Path`` (``datasets``), then any value
    not claimed above (``others``).

    A ``wandb.CommError`` raised while resolving the artifact is reported as a
    warning and None is returned.
    """
    scalar_types = (dict, list, set, str, int, float, bool)
    if isinstance(data, scalar_types):
        return None

    try:
        if data_pandas and data_pandas.is_dataframe(data) and datasets:
            return data_pandas.use_dataframe(name, run, testing)
        if data_pytorch and data_pytorch.is_nn_module(data) and models:
            return data_pytorch.use_nn_module(name, run, testing)
        if data_sklearn and data_sklearn.is_estimator(data) and models:
            return data_sklearn.use_estimator(name, run, testing)
        if isinstance(data, Path) and datasets:
            return _use_path(name, data, run, testing)
        if others:
            return _use_generic(name, data, run, testing)
        return None
    except wandb.CommError:
        wandb.termwarn(
            f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
            " If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier),"
            " then you can safely ignore this. Otherwise you may want to check your internet connection!"
        )
        return None


def _use_path(
    name: str,
    data: Path,
    run,
    testing: bool = False,
) -> str | None:
    if testing:
        return "datasets"

    run.use_artifact(f"{name}:latest")
    wandb.termlog(f"Using artifact: {name} ({type(data)})")
    return None


def _use_generic(
    name: str,
    data,
    run,
    testing: bool = False,
) -> str | None:
    if testing:
        return "others"

    run.use_artifact(f"{name}:latest")
    wandb.termlog(f"Using artifact: {name} ({type(data)})")
    return None


def coalesce(*arg):
    """Return the first argument that is not None, or None if all are None."""
    for candidate in arg:
        if candidate is not None:
            return candidate
    return None


def wandb_log(
    func=None,
    /,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    settings: wandb.Settings | None = None,
):
    """Automatically log parameters and artifacts to W&B.

    This decorator can be applied to a flow, step, or both:

    - Decorating a step enables or disables logging within that step
    - Decorating a flow is equivalent to decorating all steps
    - Decorating a step after decorating its flow overwrites the flow decoration

    Note: decorating a flow class wraps every callable attribute defined
    directly on that class, except members already wrapped by a method-level
    ``@wandb_log`` (those keep their own options).

    Args:
        func: The step method or flow class to decorate. When omitted, a
            decorator configured with the remaining arguments is returned.
        datasets: Whether to log `pd.DataFrame` and `pathlib.Path`
            types. Defaults to False.
        models: Whether to log `nn.Module` and `sklearn.base.BaseEstimator`
            types. Defaults to False.
        others: If `True`, log anything pickle-able. Defaults to False.
        settings: Custom settings to pass to `wandb.init`. A copy is made for
            each run, so the object passed in is never modified.
            If `run_group` is `None`, it is set to `{flow_name}/{run_id}`.
            If `run_job_type` is `None`, it is set to `{step_name}`.

    Returns:
        The decorated class or step method, or a decorator when `func` is
        omitted.
    """

    def decorator(func):
        # Decorating a class decorates every callable defined directly on it.
        if inspect.isclass(func):
            cls = func
            # Snapshot the names: setattr below writes into cls.__dict__.
            for attr in list(cls.__dict__):
                member = getattr(cls, attr)
                # Members wrapped by a method-level @wandb_log keep their options.
                if callable(member) and not hasattr(member, "_base_func"):
                    setattr(cls, attr, decorator(member))
            return cls

        # prefer the earliest decoration (i.e. method decoration overrides class decoration)
        if hasattr(func, "_base_func"):
            return func

        @wraps(func)
        def wrapper(self, *args, settings=settings, **kwargs):
            if isinstance(settings, wandb.sdk.wandb_settings.Settings):
                # The default `settings` object is shared by every step this
                # decorator wraps. Filling in run_job_type on it directly would
                # pin later steps (in the same process) to the first step's
                # name, so mutate a private copy instead.
                settings = copy.deepcopy(settings)
            else:
                settings = wandb.Settings()

            settings.update_from_dict(
                {
                    "run_group": coalesce(
                        settings.run_group, f"{current.flow_name}/{current.run_id}"
                    ),
                    "run_job_type": coalesce(settings.run_job_type, current.step_name),
                }
            )

            with wandb.init(settings=settings) as run:
                with wb_telemetry.context(run=run) as tel:
                    tel.feature.metaflow = True
                proxy = ArtifactProxy(self)
                run.config.update(proxy.params)
                result = func(proxy, *args, **kwargs)

                # Artifacts read during the step are recorded as run inputs...
                for name, data in proxy.inputs.items():
                    wandb_use(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )

                # ...and artifacts assigned during the step as run outputs.
                for name, data in proxy.outputs.items():
                    wandb_track(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )

            return result

        wrapper._base_func = func

        # Add for testing visibility
        wrapper._kwargs = {
            "datasets": datasets,
            "models": models,
            "others": others,
            "settings": settings,
        }
        return wrapper

    if func is None:
        return decorator
    else:
        return decorator(func)
