from __future__ import annotations

import asyncio
import logging
import shutil
import tarfile
import time
import uuid
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import pyarrow.parquet as pq

from .final_export_common import build_pack_artifacts
from .final_export_config import FinalExportConfig
from .final_export_db import (
    FinalExportMicroshardJob,
    FinalExportPostgresDB,
    FinalExportWorkerStats,
)
from .final_export_r2 import FinalExportR2Client


logger = logging.getLogger(__name__)


@dataclass
class BufferedSegment:
    """One microshard row held in memory until it is written into a final shard."""

    # ID of the microshard this row was loaded from (copied from the claimed job).
    microshard_id: str
    # Video ID taken from the metadata row, falling back to the job's video ID.
    video_id: str
    # The row as read from the microshard's metadata.parquet.
    metadata_row: dict
    # The matching row from audio_index.parquet; the compactor reads its
    # "tar_member_name", "flac_sha256" and "audio_duration_s" entries.
    audio_index_row: dict
    # Local path of the downloaded microshard audio.tar holding this row's audio.
    audio_tar_path: Path


class FinalExportCompactor:
    def __init__(self, config: FinalExportConfig):
        self.config = config
        self.db = FinalExportPostgresDB(config.database_url)
        self.r2 = FinalExportR2Client(config)
        self.stats = FinalExportWorkerStats()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._work_root = config.local_work_root / "compact_stage" / config.run_id / config.worker_id
        self._current_language: Optional[str] = None
        self._download_pool = ThreadPoolExecutor(max_workers=12)

    async def start(self):
        try:
            self._work_root.mkdir(parents=True, exist_ok=True)
            await self.db.connect()
            try:
                await self.db.init_schema()
            except asyncio.TimeoutError:
                logger.warning(
                    "init_schema timed out for compactor startup; assuming schema already exists"
                )
            await self._register()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as exc:
            logger.error("final export compactor fatal error: %s", exc, exc_info=True)
            try:
                await self.db.set_worker_error(self.config.worker_id, str(exc))
            except Exception:
                pass
            raise
        finally:
            await self._cleanup()

    async def _register(self):
        config_json = {
            "stage": "language_compactor",
            "run_id": self.config.run_id,
            "output_bucket": self.config.output_bucket,
            "output_prefix": self.config.output_prefix,
                "reference_mode": self.config.reference_mode,
            "final_shard_target_rows": self.config.final_shard_target_rows,
            "allow_partial_shards": self.config.allow_partial_shards,
            "language_filters": self.config.language_filters,
        }
        await self.db.register_worker(
            worker_id=self.config.worker_id,
            stage="language_compactor",
            gpu_type=self.config.gpu_type,
            config_json=config_json,
        )

    async def _heartbeat_loop(self):
        while not self._shutdown_event.is_set():
            try:
                await self.db.update_heartbeat(self.config.worker_id, self.stats)
            except Exception as exc:
                logger.warning("final export compactor heartbeat failed: %s", str(exc)[:160])
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=30)
                break
            except asyncio.TimeoutError:
                pass

    async def _main_loop(self):
        shards_written = 0
        while not self._shutdown_event.is_set():
            if self.config.max_shards > 0 and shards_written >= self.config.max_shards:
                logger.info("reached FINAL_EXPORT_MAX_SHARDS=%s", self.config.max_shards)
                break

            jobs = await self.db.claim_microshards_batch(
                worker_id=self.config.worker_id,
                run_id=self.config.run_id,
                limit=self.config.compactor_claim_limit,
                languages=self.config.language_filters or None,
            )
            if not jobs:
                logger.info("no pending final export microshards to compact")
                break

            language = jobs[0].language
            self._current_language = language
            self.stats.current_item = f"lang={language} jobs={len(jobs)}"
            try:
                wrote = await self._process_claimed_jobs(language, jobs)
                shards_written += wrote
            finally:
                self._current_language = None
                self.stats.current_item = None

    async def _process_claimed_jobs(
        self, language: str, initial_jobs: list[FinalExportMicroshardJob]
    ) -> int:
        logger.info("[lang=%s] compaction start jobs=%s", language, len(initial_jobs))
        claimed_jobs: dict[str, FinalExportMicroshardJob] = {}
        buffer: list[BufferedSegment] = []
        written = 0
        exhausted = False
        pending_jobs = list(initial_jobs)

        while not self._shutdown_event.is_set():
            if pending_jobs:
                self.stats.jobs_claimed += len(pending_jobs)
                for job in pending_jobs:
                    claimed_jobs[job.microshard_id] = job
                buffer.extend(await self._load_microshard_rows_parallel(pending_jobs))
                pending_jobs = []
            self.stats.rows_buffered = len(buffer)

            if len(buffer) >= self.config.final_shard_target_rows:
                await self._flush_buffer(language, buffer, self.config.final_shard_target_rows)
                written += 1
                self.stats.jobs_completed += 1
                self.stats.rows_buffered = len(buffer)
                continue

            if not exhausted:
                pending_jobs = await self.db.claim_microshards_for_language(
                    worker_id=self.config.worker_id,
                    run_id=self.config.run_id,
                    language=language,
                    limit=self.config.compactor_claim_limit,
                )
                if pending_jobs:
                    continue
                exhausted = True

            if exhausted:
                if buffer and self.config.allow_partial_shards:
                    await self._flush_buffer(language, buffer, len(buffer))
                    written += 1
                    self.stats.jobs_completed += 1
                    self.stats.rows_buffered = len(buffer)
                break

        if claimed_jobs:
            await self.db.release_microshards(list(claimed_jobs.keys()), self.config.worker_id)
        return written

    async def _load_microshard_rows_parallel(
        self, jobs: list[FinalExportMicroshardJob]
    ) -> list[BufferedSegment]:
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(self._download_pool, self._load_microshard_rows, job)
            for job in jobs
        ]
        buffered: list[BufferedSegment] = []
        for rows in await asyncio.gather(*tasks):
            buffered.extend(rows)
        return buffered

    def _load_microshard_rows(self, job: FinalExportMicroshardJob) -> list[BufferedSegment]:
        local_dir = self._work_root / "microshards" / job.microshard_id
        local_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = local_dir / "metadata.parquet"
        audio_tar_path = local_dir / "audio.tar"
        audio_index_path = local_dir / "audio_index.parquet"
        if not metadata_path.exists():
            self.r2.download_file(job.output_bucket, job.metadata_key, metadata_path)
        if not audio_tar_path.exists():
            self.r2.download_file(job.output_bucket, job.audio_key, audio_tar_path)
        if not audio_index_path.exists():
            self.r2.download_file(job.output_bucket, job.audio_index_key, audio_index_path)

        metadata_rows = pq.read_table(metadata_path).to_pylist()
        audio_index_rows = pq.read_table(audio_index_path).to_pylist()
        if len(metadata_rows) != len(audio_index_rows):
            raise RuntimeError(
                f"Microshard row mismatch for {job.microshard_id}: {len(metadata_rows)} != {len(audio_index_rows)}"
            )

        buffered: list[BufferedSegment] = []
        start = min(job.consumed_rows, len(metadata_rows))
        for idx in range(start, len(metadata_rows)):
            meta = metadata_rows[idx]
            audio = audio_index_rows[idx]
            if str(meta.get("segment_id")) != str(audio.get("segment_id")):
                raise RuntimeError(f"Microshard order mismatch for {job.microshard_id} at row {idx}")
            buffered.append(
                BufferedSegment(
                    microshard_id=job.microshard_id,
                    video_id=str(meta.get("video_id") or job.video_id),
                    metadata_row=meta,
                    audio_index_row=audio,
                    audio_tar_path=audio_tar_path,
                )
            )
        return buffered

    async def _flush_buffer(self, language: str, buffer: list[BufferedSegment], take_rows: int):
        selected = list(buffer[:take_rows])
        del buffer[:take_rows]
        worker_suffix = self.config.worker_id.rsplit("-", 1)[-1]
        shard_id = (
            f"{language}_shard_{time.time_ns()}_{worker_suffix}_"
            f"{self.stats.packs_uploaded + 1:06d}_{uuid.uuid4().hex[:8]}"
        )
        pack_dir = self._work_root / "final_shards" / language / shard_id
        metadata_rows = [item.metadata_row for item in selected]
        audio_rows = self._load_audio_rows(selected)

        source_counts = Counter(item.microshard_id for item in selected)
        video_ids = sorted({item.video_id for item in selected})
        artifacts = build_pack_artifacts(
            pack_dir=pack_dir,
            manifest_name="manifest.json",
            metadata_rows=metadata_rows,
            audio_rows=audio_rows,
            manifest_payload={
                "shard_id": shard_id,
                "run_id": self.config.run_id,
                "language": language,
                "segment_count": len(selected),
                "video_count": len(video_ids),
                "source_microshard_count": len(source_counts),
                "source_video_ids_sample": video_ids[:32],
                "worker_id": self.config.worker_id,
                "created_at": datetime.now(timezone.utc).isoformat(),
            },
        )
        keys = self._upload_pack(
            bucket=self.config.output_bucket,
            base_prefix=f"{self.config.shard_prefix}/lang={language}/{shard_id}",
            artifacts=artifacts,
        )
        await self.db.insert_final_shard(
            {
                "shard_id": shard_id,
                "run_id": self.config.run_id,
                "language": language,
                "output_bucket": self.config.output_bucket,
                "metadata_key": keys["metadata_key"],
                "audio_key": keys["audio_key"],
                "audio_index_key": keys["audio_index_key"],
                "manifest_key": keys["manifest_key"],
                "segment_count": artifacts.row_count,
                "video_count": len(video_ids),
                "source_microshard_count": len(source_counts),
                "metadata_size_bytes": artifacts.metadata_path.stat().st_size,
                "audio_size_bytes": artifacts.audio_tar_path.stat().st_size,
                "audio_index_size_bytes": artifacts.audio_index_path.stat().st_size,
                "manifest_size_bytes": artifacts.manifest_path.stat().st_size,
                "metadata_sha256": artifacts.metadata_sha256,
                "audio_sha256": artifacts.audio_sha256,
                "audio_index_sha256": artifacts.audio_index_sha256,
                "segment_id_set_sha256": artifacts.segment_id_set_sha256,
                "metadata_json": {
                    "source_microshard_slices": dict(source_counts),
                    "manifest_sha256": artifacts.manifest_sha256,
                },
            }
        )
        await self.db.commit_microshard_consumption(
            worker_id=self.config.worker_id,
            consumption=dict(source_counts),
        )
        self.stats.packs_uploaded += 1
        self.stats.rows_uploaded += artifacts.row_count
        logger.info(
            "[lang=%s] wrote shard=%s rows=%s videos=%s microshards=%s",
            language,
            shard_id,
            artifacts.row_count,
            len(video_ids),
            len(source_counts),
        )

    def _load_audio_rows(self, selected: list[BufferedSegment]) -> list[dict]:
        audio_rows: list[dict] = []
        tar_cache: dict[Path, tarfile.TarFile] = {}
        try:
            for item in selected:
                tf = tar_cache.get(item.audio_tar_path)
                if tf is None:
                    tf = tarfile.open(item.audio_tar_path, "r")
                    tar_cache[item.audio_tar_path] = tf
                member_name = str(item.audio_index_row["tar_member_name"])
                handle = tf.extractfile(member_name)
                if handle is None:
                    raise RuntimeError(f"Missing tar member {member_name} in {item.audio_tar_path}")
                flac_bytes = handle.read()
                audio_rows.append(
                    {
                        "video_id": item.video_id,
                        "segment_id": item.metadata_row["segment_id"],
                        "tar_member_name": member_name,
                        "flac_bytes": flac_bytes,
                        "flac_sha256": item.audio_index_row["flac_sha256"],
                        "audio_duration_s": item.audio_index_row["audio_duration_s"],
                    }
                )
        finally:
            for tf in tar_cache.values():
                tf.close()
        return audio_rows

    def _upload_pack(self, *, bucket: str, base_prefix: str, artifacts) -> dict[str, str]:
        metadata_key = f"{base_prefix}/metadata.parquet"
        audio_key = f"{base_prefix}/audio.tar"
        audio_index_key = f"{base_prefix}/audio_index.parquet"
        manifest_key = f"{base_prefix}/manifest.json"
        self.r2.upload_file(artifacts.metadata_path, bucket, metadata_key)
        self.r2.upload_file(artifacts.audio_tar_path, bucket, audio_key)
        self.r2.upload_file(artifacts.audio_index_path, bucket, audio_index_key)
        self.r2.upload_file(artifacts.manifest_path, bucket, manifest_key)

        if not self.config.mock_mode:
            expected = {
                metadata_key: artifacts.metadata_path.stat().st_size,
                audio_key: artifacts.audio_tar_path.stat().st_size,
                audio_index_key: artifacts.audio_index_path.stat().st_size,
                manifest_key: artifacts.manifest_path.stat().st_size,
            }
            for key, size in expected.items():
                remote_size = self.r2.head_size(bucket, key)
                if remote_size != size:
                    raise RuntimeError(f"HEAD size mismatch for s3://{bucket}/{key}: {remote_size} != {size}")

        return {
            "metadata_key": metadata_key,
            "audio_key": audio_key,
            "audio_index_key": audio_index_key,
            "manifest_key": manifest_key,
        }

    async def _cleanup(self):
        self._shutdown_event.set()
        if self._heartbeat_task:
            try:
                await self._heartbeat_task
            except Exception:
                pass
        self._download_pool.shutdown(wait=False, cancel_futures=True)
        shutil.rmtree(self._work_root, ignore_errors=True)
        try:
            await self.db.set_worker_offline(self.config.worker_id)
        except Exception:
            pass
        await self.db.close()
