"""Report generation: CSV summary and Markdown comparison report."""

from __future__ import annotations

import csv
import json
from pathlib import Path
from typing import Any


def load_jsonl(path: str | Path) -> list[dict]:
    results = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                results.append(json.loads(line))
    return results


def generate_summary_csv(results: list[dict], output_path: str | Path) -> None:
    """Write aggregated summary CSV from JSONL results."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = [
        "codec", "sr", "batch_size", "dtype", "clip_seconds",
        "encode_ms", "encode_p50_ms", "encode_p95_ms",
        "decode_ms", "decode_p50_ms", "decode_p95_ms",
        "e2e_ms", "tokens_per_sec", "peak_vram_mb",
        "mel_l1", "mrstft", "sisdr", "hf_energy_delta_db",
        "gpu_name", "torch",
    ]
    existing_fields = set()
    for r in results:
        existing_fields.update(r.keys())
    fieldnames = [f for f in fieldnames if f in existing_fields]

    with open(output_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        for row in results:
            writer.writerow(row)


def _modeling_pain_score(codec_name: str, results: list[dict]) -> tuple[int, str]:
    """Heuristic 'modeling pain' score for LM integration.

    1 = trivial (single stream, simple tokens)
    5 = painful (multi-stream, extra token types, complex structure)
    """
    codec_name = codec_name.lower()
    if "xcodec2" in codec_name:
        return 1, "Single VQ stream @ 50 TPS - trivial for LM"
    if "wavtokenizer" in codec_name:
        return 2, "Single stream @ 40 TPS - easy, needs codes_to_features at inference"
    if "bicodec" in codec_name:
        return 3, "Dual stream (semantic 50 TPS + fixed global) - moderate, requires both for decode"
    if "snac" in codec_name:
        return 4, "Multi-scale hierarchical (3+ levels) - needs interleaving scheme for LM"
    return 5, "Unknown structure"


def _speed_section(results: list[dict]) -> list[str]:
    """Render the speed-ranking table, sorted by end-to-end latency (ascending)."""
    lines = [
        "## Speed Ranking (encode + decode)\n",
        "| Codec | BS | dtype | Encode (ms) | Decode (ms) | E2E (ms) | TPS | VRAM (MB) |",
        "|-------|---:|-------|------------:|------------:|---------:|----:|----------:|",
    ]
    # Rows missing e2e_ms sort last via the 9999 sentinel.
    for r in sorted(results, key=lambda r: r.get("e2e_ms", 9999)):
        lines.append(
            f"| {r.get('codec', '')} | {r.get('batch_size', '')} | {r.get('dtype', '')} "
            f"| {r.get('encode_ms', 0):.1f} | {r.get('decode_ms', 0):.1f} "
            f"| {r.get('e2e_ms', 0):.1f} | {r.get('tokens_per_sec', 0):.0f} "
            f"| {r.get('peak_vram_mb', 0):.0f} |"
        )
    lines.append("")
    return lines


def _quality_section(eval_results: list[dict]) -> list[str]:
    """Render the quality table, sorted by mel-L1 distance (lower = better)."""
    lines = [
        "## Quality Ranking\n",
        "| Codec | Mel L1 | MR-STFT | SI-SDR (dB) | HF Delta (dB) |",
        "|-------|-------:|--------:|------------:|--------------:|",
    ]
    for r in sorted(eval_results, key=lambda r: r.get("mel_l1", 9999)):
        lines.append(
            f"| {r.get('codec', '')} "
            f"| {r.get('mel_l1', 0):.4f} | {r.get('mrstft', 0):.4f} "
            f"| {r.get('sisdr', 0):.1f} | {r.get('hf_energy_delta_db', 0):.1f} |"
        )
    lines.append("")
    return lines


def _pain_section(results: list[dict]) -> list[str]:
    """Render the modeling-pain table, one row per distinct codec (first occurrence)."""
    lines = [
        "## Modeling Pain Score\n",
        "| Codec | Pain (1-5) | Notes |",
        "|-------|----------:|-------|",
    ]
    seen_codecs: set[str] = set()
    for r in results:
        codec = r.get("codec", "")
        if codec in seen_codecs:
            continue
        seen_codecs.add(codec)
        score, notes = _modeling_pain_score(codec, results)
        lines.append(f"| {codec} | {score} | {notes} |")
    lines.append("")
    return lines


def _recommendation_section(
    results: list[dict],
    eval_results: list[dict] | None,
) -> list[str]:
    """Render the top-2 recommendation from combined speed + quality ranks."""
    lines = ["## Recommendation\n"]
    if not (eval_results and results):
        lines.append("Run both bench and eval to generate a combined recommendation.")
        lines.append("")
        return lines

    # Best (minimum) observed value per codec; lower is better for both metrics.
    speed_rank: dict[str, float] = {}
    for r in results:
        c = r.get("codec", "")
        e2e = r.get("e2e_ms", 9999)
        if c not in speed_rank or e2e < speed_rank[c]:
            speed_rank[c] = e2e
    quality_rank: dict[str, float] = {}
    for r in eval_results:
        c = r.get("codec", "")
        mel = r.get("mel_l1", 9999)
        if c not in quality_rank or mel < quality_rank[c]:
            quality_rank[c] = mel

    all_codecs = set(speed_rank) | set(quality_rank)
    # Hoisted out of the loop: the originals re-sorted the value lists for
    # every codec. A codec missing from one metric gets the worst rank.
    speed_order = sorted(speed_rank.values())
    quality_order = sorted(quality_rank.values())
    worst = len(all_codecs)

    # Iterate codecs in sorted order so dict insertion order — and therefore
    # tie-breaking in the stable sort below — is deterministic across runs
    # (a raw set would iterate in hash order).
    combined: dict[str, int] = {}
    for c in sorted(all_codecs):
        s = speed_order.index(speed_rank[c]) if c in speed_rank else worst
        q = quality_order.index(quality_rank[c]) if c in quality_rank else worst
        combined[c] = s + q

    top2 = sorted(combined, key=combined.get)[:2]
    lines.append(f"Top 2 for production tokenization: **{', '.join(top2)}**\n")
    lines.append("Based on combined speed + quality ranking. Review the numbers above for specifics.")
    lines.append("")
    return lines


def generate_report_md(
    results: list[dict],
    output_path: str | Path,
    eval_results: list[dict] | None = None,
) -> None:
    """Generate markdown comparison report.

    Sections: header, speed ranking, quality ranking (only when
    *eval_results* is given), modeling-pain scores, and a combined
    recommendation (only when both inputs are non-empty).

    Args:
        results: Benchmark timing records.
        output_path: Destination ``.md`` path; parent dirs are created.
        eval_results: Optional quality-metric records.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    lines: list[str] = ["# CodecBench Report\n"]
    if results:
        # GPU/torch versions are assumed uniform across a run; read the first row.
        lines.append(f"**GPU**: {results[0].get('gpu_name', 'N/A')}")
        lines.append(f"**PyTorch**: {results[0].get('torch', 'N/A')}")
        lines.append("")

    lines.extend(_speed_section(results))
    if eval_results:
        lines.extend(_quality_section(eval_results))
    lines.extend(_pain_section(results))
    lines.extend(_recommendation_section(results, eval_results))

    output_path.write_text("\n".join(lines), encoding="utf-8")
