"""
Usage:
    python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]

CSV format (header required, "|" delimiter):
    audio_file|text
    /path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
    /path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.

Notes:
    - audio_file must be an absolute path.
"""

import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess
import sys
from contextlib import contextmanager


sys.path.append(os.getcwd())

import argparse
import csv
import json
from importlib.resources import files
from pathlib import Path

import soundfile as sf
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm

from f5_tts.model.utils import convert_char_to_pinyin


# Pretrained vocab shipped alongside f5_tts; copied verbatim into the output
# directory when fine-tuning (see save_prepped_dataset).
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")

# Configuration constants
BATCH_SIZE = 100  # Number of texts per convert_char_to_pinyin call (see batch_convert_texts)
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
THREAD_NAME_PREFIX = "AudioProcessor"  # Thread name prefix for the pool (aids debugging)
CHUNK_SIZE = 100  # Number of files submitted to the thread pool per batch
executor = None  # Global executor handle so signal handlers in graceful_exit can shut it down


def is_csv_wavs_format(input_path):
    """Return True if *input_path* is an existing file with a .csv extension."""
    candidate = Path(input_path).expanduser()
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() == ".csv"


@contextmanager
def graceful_exit():
    """Context manager for graceful shutdown on signals.

    Installs SIGINT/SIGTERM handlers that cancel the global executor's pending
    work and exit; on normal exit from the context, the executor (if any) is
    shut down without waiting.
    """

    def _terminate(signum, frame):
        print("\nReceived signal to terminate. Cleaning up...")
        if executor is not None:
            print("Shutting down executor...")
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)

    # Register the same handler for both termination signals
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _terminate)

    try:
        yield
    finally:
        if executor is not None:
            executor.shutdown(wait=False)


def process_audio_file(audio_path, text, polyphone):
    """Validate one audio file and return (audio_path, text, duration), or None on failure.

    NOTE: *polyphone* is accepted for call-signature compatibility but is not
    used here; pinyin conversion happens later in batch_convert_texts.
    """
    if not Path(audio_path).exists():
        print(f"audio {audio_path} not found, skipping")
        return None
    try:
        duration = get_audio_duration(audio_path)
        if duration <= 0:
            raise ValueError(f"Duration {duration} is non-positive.")
    except Exception as e:
        print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
        return None
    return (audio_path, text, duration)


def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
    """Convert a list of texts to pinyin in batches."""
    pinyin_texts = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    batch_starts = tqdm(
        range(0, len(texts), batch_size),
        total=n_batches,
        desc="Converting texts to pinyin",
    )
    for start in batch_starts:
        pinyin_texts += convert_char_to_pinyin(texts[start : start + batch_size], polyphone=polyphone)
    return pinyin_texts


def prepare_csv_wavs_dir(input_path, num_workers=None):
    """Read an `audio_file|text` CSV, probe audio durations concurrently, and
    convert all texts to pinyin.

    Args:
        input_path: path to the CSV file (header `audio_file|text`, `|` delimiter).
        num_workers: thread count; defaults to min(MAX_WORKERS, number of files).

    Returns:
        (sub_result, durations, vocab_set): list of {"audio_path", "text",
        "duration"} dicts, the per-file durations in seconds, and the set of
        characters appearing in the converted texts.

    Raises:
        ValueError: if input_path is not a .csv file.
        RuntimeError: if the CSV has no valid rows or no file could be processed.
    """
    global executor
    if not is_csv_wavs_format(input_path):
        raise ValueError(f"input must be a .csv file: {input_path}")
    audio_path_text_pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix())

    polyphone = True
    total_files = len(audio_path_text_pairs)
    if total_files == 0:
        raise RuntimeError("No valid rows found in CSV.")

    # Use provided worker count or calculate optimal number
    worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")

    results = []
    with graceful_exit():
        # FIX: the pool variable was previously named `exec`, shadowing the builtin.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
        ) as pool:
            executor = pool  # expose globally so graceful_exit's signal handler can cancel it
            try:
                # Process files in chunks for better efficiency
                for i in range(0, total_files, CHUNK_SIZE):
                    chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
                    # Submit futures in order so results stay aligned with the CSV
                    chunk_futures = [pool.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]

                    # Iterate over futures in the original submission order to preserve ordering
                    for future in tqdm(
                        chunk_futures,
                        total=len(chunk),
                        desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
                    ):
                        try:
                            result = future.result()
                            if result is not None:
                                results.append(result)
                        except Exception as e:
                            print(f"Error processing file: {e}")
            finally:
                # FIX: always clear the global (previously only cleared on the
                # success path), so the signal handler never touches a dead pool.
                executor = None

    # None results were already filtered out above, so `results` only holds
    # successfully-processed (audio_path, text, duration) tuples.
    processed = results
    if not processed:
        raise RuntimeError("No valid audio files were processed!")

    # Batch process text conversion
    raw_texts = [item[1] for item in processed]
    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)

    # Prepare final results
    sub_result = []
    durations = []
    vocab_set = set()

    for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(conv_text))

    return sub_result, durations, vocab_set


