import logging
import os
import random
from datetime import datetime
import copy
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler

try:
    import wandb
except ImportError:
    wandb = None

try:
    import torch.utils.tensorboard as tensorboard
except ImportError:
    tensorboard = None

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from clap_module import create_model_and_transforms, trace_model, create_model
from training.data import get_data
from training.distributed import is_master, init_distributed_device, world_info_from_env
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr
from training.train import train_one_epoch, evaluate
from clap_module.utils import dataset_split, get_optimizer


def maintain_ckpts(args, startidx, all_idx_len):
    """Shift ranked checkpoint files one slot down to free slot ``startidx``.

    Every existing ``epoch_top_{i}.pt`` under ``args.checkpoint_path`` with
    ``startidx <= i < all_idx_len`` is renamed to ``epoch_top_{i+1}.pt``.
    Afterwards the file that fell off the end, ``epoch_top_{all_idx_len}.pt``,
    is deleted if it exists.
    """

    def slot_path(rank):
        return os.path.join(args.checkpoint_path, f"epoch_top_{rank}.pt")

    # Walk from the last slot towards startidx so that a rename never lands
    # on a file that still has to be moved.
    rank = all_idx_len - 1
    while rank >= startidx:
        source = slot_path(rank)
        if os.path.exists(source):
            os.rename(source, slot_path(rank + 1))
        rank -= 1

    overflow = slot_path(all_idx_len)
    if os.path.exists(overflow):
        os.remove(overflow)
    return


def update_top_k_performance(
    new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
):
    """Insert a new score into the top-k ranking and save ``ckpt`` if it ranks.

    Args:
        new_metrics_inputs: A number (Python or NumPy scalar), or a
            list / tuple / dict of numbers which is reduced to its mean.
        current_top_k_ckpt_metrics: Dict mapping rank to score, of the form
            ``{0: best_score, 1: second_best_score, ...}``. It is updated in
            place. Rank ``i`` corresponds to the file ``epoch_top_{i}.pt``.
        args: Namespace providing ``checkpoint_path``.
        ckpt: Object passed to ``torch.save`` when the new score ranks.
        bignumbetter: If True larger scores rank first, otherwise smaller.

    Returns:
        ``(current_top_k_ckpt_metrics, score)`` where ``score`` is the scalar
        the input was reduced to (NaN for an empty container).

    Raises:
        TypeError: If the input is neither a number nor a supported container.
    """
    # Collapse containers to a single score (mean over all entries).
    if isinstance(new_metrics_inputs, dict):
        new_metrics_inputs = list(new_metrics_inputs.values())
    if isinstance(new_metrics_inputs, (list, tuple)):
        if len(new_metrics_inputs) == 0:
            # np.mean([]) would yield NaN plus a RuntimeWarning; there is
            # nothing to rank, so leave the ranking and the files alone.
            logging.warning(
                "update_top_k_performance received no metrics; "
                "top-k checkpoints are left unchanged."
            )
            return current_top_k_ckpt_metrics, float("nan")
        new_metrics_inputs = np.mean(new_metrics_inputs)

    # np.floating / np.integer are accepted explicitly: np.float32 (and other
    # non-float64 NumPy scalars) are not subclasses of Python float / int and
    # would otherwise be ignored without any checkpoint being written.
    if not isinstance(new_metrics_inputs, (float, int, np.floating, np.integer)):
        raise TypeError(
            "new_metrics_inputs must be a number or a list/tuple/dict of numbers, "
            f"got {type(new_metrics_inputs).__name__}"
        )

    # NaN is the only value that differs from itself. It cannot be ordered
    # meaningfully, so it must never displace a stored checkpoint.
    if new_metrics_inputs != new_metrics_inputs:
        logging.warning(
            "update_top_k_performance received a NaN metric; "
            "top-k checkpoints are left unchanged."
        )
        return current_top_k_ckpt_metrics, new_metrics_inputs

    sorted_keys = sorted(current_top_k_ckpt_metrics)
    previous = sorted(current_top_k_ckpt_metrics.values(), reverse=bignumbetter)
    # Insert the candidate, re-rank, and drop the worst entry so that the
    # ranking keeps exactly k scores.
    updated = sorted(previous + [new_metrics_inputs], reverse=bignumbetter)[:-1]

    if updated == previous:
        # The candidate did not make it into the top k.
        return current_top_k_ckpt_metrics, new_metrics_inputs

    first_changed = None
    for rank, key in enumerate(sorted_keys):
        if current_top_k_ckpt_metrics[key] != updated[rank]:
            current_top_k_ckpt_metrics[key] = updated[rank]
            if first_changed is None:
                first_changed = rank

    if first_changed is not None:
        # The first rank whose score changed is where the new checkpoint
        # belongs; push the files from that rank on down by one slot first.
        maintain_ckpts(args, first_changed, len(sorted_keys))
        torch.save(
            ckpt,
            os.path.join(args.checkpoint_path, f"epoch_top_{first_changed}.pt"),
        )
    return current_top_k_ckpt_metrics, new_metrics_inputs


