import argparse
import os

import pycountry
import yaml


def get_language_from_code(code: str) -> str:
    """Return the language name pycountry records for a language code.

    The length of ``code`` selects the pycountry lookup field: a 2-letter
    code is looked up by ``alpha_2`` and a 3-letter code by ``alpha_3``
    (e.g. ``"sw"`` or ``"nso"``).

    :param code: A 2- or 3-letter language code.
    :return: The ``name`` attribute of the matching pycountry language entry.
    :raises LookupError: If pycountry has no language with that code.
    """
    language = pycountry.languages.get(**{f"alpha_{len(code)}": code})
    # Depending on the pycountry version, an unknown code yields None instead
    # of an exception; report it clearly rather than failing with an
    # AttributeError on ``None.name``.
    if language is None:
        raise LookupError(f"Unknown language code: {code!r}")
    return language.name


def prompt_func(mode):
    """Return the Jinja ``doc_to_text`` template for a prompt variant.

    Every variant ends with the same block: a loop that prints each choice
    as a tab-indented lettered line (``A``-``D``) followed by an
    ``"Answer: "`` cue. The variants differ only in the header text that
    precedes that block.

    :param mode: One of ``"prompt_1"`` through ``"prompt_5"``.
    :return: The template string for that variant.
    :raises KeyError: If ``mode`` is not a known prompt name.
    """
    question_line = "Question: {{question}}\n\n"
    choices_loop = (
        "{% for i in range(choices['text']|length) %}"
        "\t{{ 'ABCD'[i] }}: {{ choices['text'][i] }}\n"
        "{% endfor %}\n"
    )
    answer_cue = "Answer: "

    # Text that precedes the shared choices loop, keyed by prompt name.
    headers = {
        "prompt_1": "You are a virtual assistant that answers multiple-choice "
        "questions with the correct option only.\n\n"
        + question_line
        + "Choices:\n\n",
        "prompt_2": "Choose the correct option that answers the question below:\n\n"
        + question_line
        + "Choices:\n\n",
        "prompt_3": "Answer the following multiple-choice question by picking "
        "'A', 'B', 'C', or 'D'.\n\n"
        + question_line
        + "Options:\n\n",
        "prompt_4": question_line + "Options:\n\n",
        "prompt_5": "Which of the following options answers this question: "
        "{{question}}\n\n",
    }

    header = headers[mode]
    return header + choices_loop + answer_cue


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    Each file is written to ``<output_dir>/<mode>/uhura-arc-easy_<lang>.yaml``
    and includes the ``uhura-arc-easy`` template. Files that already exist
    are skipped (unless ``overwrite`` is set) and reported together at the end.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt variant name passed to ``prompt_func``; also used as
        the output sub-directory name and the task-name suffix.
    :raises FileExistsError: If ``overwrite`` is False and one or more target
        files already exist. All remaining files are still written first.
    """
    # A fixed, sorted order keeps generation and error messages reproducible;
    # iteration order of a set of strings can change between interpreter runs.
    languages = ("am", "en", "ha", "nso", "sw", "yo", "zu")
    yaml_template = "uhura-arc-easy"

    # Resolve the prompt first so an unknown mode fails before anything is
    # created on disk.
    doc_to_text = prompt_func(mode)

    # Create the target directory once, outside the per-file error handling:
    # makedirs(exist_ok=True) still raises FileExistsError when the path
    # exists as a regular file, and that must not be misreported as an
    # already-existing yaml file.
    mode_dir = os.path.join(output_dir, mode)
    os.makedirs(mode_dir, exist_ok=True)

    err = []
    for lang in languages:
        file_name = f"uhura-arc-easy_{lang}.yaml"
        yaml_details = {
            "include": yaml_template,
            "task": f"uhura-arc-easy_{lang}_{mode}",
            # The nso subset uses a dataset config with an "_unmatched" suffix.
            "dataset_name": f"{lang}_multiple_choice{'_unmatched' if lang == 'nso' else ''}",
            "doc_to_text": doc_to_text,
        }
        # NOTE(review): nso and zu take few-shot examples from "train";
        # presumably the template's default few-shot split is unavailable
        # for them — confirm against the dataset.
        if lang in ("nso", "zu"):
            yaml_details["fewshot_split"] = "train"

        try:
            # Mode "x" fails with FileExistsError if the file already exists.
            with open(
                os.path.join(mode_dir, file_name),
                "w" if overwrite else "x",
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    yaml_details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Options:
      --overwrite / --no-overwrite  Overwrite existing files (the default) or
                                    refuse to replace them.
      --output-dir DIR              Directory to write yaml files to.
      --mode PROMPT [PROMPT ...]    Prompt variant(s) to generate; all by default.
    """
    parser = argparse.ArgumentParser()
    # Overwriting is the default. A lone ``store_true`` flag with
    # ``default=True`` can never turn it off, so a paired ``store_false``
    # flag on the same destination exposes the non-clobbering mode.
    # ``--overwrite`` stays accepted for compatibility.
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        action="store_true",
        help="Overwrite files if they already exist (default)",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    parser.set_defaults(overwrite=True)
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )

    PROMPT_CHOICES = ["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"]
    parser.add_argument(
        "--mode",
        nargs="*",
        default=PROMPT_CHOICES,
        choices=PROMPT_CHOICES,
        help="Prompt number(s)",
    )
    args = parser.parse_args()

    for mode in args.mode:
        gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=mode)


if __name__ == "__main__":
    # Generate files only when run as a script, not when imported.
    main()
