o
    X۷ih                     @  s   d dl mZ d dlZd dlZd dlmZ d dlZd dlmZm	Z	 d dl
mZ d dlmZ d dlmZ dd	gZd
Zejedddgdd eD  dZdd Zdd ZG dd dZdddZdS )    )annotationsN)prod)
_get_dtype_get_module_func)_not_a_knot)
csr_matrix)spsolvedoublezthrust::complex<double>a  
#include <cupy/complex.cuh>
#include <cupy/math_constants.h>

__forceinline__ __device__ int getCurThreadIdx()
{
    const int threadsPerBlock   = blockDim.x;
    const int curThreadIdx    = ( blockIdx.x * threadsPerBlock ) + threadIdx.x;
    return curThreadIdx;
}
__forceinline__ __device__ int getThreadNum()
{
    const int blocksPerGrid     = gridDim.x;
    const int threadsPerBlock   = blockDim.x;
    const int threadNum         = blocksPerGrid * threadsPerBlock;
    return threadNum;
}

__device__ long long find_interval(
        const double* t, double xp, int k, int n, bool extrapolate) {

    double tb = *&t[k];
    double te = *&t[n];

    if(isnan(xp)) {
        return -1;
    }

    if((xp < tb || xp > te) && !extrapolate) {
        return -1;
    }

    int left = k;
    int right = n;
    int mid;
    bool found = false;

    while(left < right && !found) {
        mid = ((right + left) / 2);
        if(xp > *&t[mid]) {
            left = mid + 1;
        } else if (xp < *&t[mid]) {
            right = mid - 1;
        } else {
            found = true;
        }
    }

    int default_value = left - 1 < k ? k : left - 1;
    int result = found ? mid + 1 : default_value + 1;

    while(xp >= *&t[result] && result != n) {
        result++;
    }

    return result - 1;
}

__device__ void d_boor(
        const double* t, double xp, long long interval, const long long k,
        const int mu, double* temp) {

    double* h = temp;
    double* hh = h + k + 1;

    int ind, j, n;
    double xa, xb, w;

    /*
     * Perform k-m "standard" deBoor iterations
     * so that h contains the k+1 non-zero values of beta_{ell,k-m}(x)
     * needed to calculate the remaining derivatives.
     */
    h[0] = 1.0;
    for (j = 1; j <= k - mu; j++) {
        for(int p = 0; p < j; p++) {
            hh[p] = h[p];
        }
        h[0] = 0.0;
        for (n = 1; n <= j; n++) {
            ind = interval + n;
            xb = t[ind];
            xa = t[ind - j];
            if (xb == xa) {
                h[n] = 0.0;
                continue;
            }
            w = hh[n - 1]/(xb - xa);
            h[n - 1] += w*(xb - xp);
            h[n] = w*(xp - xa);
        }
    }

    /*
     * Now do m "derivative" recursions
     * to convert the values of beta into the mth derivative
     */
    for (j = k - mu + 1; j <= k; j++) {
        for(int p = 0; p < j; p++) {
            hh[p] = h[p];
        }
        h[0] = 0.0;
        for (n = 1; n <= j; n++) {
            ind = interval + n;
            xb = t[ind];
            xa = t[ind - j];
            if (xb == xa) {
                h[mu] = 0.0;
                continue;
            }
            w = ((double) j) * hh[n - 1]/(xb - xa);
            h[n - 1] -= w;
            h[n] = w;
        }
    }

}

__global__ void compute_nd_bsplines(
        const double* xi, int n_xi, const double* t, const long long* t_sz,
        int ndim, int max_t, const long long* k, const long long* max_k,
        const int* nu, bool extrapolate, bool check_all_validity,
        long long* intervals, double* splines, bool* invalid) {

    int total = n_xi * ndim;

    for(int midx = getCurThreadIdx(); midx < total; midx += getThreadNum()) {
        int idx = midx / ndim;
        int dim_idx = midx % ndim;

        double xd = xi[ndim * idx + dim_idx];
        const double* dim_t = t + max_t * dim_idx;
        const long long dim_k = k[dim_idx];
        const long long dim_t_sz = t_sz[dim_idx];

        long long interval = find_interval(
            dim_t, xd, dim_k, dim_t_sz - dim_k - 1, extrapolate);

        if(interval < 0) {
            invalid[check_all_validity ? idx : blockIdx.x] = true;
            continue;
        }

        intervals[ndim * idx + dim_idx] = interval;
        double* dim_splines = (
            splines + ndim * (2 * max_k[0] + 2) * idx +
            (2 * max_k[0] + 2) * dim_idx);

        d_boor(dim_t, xd, interval, dim_k, nu[dim_idx], dim_splines);
    }
}

