#!/usr/bin/env python3
"""
Smoke test: loads from config.yaml, constructs Trainer through setup_model/setup_data,
runs one forward/backward step, verifies decoder alignment via the real helper,
and runs one generate call.

Usage:
    python smoke_test.py --config config.yaml
"""

import argparse
import os
import sys

import torch
import numpy as np

os.environ["HF_HOME"] = "/workspace/.hf_home"

# ── Imports from the actual training code ───────────────────────────────
sys.path.insert(0, os.path.dirname(__file__))
from train import TrainConfig, Trainer, load_config
from tokenizer_utils import decode_tokens


def test_trainer_setup(config):
    """Construct Trainer through setup_model() — catches checkpointing, loading, etc."""
    print("=" * 60)
    print("TEST 1: Trainer.setup_model()")
    print("=" * 60)

    # Strip any distributed-launch env vars so the Trainer initializes in
    # plain single-GPU mode instead of trying to join a process group.
    distributed_vars = ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
    for name in distributed_vars:
        if name in os.environ:
            del os.environ[name]

    trainer = Trainer(config)
    trainer.setup_model()
    model = trainer.raw_model
    print(f"  Model loaded: {type(model).__name__}")
    print(f"  Vocab size: {model.config.vocab_size}")
    print(f"  build_prompt available: {hasattr(model, 'build_prompt')}")
    return trainer


def test_decoder_alignment(trainer):
    """Verify decoder_input_ids/labels alignment using the REAL _build_decoder_inputs."""
    print("\n" + "=" * 60)
    print("TEST 2: Decoder alignment (via real _build_decoder_inputs)")
    print("=" * 60)

    # Populate per-language prompt IDs on the trainer (normally done in setup_data).
    tokenizer = trainer.processor.tokenizer
    trainer.tokenizer = tokenizer
    languages = ["en", "hi", "te", "ta", "ml", "bn", "gu", "kn", "pa", "mr", "or", "as"]
    trainer.lang_prompt_ids = {
        lang: tokenizer.encode(
            trainer.raw_model.build_prompt(language=lang, punctuation=True),
            add_special_tokens=False,
        )
        for lang in languages
    }

    eos_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    prompt_len = len(trainer.lang_prompt_ids["hi"])

    # Fabricate a 2-sample batch with identical Hindi transcripts.
    transcript = "नमस्ते दुनिया"
    transcript_ids = tokenizer.encode(transcript, add_special_tokens=False)
    n_tok = len(transcript_ids)

    max_tokens = max(n_tok, 20)
    labels_tensor = torch.full((2, max_tokens), -100, dtype=torch.long)
    for row in range(2):
        labels_tensor[row, :n_tok] = torch.tensor(transcript_ids)

    batch = {
        'labels': labels_tensor,
        'token_lengths': torch.tensor([n_tok, n_tok]),
        'languages': ['hi', 'hi'],
    }

    # Exercise the real production method, not a re-implementation.
    decoder_input_ids, labels, decoder_attn_mask = trainer._build_decoder_inputs(batch)

    did = decoder_input_ids[0].tolist()
    lab = labels[0].tolist()

    # Teacher-forcing invariant: at every supervised position i,
    # labels[i] must equal decoder_input_ids[i+1].
    errors = []
    print(f"  prompt_len={prompt_len}, n_tok={n_tok}, eos={eos_id}, pad={pad_id}")
    print(f"  Checking labels[{prompt_len-1}..{prompt_len+n_tok-1}]:")

    for i in range(prompt_len - 1, prompt_len + n_tok):
        next_input = did[i + 1] if i + 1 < len(did) else -1
        if lab[i] == -100:
            continue
        if lab[i] == eos_id:
            # EOS prediction — next input is pad, that's expected
            print(f"    pos {i}: label=EOS({eos_id}) — OK (final prediction)")
        elif lab[i] == next_input:
            print(f"    pos {i}: label={lab[i]} == input[{i+1}]={next_input} — OK")
        else:
            errors.append(f"pos {i}: labels={lab[i]} != input[{i+1}]={next_input}")
            print(f"    pos {i}: label={lab[i]} != input[{i+1}]={next_input} — MISMATCH")

    if not errors:
        print("\n  PASSED — alignment correct")
        return True
    print(f"\n  FAILED — {len(errors)} mismatches")
    return False


