o
    X۷ib                     @  s*  d dl mZ d dlZd dlZd dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ z
d dlmZ dZW n ey?   d	ZY nw ejrlejejejejejejejejejejejejejd
ZejejejejdZ ni Zi Z dddZ!G dd de
Z"G dd dZ#dd Z$dd Z%G dd dZ&dS )    )annotationsN)nccl)_store)_Backend)sparse)MPITF)bBiIlLqQefdFD)sumprodmaxminc                 C  sT   | j j}|tvrtd| j  dt| }|d u r| j}|dv r&|d| fS ||fS )NUnknown dtype 	 for NCCLFD   )dtypechar_nccl_dtypes	TypeErrorsize)arraycountr   
nccl_dtype r%   R/home/ubuntu/vllm_env/lib/python3.10/site-packages/cupyx/distributed/_nccl_comm.py_get_nccl_dtype_and_count0   s   r'   c                      s   e Zd ZdZejejdf fdd	Zdd Zdd Z	d	d
 Z
dd Zdd Zdd Zd,ddZd-ddZd.ddZ	d,ddZd/ddZd/ddZd/d d!Zd/d"d#Zd.d$d%Zd.d&d'Zd/d(d)Zd*d+ Z  ZS )0NCCLBackenda  Interface that uses NVIDIA's NCCL to perform communications.

    Args:
        n_devices (int): Total number of devices that will be used in the
            distributed execution.
        rank (int): Unique id of the GPU that the communicator is associated to
            its value needs to be `0 <= rank < n_devices`.
        host (str, optional): host address for the process rendezvous on
            initialization. Defaults to `"127.0.0.1"`.
        port (int, optional): port used for the process rendezvous on
            initialization. Defaults to `13333`.
        use_mpi(bool, optional): switch between MPI and use the included TCP
            server for initialization & synchronization. Defaults to `False`.
    Fc                   sF   t  |||| to|| _| jr| || d S | |||| d S N)super__init___mpi_available_use_mpi_init_with_mpi_init_with_tcp_store)self	n_devicesrankhostportuse_mpi	__class__r%   r&   r+   L   s
   
zNCCLBackend.__init__c                 C  sX   t j| _| j | _| j  d }| jdkrt }| jj|dd}t	|||| _
d S )Nr   root)r   
COMM_WORLD	_mpi_commGet_rank	_mpi_rankBarrierr   get_unique_idbcastNcclCommunicator_comm)r0   r1   r2   nccl_idr%   r%   r&   r.   V   s   

