from __future__ import annotations

import hashlib
import io
import json
import logging
import math
import re
import tarfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import duckdb
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

from .final_export_config import FinalExportConfig


logger = logging.getLogger(__name__)

# Matches any bracketed transcript annotation, e.g. "[music]", "[UNK]",
# "[INAUDIBLE]" — one or more non-"]" characters between square brackets.
_TAG_RE = re.compile(r"\[[^\]]+\]")


@dataclass
class PackArtifacts:
    """Filesystem locations and integrity digests for one written export pack.

    Produced by ``build_pack_artifacts``; every digest is a hex SHA-256 string.
    """

    metadata_path: Path  # metadata.parquet (one row per exported segment)
    audio_tar_path: Path  # audio.tar holding per-segment FLAC payloads
    audio_index_path: Path  # audio_index.parquet (tar member index)
    manifest_path: Path  # JSON manifest embedding counts and digests
    row_count: int  # rows written to metadata.parquet
    member_count: int  # members added to audio.tar
    sum_flac_bytes: int  # total FLAC payload bytes across all tar members
    metadata_sha256: str  # digest of metadata.parquet
    audio_sha256: str  # digest of audio.tar
    audio_index_sha256: str  # digest of audio_index.parquet
    manifest_sha256: str  # digest of the manifest file (computed after writing)
    segment_id_set_sha256: str  # order-insensitive digest of all segment ids


def stable_json_dumps(payload: dict[str, Any]) -> str:
    """Serialize *payload* deterministically: sorted keys, 2-space indent,
    non-ASCII kept verbatim, and a trailing newline for clean text files."""
    body = json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True)
    return body + "\n"


def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as deterministic, human-readable JSON.

    The encoding is pinned to UTF-8: ``stable_json_dumps`` emits non-ASCII
    characters verbatim (``ensure_ascii=False``), so relying on the platform's
    locale-default encoding could raise or corrupt the file on non-UTF-8
    systems.
    """
    path.write_text(stable_json_dumps(payload), encoding="utf-8")


def sha256_bytes(data: bytes) -> str:
    """Return the hex SHA-256 digest of an in-memory byte string."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()


def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of a file, streamed in 1 MiB chunks.

    Streaming keeps memory flat regardless of file size (audio tars can be
    large).
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while chunk := stream.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()


def sha256_string_set(values: list[str]) -> str:
    """Order-insensitive hex SHA-256 digest of a collection of strings.

    The values are sorted and newline-joined before hashing, so any
    permutation of the same strings yields the same digest.
    """
    joined = "\n".join(sorted(values))
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()


def segment_audio_member_name(segment_id: str) -> str:
    """Tar member name for a segment: the id itself, with ".flac" appended
    only when the id does not already carry the suffix."""
    if segment_id.endswith(".flac"):
        return segment_id
    return f"{segment_id}.flac"


def replay_segment_id(original_file: str, was_split: bool, split_index: int) -> str:
    """Reconstruct a derived segment id: "_split<i>" is appended only when the
    segment is a split part; whole segments keep the original file name."""
    return f"{original_file}_split{split_index}" if was_split else original_file


def clean_transcript_text(text: str) -> str:
    """Strip bracketed annotations (e.g. "[music]", "[UNK]") from a transcript.

    Each "[...]" span is replaced with a single space, so interior runs of
    spaces may widen; only leading/trailing whitespace is trimmed. ``None``
    is tolerated and treated as the empty string.

    Note: the previous chained ``.replace("[UNK]", " ")`` and
    ``.replace("[INAUDIBLE]", " ")`` calls were dead code — ``_TAG_RE``
    already matches every bracketed tag, so those literals could never
    survive the substitution. Removing them does not change behavior.
    """
    return _TAG_RE.sub(" ", text or "").strip()


