"""
Lab 2 — HPML Spring 2026
ResNet-18 on CIFAR10: Training & Profiling (Part A)
Exercises: C1–C6, Q3
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import argparse
import time


# ==== 2.1 ResNet-18 (arXiv:1512.03385) ====
# CIFAR10 variant — no maxpool, 3x3 first conv (32x32 inputs don't need 7x7)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1, use_bn=True):
        super().__init__()
        # bias=False: batchnorm's affine shift absorbs the bias term
        # (left off even when use_bn=False, so C6 compares the same conv parameterization)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch) if use_bn else nn.Identity()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch) if use_bn else nn.Identity()

        # shortcut projection when spatial dims or channels change
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch * self.expansion:
            sc = [nn.Conv2d(in_ch, out_ch * self.expansion, 1, stride=stride, bias=False)]
            if use_bn:
                sc.append(nn.BatchNorm2d(out_ch * self.expansion))
            self.shortcut = nn.Sequential(*sc)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, use_bn=True):
        super().__init__()
        self.in_planes = 64
        self.use_bn = use_bn

        # first conv: 3 input channels -> 64, kernel 3x3, stride 1, pad 1
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64) if use_bn else nn.Identity()

        # 4 subgroups x 2 blocks = 8 BasicBlocks total
        self.layer1 = self._build_group(block, 64,  layers[0], stride=1)   # 32x32
        self.layer2 = self._build_group(block, 128, layers[1], stride=2)   # 16x16
        self.layer3 = self._build_group(block, 256, layers[2], stride=2)   # 8x8
        self.layer4 = self._build_group(block, 512, layers[3], stride=2)   # 4x4

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _build_group(self, block, planes, n_blocks, stride):
        # first block may downsample, rest keep spatial dims
        strides = [stride] + [1] * (n_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, self.use_bn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def resnet18(use_bn=True):
    return ResNet(BasicBlock, [2, 2, 2, 2], use_bn=use_bn)
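

# A minimal smoke-test sketch (illustrative; the helper name is ours, not part
# of the lab spec): a 32x32 RGB batch should map to 10 class logits.
def _smoke_test():
    m = resnet18()
    out = m(torch.randn(2, 3, 32, 32))
    assert out.shape == (2, 10), out.shape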


# ==== 2.2 CIFAR10 DataLoader ====
# Transforms: RandomCrop(32, pad=4) -> RandomHFlip(0.5) -> Normalize(RGB means/stds)

CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD  = (0.2023, 0.1994, 0.2010)

def get_cifar10_loader(data_path='./data', batch_size=128, num_workers=2, train=True):
    if train:
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])

    dataset = torchvision.datasets.CIFAR10(
        root=data_path, train=train, download=True, transform=transform
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=train, num_workers=num_workers
    )
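
# A variant sketch (assumption, not part of the lab spec): for CUDA runs,
# pinned host buffers let `inputs.to(device, non_blocking=True)` overlap the
# H2D copy with compute. Same loader as above plus one flag:
#
#     torch.utils.data.DataLoader(dataset, batch_size=batch_size,
#                                 shuffle=train, num_workers=num_workers,
#                                 pin_memory=True)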


# ==== C1: Training ====

def get_optimizer(name, model_params, lr=0.1, momentum=0.9, weight_decay=5e-4):
    name = name.lower()
    if name == 'sgd':
        return torch.optim.SGD(model_params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif name == 'sgd_nesterov':
        return torch.optim.SGD(model_params, lr=lr, momentum=momentum,
                               weight_decay=weight_decay, nesterov=True)
    elif name == 'adam':
        # note: the shared lr=0.1 applies here too (C5 compares all three
        # optimizers under identical hyperparameters); Adam's own default is 1e-3
        return torch.optim.Adam(model_params, lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {name}")


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Single epoch: fwd+bwd on all batches, returns (loss, acc, data_time, train_time)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    data_time = 0.0
    train_time = 0.0

    # timer starts right before iterating — captures pure iterator wait
    batch_end = time.perf_counter()
    for inputs, targets in loader:
        # C2.1 — data loading: time spent waiting on the dataloader iterator
        data_time += time.perf_counter() - batch_end

        # note: this H2D copy is counted in neither data_time nor train_time,
        # so it shows up only in the per-epoch total
        inputs, targets = inputs.to(device), targets.to(device)

        # CUDA kernels launch asynchronously; synchronize so perf_counter
        # brackets finished GPU work, not just kernel launches
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t0 = time.perf_counter()

        # C2.2 — mini-batch computation (fwd + loss + bwd + step)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if device.type == 'cuda':
            torch.cuda.synchronize()
        train_time += time.perf_counter() - t0

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        batch_end = time.perf_counter()

    n_batches = len(loader)
    return running_loss / n_batches, 100.0 * correct / total, data_time, train_time
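

# A minimal held-out evaluation sketch (illustrative; not graded in C1–C6).
# Reuses get_cifar10_loader(train=False) above to spot-check test accuracy.
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).max(1)[1]  # argmax over class logits
        correct += preds.eq(targets).sum().item()
        total += targets.size(0)
    return 100.0 * correct / total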


# ==== C1 + C2: Training with per-epoch timing ====

def run_training(args, device):
    """C1: full train loop, C2: timing breakdown per epoch."""
    model = resnet18(use_bn=not args.no_batchnorm).to(device)
    loader = get_cifar10_loader(args.data_path, args.batch_size, args.num_workers)
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(args.optimizer, model.parameters())

    print(f"\n{'='*60}")
    print(f" C1/C2: Training — {args.epochs} epochs")
    print(f" optimizer={args.optimizer}  workers={args.num_workers}"
          f"  bn={'on' if not args.no_batchnorm else 'off'}  device={device}")
    print(f"{'='*60}")
    print(f"{'Ep':>3} {'Loss':>8} {'Acc%':>7} {'Data(s)':>8} {'Train(s)':>9} {'Total(s)':>9}")

    for epoch in range(1, args.epochs + 1):
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t_start = time.perf_counter()

        loss, acc, dt, tt = train_one_epoch(model, loader, criterion, optimizer, device)

        if device.type == 'cuda':
            torch.cuda.synchronize()
        t_total = time.perf_counter() - t_start

        print(f"{epoch:>3} {loss:>8.4f} {acc:>6.2f}% {dt:>8.2f} {tt:>9.2f} {t_total:>9.2f}")

    print()
    return model, optimizer


# ==== C3: I/O Worker Sweep ====

def run_worker_sweep(args, device):
    """C3: vary num_workers in steps of 4, measure data-loading time, plot results."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    worker_list = list(range(0, 21, 4))  # [0, 4, 8, 12, 16, 20]
    avg_data_times = []
    avg_total_times = []

    print(f"\n{'='*60}")
    print(f" C3: I/O Optimization — Worker Sweep ({args.epochs} epochs each)")
    print(f"{'='*60}")
    print(f"{'Workers':>8} {'AvgData(s)':>11} {'AvgTrain(s)':>12} {'AvgTotal(s)':>12}")

    for nw in worker_list:
        # fresh model each run so results are comparable
        model = resnet18().to(device)
        loader = get_cifar10_loader(args.data_path, args.batch_size, nw)
        criterion = nn.CrossEntropyLoss()
        optimizer = get_optimizer('sgd', model.parameters())

        sum_dt, sum_tt, sum_total = 0.0, 0.0, 0.0
        for ep in range(1, args.epochs + 1):
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()

            _, _, dt, tt = train_one_epoch(model, loader, criterion, optimizer, device)

            if device.type == 'cuda':
                torch.cuda.synchronize()
            sum_dt += dt
            sum_tt += tt
            sum_total += time.perf_counter() - t0

        avg_dt = sum_dt / args.epochs
        avg_tt = sum_tt / args.epochs
        avg_tot = sum_total / args.epochs
        avg_data_times.append(avg_dt)
        avg_total_times.append(avg_tot)

        print(f"{nw:>8} {avg_dt:>11.3f} {avg_tt:>12.3f} {avg_tot:>12.3f}")

    # C3.2 — best worker count judged by total epoch runtime (not data time alone)
    best_idx = avg_total_times.index(min(avg_total_times))
    best_nw = worker_list[best_idx]
    print(f"\nC3.2 => Best num_workers = {best_nw} "
          f"(avg total: {avg_total_times[best_idx]:.3f}s, "
          f"avg data: {avg_data_times[best_idx]:.3f}s)")

    # plot
    fig, ax1 = plt.subplots(figsize=(8, 5))
    ax1.plot(worker_list, avg_data_times, 'bo-', linewidth=2, label='Data Loading')
    ax1.plot(worker_list, avg_total_times, 'rs--', linewidth=2, label='Total Epoch')
    ax1.set_xlabel('num_workers')
    ax1.set_ylabel('Avg Time per Epoch (s)')
    ax1.set_title('C3: DataLoader I/O Time vs. num_workers')
    ax1.set_xticks(worker_list)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig('c3_workers.png', dpi=150)
    plt.close(fig)
    print("Plot saved -> c3_workers.png")

    return best_nw


# ==== C4: GPU vs CPU comparison ====

def run_cpu_vs_gpu(args):
    """C4: train the same args.epochs epochs on CPU and GPU, report avg epoch time."""
    if not torch.cuda.is_available():
        raise RuntimeError("C4 needs a CUDA device for the GPU half of the comparison")
    results = {}
    for dev_name in ['cpu', 'cuda']:
        device = torch.device(dev_name)
        model = resnet18().to(device)
        loader = get_cifar10_loader(args.data_path, args.batch_size, args.num_workers)
        criterion = nn.CrossEntropyLoss()
        optimizer = get_optimizer('sgd', model.parameters())

        print(f"\n  [{dev_name.upper()}] Training {args.epochs} epochs...")
        epoch_times = []
        for ep in range(1, args.epochs + 1):
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()

            loss, acc, dt, tt = train_one_epoch(model, loader, criterion, optimizer, device)

            if device.type == 'cuda':
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - t0
            epoch_times.append(elapsed)
            print(f"    Ep {ep}: loss={loss:.4f}  acc={acc:.2f}%  time={elapsed:.2f}s")

        avg = sum(epoch_times) / len(epoch_times)
        results[dev_name] = avg
        print(f"  [{dev_name.upper()}] Avg epoch time: {avg:.2f}s")

    print(f"\n{'='*60}")
    print(f" C4 Summary: GPU vs CPU (workers={args.num_workers})")
    print(f"{'='*60}")
    print(f"  CPU avg: {results['cpu']:.2f}s/epoch")
    print(f"  GPU avg: {results['cuda']:.2f}s/epoch")
    speedup = results['cpu'] / results['cuda']
    print(f"  GPU speedup: {speedup:.1f}x")
    return results


# ==== C5: Optimizer comparison ====

def run_optimizer_comparison(args, device):
    """C5: same setup, three optimizers — per-epoch stats side by side."""
    opt_names = ['sgd', 'sgd_nesterov', 'adam']
    all_results = {}

    for opt in opt_names:
        model = resnet18().to(device)
        loader = get_cifar10_loader(args.data_path, args.batch_size, args.num_workers)
        criterion = nn.CrossEntropyLoss()
        optimizer = get_optimizer(opt, model.parameters())

        print(f"\n  [{opt.upper()}]")
        print(f"  {'Ep':>3} {'Loss':>8} {'Acc%':>7} {'Train(s)':>9} {'Total(s)':>9}")

        epoch_data = []
        for ep in range(1, args.epochs + 1):
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()

            loss, acc, dt, tt = train_one_epoch(model, loader, criterion, optimizer, device)

            if device.type == 'cuda':
                torch.cuda.synchronize()
            t_total = time.perf_counter() - t0

            epoch_data.append((loss, acc, tt, t_total))
            print(f"  {ep:>3} {loss:>8.4f} {acc:>6.2f}% {tt:>9.2f} {t_total:>9.2f}")

        avg_loss = sum(r[0] for r in epoch_data) / args.epochs
        avg_acc  = sum(r[1] for r in epoch_data) / args.epochs
        avg_tt   = sum(r[2] for r in epoch_data) / args.epochs
        all_results[opt] = (avg_loss, avg_acc, avg_tt)

    print(f"\n{'='*60}")
    print(f" C5 Summary: Optimizer Comparison (workers={args.num_workers})")
    print(f"{'='*60}")
    print(f"  {'Optimizer':<15} {'AvgLoss':>8} {'AvgAcc%':>8} {'AvgTrain(s)':>12}")
    for opt in opt_names:
        l, a, t = all_results[opt]
        print(f"  {opt:<15} {l:>8.4f} {a:>7.2f}% {t:>12.2f}")
    return all_results


# ==== C6: Without batch normalization ====

def run_no_batchnorm(args, device):
    """C6: train with SGD but all batchnorm layers replaced by Identity."""
    model = resnet18(use_bn=False).to(device)
    loader = get_cifar10_loader(args.data_path, args.batch_size, args.num_workers)
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer('sgd', model.parameters())

    print(f"\n{'='*60}")
    print(f" C6: Without Batch Norm — {args.epochs} epochs (SGD, workers={args.num_workers})")
    print(f"{'='*60}")
    print(f"{'Ep':>3} {'Loss':>8} {'Acc%':>7} {'Train(s)':>9} {'Total(s)':>9}")

    epoch_data = []
    for ep in range(1, args.epochs + 1):
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t0 = time.perf_counter()

        loss, acc, dt, tt = train_one_epoch(model, loader, criterion, optimizer, device)

        if device.type == 'cuda':
            torch.cuda.synchronize()
        t_total = time.perf_counter() - t0

        epoch_data.append((loss, acc, tt, t_total))
        print(f"{ep:>3} {loss:>8.4f} {acc:>6.2f}% {tt:>9.2f} {t_total:>9.2f}")

    avg_loss = sum(r[0] for r in epoch_data) / args.epochs
    avg_acc  = sum(r[1] for r in epoch_data) / args.epochs
    print(f"\nC6 Summary => avg loss: {avg_loss:.4f}, avg acc: {avg_acc:.2f}%")
    return avg_loss, avg_acc


# ==== Q3: param & gradient counting ====

def count_parameters(model, optimizer):
    """Q3: count trainable params, gradient elements, and optimizer state entries."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    grad_count = sum(p.numel() for p in model.parameters() if p.grad is not None)
    print(f"[Q3] Trainable params : {trainable:,}")
    print(f"[Q3] Params w/ grads  : {grad_count:,}")
    # optimizer.state maps each param tensor to its state dict (e.g. SGD's
    # 'momentum_buffer'); count the state entries across all params
    n_states = sum(len(v) for v in optimizer.state.values())
    print(f"[Q3] Optimizer states : {n_states}")
    return trainable, grad_count
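
# For reference (standard SGD-with-momentum behavior, not lab-specific): after
# at least one optimizer.step(), every trainable parameter has a .grad of the
# same shape, so the two counts above coincide; SGD with momentum then keeps
# exactly one 'momentum_buffer' tensor per parameter in optimizer.state.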


# ==== Main ====

def main():
    parser = argparse.ArgumentParser(description='Lab2 — ResNet-18 CIFAR10')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='default=4 (optimal from C3)')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=['sgd', 'sgd_nesterov', 'adam'])
    parser.add_argument('--no_batchnorm', action='store_true',
                        help='C6: disable batch norm layers')
    parser.add_argument('--task', type=str, default='train',
                        choices=['train', 'c3', 'c4', 'c5', 'c6', 'all'],
                        help='which experiment to run')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"ResNet-18 params: {sum(p.numel() for p in resnet18().parameters()):,}")

    if args.task in ('train', 'all'):
        model, optimizer = run_training(args, device)
        count_parameters(model, optimizer)

    if args.task in ('c3', 'all'):
        best_nw = run_worker_sweep(args, device)
        print(f"\n=> Use --num_workers {best_nw} for subsequent experiments")

    if args.task in ('c4', 'all'):
        print(f"\n{'='*60}")
        print(f" C4: GPU vs CPU")
        print(f"{'='*60}")
        run_cpu_vs_gpu(args)

    if args.task in ('c5', 'all'):
        print(f"\n{'='*60}")
        print(f" C5: Optimizer Comparison")
        print(f"{'='*60}")
        run_optimizer_comparison(args, device)

    if args.task in ('c6', 'all'):
        run_no_batchnorm(args, device)


if __name__ == '__main__':
    main()
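
# Example invocations (illustrative; the filename `lab2.py` is assumed):
#     python lab2.py --cuda --task train --epochs 5 --optimizer sgd
#     python lab2.py --cuda --task c3                 # worker sweep + plot
#     python lab2.py --cuda --task c6                 # batchnorm disabled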