def get_audio_duration(audio_path, timeout=5):
    """Get the duration of an audio file in seconds with fallbacks."""
    # First choice: soundfile header read (no subprocess needed).
    try:
        return sf.info(audio_path).duration
    except Exception as e:
        print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.")

    # Second choice: ask ffprobe for the container-reported duration.
    try:
        probe_cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ]
        completed = subprocess.run(
            probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
        )
        raw = completed.stdout.strip()
        if not raw:
            raise ValueError("Empty duration string from ffprobe.")
        # A non-numeric string raises ValueError here and falls through below.
        return float(raw)
    except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.info.")

    # Last resort: torchaudio metadata; any failure here is fatal.
    try:
        meta = torchaudio.info(audio_path)
        if meta.sample_rate <= 0:
            raise ValueError("Invalid sample_rate from torchaudio.info.")
        return meta.num_frames / meta.sample_rate
    except Exception as e:
        raise RuntimeError(f"failed to get duration for {audio_path}: {e}")


def read_audio_text_pairs(csv_file_path):
    """Parse an `audio_file|text` CSV and return a list of (posix_path, text) tuples.

    Rows with fewer than two fields or an empty audio_file are skipped.

    Raises:
        ValueError: on a malformed header or a relative audio path.
    """
    pairs = []

    resolved_csv = Path(csv_file_path).expanduser().absolute()
    with open(resolved_csv.as_posix(), mode="r", newline="", encoding="utf-8-sig") as handle:
        rows = csv.reader(handle, delimiter="|")
        header = next(rows, None)
        if header is None:
            return pairs
        header_ok = len(header) >= 2 and header[0].strip() == "audio_file" and header[1].strip() == "text"
        if not header_ok:
            raise ValueError("CSV header must be: audio_file|text")
        for line_no, row in enumerate(rows, start=2):
            if len(row) < 2:
                continue
            audio_file, text = row[0].strip(), row[1].strip()
            if not audio_file:
                continue
            wav_path = Path(audio_file).expanduser()
            if not wav_path.is_absolute():
                raise ValueError(f"audio_file must be an absolute path (row {line_no}): {audio_file}")
            pairs.append((wav_path.as_posix(), text))

    return pairs


def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Write the prepared dataset to *out_dir*.

    Produces:
        raw.arrow     - one record per sample ({"audio_path", "text", "duration"})
        duration.json - {"duration": [seconds, ...]} aligned with raw.arrow
        vocab.txt     - copied from the pretrained vocab when fine-tuning,
                        otherwise built from *text_vocab_set* (one symbol per line)
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    # Save durations to JSON
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Handle vocab file - write only once based on finetune flag
    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
        shutil.copy2(file_vocab_finetune, voca_out_path)
    else:
        # FIX: explicit UTF-8 encoding. The vocab may contain non-ASCII symbols
        # (pinyin/CJK); relying on the locale default encoding can raise
        # UnicodeEncodeError (e.g. cp1252 on Windows).
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")


def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
    """Prepare the dataset described by the CSV *inp_dir* and save it under *out_dir*.

    Args:
        inp_dir: input CSV path (header `audio_file|text`).
        out_dir: output dataset directory.
        is_finetune: when True, the pretrained vocab is copied instead of built.
        num_workers: optional thread count forwarded to prepare_csv_wavs_dir.

    Raises:
        FileNotFoundError: fine-tuning requested but the pretrained vocab is missing.
    """
    if is_finetune and not PRETRAINED_VOCAB_PATH.exists():
        # FIX: explicit raise instead of `assert` -- asserts are stripped under
        # `python -O`, silently skipping this validation.
        raise FileNotFoundError(f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}")
    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
    save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)


def get_args():
    """Build the CLI parser and parse sys.argv into a namespace."""
    arg_parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    # Positional arguments: input CSV and output dataset directory.
    arg_parser.add_argument(
        "inp_dir",
        type=str,
        help="Input CSV with header 'audio_file|text' and absolute wav paths.",
    )
    arg_parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
    # Optional flags.
    arg_parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
    arg_parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
    return arg_parser.parse_args()


def cli():
    """Command-line entry point; handles Ctrl-C by cancelling outstanding work."""
    try:
        cli_args = get_args()
        prepare_and_save_set(
            cli_args.inp_dir,
            cli_args.out_dir,
            is_finetune=not cli_args.pretrain,
            num_workers=cli_args.workers,
        )
    except KeyboardInterrupt:
        print("\nOperation cancelled by user. Cleaning up...")
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)


# Allow running this module directly as a script.
if __name__ == "__main__":
    cli()
