o
    -i!S                     @   s  U d Z ddlZdZi Zeed< i Zeed< dZd$d	d
ZG dd dZ	G dd dZ
G dd de	ZG dd deZG dd de	ZG dd de	ZG dd dZG dd de	ZG dd de	ZG dd de	ZG dd de	ZG d d! d!e	ZG d"d# d#e	ZdS )%ac  
Backends in `einops` are organized to meet the following requirements
- backends are not imported unless those are actually needed, because
    - backends may not be installed
    - importing all available backends will drive to significant memory footprint
    - backends may be present but installed with errors (but never used),
      importing may drive to crashes
- backend should be either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if backend can't provide symbols for shape dimensions, UnknownSize objects are used
    NzAlex Rogozhnikov_loaded_backends_type2backendFreturnAbstractBackendc                 C   s   t | }t|d}|dur|S tt D ]\}}|| r'|t|< |  S qg }t }|rA|	 }|| 7 }|
| |s0|D ]1}trLtd| |jtvrt|jtjv rttr_td|j | }|t|j< || rt|t|< |  S qCtdt | )z
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    NzTesting for subclass of zImported backend for z Tensor type unknown to einops {})typer   getlistr   itemsis_appropriate_typer   __subclasses__popappend_debug_importingprintframework_namesysmodulesRuntimeErrorformat)tensor_type_resultr   backendbackend_subclassesbackendsBackendSubclass r   I/home/ubuntu/LTX-2/.venv/lib/python3.10/site-packages/einops/_backends.pyget_backend   s<   





r   c                   @   s   e Zd ZU dZeed< dd Zdd Zdd Zd	d
 Z	dd Z
dd Zdd Zdd Zdd Zdd ZdefddZdd Zdd Zdd Zd efd!d"Zd#d$ Zd%d& Zd'd( Zd)d* Zd+S ),r   zJBase backend class, major part of methods are only for debugging purposes.r   c                 C      t  )z4helper method should recognize tensors it can handleNotImplementedErrorselfr   r   r   r   r
   C      z#AbstractBackend.is_appropriate_typec                 C      t dNz.framework doesn't support imperative executionr    r#   xr   r   r   
from_numpyG      zAbstractBackend.from_numpyc                 C   r%   r&   r    r'   r   r   r   to_numpyJ   r*   zAbstractBackend.to_numpyc                 C   r%   Nz/framework doesn't support symbolic computationsr    r#   shaper   r   r   create_symbolM   r*   zAbstractBackend.create_symbolc                 C   r%   r,   r    r#   symbolsymbol_value_pairsr   r   r   eval_symbolP      zAbstractBackend.eval_symbolc                 C   r%   )Nz"framework doesn't implement aranger    r#   startstopr   r   r   arangeT   r4   zAbstractBackend.arangec                 C      |j S )zashape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)r.   r'   r   r   r   r.   X   r$   zAbstractBackend.shapec                 C   
   | |S Nreshaper#   r(   r.   r   r   r   r>   \      
zAbstractBackend.reshapec                 C   r;   r<   	transposer#   r(   axesr   r   r   rB   _   r@   zAbstractBackend.transposec                 C   s   t |||dS Naxis)getattrr#   r(   	operationrD   r   r   r   reduceb      zAbstractBackend.reducetensorsc                 C   r   r<   r    r#   rM   r   r   r   stack_on_zeroth_dimensione      z)AbstractBackend.stack_on_zeroth_dimensionc                 C   r   r<   r    r#   r(   new_positionr   r   r   add_axish   rP   zAbstractBackend.add_axisc                 C   s@   dg| }|  D ]\}}| ||}|||< q	| |t|S )N   )r	   rS   tiletupler#   r(   n_axespos2lenrepeatsaxis_positionaxis_lengthr   r   r   add_axesk   s
   

