# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from:
# Link: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
# Link: https://github.com/huggingface/datasets/blob/master/metrics/squad/squad.py
import re
import string
from collections import Counter
from typing import Any, Callable, Union

from torch import Tensor, tensor

from torchmetrics.utilities import rank_zero_warn

# A single prediction: a dict of strings (`_squad_input_check` requires the keys "prediction_text" and "id").
SINGLE_PRED_TYPE = dict[str, str]
# Either one prediction or a list of predictions.
PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]]
# A single target record; "answers" maps to a dict of lists (`_squad_input_check` requires its "text" entry).
SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]]
# Either one target or a list of targets.
TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]]
# NOTE(review): not referenced in this file; presumably used by a class-based metric's update signature — confirm.
UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, dict[str, Union[list[str], list[int]]]]

# Example of a SQuAD-format target record; embedded verbatim in the KeyError messages of `_squad_input_check`.
SQuAD_FORMAT = {
    "answers": {"answer_start": [1], "text": ["This is a test text"]},
    "context": "This is a test context.",
    "id": "1",
    "question": "Is this a test?",
    "title": "train test",
}


def _normalize_text(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s: str) -> list[str]:
    """Normalize ``s`` with ``_normalize_text`` and split it on whitespace.

    A falsy input (e.g. the empty string) yields an empty token list without normalizing.
    """
    if s:
        return _normalize_text(s).split()
    return []


def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor:
    """Compute the token-level F1 score between a predicted and a target answer.

    Both strings are normalized and tokenized with ``_get_tokens``. F1 is the harmonic mean of the
    token-overlap precision and recall.

    Args:
        predicted_answer: The predicted answer string.
        target_answer: A reference answer string.

    Returns:
        A scalar float tensor in ``[0, 1]``. If either side has no tokens, the score is ``1.0`` when both are
        empty and ``0.0`` otherwise.
    """
    target_tokens = _get_tokens(target_answer)
    predicted_tokens = _get_tokens(predicted_answer)
    if not target_tokens or not predicted_tokens:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
        # Return a float tensor so every branch of this function has the same dtype.
        return tensor(float(target_tokens == predicted_tokens))
    # Multiset intersection: each shared token counts min(count in target, count in prediction) times.
    common = Counter(target_tokens) & Counter(predicted_tokens)
    num_same = tensor(sum(common.values()))
    if num_same == 0:
        return tensor(0.0)
    precision = 1.0 * num_same / tensor(len(predicted_tokens))
    recall = 1.0 * num_same / tensor(len(target_tokens))
    return (2 * precision * recall) / (precision + recall)


def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor:
    """Return ``tensor(1)`` if both strings are equal after ``_normalize_text``, else ``tensor(0)``."""
    normalized_prediction = _normalize_text(prediction)
    normalized_truth = _normalize_text(ground_truth)
    return tensor(1 if normalized_prediction == normalized_truth else 0)


def _metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str]
) -> Tensor:
    """Calculate maximum score for a predicted answer with all reference answers."""
    return max(metric_fn(prediction, truth) for truth in ground_truths)  # type: ignore[type-var]


