import argparse
import os

import yaml


def prompt_func(mode, lang):
    """
    Return the Jinja ``doc_to_text`` template for the prompt variant ``mode``.

    :param mode: One of ``"prompt_1"`` .. ``"prompt_5"``.
    :param lang: Language name; accepted for interface compatibility but not
        used — every language shares the same (English) template.
    :return: The template string with ``{{...}}`` placeholders for the passage,
        question and the four answer options.
    :raises KeyError: If ``mode`` is not a known prompt variant.
    """
    del lang  # templates are language-independent

    templates = dict(
        prompt_1=(
            "P: {{flores_passage}}\n"
            "Q: {{question.strip()}}\n"
            "A: {{mc_answer1}}\n"
            "B: {{mc_answer2}}\n"
            "C: {{mc_answer3}}\n"
            "D: {{mc_answer4}}\n"
            "Please choose the correct answer from the options above:"
        ),
        prompt_2=(
            "Passage: {{flores_passage}}\n"
            "Question: {{question.strip()}}\n"
            "1: {{mc_answer1}}\n"
            "2: {{mc_answer2}}\n"
            "3: {{mc_answer3}}\n"
            "4: {{mc_answer4}}\n"
            "Please select the correct answer from the given choices:"
        ),
        prompt_3=(
            "Context: {{flores_passage}}\n"
            "Query: {{question.strip()}}\n"
            "Option A: {{mc_answer1}}\n"
            "Option B: {{mc_answer2}}\n"
            "Option C: {{mc_answer3}}\n"
            "Option D: {{mc_answer4}}\n"
            "Please indicate the correct option from the list above:"
        ),
        prompt_4=(
            "{{flores_passage}}\n"
            "Based on the above passage, answer the following question:\n"
            "{{question.strip()}}\n"
            "Choices:\n"
            "A) {{mc_answer1}}\n"
            "B) {{mc_answer2}}\n"
            "C) {{mc_answer3}}\n"
            "D) {{mc_answer4}}\n"
            "Please provide the correct answer from the choices given:"
        ),
        prompt_5=(
            "Read the passage: {{flores_passage}}\n"
            "Then answer the question: {{question.strip()}}\n"
            "Options:\n"
            "A. {{mc_answer1}}\n"
            "B. {{mc_answer2}}\n"
            "C. {{mc_answer3}}\n"
            "D. {{mc_answer4}}\n"
            "Please choose the correct option from the above list:"
        ),
    )
    selected = templates[mode]
    return selected


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One file named ``belebele_<lang>.yaml`` is written per language into
    ``<output_dir>/<mode>/``. Each file includes the ``belebele`` template and
    sets ``task`` (``belebele_<lang>_<mode>``), ``dataset_name`` and the
    ``doc_to_text`` prompt returned by ``prompt_func`` for ``mode``.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt variant key passed to ``prompt_func`` (e.g.
        ``"prompt_5"``); also the task-name suffix and the output
        sub-directory name.
    :raises KeyError: If ``prompt_func`` does not know ``mode``; raised before
        anything is written to disk.
    :raises FileExistsError: If ``<output_dir>/<mode>`` exists but is not a
        directory, or if ``overwrite`` is false and some target files already
        exist (in that case every other file is still written and the error
        lists the skipped ones).
    """
    languages = {
        "afr": "Afrikaans",
        "amh": "Amharic",
        "ary": "Moroccan Arabic",
        "arz": "Egyptian Arabic",
        "bam": "Bambara",
        "eng": "English",
        "fra": "French",
        "hau": "Hausa",
        "ibo": "Igbo",
        "lin": "Lingala",
        "por": "Portuguese",
        "sna": "Shona",
        "swa": "Swahili",
        "tir": "Tigrinya",
        "tso": "Tsonga",
        "tsn": "Tswana",
        "wol": "Wolof",
        "xho": "Xhosa",
        "yor": "Yoruba",
        "zul": "Zulu",
        "ssw": "Swati",
        "sot": "Southern Sotho",
        "som": "Somali",
        "plt": "Plateau Malagasy",
        "nya": "Nyanja",
        "luo": "Luo",
        "lug": "Luganda",
        "kin": "Kinyarwanda",
        "kea": "Kabuverdianu",
        "gaz": "Oromo",
        "fuv": "Nigerian Fulfulde",
    }

    # Maps the short code used in task/file names to the dataset config name.
    # Note the short code and dataset code do not always agree
    # (e.g. "swa" -> "swh_Latn").
    lang_2_dataset_lang_code = {
        "afr": "afr_Latn",
        "amh": "amh_Ethi",
        "ary": "ary_Arab",
        "arz": "arz_Arab",
        "bam": "bam_Latn",
        "eng": "eng_Latn",
        "fra": "fra_Latn",
        "hau": "hau_Latn",
        "ibo": "ibo_Latn",
        "lin": "lin_Latn",
        "por": "por_Latn",
        "sna": "sna_Latn",
        "swa": "swh_Latn",
        "tir": "tir_Ethi",
        "tso": "tso_Latn",
        "tsn": "tsn_Latn",
        "wol": "wol_Latn",
        "xho": "xho_Latn",
        "yor": "yor_Latn",
        "zul": "zul_Latn",
        "ssw": "ssw_Latn",
        "sot": "sot_Latn",
        "som": "som_Latn",
        "plt": "plt_Latn",
        "nya": "nya_Latn",
        "luo": "luo_Latn",
        "lug": "lug_Latn",
        "kin": "kin_Latn",
        "kea": "kea_Latn",
        "gaz": "gaz_Latn",
        "fuv": "fuv_Latn",
    }

    yaml_template = "belebele"

    # Build every config up front so an unknown ``mode`` (KeyError from
    # prompt_func) fails before any directory or file is created.
    configs = {
        f"belebele_{lang}.yaml": {
            "include": yaml_template,
            "task": f"belebele_{lang}_{mode}",
            "dataset_name": lang_2_dataset_lang_code[lang],
            "doc_to_text": prompt_func(mode, lang_name),
        }
        for lang, lang_name in languages.items()
    }

    # Create the output directory once and outside the per-file handler: if the
    # path exists but is not a directory, os.makedirs raises FileExistsError,
    # which must not be misreported as "files already exist".
    target_dir = os.path.join(output_dir, mode)
    os.makedirs(target_dir, exist_ok=True)

    # "x" (exclusive create) raises FileExistsError for an existing file.
    open_mode = "w" if overwrite else "x"
    err = []
    for file_name, yaml_details in configs.items():
        try:
            f = open(os.path.join(target_dir, file_name), open_mode, encoding="utf8")
        except FileExistsError:
            err.append(file_name)
            continue
        with f:
            f.write("# Generated by utils.py\n")
            yaml.dump(
                yaml_details,
                f,
                allow_unicode=True,
            )

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """
    Parse CLI args and generate language-specific yaml files.

    Existing files are overwritten by default; ``--overwrite`` is kept for
    backward compatibility, and ``--no-overwrite`` makes generation fail with
    ``FileExistsError`` (listing the skipped files) instead of replacing them.
    """
    parser = argparse.ArgumentParser()
    overwrite_group = parser.add_mutually_exclusive_group()
    # ``store_true`` with ``default=True`` alone can never yield False, which
    # left the non-overwriting path of gen_lang_yamls unreachable; the paired
    # ``--no-overwrite`` (same dest) makes it selectable.
    overwrite_group.add_argument(
        "--overwrite",
        dest="overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist (default)",
    )
    overwrite_group.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Do not overwrite existing files; report them as an error instead",
    )
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_5",
        choices=["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"],
        help="Prompt number",
    )
    args = parser.parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