zAbstractBackend.add_axesc                 C   r   )z!repeats - same lengths as x.shaper    r#   r(   rZ   r   r   r   rU   r   r$   zAbstractBackend.tilerG   c                 C   r   )zzconcatenates tensors along axis.
        Assume identical across tensors: devices, dtypes and shapes except selected axis.r    r#   rM   rG   r   r   r   concatv      zAbstractBackend.concatc                 C   r   r<   r    r'   r   r   r   is_float_type{   ra   zAbstractBackend.is_float_typec                 C   r%   )Nzbackend does not provide layersr    r#   r   r   r   layers   r*   zAbstractBackend.layersc                 C   s   d | jS )Nz<einops backend for {}>)r   r   rc   r   r   r   __repr__      zAbstractBackend.__repr__c                 G   r%   )Nzbackend does not support einsumr    r#   patternr(   r   r   r   einsum   r*   zAbstractBackend.einsumN)__name__
__module____qualname____doc__str__annotations__r
   r)   r+   r/   r3   r8   r.   r>   rB   rK   r   rO   rS   r]   rU   intr`   rb   rd   re   ri   r   r   r   r   r   >   s,   
 c                   @   s8   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d ZdS )UnknownSizezUpseudo-symbol for symbolic frameworks which do not provide symbols for shape elementsc                 C      | S r<   r   r#   otherr   r   r   __floordiv__      zUnknownSize.__floordiv__c                 C   s   dS NTr   rs   r   r   r   __eq__   rv   zUnknownSize.__eq__c                 C   rr   r<   r   rs   r   r   r   __mul__   rv   zUnknownSize.__mul__c                 C   rr   r<   r   rs   r   r   r   __rmul__   rv   zUnknownSize.__rmul__c                 C   s   t d S r<   )hashrc   r   r   r   __hash__   r*   zUnknownSize.__hash__N)	rj   rk   rl   rm   ru   rx   ry   rz   r|   r   r   r   r   rq      s    rq   c                   @   t   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zde	fddZ
dd ZdefddZdd Zdd Zdd ZdS )NumpyBackendnumpyc                 C      dd l }|| _d S Nr   )r   np)r#   r   r   r   r   __init__      
zNumpyBackend.__init__c                 C      t || jjS r<   )
isinstancer   ndarrayr"   r   r   r   r
         z NumpyBackend.is_appropriate_typec                 C      |S r<   r   r'   r   r   r   r)      rv   zNumpyBackend.from_numpyc                 C   r   r<   r   r'   r   r   r   r+      rv   zNumpyBackend.to_numpyc                 C      | j ||S r<   )r   r8   r5   r   r   r   r8      r   zNumpyBackend.arangerM   c                 C      | j |S r<   )r   stackrN   r   r   r   rO      rf   z&NumpyBackend.stack_on_zeroth_dimensionc                 C   r   r<   )r   rU   r^   r   r   r   rU      r   zNumpyBackend.tilerG   c                 C      | j j||dS rE   )r   concatenater_   r   r   r   r`      rL   zNumpyBackend.concatc                 C   
   |j dv S N)float16float32float64float128bfloat16dtyper'   r   r   r   rb      r@   zNumpyBackend.is_float_typec                 C   r   r<   )r   expand_dimsrQ   r   r   r   rS      r   zNumpyBackend.add_axisc                 G      | j j|g|R  S r<   )r   ri   rg   r   r   r   ri         zNumpyBackend.einsumN)rj   rk   rl   r   r   r
   r)   r+   r8   r   rO   rU   rp   r`   rb   rS   ri   r   r   r   r   r~          r~   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )
JaxBackendjaxc                    s*   t t|   | j| _dd l}|j| _d S r   )superr   r   r   onp	jax.numpyr   )r#   r   	__class__r   r   r      s   zJaxBackend.__init__c                 C   r   r<   )r   asarrayr'   r   r   r   r)      rf   zJaxBackend.from_numpyc                 C   r   r<   )r   r   r'   r   r   r   r+      rf   zJaxBackend.to_numpy)rj   rk   rl   r   r   r)   r+   __classcell__r   r   r   r   r      s
    r   c                   @      e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
defddZdd Zdd ZdefddZdd Zdd Zdd Zd d! Zd"S )#TorchBackendtorchc                 C   s   dd l }|| _ ddlm} d S )Nr   rT   )_torch_specific)r    r   )r#   r   r   r   r   r   r      s   zTorchBackend.__init__c                 C   r   r<   )r   r   Tensorr"   r   r   r   r
      r   z TorchBackend.is_appropriate_typec                 C       | j |}| |rd|_|S rw   )r   r)   rb   requires_gradr#   r(   variabler   r   r   r)         
zTorchBackend.from_numpyc                 C      |    S r<   detachcpur   r'   r   r   r   r+      rL   zTorchBackend.to_numpyc                 C      | j j||| j jdS Nr   )r   r8   int64r5   r   r   r   r8         zTorchBackend.arangec                 C   s   |dkr
|j |dS |dkr|j|dS |dkr|j|dS |dkr(|j|dS |dv rDtt|d d d D ]
}t|||d}q7|S td|)	Nmindimmaxsummean)anyallprodUnknown reduction )aminamaxr   r   r   sortedrH   r!   )r#   r(   rJ   reduced_axesir   r   r   rK      s   
zTorchBackend.reducec                 C   r;   r<   permuterC   r   r   r   rB     r@   zTorchBackend.transposerM   c                 C   r   r<   )r   r   rN   r   r   r   rO     rf   z&TorchBackend.stack_on_zeroth_dimensionc                 C   :   dg| }|  D ]\}}| ||}|||< q	||S Nr   r	   rS   expandrW   r   r   r   r]     
   


zTorchBackend.add_axesc                 C   r;   r<   repeatr^   r   r   r   rU     r@   zTorchBackend.tilerG   c                 C   r   Nr   )r   catr_   r   r   r   r`     rL   zTorchBackend.concatc                 C   r   r<   )r   	unsqueezerQ   r   r   r   rS     r   zTorchBackend.add_axisc                 C   s"   |j | jj| jj| jj| jjfv S r<   )r   r   r   r   r   r   r'   r   r   r   rb     s   "zTorchBackend.is_float_typec                 C      ddl m} |S )NrT   )r   )rd   r   )r#   r   r   r   r   rd        zTorchBackend.layersc                 G   r   r<   )r   ri   rg   r   r   r   ri     r   zTorchBackend.einsumNrj   rk   rl   r   r   r
   r)   r+   r8   rK   rB   r   rO   r]   rU   rp   r`   rS   rb   rd   ri   r   r   r   r   r      s"    r   c                   @   r}   )CupyBackendcupyc                 C      dd l }|| _ d S r   )r   )r#   r   r   r   r   r   &  r   zCupyBackend.__init__c                 C   r   r<   )r   r   r   r"   r   r   r   r
   +  r   zCupyBackend.is_appropriate_typec                 C   r   r<   )r   r   r'   r   r   r   r)   .  rf   zCupyBackend.from_numpyc                 C   r   r<   )r   asnumpyr'   r   r   r   r+   1  rf   zCupyBackend.to_numpyc                 C   r   r<   )r   r8   r5   r   r   r   r8   4  r   zCupyBackend.arangerM   c                 C   r   r<   )r   r   rN   r   r   r   rO   7  rf   z%CupyBackend.stack_on_zeroth_dimensionc                 C   r   r<   )r   rU   r^   r   r   r   rU   :  r   zCupyBackend.tilerG   c                 C   r   rE   )r   r   r_   r   r   r   r`   =  rL   zCupyBackend.concatc                 C   r   r<   )r   r   rQ   r   r   r   rS   @  r   zCupyBackend.add_axisc                 C   r   r   r   r'   r   r   r   rb   C  r@   zCupyBackend.is_float_typec                 G   r   r<   )r   ri   rg   r   r   r   ri   F  r   zCupyBackend.einsumN)rj   rk   rl   r   r   r
   r)   r+   r8   r   rO   rU   rp   r`   rS   rb   ri   r   r   r   r   r   #  r   r   c                   @   s6   e Zd ZdZdefddZdd Zdd Zd	d
 ZdS )HashableTuplez.Overcomes non-hashability of symbolic elementselementsc                 C   s
   || _ d S r<   r   )r#   r   r   r   r   r   M  r@   zHashableTuple.__init__c                 c   s    | j D ]}|V  qd S r<   r   r'   r   r   r   __iter__P  s   
zHashableTuple.__iter__c                 C   s
   t | jS r<   )lenr   rc   r   r   r   __len__T  r@   zHashableTuple.__len__c                 C   s
   | j | S r<   r   )r#   itemr   r   r   __getitem__W  r@   zHashableTuple.__getitem__N)	rj   rk   rl   rm   rV   r   r   r   r   r   r   r   r   r   J  s    r   c                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZdefddZdd ZdefddZdd Zdd Zd d! Zd"d# Zd$S )%TensorflowBackend
tensorflowc                 C   r   r   )r   tfr#   r   r   r   r   r   `  r   zTensorflowBackend.__init__c                 C   s   t || jj| jjfS r<   )r   r   r   Variabler"   r   r   r   r
   e  r   z%TensorflowBackend.is_appropriate_typec                 C   s   | j  sJ | j |S r<   )r   executing_eagerlyconvert_to_tensorr'   r   r   r   r)   h  s   zTensorflowBackend.from_numpyc                 C   s   | j  sJ | S r<   )r   r   r   r'   r   r   r   r+   l  s   zTensorflowBackend.to_numpyc                 C   r   r<   )r   ranger5   r   r   r   r8   p  r   zTensorflowBackend.arangec                    sx   | j  rtdd |jD S |j }| j | t fddt|D }zt| |W S  ty;   t| Y S w )Nc                 s   s&    | ]}|d u rt  nt|V  qd S r<   )rq   rp   ).0dr   r   r   	<genexpr>u  s   $ z*TensorflowBackend.shape.<locals>.<genexpr>c                    s   g | ]
