o
    װiU                     @   s  d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
 dd ZejejejgZejejejejgZejejgZejejejejgZee e e Zd	d
 eeeD Zdd
 eD Zdd
 eD Z dZ!dZ"ej#e!ddd
 e D dd
 eD  dd
 eD  dd
 eD  dZ$ej#e"ddd
 e D dd
 eD  dd
 eD  dd
 eD  dd
 eD  dd
 eD  dZ%dd Z&dd Z'dd  Z(d!d" Z)d,d%d&Z*d'd( Z+	$	d-d*d+Z,dS ).    )productN)_normalize_axis_index)get_typename)runtime)
axis_slicec                 C   s&   t | }|dkrtjrd}|S d}|S )Nfloat16__halfhalf)r   r   is_hip)dtypetypename r   Q/home/ubuntu/.local/lib/python3.10/site-packages/cupyx/scipy/signal/_iir_utils.py_get_typename   s   r   c                 C   s.   g | ]\}}t ||t |u r||fqS r   )cupypromote_typesr   .0xyr   r   r   
<listcomp>   s    r   c                 C   s   g | ]}t |qS r   r   )r   tr   r   r   r       s    c                 C   s    g | ]\}}t |t |fqS r   r   r   r   r   r   r   !   s     a  
#include <cupy/math_constants.h>
#include <cupy/carray.cuh>
#include <cupy/complex.cuh>

template<typename U, typename T>
__global__ void compute_correction_factors(
        const int m, const int k, const T* b, U* out) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(idx >= k) {
        return;
    }

    U* out_start = out + idx * (k + m);
    U* out_off = out_start + k;

    for(int i = 0; i < m; i++) {
        U acc = 0.0;
        for(int j = 0; j < k; j++) {
            acc += ((U) b[j]) * out_off[i - j - 1];

        }
        out_off[i] = acc;
    }
}

template<typename T>
__global__ void first_pass_iir(
        const int m, const int k, const int n, const int n_blocks,
        const int carries_stride, const T* factors, T* out,
        T* carries) {
    int orig_idx = blockDim.x * (blockIdx.x % n_blocks) + threadIdx.x;

    int num_row = blockIdx.x / n_blocks;
    int idx = 2 * orig_idx + 1;

    if(idx >= n) {
        return;
    }

    int group_num = idx / m;
    int group_pos = idx % m;

    T* out_off = out + num_row * n;
    T* carries_off = carries + num_row * carries_stride;

    T* group_start = out_off + m * group_num;
    T* group_carries = carries_off + k * group_num;

    int pos = group_pos;
    int up_bound = pos;
    int low_bound = pos;
    int rel_pos;

    for(int level = 1, iter = 1; level < m; level *=2, iter++) {
        int sz = min(pow(2.0f, ((float) iter)), ((float) m));

        if(level > 1) {
            int factor = ceil(pos / ((float) sz));
            up_bound = sz * factor - 1;
            low_bound = up_bound - level + 1;
        }

        if(level == 1) {
            pos = low_bound;
        }

        if(pos < low_bound) {
            pos += level / 2;
        }

        if(pos + m * group_num >= n) {
            break;
        }

        rel_pos = pos % level;
        T carry = 0.0;
        for(int i = 1; i <= min(k, level); i++) {
            T k_value = group_start[low_bound - i];
            const T* k_factors = factors + (m + k) * (i - 1) + k;
            T factor = k_factors[rel_pos];
            carry += k_value * factor;
        }

        group_start[pos] += carry;
        __syncthreads();
    }

    if(pos >= m - k) {
        if(carries != NULL) {
            group_carries[pos - (m - k)] = group_start[pos];
        }
    }

}

