#!/usr/bin/env python3
"""
Transcribe Modi audio segments using Gemini 3.1 Pro.
Hindi-only, clean text output for VibeVoice fine-tuning, resumable, parallel.
"""
from __future__ import annotations

import json
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from pydantic import BaseModel, Field

# ── Config ──────────────────────────────────────────────────────────────────

# Input audio segments and every artefact written by this script.
SEGMENTS_DIR = Path("/home/ubuntu/modi_processed/segments")
OUTPUT_DIR = Path("/home/ubuntu/modi_processed/transcripts")
PROGRESS_FILE = Path("/home/ubuntu/modi_processed/transcribe_progress.json")
DATASET_FILE = Path("/home/ubuntu/modi_processed/dataset.jsonl")

GEMINI_MODEL = "gemini-3.1-pro-preview"
# The environment variable wins; the .env file below is only a fallback.
GEMINI_KEY = os.environ.get("GEMINI_KEY", "")

WORKERS = 50  # size of the thread pool created in main()
RETRY_MAX = 5  # attempts per segment in process_one()
RETRY_BACKOFF = 3.0  # base sleep (seconds) that the retry delays scale from
SAVE_EVERY = 50  # main() saves the progress file every N completed segments

if not GEMINI_KEY:
    env_file = Path("/home/ubuntu/transcripts/.env")
    if env_file.exists():
        # Take the first "GEMINI_KEY=..." line; keep the empty key if none exists.
        GEMINI_KEY = next(
            (
                entry.split("=", 1)[1].strip()
                for entry in env_file.read_text().splitlines()
                if entry.startswith("GEMINI_KEY=")
            ),
            GEMINI_KEY,
        )


# ── Schema (minimal — only what VibeVoice needs) ────────────────────────────

# Response model for one transcribed segment. Its JSON schema is what
# get_json_schema() hands to the Gemini request as response_schema, and
# build_dataset() later reads the "transcription" key from the parsed reply.
# NOTE(review): documented with comments rather than a class docstring on
# purpose — pydantic is understood to copy class docstrings into the generated
# schema, which would change the request payload; confirm before adding one.
class TranscriptionSchema(BaseModel):
    # The only field requested from the model; the description text is part of
    # the schema sent with every request.
    transcription: str = Field(description="Verbatim Devanagari transcription. English words stay in Latin script.")


