# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Union

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging

__all__ = ['AggregateTokenizer', 'TokenizerWrapper']


class DummyTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.vocab_size = len(vocab)

    # minimum compatibility
    # since all the monolingual tokenizers have a vocab
    # additional methods could be added here
    def get_vocab(self):
        return self.vocab


class AggregateTokenizer(TokenizerSpec):
    '''
    AggregateTokenizer, allowing one to combine multiple regular monolongual tokenizers into one tokenizer.
    The intuition is that we can use existing tokenizers "as is", without retraining, and associate each tokenizer with a language id
    during text processing (language id will be used to route the incoming text sample to the right tokenizer)
    as well as a token id range for detokenization (e.g. [0..127] for tokenizer A, [128..255] for tokenizer B) so
    that the orignal text could be reconstructed. Note that we assume that the incoming dict of langs / tokenizers
    is ordered, e.g. the first tokenizer will be assigned a lower interval of token ids
        Args:
        tokenizers: dict of tokenizers, keys are lang ids, values are actual tokenizers
    '''

    def __init__(self, tokenizers: Dict):

        self.tokenizers_dict = tokenizers
        self.vocabulary = []

        # the tokenizers should produce non-overlapping, ordered token ids
        # keys are language ids
        self.token_id_offset = {}

        # keys are tokenizer numbers
        self.token_id_offset_by_tokenizer_num = {}
        offset = 0
        i = 0
        for lang, tokenizer in self.tokenizers_dict.items():
            self.token_id_offset[lang] = offset
            self.token_id_offset_by_tokenizer_num[i] = offset
            offset += len(tokenizer.vocab)
            i += 1

        for tokenizer in self.tokenizers_dict.values():
            self.vocabulary.extend(tokenizer.vocab)

        self.vocab_size = len(self.vocabulary)
        logging.info(f'Aggregate vocab size: {self.vocab_size}')

        # for compatibility purposes only -- right now only the get_vocab method
        # is supported, returning the joint vocab across all tokenizers
        self.tokenizer = DummyTokenizer(self.vocabulary)

        # lookup tables to speed up token to text operations
        # if there are two tokenizers, [0,1], ['en', 'es'], each with 128 tokens, the aggregate tokenizer
        # token range will be [0,255]. The below method provides three look up tables:
        # one, to convert the incoming token id -- e.g. 200 into its real id (200-127 = 73)
        # second, to compute the tokenizer id that should process that token (1)
        # third, the compute the lang id for that token ('es')
        offset_token_ids_by_token_id, tokenizers_by_token_id, langs_by_token_id = self._calculate_offsets()

        self.offset_token_ids_by_token_id = offset_token_ids_by_token_id
        self.tokenizers_by_token_id = tokenizers_by_token_id
        self.langs_by_token_id = langs_by_token_id

    def _calculate_offsets(self):
        offsets = {}
        tokenizers = {}
        langs = {}
        cur_num = 0
        tot = len(self.tokenizers_dict)
        for id in range(len(self.vocabulary)):
            off_id = id - list(self.token_id_offset.values())[cur_num]
            if cur_num + 1 < tot:
                if id >= list(self.token_id_offset.values())[cur_num + 1]:
                    cur_num += 1
                    off_id = id - list(self.token_id_offset.values())[cur_num]
            offsets[id] = off_id
            tokenizers[id] = list(self.tokenizers_dict.values())[cur_num]
            langs[id] = list(self.tokenizers_dict.keys())[cur_num]

        return offsets, tokenizers, langs

    def text_to_tokens(self, text, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.text_to_tokens(text)

    def text_to_ids(self, text, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        token_ids = tokenizer.text_to_ids(text)
        token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids]

        return token_ids

    def tokens_to_text(self, tokens, lang_id):
        if isinstance(tokens, np.ndarray):
            tokens = tokens.tolist()

        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.decode_pieces(tokens)

    def ids_to_text(self, ids):
        if isinstance(ids, (np.ndarray, torch.Tensor)):
            ids = ids.tolist()

        tokens = []
        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            tokens.extend(tokenizer.ids_to_tokens([offset_id]))
        text = ''.join(tokens).replace('▁', ' ')

        return text

    def token_to_id(self, token, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]

    def ids_to_tokens(self, ids):
        tokens = []

        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            tokens.append(token)

        return tokens

    def ids_to_text_and_langs(self, ids):
        text_and_langs = []

        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            text = token.replace('▁', ' ')
            text = text.strip()  # strip for display purposes
            lang = self.langs_by_token_id[id]
            text_and_langs.append({'char': text, 'lang': lang})

        return text_and_langs

    def ids_to_words_and_langs(self, ids):
        words_and_langs = []

        word_ids = []  # tokens belonging to the current word
        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            if token.startswith('▁'):
                if len(word_ids) > 0:  # if this isn't the first word
                    word = self.ids_to_text(word_ids)
                    word = word.strip()  # strip for display purposes
                    lang = self.ids_to_lang(word_ids)
                    wl = {'word': word, 'lang': lang}
                    words_and_langs.append(wl)
                word_ids = []
            word_ids.append(id)

        if len(word_ids) > 0:  # the last tokens
            word = self.ids_to_text(word_ids)
            word = word.strip()  # strip for display purposes
            lang = self.ids_to_lang(word_ids)
            wl = {'word': word, 'lang': lang}
            words_and_langs.append(wl)

        return words_and_langs

    def ids_to_lang(self, ids):
        lang_cnts = {}

        for id in ids:
            lang = self.langs_by_token_id[id]
            lang_cnt = lang_cnts.get(lang)
            if lang_cnt is not None:
                lang_cnts[lang] = lang_cnt + 1
            else:
                lang_cnts[lang] = 1

        max_lang = ''
        max_lang_cnt = -1
        for lang, lang_cnt in lang_cnts.items():
            if lang_cnt > max_lang_cnt:
                max_lang = lang
                max_lang_cnt = lang_cnt

        return max_lang

    def tokens_to_ids(self, tokens: Union[str, List[str]], langs: Union[str, List[str]]) -> Union[int, List[int]]:
        if isinstance(tokens, str):
            tokens = [tokens]
        if isinstance(langs, str):
            langs = [langs]

        ids = []
        for i, token in enumerate(tokens):
            lang_id = langs[i]
            ids.append(self.token_to_id(token, lang_id))
        return ids

    def get_bos(self, lang_id: str) -> int:
        return self.tokenizers_dict[lang_id].bos + self.token_id_offset[lang_id]

    def get_eos(self, lang_id: str) -> int:
        return self.tokenizers_dict[lang_id].eos + self.token_id_offset[lang_id]

    @property
    def vocab(self):
        return self.vocabulary

    @property
    def langs(self):
        return list(self.tokenizers_dict.keys())


class TokenizerWrapper:
    """
    Provide a unified interface for NeMo Tokenizer, AggregateTokenizer, and (char) Parser.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        if isinstance(tokenizer, AggregateTokenizer):
            self._impl = self._call_agg_tokenizer
        elif isinstance(tokenizer, TokenizerSpec):
            self._impl = self._call_tokenizer
        else:
            self._impl = self._call_parser

    def __call__(self, text: str, lang: str | None = None):
        return self._impl(text, lang)

    def _call_agg_tokenizer(self, text: str, lang: str | None = None):
        assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer."
        return self._tokenizer.text_to_ids(text, lang)

    def _call_tokenizer(self, text: str, lang: str | None = None):
        return self._tokenizer.text_to_ids(text)

    def _call_parser(self, text: str, lang: str | None = None):
        return self._tokenizer(text)
