# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import os.path as op

import torch
import torch.nn.functional as F
import numpy as np

from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
    AutoRegressiveSpeechGenerator,
    NonAutoregressiveSpeechGenerator,
    TeacherForcingAutoRegressiveSpeechGenerator,
)

# Configure root logging once at import time and obtain this module's logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)


# tensorboardX is optional: without it, SummaryWriter is None and the task
# skips TensorBoard logging of inference samples.
try:
    from tensorboardX import SummaryWriter
except ImportError:
    SummaryWriter = None
    logger.info("Please install tensorboardX: pip install tensorboardX")


@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
    """Fairseq task for training and evaluating text-to-speech models.

    Extends ``SpeechToTextTask`` with TTS-specific dataset loading
    (``TextToSpeechDatasetCreator``) and speech-generator construction. When
    ``--eval-inference`` is set, validation also runs inference and reports
    an MFCC-based distortion computed over a DTW alignment. If a TensorBoard
    log dir is configured and tensorboardX is installed, it also logs plots
    and audio.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line arguments understood by this task."""
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1200,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        # Audio sample rate from the "features" section of the data config;
        # used for MCD feature extraction and TensorBoard audio logging.
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        # The writer itself is created lazily in log_tensorboard(); an empty
        # tensorboard_dir disables TensorBoard logging in valid_step().
        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load ``split`` from the manifest root into ``self.datasets``.

        Splits whose name starts with ``"train"`` are loaded as training splits.
        """
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        # Targets are acoustic features, not tokens, so there is no target dictionary.
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        """Return the path of the pretrained speaker-embedding file, or ``None``.

        The file name comes from ``speaker_emb_filename`` in the data config
        and is resolved relative to the manifest root (``args.data``).
        """
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        """Build the speaker embedding table for a model, or return ``None``.

        Returns ``None`` when ``args.speaker_to_id`` is ``None``. Otherwise:

        * returns a trainable ``nn.Embedding`` of size
          ``(len(args.speaker_to_id), args.speaker_embed_dim)`` when
          ``args.speaker_emb_path`` is ``None``;
        * otherwise returns a frozen embedding initialised from the matrix
          loaded with ``np.load(args.speaker_emb_path)``. Its second dimension
          must equal ``args.speaker_embed_dim``.
        """
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                # np.load keeps the on-disk dtype (float64 for default NumPy
                # arrays). Cast to float32 so the frozen table has the same
                # dtype as freshly initialised model parameters.
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat).float(),
                    freeze=True,
                )
                logger.info(
                    f"load speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker

    def build_model(self, cfg, from_checkpoint=False):
        """Build the model after copying feature statistics into ``cfg``.

        Copies pitch/energy min/max (``None`` when absent) from the data
        config's "features" section, plus the speaker-embedding path, into
        ``cfg``. It then builds the model via the parent task. When
        ``cfg.eval_inference`` is set, it also prepares ``self.generator``
        for validation-time inference.
        """
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg, from_checkpoint)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        """Create a speech generator for ``models[0]``.

        Uses ``NonAutoregressiveSpeechGenerator`` for models whose
        ``NON_AUTOREGRESSIVE`` attribute is truthy. Otherwise it uses
        ``AutoRegressiveSpeechGenerator``, or the teacher-forcing variant when
        ``cfg.teacher_forcing`` is set. A default vocoder is built when none
        is given.
        """
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model,
                vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )

    def build_default_vocoder(self):
        """Build the vocoder selected by ``args``.

        The vocoder is placed on CUDA when it is available and ``args.cpu`` is
        not set; otherwise it stays on the CPU.
        """
        from fairseq.models.text_to_speech.vocoder import get_vocoder

        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        """Run the parent validation step, optionally adding inference metrics.

        With ``args.eval_inference`` set, the inference losses from
        ``valid_step_with_inference`` are merged into ``logging_output``.
        Their keys must not collide with existing ones. Plots and audio are
        logged to TensorBoard only for batches that contain sample id 0.
        """
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[: self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        """Generate speech for ``sample`` and compute batch-summed MCD statistics.

        Returns ``(hypos, losses)``. ``losses`` holds sums over the batch:

        * ``mcd_loss``: un-normalised DTW path cost (``normalize_type=None``);
        * ``targ_frames`` / ``pred_frames``: target / predicted MFCC frames;
        * ``nins``: per target frame, the number of extra predicted frames
          aligned to it (count minus one), summed;
        * ``ndel``: per predicted frame, the number of extra target frames
          aligned to it (count minus one), summed.
        """
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None,
        )
        for d, extra in rets:
            # Rows of the path map index target frames, columns predicted frames.
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()

        return hypos, losses

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        """Write plots and audio for each hypothesis in ``hypos`` to TensorBoard.

        For hypothesis ``b`` this logs the image ``inference_sample_{b}``. It
        also logs the audio ``inference_targ_{b}`` / ``inference_pred_{b}``
        when the hypothesis has a waveform.

        Args:
            sample: the validation batch. ``sample["id"]`` and
                ``sample["src_texts"]`` are indexed by hypothesis position.
            hypos: generator outputs, one dict per logged sample.
            num_updates: global step used for the TensorBoard entries.
            is_na_model: for non-autoregressive models, ``attn`` is drawn as
                the line plot; otherwise ``attn`` becomes a 2-D panel and
                ``eos_prob`` is the line plot.
        """
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            if isinstance(idx, torch.Tensor):
                # Show a plain number in the plot title instead of "tensor(...)".
                idx = idx.item()
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]

            if is_na_model:
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"],
                    attn,
                    "alignment",
                    ret_np=True,
                    suptitle=text,
                )
            else:
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"],
                    eos_prob,
                    "eos prob",
                    ret_np=True,
                    suptitle=text,
                )

            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates, dataformats="HWC"
            )

            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
                )


def save_figure_to_numpy(fig):
    """Return the rendered pixels of a drawn matplotlib figure as an RGB array.

    The figure must already have been drawn (``fig.canvas.draw()``). The
    result is a fresh, writable ``(H, W, 3)`` uint8 array that does not alias
    the canvas buffer, so it stays valid after the figure is closed.
    """
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # buffer_rgba() already carries the true pixel shape (this also covers
        # HiDPI canvases where get_width_height() reports logical pixels).
        rgba = np.asarray(canvas.buffer_rgba())
        return np.array(rgba[..., :3], dtype=np.uint8)
    # Fallback for canvases that only provide tostring_rgb() (deprecated in
    # Matplotlib 3.8, removed in 3.10). np.frombuffer replaces the deprecated
    # binary mode of np.fromstring; .copy() keeps the result writable.
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()


# Default lower clamp for the colour scale of 2-D panels: log(1e-5).
# Presumably chosen to match a log-feature floor of 1e-5 — TODO confirm.
DEFAULT_V_MIN = np.log(1e-5)


def plot_tts_output(
    data_2d,
    title_2d,
    data_1d,
    title_1d,
    figsize=(24, 4),
    v_min=DEFAULT_V_MIN,
    v_max=3,
    ret_np=False,
    suptitle="",
):
    """Plot 2-D panels side by side, followed by one line-plot panel.

    Args:
        data_2d: list of 2-D arrays/tensors. Each is drawn with ``imshow``
            (``origin="lower"``) and gets its own colour bar.
        title_2d: titles for the 2-D panels, zipped with ``data_2d``.
        data_1d: array/tensor drawn with ``Axes.plot`` in the last panel.
        title_1d: title of the last panel.
        figsize: matplotlib figure size.
        v_min, v_max: clamps for each panel's colour range, which is
            ``[max(x.min(), v_min), min(x.max(), v_max)]``. If that interval
            is empty, the panel's own data range is used instead.
        ret_np: if True, draw the figure, close it and return its pixels.
        suptitle: optional figure title, truncated to 400 characters.

    Returns:
        An ``(H, W, 3)`` uint8 image when ``ret_np`` is True. Otherwise
        ``None``, and the figure is left open.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError as err:
        raise ImportError("Please install Matplotlib: pip install matplotlib") from err

    data_2d = [
        x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
        for x in data_2d
    ]
    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    # With a single panel plt.subplots returns a bare Axes, not an array.
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        vmin = max(x.min(), v_min)
        vmax = min(x.max(), v_max)
        if vmin > vmax:
            # The data lies entirely outside [v_min, v_max]. matplotlib's
            # Normalize rejects vmin > vmax at draw time, so use the data range.
            vmin, vmax = x.min(), x.max()
        im = ax.imshow(x, origin="lower", aspect="auto", vmin=vmin, vmax=vmax)
        fig.colorbar(im, cax=cax, orientation="vertical")

    if isinstance(data_1d, torch.Tensor):
        # Convert to float32 like the 2-D inputs: bfloat16 tensors cannot be
        # turned into NumPy arrays directly.
        data_1d = data_1d.detach().cpu().float().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data