template<typename T>
__global__ void correct_carries(
    const int m, const int k, const int n_blocks, const int carries_stride,
    const int offset, const T* factors, T* carries) {

    int idx = threadIdx.x;
    int pos = idx + (m - k);
    T* row_carries = carries + carries_stride * blockIdx.x;

    for(int i = offset; i < n_blocks; i++) {
        T* this_carries = row_carries + k * (i + (1 - offset));
        T* prev_carries = row_carries + k * (i - offset);

        T carry = 0.0;
        for(int j = 1; j <= k; j++) {
            const T* k_factors = factors + (m + k) * (j - 1) + k;
            T factor = k_factors[pos];
            T k_value = prev_carries[k - j];
            carry += factor * k_value;
        }

        this_carries[idx] += carry;
        __syncthreads();
    }
}

template<typename T>
__global__ void second_pass_iir(
        const int m, const int k, const int n, const int carries_stride,
        const int n_blocks, const int offset, const T* factors,
        T* carries, T* out) {

    int idx = blockDim.x * (blockIdx.x % n_blocks) + threadIdx.x;
    idx += offset * m;

    int row_num = blockIdx.x / n_blocks;
    int n_group = idx / m;
    int pos = idx % m;

    if(idx >= n) {
        return;
    }

    T* out_off = out + row_num * n;
    T* carries_off = carries + row_num * carries_stride;
    const T* prev_carries = carries_off + (n_group - offset) * k;

    T carry = 0.0;
    for(int i = 1; i <= k; i++) {
        const T* k_factors = factors + (m + k) * (i - 1) + k;
        T factor = k_factors[pos];
        T k_value = prev_carries[k - i];
        carry += factor * k_value;
    }

    out_off[idx] += carry;
}
a  
#include <cupy/math_constants.h>
#include <cupy/carray.cuh>
#include <cupy/complex.cuh>

template<typename T>
__global__ void pick_carries(
        const int m, const int n, const int carries_stride, const int n_blocks,
        const int offset, T* x, T* carries) {

    int idx = m * (blockIdx.x % n_blocks) + threadIdx.x + m - 2;
    int pos = threadIdx.x;
    int row_num = blockIdx.x / n_blocks;
    int n_group = idx / m;

    T* x_off = x + row_num * n;
    T* carries_off = carries + row_num * carries_stride;
    T* group_carries = carries_off + (n_group + (1 - offset)) * 2;

    if(idx >= n) {
        return;
    }

    group_carries[pos] = x_off[idx];
}

template<typename U, typename T>
__global__ void compute_correction_factors_sos(
        const int m, const T* f_const, U* all_out) {

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> bc_d[2];
    T* b_c = reinterpret_cast<T*>(bc_d);

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> off_d[4];
    U* off_cache = reinterpret_cast<U*>(off_d);

    int idx = threadIdx.x;
    int num_section = blockIdx.x;

    const int n_const = 6;
    const int a_off = 3;
    const int k = 2;
    const int off_idx = 1;

    U* out = all_out + num_section * k * m;
    U* out_start = out + idx * m;
    const T* b = f_const + num_section * n_const + a_off + 1;

    b_c[idx] = b[idx];
    __syncthreads();

    U* this_cache = off_cache + k * idx;
    this_cache[off_idx - idx] = 1;
    this_cache[idx] = 0;

    for(int i = 0; i < m; i++) {
        U acc = 0.0;
        for(int j = 0; j < k; j++) {
            acc += -((U) b_c[j]) * this_cache[off_idx - j];

        }
        this_cache[0] = this_cache[1];
        this_cache[1] = acc;
        out_start[i] = acc;
    }
}


