# Copyright  2022-2024  Xiaomi Corp.       (authors: Daniel Povey
#                                                    Zengwei Yao
#                                                    Mingshuang Luo,
#                                                    Zengrui Jin,)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class TensorDiagnosticOptions(object):
    """Configuration for tensor diagnostics.

    Args:
      max_eig_dim:
        Largest dimension size for which eigenvalue stats are collected and
        printed; larger dimensions are skipped for speed reasons.
    """

    def __init__(self, max_eig_dim: int = 512):
        self.max_eig_dim = max_eig_dim

    def dim_is_summarized(self, size: int):
        """Return True if a stats vector with `size` elements should be printed
        as percentiles rather than element by element."""
        # Short vectors are printed in full.
        if size <= 10:
            return False
        # NOTE(review): size 31 is also printed in full; the reason for this
        # special case is not visible in this file.
        return size != 31


def get_tensor_stats(
    x: Tensor,
    dim: int,
    stats_type: str,
) -> Tuple[Tensor, int]:
    """
    Returns the specified transformation of the Tensor (either x, x.abs(),
    x**2 or (x > 0)), summed over all but the index `dim`.

    Half-precision inputs (float16 / bfloat16) are converted to float32 first,
    because sums of many elements (especially of squares) overflow easily in
    half precision (float16 saturates at 65504).

    Args:
      x:
        Tensor, tensor to be analyzed
      dim:
        Dimension with 0 <= dim < x.ndim
      stats_type:
        The stats_type includes several types:
        "abs" -> take abs() before summing
        "positive" -> take (x > 0) before summing
        "rms" -> square before summing, we'll take sqrt later
        "value"  -> just sum x itself
        "max", "min" -> take the maximum or minimum [over all other dims but dim]
            instead of summing
        "eigs" -> return the (s, s) matrix x^T x, where x is reshaped so that
            `dim` (of size s) is the last axis and all other axes are flattened.
        "rms-sort" -> this is a bit different than the others, it's based on computing
            the rms over the specified dim and returning percentiles of the result
            (11 of them).
    Returns:
      stats: a Tensor of shape (x.shape[dim],) (shape (s, s) for "eigs", shape
        (11,) for "rms-sort").
      count: an integer saying how many items were counted in each element
        of stats (always 1 for "rms-sort").
    """
    # Avoid overflow (and unsupported ops) in half precision.
    if x.dtype in (torch.float16, torch.bfloat16):
        x = x.to(torch.float32)

    if stats_type == "rms-sort":
        rms = (x**2).mean(dim=dim).sqrt()
        rms = rms.flatten().sort()[0]
        n = rms.numel()
        # Indexes of the 0th, 10th, ..., 100th percentiles.
        index = (torch.arange(11, device=rms.device) * n // 10).clamp(max=n - 1)
        return rms[index], 1

    count = x.numel() // x.shape[dim]

    if stats_type == "eigs":
        x = x.transpose(dim, -1)
        x = x.reshape(-1, x.shape[-1])
        # shape of returned tensor: (s, s),
        # where s is size of dimension `dim` of original x.
        return torch.matmul(x.transpose(0, 1), x), count
    elif stats_type == "abs":
        x = x.abs()
    elif stats_type == "rms":
        x = x**2
    elif stats_type == "positive":
        x = (x > 0).to(dtype=torch.float)
    else:
        assert stats_type in ["value", "max", "min"]

    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        if stats_type == "max":
            # Reduce the highest dims first so lower dim indexes stay valid.
            for d in reversed(sum_dims):
                x = torch.max(x, dim=d)[0]
        elif stats_type == "min":
            for d in reversed(sum_dims):
                x = torch.min(x, dim=d)[0]
        else:
            x = torch.sum(x, dim=sum_dims)
    x = x.flatten().clone()
    return x, count


@dataclass
class TensorAndCount:
    """An accumulated statistics tensor together with the number of items
    that were aggregated into each of its elements."""

    # The accumulated statistics (e.g. a running sum, max or min).
    tensor: Tensor
    # How many items contributed to each element of `tensor`.
    count: int


class TensorDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.

    Args:
      opts:
        Options object.
      name:
        The name associated with this diagnostics object, will probably be
        {module_name}.X where X is "output" or "grad", or
        {parameter_name}.Y where Y is "param_value" or "param_grad".
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate()

        # self.stats is set on the first call to accumulate() to a list of
        # dicts, indexed by dim (i.e. by the axis of the tensor).  The keys of
        # each dict are stats types: "abs", "max", "min", "positive", "value",
        # "rms", "rms-sort" and "eigs" (print_diagnostics() also adds
        # "stddev").  The values, e.g. self.stats[dim]["rms"], are lists of
        # TensorAndCount, each holding an accumulated tensor and its count
        # (the number of items aggregated into each element, e.g. the number of
        # frames and/or batch elements and/or channels).
        # A new list element is only added when a stats tensor of a not yet
        # seen shape arrives; otherwise the matching element is accumulated
        # into.  For "eigs", if more than one size is seen for a dim, the
        # value becomes None and "eigs" accumulation stops for that dim, as it
        # would use too much memory.
        self.stats = None

        # scalar_stats contains some analysis of the activations and gradients.
        # NOTE(review): not read or written anywhere else in this class.
        self.scalar_stats = None

    def accumulate(self, x, class_name: Optional[str] = None):
        """
        Accumulate statistics of `x` along each of its dimensions.

        Args:
          x:
            The tensor to analyze.  If it is a tuple, only its first element is
            used.  Anything that is not a non-empty Tensor is ignored.
          class_name:
            If given, remembered and printed together with the stats.
        """
        if class_name is not None:
            self.class_name = class_name
        if isinstance(x, tuple):
            if len(x) == 0:
                return
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.numel() == 0:  # for empty tensor
            return
        # get_tensor_stats() builds new tensors rather than modifying its input
        # in place, so detaching is enough and no copy is needed.
        x = x.detach()
        # Accumulate in float32: half-precision sums overflow easily (float16
        # saturates at 65504), and the eigendecomposition in
        # print_diagnostics() is not expected to support half precision.
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        if x.ndim == 0:
            x = x.unsqueeze(0)
        ndim = x.ndim
        if self.stats is None:
            self.stats = [dict() for _ in range(ndim)]

        for dim in range(ndim):
            this_dim_stats = self.stats[dim]
            if ndim > 1:
                # rms-sort is different from the others, it's based on summing over just
                # this dim, then sorting and returning the percentiles.
                stats_types = [
                    "abs",
                    "max",
                    "min",
                    "positive",
                    "value",
                    "rms",
                    "rms-sort",
                ]
                if x.shape[dim] <= self.opts.max_eig_dim:
                    stats_types.append("eigs")
            else:
                stats_types = ["value", "abs", "max", "min"]

            for stats_type in stats_types:
                stats_list = this_dim_stats.setdefault(stats_type, [])
                if stats_list is None:
                    # We get here if, for stats_type "eigs", more than one size
                    # was seen for this dim.  Accumulating this stats type has
                    # then been disabled because it uses too much memory, so
                    # don't even compute it.
                    continue

                stats, count = get_tensor_stats(x, dim, stats_type)
                for s in stats_list:
                    if s.tensor.shape == stats.shape:
                        if stats_type == "max":
                            s.tensor = torch.maximum(s.tensor, stats)
                        elif stats_type == "min":
                            s.tensor = torch.minimum(s.tensor, stats)
                        else:
                            s.tensor += stats
                        s.count += count
                        break
                else:  # no element with a matching shape was found
                    if stats_list and stats_type == "eigs":
                        # >1 size encountered on this dim, e.g. it's a batch or
                        # time dimension; don't accumulate "eigs" stats, it uses
                        # too much memory.
                        this_dim_stats[stats_type] = None
                    else:
                        stats_list.append(TensorAndCount(stats, count))

    def print_diagnostics(self):
        """Print diagnostics for each dimension of the tensor.

        For each dim and stats type, prints either percentiles (when the stats
        came from tensors of varying size or are long, see
        TensorDiagnosticOptions.dim_is_summarized) or the individual values,
        followed by mean and rms (and norm for some stats types).
        """
        if self.stats is None:
            print(f"Warning: the stats of {self.name} is None.")
            return
        for dim, this_dim_stats in enumerate(self.stats):
            if "rms" in this_dim_stats and "value" in this_dim_stats:
                # produce "stddev" stats, which is centered RMS.
                rms_stats_list = this_dim_stats["rms"]
                value_stats_list = this_dim_stats["value"]
                if len(rms_stats_list) == len(value_stats_list):
                    # sum(x^2) - sum(x)^2 / n is n * variance; it is divided by
                    # the count and square-rooted below.  Rounding can make it
                    # slightly negative, which would give NaN after the sqrt,
                    # so clamp it at zero.  r.count and v.count should be the
                    # same, but we don't check this.
                    this_dim_stats["stddev"] = [
                        TensorAndCount(
                            (
                                r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20)
                            ).clamp(min=0.0),
                            r.count,
                        )
                        for r, v in zip(rms_stats_list, value_stats_list)
                    ]

            for stats_type, stats_list in this_dim_stats.items():
                # stats_type could be "rms", "value", "abs", "eigs", "positive",
                # "min", "max", "rms-sort" or "stddev". "stats_list" could be a
                # list of TensorAndCount (one element per distinct tensor shape
                # of the stats), or None.
                if stats_list is None:
                    assert stats_type == "eigs"
                    continue

                def get_count(count):
                    # "max"/"min" are not sums, so they are not averaged.
                    return 1 if stats_type in ["max", "min"] else count

                if len(stats_list) == 1:
                    stats = stats_list[0].tensor / get_count(stats_list[0].count)
                else:
                    # a dimension that has variable size in different nnet
                    # forwards, e.g. a time dimension in an ASR model.
                    stats = torch.cat(
                        [s.tensor / get_count(s.count) for s in stats_list], dim=0
                    )

                if stats_type == "eigs":
                    try:
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                            eigs, _ = torch.linalg.eigh(stats)
                        else:
                            eigs, _ = torch.symeig(stats)
                        stats = eigs.abs().sqrt()
                    except RuntimeError:
                        print("Error getting eigenvalues, trying another method.")
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                            eigs, _ = torch.linalg.eig(stats)
                            eigs = eigs.abs()
                        else:
                            eigs, _ = torch.eig(stats)
                            eigs = eigs.norm(dim=1)
                        stats = eigs.sqrt()
                        # sqrt so it reflects data magnitude, like stddev- not variance

                if stats_type in ["rms", "stddev"]:
                    # we stored the square; after aggregation we need to take sqrt.
                    stats = stats.sqrt()

                # if `summarize` we print percentiles of the stats; else,
                # we print out individual elements.
                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
                    stats.numel()
                )
                if summarize:  # usually `summarize` will be true
                    # print out percentiles.
                    stats = stats.sort()[0]
                    num_percentiles = 10
                    size = stats.numel()
                    percentiles = [
                        stats[(i * (size - 1)) // num_percentiles].item()
                        for i in range(num_percentiles + 1)
                    ]
                    percentiles = " ".join("%.2g" % p for p in percentiles)
                    ans = f"percentiles: [{percentiles}]"
                else:
                    ans = "[" + " ".join("%.2g" % v for v in stats.tolist()) + "]"
                if stats_type in ["value", "rms", "stddev", "eigs"]:
                    # This norm is useful because it is strictly less than the largest
                    # sqrt(eigenvalue) of the variance, which we print out, and shows,
                    # speaking in an approximate way, how much of that largest
                    # eigenvalue can be attributed to the mean of the distribution.
                    norm = (stats**2).sum().sqrt().item()
                    ans += f", norm={norm:.2g}"
                mean = stats.mean().item()
                rms = (stats**2).mean().sqrt().item()
                ans += f", mean={mean:.3g}, rms={rms:.3g}"

                # OK, "ans" contains the actual stats, e.g.
                # ans = "percentiles: \
                # [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], \
                # mean=0.5, rms=0.5"

                sizes = [s.tensor.shape[0] for s in stats_list]
                size_str = (
                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
                )
                maybe_class_name = (
                    f" type={self.class_name}," if self.class_name is not None else ""
                )
                print(
                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, "
                    f"{stats_type} {ans}"
                )


class ScalarDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a single module (subclass of torch.nn.Module) that
    represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.

    The module's input is saved in the forward pass.  In the backward pass it is
    paired with the gradient passed to accumulate_output_grad(), and for each
    histogram "tick" of input value we accumulate the count, the sum of
    gradients, the sum of squared gradients and the sum of absolute gradients.
    """

    def __init__(self, opts: "TensorDiagnosticOptions", name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate_input()
        self.is_forward_pass = True

        # Width of one histogram tick in input units; set from the first
        # batch seen in process_input_and_grad().
        self.tick_scale = None

        # Inputs saved in the forward pass, waiting for their gradients.
        self.saved_inputs = []
        # Set to False on an unrecoverable problem; accumulation then stops.
        self.is_ok = True

        # Per-tick accumulators, allocated in process_input_and_grad().
        self.counts = None
        self.sum_grad = None
        self.sum_gradsq = None
        self.sum_abs_grad = None

    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
        """
        Called in forward pass: save the module input `x` until its gradient
        arrives in accumulate_output_grad().
        """
        if not self.is_forward_pass:
            # in case we did a forward pass without a backward pass, for some reason.
            self.saved_inputs = []
            self.is_forward_pass = True

        if class_name is not None:
            self.class_name = class_name
        if not self.is_ok:
            return

        limit = 10
        if len(self.saved_inputs) > limit:
            print(
                f"ERROR: forward pass called for this module over {limit} times "
                f"with no backward pass. Will not accumulate scalar stats."
            )
            self.is_ok = False
            return
        # Only the values are needed; detach so the saved tensor carries no
        # autograd history.
        self.saved_inputs.append(x.detach())

    def accumulate_output_grad(self, grad: Tensor):
        """Called in backward pass: pair `grad` with the most recently saved
        input (which must have the same shape) and accumulate stats."""
        if not self.is_ok:
            return
        if self.is_forward_pass:
            self.is_forward_pass = False

        last_shape = (
            "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
        )
        if len(self.saved_inputs) == 0 or grad.shape != last_shape:
            print(
                f"ERROR: shape mismatch or no forward activation present when backward "
                f"pass called: grad shape ={tuple(grad.shape)}"
                f", num-saved-inputs={len(self.saved_inputs)}"
                f", shape-of-last-saved-input={last_shape}"
            )
            self.is_ok = False
            return

        x = self.saved_inputs.pop()
        self.process_input_and_grad(x, grad)

    def process_input_and_grad(self, x: Tensor, grad: Tensor):
        """Add one (input, gradient) pair of equal shape to the per-tick sums."""
        assert x.shape == grad.shape
        x = x.flatten()
        grad = grad.flatten()
        if x.numel() == 0:
            return

        num_ticks_per_side = 256

        if self.tick_scale is None:
            x_abs_sorted = x.abs().sort()[0]
            # take the 98th percentile as the largest value we count separately.
            index = int(x.numel() * 0.98)
            tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
            if tick_scale <= 0.0:
                # The 98th percentile is zero (e.g. a mostly-zero first batch):
                # fall back to the maximum, or to a unit range if the batch is
                # all zeros, so that we never divide by zero below.
                max_abs = float(x_abs_sorted[-1])
                tick_scale = (max_abs if max_abs > 0.0 else 1.0) / num_ticks_per_side
            self.tick_scale = tick_scale

            # integerize from tick * (-num ticks_per_side ..  num_ticks_per_side - 1]
            self.counts = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.long, device=x.device
            )
            self.sum_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            # sum_gradsq is for getting error bars.
            self.sum_gradsq = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            self.sum_abs_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )

        # Round down with floor(): a bare .to(torch.long) truncates toward
        # zero, which would put e.g. -0.5 ticks into tick 0 (together with
        # +0.5) and shift every negative value by one tick relative to the
        # boundaries computed in print_diagnostics().  Clamping in floating
        # point before the cast also keeps huge values from overflowing.
        ticks = torch.floor(x / self.tick_scale)
        ticks = ticks.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
        ticks = ticks.to(torch.long) + num_ticks_per_side

        self.counts.index_add_(dim=0, index=ticks, source=torch.ones_like(ticks))
        self.sum_grad.index_add_(dim=0, index=ticks, source=grad.to(torch.double))
        self.sum_gradsq.index_add_(
            dim=0, index=ticks, source=(grad * grad).to(torch.double)
        )
        self.sum_abs_grad.index_add_(
            dim=0, index=ticks, source=grad.abs().to(torch.double)
        )

    def print_diagnostics(self):
        """Print diagnostics: for `num_bins` roughly equal-count bins of the
        input distribution, the bin boundaries, the mean gradient relative to
        the mean absolute gradient, and the mean gradient relative to a rough
        standard-deviation estimate."""
        if not self.is_ok or self.counts is None:
            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
            return

        counts = self.counts.to("cpu")
        sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
        sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
        sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

        counts_cumsum = counts.cumsum(dim=0)
        counts_tot = counts_cumsum[-1]

        # subdivide the distribution up into `num_bins` intervals for analysis, for
        # greater statistical significance.  each bin corresponds to multiple of the
        # original 'tick' intervals.
        num_bins = 20

        # integer division; the +1 guarantees counts_cumsum // counts_per_bin
        # is at most num_bins - 1.
        counts_per_bin = (counts_tot // num_bins) + 1
        bin_indexes = counts_cumsum // counts_per_bin
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins - 1).to(torch.long)

        bin_counts = torch.zeros(num_bins, dtype=torch.long)
        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
        bin_grad = torch.zeros(num_bins)
        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
        bin_gradsq = torch.zeros(num_bins)
        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
        bin_abs_grad = torch.zeros(num_bins)
        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

        bin_boundary_counts = (
            torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
        )
        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
        # boundaries are the "x" values between the bins, e.g. corresponding to the
        # locations of percentiles of the distribution.
        num_ticks_per_side = counts.numel() // 2
        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale

        bin_grad = bin_grad / (bin_counts + 1)
        # consider this a standard deviation.
        bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)
        # bin_grad / bin_abs_grad will give us a sense for how important in a practical
        # sense, the gradients are.
        bin_abs_grad = bin_abs_grad / (bin_counts + 1)

        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)

        def tensor_to_str(x: Tensor):
            return "[" + " ".join("%.2g" % f for f in x) + "]"

        maybe_class_name = (
            f" type={self.class_name}," if self.class_name is not None else ""
        )

        print(
            f"module={self.name},{maybe_class_name} "
            f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
            f"rel_grad={tensor_to_str(bin_rel_grad)}, "
            f"grad_conf={tensor_to_str(bin_conf)}"
        )


class ModelDiagnostic(object):
    """Container for the diagnostics of every tensor in a torch.nn.Module.

    Diagnostic objects are created lazily on first lookup by name: names
    ending in ".scalar" get a ScalarDiagnostic, all others a TensorDiagnostic.

    Args:
      opts:
        Options object; a default TensorDiagnosticOptions() is used if None.
    """

    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
        self.opts = TensorDiagnosticOptions() if opts is None else opts
        # Maps tensor name -> its diagnostic object.
        self.diagnostics = dict()

    def __getitem__(self, name: str):
        existing = self.diagnostics.get(name)
        if existing is not None:
            return existing
        diagnostic_cls = (
            ScalarDiagnostic if name.endswith(".scalar") else TensorDiagnostic
        )
        created = diagnostic_cls(self.opts, name)
        self.diagnostics[name] = created
        return created

    def print_diagnostics(self):
        """Print diagnostics for each tensor, in order of tensor name."""
        for tensor_name in sorted(self.diagnostics):
            self.diagnostics[tensor_name].print_diagnostics()


def get_class_name(module: nn.Module):
    """Return the class name of `module`, decorated with key constructor
    settings for a few known module types (Balancer, ActivationBalancer,
    AbsValuePenalizer).

    The decoration is best-effort: if the expected attributes are missing or
    cannot be converted, the plain class name is returned.
    """
    ans = type(module).__name__
    # we put the below in try blocks in case anyone is using a different version of
    # these modules that might have different member names.
    if ans == "Balancer" or ans == "ActivationBalancer":
        try:
            # A single parenthesized f-string: previously the second half was a
            # separate, discarded expression statement, so the name came out
            # truncated as e.g. "Balancer[0.05,0.95,".
            ans += (
                f"[{float(module.min_positive)},{float(module.max_positive)},"
                f"{float(module.min_abs)},{float(module.max_abs)}]"
            )
        except Exception:  # best-effort decoration only
            pass
    elif ans == "AbsValuePenalizer":
        try:
            ans += f"[{module.limit}]"
        except Exception:  # best-effort decoration only
            pass
    return ans


def attach_diagnostics(
    model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
    """Attach a ModelDiagnostic object to the model by
    1) registering forward hook and backward hook on each module, to accumulate
    its output tensors and gradient tensors, respectively;
    2) for modules whose class is one of a few known nonlinearities, registering
    extra hooks that accumulate "{name}.scalar" diagnostics of input vs. gradient;
    3) registering backward hook on each module parameter, to accumulate its
    values and gradients.

    Args:
      model:
        the model to be analyzed.
      opts:
        Options object.

    Returns:
      The ModelDiagnostic object attached to the model.
    """

    # Only floating-point tensors are analyzed; other dtypes are skipped.
    float_dtypes = (torch.float32, torch.float16, torch.float64)

    def _accumulate_float_tensors(model_diagnostic, key, value, module):
        # Accumulate `value` under `key` if it is a floating-point Tensor (a
        # 1-tuple is unwrapped first).  If it is a longer tuple, accumulate each
        # floating-point Tensor element i under f"{key}[{i}]".
        if isinstance(value, tuple) and len(value) == 1:
            value = value[0]
        if isinstance(value, Tensor) and value.dtype in float_dtypes:
            model_diagnostic[key].accumulate(value, class_name=get_class_name(module))
        elif isinstance(value, tuple):
            for i, v in enumerate(value):
                if isinstance(v, Tensor) and v.dtype in float_dtypes:
                    model_diagnostic[f"{key}[{i}]"].accumulate(
                        v, class_name=get_class_name(module)
                    )

    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # Setting model_diagnostic=ans and n=name below, instead of trying to
        # capture the variables, ensures that we use the current values.
        # (this matters for `name`, since the variable gets overwritten).
        # These closures don't really capture by value, only by
        # "the final value the variable got in the function" :-(
        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            _accumulate_float_tensors(
                _model_diagnostic, f"{_name}.output", _output, _module
            )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            # With register_backward_hook, the third argument holds the
            # gradients w.r.t. the module's outputs.
            _accumulate_float_tensors(
                _model_diagnostic, f"{_name}.grad", _output, _module
            )

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

        if type(module).__name__ in [
            "Sigmoid",
            "Tanh",
            "ReLU",
            "TanSwish",
            "Swish",
            "DoubleSwish",
            "Swoosh",
        ]:
            # For these specific module types, accumulate some additional diagnostics
            # that can help us improve the activation function.  These require a lot of
            # memory, to save the forward activations, so limit this to some select
            # classes. Note: this will not work correctly for all model types.
            def scalar_forward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_input, tuple):
                    (_input,) = _input
                assert isinstance(_input, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
                    _input, class_name=get_class_name(_module)
                )

            def scalar_backward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_output, tuple):
                    (_output,) = _output
                assert isinstance(_output, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

            module.register_forward_hook(scalar_forward_hook)
            module.register_backward_hook(scalar_backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(
            grad, _parameter=parameter, _model_diagnostic=ans, _name=name
        ):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        try:
            parameter.register_hook(param_backward_hook)
        except RuntimeError:
            # Tensor.register_hook raises RuntimeError for tensors that do not
            # require grad (e.g. frozen parameters).
            logging.warning(
                f"Warning: could not register backward hook for parameter {name}, "
                f"it might not be differentiable."
            )

    return ans


def _test_tensor_diagnostic():
    """Smoke test: exercise a standalone TensorDiagnostic, then a small model
    with diagnostics attached, printing all collected stats."""
    opts = TensorDiagnosticOptions(512)

    # Part 1: feed a standalone diagnostic ten random (50, 100) tensors.
    standalone = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        standalone.accumulate(torch.randn(50, 100) * 10.0)
    standalone.print_diagnostics()

    # Part 2: attach hooks to a small MLP and run forward/backward passes with
    # a varying number of rows.
    model = nn.Sequential(
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 80),
    )
    model_diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
        num_rows = random.randint(200, 300)
        output = model(torch.randn(num_rows, 100))
        output.sum().backward()
    model_diagnostic.print_diagnostics()


if __name__ == "__main__":
    # Running this file directly executes the smoke test above.
    _test_tensor_diagnostic()
