import argparse
import os

import yaml


class FunctionTag:
    """Minimal container that keeps a reference to one arbitrary object."""

    def __init__(self, value: object):
        # The payload is stored as-is; no copying or validation happens here.
        self.value = value


def prompt_func(mode, lang, lang_dict):
    """Return the ``doc_to_text`` prompt template for ``mode`` and ``lang``.

    Non-reverse modes ask for a ``lang`` -> English translation and reference
    the ``sentence_<lang>`` column; ``*_reverse`` modes ask for English ->
    ``lang`` and reference ``sentence_eng_Latn``. Column references are
    emitted as literal ``{{column}}`` placeholders (presumably filled in by
    the evaluation harness's template engine -- confirm downstream).

    :param mode: One of ``prompt_1``, ``prompt_2``, ``prompt_3``, optionally
        suffixed with ``_reverse``.
    :param lang: Language code; key into ``lang_dict`` and suffix of the
        source-sentence column name.
    :param lang_dict: Mapping from language code to human-readable name.
    :raises KeyError: If ``lang`` is not in ``lang_dict`` (checked first) or
        ``mode`` is not a known prompt name.
    """
    name = lang_dict[lang]
    # Doubled braces in an f-string yield literal "{{" / "}}".
    source_slot = f"{{{{sentence_{lang}}}}}"
    english_slot = "{{sentence_eng_Latn}}"

    # Shared fragments: each instruction preamble is combined with a
    # direction-specific tail.
    expert = "You are a translation expert. Translate the following "
    linguist = f"As a {name} and English linguist, translate the following "
    to_english = f"{name} sentences to English \n{name}: {source_slot}\nEnglish: "
    from_english = f"English sentences to {name} \nEnglish: {english_slot} \n{name}: "

    templates = {
        "prompt_1": f"{name}: {source_slot} \nEnglish: ",
        "prompt_1_reverse": f"English: {english_slot} \n{name}: ",
        "prompt_2": expert + to_english,
        "prompt_2_reverse": expert + from_english,
        "prompt_3": linguist + to_english,
        "prompt_3_reverse": linguist + from_english,
    }
    return templates[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str, reverse: bool) -> None:
    """
    Generate a yaml file for each language.

    One FLORES task file is written per language, for a single translation
    direction:

    * ``reverse`` false: ``<lang> -> eng_Latn``, written to
      ``<output_dir>/<mode>/african-english/flores_<lang>-eng_Latn.yaml``.
    * ``reverse`` true: ``eng_Latn -> <lang>``, written to
      ``<output_dir>/<mode>/english-african/flores_eng_Latn-<lang>.yaml``.

    Each file includes the shared ``flores`` template and sets the task name,
    dataset config name, target column and prompt text.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt variant (``prompt_1``, ``prompt_2`` or ``prompt_3``).
        The ``_reverse`` prompt is selected automatically when ``reverse`` is
        true.
    :param reverse: Generate English -> language tasks instead of
        language -> English.
    :raises FileExistsError: If ``overwrite`` is false and some files already
        exist. All other files are still written; the message lists the
        skipped file names.
    """
    err = []
    languages = {
        "ace_Latn": "Acehnese (Latin script)",
        "ace_Arab": "Acehnese (Arabic script)",
        "acq_Arab": "Ta’izzi-Adeni Arabic",
        "aeb_Arab": "Tunisian Arabic",
        "afr_Latn": "Afrikaans",
        "aka_Latn": "Akan",
        "amh_Ethi": "Amharic",
        "ary_Arab": "Moroccan Arabic",
        "arz_Arab": "Egyptian Arabic",
        "bam_Latn": "Bambara",
        "ban_Latn": "Balinese",
        "bem_Latn": "Bemba",
        "cjk_Latn": "Chokwe",
        "dik_Latn": "Southwestern Dinka",
        "dyu_Latn": "Dyula",
        "ewe_Latn": "Ewe",
        "fon_Latn": "Fon",
        "fra_Latn": "French",
        "fuv_Latn": "Nigerian Fulfulde",
        "hau_Latn": "Hausa",
        "ibo_Latn": "Igbo",
        "kab_Latn": "Kabyle",
        "kam_Latn": "Kamba",
        "knc_Arab": "Central Kanuri (Arabic script)",
        "knc_Latn": "Central Kanuri (Latin script)",
        "kbp_Latn": "Kabiyè",
        "kea_Latn": "Kabuverdianu",
        "kik_Latn": "Kikuyu",
        "kin_Latn": "Kinyarwanda",
        "kmb_Latn": "Kimbundu",
        "kon_Latn": "Kikongo",
        "lin_Latn": "Lingala",
        "lua_Latn": "Luba-Kasai",
        "lug_Latn": "Luganda",
        "luo_Latn": "Luo",
        "plt_Latn": "Plateau Malagasy",
        "mos_Latn": "Mossi",
        "nso_Latn": "Northern Sotho",
        "nus_Latn": "Nuer",
        "nya_Latn": "Nyanja",
        "gaz_Latn": "Oromo",
        "run_Latn": "Rundi",
        "sag_Latn": "Sango",
        "sna_Latn": "Shona",
        "som_Latn": "Somali",
        "sot_Latn": "Southern Sotho",
        "ssw_Latn": "Swati",
        "sun_Latn": "Sundanese",
        "swh_Latn": "Swahili",
        "tir_Ethi": "Tigrinya",
        "taq_Latn": "Tamasheq",
        "taq_Tfng": "Tamasheq (Tifinagh script)",
        "tsn_Latn": "Setswana",
        "tso_Latn": "Tsonga",
        "tum_Latn": "Tumbuka",
        "twi_Latn": "Twi",
        "tzm_Tfng": "Central Atlas Tamazight",
        "umb_Latn": "Umbundu",
        "wol_Latn": "Wolof",
        "xho_Latn": "Xhosa",
        "yor_Latn": "Yoruba",
        "zul_Latn": "Zulu",
    }

    # Direction-dependent settings are fixed for the whole call, so resolve
    # them once instead of duplicating the per-language body.
    if reverse:
        subdir = "english-african"
        prompt_mode = f"{mode}_reverse"
    else:
        subdir = "african-english"
        prompt_mode = mode
    target_dir = os.path.join(output_dir, mode, subdir)
    # Created once, outside the per-file try: a failure here is a real
    # directory problem, not an "existing yaml file" to report per language.
    os.makedirs(target_dir, exist_ok=True)
    # "x" refuses to clobber an existing file (raises FileExistsError).
    file_mode = "w" if overwrite else "x"
    yaml_template = "flores"

    for lang in languages:
        pair = f"eng_Latn-{lang}" if reverse else f"{lang}-eng_Latn"
        file_name = f"flores_{pair}.yaml"
        yaml_details = {
            "include": yaml_template,
            "task": f"flores_{pair}_{mode}",
            "dataset_name": pair,
            # The target column is always the output side of the translation.
            "doc_to_target": f"sentence_{lang}" if reverse else "sentence_eng_Latn",
            "doc_to_text": prompt_func(prompt_mode, lang, languages),
        }

        try:
            handle = open(os.path.join(target_dir, file_name), file_mode, encoding="utf8")
        except FileExistsError:
            # Keep going so every clash is reported in a single error below.
            err.append(file_name)
            continue
        with handle:
            handle.write("# Generated by utils.py\n")
            yaml.dump(
                yaml_details,
                handle,
                allow_unicode=True,
            )

    if len(err) > 0:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def str_to_bool(value):
    """Convert a command-line string such as ``"True"`` or ``"0"`` to a bool.

    Used as an argparse ``type=``. Plain argparse would keep the raw string,
    which never matches boolean ``choices``. Bools pass through unchanged.

    :param value: The raw argument (normally a ``str``).
    :returns: ``True`` or ``False``.
    :raises argparse.ArgumentTypeError: If the string is not a recognised
        boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main() -> None:
    """Parse CLI args and generate language-specific yaml files."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist (default)",
    )
    # Overwriting is on by default, so --overwrite alone can never turn it
    # off; this flag writes False to the same destination.
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_1",
        choices=["prompt_1", "prompt_2", "prompt_3"],
        help="Prompt number",
    )
    parser.add_argument(
        "--reverse",
        default=True,
        # Without a type converter "--reverse False" stays the string "False",
        # which is not in choices and was always rejected.
        type=str_to_bool,
        choices=[True, False],
        help="Reverse the translation direction",
    )
    args = parser.parse_args()

    gen_lang_yamls(
        output_dir=args.output_dir,
        overwrite=args.overwrite,
        mode=args.mode,
        reverse=args.reverse,
    )


if __name__ == "__main__":
    # Only run the generator when executed as a script, not on import.
    main()
