from __future__ import annotations

import asyncio
import logging
import shutil
import time
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from .audio_polish import polish_all_segments
from .final_export_common import (
    build_export_segment_payload,
    build_pack_artifacts,
    replay_segment_id,
)
from .final_export_config import FinalExportConfig
from .final_export_db import FinalExportPostgresDB, FinalExportVideoJob, FinalExportWorkerStats
from .final_export_reference_store import FinalExportReferenceStore
from .final_export_r2 import FinalExportR2Client
from .r2_client import R2Client


logger = logging.getLogger(__name__)


class FinalExportVideoWorker:
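    """Claims video export jobs from Postgres, polishes segment audio, and
    spools per-language microshard packs to R2."""
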
    def __init__(self, config: FinalExportConfig):
        self.config = config
        self.db = FinalExportPostgresDB(config.database_url)
        self.raw_r2 = R2Client(config.base)
        self.out_r2 = FinalExportR2Client(config)
        self._work_root = config.local_work_root / "video_stage" / config.run_id / config.worker_id
        self.reference_store = FinalExportReferenceStore(config, self._work_root)
        self._jobs_root = self._work_root / "jobs"
        self.stats = FinalExportWorkerStats()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def start(self):
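        """Run the worker lifecycle: set up scratch dirs, connect to the DB,
        reset stale claims, warm the reference store, register, start the
        heartbeat, then drain the job queue until shutdown."""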
        try:
            self._work_root.mkdir(parents=True, exist_ok=True)
            self._jobs_root.mkdir(parents=True, exist_ok=True)
            await self.db.connect()
            await self.db.init_schema()
            await self.db.reset_stale_claims(self.config.claim_stale_after_s)
            await self.reference_store.start()
            await self._register()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as exc:
            logger.error("final export video worker fatal error: %s", exc, exc_info=True)
            try:
                await self.db.set_worker_error(self.config.worker_id, str(exc))
            except Exception:
                pass
            raise
        finally:
            await self._cleanup()

    async def _register(self):
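        """Register this worker and its effective export config with the DB."""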
        config_json = {
            "stage": "video_export",
            "run_id": self.config.run_id,
            "output_bucket": self.config.output_bucket,
            "output_prefix": self.config.output_prefix,
            "reference_mode": self.config.reference_mode,
            "reference_bucket": self.config.reference_bucket,
            "microshard_target_rows": self.config.microshard_target_rows,
            "final_shard_target_rows": self.config.final_shard_target_rows,
            "polish_threads": self.config.polish_threads,
        }
        await self.db.register_worker(
            worker_id=self.config.worker_id,
            stage="video_export",
            gpu_type=self.config.gpu_type,
            config_json=config_json,
        )

    async def _heartbeat_loop(self):
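        """Push worker stats every 30s; a set shutdown event ends the loop early."""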
        while not self._shutdown_event.is_set():
            try:
                await self.db.update_heartbeat(self.config.worker_id, self.stats)
            except Exception as exc:
                logger.warning("final export heartbeat failed: %s", str(exc)[:160])
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=30)
                break
            except asyncio.TimeoutError:
                pass

    # ------------------------------------------------------------------
    # Prefetch pipeline: overlap download of video N+1 with processing N
    # ------------------------------------------------------------------

    async def _download_video_tar(self, video_id: str) -> tuple[Path, Any]:
        """Download and extract a video tar (I/O phase)."""
        video_work = self._jobs_root / video_id
        video_work.mkdir(parents=True, exist_ok=True)
        loop = asyncio.get_running_loop()
        t0 = time.monotonic()
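        # Run the blocking download/extract in the default executor so the
        # event loop stays responsive for heartbeats and DB calls.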
        tar_path = await loop.run_in_executor(
            None, self.raw_r2.download_tar, video_id, video_work,
        )
        extracted = await loop.run_in_executor(
            None, self.raw_r2.extract_tar, tar_path, video_id,
        )
        logger.info("[%s] download+extract %.1fs", video_id, time.monotonic() - t0)
        return video_work, extracted

    async def _main_loop(self):
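        """Drain the job queue with a 1-deep prefetch: while video N is being
        processed, video N+1 is already claimed and downloading."""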
        processed = 0
        # 1-deep prefetch: claim next job and start its download while processing current
        prefetch_job: Optional[FinalExportVideoJob] = None
        prefetch_task: Optional[asyncio.Task] = None

        while not self._shutdown_event.is_set():
            if self.config.max_videos > 0 and processed >= self.config.max_videos:
                logger.info("reached FINAL_EXPORT_MAX_VIDEOS=%s", self.config.max_videos)
                break

            # Pick up prefetched job, or claim a new one
            if prefetch_job is not None:
                job = prefetch_job
                dl_task = prefetch_task
                prefetch_job = None
                prefetch_task = None
            else:
                job = await self.db.claim_video_job(self.config.worker_id)
                if job is None:
                    logger.info("no pending final export videos")
                    break
                await self.db.mark_video_processing(job.video_id, self.config.worker_id)
                dl_task = asyncio.create_task(self._download_video_tar(job.video_id))

            self.stats.jobs_claimed += 1
            self.stats.current_item = job.video_id

            try:
                video_work, extracted = await dl_task

                # Kick off prefetch for next video while we do CPU-heavy processing
                want_prefetch = (
                    self.config.max_videos == 0
                    or processed + 1 < self.config.max_videos
                )
                if want_prefetch and not self._shutdown_event.is_set():
                    next_job = await self.db.claim_video_job(self.config.worker_id)
                    if next_job:
                        await self.db.mark_video_processing(
                            next_job.video_id, self.config.worker_id,
                        )
                        prefetch_job = next_job
                        prefetch_task = asyncio.create_task(
                            self._download_video_tar(next_job.video_id),
                        )

                await self._process_video(job.video_id, video_work, extracted)
                self.stats.jobs_completed += 1
                processed += 1
            except Exception as exc:
                self.stats.jobs_failed += 1
                await self.db.fail_video(job.video_id, str(exc))
                logger.error(
                    "final export video %s failed: %s", job.video_id, exc, exc_info=True,
                )
            finally:
                self.stats.current_item = None

        # Release any dangling prefetch back to pending
        if prefetch_job is not None and prefetch_task is not None:
            prefetch_task.cancel()
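            # CancelledError subclasses BaseException (not Exception) on
            # Python 3.8+, so it must be caught explicitly below.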
            try:
                await prefetch_task
            except (asyncio.CancelledError, Exception):
                pass
            try:
                await self.db.release_video_job(prefetch_job.video_id)
            except Exception:
                pass

    async def _process_video(self, video_id: str, video_work: Path, extracted: Any):
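        """Polish, filter, and pack one video's segments into microshards.

        Segments are polished in a thread pool, matched against canonical
        reference rows, filtered (language, variants, validation), grouped by
        language, chunked into microshards, uploaded, and recorded in
        Postgres before the scratch dir is removed.
        """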
        logger.info("[%s] final export start", video_id)
        try:
            loop = asyncio.get_running_loop()

            reference_rows = self.reference_store.get_video_reference_rows(video_id)
            if not reference_rows:
                raise RuntimeError("No canonical final export rows found for video")

            raw_paths = sorted(extracted.segment_paths)
            polished_segments = await loop.run_in_executor(
                None,
                lambda: polish_all_segments(raw_paths, max_workers=self.config.polish_threads),
            )

            grouped_metadata: dict[str, list[dict[str, Any]]] = defaultdict(list)
            grouped_audio: dict[str, list[dict[str, Any]]] = defaultdict(list)
            drop_counts: Counter[str] = Counter()
            replay_discarded = 0
            total_flac_bytes = 0
            kept_count = 0
            replay_valid_count = 0

            for polished in polished_segments:
                if polished.trim_meta.discarded:
                    replay_discarded += 1
                    continue

                replay_valid_count += 1
                segment_id = replay_segment_id(
                    polished.trim_meta.original_file,
                    polished.trim_meta.was_split,
                    polished.trim_meta.split_index,
                )
                row = reference_rows.get(segment_id)
                if row is None:
                    drop_counts["no_transcription_row"] += 1
                    continue

                segment_language = str(row.get("segment_language") or "").strip().lower()
                if not segment_language:
                    drop_counts["transcript_row_but_missing_required_fields"] += 1
                    continue
                if segment_language not in self.config.supported_languages:
                    drop_counts["language_unsupported"] += 1
                    continue

                transcription_native = str(
                    row.get("native_script_text") or row.get("transcription") or ""
                )
                transcription_romanized = str(
                    row.get("romanized_text") or row.get("transcription") or ""
                )
                if (
                    self.config.require_variants
                    and not transcription_native.strip()
                    and not transcription_romanized.strip()
                ):
                    drop_counts["variant_missing"] += 1
                    continue
                if self.config.require_validation and not bool(row.get("final_has_validation")):
                    drop_counts["validation_missing"] += 1
                    continue

                payload = build_export_segment_payload(
                    video_id=video_id,
                    canonical_row=row,
                    polished_segment=polished,
                    run_id=self.config.run_id,
                    worker_id=self.config.worker_id,
                    exported_at=datetime.now(timezone.utc).isoformat(),
                )
                grouped_metadata[segment_language].append(payload["metadata_row"])
                grouped_audio[segment_language].append(payload["audio_row"])
                kept_count += 1
                total_flac_bytes += len(polished.flac_bytes)

            microshard_count = 0
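            # Pack kept rows into per-language microshards; sorting by
            # segment_id makes chunk contents deterministic for a given
            # input set.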
            for language, metadata_rows in grouped_metadata.items():
                audio_rows = grouped_audio[language]
                pairs = list(zip(metadata_rows, audio_rows))
                pairs.sort(key=lambda item: item[0]["segment_id"])
                for chunk_index, chunk_start in enumerate(
                    range(0, len(pairs), self.config.microshard_target_rows)
                ):
                    chunk = pairs[chunk_start : chunk_start + self.config.microshard_target_rows]
                    chunk_metadata = [item[0] for item in chunk]
                    chunk_audio = [item[1] for item in chunk]
                    microshard_id = f"{video_id}_{language}_{chunk_index:04d}"
                    pack_dir = video_work / "microshards" / language / microshard_id
                    artifacts = build_pack_artifacts(
                        pack_dir=pack_dir,
                        manifest_name="manifest.json",
                        metadata_rows=chunk_metadata,
                        audio_rows=chunk_audio,
                        manifest_payload={
                            "microshard_id": microshard_id,
                            "run_id": self.config.run_id,
                            "video_id": video_id,
                            "language": language,
                            "segment_count": len(chunk_metadata),
                            "source_video_ids_sample": [video_id],
                            "worker_id": self.config.worker_id,
                            "created_at": datetime.now(timezone.utc).isoformat(),
                        },
                    )
                    keys = self._upload_pack(
                        bucket=self.config.output_bucket,
                        base_prefix=(
                            f"{self.config.microshard_prefix}/lang={language}/video={video_id}/{microshard_id}"
                        ),
                        artifacts=artifacts,
                    )
                    await self.db.insert_microshard(
                        {
                            "microshard_id": microshard_id,
                            "run_id": self.config.run_id,
                            "video_id": video_id,
                            "language": language,
                            "chunk_index": chunk_index,
                            "status": "pending",
                            "row_count": artifacts.row_count,
                            "consumed_rows": 0,
                            "output_bucket": self.config.output_bucket,
                            "metadata_key": keys["metadata_key"],
                            "audio_key": keys["audio_key"],
                            "audio_index_key": keys["audio_index_key"],
                            "manifest_key": keys["manifest_key"],
                            "metadata_size_bytes": artifacts.metadata_path.stat().st_size,
                            "audio_size_bytes": artifacts.audio_tar_path.stat().st_size,
                            "audio_index_size_bytes": artifacts.audio_index_path.stat().st_size,
                            "manifest_size_bytes": artifacts.manifest_path.stat().st_size,
                            "metadata_sha256": artifacts.metadata_sha256,
                            "audio_sha256": artifacts.audio_sha256,
                            "audio_index_sha256": artifacts.audio_index_sha256,
                            "segment_id_set_sha256": artifacts.segment_id_set_sha256,
                            "claimed_by": None,
                            "claimed_at": None,
                            "compacted_at": None,
                            "error_message": None,
                            "attempt_count": 0,
                            "metadata_json": {
                                "source_video_ids": [video_id],
                                "drop_counts": dict(drop_counts),
                                "replay_discarded": replay_discarded,
                                "manifest_sha256": artifacts.manifest_sha256,
                            },
                        }
                    )
                    microshard_count += 1
                    self.stats.packs_uploaded += 1
                    self.stats.rows_uploaded += artifacts.row_count

            await self.db.insert_video_output(
                {
                    "video_id": video_id,
                    "run_id": self.config.run_id,
                    "status": "spooled",
                    "raw_parent_count": len(raw_paths),
                    "replay_valid_count": replay_valid_count,
                    "kept_count": kept_count,
                    "dropped_count": sum(drop_counts.values()) + replay_discarded,
                    "microshard_count": microshard_count,
                    "total_flac_bytes": total_flac_bytes,
                    "drop_counts_json": dict(drop_counts),
                    "metadata_json": {
                        "replay_discarded": replay_discarded,
                        "languages": sorted(grouped_metadata.keys()),
                    },
                }
            )
            await self.db.complete_video_spooled(video_id)
            logger.info(
                "[%s] spooled raw_parents=%s replay_valid=%s kept=%s dropped=%s microshards=%s",
                video_id,
                len(raw_paths),
                replay_valid_count,
                kept_count,
                sum(drop_counts.values()) + replay_discarded,
                microshard_count,
            )
        finally:
            shutil.rmtree(video_work, ignore_errors=True)

    def _upload_pack(self, *, bucket: str, base_prefix: str, artifacts) -> dict[str, str]:
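        """Upload the four pack artifacts, then verify remote sizes via HEAD.

        The size check guards against truncated uploads and is skipped in
        mock mode.
        """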
        metadata_key = f"{base_prefix}/metadata.parquet"
        audio_key = f"{base_prefix}/audio.tar"
        audio_index_key = f"{base_prefix}/audio_index.parquet"
        manifest_key = f"{base_prefix}/manifest.json"
        self.out_r2.upload_file(artifacts.metadata_path, bucket, metadata_key)
        self.out_r2.upload_file(artifacts.audio_tar_path, bucket, audio_key)
        self.out_r2.upload_file(artifacts.audio_index_path, bucket, audio_index_key)
        self.out_r2.upload_file(artifacts.manifest_path, bucket, manifest_key)

        if not self.config.mock_mode:
            expected = {
                metadata_key: artifacts.metadata_path.stat().st_size,
                audio_key: artifacts.audio_tar_path.stat().st_size,
                audio_index_key: artifacts.audio_index_path.stat().st_size,
                manifest_key: artifacts.manifest_path.stat().st_size,
            }
            for key, size in expected.items():
                remote_size = self.out_r2.head_size(bucket, key)
                if remote_size != size:
                    raise RuntimeError(f"HEAD size mismatch for s3://{bucket}/{key}: {remote_size} != {size}")

        return {
            "metadata_key": metadata_key,
            "audio_key": audio_key,
            "audio_index_key": audio_index_key,
            "manifest_key": manifest_key,
        }

    async def _cleanup(self):
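        """Best-effort teardown: stop the heartbeat, close the reference
        store, drop scratch dirs, and mark the worker offline."""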
        self._shutdown_event.set()
        if self._heartbeat_task:
            try:
                await self._heartbeat_task
            except Exception:
                pass
        try:
            await self.reference_store.close()
        except Exception:
            pass
        if self._jobs_root.exists():
            shutil.rmtree(self._jobs_root, ignore_errors=True)
        try:
            await self.db.set_worker_offline(self.config.worker_id)
        except Exception:
            pass
        try:
            await self.db.close()
        except Exception:
            pass
