"""CUDA event-based timing with proper synchronization."""

from __future__ import annotations

from dataclasses import dataclass, field

import torch


@dataclass
class TimingResult:
    """Result of a timed region."""
    gpu_ms: float  # elapsed GPU time in milliseconds, as reported by CUDA events
    label: str = ""  # optional human-readable name for the timed region


class CUDATimer:
    """Context-manager and manual API for GPU-accurate timing via CUDA events.

    Usage::

        with CUDATimer() as t:
            run_kernel()
        print(t.elapsed_ms)

    Or manually: ``record_start()`` ... ``record_end()`` (which returns ms).
    """

    def __init__(self, device: str = "cuda"):
        # Device whose work is synchronized before reading the elapsed time.
        self._device = device
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        # Fix: initialize so `elapsed_ms` never raises AttributeError when
        # read before a timing region has completed.
        self._elapsed: float = 0.0

    def record_start(self) -> None:
        """Record the start event on the current stream."""
        self._start.record()

    def record_end(self) -> float:
        """Record end event, synchronize, return elapsed ms."""
        self._end.record()
        # Fix: synchronize the configured device — previously the `device`
        # ctor argument was ignored and the default device was synced.
        torch.cuda.synchronize(self._device)
        return self._start.elapsed_time(self._end)

    def __enter__(self) -> "CUDATimer":
        self.record_start()
        return self

    def __exit__(self, *exc) -> None:
        # Timing is captured even if the body raised; the exception still
        # propagates (we return None).
        self._elapsed = self.record_end()

    @property
    def elapsed_ms(self) -> float:
        """Elapsed GPU time in ms of the last completed region (0.0 if none)."""
        return self._elapsed


@dataclass
class BenchStats:
    """Aggregated statistics from repeated timing runs."""
    label: str
    times_ms: list[float] = field(default_factory=list)

    @property
    def n(self) -> int:
        """Number of recorded samples."""
        return len(self.times_ms)

    @property
    def mean_ms(self) -> float:
        """Arithmetic mean in ms (0.0 when no samples were recorded)."""
        count = self.n
        return sum(self.times_ms) / count if count else 0.0

    @property
    def p50_ms(self) -> float:
        """Median (50th percentile) in ms."""
        return self._percentile(50)

    @property
    def p95_ms(self) -> float:
        """95th percentile in ms."""
        return self._percentile(95)

    @property
    def min_ms(self) -> float:
        """Fastest sample in ms (0.0 when empty)."""
        return min(self.times_ms, default=0.0)

    @property
    def max_ms(self) -> float:
        """Slowest sample in ms (0.0 when empty)."""
        return max(self.times_ms, default=0.0)

    def _percentile(self, p: float) -> float:
        """Linearly interpolated percentile over the sorted samples."""
        samples = sorted(self.times_ms)
        if not samples:
            return 0.0
        # Fractional rank into the sorted samples; interpolate between the
        # two neighboring values (hi is clamped at the last index).
        rank = (p / 100.0) * (len(samples) - 1)
        below = int(rank)
        above = min(below + 1, len(samples) - 1)
        weight = rank - below
        return samples[below] + (samples[above] - samples[below]) * weight

    def as_dict(self) -> dict:
        """Summary of every statistic, values rounded to 3 decimal places."""
        summary: dict = {"label": self.label, "n": self.n}
        for key in ("mean_ms", "p50_ms", "p95_ms", "min_ms", "max_ms"):
            summary[key] = round(getattr(self, key), 3)
        return summary


def measure_peak_vram(device: str = "cuda") -> int:
    """Peak bytes allocated on *device* since the last peak-stats reset.

    Pair with ``reset_vram_stats`` to bracket a measurement window.
    """
    peak_bytes = torch.cuda.max_memory_allocated(device)
    return peak_bytes


def reset_vram_stats(device: str = "cuda") -> None:
    """Reset the peak-memory counters on *device*, starting a fresh window for ``measure_peak_vram``."""
    torch.cuda.reset_peak_memory_stats(device)
