#!/usr/bin/env python3
"""
Evaluation script for a fine-tuned Cohere Transcribe model.
Computes per-language WER and CER on a dev/test set.

Usage:
    python evaluate.py \
        --checkpoint ./checkpoints/final-ema/model \
        --test-manifest /data/mel_shards/test_manifest.parquet \
        --mel-shards-dir /data/mel_shards/ \
        --output results.json
"""

import argparse
import json
import time
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from tokenizer_utils import decode_tokens, load_extended_tokenizer

# Use jiwer for WER/CER computation
try:
    from jiwer import wer as compute_wer, cer as compute_cer
except ImportError:
    print("Install jiwer: pip install jiwer")
    raise


def evaluate(
    model,
    processor,
    test_manifest: pd.DataFrame,
    mel_shards_dir: str,
    device: torch.device,
    batch_size: int = 16,
):
    """Evaluate model on test set, return per-language WER/CER."""
    import tarfile
    import io

    model.eval()
    results_by_lang = defaultdict(lambda: {'refs': [], 'hyps': [], 'durations': []})
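    # Accumulate per-language reference/hypothesis pairs and utterance durations;
    # metrics are computed once all shards have been read.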

    # Group by shard for efficient reading
    shard_groups = test_manifest.groupby('shard_path')
    total_utts = len(test_manifest)
    processed = 0

    tokenizer = processor.tokenizer

    for shard_path, group in shard_groups:
        # Resolve relative shard paths against mel_shards_dir; absolute paths pass through.
        shard_file = Path(shard_path)
        if not shard_file.is_absolute():
            shard_file = Path(mel_shards_dir) / shard_file
        # Read mel features from shard
        try:
            with tarfile.open(shard_file, 'r') as tar:
                for _, row in group.iterrows():
                    segment_id = row['segment_id']
                    language = row['language']

                    # Read mel
                    try:
                        mel_member = tar.getmember(f"{segment_id}.mel.npy")
                        mel_data = np.load(io.BytesIO(tar.extractfile(mel_member).read()))
                    except Exception:
                        # Missing or unreadable mel entry; skip this segment.
                        continue

                    mel_tensor = torch.from_numpy(mel_data).unsqueeze(0).to(device, dtype=torch.bfloat16)
                    mel_length = torch.tensor([mel_data.shape[-1]], device=device)
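                    # The mel array is assumed to be stored as (n_mels, num_frames),
                    # so shape[-1] above is the frame count passed to generate().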

                    # Build language-conditioned decoder prompt
                    prompt_str = model.build_prompt(language=language, punctuation=True)
                    prompt_ids = tokenizer.encode(prompt_str, add_special_tokens=False)
                    decoder_input_ids = torch.tensor([prompt_ids], device=device)

                    # Generate transcription with language prompt
                    with torch.no_grad():
                        outputs = model.generate(
                            input_features=mel_tensor,
                            length=mel_length,
                            decoder_input_ids=decoder_input_ids,
                            max_new_tokens=256,
                        )

                    hyp = decode_tokens(tokenizer, outputs[0]).strip()
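                    # decode_tokens (from tokenizer_utils) is assumed to strip the
                    # decoder prompt and special tokens from the generated ids.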
                    ref = row.get('transcript', '').strip()

                    results_by_lang[language]['refs'].append(ref)
                    results_by_lang[language]['hyps'].append(hyp)
                    results_by_lang[language]['durations'].append(row.get('duration_s', 0))

                    processed += 1
                    if processed % 500 == 0:
                        print(f"  Evaluated {processed}/{total_utts} utterances")

        except Exception as e:
            print(f"Error reading shard {shard_path}: {e}")

    # Compute metrics per language
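    # Note: jiwer's list-based wer()/cer() return corpus-level rates (total edit
    # distance over total reference length), not a mean of per-utterance rates.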
    metrics = {}
    for lang, data in results_by_lang.items():
        refs = data['refs']
        hyps = data['hyps']

        # Filter out empty refs
        valid = [(r, h) for r, h in zip(refs, hyps) if r.strip()]
        if not valid:
            continue
        valid_refs, valid_hyps = zip(*valid)

        lang_wer = compute_wer(list(valid_refs), list(valid_hyps))
        lang_cer = compute_cer(list(valid_refs), list(valid_hyps))

        metrics[lang] = {
            'wer': round(lang_wer * 100, 2),
            'cer': round(lang_cer * 100, 2),
            'num_utterances': len(valid_refs),
            'total_hours': round(sum(data['durations']) / 3600, 2),
        }

    # Overall metrics — pair refs and hyps BEFORE filtering to maintain alignment
    all_pairs = [(r, h) for d in results_by_lang.values()
                 for r, h in zip(d['refs'], d['hyps'])]
    valid_pairs = [(r, h) for r, h in all_pairs if r.strip()]
    if valid_pairs:
        all_valid_refs, all_valid_hyps = zip(*valid_pairs)
        metrics['overall'] = {
            'wer': round(compute_wer(list(all_valid_refs), list(all_valid_hyps)) * 100, 2),
            'cer': round(compute_cer(list(all_valid_refs), list(all_valid_hyps)) * 100, 2),
            'num_utterances': len(all_valid_refs),
        }

    return metrics


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
    parser.add_argument('--test-manifest', required=True, help='Path to test manifest parquet')
    parser.add_argument('--mel-shards-dir', required=True)
    parser.add_argument('--output', default='eval_results.json')
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--device', default='cuda:0')
    args = parser.parse_args()

    device = torch.device(args.device)

    print(f"Loading model from {args.checkpoint}...")
    processor = AutoProcessor.from_pretrained(args.checkpoint, trust_remote_code=True)
    processor.tokenizer = load_extended_tokenizer(args.checkpoint)
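    # trust_remote_code is required here: the checkpoint ships custom model code
    # (e.g. the build_prompt() method used during evaluation).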
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        args.checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device)
    model.eval()

    print(f"Loading test manifest from {args.test_manifest}...")
    test_df = pd.read_parquet(args.test_manifest)
    print(f"Test set: {len(test_df)} utterances")

    print("Evaluating...")
    start = time.time()
    metrics = evaluate(model, processor, test_df, args.mel_shards_dir, device, args.batch_size)
    elapsed = time.time() - start

    print(f"\nEvaluation complete in {elapsed:.1f}s")
    print("=" * 60)
    print(f"{'Language':<12} {'WER':>8} {'CER':>8} {'Utts':>8} {'Hours':>8}")
    print("-" * 60)
    for lang in sorted(metrics.keys()):
        if lang == 'overall':
            continue
        m = metrics[lang]
        print(f"{lang:<12} {m['wer']:>7.2f}% {m['cer']:>7.2f}% {m['num_utterances']:>8} {m.get('total_hours', 0):>7.1f}h")
    if 'overall' in metrics:
        print("-" * 60)
        m = metrics['overall']
        print(f"{'OVERALL':<12} {m['wer']:>7.2f}% {m['cer']:>7.2f}% {m['num_utterances']:>8}")
    print("=" * 60)

    with open(args.output, 'w') as f:
        json.dump(metrics, f, indent=2)
    print(f"Results saved to {args.output}")


if __name__ == '__main__':
    main()
