#!/usr/bin/env python3
from __future__ import annotations

import argparse
import io
import os
import sys
import tarfile
from pathlib import Path

import boto3
import numpy as np
import pyarrow.parquet as pq
import torch
import torchaudio
from dotenv import load_dotenv

# Repository root: two directory levels above this script's resolved location.
# It is prepended to sys.path so the ``codecbench`` imports that follow resolve
# even when the script is executed directly rather than as an installed package.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

from codecbench.pipeline.config import PipelineConfig
from codecbench.pipeline.encoder import HotEncoder
from codecbench.pipeline.vad import Segment


def load_remote_tokens(s3, bucket: str, token_key: str):
    """Download a token parquet from object storage and index it by segment id.

    Args:
        s3: boto3 S3 client; only ``get_object`` is called on it.
        bucket: Bucket holding the parquet object.
        token_key: Object key of the parquet file. The file must contain a
            ``segment_id`` column and an ``xcodec2_tokens`` column of raw bytes.

    Returns:
        dict mapping each ``segment_id`` to a 1-D ``np.uint16`` array decoded
        from that row's ``xcodec2_tokens`` bytes. If a segment id appears in
        more than one row, the last row wins.
    """
    body = s3.get_object(Bucket=bucket, Key=token_key)['Body'].read()
    # Only these two columns are used; projecting avoids decoding any other
    # (possibly large) columns stored in the same file.
    table = pq.read_table(io.BytesIO(body), columns=['segment_id', 'xcodec2_tokens'])
    seg_ids = table.column('segment_id').to_pylist()
    raw = table.column('xcodec2_tokens').to_pylist()
    # NOTE(review): np.uint16 uses the host's native byte order; this assumes the
    # writer serialized tokens in that same order — confirm against the producer.
    # frombuffer yields a read-only view over the bytes object, so copy to give
    # callers an independent, writable array.
    return {sid: np.frombuffer(b, dtype=np.uint16).copy() for sid, b in zip(seg_ids, raw)}


def load_local_segments(s3, bucket: str, shard_key: str, wanted: set[str], target_sr: int = 16000):
    """Fetch a shard's audio tar and decode the requested FLAC segments.

    The object key is ``f'{shard_key}audio.tar'``, so ``shard_key`` is expected
    to already carry any trailing separator (e.g. ``'.../shard_0001/'``).

    Args:
        s3: boto3 S3 client; only ``get_object`` is called on it.
        bucket: Bucket holding the tar object.
        shard_key: Key prefix the literal ``audio.tar`` is appended to.
        wanted: Segment ids to load. A tar member's id is its file stem, so
            ``a/b/seg42.flac`` has id ``seg42``.
        target_sr: Sample rate every returned waveform is resampled to.

    Returns:
        dict mapping segment id to a ``Segment`` spanning the whole clip
        (``start_s=0.0``, ``end_s`` = duration in seconds). Its audio is a mono
        ``(1, samples)`` tensor at ``target_sr``. Ids in ``wanted`` that have no
        matching ``.flac`` member are simply absent from the result.
    """
    if not wanted:
        # Nothing to match, so skip downloading the (potentially large) tar.
        return {}
    body = s3.get_object(Bucket=bucket, Key=f'{shard_key}audio.tar')['Body'].read()
    out = {}
    with tarfile.open(fileobj=io.BytesIO(body), mode='r') as tar:
        for m in tar.getmembers():
            if not (m.isfile() and m.name.endswith('.flac')):
                continue
            sid = Path(m.name).stem
            if sid not in wanted:
                continue
            # extractfile is non-None for regular files; close it deterministically.
            with tar.extractfile(m) as fh:
                wav, sr = torchaudio.load(io.BytesIO(fh.read()))
            if sr != target_sr:
                wav = torchaudio.functional.resample(wav, sr, target_sr)
            # Downmix multi-channel audio to mono, keeping a leading channel axis.
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            dur = wav.shape[-1] / target_sr
            out[sid] = Segment(start_s=0.0, end_s=dur, audio=wav)
            # Every requested id has been decoded; skip the rest of the archive.
            if len(out) == len(wanted):
                break
    return out


