from __future__ import annotations

import asyncio
import logging
import shutil
import tarfile
import time
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import pyarrow.parquet as pq

from .final_export_common import build_pack_artifacts
from .final_export_config import FinalExportConfig
from .final_export_db import (
    FinalExportMicroshardJob,
    FinalExportPostgresDB,
    FinalExportWorkerStats,
)
from .final_export_r2 import FinalExportR2Client


logger = logging.getLogger(__name__)


@dataclass
class BufferedSegment:
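    """A single segment staged in memory before compaction.

    Holds the metadata and audio-index rows loaded from a source microshard,
    plus the local path of that microshard's audio tar so the FLAC bytes can
    be re-read when the final shard is packed.
    """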
    microshard_id: str
    video_id: str
    metadata_row: dict
    audio_index_row: dict
    audio_tar_path: Path


class FinalExportCompactor:
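    """Compacts per-video microshards into language-level final shards.

    The worker acquires a per-language lease from Postgres, claims pending
    microshards for that language, buffers their rows locally, and repacks
    them into shards of roughly ``final_shard_target_rows`` segments, which
    are uploaded to R2 and registered back in the database.
    """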
    def __init__(self, config: FinalExportConfig):
        self.config = config
        self.db = FinalExportPostgresDB(config.database_url)
        self.r2 = FinalExportR2Client(config)
        self.stats = FinalExportWorkerStats()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._work_root = config.local_work_root / "compact_stage" / config.run_id / config.worker_id
        self._current_language: Optional[str] = None

    async def start(self):
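        """Run the compactor: connect to the database, register this worker,
        start the heartbeat task, and process languages until exhausted."""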
        try:
            self._work_root.mkdir(parents=True, exist_ok=True)
            await self.db.connect()
            await self.db.init_schema()
            await self.db.reset_stale_claims(self.config.claim_stale_after_s)
            await self._register()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as exc:
            logger.error("final export compactor fatal error: %s", exc, exc_info=True)
            try:
                await self.db.set_worker_error(self.config.worker_id, str(exc))
            except Exception:
                pass
            raise
        finally:
            await self._cleanup()

    async def _register(self):
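        """Register this worker in the database along with its effective compaction settings."""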
        config_json = {
            "stage": "language_compactor",
            "run_id": self.config.run_id,
            "output_bucket": self.config.output_bucket,
            "output_prefix": self.config.output_prefix,
                "reference_mode": self.config.reference_mode,
            "final_shard_target_rows": self.config.final_shard_target_rows,
            "allow_partial_shards": self.config.allow_partial_shards,
            "language_filters": self.config.language_filters,
        }
        await self.db.register_worker(
            worker_id=self.config.worker_id,
            stage="language_compactor",
            gpu_type=self.config.gpu_type,
            config_json=config_json,
        )

    async def _heartbeat_loop(self):
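        """Report worker stats every 30 seconds and renew the lease for the
        language currently being compacted, until shutdown is signalled."""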
        while not self._shutdown_event.is_set():
            try:
                await self.db.update_heartbeat(self.config.worker_id, self.stats)
                if self._current_language:
                    await self.db.heartbeat_language_lease(
                        self._current_language,
                        self.config.worker_id,
                        self.config.language_lease_seconds,
                    )
            except Exception as exc:
                logger.warning("final export compactor heartbeat failed: %s", str(exc)[:160])
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=30)
                break
            except asyncio.TimeoutError:
                pass

    async def _main_loop(self):
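        """Acquire language leases one at a time and compact each language.

        Stops when no lease is available, when ``max_shards`` final shards
        have been written (if configured), or when shutdown is requested.
        """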
        shards_written = 0
        while not self._shutdown_event.is_set():
            if self.config.max_shards > 0 and shards_written >= self.config.max_shards:
                logger.info("reached FINAL_EXPORT_MAX_SHARDS=%s", self.config.max_shards)
                break

            language = await self.db.acquire_language_lease(
                worker_id=self.config.worker_id,
                run_id=self.config.run_id,
                lease_seconds=self.config.language_lease_seconds,
                languages=self.config.language_filters or None,
            )
            if language is None:
                logger.info("no pending final export microshards to compact")
                break

            self._current_language = language
            self.stats.current_item = f"lang={language}"
            try:
                wrote = await self._process_language(language)
                shards_written += wrote
            finally:
                await self.db.release_language_lease(language, self.config.worker_id)
                self._current_language = None
                self.stats.current_item = None

    async def _process_language(self, language: str) -> int:
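        """Compact all pending microshards for ``language`` into final shards.

        Claims microshards in batches, buffers their rows, and flushes a full
        shard whenever the buffer reaches the target row count. A trailing
        partial shard is written only when ``allow_partial_shards`` is set.
        Returns the number of final shards written.
        """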
        logger.info("[lang=%s] compaction start", language)
        claimed_jobs: dict[str, FinalExportMicroshardJob] = {}
        buffer: list[BufferedSegment] = []
        written = 0
        exhausted = False

        while not self._shutdown_event.is_set():
            if len(buffer) < self.config.final_shard_target_rows and not exhausted:
                new_jobs = await self.db.claim_microshards_for_language(
                    worker_id=self.config.worker_id,
                    run_id=self.config.run_id,
                    language=language,
                    limit=self.config.compactor_claim_limit,
                )
                if not new_jobs:
                    exhausted = True
                else:
                    self.stats.jobs_claimed += len(new_jobs)
                    for job in new_jobs:
                        claimed_jobs[job.microshard_id] = job
                        buffer.extend(self._load_microshard_rows(job))

            self.stats.rows_buffered = len(buffer)

            if len(buffer) >= self.config.final_shard_target_rows:
                await self._flush_buffer(language, buffer, self.config.final_shard_target_rows)
                written += 1
                self.stats.jobs_completed += 1
                self.stats.rows_buffered = len(buffer)
                continue

            if exhausted:
                if buffer and self.config.allow_partial_shards:
                    await self._flush_buffer(language, buffer, len(buffer))
                    written += 1
                    self.stats.jobs_completed += 1
                    self.stats.rows_buffered = len(buffer)
                break

        if claimed_jobs:
            await self.db.release_microshards(list(claimed_jobs.keys()), self.config.worker_id)
        return written

    def _load_microshard_rows(self, job: FinalExportMicroshardJob) -> list[BufferedSegment]:
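        """Download a microshard's artifacts and return its unconsumed rows.

        Metadata and audio-index rows are read positionally and must agree on
        ``segment_id``; rows before ``job.consumed_rows`` are treated as
        already consumed and skipped.
        """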
        local_dir = self._work_root / "microshards" / job.microshard_id
        local_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = local_dir / "metadata.parquet"
        audio_tar_path = local_dir / "audio.tar"
        audio_index_path = local_dir / "audio_index.parquet"
        if not metadata_path.exists():
            self.r2.download_file(job.output_bucket, job.metadata_key, metadata_path)
        if not audio_tar_path.exists():
            self.r2.download_file(job.output_bucket, job.audio_key, audio_tar_path)
        if not audio_index_path.exists():
            self.r2.download_file(job.output_bucket, job.audio_index_key, audio_index_path)

        metadata_rows = pq.read_table(metadata_path).to_pylist()
        audio_index_rows = pq.read_table(audio_index_path).to_pylist()
        if len(metadata_rows) != len(audio_index_rows):
            raise RuntimeError(
                f"Microshard row mismatch for {job.microshard_id}: {len(metadata_rows)} != {len(audio_index_rows)}"
            )

        buffered: list[BufferedSegment] = []
        start = min(job.consumed_rows, len(metadata_rows))
        for idx in range(start, len(metadata_rows)):
            meta = metadata_rows[idx]
            audio = audio_index_rows[idx]
            if str(meta.get("segment_id")) != str(audio.get("segment_id")):
                raise RuntimeError(f"Microshard order mismatch for {job.microshard_id} at row {idx}")
            buffered.append(
                BufferedSegment(
                    microshard_id=job.microshard_id,
                    video_id=str(meta.get("video_id") or job.video_id),
                    metadata_row=meta,
                    audio_index_row=audio,
                    audio_tar_path=audio_tar_path,
                )
            )
        return buffered

    async def _flush_buffer(self, language: str, buffer: list[BufferedSegment], take_rows: int):
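        """Pack the first ``take_rows`` buffered segments into one final shard.

        Builds the shard artifacts locally, uploads them to R2, records the
        shard in the database, and commits how many rows were consumed from
        each source microshard. The selected rows are removed from ``buffer``.
        """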
        selected = list(buffer[:take_rows])
        del buffer[:take_rows]
        shard_id = f"{language}_shard_{int(time.time())}_{self.stats.packs_uploaded + 1:06d}"
        pack_dir = self._work_root / "final_shards" / language / shard_id
        metadata_rows = [item.metadata_row for item in selected]
        audio_rows = self._load_audio_rows(selected)

        source_counts = Counter(item.microshard_id for item in selected)
        video_ids = sorted({item.video_id for item in selected})
        artifacts = build_pack_artifacts(
            pack_dir=pack_dir,
            manifest_name="manifest.json",
            metadata_rows=metadata_rows,
            audio_rows=audio_rows,
            manifest_payload={
                "shard_id": shard_id,
                "run_id": self.config.run_id,
                "language": language,
                "segment_count": len(selected),
                "video_count": len(video_ids),
                "source_microshard_count": len(source_counts),
                "source_video_ids_sample": video_ids[:32],
                "worker_id": self.config.worker_id,
                "created_at": datetime.now(timezone.utc).isoformat(),
            },
        )
        keys = self._upload_pack(
            bucket=self.config.output_bucket,
            base_prefix=f"{self.config.shard_prefix}/lang={language}/{shard_id}",
            artifacts=artifacts,
        )
        await self.db.insert_final_shard(
            {
                "shard_id": shard_id,
                "run_id": self.config.run_id,
                "language": language,
                "output_bucket": self.config.output_bucket,
                "metadata_key": keys["metadata_key"],
                "audio_key": keys["audio_key"],
                "audio_index_key": keys["audio_index_key"],
                "manifest_key": keys["manifest_key"],
                "segment_count": artifacts.row_count,
                "video_count": len(video_ids),
                "source_microshard_count": len(source_counts),
                "metadata_size_bytes": artifacts.metadata_path.stat().st_size,
                "audio_size_bytes": artifacts.audio_tar_path.stat().st_size,
                "audio_index_size_bytes": artifacts.audio_index_path.stat().st_size,
                "manifest_size_bytes": artifacts.manifest_path.stat().st_size,
                "metadata_sha256": artifacts.metadata_sha256,
                "audio_sha256": artifacts.audio_sha256,
                "audio_index_sha256": artifacts.audio_index_sha256,
                "segment_id_set_sha256": artifacts.segment_id_set_sha256,
                "metadata_json": {
                    "source_microshard_slices": dict(source_counts),
                    "manifest_sha256": artifacts.manifest_sha256,
                },
            }
        )
        await self.db.commit_microshard_consumption(
            worker_id=self.config.worker_id,
            consumption=dict(source_counts),
        )
        self.stats.packs_uploaded += 1
        self.stats.rows_uploaded += artifacts.row_count
        logger.info(
            "[lang=%s] wrote shard=%s rows=%s videos=%s microshards=%s",
            language,
            shard_id,
            artifacts.row_count,
            len(video_ids),
            len(source_counts),
        )

    def _load_audio_rows(self, selected: list[BufferedSegment]) -> list[dict]:
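        """Re-read FLAC bytes for the selected segments from their source
        audio tars, keeping one open tarfile handle per tar path."""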
        audio_rows: list[dict] = []
        tar_cache: dict[Path, tarfile.TarFile] = {}
        try:
            for item in selected:
                tf = tar_cache.get(item.audio_tar_path)
                if tf is None:
                    tf = tarfile.open(item.audio_tar_path, "r")
                    tar_cache[item.audio_tar_path] = tf
                member_name = str(item.audio_index_row["tar_member_name"])
                handle = tf.extractfile(member_name)
                if handle is None:
                    raise RuntimeError(f"Missing tar member {member_name} in {item.audio_tar_path}")
                flac_bytes = handle.read()
                audio_rows.append(
                    {
                        "video_id": item.video_id,
                        "segment_id": item.metadata_row["segment_id"],
                        "tar_member_name": member_name,
                        "flac_bytes": flac_bytes,
                        "flac_sha256": item.audio_index_row["flac_sha256"],
                        "audio_duration_s": item.audio_index_row["audio_duration_s"],
                    }
                )
        finally:
            for tf in tar_cache.values():
                tf.close()
        return audio_rows

    def _upload_pack(self, *, bucket: str, base_prefix: str, artifacts) -> dict[str, str]:
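        """Upload the shard artifacts under ``base_prefix`` and return their
        object keys. Outside mock mode, each upload is verified with a HEAD
        size check against the local file."""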
        metadata_key = f"{base_prefix}/metadata.parquet"
        audio_key = f"{base_prefix}/audio.tar"
        audio_index_key = f"{base_prefix}/audio_index.parquet"
        manifest_key = f"{base_prefix}/manifest.json"
        self.r2.upload_file(artifacts.metadata_path, bucket, metadata_key)
        self.r2.upload_file(artifacts.audio_tar_path, bucket, audio_key)
        self.r2.upload_file(artifacts.audio_index_path, bucket, audio_index_key)
        self.r2.upload_file(artifacts.manifest_path, bucket, manifest_key)

        if not self.config.mock_mode:
            expected = {
                metadata_key: artifacts.metadata_path.stat().st_size,
                audio_key: artifacts.audio_tar_path.stat().st_size,
                audio_index_key: artifacts.audio_index_path.stat().st_size,
                manifest_key: artifacts.manifest_path.stat().st_size,
            }
            for key, size in expected.items():
                remote_size = self.r2.head_size(bucket, key)
                if remote_size != size:
                    raise RuntimeError(f"HEAD size mismatch for s3://{bucket}/{key}: {remote_size} != {size}")

        return {
            "metadata_key": metadata_key,
            "audio_key": audio_key,
            "audio_index_key": audio_index_key,
            "manifest_key": manifest_key,
        }

    async def _cleanup(self):
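        """Stop the heartbeat task, remove the local work directory, mark the
        worker offline, and close the database connection."""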
        self._shutdown_event.set()
        if self._heartbeat_task:
            try:
                await self._heartbeat_task
            except Exception:
                pass
        shutil.rmtree(self._work_root, ignore_errors=True)
        try:
            await self.db.set_worker_offline(self.config.worker_id)
        except Exception:
            pass
        await self.db.close()
