o
    Ơi                     @  s   d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	m
Z
 d dlZd dlZd dlZd dlmZmZ d dlmZ G dd deZG d	d
 d
eZG dd deZG dd deZG dd deZdS )    )annotationsN)Enumauto)AnyOptionalTypeUnion)ENABLED_FEATURESneeds_torch_tensorrt_runtime)is_tensorrt_version_supportedc                   @  s"  e Zd ZdZe Z	 e Z	 e Z	 e Z	 e Z		 e Z
	 e Z	 e Z	 e Z	 e Z	 e Z	 e Z	 eZeZeZe	Ze	ZeZeZeZeZe
Ze
Ze
ZeZeZeZeZ eZ!eZ"eZ#e$dddZ%e&	dd ddZ'e&	dd!ddZ(	dd"ddZ)d#ddZ*d$ddZ+d%ddZ,eZ-eZ.dS )&dtypezfEnum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypestr   returnboolc                 C  s.   t | tjrdS t | trt| tjrdS dS )NTF)
isinstancenpr   type
issubclassgeneric)r    r   I/home/ubuntu/.local/lib/python3.10/site-packages/torch_tensorrt/_enums.py
_is_np_objv   s   
zdtype._is_np_objF7Union[torch.dtype, trt.DataType, np.dtype, dtype, type]use_defaultc                 C  sp  t |tjrs|tjkrtjS |tjkrtjS |tjkrtjS |tj	kr&tj
S |tjkr.tjS |tjkr6tjS |tjkr>tjS |tjkrFtjS |tjkrNtjS |tjkrVtjS |tjkr^tjS |rltd| d tjS td| t |tjr|tjjkrtjS |tjjkrtjS |tjj krtjS |tjj!krtj
S |tjj"krtjS |tjj#krtjS |tjj$krtjS |tjj%krtjS |tjj&krtjS t'dr|tjj(krtj)S td| t*|r@|t+jkrtjS |t+jkrtjS |t+j	krtj
S |t+j,krtjS |t+j-krtjS |t+j.krtjS |t+jkr tjS |t+j/kr)tjS |r8td| d tjS tdt0| t |trH|S t1j2rddl3m4} t ||jr||jjkrctjS ||jj	krmtj
S ||jjkrwtjS ||jjkrtjS ||jjkrtjS ||jj5krtjS ||jjkrtjS ||jj6krtj6S td	| td
| d)a  Create a Torch-TensorRT dtype from another library's dtype system.

        Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
        If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.dtype.try_from()``

        Arguments:
            t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
            use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.

        Returns:
            dtype: Equivalent ``torch_tensorrt.dtype`` to ``t``

        Raises:
            TypeError: Unsupported data type or unknown source

        Examples:

            .. code:: py

                # Succeeds
                float_dtype = torch_tensorrt.dtype._from(torch.float) # Returns torch_tensorrt.dtype.f32

                # Throws exception
                float_dtype = torch_tensorrt.dtype._from(torch.complex128)

        zQGiven dtype that does not have direct mapping to Torch-TensorRT supported types (z+), defaulting to torch_tensorrt.dtype.floatzyProvided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: 10.8.0zsProvided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: zpProvided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: r   _ChProvided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: z<Provided unsupported source type for dtype conversion (got: ))7r   torchr   uint8u8int8i8longi64int32i32float8_e4m3fnf8float4_e2m1fn_x2f4halff16floatf32float64f64r   bbfloat16bf16loggingwarning	TypeErrortrtDataTypeUINT8INT8FP8INT32INT64HALFFLOATBOOLBF16r   FP4fp4r   r   int64float16float32bool_strr	   torchscript_frontendtorch_tensorrtr   doubleunknown)clsr   r   r   r   r   r   _from   s   %
















zdtype._from1Union[torch.dtype, trt.DataType, np.dtype, dtype]Optional[dtype]c              
   C  sV   z
