o
    װiP                     @   s   d dl Z d dlZd dlmZ d dlmZ d dlmZ d dlm	Z	 dd Z
dd	 Zd
d Zdd ZdddZdd ZdddZdS )    N)runtime)internal)device)_utilc                 C   s,  ddl m} ddl m} | j}|d }| j|ddd||} | jd }tj||ftj	d}tj|ftj	d}|d	 || k}	|	rt
 }
|}|| | j }| jj}|||  }tj|||tjd}|tjkrk|j}n|tjkrt|j}n|tjkr}|j}n|tjkr|j}nJ ||
||jj||jj|jj| ngt
 }
|tjkr|j}|j}n&|tjkr|j}|j}n|tjkr|j}|j}n|tjkr|j }|j!}nJ t"|D ],}| | jj}||
||||}tj||d}||
|||||jj|| jj|| jj q| |||dd ||dd fS )a  Compute pivoted LU decomposition.

    Decompose a given batch of square matrices. Inputs and outputs are
    transposed.

    Args:
        a_t (cupy.ndarray): The input matrix with dimension ``(..., N, N)``.
            The dimension condition is not checked.
        dtype (numpy.dtype): float32, float64, complex64, or complex128.

    Returns:
        tuple:
        lu_t (cupy.ndarray):
            ``L`` without its unit diagonal and ``U`` with
            dimension ``(..., N, N)``.
        piv (cupy.ndarray):
            1-origin pivot indices with dimension
            ``(..., N)``.
        dev_info (cupy.ndarray):
            ``getrf`` info with dimension ``(...)``.

    .. seealso:: :func:`scipy.linalg.lu_factor`

    r   cublascusolverC)orderdtypei   FN)#cupy_backends.cuda.libsr   r	   shapeastypereshapecupyemptynumpyint32r   get_cublas_handleitemsizedataptrarangeuintpfloat32sgetrfBatchedfloat64dgetrfBatched	complex64cgetrfBatched
complex128zgetrfBatchedget_cusolver_handlesgetrf_bufferSizesgetrfdgetrf_bufferSizedgetrfcgetrf_bufferSizecgetrfzgetrf_bufferSizezgetrfrange)a_tr   r   r	   
orig_shapen
batch_sizeipivdev_infouse_batchedhandleldastepstartstopa_arraygetrfBatchedgetrf_bufferSizegetrfia_ptr
buffersize	workspace rD   N/home/ubuntu/.local/lib/python3.10/site-packages/cupy/linalg/_decomposition.py
_lu_factor
   sn   








rF   c                 C   s6  ddl m} ddl m} ddlm} |dstdt| \}}| jdkr-t	
| j|S |dkr5|j}n|dkr=|j}n|d	krE|j}n|j}| j|d
dd}t	j|}|jd }	|jd |jj }
t }t|jdd }t	j
|tjd}|||j|	|jj|
|jj| t	jj || t	!|j|ddS )a  Batched Cholesky decomposition.

    Decompose a given array of two-dimensional square matrices into
    ``L * L.T``, where ``L`` is a lower-triangular matrix and ``.T``
    is a conjugate transpose operator.

    Args:
        a (cupy.ndarray): The input array of matrices
            with dimension ``(..., N, N)``

    Returns:
        cupy.ndarray: The lower-triangular matrix.
    r   r   r   )check_availabilitypotrfBatchedzpotrfBatched is not availablefdFr   Tr   copyr   r
   Nr   FrM   )"r   r   r	   cupyx.cusolverrG   RuntimeErrorr   linalg_common_typesizer   r   r   spotrfBatcheddpotrfBatchedcpotrfBatchedzpotrfBatchedr   _core	_mat_ptrsstridesr   r   r   r&   r   prodr   r   CUBLAS_FILL_MODE_UPPERr   r   linalg3_check_cusolver_dev_info_if_synchronization_allowedtril)ar   r	   rG   r   	out_dtyperH   xxpr2   ldxr7   r3   r5   rD   rD   rE   _potrf_batchedj   s<   

