# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and the HuggingFace Team.
# All rights reserved.
#
# This module is adapted from HuggingFace diffusers library:
#   diffusers/src/diffusers/hooks/context_parallel.py
#
# NOTE: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers.
# We use the term "Sequence Parallelism" to align with vLLM-Omni's existing terminology.
#
# Key adaptations from diffusers:
#   - ModuleForwardMetadata: parameter lookup logic (adapted)
#   - SequenceParallelSplitHook/GatherHook: hook structure (adapted from ContextParallel*)
#   - apply_sequence_parallel: registration logic (adapted from apply_context_parallel)
#
# Key differences from diffusers:
#   - Uses vLLM-Omni's SequenceParallelGroupCoordinator instead of DeviceMesh
#   - Uses sp_shard/sp_gather from sp_sharding.py instead of funcol operations
#   - Supports Ulysses + Ring hybrid parallelism
#
"""Sequence Parallelism hooks for non-intrusive SP support.

This module implements the hook-based mechanism for applying sequence parallelism
to models without modifying their forward() methods.

Usage:
    1. Define _sp_plan on your model class (corresponds to diffusers' _cp_plan)
    2. Call apply_sequence_parallel(model, config, plan) to enable SP
    3. Call remove_sequence_parallel(model, plan) to disable SP

The hooks automatically shard inputs before forward and gather outputs after,
based on the plan specification.
"""

from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
from vllm.logger import init_logger

from vllm_omni.diffusion.distributed.sp_plan import (
    AnySequenceParallelInput,
    SequenceParallelConfig,
    SequenceParallelInput,
    SequenceParallelModelPlan,
    SequenceParallelOutput,
    SequenceParallelPartialInput,
)
from vllm_omni.diffusion.distributed.sp_sharding import sp_gather, sp_shard
from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook

logger = init_logger(__name__)

# Hook name templates for identifying SP hooks.
# Formatted with the dotted module path from the plan (e.g. "sp_input---proj_out"),
# so apply/remove can address the exact same hook by name.
_SP_INPUT_HOOK_TEMPLATE = "sp_input---{}"
_SP_OUTPUT_HOOK_TEMPLATE = "sp_output---{}"