t j||d}|W S  ttfy* } ztjd| ddd W Y d}~dS d}~ww )a  Create a Torch-TensorRT dtype from another library's dtype system.

        Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
        If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
        then returns ``None``.


        Arguments:
            t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
            use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.

        Returns:
            Optional(dtype): Equivalent ``torch_tensorrt.dtype`` to ``t`` or ``None``

        Examples:

            .. code:: py

                # Succeeds
                float_dtype = torch_tensorrt.dtype.try_from(torch.float) # Returns torch_tensorrt.dtype.f32

                # Unsupported type
                float_dtype = torch_tensorrt.dtype.try_from(torch.complex128) # Returns None

        r   Conversion from z to torch_tensorrt.dtype failedTexc_infoN)r   rO   
ValueErrorr7   r5   debug)rN   r   r   casted_formater   r   r   try_from  s    zdtype.try_fromIUnion[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]]c                 C  sf  |t jkrs| tjkrt jS | tjkrt jS | tjkrt jS | tjkr%t j	S | tj
kr-t jS | tjkr5t jS | tjkr=t jS | tjkrEt jS | tjkrMt jS | tjkrUt jS | tjkr]t jS |rktd|  d t jS td|  d|tjkr| tjkrtjjS | tjkrtjjS | tjkrtjjS | tj
krtjj S | tjkrtjj!S | tjkrtjj"S | tjkrtjj#S | tjkrtjj$S | tjkrtjj%S |rtjj#S t&dr| tjkrtjj'S td|t(jkr>| tjkrt(jS | tjkrt(jS | tjkrt(j)S | tjkrt(j*S | tjkrt(j+S | tjkrt(jS | tjkr"t(j,S | tjkr+t(j-S | tjkr4t(j.S |r:t(j,S td|tkrE| S t/j0rdd	l1m2} ||jkr| tjkr_|jj	S | tjkri|jjS | tjkrs|jj)S | tjkr}|jjS | tjkr|jjS | tjkr|jjS | tjkr|jjS | tj3kr|jj3S td
|  td| )a  Convert dtype into the equivalent type in [torch, numpy, tensorrt]

        Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
        If  ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.dtype.try_to()``

        Arguments:
            t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
            use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.

        Returns:
            Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``

        Raises:
            TypeError: Unsupported data type or unknown target

        Examples:

            .. code:: py

                # Succeeds
                float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float

                # Failure
                float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Throws exception

        z8Given dtype that does not have direct mapping to torch (z), defaulting to torch.floatzUnsupported torch dtype (had: r   r   zUnsupported tensorrt dtypezUnsupported numpy dtyper   r   r   z;Provided unsupported destination type for dtype conversion )4r   r   r!   r    r#   r"   r'   intr%   r$   r)   r(   r+   r*   r-   r,   r/   r.   r1   rL   r2   r   r4   r3   r5   r6   r7   r8   r9   r:   r;   r=   r<   r>   r?   r@   rA   rB   r   rC   r   r&   rE   rF   rG   r0   rH   r	   rJ   rK   r   rM   selfr   r   r   r   r   r   toG  s   
$

























zdtype.to;Optional[Union[torch.dtype, trt.DataType, np.dtype, dtype]]c              
   C  sT   z	|  ||}|W S  ttfy) } ztjd| ddd W Y d}~dS d}~ww )a  Convert dtype into the equivalent type in [torch, numpy, tensorrt]

        Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
        If  ``self`` is not supported in the target library, then returns ``None``.

        Arguments:
            t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
            use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.

        Returns:
            Optional(Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float

                # Failure
                float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Returns None

        z/torch_tensorrt.dtype conversion to target type  failedTrT   Nr_   rV   r7   r5   rW   r^   r   r   rX   rY   r   r   r   try_to  s   
zdtype.try_tootherc                 C     t |}t| j|jkS N)r   rO   r   valuer^   re   other_r   r   r   __eq__     
zdtype.__eq__r\   c                 C  
   t | jS rg   hashrh   r^   r   r   r   __hash__     
