o
    װi                     @   s  d dl Z d dlZd dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
 z
d dlmZ dZW n ey9   dZY nw ejrfejejejejejejejejejejejejejd	Zejejejejd
Zni Zi ZdddZG dd deZ G dd dZ!dd Z"dd Z#G dd dZ$dS )    N)nccl)_store)_Backend)sparse)MPITF)bBiIlLqQefdFD)sumprodmaxminc                 C   sT   | j j}|tvrtd| j  dt| }|d u r| j}|dv r&|d| fS ||fS )NUnknown dtype 	 for NCCLFD   )dtypechar_nccl_dtypes	TypeErrorsize)arraycountr   
nccl_dtype r$   P/home/ubuntu/.local/lib/python3.10/site-packages/cupyx/distributed/_nccl_comm.py_get_nccl_dtype_and_count.   s   r&   c                       s   e Zd ZdZejejdf fdd	Zdd Zdd Z	d	d
 Z
dd Zdd Zdd Zd,ddZd-ddZd.ddZ	d,ddZd/ddZd/ddZd/d d!Zd/d"d#Zd.d$d%Zd.d&d'Zd/d(d)Zd*d+ Z  ZS )0NCCLBackenda  Interface that uses NVIDIA's NCCL to perform communications.

    Args:
        n_devices (int): Total number of devices that will be used in the
            distributed execution.
        rank (int): Unique id of the GPU that the communicator is associated to
            its value needs to be `0 <= rank < n_devices`.
        host (str, optional): host address for the process rendezvous on
            initialization. Defaults to `"127.0.0.1"`.
        port (int, optional): port used for the process rendezvous on
            initialization. Defaults to `13333`.
        use_mpi(bool, optional): switch between MPI and use the included TCP
            server for initialization & synchronization. Defaults to `False`.
    Fc                    sF   t  |||| to|| _| jr| || d S | |||| d S N)super__init___mpi_available_use_mpi_init_with_mpi_init_with_tcp_store)self	n_devicesrankhostportuse_mpi	__class__r$   r%   r*   J   s
   
zNCCLBackend.__init__c                 C   sX   t j| _| j | _| j  d }| jdkrt }| jj|dd}t	|||| _
d S )Nr   root)r   
COMM_WORLD	_mpi_commGet_rank	_mpi_rankBarrierr   get_unique_idbcastNcclCommunicator_comm)r/   r0   r1   nccl_idr$   r$   r%   r-   T   s   

zNCCLBackend._init_with_mpic                 C   s   d }|dkr%| j || t }tdd |D }|| jd< | j  n| j  | jd }tdd |D }t|||| _	d S )Nr   c                 S   s   g | ]}|d  qS    r$   .0r   r$   r$   r%   
<listcomp>i   s    z4NCCLBackend._init_with_tcp_store.<locals>.<listcomp>rB   c                 S   s   g | ]}t |d  qS rC   )intrE   r$   r$   r%   rG   o   s    )
r   runr   r>   bytes_store_proxybarriertupler@   rA   )r/   r0   r1   r2   r3   rB   shifted_nccl_idr$   r$   r%   r.   a   s   


z NCCLBackend._init_with_tcp_storec                 C   s    |j js|j jstdd S d S )Nz4NCCL requires arrays to be either c- or f-contiguous)flagsc_contiguousf_contiguousRuntimeError)r/   r!   r$   r$   r%   _check_contiguousr   s
   zNCCLBackend._check_contiguousc                 C   s   |d u r