def _resolve_refs(schema: dict, defs: dict | None = None) -> dict:
    """Return a copy of *schema* with every ``$ref`` replaced by its definition.

    Args:
        schema: JSON-schema fragment to process. It is never modified.
        defs: Mapping of definition name to schema fragment. When omitted, it
            is taken from the ``"$defs"`` entry of *schema* (empty if absent).

    Returns:
        A new dict in which each ``{"$ref": ".../Name"}`` node has been
        replaced by the (recursively resolved) definition ``Name``, and the
        ``"$defs"`` entry has been dropped. A reference to an unknown name
        resolves to an empty dict.
    """
    if defs is None:
        # Read, don't pop: the caller's schema must stay intact. "$defs" is
        # left out of the result by the loop below anyway.
        defs = schema.get("$defs", {})

    if "$ref" in schema:
        # Only the last path component of the pointer names the definition.
        ref_name = schema["$ref"].split("/")[-1]
        return _resolve_refs(defs.get(ref_name, {}), defs)

    resolved = {}
    for key, value in schema.items():
        if key == "$defs":
            continue
        if isinstance(value, dict):
            resolved[key] = _resolve_refs(value, defs)
        elif isinstance(value, list):
            # Dict items may themselves carry references; other items are
            # copied through unchanged.
            resolved[key] = [
                _resolve_refs(item, defs) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            resolved[key] = value
    return resolved


def get_json_schema() -> dict:
    """Return the JSON schema of ``TranscriptionSchema`` with ``$ref`` nodes inlined."""
    return _resolve_refs(TranscriptionSchema.model_json_schema())


# ── Hindi-Only System Prompt (clean text for TTS fine-tuning) ────────────────

# Sent as the system instruction of every request made by transcribe_segment().
# Rule 5 defines two markers: build_dataset() drops segments whose whole
# transcription is "[NO_SPEECH]", so that literal must stay in sync with it;
# "[UNK]" is not filtered anywhere in this file and ends up in the dataset text.
SYSTEM_PROMPT = """You are a verbatim speech-to-text transcription system for Hindi. Output ONLY the JSON.

This is Narendra Modi's speech audio — expect formal Hindi with occasional English terms.

RULES:
1. NEVER TRANSLATE. Write what you HEAR. English words stay in Latin script, Hindi in Devanagari.
2. VERBATIM FIDELITY. Every repetition, filler, stammer, false start — exactly as spoken.
3. NO CORRECTION. Do not fix grammar, pronunciation, or word choice.
4. NO HALLUCINATION. Never add words not in the audio. If audio cuts off, STOP where it stops.
5. If a word is unclear, write [UNK]. If no speech at all, write [NO_SPEECH].
6. Audio is VAD-cut and may start/end mid-speech. Transcribe only what you hear.

SCRIPT: Devanagari. Preserve Sandhi and combined forms as spoken.
PUNCTUATION: Only comma, period, ? and ! — from audible pauses/intonation only.
NUMBERS: Write as digits (1, 2, 100) not words.

Do NOT add any audio event tags like [breath], [applause], [music] etc. Output clean text only.

EXAMPLES:
transcription: "भाइयों और बहनों, आज हम एक नए India की बात करते हैं"
transcription: "तो basically हमने Digital India को आगे बढ़ाया"
transcription: "और इसलिए मैं कहता हूँ कि हमारे देश" (cutoff — stop where audio stops)"""

# Text part sent together with the audio in the request contents.
USER_PROMPT = "Transcribe this Hindi audio segment. Return JSON with the transcription field only."


# ── Gemini Client ────────────────────────────────────────────────────────────

def make_client():
    """Build a Gemini API client authenticated with ``GEMINI_KEY``."""
    # Imported inside the function, so the SDK is only loaded when a client
    # is actually created.
    from google import genai

    gemini_client = genai.Client(api_key=GEMINI_KEY)
    return gemini_client


def transcribe_segment(client, wav_path: Path) -> dict:
    """Send one WAV file to Gemini and return the parsed JSON reply.

    Args:
        client: Client object as returned by ``make_client()``.
        wav_path: Audio file to transcribe; its bytes are sent inline with
            MIME type ``audio/wav``.

    Returns:
        The decoded JSON object from the model's reply.

    Raises:
        ValueError: The reply carried no text at all, or its JSON payload is
            not an object.
        json.JSONDecodeError: The reply text is not valid JSON.
        Exception: Whatever the SDK call or reading the file raises is
            propagated unchanged; ``process_one()`` decides what to retry.
    """
    from google.genai import types

    audio_part = types.Part(
        inline_data=types.Blob(data=wav_path.read_bytes(), mime_type="audio/wav")
    )

    config = types.GenerateContentConfig(
        temperature=0,
        thinking_config=types.ThinkingConfig(thinking_budget=1024),
        response_mime_type="application/json",
        response_schema=get_json_schema(),
        system_instruction=SYSTEM_PROMPT,
    )

    response = client.models.generate_content(
        model=GEMINI_MODEL,
        contents=[audio_part, USER_PROMPT],
        config=config,
    )

    # A reply without any text part has text == None; fail with a readable
    # message instead of an AttributeError on .strip().
    raw_text = response.text
    if raw_text is None:
        raise ValueError("Gemini returned no text for this segment")

    parsed = json.loads(raw_text.strip())
    # Downstream code calls .get() on the result, so anything but a JSON
    # object is unusable and is reported here rather than later.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object from Gemini, got {type(parsed).__name__}"
        )
    return parsed


# ── Progress Management ──────────────────────────────────────────────────────

