"""Golden Set Calibration: validate reference segments to calibrate thresholds.
Dir: golden_set/<lang>/segments/*.flac + reference.json
Usage: python golden_set/calibrate.py --language te [--run-transcription]"""
import json, os, sys, argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.validators import validate_transcription, cleanup

def compute_cer(ref, hyp):
    """Return the character error rate (CER) of ``hyp`` measured against ``ref``.

    CER is the Levenshtein edit distance (unit-cost insertions, deletions and
    substitutions) between the two sequences, divided by the reference length.

    Args:
        ref: Reference text (any iterable of comparable items; normally ``str``).
        hyp: Hypothesis text to score against ``ref``.

    Returns:
        ``edit_distance / max(len(ref), 1)`` as a float. An empty reference
        therefore yields the raw edit distance, and the value can exceed 1.0
        when the hypothesis is much longer than the reference.

    Note:
        For ``str`` inputs the comparison is per Unicode code point, so a
        combining mark counts as its own character.
    """
    ref_chars, hyp_chars = list(ref), list(hyp)

    # Only the previous DP row is ever read, so keep two rows instead of the
    # full (len(ref)+1) x (len(hyp)+1) matrix.
    # Row 0: turning the empty prefix into hyp[:j] costs j insertions.
    prev_row = list(range(len(hyp_chars) + 1))
    for i, ref_ch in enumerate(ref_chars, 1):
        # Column 0: turning ref[:i] into the empty string costs i deletions.
        cur_row = [i]
        for j, hyp_ch in enumerate(hyp_chars, 1):
            substitution = 0 if ref_ch == hyp_ch else 1
            cur_row.append(min(
                prev_row[j] + 1,                  # deletion
                cur_row[j - 1] + 1,               # insertion
                prev_row[j - 1] + substitution,   # match / substitution
            ))
        prev_row = cur_row

    return prev_row[-1] / max(len(ref_chars), 1)

def calibrate(lc, run_tx=False):
    """Score the golden-set segments of one language and save the results.

    Looks for ``<this dir>/<lc>/reference.json`` and ``<this dir>/<lc>/segments/``.
    ``reference.json`` is read as a dict with an optional ``"language"`` name
    (default ``"Telugu"``) and a ``"segments"`` list whose entries carry
    ``"file"``, ``"native"`` and ``"romanized"`` keys.

    For every segment whose audio file exists:

    * ``run_tx=False``: the reference native/romanized text is passed to
      ``validate_transcription``.
    * ``run_tx=True``: the audio is transcribed with ``GeminiTranscriber``
      first; the transcription is validated instead and its CER against the
      reference native text is recorded as well.

    Per-segment scores are printed, then a summary, and the results are
    written to ``<this dir>/<lc>/calibration_results.json``.

    Args:
        lc: Language code; used as the sub-directory name and passed to the
            validator as ``language``.
        run_tx: When true, transcribe each segment before validating.

    Returns:
        None. Returns early (after printing a hint) if ``reference.json`` is
        missing, or if the segments directory is missing (it is then created).
    """
    base = os.path.dirname(__file__)
    ref_path = os.path.join(base, lc, "reference.json")
    seg_dir = os.path.join(base, lc, "segments")
    if not os.path.exists(ref_path):
        print(f"Create {ref_path} first.")
        return
    if not os.path.isdir(seg_dir):
        os.makedirs(seg_dir, exist_ok=True)
        print(f"Add segments to {seg_dir}/")
        return

    with open(ref_path, encoding="utf-8") as f:
        ref = json.load(f)
    lang, segs = ref.get("language", "Telugu"), ref.get("segments", [])
    print(f"Calibrating {lang} ({lc}): {len(segs)} segments\n")

    results = []
    # Transcriber and config do not depend on the segment: build them once,
    # lazily, so non-transcription runs never import the Gemini backend.
    tx = cfg = None
    try:
        for seg in segs:
            name = seg.get("file")
            if not name:
                # A malformed entry should not abort the whole calibration run.
                print(f"  SKIP: entry without 'file': {seg!r}")
                continue
            ap = os.path.join(seg_dir, name)
            if not os.path.exists(ap):
                print(f"  SKIP: {name}")
                continue
            nr, rr = seg.get("native", ""), seg.get("romanized", "")

            if run_tx:
                if tx is None:
                    from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig
                    tx = GeminiTranscriber()
                    cfg = TranscriptionConfig(model="gemini-3-flash-preview", thinking_level="low", temperature=0.0, language=lang)
                raw = tx.transcribe_audio(ap, cfg)
                if raw.get("error"):
                    print(f"  ERR: {raw['error']}")
                    continue
                gn, gr = raw.get("transcription", ""), raw.get("romanized", "")
                cer = compute_cer(nr, gn)
                val = validate_transcription(ap, gn, romanized_text=gr, language=lc)
                results.append({
                    "file": name,
                    "cer": round(cer, 4),
                    "ctc": val.native_ctc_score,
                    "mms": val.roman_mms_score,
                    "S": val.combined_score,
                    "verdict": val.status,
                })
                print(f"  {name}: CER={cer:.2%} S={val.combined_score:.2f} ({val.status})")
            else:
                val = validate_transcription(ap, nr, romanized_text=rr, language=lc)
                results.append({
                    "file": name,
                    "ctc": val.native_ctc_score,
                    "mms": val.roman_mms_score,
                    "S": val.combined_score,
                    "verdict": val.status,
                })
                print(f"  {name}: S={val.combined_score:.2f} ({val.status}) CTC={val.native_ctc_score:.2f} MMS={val.roman_mms_score:.2f}")
    finally:
        # Run the validators' cleanup hook even if a segment raised.
        # NOTE(review): what cleanup() releases is not visible in this file.
        cleanup()

    if results:
        sc = [r["S"] for r in results]
        vd = {}
        for r in results:
            vd[r["verdict"]] = vd.get(r["verdict"], 0) + 1
        print(f"\nSUMMARY: {len(results)} segs | avg={sum(sc)/len(sc):.3f} min={min(sc):.3f} max={max(sc):.3f}")
        print(f"Verdicts: {vd}")
        out = os.path.join(base, lc, "calibration_results.json")
        with open(out, "w", encoding="utf-8") as f:
            json.dump({"language": lang, "results": results}, f, ensure_ascii=False, indent=2)
        print(f"Saved: {out}")

if __name__ == "__main__":
    # Command-line entry point: choose the golden-set language and whether to
    # transcribe each segment before validating it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--language", "-l", default="te")
    parser.add_argument("--run-transcription", action="store_true")
    args = parser.parse_args()
    calibrate(args.language, run_tx=args.run_transcription)
