"""
Pipeline funnel analysis: trace segment counts and durations across all stages.
Compares: DB segment_count vs raw tar metadata vs raw FLACs vs transcribed FLACs (validation results are not analyzed yet).
"""
import json
import os
import sys
import tarfile
import tempfile
import time
import io
import struct
from collections import defaultdict
from pathlib import Path

import boto3
import psycopg2

# ─── Config ───
# Each connection setting is read from the named environment variable and falls
# back to the literal default when the variable is unset.
#
# SECURITY NOTE(review): the fallback literals below embed what look like real
# credentials (an R2 access key id and secret, and a Postgres URL containing a
# password) directly in source. They should be rotated and supplied only via
# the environment. The defaults are left untouched here so existing
# invocations keep working.
R2_ENDPOINT = os.getenv("R2_ENDPOINT_URL", "https://cb908ed13329eb7b186e06ab51bda190.r2.cloudflarestorage.com")
R2_KEY_ID = os.getenv("R2_ACCESS_KEY_ID", "c3c9190ae7ff98b10271ea8db6940210")
R2_SECRET = os.getenv("R2_SECRET_ACCESS_KEY", "eab9394d02b48a865634105b92c74751ec9a311c56884f7aead5d76476c6b576")
DB_URL = os.getenv("DATABASE_URL", "postgresql://postgres.exlkkfpymkpqlxulurel:Chibhakaku%402001@aws-0-us-west-2.pooler.supabase.com:6543/postgres")

# Bucket names for each pipeline stage.
RAW_BUCKET = "1-cleaned-data"
TRANSCRIBED_BUCKET = "transcribed"
# NOTE(review): VALIDATED_BUCKET is not referenced by any function in this file.
VALIDATED_BUCKET = "validation-results"


def get_s3():
    """Build a boto3 S3 client pointed at the R2 endpoint from the module config."""
    client_options = {
        "endpoint_url": R2_ENDPOINT,
        "aws_access_key_id": R2_KEY_ID,
        "aws_secret_access_key": R2_SECRET,
        "region_name": "auto",
    }
    return boto3.client("s3", **client_options)


def get_db():
    """Open and return a new psycopg2 connection for ``DB_URL``."""
    connection = psycopg2.connect(DB_URL)
    return connection


def get_flac_duration_from_bytes(data: bytes) -> float:
    """Return the duration in seconds encoded in a FLAC header, without decoding audio.

    Only the first 42 bytes are inspected: the 4-byte ``fLaC`` marker, a
    4-byte metadata block header, and the 34-byte block that follows it, which
    is read as STREAMINFO. Returns 0.0 when the marker is missing, fewer than
    42 bytes are supplied, or the stored sample rate is zero.
    """
    if data[:4] != b'fLaC' or len(data) < 42:
        return 0.0

    # Absolute bytes 18..25 hold one packed big-endian 64-bit field:
    #   sample rate (20 bits) | 8 bits that are not needed here | total samples (36 bits)
    packed = int.from_bytes(data[18:26], "big")
    sample_rate = packed >> 44
    if sample_rate == 0:
        return 0.0
    total_samples = packed & ((1 << 36) - 1)
    return total_samples / sample_rate


def flac_duration_from_stream(stream) -> float:
    """Read the first 42 bytes of *stream* and return the FLAC duration in seconds.

    Returns 0.0 when the read yields fewer than 42 bytes; otherwise the bytes
    are handed to ``get_flac_duration_from_bytes``.
    """
    head = stream.read(42)
    if head and len(head) >= 42:
        return get_flac_duration_from_bytes(head)
    return 0.0


