"""
Lab 2 — HPML Spring 2026
Part B: TorchScript Optimization (C7–C10)
Converts trained ResNet-18 to TorchScript, evaluates, and benchmarks.
"""

import torch
import torch.nn as nn
import argparse
import time

from lab2 import (resnet18, get_cifar10_loader, get_optimizer,
                  train_one_epoch)


# ==== C7: Convert to TorchScript ====

def train_and_script(args, device):
    """C7: train a BN ResNet-18, then compile it with ``torch.jit.script``.

    The model is trained for ``args.epochs`` epochs (SGD optimizer from
    ``get_optimizer``, cross-entropy loss) and switched to eval mode. It is
    then scripted, written to ``resnet18_scripted.pt``, and read back onto
    ``device``. A random 1x3x32x32 input is pushed through both the in-memory
    scripted module and the reloaded one. The largest absolute output
    difference is printed as a save/load round-trip check.

    Returns:
        tuple: ``(model, loaded)`` -- the trained eager model (left in eval
        mode) and the TorchScript module reloaded from disk.
    """
    net = resnet18(use_bn=True).to(device)
    train_loader = get_cifar10_loader(args.data_path, args.batch_size, args.num_workers)
    loss_fn = nn.CrossEntropyLoss()
    opt = get_optimizer('sgd', net.parameters())

    print(f"Training {args.epochs} epochs before scripting...")
    for epoch in range(1, args.epochs + 1):
        epoch_stats = train_one_epoch(net, train_loader, loss_fn, opt, device)
        epoch_loss, epoch_acc, _, _ = epoch_stats
        print(f"  Ep {epoch}: loss={epoch_loss:.4f}  acc={epoch_acc:.2f}%")

    # Switch to eval mode first so the scripted copy (and the saved file)
    # start out in inference mode.
    net.eval()
    compiled = torch.jit.script(net)

    # Round-trip through disk to prove the serialized artifact is usable.
    ckpt = "resnet18_scripted.pt"
    torch.jit.save(compiled, ckpt)
    reloaded = torch.jit.load(ckpt, map_location=device)
    print(f"\nC7: Scripted model saved -> {ckpt}")

    probe = torch.randn(1, 3, 32, 32, device=device)
    with torch.no_grad():
        in_memory_out = compiled(probe)
        from_disk_out = reloaded(probe)
        max_abs_diff = (in_memory_out - from_disk_out).abs().max().item()
    print(f"C7: Save/load verification — max diff: {max_abs_diff:.1e}")

    return net, reloaded


# ==== C8: Print Model Graph ====

def print_graph(scripted_model):
    """C8: print ``scripted_model.graph`` (its TorchScript IR) under a banner."""
    rule = '=' * 60
    for header_line in (f"\n{rule}", " C8: TorchScript Model Graph", rule):
        print(header_line)
    print(scripted_model.graph)


# ==== C9: Evaluate on CIFAR10 test set ====

def evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Args:
        model: eager ``nn.Module`` or TorchScript module. It is assumed to
            return per-class scores shaped ``(batch, num_classes)``; the
            argmax over dim 1 is taken as the prediction.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device (``torch.device`` or string) the model lives on. Each
            batch is moved there.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: if ``loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    # Keep the hit count on-device. Calling .item() per batch would force a
    # host/device synchronization on every iteration.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = model(inputs).max(dim=1)
            correct += predicted.eq(targets).sum()
            total += targets.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return 100.0 * correct.item() / total


def run_evaluation(pytorch_model, scripted_model, args, device):
    """C9: score the eager and TorchScript models on the CIFAR-10 test split.

    Prints a two-row accuracy table. Returns ``(eager_acc, scripted_acc)``
    as percentages, in that order.
    """
    test_loader = get_cifar10_loader(
        args.data_path, args.batch_size, args.num_workers, train=False,
    )

    pytorch_model.eval()
    accuracies = tuple(
        evaluate(candidate, test_loader, device)
        for candidate in (pytorch_model, scripted_model)
    )

    rule = '=' * 60
    print(f"\n{rule}")
    print(" C9: Test Set Accuracy")
    print(rule)
    for label, acc in zip(("PyTorch model:", "TorchScript model:"), accuracies):
        print(f"  {label:<18} {acc:.2f}%")
    return accuracies


# ==== C10: Latency Comparison ====

