import soundfile as sf
import torchaudio
import torch
from voxcpm import VoxCPM

ref_audio_full = "/home/ubuntu/modi_mann_ki_baat_2271149762.mp3"
ref_audio_clip = "/home/ubuntu/modi_clip_10s.wav"

print("Step 1: Extracting 10s clip (skipping first 15s of intro)...")
# Load the full reference recording; torchaudio returns (channels, samples).
waveform, sr = torchaudio.load(ref_audio_full)
# Downmix to mono so the saved clip is single-channel regardless of source.
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
start_sec, end_sec = 15, 25
# Python slicing clamps silently, so an audio file shorter than start_sec
# would yield an empty clip and fail confusingly downstream (ASR/cloning).
# Fail fast with a clear message instead.
if waveform.shape[1] <= sr * start_sec:
    raise ValueError(
        f"Reference audio is only {waveform.shape[1] / sr:.2f}s long; "
        f"cannot extract a clip starting at {start_sec}s"
    )
waveform_clip = waveform[:, sr * start_sec : sr * end_sec]
# Resample to 16 kHz — the rate Whisper expects for transcription below.
if sr != 16000:
    resampler = torchaudio.transforms.Resample(sr, 16000)
    waveform_clip = resampler(waveform_clip)
torchaudio.save(ref_audio_clip, waveform_clip, 16000)
print(f"Saved clip ({start_sec}s-{end_sec}s) to {ref_audio_clip}")

print("\nStep 2: Transcribing with OpenAI Whisper...")
import whisper

# Fall back to CPU so the script still runs on machines without a GPU
# (hard-coding "cuda" raises if CUDA is unavailable).
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisper.load_model("base", device=device)
# Force Hindi decoding; language auto-detection on a 10s clip is unreliable.
result = whisper_model.transcribe(ref_audio_clip, language="hi")
prompt_text = result["text"].strip()
print(f"Transcription: {prompt_text}")
# Release the ASR model before loading the TTS model to free GPU memory.
del whisper_model
if device == "cuda":
    torch.cuda.empty_cache()

print("\nStep 3: Loading VoxCPM1.5...")
# Zero-shot voice-cloning TTS; weights are fetched from the Hugging Face hub.
model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
print("Model loaded!")

# Target sentence to synthesize in the cloned voice (Hindi, Devanagari script).
hindi_text = "नमस्कार, मैं आज आपसे बात करना चाहता हूँ। भारत एक महान देश है और हम सब मिलकर इसे और भी आगे ले जा सकते हैं। आइए, हम सब मिलकर एक नए भारत का निर्माण करें।"

# Plain string: the original had an f-prefix with no placeholders (ruff F541);
# the printed output is byte-identical.
print("\nGenerating Hindi speech with voice clone...")
print(f"Prompt text: {prompt_text}")
print(f"Target text: {hindi_text}")

# Synthesize the target text in the reference speaker's voice.
# Settings gathered into one dict so they are easy to scan and tweak.
generation_kwargs = {
    "text": hindi_text,
    "prompt_wav_path": ref_audio_clip,  # 16 kHz reference clip saved in Step 1
    "prompt_text": prompt_text,         # Whisper transcription of that clip
    "cfg_value": 2.0,                   # presumably CFG guidance strength — confirm in VoxCPM docs
    "inference_timesteps": 10,
    "normalize": False,                 # skip text normalization (raw Devanagari input)
    "denoise": False,
    "retry_badcase": True,              # NOTE(review): looks like auto-retry on bad outputs — verify
    "retry_badcase_max_times": 3,
    "retry_badcase_ratio_threshold": 6.0,
}
wav = model.generate(**generation_kwargs)

# Persist the generated waveform and report basic stats about it.
output_path = "/home/ubuntu/voxcpm_hindi_clone_output.wav"
sample_rate = model.tts_model.sample_rate  # hoisted: used three times below
sf.write(output_path, wav, sample_rate)
print(f"\nSaved to: {output_path}")
print(f"Sample rate: {sample_rate}")
print(f"Duration: {len(wav) / sample_rate:.2f}s")
