from typing import TYPE_CHECKING, Any, cast

import numpy as np
import torch
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.interfaces import supports_mrope
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.sampling_params import SamplingType
from vllm.tracing import instrument
from vllm.utils.import_utils import LazyLoader
from vllm.utils.math_utils import cdiv
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices

from vllm_omni.model_executor.models.output_templates import OmniOutput

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
    xgr_torch_compile = LazyLoader(
        "xgr_torch_compile",
        globals(),
        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile",
    )

logger = init_logger(__name__)


class OmniGPUModelRunner(GPUModelRunner):
    def __init__(self, *args, **kwargs):
        """Build the upstream GPU runner, then reset the Omni-specific state.

        All positional and keyword arguments are forwarded unchanged to
        ``GPUModelRunner.__init__``. The three Omni attributes start as
        ``None`` and are populated elsewhere.
        """
        super().__init__(*args, **kwargs)
        # Omni-only bookkeeping; every slot starts out unset.
        self._omni_last_model_output: object | None = None
        self._omni_num_scheduled_tokens_np: np.ndarray | None = None
        self._omni_per_req_additional_information: dict[str, dict] | None = None

    def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
        """Build attention metadata builders, then grow undersized FA3 buffers.

        Upstream ``FlashAttentionMetadataBuilder`` pre-allocates
        ``scheduler_metadata`` with ``max_num_seqs + 1`` entries. FA3's
        ``get_scheduler_metadata()`` may however produce up to
        ``max_num_seqs * max_num_splits + 1`` entries, which makes CUDA graph
        capture fail with a RuntimeError. After the parent has created the
        builders, any builder that splits (``max_num_splits > 1``) and whose
        buffer is too short gets a fresh zero-filled buffer of the required
        length, with the same dtype and device as the old one.
        """
        super().initialize_metadata_builders(kv_cache_config, kernel_block_sizes)

        # Flatten groups -> attention groups -> builders into a single stream.
        all_builders = (
            builder
            for kv_cache_group in self.attn_groups
            for attn_group in kv_cache_group
            for builder in attn_group.metadata_builders
        )
        for builder in all_builders:
            buf = getattr(builder, "scheduler_metadata", None)
            splits = getattr(builder, "max_num_splits", 0)
            # Only builders that own a buffer AND actually split need checking.
            if buf is None or not splits > 1:
                continue
            needed = self.scheduler_config.max_num_seqs * splits + 1
            if buf.shape[0] >= needed:
                continue
            builder.scheduler_metadata = torch.zeros(needed, dtype=buf.dtype, device=buf.device)

    @instrument(span_name="Loading (GPU)")
    def load_model(self, *args, **kwargs) -> None:
        """Load the model, then set up talker-MTP state when the model has one.

        After the upstream load, a model exposing a ``talker_mtp`` attribute
        gets it stored on ``self.talker_mtp``. For Omni models, which have a
        separate ``.talker`` sub-module, it is wrapped in a FULL
        ``CUDAGraphWrapper`` when full CUDA graphs are enabled. Persistent
        buffers for the MTP inputs/outputs are also allocated, sized for the
        larger of ``max_num_reqs`` and the maximum CUDA-graph capture size.
        """
        super().load_model(*args, **kwargs)

        # TODO move this model specific logic to a separate class
        # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS.
        talker_mtp = getattr(self.model, "talker_mtp", None)
        if talker_mtp is None:
            return

        self.talker_mtp = talker_mtp  # type: ignore[assignment]
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
        # have a separate .talker sub-module.  TTS models' code predictor
        # has internal AR loops / torch.multinomial — not graph-safe.
        has_separate_talker = getattr(self.model, "talker", None) is not None
        if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
            self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)

        # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
        hidden_size = int(getattr(self.model, "mtp_hidden_size", 0) or self.model_config.hf_text_config.hidden_size)
        # max_cudagraph_capture_size may be None (e.g. if no capture sizes are
        # configured); treat that as 0 so max() does not raise TypeError.
        capture_size = self.compilation_config.max_cudagraph_capture_size or 0
        max_batch_size = max(self.max_num_reqs, capture_size)
        self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32)
        self.talker_mtp_inputs_embeds = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
        self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
        self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)

    def _init_mrope_positions(self, req_state: CachedRequestState):
        """Initialize M-RoPE positions for a newly added request.

        Collects multimodal metadata from ``req_state.mm_features``:
        image/video grids, ``second_per_grid_ts``, audio feature lengths and
        the ``use_audio_in_video`` flag. It then computes the request's M-RoPE
        positions and position delta.

        If the model implements the ``SupportsMRoPE`` interface, its own
        ``get_mrope_input_positions`` is used, with all extracted metadata
        passed through so the model picks what it needs. Otherwise this falls
        back to ``MRotaryEmbedding.get_input_positions_tensor``.

        Args:
            req_state: Cached request state containing multimodal features.
                Its ``mrope_positions`` and ``mrope_position_delta`` are
                overwritten.
        """
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_feature in req_state.mm_features:
            mm_item = mm_feature.data
            if mm_item is None:
                continue
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
            if (t := mm_input.get("video_grid_thw")) is not None:
                video_grid_thw.append(t.tolist())
            if (t := mm_input.get("second_per_grid_ts")) is not None:
                second_per_grid_ts.append(t)
            if (t := mm_input.get("audio_feature_lengths")) is not None:
                audio_feature_lengths.append(t)
            # The last item that specifies use_audio_in_video wins. Accept both
            # tensor-like values (with .item()) and plain Python scalars.
            use_audio_in_video_value = mm_input.get("use_audio_in_video")
            if use_audio_in_video_value is not None:
                if hasattr(use_audio_in_video_value, "item"):
                    use_audio_in_video_value = use_audio_in_video_value.item()
                use_audio_in_video = bool(use_audio_in_video_value)

        # Use the unwrapped model for both the capability check and the call,
        # so they always refer to the same object.
        model = self.get_model()
        if supports_mrope(model):
            # Model implements SupportsMRoPE interface
            # Pass all extracted metadata; models use what they need via **kwargs
            req_state.mrope_positions, req_state.mrope_position_delta = model.get_mrope_input_positions(
                req_state.prompt_token_ids,
                mm_features=req_state.mm_features,
                hf_config=self.model_config.hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )
        else:
            req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                hf_config=self.model_config.hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        """Compute M-RoPE positions for every token scheduled this step.

        The upstream implementation always runs first. It fills in prompt
        positions and the linear decode fallback. Afterwards, models that set
        the class attribute ``precomputed_mrope_decode = True`` (e.g.
        GLM-Image) get their decode positions replaced by the values their
        ``get_mrope_input_positions`` precomputed. For such models that method
        must return positions covering **both** prefill and decode tokens.
        This keeps the upstream method intact while still supporting
        non-linear decode position patterns.
        """
        # Upstream pass: prompt positions + linear decode positions.
        super()._calc_mrope_positions(scheduler_output)

        wants_precomputed_decode = getattr(self.get_model(), "precomputed_mrope_decode", False)
        if wants_precomputed_decode:
            self._fixup_precomputed_mrope_decode_positions(scheduler_output)

    def _fixup_precomputed_mrope_decode_positions(self, scheduler_output: "SchedulerOutput") -> None:
        """Replace upstream's linear decode M-RoPE positions with precomputed ones.

        Image-generation models such as GLM-Image emit tokens in 2D grid
        order, so their ``get_mrope_input_positions`` returns positions for the
        whole sequence (prefill + decode). Upstream only consumes the prefill
        slice and writes linear increments for decode. This walks the batch in
        the same order as upstream and overwrites each request's decode slice
        in ``self.mrope_positions.cpu`` with the precomputed values. A slice is
        only overwritten when the precomputed table is long enough to cover
        it; otherwise the linear fallback is kept.
        """
        from vllm.utils import length_from_prompt_token_ids_or_embeds

        write_ptr = 0
        for batch_idx, req_id in enumerate(self.input_batch.req_ids):
            cached = self.requests[req_id]
            assert cached.mrope_positions is not None

            computed = self.input_batch.num_computed_tokens_cpu[batch_idx]
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            prompt_len = length_from_prompt_token_ids_or_embeds(cached.prompt_token_ids, cached.prompt_embeds)

            # Split this step's tokens into a prompt part and a completion part,
            # using the same rule as the upstream runner.
            crosses_prompt_end = computed + scheduled > prompt_len
            n_prompt = max(0, prompt_len - computed) if crosses_prompt_end else scheduled
            n_decode = max(0, scheduled - n_prompt) if crosses_prompt_end else 0

            # Upstream already wrote the prompt positions; just skip past them.
            write_ptr += n_prompt
            if not n_decode > 0:
                continue

            src_lo = computed + n_prompt
            src_hi = src_lo + n_decode
            if src_hi <= cached.mrope_positions.shape[1]:
                # Swap the linear positions for the precomputed 2D spatial ones.
                self.mrope_positions.cpu[:, write_ptr : write_ptr + n_decode] = cached.mrope_positions[
                    :, src_lo:src_hi
                ]
            write_ptr += n_decode

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Omni addition: for each new request, serialized ``prompt_embeds`` and
        ``additional_information`` payloads are decoded and attached to the
        cached request state as ``prompt_embeds_cpu`` and
        ``additional_information_cpu`` respectively (best effort; see
        ``_decode_prompt_embeds_payload`` and
        ``_decode_additional_information_payload``).
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.num_prompt_logprobs.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
        # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint,
        # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds
        # apart from the forced-preemption case in reset_prefix_cache. And in
        # that case we include the resumed_req_ids in the unscheduled set so
        # that they get cleared from the persistent batch before being re-scheduled
        # in the normal resumed request path.
        unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        reqs_to_add: list[CachedRequestState] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if sampling_params and sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if self.is_pooling_model:
                assert pooling_params is not None
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt_embeds=new_req_data.prompt_embeds,
                mm_features=new_req_data.mm_features,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Omni: decode serialized payloads onto the cached request state.
            # NOTE(review): req_state.prompt_embeds above was set from the raw
            # (undecoded) payload; confirm downstream consumers expect that.
            self._decode_prompt_embeds_payload(req_id, new_req_data)
            self._decode_additional_information_payload(req_id, new_req_data)

            if sampling_params and sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.input_batch.vocab_size
                    if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs
                )
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
            if self.uses_xdrope_dim > 0:
                self._init_xdrope_positions(req_state)

            reqs_to_add.append(self.requests[req_id])

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

        # Wait until valid_sampled_tokens_count is copied to cpu,
        # then use it to update actual num_computed_tokens of each request.
        valid_sampled_token_count = self._get_valid_sampled_token_count()

        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_id in req_data.resumed_req_ids
            num_output_tokens = req_data.num_output_tokens[i]
            req_index = self.input_batch.req_id_to_index.get(req_id)

            if req_state.prev_num_draft_len and self.use_async_scheduling:
                # prev_num_draft_len is used in async scheduling mode with
                # spec decode. it indicates if need to update num_computed_tokens
                # of the request. for example:
                # first step: num_computed_tokens = 0, spec_tokens = [],
                # prev_num_draft_len = 0.
                # second step: num_computed_tokens = 100(prompt length),
                # spec_tokens = [a,b], prev_num_draft_len = 0.
                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
                # prev_num_draft_len = 2.
                # num_computed_tokens in first step and second step doesn't contain
                # the spec tokens length, but in third step it contains the
                # spec tokens length. we only need to update num_computed_tokens
                # when prev_num_draft_len > 0.
                if req_index is None:
                    req_state.prev_num_draft_len = 0
                else:
                    assert self.input_batch.prev_req_id_to_index is not None
                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
                    num_rejected = req_state.prev_num_draft_len - num_accepted
                    num_computed_tokens -= num_rejected
                    req_state.output_token_ids.extend([-1] * num_accepted)

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state.num_tokens
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
            elif num_output_tokens < len(req_state.output_token_ids):
                # Some output tokens were discarded due to a sync-KV-load
                # failure. Align the cached state.
                del req_state.output_token_ids[num_output_tokens:]
                if req_index is not None:
                    end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens
                    self.input_batch.num_tokens_no_spec[req_index] = end_idx

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert req_index is None
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.

                if self.use_async_scheduling and num_output_tokens > 0:
                    # We must recover the output token ids for resumed requests in the
                    # async scheduling case, so that correct input_ids are obtained.
                    resumed_token_ids = req_data.all_token_ids[req_id]
                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]

                reqs_to_add.append(req_state)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for request in reqs_to_add:
            self.input_batch.add_request(request)
            self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch, if it needs to.
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    def _decode_prompt_embeds_payload(self, req_id: str, new_req_data: Any) -> None:
        """Decode a serialized prompt-embeds payload into a CPU tensor.

        The decoded tensor is attached to ``self.requests[req_id]`` as
        ``prompt_embeds_cpu``. If possible, ``new_req_data.prompt_embeds`` is
        also replaced by the tensor. This is best effort: a decoding failure
        is logged and the request proceeds without ``prompt_embeds_cpu``.

        NOTE(review): the payload is assumed to expose ``dtype`` (a numpy type
        name), ``data`` (a raw buffer) and ``shape``; this is inferred from
        the decoding below, so confirm against the serializer.
        """
        payload = getattr(new_req_data, "prompt_embeds", None)
        if payload is None:
            return
        try:
            dtype = getattr(np, payload.dtype)
            arr = np.frombuffer(payload.data, dtype=dtype).reshape(payload.shape)
            # np.frombuffer over an immutable buffer yields a read-only array,
            # which torch.from_numpy warns about; copy so the tensor owns
            # writable memory (matches the additional_information path).
            pe_cpu = torch.from_numpy(arr.copy())
        except Exception as e:
            logger.error("Error decoding prompt embeds for request %s: %s", req_id, e)
            return
        # Store temporarily on CPU; later moved to device in builder
        setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu)
        # Also replace payload with Tensor for user visibility in
        # scheduler_output.
        try:
            new_req_data.prompt_embeds = pe_cpu  # type: ignore[assignment]
        except Exception:
            # Best effort: the request-data object may reject assignment.
            pass

    def _decode_additional_information_payload(self, req_id: str, new_req_data: Any) -> None:
        """Decode ``additional_information`` into a dict on the cached request.

        A plain ``dict`` payload is used as-is. For an
        ``AdditionalInformationPayload``, each entry becomes either a CPU
        tensor (when ``tensor_data`` is set) or its ``list_data``. A non-empty
        result is attached to ``self.requests[req_id]`` as
        ``additional_information_cpu``. This is best effort: on any decoding
        failure the error is logged and nothing is attached.
        """
        payload_info = getattr(new_req_data, "additional_information", None)
        if payload_info is None:
            return
        try:
            info_dict = {}
            if isinstance(payload_info, dict):
                info_dict = payload_info
            else:
                from vllm_omni.engine import AdditionalInformationPayload

                if isinstance(payload_info, AdditionalInformationPayload):
                    for k, entry in payload_info.entries.items():
                        if entry.tensor_data is not None:
                            dt = np.dtype(getattr(entry, "tensor_dtype", "float32"))
                            arr = np.frombuffer(entry.tensor_data, dtype=dt)
                            arr = arr.reshape(entry.tensor_shape)
                            info_dict[k] = torch.from_numpy(arr.copy())
                        else:
                            info_dict[k] = entry.list_data
        except Exception as e:
            logger.error("Error decoding additional information for request %s: %s", req_id, e)
            return
        if info_dict:
            setattr(
                self.requests[req_id],
                "additional_information_cpu",
                info_dict,
            )

    @torch.inference_mode()
    def extract_multimodal_outputs(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...] | OmniOutput,
    ) -> tuple[torch.Tensor, dict]:
        """Split a model forward output into text hidden states and multimodal outputs.

        Args:
            hidden_states: The model output. This is one of:
                an ``OmniOutput``, only accepted when the model has a truthy
                ``have_multimodal_outputs``;
                a plain tensor;
                a list/tuple whose first element holds the text hidden states.

        Returns:
            ``(text_hidden_states, multimodal_outputs)``. ``multimodal_outputs``
            is an empty dict for tensor and list/tuple inputs.

        Raises:
            ValueError: For any other input type. This includes an
                ``OmniOutput`` when the model does not declare multimodal
                outputs.
        """
        if getattr(self.model, "have_multimodal_outputs", False) and isinstance(hidden_states, OmniOutput):
            text_hidden_states = hidden_states.text_hidden_states
            multimodal_outputs = hidden_states.multimodal_outputs
        elif isinstance(hidden_states, torch.Tensor):
            text_hidden_states = hidden_states
            multimodal_outputs = {}
        elif isinstance(hidden_states, (list, tuple)):
            text_hidden_states = hidden_states[0]
            multimodal_outputs = {}
        else:
            raise ValueError(f"Invalid hidden states type: {type(hidden_states)}")
        return text_hidden_states, multimodal_outputs

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        cudagraph_runtime_mode: CUDAGraphMode | None = None,
        force_attention: bool = False,
        uniform_decode: bool = False,
        allow_microbatching: bool = True,
        skip_eplb: bool = False,
        is_profile: bool = False,
        create_mixed_batch: bool = False,
        remove_lora: bool = True,
        is_graph_capturing: bool = False,
        num_active_loras: int = 0,
        activate_lora: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a dummy forward pass to warm up/profile run or capture the
        CUDA graph for the model.

        Args:
            num_tokens: Number of tokens to run the dummy forward pass.
            cudagraph_runtime_mode: used to control the behavior.
                - if not set will determine the cudagraph mode based on using
                    the self.cudagraph_dispatcher.
                - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                    needed.
            force_attention: If True, always create attention metadata. Used to
                warm up attention backend when mode is NONE.
            uniform_decode: If True, the batch is a uniform decode batch.
            skip_eplb: If True, skip EPLB state update.
            is_profile: If True, this is a profile run.
            create_mixed_batch: If True, create a mixed batch with both decode
                (1 token) and prefill (multiple tokens) requests.
            remove_lora: If False, dummy LoRAs are not destroyed after the run
            num_active_loras: Number of active LoRAs to capture for.
            activate_lora: Backward-compatible override for LoRA activation.
        """
        if activate_lora is None:
            activate_lora = num_active_loras > 0

        mm_config = self.vllm_config.model_config.multimodal_config
        if mm_config and mm_config.mm_encoder_only:
            # The current dummy run only covers LM execution, so we can skip it.
            # mm encoder dummy run may need to add in the future.
            return torch.tensor([]), torch.tensor([])

        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes()

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(). This means that we are using
        # different graphs and/or modes for mixed prefill-decode batches vs.
        # uniform decode batches. A uniform decode batch means that all
        # requests have identical query length, except a potential virtual
        # request (shorter) in the batch account for padding.
        # Uniform decode batch could either be common pure decode, where
        # max_query_len == 1, or speculative decode, where
        # max_query_len == 1 + num_spec_decode_tokens.

        # When setting max_query_len = 1, we switch to and capture the optimized
        # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
        # for GQA/MQA.
        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        if create_mixed_batch:
            assert not uniform_decode
            # Create mixed batch:
            # first half decode tokens, second half one prefill
            num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
            num_prefill_tokens = num_tokens - num_decode_tokens
            num_reqs = num_decode_tokens + 1

            # Create decode requests (1 token each) followed by prefill request
            num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
            # Note: Overriding max_query_len to be the prefill tokens
            max_query_len = num_prefill_tokens
        elif uniform_decode:
            assert not create_mixed_batch
            num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
            num_scheduled_tokens_list = [max_query_len] * num_reqs
            if num_tokens % max_query_len != 0:
                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
        else:
            num_reqs = min(num_tokens, max_num_reqs)
            min_tokens_per_req = num_tokens // num_reqs
            num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
            num_scheduled_tokens_list[-1] += num_tokens % num_reqs

        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
        num_tokens_unpadded = int(num_scheduled_tokens.sum())

        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
            self._determine_batch_execution_and_padding(
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs,
                num_scheduled_tokens_np=num_scheduled_tokens,
                max_num_scheduled_tokens=max_query_len,
                use_cascade_attn=False,
                allow_microbatching=allow_microbatching,
                force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
                # `force_uniform_decode` is used for cudagraph capture; because for
                # capturing mixed prefill-decode batches, we sometimes use
                # num_tokens == num_reqs which looks like a uniform decode batch to the
                # dispatcher; but we actually want to capture a piecewise cudagraph
                force_uniform_decode=uniform_decode,
                # `force_has_lora` is used for cudagraph capture; because LoRA is
                # activated later in the context manager, but we need to know the
                # LoRA state when determining the batch descriptor for capture
                force_has_lora=activate_lora,
                # Capture shape specialization for specific active LoRA counts.
                force_num_active_loras=num_active_loras,
            )
        )

        if cudagraph_runtime_mode is None:
            cudagraph_runtime_mode = _cudagraph_mode
        else:
            assert cudagraph_runtime_mode == _cudagraph_mode, (
                f"Cudagraph runtime mode mismatch in dummy_run. "
                f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
            )

        num_tokens_padded = batch_desc.num_tokens
        num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
            should_ubatch,
            num_scheduled_tokens,
            num_tokens_padded,
            num_reqs_padded,
            self.vllm_config.parallel_config.num_ubatches,
        )
        logger.debug(
            "ubatch_slices: %s, ubatch_slices_padded: %s",
            ubatch_slices,
            ubatch_slices_padded,
        )

        attn_metadata: PerLayerAttnMetadata | None = None

        slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
            num_tokens_padded=num_tokens,
            num_reqs_padded=num_reqs_padded,
            num_tokens_unpadded=num_tokens_unpadded,
            ubatch_slices=ubatch_slices_padded,
        )

        # If force_attention is True, we always capture attention. Otherwise,
        # it only happens for cudagraph_runtime_mode=FULL.
        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
            if create_mixed_batch:
                # In the mixed batch mode (used for FI warmup), we use
                # shorter sequence lengths to run faster.
                # TODO(luka) better system for describing dummy batches
                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
            else:
                seq_lens = max_query_len  # type: ignore[assignment]
            self.seq_lens.np[:num_reqs] = seq_lens
            self.seq_lens.np[num_reqs:] = 0
            self.seq_lens.copy_to_gpu()

            cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
            self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
            self.query_start_loc.copy_to_gpu()

            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
            attn_metadata, _ = self._build_attention_metadata(
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs_padded,
                max_query_len=max_query_len,
                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
                for_cudagraph_capture=is_graph_capturing,
                slot_mappings=slot_mappings_by_group,
            )

        with self.maybe_dummy_run_with_lora(
            self.lora_config,
            num_scheduled_tokens,
            num_sampled_tokens,
            activate_lora,
            remove_lora,
        ):
            # Make sure padding doesn't exceed max_num_tokens
            assert num_tokens_padded <= self.max_num_tokens
            model_kwargs = self._init_model_kwargs()
            if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
                input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)

                model_kwargs = {
                    **model_kwargs,
                    **self._dummy_mm_kwargs(num_reqs),
                }
            elif self.enable_prompt_embeds:
                input_ids = None
                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
                model_kwargs = self._init_model_kwargs()
            elif getattr(getattr(self, "model", None), "has_preprocess", False):
                # Capture CUDA graph with inputs_embeds path so replay reads
                # from the same buffer that _preprocess writes into.
                input_ids = self.input_ids.gpu[:num_tokens_padded]
                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
            else:
                input_ids = self.input_ids.gpu[:num_tokens_padded]
                inputs_embeds = None

            if self.uses_mrope:
                positions = self.mrope_positions.gpu[:, :num_tokens_padded]
            elif self.uses_xdrope_dim > 0:
                positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
            else:
                positions = self.positions.gpu[:num_tokens_padded]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device,
                    )

                intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False)

            if ubatch_slices_padded is not None:
                # Adjust values to reflect a single ubatch.
                # TODO(sage,lucas): this is cruft that should be addressed in
                #  the padding refactor.
                num_tokens_padded = ubatch_slices_padded[0].num_tokens
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[:] = num_tokens_padded

            with (
                self.maybe_randomize_inputs(input_ids, inputs_embeds),
                set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens_padded,
                    num_tokens_across_dp=num_tokens_across_dp,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=batch_desc,
                    ubatch_slices=ubatch_slices_padded,
                    slot_mapping=slot_mappings,
                ),
            ):
                if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
                    num_tokens_padded_talker_mtp = num_tokens_padded
                    if num_tokens_padded_talker_mtp == self.max_num_tokens:
                        num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0]
                    outputs = self.talker_mtp(
                        self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp],
                        self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp],
                        self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp],
                        self.text_step.gpu[:num_tokens_padded_talker_mtp],
                    )
                    self.compilation_config.cache_dir = None
                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )

            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs
            hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
            if self.speculative_config and self.speculative_config.use_eagle():
                assert isinstance(self.drafter, EagleProposer)
                # Eagle currently only supports PIECEWISE cudagraphs.
                # Therefore only use cudagraphs if the main model uses PIECEWISE
                # NOTE(lucas): this is a hack, need to clean up.
                use_cudagraphs = (
                    (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE)
                    or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE)
                ) and not self.speculative_config.enforce_eager

                # Note(gnovack) - We need to disable cudagraphs for one of the two
                # lora cases when cudagraph_specialize_lora is enabled. This is a
                # short term mitigation for issue mentioned in
                # https://github.com/vllm-project/vllm/issues/28334
                if self.compilation_config.cudagraph_specialize_lora and activate_lora:
                    use_cudagraphs = False

                self.drafter.dummy_run(
                    num_tokens,
                    use_cudagraphs=use_cudagraphs,
                    is_graph_capturing=is_graph_capturing,
                    slot_mappings=slot_mappings,
                )

        # We register layerwise NVTX hooks here after the first dynamo tracing is
        # done to avoid nvtx operations in hook functions being traced by
        # torch dynamo and causing graph breaks.
        # Note that for DYNAMO_ONCE and VLLM_COMPILE mode,
        # compiled model's dynamo tracing is only done once and the compiled model's
        # __call__ function is replaced by calling the compiled function.
        # So it's safe to register hooks here. Hooks will be registered to
        # both compiled and uncompiled models but they will never
        # be called on the compiled model execution path.
        self._register_layerwise_nvtx_hooks()

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        logit_indices_device = torch.from_numpy(logit_indices).to(self.device, non_blocking=True)
        return hidden_states, hidden_states[logit_indices_device]

    def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None:
        """Decode per-request prompt_embeds and additional_information for newly
        scheduled requests and store them to CPU in the request state.
        This version avoids hard dependency on payload classes by duck-typing."""
        try:
            new_reqs = getattr(scheduler_output, "scheduled_new_reqs", [])
            if not new_reqs:
                return
            for nr in new_reqs:
                req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None)
                if req_id is None:
                    continue
                # prompt_embeds
                payload_pe = getattr(nr, "prompt_embeds", None)
                pe_cpu = None
                if payload_pe is not None:
                    if isinstance(payload_pe, torch.Tensor):
                        pe_cpu = payload_pe.detach().to("cpu").contiguous()
                    else:
                        # Try duck-typing a payload with data/shape/dtype
                        data = getattr(payload_pe, "data", None)
                        shape = getattr(payload_pe, "shape", None)
                        if data is not None and shape is not None:
                            dt = np.dtype(getattr(payload_pe, "dtype", "float32"))
                            arr = np.frombuffer(data, dtype=dt)
                            arr = arr.reshape(shape)
                            pe_cpu = torch.from_numpy(arr.copy())
                if pe_cpu is not None and req_id in self.requests:
                    setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu)
                # additional_information
                payload_info = getattr(nr, "additional_information", None)
                if payload_info is not None:
                    info_dict = {}
                    if isinstance(payload_info, dict):
                        info_dict = payload_info
                    else:
                        # Try duck-typing a payload with entries, each entry may have
                        # tensor_data/tensor_dtype/tensor_shape or list_data
                        entries = getattr(payload_info, "entries", None)
                        if isinstance(entries, dict):
                            for k, entry in entries.items():
                                tensor_data = getattr(entry, "tensor_data", None)
                                if tensor_data is not None:
                                    dt = np.dtype(getattr(entry, "tensor_dtype", "float32"))
                                    arr = np.frombuffer(tensor_data, dtype=dt)
                                    arr = arr.reshape(getattr(entry, "tensor_shape", ()))
                                    info_dict[k] = torch.from_numpy(arr.copy())
                                else:
                                    info_dict[k] = getattr(entry, "list_data", None)
                    if info_dict and req_id in self.requests:
                        setattr(self.requests[req_id], "additional_information_cpu", info_dict)
        except Exception as e:
            logger.error(f"Error decoding prompt_embeds / additional_information: {e}")

    def _gather_runtime_additional_information(self) -> list[dict]:
        """Gather per-request additional_information stored in request state in batch order."""
        per_req_runtime_info = []
        for req_id in self.input_batch.req_ids:
            req_state = self.requests.get(req_id)
            info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
            if info and isinstance(info, dict):
                per_req_runtime_info.append(info)
                if "thinker_reply_part_per_request" in info:
                    q = info["thinker_reply_part_per_request"]
                    if hasattr(q, "shape"):
                        logger.debug(f"[OMNI] req={req_id} has thinker_reply_part_per_request queue shape: {q.shape}")
            else:
                per_req_runtime_info.append({})
        return per_req_runtime_info

    def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[int, int]]:
        """Compute (start, end) token spans for each request within the flattened step sequence."""
        req_token_spans: list[tuple[int, int]] = []
        for req_index in range(len(self.input_batch.req_ids)):
            start_offset = int(self.query_start_loc.cpu[req_index])
            sched_tokens = int(num_scheduled_tokens_np[req_index])
            req_token_spans.append((start_offset, start_offset + sched_tokens))
        return req_token_spans

    def _build_model_kwargs_extra(self) -> dict:
        """Build extra keyword arguments passed to the model for this step, including:
        - runtime_additional_information: per-request additional information stored in request state
        """
        model_kwargs_extra: dict[str, object] = {}
        try:
            model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information()
        except Exception as e:
            logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}")
            import traceback

            traceback.print_exc()
        return model_kwargs_extra

    def _process_additional_information_updates(
        self,
        hidden_states: torch.Tensor,
        multimodal_outputs: object,
        num_scheduled_tokens_np: np.ndarray,
        scheduler_output: "SchedulerOutput",
    ) -> None:
        """Process model-provided per-request additional_information updates and merge into request state."""
        try:
            # execute the custom postprocess function
            # TODO(Peiqi): do we have a more elegant way to do this?
            if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
                for req_index, req_id in enumerate(self.input_batch.req_ids):
                    req_state = self.requests.get(req_id)
                    req_infos = (
                        getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
                    )
                    start_offset = int(self.query_start_loc.cpu[req_index])
                    sched_tokens = int(num_scheduled_tokens_np[req_index])
                    s, e = start_offset, start_offset + sched_tokens
                    # only consider to store data into update dict.
                    hidden_states_slice = hidden_states[s:e]
                    update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
                    self._merge_additional_information_update(req_id, update_dict)
        except Exception as e:
            logger.error(
                f"Error merging for requests:{self.input_batch.req_ids} "
                f"additional information update: {e}, with the multimodal_outputs "
                f"as {multimodal_outputs}"
            )
            import traceback

            traceback.print_exc()

    def _collect_additional_information_for_prefill(
        self,
        num_scheduled_tokens_np: np.ndarray,
    ) -> dict[str, dict]:
        """Overlay per-request prompt_embeds for the prefill portion and collect
        additional_information slices for this step. Returns a map req_id -> dict."""
        for req_index, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            pe_cpu = getattr(req_state, "prompt_embeds_cpu", None)
            num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index])
            prompt_len = len(req_state.prompt_token_ids)
            prompt_remaining = max(0, prompt_len - num_computed_tokens)
            sched_tokens = int(num_scheduled_tokens_np[req_index])
            overlay_len = min(sched_tokens, prompt_remaining)
            if overlay_len <= 0:
                continue
            if overlay_len > 0 and pe_cpu is not None:
                src = pe_cpu[num_computed_tokens : num_computed_tokens + overlay_len].to(
                    dtype=self.dtype, device=self.device, non_blocking=True
                )
                start_offset = int(self.query_start_loc.cpu[req_index])
                self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)

    def _update_request_information(self, request_id: str, payload_info: dict) -> None:
        """Update per-request additional_information stored in request state."""
        req_state = self.requests.get(request_id)
        if req_state is None:
            return

        info_dict = getattr(req_state, "additional_information_cpu", None)
        if isinstance(payload_info, dict) and info_dict is not None:
            info_dict.update(payload_info)

    def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None:
        for new_req in scheduler_output.scheduled_new_reqs:
            payload_info = getattr(new_req, "additional_information", None)
            if isinstance(payload_info, dict):
                self._update_request_information(new_req.req_id, payload_info)

        if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"):
            cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {})
            if isinstance(cached_infos, dict):
                for req_id, req_infos in cached_infos.items():
                    self._update_request_information(req_id, req_infos)

    def _maybe_attach_mimo_audio_req_infos(
        self,
        req_state: CachedRequestState | None,
        req_infos: dict | None,
        req_id: str,
    ) -> dict | None:
        """Attach MiMoAudio-specific fields into req_infos if applicable.

        This helper is intentionally small and self-contained so that it can be
        unit-tested to prevent regressions when updating MiMoAudio handling.
        """
        if req_state is None or self.model.__class__.__name__ != "MiMoAudioForConditionalGeneration":
            return req_infos

        # Always operate on a dict copy to avoid mutating shared instances.
        req_infos = dict(req_infos) if isinstance(req_infos, dict) else {}
        mm_features = getattr(req_state, "mm_features", None)
        if mm_features and (not req_infos.get("mm_features")):
            req_infos["mm_features"] = mm_features
        req_infos["req_id"] = req_id

        return req_infos

    def _preprocess(
        self,
        scheduler_output: "SchedulerOutput",
        num_input_tokens: int,
        intermediate_tensors: IntermediateTensors | None = None,
    ):
        """Build the model inputs for one step, plus omni-specific preprocessing.

        The first part aligns with the upstream v0.14.0 ``_preprocess``:

        - it chooses ``input_ids`` / ``inputs_embeds`` for the multimodal,
          prompt-embeds, ``has_preprocess`` or plain-text path;
        - it selects ``positions``;
        - it slices ``intermediate_tensors`` on non-first pipeline-parallel ranks;
        - it runs the encoder for encoder-decoder models.

        The omni part then:

        * overlays stored per-request ``prompt_embeds_cpu`` whenever an
          embeddings input is in use;
        * for models with ``has_preprocess``, calls ``model.preprocess`` per
          request. It writes the returned embeddings / token ids back into the
          batch buffers and merges the returned update dict into request state.
          When the model has ``talker_mtp``, it also stages single-token requests
          and runs the MTP step for them.

        Args:
            scheduler_output: Scheduler output for this step.
            num_input_tokens: Number of input tokens to slice from the
                persistent input buffers.
            intermediate_tensors: Tensors from the previous pipeline stage;
                required (asserted) on non-first ranks.

        Returns:
            ``(input_ids, inputs_embeds, positions, intermediate_tensors,
            model_kwargs, ec_connector_output)``.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        is_first_rank = get_pp_group().is_first_rank
        is_encoder_decoder = self.model_config.is_encoder_decoder

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        ec_connector_output = None

        if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder:
            # Run the multimodal encoder if any.
            with self.maybe_get_ec_connector_output(
                scheduler_output,
                encoder_cache=self.encoder_cache,
            ) as ec_connector_output:
                self._execute_mm_encoder(scheduler_output)
                mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            inputs_embeds_scheduled = self.model.embed_input_ids(
                self.input_ids.gpu[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)

            input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
            model_kwargs = {
                **self._init_model_kwargs(),
                **self._extract_mm_kwargs(scheduler_output),
            }
        elif self.enable_prompt_embeds and is_first_rank:
            # Get the input embeddings for the tokens that are not input embeds,
            # then put them into the appropriate positions.
            # TODO(qthequartermasterman): Since even when prompt embeds are
            # enabled, (a) not all requests will use prompt embeds, and (b)
            # after the initial prompt is processed, the rest of the generated
            # tokens will be token ids, it is not desirable to have the
            # embedding layer outside of the CUDA graph all the time. The v0
            # engine avoids this by "double compiling" the CUDA graph, once
            # with input_ids and again with inputs_embeds, for all num_tokens.
            # If a batch only has token ids, then including the embedding layer
            # in the CUDA graph will be more performant (like in the else case
            # below).
            token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens].nonzero(as_tuple=False).squeeze(1)
            # Some tokens ids may need to become embeds
            if token_ids_idx.numel() > 0:
                token_ids = self.input_ids.gpu[token_ids_idx]
                tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids)
                self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds

            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
            model_kwargs = self._init_model_kwargs()
            input_ids = self.input_ids.gpu[:num_input_tokens]
        elif getattr(self.model, "has_preprocess", False):
            # Use pre-allocated buffer for CUDA graph compatibility.
            # Both slices are views into the persistent buffers; the per-request
            # model.preprocess loop below writes its results back into them.
            input_ids = self.input_ids.gpu[:num_input_tokens]
            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
            model_kwargs = self._init_model_kwargs()
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids.gpu[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs()

        if self.uses_mrope:
            positions = self.mrope_positions.gpu[:, :num_input_tokens]
        elif self.uses_xdrope_dim > 0:
            positions = self.xdrope_positions.gpu[:, :num_input_tokens]
        else:
            positions = self.positions.gpu[:num_input_tokens]

        if is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True
            )

        if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
            # Run the encoder, just like we do with other multimodal inputs.
            # For an encoder-decoder model, our processing here is a bit
            # simpler, because the outputs are just passed to the decoder.
            # We are not doing any prompt replacement. We also will only
            # ever have a single encoder input.
            encoder_outputs = self._execute_mm_encoder(scheduler_output)
            model_kwargs.update({"encoder_outputs": encoder_outputs})

        # Per-request scheduled token counts in batch order. They are also cached
        # on the runner; presumably read later in the step (verify readers).
        req_ids = self.input_batch.req_ids
        num_scheduled_tokens_np = np.array(
            [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids],
            dtype=np.int32,
        )
        self._omni_num_scheduled_tokens_np = num_scheduled_tokens_np

        # Note: only prefill needs to collect additional_information for now;
        # decode no longer needs per-request additional_information.
        # NOTE(review): this branch runs whenever inputs_embeds is set, not only
        # for prefill; the callee skips requests with no remaining prompt tokens.
        if inputs_embeds is not None:
            # Prefill: overlay prompt_embeds and collect additional_information
            self._collect_additional_information_for_prefill(num_scheduled_tokens_np)

        if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
            # Overlay custom prompt_embeds per request for the prompt portion;
            # collect additional_information (tensor/list) for prefill portion only
            # Ids of requests staged into the talker MTP buffers, in staging order.
            decode_req_ids = []
            if self.vllm_config.model_config.async_chunk:
                # Merge fresh additional_information dicts from the scheduler
                # into request state before running model.preprocess.
                self._update_additional_information(scheduler_output)
            for req_index, req_id in enumerate(self.input_batch.req_ids):
                req_state = self.requests.get(req_id)
                req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None

                # mimo-audio check
                req_infos = self._maybe_attach_mimo_audio_req_infos(req_state, req_infos, req_id)

                # This request's [s, e) token range in the flattened batch.
                start_offset = int(self.query_start_loc.cpu[req_index])
                sched_tokens = int(num_scheduled_tokens_np[req_index])
                s, e = start_offset, start_offset + sched_tokens
                span_len = int(e) - int(s)

                # call the custom process function
                # NOTE(review): req_infos may still be None here (no stored
                # additional_information_cpu and not a MiMoAudio model), in which
                # case ``**req_infos`` raises TypeError — confirm it is always set.
                embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None
                req_input_ids, req_embeds, update_dict = self.model.preprocess(
                    input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos
                )
                if inputs_embeds is None:
                    # Lazily allocate a full-batch embeddings tensor using the
                    # hidden size / dtype / device of the first returned embeds.
                    inputs_embeds = torch.empty(
                        (input_ids.shape[0], req_embeds.shape[-1]),
                        device=req_embeds.device,
                        dtype=req_embeds.dtype,
                    )

                # Single-token spans are staged for the talker MTP step: their
                # inputs go into the next free row of the persistent MTP buffers.
                if hasattr(self.model, "talker_mtp") and span_len == 1:
                    last_talker_hidden, text_step = update_dict.pop("mtp_inputs")
                    decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1)
                    self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids)
                    self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds)
                    self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden)
                    self.text_step.gpu[decode_slice].copy_(text_step)
                    decode_req_ids.append(req_id)

                # TODO(Peiqi): the merge stage could move out from the critical path
                self._merge_additional_information_update(req_id, update_dict)

                # update the inputs_embeds and input_ids
                # (input_ids are only overwritten when preprocess returned
                # exactly one id per written embedding row)
                seg_len = min(span_len, req_embeds.shape[0])
                inputs_embeds[s : s + seg_len] = req_embeds[:seg_len]
                if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len:
                    input_ids[s : s + seg_len] = req_input_ids

            # run talker mtp decode
            if hasattr(self.model, "talker_mtp"):
                self._talker_mtp_forward(decode_req_ids, inputs_embeds)

        return (
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        )

    def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None:
        """Run the talker MTP module for the staged decode requests.

        Inputs for the ``len(decode_req_ids)`` decode requests are expected to be
        staged, in order, in ``talker_mtp_input_ids`` /
        ``talker_mtp_inputs_embeds`` / ``last_talker_hidden`` / ``text_step``.
        After the forward:

        - each request's output embedding row is written into ``inputs_embeds``
          at the request's batch offset;
        - its code-predictor output is merged into the request's additional
          information under ``model.talker_mtp_output_key`` (default
          ``"code_predictor_codes"``).

        Args:
            decode_req_ids: Ids of the staged decode requests, in staging order.
                Each must be present in ``self.input_batch.req_ids``.
            inputs_embeds: Batch embeddings tensor, updated in place.
        """
        decode_batch_size = len(decode_req_ids)
        if decode_batch_size == 0:
            return
        _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(
            num_tokens=decode_batch_size,
            num_reqs=decode_batch_size,
            num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32),
            max_num_scheduled_tokens=1,
            use_cascade_attn=False,
        )
        # Force eager for unwrapped code predictors (AR loops / multinomial).
        if not isinstance(self.talker_mtp, CUDAGraphWrapper):
            _cudagraph_mode = CUDAGraphMode.NONE
        # Slice the persistent buffers to the token count from the batch descriptor.
        num_tokens_padded = batch_desc.num_tokens
        req_input_ids = self.talker_mtp_input_ids.gpu[:num_tokens_padded]
        req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded]
        last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded]
        text_step = self.text_step.gpu[:num_tokens_padded]
        with set_forward_context(
            None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
        ):
            req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
        # update the inputs_embeds and code_predictor_codes
        code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous()
        out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes")
        # Build the id -> batch-row map once; list.index per request was O(n^2).
        row_of_req = {rid: row for row, rid in enumerate(self.input_batch.req_ids)}
        for idx, req_id in enumerate(decode_req_ids):
            start_offset = int(self.query_start_loc.cpu[row_of_req[req_id]])
            inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
            update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]}
            self._merge_additional_information_update(req_id, update_dict)

    def _model_forward(
        self,
        input_ids: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **model_kwargs: Any,
    ):
        """Inject omni-specific kwargs into the model forward and cache its output.

        The extra kwargs from ``self._build_model_kwargs_extra()`` are passed to
        the upstream ``_model_forward`` together with ``model_kwargs``. A key
        present in both, or one that collides with a named parameter, raises
        ``TypeError`` (duplicate keyword argument).

        If the result is not already an ``OmniOutput`` and the model defines
        ``make_omni_output``, the result is converted with it. The converter
        receives the same extra kwargs.

        Returns:
            The (possibly converted) model output. It is also stored on
            ``self._omni_last_model_output``.
        """
        model_kwargs_extra = self._build_model_kwargs_extra()

        model_output = super()._model_forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
            **model_kwargs_extra,
        )
        if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"):
            model_output = self.model.make_omni_output(model_output, **model_kwargs_extra)
        # Cache model output so later sample_tokens can consume multimodal results.
        self._omni_last_model_output = model_output
        return model_output

    def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None:
        if not isinstance(upd, dict):
            return
        req_state = self.requests.get(req_id)
        if req_state is None:
            return
        existing = getattr(req_state, "additional_information_cpu", {})
        if not isinstance(existing, dict):
            existing = {}
        merged = dict(existing)
        for k, v in upd.items():
            if isinstance(v, torch.Tensor):
                merged[k] = v.detach().to("cpu").contiguous()
            elif isinstance(v, list):
                merged[k] = [
                    (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
                ]
            else:
                merged[k] = v
        setattr(req_state, "additional_information_cpu", merged)
