#!/usr/bin/env python3
"""
Parallel pipeline: transcribe remaining segments using concurrent Gemini API calls.
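Resumes from existing full_run.json and appends new results. Resumption is
positional: if the file already holds N segments, the first N audio chunks are
treated as done and only the remaining chunks are processed.

Concurrency model:
  - Gemini API calls: ThreadPoolExecutor (IO-bound, so they are run concurrently)
  - Validation: sequential, one segment at a time (the GPU/CPU validator models
    are treated as not thread-safe)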
"""
import os, sys, json, time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.backend.audio_processor import AudioProcessor, AudioChunk
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig
from src.validators import validate_transcription, cleanup as cleanup_validators


def transcribe_one(transcriber, config, chunk, seg_num, seg_index):
    """Transcribe one audio chunk and tag the outcome with its identity.

    Submitted to a thread pool by ``main``. Any ``Exception`` raised by
    ``transcriber.transcribe_audio`` is captured instead of propagated, so the
    caller always gets back ``(seg_num, seg_index, chunk, outcome)``, where
    ``outcome`` is either the transcriber's response or ``{"error": <message>}``.
    """
    try:
        outcome = transcriber.transcribe_audio(chunk.file_path, config)
    except Exception as exc:  # best-effort: the failure is recorded, not raised
        outcome = {"error": str(exc)}
    return seg_num, seg_index, chunk, outcome


# Validator verdict -> one-character tag shown in the per-segment log line.
# Verdicts not listed here are shown as "?".
_STATUS_SYMBOLS = {"accept": "+", "review": "~", "retry": "?", "reject": "X"}


def _transcribe_all(transcriber, config, remaining, done_count, workers, start):
    """Phase 1: transcribe every chunk in ``remaining`` on a thread pool.

    Chunk ``i`` of ``remaining`` gets absolute index ``done_count + i`` and the
    label ``SEG_<index:04d>``. Progress is printed every 10 completions (and on
    the last one), with rate/ETA measured from ``start``.

    Returns:
        dict mapping absolute index -> ``(seg_num, chunk, raw)``, where ``raw``
        is what ``transcribe_one`` returned (the transcriber's response or an
        ``{"error": ...}`` dict). Because ``transcribe_one`` converts ordinary
        exceptions into error dicts, every submitted chunk gets an entry.
    """
    total = len(remaining)
    raw_results = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = []
        for offset, chunk in enumerate(remaining):
            idx = done_count + offset
            futures.append(pool.submit(
                transcribe_one, transcriber, config, chunk, f"SEG_{idx:04d}", idx
            ))

        # Results arrive in completion order; keying them by index lets
        # Phase 2 restore the original segment order.
        for completed, future in enumerate(as_completed(futures), start=1):
            seg_num, seg_index, chunk, raw = future.result()
            raw_results[seg_index] = (seg_num, chunk, raw)
            if completed % 10 == 0 or completed == total:
                elapsed = time.time() - start
                rate = completed / elapsed if elapsed > 0 else 0
                eta = (total - completed) / rate if rate > 0 else 0
                print(f"  Transcribed {completed}/{total} "
                      f"({elapsed:.0f}s, {rate:.1f}/s, ETA {eta:.0f}s)")
    return raw_results


def _build_entry(seg_num, idx, chunk, raw):
    """Phase 2 for one segment: turn a raw transcription into a result record.

    Copies the chunk's metadata, then does one of three things: records the
    transcription error, marks the segment ``"empty"`` (no native-script text
    returned), or runs ``validate_transcription`` (``language="te"``) and
    stores its scores and verdict. A log line is printed per validated segment.

    ``raw`` is assumed to be a dict, either the transcriber's response or the
    ``{"error": ...}`` dict from ``transcribe_one``. Only the keys read below
    are relied on; the full schema is defined by GeminiTranscriber.
    """
    dur = chunk.duration_sec
    entry = {
        "seg_num": seg_num,
        "seg_index": idx,
        "original_file": chunk.original_segment,
        "audio_file": os.path.basename(chunk.file_path),
        "audio_path": chunk.file_path,
        "chunk_index": chunk.chunk_index,
        "total_chunks": chunk.total_chunks,
        "duration_sec": round(dur, 2),
        "start_ms": chunk.start_ms,
        "end_ms": chunk.end_ms,
    }

    if raw.get("error"):
        entry.update({
            "error": raw["error"],
            "transcription": "", "tagged": "", "romanized": "",
            "detected_language": "", "status": "error",
            "native_ctc": 0, "roman_mms": 0, "combined": 0,
        })
        return entry

    native = raw.get("transcription", "")
    entry.update({
        "transcription": native,
        "tagged": raw.get("tagged", ""),
        "romanized": raw.get("romanized", ""),
        "detected_language": raw.get("detected_language", ""),
        "speaker": raw.get("speaker", {}),
    })

    if not native:
        entry.update({
            "native_ctc": 0, "roman_mms": 0, "combined": 0,
            "status": "empty",
        })
        return entry

    val = validate_transcription(
        chunk.file_path, native, language="te", duration_sec=dur
    )
    entry.update({
        "native_ctc": val.native_ctc_score,
        "roman_mms": val.roman_mms_score,
        "combined": val.combined_score,
        "status": val.status,
        "uroman": val.uroman_romanized,
    })
    symbol = _STATUS_SYMBOLS.get(val.status, "?")
    short = Path(chunk.original_segment).stem
    print(f"  [{symbol}] {seg_num} S={val.combined_score:.3f} | {val.status} | {short[:30]}")
    return entry


def _print_summary(all_segs, tx_time, val_time, total_time):
    """Print verdict counts and the mean positive combined score over ``all_segs``."""
    verdicts = {}
    score_sum, scored = 0, 0
    for seg in all_segs:
        status = seg.get("status", "?")
        verdicts[status] = verdicts.get(status, 0) + 1
        combined = seg.get("combined", 0)
        # Unscored segments (errors / empty) carry combined=0 and are excluded.
        if combined > 0:
            score_sum += combined
            scored += 1

    print(f"\n{'='*60}")
    print("PARALLEL RUN COMPLETE")
    print(f"{'='*60}")
    print(f"Total segments: {len(all_segs)}")
    print(f"Transcription: {tx_time:.0f}s | Validation: {val_time:.0f}s | Total: {total_time:.0f}s")
    print(f"Avg S: {score_sum/max(scored, 1):.3f}")
    print(f"Verdicts: {verdicts}")


def main(data_path="./consistency_test/full_run.json",
         seg_dir="/tmp/maya3_transcribe/pF_BQpHaIdU/extracted/pF_BQpHaIdU/segments",
         workers=4):
    """Resume a run: transcribe, validate and append the remaining chunks.

    Resumption is positional. If ``data_path`` already holds N segments, the
    first N chunks that ``AudioProcessor`` produces for ``seg_dir`` are
    treated as done, and only chunks N.. are processed.

    Args:
        data_path: JSON results file containing a ``"segments"`` list. It is
            rewritten in place at every 50-segment checkpoint and at the end.
        seg_dir: directory of audio segments to chunk and transcribe.
        workers: number of concurrent Gemini API calls in Phase 1.

    The defaults are the values that were previously hard-coded.
    """
    # The file is written as UTF-8 with ensure_ascii=False (Telugu text), so
    # it must be read as UTF-8 too, not with the locale default encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        existing = json.load(f)
    done_count = len(existing["segments"])
    print(f"[Parallel] Resuming from {done_count} completed segments")

    # Rebuild all chunks
    proc = AudioProcessor(max_duration_sec=10.0, min_duration_sec=2.0)
    all_chunks = proc.process_segments_directory(seg_dir, skip_short=True)
    print(f"[Parallel] Total chunks: {len(all_chunks)}, remaining: {len(all_chunks) - done_count}")

    remaining = all_chunks[done_count:]
    if not remaining:
        print("[Parallel] Nothing to do")
        return

    transcriber = GeminiTranscriber()
    config = TranscriptionConfig(
        model="gemini-3-flash-preview", thinking_level="low",
        temperature=0.0, language="Telugu"
    )

    # Phase 1: Parallel transcription (IO-bound)
    print(f"\n[Phase 1] Transcribing {len(remaining)} segments with {workers} workers...")
    start = time.time()
    raw_results = _transcribe_all(transcriber, config, remaining, done_count, workers, start)
    tx_time = time.time() - start
    tx_rate = len(remaining) / tx_time if tx_time > 0 else 0
    print(f"[Phase 1] Done in {tx_time:.0f}s ({tx_rate:.1f} seg/s)")

    # Phase 2: Sequential validation (models not thread-safe)
    print(f"\n[Phase 2] Validating {len(raw_results)} segments...")
    val_start = time.time()
    new_segments = []
    try:
        for idx in sorted(raw_results):
            seg_num, chunk, raw = raw_results[idx]
            new_segments.append(_build_entry(seg_num, idx, chunk, raw))
            if len(new_segments) % 50 == 0:
                _save(data_path, existing, new_segments, start)
                print(f"  [checkpoint] {done_count + len(new_segments)} total saved")
    finally:
        # Always persist what was validated, even when a validator raises or
        # the run is interrupted. Always release the validator models too.
        try:
            _save(data_path, existing, new_segments, start)
        finally:
            cleanup_validators()

    val_time = time.time() - val_start
    total_time = time.time() - start
    _print_summary(existing["segments"] + new_segments, tx_time, val_time, total_time)


def _save(path, existing, new_segments, start):
    all_segs = existing["segments"] + new_segments
    data = {
        "meta": {
            "total_segments": len(all_segs),
            "processed": len(all_segs),
            "elapsed_sec": round(time.time() - start + existing["meta"].get("elapsed_sec", 0), 1),
            "timestamp": datetime.now().isoformat(),
            "partial": True,
        },
        "segments": all_segs,
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