t jj }|jS r(   )cupycudastreamget_current_streamptr)r/   rV   r$   r$   r%   _get_streamw   s   zNCCLBackend._get_streamc                 C   s8   |t vrtd| d|dv r|dkrtdt | S )NzUnknown op r   r   r   z-Only nccl.SUM is supported for complex arrays)	_nccl_opsrR   
ValueError)r/   opr   r$   r$   r%   _get_op|   s   zNCCLBackend._get_opc                 C   sT   t }t|d ttfrt|d d st|d rt}t||| g|R   d S Nr   )_DenseNCCLCommunicator
isinstancelistrM   r   issparse_SparseNCCLCommunicatorgetattr)r/   functionargs
comm_classr$   r$   r%   _dispatch_arg_type   s   zNCCLBackend._dispatch_arg_typer   Nc                 C      |  d||||f dS )a  Performs an all reduce operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_reduceNrh   )r/   in_array	out_arrayr\   rV   r$   r$   r%   rj      s   zNCCLBackend.all_reducer   c                 C      |  d|||||f dS )a  Performs a reduce operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
                will only be modified by the `root` process.
            root (int, optional): rank of the process that will perform the
                reduction. Defaults to `0`.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        reduceNrk   )r/   rl   rm   r8   r\   rV   r$   r$   r%   ro      s   zNCCLBackend.reducec                 C      |  d|||f dS )a  Performs a broadcast operation.

        Args:
            in_out_array (cupy.ndarray): array to be sent for `root` rank.
                Other ranks will receive the broadcast data here.
            root (int, optional): rank of the process that will send the
                broadcast. Defaults to `0`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        	broadcastNrk   )r/   in_out_arrayr8   rV   r$   r$   r%   rq      s   
zNCCLBackend.broadcastc                 C   rn   )a/  Performs a reduce scatter operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            count (int): Number of elements to send to each rank.
            op (str): reduction operation, can be one of
                ('sum', 'prod', 'min' 'max'), arrays of complex type only
                support `'sum'`. Defaults to `'sum'`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        reduce_scatterNrk   )r/   rl   rm   r"   r\   rV   r$   r$   r%   rs      s   zNCCLBackend.reduce_scatterc                 C   ri   )as  Performs an all gather operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
            count (int): Number of elements to send to each rank.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_gatherNrk   )r/   rl   rm   r"   rV   r$   r$   r%   rt      s   
zNCCLBackend.all_gatherc                 C   rp   )a  Performs a send operation.

        Args:
            array (cupy.ndarray): array to be sent.
            peer (int): rank of the process `array` will be sent to.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        sendNrk   )r/   r!   peerrV   r$   r$   r%   ru         	zNCCLBackend.sendc                 C   rp   )a2  Performs a receive operation.

        Args:
            array (cupy.ndarray): array used to receive data.
            peer (int): rank of the process `array` will be received from.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        recvNrk   )r/   rm   rv   rV   r$   r$   r%   rx      rw   zNCCLBackend.recvc                 C   ri   )a  Performs a send and receive operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array used to receive data.
            peer (int): rank of the process to send `in_array` and receive
                `out_array`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        	send_recvNrk   )r/   rl   rm   rv   rV   r$   r$   r%   ry         zNCCLBackend.send_recvc                 C   ri   )a  Performs a scatter operation.

        Args:
            in_array (cupy.ndarray): array to be sent. Its shape must be
                `(total_ranks, ...)`.
            out_array (cupy.ndarray): array where the result with be stored.
            root (int): rank that will send the `in_array` to other ranks.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        scatterNrk   r/   rl   rm   r8   rV   r$   r$   r%   r{      rz   zNCCLBackend.scatterc                 C   ri   )a  Performs a gather operation.

        Args:
            in_array (cupy.ndarray): array to be sent.
            out_array (cupy.ndarray): array where the result with be stored.
                Its shape must be `(total_ranks, ...)`.
            root (int): rank that will receive `in_array` from other ranks.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        gatherNrk   r|   r$   r$   r%   r}     rz   zNCCLBackend.gatherc                 C   rp   )a  Performs an all to all operation.

        Args:
            in_array (cupy.ndarray): array to be sent. Its shape must be
                `(total_ranks, ...)`.
            out_array (cupy.ndarray): array where the result with be stored.
                Its shape must be `(total_ranks, ...)`.
            stream (cupy.cuda.Stream, optional): if supported, stream to
                perform the communication.
        