zNCCLBackend._init_with_mpic                 C  s`   d }|dkr| j || t }|| jd< | j  n
| j  | jd }t|||| _d S )Nr   rC   )r   runr   r?   _store_proxybarrierrA   rB   )r0   r1   r2   r3   r4   rC   r%   r%   r&   r/   c   s   


z NCCLBackend._init_with_tcp_storec                 C  s    |j js|j jstdd S d S )Nz4NCCL requires arrays to be either c- or f-contiguous)flagsc_contiguousf_contiguousRuntimeError)r0   r"   r%   r%   r&   _check_contiguouso   s
   zNCCLBackend._check_contiguousc                 C  s   |d u r
t jj }|jS r)   )cupycudastreamget_current_streamptr)r0   rN   r%   r%   r&   _get_streamt   s   zNCCLBackend._get_streamc                 C  s8   |t vrtd| d|dv r|dkrtdt | S )NzUnknown op r   r   r   z-Only nccl.SUM is supported for complex arrays)	_nccl_opsrJ   
ValueError)r0   opr   r%   r%   r&   _get_opy   s   zNCCLBackend._get_opc                 C  sT   t }t|d ttfrt|d d st|d rt}t||| g|R   d S Nr   )_DenseNCCLCommunicator
isinstancelisttupler   issparse_SparseNCCLCommunicatorgetattr)r0   functionargs
comm_classr%   r%   r&   _dispatch_arg_type   s   zNCCLBackend._dispatch_arg_typer   Nc                 C     |  d||||f dS )a  Performs an all reduce operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_reduceNra   )r0   in_array	out_arrayrT   rN   r%   r%   r&   rc      s   zNCCLBackend.all_reducer   c                 C     |  d|||||f dS )a  Performs a reduce operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
                will only be modified by the `root` process.
            root (int, optional): rank of the process that will perform the
                reduction. Defaults to `0`.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        reduceNrd   )r0   re   rf   r9   rT   rN   r%   r%   r&   rh      s   zNCCLBackend.reducec                 C     |  d|||f dS )a  Performs a broadcast operation.

        Args:
            in_out_array (cupy.ndarray): array to be sent for `root` rank.
                Other ranks will receive the broadcast data here.
            root (int, optional): rank of the process that will send the
                broadcast. Defaults to `0`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        	broadcastNrd   )r0   in_out_arrayr9   rN   r%   r%   r&   rj      s   
zNCCLBackend.broadcastc                 C  rg   )a/  Performs a reduce scatter operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            count (int): Number of elements to send to each rank.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        reduce_scatterNrd   )r0   re   rf   r#   rT   rN   r%   r%   r&   rl      s   zNCCLBackend.reduce_scatterc                 C  rb   )as  Performs an all gather operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            count (int): Number of elements to send to each rank.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_gatherNrd   )r0   re   rf   r#   rN   r%   r%   r&   rm      s   
zNCCLBackend.all_gatherc                 C  ri   )a  Performs a send operation.

        Args:
            array (cupy.ndarray): array to be sent.
            peer (int): rank of the process `array` will be sent to.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        sendNrd   )r0   r"   peerrN   r%   r%   r&   rn         	zNCCLBackend.sendc                 C  ri   )a2  Performs a receive operation.

        Args:
            array (cupy.ndarray): array used to receive data.
            peer (int): rank of the process `array` will be received from.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        recvNrd   )r0   rf   ro   rN   r%   r%   r&   rq      rp   zNCCLBackend.recvc                 C  rb   )a  Performs a send and receive operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array used to receive data.
            peer (int): rank of the process to send `in_array` and receive
                `out_array`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        	send_recvNrd   )r0   re   rf   ro   rN   r%   r%   r&   rr         zNCCLBackend.send_recvc                 C  rb   )a  Performs a scatter operation.

        Args:
            in_array (cupy.ndarray): array to be sent. Its shape must be
                `(total_ranks, ...)`.
            out_array (cupy.ndarray): array where the result with be stored.
            root (int): rank that will send the `in_array` to other ranks.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        scatterNrd   r0   re   rf   r9   rN   r%   r%   r&   rt      rs   zNCCLBackend.scatterc                 C  rb   )a  Performs a gather operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
                Its shape must be `(total_ranks, ...)`.
            root (int): rank that will receive `in_array` from other ranks.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        gatherNrd   ru   r%   r%   r&   rv     rs   zNCCLBackend.gatherc                 C  ri   )a  Performs an all to all operation.

        Args:
            in_array (cupy.ndarray): array to be sent. Its shape must be
                `(total_ranks, ...)`.
            out_array (cupy.ndarray): array where the result with be stored.
                Its shape must be `(total_ranks, ...)`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_to_allNrd   )r0   re   rf   rN   r%   r%   r&   rw     s   
zNCCLBackend.all_to_allc                 C  s"   | j r
| j  dS | j  dS )zPerforms a barrier operation.

        The barrier is done in the cpu and is a explicit synchronization
        mechanism that halts the thread progression.
        N)r-   r;   r>   rE   rF   )r0   r%   r%   r&   rF   '  s   zNCCLBackend.barrierr   Nr   r   Nr   Nr)   )__name__
__module____qualname____doc__r   _DEFAULT_HOST_DEFAULT_PORTr+   r.   r/   rK   rQ   rU   ra   rc   rh   rj   rl   rm   rn   rq   rr   rt   rv   rw   rF   __classcell__r%   r%   r6   r&   r(   <   s0    












r(   c                   @  s   e Zd ZedddZedddZed dd	Ze	dd
dZed!ddZed!ddZ	ed!ddZ
ed!ddZed!ddZed!ddZed ddZed ddZed!ddZdS )"rW   r   Nc                 C  s\   | | | | ||}t|\}}|||jj}|j|jj	|jj	|||| d S r)   )
rK   rQ   r'   rU   r   r   rB   	allReducedatarP   )clscommre   rf   rT   rN   r   r#   r%   r%   r&   rc   7  s   


z!_DenseNCCLCommunicator.all_reducer   c           	   	   C  sh   | | |j|kr| | ||}t|\}}|||jj}|j|j	j
|j	j
||||| d S r)   )rK   r2   rQ   r'   rU   r   r   rB   rh   r   rP   )	r   r   re   rf   r9   rT   rN   r   r#   r%   r%   r&   rh   A  s   




z_DenseNCCLCommunicator.reducec                 C  sB   | | ||}t|\}}|j|jj|jj|||| d S r)   )rK   rQ   r'   rB   rj   r   rP   )r   r   rk   r9   rN   r   r#   r%   r%   r&   rj   M  s   

z _DenseNCCLCommunicator.broadcastc                 C  s^   | | | | ||}t||\}}|||jj}|j|jj	|jj	|||| d S r)   )
rK   rQ   r'   rU   r   r   rB   reduceScatterr   rP   )r   r   re   rf   r#   rT   rN   r   r%   r%   r&   rl   V  s   


z%_DenseNCCLCommunicator.reduce_scatterc                 C  sL   | | | | ||}t||\}}|j|jj|jj||| d S r)   )rK   rQ   r'   rB   	allGatherr   rP   )r   r   re   rf   r#   rN   r   r%   r%   r&   rm   a  s   


z!_DenseNCCLCommunicator.all_gatherc                 C  8   | | ||}t|\}}| |||||| d S r)   )rK   rQ   r'   _send)r   r   r"   ro   rN   r   r#   r%   r%   r&   rn   j     

z_DenseNCCLCommunicator.sendc                 C     |j |jj|||| d S r)   )rB   rn   r   rP   r   r   r"   ro   r   r#   rN   r%   r%   r&   r   q     z_DenseNCCLCommunicator._sendc                 C  r   r)   )rK   rQ   r'   _recv)r   r   rf   ro   rN   r   r#   r%   r%   r&   rq   u  r   z_DenseNCCLCommunicator.recvc                 C  r   r)   )rB   rq   r   rP   r   r   rf   ro   r   r#   rN   r%   r%   r&   r   |  r   z_DenseNCCLCommunicator._recvc           
      C  sr   | | | | ||}t|\}}t|\}}	t  | |||||| | |||||	| t  d S r)   )rK   rQ   r'   r   
groupStartr   r   groupEnd)
r   r   re   rf   ro   rN   idtypeicountodtypeocountr%   r%   r&   rr     s   


z _DenseNCCLCommunicator.send_recvc              	   C  s   |j d |jkrtd|j d|j  || || ||}t  ||jkrHt|jD ]}|| }t	|\}}	| 
|||||	| q1t	|\}
}| ||||
|| t  d S )Nr   z"scatter requires in_array to have 'elements in its first dimension, found )shape
_n_devicesrJ   rK   rQ   r   r   r2   ranger'   r   r   r   )r   r   re   rf   r9   rN   r
   r"   r   r   r   r#   r%   r%   r&   rt     s$   




z_DenseNCCLCommunicator.scatterc              	   C  s   |j d |jkrtd|j d|j  || || ||}t  ||jkrHt|jD ]}|| }t	|\}}	| 
|||||	| q1t	|\}
}| ||||
|| t  d S )Nr   z"gather requires out_array to have r   )r   r   rJ   rK   rQ   r   r   r2   r   r'   r   r   r   )r   r   re   rf   r9   rN   r
   r"   r   r   r   r#   r%   r%   r&   rv     s$   




z_DenseNCCLCommunicator.gatherc           
   	   C  s   |j d |jkrtd|j d|j  |j d |jkr(td|j d|j  || || ||}t|d \}}t|d \}}t  t|jD ]}	| 	|||	 |	||| | 
|||	 |	||| qPt  d S )Nr   %all_to_all requires in_array to have r   z&all_to_all requires out_array to have )r   r   rJ   rK   rQ   r'   r   r   r   r   r   r   )
r   r   re   rf   rN   r   r   r   r   r
   r%   r%   r&   rw     s,   




z!_DenseNCCLCommunicator.all_to_allrx   ry   rz   r)   )r{   r|   r}   classmethodrc   rh   rj   rl   rm   rn   r   rq   r   rr   rt   rv   rw   r%   r%   r%   r&   rW   5  s8    	
rW   c                 C  s   t d| }t dd}t dd}|dkr tj|||fddS |dkr.tj|||fddS |dkr=tj|||ffddS td)	N   r
   csr)r   r   )r   csccoo4NCCL is not supported for this type of sparse matrix)rL   emptyr   
csr_matrix
csc_matrix
coo_matrixr    )r   sparse_typer   ar   r%   r%   r&   _make_sparse_empty  s   r   c                 C  s2   t | rdS t | rdS t | rdS td)Nr   r   r   r   )r   isspmatrix_cooisspmatrix_csrisspmatrix_cscr    )matrixr%   r%   r&   _get_sparse_type  s   


r   c                   @  s   e Zd Zedd Zedd Zedd Zdd Zed&ddZed'ddZ	ed(ddZ
e	
d&ddZed)ddZed)ddZed)ddZed)ddZed)ddZed)ddZed(d d!Zed(d"d#Zed)d$d%Zd
S )*r\   c                 C  sN   t |r|  |j|j|jfS t |st |r#|j|j|j	fS t
d)Nr   )r   r   sum_duplicatesr   rowcolr   r   indptrindicesr    )r   r"   r%   r%   r&   _get_internal_arrays  s   
z,_SparseNCCLCommunicator._get_internal_arraysc                 C  s   |t dd |D  }|S )Nc                 s  s    | ]}|j V  qd S r)   )r!   ).0r   r%   r%   r&   	<genexpr>  s    z?_SparseNCCLCommunicator._get_shape_and_sizes.<locals>.<genexpr>)rZ   )r   arraysr   sizes_shaper%   r%   r&   _get_shape_and_sizes  s   z,_SparseNCCLCommunicator._get_shape_and_sizesc                 C  sz  |j r|dkrtj|dd}|jj||dd d S |dkr/tjddd}|jj||dd |S |d	krQ|j|kr@tj|dd}ntjddd}|jj||d
 |S |dkrptj|dd}tj|j	dgdd}|j
||| |S |dkrtj|dd}tj|j	dgdd}|j|| |S tdtd |dkrtj|dd}| ||||jd| d S |dkrtjddd}| ||||jd| t|S |d	kr|j|krtj|dd}ntjddd}tj||||d t|S |dkrtj|dd}tj|j	dfdd}tj|||||d t|S |dkr9tj|dd}tj|j	dfdd}tj||||d t|S td)Nrn   r   r   r   )desttagrq      )sourcer   r@   r8   rv   alltoallzUnsupported methodzUsing NCCL for transferring sparse arrays metadata. This will cause device synchronization and a huge performance degradation. Please install MPI and `mpi4py` in order to avoid this issue.)r9   rN   )rN   )r-   numpyr"   r;   Sendr   Recvr2   Bcastr   GatherAlltoallrJ   warningswarnrL   r   r   r   asnumpyrW   rj   rv   rw   )r   r   ro   r   methodrN   recv_bufr%   r%   r&   _exchange_shape_and_sizes  s   








z1_SparseNCCLCommunicator._exchange_shape_and_sizesc                 C  s~   t | r|d | _|d | _|d | _t|| _d S t | s%t | r;|d | _|d | _	|d | _
t|| _d S td)Nr   r   r   r   )r   r   r   r   r   rZ   _shaper   r   r   r   r    )r   r   r   r%   r%   r&   _assign_arraysE  s   






z&_SparseNCCLCommunicator._assign_arraysr   Nc                 C  s,   d}|  |||||| | |||| d S rV   )rh   rj   )r   r   re   rf   rT   rN   r9   r%   r%   r&   rc   T  s   z"_SparseNCCLCommunicator.all_reducer   c              
   C  sf  |  |}| ||j}| |||d|}|j|krt|t|kr&td|}	t|jt|}
t	|D ]V\}}t
|dd }|dd  }dd t||D }||krt  |D ]}| ||||j|j| qZt  | |
|| |dkr}|	|
 }	q4|dkr|	|
 }	q4td	q4| ||  |	|	j d S t  |D ]}| ||||j|j| qt  d S )
Nrv   z.in_array and out_array must be the same formatr   r   c                 S      g | ]\}}t j||jd qS r   rL   r   r   r   sr   r%   r%   r&   
<listcomp>n      z2_SparseNCCLCommunicator.reduce.<locals>.<listcomp>r   r   z.Sparse matrix only supports sum/prod reduction)r   r   r   r   r2   r   rS   r   r   	enumeraterZ   zipr   r   r   r!   r   r   r   )r   r   re   rf   r9   rT   rN   r   shape_and_sizesresultpartialro   ssr   sizesr   r%   r%   r&   rh   \  sV   





z_SparseNCCLCommunicator.reducec           
      C  s   |  |}|j|kr| ||j}nd}| |||d|}t|dd }|dd  }|j|kr:dd t||D }t  |D ]
}	t	
||	|| q@t  | ||| d S )Nr%   r@   r   r   c                 S  r   r   r   r   r%   r%   r&   r     r   z5_SparseNCCLCommunicator.broadcast.<locals>.<listcomp>)r   r2   r   r   r   rZ   r   r   r   rW   rj   r   r   )
r   r   rk   r9   rN   r   r   r   r   r   r%   r%   r&   rj     s(   



z!_SparseNCCLCommunicator.broadcastc              	   C  sl   d}g }t |ttfstd|D ]}	t|	jt|	}
| ||	|
||| ||
 q| 	||||| d S )Nr   z5in_array must be a list or a tuple of sparse matrices)
rX   rY   rZ   rS   r   r   r   rh   appendrt   )r   r   re   rf   r#   rT   rN   r9   reduce_out_arrayss_mpartial_out_arrayr%   r%   r&   rl     s   
z&_SparseNCCLCommunicator.reduce_scatterc           	        sd   d}g }|  | ||| |j|kr fddt|jD }|D ]}| |||| || q d S )Nr   c                   s   g | ]
}t  jt qS r%   )r   r   r   )r   _re   r%   r&   r     s    z6_SparseNCCLCommunicator.all_gather.<locals>.<listcomp>)rv   r2   r   r   rj   r   )	r   r   re   rf   r#   rN   r9   gather_out_arraysarrr%   r   r&   rm     s   

z"_SparseNCCLCommunicator.all_gatherc              	   C  s`   |  |}| ||j}| |||d| t  |D ]}| ||||j|j| qt	  d S )Nrn   )
r   r   r   r   r   r   r   r   r!   r   )r   r   r"   ro   rN   r   r   r   r%   r%   r&   rn     s   