template<typename T>
__global__ void eval_nd_bspline(
        const long long* indices_k1d, const long long* strides_c1,
        const double* b, const long long* intervals, const long long* k,
        bool* invalid, T* c1r, long long* volume, int ndim, int num_c,
        int n_xi, const long long* max_k, T* out) {

    for(int idx = getCurThreadIdx(); idx < n_xi; idx += getThreadNum()) {
        if(invalid[idx]) {
            for(int i = 0; i < num_c; i++) {
                out[num_c * idx + i] = CUDART_NAN;
            }
            continue;
        }

        for(int i = 0; i < num_c; i++) {
            out[num_c * idx + i] = 0;
        }

        const double* idx_splines = b + ndim * (2 * max_k[0] + 2) * idx;

        for(long long iflat = 0; iflat < *volume; iflat++) {
            const long long* idx_b = indices_k1d + ndim * iflat;
            long long idx_cflat_base = 0;
            double factor = 1.0;

            for(int d = 0; d < ndim; d++) {
                const double* dim_splines = (
                    idx_splines + (2 * max_k[0] + 2) * d);
                factor *= dim_splines[idx_b[d]];
                long long d_idx = idx_b[d] + intervals[ndim * idx + d] - k[d];
                idx_cflat_base += d_idx * strides_c1[d];
            }

            for(int i = 0; i < num_c; i++) {
                out[num_c * idx + i] += c1r[idx_cflat_base + i] * factor;
            }
        }
    }
}