def compute_audio_metrics(audio: np.ndarray, sr: int, transcript: str) -> dict[str, Any]:
    """Per-segment audio/text statistics.

    Returns duration, RMS/peak level in dBFS (``None`` for pure silence,
    where log10 is undefined), zero-crossing rate, the fraction of 20 ms
    frames judged silent, and characters/words per second of the cleaned
    transcript. Floats are rounded to 6 decimal places.
    """
    if len(audio) == 0:
        # Degenerate segment: zero duration, fully silent, no text rate.
        return {
            "duration_s": 0.0,
            "sample_rate_hz": sr,
            "rms_dbfs": None,
            "peak_dbfs": None,
            "zero_crossing_rate": 0.0,
            "silence_fraction": 1.0,
            "chars_per_sec": 0.0,
            "words_per_sec": 0.0,
        }

    samples = np.asarray(audio, dtype=np.float32)
    sample_count = len(samples)
    duration_s = float(sample_count) / float(sr)

    rms = float(np.sqrt(np.mean(np.square(samples))))
    peak = float(np.max(np.abs(samples)))

    def to_dbfs(amplitude: float) -> float | None:
        # Non-positive amplitude means digital silence; dBFS is undefined there.
        return None if amplitude <= 0 else float(20.0 * math.log10(amplitude))

    rms_dbfs = to_dbfs(rms)
    peak_dbfs = to_dbfs(peak)

    zcr = 0.0
    if sample_count > 1:
        # Mean fraction of adjacent sample pairs whose sign bit flips.
        zcr = float(np.mean(np.abs(np.diff(np.signbit(samples).astype(np.int8)))))

    # Frame-level RMS over 20 ms windows estimates how much of the clip is silent.
    frame_len = max(int(sr * 0.02), 1)
    if sample_count < frame_len:
        # Clip shorter than one frame: fall back to the whole-clip RMS.
        frame_rms = np.array([rms], dtype=np.float32)
    else:
        frame_count = sample_count // frame_len
        frames = samples[: frame_count * frame_len].reshape(frame_count, frame_len)
        frame_rms = np.sqrt(np.mean(np.square(frames), axis=1))
    # Threshold is relative to overall loudness, floored at -80 dBFS-ish.
    silence_threshold = max(rms * 0.1, 1e-4)
    silence_fraction = float(np.mean(frame_rms <= silence_threshold))

    cleaned = clean_transcript_text(transcript)
    chars_per_sec = float(len(cleaned) / duration_s) if duration_s > 0 else 0.0
    words_per_sec = float(len(cleaned.split()) / duration_s) if duration_s > 0 else 0.0

    return {
        "duration_s": round(duration_s, 6),
        "sample_rate_hz": int(sr),
        "rms_dbfs": None if rms_dbfs is None else round(rms_dbfs, 6),
        "peak_dbfs": None if peak_dbfs is None else round(peak_dbfs, 6),
        "zero_crossing_rate": round(zcr, 6),
        "silence_fraction": round(silence_fraction, 6),
        "chars_per_sec": round(chars_per_sec, 6),
        "words_per_sec": round(words_per_sec, 6),
    }