template<typename T>
__global__ void first_pass_iir_sos(
        const int m, const int n, const int n_blocks,
        const T* factors, T* out, T* carries) {

    extern __shared__ unsigned int thread_status[2];
    extern __shared__ __align__(sizeof(T)) thrust::complex<double> fc_d[2 * 1024];
    T* factor_cache = reinterpret_cast<T*>(fc_d);

    int orig_idx = blockDim.x * (blockIdx.x % n_blocks) + threadIdx.x;

    int num_row = blockIdx.x / n_blocks;
    int idx = 2 * orig_idx + 1;
    const int k = 2;

    if(idx >= n) {
        return;
    }

    int group_num = idx / m;
    int group_pos = idx % m;
    T* out_off = out + num_row * n;
    T* carries_off = carries + num_row * n_blocks * k;

    T* group_start = out_off + m * group_num;
    T* group_carries = carries_off + group_num * k;

    const T* section_factors = factors;
    T* section_carries = group_carries;

    factor_cache[group_pos] = section_factors[group_pos];
    factor_cache[group_pos - 1] = section_factors[group_pos - 1];
    factor_cache[m + group_pos] = section_factors[m + group_pos];
    factor_cache[m + group_pos - 1] = section_factors[m + group_pos - 1];
    __syncthreads();

    int pos = group_pos;
    int up_bound = pos;
    int low_bound = pos;
    int rel_pos;

    for(int level = 1, iter = 1; level < m; level *= 2, iter++) {
        int sz = min(pow(2.0f, ((float) iter)), ((float) m));

        if(level > 1) {
            int factor = ceil(pos / ((float) sz));
            up_bound = sz * factor - 1;
            low_bound = up_bound - level + 1;
        }

        if(level == 1) {
            pos = low_bound;
        }

        if(pos < low_bound) {
            pos += level / 2;
        }

        if(pos + m * group_num >= n) {
            break;
        }

        rel_pos = pos % level;
        T carry = 0.0;
        for(int i = 1; i <= min(k, level); i++) {
            T k_value = group_start[low_bound - i];
            const T* k_factors = factor_cache + m  * (i - 1);
            T factor = k_factors[rel_pos];
            carry += k_value * factor;
        }

        group_start[pos] += carry;
        __syncthreads();
    }

    if(pos >= m - k) {
        if(carries != NULL) {
            section_carries[pos - (m - k)] = group_start[pos];
        }
    }
}

template<typename T>
__global__ void correct_carries_sos(
    const int m, const int n_blocks, const int carries_stride,
    const int offset, const T* factors, T* carries) {

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> fcd3[4];
    T* factor_cache = reinterpret_cast<T*>(fcd3);

    int idx = threadIdx.x;
    const int k = 2;
    int pos = idx + (m - k);
    T* row_carries = carries + carries_stride * blockIdx.x;

    factor_cache[2 * idx] = factors[pos];
    factor_cache[2 * idx + 1] = factors[m + pos];
    __syncthreads();

    for(int i = offset; i < n_blocks; i++) {
        T* this_carries = row_carries + k * (i + (1 - offset));
        T* prev_carries = row_carries + k * (i - offset);

        T carry = 0.0;
        for(int j = 1; j <= k; j++) {
            // const T* k_factors = factors + m * (j - 1);
            // T factor = k_factors[pos];
            T factor = factor_cache[2 * idx + (j - 1)];
            T k_value = prev_carries[k - j];
            carry += factor * k_value;
        }

        this_carries[idx] += carry;
        __syncthreads();
    }
}

template<typename T>
__global__ void second_pass_iir_sos(
        const int m, const int n, const int carries_stride,
        const int n_blocks, const int offset, const T* factors,
        T* carries, T* out) {

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> fcd2[2 * 1024];
    T* factor_cache = reinterpret_cast<T*>(fcd2);

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> c_d[2];
    T* carries_cache = reinterpret_cast<T*>(c_d);

    int idx = blockDim.x * (blockIdx.x % n_blocks) + threadIdx.x;
    idx += offset * m;

    int row_num = blockIdx.x / n_blocks;
    int n_group = idx / m;
    int pos = idx % m;
    const int k = 2;

    T* out_off = out + row_num * n;
    T* carries_off = carries + row_num * carries_stride;
    const T* prev_carries = carries_off + (n_group - offset) * k;

    if(pos < k) {
        carries_cache[pos] = prev_carries[pos];
    }

    if(idx >= n) {
        return;
    }

    factor_cache[pos] = factors[pos];
    factor_cache[pos + m] = factors[pos + m];
    __syncthreads();

    T carry = 0.0;
    for(int i = 1; i <= k; i++) {
        const T* k_factors = factor_cache + m * (i - 1);
        T factor = k_factors[pos];
        T k_value = carries_cache[k - i];
        carry += factor * k_value;
    }

    out_off[idx] += carry;
}