def _head_object_exists(s3, bucket, key):
    """Return True when a HEAD request for ``bucket/key`` succeeds, else False.

    A "not found" error is the expected negative answer and stays silent.
    Any other failure (credentials, permissions, network, ...) is still
    treated as "missing" so a long run keeps going, but it is reported on
    stderr so it cannot be mistaken for a genuinely absent object.
    KeyboardInterrupt and SystemExit are not caught and propagate normally.
    """
    try:
        s3.head_object(Bucket=bucket, Key=key)
    except Exception as exc:  # deliberately broad: this is a best-effort probe
        # botocore's ClientError carries the service error code in
        # exc.response["Error"]["Code"]; other exception types have no such field.
        response = getattr(exc, "response", None)
        error = response.get("Error") if isinstance(response, dict) else None
        code = str(error.get("Code", "")) if isinstance(error, dict) else ""
        if code not in ("404", "NoSuchKey", "NotFound"):
            print(f"  WARNING: head_object failed for {bucket}/{key}: {exc}", file=sys.stderr)
        return False
    return True


def check_raw_tar_exists(s3, video_id):
    """Return the key of the raw tar for *video_id* in ``RAW_BUCKET``, or None.

    Two key layouts are probed in order, ``<video_id>.tar`` and
    ``cleaned/trail/<video_id>.tar``; the first one that exists wins.
    """
    for key in (f"{video_id}.tar", f"cleaned/trail/{video_id}.tar"):
        if _head_object_exists(s3, RAW_BUCKET, key):
            return key
    return None


def check_transcribed_tar_exists(s3, video_id):
    """Return the key ``<video_id>_transcribed.tar`` if it exists in ``TRANSCRIBED_BUCKET``, else None."""
    key = f"{video_id}_transcribed.tar"
    if _head_object_exists(s3, TRANSCRIBED_BUCKET, key):
        return key
    return None


def analyze_raw_tar(s3, video_id, raw_key):
    """Download one raw tar and summarise its metadata.json and segment FLACs.

    The tar is fetched from ``RAW_BUCKET`` into a temporary file. The first
    member whose name ends with ``metadata.json`` supplies the segment totals,
    language and per-segment usable/unusable status. Every ``.flac`` member
    under a ``/segments/`` path is counted and its header duration is added to
    the totals and to the per-bucket histograms.

    ``video_id`` is accepted for signature symmetry but is not used here.

    Returns:
        A dict with the keys ``raw_key``, ``metadata_total_segments``,
        ``metadata_usable``, ``metadata_unusable``, ``raw_flac_count``,
        ``raw_total_duration_s``, ``duration_buckets``,
        ``duration_hours_buckets``, ``unusable_reasons`` and
        ``metadata_language``.
    """
    summary = {
        "raw_key": raw_key,
        "metadata_total_segments": 0,
        "metadata_usable": 0,
        "metadata_unusable": 0,
        "raw_flac_count": 0,
        "raw_total_duration_s": 0.0,
        "duration_buckets": defaultdict(int),
        "duration_hours_buckets": defaultdict(float),
        "unusable_reasons": defaultdict(int),
        "metadata_language": "",
    }

    with tempfile.NamedTemporaryFile(suffix=".tar") as scratch:
        s3.download_file(RAW_BUCKET, raw_key, scratch.name)
        with tarfile.open(scratch.name, "r:*") as archive:
            members = archive.getmembers()

            # Only the first metadata.json (in archive order) is consulted.
            meta_member = next(
                (m for m in members if m.name.endswith("metadata.json")), None
            )
            metadata = {}
            if meta_member is not None:
                meta_file = archive.extractfile(meta_member)
                if meta_file:
                    metadata = json.loads(meta_file.read())

            if metadata:
                summary["metadata_total_segments"] = metadata.get("total_segments", 0)
                summary["metadata_language"] = metadata.get("language", "")
                for segment in metadata.get("segments", []):
                    status = segment.get("status", "unknown")
                    if status == "usable":
                        summary["metadata_usable"] += 1
                        continue
                    summary["metadata_unusable"] += 1
                    # Prefer the segment's label as the reason; fall back to its status.
                    reason = segment.get("label", status)
                    summary["unusable_reasons"][reason] += 1

            # Count segment FLACs and accumulate their header-declared durations.
            for member in members:
                name = member.name
                if not (name.endswith(".flac") and "/segments/" in name):
                    continue
                summary["raw_flac_count"] += 1
                audio = archive.extractfile(member)
                if not audio:
                    continue
                seconds = flac_duration_from_stream(audio)
                label = duration_bucket(seconds)
                summary["raw_total_duration_s"] += seconds
                summary["duration_buckets"][label] += 1
                summary["duration_hours_buckets"][label] += seconds / 3600

    return summary


