import argparse


def get_default_params(model_name):
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    model_name = model_name.lower()
    if "vit" in model_name:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to h5 filewith training data",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to h5 file with validation data",
    )
    parser.add_argument(
        "--freeze-text",
        default=False,
        action="store_true",
        help="if you need to freeze the text encoder, make this True",
    )
    parser.add_argument(
        "--freeze-text-after",
        type=int,
        default=-1,
        help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
    )
    parser.add_argument(
        "--train-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in training data",
    )
    parser.add_argument(
        "--val-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "auto", "toy"],
        default="auto",
        help="Which type of dataset to process.",
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use.",
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths.",
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions.",
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--datasetnames",
        nargs="+",
        default=None,
        help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
    )
    parser.add_argument(
        "--full-train-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be trained with all the subsets. (train+test)",
    )
    parser.add_argument(
        "--exclude-eval-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be excluded with evaluation",
    )
    parser.add_argument(
        "--datasetinfos",
        nargs="+",
        default=None,
        help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
    )
    parser.add_argument(
        "--dataset-proportion",
        type=float,
        default=1.0,
        help="How much proportion of dataset we want to train.",
    )
    parser.add_argument(
        "--remotedata",
        default=False,
        action="store_true",
        help="if the dataset is remote, set this flag",
    )
    parser.add_argument(
        "--class-label-path",
        type=str,
        default=None,
        help="The path of the class label pickle or csv.",
    )
    parser.add_argument(
        "--datasetpath",
        type=str,
        default="/mnt/audio_clip/webdataset_tar",
        help="The path to the dataset",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")

    parser.add_argument(
        "--split-opt",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--lr-pretrained", type=float, default=None, help="Learning rate for text."
    )
    parser.add_argument(
        "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
    )
    parser.add_argument(
        "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
    )
    parser.add_argument(
        "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
    )
    parser.add_argument(
        "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
    )
    parser.add_argument(
        "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
    )
    parser.add_argument(
        "--lr-new", type=float, default=None, help="Learning rate for audio."
    )
    parser.add_argument(
        "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
    )
    parser.add_argument(
        "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
    )
    parser.add_argument(
        "--eps-new", type=float, default=None, help="Adam epsilon for audio."
    )
    parser.add_argument(
        "--wd-new", type=float, default=0.2, help="Weight decay for audio."
    )
    parser.add_argument(
        "--momentum-new", type=float, default=0.9, help="Momentum for audio."
    )
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.",
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-top-performance",
        type=int,
        default=0,
        help="Save the top x performance weights if the value >0",
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--val-frequency",
        type=int,
        default=1,
        help="How often to run evaluation with val data.",
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--amodel",
        type=str,
        default="RN50",
        help="Name of the audio backbone to use.",
    )
    parser.add_argument(
        "--tmodel",
        type=str,
        default="transformer",
        help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
    )
    parser.add_argument(
        "--pretrained-audio",
        default="",
        type=str,
        help="Use a pretrained audio model weights for the audio encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained-text",
        default="",
        type=str,
        help="Use a pretrained text model weights for the text encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action="store_true",
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action="store_true",
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action="store_true",
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather",
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action="store_true",
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action="store_true",
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action="store_true",
        help="torch.jit.trace the model for inference / eval only",
    )
    # arguments for distributed training
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--report-to",
        default="",
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
    )
    parser.add_argument(
        "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged.",
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire base on the log diretory, and execute from there.",
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action="store_true",
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")

    parser.add_argument(
        "--top-k-checkpoint-select-dataset",
        type=str,
        default="all",
        help="The dataset of selecting top-k checkpoint.",
    )

    # @R10, @R@5, @R1, mAP@10
    parser.add_argument(
        "--top-k-checkpoint-select-metric",
        type=str,
        default="_R@10",
        help="The metric for selecting top-k checkpoint.",
    )
    parser.add_argument(
        "--openai-model-cache-dir",
        type=str,
        default="~/.cache/clip",
        help="Directory to download OpenAI models.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
        help="can be AdamW or SGD",
    )
    parser.add_argument(
        "--parallel-eval",
        default=False,
        action="store_true",
        help="Eval in parallel (multi-GPU, multi-node).",
    )

    parser.add_argument(
        "--no-eval",
        default=False,
        action="store_true",
        help="Training without evaluation.",
    )

    parser.add_argument(
        "--lp-mlp",
        default=False,
        action="store_true",
        help="Linear Probe using MLP layer or not.",
    )

    parser.add_argument(
        "--lp-freeze",
        default=False,
        action="store_true",
        help="Linear Probe using Freeze CLAP or not",
    )

    parser.add_argument(
        "--lp-act",
        default="None",
        type=str,
        help="Options are ['relu','elu','prelu','softmax','sigmoid']",
    )

    parser.add_argument(
        "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
    )

    parser.add_argument(
        "--lp-metrics",
        type=str,
        default="map,mauc,acc",
        help="Metrics of Linear Probe.",
    )

    parser.add_argument(
        "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
    )
    parser.add_argument(
        "--kappa", type=float, default=0,
        help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss"
    )

    parser.add_argument(
        "--data-filling",
        type=str,
        default="pad",
        help="type of data filling when the audio length is shorter than the max length."
             "Can be one of the following: repeat, repeatpad, pad",
    )
    parser.add_argument(
        "--data-truncating",
        type=str,
        default="rand_trunc",
        help="type of data truncation when the audio length is longer than the max length."
             "Can be one of the following: rand_trunc, fusion",
    )

    parser.add_argument(
        "--clap-mlploss",
        default=False,
        action="store_true",
        help="Using MLP loss for CLAP model or not",
    )

    parser.add_argument(
        "--wandb-id",
        type=str,
        default=None,
        help="the id of wandb experiment to restore.",
    )

    parser.add_argument(
        "--sleep", type=float, default=0, help="sleep n seconds before start training"
    )

    # variable length processing
    parser.add_argument(
        "--enable-fusion",
        default=False,
        action="store_true",
        help="Enable feature funsion for variable-length data",
    )

    parser.add_argument(
        "--fusion-type",
        type=str,
        default='None',
        help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
    )

    parser.add_argument(
        "--mixup",
        default=False,
        action="store_true",
        help="Enable mixup in finetuning training.",
    )
    parser.add_argument(
        "--text-augment-selection",
        type=str,
        default=None,
        help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
    )
    parser.add_argument(
        "--prefetch-factor",
        type=int,
        default=None,
        help="The prefetch factor for dataloader. Larger value will use more memory and CPU but faster.",
    )

    args = parser.parse_args()

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.amodel)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
