#!/usr/bin/env python3
"""
Full pipeline run: all segments >= 2s for a video, with splitting, polishing,
transcription, and validation. Outputs numbered segments for dashboard review.

Usage:
    cd /home/ubuntu/maya3_transcribe && source venv/bin/activate
    python bin/full_run.py /tmp/maya3_transcribe/pF_BQpHaIdU/extracted/pF_BQpHaIdU/segments
"""
import os, sys, json, time
from pathlib import Path
from datetime import datetime

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.backend.audio_processor import AudioProcessor, AudioChunk
from src.backend.audio_polisher import AudioPolisher
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig
from src.validators import validate_transcription, cleanup as cleanup_validators

# ISO-639-1 codes expected by the validators; unknown languages fall back
# to Telugu ("te"), matching the script's default --language.
LANG_CODE_MAP = {
    "Telugu": "te", "Hindi": "hi", "Tamil": "ta", "Kannada": "kn",
    "Malayalam": "ml", "Bengali": "bn", "Marathi": "mr", "English": "en",
}

# One-character verdict markers for the per-segment progress line.
# Hoisted to module level: the original rebuilt this dict every iteration.
STATUS_SYM = {"accept": "+", "review": "~", "retry": "?", "reject": "X"}


def main():
    """Run the full pipeline for one segments directory.

    Steps: discover/split segments (>= 2s), optionally polish audio,
    transcribe each chunk with Gemini, optionally validate, and write a
    numbered JSON report (checkpointed every 50 segments) for dashboard
    review.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("seg_dir", help="Path to segments directory")
    parser.add_argument("--language", "-l", default="Telugu")
    parser.add_argument("--output", "-o", default="./consistency_test/full_run.json")
    parser.add_argument("--max", "-n", type=int, default=None, help="Max segments (for testing)")
    parser.add_argument("--no-polish", action="store_true")
    parser.add_argument("--no-validate", action="store_true")
    args = parser.parse_args()

    seg_dir = args.seg_dir
    lang_code = LANG_CODE_MAP.get(args.language, "te")

    print(f"[Full Run] Segments: {seg_dir}")
    print(f"[Full Run] Language: {args.language} ({lang_code})")

    # Step 1: discover segment files, split long ones, drop anything < 2s.
    proc = AudioProcessor(max_duration_sec=10.0, min_duration_sec=2.0)
    chunks = proc.process_segments_directory(seg_dir, max_segments=args.max, skip_short=True)
    print(f"[Full Run] {len(chunks)} chunks ready (after splitting + min 2s filter)")

    # Step 2: polish boundaries and apply SNR / duration quality gates.
    if not args.no_polish:
        chunks = _polish_chunks(chunks, seg_dir)

    # Step 3: chunks are numbered sequentially as SEG_NNNN (0-padded 4 digits).
    print(f"[Full Run] {len(chunks)} chunks to transcribe")

    # Step 4: transcribe + validate every surviving chunk.
    transcriber = GeminiTranscriber()
    config = TranscriptionConfig(
        model="gemini-3-flash-preview",
        thinking_level="low",
        temperature=0.0,
        language=args.language,
    )

    results = []
    start_time = time.time()

    for i, chunk in enumerate(chunks):
        entry = _process_chunk(
            transcriber, config, chunk, i, len(chunks),
            lang_code=lang_code, validate=not args.no_validate,
        )
        results.append(entry)

        # Gentle rate limit between API calls; skipped after the last chunk.
        if i < len(chunks) - 1:
            time.sleep(0.3)

        # Checkpoint so a crash never loses more than 50 segments of work.
        if (i + 1) % 50 == 0:
            _save(args.output, results, chunks, start_time, partial=True)
            print(f"  [checkpoint] Saved {len(results)} results")

    # Final save + summary.
    elapsed = time.time() - start_time
    _save(args.output, results, chunks, start_time, partial=False)

    cleanup_validators()
    _print_summary(results, elapsed, args.output)


def _polish_chunks(chunks, seg_dir):
    """Polish chunk boundaries; return only chunks passing quality gates.

    Drops a chunk when its polished SNR is below 5 dB or its polished
    duration falls under 2 s.  chunk.file_path is repointed at the polished
    audio whenever the polisher modified it -- even for chunks that are
    subsequently dropped (preserves the original pipeline behavior).
    """
    print("[Full Run] Polishing audio boundaries...")
    polisher = AudioPolisher()
    polished_dir = os.path.join(seg_dir, "polished_full")
    polished = 0
    snr_skipped = 0
    short_skipped = 0
    surviving = []
    for chunk in chunks:
        result = polisher.polish(chunk.file_path, output_dir=polished_dir)
        if result.was_modified:
            chunk.file_path = result.output_path
            polished += 1
        if result.snr_db < 5.0:
            snr_skipped += 1
            continue
        if result.polished_duration_ms / 1000.0 < 2.0:
            short_skipped += 1
            continue
        surviving.append(chunk)
    # Fix: the old log folded too-short drops into the SNR counter, making
    # the "SNR-skipped" figure misleading; report the two reasons separately.
    print(f"[Full Run] Polished {polished}/{len(chunks)}, "
          f"SNR-skipped {snr_skipped}, short-skipped {short_skipped}")
    return surviving


def _process_chunk(transcriber, config, chunk, index, total, *, lang_code, validate):
    """Transcribe one chunk and return its result entry dict.

    The entry always carries positional/audio metadata; transcription fields
    are added on success, validator scores when validation runs, and zeroed
    scores with status "error"/"unvalidated" otherwise.
    """
    seg_num = f"SEG_{index:04d}"
    short_name = Path(chunk.original_segment).stem
    dur = chunk.duration_sec

    print(f"  [{seg_num}] {index+1}/{total}: {short_name} ({dur:.1f}s)"
          f" chunk {chunk.chunk_index+1}/{chunk.total_chunks}")

    raw = transcriber.transcribe_audio(chunk.file_path, config)

    entry = {
        "seg_num": seg_num,
        "seg_index": index,
        "original_file": chunk.original_segment,
        "audio_file": os.path.basename(chunk.file_path),
        "audio_path": chunk.file_path,
        "chunk_index": chunk.chunk_index,
        "total_chunks": chunk.total_chunks,
        "duration_sec": round(dur, 2),
        "start_ms": chunk.start_ms,
        "end_ms": chunk.end_ms,
    }

    if raw.get("error"):
        print(f"    ERROR: {raw['error']}")
        entry.update({
            "error": raw["error"],
            "transcription": "",
            "tagged": "",
            "romanized": "",
            "detected_language": "",
            "status": "error",
            "native_ctc": 0, "roman_mms": 0, "combined": 0,
        })
        return entry

    native = raw.get("transcription", "")
    entry.update({
        "transcription": native,
        "tagged": raw.get("tagged", ""),
        "romanized": raw.get("romanized", ""),
        "detected_language": raw.get("detected_language", ""),
        "speaker": raw.get("speaker", {}),
    })

    if validate and native:
        val = validate_transcription(
            chunk.file_path, native, language=lang_code,
            duration_sec=dur,
        )
        entry.update({
            "native_ctc": val.native_ctc_score,
            "roman_mms": val.roman_mms_score,
            "combined": val.combined_score,
            "status": val.status,
            "uroman": val.uroman_romanized,
        })
        print(f"    [{STATUS_SYM.get(val.status, '?')}] S={val.combined_score:.3f} "
              f"CTC={val.native_ctc_score:.3f} MMS={val.roman_mms_score:.3f} "
              f"| {val.status}")
    else:
        entry.update({
            "native_ctc": 0, "roman_mms": 0, "combined": 0, "status": "unvalidated",
        })
    return entry


def _print_summary(results, elapsed, output_path):
    """Print verdict counts, average combined score, and timing."""
    statuses = {}
    total_s = 0
    count_s = 0
    for r in results:
        s = r.get("status", "unknown")
        statuses[s] = statuses.get(s, 0) + 1
        # Only validated, non-zero scores contribute to the average.
        if r.get("combined", 0) > 0:
            total_s += r["combined"]
            count_s += 1

    print(f"\n{'='*60}")
    print(f"FULL RUN COMPLETE")
    print(f"{'='*60}")
    print(f"Segments: {len(results)}")
    print(f"Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
    print(f"Avg S: {total_s/max(count_s,1):.3f}")
    print(f"Verdicts: {statuses}")
    print(f"Output: {output_path}")


def _save(path, results, chunks, start_time, partial=False):
    """Atomically write the results JSON to *path*.

    Args:
        path: Output JSON path; parent directories are created as needed.
        results: Per-segment result dicts, written under the "segments" key.
        chunks: Full chunk list (only its length is recorded in meta).
        start_time: time.time() at run start, used for meta "elapsed_sec".
        partial: True for mid-run checkpoints, False for the final save.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    data = {
        "meta": {
            "total_segments": len(chunks),
            "processed": len(results),
            "elapsed_sec": round(time.time() - start_time, 1),
            "timestamp": datetime.now().isoformat(),
            "partial": partial,
        },
        "segments": results,
    }
    # Write to a sibling temp file, then atomically replace the target:
    # this function is used for mid-run checkpoints, so a crash mid-write
    # must never leave a truncated/corrupt JSON file behind.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


# Standard entry guard: run the pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