def test_forward_backward(trainer):
    """One forward/backward step with synthetic data on the real Trainer.

    Builds a random mel batch and a fixed 5-token transcript, runs the real
    Trainer.train_step(), and fails only if the reported loss is NaN/Inf.

    Returns:
        bool: True on pass (or when skipped because CUDA is unavailable),
        False if the loss is non-finite.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Forward/backward pass")
    print("=" * 60)

    if not torch.cuda.is_available():
        print("  SKIPPED — no CUDA")
        return True

    # Synthetic batch: B samples of T mel frames with identical dummy labels.
    B, T = 2, 1000
    transcript_ids = [100, 200, 300, 400, 500]
    n_tok = len(transcript_ids)
    max_tokens = 20
    labels_tensor = torch.full((B, max_tokens), -100, dtype=torch.long)
    # Broadcast the transcript into every row (replaces the per-row loop).
    labels_tensor[:, :n_tok] = torch.tensor(transcript_ids)

    batch = {
        'mel': torch.randn(B, 128, T, dtype=torch.float16),
        # Second sample is half-length to exercise variable-length handling.
        'mel_lengths': torch.tensor([T, T // 2]),
        'labels': labels_tensor,
        'token_lengths': torch.tensor([n_tok, n_tok]),
        'languages': ['hi', 'hi'],
    }

    # Run the real train_step with a throwaway optimizer and SpecAugment off.
    trainer.model.train()
    trainer.optimizer = torch.optim.AdamW(trainer.raw_model.parameters(), lr=1e-5)
    trainer.optimizer.zero_grad()
    trainer.spec_augment = None

    metrics = trainer.train_step(batch)

    print(f"  loss={metrics['loss']:.4f}, batch_size={metrics['batch_size']}")
    # np.isfinite is False for both NaN and Inf — equivalent to the
    # previous isnan-or-isinf check.
    if not np.isfinite(metrics['loss']):
        print("  FAILED — NaN/Inf loss")
        return False
    print("  PASSED")
    return True


def test_generate(trainer):
    """One generate call on synthetic mel."""
    print("\n" + "=" * 60)
    print("TEST 4: Generate call")
    print("=" * 60)

    if not torch.cuda.is_available():
        print("  SKIPPED — no CUDA")
        return True

    device = trainer.device
    model = trainer.raw_model
    tokenizer = trainer.tokenizer
    model.eval()

    # Random 500-frame mel; the Hindi prompt seeds the decoder.
    n_frames = 500
    mel = torch.randn(1, 128, n_frames, dtype=torch.bfloat16).to(device)
    mel_length = torch.tensor([n_frames], dtype=torch.long).to(device)

    decoder_input_ids = torch.tensor(
        [trainer.lang_prompt_ids["hi"]], dtype=torch.long
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_features=mel,
            length=mel_length,
            decoder_input_ids=decoder_input_ids,
            max_new_tokens=20,
        )

    text = decode_tokens(tokenizer, outputs[0])
    print(f"  Generated {outputs.shape[1]} tokens: {text[:100]}...")
    print("  PASSED")
    return True


def main():
    """Run every smoke test in order and return a process exit code (0 = all pass)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.yaml")
    args = parser.parse_args()

    config = load_config(args.config)
    print(f"Config loaded: model_name={config.model_name}")

    trainer = test_trainer_setup(config)

    # Dict literal evaluates entries in order, so the tests still run sequentially.
    results = {
        "trainer_setup": trainer is not None,
        "alignment": test_decoder_alignment(trainer),
        "forward_backward": test_forward_backward(trainer),
        "generate": test_generate(trainer),
    }

    print("\n" + "=" * 60)
    print("SMOKE TEST SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  {name}: {status}")

    if all(results.values()):
        print("\nOverall: ALL PASS")
        return 0
    print("\nOverall: SOME FAILURES")
    return 1


# Propagate the aggregate pass/fail result as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