# def updateifNone(a, b):
#     a = b if None else a
#     return a


def is_pretrained_params(n):
    """Return True if parameter name ``n`` is treated as a pretrained parameter.

    A name qualifies when it is exactly ``positional_embedding`` or
    ``text_projection``, or when it starts with one of ``transformer``,
    ``token_embedding``, ``ln_final`` or ``logit_scale_t``.
    """
    exact_names = ("positional_embedding", "text_projection")
    prefixes = ("transformer", "token_embedding", "ln_final", "logit_scale_t")
    if n in exact_names:
        return True
    return n.startswith(prefixes)


def random_seed(seed=42, rank=0):
    """Seed the Python, NumPy and torch random generators with ``seed + rank``."""
    effective_seed = seed + rank
    random.seed(effective_seed)
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)


def main():
    """Run a full training / evaluation session driven by command-line args.

    Steps, in order: parse args and seed the RNGs, set up logging and the
    distributed device, build the model and the data, build the optimizer(s)
    and scheduler(s), optionally resume from ``args.resume``, then either
    evaluate once (no train data) or loop over epochs doing
    train -> evaluate -> save checkpoints.

    Returns:
        -1 if the log file of the experiment already exists, otherwise None.
    """
    args = parse_args()
    # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
    args.amodel = args.amodel.replace("/", "-")

    # (yusong): the below two lines are for debug
    # print("setting up faulthandler")
    # faulthandler.register(10)

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart":
        assert (
            args.pretrained == "" or args.pretrained is None
        ), "bert/roberta/bart text encoder does not support pretrained models."

    # get the name of the experiments
    if args.name is None:
        args.name = "-".join(
            [
                datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
                f"model_{args.amodel}",
                f"lr_{args.lr}",
                f"b_{args.batch_size}",
                f"j_{args.workers}",
                f"p_{args.precision}",
            ]
        )

    # discover initial world args early so we can log properly
    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()

    # download the sizes.json file of every dataset split (master only)
    if args.remotedata and is_master(args):
        for dataset_name in args.datasetnames:
            for split in dataset_split[dataset_name]:
                if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
                    os.makedirs(f"./json_files/{dataset_name}/{split}")
                os.system(
                    f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
                )

    args.log_path = None
    if is_master(args, local=args.log_local):
        log_base_path = os.path.join(args.logs, args.name)
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f"out-{args.rank}" if args.log_local else "out.log"
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path):
            # Refuse to append to the log of an earlier run with the same name.
            print(
                f"Error. Experiment already exists at {args.log_path}. Use --name to specify a new experiment."
            )
            return -1

    # Set logger
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)

    # fully initialize distributed device environment
    device = init_distributed_device(args)

    args.wandb = "wandb" in args.report_to or "all" in args.report_to
    args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
    if is_master(args):
        args.tensorboard_path = (
            os.path.join(args.logs, args.name, "tensorboard")
            if args.tensorboard
            else ""
        )
        args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ""
        args.checkpoint_path = ""

    if args.copy_codebase:
        copy_codebase(args)

    assert args.precision in ["amp", "fp16", "fp32"]
    if args.precision == "fp16":
        logging.warning(
            "It is recommended to use fp32 mixed-precision instead of FP16 and AMP in this model. "
            "They will cause NaN loss and NaN gradients. "
            "FP16 and AMP support needs further verification and tuning, especially for train."
        )

    if args.horovod:
        logging.info(
            f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
            f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
        )
    elif args.distributed:
        logging.info(
            f"Running in distributed mode with multiple processes. Device: {args.device}."
            f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
        )
    else:
        logging.info(f"Running with a single process. Device {args.device}.")

    logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")

    model, model_cfg = create_model(
        args.amodel,
        args.tmodel,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=True,
        pretrained_audio=args.pretrained_audio,
        pretrained_text=args.pretrained_text,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type,
    )

    if args.horovod:
        with torch.no_grad():
            for param in model.parameters():
                param.set_(param.contiguous())

    if args.trace:
        model = trace_model(model, batch_size=args.batch_size, device=device)

    if is_master(args):
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed and not args.horovod:
        if args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        ddp_args = {}
        if args.ddp_static_graph:
            # this doesn't exist in older PyTorch, arg only added if enabled
            ddp_args["static_graph"] = True
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True, **ddp_args
        )

    data = get_data(args, model_cfg)
    assert len(data), "At least one train or eval dataset must be specified."
    if args.trace:
        assert "train" not in data, "Cannot train with traced model"

    # Parameters matched by `exclude` are put in a group with weight_decay 0.0
    # below; everything else gets the configured weight decay.
    exclude = (
        lambda n, p: p.ndim < 2
        or "bn" in n
        or "ln" in n
        or "bias" in n
        or "logit_scale" in n
    )
    include = lambda n, p: not exclude(n, p)

    named_parameters = list(model.named_parameters())

    # freeze text encoder
    text_freeze_parameters = [
        p
        for n, p in named_parameters
        if 'text_branch' in n
    ]

    if args.freeze_text:
        print("Freeze Text!!!!")
        for k in text_freeze_parameters:
            k.requires_grad = False

    gain_or_bias_params = [
        p for n, p in named_parameters if exclude(n, p) and p.requires_grad
    ]
    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

    # set wd-related params to 0 if use adam optimizer
    if args.optimizer == "adam":
        args.wd = 0
        args.wd_pretrained = 0
        args.wd_new = 0

    if args.train_data is None:
        optimizer = None
        scheduler = None
    else:
        total_steps = data["train"].dataloader.num_batches * args.epochs

        if args.split_opt:
            # Unset per-group hyper-parameters fall back to the shared value.
            for x in ["lr", "beta1", "beta2", "eps", "wd"]:
                for y in ["_new", "_pretrained"]:
                    if getattr(args, x + y) is None:
                        setattr(args, x + y, getattr(args, x))

            gain_or_bias_pretrained_params = [
                p
                for n, p in named_parameters
                if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
            ]
            rest_pretrained_params = [
                p
                for n, p in named_parameters
                if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
            ]
            gain_or_bias_new_params = [
                p
                for n, p in named_parameters
                if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
            ]
            rest_new_params = [
                p
                for n, p in named_parameters
                if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
            ]
            pretrained_params_optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
                    {
                        "params": rest_pretrained_params,
                        "weight_decay": args.wd_pretrained,
                    },
                ],
                lr=args.lr_pretrained,
                betas=(args.beta1_pretrained, args.beta2_pretrained),
                eps=args.eps_pretrained,
                momentum=args.momentum_pretrained,
                optimizer_name=args.optimizer,
            )
            pretrained_params_scheduler = cosine_lr(
                pretrained_params_optimizer,
                args.lr_pretrained,
                args.warmup,
                total_steps,
            )
            new_params_optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_new_params, "weight_decay": 0.0},
                    {"params": rest_new_params, "weight_decay": args.wd_new},
                ],
                lr=args.lr_new,
                betas=(args.beta1_new, args.beta2_new),
                eps=args.eps_new,
                momentum=args.momentum_new,
                optimizer_name=args.optimizer,
            )

            new_params_scheduler = cosine_lr(
                new_params_optimizer, args.lr_new, args.warmup, total_steps
            )

            if args.horovod:
                pretrained_params_optimizer = hvd.DistributedOptimizer(
                    pretrained_params_optimizer,
                    named_parameters=model.named_parameters(),
                )
                new_params_optimizer = hvd.DistributedOptimizer(
                    new_params_optimizer, named_parameters=model.named_parameters()
                )
                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
                hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)

            # Build the dicts only after the optional Horovod wrap, so the
            # optimizers handed to train_one_epoch are the wrapped ones and
            # not the plain optimizers they were created from.
            optimizer = {
                "pretrained": pretrained_params_optimizer,
                "new": new_params_optimizer,
            }
            scheduler = {
                "pretrained": pretrained_params_scheduler,
                "new": new_params_scheduler,
            }
        else:
            optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_params, "weight_decay": 0.0},
                    {"params": rest_params, "weight_decay": args.wd},
                ],
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                eps=args.eps,
                momentum=args.momentum,
                optimizer_name=args.optimizer,
            )

            scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

            if args.horovod:
                optimizer = hvd.DistributedOptimizer(
                    optimizer, named_parameters=model.named_parameters()
                )
                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    scaler = GradScaler() if args.precision == "amp" else None

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            if "epoch" in checkpoint:
                # resuming a train checkpoint w/ epoch and optimizer state
                start_epoch = checkpoint["epoch"]
                sd = checkpoint["state_dict"]
                if not args.distributed and next(iter(sd.items()))[0].startswith(
                    "module"
                ):
                    sd = {k[len("module.") :]: v for k, v in sd.items()}
                model.load_state_dict(sd)
                if optimizer is not None:
                    if args.split_opt:
                        # `optimizer` is a dict here; each entry has its own
                        # "<key>_optimizer" state in the checkpoint.
                        for k, o_ in optimizer.items():
                            o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
                    else:
                        optimizer.load_state_dict(checkpoint["optimizer"])
                if scaler is not None and "scaler" in checkpoint:
                    scaler.load_state_dict(checkpoint["scaler"])
                logging.info(
                    f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
                )
            else:
                # loading a bare (model only) checkpoint for fine-tune or evaluation
                model.load_state_dict(checkpoint)
                logging.info(
                    f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
                )
            if args.freeze_text:
                print("Freeze Text!!!!")
                for k in text_freeze_parameters:
                    k.requires_grad = False
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    cudnn.deterministic = False

    # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        assert wandb is not None, "Please install wandb."
        logging.debug("Starting wandb.")
        # Evaluation-only runs have no "train" entry in `data`.
        if "train" in data:
            args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        wandb.init(
            entity="clap",
            project="clap",
            notes=args.wandb_notes,
            name=args.wandb_notes,
            tags=[],
            config=vars(args),
        )
        if args.debug:
            wandb.watch(model, log="all")
        wandb.save(params_file)
        logging.debug("Finished loading wandb.")

    if "train" not in data:
        evaluate(model, data, start_epoch, args, writer)
        return
    elif start_epoch == 0 and "val" in data and not args.no_eval:
        evaluate(model, data, 0, args, writer)
        #  print(f'rank {args.rank}, Start First Evaluation')#  (yusong): for debug
    if args.save_top_performance:
        current_top_k_ckpt_metrics = {
            i: 0 for i in range(args.save_top_performance)
        }  # initialize the top-k metric for ckpts to 0

    #  print(f'rank {args.rank}, Start Training') #  (yusong): for debug
    for epoch in range(start_epoch, args.epochs):
        # freeze the text param after (include) args.freeze_text_after, this is -1 by default
        if epoch == args.freeze_text_after:
            print("Text pretrained parameters are freezed since this epoch.")
            for k in text_freeze_parameters:
                k.requires_grad = False
        if is_master(args):
            logging.info(f"Start epoch {epoch}")

        train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
        completed_epoch = epoch + 1

        # Stays None when no evaluation runs in this epoch, so the top-k
        # update below is skipped instead of reading an unbound name.
        filtered_metrics = None
        if (
            any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
            and not args.no_eval
        ):
            metrics = evaluate(model, data, completed_epoch, args, writer)
            if args.save_top_performance:
                top_k_dataset = args.top_k_checkpoint_select_dataset
                top_k_metric = args.top_k_checkpoint_select_metric
                filtered_metrics = [
                    v
                    for k, v in metrics.items()
                    if top_k_metric in k and top_k_dataset in k
                ]  # check all R@10 metrics (all dataset) and use it to update the ckpt
        # Saving checkpoints.
        if args.save_logs:
            if args.split_opt:
                opt_dict = {
                    k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
                }
            else:
                opt_dict = {"optimizer": optimizer.state_dict()}
            checkpoint_dict = {
                "epoch": completed_epoch,
                "name": args.name,
                "state_dict": model.state_dict(),
            }
            checkpoint_dict.update(opt_dict)
            if scaler is not None:
                checkpoint_dict["scaler"] = scaler.state_dict()

            if completed_epoch == args.epochs or (
                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
            ):
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
                )
            if args.save_most_recent:
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, "epoch_latest.pt"),
                )
            # filtered_metrics is only set when save_top_performance is on
            # and an evaluation actually ran in this epoch.
            if args.save_top_performance and filtered_metrics is not None:
                update_top_k_performance(
                    filtered_metrics,
                    current_top_k_ckpt_metrics,
                    args,
                    checkpoint_dict,
                    bignumbetter=True,
                )

    if args.wandb and is_master(args):
        wandb.finish()


def copy_codebase(args):
    """Snapshot the code tree into ``<args.logs>/<args.name>/code``.

    The tree that gets copied is the directory three levels above this file;
    entries named ``log``, ``logs`` or ``wandb`` are skipped.

    Returns:
        1 after a successful copy, -1 (without copying) if the target
        directory already exists.
    """
    from shutil import copytree, ignore_patterns

    destination = os.path.join(args.logs, args.name, "code")
    if os.path.exists(destination):
        print(
            f"Error. Experiment already exists at {destination}. Use --name to specify a new experiment."
        )
        return -1

    print(f"Copying codebase to {destination}")
    # Go up three directory levels from the resolved location of this file.
    source_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    )
    skip_dirs = ignore_patterns("log", "logs", "wandb")
    copytree(source_root, destination, ignore=skip_dirs)
    print("Done copying code.")
    return 1


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
