o
     iqO                     @   s  d dl Z d dlZd dlmZ d dlmZ d dlmZmZ d dl	m
Z
mZmZmZmZmZmZ d dlZd dlmZmZmZ d dlmZ d dlmZ d	Zzd d
lmZmZmZ dZW n eyc   d	ZY nw zd dl m!Z! d dl"m#Z#m$Z$ W n ey   ej%j&Z!G dd dZ'e'Z#e'Z$Y nw z
d dl"m(Z) da*W n ey   d dl"m)Z) d	a*Y nw ej+j,j-j.ej+j/j0j1ej+j,j2j.hZ3e)e3B Z4eG dd dZ5d=ddZ6G dd deZ7dd Z8G dd de#Z9G dd deZ:d=ddZ;ddd d!e
fd"d#Z<G d$d% d%eZ=d!ee5 fd&d'Z>d(e?d!efd)d*Z@d+ejAd,ejAd-e?d.eeB d/eeeBd0f  d1eeB d2eCd!ejAfd3d4ZDG d5d6 d6ZEG d7d8 d8e!ZF		d>d9ej%j&d(ee? d:ee fd;d<ZGdS )?    N)defaultdict)deepcopy)astuple	dataclass)AnyCallableContextManagerDictListOptionalTuple)
is_inplaceis_inplace_view_fn
is_view_fn)TorchDispatchMode)tree_mapF)BoundsLinearConstraintmilpT)ActivationWrapper)_CachedTorchDispatchMode_CachingTorchDispatchModec                   @   s   e Zd Zdd ZdS )_NotAvailablec                 O   s   t d)NzNeed PyTorch >= 2.2)RuntimeErrorselfargskwargs r   G/home/ubuntu/.local/lib/python3.10/site-packages/xformers/checkpoint.py__init__-   s   z_NotAvailable.__init__N)__name__
__module____qualname__r    r   r   r   r   r   ,   s    r   )SAC_IGNORED_OPS)_ignored_opsc                   @   sV   e Zd ZU eed< eed< eed< eed< eed< eeef ed< e	ed< e	ed< d	S )
ProfileMetadataname
time_takenmemory_usedcurr_idx
output_idsinplace_infois_view_like
is_rand_opN)
r!   r"   r#   str__annotations__floatintr   r   boolr   r   r   r   r&   F   s   
 r&   c                    s$   g d} d u r
|  fdd}|S )N)z4xformers.efficient_attention_forward_cutlass.defaultz xformers_flash.flash_fwd.defaultzaten.addmm.defaultzaten.mm.defaultc                    s   t | v S N)r/   )ctxfuncr   r   
allow_listr   r   _default_policy\   s   z,_get_default_policy.<locals>._default_policyr   )r8   _default_allow_listr9   r   r7   r   _get_default_policyR   s
   r;   c                   @   s   e Zd Zdd ZdddZdS )VerboseTorchDispatchModec                 C   s
   g | _ d S r4   )	operators)r   r   r   r   r    c   s   
z!VerboseTorchDispatchMode.__init__r   Nc                 C   s&   |d u ri }| j | ||i |S r4   )r=   appendr   r6   typesr   r   r   r   r   __torch_dispatch__f   s   z+VerboseTorchDispatchMode.__torch_dispatch__r   N)r!   r"   r#   r    rA   r   r   r   r   r<   b   s    r<   c                 O   sB   t  }| | |i | W d   |jS 1 sw   Y  |jS )zZ
    Returns the list of operators used inside `function` with
    *args and **kwargs
    N)r<   r=   )functionr   r   verbose_moder   r   r   list_operatorsm   s   
