Qwen3-TTS - vLLM-Omni Skip to content # Qwen3-TTS¶ Source [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/qwen3_tts](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/qwen3_tts). This directory contains an offline demo for running Qwen3 TTS models with vLLM Omni. It builds task-specific inputs and generates WAV files locally. ## Model Overview¶ Qwen3 TTS provides multiple task variants for speech generation: - Base: Voice cloning using a reference audio + reference transcript, with optional mode selection. - VoiceDesign: Generate speech from text plus a descriptive instruction that designs a new voice. - CustomVoice: Generate speech with a known speaker identity (speaker ID) and optional instruction. - Example materials ## Setup¶ Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. ### ROCm Dependencies¶ You will need to install these two dependencies`onnxruntime-rocm` and`sox`. ``` pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm pip install onnxruntime-rocm sox ``` ## Quick Start¶ Run a single sample for a task: ``` python end2end.py --query-type CustomVoice ``` Generated audio files are saved to`output_audio/` by default. ## Task Usage¶ ### CustomVoice¶ Single sample: ``` python end2end.py --query-type CustomVoice ``` Batch sample (multiple prompts in one run): ``` python end2end.py --query-type CustomVoice --use-batch-sample ``` ### VoiceDesign¶ Single sample: ``` python end2end.py --query-type VoiceDesign ``` Batch sample: ``` python end2end.py --query-type VoiceDesign --use-batch-sample ``` ### Base (Voice Clone)¶ Single sample: ``` python end2end.py --query-type Base ``` Batch sample: ``` python end2end.py --query-type Base --use-batch-sample ``` Mode selection for Base: - `--mode-tag xvec_only`: enable`x_vector_only_mode` in the request - `--mode-tag icl`(default): standard mode Examples: ``` python end2end.py --query-type Base --mode-tag icl ``` ## Streaming Mode¶ Add`--streaming` to stream audio chunks progressively via`AsyncOmni`(requires`async_chunk: true` in the stage config): ``` python end2end.py --query-type CustomVoice --streaming --output-dir /tmp/out_stream ``` Each Code2Wav chunk is logged as it arrives (default 25 frames; configurable via`codec_chunk_frames` in the stage config). The initial chunk size is dynamically selected based on server load for reduced TTFA, and can be overridden per-request via the`initial_codec_chunk_frames` API field. The final WAV file is written once generation completes. This demonstrates that audio data is available progressively rather than only at the end. Note: Streaming uses`AsyncOmni` internally. The non-streaming path (`Omni`) is unchanged. ## Batched Decoding¶ The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with`max_batch_size > 1` and pass multiple prompts via`--txt-prompts` with a matching`--batch-size`. ``` python end2end.py --query-type CustomVoice \ --txt-prompts benchmark_prompts.txt \ --batch-size 4 \ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml ``` Important:`--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to`max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need`max_batch_size >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher`max_batch_size`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently. ## Notes¶ - Use`--output-dir` to change the output folder. - The script uses the model paths embedded in`end2end.py`. Update them if your local cache path differs. ## Example materials¶ benchmark_prompts.txt ``` Hello, welcome to the voice synthesis benchmark test. She said she would be here by noon, but nobody showed up. The quick brown fox jumps over the lazy dog near the riverbank. I can't believe how beautiful the sunset looks from up here on the mountain. Please remember to bring your identification documents to the appointment tomorrow morning. Have you ever wondered what it would be like to travel through time and visit ancient civilizations? The restaurant on the corner serves the best pasta I have ever tasted in my entire life. After the meeting, we should discuss the quarterly results and plan for the next phase. Learning a new language takes patience, practice, and a genuine curiosity about other cultures. The train leaves at half past seven, so we need to arrive at the station before then. Could you please turn down the music a little bit, I'm trying to concentrate on my work. It was a dark and stormy night when the old lighthouse keeper heard a knock at the door. ``` end2end.py ``` """Offline inference demo for Qwen3 TTS via vLLM Omni. Provides single and batch sample inputs for CustomVoice, VoiceDesign, and Base tasks, then runs Omni generation and saves output wav files. """ import asyncio import logging import os import time from typing import Any, NamedTuple import soundfile as sf import torch os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm_omni import AsyncOmni, Omni logger = logging.getLogger(__name__) class QueryResult(NamedTuple): """Container for a prepared Omni request.""" inputs: dict model_name: str def _estimate_prompt_len( additional_information: dict[str, Any], model_name: str, _cache: dict[str, Any] = {}, ) -> int: """Estimate prompt_token_ids placeholder length for the Talker stage. The AR Talker replaces all input embeddings via ``preprocess``, so the placeholder values are irrelevant but the **length** must match the embeddings that ``preprocess`` will produce. """ try: from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( Qwen3TTSTalkerForConditionalGeneration, ) if model_name not in _cache: from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) tok, tcfg = _cache[model_name] task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( additional_information=additional_information, task_type=task_type, tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], codec_language_id=getattr(tcfg, "codec_language_id", None), spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), ) except Exception as exc: logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) return 2048 def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: """Build CustomVoice sample inputs. Args: use_batch_sample: When True, return a batch of prompts; otherwise a single prompt. Returns: QueryResult with Omni inputs and the CustomVoice model path. """ task_type = "CustomVoice" model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" if use_batch_sample: texts = [ "其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon.", "I like you very much.", "Really, you do?", "Yes, absolutely.", ] instructs = ["", "Very happy.", "Very happy.", "Very happy.", "Very happy."] languages = ["Chinese", "English", "English", "English", "English"] speakers = ["Vivian", "Ryan", "Ryan", "Ryan", "Ryan"] inputs = [] for text, instruct, language, speaker in zip(texts, instructs, languages, speakers): additional_information = { "task_type": [task_type], "text": [text], "instruct": [instruct], "language": [language], "speaker": [speaker], "max_new_tokens": [2048], } inputs.append( { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } ) else: text = "其实我真的有发现,我是一个特别善于观察别人情绪的人。" language = "Chinese" speaker = "Vivian" instruct = "用特别愤怒的语气说" additional_information = { "task_type": [task_type], "text": [text], "language": [language], "speaker": [speaker], "instruct": [instruct], "max_new_tokens": [2048], } inputs = { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } return QueryResult( inputs=inputs, model_name=model_name, ) def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: """Build VoiceDesign sample inputs. Args: use_batch_sample: When True, return a batch of prompts; otherwise a single prompt. Returns: QueryResult with Omni inputs and the VoiceDesign model path. """ task_type = "VoiceDesign" model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" if use_batch_sample: texts = [ "哥哥,你回来啦,人家等了你好久好久了,要抱抱!", "It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!", ] instructs = [ "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。", "Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice.", ] languages = ["Chinese", "English"] inputs = [] for text, instruct, language in zip(texts, instructs, languages): additional_information = { "task_type": [task_type], "text": [text], "language": [language], "instruct": [instruct], "max_new_tokens": [2048], "non_streaming_mode": [True], } inputs.append( { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } ) else: text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!" instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。" language = "Chinese" additional_information = { "task_type": [task_type], "text": [text], "language": [language], "instruct": [instruct], "max_new_tokens": [2048], "non_streaming_mode": [True], } inputs = { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } return QueryResult( inputs=inputs, model_name=model_name, ) def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> QueryResult: """Build Base (voice clone) sample inputs. Args: use_batch_sample: When True, return a batch of prompts (Case 2). mode_tag: "icl" or "xvec_only" to control x_vector_only_mode behavior. Returns: QueryResult with Omni inputs and the Base model path. """ task_type = "Base" model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" ref_audio_single = ref_audio_path_1 ref_text_single = ( "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you." ) syn_text_single = "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye." syn_lang_single = "Auto" x_vector_only_mode = mode_tag == "xvec_only" if use_batch_sample: syn_text_batch = [ "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye.", "其实我真的有发现,我是一个特别善于观察别人情绪的人。", ] syn_lang_batch = ["Chinese", "English"] inputs = [] for text, language in zip(syn_text_batch, syn_lang_batch): additional_information = { "task_type": [task_type], "ref_audio": [ref_audio_single], "ref_text": [ref_text_single], "text": [text], "language": [language], "x_vector_only_mode": [x_vector_only_mode], "max_new_tokens": [2048], } inputs.append( { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } ) else: additional_information = { "task_type": [task_type], "ref_audio": [ref_audio_single], "ref_text": [ref_text_single], "text": [syn_text_single], "language": [syn_lang_single], "x_vector_only_mode": [x_vector_only_mode], "max_new_tokens": [2048], } inputs = { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), "additional_information": additional_information, } return QueryResult( inputs=inputs, model_name=model_name, ) query_map = { "CustomVoice": get_custom_voice_query, "VoiceDesign": get_voice_design_query, "Base": get_base_query, } def _build_inputs(args) -> tuple[str, list]: """Resolve model name and inputs list from CLI args.""" if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0: raise ValueError( f"--batch-size must be a power of two (got {args.batch_size}); " "non-power-of-two values do not align with CUDA graph capture sizes " "of Code2Wav." ) query_func = query_map[args.query_type] if args.query_type in {"CustomVoice", "VoiceDesign"}: query_result = query_func(use_batch_sample=args.use_batch_sample) elif args.query_type == "Base": query_result = query_func(use_batch_sample=args.use_batch_sample, mode_tag=args.mode_tag) else: query_result = query_func() model_name = query_result.model_name if args.txt_prompts: with open(args.txt_prompts) as f: lines = [line.strip() for line in f if line.strip()] if not lines: raise ValueError(f"No valid prompts found in {args.txt_prompts}") template = query_result.inputs if not isinstance(query_result.inputs, list) else query_result.inputs[0] template_info = template["additional_information"] inputs = [ { "prompt_token_ids": [0] * _estimate_prompt_len({**template_info, "text": [t]}, model_name), "additional_information": {**template_info, "text": [t]}, } for t in lines ] else: inputs = query_result.inputs if isinstance(query_result.inputs, list) else [query_result.inputs] return model_name, inputs def _save_wav(output_dir: str, request_id: str, mm: dict) -> None: """Concatenate audio chunks and write to a wav file.""" audio_data = mm["audio"] sr_raw = mm["sr"] sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw sr = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) audio_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data out_wav = os.path.join(output_dir, f"output_{request_id}.wav") sf.write(out_wav, audio_tensor.float().cpu().numpy().flatten(), samplerate=sr, format="WAV") logger.info(f"Request ID: {request_id}, Saved audio to {out_wav}") def main(args): """Run offline inference with Omni.""" model_name, inputs = _build_inputs(args) output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) omni = Omni( model=model_name, stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, ) batch_size = args.batch_size for batch_start in range(0, len(inputs), batch_size): batch = inputs[batch_start : batch_start + batch_size] for stage_outputs in omni.generate(batch): for output in stage_outputs.request_output: _save_wav(output_dir, output.request_id, output.outputs[0].multimodal_output) async def main_streaming(args): """Run offline inference with AsyncOmni, logging each audio chunk as it arrives.""" model_name, inputs = _build_inputs(args) output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) omni = AsyncOmni( model=model_name, stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, ) for i, prompt in enumerate(inputs): request_id = str(i) t_start = time.perf_counter() t_prev = t_start chunk_idx = 0 async for stage_output in omni.generate(prompt, request_id=request_id): mm = stage_output.request_output.outputs[0].multimodal_output if not stage_output.finished: t_now = time.perf_counter() audio = mm.get("audio") n = len(audio) if isinstance(audio, list) else (0 if audio is None else 1) dt_ms = (t_now - t_prev) * 1000 ttfa_ms = (t_now - t_start) * 1000 if chunk_idx == 0: logger.info(f"Request {request_id}: chunk {chunk_idx} samples={n} TTFA={ttfa_ms:.1f}ms") else: logger.info(f"Request {request_id}: chunk {chunk_idx} samples={n} inter_chunk={dt_ms:.1f}ms") t_prev = t_now chunk_idx += 1 else: t_end = time.perf_counter() total_ms = (t_end - t_start) * 1000 logger.info(f"Request {request_id}: done total={total_ms:.1f}ms chunks={chunk_idx}") _save_wav(output_dir, request_id, mm) def parse_args(): parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") parser.add_argument( "--query-type", "-q", type=str, default="CustomVoice", choices=query_map.keys(), help="Query type.", ) parser.add_argument( "--log-stats", action="store_true", default=False, help="Enable writing detailed statistics (default: disabled)", ) parser.add_argument( "--stage-init-timeout", type=int, default=300, help="Timeout for initializing a single stage in seconds (default: 300)", ) parser.add_argument( "--batch-timeout", type=int, default=5, help="Timeout for batching in seconds (default: 5)", ) parser.add_argument( "--init-timeout", type=int, default=300, help="Timeout for initializing stages in seconds (default: 300)", ) parser.add_argument( "--shm-threshold-bytes", type=int, default=65536, help="Threshold for using shared memory in bytes (default: 65536)", ) parser.add_argument( "--output-dir", default="output_audio", help="Output directory for generated wav files (default: output_audio).", ) parser.add_argument( "--num-prompts", type=int, default=1, help="Number of prompts to generate.", ) parser.add_argument( "--txt-prompts", type=str, default=None, help="Path to a .txt file with one prompt per line (preferred).", ) parser.add_argument( "--stage-configs-path", type=str, default=None, help="Path to a stage configs file.", ) parser.add_argument( "--audio-path", "-a", type=str, default=None, help="Path to local audio file. If not provided, uses default audio asset.", ) parser.add_argument( "--sampling-rate", type=int, default=16000, help="Sampling rate for audio loading (default: 16000).", ) parser.add_argument( "--log-dir", type=str, default="logs", help="Log directory (default: logs).", ) parser.add_argument( "--py-generator", action="store_true", default=False, help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.", ) parser.add_argument( "--use-batch-sample", action="store_true", default=False, help="Use batch input sample for CustomVoice/VoiceDesign/Base query.", ) parser.add_argument( "--mode-tag", type=str, default="icl", choices=["icl", "xvec_only"], help="Mode tag for Base query x_vector_only_mode (default: icl).", ) parser.add_argument( "--streaming", action="store_true", default=False, help="Stream audio chunks as they arrive via AsyncOmni (async_chunk mode only).", ) parser.add_argument( "--batch-size", type=int, default=1, help="Number of prompts per batch (default: 1, sequential).", ) return parser.parse_args() if __name__ == "__main__": args = parse_args() if args.streaming: asyncio.run(main_streaming(args)) else: main(args) ```