template<typename T>
__global__ void fir_sos(
        const int m, const int n, const int carries_stride, const int n_blocks,
        const int offset, const T* sos, T* carries, T* out) {

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> fir_cc[1024 + 2];
    T* fir_cache = reinterpret_cast<T*>(fir_cc);

    extern __shared__ __align__(sizeof(T)) thrust::complex<double> fir_b[3];
    T* b = reinterpret_cast<T*>(fir_b);

    int idx = blockDim.x * (blockIdx.x % n_blocks) + threadIdx.x;
    int row_num = blockIdx.x / n_blocks;
    int n_group = idx / m;
    int pos = idx % m;
    const int k = 2;

    T* out_row = out + row_num * n;
    T* out_off = out_row + n_group * m;
    T* carries_off = carries + row_num * carries_stride;
    T* this_carries = carries_off + k * (n_group + (1 - offset));
    T* group_carries = carries_off + (n_group - offset) * k;

    if(pos <= k) {
        b[pos] = sos[pos];
    }

    if(pos < k) {
        if(offset && n_group == 0) {
            fir_cache[pos] = 0;
        } else {
            fir_cache[pos] = group_carries[pos];
        }
    }

    if(idx >= n) {
        return;
    }

    fir_cache[pos + k] = out_off[pos];
    __syncthreads();

    T acc = 0.0;
    for(int i = k; i >= 0; i--) {
        acc += fir_cache[pos + i] * b[k - i];
    }

    out_off[pos] = acc;
}
)z
-std=c++11c                 C   "   g | ]\}}d | d| dqS )zcompute_correction_factors<, >r   r   r   r   r   r         c                 C      g | ]}d | dqS )zcorrect_carries<r   r   r   r   r   r   r   r         c                 C   r   )zfirst_pass_iir<r   r   r   r   r   r   r     r   c                 C   r   )zsecond_pass_iir<r   r   r   r   r   r   r     r   )codeoptionsname_expressionsc                 C   r   )zcompute_correction_factors_sos<r   r   r   r   r   r   r   r     r   c                 C   r   )zpick_carries<r   r   r   r   r   r   r     r   c                 C   r   )zcorrect_carries_sos<r   r   r   r   r   r   r     r   c                 C   r   )zfirst_pass_iir_sos<r   r   r   r   r   r   r     r   c                 C   r   )zsecond_pass_iir_sos<r   r   r   r   r   r   r     r   c                 C   r   )zfir_sos<r   r   r   r   r   r   r     r   c                 G   s>   dd |D }d |}|r| d| dn|}| |}|S )Nc                 S   s   g | ]}t |jqS r   )r   r   )r   argr   r   r   r     s    z$_get_module_func.<locals>.<listcomp>r   <r   )joinget_function)module	func_nametemplate_argsargs_dtypestemplatekernel_namekernelr   r   r   _get_module_func  s
   

