"""
Transcript variant shard worker.

Claims shard jobs from PostgreSQL, downloads the input parquet/CSV shard from R2,
skips fully-Roman rows locally, batches the remaining rows through Gemini, uploads
packed parquet outputs back to R2, and records heartbeats and pack manifests in the
database.
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import tempfile
import time
import uuid
from pathlib import Path
from typing import Any, Optional

import pandas as pd

from .config import EnvConfig, HEARTBEAT_INTERVAL_S
from .transcript_variant_prompt import (
    InputScriptProfile,
    classify_input_script,
    detect_script_counts,
    extract_protected_spans,
    get_target_script_block,
    romanized_text_is_ascii,
)
from .variant_db import VariantJob, VariantPostgresDB, VariantWorkerStats
from .variant_provider import TranscriptVariantCacheManager, TranscriptVariantClient
from .variant_r2 import VariantR2Client


logger = logging.getLogger(__name__)


def _env_int(name: str, default: int) -> int:
    try:
        return int(os.getenv(name, str(default)))
    except ValueError:
        return default


def _env_str(name: str, default: str) -> str:
    return os.getenv(name, default)


class TranscriptVariantWorker:
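    """Claims transcript-variant shard jobs, routes rows locally or through Gemini, and uploads packed results to R2."""
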
    def __init__(self, config: EnvConfig):
        self.config = config
        self.worker_id = config.worker_id
        self.db = VariantPostgresDB(config.database_url)
        self.r2 = VariantR2Client(config)
        self.stats = VariantWorkerStats()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self.client: Optional[TranscriptVariantClient] = None

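        # Tunable knobs, read from the environment with sensible defaults.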
        self.batch_size = _env_int("VARIANT_BATCH_SIZE", 10)
        self.concurrent_requests = _env_int("VARIANT_CONCURRENT_REQUESTS", 20)
        self.max_jobs = _env_int("VARIANT_MAX_JOBS", 0)
        self.max_rows_per_job = _env_int("VARIANT_MAX_ROWS_PER_JOB", 0)
        self.pack_target_videos = _env_int("VARIANT_PACK_TARGET_VIDEOS", 50)
        self.pack_target_rows = _env_int("VARIANT_PACK_TARGET_ROWS", 5000)
        self.cache_ttl_s = _env_int("VARIANT_CACHE_TTL_S", 518400)  # 6 days

    async def start(self):
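        """Run the worker lifecycle: connect to the DB, register, set up the Gemini cache, start heartbeats, and run the main loop."""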
        try:
            await self.db.connect()
            await self.db.init_schema()
            await self._register()
            await self._setup_cache()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as e:
            logger.error("variant worker fatal error: %s", e, exc_info=True)
            try:
                await self.db.set_worker_error(self.worker_id, str(e))
            except Exception:
                pass
            raise
        finally:
            await self._cleanup()

    async def _register(self):
        config_json = {
            "worker_type": "transcript_variant",
            "model": "gemini-3-flash-preview",
            "thinking_level": os.getenv("THINKING_LEVEL", "low"),
            "temperature": os.getenv("TEMPERATURE", "0"),
            "batch_size": self.batch_size,
            "concurrent_requests": self.concurrent_requests,
            "pack_target_videos": self.pack_target_videos,
            "pack_target_rows": self.pack_target_rows,
            "gemini_key_index": self.config.gemini_key_index,
        }
        await self.db.register_worker(
            worker_id=self.worker_id,
            provider="aistudio",
            gpu_type=self.config.gpu_type,
            config_json=config_json,
        )

    async def _setup_cache(self):
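        """Ensure the Gemini context cache is available and build the request client bound to it."""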
        cache_manager = TranscriptVariantCacheManager(self.config.primary_gemini_key)
        cache_info = await cache_manager.ensure_cache(self.cache_ttl_s)
        cache_name = cache_info["name"]
        tokens = cache_info.get("usageMetadata", {}).get("totalTokenCount", "?")
        logger.info("variant cache ready: %s (%s tokens)", cache_name, tokens)
        self.client = TranscriptVariantClient(
            api_key=self.config.primary_gemini_key,
            cache_name=cache_name,
        )

    async def _heartbeat_loop(self):
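        """Report worker stats on each HEARTBEAT_INTERVAL_S tick; the shutdown event doubles as an interruptible sleep."""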
        while not self._shutdown_event.is_set():
            try:
                await self.db.update_heartbeat(self.worker_id, self.stats)
            except Exception as e:
                logger.warning("variant heartbeat failed: %s", str(e)[:120])
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=HEARTBEAT_INTERVAL_S)
                break
            except asyncio.TimeoutError:
                pass

    async def _main_loop(self):
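        """Claim and process shard jobs until none are pending, shutdown is requested, or VARIANT_MAX_JOBS is reached."""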
        jobs_done = 0
        while not self._shutdown_event.is_set():
            if self.max_jobs > 0 and jobs_done >= self.max_jobs:
                logger.info("reached VARIANT_MAX_JOBS=%s", self.max_jobs)
                break

            job = await self.db.claim_job(self.worker_id)
            if job is None:
                logger.info("no pending variant shard jobs")
                break

            self.stats.jobs_claimed += 1
            self.stats.current_shard_id = job.shard_id
            self.stats.rows_remaining = job.total_rows
            logger.info("claimed variant shard %s (%s rows)", job.shard_id, job.total_rows)

            try:
                await self._process_job(job)
                self.stats.jobs_completed += 1
            except Exception as e:
                self.stats.jobs_failed += 1
                await self.db.fail_job(job.shard_id, str(e))
                logger.error("variant shard %s failed: %s", job.shard_id, e, exc_info=True)
            finally:
                self.stats.current_shard_id = None
                self.stats.rows_remaining = 0
                jobs_done += 1

    async def _process_job(self, job: VariantJob):
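        """Process one claimed shard: download the input, route rows locally or to Gemini, and upload packed outputs."""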
        assert self.client is not None
        job_started = time.monotonic()
        work_dir = Path(tempfile.mkdtemp(prefix=f"variant_{job.shard_id}_"))
        input_path = work_dir / f"{job.shard_id}.{job.input_format}"
        self.r2.download_file(job.input_bucket, job.input_r2_key, input_path)

        df = self._load_input_frame(input_path, job.input_format)
        if self.max_rows_per_job > 0:
            df = df.head(self.max_rows_per_job).copy()

        metadata = job.metadata_json or {}
        column_map = metadata.get("column_map", {})
        rows = self._prepare_source_rows(df, column_map)
        total_rows = len(rows)
        self.stats.rows_remaining = total_rows

        skipped_rows: list[dict[str, Any]] = []
        gemini_rows: list[dict[str, Any]] = []

        for row in rows:
            profile = classify_input_script(row["text"], row["language_code"])
            row["input_script_profile"] = profile.value
            if profile == InputScriptProfile.fully_roman:
                skipped_rows.append(self._build_skipped_row(row))
            else:
                gemini_rows.append(row)

        logger.info(
            "[%s] total=%s gemini=%s skipped=%s",
            job.shard_id,
            total_rows,
            len(gemini_rows),
            len(skipped_rows),
        )

        output_buffer: list[dict[str, Any]] = []
        output_buffer.extend(skipped_rows)
        rows_processed = len(skipped_rows)
        rows_skipped = len(skipped_rows)
        rows_gemini = 0
        self.stats.rows_processed += len(skipped_rows)
        self.stats.rows_skipped += len(skipped_rows)
        packs_uploaded = 0
        last_pack_key = ""

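        # If the locally skipped rows alone already meet the pack thresholds, flush them
        # before any Gemini work starts.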
        if self._should_flush(output_buffer):
            packs_uploaded, last_pack_key = await self._flush_pack(
                job=job,
                output_rows=output_buffer,
                pack_index=packs_uploaded,
            )
            output_buffer = []

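        # Fan the remaining rows out to Gemini in fixed-size batches; the semaphore bounds in-flight requests.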
        batches = [
            gemini_rows[idx: idx + self.batch_size]
            for idx in range(0, len(gemini_rows), self.batch_size)
        ]
        semaphore = asyncio.Semaphore(self.concurrent_requests)

        async def run_batch(batch_index: int, batch_rows: list[dict[str, Any]]):
            async with semaphore:
                return await self._run_single_batch(batch_index, batch_rows)

        tasks = [
            asyncio.create_task(run_batch(batch_index, batch_rows))
            for batch_index, batch_rows in enumerate(batches)
        ]

        for task in asyncio.as_completed(tasks):
            batch_output_rows, batch_row_count = await task
            output_buffer.extend(batch_output_rows)
            rows_processed += batch_row_count
            rows_gemini += batch_row_count
            self.stats.rows_processed += batch_row_count
            self.stats.rows_gemini += batch_row_count
            self.stats.rows_remaining = max(total_rows - rows_processed, 0)
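            # Rough throughput signal: lifetime successful requests over this job's elapsed minutes.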
            self.stats.active_rpm = (
                self.stats.requests_succeeded / max((time.monotonic() - job_started) / 60.0, 0.001)
            )

            if self._should_flush(output_buffer):
                packs_uploaded, last_pack_key = await self._flush_pack(
                    job=job,
                    output_rows=output_buffer,
                    pack_index=packs_uploaded,
                )
                output_buffer = []
                await self.db.update_job_progress(
                    job.shard_id,
                    rows_processed=rows_processed,
                    rows_skipped=rows_skipped,
                    rows_gemini=rows_gemini,
                    packs_uploaded=packs_uploaded,
                    last_pack_key=last_pack_key,
                )

        if output_buffer:
            packs_uploaded, last_pack_key = await self._flush_pack(
                job=job,
                output_rows=output_buffer,
                pack_index=packs_uploaded,
            )

        await self.db.complete_job(
            job.shard_id,
            rows_processed=rows_processed,
            rows_skipped=rows_skipped,
            rows_gemini=rows_gemini,
            packs_uploaded=packs_uploaded,
            last_pack_key=last_pack_key,
        )
        logger.info(
            "[%s] complete rows=%s gemini=%s skipped=%s packs=%s",
            job.shard_id,
            rows_processed,
            rows_gemini,
            rows_skipped,
            packs_uploaded,
        )

    async def _run_single_batch(
        self, batch_index: int, batch_rows: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], int]:
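        """Send one batch to Gemini and map its items back to output rows; a failed request yields error rows instead of raising."""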
        assert self.client is not None
        request_id = f"{self.worker_id}_{batch_index}_{uuid.uuid4().hex[:8]}"
        items = [
            {
                "id": row["row_id"],
                "language_code": row["language_code"],
                "input_script_profile": row["input_script_profile"],
                "text": row["text"],
            }
            for row in batch_rows
        ]

        try:
            self.stats.requests_sent += 1
            result = await self.client.generate_batch(items)
            self.stats.requests_succeeded += 1
            self.stats.cache_hits += int(result.token_usage.cache_hit)
            self.stats.total_input_tokens += result.token_usage.input_tokens
            self.stats.total_output_tokens += result.token_usage.output_tokens
            self.stats.total_cached_tokens += result.token_usage.cached_tokens

            result_map = {item["id"]: item for item in result.items}
            output_rows: list[dict[str, Any]] = []
            for row in batch_rows:
                actual = result_map.get(row["row_id"])
                if actual is None:
                    output_rows.append(self._build_error_row(row, request_id, "missing_from_response"))
                    continue
                validation_errors = self._validate_item(row, actual)
                output_rows.append(
                    {
                        **row,
                        "processing_route": "gemini",
                        "native_script_text": actual["native_script_text"],
                        "romanized_text": actual["romanized_text"],
                        "request_id": request_id,
                        "request_input_tokens": result.token_usage.input_tokens,
                        "request_output_tokens": result.token_usage.output_tokens,
                        "request_cached_tokens": result.token_usage.cached_tokens,
                        "request_cache_hit": result.token_usage.cache_hit,
                        "validation_errors": json.dumps(validation_errors, ensure_ascii=False),
                    }
                )
            return output_rows, len(batch_rows)
        except Exception as exc:
            self.stats.requests_failed += 1
            logger.error("batch %s failed: %s", batch_index, str(exc)[:300])
            return [self._build_error_row(row, request_id, str(exc)[:200]) for row in batch_rows], len(batch_rows)

    def _build_error_row(self, row: dict[str, Any], request_id: str, error: str) -> dict[str, Any]:
        return {
            **row,
            "processing_route": "gemini_error",
            "native_script_text": "",
            "romanized_text": "",
            "request_id": request_id,
            "request_input_tokens": 0,
            "request_output_tokens": 0,
            "request_cached_tokens": 0,
            "request_cache_hit": False,
            "validation_errors": json.dumps([f"request_error:{error}"], ensure_ascii=False),
        }

    def _prepare_source_rows(self, df: pd.DataFrame, column_map: dict[str, str]) -> list[dict[str, Any]]:
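        """Normalize the input frame into plain row dicts via the job's column_map, with defaults for missing fields."""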
        id_col = column_map.get("id", "row_id")
        video_id_col = column_map.get("video_id", "video_id")
        segment_col = column_map.get("segment_id", "segment_id")
        language_col = column_map.get("language_code", "language_code")
        text_col = column_map.get("text", "text")

        rows: list[dict[str, Any]] = []
        for idx, item in enumerate(df.to_dict(orient="records")):
            row_id = str(item.get(id_col) or f"row_{idx:08d}")
            rows.append(
                {
                    "row_id": row_id,
                    "video_id": str(item.get(video_id_col) or ""),
                    "segment_id": str(item.get(segment_col) or ""),
                    "language_code": str(item.get(language_col) or "").strip() or "en",
                    "text": str(item.get(text_col) or "").strip(),
                    "source_row_index": idx,
                }
            )
        return rows

    def _load_input_frame(self, path: Path, input_format: str) -> pd.DataFrame:
        if input_format.lower() == "csv":
            return pd.read_csv(path)
        return pd.read_parquet(path)

    def _build_skipped_row(self, row: dict[str, Any]) -> dict[str, Any]:
        return {
            **row,
            "processing_route": "local_skip_fully_roman",
            "native_script_text": "",
            "romanized_text": row["text"],
            "request_id": "",
            "request_input_tokens": 0,
            "request_output_tokens": 0,
            "request_cached_tokens": 0,
            "request_cache_hit": False,
            "validation_errors": "[]",
        }

    def _should_flush(self, rows: list[dict[str, Any]]) -> bool:
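        """Flush once the buffer reaches the target row count or spans enough distinct videos."""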
        if not rows:
            return False
        if len(rows) >= self.pack_target_rows:
            return True
        video_ids = {row.get("video_id", "") for row in rows if row.get("video_id", "")}
        return bool(video_ids) and len(video_ids) >= self.pack_target_videos

    async def _flush_pack(
        self, job: VariantJob, output_rows: list[dict[str, Any]], pack_index: int
    ) -> tuple[int, str]:
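        """Write buffered rows to a parquet pack, upload it to R2, and record its manifest; returns (next_pack_index, output_key)."""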
        pack_id = f"{job.shard_id}_pack_{pack_index:04d}"
        local_path = Path(tempfile.mkdtemp(prefix=f"{pack_id}_")) / f"{pack_id}.parquet"
        pd.DataFrame.from_records(output_rows).to_parquet(local_path, index=False)
        output_key = f"{job.output_prefix.rstrip('/')}/{job.shard_id}/{pack_id}.parquet"
        byte_size = self.r2.upload_file(local_path, job.output_bucket, output_key)
        manifest = {
            "pack_id": pack_id,
            "shard_id": job.shard_id,
            "worker_id": self.worker_id,
            "output_bucket": job.output_bucket,
            "output_key": output_key,
            "row_count": len(output_rows),
            "gemini_row_count": sum(1 for row in output_rows if row["processing_route"] == "gemini"),
            "skipped_row_count": sum(1 for row in output_rows if row["processing_route"] != "gemini"),
            "distinct_video_count": len(
                {row.get("video_id", "") for row in output_rows if row.get("video_id", "")}
            ),
            "byte_size": byte_size,
            "metadata_json": {
                "source_rows": len(output_rows),
                "pack_index": pack_index,
            },
        }
        await self.db.insert_pack_manifest(manifest)
        self.stats.packs_uploaded += 1
        return pack_index + 1, output_key

    def _validate_item(self, expected: dict[str, Any], actual: dict[str, Any]) -> list[str]:
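        """Check one Gemini item: ID match, non-empty outputs, protected spans preserved
        in both variants, ASCII romanization, and target-script sanity."""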
        errors: list[str] = []
        if actual["id"] != expected["row_id"]:
            errors.append("id_mismatch")
        if not actual["native_script_text"].strip():
            errors.append("empty_native_script_text")
        if not actual["romanized_text"].strip():
            errors.append("empty_romanized_text")

        for span in extract_protected_spans(expected["text"]):
            if span not in actual["native_script_text"]:
                errors.append(f"missing_protected_native:{span}")
            if span not in actual["romanized_text"]:
                errors.append(f"missing_protected_roman:{span}")

        if not romanized_text_is_ascii(actual["romanized_text"]):
            errors.append("roman_not_ascii")

        target_script = get_target_script_block(expected["language_code"])
        native_counts = detect_script_counts(actual["native_script_text"])
        roman_counts = detect_script_counts(actual["romanized_text"])
        if target_script != "Latin" and native_counts.get(target_script, 0) == 0:
            errors.append("native_missing_target_script")
        if target_script != "Latin" and roman_counts.get(target_script, 0) > 0:
            errors.append("roman_contains_target_script")
        return errors

    async def _cleanup(self):
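        """Best-effort teardown: stop heartbeats, send a final heartbeat, mark the worker offline, and close clients."""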
        self._shutdown_event.set()
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
        try:
            await self.db.update_heartbeat(self.worker_id, self.stats)
        except Exception:
            pass
        try:
            await self.db.set_worker_offline(self.worker_id)
        except Exception:
            pass
        if self.client:
            await self.client.close()
        await self.db.close()
