# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from collections.abc import Sequence
from typing import Callable, Optional, Union

import torch
from torch import Tensor, tensor


def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
    """Tally every n-gram of order 1 through ``n_gram`` found in a token sequence.

    Args:
        ngram_input_list: Tokens of a translated text or of a reference text
        n_gram: Highest n-gram order to count; all orders from 1 up to this value are included

    Return:
        A ``collections.Counter`` keyed by n-gram tuples, mapping each n-gram to its number of occurrences

    """
    total_tokens = len(ngram_input_list)
    counts: Counter = Counter()

    for order in range(1, n_gram + 1):
        # Slide a window of width ``order`` over the tokens; when the sequence is
        # shorter than ``order`` the range is empty and nothing is counted.
        counts.update(
            tuple(ngram_input_list[start : start + order]) for start in range(total_tokens - order + 1)
        )

    return counts


def _tokenize_fn(sentence: str) -> Sequence[str]:
    """Break a sentence into words using whitespace as the delimiter.

    Args:
        sentence: Text whose words are separated by white space.

    Return:
        The words of ``sentence`` in their original order; an empty list if there are none

    """
    # ``split()`` without arguments collapses runs of whitespace and ignores
    # leading/trailing whitespace, so no empty tokens are produced.
    words = sentence.split()
    return words


def _bleu_score_update(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    numerator: Tensor,
    denominator: Tensor,
    preds_len: Tensor,
    target_len: Tensor,
    n_gram: int = 4,
    tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
) -> tuple[Tensor, Tensor]:
    """Accumulate the statistics needed for the BLEU score over a batch of sentences.

    ``numerator``, ``denominator``, ``preds_len`` and ``target_len`` are updated in place.

    Args:
        preds: An iterable of machine translated corpus
        target: An iterable of iterables of reference corpus
        numerator: Numerator of precision score (true positives), one entry per n-gram order
        denominator: Denominator of precision score (true positives + false positives), one entry per n-gram order
        preds_len: Running count of words in the candidate predictions
        target_len: Running count of words in the reference translations; for each prediction the
            reference whose length is closest to the prediction's length is used
        n_gram: Highest n-gram order to count
        tokenizer: A function that turns sentence into list of words

    Return:
        The updated ``preds_len`` and ``target_len`` tensors

    """
    # Empty strings bypass the tokenizer and become empty token lists.
    tokenized_refs: Sequence[Sequence[Sequence[str]]] = [
        [tokenizer(ref) if ref else [] for ref in refs] for refs in target
    ]
    tokenized_preds: Sequence[Sequence[str]] = [tokenizer(hyp) if hyp else [] for hyp in preds]

    for hyp_tokens, ref_group in zip(tokenized_preds, tokenized_refs):
        hyp_size = len(hyp_tokens)
        preds_len += hyp_size

        # Effective reference length: the reference closest in length to the
        # prediction. ``min`` returns the first minimum, so ties go to the
        # earliest reference.
        target_len += min((len(ref) for ref in ref_group), key=lambda ref_size: abs(hyp_size - ref_size))

        hyp_counts: Counter = _count_ngram(hyp_tokens, n_gram)

        # Counter union keeps, for every n-gram, its maximum count over all references.
        ref_counts: Counter = Counter()
        for ref_tokens in ref_group:
            ref_counts |= _count_ngram(ref_tokens, n_gram)

        # Counter intersection clips each predicted n-gram count by the reference maximum.
        for ngram, clipped in (hyp_counts & ref_counts).items():
            numerator[len(ngram) - 1] += clipped

        for ngram, occurrences in hyp_counts.items():
            denominator[len(ngram) - 1] += occurrences

    return preds_len, target_len


def _bleu_score_compute(
    preds_len: Tensor,
    target_len: Tensor,
    numerator: Tensor,
    denominator: Tensor,
    n_gram: int,
    weights: Sequence[float],
    smooth: bool,
) -> Tensor:
    """Turn accumulated n-gram statistics into the final BLEU score.

    Args:
        preds_len: count of words in a candidate translation
        target_len: count of words in a reference translation
        numerator: Numerator of precision score (true positives)
        denominator: Denominator of precision score (true positives + false positives)
        n_gram: Number of n-gram orders held in ``numerator`` and ``denominator``
        weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
        smooth: Whether to apply smoothing

    Return:
        Scalar tensor with the BLEU score, on the same device as ``numerator``

    """
    device = numerator.device

    # If any n-gram order has no match at all the score is zero
    # (this check runs before smoothing is considered).
    if min(numerator) == 0.0:
        return tensor(0.0, device=device)

    if not smooth:
        precision_scores = numerator / denominator
    else:
        # Add-one smoothing on every order ...
        ones = torch.ones(n_gram, device=device)
        precision_scores = (numerator + ones) / (denominator + ones)
        # ... except unigrams, which keep their raw precision.
        precision_scores[0] = numerator[0] / denominator[0]

    # Weighted geometric mean of the precisions, evaluated in log space.
    weighted_log_precisions = tensor(weights, device=device) * precision_scores.log()
    geometric_mean = weighted_log_precisions.sum().exp()

    # Only predictions that are not longer than the reference are penalised.
    if preds_len > target_len:
        brevity_penalty = tensor(1.0, device=device)
    else:
        brevity_penalty = torch.exp(1 - (target_len / preds_len))

    return brevity_penalty * geometric_mean


def bleu_score(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    n_gram: int = 4,
    smooth: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor:
    """Calculate `BLEU score`_ of machine translated text with one or more references.

    Args:
        preds: An iterable of machine translated corpus; a single string is treated as a one-sentence corpus
        target: An iterable of iterables of reference corpus; an entry given as a single string is
            treated as a one-reference list
        n_gram: Maximum n-gram order to score; precisions are computed for orders 1 through ``n_gram``
        smooth: Whether to apply smoothing - see [2]
        weights:
            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
            If not provided, uniform weights are used.

    Return:
        Tensor with BLEU Score

    Raises:
        ValueError: If ``preds`` and ``target`` corpus have different lengths.
        ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.

    Example:
        >>> from torchmetrics.functional.text import bleu_score
        >>> preds = ['the cat is on the mat']
        >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
        >>> bleu_score(preds, target)
        tensor(0.7598)

    References:
        [1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
        Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_

        [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
        and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_

    """
    # Normalise bare strings into the list-based corpus layout.
    preds_ = preds if not isinstance(preds, str) else [preds]
    target_ = [tgt if not isinstance(tgt, str) else [tgt] for tgt in target]

    if len(preds_) != len(target_):
        raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")

    if weights is None:
        # Uniform weighting across all n-gram orders.
        weights = [1.0 / n_gram] * n_gram
    elif len(weights) != n_gram:
        raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")

    # Fresh accumulators: per-order clipped matches and totals, plus corpus lengths.
    numerator = torch.zeros(n_gram)
    denominator = torch.zeros(n_gram)

    preds_len, target_len = _bleu_score_update(
        preds=preds_,
        target=target_,
        numerator=numerator,
        denominator=denominator,
        preds_len=tensor(0.0),
        target_len=tensor(0.0),
        n_gram=n_gram,
        tokenizer=_tokenize_fn,
    )

    return _bleu_score_compute(
        preds_len=preds_len,
        target_len=target_len,
        numerator=numerator,
        denominator=denominator,
        n_gram=n_gram,
        weights=weights,
        smooth=smooth,
    )