r.   c                 C   s>   t | |d} | j}| d| jd } | jjs|  } | |fS )Nr   moveaxisshapereshapeflagsc_contiguouscopyr   axisx_shaper   r   r   collapse_2d  s   r:   c                 C   sJ   t | |d d} | j}| | jd d| jd } | jjs!|  } | |fS )N   r/   r   r0   r7   r   r   r   collapse_2d_rest  s   r<   c                 C   sb   | j }tj||d}tj|d d d tj||f|df }ttd|| }||fd||| |f |S )Nr   r/   compute_correction_factorsr;   )sizer   eyec_emptyr.   
IIR_MODULE)ablock_szr   k
correctioncorr_kernelr   r   r   r>     s   r>   r/      c                 C   s  |d u rt | j|j}||}|d ur||}| j}| j}t||}|j}|| }	|dkrBt| |\} }|d urBt||\}}
t j	| |dd}| jdkrQdn| jd }|	| d | }|| }t j
||d}t j|d d d t j||f|df }t j|||f|d}ttd||}ttd|}ttd	|}ttd
|}||fd||||f ||f|d f|||	||| |||f |d ur|jdkrt ||d|jf}n|jdkr||d|jd }|jdkr|}n	t j||fdd}|jjs| }|dks|d ur8t|d u }|| }|d|  | }||f|f|||||||f ||| f|f|||	||||||f	 |dkrR||}t |d|}|jjsR| }|S )Nr;   Tr   r6   r   r=   r/   r>   first_pass_iirsecond_pass_iircorrect_carriesr?      )r8   )r   result_typer   astyper2   ndimr   r@   r:   arrayrA   rB   rC   r.   rD   broadcast_tor3   concatenater4   r5   r6   intr1   )r   rE   r8   zir   rF   r9   x_ndimrG   n_outnum_rowsn_blockstotal_blocksrH   carriesrI   first_pass_kernelsecond_pass_kernelcarry_correction_kernelstarting_groupblocks_to_mergecarries_strider   r   r   	apply_iir  s   











rf   c                 C   sD   | j d }tj|d|f|d}ttd|| }||fd|| |f |S )Nr   rO   r=   compute_correction_factors_sos)rO   )r2   r   rC   r.   IIR_SOS_MODULE)sosrF   r   
n_sectionsrH   rI   r   r   r   rg   e  s   
rg   Tc                  C   s  |d u rt | j|j}||}|d ur||}| j}| j}	|jd }
t||	}d}|| }d }|	dkr=t| |\} }|d urHt||\}}|d u rTt j	| |dd}| jdkr[dn| jd }|| d | }|| }t
|||}t j|||f|d}|}d }|d urt |}t j||d |f|d}ttd|}ttd|}ttd	|}ttd
|}ttd|}t|d u }|| }|d|  | }||| f|f|||||||f t|
D ]}|| }|d ur||d d d df }||d d dd d f< t||d |||d d d df< |r!||| f|f||||||||f ||f|d f||||| ||f |dks=|d ur|d ure||d d dd f }||d d dd d f< ||d d dd d d f< ||f|f|||||| |f ||| f|f||||||| ||f |r||| f|f|||||||f |d urt||d |||d d dd f< q|	dkr||}t |d|}|jjs| }|d ur||}t|dkrt |d|}|jjs| }|d ur||fS |S )Nr   rO   r;   TrK   r=   first_pass_iir_sossecond_pass_iir_soscorrect_carries_sosfir_sospick_carriesr/   )r   rP   r   rQ   r2   rR   r   r:   r<   rS   rg   rC   
empty_liker.   rh   rV   ranger   r3   r1   r4   r5   r6   len) r   ri   r8   rW   r   rF   	apply_firr[   r9   rX   rj   rG   rY   zi_shaper\   r]   r^   rH   r_   all_carrieszi_outr`   ra   rb   
fir_kernelcarries_kernelrc   rd   re   sb
section_zir   r   r   apply_iir_sosn  s   







"



"






r|   )r/   NNrJ   )r/   NNrJ   TN)-	itertoolsr   r   cupy._core.internalr   cupy._core._scalarr   cupy_backends.cuda.apir   cupyx.scipy.signal._arraytoolsr   r   r   float32float64FLOAT_TYPESint8int16int32int64	INT_TYPES	complex64
complex128COMPLEX_TYPESuint8uint16uint32uint64UNSIGNED_TYPESTYPES
TYPE_PAIRS
TYPE_NAMESTYPE_PAIR_NAMES
IIR_KERNELIIR_SOS_KERNEL	RawModulerD   rh   r.   r:   r<   r>   rf   rg   r|   r   r   r   r   <module>   sr      		
S	