all_to_allNrk   )r/   rl   rm   rV   r$   r$   r%   r~     s   
zNCCLBackend.all_to_allc                 C   s"   | j r
| j  dS | j  dS )zPerforms a barrier operation.

        The barrier is done in the cpu and is a explicit synchronization
        mechanism that halts the thread progression.
        N)r,   r:   r=   rK   rL   )r/   r$   r$   r%   rL   *  s   zNCCLBackend.barrierr   Nr   r   Nr   Nr(   )__name__
__module____qualname____doc__r   _DEFAULT_HOST_DEFAULT_PORTr*   r-   r.   rS   rY   r]   rh   rj   ro   rq   rs   rt   ru   rx   ry   r{   r}   r~   rL   __classcell__r$   r$   r5   r%   r'   :   s0    












r'   c                   @   s   e Zd ZedddZedddZed dd	Ze	dd
dZed!ddZed!ddZ	ed!ddZ
ed!ddZed!ddZed!ddZed ddZed ddZed!ddZdS )"r_   r   Nc                 C   s\   | | | | ||}t|\}}|||jj}|j|jj	|jj	|||| d S r(   )
rS   rY   r&   r]   r   r   rA   	allReducedatarX   )clscommrl   rm   r\   rV   r   r"   r$   r$   r%   rj   :  s   


z!_DenseNCCLCommunicator.all_reducer   c           	   	   C   sh   | | |j|kr| | ||}t|\}}|||jj}|j|j	j
|j	j
||||| d S r(   )rS   r1   rY   r&   r]   r   r   rA   ro   r   rX   )	r   r   rl   rm   r8   r\   rV   r   r"   r$   r$   r%   ro   D  s   




z_DenseNCCLCommunicator.reducec                 C   sB   | | ||}t|\}}|j|jj|jj|||| d S r(   )rS   rY   r&   rA   rq   r   rX   )r   r   rr   r8   rV   r   r"   r$   r$   r%   rq   P  s   

z _DenseNCCLCommunicator.broadcastc                 C   s^   | | | | ||}t||\}}|||jj}|j|jj	|jj	|||| d S r(   )
rS   rY   r&   r]   r   r   rA   reduceScatterr   rX   )r   r   rl   rm   r"   r\   rV   r   r$   r$   r%   rs   Y  s   


z%_DenseNCCLCommunicator.reduce_scatterc                 C   sL   | | | | ||}t||\}}|j|jj|jj||| d S r(   )rS   rY   r&   rA   	allGatherr   rX   )r   r   rl   rm   r"   rV   r   r$   r$   r%   rt   d  s   


z!_DenseNCCLCommunicator.all_gatherc                 C   8   | | ||}t|\}}| |||||| d S r(   )rS   rY   r&   _send)r   r   r!   rv   rV   r   r"   r$   r$   r%   ru   m     

z_DenseNCCLCommunicator.sendc                 C      |j |jj|||| d S r(   )rA   ru   r   rX   r   r   r!   rv   r   r"   rV   r$   r$   r%   r   t     z_DenseNCCLCommunicator._sendc                 C   r   r(   )rS   rY   r&   _recv)r   r   rm   rv   rV   r   r"   r$   r$   r%   rx   x  r   z_DenseNCCLCommunicator.recvc                 C   r   r(   )rA   rx   r   rX   r   r   rm   rv   r   r"   rV   r$   r$   r%   r     r   z_DenseNCCLCommunicator._recvc           
      C   sr   | | | | ||}t|\}}t|\}}	t  | |||||| | |||||	| t  d S r(   )rS   rY   r&   r   
groupStartr   r   groupEnd)
r   r   rl   rm   rv   rV   idtypeicountodtypeocountr$   r$   r%   ry     s   


