"""
Preflight check: progressive test on real pF_BQpHaIdU Telugu podcast.
Phase 1: 1 segment (sanity check)
Phase 2: 5 segments (concurrency test)
Phase 3: All 426 segments (production simulation)

Results saved to preflight/results/ for dashboard consumption.
"""
import asyncio
import json
import logging
import os
import sys
import time
from pathlib import Path

# Make the repo root importable so the `src.*` imports below resolve when
# this script is run directly from the preflight/ directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Force real-provider mode for this run unless the caller already set it.
os.environ.setdefault("MOCK_MODE", "false")

# Load .env from the repo root before config is imported (presumably this is
# where the API key read by EnvConfig lives — verify against src.config).
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent.parent / ".env")

from src.config import EnvConfig, WORKER_BATCH_SIZE
from src.audio_polish import polish_segment, polish_all_segments, PolishedSegment
from src.providers.aistudio import AIStudioProvider
from src.providers.base import TranscriptionRequest, RequestStatus
from src.validator import validate_transcription
from src.batch_cycle import BatchCycleEngine
from src.prompt_builder import build_system_prompt

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("preflight")

# Per-phase JSON output directory (consumed by the dashboard per the module docstring).
RESULTS_DIR = Path(__file__).parent / "results"
# Pre-downloaded audio segments for the test podcast, one *.flac per segment.
SEGMENTS_DIR = Path(__file__).parent / "pF_BQpHaIdU" / "segments"
VIDEO_ID = "pF_BQpHaIdU"  # id of the Telugu podcast under test
LANGUAGE = "te"  # ISO 639-1 language code: Telugu


def get_segment_paths() -> list[Path]:
    """Return every *.flac file under SEGMENTS_DIR, sorted by name."""
    flac_files = SEGMENTS_DIR.glob("*.flac")
    return sorted(flac_files)


