#!/usr/bin/env python3
"""
Transcribe indic-asr-validation-300k with Qwen3-ASR-mixed via vLLM.

Processes parquet shards one at a time to control memory.
Saves transcriptions as JSONL incrementally with --resume support.

Usage:
    python transcribe_300k.py \
        --checkpoint /home/ubuntu/training/checkpoints/qwen3-asr-mixed-ckpt-120000 \
        --output /home/ubuntu/training/transcriptions_300k/ckpt-120000-mixed.jsonl \
        --batch-size 32
"""

import argparse
import glob
import json
import os
import re
import sys
import time

os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

import numpy as np
import pandas as pd
import soundfile as sf
from io import BytesIO

sys.path.insert(0, "/home/ubuntu/training/qwen3-asr-1.7b-phase2-sft")

# Matches an optional leading "<|xx|>" language tag (two lowercase letters)
# plus any trailing whitespace; strip_lang_tag() removes it from references.
LANG_TAG_RE = re.compile(r"^<\|[a-z]{2}\|>\s*")
# ISO-639-1 code -> language name passed to model.transcribe(language=...).
# Codes missing from this map resolve to None, i.e. model auto-detection.
LANG_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "en": "English", "gu": "Gujarati",
    "hi": "Hindi", "kn": "Kannada", "ml": "Malayalam", "mr": "Marathi",
    "or": "Odia", "pa": "Punjabi", "ta": "Tamil", "te": "Telugu",
}


def strip_lang_tag(text: str) -> str:
    """Drop a leading ``<|xx|>`` language tag from *text* and trim whitespace."""
    untagged = LANG_TAG_RE.sub("", text)
    return untagged.strip()


def decode_audio(audio_dict):
    """Decode one audio field from a parquet row.

    Accepts the HF-datasets style dict carrying an ``array``, ``bytes``, or
    ``path`` entry (tried in that order, skipping empty/None values); for
    convenience also accepts a raw ``bytes`` payload or a filesystem path.

    Returns:
        (wav, sr): float32 numpy array and its sample rate. Defaults to
        16000 Hz when a decoded array has no ``sampling_rate`` entry.

    Raises:
        ValueError: when no usable audio payload is found.
    """
    if isinstance(audio_dict, dict):
        # BUGFIX: the old `"array" in audio_dict` check matched even when the
        # value was None (rows that only carry bytes), crashing in np.array
        # instead of falling through to the bytes/path branches.
        array = audio_dict.get("array")
        if array is not None:
            # asarray avoids a copy when the array is already float32.
            return np.asarray(array, dtype=np.float32), audio_dict.get("sampling_rate", 16000)
        raw = audio_dict.get("bytes")
        if raw:
            wav, sr = sf.read(BytesIO(raw))
            return wav.astype(np.float32), sr
        path = audio_dict.get("path")
        if path:
            wav, sr = sf.read(path)
            return wav.astype(np.float32), sr
    elif isinstance(audio_dict, (bytes, bytearray)):
        wav, sr = sf.read(BytesIO(audio_dict))
        return wav.astype(np.float32), sr
    elif isinstance(audio_dict, str):
        wav, sr = sf.read(audio_dict)
        return wav.astype(np.float32), sr
    raise ValueError(f"Cannot decode audio: {type(audio_dict)}")