def build_meta_information(
    *,
    canonical_row: dict[str, Any],
    raw_tx_row: dict[str, Any] | None,
    variant_row: dict[str, Any] | None,
    validation_row: dict[str, Any] | None,
    video_metadata: dict[str, Any] | None,
    export_provenance: dict[str, Any],
) -> str:
    """Assemble the per-segment provenance record as compact, key-sorted JSON.

    ``None`` inputs are treated as empty dicts; missing canonical keys default
    to "" for the string-valued fields and to ``None`` for the numeric replay
    offsets.
    """
    raw_tx = raw_tx_row or {}

    replay = {
        "segment_file": canonical_row.get("segment_file", ""),
        "parent_segment_file": canonical_row.get("parent_segment_file", ""),
        "is_split_segment": bool(canonical_row.get("is_split_segment", False)),
    }
    # Timing/split fields keep None as the "unknown" marker.
    for key in (
        "split_index_from_id",
        "original_start_ms",
        "original_end_ms",
        "trimmed_start_ms",
        "trimmed_end_ms",
        "leading_pad_ms",
        "trailing_pad_ms",
    ):
        replay[key] = canonical_row.get(key)

    transcript = {
        "canonical_transcription": canonical_row.get("transcription", ""),
        "canonical_tagged": canonical_row.get("tagged", ""),
        "raw_transcription": raw_tx.get("transcription", ""),
        "raw_tagged": raw_tx.get("tagged", ""),
    }

    language = {
        key: canonical_row.get(key, "")
        for key in (
            "segment_language",
            "tx_detected_language",
            "gemini_lang",
            "corrected_language",
            "queue_language",
            "youtube_audio_language",
            "youtube_default_language",
        )
    }

    record = {
        "replay_provenance": replay,
        "source_row_provenance": raw_tx,
        "transcript_provenance": transcript,
        "variant_provenance": variant_row or {},
        "validation_provenance": validation_row or {},
        "language_evidence": language,
        "video_metadata": video_metadata or {},
        "export_provenance": export_provenance,
    }
    return json.dumps(record, ensure_ascii=False, sort_keys=True)