def _squad_input_check(
    preds: PREDS_TYPE, targets: TARGETS_TYPE
) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]:
    """Validate SQuAD-style inputs and convert them into the layout consumed by ``_squad_update``.

    Args:
        preds: A single prediction dict or a list of them; each must contain ``prediction_text`` and ``id``.
        targets: A single target dict or a list of them; each must contain ``answers`` and ``id``. The
            ``answers`` dict must contain a ``text`` entry listing the reference answer strings.

    Returns:
        A tuple ``(preds_dict, targets_dict)``:

        - ``preds_dict`` maps each prediction ``id`` to its ``prediction_text``.
        - ``targets_dict`` is the one-element list
          ``[{"paragraphs": [{"qas": [{"answers": [{"text": ...}, ...], "id": ...}, ...]}]}]``.

    Raises:
        KeyError: If a prediction, a target, or a target's ``answers`` is missing a required key.
    """
    # Accept a single record as shorthand for a one-element batch.
    if isinstance(preds, dict):
        preds = [preds]

    if isinstance(targets, dict):
        targets = [targets]

    for pred in preds:
        if "prediction_text" not in pred or "id" not in pred:
            raise KeyError(
                "Expected keys in a single prediction are 'prediction_text' and 'id'. "
                "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string."
            )

    for target in targets:
        if "answers" not in target or "id" not in target:
            raise KeyError(
                "Expected keys in a single target are 'answers' and 'id'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

        answers: dict[str, Union[list[str], list[int]]] = target["answers"]  # type: ignore[assignment]
        if "text" not in answers:
            raise KeyError(
                "Expected keys in a 'answers' are 'text'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

    preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds}
    # One "qa" entry per target, with each reference answer string wrapped as {"text": ...}.
    qas = [
        {"answers": [{"text": txt} for txt in target["answers"]["text"]], "id": target["id"]}
        for target in targets
    ]
    targets_dict = [{"paragraphs": [{"qas": qas}]}]
    return preds_dict, targets_dict


def _squad_update(
    preds: dict[str, str],
    target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]],
) -> tuple[Tensor, Tensor, Tensor]:
    """Accumulate F1 and Exact Match over a collection of predictions and references.

    Every question found in ``target`` increases the example count. A question with no entry in
    ``preds`` triggers a warning and contributes nothing to either score. Each answered question
    contributes its best score over all of its reference answers.

    Args:
        preds: A dictionary mapping an `id` to the predicted `answer`.
        target:
            A list of dictionary mapping `paragraphs` to list of dictionary mapping `qas` to a list of dictionary
            containing `id` and list of all possible `answers`.

    Return:
        Tuple containing F1 score, Exact match score and total number of examples.

    Example:
        >>> from torchmetrics.functional.text.squad import _squad_update
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
        >>> preds_dict = {pred["id"]: pred["prediction_text"] for pred in preds}
        >>> targets_dict = [
        ...     dict(paragraphs=[dict(qas=[dict(answers=[
        ...         {"text": txt} for txt in tgt["answers"]["text"]], id=tgt["id"]) for tgt in target
        ...     ])])
        ... ]
        >>> _squad_update(preds_dict, targets_dict)
        (tensor(1.), tensor(1.), tensor(1))

    """
    f1 = tensor(0.0)
    exact_match = tensor(0.0)
    total = tensor(0)

    # Walk articles -> paragraphs -> questions lazily, in the same order as explicit nested loops.
    questions = (
        qa
        for article in target
        for paragraph in article["paragraphs"]
        for qa in paragraph["qas"]
    )
    for qa in questions:
        total += 1
        if qa["id"] in preds:
            answer = preds[qa["id"]]
            references = [candidate["text"] for candidate in qa["answers"]]
            exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, answer, references)
            f1 += _metric_max_over_ground_truths(_compute_f1_score, answer, references)
        else:
            rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.")

    return f1, exact_match, total


def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]:
    """Aggregate the F1 Score and Exact match for the batch.

    Return:
        Dictionary containing the F1 score, Exact match score for the batch.

    """
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {"exact_match": exact_match, "f1": f1}


def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]:
    """Calculate `SQuAD Metric`_ .

    Args:
        preds: A dictionary, or a list of dictionaries, each mapping `id` and `prediction_text` to their values.

            Example prediction:

            .. code-block:: python

                {"prediction_text": "TorchMetrics is awesome", "id": "123"}

        target: A dictionary, or a list of dictionaries, each containing `answers` and `id` in the SQuAD format.
            `answers` must be a dictionary whose `text` entry lists the reference answers.

            Example target:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test answer']},
                    'id': '1',
                }

            Reference SQuAD Format:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test text']},
                    'context': 'This is a test context.',
                    'id': '1',
                    'question': 'Is this a test?',
                    'title': 'train test'
                }


    Return:
        Dictionary containing the F1 score and Exact match score for the batch, as percentages.

    Example:
        >>> from torchmetrics.functional.text.squad import squad
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
        >>> squad(preds, target)
        {'exact_match': tensor(100.), 'f1': tensor(100.)}

    Raises:
        KeyError:
            If the required keys are missing in either predictions or targets.

    References:
        [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin
        Lopyrev, Percy Liang `SQuAD Metric`_ .

    """
    # Validate and reshape inputs, accumulate per-question scores, then convert the sums to percentages.
    preds_dict, target_dict = _squad_input_check(preds, target)
    f1, exact_match, total = _squad_update(preds_dict, target_dict)
    return _squad_compute(f1, exact_match, total)
