o
    X۷izu                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
ZddgZdd	gZd
ZejedZdZejedd eD dZdZejedd eD dZdd Zdd Zd$ddZdd Zdd Zd%ddZd%d d!ZG d"d# d#ZdS )&    )annotationsN)internal)get_typename)
csr_matrixdoublezthrust::complex<double>intz	long longan  
#include <cupy/complex.cuh>
#include <cupy/float16.cuh>  // TODO(seberg): Add this via type_headers?

extern "C" {
__global__ void find_interval(
        const double* t, const double* x, long long* out,
        int k, int n, bool extrapolate, int total_x) {

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(idx >= total_x) {
        return;
    }

    double xp = *&x[idx];
    double tb = *&t[k];
    double te = *&t[n];

    if(isnan(xp)) {
        out[idx] = -1;
        return;
    }

    if((xp < tb || xp > te) && !extrapolate) {
        out[idx] = -1;
        return;
    }

    int left = k;
    int right = n;
    int mid;
    bool found = false;

    while(left < right && !found) {
        mid = ((right + left) / 2);
        if(xp > *&t[mid]) {
            left = mid + 1;
        } else if (xp < *&t[mid]) {
            right = mid - 1;
        } else {
            found = true;
        }
    }

    int default_value = left - 1 < k ? k : left - 1;
    int result = found ? mid + 1 : default_value + 1;

    while(result != n && xp >= *&t[result]) {
        result++;
    }

    out[idx] = result - 1;
}
}
)codea	  
#include <cupy/complex.cuh>
#include <cupy/math_constants.h>
#include <cupy/float16.cuh>  // TODO(seberg): Add this via type_headers?
#define COMPUTE_LINEAR 0x1

template<typename T>
__global__ void d_boor(
        const double* t, const T* c, const int k, const int mu,
        const double* x, const long long* intervals, T* out,
        double* temp, int num_c, int mode, int num_x) {

    int idx = blockDim.x * blockIdx.x + threadIdx.x;

    if(idx >= num_x) {
        return;
    }

    double xp = *&x[idx];
    long long interval = *&intervals[idx];

    double* h = temp + idx * (2 * k + 1);
    double* hh = h + k + 1;

    int ind, j, n;
    double xa, xb, w;

    if(mode == COMPUTE_LINEAR && interval < 0) {
        for(j = 0; j < num_c; j++) {
            out[num_c * idx + j] = CUDART_NAN;
        }
        return;
    }

    /*
     * Perform k-m "standard" deBoor iterations
     * so that h contains the k+1 non-zero values of beta_{ell,k-m}(x)
     * needed to calculate the remaining derivatives.
     */
    h[0] = 1.0;
    for (j = 1; j <= k - mu; j++) {
        for(int p = 0; p < j; p++) {
            hh[p] = h[p];
        }
        h[0] = 0.0;
        for (n = 1; n <= j; n++) {
            ind = interval + n;
            xb = t[ind];
            xa = t[ind - j];
            if (xb == xa) {
                h[n] = 0.0;
                continue;
            }
            w = hh[n - 1]/(xb - xa);
            h[n - 1] += w*(xb - xp);
            h[n] = w*(xp - xa);
        }
    }

    /*
     * Now do m "derivative" recursions
     * to convert the values of beta into the mth derivative
     */
    for (j = k - mu + 1; j <= k; j++) {
        for(int p = 0; p < j; p++) {
            hh[p] = h[p];
        }
        h[0] = 0.0;
        for (n = 1; n <= j; n++) {
            ind = interval + n;
            xb = t[ind];
            xa = t[ind - j];
            if (xb == xa) {
                h[n] = 0.0;
                continue;
            }
            w = ((double) j) * hh[n - 1]/(xb - xa);
            h[n - 1] -= w;
            h[n] = w;
        }
    }

    if(mode != COMPUTE_LINEAR) {
        return;
    }

    // Compute linear combinations
    for(j = 0; j < num_c; j++) {
        out[num_c * idx + j] = 0;
        for(n = 0; n < k + 1; n++) {
            out[num_c * idx + j] = (
                out[num_c * idx + j] +
                c[(interval + n - k) * num_c + j] * ((T) h[n]));
        }
    }

}
c                 C     g | ]}d | dqS )zd_boor<> ).0	type_namer   r   V/home/ubuntu/vllm_env/lib/python3.10/site-packages/cupyx/scipy/interpolate/_bspline.py
<listcomp>       r   )r   name_expressionsa  
#include <cupy/complex.cuh>
#include <cupy/float16.cuh>  // TODO(seberg): Add this via type_headers?