__global__ void store_nd_bsplines(
        const long long* indices_k1d, const long long* strides_c1,
        const double* b, const long long* intervals, const long long* k,
        long long* volume, int ndim, int n_xi, const long long* max_k,
        long long* out_idx, double* out) {

    int total = n_xi * volume[0];

    for(int midx = getCurThreadIdx(); midx < total; midx += getThreadNum()) {
        int idx = midx / volume[0];
        int iflat = midx % volume[0];

        const double* idx_splines = b + ndim * (2 * max_k[0] + 2) * idx;
        const long long* idx_b = indices_k1d + ndim * iflat;

        long long idx_cflat_base = 0;
        double factor = 1.0;

        for(int d = 0; d < ndim; d++) {
            const double* dim_splines = (
                idx_splines + (2 * max_k[0] + 2) * d);
            factor *= dim_splines[idx_b[d]];
            long long d_idx = idx_b[d] + intervals[ndim * idx + d] - k[d];
            idx_cflat_base += d_idx * strides_c1[d];
        }

        out_idx[volume[0] * idx + iflat] = idx_cflat_base;
        out[volume[0] * idx + iflat] = factor;
    }
}
)z
-std=c++17compute_nd_bsplinesstore_nd_bsplinesc                 C  s   g | ]}d | dqS )zeval_nd_bspline<> ).0tr   r   X/home/ubuntu/vllm_env/lib/python3.10/site-packages/cupyx/scipy/interpolate/_ndbspline.py
<listcomp>   s    r   )codeoptionsname_expressionsc                 C  s   |  }t|d }tj| jd |jd ftjd}tj| jd |jd d|  d ftjd}tj| jd tj	d}t
d}|dd| | jd ||| jd |jd ||||d|||f tt
d	|}|dd|	|||||||| jd || jd ||
f d
S )a  Evaluate an N-dim tensor product spline or its derivative.

    Parameters
    ----------
    xi : ndarray, shape(npoints, ndim)
        ``npoints`` values to evaluate the spline at, each value is
        a point in an ``ndim``-dimensional space.
    t : ndarray, shape(ndim, max_len_t)
        Array of knots for each dimension.
        This array packs the tuple of knot arrays per dimension into a single
        2D array. The array is ragged (knot lengths may differ), hence
        the real knots in dimension ``d`` are ``t[d, :len_t[d]]``.
    len_t : ndarray, 1D, shape (ndim,)
        Lengths of the knot arrays, per dimension.
    k : tuple of ints, len(ndim)
        Spline degrees in each dimension.
    nu : ndarray of ints, shape(ndim,)
        Orders of derivatives to compute, per dimension.
    extrapolate : int
        Whether to extrapolate out of bounds or return nans.
    c1r: ndarray, one-dimensional
        Flattened array of coefficients.
        The original N-dimensional coefficient array ``c`` has shape
        ``(n1, ..., nd, ...)`` where each ``ni == len(t[d]) - k[d] - 1``,
        and the second "..." represents trailing dimensions of ``c``.
        In code, given the C-ordered array ``c``, ``c1r`` is
        ``c1 = c.reshape(c.shape[:ndim] + (-1,)); c1r = c1.ravel()``
    num_c_tr : int
        The number of elements of ``c1r``, which correspond to the trailing
        dimensions of ``c``. In code, this is
        ``c1 = c.reshape(c.shape[:ndim] + (-1,)); num_c_tr = c1.shape[-1]``.
    strides_c1 : ndarray, one-dimensional
        Pre-computed strides of the ``c1`` array.
        Note: These are *data* strides, not numpy-style byte strides.
        This array is equivalent to
        ``[stride // s1.dtype.itemsize for stride in s1.strides]``.
    indices_k1d : ndarray, shape((k+1)**ndim, ndim)
        Pre-computed mapping between indices for iterating over a flattened
        array of shape ``[k[d] + 1) for d in range(ndim)`` and
        ndim-dimensional indices of the ``(k+1,)*ndim`` dimensional array.
        This is essentially a transposed version of
        ``cupy.unravel_index(cupy.arange((k+1)**ndim), (k+1,)*ndim)``.
    out : ndarray, shape (npoints, num_c_tr)
        Output values of the b-spline at given ``xi`` points.

    Notes
    -----

    This function is essentially equivalent to the following: given an
    N-dimensional vector ``x = (x1, x2, ..., xN)``, iterate over the
    dimensions, form linear combinations of products,
    B(x1) * B(x2) * ... B(xN) of (k+1)**N b-splines which are non-zero
    at ``x``.

    Since b-splines are localized, the sum has (k+1)**N non-zero elements.

    If ``i = (i1, i2, ..., iN)`` is a vector if intervals of the knot
    vectors, ``t[d, id] <= xd < t[d, id+1]``, for ``d=1, 2, ..., N``, then
    the core loop of this function is nothing but

    ```
    result = 0
    iters = [range(i[d] - self.k[d], i[d] + 1) for d in range(ndim)]
    for idx in itertools.product(*iters):
        term = self.c[idx] * cupy.prod([B(x[d], self.k[d], idx[d], self.t[d])
                                        for d in range(ndim)])
        result += term
    ```

    For efficiency reasons, we iterate over the flattened versions of the
    arrays.

       r   dtype   r
         Teval_nd_bsplineN)maxcupyr   emptyshapeint64itemfloat64zerosbool_
NDBSPL_MODget_functionr   )xir   len_tknuextrapolatec1rnum_c_tr
strides_c1indices_k1doutmax_kvolume	intervalssplinesinvalidr
   r   r   r   r   evaluate_ndbspline   s"   L $
