import os
import time

import requests
from sarvamai import SarvamAI

# The API key is read from the environment so the secret never lives in
# source control. Export SARVAM_API_KEY before running this script.
# NOTE(review): the key that used to be hard-coded on this line has been
# exposed in the file history and should be revoked/rotated.
api_key = os.environ.get("SARVAM_API_KEY")
if not api_key:
    raise SystemExit("Set the SARVAM_API_KEY environment variable before running.")

client = SarvamAI(api_subscription_key=api_key)

# Create a fresh speech-to-text job with diarization and timestamps enabled.
# `job.job_id` identifies the job in every later call in this script.
job = client.speech_to_text_job.create_job(
    model="saaras:v3",
    language_code="te-IN",
    with_diarization=True,
    with_timestamps=True,
)
print(f"Job: {job.job_id}")

# Request upload links for the audio file, then PUT the local copy to each
# returned URL. The job must not be started unless every upload succeeded.
audio_name = "lalal_voice_16k.wav"           # name registered with the job
audio_path = "test_mIp/lalal_voice_16k.wav"  # local file that gets uploaded

links = client.speech_to_text_job.get_upload_links(job_id=job.job_id, files=[audio_name])
print(f"upload_urls type: {type(links.upload_urls)}")

# `upload_urls` may be a dict (name -> entry) or a list of entries; normalise
# both shapes to (name-or-None, entry) pairs so a single loop handles them.
urls = links.upload_urls
if isinstance(urls, dict):
    targets = list(urls.items())
elif isinstance(urls, list):
    targets = [(None, obj) for obj in urls]
else:
    raise TypeError(f"Unexpected upload_urls type: {type(urls)!r}")

if not targets:
    # Starting the job with nothing uploaded would only fail later and less clearly.
    raise RuntimeError("No upload URLs were returned for the job")

for fname, obj in targets:
    # Entries either expose the URL as `.file_url` or are used as the URL itself.
    url = obj.file_url if hasattr(obj, 'file_url') else str(obj)
    if fname is not None:
        print(f"Uploading {fname} -> {url[:80]}...")
    else:
        print(f"Uploading -> {url[:80]}...")
    with open(audio_path, "rb") as f:
        # The x-ms-blob-type header is sent with every upload; presumably the
        # links are pre-signed Azure Blob Storage URLs — TODO confirm.
        resp = requests.put(
            url,
            data=f,
            headers={"x-ms-blob-type": "BlockBlob"},
            timeout=300,  # do not hang forever on a stalled connection
        )
    print(f"Upload: {resp.status_code}")
    # Abort on an HTTP error instead of starting a job without its input.
    resp.raise_for_status()

# Start the job now that its input has been uploaded.
result = client.speech_to_text_job.start(job.job_id)
print(f"Started: {result.job_state}")

# Poll every 10 seconds, at most 120 times (about 20 minutes), until the job
# reaches a terminal state.
for i in range(120):
    time.sleep(10)
    status = client.speech_to_text_job.get_status(job_id=job.job_id)
    print(f"[{i+1}] {status.job_state}")

    if status.job_state == "Failed":
        # Do not request download links for a job that has failed.
        raise RuntimeError(f"Job {job.job_id} failed; not attempting to download output")
    if status.job_state != "Completed":
        continue

    dl = client.speech_to_text_job.get_download_links(job_id=job.job_id, files=["0.json"])
    # `download_urls` may be a dict or a list; only the entries are needed.
    dl_urls = dl.download_urls
    if isinstance(dl_urls, dict):
        dl_urls = list(dl_urls.values())

    for idx, d in enumerate(dl_urls):
        d_url = d.file_url if hasattr(d, 'file_url') else str(d)
        r = requests.get(d_url, timeout=60)
        # Never save an HTTP error page as if it were the transcript.
        r.raise_for_status()

        # The first output keeps the original file name; any further outputs
        # get an index suffix so they do not overwrite each other.
        if idx == 0:
            fname = "test_mIp/sarvam_diarization.json"
        else:
            fname = f"test_mIp/sarvam_diarization_{idx}.json"

        # Write the raw bytes so the saved file matches the download exactly,
        # independent of the platform's default text encoding.
        with open(fname, "wb") as f:
            f.write(r.content)

        # Decode only for the console preview; JSON is expected to be UTF-8.
        text = r.content.decode("utf-8", errors="replace")
        print(f"Saved: {fname} ({len(text)} chars)")
        print(text[:5000])
    break
else:
    # The loop ran out of attempts without seeing a terminal state.
    raise TimeoutError(f"Job {job.job_id} did not finish within the polling window")
