"""
Training utilities for pretrained models

Authors
* Artem Ploujnikov 2021
"""

import os
import shutil

from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


def save_for_pretrained(
    hparams,
    min_key=None,
    max_key=None,
    ckpt_predicate=None,
    pretrainer_key="pretrainer",
    checkpointer_key="checkpointer",
):
    """
    Saves the necessary files for the pretrained model
    from the best checkpoint found. The goal of this function
    is to export the model for a Pretrainer.

    Only the parameter files whose keys appear both in the pretrainer's
    ``loadables`` and in the selected checkpoint's ``paramfiles`` are
    copied; each one is written to the corresponding entry of the
    pretrainer's ``paths``.

    Arguments
    ---------
    hparams: dict
        the hyperparameter file
    min_key: str
        Key to use for finding best checkpoint (lower is better).
        Passed to ``checkpointer.find_checkpoint()``.
    max_key: str
        Key to use for finding best checkpoint (higher is better).
        Passed to ``checkpointer.find_checkpoint()``.
    ckpt_predicate: callable
        a filter predicate to locate checkpoints.
        Passed to ``checkpointer.find_checkpoint()``.
    pretrainer_key: str
        the key under which the pretrainer is stored
    checkpointer_key: str
        the key under which the checkpointer is stored

    Returns
    -------
    saved: bool
        True if a matching checkpoint was found and its files were
        copied, False if no matching checkpoint was found.

    Raises
    ------
    ValueError
        If ``hparams`` lacks the pretrainer or the checkpointer key, or if
        a parameter file listed in the checkpoint does not exist on disk.
    """
    if any(key not in hparams for key in [pretrainer_key, checkpointer_key]):
        raise ValueError(
            f"Incompatible hparams: a checkpointer with key {checkpointer_key} "
            f"and a pretrainer with key {pretrainer_key} are required"
        )
    pretrainer = hparams[pretrainer_key]
    checkpointer = hparams[checkpointer_key]
    checkpoint = checkpointer.find_checkpoint(
        min_key=min_key, max_key=max_key, ckpt_predicate=ckpt_predicate
    )

    if not checkpoint:
        # Nothing to export: report what is available to help diagnose
        # the key / predicate mismatch, then signal failure to the caller.
        logger.info(
            "Unable to find a matching checkpoint for min_key = %s, max_key = %s",
            min_key,
            max_key,
        )
        checkpoints_str = "\n".join(
            f"{ckpt.path}: {ckpt.meta}"
            for ckpt in checkpointer.list_checkpoints()
        )
        logger.info("Available checkpoints: %s", checkpoints_str)
        return False

    logger.info(
        "Saving checkpoint '%s' as a pretrained model", checkpoint.path
    )
    pretrainer_keys = set(pretrainer.loadables.keys())
    checkpointer_keys = set(checkpoint.paramfiles.keys())
    # Only export what the pretrainer can load AND the checkpoint contains.
    keys_to_save = pretrainer_keys & checkpointer_keys
    for key in keys_to_save:
        source_path = checkpoint.paramfiles[key]
        if not os.path.exists(source_path):
            raise ValueError(
                f"File {source_path} does not exist in the checkpoint"
            )
        target_path = pretrainer.paths[key]
        dirname = os.path.dirname(target_path)
        # dirname is "" for a bare file name (current directory);
        # os.makedirs("") would raise, so only create real directories.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        # Remove any existing entry first. lexists also catches dangling
        # symlinks, so copyfile creates a fresh file instead of writing
        # through a link to wherever it points.
        if os.path.lexists(target_path):
            os.remove(target_path)
        shutil.copyfile(source_path, target_path)

    return True