def measure_latency(model, device, n_warmup=50, n_runs=200, batch_size=1):
    """Return the mean wall-clock latency, in milliseconds, of one forward pass.

    The input is a random ``(batch_size, 3, 32, 32)`` tensor created on
    ``device``. With the default ``batch_size=1`` this is single-image
    latency.

    Args:
        model: eager or TorchScript module, already placed on ``device``.
            It is put into eval mode.
        device: ``torch.device`` or device string (e.g. ``'cpu'``, ``'cuda'``).
        n_warmup: untimed forward passes run before measuring.
        n_runs: number of timed forward passes; must be >= 1.
        batch_size: number of images in the benchmark input; must be >= 1.

    Returns:
        float: arithmetic mean of the ``n_runs`` per-call latencies (ms).

    Raises:
        ValueError: if ``n_runs`` or ``batch_size`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    dummy = torch.randn(batch_size, 3, 32, 32, device=device)
    model.eval()

    def _sync():
        # CUDA launches are asynchronous. Wait for queued kernels so the
        # timer measures completed work, not just the launch.
        if is_cuda:
            torch.cuda.synchronize(device)

    times_ms = []
    with torch.no_grad():
        # Warmup: TorchScript's executor profiles/optimizes the graph during
        # its first calls, and allocator/library caches settle.
        for _ in range(n_warmup):
            model(dummy)
        _sync()

        for _ in range(n_runs):
            # The previous iteration (or the warmup) already synchronized.
            t0 = time.perf_counter()
            model(dummy)
            _sync()
            times_ms.append((time.perf_counter() - t0) * 1000.0)

    return sum(times_ms) / len(times_ms)


def run_latency_comparison(pytorch_model, scripted_model, args):
    """C10: benchmark eager PyTorch vs TorchScript latency on CPU and GPU.

    A fresh copy of each model is placed on every benchmarked device. The
    eager copy is a new ``resnet18(use_bn=True)`` loaded with
    ``pytorch_model``'s weights. The TorchScript copy is ``scripted_model``
    serialized in memory and reloaded with ``map_location``. The GPU column
    is skipped (shown as N/A) when CUDA is not available, so the script also
    runs on CPU-only machines.

    Args:
        pytorch_model: trained eager model whose ``state_dict`` is copied.
        scripted_model: TorchScript module to benchmark.
        args: parsed CLI arguments (unused; kept for interface compatibility).

    Returns:
        dict: maps ``'cpu'`` (and ``'cuda'`` when available) to
        ``(eager_ms, torchscript_ms)``.
    """
    import io

    # Serialize the module we were actually given, once. A fresh copy is then
    # reloaded per device, instead of re-reading a hard-coded file from disk.
    script_buf = io.BytesIO()
    torch.jit.save(scripted_model, script_buf)

    dev_names = ['cpu']
    if torch.cuda.is_available():
        dev_names.append('cuda')

    results = {}
    for dev_name in dev_names:
        device = torch.device(dev_name)

        # Fresh model copies per device; this avoids stale CUDA state on CPU runs.
        pt_model = resnet18(use_bn=True)
        pt_model.load_state_dict(pytorch_model.state_dict())
        pt_model = pt_model.to(device).eval()

        script_buf.seek(0)
        ts_model = torch.jit.load(script_buf, map_location=device)

        lat_pt = measure_latency(pt_model, device)
        lat_ts = measure_latency(ts_model, device)
        results[dev_name] = (lat_pt, lat_ts)

    def _cell(dev_name, idx):
        """Format one latency cell, or N/A when that device was not benchmarked."""
        if dev_name not in results:
            return f"{'N/A':>10}"
        return f"{results[dev_name][idx]:>10.2f}"

    print(f"\n{'='*60}")
    print(" C10: Latency Comparison (single image, ms)")
    print(f"{'='*60}")
    print(f"  {'':15} {'CPU (ms)':>10} {'GPU (ms)':>10}")
    print(f"  {'PyTorch':15} {_cell('cpu', 0)} {_cell('cuda', 0)}")
    print(f"  {'TorchScript':15} {_cell('cpu', 1)} {_cell('cuda', 1)}")

    for dev_name, (pt, ts) in results.items():
        speedup = pt / ts if ts > 0 else float('inf')
        print(f"  {dev_name.upper()} speedup: {speedup:.2f}x")
    if 'cuda' not in results:
        print("  GPU benchmark skipped: CUDA is not available.")

    return results


# ==== Main ====

def main():
    """Parse CLI flags, then run the C7 -> C10 steps in order."""
    parser = argparse.ArgumentParser(description='Lab2 Part B — TorchScript')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=5)
    args = parser.parse_args()

    use_cuda = args.cuda and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        # Don't silently ignore an explicit request for the GPU.
        print("Warning: --cuda requested but CUDA is not available; using CPU.")
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Device: {device}\n")

    # C7 — train then script
    pytorch_model, scripted_model = train_and_script(args, device)

    # C8 — print graph
    print_graph(scripted_model)

    # C9 — test accuracy
    run_evaluation(pytorch_model, scripted_model, args, device)

    # C10 — latency benchmark
    run_latency_comparison(pytorch_model, scripted_model, args)


if __name__ == '__main__':
    main()