z _DenseNCCLCommunicator.send_recvc              	   C   s   |j d |jkrtd|j d|j  || || ||}t  ||jkrHt|jD ]}|| }t	|\}}	| 
|||||	| q1t	|\}
}| ||||
|| t  d S )Nr   z"scatter requires in_array to have 'elements in its first dimension, found )shape
_n_devicesrR   rS   rY   r   r   r1   ranger&   r   r   r   )r   r   rl   rm   r8   rV   r	   r!   r   r   r   r"   r$   r$   r%   r{     s$   




z_DenseNCCLCommunicator.scatterc              	   C   s   |j d |jkrtd|j d|j  || || ||}t  ||jkrHt|jD ]}|| }t	|\}}	| 
|||||	| q1t	|\}
}| ||||
|| t  d S )Nr   z"gather requires out_array to have r   )r   r   rR   rS   rY   r   r   r1   r   r&   r   r   r   )r   r   rl   rm   r8   rV   r	   r!   r   r   r   r"   r$   r$   r%   r}     s$   




z_DenseNCCLCommunicator.gatherc           
   	   C   s   |j d |jkrtd|j d|j  |j d |jkr(td|j d|j  || || ||}t|d \}}t|d \}}t  t|jD ]}	| 	|||	 |	||| | 
|||	 |	||| qPt  d S )Nr   %all_to_all requires in_array to have r   z&all_to_all requires out_array to have )r   r   rR   rS   rY   r&   r   r   r   r   r   r   )
r   r   rl   rm   rV   r   r   r   r   r	   r$   r$   r%   r~     s,   




z!_DenseNCCLCommunicator.all_to_allr   r   r   r(   )r   r   r   classmethodrj   ro   rq   rs   rt   ru   r   rx   r   ry   r{   r}   r~   r$   r$   r$   r%   r_   8  s8    	
r_   c                 C   s   t d| }t dd}t dd}|dkr tj|||fddS |dkr.tj|||fddS |dkr=tj|||ffddS td)	N   r	   csr)r   r   )r   csccoo4NCCL is not supported for this type of sparse matrix)rT   emptyr   
csr_matrix
csc_matrix
coo_matrixr   )r   sparse_typer   ar   r$   r$   r%   _make_sparse_empty  s   r   c                 C   s2   t | rdS t | rdS t | rdS td)Nr   r   r   r   )r   isspmatrix_cooisspmatrix_csrisspmatrix_cscr   )matrixr$   r$   r%   _get_sparse_type  s   


r   c                   @   s   e Zd Zedd Zedd Zedd Zdd Zed&ddZed'ddZ	ed(ddZ
e	
d&ddZed)ddZed)ddZed)ddZed)ddZed)ddZed)ddZed(d d!Zed(d"d#Zed)d$d%Zd
S )*rc   c                 C   sN   t |r|  |j|j|jfS t |st |r#|j|j|j	fS t
d)Nr   )r   r   sum_duplicatesr   rowcolr   r   indptrindicesr   )r   r!   r$   r$   r%   _get_internal_arrays  s   
z,_SparseNCCLCommunicator._get_internal_arraysc                 C   s   |t dd |D  }|S )Nc                 s   s    | ]}|j V  qd S r(   )r    )rF   r   r$   r$   r%   	<genexpr>  s    z?_SparseNCCLCommunicator._get_shape_and_sizes.<locals>.<genexpr>)rM   )r   arraysr   sizes_shaper$   r$   r%   _get_shape_and_sizes  s   z,_SparseNCCLCommunicator._get_shape_and_sizesc                 C   sz  |j r|dkrtj|dd}|jj||dd d S |dkr/tjddd}|jj||dd |S |d	krQ|j|kr@tj|dd}ntjddd}|jj||d
 |S |dkrptj|dd}tj|j	dgdd}|j