def analyze_transcribed_tar(s3, video_id, tx_key):
    """Download one transcribed tar and count its FLACs, transcription JSONs and durations.

    The tar is fetched from ``TRANSCRIBED_BUCKET`` into a temporary file. The
    first member whose name ends with ``metadata.json`` is parsed and returned
    as-is. ``.flac`` members under ``/segments/`` are counted and their header
    durations accumulated; ``.json`` members under ``/transcriptions/`` are
    counted.

    ``video_id`` is accepted for signature symmetry but is not used here.

    Returns:
        A dict with the keys ``tx_key``, ``tx_flac_count``, ``tx_json_count``,
        ``tx_total_duration_s``, ``tx_duration_buckets``,
        ``tx_duration_hours_buckets`` and ``tx_metadata``.
    """
    summary = {
        "tx_key": tx_key,
        "tx_flac_count": 0,
        "tx_json_count": 0,
        "tx_total_duration_s": 0.0,
        "tx_duration_buckets": defaultdict(int),
        "tx_duration_hours_buckets": defaultdict(float),
        "tx_metadata": {},
    }

    with tempfile.NamedTemporaryFile(suffix=".tar") as scratch:
        s3.download_file(TRANSCRIBED_BUCKET, tx_key, scratch.name)
        with tarfile.open(scratch.name, "r:*") as archive:
            members = archive.getmembers()

            # Only the first metadata.json (in archive order) is consulted.
            meta_member = next(
                (m for m in members if m.name.endswith("metadata.json")), None
            )
            if meta_member is not None:
                meta_file = archive.extractfile(meta_member)
                if meta_file:
                    summary["tx_metadata"] = json.loads(meta_file.read())

            for member in members:
                name = member.name
                is_segment_flac = name.endswith(".flac") and "/segments/" in name
                if is_segment_flac:
                    summary["tx_flac_count"] += 1
                    audio = archive.extractfile(member)
                    if not audio:
                        continue
                    seconds = flac_duration_from_stream(audio)
                    label = duration_bucket(seconds)
                    summary["tx_total_duration_s"] += seconds
                    summary["tx_duration_buckets"][label] += 1
                    summary["tx_duration_hours_buckets"][label] += seconds / 3600
                elif name.endswith(".json") and "/transcriptions/" in name:
                    summary["tx_json_count"] += 1

    return summary


# (exclusive upper bound in seconds, label), in ascending order.
_BUCKET_UPPER_BOUNDS = (
    (0.5, "<0.5s"),
    (1.0, "0.5-1s"),
    (2.0, "1-2s"),
    (3.0, "2-3s"),
    (5.0, "3-5s"),
    (10.0, "5-10s"),
    (15.0, "10-15s"),
    (30.0, "15-30s"),
)


def duration_bucket(dur_s):
    """Map a duration in seconds to its histogram label.

    Each bucket includes its lower bound and excludes its upper bound;
    anything that is not below 30 seconds lands in ``"30s+"``.
    """
    for upper_bound, label in _BUCKET_UPPER_BOUNDS:
        if dur_s < upper_bound:
            return label
    return "30s+"


# Display order of the labels produced by duration_bucket().
_BUCKET_LABELS = ("<0.5s", "0.5-1s", "1-2s", "2-3s", "3-5s", "5-10s", "10-15s", "15-30s", "30s+")


def _merge_counts(target, source):
    """Add every value in *source* onto the same key of *target* (a defaultdict)."""
    for key, value in source.items():
        target[key] += value