template<typename U>
__global__ void compute_design_matrix(
        const int k, const long long* intervals, double* bspline_basis,
        double* data, U* indices, int num_intervals) {

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(idx >= num_intervals) {
        return;
    }

    long long interval = *&intervals[idx];

    double* work = bspline_basis + idx * (2 * k + 1);

    for(int j = 0; j <= k; j++) {
        int m = (k + 1) * idx + j;
        data[m] = work[j];
        indices[m] = (U) (interval - k + j);
    }
}
c                 C  r	   )zcompute_design_matrix<r
   r   )r   ityper   r   r   r      r   c                 G  s>   dd |D }d |}|r| d| dn|}| |}|S )Nc                 S  s   g | ]}t |jqS r   )r   dtype)r   argr   r   r   r      s    z$_get_module_func.<locals>.<listcomp>z, <r
   )joinget_function)module	func_nametemplate_argsargs_dtypestemplatekernel_namekernelr   r   r   _get_module_func   s
   

r   c                 C  s   t | t jr
t jS t jS )z>Return np.complex128 for complex dtypes, np.float64 otherwise.)cupy
issubdtypecomplexfloating
complex128float64r   r   r   r   
_get_dtype   s   r&   Fc                 C  sJ   t | } t | } t| j}| j|dd} |r#t |  s#td| S )z~Convert the input into a C contiguous float array.
    NB: Upcasts half- and single-precision floats to double precision.
    F)copyz$Array must not contain infs or nans.)	r    asarrayascontiguousarrayr&   r   astypeisfiniteall
ValueError)xcheck_finitedtypr   r   r   _as_float_array   s   


r1   c                 C  s   | j d | d }tj|tjd}ttd}	|	|j d d d d fd| ||||||j d f tt|j dd }
t	|j d d| d  }tt
d	|}||j d d d d fd| ||||||||
d|j d f dS )
a1  
    Evaluate a spline in the B-spline basis.

    Parameters
    ----------
    t : ndarray, shape (n+k+1)
        knots
    c : ndarray, shape (n, m)
        B-spline coefficients
    xp : ndarray, shape (s,)
        Points to evaluate the spline at.
    nu : int
        Order of derivative to evaluate.
    extrapolate : int, optional
        Whether to extrapolate to ouf-of-bounds points, or to return NaNs.
    out : ndarray, shape (s, m)
        Computed values of the spline at each of the input points.
        This argument is modified in-place.
    r      r%   find_interval   r4   N   d_boor)shaper    
empty_likeint64r   INTERVAL_MODULEr   npprodemptyD_BOOR_MODULE)tckxpnuextrapolateoutn	intervalsinterval_kernelnum_ctempd_boor_kernelr   r   r   _evaluate_spline   s   
rM   c                 C  s(  |j d | d }tj| tjd}ttd}|| j d d d d fd|| ||||| j d f t| j d d| d  }ttd| }	|	| j d d d d fd|d	|d| |d	|dd| j d f tj| j d |d  tj	d}
tt
d
|}|| j d d d d fd||||
|| j d f |
|fS )a  
    Returns a design matrix in CSR format.
    Note that only indices is passed, but not indptr because indptr is already
    precomputed in the calling Python function design_matrix.

    Parameters
    ----------
    x : array_like, shape (n,)
        Points to evaluate the spline at.
    t : array_like, shape (nt,)
        Sorted 1D array of knots.
    k : int
        B-spline degree.
    extrapolate : bool, optional
        Whether to extrapolate to ouf-of-bounds points.
    indices : ndarray, shape (n * (k + 1),)
        Preallocated indices of the final CSR array.
    Returns
    -------
    data
        The data array of a CSR array of the b-spline design matrix.
        In each row all the basis elements are evaluated at the certain point
        (first row - x[0], ..., last row - x[-1]).

    indices
        The indices array of a CSR array of the b-spline design matrix.
    r   r2   r%   r3   r4   r5   r6   r7   Ncompute_design_matrix)r8   r    r9   r:   r   r;   r>   r?   zerosr$   DESIGN_MAT_MODULE)r.   r@   rB   rE   indicesrG   rH   rI   bspline_basisrL   datadesign_mat_kernelr   r   r   _make_design_matrix  s.   

