import argparse
import os

import yaml


class FunctionTag:
    """Wrap a single object in a tagged container exposed as ``.value``.

    NOTE(review): this class is not referenced in the visible part of the
    file; presumably it is consumed by a custom YAML representer elsewhere —
    confirm before relying on it.
    """

    def __init__(self, value) -> None:
        # Stored as-is; no copying or validation is performed.
        self.value = value


def prompt_func(mode, lang, intent):
    """
    Build the ``doc_to_text`` prompt template for the selected prompt mode.

    The returned string contains a literal ``{{text}}`` placeholder; it is
    presumably substituted per-document when the generated YAML template is
    rendered (TODO confirm against the consuming harness).

    :param mode: Prompt key, one of ``"prompt_1"`` .. ``"prompt_5"``.
    :param lang: Human-readable language name; interpolated into prompts 4 and 5.
    :param intent: Iterable of intent labels, rendered as ``[a, b, ...]``.
    :return: The prompt template string for ``mode``.
    :raises KeyError: If ``mode`` is not one of the known prompt keys.
    """
    # Join once: every template embeds the same list, and joining repeatedly
    # would silently exhaust a one-shot iterable after the first prompt.
    options = ", ".join(intent)
    # Non-f-string pieces keep ``{{text}}`` verbatim; inside an f-string the
    # braces must be doubled again (``{{{{text}}}}``) to produce the same output.
    prompt_map = {
        "prompt_1": "Given the text: '{{text}}', determine the correct intent from the following list: "
        f"[{options}]. Only output one intent from the list.",
        "prompt_2": "Analyze the text: '{{text}}'. Choose the most appropriate intent from these options: "
        f"[{options}]. Respond with only the selected intent.",
        "prompt_3": "You are a linguistic analyst trained to understand user intent. Based on the text: '{{text}}', "
        f"choose the intent that best matches from this list: [{options}]. Return only the intent.",
        # Trailing space after ``{lang}`` is required: the next literal starts
        # with "text", and implicit concatenation adds no separator.
        "prompt_4": f"You are a {lang} linguistic analyst trained to understand {lang} user intent. Based on the {lang} "
        "text: '{{text}}', choose the intent that best matches from this list: "
        f"[{options}]. Return only the intent.",
        "prompt_5": f"The following text is in {lang}: '{{{{text}}}}'. Given the list of intents: [{options}], "
        "identify the intent expressed in the text. Return only the identified intent.",
    }
    return prompt_map[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One file named ``injongointent_<code>.yaml`` is written per language into
    ``<output_dir>/<mode>/``. Each file includes the ``injongointent`` template
    and sets ``task``, ``dataset_name`` and ``doc_to_text`` for that language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt key forwarded to ``prompt_func`` (e.g. ``"prompt_3"``);
        also used as the output sub-directory and as the task-name suffix.
    :raises KeyError: If ``mode`` is not a known prompt key (raised before
        anything is written to disk).
    :raises FileExistsError: If ``overwrite`` is False and some target files
        already exist. The remaining files are still written before raising.
    """
    err = []
    languages = {
        "amh": "Amharic",
        "ewe": "Ewe",
        "hau": "Hausa",
        "ibo": "Igbo",
        "kin": "Kinyarwanda",
        "lin": "Lingala",
        "lug": "Luganda",
        "orm": "Oromo",
        "sna": "Shona",
        "sot": "Sotho",
        "swa": "Swahili",
        "twi": "Twi",
        "wol": "Wolof",
        "xho": "Xhosa",
        "yor": "Yoruba",
        "zul": "Zulu",
        "eng": "English",
    }

    intents = [
        "alarm",
        "balance",
        "bill_balance",
        "book_flight",
        "book_hotel",
        "calendar_update",
        "cancel_reservation",
        "car_rental",
        "confirm_reservation",
        "cook_time",
        "exchange_rate",
        "food_last",
        "freeze_account",
        "ingredients_list",
        "interest_rate",
        "international_visa",
        "make_call",
        "meal_suggestion",
        "min_payment",
        "pay_bill",
        "pin_change",
        "play_music",
        "plug_type",
        "recipe",
        "restaurant_reservation",
        "restaurant_reviews",
        "restaurant_suggestion",
        "share_location",
        "shopping_list_update",
        "spending_history",
        "text",
        "time",
        "timezone",
        "transactions",
        "transfer",
        "translate",
        "travel_notification",
        "travel_suggestion",
        "update_playlist",
        "weather",
    ]

    # Build every config first so an unknown ``mode`` fails (KeyError from
    # prompt_func) before any directory or file is created.
    configs = {
        code: {
            "include": "injongointent",
            "task": f"injongointent_{code}_{mode}",
            "dataset_name": code,
            "doc_to_text": prompt_func(mode, name, intents),
        }
        for code, name in languages.items()
    }

    target_dir = os.path.join(output_dir, mode)
    # Created once, outside the per-file ``try``, so a path collision here is
    # not misreported as "yaml file already exists".
    os.makedirs(target_dir, exist_ok=True)
    # "x" makes open() raise FileExistsError instead of clobbering a file.
    open_mode = "w" if overwrite else "x"

    for code, details in configs.items():
        file_name = f"injongointent_{code}.yaml"
        try:
            with open(
                os.path.join(target_dir, file_name),
                open_mode,
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            # Collect and keep going so all non-conflicting files are written.
            err.append(file_name)

    if len(err) > 0:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files."""
    parser = argparse.ArgumentParser()
    # ``store_true`` combined with ``default=True`` can never yield False, so
    # pair it with an explicit ``--no-overwrite`` on the same destination.
    # Overwriting remains the default to keep existing invocations unchanged.
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Do not overwrite existing files; report them as an error instead",
    )
    parser.set_defaults(overwrite=True)
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_3",
        choices=["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"],
        help="Prompt number",
    )
    args = parser.parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
