#!/usr/bin/env python3
"""Train a SentencePiece BPE tokenizer from a NeMo JSONL manifest.

Extracts text from the manifest, trains a SentencePiece model, and writes
the model + vocab + metadata to the output directory.

Usage:
  # Smoke (small vocab for testing)
  python3 scripts/build_tokenizer.py \
    --manifest data/manifests/smoke_train.jsonl \
    --output-dir tokenizers/smoke_bpe \
    --vocab-size 512

  # Production
  python3 scripts/build_tokenizer.py \
    --manifest data/manifests/train.jsonl \
    --output-dir tokenizers/indic_8k \
    --vocab-size 8192 --character-coverage 0.9999
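
The resulting tokenizer directory can then be referenced from a NeMo ASR BPE
config (the field names below assume NeMo's standard tokenizer setup):

  model:
    tokenizer:
      dir: tokenizers/indic_8k
      type: bpe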
"""

import argparse
import json
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path

import sentencepiece as spm

# Add src to path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import file_sha256


def main():
    parser = argparse.ArgumentParser(description="Train SentencePiece BPE tokenizer from manifest")
    parser.add_argument("--manifest", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--vocab-size", type=int, default=512)
    parser.add_argument(
        "--model-type",
        default="bpe",
        choices=["bpe", "unigram"],
    )
    parser.add_argument("--character-coverage", type=float, default=0.9995)
    parser.add_argument(
        "--max-sentences",
        type=int,
        default=0,
        help="Max sentences to use for training (0 = use all; lower values are handy for smoke runs)",
    )
    args = parser.parse_args()

    if not args.manifest.exists():
        print(f"ERROR: Manifest not found: {args.manifest}", file=sys.stderr)
        sys.exit(1)

    # Extract text from manifest
    print(f"Reading text from {args.manifest}...")
    texts = []
    empty_count = 0
    with open(args.manifest, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            text = row.get("text", "")
            if not isinstance(text, str) or not text.strip():
                empty_count += 1
                continue
            texts.append(text.strip())

    if not texts:
        print(
            "ERROR: No valid text found in manifest. Check that rows have non-empty 'text' field.",
            file=sys.stderr,
        )
        sys.exit(1)

    if empty_count > 0:
        print(f"  Skipped {empty_count} rows with empty/missing text")
    print(f"  Extracted {len(texts):,} text samples")

    if args.max_sentences > 0:
        texts = texts[: args.max_sentences]
        print(f"  Truncated to {len(texts):,} sentences (--max-sentences)")

    # Write text to temp file for SentencePiece
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as tmp:
        for text in texts:
            tmp.write(text + "\n")
        tmp_path = tmp.name

    # Train SentencePiece model
    args.output_dir.mkdir(parents=True, exist_ok=True)
    model_prefix = str(args.output_dir / "tokenizer")

    print(f"Training {args.model_type} tokenizer (vocab_size={args.vocab_size})...")
    try:
        spm.SentencePieceTrainer.train(
            input=tmp_path,
            model_prefix=model_prefix,
            vocab_size=args.vocab_size,
            model_type=args.model_type,
            character_coverage=args.character_coverage,
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            max_sentence_length=16384,
            num_threads=4,
        )
    finally:
        # Remove the temporary training text even if training fails.
        Path(tmp_path).unlink()

    # Verify output files
    model_file = args.output_dir / "tokenizer.model"
    vocab_file = args.output_dir / "tokenizer.vocab"
    if not model_file.exists():
        print(f"ERROR: Model file not created: {model_file}", file=sys.stderr)
        sys.exit(1)

    # Load and verify
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))
    actual_vocab = sp.get_piece_size()

    # Write NeMo-compatible vocab.txt (one token per line, required by NeMo BPE)
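    # With the id assignments used at training time, the first four entries
    # should be the special pieces <pad>, <unk>, <s>, </s>.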
    nemo_vocab_file = args.output_dir / "vocab.txt"
    with open(nemo_vocab_file, "w") as f:
        for i in range(actual_vocab):
            f.write(sp.id_to_piece(i) + "\n")

    # Write metadata
    metadata = {
        "vocab_size_requested": args.vocab_size,
        "vocab_size": actual_vocab,
        "model_type": args.model_type,
        "character_coverage": args.character_coverage,
        "source_manifest": str(args.manifest.resolve()),
        "source_manifest_sha256": file_sha256(args.manifest),
        "num_training_sentences": len(texts),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    metadata_file = args.output_dir / "metadata.json"
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)

    # Summary
    print("\nTokenizer trained successfully:")
    print(f"  Model:    {model_file}")
    print(f"  Vocab:    {vocab_file}")
    print(f"  Metadata: {metadata_file}")
    print(f"  Vocab size: {actual_vocab}")

    # Quick encode test
    sample = texts[0][:80]
    tokens = sp.encode(sample, out_type=str)
    print(f"\n  Sample encode: {sample!r}")
    print(f"  Tokens ({len(tokens)}): {tokens[:15]}{'...' if len(tokens) > 15 else ''}")


if __name__ == "__main__":
    main()