def save_result(phase: str, data: dict) -> None:
    """Serialize *data* to RESULTS_DIR/<phase>.json, pretty-printed UTF-8.

    Creates the results directory on first use. `default=str` is a deliberate
    catch-all so Path/enum/dataclass-ish values in the payload don't abort
    the dump (they get stringified).
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    path = RESULTS_DIR / f"{phase}.json"
    # encoding must be explicit: ensure_ascii=False writes raw Telugu text,
    # which fails (or mangles) on platforms whose default encoding isn't UTF-8.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f"Saved results to {path}")


async def run_phase(phase_name: str, segment_paths: list[Path], provider: AIStudioProvider) -> dict:
    """Run one preflight phase end-to-end and persist its results.

    Pipeline: polish the audio segments, build transcription requests,
    send them to AI Studio in batches, validate successful responses,
    then aggregate a summary and save everything via save_result().

    Args:
        phase_name: Label used for logging and the output JSON file name.
        segment_paths: Audio segment files to process in this phase.
        provider: Configured AI Studio client that performs transcription.

    Returns:
        A dict with "summary", "results" and "discarded" keys, or a short
        error dict when every segment was discarded during polishing.
    """
    logger.info(f"\n{'='*70}")
    logger.info(f"PHASE: {phase_name} ({len(segment_paths)} segments)")
    logger.info(f"{'='*70}")

    phase_start = time.monotonic()

    # --- Polish audio -------------------------------------------------------
    logger.info("Polishing audio segments...")
    polish_start = time.monotonic()
    polished = polish_all_segments(segment_paths)
    valid = [p for p in polished if not p.trim_meta.discarded]
    discarded = [p for p in polished if p.trim_meta.discarded]
    polish_time = time.monotonic() - polish_start
    logger.info(f"Polish done: {len(valid)} valid, {len(discarded)} discarded in {polish_time:.2f}s")

    if not valid:
        # Nothing left to transcribe: short-circuit with a minimal error record.
        return {"phase": phase_name, "error": "All segments discarded", "polished": len(polished), "valid": 0}

    # --- Build requests -----------------------------------------------------
    requests = []
    audio_durations = {}    # segment_id -> trimmed duration in seconds
    trim_metas = {}         # segment_id -> trim metadata forwarded to the validator
    segment_audio_map = {}  # segment_id -> original (pre-split) file name

    for seg in valid:
        seg_id = seg.trim_meta.original_file
        if seg.trim_meta.was_split:
            # Split parts need distinct ids; suffix with the split index.
            seg_id = f"{seg.trim_meta.original_file}_split{seg.trim_meta.split_index}"

        requests.append(TranscriptionRequest(
            segment_id=seg_id,
            audio_base64=seg.base64_audio,
            language_code=LANGUAGE,
            original_file=seg.trim_meta.original_file,
        ))
        audio_durations[seg_id] = seg.trim_meta.final_duration_ms / 1000
        trim_metas[seg_id] = {
            "abrupt_start": seg.trim_meta.abrupt_start,
            "abrupt_end": seg.trim_meta.abrupt_end,
            "was_split": seg.trim_meta.was_split,
            "original_duration_ms": seg.trim_meta.original_duration_ms,
            "final_duration_ms": seg.trim_meta.final_duration_ms,
        }
        segment_audio_map[seg_id] = seg.trim_meta.original_file

    # --- Send to AI Studio in batches ---------------------------------------
    batch_size = min(len(requests), WORKER_BATCH_SIZE)
    # Loop-invariant: ceil(len(requests) / batch_size), hoisted out of the loop.
    total_batches = (len(requests) + batch_size - 1) // batch_size
    all_responses = []
    all_validations = []
    total_tokens = 0
    total_cache_hits = 0
    batch_count = 0

    for i in range(0, len(requests), batch_size):
        batch = requests[i:i + batch_size]
        batch_num = i // batch_size + 1
        logger.info(f"Batch {batch_num}/{total_batches}: sending {len(batch)} requests...")

        batch_start = time.monotonic()
        responses = await provider.send_batch(batch)
        batch_time = time.monotonic() - batch_start
        batch_count += 1

        successes = [r for r in responses if r.status == RequestStatus.SUCCESS]
        errors = [r for r in responses if r.status == RequestStatus.ERROR]
        rate_limited = [r for r in responses if r.status == RequestStatus.RATE_LIMITED]

        logger.info(f"Batch {batch_num}: {len(successes)} ok, {len(errors)} errors, "
                     f"{len(rate_limited)} 429s in {batch_time:.1f}s")

        for resp in responses:
            all_responses.append(resp)
            if resp.token_usage:
                total_tokens += resp.token_usage.total_tokens
                if resp.token_usage.cache_hit:
                    total_cache_hits += 1

            # Only successful responses that carry a payload get validated.
            if resp.status == RequestStatus.SUCCESS and resp.transcription_data:
                vr = validate_transcription(
                    segment_id=resp.segment_id,
                    transcription_data=resp.transcription_data,
                    expected_language=LANGUAGE,
                    audio_duration_s=audio_durations.get(resp.segment_id, 5.0),
                    trim_meta=trim_metas.get(resp.segment_id),
                )
                all_validations.append(vr)

    phase_time = time.monotonic() - phase_start

    # Index validations by segment id for O(1) lookup while building results
    # (replaces an O(n^2) scan over all_validations per response). setdefault
    # keeps the FIRST validation per id, matching the previous first-match-wins
    # behavior if duplicate ids ever occur.
    validations_by_id = {}
    for vr in all_validations:
        validations_by_id.setdefault(vr.segment_id, vr)

    # --- Build per-response result entries ----------------------------------
    results_list = []
    for resp in all_responses:
        entry = {
            "segment_id": resp.segment_id,
            "original_file": segment_audio_map.get(resp.segment_id, resp.segment_id),
            "status": resp.status.value,
            "latency_ms": resp.latency_ms,
            "audio_duration_s": audio_durations.get(resp.segment_id, 0),
            "error_message": resp.error_message,
        }
        if resp.transcription_data:
            entry["transcription"] = resp.transcription_data.get("transcription", "")
            entry["tagged"] = resp.transcription_data.get("tagged", "")
            entry["detected_language"] = resp.transcription_data.get("detected_language", "")
            entry["speaker"] = resp.transcription_data.get("speaker", {})
        if resp.token_usage:
            entry["token_usage"] = {
                "input": resp.token_usage.input_tokens,
                "output": resp.token_usage.output_tokens,
                "cached": resp.token_usage.cached_tokens,
                "cache_hit": resp.token_usage.cache_hit,
            }
        # Attach validation (if this response produced one).
        vr = validations_by_id.get(resp.segment_id)
        if vr is not None:
            entry["validation"] = {
                "quality_score": vr.quality_score,
                "is_empty": vr.is_empty,
                "is_no_speech": vr.is_no_speech,
                "chars_per_second": vr.chars_per_second,
                "lang_mismatch": vr.lang_mismatch,
                "boundary_score": vr.boundary_score,
                "asr_eligible": vr.asr_eligible,
                "tts_clean_eligible": vr.tts_clean_eligible,
                "tts_expressive_eligible": vr.tts_expressive_eligible,
                "num_unk": vr.num_unk,
                "num_event_tags": vr.num_event_tags,
                "flags": vr.flags,
            }
        results_list.append(entry)

    # --- Compute summary ----------------------------------------------------
    successes = [r for r in all_responses if r.status == RequestStatus.SUCCESS]
    quality_scores = [v.quality_score for v in all_validations]
    latencies = [r.latency_ms for r in successes if r.latency_ms > 0]

    summary = {
        "phase": phase_name,
        "video_id": VIDEO_ID,
        "language": LANGUAGE,
        "total_segments_input": len(segment_paths),
        "polished_valid": len(valid),
        "polished_discarded": len(discarded),
        "requests_sent": len(requests),
        "responses_success": len(successes),
        "responses_error": len([r for r in all_responses if r.status == RequestStatus.ERROR]),
        "responses_429": len([r for r in all_responses if r.status == RequestStatus.RATE_LIMITED]),
        "total_tokens": total_tokens,
        "cache_hits": total_cache_hits,
        "cache_hit_rate": total_cache_hits / max(len(successes), 1),
        "avg_quality_score": sum(quality_scores) / len(quality_scores) if quality_scores else 0,
        "min_quality_score": min(quality_scores) if quality_scores else 0,
        "max_quality_score": max(quality_scores) if quality_scores else 0,
        "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
        # NOTE: nearest-rank percentiles on the sorted list; approximate for
        # small sample sizes, which is fine for a preflight report.
        "p50_latency_ms": sorted(latencies)[len(latencies)//2] if latencies else 0,
        "p95_latency_ms": sorted(latencies)[int(len(latencies)*0.95)] if latencies else 0,
        "polish_time_s": polish_time,
        "total_time_s": phase_time,
        "batch_count": batch_count,
        "segments_per_second": len(requests) / phase_time if phase_time > 0 else 0,
    }

    # --- Discarded details --------------------------------------------------
    discarded_details = []
    for d in discarded:
        discarded_details.append({
            "file": d.trim_meta.original_file,
            "reason": d.trim_meta.discard_reason,
            "original_duration_ms": d.trim_meta.original_duration_ms,
        })

    result = {
        "summary": summary,
        "results": results_list,
        "discarded": discarded_details,
    }

    save_result(phase_name, result)

    logger.info(f"\n--- {phase_name} SUMMARY ---")
    for k, v in summary.items():
        if isinstance(v, float):
            logger.info(f"  {k}: {v:.2f}")
        else:
            logger.info(f"  {k}: {v}")

    return result


async def main():
    """Drive the three preflight phases in sequence against the live provider."""
    env = EnvConfig()
    ai_provider = AIStudioProvider(api_key=env.gemini_key, mock_mode=False)

    segments = get_segment_paths()
    logger.info(f"Total segments available: {len(segments)}")

    # Phase 1: one segment as a sanity check — abort early if it fails.
    p1 = await run_phase("phase1_single", segments[:1], ai_provider)
    entries = p1.get("results") or []
    first_ok = bool(entries) and entries[0].get("status") == "success"
    if not first_ok:
        logger.error("Phase 1 FAILED - stopping.")
        return
    logger.info(f"\nPhase 1 transcription: {entries[0].get('transcription', 'N/A')[:200]}")

    # Phase 2: five segments to exercise concurrency.
    await asyncio.sleep(2)
    p2 = await run_phase("phase2_concurrent", segments[:5], ai_provider)

    # Phase 3: the full set, simulating production load.
    await asyncio.sleep(2)
    p3 = await run_phase("phase3_full", segments, ai_provider)

    # Persist one combined summary for the dashboard.
    combined = {
        "video_id": VIDEO_ID,
        "language": LANGUAGE,
        "total_segments_in_tar": len(segments),
        "phases": {
            "phase1": p1.get("summary", {}),
            "phase2": p2.get("summary", {}),
            "phase3": p3.get("summary", {}),
        },
    }
    save_result("combined_summary", combined)
    logger.info("\n*** PREFLIGHT COMPLETE ***")


# Script entry point: run the full three-phase preflight when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