zdtype.__hash__N)r   r   r   r   F)r   r   r   r   r   r   )r   rP   r   r   r   rQ   )r   r[   r   r   r   rP   )r   r[   r   r   r   r`   )re   rP   r   r   r   r\   )/__name__
__module____qualname____doc__r   rM   r!   r#   r'   r%   r-   r/   r1   r2   r4   r)   r+   r    r"   r&   r$   rE   float8fp8float4rD   r,   fp16rF   r.   fp32rG   rL   fp64r0   r3   staticmethodr   classmethodrO   rZ   r_   rd   rk   rq   r   r\   r   r   r   r   r      s|     + 

'
r   c                   @  s   e Zd ZdZe Z	 e Z	 e Z	 e Z	 e Z		 e Z
	 e Z	 e Z	 e Z	 e Z	 e Z	 e Z	 e Z	 eZeZeZedddZeddd	ZdddZdddZdddZdddZdS )memory_format f;Union[torch.memory_format, trt.TensorFormat, memory_format]r   c                 C  s  t |tjr%|tjkrtjS |tjkrtjS |tjkrtjS tdt t |t	j
r|t	jjkr4tjS |t	jjkr=tjS |t	jjkrFtjS |t	jjkrOtjS |t	jjkrXtjS |t	jjkratjS |t	jjkrjtjS |t	jjkrstjS |t	jjkr|tjS |t	jjkrtjS |t	jj krtj!S |t	jj"krtj#S |t	jj$krtj%S tdt t |tr|S t&j'rddl(m)} t ||jr||jjkrtjS ||jjkrtjS t*dtd)a  Create a Torch-TensorRT memory format enum from another library memory format enum.

        Takes a memory format enum from one of torch, and tensorrt and create a ``torch_tensorrt.memory_format``.
        If the source is not supported or the memory format is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.memory_format.try_from()``

        Arguments:
            f (Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format enum from another library

        Returns:
            memory_format: Equivalent ``torch_tensorrt.memory_format`` to ``f``

        Raises:
            TypeError: Unsupported memory format or unknown source

        Examples:

            .. code:: py

                torchtrt_linear = torch_tensorrt.memory_format._from(torch.contiguous)

        z7Provided an unsupported memory format for tensor, got: z7Provided an unsupported tensor format for tensor, got: r   r   ZProvided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)z=Provided unsupported source type for memory_format conversion)+r   r   r   contiguous_format
contiguouschannels_lastchannels_last_3dr7   r   r8   r9   TensorFormatLINEARlinearCHW2chw2HWC8hwc8CHW4chw4CHW16chw16CHW32chw32DHWC8dhwc8CDHW32cdhw32HWChwc
DLA_LINEAR
dla_linearDLA_HWC4dla_hwc4HWC16hwc16DHWCdhwcr	   rJ   rK   r   rV   )rN   r   r   r   r   r   rO     sj   



zmemory_format._fromOptional[memory_format]c              
   C  R   zt |}|W S  ttfy( } ztjd| ddd W Y d}~dS d}~ww )a  Create a Torch-TensorRT memory format enum from another library memory format enum.

        Takes a memory format enum from one of torch, and tensorrt and create a ``torch_tensorrt.memory_format``.
        If the source is not supported or the memory format is not supported in Torch-TensorRT,
        then ``None`` will be returned.


        Arguments:
            f (Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format enum from another library

        Returns:
            Optional(memory_format): Equivalent ``torch_tensorrt.memory_format`` to ``f``

        Examples:

            .. code:: py

                torchtrt_linear = torch_tensorrt.memory_format.try_from(torch.contiguous)

        rS   z' to torch_tensorrt.memory_format failedTrT   N)r   rO   rV   r7   r5   rW   )rN   r   rX   rY   r   r   r   rZ     s   

zmemory_format.try_fromr   MUnion[Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format]]c                 C  s  |t jkr!| tjkrt jS | tjkrt jS | tjkrt jS td|tjkr| tj	kr/tjj
S | tjkr8tjjS | tjkrAtjjS | tjkrJtjjS | tjkrStjjS | tjkr\tjjS | tjkretjjS | tjkrntjjS | tjkrwtjjS | tjkrtjjS | tjkrtjjS | tjkrtjj S | tj!krtjj"S td|tkr| S t#j$rddl%m&} ||jkr| tjkr|jjS | tjkr|jjS t'dtd)a  Convert ``memory_format`` into the equivalent type in torch or tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent memory format.
        If  ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.memory_format.try_to()``

        Arguments:
            t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))): Memory format type enum from another library to convert to

        Returns:
            Union(torch.memory_format, tensorrt.TensorFormat, memory_format): Memory format equivalent ``torch_tensorrt.memory_format`` in enum ``t``

        Raises:
            TypeError: Unknown target type or unsupported memory format

        Examples:

            .. code:: py

                # Succeeds
                tf = torch_tensorrt.memory_format.linear.to(torch.dtype) # Returns torch.contiguous
        zUnsupported torch dtypez"Unsupported tensorrt memory formatr   r   r   zBProvided unsupported destination type for memory format conversion)(r   r   r   r   r   r   r7   r8   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r	   rJ   rK   r   rV   r^   r   r   r   r   r   r_     sf   




















zmemory_format.toEOptional[Union[torch.memory_format, trt.TensorFormat, memory_format]]c              
   C  R   z|  |}|W S  ttfy( } ztjd| ddd W Y d}~dS d}~ww )a  Convert ``memory_format`` into the equivalent type in torch or tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent memory format.
        If  ``self`` is not supported in the target library, then ``None`` will be returned

        Arguments:
            t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))): Memory format type enum from another library to convert to

        Returns:
            Optional(Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format equivalent ``torch_tensorrt.memory_format`` in enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                tf = torch_tensorrt.memory_format.linear.to(torch.dtype) # Returns torch.contiguous
        z7torch_tensorrt.memory_format conversion to target type ra   TrT   Nrb   r^   r   rX   rY   r   r   r   rd   n     

zmemory_format.try_tore   r   c                 C  s   t |}| j|jkS rg   )r   rO   rh   ri   r   r   r   rk     s   
zmemory_format.__eq__r\   c                 C  rm   rg   rn   rp   r   r   r   rq     rr   zmemory_format.__hash__N)r   r   r   r   )r   r   r   r   )r   r   r   r   )r   r   r   r   )re   r   r   r   rt   )ru   rv   rw   rx   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rO   rZ   r_   rd   rk   rq   r   r   r   r   r     sN    									Z
!
[
#r   c                   @  st   e Zd ZdZe Z	 e Z	 e Z	 edddZ	eddd	Z
	
ddddZ	
ddddZd ddZd!ddZdS )"
DeviceTypez#Type of device TensorRT will targetd!Union[trt.DeviceType, DeviceType]r   c                 C  s   t |tjr|tjjkrtjS |tjjkrtjS tdt |tr#|S tjrHddlm	} t ||jrH||jjkr;tjS ||jjkrDtjS tdt
d)ab  Create a Torch-TensorRT device type enum from a TensorRT device type enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
        If the source is not supported or the device type is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.DeviceType.try_from()``

        Arguments:
            d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library

        Returns:
            DeviceType: Equivalent ``torch_tensorrt.DeviceType`` to ``d``

        Raises:
            TypeError: Unknown source type or unsupported device type

        Examples:

            .. code:: py

                torchtrt_dla = torch_tensorrt.DeviceType._from(tensorrt.DeviceType.DLA)

        6Provided an unsupported device type (support: GPU/DLA)r   r   z:Provided unsupported source type for DeviceType conversion)r   r8   r   GPUDLArV   r	   rJ   rK   r   r7   )rN   r   r   r   r   r   rO     s*   
zDeviceType._fromOptional[DeviceType]c              
   C  r   )a  Create a Torch-TensorRT device type enum from a TensorRT device type enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
        If the source is not supported or the device type is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.DeviceType.try_from()``

        Arguments:
            d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library

        Returns:
            DeviceType: Equivalent ``torch_tensorrt.DeviceType`` to ``d``

        Examples:

            .. code:: py

                torchtrt_dla = torch_tensorrt.DeviceType._from(tensorrt.DeviceType.DLA)

        rS   z$ to torch_tensorrt.DeviceType failedTrT   N)r   rO   rV   r7   r5   rW   )rN   r   rX   rY   r   r   r   rZ     s   

zDeviceType.try_fromFr   -Union[Type[trt.DeviceType], Type[DeviceType]]r   r   c                 C  s   |t jkr!| tjkrt jjS | tjkrt jjS |rt jjS td|tkr'| S tjrKddlm} ||jkrK| tjkr>|jjS | tjkrG|jjS tdt	d)a  Convert ``DeviceType`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent device type.
        If  ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.DeviceType.try_to()``

        Arguments:
            t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to

        Returns:
            Union(tensorrt.DeviceType, DeviceType): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``

        Raises:
            TypeError: Unknown target type or unsupported device type

        Examples:

            .. code:: py

                # Succeeds
                trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA
        r   r   r   z@Provided unsupported destination type for device type conversion)