rd   c              	   C   s`  ddl m} ddl m} t|  t|  t|  | jdkr$t| S t	| \}}| j
dkr7t| j|S | j|ddd}t| }t }tjdtjd	}|d
krZ|j}	|j}
n|dkre|j}	|j}
n|dkrp|j}	|j}
n|j}	|j}
|
||j||jj|}tj||d	}|	||j||jj||jj||jj tjj |	| tj!|dd |j|ddS )a6  Cholesky decomposition.

    Decompose a given two-dimensional square matrix into ``L * L.H``,
    where ``L`` is a lower-triangular matrix and ``.H`` is a conjugate
    transpose operator.

    Args:
        a (cupy.ndarray): Hermitian (symmetric if all elements are real),
            positive-definite input matrix with dimension ``(..., M, M)``.

    Returns:
        cupy.ndarray: The lower-triangular matrix of shape ``(..., M, M)``.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.cholesky`
    r   r   r      r   TrL      r   rI   rJ   rK   )kFrN   )"r   r   r	   r   _assert_cupy_array_assert_stacked_2d_assert_stacked_squarendimrd   rQ   rR   r   r   r   r   lenr   r&   r   r   spotrfspotrf_bufferSizedpotrfdpotrf_bufferSizecpotrfcpotrf_bufferSizezpotrfzpotrf_bufferSizer[   r   r   r\   r]   _tril)r_   r   r	   r   r`   ra   r2   r7   r5   potrfpotrf_bufferSizerB   rC   rD   rD   rE   cholesky   sL   




rx   c                 C   s  ddl m} | jd d }t|}| jdd  \}}t||}|dks(|dkrt| \}}	|dkrGt	|||f |	t	|||f |	fS |dkr^t
|||	}
|
t	|||f |	fS |dkrlt	|||f |	S |dkrt	|||f |	t	||f |	fS | jdg| jdd  R  } || |}|dkr|||jdd   S |\}
}|
||
jdd   }
|dkrdnd}|||j|d   }|
|fS )	Nr   )_geqrf_orgqr_batchedr
   reducedcompleterrawr   )rO   ry   r   r   rZ   minr   rQ   r   r   stacked_identityr   )r_   modery   batch_shaper3   mr2   rg   r   r`   qoutr|   idxrD   rD   rE   _qr_batched   s:   


r   rz   c                 C   s  ddl m} t|  |dvr%|dv rd|}t|d|}t|| jdkr/t| |S t| \}}| j	\}}t
||}|dkr|dkrXt|df|td|f|fS |d	krjt||t||f|fS |d
krvtd|f|S t||f|td|fS |  j|ddd}	t }
tjdtjd}|dkr|j}|j}n+|dkr|j}|j}n |dkr|j}|j}n|dkr|j}|j}n
d| j}t|||
|||	jj|}tj||d}tj||d}||
|||	jj||jj|jj||jj	 tjj || |d
kr|	ddd|f  }t!|j|ddS |dkr2|	j|dd|j|ddfS |d	krG||krG|}t||f|}n
|}t||f|}|	|d|< |dkrc|j"}|j#}n#|dkro|j$}|j%}n|dkr{|j&}|j'}n|dkr|j(}|j)}||
||||jj||jj}tj||d}||
||||jj||jj|jj||jj
 tjj || |d|  }|	ddd|f  }|j|ddt!|j|ddfS )a  QR decomposition.

    Decompose a given two-dimensional matrix into ``Q * R``, where ``Q``
    is an orthonormal and ``R`` is an upper-triangular matrix.

    Args:
        a (cupy.ndarray): The input matrix.
        mode (str): The mode of decomposition. Currently 'reduced',
            'complete', 'r', and 'raw' modes are supported. The default mode
            is 'reduced', in which matrix ``A = (..., M, N)`` is decomposed
            into ``Q``, ``R`` with dimensions ``(..., M, K)``, ``(..., K, N)``,
            where ``K = min(M, N)``.

    Returns:
        cupy.ndarray, or tuple of ndarray:
            Although the type of returned object depends on the mode,
            it returns a tuple of ``(Q, R)`` by default.
            For details, please see the document of :func:`numpy.linalg.qr`.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.qr`
    r   r   )rz   r{   r|   r}   )rI   fulleeconomicz)The deprecated mode '{}' is not supportedzUnrecognized mode '{}'re   rz   r{   r|   r   r   TrL   rf   r   rI   rJ   rK   DzDdtype must be float32, float64, complex64 or complex128 (actual: {})NFrN   r}   )*r   r	   r   rh   format