def load_progress() -> dict:
    """Return the saved progress mapping, or an empty dict on a first run."""
    if not PROGRESS_FILE.exists():
        return {}
    raw = PROGRESS_FILE.read_text()
    return json.loads(raw)


def save_progress(progress: dict):
    """Write *progress* to ``PROGRESS_FILE`` as indented JSON.

    The data is first written to a sibling ``<name>.tmp`` file and then renamed
    over the real file, so a process killed mid-write leaves the previous
    progress file intact instead of a truncated one that would break resuming.

    Args:
        progress: Mapping of segment key to its status entry.
    """
    tmp_file = PROGRESS_FILE.with_name(PROGRESS_FILE.name + ".tmp")
    tmp_file.write_text(json.dumps(progress, indent=2))
    # Same directory, hence same filesystem: the rename replaces the target
    # in a single step.
    tmp_file.replace(PROGRESS_FILE)


# ── Main Pipeline ────────────────────────────────────────────────────────────

def collect_segments() -> list[Path]:
    """Return every ``*.wav`` file found anywhere under ``SEGMENTS_DIR``, sorted."""
    wav_files = list(SEGMENTS_DIR.rglob("*.wav"))
    wav_files.sort()
    return wav_files


def segment_key(wav_path: Path) -> str:
    """Return the progress key for a segment: ``"<parent dir name>/<file name>"``.

    Only the immediate parent directory is part of the key.
    """
    folder_name = wav_path.parent.name
    file_name = wav_path.name
    return "/".join((folder_name, file_name))


def process_one(client, wav_path: Path) -> tuple[str, dict | None, str | None]:
    """Transcribe one segment, retrying transient API failures.

    Failures are classified by substring match on the exception text:
    ``429`` / ``RESOURCE_EXHAUSTED`` is retried with exponential backoff
    (``RETRY_BACKOFF * 2**attempt``), ``500`` / ``503`` with linear backoff
    (``RETRY_BACKOFF * (attempt + 1)``). A JSON parse error or any other
    exception is reported immediately without a retry.

    Args:
        client: Client object passed through to ``transcribe_segment()``.
        wav_path: Segment to transcribe.

    Returns:
        ``(key, result, error)`` where ``key`` is ``segment_key(wav_path)``.
        On success ``result`` is the parsed reply and ``error`` is ``None``;
        on failure ``result`` is ``None`` and ``error`` describes the problem.
        This function does not raise for transcription failures.
    """
    key = segment_key(wav_path)
    last_error = ""
    for attempt in range(RETRY_MAX):
        try:
            return key, transcribe_segment(client, wav_path), None
        except json.JSONDecodeError as e:
            return key, None, f"JSON parse error: {e}"
        except Exception as e:
            err_str = str(e)
            if "429" in err_str or "RESOURCE_EXHAUSTED" in err_str:
                wait = RETRY_BACKOFF * (2 ** attempt)
            elif "500" in err_str or "503" in err_str:
                wait = RETRY_BACKOFF * (attempt + 1)
            else:
                return key, None, f"Error: {err_str}"
            last_error = err_str
            # No point sleeping after the final attempt: nothing follows it.
            if attempt + 1 < RETRY_MAX:
                time.sleep(wait)
    # Report what actually kept failing; it may be a 5xx, not a rate limit.
    return key, None, f"Max retries exceeded: {last_error}"


