from warnings import simplefilter

import numpy as np
from sklearn import model_selection

import wandb
from wandb.integration.sklearn import utils

# Silence every FutureWarning. This runs at import time and installs the filter
# in the global warnings configuration, so it is not limited to this module.
simplefilter(action="ignore", category=FutureWarning)


def learning_curve(
    model,
    X,  # noqa: N803
    y,
    cv=None,
    shuffle=False,
    random_state=None,
    train_sizes=None,
    n_jobs=1,
    scoring=None,
):
    """Train model on datasets of varying size and build a score-vs-size chart.

    Called by plot_learning_curve to visualize learning curve. Please use the function
    plot_learning_curve() if you wish to visualize your learning curves.

    Args:
        model: Estimator handed to sklearn.model_selection.learning_curve.
        X: Training features, forwarded unchanged.
        y: Training targets, forwarded unchanged.
        cv: Cross-validation setting, forwarded unchanged.
        shuffle: Forwarded unchanged as the `shuffle` argument.
        random_state: Forwarded unchanged as the `random_state` argument.
        train_sizes: Training-set sizes to evaluate. When None, the grid
            np.linspace(0.1, 1.0, 5) is used, which is the default documented
            for sklearn.model_selection.learning_curve.
        n_jobs: Forwarded unchanged as the `n_jobs` argument.
        scoring: Forwarded unchanged as the `scoring` argument.

    Returns:
        The object produced by wandb.visualize for the
        "wandb/learning_curve/v1" chart, backed by the table from make_table.
    """
    # scikit-learn expects an array-like here and does not accept None, so
    # substitute its own documented default instead of forwarding None.
    if train_sizes is None:
        train_sizes = np.linspace(0.1, 1.0, 5)

    train_sizes, train_scores, test_scores = model_selection.learning_curve(
        model,
        X,
        y,
        cv=cv,
        n_jobs=n_jobs,
        train_sizes=train_sizes,
        scoring=scoring,
        shuffle=shuffle,
        random_state=random_state,
    )

    # Collapse axis 1 so there is one mean score per training-set size
    # (scikit-learn documents the score arrays as sizes x CV folds).
    train_scores_mean = np.mean(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)

    table = make_table(train_scores_mean, test_scores_mean, train_sizes)
    chart = wandb.visualize("wandb/learning_curve/v1", table)

    return chart


def make_table(train, test, train_sizes):
    """Pack mean train/test scores into a wandb.Table for the learning-curve chart.

    For each index of `train`, two rows are emitted: a "train" row and a "test"
    row, each holding the score passed through utils.round_2 and the matching
    entry of `train_sizes`. Iteration stops as soon as
    utils.check_against_limit reports that the index is over the limit.

    The resulting table has the columns "dataset", "score" and "train_size".
    """
    rows = []
    for idx, train_score in enumerate(train):
        # The limit handed over is half of utils.chart_limit -- presumably
        # because every index contributes two rows to the table.
        over_limit = utils.check_against_limit(
            idx,
            "learning_curve",
            utils.chart_limit / 2,
        )
        if over_limit:
            break

        train_row = ["train", utils.round_2(train_score), train_sizes[idx]]
        test_row = ["test", utils.round_2(test[idx]), train_sizes[idx]]
        rows.extend((train_row, test_row))

    return wandb.Table(columns=["dataset", "score", "train_size"], data=rows)
