"""
Preprocess Modi dataset → DAC latents.
Encodes all audio through frozen DAC encoder, saves continuous latents.
"""

import os
import json
import numpy as np
import torch
import torchaudio
import dac
from datasets import load_from_disk
from pathlib import Path
import argparse


# Segmentation parameters. Durations are in seconds.
SEG_CONFIG = {
    "min_duration": 3.0,        # segments shorter than this are dropped
    "max_duration": 12.0,       # soft cap; chunks up to 1.2x this are still kept
    "target_duration": 8.0,     # preferred chunk length when splitting on silence
    "silence_threshold_db": -40,  # frame RMS below this (dB) counts as silence
    "min_silence_len": 0.3,     # minimum silent run required to place a split point
}


def detect_silence_points(waveform, sr, threshold_db=-40, min_silence_len=0.3):
    """Find sample indices at the middle of sufficiently long silent stretches.

    Frames the signal with a 25 ms window and 10 ms hop, computes per-frame
    RMS in dB, and returns the midpoint (in samples) of every silent run that
    lasts at least ``min_silence_len`` seconds.

    Args:
        waveform: 1-D float array of audio samples.
        sr: sample rate in Hz.
        threshold_db: frames with RMS below this (dB) count as silence.
        min_silence_len: minimum silence duration (s) to yield a split point.

    Returns:
        list[int]: sample indices suitable as segment split points.
    """
    frame_length = int(0.025 * sr)
    hop = int(0.010 * sr)
    # Guard: shorter than one analysis frame -> nothing to measure.
    # (The previous version computed a zero/negative frame count here.)
    if len(waveform) < frame_length:
        return []
    n_frames = 1 + (len(waveform) - frame_length) // hop
    # Vectorized framing: same per-frame math as the original Python loop,
    # but executed in C. [::hop] selects exactly n_frames windows.
    frames = np.lib.stride_tricks.sliding_window_view(waveform, frame_length)[::hop][:n_frames]
    rms = np.sqrt(np.mean(frames ** 2, axis=1) + 1e-10)
    rms_db = 20 * np.log10(rms + 1e-10)
    is_silence = rms_db < threshold_db
    silence_points = []
    min_frames = int(min_silence_len / 0.010)  # hop is 10 ms, so frames = seconds / 0.010
    in_silence = False
    silence_start = 0
    for i, s in enumerate(is_silence):
        if s and not in_silence:
            silence_start = i
            in_silence = True
        elif not s and in_silence:
            if i - silence_start >= min_frames:
                # Split in the middle of the silent run, mapped back to samples.
                mid_frame = (silence_start + i) // 2
                silence_points.append(mid_frame * hop)
            in_silence = False
    return silence_points


def segment_audio(waveform, sr, text, config):
    """Split (waveform, text) into duration-bounded segments.

    Audio already within ``max_duration`` is returned whole (or dropped if
    under ``min_duration``). Longer audio is split at detected silence points,
    with oversized chunks hard-split at ``max_duration``. The transcript is
    allocated to segments proportionally by sample count.

    Args:
        waveform: 1-D float array of audio samples.
        sr: sample rate in Hz.
        text: transcript for the full waveform.
        config: dict with min/max/target_duration, silence_threshold_db,
            min_silence_len (see SEG_CONFIG).

    Returns:
        list[tuple[np.ndarray, str]]: (segment_waveform, segment_text) pairs.
    """
    total_duration = len(waveform) / sr
    if total_duration <= config["max_duration"]:
        return [(waveform, text)] if total_duration >= config["min_duration"] else []
    silence_points = detect_silence_points(waveform, sr,
        threshold_db=config["silence_threshold_db"],
        min_silence_len=config["min_silence_len"])
    split_points = [0] + silence_points + [len(waveform)]
    segments = []
    current_start = 0
    for point in split_points[1:]:
        current_duration = (point - current_start) / sr
        # Only cut once the accumulated span reaches the target length.
        if current_duration >= config["target_duration"]:
            chunk = waveform[current_start:point]
            chunk_duration = len(chunk) / sr
            if config["min_duration"] <= chunk_duration <= config["max_duration"] * 1.2:
                segments.append(chunk)
            elif chunk_duration > config["max_duration"] * 1.2:
                # No usable silence point: hard-split at max_duration.
                max_samples = int(config["max_duration"] * sr)
                for j in range(0, len(chunk), max_samples):
                    sub = chunk[j:j + max_samples]
                    if len(sub) / sr >= config["min_duration"]:
                        segments.append(sub)
            current_start = point
    if current_start < len(waveform):
        remaining = waveform[current_start:]
        if len(remaining) / sr >= config["min_duration"]:
            segments.append(remaining)
        elif segments:
            # Tail too short to stand alone: merge into the previous segment
            # if the combined length stays within tolerance.
            combined = np.concatenate([segments[-1], remaining])
            if len(combined) / sr <= config["max_duration"] * 1.2:
                segments[-1] = combined
    results = []
    total_samples = sum(len(s) for s in segments)
    char_idx = 0
    for k, seg in enumerate(segments):
        if k == len(segments) - 1:
            # Last segment takes the full remainder; int() truncation in the
            # proportional split would otherwise drop trailing characters.
            seg_text = text[char_idx:].strip()
        else:
            proportion = len(seg) / total_samples
            n_chars = max(1, int(len(text) * proportion))
            seg_text = text[char_idx:char_idx + n_chars].strip()
            char_idx += n_chars
        results.append((seg, seg_text))
    return results