def main():
    """Transcribe all pending segments in parallel, then rebuild the dataset.

    Segments already marked ``"done"`` in the progress file are skipped, so
    the script can be re-run to resume; segments recorded as ``"error"`` are
    attempted again. Each successful result is written as a JSON file under
    ``OUTPUT_DIR`` and recorded in the progress mapping, which is saved every
    ``SAVE_EVERY`` completions and once more when the loop ends — including
    when it ends because of an interrupt or an unexpected exception.

    Exits with status 1 when no API key is configured.
    """
    if not GEMINI_KEY:
        print("ERROR: No GEMINI_KEY found. Set env var or check /home/ubuntu/transcripts/.env")
        sys.exit(1)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    client = make_client()
    progress = load_progress()
    all_segments = collect_segments()

    done_keys = {k for k, v in progress.items() if v.get("status") == "done"}
    pending = [s for s in all_segments if segment_key(s) not in done_keys]

    total = len(all_segments)
    already_done = len(done_keys)
    print(f"Total segments: {total}")
    print(f"Already transcribed: {already_done}")
    print(f"Remaining: {len(pending)}")
    print(f"Workers: {WORKERS}")
    print(f"Model: {GEMINI_MODEL}")
    print()

    if not pending:
        print("All segments already transcribed!")
        build_dataset(progress)
        return

    completed_count = 0
    errors_count = 0
    t0 = time.time()

    with ThreadPoolExecutor(max_workers=WORKERS) as pool:
        futures = [pool.submit(process_one, client, seg) for seg in pending]

        try:
            for future in as_completed(futures):
                key, result, error = future.result()
                completed_count += 1

                if result:
                    progress[key] = {"status": "done", "result": result}

                    # with_suffix only touches the file extension, unlike a
                    # str.replace over the whole key.
                    out_file = OUTPUT_DIR / Path(key).with_suffix(".json")
                    out_file.parent.mkdir(parents=True, exist_ok=True)
                    # The JSON holds raw Devanagari (ensure_ascii=False), so
                    # the encoding must not depend on the locale.
                    out_file.write_text(
                        json.dumps(result, ensure_ascii=False, indent=2),
                        encoding="utf-8",
                    )
                else:
                    # An empty result comes back without an error message;
                    # give it one so the status line below cannot fail.
                    error = error or "Empty transcription result"
                    errors_count += 1
                    progress[key] = {"status": "error", "error": error}

                if completed_count % SAVE_EVERY == 0:
                    save_progress(progress)

                done_total = already_done + completed_count
                elapsed = time.time() - t0
                rate = completed_count / elapsed if elapsed > 0 else 0
                eta_s = (len(pending) - completed_count) / rate if rate > 0 else 0
                eta_m = eta_s / 60

                status = "OK" if result else f"ERR: {error[:60]}"
                print(
                    f"[{done_total}/{total}] {key} — {status} "
                    f"({rate:.1f}/s, ETA {eta_m:.0f}m, errs={errors_count})"
                )
        except BaseException:
            # Ctrl-C or an unexpected failure: drop the queued work so leaving
            # the pool does not wait for the whole backlog to be processed.
            for queued in futures:
                queued.cancel()
            raise
        finally:
            # Persist whatever finished, on success and on failure alike.
            save_progress(progress)

    elapsed_total = time.time() - t0
    print(f"\nDone. {completed_count} processed in {elapsed_total/60:.1f}m. Errors: {errors_count}")

    build_dataset(progress)


def build_dataset(progress: dict):
    """Write the JSONL training file from all successfully transcribed segments.

    Entries are emitted in sorted key order. An entry is skipped unless its
    status is ``"done"``, and also when its transcription is empty or exactly
    ``[NO_SPEECH]``. Each output line holds the speaker-prefixed text and the
    audio path formed by joining ``SEGMENTS_DIR`` with the segment key.
    """
    print(f"\nBuilding dataset JSONL → {DATASET_FILE}")
    written = 0
    with open(DATASET_FILE, "w", encoding="utf-8") as out:
        for seg_key, record in sorted(progress.items()):
            if record.get("status") != "done":
                continue
            transcript = record["result"].get("transcription", "").strip()
            if transcript in ("", "[NO_SPEECH]"):
                continue

            row = {
                "text": f"Speaker 0: {transcript}",
                "audio": str(SEGMENTS_DIR / seg_key),
            }
            out.write(json.dumps(row, ensure_ascii=False) + "\n")
            written += 1

    print(f"Dataset: {written} entries written to {DATASET_FILE}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