rE   c                       s$   e Zd Z fddZdd Z  ZS )CachedTorchDispatchModec                    s*   t rt ||| d S t || d S r4   )_PT_HAS_NEW_IMPLsuperr    )r   	policy_fnstorageallow_cache_entry_mutation	__class__r   r   r    y   s   z CachedTorchDispatchMode.__init__c                 C   s(   | j | r| j | dS ||i |S Nr   )rJ   pop)r   r6   r   r   r   r   r   pop_from_storage   s   
z(CachedTorchDispatchMode.pop_from_storage)r!   r"   r#   r    rP   __classcell__r   r   rL   r   rF   x   s    rF   c                   @   s   e Zd ZdddZdS )NullTorchDispatchModer   Nc                 C   s   |d u ri }||i |S r4   r   r?   r   r   r   rA      s   z(NullTorchDispatchMode.__torch_dispatch__rB   )r!   r"   r#   rA   r   r   r   r   rR      s    rR   c                 C   sr   | du rt  } nt| trt | } nt| sJ dtt}t r*tt| |}nt	 }t
t| |d}||fS )a  An activation checkpoint context_fn for selectively deciding what to
    store and what to recompute. Accepts a custom policy.
    Args:
        policy_fn(Union[List[Op], callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should
            be of form (func, *args, **kwargs) -> bool which indicates
            if func outputs with *args and **kwargs should be stored or not.
            Additionally, a list[Op] is also supported for easier cases.
            The op should be in the format `torch.ops.***`, where the `***`
            names of operators can be obtained with `list_operators`.
    Nz,policy_fn should be None, list or a callableT)r;   
isinstancelistcallabler   torchis_grad_enabledr   r   rR   rF   )rI   temp_storagecaching_modecached_moder   r   r   selective_checkpoint_context_fn   s   

r[   )preserve_rng_staterI   returnc                O   s,   t jjj| g|R d|tt|d|S )a8  Wrapper around torch.utils.checkpoint that accepts a custom policy
    function for selectively deciding what to store and what to recompute
    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional):  Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        policy_fn(Union[List[Op], callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should
            be of form (func, *args, **kwargs) -> bool which indicates
            if func outputs with *args and **kwargs should be stored or not.
            Additionally, a list[Op] is also supported for easier cases.
            The op should be in the format `torch.ops.***`, where the `***`
            names of operators can be obtained with `list_operators`.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    F)use_reentrantr\   
context_fn)rV   utils
checkpoint	functoolspartialr[   )rC   r\   rI   r   r   r   r   r   ra      s   
ra   c                   @   sJ   e Zd ZddeddfddZdeeeeedf f fdd	ZdddZdS )!ProfileOperatorsTorchDispatchMode
   num_runsr]   Nc                 C   s   g | _ || _d S r4   )datarf   )r   rf   r   r   r   r       s   
z*ProfileOperatorsTorchDispatchMode.__init__.c                 C   s   t | j}dd }t||}t|s||dfS |}d}t| jD ]\}}	|	j}
t|
ttt	fs2|
gn|
}
||
v r<|} nq |dk rC|}||f}|||fS )Nc                 S   s   t | tjr|   S d S r4   )rS   rV   Tensoruntyped_storagedata_ptr)er   r   r   get_tensor_id   s   zNProfileOperatorsTorchDispatchMode._get_inplace_metadata.<locals>.get_tensor_idr   r   )
lenrg   r   r   	enumerater+   rS   rT   tupledict)r   r6   outr*   rl   r+   op_idop_parent_ididpast_output_idsr,   r   r   r   _get_inplace_metadata   s*   



z7ProfileOperatorsTorchDispatchMode._get_inplace_metadatar   c                 C   s  |d u ri }||i |}|  ||\}}}t|pt|}	tjj|jv }
|jjdkr3|	dddk}
tj
  t }t| jD ]	}||i | qAtj
  t | | j }tj
  tj
 d }||i | tj
 d }| jt|||| ||||	|
 |S )N#_scaled_dot_product_flash_attention	dropout_pr   i   )rx   r   r   rV   Tagnondeterministic_seededtagsoverloadpacketr!   getcudasynchronizetimerangerf   reset_peak_memory_statsmax_memory_allocatedrg   r>   r&   )r   r6   r@   r   r   rr   r*   r+   r,   r-   r.   tru   r(   mem1mem2r   r   r   rA      s>   