def trim_silence(waveform, sr, threshold_db=-45):
    """Trim leading and trailing silence, keeping a 50 ms pad on each side.

    Finds the first and last samples whose absolute amplitude exceeds the
    linear threshold and returns the enclosing slice padded by 50 ms.

    Args:
        waveform: 1-D float array of audio samples.
        sr: sample rate in Hz.
        threshold_db: amplitude threshold in dB (linear = 10**(db/20)).

    Returns:
        np.ndarray: trimmed view of ``waveform``; empty when the whole
        signal is below the threshold.
    """
    threshold = 10 ** (threshold_db / 20)
    # Sample-accurate scan. The previous version stepped in coarse 10 ms
    # windows, reused its loop variables after the loops, and produced a
    # degenerate (start > end) slice on fully silent input.
    loud = np.flatnonzero(np.abs(waveform) > threshold)
    if loud.size == 0:
        return waveform[:0]  # fully silent: nothing to keep
    pad = int(0.05 * sr)
    start = max(0, int(loud[0]) - pad)
    end = min(len(waveform), int(loud[-1]) + 1 + pad)
    return waveform[start:end]


def preprocess_dataset(input_path, output_path):
    """Encode a HF dataset of (audio, text) into DAC latents + a JSON manifest.

    For each sample: trim silence, split into duration-bounded segments,
    encode each segment through the frozen DAC encoder on CUDA, and save the
    continuous latents as float16 ``.pt`` files. Writes ``manifest.json``
    (per-segment metadata) and ``config.json`` (run summary) to output_path.

    Args:
        input_path: path to a dataset saved with ``datasets.save_to_disk``;
            samples must have "audio" ({"array", "sampling_rate"}) and "text".
        output_path: directory for latents/ plus manifest and config JSON.
    """
    print(f"Loading dataset from {input_path}...")
    ds = load_from_disk(input_path)

    # Load the frozen DAC codec; we only use its encoder.
    print("Loading DAC 24kHz model...")
    model_path = dac.utils.download(model_type='24khz')
    dac_model = dac.DAC.load(model_path)
    dac_model.eval()
    dac_model = dac_model.to('cuda')
    print(f"DAC loaded. Hop={dac_model.hop_length}, latent_dim={dac_model.latent_dim}")

    DAC_SR = 24000  # the 24khz DAC model only accepts audio at this rate

    output_dir = Path(output_path)
    latent_dir = output_dir / "latents"
    latent_dir.mkdir(parents=True, exist_ok=True)

    manifest = []
    total_duration = 0.0
    idx = 0  # global segment counter; doubles as the latent filename stem

    for i in range(len(ds)):
        sample = ds[i]
        audio_data = sample["audio"]
        waveform = np.array(audio_data["array"], dtype=np.float32)
        sr = audio_data["sampling_rate"]
        text = sample["text"]

        # BUG FIX: DAC's preprocess() asserts the input sample rate matches the
        # model rate and does NOT resample, so non-24kHz datasets crashed.
        # Resample up front so trimming/segmentation durations stay consistent.
        if sr != DAC_SR:
            wav_t = torch.from_numpy(waveform).unsqueeze(0)
            waveform = torchaudio.functional.resample(wav_t, sr, DAC_SR).squeeze(0).numpy()
            sr = DAC_SR

        waveform = trim_silence(waveform, sr)
        segments = segment_audio(waveform, sr, text, SEG_CONFIG)

        for seg_wav, seg_text in segments:
            if not seg_text:  # proportional text split can yield empty strings
                continue
            duration = len(seg_wav) / sr
            total_duration += duration

            wav_tensor = torch.from_numpy(seg_wav).unsqueeze(0)

            # Encode with DAC -> continuous latents (pre-quantizer).
            with torch.no_grad():
                wav_input = wav_tensor.unsqueeze(0).to('cuda')  # [1, 1, samples]
                wav_input = dac_model.preprocess(wav_input, sr)  # pads to hop multiple
                z = dac_model.encoder(wav_input)  # [1, latent_dim, T]
                z = z.squeeze(0).cpu()  # [latent_dim, T]

            latent_path = latent_dir / f"{idx:06d}.pt"
            torch.save(z.half(), str(latent_path))  # float16 to save disk

            manifest.append({
                "id": idx,
                "latent_path": str(latent_path),
                "text": seg_text,
                "duration": round(duration, 3),
                "latent_frames": z.shape[1],
            })
            idx += 1

        if (i + 1) % 100 == 0:
            print(f"  Processed {i+1}/{len(ds)}, {idx} segments, {total_duration/3600:.2f}h")

    with open(output_dir / "manifest.json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    config = {
        "dac": {"hop_length": int(dac_model.hop_length), "latent_dim": int(dac_model.latent_dim),
                "sample_rate": 24000},
        "segmentation": SEG_CONFIG,
        "total_segments": idx,
        "total_duration_hours": round(total_duration / 3600, 2),
    }
    with open(output_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nDone! {idx} segments, {total_duration/3600:.2f} hours")

if __name__ == "__main__":
    # CLI entry point: dataset location in, processed latents out.
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", default="/home/ubuntu/modi_dataset")
    cli.add_argument("--output", default="/home/ubuntu/lewm-tts/v2/processed_data")
    opts = cli.parse_args()
    preprocess_dataset(opts.input, opts.output)