rU   r2   c           	   
   C  s(  |dk r
t | | S | \}}}||krtd|| d f tdfdt|jdd   }zNt|D ]G}||d d |d| d   }|| }|dd|  |dd|   | | }tj|t	|f|jdd  f }|dd }|d8 }q2W n t
y } ztd	| |d}~ww |||fS )
a  
    Compute the spline representation of the derivative of a given spline

    Parameters
    ----------
    tck : tuple of (t, c, k)
        Spline whose derivative to compute
    n : int, optional
        Order of derivative to evaluate. Default: 1

    Returns
    -------
    tck_der : tuple of (t2, c2, k2)
        Spline of order k2=k-n representing the derivative
        of the input spline.

    Notes
    -----
    .. seealso:: :class:`scipy.interpolate.splder`

    See Also
    --------
    splantider, splev, spalde
    r   z@Order of derivative (n = %r) must be <= order of spline (k = %r)r6   NNr2   zIThe spline has internal repeated knots and is not differentiable %d times)
splantiderr-   slicelenr8   ranger    r_r<   rO   FloatingPointError)	tckrG   r@   rA   rB   shjdter   r   r   splderL  s4   

 "($

rd   c                 C  s   |dk r
t | | S | \}}}tdfdt|jdd   }t|D ]T}||d d |d| d   }|| }tj|d| d  | dd|d  }tjtd|jdd  ||d g|d  f }tj|d ||d f }|d7 }q#|||fS )	a  
    Compute the spline for the antiderivative (integral) of a given spline.

    Parameters
    ----------
    tck : tuple of (t, c, k)
        Spline whose antiderivative to compute
    n : int, optional
        Order of antiderivative to evaluate. Default: 1

    Returns
    -------
    tck_ader : tuple of (t2, c2, k2)
        Spline of order k2=k+n representing the antiderivative of the input
        spline.

    See Also
    --------
    splder, splev, spalde

    Notes
    -----
    The `splder` function is the inverse operation of this function.
    Namely, ``splder(splantider(tck))`` is identical to `tck`, modulo
    rounding error.

    .. seealso:: :class:`scipy.interpolate.splantider`
    r   NrV   r2   )axisr2   rW   r6   )	rd   rZ   r[   r8   r\   r    cumsumr]   rO   )r_   rG   r@   rA   rB   r`   ra   rb   r   r   r   rY     s   
 "(

rY   c                   @  s   e Zd ZdZdddZedddZedd	 Zedd
dZ	edddZ
d ddZdd Zdd Zd!ddZd!ddZd"ddZdS )#BSplinea  Univariate spline in the B-spline basis.

    .. math::
        S(x) = \sum_{j=0}^{n-1} c_j  B_{j, k; t}(x)

    where :math:`B_{j, k; t}` are B-spline basis functions of degree `k`
    and knots `t`.

    Parameters
    ----------
    t : ndarray, shape (n+k+1,)
        knots
    c : ndarray, shape (>=n, ...)
        spline coefficients
    k : int
        B-spline degree
    extrapolate : bool or 'periodic', optional
        whether to extrapolate beyond the base interval, ``t[k] .. t[n]``,
        or to return nans.
        If True, extrapolates the first and last polynomial pieces of b-spline
        functions active on the base interval.
        If 'periodic', periodic extrapolation is used.
        Default is True.
    axis : int, optional
        Interpolation axis. Default is zero.

    Attributes
    ----------
    t : ndarray
        knot vector
    c : ndarray
        spline coefficients
    k : int
        spline degree
    extrapolate : bool
        If True, extrapolates the first and last polynomial pieces of b-spline
        functions active on the base interval.
    axis : int
        Interpolation axis.
    tck : tuple
        A read-only equivalent of ``(self.t, self.c, self.k)``

    Notes
    -----
    B-spline basis elements are defined via

    .. math::
        B_{i, 0}(x) = 1, \textrm{if $t_i \le x < t_{i+1}$, otherwise $0$,}

        B_{i, k}(x) = \frac{x - t_i}{t_{i+k} - t_i} B_{i, k-1}(x)
                 + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}} B_{i+1, k-1}(x)

    **Implementation details**

    - At least ``k+1`` coefficients are required for a spline of degree `k`,
      so that ``n >= k+1``. Additional coefficients, ``c[j]`` with
      ``j > n``, are ignored.

    - B-spline basis elements of degree `k` form a partition of unity on the
      *base interval*, ``t[k] <= x <= t[n]``.

    - Based on [1]_ and [2]_

    .. seealso:: :class:`scipy.interpolate.BSpline`

    References
    ----------
    .. [1] Tom Lyche and Knut Morken, Spline methods,
        http://www.uio.no/studier/emner/matnat/ifi/INF-MAT5340/v05/undervisningsmateriale/
    .. [2] Carl de Boor, A practical guide to splines, Springer, 2001.
    Tr   c                 C  s~  t || _t|| _tj|tjd| _|dkr|| _	nt
|| _	| jjd | j d }t|| jj}|| _|dkrEt| j|d| _|dk rMtd| jjdkrWtd|| jd k rjtdd| d |f t| jdk  rxtd	tt| j||d  dk rtd
t| j std| jjdk rtd| jjd |k rtdt| jj}tj| j|d| _d S )Nr%   periodicr   r2    Spline order cannot be negative.z$Knot vector must be one-dimensional.z$Need at least %d knots for degree %dr6   z(Knots must be in a non-decreasing order.z!Need at least two internal knots.z#Knots should not have nans or infs.z,Coefficients must be at least 1-dimensional.z0Knots, coefficients and degree are inconsistent.)operatorindexrB   r    r(   rA   r)   r$   r@   rE   boolr8   r   _normalize_axis_indexndimre   moveaxisr-   diffanyr[   uniquer+   r,   r&   r   )selfr@   rA   rB   rE   re   rG   rb   r   r   r   __init__  sB   
 zBSpline.__init__c                 C  s0   t | }||||_|_|_||_||_|S )zConstruct a spline without making checks.
        Accepts same parameters as the regular constructor. Input arrays
        `t` and `c` must of correct shape and dtype.
        )object__new__r@   rA   rB   rE   re   )clsr@   rA   rB   rE   re   rt   r   r   r   construct_fast5  s
   
zBSpline.construct_fastc                 C  s   | j | j| jfS )z@Equivalent to ``(self.t, self.c, self.k)`` (read-only).
        )r@   rA   rB   rt   r   r   r   r_   A  s   zBSpline.tckc                 C  sb   t |d }t|}tj|d d f| ||d d f| f }t|}d||< | ||||S )ao  Return a B-spline basis element ``B(x | t[0], ..., t[k+1])``.

        Parameters
        ----------
        t : ndarray, shape (k+2,)
            internal knots
        extrapolate : bool or 'periodic', optional
            whether to extrapolate beyond the base interval,
            ``t[0] .. t[k+1]``, or to return nans.
            If 'periodic', periodic extrapolation is used.
            Default is True.

        Returns
        -------
        basis_element : callable
            A callable representing a B-spline basis element for the knot
            vector `t`.

        Notes
        -----
        The degree of the B-spline, `k`, is inferred from the length of `t` as
        ``len(t)-2``. The knot vector is constructed by appending and
        prepending ``k+1`` elements to internal knots `t`.

        .. seealso:: :class:`scipy.interpolate.BSpline`
        r6   r   r2   rW   g      ?)r[   r1   r    r]   
zeros_likery   )rx   r@   rE   rB   rA   r   r   r   basis_elementG  s   ,
zBSpline.basis_elementFc                 C  s  t |d}t |d}|dkrt|}|dk rtd|jdks.t|dd |dd k r6td| d	t|d
| d
 k rHtd| d	|dkrh|j| d }|| |||  || ||    }d}n!|st||| k st	|||j
d | d  krtd| d	|j
d }||d  }|ttjj	k rtj}ntj}tj||d  |d}tjd|d |d  |d |d}	t|||||\}
}t|
||	f|j
d |j
d | d fdS )a  
        Returns a design matrix as a CSR format sparse array.

        Parameters
        ----------
        x : array_like, shape (n,)
            Points to evaluate the spline at.
        t : array_like, shape (nt,)
            Sorted 1D array of knots.
        k : int
            B-spline degree.
        extrapolate : bool or 'periodic', optional
            Whether to extrapolate based on the first and last intervals
            or raise an error. If 'periodic', periodic extrapolation is used.
            Default is False.

        Returns
        -------
        design_matrix : `csr_matrix` object
            Sparse matrix in CSR format where each row contains all the basis
            elements of the input row (first row = basis elements of x[0],
            ..., last row = basis elements x[-1]).

        Notes
        -----
        In each row of the design matrix all the basis elements are evaluated
        at the certain point (first row - x[0], ..., last row - x[-1]).
        `nt` is a length of the vector of knots: as far as there are
        `nt - k - 1` basis elements, `nt` should be not less than `2 * k + 2`
        to have at least `k + 1` basis element.

        Out of bounds `x` raises a ValueError.

        .. note::
            This method returns a `csr_matrix` instance as CuPy still does not
            have `csr_array`.

        .. seealso:: :class:`scipy.interpolate.BSpline`
        Tri   r   rj   r2   NrW   z2Expect t to be a 1-D sorted array_like, but got t=.r6   zLength t is not enough for k=FzOut of bounds w/ x = r%   )r8   )r1   rm   r-   ro   r<   rr   r[   sizeminmaxr8   r    iinfoint32r:   r>   arangerU   r   )rx   r.   r@   rB   rE   rG   nnz	int_dtyperQ   indptrrS   r   r   r   design_matrixj  sB   
)
(
$.
"
zBSpline.design_matrixNc           	      C  s<  |du r| j }t|}|j|j}}tjt|tjd}|dkrF| jj	| j
 d }| j| j
 || j| j
  | j| | j| j
    }d}tjt|tt| jjdd f| jjd}| |||| ||| jjdd  }| jdkrtt|j}|||| j  |d|  ||| j d  }||}|S )a  
        Evaluate a spline function.

        Parameters
        ----------
        x : array_like
            points to evaluate the spline at.
        nu : int, optional
            derivative to evaluate (default is 0).
        extrapolate : bool or 'periodic', optional
            whether to extrapolate based on the first and last intervals
            or return nans. If 'periodic', periodic extrapolation is used.
            Default is `self.extrapolate`.

        Returns
        -------
        y : array_like
            Shape is determined by replacing the interpolation axis
            in the coefficient array with the shape of `x`.
        Nr%   ri   r2   Fr   )rE   r    r(   r8   ro   r)   ravelr$   r@   r~   rB   r>   r[   r   r<   r=   rA   r   	_evaluatereshapere   listr\   	transpose)	rt   r.   rD   rE   x_shapex_ndimrG   rF   	dim_orderr   r   r   __call__  s4   
 
&


zBSpline.__call__c                 C  s4   | j jjs| j  | _ | jjjs| j | _d S d S rV   )r@   flagsc_contiguousr'   rA   rz   r   r   r   _ensure_c_contiguous  s
   

zBSpline._ensure_c_contiguousc                 C  s.   t | j| j| jjd d| j|||| d S )Nr   rW   )rM   r@   rA   r   r8   rB   )rt   rC   rD   rE   rF   r   r   r   r      s   zBSpline._evaluater2   c                 C  sn   | j }t| jt| }|dkr"tj|t|f|jdd  f }t| j|| jf|}| j	|| j
| jdS )al  
        Return a B-spline representing the derivative.

        Parameters
        ----------
        nu : int, optional
            Derivative order.
            Default is 1.

        Returns
        -------
        b : BSpline object
            A new instance representing the derivative.

        See Also
        --------
        splder, splantider
        r   r2   NrE   re   )rA   r[   r@   r    r]   rO   r8   rd   rB   ry   rE   re   )rt   rD   rA   ctr_   r   r   r   
derivative  s   $
zBSpline.derivativec                 C  s   | j }t| jt| }|dkr"tj|t|f|jdd  f }t| j|| jf|}| j	dkr4d}n| j	}| j
||| jdS )a  
        Return a B-spline representing the antiderivative.

        Parameters
        ----------
        nu : int, optional
            Antiderivative order. Default is 1.

        Returns
        -------
        b : BSpline object
            A new instance representing the antiderivative.

        Notes
        -----
        If antiderivative is computed and ``self.extrapolate='periodic'``,
        it will be set to False for the returned instance. This is done because
        the antiderivative is no longer periodic and its correct evaluation
        outside of the initially given x interval is difficult.

        See Also
        --------
        splder, splantider
        r   r2   Nri   Fr   )rA   r[   r@   r    r]   rO   r8   rY   rB   rE   ry   re   )rt   rD   rA   r   r_   rE   r   r   r   antiderivative   s   $
zBSpline.antiderivativec                 C  s4  |du r| j }|   d}||k r||}}d}| jj| j d }|dkr<|s<t|| j| j  }t|| j|  }tj	dt
t| jjdd f| jjd}| j}t| jt| }|dkrttj|t|f|jdd  f }t| j|| jfd\}	}
}|dkri| j| j | j| }}|| }|| }t||\}}|dkrtj||gtjd}t|	|
|
jd d||dd| |d |d  }||9 }ntjdt
t| jjdd f| jjd}||| |  }|| }||krtj||gtjd}t|	|
|
jd d||dd| ||d |d  7 }nrtj||gtjd}t|	|
|
jd d||dd| ||d |d  7 }tj||| | gtjd}t|	|
|
jd d||dd| ||d |d  7 }n#tj||gtjd}t|	|
|
jd d||d|| |d |d  }||9 }||
jdd S )	a  
        Compute a definite integral of the spline.

        Parameters
        ----------
        a : float
            Lower limit of integration.
        b : float
            Upper limit of integration.
        extrapolate : bool or 'periodic', optional
            whether to extrapolate beyond the base interval,
            ``t[k] .. t[-k-1]``, or take the spline to be zero outside of the
            base interval. If 'periodic', periodic extrapolation is used.
            If None (default), use `self.extrapolate`.

        Returns
        -------
        I : array_like
            Definite integral of the spline over the interval ``[a, b]``.
        Nr2   rW   ri   r6   r%   r   F)rE   r   r@   r~   rB   r   itemr   r    r>   r   r<   r=   rA   r8   r   r[   r]   rO   rY   divmodr(   r$   rM   r   )rt   abrE   signrG   rF   rA   r   tacakatsteperiodinterval	n_periodsleftr.   integralr   r   r   	integrateH  sv   
"$


 




zBSpline.integrate)Tr   )TF)r   Nrf   rV   )__name__
__module____qualname____doc__ru   classmethodry   propertyr_   r|   r   r   r   r   r   r   r   r   r   r   r   rh     s"    
H.
"
[4

(rh   r   rf   )
__future__r   rk   r    
cupy._corer   cupy._core._scalarr   cupyx.scipy.sparser   numpyr<   TYPES	INT_TYPESINTERVAL_KERNEL	RawModuler;   D_BOOR_KERNELr?   DESIGN_MAT_KERNELrP   r   r&   r1   rM   rU   rd   rY   rh   r   r   r   r   <module>   s@    8c
%
5
<6