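"""
Data preprocessing for LeWM TTS.
- Loads the Modi dataset saved in HuggingFace `datasets` Arrow format (via `load_from_disk`)
- Trims leading/trailing silence, then segments long clips into roughly 3-12 second
  chunks at silence boundaries (a chunk may reach 1.2x the maximum, about 14.4 s)
- Extracts log-mel spectrograms
- Saves segment audio, mel spectrograms, a JSON manifest, and the preprocessing config
"""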

import os
import json
import numpy as np
import torch
import torchaudio
from datasets import load_from_disk
from pathlib import Path
import argparse


# ─── Mel spectrogram config ───
# Log-mel feature settings. Key order is preserved in config.json.
MEL_CONFIG = dict(
    sample_rate=24000,  # Hz; the mel filterbank assumes audio at this rate
    n_fft=1024,
    hop_length=256,     # samples between consecutive mel frames
    n_mels=100,
    f_min=0,
    f_max=None,         # None -> full bandwidth (matches Vocos mel-24khz)
    power=1.0,          # exponent on the STFT magnitude (1.0 = magnitude spectrogram)
    clip_val=1e-7,      # floor applied before the log (matches Vocos safe_log)
)

# ─── Segmentation config ───
# Durations are in seconds.
SEG_CONFIG = dict(
    min_duration=3.0,          # chunks shorter than this are dropped
    max_duration=12.0,         # clips up to this length are kept whole
    target_duration=8.0,       # preferred chunk size when splitting long clips
    silence_threshold_db=-40,  # frame RMS level (dB) below which a frame is silent
    min_silence_len=0.3,       # shortest silent stretch usable as a split point
)


def detect_silence_points(waveform: np.ndarray, sr: int,
                          threshold_db: float = -40,
                          min_silence_len: float = 0.3) -> list:
    """Find split points in the middle of silent stretches of audio.

    The signal is framed with 25 ms windows and a 10 ms hop. A frame is
    silent when its RMS level is below ``threshold_db`` (dB relative to an
    amplitude of 1.0). Each run of silent frames counts if it lasts at least
    ``min_silence_len`` seconds and is followed by a non-silent frame. Such a
    run contributes one split point: the frame-index midpoint of the run,
    converted to samples. A silent run that extends to the end of the audio
    is not reported.

    Args:
        waveform: 1-D audio samples.
        sr: Sample rate in Hz.
        threshold_db: RMS level (dB) below which a frame counts as silent.
        min_silence_len: Minimum length of a silent run, in seconds.

    Returns:
        Ascending sample indices (Python ints) of the split points. Returns an
        empty list when the input is shorter than one frame.
    """
    frame_length = int(0.025 * sr)  # 25 ms frames
    hop = int(0.010 * sr)           # 10 ms hop
    if frame_length <= 0 or hop <= 0 or len(waveform) < frame_length:
        return []

    # Frame the signal as a strided view (no copy). einsum then computes the
    # per-frame sum of squares without allocating an (n_frames, frame_length)
    # temporary array.
    samples = np.asarray(waveform, dtype=np.float64)
    frames = np.lib.stride_tricks.sliding_window_view(samples, frame_length)[::hop]
    mean_sq = np.einsum("ij,ij->i", frames, frames) / frame_length
    rms_db = 20 * np.log10(np.sqrt(mean_sq + 1e-10) + 1e-10)
    is_silence = rms_db < threshold_db

    # Rising/falling edges of the silence mask. A non-silent frame is implied
    # before frame 0, but none after the last frame, so a run still open at
    # the end has a start without a matching end and is discarded.
    edges = np.diff(is_silence.astype(np.int8), prepend=0)
    run_ends = np.flatnonzero(edges == -1)  # first non-silent frame after each run
    run_starts = np.flatnonzero(edges == 1)[:len(run_ends)]

    # Use round() rather than int(): int(0.3 / 0.01) == 29 because of float error.
    min_frames = int(round(min_silence_len * sr / hop))
    keep = (run_ends - run_starts) >= min_frames
    mid_frames = (run_starts[keep] + run_ends[keep]) // 2
    return [int(frame) * hop for frame in mid_frames]


def _force_split(chunk: np.ndarray, sr: int, config: dict) -> list:
    """Split an over-long chunk into the fewest equal pieces of at most ``max_duration``."""
    max_samples = int(config["max_duration"] * sr)
    n_pieces = -(-len(chunk) // max_samples)  # ceiling division
    # Equal pieces avoid the short remainder that a fixed-stride split leaves.
    # That remainder used to be discarded, losing audio the text still covered.
    pieces = np.array_split(chunk, n_pieces)
    return [piece for piece in pieces if len(piece) / sr >= config["min_duration"]]


def _split_at_silences(waveform: np.ndarray, sr: int, config: dict) -> list:
    """Greedily cut ``waveform`` at silence points into chunks near ``target_duration``."""
    silence_points = detect_silence_points(
        waveform, sr,
        threshold_db=config["silence_threshold_db"],
        min_silence_len=config["min_silence_len"],
    )
    hard_max = config["max_duration"] * 1.2  # tolerance before forcing a split

    segments = []
    current_start = 0
    for point in silence_points + [len(waveform)]:
        chunk_duration = (point - current_start) / sr
        if chunk_duration < config["target_duration"]:
            continue  # keep extending to a later split point
        chunk = waveform[current_start:point]
        if config["min_duration"] <= chunk_duration <= hard_max:
            segments.append(chunk)
        elif chunk_duration > hard_max:
            segments.extend(_force_split(chunk, sr, config))
        current_start = point

    # Leftover audio after the last cut: keep it if it is long enough.
    # Otherwise fold it into the previous chunk if that stays within the tolerance.
    remaining = waveform[current_start:]
    if len(remaining):
        if len(remaining) / sr >= config["min_duration"]:
            segments.append(remaining)
        elif segments:
            combined = np.concatenate([segments[-1], remaining])
            if len(combined) / sr <= hard_max:
                segments[-1] = combined
    return segments


def _split_text_by_length(text: str, lengths: list) -> list:
    """Divide ``text`` into consecutive stripped pieces proportional to ``lengths``.

    Cut positions are rounded from cumulative proportions, and the last piece
    always ends at ``len(text)``, so no characters are dropped.
    """
    total = sum(lengths)
    if total <= 0:
        return ["" for _ in lengths]
    n_chars = len(text)
    cuts = [0]
    cumulative = 0
    for length in lengths:
        cumulative += length
        cuts.append(round(n_chars * cumulative / total))
    cuts[-1] = n_chars  # guard against float rounding on the final boundary
    return [text[a:b].strip() for a, b in zip(cuts, cuts[1:])]


def segment_audio(waveform: np.ndarray, sr: int, text: str, config: dict) -> list:
    """Segment a long audio clip into shorter chunks at silence boundaries.

    Clips no longer than ``config["max_duration"]`` are returned whole, or
    dropped when shorter than ``config["min_duration"]``. Longer clips are cut
    greedily: the next cut is the first silence point at least
    ``target_duration`` seconds after the previous cut. A resulting chunk
    longer than 1.2 x ``max_duration`` is force-split into equal pieces.

    The text is not aligned to the audio. It is divided among chunks in
    proportion to their sample counts, and boundaries may fall mid-word.
    This is a placeholder until forced alignment is available.

    Args:
        waveform: 1-D audio samples.
        sr: Sample rate in Hz.
        text: Transcript of the whole clip.
        config: Segmentation settings with keys ``min_duration``,
            ``max_duration``, ``target_duration`` (seconds),
            ``silence_threshold_db`` and ``min_silence_len``.

    Returns:
        List of ``(chunk_waveform, chunk_text)`` tuples; may be empty.
    """
    total_duration = len(waveform) / sr

    # Already short enough: keep it whole, or drop it if it is too short.
    if total_duration <= config["max_duration"]:
        if total_duration >= config["min_duration"]:
            return [(waveform, text)]
        return []

    segments = _split_at_silences(waveform, sr, config)
    texts = _split_text_by_length(text, [len(seg) for seg in segments])
    return list(zip(segments, texts))


# Cache of MelSpectrogram modules, keyed by their constructor arguments.
# Building the transform creates the STFT window and the mel filterbank.
# Repeating that for every segment with the same config is wasted work.
_MEL_TRANSFORM_CACHE: dict = {}


def extract_mel(waveform: torch.Tensor, config: dict) -> torch.Tensor:
    """Extract a log-mel spectrogram (matches the Vocos mel-24khz config).

    Args:
        waveform: Audio tensor accepted by ``torchaudio.transforms.MelSpectrogram``.
            The caller in this file passes shape ``[1, T]``. Its sample rate
            must equal ``config["sample_rate"]``; no resampling is done here.
        config: Mel settings: ``sample_rate``, ``n_fft``, ``hop_length``,
            ``n_mels``, ``power``, plus optional ``f_min``, ``f_max`` (only
            forwarded when not None) and ``clip_val`` (floor applied before
            the log, default 1e-7).

    Returns:
        Natural-log mel spectrogram with mel bins on the second-to-last axis
        and frames on the last axis.
    """
    mel_kwargs = dict(
        sample_rate=config["sample_rate"],
        n_fft=config["n_fft"],
        hop_length=config["hop_length"],
        n_mels=config["n_mels"],
        power=config["power"],
    )
    # Only forward f_min / f_max when set, so torchaudio's defaults apply otherwise.
    if config.get("f_min") is not None:
        mel_kwargs["f_min"] = config["f_min"]
    if config.get("f_max") is not None:
        mel_kwargs["f_max"] = config["f_max"]

    cache_key = tuple(sorted(mel_kwargs.items()))
    mel_transform = _MEL_TRANSFORM_CACHE.get(cache_key)
    if mel_transform is None:
        mel_transform = torchaudio.transforms.MelSpectrogram(**mel_kwargs)
        _MEL_TRANSFORM_CACHE[cache_key] = mel_transform

    mel = mel_transform(waveform)
    clip_val = config.get("clip_val", 1e-7)
    # Clamp before the log to avoid log(0) = -inf on silent bins.
    return torch.log(torch.clamp(mel, min=clip_val))


def trim_silence(waveform: np.ndarray, sr: int, threshold_db: float = -45) -> np.ndarray:
    """Trim leading and trailing silence, keeping a 50 ms margin on each side.

    A sample is non-silent when its absolute amplitude exceeds
    ``threshold_db`` converted to linear amplitude (dB relative to 1.0).

    Args:
        waveform: 1-D audio samples.
        sr: Sample rate in Hz (used to size the 50 ms margin).
        threshold_db: Level (dB) at or below which samples count as silent.

    Returns:
        A slice (view) of ``waveform``. It runs from 50 ms before the first
        non-silent sample to 50 ms after the last one, clipped to the array
        bounds. If no sample exceeds the threshold (including empty input),
        an empty slice is returned.
    """
    threshold = 10 ** (threshold_db / 20)  # dB -> linear amplitude
    above = np.abs(waveform) > threshold
    if not above.any():
        # Entirely below threshold: nothing worth keeping. Previously this
        # produced an inverted slice, or crashed on very short inputs.
        return waveform[:0]

    # argmax on a boolean mask gives the first True without allocating an index array.
    first = int(np.argmax(above))
    last = len(above) - 1 - int(np.argmax(above[::-1]))

    pad = int(0.05 * sr)  # 50 ms padding
    start = max(0, first - pad)
    end = min(len(waveform), last + 1 + pad)
    return waveform[start:end]


def preprocess_dataset(input_path: str, output_path: str):
    """Run the full preprocessing pipeline and write results under ``output_path``.

    Each example in the dataset at ``input_path`` is processed in four steps:
    resample to ``MEL_CONFIG["sample_rate"]`` if needed, trim silence,
    segment, and then write outputs for every segment with non-empty text:

    - ``audio/{id:06d}.wav``: the segment waveform,
    - ``mels/{id:06d}.pt``: its log-mel spectrogram, ``[n_mels, T]``,

    Each segment is listed in ``manifest.json``. ``config.json`` records the
    mel and segmentation settings plus summary statistics.

    Args:
        input_path: Directory readable by ``datasets.load_from_disk``. Each
            example is expected to have ``audio`` (with ``array`` and
            ``sampling_rate``) and ``text`` fields.
        output_path: Output directory (created if missing).
    """
    print(f"Loading dataset from {input_path}...")
    ds = load_from_disk(input_path)

    output_dir = Path(output_path)
    audio_dir = output_dir / "audio"
    mel_dir = output_dir / "mels"
    audio_dir.mkdir(parents=True, exist_ok=True)
    mel_dir.mkdir(parents=True, exist_ok=True)

    target_sr = MEL_CONFIG["sample_rate"]
    manifest = []
    total_duration = 0.0
    idx = 0

    for i, sample in enumerate(ds):
        audio_data = sample["audio"]
        waveform = np.array(audio_data["array"], dtype=np.float32)
        sr = audio_data["sampling_rate"]
        text = sample["text"]

        # The mel filterbank (and its Vocos mel-24khz settings) assumes audio
        # at MEL_CONFIG["sample_rate"]. Resample once up front so the saved
        # wav, the mel and the durations all agree.
        if sr != target_sr:
            waveform = torchaudio.functional.resample(
                torch.from_numpy(waveform), sr, target_sr
            ).numpy()
            sr = target_sr

        # Trim silence
        waveform = trim_silence(waveform, sr)

        # Segment into chunks
        segments = segment_audio(waveform, sr, text, SEG_CONFIG)

        for seg_wav, seg_text in segments:
            if not seg_text:
                continue  # no transcript left for this chunk; unusable for TTS

            duration = len(seg_wav) / sr
            total_duration += duration

            # Save audio
            audio_path = audio_dir / f"{idx:06d}.wav"
            wav_tensor = torch.from_numpy(seg_wav).unsqueeze(0)  # [1, T]
            torchaudio.save(str(audio_path), wav_tensor, sr)

            # Extract and save mel
            mel = extract_mel(wav_tensor, MEL_CONFIG)
            mel_path = mel_dir / f"{idx:06d}.pt"
            torch.save(mel.squeeze(0), str(mel_path))  # [n_mels, T]

            manifest.append({
                "id": idx,
                "audio_path": str(audio_path),
                "mel_path": str(mel_path),
                "text": seg_text,
                "duration": round(duration, 3),
                "mel_frames": mel.shape[-1],
            })

            idx += 1

        if (i + 1) % 100 == 0:
            print(f"  Processed {i+1}/{len(ds)} samples, {idx} segments, {total_duration/3600:.2f}h total")

    # Save manifest
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    # Save config
    config = {
        "mel": MEL_CONFIG,
        "segmentation": SEG_CONFIG,
        "total_segments": idx,
        "total_duration_hours": round(total_duration / 3600, 2),
        "original_samples": len(ds),
    }
    config_path = output_dir / "config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"\nDone! {idx} segments, {total_duration/3600:.2f} hours")
    print(f"Manifest: {manifest_path}")
    print(f"Config: {config_path}")


if __name__ == "__main__":
    # Command-line entry point: preprocess a saved dataset into audio/mel files.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--input", default="/home/ubuntu/modi_dataset",
        help="dataset directory readable by datasets.load_from_disk (default: %(default)s)",
    )
    parser.add_argument(
        "--output", default="/home/ubuntu/lewm-tts/processed_data",
        help="output directory for audio/, mels/, manifest.json and config.json "
             "(default: %(default)s)",
    )
    args = parser.parse_args()
    preprocess_dataset(args.input, args.output)