z4ProfileOperatorsTorchDispatchMode.__torch_dispatch__)re   rB   )r!   r"   r#   r2   r    r   rx   rA   r   r   r   r   rd      s     rd   c                 G   s<   t  }| | |  W d   n1 sw   Y  |j}|S )a  
    Use ProfileOperatorsTorchDispatchMode to get runtime and memory info.

    Args:
        function: The function to optimize which will be selectively checkpointed. Usually the forward pass
            of the model.
        *args: Arguments to pass in to the given ``function``.

    Returns:
        A list of tuples, where each tuples contains the name of the operator, the runtime of the operator,
            and the memory usage of the operator.

    N)rd   rg   )rC   r   profile_opsrg   r   r   r   _analyze_operators  s   
r   memory_budgetc             	      s\  t std|dk s|dkrtd| dt| g|R  }dd |D }tdd |D  \}}} }}}	}
tj|tjd	}tj|tjd	}d
d t|	D }dd t|
D } fdd|D }t	|d }t
|t
dd |D B }tt|}t|D ]
}||kr|d8 }q~d||< ||   }tdd |D }t|||||||d}t|dS )a  
    Given a function, its arguments, and the maximum amount of memory available,
    find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget.

    Args:
        function: The function to optimize which will be selectively checkpointed. Usually the forward pass
            of the model.
        *args: Arguments to pass in to the given ``function``.
        memory_budget (float): A float between 0 and 1 which describes what percentage of the total memory to use.

    Returns:
        A callable policy which can be passed to xformers.checkpoint()

    Raises:
        RuntimeError: If `scipy` is not available.
        ValueError: If `memory_budget` is not a float between 0 and 1.

    zlPlease install scipy 1.9.0+ to use `get_optimal_checkpoint_policy`. You can do so using `pip install scipy`.r      z5`memory_budget` must be a float between 0 and 1. Got .c                 S   s   g | ]	}|j tvr|qS r   )r'   OPS_TO_ALWAYS_SKIP.0xr   r   r   
