"""
Generate roofline model plot for Q2.
Parameters: 200 GFLOP/s peak compute, 30 GB/s peak memory bandwidth.
Dot product AI = 2N FLOPs / (2 vectors * N elements * 4 bytes) = 0.25 FLOP/byte.
"""
import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

# === Roofline parameters ===
peak_compute = 200   # GFLOP/s
peak_bw = 30         # GB/s
# Ridge point: the arithmetic intensity where the bandwidth roof
# (peak_bw * AI) meets the compute roof (peak_compute).
ridge_point = peak_compute / peak_bw  # ~6.67 FLOP/byte

# === Arithmetic intensity for dot product ===
# Per element: 1 multiply + 1 add = 2 FLOPs; two input operands are read at
# 4 bytes each (presumably float32 — confirm against the benchmark code).
# Deriving AI from these counts keeps the value in sync with the formula.
flops_per_element = 2
bytes_per_element = 2 * 4
ai_dot = flops_per_element / bytes_per_element  # 0.25 FLOP/byte

# === Data points: (label, AI, GFLOP/s) — all have same AI ===
# Measured throughput per benchmark, keyed by problem size N.
data = {
    'N=1M': [
        ('C1: Simple',    ai_dot, 2.21),
        ('C2: Unrolled',  ai_dot, 6.71),
        ('C3: CBLAS',     ai_dot, 5.59),
        ('C4: Py loop',   ai_dot, 0.005),
        ('C5: np.dot',    ai_dot, 15.93),
    ],
    'N=300M': [
        ('C1: Simple',    ai_dot, 2.14),
        ('C2: Unrolled',  ai_dot, 3.23),
        ('C3: CBLAS',     ai_dot, 2.96),
        ('C4: Py loop',   ai_dot, 0.005),
        ('C5: np.dot',    ai_dot, 3.98),
    ],
}

# === Plot ===
fig, ax = plt.subplots(1, 1, figsize=(12, 8))

# Roofline: attainable performance = min(peak compute, peak BW * AI),
# i.e. the bandwidth-bound slope up to the ridge point, flat beyond it.
ai_range = np.logspace(-2, 2, 500)
roofline = np.minimum(peak_compute, peak_bw * ai_range)
ax.plot(ai_range, roofline, 'k-', linewidth=2.5, label='Roofline')

# Bandwidth ceiling line (extends into compute region as dashed)
bw_line = peak_bw * ai_range
ax.plot(ai_range, bw_line, 'b--', linewidth=1, alpha=0.3,
        label=f'BW ceiling ({peak_bw} GB/s)')

# Compute ceiling line
ax.axhline(y=peak_compute, color='r', linestyle='--', linewidth=1, alpha=0.3,
           label=f'Compute ceiling ({peak_compute} GFLOP/s)')

# Vertical line at dot product AI
ax.axvline(x=ai_dot, color='green', linestyle='--', linewidth=1.5, alpha=0.7,
           label=f'Dot product AI = {ai_dot} FLOP/byte')

# One marker shape and colour per benchmark (C1..C5), shared by both problem
# sizes so the same kernel looks the same at N=1M and N=300M.
markers = ['o', 's', '^', 'v', 'D']
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']

# Every benchmark has the same AI, so the two problem sizes are drawn slightly
# left/right of ai_dot to avoid overlap. The offsets scale with ai_dot so the
# points follow it if it changes (0.22 / 0.28 when ai_dot == 0.25).
x_offset_1m = ai_dot * 0.88
x_offset_300m = ai_dot * 1.12

# (data key, x position, label offset in points): N=1M labels go right,
# N=300M labels go left of their points.
series = [
    ('N=1M', x_offset_1m, (15, 5)),
    ('N=300M', x_offset_300m, (-75, 5)),
]
for size_label, x_pos, text_offset in series:
    for i, (label, _ai, gflops) in enumerate(data[size_label]):
        # Modulo indexing cycles styles if more than 5 benchmarks are added.
        ax.scatter(x_pos, gflops, marker=markers[i % len(markers)], s=120,
                   c=colors[i % len(colors)], edgecolors='black',
                   linewidths=0.5, zorder=5)
        ax.annotate(f'{label}\n{size_label}\n{gflops:.2f}', (x_pos, gflops),
                    textcoords="offset points", xytext=text_offset, fontsize=7,
                    arrowprops=dict(arrowstyle='-', color='gray', lw=0.5))

# Ridge point annotation
ax.annotate(f'Ridge point\nAI={ridge_point:.2f}', (ridge_point, peak_compute),
            textcoords="offset points", xytext=(20, -30), fontsize=9,
            arrowprops=dict(arrowstyle='->', color='black'),
            bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

# Bandwidth ceiling at the dot-product AI: the best any of these kernels can do.
bw_at_ai = peak_bw * ai_dot
ax.axhline(y=bw_at_ai, color='green', linestyle=':', linewidth=1, alpha=0.4)
ax.annotate(f'BW limit @ AI={ai_dot}\n= {bw_at_ai:.1f} GFLOP/s',
            (0.01, bw_at_ai), fontsize=8, color='green', va='bottom')

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Arithmetic Intensity (FLOP/byte)', fontsize=13)
ax.set_ylabel('Attainable Performance (GFLOP/s)', fontsize=13)
ax.set_title('Roofline Model — Dot Product Microbenchmarks\n'
             f'Peak: {peak_compute} GFLOP/s, Memory BW: {peak_bw} GB/s', fontsize=14)
ax.set_xlim(0.01, 100)
ax.set_ylim(0.001, 500)
ax.legend(loc='upper left', fontsize=9)
ax.grid(True, which='both', alpha=0.3)
plt.tight_layout()

# Default keeps the original location; ROOFLINE_OUT overrides it so the
# script can run on machines without that directory.
output_path = os.environ.get('ROOFLINE_OUT', '/home/ubuntu/nyu_hpml/roofline.png')
plt.savefig(output_path, dpi=150)
plt.close(fig)
print(f"Saved {output_path}")
