/*
 * Extra Credit — Load TorchScript ResNet-18 in C++ via LibTorch
 * Runs inference on a dummy CIFAR10-shaped input (1x3x32x32)
 * Usage: ./inference <path-to-resnet18_scripted.pt> [--gpu]
 */

#include <torch/cuda.h>
#include <torch/script.h>

#include <chrono>
#include <iostream>
#include <string>
#include <vector>

int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <model.pt> [--gpu]" << std::endl;
        return 1;
    }

    std::string model_path = argv[1];
    bool use_gpu = (argc >= 3 && std::string(argv[2]) == "--gpu");

    torch::Device device = use_gpu ? torch::kCUDA : torch::kCPU;

    // load serialized TorchScript model
    torch::jit::script::Module model;
    try {
        model = torch::jit::load(model_path, device);
        model.eval();
    } catch (const c10::Error& e) {
        std::cerr << "Failed to load model: " << e.what() << std::endl;
        return 1;
    }
    std::cout << "Model loaded from " << model_path
              << " on " << (use_gpu ? "GPU" : "CPU") << std::endl;

    // dummy CIFAR10 input: batch=1, channels=3, height=32, width=32
    auto input = torch::randn({1, 3, 32, 32}).to(device);
    std::vector<torch::jit::IValue> inputs{input};

    // warmup
    for (int i = 0; i < 10; ++i)
        model.forward(inputs);

    // timed inference
    auto t0 = std::chrono::high_resolution_clock::now();
    int n_runs = 100;
    for (int i = 0; i < n_runs; ++i)
        model.forward(inputs);
    auto t1 = std::chrono::high_resolution_clock::now();

    auto output = model.forward(inputs).toTensor();
    double ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / n_runs;

    auto pred = output.argmax(1).item<int>();
    auto probs = torch::softmax(output, 1);

    std::cout << "Output shape: " << output.sizes() << std::endl;
    std::cout << "Predicted class: " << pred << std::endl;
    std::cout << "Avg latency (" << n_runs << " runs): " << ms << " ms" << std::endl;
    std::cout << "Top-5 probs: " << probs.topk(5) << std::endl;

    return 0;
}
