import argparse
import os

import yaml


# Languages that are part of PAWS-X.
# These correspond to dataset names (subsets) on HuggingFace.
# This script generates one yaml file per language.

LANGUAGES = {
    # Each entry maps a language code to the words spliced into its prompt:
    #   QUESTION_WORD - tag question placed after sentence1 ("..., right? ...")
    #   YES / NO      - answer words of the two doc_to_choice options
    #                   (the YES option is listed first, the NO option second)
    "de": {  # German
        "QUESTION_WORD": "richtig",
        "YES": "Ja",
        "NO": "Nein",
    },
    "en": {  # English
        "QUESTION_WORD": "right",
        "YES": "Yes",
        "NO": "No",
    },
    "es": {  # Spanish
        "QUESTION_WORD": "verdad",
        "YES": "Sí",
        "NO": "No",
    },
    "fr": {  # French
        "QUESTION_WORD": "n'est-ce pas",
        "YES": "Oui",
        "NO": "Non",  # French negative answer (not the English/Spanish "No")
    },
    "ja": {  # Japanese
        "QUESTION_WORD": "ですね",
        "YES": "はい",
        "NO": "いいえ",
    },
    "ko": {  # Korean
        "QUESTION_WORD": "맞죠",
        "YES": "예",
        "NO": "아니요",
    },
    "zh": {  # Chinese
        "QUESTION_WORD": "对吧",
        "YES": "是",
        "NO": "不是",
    },
}


def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language.

    One ``paws_<lang>.yaml`` file is written per entry of ``LANGUAGES``. When
    ``overwrite`` is False, files that already exist are left untouched. The
    remaining files are still written, and the skipped names are reported
    together at the end.

    :param output_dir: The directory to output the files to. It is created
        (including parents) if it does not exist; an empty string means the
        current directory.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is False and at least one target
        file already exists. The message lists every skipped file name.
    """
    if output_dir:
        # Without this, open() fails with FileNotFoundError for a new directory.
        # (os.makedirs("") itself raises, hence the guard.)
        os.makedirs(output_dir, exist_ok=True)

    # "x" (exclusive creation) raises FileExistsError instead of truncating.
    mode = "w" if overwrite else "x"

    err = []
    for lang, words in LANGUAGES.items():
        file_name = f"paws_{lang}.yaml"
        question_word = words["QUESTION_WORD"]
        yes = words["YES"]
        no = words["NO"]
        # Template expression (presumably rendered as Jinja by the consuming
        # harness) that yields two choices, the YES option first, then NO:
        #   sentence1 + ", <QUESTION_WORD>? <ANSWER>, " + sentence2
        # The words are inserted unescaped into double-quoted string literals,
        # so they must not themselves contain a double quote.
        doc_to_choice = (
            "{{["
            f'sentence1+", {question_word}? {yes}, "+sentence2, '
            f'sentence1+", {question_word}? {no}, "+sentence2'
            "]}}"
        )
        config = {
            "include": "pawsx_template_yaml",
            "dataset_name": lang,
            "task": f"paws_{lang}",
            "doc_to_text": "",
            "doc_to_choice": doc_to_choice,
        }
        # os.path.join avoids writing to "/" when output_dir is "" and avoids
        # a doubled separator when output_dir ends with one.
        path = os.path.join(output_dir, file_name)
        try:
            with open(path, mode, encoding="utf8") as f:
                f.write("# Generated by utils.py\n")
                # allow_unicode writes non-ASCII answer words as-is, not escaped.
                yaml.dump(config, f, allow_unicode=True)
        except FileExistsError:
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """
    Command-line entry point.

    Reads ``--overwrite`` and ``--output-dir`` from the command line and hands
    them to ``gen_lang_yamls``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    options = cli.parse_args()

    gen_lang_yamls(overwrite=options.overwrite, output_dir=options.output_dir)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
