import argparse
import os

import yaml


class FunctionTag:
    """Minimal wrapper object that carries a single ``value`` attribute.

    NOTE(review): not referenced by the code in this module as shown;
    presumably used (or formerly used) to mark values for special YAML
    handling elsewhere — confirm before removing.
    """

    def __init__(self, value):
        # The payload is kept by reference, without copying or validation.
        self.value = value


def prompt_func(mode, lang):
    """Return the prompt template identified by ``mode``, specialised to ``lang``.

    The result is used as the ``doc_to_text`` value of a generated task. It
    contains a literal ``{{text}}`` placeholder (double braces in a plain
    string, not an f-string field), left for whatever renders that template.
    ``lang`` is interpolated only into ``prompt_3`` and ``prompt_4``.

    :param mode: Prompt key, one of ``"prompt_1"`` .. ``"prompt_5"``.
    :param lang: Human-readable language name, e.g. ``"Hausa"``.
    :raises KeyError: If ``mode`` is not a known prompt key.
    """
    # Output-format instructions shared verbatim by prompts 2-5. prompt_1 uses
    # a single "$" separator and different wording, so it is spelled out whole.
    format_rules = (
        "Ensure the output strictly follows the format: label: entity $$ label: entity, with each unique "
        "entity on a separate label line, avoiding grouped entities (e.g., avoid LOC: entity, entity) or "
        "irrelevant entries like none. Return only the output"
    )
    # Trailing input slot; a plain string, so "{{text}}" stays literal.
    text_slot = " \n\nText: {{text}}"
    label_set = "PERSON, LOCATION, ORGANIZATION, and DATE. "

    templates = {
        "prompt_1": (
            "Named entities refers to names of location, organisation and personal name. \n For example, "
            "'David is an employee of Amazon and he is visiting New York next week to see Esther' will be \n"
            "PERSON: David $ ORGANIZATION: Amazon $ LOCATION: New York $ PERSON: Esther \n\n"
            "Ensure the output strictly follows the format: label: entity $ label: entity, with each unique "
            "entity on a separate label line, avoiding grouped entities (e.g., avoid LOC: entity, entity) or "
            "irrelevant entries like none. \n\nText: {{text}} \n"
            "Return only the output"
        ),
        "prompt_2": (
            "You are working as a named entity recognition expert and your task is to label a given text "
            "with named entity labels. Your task is to identify and label any named entities present in the "
            "text. The named entity labels that you will be using are PER (person), LOC (location), "
            "ORG (organization) and DATE (date). Label multi-word entities as a single named entity. "
            "For words which are not part of any named entity, do not return any value for it. \n"
            + format_rules
            + text_slot
        ),
        "prompt_3": (
            f"You are a Named Entity Recognition expert in {lang} language. \nExtract all named entities from "
            f"the following {lang} text and categorize them into PERSON, LOCATION, ORGANIZATION, or DATE. "
            + format_rules
            + text_slot
        ),
        "prompt_4": (
            f"As a {lang} linguist, label all named entities in the {lang} text below with the categories: "
            + label_set
            + format_rules
            + "."
            + text_slot
        ),
        # Note: prompt_5 has two spaces between "output." and the newlines.
        "prompt_5": (
            "Provide a concise list of named entities in the text below. Use the following labels: "
            + label_set
            + format_rules
            + ". "
            + text_slot
        ),
    }
    return templates[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One ``masakhaner_<code>.yaml`` file per language is written into
    ``<output_dir>/<mode>/``. Each file includes the ``masakhaner`` template
    and sets the task name, the dataset name (the language code) and the
    prompt text returned by ``prompt_func(mode, <language name>)``.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt key passed to ``prompt_func`` (e.g. ``"prompt_1"``);
        also used as the output sub-directory and the task-name suffix.
    :raises FileExistsError: If ``overwrite`` is False and some target files
        already exist. The remaining files are still written; the error lists
        the ones that were skipped. Also raised by ``os.makedirs`` if
        ``<output_dir>/<mode>`` exists but is not a directory.
    """
    err = []
    languages = {
        "am": "Amharic",
        "bm": "Bambara",
        "bbj": "Ghomala",
        "ee": "Ewe",
        "ha": "Hausa",
        "ig": "Igbo",
        "rw": "Kinyarwanda",
        "lg": "Luganda",
        "luo": "Luo",
        "mos": "Mossi",
        "ny": "Chichewa",
        "pcm": "Nigerian Pidgin",
        "sn": "chiShona",
        "sw": "Kiswahili",
        "tn": "Setswana",
        "tw": "Twi",
        "wo": "Wolof",
        "xh": "isiXhosa",
        "yo": "Yoruba",
        "zu": "isiZulu",
    }

    # The target directory is the same for every language, so create it once.
    # Keeping this out of the per-file try/except stops a path collision
    # (e.g. a regular file named <mode>) being misreported as existing yamls.
    # os.path.join also avoids producing "/<mode>" when output_dir is "".
    mode_dir = os.path.join(output_dir, mode)
    os.makedirs(mode_dir, exist_ok=True)

    for lang, lang_name in languages.items():
        file_name = f"masakhaner_{lang}.yaml"
        yaml_details = {
            "include": "masakhaner",
            "task": f"masakhaner_{lang}_{mode}",
            "dataset_name": lang,
            "doc_to_text": prompt_func(mode, lang_name),
        }
        # "x" (exclusive create) raises FileExistsError instead of clobbering
        # an existing file; collect those and report them all at the end.
        open_mode = "w" if overwrite else "x"
        try:
            with open(
                os.path.join(mode_dir, file_name),
                open_mode,
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    yaml_details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Existing files are overwritten by default (unchanged behaviour). Pass
    ``--no-overwrite`` to keep existing files instead; the run then raises
    ``FileExistsError`` listing the files it skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    # ``--overwrite`` defaults to True, so on its own it can never turn
    # overwriting off. This companion flag writes False to the same dest,
    # which makes the non-overwriting path of gen_lang_yamls reachable.
    # The first action's default (True) is kept, and the last flag given wins.
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Do not overwrite existing files; fail and list the files that already exist",
    )
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_1",
        choices=["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"],
        help="Prompt number",
    )
    args = parser.parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Script entry point: run the generator only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
