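"""Single benchmark worker: load VibeVoice, generate one batch, write timing to a file.

Usage: python <this script> WORKER_ID OUT_FILE

After loading the model and running one untimed warmup generation, the worker
creates /tmp/worker_<WORKER_ID>_ready and then polls until /tmp/parallel_go
exists (presumably created by a coordinating process so several workers start
together). It then times a single generation and writes
"<generation_seconds>,<audio_seconds>" to OUT_FILE.
"""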
import sys, time, torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} WORKER_ID OUT_FILE')
worker_id = int(sys.argv[1])  # used to name this worker's ready-marker file
out_file = sys.argv[2]        # receives "<gen_seconds>,<audio_seconds>"

MODEL_ID = 'microsoft/VibeVoice-1.5B'

processor = VibeVoiceProcessor.from_pretrained(MODEL_ID)


def _load_model(attn_implementation):
    """Load the VibeVoice inference model in bfloat16 on CUDA.

    Args:
        attn_implementation: attention backend name forwarded to
            ``from_pretrained`` (e.g. 'flash_attention_2' or 'sdpa').
    """
    return VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map='cuda',
        attn_implementation=attn_implementation)


try:
    model = _load_model('flash_attention_2')
except Exception as err:
    # Prefer FlashAttention-2 and fall back to SDPA when it cannot be used
    # (presumably flash_attn missing or unsupported on this GPU). Catching
    # Exception instead of a bare except keeps Ctrl-C / SystemExit working,
    # and the message records why the faster backend was not used.
    print(f'worker {worker_id}: flash_attention_2 load failed ({err!r}); '
          f'falling back to sdpa', file=sys.stderr)
    model = _load_model('sdpa')
model.eval()
model.set_ddpm_inference_steps(num_steps=10)

voice = 'demo/voices/modi.wav'
text = 'Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.'

# Batch of one: a single script conditioned on a single reference voice sample.
inp = processor(text=[text], voice_samples=[[voice]], padding=True, return_tensors='pt', return_attention_mask=True)
for k, v in inp.items():
    # Overwriting existing keys while iterating is safe: the dict size is unchanged.
    if torch.is_tensor(v):
        inp[k] = v.to('cuda')


def _generate():
    """Run one non-sampling (do_sample=False) generation over ``inp``.

    Shared by the warmup and the timed run so both use identical settings;
    a fresh ``generation_config`` dict is built on every call.
    """
    return model.generate(
        **inp, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
        generation_config={'do_sample': False}, verbose=False, is_prefill=True,
        show_progress_bar=False)


# Warmup: one untimed run so one-off startup costs are presumably excluded
# from the measurement below.
_ = _generate()

# Signal ready
with open(f'/tmp/worker_{worker_id}_ready', 'w') as f:
    f.write('ready')

# Wait for go: poll until the go-file exists. The handle is closed immediately;
# only a missing file is retried, so Ctrl-C and real I/O errors propagate.
while True:
    try:
        with open('/tmp/parallel_go', 'r'):
            break
    except FileNotFoundError:
        time.sleep(0.01)

# Synchronize on both sides so the wall-clock interval covers all queued GPU work.
torch.cuda.synchronize()
t0 = time.perf_counter()
out = _generate()
torch.cuda.synchronize()
gen = time.perf_counter() - t0

# Assumed output sample rate of the generated audio (value taken from the
# original code) — TODO confirm against the model/processor config.
SAMPLE_RATE_HZ = 24000.0
audio_dur = sum(a.shape[-1] / SAMPLE_RATE_HZ
                for a in (out.speech_outputs or []) if a is not None)
with open(out_file, 'w') as f:
    f.write(f'{gen},{audio_dur}')