"
r8   c                 C  s  | j d }| j d }| }t|d }| }tj| j d |j d ftjd}	tj| j d |j d d|  d ftjd}
tj	dtj
d}tj	|tjd}tdd | D }|tj||jd }tj|dd df }tj|ddd	 tjdddd	  }tt||}tj|tjdj }tj|| ftjd
}tj|| ftjd
}tjd|| d |tjd}td}|dd| | j d ||| j d |j d |||dd|	|
|f t| rtdtd}|dd|||
|	||t|t||||f |||fS )ae  Construct the N-D tensor product collocation matrix as a CSR array.

    In the dense representation, each row of the collocation matrix corresponds
    to a data point and contains non-zero b-spline basis functions which are
    non-zero at this data point.

    Parameters
    ----------
    xvals : ndarray, shape(size, ndim)
        Data points. ``xvals[j, :]`` gives the ``j``-th data point as an
        ``ndim``-dimensional array.
    t : tuple of 1D arrays, length-ndim
        Tuple of knot vectors
    k : ndarray, shape (ndim,)
        Spline degrees

    Returns
    -------
    csr_data, csr_indices, csr_indptr
        The collocation matrix in the CSR array format.

    Notes
    -----
    Algorithm: given `xvals` and the tuple of knots `t`, we construct a tensor
    product spline, i.e. a linear combination of

        B(x1; i1, t1) * B(x2; i2, t2) * ... * B(xN; iN, tN)


    Here ``B(x; i, t)`` is the ``i``-th b-spline defined by the knot vector
    ``t`` evaluated at ``x``.

    Since ``B`` functions are localized, for each point `(x1, ..., xN)` we
    loop over the dimensions, and
    - find the the location in the knot array, `t[i] <= x < t[i+1]`,
    - compute all non-zero `B` values
    - place these values into the relevant row

    In the dense representation, the collocation matrix would have had a row
    per data point, and each row has the values of the basis elements
    (i.e., tensor products of B-splines) evaluated at this data point.
    Since the matrix is very sparse (has size = len(x)**ndim, with only
    (k+1)**ndim non-zero elements per row), we construct it in the CSR format.
    r   r   r   r   r   c                 s      | ]}|d  V  qdS r   Nr   r   kdr   r   r   	<genexpr>      zcolloc_nd.<locals>.<genexpr>N)r!   r   r
   r   r   TFzOut of boundsr   )r!   r   r   r   getr    r"   r#   r$   r%   r&   tupleasarrayr   r_cumprodcopyunravel_indexarangeTr'   r(   any
ValueErrorint)xvalsr   r*   r+   sizendimr3   r4   
cpu_volumer5   r6   r7   r,   k1_shapec_shapecscstridesindices_indices_k1dcsr_indicescsr_data
csr_indptrr
   store_nd_splinesr   r   r   	colloc_ndX  sF   
-
 $(
"

rZ   c                   @  s<   e Zd ZdZddddZdddddZedd
dZdS )	NdBSplineaq  Tensor product spline object.

    The value at point ``xp = (x1, x2, ..., xN)`` is evaluated as a linear
    combination of products of one-dimensional b-splines in each of the ``N``
    dimensions::

       c[i1, i2, ..., iN] * B(x1; i1, t1) * B(x2; i2, t2) * ... * B(xN; iN, tN)


    Here ``B(x; i, t)`` is the ``i``-th b-spline defined by the knot vector
    ``t`` evaluated at ``x``.

    Parameters
    ----------
    t : tuple of 1D ndarrays
        knot vectors in directions 1, 2, ... N,
        ``len(t[i]) == n[i] + k + 1``
    c : ndarray, shape (n1, n2, ..., nN, ...)
        b-spline coefficients
    k : int or length-d tuple of integers
        spline degrees.
        A single integer is interpreted as having this degree for
        all dimensions.
    extrapolate : bool, optional
        Whether to extrapolate out-of-bounds inputs, or return `nan`.
        Default is to extrapolate.

    Attributes
    ----------
    t : tuple of ndarrays
        Knots vectors.
    c : ndarray
        Coefficients of the tensor-produce spline.
    k : tuple of integers
        Degrees for each dimension.
    extrapolate : bool, optional
        Whether to extrapolate or return nans for out-of-bounds inputs.
        Defaults to true.

    See Also
    --------
    BSpline : a one-dimensional B-spline object
    NdPPoly : an N-dimensional piecewise tensor product polynomial

    N)r-   c                C  s:  t |}zt | W n ty   |f| }Y nw t ||kr.tdt |dt |dtdd |D | _tdd |D | _t|| _|d u rNd}t	|| _