r8   r   r   r   rV   r	   rJ   rK   r   r7   r]   r   r   r   r_   
  s2   





zDeviceType.to+Optional[Union[trt.DeviceType, DeviceType]]c              
   C  sV   z
| j ||d}|W S  ttfy* } ztjd| ddd W Y d}~dS d}~ww )a  Convert ``DeviceType`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent memory format.
        If  ``self`` is not supported in the target library, then ``None`` will be returned.

        Arguments:
            t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to

        Returns:
            Optional(Union(tensorrt.DeviceType, DeviceType)): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA
        rR   z4torch_tensorrt.DeviceType conversion to target type ra   TrT   Nrb   rc   r   r   r   rd   H  s   
zDeviceType.try_tore   c                 C  rf   rg   )r   rO   r   rh   ri   r   r   r   rk   i  rl   zDeviceType.__eq__r\   c                 C  rm   rg   rn   rp   r   r   r   rq   m  rr   zDeviceType.__hash__N)r   r   r   r   )r   r   r   r   rs   )r   r   r   r   r   r   )r   r   r   r   r   r   )re   r   r   r   rt   )ru   rv   rw   rx   r   UNKNOWNr   r   r   rO   rZ   r_   rd   rk   rq   r   r   r   r   r     s$    5#A
!r   c                   @  sl   e Zd ZdZe Z	 e Z	 e Z	 edddZ	eddd	Z
dddZdddZdddZdddZdS )EngineCapabilityzr
    EngineCapability determines the restrictions of a network during build time and what runtime it targets.
    c-Union[trt.EngineCapability, EngineCapability]r   c                 C  s   t |tjr%|tjjkrtjS |tjjkrtjS |tjjkr!tjS tdt |tr,|S tjrZddl	m
} t ||jrZ||jjkrDtjS ||jjkrMtjS ||jjkrVtjS tdtd)a  Create a Torch-TensorRT Engine capability enum from a TensorRT Engine capability enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
        If the source is not supported or the engine capability is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.EngineCapability.try_from()``

        Arguments:
            c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library

        Returns:
            EngineCapability: Equivalent ``torch_tensorrt.EngineCapability`` to ``c``

        Raises:
            TypeError: Unknown source type or unsupported engine capability

        Examples:

            .. code:: py

                torchtrt_ec = torch_tensorrt.EngineCapability._from(tensorrt.EngineCapability.SAFETY)

        )Provided an unsupported engine capabilityr   r   z@Provided unsupported source type for EngineCapability conversion)r   r8   r   STANDARDSAFETYDLA_STANDALONErV   r	   rJ   rK   r   r7   )rN   r   r   r   r   r   rO     s.   
zEngineCapability._fromOptional[EngineCapability]c              
   C  sR   zt | }|W S  ttfy( } ztjd|  ddd W Y d}~dS d}~ww )ab  Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
        If the source is not supported or the engine capability level is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.EngineCapability.try_from()``

        Arguments:
            c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library

        Returns:
            EngineCapability: Equivalent ``torch_tensorrt.EngineCapability`` to ``c``

        Examples:

            .. code:: py

                torchtrt_safety_ec = torch_tensorrt.EngineCapability._from(tensorrt.EngineCapability.SAEFTY)

        rS   z) to torch_tensorrt.EngineCapablity failedTrT   N)r   rO   rV   r7   r5   rW   )r   rX   rY   r   r   r   rZ     r   zEngineCapability.try_fromr   9Union[Type[trt.EngineCapability], Type[EngineCapability]]c                 C  s   |t jkr$| tjkrt jjS | tjkrt jjS | tjkr t jjS td|tkr*| S tjrWddlm	} ||jkrW| tjkrA|jjS | tjkrJ|jjS | tjkrS|jjS tdt
d)a  Convert ``EngineCapability`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent engine capability.
        If  ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.EngineCapability.try_to()``

        Arguments:
            t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to

        Returns:
            Union(tensorrt.EngineCapability, EngineCapability): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``

        Raises:
            TypeError: Unknown target type or unsupported engine capability

        Examples:

            .. code:: py

                # Succeeds
                torchtrt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA
        r   r   r   zKProvided unsupported destination type for engine capability type conversion)r8   r   r   r   r   rV   r	   rJ   rK   r   r7   r   r   r   r   r_     s.   