@dataclass
class ModuleForwardMetadata:
    """Caches the mapping from forward() parameter names to positional indices.

    The forward signature of the wrapped class is inspected once, on the first
    positional lookup; subsequent lookups reuse the cached name -> position
    table instead of re-running `inspect.signature`.
    """

    # name -> positional index table, built lazily on first lookup.
    cached_parameter_indices: dict[str, int] | None = None
    # The (unwrapped) module class whose forward() signature is inspected.
    _cls: type | None = None

    def _get_parameter_from_args_kwargs(
        self,
        identifier: str,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> tuple[Any, bool, int | None]:
        """Locate a forward() parameter by name in args or kwargs.

        Args:
            identifier: The parameter name to look up.
            args: Positional arguments passed to forward.
            kwargs: Keyword arguments passed to forward.

        Returns:
            Tuple of (value, is_kwarg, index):
            - value: the parameter value, or None when it was not supplied
              positionally
            - is_kwarg: True when the value came from kwargs
            - index: the parameter's positional index (None for kwargs hits)

        Raises:
            ValueError: If the parameter is not in the signature, or no
                class was set on this metadata.
        """
        kwargs = {} if kwargs is None else kwargs

        # Keyword arguments win over positional lookup.
        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is None:
            # Lazily build the name -> position table from the forward signature.
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            names = list(inspect.signature(self._cls.forward).parameters)[1:]  # drop `self`
            self.cached_parameter_indices = {name: pos for pos, name in enumerate(names)}
            if identifier not in self.cached_parameter_indices:
                raise ValueError(f"Parameter '{identifier}' not found in function signature.")
            index = self.cached_parameter_indices[identifier]
        else:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")

        if index < len(args):
            return args[index], False, index
        # Known parameter, but not supplied positionally in this call.
        return None, False, index


def _unwrap_module(module: nn.Module) -> nn.Module:
    """Unwrap a module from any wrappers to get the original class.

    Args:
        module: Potentially wrapped module.

    Returns:
        The unwrapped module.
    """
    # Handle common wrappers
    while hasattr(module, "_modules") and len(module._modules) == 1:
        inner = next(iter(module._modules.values()))
        if inner is not None:
            module = inner
        else:
            break
    return module


class SequenceParallelSplitHook(ModelHook):
    """Hook for splitting inputs before a module's forward pass.

    This hook is registered to modules that need their inputs sharded
    across sequence parallel ranks. It intercepts the forward call,
    shards specified inputs according to the plan, and passes the
    sharded inputs to the original forward.

    For split_output=True inputs, it shards the output instead.

    Supports both SequenceParallelInput (full split) and SequenceParallelPartialInput
    (partial split for text/image separation).

    Note: This corresponds to `ContextParallelSplitHook` in diffusers.
    """

    def __init__(
        self,
        metadata: dict[str | int, AnySequenceParallelInput | list[AnySequenceParallelInput]],
        config: SequenceParallelConfig,
    ) -> None:
        # metadata keys are forward() parameter names (str) or, for
        # split_output entries, integer indices into the forward output.
        super().__init__()
        self.metadata = metadata
        self.config = config
        # Populated in initialize_hook once the wrapped class is known.
        self.module_forward_metadata: ModuleForwardMetadata | None = None
        # Cache for text lengths resolved from kwargs
        self._text_len_cache: dict[str, int] = {}

    def initialize_hook(self, module: nn.Module) -> nn.Module:
        # Inspect the unwrapped class so the cached signature matches the
        # real forward(), not that of a single-child wrapper.
        cls = _unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tuple, dict]:
        """Shard inputs before forward.

        Iterates the plan metadata, shards each plan-listed tensor input
        along its configured split_dim, and returns the updated
        (args, kwargs) to be passed to the wrapped forward.
        """
        args_list = list(args)
        # Clear text length cache for this forward pass
        self._text_len_cache.clear()

        for name, spm in self.metadata.items():
            # Skip if this is a split_output entry (handled in post_forward)
            if isinstance(spm, (SequenceParallelInput, SequenceParallelPartialInput)) and spm.split_output:
                continue

            # Get the parameter value
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                # Parameter not supplied in this call; nothing to shard.
                continue

            # Shard the input
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_sp_input(input_val, spm, args_list, kwargs)
            elif isinstance(input_val, (list, tuple)):
                # Handle list/tuple of tensors with per-element config
                if not isinstance(spm, (list, tuple)):
                    raise ValueError(
                        f"Expected list/tuple of SequenceParallelInput for parameter '{name}' "
                        f"which is a list/tuple, but got {type(spm).__name__}"
                    )
                if len(input_val) != len(spm):
                    raise ValueError(f"Expected {len(spm)} elements for parameter '{name}', got {len(input_val)}")
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    # Non-tensor elements and split_output entries pass through unchanged.
                    if torch.is_tensor(x) and not spm[i].split_output:
                        x = self._prepare_sp_input(x, spm[i], args_list, kwargs)
                    sharded_input_val.append(x)
                # Preserve the caller's container type (list vs tuple).
                input_val = type(input_val)(sharded_input_val)
            else:
                raise ValueError(f"Unsupported input type for sharding: {type(input_val).__name__}")

            # Update args or kwargs
            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(f"Failed to update parameter '{name}' after sharding.")

        # Store kwargs for post_forward to resolve text lengths
        self._last_kwargs = kwargs
        self._last_args = tuple(args_list)

        return tuple(args_list), kwargs

    def post_forward(self, module: nn.Module, output: Any) -> Any:
        """Shard outputs for split_output=True entries.

        Only int-keyed plan entries with split_output=True are handled here;
        str-keyed input entries were already processed in pre_forward.
        """
        from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available

        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            # No tensor outputs to shard
            return output

        output_list = [output] if is_tensor else list(output)
        actually_sharded = False

        for index, spm in self.metadata.items():
            if not isinstance(index, int):
                continue
            if not isinstance(spm, (SequenceParallelInput, SequenceParallelPartialInput)) or not spm.split_output:
                continue
            if index >= len(output_list):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output_list)}.")

            # _last_args/_last_kwargs were captured in pre_forward; needed to
            # resolve text_len_source for SequenceParallelPartialInput.
            original = output_list[index]
            output_list[index] = self._prepare_sp_input(original, spm, self._last_args, self._last_kwargs)
            if output_list[index] is not original:
                actually_sharded = True

        # Mark SP as active only if at least one tensor was actually sharded
        if actually_sharded and is_forward_context_available():
            get_forward_context()._sp_shard_depth += 1

        # Preserve the original container type (tensor/list/tuple).
        return output_list[0] if is_tensor else type(output)(output_list)

    def _resolve_text_len(
        self,
        sp_input: SequenceParallelPartialInput,
        args: tuple,
        kwargs: dict,
    ) -> int:
        """Resolve text length from the source specification.

        The source may be a literal int, or the name of a forward() parameter
        whose value (int, or tensor whose first dim is the length) provides it.
        Resolved lengths are cached per forward pass in _text_len_cache.
        """
        source = sp_input.text_len_source

        if isinstance(source, int):
            return source

        # String source - look up from kwargs or args
        if source in self._text_len_cache:
            return self._text_len_cache[source]

        # Try to get from kwargs/args
        try:
            val, _, _ = self.module_forward_metadata._get_parameter_from_args_kwargs(source, args, kwargs)
            if val is None:
                raise ValueError(f"Parameter '{source}' is None, cannot determine text length.")
            if isinstance(val, torch.Tensor):
                # TODO: Currently assumes batch_size=1, where shape[0] is sequence length.
                # For batch inference support, this should be updated to handle
                # shape (batch_size, seq_len, ...) where text_len varies per sample.
                text_len = val.shape[0]
            elif isinstance(val, int):
                text_len = val
            else:
                raise ValueError(f"Cannot determine text length from '{source}' of type {type(val).__name__}")
            self._text_len_cache[source] = text_len
            return text_len
        except ValueError as e:
            raise ValueError(f"Failed to resolve text_len_source '{source}': {e}") from e

    def _prepare_sp_input(
        self,
        x: torch.Tensor,
        sp_input: AnySequenceParallelInput,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> torch.Tensor:
        """Shard a tensor according to the input specification.

        Args:
            x: The tensor to shard.
            sp_input: Full or partial split specification from the plan.
            args: Positional forward args (used to resolve text lengths).
            kwargs: Keyword forward args (used to resolve text lengths).

        Returns:
            This rank's shard of ``x`` (or ``x`` unchanged when the dim
            check fails and the split is skipped).
        """
        kwargs = kwargs or {}

        if sp_input.expected_dims is not None and x.dim() != sp_input.expected_dims:
            # Dim mismatch is treated as "not the tensor this spec targets":
            # warn once and leave the tensor unsharded rather than erroring.
            logger.warning_once(f"Expected tensor with {sp_input.expected_dims} dims, got {x.dim()}. Skipping split.")
            return x

        if isinstance(sp_input, SequenceParallelInput):
            # Full split with optional auto-padding
            if sp_input.auto_pad:
                return self._shard_with_auto_pad(x, sp_input.split_dim)
            return sp_shard(x, sp_input.split_dim, validate=False)
        elif isinstance(sp_input, SequenceParallelPartialInput):
            # Partial split: keep text portion, split image portion
            text_len = self._resolve_text_len(sp_input, args, kwargs)
            dim = sp_input.split_dim

            # Split tensor into text and image portions
            text_part = x.narrow(dim, 0, text_len)
            image_part = x.narrow(dim, text_len, x.size(dim) - text_len)

            # Only shard the image portion
            image_part_sharded = sp_shard(image_part, dim, validate=False)

            # Concatenate back: [text_full, image_sharded]
            return torch.cat([text_part, image_part_sharded], dim=dim)
        else:
            raise ValueError(f"Unsupported input config type: {type(sp_input).__name__}")

    def _shard_with_auto_pad(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Shard tensor with automatic padding and attention mask creation.

        When sequence length is not divisible by SP world size, this method:
        1. Pads the tensor to make it divisible
        2. Creates an attention mask indicating valid vs padding positions
        3. Stores the mask and padding info in ForwardContext

        Raises:
            ValueError: If padding would be needed but the active attention
                backend (or Ring attention) cannot honor an attention mask.
        """
        from vllm_omni.diffusion.attention.selector import get_attn_backend
        from vllm_omni.diffusion.distributed.parallel_state import (
            get_ring_parallel_world_size,
            get_sequence_parallel_rank,
            get_sequence_parallel_world_size,
        )
        from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available

        world_size = get_sequence_parallel_world_size()
        if world_size == 1:
            # No parallelism: nothing to shard or pad.
            return x

        seq_len = x.size(dim)
        remainder = seq_len % world_size

        if remainder == 0:
            # No padding needed
            return sp_shard(x, dim, validate=False)

        # Check backend compatibility
        attn_backend = get_attn_backend(-1)
        if not attn_backend.supports_attention_mask:
            raise ValueError(
                f"Sequence length ({seq_len}) is not divisible by SP world size ({world_size}). "
                f"Cannot use {attn_backend.get_name()} which does not support attention_mask. "
                f"Please switch to SDPA or Ascend attention backend."
            )

        # Ring attention does not support attention_mask
        if get_ring_parallel_world_size() > 1:
            raise ValueError(
                f"Sequence length ({seq_len}) is not divisible by SP world size ({world_size}). "
                f"Cannot use Ring attention which does not support attention_mask. "
                f"Please switch to Ulysses SP only."
            )

        # Calculate padding
        pad_size = world_size - remainder
        padded_seq_len = seq_len + pad_size

        # Pad the tensor
        pad_shape = list(x.shape)
        pad_shape[dim] = pad_size
        padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
        x_padded = torch.cat([x, padding], dim=dim)

        # Store padding info in forward context (only once, for primary tensor)
        # Attention layers will create masks dynamically using this info
        if is_forward_context_available():
            ctx = get_forward_context()
            # Only set if not already set (first auto_pad tensor wins)
            if ctx.sp_original_seq_len is None:
                ctx.sp_padding_size = pad_size
                ctx.sp_original_seq_len = seq_len
                logger.debug(
                    f"Auto-padded sequence from {seq_len} to {padded_seq_len} "
                    f"(pad_size={pad_size}, world_size={world_size}, dim={dim})"
                )

        # Shard the padded tensor
        rank = get_sequence_parallel_rank()
        return x_padded.chunk(world_size, dim=dim)[rank]


class SequenceParallelGatherHook(ModelHook):
    """Hook for gathering outputs after a module's forward pass.

    This hook is registered to modules that need their outputs gathered
    from all sequence parallel ranks. It intercepts the output and gathers
    it according to the plan specification.

    Note: This corresponds to `ContextParallelGatherHook` in diffusers.
    """

    def __init__(
        self,
        metadata: SequenceParallelOutput | list[SequenceParallelOutput],
        config: SequenceParallelConfig,
    ) -> None:
        super().__init__()
        # Normalize to a list so post_forward can zip it against the outputs.
        if isinstance(metadata, SequenceParallelOutput):
            metadata = [metadata]
        self.metadata = metadata
        self.config = config

    def initialize_hook(self, module: nn.Module) -> nn.Module:
        return module

    def post_forward(self, module: nn.Module, output: Any) -> Any:
        """Gather outputs after forward and remove padding if applied.

        Args:
            module: The module whose forward just completed.
            output: A tensor, or a list/tuple of tensors, produced by forward.

        Returns:
            The gathered output(s), with the original container type
            preserved (tensor in -> tensor out, tuple in -> tuple out).

        Raises:
            ValueError: If output is not a tensor or list/tuple of tensors,
                or its length does not match the plan metadata.
        """
        from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available

        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            outputs = [output]
        elif isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output):
            outputs = list(output)
        else:
            raise ValueError(f"Expected tensor or list/tuple of tensors, got {type(output).__name__}")

        if len(outputs) != len(self.metadata):
            raise ValueError(f"Expected {len(self.metadata)} outputs, got {len(outputs)}.")

        # Check if padding was applied during split
        original_seq_len = None
        if is_forward_context_available():
            ctx = get_forward_context()
            original_seq_len = ctx.sp_original_seq_len

        actually_gathered = False

        for i, spm in enumerate(self.metadata):
            if spm is None:
                # None entries mark outputs that should be left untouched.
                continue

            x = outputs[i]
            if spm.expected_dims is not None and x.dim() != spm.expected_dims:
                logger.warning_once(
                    f"Expected output tensor with {spm.expected_dims} dims, got {x.dim()}. Skipping gather."
                )
                continue

            # Gather from all ranks
            gathered = sp_gather(x, spm.gather_dim, validate=False)

            # Remove padding if it was applied
            if original_seq_len is not None and gathered.size(spm.gather_dim) > original_seq_len:
                gathered = gathered.narrow(spm.gather_dim, 0, original_seq_len)
                logger.debug(f"Removed padding: gathered shape {gathered.shape} (original_seq_len={original_seq_len})")

            outputs[i] = gathered
            actually_gathered = True

        # Mark SP as inactive only if at least one tensor was actually gathered
        if actually_gathered and is_forward_context_available():
            ctx = get_forward_context()
            ctx._sp_shard_depth = max(0, ctx._sp_shard_depth - 1)

        # Bug fix: the previous implementation rebound `output` to a list before
        # the final `type(output)(...)`, so tuple outputs were silently returned
        # as lists. Build results in `outputs` and keep `output` untouched so the
        # original container type is preserved, consistent with
        # SequenceParallelSplitHook.post_forward.
        return outputs[0] if is_tensor else type(output)(outputs)


def _get_submodule_by_name(model: nn.Module, name: str) -> nn.Module | list[nn.Module]:
    """Get a submodule by dotted name, supporting wildcards.

    Args:
        model: The root module.
        name: Dotted path to submodule. Use "*" to match all children
            of a ModuleList.

    Returns:
        The submodule, or a list of submodules when a wildcard is used.

    Raises:
        ValueError: If more than one wildcard appears, or the path does
            not resolve to a submodule.
    """
    # At most one wildcard is supported per path.
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: nn.Module, name: str) -> nn.Module | list[nn.Module]:
    """Recursive helper for _get_submodule_by_name."""
    if name == "":
        return model

    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")

    if first_atom == "*":
        if not isinstance(model, nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")


def apply_sequence_parallel(
    module: nn.Module,
    config: SequenceParallelConfig,
    plan: SequenceParallelModelPlan,
) -> None:
    """Apply sequence parallel hooks to a model according to the plan.

    Registers a SequenceParallelSplitHook (for dict-valued plan entries,
    which describe inputs) or a SequenceParallelGatherHook (for
    SequenceParallelOutput-valued entries) on every submodule matched by
    each plan key.

    Note: This corresponds to `apply_context_parallel` in diffusers.

    The complete SP flow is:
    1. Input sharding (SequenceParallelSplitHook): Split sequence across SP ranks
    2. Attention parallelism (handled by vLLM-Omni's Attention layer):
       - Ulysses: All-to-All over Q/K/V heads
       - Ring: K/V circulation in ring topology
       - Hybrid: Both (Ulysses handles head redistribution, Ring handles K/V)
    3. Output gathering (SequenceParallelGatherHook): Gather sequence from SP ranks

    Args:
        module: The model to apply SP to.
        config: The sequence parallel configuration.
        plan: Dictionary mapping module names to input/output specifications.

    Raises:
        ValueError: If a plan entry has an unsupported type.

    Example:
        config = SequenceParallelConfig(ulysses_degree=2)
        plan = {
            "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
            "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
        }
        apply_sequence_parallel(model, config, plan)

    Note:
        vLLM-Omni's Attention layer automatically handles the internal
        parallelism (Ulysses All-to-All or Ring attention) based on the
        forward_context configuration. This function only handles input/output
        sharding for the model as a whole.
    """
    logger.debug(
        f"Applying sequence parallel with config: ulysses={config.ulysses_degree}, "
        f"ring={config.ring_degree}, plan keys: {list(plan.keys())}"
    )

    for module_id, sp_model_plan in plan.items():
        targets = _get_submodule_by_name(module, module_id)
        targets = targets if isinstance(targets, list) else [targets]

        logger.debug(f"Applying SP hooks to '{module_id}' ({len(targets)} module(s))")

        for target in targets:
            if isinstance(sp_model_plan, dict):
                # Input specification: shard named forward() inputs.
                hook_name = _SP_INPUT_HOOK_TEMPLATE.format(module_id)
                hook = SequenceParallelSplitHook(sp_model_plan, config)
            elif isinstance(sp_model_plan, (SequenceParallelOutput, list, tuple)):
                # Output specification: gather forward() outputs.
                if isinstance(sp_model_plan, SequenceParallelOutput):
                    sp_model_plan = [sp_model_plan]
                if not all(x is None or isinstance(x, SequenceParallelOutput) for x in sp_model_plan):
                    raise ValueError(f"Expected SequenceParallelOutput elements, got {sp_model_plan}")
                hook_name = _SP_OUTPUT_HOOK_TEMPLATE.format(module_id)
                hook = SequenceParallelGatherHook(sp_model_plan, config)
            else:
                raise ValueError(f"Unsupported plan type: {type(sp_model_plan).__name__}")

            HookRegistry.get_or_create(target).register_hook(hook_name, hook)


def remove_sequence_parallel(
    module: nn.Module,
    plan: SequenceParallelModelPlan,
) -> None:
    """Remove sequence parallel hooks from a model.

    Note: This corresponds to `remove_context_parallel` in diffusers.

    Args:
        module: The model to remove SP from.
        plan: The same plan used when applying SP.
    """
    for module_id, sp_model_plan in plan.items():
        targets = _get_submodule_by_name(module, module_id)
        targets = targets if isinstance(targets, list) else [targets]

        # Resolve the hook name once per plan entry; unknown entry types
        # were never registered and are skipped.
        if isinstance(sp_model_plan, dict):
            hook_name = _SP_INPUT_HOOK_TEMPLATE.format(module_id)
        elif isinstance(sp_model_plan, (SequenceParallelOutput, list, tuple)):
            hook_name = _SP_OUTPUT_HOOK_TEMPLATE.format(module_id)
        else:
            continue

        for target in targets:
            registry = getattr(target, "_hook_registry", None)
            if registry is not None:
                registry.remove_hook(hook_name)


def enable_sequence_parallel_for_model(
    model: nn.Module,
    config: SequenceParallelConfig | None = None,
) -> None:
    """Enable sequence parallelism for a model using its _sp_plan.

    Convenience wrapper: reads the model's _sp_plan attribute and applies
    sequence parallelism via apply_sequence_parallel.

    Note: This corresponds to `enable_context_parallel_for_model` in diffusers,
    but uses vLLM-Omni's _sp_plan instead of diffusers' _cp_plan.

    The function performs two main tasks:
    1. Applies _sp_plan hooks to shard inputs and gather outputs
    2. Ensures Attention layers are configured for the correct parallel mode
       (handled automatically by vLLM-Omni's forward_context mechanism)

    Args:
        model: The model to enable SP for. Must have a _sp_plan attribute.
        config: Optional config. If None, uses default based on current
            parallel state.

    Raises:
        ValueError: If model has no _sp_plan defined.

    Note:
        vLLM-Omni supports Ulysses + Ring hybrid parallelism:
        - ulysses_degree > 1: Uses All-to-All communication over Q/K/V heads
        - ring_degree > 1: Uses Ring attention with K/V passing
        - Both > 1: Hybrid mode (Ulysses handles head redistribution,
          Ring handles K/V circulation)
    """
    from vllm_omni.diffusion.distributed.parallel_state import (
        get_ring_parallel_world_size,
        get_ulysses_parallel_world_size,
    )
    from vllm_omni.diffusion.distributed.sp_plan import get_sp_plan_from_model

    plan = get_sp_plan_from_model(model)
    if plan is None:
        raise ValueError(
            f"Model {model.__class__.__name__} has no _sp_plan defined. "
            f"Define _sp_plan as a class attribute or pass a plan explicitly."
        )

    if config is None:
        # Derive degrees from the live parallel state.
        ulysses_degree = get_ulysses_parallel_world_size()
        ring_degree = get_ring_parallel_world_size()
        config = SequenceParallelConfig(
            ulysses_degree=ulysses_degree,
            ring_degree=ring_degree,
        )
        # NOTE(review): when both degrees are 1 this reports mode=ring even
        # though no SP is effectively active — confirm that is intended.
        if ulysses_degree > 1:
            mode = "hybrid" if ring_degree > 1 else "ulysses"
        else:
            mode = "ring"
        logger.info(
            f"Created SP config from parallel state: "
            f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, "
            f"mode={mode}"
        )

    apply_sequence_parallel(model, config, plan)
    logger.info(f"Enabled sequence parallelism for {model.__class__.__name__}")


def disable_sequence_parallel_for_model(model: nn.Module) -> None:
    """Disable sequence parallelism for a model.

    Removes the hooks registered from the model's _sp_plan; a model without
    an _sp_plan is left untouched.

    Note: This corresponds to `disable_context_parallel_for_model` in diffusers.

    Args:
        model: The model to disable SP for.
    """
    from vllm_omni.diffusion.distributed.sp_plan import get_sp_plan_from_model

    sp_plan = get_sp_plan_from_model(model)
    if sp_plan is None:
        return
    remove_sequence_parallel(model, sp_plan)