t|| _t|D ]}| j| }| j| }|jd | d	 }	|dk r~td
| d|jd	krtd| d|	|d	 k rtdd| d  d| d| dt|dk  rtd| dt t|||	d	  dk rtd| dt| std| d| jj|k rtd| d| jj| |	krtd| d| jj|  dt | d|	 d| dq]t| jj}
tj| j|
d| _d S )Nzlen(t)=z != len(k)=.c                 s  s    | ]}t |V  qd S N)operatorindex)r   kir   r   r   r=     s    z%NdBSpline.__init__.<locals>.<genexpr>c                 s  s    | ]
}t j|td V  qdS r   N)r   ascontiguousarrayfloatr   tir   r   r   r=     s    Tr   r   zSpline degree in dimension z cannot be negative.zKnot vector in dimension z must be one-dimensional.zNeed at least r   z knots for degree z in dimension zKnots in dimension z# must be in a non-decreasing order.z.Need at least two internal knots in dimension z should not have nans or infs.zCoefficients must be at least z-dimensional.z,Knots, coefficients and degree in dimension z are inconsistent: got z coefficients for z knots, need at least z for k=r   )len	TypeErrorrJ   rA   r+   r   r   rB   cboolr-   ranger!   rN   diffrI   uniqueisfiniteallr   r   rb   )selfr   rh   r+   r-   rN   dtdr<   ndtr   r   r   __init__  sp   