zEngineCapability.to7Optional[Union[trt.EngineCapability, EngineCapability]]c              
   C  r   )a"  Convert ``EngineCapability`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent engine capability.
        If  ``self`` is not supported in the target library, then ``None`` will be returned.

        Arguments:
            t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to

        Returns:
            Optional(Union(tensorrt.EngineCapability, EngineCapability)): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                trt_dla_ec = torch_tensorrt.EngineCapability.DLA.to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA_STANDALONE
        z9torch_tensorrt.EngineCapablity conversion to target type ra   TrT   Nrb   r   r   r   r   rd   !  s   

zEngineCapability.try_tore   r   c                 C  rf   rg   )r   rO   r   rh   ri   r   r   r   rk   @  rl   zEngineCapability.__eq__r\   c                 C  rm   rg   rn   rp   r   r   r   rq   D  rr   zEngineCapability.__hash__N)r   r   r   r   )r   r   r   r   )r   r   r   r   )r   r   r   r   )re   r   r   r   rt   )ru   rv   rw   rx   r   r   r   r   r   rO   rZ   r_   rd   rk   rq   r   r   r   r   r   q  s     9
"
9
r   c                   @  sT   e Zd ZdZe Z	 e Z	 e Z	 e Ze	dddZ
dddZeddd	Zd
S )PlatformzZ
    Specifies a target OS and CPU architecture that a Torch-TensorRT program targets
    r   c                 C  s   ddl }|  dr(|  drtjS |  dr%tjS tj	S |  dr=|  dr=tjS tj	S )z
        Returns an enum for the current platform Torch-TensorRT is running on

        Returns:
            Platform: Current platform
        r   Nlinuxaarch64x86_64windowsamd64)
platformsystemlower
startswithmachiner   LINUX_AARCH64LINUX_X86_64
WIN_X86_64r   )rN   r   r   r   r   current_platformd  s   zPlatform.current_platformrI   c                 C  rm   rg   )rI   namerp   r   r   r   __str__|  rr   zPlatform.__str__c                 C  sZ   t jj }| tjkrt jj }|S | tjkr t jj }|S | tj	kr+t jj
 }|S rg   )r   opstensorrt_platform_unknownr   r   _platform_linux_x86_64r   _platform_linux_aarch64r   _platform_win_x86_64)r^   valr   r   r   _to_serialized_rt_platform  s   


z#Platform._to_serialized_rt_platformN)r   r   )r   rI   )ru   rv   rw   rx   r   r   r   r   r   r   r   r   r
   r   r   r   r   r   r   H  s    
r   )
__future__r   r5   enumr   r   typingr   r   r   r   numpyr   r   r8   r   torch_tensorrt._featuresr	   r
   torch_tensorrt._utilsr   r   r   r   r   r   r   r   r   r   <module>   s.        	    W X