from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

import duckdb
from boto3.s3.transfer import TransferConfig
from botocore.config import Config as BotoConfig

from .final_export_config import FinalExportConfig


logger = logging.getLogger(__name__)


class FinalExportReferenceStore:
    """DuckDB-backed lookup over the final-export reference datasets.

    When ``config.reference_mode == "r2"`` the reference files are fetched from
    ``config.reference_bucket`` into ``<cache_root>/final_export_reference``
    (a cached copy is reused when its size matches the remote object's
    ``ContentLength``); otherwise the local paths from ``config`` are read
    directly. ``start()`` opens ``final_export_reference.duckdb`` in the cache
    directory and defines the views that ``get_video_reference_rows`` joins.
    """

    def __init__(self, config: FinalExportConfig, cache_root: Path):
        """Resolve where every reference file lives.

        No downloads happen here; in R2 mode the cache directory is created
        and every local path is redirected into it.

        Args:
            config: Export configuration (reference mode, R2 keys, local paths).
            cache_root: Parent directory for the ``final_export_reference`` cache.
        """
        self.config = config
        self.cache_dir = cache_root / "final_export_reference"
        self.local_canonical_path = config.canonical_segments_path
        self.local_raw_transcripts_path = config.raw_transcripts_path
        self.local_validation_path = config.validation_path
        self.local_youtube_meta_path = config.youtube_meta_path
        self.local_variants_path = config.variants_path
        self.local_manifest_path: Optional[Path] = None

        if self.config.reference_mode == "r2":
            # Downloaded copies are named after their flattened R2 key; optional
            # sources without a configured key get no local path at all.
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.local_canonical_path = self.cache_dir / _cache_name(config.canonical_segments_r2_key)
            self.local_raw_transcripts_path = self.cache_dir / _cache_name(config.raw_transcripts_r2_key)
            self.local_validation_path = (
                self.cache_dir / _cache_name(config.validation_r2_key)
                if config.validation_r2_key
                else None
            )
            self.local_youtube_meta_path = (
                self.cache_dir / _cache_name(config.youtube_meta_r2_key)
                if config.youtube_meta_r2_key
                else None
            )
            self.local_variants_path = (
                self.cache_dir / _cache_name(config.variants_r2_key)
                if config.variants_r2_key
                else None
            )
            self.local_manifest_path = (
                self.cache_dir / _cache_name(config.reference_manifest_r2_key)
                if config.reference_manifest_r2_key
                else None
            )

        self.duckdb_path = self.cache_dir / "final_export_reference.duckdb"
        self._con: Optional[duckdb.DuckDBPyConnection] = None
        # NOTE(review): not used anywhere in this class; kept in case callers rely on it.
        self._lock = asyncio.Lock()

    async def start(self):
        """Download (R2 mode) and open the reference database in a worker thread."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._start_sync)

    async def close(self):
        """Close the DuckDB connection in a worker thread; no-op if not open."""
        if self._con is None:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._close_sync)

    def get_video_reference_rows(self, video_id: str) -> dict[str, dict]:
        """Return all canonical segments of *video_id* with their joined reference data.

        Each canonical segment row is left-joined with raw transcripts, text
        variants, validation signals (all keyed on ``video_id`` +
        ``segment_file``) and YouTube metadata (keyed on ``video_id``).

        Returns:
            Mapping of ``segment_file`` -> row dict (column name -> value).
            Rows whose ``segment_file`` is NULL or blank are omitted.

        Raises:
            RuntimeError: if ``start()`` has not successfully completed.
        """
        if self._con is None:
            raise RuntimeError("Final export reference store not initialized")
        rel = self._con.execute(
            """
            SELECT
                c.video_id,
                c.recommended_action,
                c.queue_language,
                c.corrected_language,
                c.segment_file,
                c.speaker_id,
                c.parent_segment_file,
                c.is_split_segment,
                c.split_index_from_id,
                c.original_start_ms,
                c.original_end_ms,
                c.trimmed_start_ms,
                c.trimmed_end_ms,
                c.leading_pad_ms,
                c.trailing_pad_ms,
                c.expected_language_hint,
                c.tx_detected_language,
                c.gemini_lang,
                c.segment_language,
                c.transcription,
                c.tagged,
                c.num_unk,
                c.num_inaudible,
                c.num_event_tags,
                c.text_length_per_sec,
                c.tx_quality_score,
                c.asr_eligible,
                c.tts_clean_eligible,
                c.tts_expressive_eligible,
                c.lang_mismatch_flag,
                c.duration_s,
                c.final_validation_source,
                c.final_has_validation,
                c.final_bucket,
                c.input_script_profile,
                coalesce(nullif(c.native_script_text, ''), v.native_script_text) AS native_script_text,
                coalesce(nullif(c.romanized_text, ''), v.romanized_text) AS romanized_text,
                coalesce(nullif(c.variant_route, ''), v.variant_route) AS variant_route,
                coalesce(nullif(c.variant_validation_errors, ''), v.variant_validation_errors) AS variant_validation_errors,
                y.youtube_audio_language,
                y.youtube_default_language,
                y.channel_id,
                y.channel_title,
                y.title,
                y.description,
                y.tags,
                rt.transcription AS raw_transcription,
                rt.tagged AS raw_tagged,
                rt.detected_language AS raw_detected_language,
                rt.quality_score AS raw_quality_score,
                rt.speaker_emotion,
                rt.speaker_style,
                rt.speaker_pace,
                rt.speaker_accent,
                vs.lid_consensus,
                vs.lid_agree_count,
                vs.consensus_lang,
                vs.conformer_multi_ctc_normalized,
                vs.mms_confidence
            FROM canonical_segments c
            LEFT JOIN raw_transcripts rt
                ON rt.video_id = c.video_id
               AND rt.segment_file = c.segment_file
            LEFT JOIN variants_source v
                ON v.video_id = c.video_id
               AND v.segment_file = c.segment_file
            LEFT JOIN validation_source vs
                ON vs.video_id = c.video_id
               AND vs.segment_file = c.segment_file
            LEFT JOIN youtube_meta y
                ON y.video_id = c.video_id
            WHERE c.video_id = ?
            ORDER BY c.segment_file
            """,
            [video_id],
        )
        cols = [d[0] for d in rel.description]
        return self._index_rows_by_segment(cols, rel.fetchall())

    @staticmethod
    def _index_rows_by_segment(cols: list[str], rows: list[tuple]) -> dict[str, dict]:
        """Turn result tuples into dicts keyed by ``str(segment_file)``.

        NULL and whitespace-only ``segment_file`` values are skipped. The NULL
        check must come before ``str()``, because ``str(None)`` is the truthy
        ``"None"``. When several rows share a ``segment_file``, the last one wins.
        """
        indexed: dict[str, dict] = {}
        for values in rows:
            row = dict(zip(cols, values))
            segment_file = row.get("segment_file")
            if segment_file is None or not str(segment_file).strip():
                continue
            indexed[str(segment_file)] = row
        return indexed

    def _start_sync(self):
        """Blocking body of ``start()``: fetch reference files if needed, then open DuckDB."""
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        if self.config.reference_mode == "r2":
            self._download_reference_files()
        # A repeated start() must not leak the connection from the previous one.
        self._close_sync()
        self._open_duckdb()

    def _close_sync(self):
        """Close and forget the DuckDB connection, if any."""
        if self._con is not None:
            self._con.close()
            self._con = None

    def _download_reference_files(self):
        """Fetch every configured reference object from ``config.reference_bucket``.

        The canonical-segments and raw-transcripts objects are required; any
        failure propagates. The manifest, validation, YouTube-metadata and
        variants objects are best-effort (see ``_try_download_optional``).
        """
        import boto3

        s3 = boto3.client(
            "s3",
            endpoint_url=self.config.base.r2_endpoint_url,
            aws_access_key_id=self.config.base.r2_access_key_id,
            aws_secret_access_key=self.config.base.r2_secret_access_key,
            region_name="auto",
            config=BotoConfig(max_pool_connections=_s3_pool_size(self.config.reference_download_concurrency)),
        )
        # Objects above 64 MiB are fetched as 64 MiB parts in parallel.
        transfer = TransferConfig(
            multipart_threshold=64 * 1024 * 1024,
            multipart_chunksize=64 * 1024 * 1024,
            max_concurrency=max(self.config.reference_download_concurrency, 1),
            use_threads=True,
        )

        if self.local_manifest_path and self.config.reference_manifest_r2_key:
            self._try_download_optional(
                s3=s3,
                bucket=self.config.reference_bucket,
                key=self.config.reference_manifest_r2_key,
                dest=self.local_manifest_path,
                transfer=None,
            )
        self._download_required(
            s3=s3,
            bucket=self.config.reference_bucket,
            key=self.config.canonical_segments_r2_key,
            dest=self.local_canonical_path,
            transfer=transfer,
        )
        self._download_required(
            s3=s3,
            bucket=self.config.reference_bucket,
            key=self.config.raw_transcripts_r2_key,
            dest=self.local_raw_transcripts_path,
            transfer=transfer,
        )
        if self.local_validation_path and self.config.validation_r2_key:
            self._try_download_optional(
                s3=s3,
                bucket=self.config.reference_bucket,
                key=self.config.validation_r2_key,
                dest=self.local_validation_path,
                transfer=transfer,
            )
        if self.local_youtube_meta_path and self.config.youtube_meta_r2_key:
            self._try_download_optional(
                s3=s3,
                bucket=self.config.reference_bucket,
                key=self.config.youtube_meta_r2_key,
                dest=self.local_youtube_meta_path,
                transfer=transfer,
            )
        if self.local_variants_path and self.config.variants_r2_key:
            self._try_download_optional(
                s3=s3,
                bucket=self.config.reference_bucket,
                key=self.config.variants_r2_key,
                dest=self.local_variants_path,
                transfer=transfer,
            )

    def _download_required(self, *, s3, bucket: str, key: str, dest: Path, transfer):
        """Download *key* to *dest*; any error propagates to the caller."""
        self._download_object(s3=s3, bucket=bucket, key=key, dest=dest, transfer=transfer, force=False)

    def _try_download_optional(self, *, s3, bucket: str, key: str, dest: Path, transfer):
        """Download *key* to *dest*, logging (not raising) on any failure.

        NOTE(review): if the download fails but an older copy already exists at
        *dest*, ``_open_duckdb`` will still read that cached copy.
        """
        try:
            self._download_object(s3=s3, bucket=bucket, key=key, dest=dest, transfer=transfer, force=False)
        except Exception as exc:
            logger.info("Optional final export reference unavailable %s: %s", key, exc)

    def _download_object(self, *, s3, bucket: str, key: str, dest: Path, transfer, force: bool):
        """Download ``s3://bucket/key`` to *dest* unless an equal-sized copy is cached.

        The cache check compares only the byte size from ``head_object``; a
        remote object that changed without changing size is not re-fetched.

        Returns:
            The remote object's ``ContentLength`` (0 when absent).
        """
        head = s3.head_object(Bucket=bucket, Key=key)
        size_bytes = int(head.get("ContentLength", 0) or 0)
        if not force and dest.exists() and dest.stat().st_size == size_bytes:
            logger.info("Reusing cached final export reference %s (%s)", dest.name, _format_bytes(size_bytes))
            return size_bytes
        logger.info(
            "Downloading final export reference s3://%s/%s -> %s (%s)",
            bucket,
            key,
            dest,
            _format_bytes(size_bytes),
        )
        t0 = time.time()
        extra = {"Config": transfer} if transfer is not None else {}
        s3.download_file(bucket, key, str(dest), **extra)
        logger.info("Downloaded %s in %.1fs", dest.name, time.time() - t0)
        return size_bytes

    def _open_duckdb(self):
        """Open the DuckDB file and define all reference views.

        The connection is stored on ``self._con`` only after every view was
        created successfully. On any failure (e.g. a reference file missing an
        expected column) it is closed, so the store is not left looking
        initialized with broken views.
        """
        con = duckdb.connect(str(self.duckdb_path))
        try:
            # Never ask for more threads than CPUs, and always at least one.
            threads = min(os.cpu_count() or 4, max(self.config.duckdb_threads, 1))
            con.execute(f"SET threads = {threads}")
            con.execute("SET preserve_insertion_order = false")
            self._create_views(con)
        except BaseException:
            con.close()
            raise
        self._con = con

    def _create_views(self, con):
        """Define the required views, then the optional ones, on *con*."""
        canonical_path = _sql_path(self.local_canonical_path)
        raw_tx_path = _sql_path(self.local_raw_transcripts_path)
        # segment_language = first non-empty of transcription-detected, Gemini,
        # corrected and queue language.
        con.execute(
            f"""
            CREATE OR REPLACE VIEW canonical_segments AS
            SELECT
                *,
                coalesce(
                    nullif(tx_detected_language, ''),
                    nullif(gemini_lang, ''),
                    nullif(corrected_language, ''),
                    nullif(queue_language, '')
                ) AS segment_language
            FROM read_parquet('{canonical_path}')
            """
        )
        con.execute(
            f"""
            CREATE OR REPLACE VIEW raw_transcripts AS
            SELECT *
            FROM read_parquet('{raw_tx_path}')
            """
        )
        self._create_variants_view(con)
        self._create_validation_view(con)
        self._create_youtube_meta_view(con)

    def _create_variants_view(self, con):
        """Define ``variants_source`` from the variants parquet, or as an empty typed view.

        When ``video_id``/``segment_id`` are NULL they are parsed out of a
        ``row_id`` of the form ``<video_id>/<segment_file>``.
        """
        if self.local_variants_path and self.local_variants_path.exists():
            variants_path = _sql_path(self.local_variants_path)
            con.execute(
                f"""
                CREATE OR REPLACE VIEW variants_source AS
                SELECT
                    coalesce(video_id, regexp_extract(row_id, '^([^/]+)/', 1)) AS video_id,
                    coalesce(segment_id, regexp_extract(row_id, '^[^/]+/(.*)$', 1)) AS segment_file,
                    native_script_text,
                    romanized_text,
                    coalesce(variant_route, processing_route) AS variant_route,
                    coalesce(variant_validation_errors, validation_errors) AS variant_validation_errors
                FROM read_parquet('{variants_path}')
                """
            )
        else:
            con.execute(
                """
                CREATE OR REPLACE VIEW variants_source AS
                SELECT
                    ''::VARCHAR AS video_id,
                    ''::VARCHAR AS segment_file,
                    ''::VARCHAR AS native_script_text,
                    ''::VARCHAR AS romanized_text,
                    ''::VARCHAR AS variant_route,
                    ''::VARCHAR AS variant_validation_errors
                WHERE FALSE
                """
            )

    def _create_validation_view(self, con):
        """Define ``validation_source`` from the validation parquet, or as an empty typed view."""
        if self.local_validation_path and self.local_validation_path.exists():
            validation_path = _sql_path(self.local_validation_path)
            con.execute(
                f"""
                CREATE OR REPLACE VIEW validation_source AS
                SELECT
                    video_id,
                    segment_file,
                    lid_consensus,
                    lid_agree_count,
                    consensus_lang,
                    conformer_multi_ctc_normalized,
                    mms_confidence
                FROM read_parquet('{validation_path}')
                """
            )
        else:
            con.execute(
                """
                CREATE OR REPLACE VIEW validation_source AS
                SELECT
                    ''::VARCHAR AS video_id,
                    ''::VARCHAR AS segment_file,
                    false AS lid_consensus,
                    0::INTEGER AS lid_agree_count,
                    ''::VARCHAR AS consensus_lang,
                    NULL::DOUBLE AS conformer_multi_ctc_normalized,
                    NULL::DOUBLE AS mms_confidence
                WHERE FALSE
                """
            )

    def _create_youtube_meta_view(self, con):
        """Define ``youtube_meta`` from the metadata CSV, or as an empty typed view.

        Language columns are lower-cased and reduced to their leading letters
        (e.g. ``"en-US"`` -> ``"en"``).
        """
        if self.local_youtube_meta_path and self.local_youtube_meta_path.exists():
            youtube_path = _sql_path(self.local_youtube_meta_path)
            con.execute(
                f"""
                CREATE OR REPLACE VIEW youtube_meta AS
                SELECT
                    video_id,
                    regexp_extract(lower(coalesce(default_audio_language, '')), '^([a-z]+)', 1) AS youtube_audio_language,
                    regexp_extract(lower(coalesce(default_language, '')), '^([a-z]+)', 1) AS youtube_default_language,
                    channel_id,
                    channel_title,
                    title,
                    description,
                    tags
                FROM read_csv_auto('{youtube_path}', header=true)
                """
            )
        else:
            con.execute(
                """
                CREATE OR REPLACE VIEW youtube_meta AS
                SELECT
                    ''::VARCHAR AS video_id,
                    ''::VARCHAR AS youtube_audio_language,
                    ''::VARCHAR AS youtube_default_language,
                    ''::VARCHAR AS channel_id,
                    ''::VARCHAR AS channel_title,
                    ''::VARCHAR AS title,
                    ''::VARCHAR AS description,
                    ''::VARCHAR AS tags
                WHERE FALSE
                """
            )


def _format_bytes(size_bytes: int) -> str:
    """Render a byte count as a short human-readable string such as ``"1.5MB"``.

    Non-positive counts render as ``"0B"``. Otherwise the value is divided by
    1024 until it drops below 1024 or the largest unit (TB) is reached, and is
    shown with one decimal place.
    """
    if size_bytes <= 0:
        return "0B"
    units = ("B", "KB", "MB", "GB", "TB")
    scaled = float(size_bytes)
    unit_index = 0
    while scaled >= 1024.0 and unit_index < len(units) - 1:
        scaled /= 1024.0
        unit_index += 1
    return f"{scaled:.1f}{units[unit_index]}"


def _sql_path(path: Path) -> str:
    """Return *path* in POSIX form, escaped for use inside a single-quoted SQL literal."""
    posix_text = path.as_posix()
    # SQL escapes a single quote by doubling it.
    return "''".join(posix_text.split("'"))


def _cache_name(key: str) -> str:
    """Flatten an object key into a single file name for the local cache.

    Leading and trailing slashes are dropped, and every remaining ``/`` becomes ``__``.
    """
    segments = key.strip("/").split("/")
    return "__".join(segments)


def _s3_pool_size(download_concurrency: int) -> int:
    """Size the botocore connection pool for a given download concurrency.

    Allows two connections per concurrent transfer (concurrency is clamped to
    at least 1), and never fewer than 32 connections.
    """
    effective = download_concurrency if download_concurrency > 1 else 1
    return max(2 * effective, 32)