zNdBSpline.__init__)r,   r-   c                  s\  t | j}|du r| j}t|}|du rtj|ftjd}n2tj|tjd}|jdks2|j	d |kr@t
d|dt | j dt|dk  rPt
d|tj|td}|j	}|d	|d	 }t|}|d	 |krwt
d
| d| tj| jtjd}dd | jD }tj|t|ftd}|tj t|D ]}	| j|	 ||	dt | j|	 f< qtj|tjd}tdd | jD }
ttt|
|
}tj|tjdj }| j| jj	d| d    }tj fdd jD tjd} j	d	 }tj|j	dd	 |f  j d}t!||||||||||| ||dd	 | jj	|d  S )a@  Evaluate the tensor product b-spline at ``xi``.

        Parameters
        ----------
        xi : array_like, shape(..., ndim)
            The coordinates to evaluate the interpolator at.
            This can be a list or tuple of ndim-dimensional points
            or an array with the shape (num_points, ndim).
        nu : array_like, optional, shape (ndim,)
            Orders of derivatives to evaluate. Each must be non-negative.
            Defaults to the zeroth derivivative.
        extrapolate : bool, optional
            Whether to exrapolate based on first and last intervals in each
            dimension, or return `nan`. Default is to ``self.extrapolate``.

        Returns
        -------
        values : ndarray, shape ``xi.shape[:-1] + self.c.shape[ndim:]``
            Interpolated values at ``xi``
        Nr   r   r   z'invalid number of derivative orders nu=z for ndim = r\   z%derivatives must be positive, got nu=r?   zShapes: xi.shape=z
 and ndim=c                 S     g | ]}t |qS r   rf   rd   r   r   r   r   X      z&NdBSpline.__call__.<locals>.<listcomp>c                 s  r9   r:   r   r;   r   r   r   r=   `  r>   z%NdBSpline.__call__.<locals>.<genexpr>)r?   c                   s   g | ]}| j j qS r   )r   itemsize)r   sc1r   r   r   i  s    )"rf   r   r-   ri   r   r%   int32rB   rN   r!   rJ   rI   r#   rc   reshaperb   r+   r"   r    r   fillnanrj   rA   rF   rG   r   rH   rE   rh   ravelstridesr   r8   )ro   r)   r,   r-   rN   xi_shape_kr*   _trp   r!   rT   rU   r.   _strides_c1r/   r2   r   rz   r   __call__%  sj   

"
 "zNdBSpline.__call__Tc                 C  s  t j|t jd}|jd }t||kr tdt| d|dzt| W n ty4   |f| }Y nw dd |D }t j|t|ft	d}|
t j t|D ]}|| ||dt|| f< qQt j|t jd}t j|t jd}	t||||	\}
}}t|
||fS )	a  Construct the design matrix as a CSR format sparse array.

        Parameters
        ----------
        xvals :  ndarray, shape(npts, ndim)
            Data points. ``xvals[j, :]`` gives the ``j``-th data point as an
            ``ndim``-dimensional array.
        t : tuple of 1D ndarrays, length-ndim
            Knot vectors in directions 1, 2, ... ndim,
        k : int
            B-spline degree.
        extrapolate : bool, optional
            Whether to extrapolate out-of-bounds values of raise a `ValueError`

        Returns
        -------
        design_matrix : a CSR matrix
            Each row of the design matrix corresponds to a value in `xvals` and
            contains values of b-spline basis elements which are non-zero
            at this value.

        r   r?   z*Data and knots are inconsistent: len(t) = z for  ndim=r\   c                 S  ru   r   rv   rd   r   r   r   r     rw   z+NdBSpline.design_matrix.<locals>.<listcomp>N)r   rB   r$   r!   rf   rJ   rg   r    r   rc   r~   r   rj   r"   rZ   r   )clsrL   r   r+   r-   rN   r*   r   rp   kkdatarT   indptrr   r   r   design_matrix}  s,   
zNdBSpline.design_matrix)T)__name__
__module____qualname____doc__rt   r   classmethodr   r   r   r   r   r[     s    .9Xr[      c                   sp  t }tdd D }zt   W n ty!    f|  Y nw tD ](\}}t t|}| | krNtd| d| d |  d | d  d	q&t fd	dt|D }tjd
d t	j
 D td}	t|	| }
|j}t|d| t||d f}||}t|jtjrt|
|jt|
|jd  }nt|
|}||||d  }t|| S )a  Construct an interpolating NdBspline.

    Parameters
    ----------
    points : tuple of ndarrays of float, with shapes (m1,), ... (mN,)
        The points defining the regular grid in N dimensions. The points in
        each dimension (i.e. every element of the `points` tuple) must be
        strictly ascending or descending.
    values : ndarray of float, shape (m1, ..., mN, ...)
        The data on the regular grid in n dimensions.
    k : int, optional
        The spline degree. Must be odd. Default is cubic, k=3
    solver : a `scipy.sparse.linalg` solver (iterative or direct), optional.
        An iterative solver from `scipy.sparse.linalg` or a direct one,
        `sparse.sparse.linalg.spsolve`.
        Used to solve the sparse linear system
        ``design_matrix @ coefficients = rhs`` for the coefficients.
        Default is `scipy.sparse.linalg.gcrotmk`
    solver_args : dict, optional
        Additional arguments for the solver. The call signature is
        ``solver(csr_array, rhs_vector, **solver_args)``

    Returns
    -------
    spl : NdBSpline object

    Notes
    -----
    Boundary conditions are not-a-knot in all dimensions.
    c                 s  s    | ]}t |V  qd S r]   rv   )r   xr   r   r   r=     r>   zmake_ndbspl.<locals>.<genexpr>z
There are z points in dimension z, but order z requires at least  r   z points per dimension.c                 3  s,    | ]}t tj| td  | V  qdS ra   )r   r   rB   rc   )r   rp   r+   pointsr   r   r=     s    c                 S  s   g | ]}|qS r   r   )r   xvr   r   r   r     s    zmake_ndbspl.<locals>.<listcomp>r   Ny              ?)rf   rA   rg   	enumerater   
atleast_1drJ   rj   rB   	itertoolsproductrc   r[   r   r!   r   r}   
issubdtyper   complexfloatingr   realimag)r   valuesr+   rN   r   rp   pointnumptsr   rL   matrv_shape
vals_shapevalscoefr   r   r   make_ndbspl  sB   

 


r   )r   )
__future__r   r   r^   mathr   r    cupyx.scipy.interpolate._bspliner   r   !cupyx.scipy.interpolate._bspline2r   cupyx.scipy.sparser   cupyx.scipy.sparse.linalgr   TYPES
NDBSPL_DEF	RawModuler'   r8   rZ   r[   r   r   r   r   r   <module>   s.     c_e r