z_SparseNCCLCommunicator.sendc                 C  sT   |j j}|tvrtd|j  dt|\}}||}|j|jj	|||| d S Nr   r   )
r   r   r   r    r'   rQ   rB   rn   r   rP   r   r%   r%   r&   r     s   
z_SparseNCCLCommunicator._sendc              	   C  s   |  ||dd|}| |}t|dd }|dd  }dd t||D }	t  |	D ]}
| ||
||
j|
j| q,t	  | 
||	| d S )Nr%   rq   r   r   c                 S  r   r   r   r   r%   r%   r&   r     s     z0_SparseNCCLCommunicator.recv.<locals>.<listcomp>)r   r   rZ   r   r   r   r   r   r!   r   r   )r   r   rf   ro   rN   r   r   r   r   arrsr   r%   r%   r&   rq     s   

z_SparseNCCLCommunicator.recvc                 C  sR   |j }|tvrtd|j dt|\}}||}|j|jj	|||| d S r   )
r   r   r    r   r'   rQ   rB   rq   r   rP   r   r%   r%   r&   r     s   
z_SparseNCCLCommunicator._recvc                 C  s4   t   | |||| | |||| t   d S r)   )r   r   rn   rq   r   )r   r   re   rf   ro   rN   r%   r%   r&   rr     s   z!_SparseNCCLCommunicator.send_recvc                 C  sz   |j |kr3t  t|D ]\}}||kr| |||| qt  | || || || j d S | 	|||| d S r)   )
r2   r   r   r   rn   r   r   r   r   rq   )r   r   re   rf   r9   rN   ro   s_ar%   r%   r&   rt     s   
z_SparseNCCLCommunicator.scatterc                 C  s|   |j |kr4t|jD ]'}t|jt|}||kr!| |||| n| || ||j	 |
| q
d S | |||| d S r)   )r2   r   r   r   r   r   rq   r   r   r   r   rn   )r   r   re   rf   r9   rN   ro   resr%   r%   r&   rv     s   

