"""Parallel beam search module for online simulation."""

import logging
from pathlib import Path
from typing import List

import torch
import yaml

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect


class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of entire utterance and
    extracts block by block from it as it shoud be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Args:
            asr_config (str): The config file for asr training

        """
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with train_config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                if "block_size" in args["encoder_conf"].keys():
                    self.block_size = args["encoder_conf"]["block_size"]
                if "hop_size" in args["encoder_conf"].keys():
                    self.hop_size = args["encoder_conf"]["hop_size"]
                if "look_ahead" in args["encoder_conf"].keys():
                    self.look_ahead = args["encoder_conf"]["look_ahead"]
            elif "config" in args.keys():
                config = args["config"]
                if config is None:
                    logging.info(
                        "Cannot find config file for streaming decoding: "
                        + "apply batch beam search instead."
                    )
                    return
        if (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        ) and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                enc_args = args["encoder_conf"]
            if enc_args and "block_size" in enc_args:
                self.block_size = enc_args["block_size"]
            if enc_args and "hop_size" in enc_args:
                self.hop_size = enc_args["hop_size"]
            if enc_args and "look_ahead" in enc_args:
                self.look_ahead = enc_args["look_ahead"]

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder
        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder
        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder
        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        self.conservative = True  # always true

        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is a implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypothesis

        Returns:
            Hypothesis: The extended hypothesis

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