def build_export_segment_payload(
    *,
    video_id: str,
    canonical_row: dict[str, Any],
    polished_segment: Any,
    run_id: str,
    worker_id: str,
    exported_at: str,
) -> dict[str, dict[str, Any]]:
    """Build the metadata row and audio row for one exported segment.

    Reads ``polished_segment.audio`` / ``.sr`` / ``.flac_bytes`` and the
    ``.trim_meta`` record (``original_file``, ``was_split``, ``split_index``)
    — presumably produced by the audio-polishing stage; confirm against the
    caller. Returns ``{"metadata_row": ..., "audio_row": ...}`` where the
    metadata row goes to metadata.parquet and the audio row feeds the tar
    writer in ``build_pack_artifacts``.
    """
    # Script variants fall back to the mixed canonical transcription when the
    # native/romanized forms are missing or empty.
    transcription_native = str(canonical_row.get("native_script_text") or canonical_row.get("transcription") or "")
    transcription_romanized = str(
        canonical_row.get("romanized_text") or canonical_row.get("transcription") or ""
    )
    transcription_mixed = str(canonical_row.get("transcription") or "")
    transcription_tagged = str(canonical_row.get("tagged") or "")
    metrics = compute_audio_metrics(polished_segment.audio, polished_segment.sr, transcription_mixed)
    flac_sha256 = sha256_bytes(polished_segment.flac_bytes)
    segment_id = str(canonical_row.get("segment_file") or "")
    member_name = segment_audio_member_name(segment_id)
    # True when the tagged transcription still contains any "[...]" annotation.
    has_audio_tag = bool(_TAG_RE.search(transcription_tagged))

    # Provenance sub-records are projected from the flattened canonical row
    # and embedded (JSON-encoded) in the metadata row's meta_information.
    raw_tx_row = {
        "transcription": canonical_row.get("raw_transcription"),
        "tagged": canonical_row.get("raw_tagged"),
        "detected_language": canonical_row.get("raw_detected_language"),
        "quality_score": canonical_row.get("raw_quality_score"),
        "speaker_emotion": canonical_row.get("speaker_emotion"),
        "speaker_style": canonical_row.get("speaker_style"),
        "speaker_pace": canonical_row.get("speaker_pace"),
        "speaker_accent": canonical_row.get("speaker_accent"),
    }
    variant_row = {
        "input_script_profile": canonical_row.get("input_script_profile"),
        "native_script_text": canonical_row.get("native_script_text"),
        "romanized_text": canonical_row.get("romanized_text"),
        "processing_route": canonical_row.get("variant_route"),
        "validation_errors": canonical_row.get("variant_validation_errors"),
    }
    validation_row = {
        "final_validation_source": canonical_row.get("final_validation_source"),
        "final_has_validation": canonical_row.get("final_has_validation"),
        "final_bucket": canonical_row.get("final_bucket"),
        "lid_consensus": canonical_row.get("lid_consensus"),
        "lid_agree_count": canonical_row.get("lid_agree_count"),
        "consensus_lang": canonical_row.get("consensus_lang"),
        "conformer_multi_ctc_normalized": canonical_row.get("conformer_multi_ctc_normalized"),
        "mms_confidence": canonical_row.get("mms_confidence"),
    }
    video_metadata = {
        "youtube_audio_language": canonical_row.get("youtube_audio_language"),
        "youtube_default_language": canonical_row.get("youtube_default_language"),
        "channel_id": canonical_row.get("channel_id"),
        "channel_title": canonical_row.get("channel_title"),
        "title": canonical_row.get("title"),
        "description": canonical_row.get("description"),
        "tags": canonical_row.get("tags"),
    }
    meta_information = build_meta_information(
        canonical_row=canonical_row,
        raw_tx_row=raw_tx_row,
        variant_row=variant_row,
        validation_row=validation_row,
        video_metadata=video_metadata,
        export_provenance={
            "run_id": run_id,
            "worker_id": worker_id,
            "exported_at": exported_at,
            "audio_sha256_type": "flac_sha256",
        },
    )

    metadata_row = {
        "video_id": video_id,
        "segment_id": segment_id,
        # Canonical row wins; trim_meta supplies the fallback for replayed data.
        "parent_segment_file": str(
            canonical_row.get("parent_segment_file") or polished_segment.trim_meta.original_file
        ),
        "is_split_part": bool(
            canonical_row.get("is_split_segment", polished_segment.trim_meta.was_split)
        ),
        # NOTE(review): a stored split index of 0 is falsy and falls through to
        # trim_meta.split_index — confirm 0 is never a legitimate stored index.
        "split_index": int(
            canonical_row.get("split_index_from_id") or polished_segment.trim_meta.split_index or 0
        ),
        "segment_language": str(canonical_row.get("segment_language") or ""),
        "video_language": str(canonical_row.get("corrected_language") or canonical_row.get("queue_language") or ""),
        "youtube_audio_language": str(canonical_row.get("youtube_audio_language") or ""),
        "youtube_default_language": str(canonical_row.get("youtube_default_language") or ""),
        "transcription_native": transcription_native,
        "transcription_romanized": transcription_romanized,
        "transcription_mixed": transcription_mixed,
        "transcription_tagged": transcription_tagged,
        "has_audio_tag": has_audio_tag,
        "duration_s": metrics["duration_s"],
        "sample_rate_hz": metrics["sample_rate_hz"],
        "rms_dbfs": metrics["rms_dbfs"],
        "peak_dbfs": metrics["peak_dbfs"],
        "zero_crossing_rate": metrics["zero_crossing_rate"],
        "silence_fraction": metrics["silence_fraction"],
        "chars_per_sec": metrics["chars_per_sec"],
        "words_per_sec": metrics["words_per_sec"],
        "tx_quality_score": canonical_row.get("tx_quality_score"),
        "final_bucket": str(canonical_row.get("final_bucket") or ""),
        "audio_sha256": flac_sha256,
        "flac_size_bytes": len(polished_segment.flac_bytes),
        "meta_information": meta_information,
    }
    # Audio row shares the segment/flac digest so the tar index can be
    # cross-checked against metadata.parquet after writing.
    audio_row = {
        "video_id": video_id,
        "segment_id": segment_id,
        "tar_member_name": member_name,
        "flac_bytes": polished_segment.flac_bytes,
        "flac_sha256": flac_sha256,
        "audio_duration_s": metrics["duration_s"],
    }
    return {"metadata_row": metadata_row, "audio_row": audio_row}