||| |S |dkrtj|dd}tj|j	dgdd}|j|| |S tdtd |dkrtj|dd}| ||||jd| d S |dkrtjddd}| ||||jd| t|S |d	kr|j|krtj|dd}ntjddd}tj||||d t|S |dkrtj|dd}tj|j	dfdd}tj|||||d t|S |dkr9tj|dd}tj|j	dfdd}tj||||d t|S td)Nru   r   r   r   )desttagrx      )sourcer   r?   r7   r}   alltoallzUnsupported methodzUsing NCCL for transferring sparse arrays metadata. This will cause device synchronization and a huge performance degradation. Please install MPI and `mpi4py` in order to avoid this issue.)r8   rV   )rV   )r,   numpyr!   r:   Sendr   Recvr1   Bcastr   GatherAlltoallrR   warningswarnrT   r   r   r   asnumpyr_   rq   r}   r~   )r   r   rv   r   methodrV   recv_bufr$   r$   r%   _exchange_shape_and_sizes  s   








z1_SparseNCCLCommunicator._exchange_shape_and_sizesc                 C   s~   t | r|d | _|d | _|d | _t|| _d S t | s%t | r;|d | _|d | _	|d | _
t|| _d S td)Nr   r   r   r   )r   r   r   r   r   rM   _shaper   r   r   r   r   )r   r   r   r$   r$   r%   _assign_arraysH  s   






z&_SparseNCCLCommunicator._assign_arraysr   Nc                 C   s,   d}|  |||||| | |||| d S r^   )ro   rq   )r   r   rl   rm   r\   rV   r8   r$   r$   r%   rj   W  s   z"_SparseNCCLCommunicator.all_reducer   c              
   C   sf  |  |}| ||j}| |||d|}|j|krt|t|kr&td|}	t|jt|}
t	|D ]V\}}t
|dd }|dd  }dd t||D }||krt  |D ]}| ||||j|j| qZt  | |
|| |dkr}|	|
 }	q4|dkr|	|
 }	q4td	q4| ||  |	|	j d S t  |D ]}| ||||j|j| qt  d S )
Nr}   z.in_array and out_array must be the same formatr   r   c                 S       g | ]\}}t j||jd qS r   rT   r   r   rF   sr   r$   r$   r%   rG   q      z2_SparseNCCLCommunicator.reduce.<locals>.<listcomp>r   r   z.Sparse matrix only supports sum/prod reduction)r   r   r   r   r1   r   r[   r   r   	enumeraterM   zipr   r   r   r    r   r   r   )r   r   rl   rm   r8   r\   rV   r   shape_and_sizesresultpartialrv   ssr   sizesr   r$   r$   r%   ro   _  sV   





z_SparseNCCLCommunicator.reducec           
      C   s   |  |}|j|kr| ||j}nd}| |||d|}t|dd }|dd  }|j|kr:dd t||D }t  |D ]
}	t	
||	|| q@t  | ||| d S )Nr$   r?   r   r   c                 S   r   r   r   r   r$   r$   r%   rG     r   z5_SparseNCCLCommunicator.broadcast.<locals>.<listcomp>)r   r1   r   r   r   rM   r   r   r   r_   rq   r   r   )
r   r   rr   r8   rV   r   r   r   r   r   r$   r$   r%   rq     s(   



z!_SparseNCCLCommunicator.broadcastc              	   C   sl   d}g }t |ttfstd|D ]}	t|	jt|	}
| ||	|
||| ||
 q| 	||||| d S )Nr   z5in_array must be a list or a tuple of sparse matrices)
r`   ra   rM   r[   r   r   r   ro   appendr{   )r   r   rl   rm   r"   r\   rV   r8   reduce_out_arrayss_mpartial_out_arrayr$   r$   r%   rs     s   
z&_SparseNCCLCommunicator.reduce_scatterc           	         sd   d}g }|  | ||| |j|kr fddt|jD }|D ]}| |||| || q d S )Nr   c                    s   g | ]
}t  jt qS r$   )r   r   r   )rF   _rl   r$   r%   rG     s    z6_SparseNCCLCommunicator.all_gather.<locals>.<listcomp>)r}   r1   r   r   rq   r   )	r   r   rl   rm   r"   rV   r8   gather_out_arraysarrr$   r   r%   rt     s   

z"_SparseNCCLCommunicator.all_gatherc              	   C   s`   |  |}| ||j}| |||d| t  |D ]}| ||||j|j| qt	  d S )Nru   )
r   r   r   r   r   r   r   r   r    r   )r   r   r!   rv   rV   r   r   r   r$   r$   r%   ru     s   

