"""
Batch cycle engine: 1-minute cadence loop.
Fire batch -> collect -> validate -> report.
Handles provider switching on 429 flood.
"""
from __future__ import annotations

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Optional

from .config import (
    WORKER_BATCH_SIZE, BATCH_INTERVAL_SECONDS, FLOOD_THRESHOLD_PCT,
    PROMPT_VERSION, SCHEMA_VERSION, TRIMMER_VERSION, VALIDATOR_VERSION,
    TEMPERATURE, THINKING_LEVEL,
)
from .providers.base import (
    BaseProvider, TranscriptionRequest, TranscriptionResponse, RequestStatus,
)
from .providers.aistudio import log_api_stats
from .validator import validate_transcription, ValidationResult
from .db import WorkerStats

logger = logging.getLogger(__name__)


@dataclass
class BatchResult:
    """Aggregated outcome of one batch cycle (fire -> collect -> validate).

    Produced by BatchCycleEngine.run_batch; carries both summary counters
    for reporting and the per-segment records destined for the DB.
    """
    batch_id: str  # unique id of the form "batch_<8 hex chars>"
    segments_sent: int = 0  # number of requests fired in this batch
    segments_returned: int = 0  # responses with RequestStatus.SUCCESS
    segments_429: int = 0  # rate-limited responses from the primary pass
    segments_error: int = 0  # responses with RequestStatus.ERROR
    provider_used: str = ""  # primary provider name, or "mixed" after a fallback retry
    cache_hits: int = 0  # successes served from provider-side cache
    avg_latency_ms: float = 0.0  # mean latency over successes with latency_ms > 0
    batch_duration_ms: float = 0.0  # wall-clock duration of the whole cycle
    responses: list[TranscriptionResponse] = field(default_factory=list)  # all responses (any status)
    validations: list[ValidationResult] = field(default_factory=list)  # one per validated success
    transcription_records: list[dict] = field(default_factory=list)  # DB rows for transcriptions
    flag_records: list[dict] = field(default_factory=list)  # DB rows for validation/error flags