def build_pack_artifacts(
    *,
    pack_dir: Path,
    manifest_name: str,
    metadata_rows: list[dict[str, Any]],
    audio_rows: list[dict[str, Any]],
    manifest_payload: dict[str, Any],
) -> PackArtifacts:
    """Materialize one export pack on disk and return its integrity summary.

    Writes four artifacts under *pack_dir* (created if missing):
      * ``metadata.parquet``    — one row per segment, zstd-compressed.
      * ``audio.tar``           — uncompressed tar of per-segment FLAC payloads.
      * ``audio_index.parquet`` — tar member index with sizes and digests.
      * *manifest_name*         — JSON manifest merging *manifest_payload* with
        row counts, byte sizes, and SHA-256 digests of the other artifacts.

    The manifest digest is computed after the manifest is written, so it
    covers the final bytes on disk. Each *audio_rows* item must provide
    ``tar_member_name``, ``flac_bytes``, ``segment_id``, ``video_id``,
    ``flac_sha256``, and ``audio_duration_s`` (the shape emitted by
    ``build_export_segment_payload``).
    """
    pack_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = pack_dir / "metadata.parquet"
    audio_tar_path = pack_dir / "audio.tar"
    audio_index_path = pack_dir / "audio_index.parquet"
    manifest_path = pack_dir / manifest_name

    pq.write_table(pa.Table.from_pylist(metadata_rows), metadata_path, compression="zstd")

    audio_index_rows: list[dict[str, Any]] = []
    member_names: list[str] = []
    sum_flac_bytes = 0
    with tarfile.open(audio_tar_path, "w") as tf:
        for item in audio_rows:
            member_name = item["tar_member_name"]
            flac_bytes = item["flac_bytes"]
            # TarInfo defaults (uid/gid/mtime = 0) keep the archive reproducible.
            info = tarfile.TarInfo(name=member_name)
            info.size = len(flac_bytes)
            tf.addfile(info, io.BytesIO(flac_bytes))
            member_names.append(member_name)
            sum_flac_bytes += len(flac_bytes)
            audio_index_rows.append(
                {
                    "segment_id": item["segment_id"],
                    "video_id": item["video_id"],
                    "tar_member_name": member_name,
                    "flac_size_bytes": len(flac_bytes),
                    "flac_sha256": item["flac_sha256"],
                    "audio_duration_s": item["audio_duration_s"],
                }
            )

    pq.write_table(pa.Table.from_pylist(audio_index_rows), audio_index_path, compression="zstd")

    segment_ids = [str(row["segment_id"]) for row in metadata_rows]
    # Compute each digest exactly once; the same values go into both the
    # manifest and the returned PackArtifacts so they can never diverge.
    segment_id_set_sha256 = sha256_string_set(segment_ids)
    metadata_sha256 = sha256_file(metadata_path)
    audio_sha256 = sha256_file(audio_tar_path)
    audio_index_sha256 = sha256_file(audio_index_path)
    pack_manifest = {
        **manifest_payload,
        "metadata_row_count": len(metadata_rows),
        "audio_index_row_count": len(audio_index_rows),
        "audio_tar_member_count": len(member_names),
        "sum_flac_bytes": sum_flac_bytes,
        "segment_id_set_sha256": segment_id_set_sha256,
        "metadata_sha256": metadata_sha256,
        "audio_index_sha256": audio_index_sha256,
        "audio_tar_sha256": audio_sha256,
        "metadata_size_bytes": metadata_path.stat().st_size,
        "audio_index_size_bytes": audio_index_path.stat().st_size,
        "audio_tar_size_bytes": audio_tar_path.stat().st_size,
    }
    write_json(manifest_path, pack_manifest)
    manifest_sha256 = sha256_file(manifest_path)

    return PackArtifacts(
        metadata_path=metadata_path,
        audio_tar_path=audio_tar_path,
        audio_index_path=audio_index_path,
        manifest_path=manifest_path,
        row_count=len(metadata_rows),
        member_count=len(member_names),
        sum_flac_bytes=sum_flac_bytes,
        metadata_sha256=metadata_sha256,
        audio_sha256=audio_sha256,
        audio_index_sha256=audio_index_sha256,
        manifest_sha256=manifest_sha256,
        segment_id_set_sha256=segment_id_set_sha256,
    )