z_SparseNCCLCommunicator.sendc                 C   sT   |j j}|tvrtd|j  dt|\}}||}|j|jj	|||| d S Nr   r   )
r   r   r   r   r&   rY   rA   ru   r   rX   r   r$   r$   r%   r     s   
z_SparseNCCLCommunicator._sendc              	   C   s   |  ||dd|}| |}t|dd }|dd  }dd t||D }	t  |	D ]}
| ||
||
j|
j| q,t	  | 
||	| d S )Nr$   rx   r   r   c                 S   r   r   r   r   r$   r$   r%   rG     s     z0_SparseNCCLCommunicator.recv.<locals>.<listcomp>)r   r   rM   r   r   r   r   r   r    r   r   )r   r   rm   rv   rV   r   r   r   r   arrsr   r$   r$   r%   rx     s   

z_SparseNCCLCommunicator.recvc                 C   sR   |j }|tvrtd|j dt|\}}||}|j|jj	|||| d S r   )
r   r   r   r   r&   rY   rA   rx   r   rX   r   r$   r$   r%   r     s   
z_SparseNCCLCommunicator._recvc                 C   s4   t   | |||| | |||| t   d S r(   )r   r   ru   rx   r   )r   r   rl   rm   rv   rV   r$   r$   r%   ry     s   z!_SparseNCCLCommunicator.send_recvc                 C   sz   |j |kr3t  t|D ]\}}||kr| |||| qt  | || || || j d S | 	|||| d S r(   )
r1   r   r   r   ru   r   r   r   r   rx   )r   r   rl   rm   r8   rV   rv   s_ar$   r$   r%   r{     s   
z_SparseNCCLCommunicator.scatterc                 C   s|   |j |kr4t|jD ]'}t|jt|}||kr!| |||| n| || ||j	 |
| q
d S | |||| d S r(   )r1   r   r   r   r   r   rx   r   r   r   r   ru   )r   r   rl   rm   r8   rV   rv   resr$   r$   r%   r}     s   

z_SparseNCCLCommunicator.gatherc              
   C   sP  t ||jkrtd|j dt | g }g }t|D ]\}}| |}	|| |	|j q| |||d|}t	|jD ]g}t
|| dd }
|| dd  }| || }dd t||D }t  |D ]}| ||||j|j| qi|D ]}| ||||j|j| qzt  |t|| jt||  | || ||
 q>d S )Nr   zelements, found r   r   r   c                 S   r   r   r   r   r$   r$   r%   rG   >  r   z6_SparseNCCLCommunicator.all_to_all.<locals>.<listcomp>)lenr   rR   r   r   r   r   r   r   r   rM   r   r   r   r   r   r    r   r   r   r   r   )r   r   rl   rm   rV   r   recv_shape_and_sizesr	   r   r   r   r   s_arraysr_arraysr$   r$   r%   r~   &  sB   



z"_SparseNCCLCommunicator.all_to_allr   r   r   r(   )r   r   r   r   r   r   r   r   rj   ro   rq   rs   rt   ru   r   rx   r   ry   r{   r}   r~   r$   r$   r$   r%   rc     sF    


K-rc   r(   )%r   r   rT   	cupy.cudar   cupyx.distributedr   cupyx.distributed._commr   cupyx.scipyr   mpi4pyr   r+   ImportError	available	NCCL_INT8
NCCL_UINT8
NCCL_INT32NCCL_UINT32
NCCL_INT64NCCL_UINT64NCCL_FLOAT16NCCL_FLOAT32NCCL_FLOAT64r   NCCL_SUM	NCCL_PRODNCCL_MAXNCCL_MINrZ   r&   r'   r_   r   r   rc   r$   r$   r$   r%   <module>   sV    
  