def main():
    """CLI entry point: transcribe every parquet shard, streaming JSONL output.

    Shards are loaded one at a time to bound memory. Results are written
    after every batch and flushed per shard, so a crash loses at most one
    shard's unflushed batches; with --resume, segment_ids already present
    in the output file are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--dataset-dir", default="/home/ubuntu/training/datasets/indic-asr-validation-300k/data")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.85)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    os.makedirs(os.path.dirname(args.output), exist_ok=True)

    # Get parquet files (sorted so shard order is stable across runs/resumes).
    parquet_files = sorted(glob.glob(os.path.join(args.dataset_dir, "*.parquet")))
    print(f"Found {len(parquet_files)} parquet shards")

    # Resume support: collect segment_ids already written so they are skipped.
    done_ids = set()
    if args.resume and os.path.exists(args.output):
        with open(args.output) as f:
            for line in f:
                done_ids.add(json.loads(line)["segment_id"])
        print(f"Resuming: {len(done_ids)} already done")

    # Deferred import: pulls in vLLM, which is heavy and GPU-dependent.
    from qwen_asr import Qwen3ASRModel
    print(f"Loading model: {args.checkpoint}")
    t0 = time.time()
    model = Qwen3ASRModel.LLM(
        model=args.checkpoint,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_inference_batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
    )
    print(f"Model loaded in {time.time() - t0:.1f}s")

    total_done = len(done_ids)
    total_skipped = 0
    total_audio_sec = 0.0
    t_start = time.time()

    # BUGFIX: the output handle was opened without `with`/`finally`, so it
    # leaked (and could lose buffered lines) if any exception escaped.
    with open(args.output, "a" if args.resume else "w") as outfile:
        for shard_idx, pf in enumerate(parquet_files):
            shard_name = os.path.basename(pf)
            df = pd.read_parquet(pf)
            print(f"\n[Shard {shard_idx+1}/{len(parquet_files)}] {shard_name}: {len(df)} rows")

            # Filter already-done rows when resuming.
            if done_ids:
                mask = ~df["segment_id"].isin(done_ids)
                df = df[mask].reset_index(drop=True)
                if len(df) == 0:
                    print("  All done, skipping")
                    continue

            # Process in batches
            for batch_start in range(0, len(df), args.batch_size):
                batch_df = df.iloc[batch_start:batch_start + args.batch_size]

                audio_inputs = []
                metas = []
                skip_count = 0
                for _, row in batch_df.iterrows():
                    try:
                        wav, sr = decode_audio(row["audio"])
                        # BUGFIX: build the full meta dict BEFORE appending
                        # anything. Previously audio_inputs was appended
                        # first, so a failure in meta construction (bad
                        # text/duration) left the two lists misaligned and
                        # mis-paired every later transcription in the batch.
                        meta = {
                            "segment_id": row["segment_id"],
                            "reference": strip_lang_tag(str(row["text"])),
                            "duration": float(row["duration"]),
                            "lang": row["lang"],
                            "source": row.get("source", ""),
                        }
                    except Exception:
                        skip_count += 1
                        continue
                    audio_inputs.append((wav, sr))
                    metas.append(meta)

                if skip_count:
                    # BUGFIX: skipped rows were counted but never reported,
                    # making missing output rows impossible to explain.
                    total_skipped += skip_count
                    print(f"\n  WARNING: skipped {skip_count} undecodable rows")

                if not audio_inputs:
                    continue

                # Force the known language where mapped; None => auto-detect.
                forced_langs = [LANG_CODE_TO_NAME.get(m["lang"]) for m in metas]

                try:
                    transcriptions = model.transcribe(
                        audio=audio_inputs,
                        language=forced_langs,
                        return_time_stamps=False,
                    )
                except Exception as e:
                    print(f"  Batch failed ({e}), trying without forced lang...")
                    try:
                        transcriptions = model.transcribe(
                            audio=audio_inputs, language=None, return_time_stamps=False,
                        )
                    except Exception as e2:
                        # Last resort: emit empty hypotheses so the run keeps
                        # going and every segment still gets an output row.
                        print(f"  FAILED: {e2}, writing empty hyps")
                        transcriptions = [type('T', (), {'text': '', 'language': ''})() for _ in metas]

                for j, t in enumerate(transcriptions):
                    result = {**metas[j], "hypothesis": t.text, "detected_language": t.language}
                    outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
                    total_audio_sec += metas[j]["duration"]

                total_done += len(metas)
                elapsed = time.time() - t_start
                rate = total_done / elapsed if elapsed > 0 else 0
                eta_min = (300000 - total_done) / rate / 60 if rate > 0 else 0

                print(f"  [{total_done:>7}/~300k] {rate:.0f} samp/s | ETA={eta_min:.0f}min", end="\r", flush=True)

            # Flush per shard so a crash costs at most one shard of batches.
            outfile.flush()
            elapsed = time.time() - t_start
            rate = total_done / elapsed if elapsed > 0 else 0
            eta_min = (300000 - total_done) / rate / 60 if rate > 0 else 0
            rtf = elapsed / total_audio_sec if total_audio_sec > 0 else 0
            print(f"\n  Shard done. Total={total_done} | {rate:.0f} samp/s | RTF={rtf:.4f} | ETA={eta_min:.0f}min")

    elapsed = time.time() - t_start
    rtf = elapsed / total_audio_sec if total_audio_sec > 0 else 0
    print(f"\nDONE: {total_done} samples in {elapsed:.0f}s ({total_audio_sec:.0f}s audio, RTF={rtf:.4f})")
    if total_skipped:
        print(f"WARNING: {total_skipped} rows skipped as undecodable")


# Script entry point; keeps main() importable without side effects.
if __name__ == "__main__":
    main()
