"""OpenAI's non-english basic text normalization module"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_basic.ipynb.

# %% auto 0
__all__ = [
    "ADDITIONAL_DIACRITICS",
    "remove_symbols_and_diacritics",
    "remove_symbols",
    "BasicTextNormalizer",
]

# %% ../nbs/00_basic.ipynb 4
# This code is taken from the OpenAI Whisper repository: https://github.com/openai/whisper/tree/main/whisper/normalizers
import re
import unicodedata

import regex

# from fastcore.foundation import add_docs


# Non-ASCII letters that "NFKD" normalization does not split into a base letter
# plus combining marks, mapped to ASCII spellings. Uppercase letters map to
# uppercase spellings so the case of the input is preserved (e.g. "Þ" -> "TH",
# consistent with "Æ" -> "AE" and "Œ" -> "OE").
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "TH",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    NFKD-decompose `s`, then rewrite it one character at a time:

    - characters that appear in `keep` are passed through unchanged;
    - letters listed in `ADDITIONAL_DIACRITICS` are replaced by their spelling;
    - nonspacing marks (Unicode category 'Mn', e.g. the accents NFKD splits
      off) are dropped;
    - any other mark, symbol or punctuation character (category starting
      with M, S or P) becomes a single space;
    - everything else is kept as is.
    """

    def _convert(char):
        # Guard clauses in the same priority order as the rules above.
        if char in keep:
            return char
        if char in ADDITIONAL_DIACRITICS:
            return ADDITIONAL_DIACRITICS[char]
        category = unicodedata.category(char)
        if category == "Mn":
            return ""
        if category[0] in "MSP":
            return " "
        return char

    decomposed = unicodedata.normalize("NFKD", s)
    return "".join(map(_convert, decomposed))


def remove_symbols(s: str):
    """
    Replace markers, symbols and punctuation with spaces while keeping diacritics.

    The input is NFKC-normalized first, so precomposed letters such as "é"
    survive intact; every remaining character whose Unicode category starts
    with M (mark), S (symbol) or P (punctuation) becomes a single space.
    """
    composed = unicodedata.normalize("NFKC", s)
    return "".join(
        [
            " " if unicodedata.category(char).startswith(("M", "S", "P")) else char
            for char in composed
        ]
    )

# %% ../nbs/00_basic.ipynb 5
class BasicTextNormalizer:
    """Whisper's basic text normalizer.

    Follows the text normalization/standardization approach described in
    Appendix C (p. 21) of the paper [Robust Speech Recognition via Large-Scale
    Weak Supervision](https://cdn.openai.com/papers/whisper.pdf). Calling an
    instance on a string:

        1. Lowercases the text.
        2. Removes every span that opens with `[` or `<` and ends at the first
           following `]` or `>` (e.g. `[music]`, `<unk>`), delimiters included.
        3. Removes every non-empty span from `(` to the next `)`, e.g. `(laughs)`.
        4. Replaces markers, symbols and punctuation characters (Unicode
           category starting with M, S or P) with a space. By default the text
           is NFKC-normalized first and diacritics are kept; with
           `remove_diacritics=True` it is NFKD-normalized instead, nonspacing
           marks (category Mn) are dropped and the letters listed in
           `ADDITIONAL_DIACRITICS` are transliterated.
        5. Lowercases the cleaned text again.
        6. Optionally (`split_letters=True`) separates extended grapheme
           clusters with spaces.
        7. Replaces every run of whitespace with a single space. Leading and
           trailing whitespace is collapsed, not stripped.

    Note: this normalizer is not recommended for non-English languages because
    replacing combining marks with spaces may remove vowel signs, as identified
    by [kavya in this tweet](https://twitter.com/kavya_manohar/status/1752574864618365059).
    """

    # Patterns compiled once and shared by all instances (steps 2, 3 and 7).
    _BRACKETS = re.compile(r"[<\[][^>\]]*[>\]]")
    _PARENTHESES = re.compile(r"\(([^)]+?)\)")
    _WHITESPACE = re.compile(r"\s+")

    def __init__(
        self,
        remove_diacritics: bool = False,
        split_letters: bool = False,
    ) -> None:
        """
        Args:
            remove_diacritics: if True, clean text with
                `remove_symbols_and_diacritics` (NFKD, diacritics dropped);
                otherwise with `remove_symbols` (NFKC, diacritics kept).
            split_letters: if True, split the cleaned text into Unicode extended
                grapheme clusters (the `regex` module's `\\X` pattern) and join
                them with spaces.
        """
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str) -> str:
        """Return the normalized form of `s` (see the class docstring)."""
        s = s.lower()
        s = self._BRACKETS.sub("", s)  # remove words between brackets
        s = self._PARENTHESES.sub("", s)  # remove words between parentheses
        # Lowercase again in case Unicode normalization produced uppercase characters.
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        # Replace any successive whitespace characters with a single space.
        return self._WHITESPACE.sub(" ", s)