def main():
    """Check whether re-encoding stored audio reproduces the stored xcodec2 tokens.

    Downloads the token parquet (``--token-key``) and the shard's audio tar
    (``--shard-key`` + ``audio.tar``) from the R2 bucket. It takes the first
    ``--limit`` segment ids in sorted order and re-encodes each segment on its
    own with ``HotEncoder``. It then prints a summary dict:

    * ``samples_compared``: segments that were encoded and compared.
    * ``same_length``: segments whose flattened token arrays have equal shape.
    * ``bit_exact_samples`` / ``bit_exact_pct``: segments whose tokens match
      exactly; the percentage is taken over ``samples_compared``.
    * ``avg_token_match_pct``: mean per-token agreement over same-length segments.
    * ``missing_local``: picked ids with no audio in the tar.
    * ``encode_failed``: segments for which the encoder returned nothing.
    """
    ap = argparse.ArgumentParser(
        description='Compare locally re-encoded xcodec2 tokens with remote tokens.'
    )
    ap.add_argument('--shard-key', required=True)
    ap.add_argument('--token-key', required=True)
    ap.add_argument('--limit', type=int, default=1000)
    ap.add_argument('--ckpt', default='/tmp/pipeline/xcodec2_custom.ckpt',
                    help='xcodec2 checkpoint used for local encoding')
    ap.add_argument('--device', default='cuda', help='torch device for the encoder')
    args = ap.parse_args()
    # A negative value would silently slice off the *last* ids instead of limiting.
    if args.limit < 0:
        ap.error('--limit must be non-negative')

    load_dotenv(PROJECT_ROOT / '.env')
    bucket = os.environ.get('R2_BUCKET_DESTINATION', 'finalsftdata')
    s3 = boto3.client(
        's3',
        endpoint_url=os.environ['R2_ENDPOINT_URL'],
        aws_access_key_id=os.environ['R2_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['R2_SECRET_ACCESS_KEY'],
        region_name='auto',
    )

    remote = load_remote_tokens(s3, bucket, args.token_key)
    # Sorted so repeated runs compare the same subset.
    picked_ids = sorted(remote)[: args.limit]
    segs = load_local_segments(s3, bucket, args.shard_key, set(picked_ids))

    cfg = PipelineConfig.from_env()
    cfg.codec.xcodec2_custom_ckpt = args.ckpt
    cfg.codec.xcodec_batch_size = 4
    enc = HotEncoder(cfg.codec, device=args.device)
    enc.load()

    bit_exact = 0
    total = 0
    same_len = 0
    missing_local = 0
    encode_failed = 0
    token_match_sum = 0.0

    for sid in picked_ids:
        seg = segs.get(sid)
        if seg is None:
            missing_local += 1
            continue
        # Batch size 1 so each segment is encoded independently of its neighbours.
        out = enc.encode_segments([seg], xcodec_batch_size_override=1)
        if not out:
            encode_failed += 1
            continue
        # Remote tokens are always a flat 1-D array, so flatten the local tensor
        # too. squeeze(0) alone turns a single-token (1,) tensor into a 0-d array
        # and leaves any other leading axes in place, which would never
        # shape-match. NOTE(review): assumes the remote bytes were written in
        # C order from the same tensor layout.
        local = out[0].xcodec2_tokens.reshape(-1).cpu().numpy().astype(np.uint16)
        remote_arr = remote[sid]
        total += 1
        if local.shape == remote_arr.shape:
            same_len += 1
            eq = local == remote_arr
            # Two empty arrays agree trivially; avoid NaN from mean() of nothing.
            token_match_sum += float(eq.mean()) if eq.size else 1.0
            if eq.all():
                bit_exact += 1

    print({
        'samples_compared': total,
        'same_length': same_len,
        'bit_exact_samples': bit_exact,
        'bit_exact_pct': round(100 * bit_exact / max(total, 1), 2),
        'avg_token_match_pct': round(100 * token_match_sum / max(same_len, 1), 4),
        'missing_local': missing_local,
        'encode_failed': encode_failed,
    })


# Run the comparison only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