<listcomp>U      z1get_optimal_checkpoint_policy.<locals>.<listcomp>c                 S   s   g | ]}t |qS r   )r   r   r   r   r   r   X      )dtypec                 S      g | ]\}}|r|qS r   r   r   ru   r   r   r   r   r   \      c                 S   r   r   r   r   r   r   r   r   ]  r   c                    s    g | ]}|rt t j|qS r   )rp   mapindexr   new_idsr   r   r   `  s     c                 S   s   g | ]}|d  qS )r   r   r   r   r   r   r   i  r   c                 S   s   g | ]	}t |tj qS r   )rS   rV   rh   r   r   r   r   r   t  r   )memoryruntimes
max_memoryview_like_opsinplace_ops
random_opsforce_store_random)optim_output)_scipy_is_availabler   
ValueErrorr   ziprV   tensorfloat64ro   rn   setsortedrT   reversedsumitemall#_optimize_runtime_with_given_memory_OptimalPolicy)rC   r   r   rg   ops	runtimes_memory__inplace_ops_view_like_ops_	rand_ops_r   r   r   rand_opsr   last_op	skip_ops_skip_opsopr   r   r   r   r   r   get_optimal_checkpoint_policy5  sL   

	r   r   r   r   r   r   .r   r   c                 C   s   | }t | |d}|g}	|D ]}
t|}d||
< |	t |ddd q|D ].\}}t|}||krFd||< d||< |	t |ddd q&d||< |	t |ddd q&|D ]}
t|}d||
< t|}|	t |||d qWt|}t||	|tddd}|jst	dt
|j}|S )aD  
    Given a list of operator names, their corresponding runtimes, and the maximum amount of memory available,
    find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget.
    Uses https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.milp.html

    Args:
        memory (torch.Tensor): Tensor containing the memory usage of each operator.
        runtimes (torch.Tensor): Tensor containing the runtime of each operator.
        max_memory (float): Maximum amount of memory to use.
        view_like_ops ([List[int]): Indices of the view-like ops.
        inplace_ops (List[Tuple[int, int]]): Tuple with the pair of inplace op -> parent of inplace op.
            This will be used to add the constraint that in-place ops need to either be
            stored in memory with the previous op, or recomputed with the previous op.
        random_ops ([List[int]): Indices of the random ops, which will always be recomputed.
        force_store_random (bool): force random ops to always be stored (instead of recomputed)
    )Aubr   r   )r   lbr   rm   )cconstraintsintegralityboundszThe problem is infeasible, and probably due to a change in xformers that makes random ops always be stored. Try passing a larger memory_budget. This will be fixed once https://github.com/pytorch/pytorch/issues/121212 is solved)r   rV   
zeros_liker>   r2   	ones_liker   r   successr   
from_numpyr   )r   r   r   r   r   r   r   r   memory_constraintr   ru   r   r   	op_parentvalr   resr   r   r   r   r     s<   



r   c                   @   s*   e Zd ZdejfddZdefddZdS )r   r   c                 C   s   d| _ | | _d S rN   )countertolistr   )r   r   r   r   r   r      s   z_OptimalPolicy.__init__r]   c                 O   s.   |t v rdS | j}|  jd7  _| j| dkS )NFr   )r   r   r   )r   r5   r6   r   r   countr   r   r   __call__  s
   z_OptimalPolicy.__call__N)r!   r"   r#   rV   rh   r    r3   r   r   r   r   r   r     s    r   c                       s>   e Zd Zd
 fdd	Zejjdd Zdd Zdd	 Z	  Z
S )SelectiveCheckpointWrapperNc                    sV   t  | |d u |d u A std|| _|| _zdtjj_W d S  t	y*   Y d S w )Nz1Need to specify either policy_fn or memory_budgetT)
rH   r    r   r   rI   rV   _dynamoconfig:_experimental_support_context_fn_in_torch_utils_checkpointAttributeError)r   modr   rI   rL   r   r   r      s   z#SelectiveCheckpointWrapper.__init__c                 O   s   t  sg S t j  t| jg|R i |d| ji}W d    n1 s'w   Y  t j rLt j	 rLt j
 dkrL|g}t jj|dd |d }|S )Nr   r   r   )src)rV   rW   randomfork_rngr   _checkpoint_wrapped_moduler   distributedis_availableis_initializedget_world_sizebroadcast_object_list)r   r   r   rI   objectsr   r   r   _get_policy_fn  s,   
z)SelectiveCheckpointWrapper._get_policy_fnc                 O   s"   | j d u r| j|i || _ | j S r4   )rI   r   r   r   r   r   get_policy_fn  s   
z(SelectiveCheckpointWrapper.get_policy_fnc                 O   s0   | j |i |}t| jg|R i |d|iS )NrI   )r   ra   r   )r   r   r   rI   r   r   r   forward  s   z"SelectiveCheckpointWrapper.forwardNN)r!   r"   r#   r    rV   compilerdisabler   r   r   rQ   r   r   rL   r   r     s    
r   modulerI   c                 C   s   t | ||S )a  
    Wrap a module with selective activation checkpointing.

    It behaves similarly to PyTorch's checkpoint_wrapper, but gives the possibility
    to the user to either specify a handcrafted policy_fn, or to let an optimization
    algorithm to select the policy given a user-specified memory_budget.

    The user should either specify the memory_budget argument or the policy_fn.

    The memory_budget is a float value between 0 (recompute everything in the backward) or 1
    (store everything for backward). Using a value of 0 should be similar to PyTorch's
    activation checkpoint, while 1 should be similar to the behavior of not using any
    activation checkpointing.
    )r   )r   r   rI   r   r   r   selective_checkpoint_wrapper  s   r   r4   r   )Hrb   r   collectionsr   copyr   dataclassesr   r   typingr   r   r   r	   r
   r   r   rV   ,torch.testing._internal.composite_compliancer   r   r   torch.utils._python_dispatchr   torch.utils._pytreer   r   scipy.optimizer   r   r   ImportError;torch.distributed.algorithms._checkpoint.checkpoint_wrapperr   torch.utils.checkpointr   r   nnModuler   r$   r%   rG   r   aten
lift_freshdefaultprofiler_record_function_exit_RecordFunctionclone_additional_ignored_opsr   r&   r;   r<   rE   rF   rR   r[   ra   rd   r   r1   r   rh   r2   r3   r   r   r   r   r   r   r   r   <module>   s   $




 
"NM
K6