"""keras init."""

import logging
import operator
import os
import shutil
import sys
from itertools import chain

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K  # noqa: N812

import wandb
from wandb.proto.wandb_telemetry_pb2 import Deprecated
from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
from wandb.sdk.lib import telemetry
from wandb.sdk.lib.deprecation import warn_and_record_deprecation
from wandb.util import add_import_hook


def _check_keras_version():
    from keras import __version__ as keras_version
    from packaging.version import parse

    if parse(keras_version) < parse("2.4.0"):
        wandb.termwarn(
            f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
        )


def _can_compute_flops() -> bool:
    """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
    from packaging.version import parse

    if parse(tf.__version__) >= parse("2.0.0"):
        return True

    return False


if "keras" in sys.modules:
    _check_keras_version()
else:
    add_import_hook("keras", _check_keras_version)


logger = logging.getLogger(__name__)


def is_dataset(data):
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if dataset_ops and hasattr(dataset_ops, "DatasetV2"):
        dataset_types = (dataset_ops.DatasetV2,)
        if hasattr(dataset_ops, "DatasetV1"):
            dataset_types = dataset_types + (dataset_ops.DatasetV1,)
        return isinstance(data, dataset_types)
    else:
        return False


def is_generator_like(data):
    # Checks if data is a generator, Sequence, or Iterator.

    types = (tf.keras.utils.Sequence,)
    iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
    if iterator_ops:
        types = types + (iterator_ops.Iterator,)
        # EagerIterator was in tensorflow < 2
        if hasattr(iterator_ops, "EagerIterator"):
            types = types + (iterator_ops.EagerIterator,)
        elif hasattr(iterator_ops, "IteratorV2"):
            types = types + (iterator_ops.IteratorV2,)
    return hasattr(data, "next") or hasattr(data, "__next__") or isinstance(data, types)


def patch_tf_keras():  # noqa: C901
    from packaging.version import parse
    from tensorflow.python.eager import context

    if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            wandb.termerror("Unable to patch Tensorflow/Keras")
            logger.exception("exception while trying to patch_tf_keras")
            return
    else:
        keras_engine = "tensorflow.python.keras.engine"

        from tensorflow.python.keras.engine import training

        try:
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                wandb.termerror("Unable to patch Tensorflow/Keras")
                logger.exception("exception while trying to patch_tf_keras")
                return

    # Tensorflow 2.1
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
    # Tensorflow 2.2
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph mode dataset generator
                def gen():
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        cbks = kwargs.get("callbacks", [])
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")
        # TODO: these could be generators, why index 0?
        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])


def _array_has_dtype(array):
    return hasattr(array, "dtype")


def _update_if_numeric(metrics, key, values):
    if not _array_has_dtype(values):
        _warn_not_logging(key)
        return

    if not is_numeric_array(values):
        _warn_not_logging_non_numeric(key)
        return

    metrics[key] = wandb.Histogram(values)


def is_numeric_array(array):
    return np.issubdtype(array.dtype, np.number)


def _warn_not_logging_non_numeric(name):
    wandb.termwarn(
        f"Non-numeric values found in layer: {name}, not logging this layer",
        repeat=False,
    )


def _warn_not_logging(name):
    wandb.termwarn(
        f"Layer {name} has undetermined datatype not logging this layer",
        repeat=False,
    )


tf_logger = tf.get_logger()

patch_tf_keras()


### For gradient logging ###


def _get_custom_optimizer_parent_class():
    from packaging.version import parse

    if parse(tf.__version__) >= parse("2.9.0"):
        custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
    else:
        custom_optimizer_parent_class = tf.keras.optimizers.Optimizer

    return custom_optimizer_parent_class


_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()


class _CustomOptimizer(_custom_optimizer_parent_class):
    def __init__(self):
        super().__init__(name="CustomOptimizer")
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)

    def _resource_apply_dense(self, grad, var):
        var.assign(grad)

    # this needs to be implemented to prevent a NotImplementedError when
    # using Lookup layers.
    def _resource_apply_sparse(self, grad, var, indices):
        pass

    def get_config(self):
        return super().get_config()


class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above."""

    def set_model(self, model):
        super().set_model(model)
        self.og_weights = model.get_weights()
        self.grads = [np.zeros(tuple(w.shape)) for w in model.trainable_weights]

    def on_batch_end(self, batch, logs=None):
        for g, w in zip(self.grads, self.model.trainable_weights):
            g += w.numpy()
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        return [g.copy() for g in self.grads]