def _print_duration_distribution(counts, hours, total_flacs):
    """Print one row per duration bucket: segment count, hours, and share of all FLACs."""
    for bucket in _BUCKET_LABELS:
        cnt = counts.get(bucket, 0)
        hrs = hours.get(bucket, 0)
        print(f"       {bucket:10s} {cnt:>8,} segments  {hrs:>8.2f} hours  ({100*cnt/max(1,total_flacs):.1f}%)")


def run_analysis(sample_size, total_db_videos=507387, output_dir="/home/ubuntu/transcripts"):
    """Sample videos from the DB, inspect their tars in R2, and print a funnel report.

    For each sampled ``video_queue`` row the raw and transcribed tars are
    located and analysed; counts and durations are aggregated, printed as a
    seven-section report, and saved as JSON.

    Args:
        sample_size: Number of random rows (with ``segment_count > 0``) to sample.
        total_db_videos: Video count used to extrapolate the sample averages
            to the full dataset. The default is a snapshot from a DB query
            and will drift over time.
        output_dir: Directory that receives ``funnel_analysis_<n>.json``; it
            is created if it does not exist.

    Returns:
        ``(totals, per_video)``: the aggregate counters dict and the list of
        per-video result dicts.

    Raises:
        Whatever the DB driver raises for the sampling query, and ``OSError``
        if the results file cannot be written. Failures while analysing an
        individual tar are recorded on that video's entry, reported on
        stderr, and do not abort the run.
    """
    print(f"\n{'='*80}")
    print(f"  PIPELINE FUNNEL ANALYSIS — {sample_size} VIDEO SAMPLE")
    print(f"{'='*80}\n")

    s3 = get_s3()

    # Fetch the sample and release the connection straight away: the S3 loop
    # below can run for a long time and never touches the database again.
    conn = get_db()
    try:
        with conn.cursor() as cur:
            # LIMIT is bound as a query parameter rather than formatted into the SQL.
            cur.execute(
                """
                SELECT video_id, language, segment_count, segments_transcribed, validation_segments, validation_status
                FROM video_queue
                WHERE segment_count > 0
                ORDER BY random()
                LIMIT %s
                """,
                (sample_size,),
            )
            samples = cur.fetchall()
    finally:
        conn.close()
    print(f"Sampled {len(samples)} videos from DB\n")

    # Aggregators
    totals = {
        "db_segment_count": 0,
        "db_segments_transcribed": 0,
        "raw_found": 0, "raw_not_found": 0,
        "tx_found": 0, "tx_not_found": 0,
        "metadata_total_segments": 0,
        "metadata_usable": 0, "metadata_unusable": 0,
        "raw_flac_total": 0,
        "raw_duration_total_s": 0.0,
        "tx_flac_total": 0,
        "tx_json_total": 0,
        "tx_duration_total_s": 0.0,
    }
    all_raw_duration_buckets = defaultdict(int)
    all_raw_duration_hours = defaultdict(float)
    all_tx_duration_buckets = defaultdict(int)
    all_tx_duration_hours = defaultdict(float)
    all_unusable_reasons = defaultdict(int)

    # Per-video detailed tracking
    per_video = []

    t0 = time.time()
    # The two validation columns are fetched but not used by this report.
    for i, (vid, lang, seg_count, seg_tx, _val_segs, _val_status) in enumerate(samples):
        if (i + 1) % 10 == 0 or i == 0:
            elapsed = time.time() - t0
            rate = (i + 1) / elapsed if elapsed > 0 else 0
            eta = (len(samples) - i - 1) / rate if rate > 0 else 0
            print(f"  [{i+1}/{len(samples)}] Processing {vid} ({rate:.1f} vids/s, ETA {eta:.0f}s)")

        totals["db_segment_count"] += seg_count
        totals["db_segments_transcribed"] += (seg_tx or 0)

        entry = {"video_id": vid, "language": lang, "db_segment_count": seg_count}

        # Check raw tar
        raw_key = check_raw_tar_exists(s3, vid)
        if raw_key:
            totals["raw_found"] += 1
            try:
                raw = analyze_raw_tar(s3, vid, raw_key)
                entry.update({
                    "raw_found": True,
                    "metadata_total": raw["metadata_total_segments"],
                    "metadata_usable": raw["metadata_usable"],
                    "metadata_unusable": raw["metadata_unusable"],
                    "raw_flacs": raw["raw_flac_count"],
                    "raw_duration_s": raw["raw_total_duration_s"],
                })
                totals["metadata_total_segments"] += raw["metadata_total_segments"]
                totals["metadata_usable"] += raw["metadata_usable"]
                totals["metadata_unusable"] += raw["metadata_unusable"]
                totals["raw_flac_total"] += raw["raw_flac_count"]
                totals["raw_duration_total_s"] += raw["raw_total_duration_s"]
                _merge_counts(all_raw_duration_buckets, raw["duration_buckets"])
                _merge_counts(all_raw_duration_hours, raw["duration_hours_buckets"])
                _merge_counts(all_unusable_reasons, raw["unusable_reasons"])
            except Exception as e:
                # Keep going, but make the failure visible instead of silent.
                entry["raw_found"] = True
                entry["raw_error"] = str(e)
                print(f"  WARNING: could not analyze raw tar {raw_key}: {e}", file=sys.stderr)
        else:
            totals["raw_not_found"] += 1
            entry["raw_found"] = False

        # Check transcribed tar
        tx_key = check_transcribed_tar_exists(s3, vid)
        if tx_key:
            totals["tx_found"] += 1
            try:
                tx = analyze_transcribed_tar(s3, vid, tx_key)
                entry.update({
                    "tx_found": True,
                    "tx_flacs": tx["tx_flac_count"],
                    "tx_jsons": tx["tx_json_count"],
                    "tx_duration_s": tx["tx_total_duration_s"],
                    "tx_summary": tx["tx_metadata"].get("transcription_summary", {}),
                })
                totals["tx_flac_total"] += tx["tx_flac_count"]
                totals["tx_json_total"] += tx["tx_json_count"]
                totals["tx_duration_total_s"] += tx["tx_total_duration_s"]
                _merge_counts(all_tx_duration_buckets, tx["tx_duration_buckets"])
                _merge_counts(all_tx_duration_hours, tx["tx_duration_hours_buckets"])
            except Exception as e:
                entry["tx_found"] = True
                entry["tx_error"] = str(e)
                print(f"  WARNING: could not analyze transcribed tar {tx_key}: {e}", file=sys.stderr)
        else:
            totals["tx_not_found"] += 1
            entry["tx_found"] = False

        per_video.append(entry)

    elapsed = time.time() - t0

    # ─── Results ───
    n = len(samples)
    raw_n = totals["raw_found"]
    tx_n = totals["tx_found"]
    # Tars that were found AND analysed without error. Per-video averages must
    # use these, otherwise failed tars count as videos with zero segments.
    raw_ok = sum(1 for e in per_video if "raw_flacs" in e)
    tx_ok = sum(1 for e in per_video if "tx_flacs" in e)

    print(f"\n{'─'*80}")
    print(f"  RESULTS ({n} videos, {elapsed:.0f}s)")
    print(f"{'─'*80}")

    print(f"\n  1. BUCKET AVAILABILITY")
    # max(1, n) guards against an empty sample (no matching rows in the DB).
    print(f"     Raw tars (1-cleaned-data):  {raw_n}/{n} found ({100*raw_n/max(1, n):.1f}%)")
    print(f"     Transcribed tars:           {tx_n}/{n} found ({100*tx_n/max(1, n):.1f}%)")

    print(f"\n  2. DB segment_count vs METADATA")
    print(f"     DB sum(segment_count) for sample:     {totals['db_segment_count']:>12,}")
    if raw_n > 0:
        print(f"     Metadata total_segments (from {raw_n} raw tars): {totals['metadata_total_segments']:>12,}")
        if totals['metadata_total_segments'] > 0:
            # Only compare for videos that have raw tars
            raw_videos_db_seg = sum(e['db_segment_count'] for e in per_video if e.get('raw_found'))
            print(f"     DB segment_count for those {raw_n} videos:    {raw_videos_db_seg:>12,}")
            ratio2 = raw_videos_db_seg / totals['metadata_total_segments']
            print(f"     Ratio (DB / metadata):                     {ratio2:.4f}")

    print(f"\n  3. RAW SEGMENT BREAKDOWN (from {raw_n} raw tars)")
    if raw_n > 0:
        print(f"     Metadata usable:   {totals['metadata_usable']:>10,}  ({100*totals['metadata_usable']/max(1,totals['metadata_total_segments']):.1f}%)")
        print(f"     Metadata unusable: {totals['metadata_unusable']:>10,}  ({100*totals['metadata_unusable']/max(1,totals['metadata_total_segments']):.1f}%)")
        print(f"     Actual FLAC files: {totals['raw_flac_total']:>10,}")
        if totals['metadata_usable'] > 0:
            print(f"     FLACs vs usable:   {100*totals['raw_flac_total']/totals['metadata_usable']:.1f}%")
        print(f"     Raw duration:      {totals['raw_duration_total_s']/3600:.2f} hours")

        print(f"\n     Unusable reasons:")
        for reason, cnt in sorted(all_unusable_reasons.items(), key=lambda x: -x[1]):
            print(f"       {reason:30s} {cnt:>8,}  ({100*cnt/max(1,totals['metadata_unusable']):.1f}%)")

        print(f"\n     Raw FLAC duration distribution:")
        _print_duration_distribution(all_raw_duration_buckets, all_raw_duration_hours, totals['raw_flac_total'])

    print(f"\n  4. TRANSCRIBED SEGMENT BREAKDOWN (from {tx_n} tars)")
    if tx_n > 0:
        print(f"     Transcribed FLACs:  {totals['tx_flac_total']:>10,}")
        print(f"     Transcription JSONs:{totals['tx_json_total']:>10,}")
        print(f"     Transcribed duration: {totals['tx_duration_total_s']/3600:.2f} hours")

        print(f"\n     Transcribed FLAC duration distribution:")
        _print_duration_distribution(all_tx_duration_buckets, all_tx_duration_hours, totals['tx_flac_total'])

    print(f"\n  5. FUNNEL RATIOS")
    if raw_n > 0 and totals['metadata_total_segments'] > 0:
        print(f"     Metadata -> Usable:     {100*totals['metadata_usable']/totals['metadata_total_segments']:.1f}% of segments")
        print(f"     Usable -> Raw FLACs:    {100*totals['raw_flac_total']/max(1,totals['metadata_usable']):.1f}%")
    if tx_n > 0 and raw_n > 0 and totals['raw_flac_total'] > 0:
        # Compare per-video where we have both raw and transcribed data
        both_raw_flac = 0
        both_tx_flac = 0
        both_raw_dur = 0.0
        both_tx_dur = 0.0
        both_count = 0
        for e in per_video:
            if e.get('raw_found') and e.get('tx_found') and e.get('raw_flacs', 0) > 0:
                both_count += 1
                both_raw_flac += e.get('raw_flacs', 0)
                both_tx_flac += e.get('tx_flacs', 0)
                both_raw_dur += e.get('raw_duration_s', 0)
                both_tx_dur += e.get('tx_duration_s', 0)
        if both_count > 0:
            print(f"     [For {both_count} videos with both raw+tx data:]")
            print(f"       Raw FLACs -> Transcribed FLACs: {100*both_tx_flac/both_raw_flac:.1f}%  ({both_raw_flac:,} -> {both_tx_flac:,})")
            print(f"       Raw hours -> Transcribed hours: {100*both_tx_dur/max(1,both_raw_dur):.1f}%  ({both_raw_dur/3600:.2f}h -> {both_tx_dur/3600:.2f}h)")

    print(f"\n  6. EXTRAPOLATION TO FULL DATASET ({total_db_videos:,} videos)")
    if tx_ok > 0:
        avg_tx_flacs = totals['tx_flac_total'] / tx_ok
        avg_tx_dur_h = (totals['tx_duration_total_s'] / tx_ok) / 3600
        print(f"     Avg transcribed FLACs/video: {avg_tx_flacs:.1f}")
        print(f"     Avg transcribed hours/video: {avg_tx_dur_h:.4f}")
        print(f"     Est total transcribed FLACs: {int(avg_tx_flacs * total_db_videos):,}")
        print(f"     Est total transcribed hours: {avg_tx_dur_h * total_db_videos:,.0f}")
    if raw_ok > 0:
        avg_raw_flacs = totals['raw_flac_total'] / raw_ok
        avg_raw_dur_h = (totals['raw_duration_total_s'] / raw_ok) / 3600
        print(f"     Avg raw FLACs/video:         {avg_raw_flacs:.1f}")
        print(f"     Avg raw hours/video:         {avg_raw_dur_h:.4f}")
        print(f"     Est total raw FLACs:         {int(avg_raw_flacs * total_db_videos):,}")
        print(f"     Est total raw hours:         {avg_raw_dur_h * total_db_videos:,.0f}")

    # Per-video sanity check: compare db segment_count to flac counts
    print(f"\n  7. PER-VIDEO CROSS-CHECK (DB segment_count vs actual FLACs)")
    mismatches = []
    for e in per_video:
        if e.get('raw_found') and 'raw_flacs' in e:
            ratio = e['raw_flacs'] / e['db_segment_count'] if e['db_segment_count'] > 0 else 0
            # Videos whose tar holds no FLACs at all (ratio 0) are left out.
            if ratio > 0:
                mismatches.append({
                    'vid': e['video_id'],
                    'db': e['db_segment_count'],
                    'raw_flacs': e['raw_flacs'],
                    'ratio': ratio,
                    'metadata_total': e.get('metadata_total', 0),
                })
    if mismatches:
        ratios = [m['ratio'] for m in mismatches]
        print(f"     Videos checked: {len(mismatches)}")
        print(f"     FLACs/DB ratio: min={min(ratios):.3f}, max={max(ratios):.3f}, avg={sum(ratios)/len(ratios):.3f}, median={sorted(ratios)[len(ratios)//2]:.3f}")

        # Show a few examples
        print(f"\n     Sample entries (first 10):")
        for m in mismatches[:10]:
            print(f"       {m['vid']:15s}  DB={m['db']:>6,}  meta={m['metadata_total']:>6,}  FLACs={m['raw_flacs']:>6,}  ratio={m['ratio']:.3f}")

    # Save detailed results
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"funnel_analysis_{n}.json")
    save_data = {
        "sample_size": n,
        "totals": dict(totals),
        "raw_duration_buckets": dict(all_raw_duration_buckets),
        "tx_duration_buckets": dict(all_tx_duration_buckets),
        "raw_duration_hours": dict(all_raw_duration_hours),
        "tx_duration_hours": dict(all_tx_duration_hours),
        "unusable_reasons": dict(all_unusable_reasons),
        # Only the first 20 per-video entries are persisted to keep the file small.
        "per_video_sample": per_video[:20],
    }
    with open(output_path, "w") as f:
        json.dump(save_data, f, indent=2, default=str)
    print(f"\n  Detailed results saved to {output_path}")

    return totals, per_video


if __name__ == "__main__":
    # Optional first CLI argument: number of videos to sample (defaults to 100).
    size = 100
    if len(sys.argv) > 1:
        size = int(sys.argv[1])
    run_analysis(size)
