"""Fine-tune VibeVoice-Realtime-0.5B on Hindi. No semantic tokenizer needed."""
import os, logging, torch, torch.nn.functional as F, numpy as np
from torch.utils.data import Dataset
from dataclasses import dataclass
from datasets import load_dataset
from transformers import HfArgumentParser, Trainer, set_seed, TrainingArguments
from peft import LoraConfig, get_peft_model
from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
try:
    import librosa
except ImportError:
    librosa = None

logger = logging.getLogger(__name__)

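# Every field below doubles as a CLI flag via HfArgumentParser,
# e.g. `--lora_r 16 --lr 1e-5` overrides the defaults.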
@dataclass
class Args:
    model_path: str = "microsoft/VibeVoice-Realtime-0.5B"
    train_jsonl: str = "data/hindi_soprano.jsonl"
    output_dir: str = "output/streaming_hindi"
    lora_r: int = 8
    lora_alpha: int = 32
    train_diffusion: bool = True
    num_epochs: int = 3
    batch_size: int = 2
    grad_accum: int = 8
    lr: float = 2.5e-5
    save_steps: int = 500

def load_audio_24k(path):
    """Load `path` as mono float32 audio at 24 kHz (the acoustic tokenizer's rate)."""
    if librosa is not None:
        # librosa resamples to the target rate on load, so no separate resample step
        wav, _ = librosa.load(path, sr=24000, mono=True)
        return wav.astype(np.float32)
    import soundfile as sf
    wav, sr = sf.read(path, dtype='float32')
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # downmix multi-channel to mono
    if sr != 24000:
        import resampy
        wav = resampy.resample(wav.astype(np.float32), sr, 24000)
    return wav.astype(np.float32)

class DS(Dataset):
    """Thin wrapper exposing (text, audio-path) pairs from a HF json dataset."""
    def __init__(self, hf_ds):
        self.ds = hf_ds
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, i):
        row = self.ds[i]
        return {"text": row["text"], "audio": row["audio"]}

class Collator:
    """Tokenize texts and zero-pad raw waveforms into a dense batch."""
    def __init__(self, processor, model):
        self.processor = processor
        self.model = model
    def __call__(self, feats):
        texts, wavs = [], []
        for f in feats:
            texts.append(f["text"])
            wavs.append(torch.from_numpy(load_audio_24k(f["audio"])).float())
        proc = self.processor(text=texts, padding=True, truncation=True, return_tensors="pt")
        max_wav = max(w.shape[0] for w in wavs)
        wav_batch = torch.zeros(len(wavs), max_wav)
        wav_lens = torch.zeros(len(wavs), dtype=torch.long)  # true (unpadded) sample counts
        for i, w in enumerate(wavs):
            wav_batch[i, :w.shape[0]] = w
            wav_lens[i] = w.shape[0]
        return {**proc, "wav_batch": wav_batch, "wav_lens": wav_lens}

class MyTrainer(Trainer):
    """Trainer with a joint objective: text CE plus diffusion-noise MSE."""
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        wav_batch = inputs.pop("wav_batch").to(model.device)
        # wav_lens is carried through but not yet used: padded tails also enter the
        # diffusion loss, which slightly biases it toward silence on short clips.
        wav_lens = inputs.pop("wav_lens").to(model.device)
        input_ids = inputs["input_ids"].to(model.device)
        attn = inputs["attention_mask"].to(model.device)
        # Text encoding through the same module the LoRA adapters wrap,
        # so they receive gradient from both loss terms
        h = model.model.tts_language_model(input_ids=input_ids, attention_mask=attn).last_hidden_state
        # CE loss: next-token prediction, with padding masked via the attention mask
        logits = model.lm_head(h)
        labels = input_ids.masked_fill(attn == 0, -100)
        ce = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)),
                             labels[:, 1:].reshape(-1), ignore_index=-100)
        # Audio latents from the frozen acoustic tokenizer (targets only, no grad)
        with torch.no_grad():
            ac = model.model.acoustic_tokenizer
            lat = ac.encode(wav_batch.unsqueeze(1) if wav_batch.ndim == 2 else wav_batch)
            if not torch.is_tensor(lat):
                lat = lat.mean  # VAE-style output object: take the posterior mean
            sf_v = model.model.speech_scaling_factor
            bf_v = model.model.speech_bias_factor
            if not torch.isnan(sf_v):  # normalize with the model's stored latent statistics
                lat = (lat + bf_v) * sf_v
        # Diffusion loss: add noise to each latent frame and predict that noise
        B, T, D = lat.shape
        noise = torch.randn_like(lat)
        ts = torch.randint(0, model.model.noise_scheduler.config.num_train_timesteps,
                           (B * T,), device=lat.device)
        flat_lat = lat.reshape(B * T, D)
        flat_noise = noise.reshape(B * T, D)
        noisy = model.model.noise_scheduler.add_noise(flat_lat, flat_noise, ts)
        # Every frame is conditioned on the final text hidden state
        cond = h[:, -1:, :].expand(-1, T, -1).reshape(B * T, -1)
        pred = model.model.prediction_head(noisy, ts, condition=cond)
        diff = F.mse_loss(pred, flat_noise)
        loss = 0.04 * ce + 1.4 * diff  # weights balance the two terms' typical magnitudes
        if self.state.global_step % 10 == 0:
            print(f"  step {self.state.global_step}: ce={ce.item():.2f} diff={diff.item():.4f} loss={loss.item():.4f}")
        return (loss, {"ce": ce, "diff": diff}) if return_outputs else loss

def main():
    parser = HfArgumentParser((Args,))
    args = parser.parse_args_into_dataclasses()[0]
    set_seed(42)
    logging.basicConfig(level=logging.INFO)
    print(f"Loading 0.5B from {args.model_path}...")
    model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa")
    proc = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)
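    # Freeze the acoustic tokenizer: it only produces latent targets for the
    # diffusion loss and is never updated.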
    for p in model.model.acoustic_tokenizer.parameters(): p.requires_grad = False
    lora_cfg = LoraConfig(r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=0.05,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], bias="none")
    model.model.tts_language_model = get_peft_model(model.model.tts_language_model, lora_cfg)
    model.model.tts_language_model.print_trainable_parameters()
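    # gradient_checkpointing=True recomputes activations in backward; without this
    # call the embedding outputs don't require grad and the LoRA adapters receive
    # no gradient (enable_input_require_grads is the standard transformers helper).
    model.model.tts_language_model.enable_input_require_grads()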
    if args.train_diffusion:
        for p in model.model.prediction_head.parameters(): p.requires_grad = True
        print(f"Diff head trainable: {sum(p.numel() for p in model.model.prediction_head.parameters() if p.requires_grad)/1e6:.0f}M")
    raw = load_dataset("json", data_files={"train": args.train_jsonl}, split="train")
    ds = DS(raw)
    print(f"Dataset: {len(ds)} samples")
    ta = TrainingArguments(
        output_dir=args.output_dir, num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr, bf16=True, logging_steps=10, save_steps=args.save_steps,
        save_total_limit=3, remove_unused_columns=False, report_to="wandb",
        run_name="vibevoice-05b-hindi", lr_scheduler_type="cosine", warmup_ratio=0.03,
        gradient_checkpointing=True, do_train=True)
    trainer = MyTrainer(model=model, args=ta, train_dataset=ds, data_collator=Collator(proc, model))
    trainer.train()
    out = os.path.join(args.output_dir, "lora")
    os.makedirs(out, exist_ok=True)
    model.model.tts_language_model.save_pretrained(out)
    if args.train_diffusion:
        torch.save(model.model.prediction_head.state_dict(), os.path.join(out, "diffusion_head_full.bin"))
    print(f"Saved to {out}")

if __name__ == "__main__":
    main()