###


class WandbCallback(tf.keras.callbacks.Callback):
    """`WandbCallback` automatically integrates keras with wandb.

    Example:
        ```python
        model.fit(
            X_train,
            y_train,
            validation_data=(X_test, y_test),
            callbacks=[WandbCallback()],
        )
        ```

    `WandbCallback` will automatically log history data from any
    metrics collected by keras: loss and anything passed into `keras_model.compile()`.

    `WandbCallback` will set summary metrics for the run associated with the "best" training
    step, where "best" is defined by the `monitor` and `mode` attributes.  This defaults
    to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
    associated with the best `epoch`.

    `WandbCallback` can optionally log gradient and parameter histograms.

    `WandbCallback` can optionally save training and validation data for wandb to visualize.

    Args:
        monitor: (str) name of metric to monitor.  Defaults to `val_loss`.
        mode: (str) one of {`auto`, `min`, `max`}.
            `min` - save model when monitor is minimized
            `max` - save model when monitor is maximized
            `auto` - try to guess when to save the model (default).
        save_model:
            True - save a model when monitor beats all previous epochs
            False - don't save models
        save_graph: (boolean) if True save model graph to wandb (default to True).
        save_weights_only: (boolean) if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        log_weights: (boolean) if True save histograms of the model's layer's weights.
        log_gradients: (boolean) if True log histograms of the training gradients
        training_data: (tuple) Same format `(X,y)` as passed to `model.fit`.  This is needed
            for calculating gradients - this is mandatory if `log_gradients` is `True`.
        validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`.  A set of data
            for wandb to visualize.  If this is set, every epoch, wandb will
            make a small number of predictions and save the results for later visualization. In case
            you are working with image data, please also set `input_type` and `output_type` in order
            to log correctly.
        generator: (generator) a generator that returns validation data for wandb to visualize.  This
            generator should return tuples `(X,y)`.  Either `validate_data` or generator should
            be set for wandb to visualize specific data examples. In case you are working with image data,
            please also set `input_type` and `output_type` in order to log correctly.
        validation_steps: (int) if `validation_data` is a generator, how many
            steps to run the generator for the full validation set.
        labels: (list) If you are visualizing your data with wandb this list of labels
            will convert numeric output to understandable string if you are building a
            multiclass classifier.  If you are making a binary classifier you can pass in
            a list of two labels ["label for false", "label for true"].  If `validate_data`
            and generator are both false, this won't do anything.
        predictions: (int) the number of predictions to make for visualization each epoch, max
            is 100.
        input_type: (string) type of the model input to help visualization. can be one of:
            (`image`, `images`, `segmentation_mask`, `auto`).
        output_type: (string) type of the model output to help visualization. can be one of:
            (`image`, `images`, `segmentation_mask`, `label`).
        log_evaluation: (boolean) if True, save a Table containing validation data and the
            model's predictions at each epoch. See `validation_indexes`,
            `validation_row_processor`, and `output_row_processor` for additional details.
        class_colors: ([float, float, float]) if the input or output is a segmentation mask,
            an array containing an rgb tuple (range 0-1) for each class.
        log_batch_frequency: (integer) if None, callback will log every epoch.
            If set to integer, callback will log training metrics every `log_batch_frequency`
            batches.
        log_best_prefix: (string) if None, no extra summary metrics will be saved.
            If set to a string, the monitored metric and epoch will be prepended with this value
            and stored as summary metrics.
        validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
            with each validation example.  If log_evaluation is True and `validation_indexes` is provided,
            then a Table of validation data will not be created and instead each prediction will
            be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
            such keys are is use `Table.get_index()` which will return a list of row keys.
        validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
            The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
            then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
            input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
            it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
            but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
            as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
        output_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
            the results of the model output.
        infer_missing_processors: (bool) Determines if `validation_row_processor` and `output_row_processor`
            should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
            processors where appropriate.
        log_evaluation_frequency: (int) Determines the frequency which evaluation results will be logged. Default 0 (only at the end of training).
            Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
        compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
    """

    def __init__(
        self,
        monitor="val_loss",
        verbose=0,
        mode="auto",
        save_weights_only=False,
        log_weights=False,
        log_gradients=False,
        save_model=True,
        training_data=None,
        validation_data=None,
        labels=None,
        predictions=36,
        generator=None,
        input_type=None,
        output_type=None,
        log_evaluation=False,
        validation_steps=None,
        class_colors=None,
        log_batch_frequency=None,
        log_best_prefix="best_",
        save_graph=True,
        validation_indexes=None,
        validation_row_processor=None,
        prediction_row_processor=None,
        infer_missing_processors=True,
        log_evaluation_frequency=0,
        compute_flops=False,
        **kwargs,
    ):
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")

        warn_and_record_deprecation(
            feature=Deprecated(keras_callback=True),
            message=(
                "WandbCallback is deprecated and will be removed in a future release. "
                "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
                "callbacks instead. "
                "See https://docs.wandb.ai/guides/integrations/keras for more information."
            ),
        )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.keras = True
        self.validation_data = None
        # This is kept around for legacy reasons
        if validation_data is not None:
            if is_generator_like(validation_data):
                generator = validation_data
            else:
                self.validation_data = validation_data
        if labels is None:
            labels = []
        self.labels = labels
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only
        self.save_graph = save_graph

        wandb.save("model-best.h5")
        self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
        self.save_model = save_model
        if save_model:
            warn_and_record_deprecation(
                feature=Deprecated(keras_callback__save_model=True),
                message=(
                    "The save_model argument by default saves the model in the HDF5 format that cannot save "
                    "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                    "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                    "as W&B files and the SavedModel as W&B Artifacts."
                ),
            )

        self.save_model_as_artifact = True
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator
        self._graph_rendered = False

        data_type = kwargs.get("data_type", None)
        if data_type is not None:
            warn_and_record_deprecation(
                feature=Deprecated(keras_callback__data_type=True),
                message=(
                    "The data_type argument of wandb.keras.WandbCallback is deprecated "
                    "and will be removed in a future release. Please use input_type instead.\n"
                    "Setting input_type = data_type."
                ),
            )
            input_type = data_type
        self.input_type = input_type
        self.output_type = output_type
        self.log_evaluation = log_evaluation
        self.validation_steps = validation_steps
        self.class_colors = np.array(class_colors) if class_colors is not None else None
        self.log_batch_frequency = log_batch_frequency
        self.log_best_prefix = log_best_prefix
        self.compute_flops = compute_flops

        self._prediction_batch_size = None

        if self.log_gradients:
            if int(tf.__version__.split(".")[0]) < 2:
                raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
            if self.training_data is None:
                raise ValueError(
                    "training_data argument is required for gradient logging."
                )
            if isinstance(self.training_data, (list, tuple)):
                if len(self.training_data) != 2:
                    raise ValueError("training data must be a tuple of length two")
                self._training_data_x, self._training_data_y = self.training_data
            else:
                self._training_data_x = (
                    self.training_data
                )  # generator, tf.data.Dataset etc
                self._training_data_y = None

        # From Keras
        if mode not in ["auto", "min", "max"]:
            wandb.termwarn(
                f"WandbCallback mode {mode} is unknown, fallback to auto mode."
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = operator.lt
            self.best = float("inf")
        elif mode == "max":
            self.monitor_op = operator.gt
            self.best = float("-inf")
        else:
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = operator.gt
                self.best = float("-inf")
            else:
                self.monitor_op = operator.lt
                self.best = float("inf")
        # Get the previous best metric for resumed runs
        previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
        if previous_best is not None:
            self.best = previous_best

        self._validation_data_logger = None
        self._validation_indexes = validation_indexes
        self._validation_row_processor = validation_row_processor
        self._prediction_row_processor = prediction_row_processor
        self._infer_missing_processors = infer_missing_processors
        self._log_evaluation_frequency = log_evaluation_frequency
        self._model_trained_since_last_eval = False

    def _build_grad_accumulator_model(self):
        inputs = self.model.inputs
        outputs = self.model(inputs)
        grad_acc_model = tf.keras.models.Model(inputs, outputs)
        grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())

        # make sure magic doesn't think this is a user model
        grad_acc_model._wandb_internal_model = True

        self._grad_accumulator_model = grad_acc_model
        self._grad_accumulator_callback = _GradAccumulatorCallback()

    def _implements_train_batch_hooks(self):
        return self.log_batch_frequency is not None

    def _implements_test_batch_hooks(self):
        return self.log_batch_frequency is not None

    def _implements_predict_batch_hooks(self):
        return self.log_batch_frequency is not None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        super().set_model(model)
        if self.input_type == "auto" and len(model.inputs) == 1:
            self.input_type = wandb.util.guess_data_type(
                model.inputs[0].shape, risky=True
            )
        if self.input_type and self.output_type is None and len(model.outputs) == 1:
            self.output_type = wandb.util.guess_data_type(model.outputs[0].shape)
        if self.log_gradients:
            self._build_grad_accumulator_model()

    def _attempt_evaluation_log(self, commit=True):
        if self.log_evaluation and self._validation_data_logger:
            try:
                if not self.model:
                    wandb.termwarn("WandbCallback unable to read model from trainer")
                else:
                    self._validation_data_logger.log_predictions(
                        predictions=self._validation_data_logger.make_predictions(
                            self.model.predict
                        ),
                        commit=commit,
                    )
                    self._model_trained_since_last_eval = False
            except Exception as e:
                wandb.termwarn("Error during prediction logging for epoch: " + str(e))

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        if self.log_weights:
            wandb.log(self._log_weights(), commit=False)

        if self.log_gradients:
            wandb.log(self._log_gradients(), commit=False)

        if self.input_type in (
            "image",
            "images",
            "segmentation_mask",
        ) or self.output_type in ("image", "images", "segmentation_mask"):
            if self.generator:
                self.validation_data = next(self.generator)
            if self.validation_data is None:
                wandb.termwarn(
                    "No validation_data set, pass a generator to the callback."
                )
            elif self.validation_data and len(self.validation_data) > 0:
                wandb.log(
                    {"examples": self._log_images(num_images=self.predictions)},
                    commit=False,
                )

        if (
            self._log_evaluation_frequency > 0
            and epoch % self._log_evaluation_frequency == 0
        ):
            self._attempt_evaluation_log(commit=False)

        wandb.log({"epoch": epoch}, commit=False)
        wandb.log(logs, commit=True)

        self.current = logs.get(self.monitor)
        if self.current and self.monitor_op(self.current, self.best):
            if self.log_best_prefix:
                wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                    self.current
                )
                wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
                if self.verbose and not self.save_model:
                    wandb.termlog(
                        f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
                    )
            if self.save_model:
                self._save_model(epoch)

            if self.save_model and self.save_model_as_artifact:
                self._save_model_as_artifact(epoch)

            self.best = self.current

    # This is what keras used pre tensorflow.keras
    def on_batch_begin(self, batch, logs=None):
        pass

    # This is what keras used pre tensorflow.keras
    def on_batch_end(self, batch, logs=None):
        if self.save_graph and not self._graph_rendered:
            # Couldn't do this in train_begin because keras may still not be built
            wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
            self._graph_rendered = True

        if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
            wandb.log(logs, commit=True)

    def on_train_batch_begin(self, batch, logs=None):
        self._model_trained_since_last_eval = True

    def on_train_batch_end(self, batch, logs=None):
        if self.save_graph and not self._graph_rendered:
            # Couldn't do this in train_begin because keras may still not be built
            wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
            self._graph_rendered = True

        if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
            wandb.log(logs, commit=True)

    def on_test_begin(self, logs=None):
        pass

    def on_test_end(self, logs=None):
        pass

    def on_test_batch_begin(self, batch, logs=None):
        pass

    def on_test_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        if self.log_evaluation:
            try:
                validation_data = None
                if self.validation_data:
                    validation_data = self.validation_data
                elif self.generator:
                    if not self.validation_steps:
                        wandb.termwarn(
                            "WandbCallback is unable to log validation data. "
                            "When using a generator for validation_data, you must pass validation_steps"
                        )
                    else:
                        x = None
                        y_true = None
                        for _ in range(self.validation_steps):
                            bx, by_true = next(self.generator)
                            if x is None:
                                x, y_true = bx, by_true
                            else:
                                x, y_true = (
                                    np.append(x, bx, axis=0),
                                    np.append(y_true, by_true, axis=0),
                                )
                        validation_data = (x, y_true)
                else:
                    wandb.termwarn(
                        "WandbCallback is unable to read validation_data from trainer "
                        "and therefore cannot log validation data. Ensure Keras is properly "
                        "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
                    )
                if validation_data:
                    self._validation_data_logger = ValidationDataLogger(
                        inputs=validation_data[0],
                        targets=validation_data[1],
                        indexes=self._validation_indexes,
                        validation_row_processor=self._validation_row_processor,
                        prediction_row_processor=self._prediction_row_processor,
                        class_labels=self.labels,
                        infer_missing_processors=self._infer_missing_processors,
                    )
            except Exception as e:
                wandb.termwarn(
                    "Error initializing ValidationDataLogger in WandbCallback. "
                    f"Skipping logging validation data. Error: {str(e)}"
                )

        if self.compute_flops and _can_compute_flops():
            try:
                wandb.summary["GFLOPs"] = self.get_flops()
            except Exception:
                logger.exception("Error computing FLOPs")
                wandb.termwarn("Unable to compute FLOPs for this model.")

    def on_train_end(self, logs=None):
        if self._model_trained_since_last_eval:
            self._attempt_evaluation_log()

    def on_predict_begin(self, logs=None):
        pass

    def on_predict_end(self, logs=None):
        pass

    def on_predict_batch_begin(self, batch, logs=None):
        pass

    def on_predict_batch_end(self, batch, logs=None):
        pass

    def _logits_to_captions(self, logits):
        if logits[0].shape[-1] == 1:
            # Scalar output from the model
            # TODO: handle validation_y
            if len(self.labels) == 2:
                # User has named true and false
                captions = [
                    self.labels[1] if logits[0] > 0.5 else self.labels[0]
                    for logit in logits
                ]
            else:
                if len(self.labels) != 0:
                    wandb.termwarn(
                        "keras model is producing a single output, "
                        'so labels should be a length two array: ["False label", "True label"].'
                    )
                captions = [logit[0] for logit in logits]
        else:
            # Vector output from the model
            # TODO: handle validation_y
            labels = np.argmax(np.stack(logits), axis=1)

            if len(self.labels) > 0:
                # User has named the categories in self.labels
                captions = []
                for label in labels:
                    try:
                        captions.append(self.labels[label])
                    except IndexError:
                        captions.append(label)
            else:
                captions = labels
        return captions

    def _masks_to_pixels(self, masks):
        # if its a binary mask, just return it as grayscale instead of picking the argmax
        if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
            return masks
        class_colors = (
            self.class_colors
            if self.class_colors is not None
            else np.array(wandb.util.class_colors(masks[0].shape[2]))
        )
        imgs = class_colors[np.argmax(masks, axis=-1)]
        return imgs

    def _log_images(self, num_images=36):
        validation_X = self.validation_data[0]  # noqa: N806
        validation_y = self.validation_data[1]

        validation_length = len(validation_X)

        if validation_length > num_images:
            # pick some data at random
            indices = np.random.choice(validation_length, num_images, replace=False)
        else:
            indices = range(validation_length)

        test_data = []
        test_output = []
        for i in indices:
            test_example = validation_X[i]
            test_data.append(test_example)
            test_output.append(validation_y[i])

        if self.model.stateful:
            predictions = self.model.predict(np.stack(test_data), batch_size=1)
            self.model.reset_states()
        else:
            predictions = self.model.predict(
                np.stack(test_data), batch_size=self._prediction_batch_size
            )
            if len(predictions) != len(test_data):
                self._prediction_batch_size = 1
                predictions = self.model.predict(
                    np.stack(test_data), batch_size=self._prediction_batch_size
                )

        if self.input_type == "label":
            if self.output_type in ("image", "images", "segmentation_mask"):
                captions = self._logits_to_captions(test_data)
                output_image_data = (
                    self._masks_to_pixels(predictions)
                    if self.output_type == "segmentation_mask"
                    else predictions
                )
                reference_image_data = (
                    self._masks_to_pixels(test_output)
                    if self.output_type == "segmentation_mask"
                    else test_output
                )
                output_images = [
                    wandb.Image(data, caption=captions[i], grouping=2)
                    for i, data in enumerate(output_image_data)
                ]
                reference_images = [
                    wandb.Image(data, caption=captions[i])
                    for i, data in enumerate(reference_image_data)
                ]
                return list(chain.from_iterable(zip(output_images, reference_images)))
        elif self.input_type in ("image", "images", "segmentation_mask"):
            input_image_data = (
                self._masks_to_pixels(test_data)
                if self.input_type == "segmentation_mask"
                else test_data
            )
            if self.output_type == "label":
                # we just use the predicted label as the caption for now
                captions = self._logits_to_captions(predictions)
                return [
                    wandb.Image(data, caption=captions[i])
                    for i, data in enumerate(test_data)
                ]
            elif self.output_type in ("image", "images", "segmentation_mask"):
                output_image_data = (
                    self._masks_to_pixels(predictions)
                    if self.output_type == "segmentation_mask"
                    else predictions
                )
                reference_image_data = (
                    self._masks_to_pixels(test_output)
                    if self.output_type == "segmentation_mask"
                    else test_output
                )
                input_images = [
                    wandb.Image(data, grouping=3)
                    for i, data in enumerate(input_image_data)
                ]
                output_images = [
                    wandb.Image(data) for i, data in enumerate(output_image_data)
                ]
                reference_images = [
                    wandb.Image(data) for i, data in enumerate(reference_image_data)
                ]
                return list(
                    chain.from_iterable(
                        zip(input_images, output_images, reference_images)
                    )
                )
            else:
                # unknown output, just log the input images
                return [wandb.Image(img) for img in test_data]
        elif self.output_type in ("image", "images", "segmentation_mask"):
            # unknown input, just log the predicted and reference outputs without captions
            output_image_data = (
                self._masks_to_pixels(predictions)
                if self.output_type == "segmentation_mask"
                else predictions
            )
            reference_image_data = (
                self._masks_to_pixels(test_output)
                if self.output_type == "segmentation_mask"
                else test_output
            )
            output_images = [
                wandb.Image(data, grouping=2)
                for i, data in enumerate(output_image_data)
            ]
            reference_images = [
                wandb.Image(data) for i, data in enumerate(reference_image_data)
            ]
            return list(chain.from_iterable(zip(output_images, reference_images)))

    def _log_weights(self):
        metrics = {}
        for layer in self.model.layers:
            weights = layer.get_weights()
            if len(weights) == 1:
                _update_if_numeric(
                    metrics, "parameters/" + layer.name + ".weights", weights[0]
                )
            elif len(weights) == 2:
                _update_if_numeric(
                    metrics, "parameters/" + layer.name + ".weights", weights[0]
                )
                _update_if_numeric(
                    metrics, "parameters/" + layer.name + ".bias", weights[1]
                )
        return metrics

    def _log_gradients(self):
        # Suppress callback warnings grad accumulator
        og_level = tf_logger.level
        tf_logger.setLevel("ERROR")

        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
        tf_logger.setLevel(og_level)
        weights = self.model.trainable_weights
        grads = self._grad_accumulator_callback.grads
        metrics = {}
        for weight, grad in zip(weights, grads):
            metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
                wandb.Histogram(grad)
            )
        return metrics

    def _log_dataframe(self):
        x, y_true, y_pred = None, None, None

        if self.validation_data:
            x, y_true = self.validation_data[0], self.validation_data[1]
            y_pred = self.model.predict(x)
        elif self.generator:
            if not self.validation_steps:
                wandb.termwarn(
                    "when using a generator for validation data with dataframes, "
                    "you must pass validation_steps. skipping"
                )
                return None

            for _ in range(self.validation_steps):
                bx, by_true = next(self.generator)
                by_pred = self.model.predict(bx)
                if x is None:
                    x, y_true, y_pred = bx, by_true, by_pred
                else:
                    x, y_true, y_pred = (
                        np.append(x, bx, axis=0),
                        np.append(y_true, by_true, axis=0),
                        np.append(y_pred, by_pred, axis=0),
                    )

        if self.input_type in ("image", "images") and self.output_type == "label":
            return wandb.image_categorizer_dataframe(
                x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
            )
        elif (
            self.input_type in ("image", "images")
            and self.output_type == "segmentation_mask"
        ):
            return wandb.image_segmentation_dataframe(
                x=x,
                y_true=y_true,
                y_pred=y_pred,
                labels=self.labels,
                class_colors=self.class_colors,
            )
        else:
            wandb.termwarn(
                f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
            )
            return None

    def _save_model(self, epoch):
        if wandb.run.disabled:
            return
        if self.verbose > 0:
            wandb.termlog(
                f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
                f"saving model to {self.filepath}"
            )

        try:
            if self.save_weights_only:
                self.model.save_weights(self.filepath, overwrite=True)
            else:
                self.model.save(self.filepath, overwrite=True)
        # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
        # also saw `TypeError: can't pickle _thread.RLock objects`
        except (ImportError, RuntimeError, TypeError, AttributeError):
            logger.exception("Error saving model in the h5py format")
            wandb.termerror(
                "Can't save model in the h5py format. The model will be saved as "
                "as an W&B Artifact in the 'tf' format."
            )

    def _save_model_as_artifact(self, epoch):
        if wandb.run.disabled:
            return

        # Save the model in the SavedModel format.
        # TODO: Replace this manual artifact creation with the `log_model` method
        # after `log_model` is released from beta.
        self.model.save(self.filepath[:-3], overwrite=True, save_format="tf")

        # Log the model as artifact.
        name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
        model_artifact = wandb.Artifact(name, type="model")
        model_artifact.add_dir(self.filepath[:-3])
        wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"])

        # Remove the SavedModel from wandb dir as we don't want to log it to save memory.
        shutil.rmtree(self.filepath[:-3])

    def get_flops(self) -> float:
        """Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.

        It uses tf.compat.v1.profiler under the hood.
        """
        if not hasattr(self, "model"):
            raise wandb.Error("self.model must be set before using this method.")

        if not isinstance(
            self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
        ):
            raise TypeError(
                "Calculating FLOPS is only supported for "
                "`tf.keras.Model` and `tf.keras.Sequential` instances."
            )

        from tensorflow.python.framework.convert_to_constants import (
            convert_variables_to_constants_v2_as_graph,
        )

        # Compute FLOPs for one sample
        batch_size = 1
        inputs = [
            tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
            for inp in self.model.inputs
        ]

        # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
        real_model = tf.function(self.model).get_concrete_function(inputs)
        frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

        # Calculate FLOPs with tf.profiler
        run_meta = tf.compat.v1.RunMetadata()
        opts = (
            tf.compat.v1.profiler.ProfileOptionBuilder(
                tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
            )
            .with_empty_output()
            .build()
        )

        flops = tf.compat.v1.profiler.profile(
            graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
        )

        # convert to GFLOPs
        return (flops.total_float_ops / 1e9) / 2