\}}|p | qS r   r   )r   r   stf_shaper   r   
<listcomp>z  s    z+TensorflowBackend.shape.<locals>.<listcomp>)	r   r   rV   r.   as_list	enumerater{   BaseExceptionr   )r#   r(   static_shaper.   r   r   r   r.   s  s   

zTensorflowBackend.shapec                 C   s   t | jd| ||dS )Nreduce_rF   )rH   r   rI   r   r   r   rK        zTensorflowBackend.reducec                 C   r   r<   )r   r>   r?   r   r   r   r>     r   zTensorflowBackend.reshapec                 C   r   r<   )r   rB   rC   r   r   r   rB     r   zTensorflowBackend.transposerM   c                 C   r   r<   )r   r   rN   r   r   r   rO     rf   z+TensorflowBackend.stack_on_zeroth_dimensionc                 C   r   r<   )r   rU   r^   r   r   r   rU     r   zTensorflowBackend.tilerG   c                 C   r   rE   )r   r`   r_   r   r   r   r`     rL   zTensorflowBackend.concatc                 C   r   r<   )r   r   rQ   r   r   r   rS     r   zTensorflowBackend.add_axisc                 C   r   r   r   r'   r   r   r   rb     r@   zTensorflowBackend.is_float_typec                 C   r   )NrT   )r   )rd   r   r   r   r   r   rd     r   zTensorflowBackend.layersc                 G   r   r<   )r   ri   rg   r   r   r   ri     r   zTensorflowBackend.einsumN)rj   rk   rl   r   r   r
   r)   r+   r8   r.   rK   r>   rB   r   rO   rU   rp   r`   rS   rb   rd   ri   r   r   r   r   r   ]  s$    r   c                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZdefddZdd ZdefddZdd Zdd Zd d! Zd"S )#TFKerasBackendztensorflow.kerasc                 C   s$   dd l }|| _|j| _|jj| _d S r   )r   r   kerasr   K)r#   r   r   r   r   r     s   zTFKerasBackend.__init__c                 C   s   | j |o| j|S r<   )r   	is_tensorr   is_keras_tensorr"   r   r   r   r
     r   z"TFKerasBackend.is_appropriate_typec                 C   s   | j j|dS )N)batch_shape)r   Inputr-   r   r   r   r/     r   zTFKerasBackend.create_symbolc                 C   s.   | j jdd |D |}|dd |D S )Nc                 S   s   g | ]\}}|qS r   r   )r   var_r   r   r   r         z.TFKerasBackend.eval_symbol.<locals>.<listcomp>c                 S   s   g | ]\}}|qS r   r   )r   r   valr   r   r   r     r  )r   modelsModelpredict_on_batch)r#   r1   r2   modelr   r   r   r3     s   zTFKerasBackend.eval_symbolc                 C   r   r<   )r   r8   r5   r   r   r   r8     r   zTFKerasBackend.arangec                 C   s   | j |}tt|S r<   )r   r.   r   rV   r?   r   r   r   r.     s   zTFKerasBackend.shapec                 C   s   t | j|||dS rE   )rH   r   rI   r   r   r   rK     r   zTFKerasBackend.reducec                 C   r   r<   )r   r>   r?   r   r   r   r>     r   zTFKerasBackend.reshapec                 C   r   r<   )r   permute_dimensionsrC   r   r   r   rB     r   zTFKerasBackend.transposerM   c                 C   r   r<   )r   r   rN   r   r   r   rO     rf   z(TFKerasBackend.stack_on_zeroth_dimensionc                 C   r   r<   )r   rU   r^   r   r   r   rU     r   zTFKerasBackend.tilerG   c                 C   r   rE   )r   r   r_   r   r   r   r`     rL   zTFKerasBackend.concatc                 C   r   r<   )r   r   rQ   r   r   r   rS     r   zTFKerasBackend.add_axisc                 C   s   d| j |v S )Nfloat)r   r   r'   r   r   r   rb     rL   zTFKerasBackend.is_float_typec                 C   r   )NrT   )r   )rd   r   )r#   r   r   r   r   rd     r   zTFKerasBackend.layersN)rj   rk   rl   r   r   r
   r/   r3   r8   r.   rK   r>   rB   r   rO   rU   rp   r`   rS   rb   rd   r   r   r   r   r     s"    r   c                   @   r   )#OneFlowBackendoneflowc                 C   r   r   )r
  flow)r#   r  r   r   r   r     r   zOneFlowBackend.__init__c                 C   r   r<   )r   r  r   r"   r   r   r   r
     r   z"OneFlowBackend.is_appropriate_typec                 C   r   rw   )r  r)   rb   r   r   r   r   r   r)     r   zOneFlowBackend.from_numpyc                 C   r   r<   r   r'   r   r   r   r+     rL   zOneFlowBackend.to_numpyc                 C   r   r   )r  r8   r   r5   r   r   r   r8     r   zOneFlowBackend.arangec                 C   sl   t |ddD ]-}|dkr|j|d\}}q|dkr"|j|d\}}q|dv r/t|||d}qtd||S )NTreverser   r   r   )r   r   r   r   r   r   )r   r   r   rH   r!   )r#   r(   rJ   r   rG   r   r   r   r   rK     s   
zOneFlowBackend.reducec                 C   r;   r<   r   rC   r   r   r   rB      r@   zOneFlowBackend.transposerM   c                 C   r   r<   )r  r   rN   r   r   r   rO     rf   z(OneFlowBackend.stack_on_zeroth_dimensionc                 C   s:   dg| }|  D ]\}}| ||}|||< q	|j| S r   r   rW   r   r   r   r]     r   zOneFlowBackend.add_axesc                 C   r;   r<   r   r^   r   r   r   rU     r@   zOneFlowBackend.tilerG   c                 C   r   r   )r  r`   r_   r   r   r   r`     rL   zOneFlowBackend.concatc                 C   r   r<   )r  r   rQ   r   r   r   rS     r   zOneFlowBackend.add_axisc                 C      |j | jj| jj| jjfv S r<   )r   r  r   r   r   r'   r   r   r   rb        zOneFlowBackend.is_float_typec                 C   r   )NrT   )r
  )rd   r
  )r#   r
  r   r   r   rd     r   zOneFlowBackend.layersc                 G   r   r<   )r  ri   rg   r   r   r   ri     r   zOneFlowBackend.einsumNr   r   r   r   r   r	    s"    r	  c                       s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Z fddZ	dd Z
dd ZdefddZdd Zdd ZdefddZdd Zdd Zd d! Zd"d# Zd$d% Z  ZS )&PaddleBackendpaddlec                 C   r   r   r  r#   r  r   r   r   r   %  r   zPaddleBackend.__init__c                 C   r   r<   )r  r   r"   r   r   r   r
   *  rf   z!PaddleBackend.is_appropriate_typec                 C   s   | j |}d|_|S )NF)r  	to_tensorstop_gradient)r#   r(   r   r   r   r   r)   -  s   zPaddleBackend.from_numpyc                 C   s   |   S r<   )r   r   r'   r   r   r   r+   2  rf   zPaddleBackend.to_numpyc                 C   r   r   )r  r8   r   r5   r   r   r   r8   5  r   zPaddleBackend.arangec                    s4   t ||jkrt |||dS t |||S r   )r   ndimr   rK   squeezerI   r   r   r   rK   8  s   zPaddleBackend.reducec                 C   r;   r<   rA   rC   r   r   r   rB   ?  r@   zPaddleBackend.transposec                 C   r   r   r   rW   r   r   r   r]   B  r   zPaddleBackend.add_axesrM   c                 C   r   r<   )r  r   rN   r   r   r   rO   I  rf   z'PaddleBackend.stack_on_zeroth_dimensionc                 C   r;   r<   r=   r?   r   r   r   r>   L  r@   zPaddleBackend.reshapec                 C   r;   r<   )rU   r^   r   r   r   rU   O  r@   zPaddleBackend.tilerG   c                 C   r   rE   )r  r`   r_   r   r   r   r`   R  rL   zPaddleBackend.concatc                 C   r;   r<   r   rQ   r   r   r   rS   U  r@   zPaddleBackend.add_axisc                 C   r  r<   )r   r  r   r   r   r'   r   r   r   rb   X  r  zPaddleBackend.is_float_typec                 C   r   )NrT   r  )rd   r  r  r   r   r   rd   [  r   zPaddleBackend.layersc                 G   r   r<   )r  ri   rg   r   r   r   ri   `  r   zPaddleBackend.einsumc                 C   s
   t |jS r<   )rV   r.   r'   r   r   r   r.   c  r@   zPaddleBackend.shape)rj   rk   rl   r   r   r
   r)   r+   r8   rK   rB   r]   r   rO   r>   rU   rp   r`   rS   rb   rd   ri   r.   r   r   r   r   r   r  "  s&    r  c                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZdefddZdd Zdd ZdefddZdd Zd d! Zd"S )#TinygradBackendtinygradc                 C   r   r   )r  )r#   r  r   r   r   r   j  r   zTinygradBackend.__init__c                 C   r   r<   )r   r  r   r"   r   r   r   r
   o  r   z#TinygradBackend.is_appropriate_typec                 C   r   r<   )r  r   r'   r   r   r   r)   r  rf   zTinygradBackend.from_numpyc                 C      |  S r<   )r   r'   r   r   r   r+   u  r*   zTinygradBackend.to_numpyc                 C   s   | j j||S r<   )r  r   r8   r5   r   r   r   r8   x  rL   zTinygradBackend.arangec                 C   r9   r<   r:   r'   r   r   r   r.   {  rP   zTinygradBackend.shapec                 C   r;   r<   r=   r?   r   r   r   r>   ~  r@   zTinygradBackend.reshapec                 C   r;   r<   r   rC   r   r   r   rB     r@   zTinygradBackend.transposec                 C   s&   t |ddD ]
}t|||d}q|S )NTr  rF   )r   rH   )r#   r(   rJ   rD   rG   r   r   r   rK     s   zTinygradBackend.reducerM   c                 C   s   | j j|S r<   )r  r   r   rN   r   r   r   rO     r   z)TinygradBackend.stack_on_zeroth_dimensionc                 C   r;   r<   r  rQ   r   r   r   rS     r@   zTinygradBackend.add_axisc                 C   r;   r<   r   r^   r   r   r   rU     r@   zTinygradBackend.tilerG   c                 C   s0   t |dkr|d j|dd  d|iS |d S )NrT   r   r   )r   r   r_   r   r   r   r`     s   0zTinygradBackend.concatc                 C   s   | j j|jS r<   )r  dtypesis_floatr   r'   r   r   r   rb     rL   zTinygradBackend.is_float_typec                 G   s   | j jj|g|R  S r<   )r  r   ri   rg   r   r   r   ri     r   zTinygradBackend.einsumN)rj   rk   rl   r   r   r
   r)   r+   r8   r.   r>   rB   rK   r   rO   rS   rU   rp   r`   rb   ri   r   r   r   r   r  g  s"    r  c                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZdefddZdd ZdefddZdd Zdd Zd S )!PyTensorBackendpytensorc                 C   s   ddl m} || _d S )Nr   )r   )r  r   ptr"   r   r   r   r     s   
zPyTensorBackend.__init__c                 C   r   r<   )r   r   TensorVariabler"   r   r   r   r
     r   z#PyTensorBackend.is_appropriate_typec                 C   s   |j | jjjv S r<   )r   r   r   float_dtypesr'   r   r   r   rb     rL   zPyTensorBackend.is_float_typec                 C   r   r<   )r   	as_tensorr'   r   r   r   r)     rf   zPyTensorBackend.from_numpyc                 C   r  r<   )evalr'   r   r   r   r+     r*   zPyTensorBackend.to_numpyc                 C   s"   t |ttB s
|f}| jj|dS )Nr:   )r   rV   r   r   r   r-   r   r   r   r/     s   zPyTensorBackend.create_symbolc                 C   s   | t|S r<   )r$  dictr0   r   r   r   r3     r   zPyTensorBackend.eval_symbolc                 C   r   r<   )r   r8   r5   r   r   r   r8     r   zPyTensorBackend.arangec                 C   s   t dd t|jj|jD S )Nc                 s   s$    | ]\}}|d ur|n|V  qd S r<   r   )r   
static_dimsymbolic_dimr   r   r   r     s
    
z(PyTensorBackend.shape.<locals>.<genexpr>)rV   zipr   r.   r'   r   r   r   r.     s   zPyTensorBackend.shaperM   c                 C   r   r<   )r   r   rN   r   r   r   rO     rf   z)PyTensorBackend.stack_on_zeroth_dimensionc                 C   r   r<   )r   rU   r^   r   r   r   rU     r   zPyTensorBackend.tilerG   c                 C   r   rE   )r   r   r_   r   r   r   r`     rL   zPyTensorBackend.concatc                 C   r   r<   )r   r   rQ   r   r   r   rS     r   zPyTensorBackend.add_axisc                 G   r   r<   )r   ri   rg   r   r   r   ri     r   zPyTensorBackend.einsumN)rj   rk   rl   r   r   r
   rb   r)   r+   r/   r3   r8   r.   r   rO   rU   rp   r`   rS   ri   r   r   r   r   r    s     r  )r   r   )rm   r   
__author__r   r%  ro   r   r   r   r   rq   r~   r   r   r   r   r   r   r	  r  r  r  r   r   r   r   <module>   s(    
(L'M'F9FE5