z_SparseNCCLCommunicator.gatherc              
   C  sP  t ||jkrtd|j dt | g }g }t|D ]\}}| |}	|| |	|j q| |||d|}t	|jD ]g}t
|| dd }
|| dd  }| || }dd t||D }t  |D ]}| ||||j|j| qi|D ]}| ||||j|j| qzt  |t|| jt||  | || ||
 q>d S )Nr   zelements, found r   r   r   c                 S  r   r   r   r   r%   r%   r&   r   ;  r   z6_SparseNCCLCommunicator.all_to_all.<locals>.<listcomp>)lenr   rJ   r   r   r   r   r   r   r   rZ   r   r   r   r   r   r!   r   r   r   r   r   )r   r   re   rf   rN   r   recv_shape_and_sizesr
   r   r   r   r   s_arraysr_arraysr%   r%   r&   rw   #  sB   



z"_SparseNCCLCommunicator.all_to_allrx   ry   rz   r)   )r{   r|   r}   r   r   r   r   r   rc   rh   rj   rl   rm   rn   r   rq   r   rr   rt   rv   rw   r%   r%   r%   r&   r\     sF    


K-r\   r)   )'
__future__r   r   r   rL   	cupy.cudar   cupyx.distributedr   cupyx.distributed._commr   cupyx.scipyr   mpi4pyr   r,   ImportError	available	NCCL_INT8
NCCL_UINT8
NCCL_INT32NCCL_UINT32
NCCL_INT64NCCL_UINT64NCCL_FLOAT16NCCL_FLOAT32NCCL_FLOAT64r   NCCL_SUM	NCCL_PRODNCCL_MAXNCCL_MINrR   r'   r(   rW   r   r   r\   r%   r%   r%   r&   <module>   sX    
 z 