class BatchCycleEngine:
    """One batch cycle: fire requests, collect responses, validate, report.

    Fires a batch of transcription requests at the primary provider. When
    the rate-limited (429) fraction of the batch exceeds FLOOD_THRESHOLD_PCT
    and a fallback provider is configured, the rate-limited subset is
    retried once against the fallback.
    """

    def __init__(
        self,
        primary_provider: BaseProvider,
        fallback_provider: Optional[BaseProvider],
        worker_id: str,
        video_id: str,
    ):
        self.primary = primary_provider
        self.fallback = fallback_provider  # None disables the flood retry
        self.worker_id = worker_id
        self.video_id = video_id

    async def run_batch(
        self,
        requests: list[TranscriptionRequest],
        expected_language: str,
        audio_durations: dict[str, float],
        trim_metas: dict[str, dict],
    ) -> BatchResult:
        """Execute one full batch cycle and return its aggregated result.

        Args:
            requests: transcription requests to fire in this batch.
            expected_language: language hint forwarded to the validator.
            audio_durations: segment_id -> audio length in seconds
                (5.0 s is assumed for segments missing from the map).
            trim_metas: segment_id -> trimming metadata for the validator.

        Returns:
            BatchResult with counters, responses, validations, and
            DB-ready transcription/flag records.
        """
        batch_id = f"batch_{uuid.uuid4().hex[:8]}"
        batch_start = time.monotonic()
        result = BatchResult(
            batch_id=batch_id,
            segments_sent=len(requests),
            provider_used=self.primary.get_provider_name(),
        )

        # PHASE 1: fire the whole batch at the primary provider.
        logger.info(f"[{batch_id}] Firing {len(requests)} requests to {self.primary.get_provider_name()}")
        responses = await self.primary.send_batch(requests)

        # Count 429s from the primary pass.
        # NOTE(review): this counter is not refreshed after the fallback
        # retry, so it reflects primary-side 429s even when the fallback
        # later succeeds (or 429s again) -- confirm that is the intended stat.
        rate_limited = [r for r in responses if r.status == RequestStatus.RATE_LIMITED]
        result.segments_429 = len(rate_limited)

        # PHASE 1b: on a 429 flood, retry the rate-limited subset on fallback.
        # max(..., 1) guards the division for an empty batch.
        if self.fallback and len(rate_limited) / max(len(requests), 1) > FLOOD_THRESHOLD_PCT:
            responses = await self._retry_with_fallback(
                batch_id, requests, responses, rate_limited, result,
            )

        # Tally outcomes across the (possibly merged) response set.
        successes = [r for r in responses if r.status == RequestStatus.SUCCESS]
        errors = [r for r in responses if r.status == RequestStatus.ERROR]
        result.segments_returned = len(successes)
        result.segments_error = len(errors)
        result.cache_hits = sum(1 for r in successes if r.token_usage.cache_hit)
        result.responses = responses

        # Mean latency over successes; non-positive latencies are excluded.
        latencies = [r.latency_ms for r in successes if r.latency_ms > 0]
        result.avg_latency_ms = sum(latencies) / len(latencies) if latencies else 0.0

        # PHASE 2: validate each successful transcription and build records.
        for resp in successes:
            if not resp.transcription_data:
                continue  # success without a payload: nothing to validate or store

            vr = validate_transcription(
                segment_id=resp.segment_id,
                transcription_data=resp.transcription_data,
                expected_language=expected_language,
                audio_duration_s=audio_durations.get(resp.segment_id, 5.0),
                trim_meta=trim_metas.get(resp.segment_id),
            )
            result.validations.append(vr)
            result.transcription_records.append(
                self._build_result_record(resp, vr, expected_language)
            )
            # Each validator flag ("type:details...") becomes one flag row;
            # the part before the first ":" is the flag type.
            for flag in vr.flags:
                result.flag_records.append(
                    self._make_flag_record(resp.segment_id, flag.split(":")[0], flag)
                )

        # Non-success responses (errors and unrecovered 429s) become flags too.
        for resp in responses:
            if resp.status != RequestStatus.SUCCESS:
                result.flag_records.append(
                    self._make_flag_record(resp.segment_id, resp.status.value, resp.error_message)
                )

        result.batch_duration_ms = (time.monotonic() - batch_start) * 1000
        logger.info(
            f"[{batch_id}] Done: {result.segments_returned}/{result.segments_sent} ok, "
            f"{result.segments_429} 429s, {result.segments_error} errors, "
            f"{result.cache_hits} cache hits, {result.avg_latency_ms:.0f}ms avg latency, "
            f"{result.batch_duration_ms:.0f}ms total"
        )
        log_api_stats()

        return result

    async def _retry_with_fallback(
        self,
        batch_id: str,
        requests: list[TranscriptionRequest],
        responses: list[TranscriptionResponse],
        rate_limited: list[TranscriptionResponse],
        result: BatchResult,
    ) -> list[TranscriptionResponse]:
        """Retry the rate-limited subset on the fallback provider.

        Returns the merged response list: non-429 primary responses plus
        whatever the fallback returned for the retried segments. Marks the
        batch provider as "mixed" when a retry actually happened.
        """
        logger.warning(f"[{batch_id}] 429 flood ({len(rate_limited)}/{len(requests)}), switching to fallback")
        retry_ids = {r.segment_id for r in rate_limited}
        retry_reqs = [req for req in requests if req.segment_id in retry_ids]

        if not retry_reqs:
            # Defensive: retry ids come from these requests, so this should
            # not happen; keep the primary responses untouched if it does.
            return responses

        fallback_responses = await self.fallback.send_batch(retry_reqs)
        # Keep non-429 primary responses, then overlay the fallback results.
        resp_map = {r.segment_id: r for r in responses if r.status != RequestStatus.RATE_LIMITED}
        for fr in fallback_responses:
            resp_map[fr.segment_id] = fr
        result.provider_used = "mixed"
        return list(resp_map.values())

    @staticmethod
    def _make_flag_record(segment_id: str, flag_type: str, details: Optional[str]) -> dict:
        """Build one unresolved flag row for the flags table."""
        return {
            "segment_id": segment_id,
            "flag_type": flag_type,
            "details": details,
            "resolved": False,
        }

    def _build_result_record(self, resp: TranscriptionResponse,
                             vr: ValidationResult, expected_lang: str) -> dict:
        """Assemble the DB row for one validated transcription response."""
        data = resp.transcription_data or {}
        speaker = data.get("speaker", {})

        return {
            "id": str(uuid.uuid4()),
            "video_id": self.video_id,
            "segment_file": resp.segment_id,
            # Convention: segment ids look like "<speaker>_<...>"; fall back
            # to "speaker0" when no underscore is present.
            "speaker_id": resp.segment_id.split("_")[0] if "_" in resp.segment_id else "speaker0",
            "expected_language_hint": expected_lang,
            "detected_language": data.get("detected_language", ""),
            "lang_mismatch_flag": vr.lang_mismatch,
            "transcription": data.get("transcription", ""),
            "tagged": data.get("tagged", ""),
            "speaker_emotion": speaker.get("emotion", "neutral"),
            "speaker_style": speaker.get("speaking_style", "conversational"),
            "speaker_pace": speaker.get("pace", "normal"),
            "speaker_accent": speaker.get("accent", ""),
            "num_unk": vr.num_unk,
            "num_inaudible": vr.num_inaudible,
            "num_event_tags": vr.num_event_tags,
            "boundary_score": vr.boundary_score,
            "text_length_per_sec": vr.chars_per_second,
            "overlap_suspected": vr.overlap_suspected,
            "quality_score": vr.quality_score,
            "asr_eligible": vr.asr_eligible,
            "tts_clean_eligible": vr.tts_clean_eligible,
            "tts_expressive_eligible": vr.tts_expressive_eligible,
            "prompt_version": PROMPT_VERSION,
            "schema_version": SCHEMA_VERSION,
            "trimmer_version": TRIMMER_VERSION,
            "validator_version": VALIDATOR_VERSION,
            "model_id": "gemini-3-flash-preview",
            "temperature": TEMPERATURE,
            "thinking_level": THINKING_LEVEL,
            # Was `cache_hit and "cached" or provider` -- the fragile and/or
            # idiom; a conditional expression is the equivalent, safe form.
            # NOTE(review): this always credits the primary provider, even
            # for segments served by the fallback after a 429 flood; the
            # response itself carries no provider attribution here. Confirm.
            "provider": "cached" if resp.token_usage.cache_hit else self.primary.get_provider_name(),
            "worker_id": self.worker_id,
            "cache_hit": resp.token_usage.cache_hit,
            "token_usage_json": {
                "input_tokens": resp.token_usage.input_tokens,
                "output_tokens": resp.token_usage.output_tokens,
                "cached_tokens": resp.token_usage.cached_tokens,
            },
        }
