import argparse
import os

import yaml


class FunctionTag:
    """Minimal container that stores a single ``value`` attribute.

    NOTE(review): not referenced anywhere else in this file; presumably kept
    for tagging values during custom YAML serialization -- confirm against
    callers before removing.
    """

    def __init__(self, value):
        # Keep the wrapped object as-is; no validation or copying is done.
        self.value = value


def prompt_func(mode, lang, lang_dict):
    """
    Build the ``doc_to_text`` prompt template for one language and prompt mode.

    :param mode: Prompt key: ``prompt_1``, ``prompt_2`` or ``prompt_3``,
        optionally suffixed with ``_reverse`` for English -> language prompts.
    :param lang: Language code used to look up the display name and to build
        the ``<lang>_text`` column placeholder.
    :param lang_dict: Mapping from language code to display name.
    :return: The prompt string; column placeholders are emitted as literal
        ``{{column}}`` markers.
    :raises KeyError: If ``lang`` is not in ``lang_dict`` or ``mode`` is not a
        known prompt key.
    """
    language = lang_dict[lang]

    # Literal double-brace placeholders that end up verbatim in the prompt.
    source_slot = "{{" + f"{lang}_text" + "}}"
    english_slot = "{{eng_source_text}}"

    # Shared fragments: the source line for language -> English prompts and
    # the complete tail for English -> language prompts.
    source_line = f"{language} sentence: {source_slot}"
    reverse_tail = f"English sentence: {english_slot} \n{language} sentence: "

    templates = {
        "prompt_1": f"{source_line} \nEnglish sentence: ",
        "prompt_1_reverse": reverse_tail,
        "prompt_2": (
            f"You are a translation expert. Translate the following {language} sentences to English \n"
            f"{source_line}\nEnglish sentence: "
        ),
        "prompt_2_reverse": (
            f"You are a translation expert. Translate the following English sentences to {language} \n"
            f"{reverse_tail}"
        ),
        "prompt_3": (
            f"As a {language} and English linguist, translate the following {language} sentences "
            f"to English. \n{source_line}\nEnglish sentence: "
        ),
        "prompt_3_reverse": (
            f"As a {language} and English linguist, translate the following English sentences to "
            f"{language}. \n{reverse_tail}"
        ),
    }
    return templates[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str, reverse: bool) -> None:
    """
    Generate one yaml task file per non-English language.

    Files are written to ``<output_dir>/<mode>/`` and named
    ``salt_<lang>-eng.yaml`` (language -> English) or ``salt_eng-<lang>.yaml``
    (English -> language, when ``reverse`` is true).

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt key (``prompt_1``, ``prompt_2`` or ``prompt_3``); it
        selects the prompt template and names the output sub-directory.
    :param reverse: If true, generate English -> language tasks; otherwise
        language -> English tasks.
    :raises FileExistsError: If ``overwrite`` is false and one or more target
        files already exist. All remaining files are still written first and
        every skipped file is listed in the message.
    :raises KeyError: If ``mode`` is not a known prompt key.
    """
    languages = {
        "eng": "English",
        "lug": "Luganda",
        "ach": "Acholi",
        "lgg": "Lugbara",
        "teo": "Ateso",
        "nyn": "Runyankole",
        "swa": "Swahili",
        "ibo": "Igbo",
    }

    # Reverse tasks use the "<mode>_reverse" prompt templates.
    prompt_mode = f"{mode}_reverse" if reverse else mode

    # Build every task definition first so that an invalid mode fails before
    # anything is created on disk.
    tasks = []
    for lang in languages:
        if lang == "eng":
            # English is the pivot language, never a task on its own.
            continue
        if reverse:
            pair = f"eng-{lang}"
            target_column = f"{lang}_text"
        else:
            pair = f"{lang}-eng"
            target_column = "eng_target_text"
        yaml_details = {
            "include": "salt",
            "task": f"salt_{pair}_{mode}",
            "dataset_name": "text-all",
            "doc_to_target": target_column,
            "doc_to_text": prompt_func(prompt_mode, lang, languages),
        }
        tasks.append((f"salt_{pair}.yaml", yaml_details))

    # os.path.join keeps an empty output_dir relative instead of resolving
    # "/<mode>" against the filesystem root.
    target_dir = os.path.join(output_dir, mode)
    os.makedirs(target_dir, exist_ok=True)

    # Mode "x" makes open() raise FileExistsError instead of clobbering.
    open_mode = "w" if overwrite else "x"
    err = []
    for file_name, yaml_details in tasks:
        try:
            with open(
                os.path.join(target_dir, file_name), open_mode, encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(yaml_details, f, allow_unicode=True)
        except FileExistsError:
            # Collect and keep going so one report lists every existing file.
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def _str_to_bool(value: str) -> bool:
    """Convert a ``True``/``False`` command-line string (any case) to a bool."""
    lowered = value.strip().lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Defaults are unchanged (overwrite on, reverse direction on). Both can now
    be switched off from the command line via ``--no-overwrite`` and
    ``--reverse False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist (default)",
    )
    # With default=True a lone store_true flag can never disable overwriting,
    # so provide the explicit opt-out sharing the same destination.
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_1",
        choices=["prompt_1", "prompt_2", "prompt_3"],
        help="Prompt number",
    )
    # argparse hands over raw strings; without a type converter neither
    # "True" nor "False" matches the boolean choices and every explicit
    # value is rejected.
    parser.add_argument(
        "--reverse",
        default=True,
        type=_str_to_bool,
        choices=[True, False],
        help="Reverse the translation direction (True: English -> language)",
    )
    args = parser.parse_args()

    gen_lang_yamls(
        output_dir=args.output_dir,
        overwrite=args.overwrite,
        mode=args.mode,
        reverse=args.reverse,
    )


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