def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """Return the (i, j) cells on the anti-diagonal ``i + j == offset``.

    The result is a ``(2, K)`` LongTensor. Row 0 holds the ``i`` values and
    row 1 the ``j`` values, ordered by increasing ``j``. Only cells with
    ``min_i <= i < max_i`` and ``min_j <= j < max_j`` are kept; a missing
    upper bound defaults to ``offset + 1``.

    Example, for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4:

        offset=2 -> (1, 1)
        offset=3 -> (2, 1), (1, 2)
        offset=4 -> (2, 2), (1, 3)
        offset=5 -> (2, 3)
    """
    upper_i = offset + 1 if max_i is None else max_i
    upper_j = offset + 1 if max_j is None else max_j
    # Since i = offset - j, the bounds on i become bounds on j:
    #   min_i <= offset - j < upper_i  <=>  offset - upper_i < j <= offset - min_i
    j_lo = max(0, min_j, offset - upper_i + 1)
    j_hi = min(offset + 1, upper_j, offset - min_i + 1)
    cols = torch.arange(j_lo, j_hi)
    rows = offset - cols
    return torch.stack((rows, cols))


def batch_dynamic_time_warping(distance, shapes=None):
    """Full batched DTW without any constraints.

    Args:
        distance: ``(batchsize, max_M, max_N)`` matrix of local costs.
        shapes: optional ``(batchsize, 2)`` integer tensor giving the valid
            ``(M, N)`` of each entry. If ``None``, every entry uses the full
            ``(max_M, max_N)`` grid. Padding cells beyond an entry's
            ``(M, N)`` cannot influence its valid region, because each cell
            only depends on cells with smaller indices.

    Returns:
        cumdist: ``(batchsize, max_M, max_N)`` accumulated costs.
            ``cumdist[b, M-1, N-1]`` is the cost of entry ``b``'s best path.
        backptr: int32 tensor of the same shape holding the move chosen at
            each cell (0=left, 1=up-left, 2=up).
        pathmap: tensor shaped like ``backptr``. It is 1 on the optimal path
            from ``(0, 0)`` to ``(M-1, N-1)`` and 0 elsewhere.
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize: the first row is reachable only from the left and the first
    # column only from above. (0, 0) is never read by the backtrace.
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps: all cells
    # with i + j == offset depend only on the two previous anti-diagonals.
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        c = torch.stack(
            [
                cumdist[:, ind[0], ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1]],
            ],
            dim=2,
        )
        v, b = c.min(dim=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        # Copy this entry's pointers to the CPU once, rather than forcing a
        # device sync on every .item() call in the walk below.
        cur_backptr = backptr[b].cpu()
        dtwpath = [(i, j)]
        # Every move lowers i + j by at least one and the walk stops at
        # (0, 0), so it ends within i + j steps. No step cap is needed; a cap
        # would silently truncate the path for long inputs.
        while i != 0 or j != 0:
            assert i >= 0 and j >= 0
            di, dj = ptr2dij[cur_backptr[i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1

    return cumdist, backptr, pathmap


def compute_l2_dist(x1, x2):
    """Return the (m, n) matrix of *squared* Euclidean distances.

    Entry ``[a, b]`` is ``||x1[a] - x2[b]||^2`` for ``x1`` of shape
    ``(m, d)`` and ``x2`` of shape ``(n, d)``.
    """
    # cdist works on batched inputs, so add and then drop a batch axis.
    pairwise = torch.cdist(x1[None], x2[None], p=2)[0]
    return pairwise ** 2


def compute_rms_dist(x1, x2):
    """Return the (m, n) matrix of root-mean-square differences between rows.

    Entry ``[a, b]`` is ``sqrt(mean_k (x1[a, k] - x2[b, k])**2)`` for
    ``x1`` of shape ``(m, d)`` and ``x2`` of shape ``(n, d)``.
    """
    n_dims = x1.size(1)
    mean_sq = compute_l2_dist(x1, x2) / n_dims
    return mean_sq.pow(0.5)


def get_divisor(pathmap, normalize_type):
    """Return the normaliser applied to a DTW path's total cost.

    ``None`` -> 1 (no normalisation); ``"len1"`` -> number of rows of
    ``pathmap``; ``"len2"`` -> number of columns; ``"path"`` -> number of
    cells on the path (sum of ``pathmap``).

    Raises:
        ValueError: for any other ``normalize_type``.
    """
    if normalize_type is None:
        return 1
    # Lambdas keep the evaluation lazy: only the selected divisor is computed.
    strategies = (
        ("len1", lambda: pathmap.size(0)),
        ("len2", lambda: pathmap.size(1)),
        ("path", lambda: pathmap.sum().item()),
    )
    for name, compute in strategies:
        if normalize_type == name:
            return compute()
    raise ValueError(f"normalize_type {normalize_type} not supported")


def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    """Compute a DTW-aligned feature distortion for each pair of waveforms.

    Args:
        y1, y2: sequences of 1-D waveforms, paired element-wise. They must
            have the same length.
        sr: sampling rate. It is not used here; feature extraction is
            delegated to ``feat_fn``.
        feat_fn: maps a 1-D waveform to a 2-D feature matrix. Presumably
            ``(frames, dim)``, which is what ``dist_fn`` is expected to accept.
        dist_fn: maps two feature matrices to an ``(m, n)`` local-distance
            matrix.
        normalize_type: forwarded to ``get_divisor`` (``None``, ``"len1"``,
            ``"len2"`` or ``"path"``).

    Returns:
        One ``(distortion, (x1, x2, dist, cumdist, backptr, pathmap))`` tuple
        per pair. ``distortion`` is the total DTW path cost divided by the
        chosen divisor. All matrices are cropped to that pair's ``(m, n)``.
        An empty input yields ``[]``.

    Raises:
        ValueError: if ``y1`` and ``y2`` contain different numbers of waveforms.
    """
    y1, y2 = list(y1), list(y2)
    if len(y1) != len(y2):
        # zip() would otherwise silently drop the unmatched waveforms.
        raise ValueError(
            f"y1 and y2 must contain the same number of waveforms, "
            f"got {len(y1)} and {len(y2)}"
        )
    if not y1:
        return []

    d, shapes, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert cur_y1.ndim == 1 and cur_y2.ndim == 1
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        shapes.append(cur_d.size())
    max_m = max(m for m, _ in shapes)
    max_n = max(n for _, n in shapes)
    # Zero-pad every distance matrix to a common size so that a single batched
    # DTW pass can be run. The per-entry shapes keep padding out of each result.
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(shapes).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    # Crop with plain Python ints (not device tensors) to avoid per-index syncs.
    itr = zip(shapes, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        dist = dist[:m, :n]
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets


def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    normalize_type: how the DTW path cost is normalised (see ``get_divisor``)
    mfcc_fn: optional pre-built ``torchaudio.transforms.MFCC``. It is reused
        only when its ``sample_rate`` equals ``sr``; otherwise a new one is
        built on the device of ``y1[0]``.

    Returns a list of ``(distortion, extra)`` tuples (one per pair), or ``[]``
    when ``y1`` is empty.
    """
    if len(y1) == 0:
        # Nothing to compare. Returning early also avoids indexing y1[0]
        # below to pick a device.
        return []

    try:
        import torchaudio
    except ImportError as err:
        raise ImportError("Please install torchaudio: pip install torchaudio") from err

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        melkwargs = {
            "n_fft": int(0.05 * sr),
            "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr),
            "f_min": 20,
            "n_mels": 80,
            "window_fn": torch.hann_window,
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1,
        y2,
        sr,
        # Transpose to (frames, n_mfcc) so rows are time steps for the distance.
        lambda y: mfcc_fn(y).transpose(-1, -2),
        compute_rms_dist,
        normalize_type,
    )