ValueErrorrk   r   rQ   r   r~   r   r   identity	transposer   r   r&   r   r   sgeqrf_bufferSizesgeqrfdgeqrf_bufferSizedgeqrfcgeqrf_bufferSizecgeqrfzgeqrf_bufferSizezgeqrfr   r   r   r\   r]   _triusorgqr_bufferSizesorgqrdorgqr_bufferSizedorgqrcungqr_bufferSizecungqrzungqr_bufferSizezungqr)r_   r   r	   msgr   r`   r   r2   rg   ra   r7   r5   geqrf_bufferSizegeqrfrB   rC   taur|   mcr   orgqr_bufferSizeorgqrrD   rD   rE   qr  s   






 





r   c                 C   s  ddl m}m} | jd d }t|}| jdd  \}}t| \}	}
|
j	 }|dkrst
||}t||f |}|rq|rVtj|||f |
d}tj|||f |
d}ntj|||f |
d}tj|||f |
d}|||fS |S |dks{|dkrt|d |}|r|rt|||
}t|||
}ntj||df |
d}tj|d|f |
d}|||fS |S | jdg| jdd  R  } tjs|dkr|dkr| j|	dd	d
} || ||d	}n	|| |	j||d	}|r:|\}}}|j|
d	d}|jg ||jdd  R  }|j|d	d}|jg ||jdd  R  }|j|
d	d}|jg ||jdd  R  }|||dd fS |}|j|d	d}|jg ||jdd  R  }|S )Nr   )_gesvdj_batched_gesvd_batchedr
   r   r   r       r   FrL   rN   )rO   r   r   r   r   rZ   r   rQ   charlowerr~   r   r   r   r   r   is_hipr   swapaxesconj)r_   full_matrices
compute_uvr   r   r   r3   r2   r   r   uv_dtypes_dtyperg   suvtr   vrD   rD   rE   _svd_batched  sZ   





r   Tc                 C   s  ddl m} t|  | jdkrt| ||S t| \}}|j }|j }| j	\}}	|	dks4|dkrft
d|}
|rd|rMt
j||d}t
j|	|d}nt
j|df|d}t
jd|	f|d}||
|fS |
S |	|kru| j|ddd}d	}n| j	\}	}|  j|ddd}d}|}|r|rt
j|	|	f|d}|d
d
d
|f }td}td}n|}t
j||f|d}td}td}|jj|jj}}nd\}}td}td}t
j||d}
t }t
jdtjd}|dkr|j}|j}n|dkr|j}|j}n|dkr|j}|j}n|j}|j}|||	|}t
j||d}tjs d}nt
jt |	|d |d}|jj}|||||	||jj|	|
jj||	|||jj|||jj t
j!j"|| |
j|d	d}
|r||j|d	d}|j|d	d}|rw| |
| fS ||
|fS |
S )a  Singular Value Decomposition.

    Factorizes the matrix ``a`` as ``u * np.diag(s) * v``, where ``u`` and
    ``v`` are unitary and ``s`` is an one-dimensional array of ``a``'s
    singular values.

    Args:
        a (cupy.ndarray): The input matrix with dimension ``(..., M, N)``.
        full_matrices (bool): If True, it returns u and v with dimensions
            ``(..., M, M)`` and ``(..., N, N)``. Otherwise, the dimensions
            of u and v are ``(..., M, K)`` and ``(..., K, N)``, respectively,
            where ``K = min(M, N)``.
        compute_uv (bool): If ``False``, it only returns singular values.

    Returns:
        tuple of :class:`cupy.ndarray`:
            A tuple of ``(u, s, v)`` such that ``a = u * np.diag(s) * v``.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. note::
        On CUDA, when ``a.ndim > 2`` and the matrix dimensions <= 32, a fast
        code path based on Jacobian method (``gesvdj``) is taken. Otherwise,
        a QR method (``gesvd``) is used.

        On ROCm, there is no such a fast code path that switches the underlying
        algorithm.

    .. seealso:: :func:`numpy.linalg.svd`
    r   r   re   r   r   r   TrL   FNAOS)r   r   Nrf   rI   rJ   rK   rN   )#r   r	   r   rh   rk   r   rQ   r   r   r   r   r   eyer   r   ordr   r   r   r&   r   r   sgesvdsgesvd_bufferSizedgesvddgesvd_bufferSizecgesvdcgesvd_bufferSizezgesvdzgesvd_bufferSizer   r   r~   r\   r]   )r_   r   r   r	   r   r   
real_dtyper   r2   r   r   r   r   ra   
trans_flagrg   job_ujob_vtu_ptrvt_ptrr7   r5   gesvdgesvd_bufferSizerB   rC   	rwork_ptrrworkrD   rD   rE   svd  s   $









r   )rz   )TT)r   r   cupy_backends.cuda.apir   
cupy._corer   	cupy.cudar   cupy.linalgr   rF   rd   rx   r   r   r   r   rD   rD   rD   rE   <module>   s    `3D
& B