o
    oi                     @  s   d dl mZ d dlmZmZmZ d dlZd dlmZ d dlm	Z	m
Z
mZ d dlmZ d dlmZ d dlmZ d	d
gZd1ddZd2d3ddZd4ddZd5d!d"Zd6d7d'd(Zd8d+d,ZG d-d	 d	ZG d.d/ d/eZG d0d
 d
ZdS )9    )annotations)OptionalTuplecastN)Size)Tensorstackzeros)validate_bbox)transform_points)eye_likeBoxesBoxes3Ddtypetorch.dtypereturnboolc                 C  s   | t jt jt jt jt jfv S N)torchfloat16float32float64bfloat16half)r    r   I/home/ubuntu/.local/lib/python3.10/site-packages/kornia/geometry/boxes.py_is_floating_point_dtype!   s   r   padboxeslist[torch.Tensor]methodstrtuple[torch.Tensor, list[int]]c                   s   t dd | D stddd | D  d|dkr:tdd | D   fd	d| D }tjjjj| d
d}||fS td| d)z&Merge a list of boxes into one tensor.c                 s  s8    | ]}|j d d tddgko| dkV  qdS )N         )shaper   r   dim.0boxr   r   r   	<genexpr>'   s   6 z"_merge_box_list.<locals>.<genexpr>z5Input boxes must be a list of (N, 4, 2) shaped. Got: c                 S  s   g | ]}|j qS r   r'   r)   r   r   r   
<listcomp>(   s    z#_merge_box_list.<locals>.<listcomp>.r   c                 s  s    | ]}|j d  V  qdS )r   Nr-   r)   r   r   r   r,   +   s    c                   s   g | ]	} |j d   qS )r   r-   r)   max_Nr   r   r.   ,   s    T)batch_first`z` is not implemented.)	all	TypeErrormaxr   nnutilsrnnpad_sequenceNotImplementedError)r   r    statsoutputr   r0   r   _merge_box_list%   s   r>   torch.TensorMc                 C  s   |  r|n| }| jdd \}}}|dkr| S | d|| |}|jdkr*|n|d}|jd |jd krJtd|jd  d|jd  dt||}|| }|S )	a  Transform 3D and 2D in kornia format by applying the transformation matrix M.

    Boxes and the transformation matrix could be batched or not.

    Args:
        boxes: 2D quadrilaterals or 3D hexahedrons in kornia format.
        M: the transformation matrix of shape :math:`(3, 3)` or :math:`(B, 3, 3)` for 2D and :math:`(4, 4)` or
            :math:`(B, 4, 4)` for 3D hexahedron.

    Nr   r&   zBatch size mismatch. Got z for boxes and z for the transformation matrix.)	is_floating_pointfloatr'   viewndim	unsqueeze
ValueErrorr   view_as)r   r@   boxes_per_batchn_points_per_boxcoordinates_dimensionpointstransformed_boxesr   r   r   _transform_boxes4   s   

rO   xminyminwidthheightc                 C  s   | j |j   kr|j   kr|j   krdks td tdt| jd | jd ddf| j| jd}| d|d< |d|d	< |d
  |d 7  < |d  |d 7  < |d  |d 7  < |d  |d 7  < |S )Nr%   zXWe expect to create a batch of 2D boxes (quadrilaterals) in vertices format (B, N, 4, 2)r      r$   devicer   rB   .r   .rT   .rT   r   .r%   r   .r%   rT   .r&   rT   )rF   rH   r	   r'   rV   r   rG   )rP   rQ   rR   rS   polygonsr   r   r   _boxes_to_polygonsR   s   .&r^   xyxyTmodevalidate_boxesc           	      C  s  |  }|dr6| jdk}d| j  krdkr)n n| jdd tddgks5td| d| j d	n2|d
ra| jdk}d| j  krKdkrTn n| jd dks`td| d| j d	ntd| |  rn| n|  } |rv| n| 	d} |dr|dkr| 
 }|ddddf d |ddddf< |ddddf d |ddddf< n|dkr| 
 }ntd| | pt| n}|d
r;|dkr| d | d  | d | d  }}n0|dkr| d | d  d | d | d  d }}n|dkr
| d | d }}ntd| |r*|dk rtd|dk r*td| d | d }}t||||}ntd| |rI|}|S |d}|S )z%Convert from boxes to quadrilaterals.verticesr$   r&   r#   Nr%   z3Boxes shape must be (N, 4, 2) or (B, N, 4, 2) when z mode. Got r/   xyrB   z-Boxes shape must be (N, 4) or (B, N, 4) when Unknown mode r   .rT   vertices_plusr_   .r&   rX   .r%   rW   	xyxy_plusxywh%Some boxes have negative widths or 0.&Some boxes have negative heights or 0.)lower
startswithrF   r'   r   r   rH   rC   rD   rG   cloner
   anyr^   squeeze)	r   r`   ra   batchedquadrilateralsrS   rR   rP   rQ   r   r   r   _boxes_to_quadrilateralsd   sX   

6

(
$&
$,


rs   zmindepthc           	      C  s0  | j |j   kr'|j   kr'|j   kr'|j   kr'|j   kr'dks,td tdt| jd | jd ddf| j| jd}| d|d	< |d|d
< |d|d< |d  |d 7  < |d  |d 7  < |d  |d 7  < |d  |d 7  < | }|d  |dd 7  < tj	||gdd}|S )Nr%   zUWe expect to create a batch of 3D boxes (hexahedrons) in vertices format (B, N, 8, 3)r   rT   r$   r&   rU   rB   rW   rX   rg   rY   rZ   r[   r\   r#   r(   )
rF   rH   r	   r'   rV   r   rG   rn   r   cat)	rP   rQ   rt   rR   rS   ru   front_verticesback_vertices
polygons3dr   r   r   _boxes3d_to_polygons3d   s    F&r{   c                   @  sT  e Zd ZdZ		dgdhddZdiddZdjddZedkddZdlddZ	dmdnddZ
	dmdod#d$Zdpd'd(Zdpd)d*Z	+	+	dqdrd/d0Zdsdtd2d3Z	dqdud7d8Zdvd:d;Ze	dwdxd>d?Z	dydzdBdCZd{dGdHZdmd|dJdKZd}dLdMZd~ddQdRZedvdSdTZeddUdVZeddXdYZedd[d\ZdddadbZddcddZddedfZd+S )r   a  2D boxes containing N or BxN boxes.

    Args:
        boxes: 2D boxes, shape of :math:`(N, 4, 2)`, :math:`(B, N, 4, 2)` or a list of :math:`(N, 4, 2)`.
            See below for more details.
        raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a
            floating point tensor. True to raise an error when `boxes` isn't a floating point tensor, False
            to cast to float.
        mode: the box format of the input boxes.

    Note:
        **2D boxes format** is defined as a floating data type tensor of shape ``Nx4x2`` or ``BxNx4x2``
        where each box is a `quadrilateral <https://en.wikipedia.org/wiki/Quadrilateral>`_ defined by it's
        4 vertices coordinates (A, B, C, D). Coordinates must be in ``x, y`` order. The height and width of
        a box is defined as ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. Examples of
        `quadrilaterals <https://en.wikipedia.org/wiki/Quadrilateral>`_ are rectangles, rhombus and trapezoids.

    Tre   r   !torch.Tensor | list[torch.Tensor]raise_if_not_floating_pointr   r`   r!   r   Nonec                 C  s   d | _ t|trt|\}| _ t|tjstdt| d| s1|r-t	d|j
 | }t|jdkr=|d}d|j  krHdkrSn n	|jdd  d	ks\t	d
|j d|jdkrcdnd| _|| _|| _d S )N"Input boxes is not a Tensor. Got: r/   +Coordinates must be in floating point. Got r   )rB   r$   r&   r$   r#   )r$   r%   z3Boxes shape must be (N, 4, 2) or (B, N, 4, 2). Got FT)_N
isinstancelistr>   r   r   r5   typerC   rH   r   rD   lenr'   reshaperF   _is_batched_data_modeselfr   r}   r`   r   r   r   __init__   s    

,
zBoxes.__init__keyslice | int | Tensorc                 C  s    t | | j| d}| j|_|S NF)r   r   r   r   r   new_boxr   r   r   __getitem__      zBoxes.__getitem__valuec                 C     |j | j |< | S r   r   r   r   r   r   r   r   __setitem__      zBoxes.__setitem__tuple[int, ...] | Sizec                 C     | j jS r   datar'   r   r   r   r   r'         zBoxes.shape!tuple[torch.Tensor, torch.Tensor]c                 C  s0   t tj| jddd}|d |d }}||fS )a  Compute boxes heights and widths.

        Returns:
            - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.

        Example:
            >>> boxes_xyxy = torch.tensor([[[1,1,2,2],[1,1,3,2]]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy)
            >>> boxes.get_boxes_shape()
            (tensor([[1., 1.]]), tensor([[1., 2.]]))

        ri   Tas_padded_sequencerg   rf   )r   r   r   	to_tensor)r   
boxes_xywhwidthsheightsr   r   r   get_boxes_shape   s   zBoxes.get_boxes_shapeFinplacec                 C  s6   t j| j|jgdd}|r|| _| S |  }||_|S )a"  Merge boxes.

        Say, current instance holds :math:`(B, N, 4, 2)` and the incoming boxes holds :math:`(B, M, 4, 2)`,
        the merge results in :math:`(B, N + M, 4, 2)`.

        Args:
            boxes: 2D boxes.
            inplace: do transform in-place and return self.

        rT   rv   )r   rw   r   r   rn   )r   r   r   r   objr   r   r   merge  s   zBoxes.mergeindices!tuple[Tensor, ...] | list[Tensor]valuesTensor | Boxesc                 C  sV   |r| j }n| j  }t|tr|||j n||| |r"| S |  }||_ |S r   )r   rn   r   r   
index_put_r   )r   r   r   r   r   r   r   r   r   	index_put#  s   

zBoxes.index_putpadding_sizer   c                 C  s   t |jdkr|ddkstd|j d| jd  |ddddf j| jjd	7  < | jd
  |ddddf j| jjd	7  < | S zMPad a bounding box.

        Args:
            padding_size: (B, 4)

        r%   rT   r$   z%Expected padding_size as (B, 4). Got r/   rW   .NrV   rX   r&   r   r'   sizeRuntimeErrorr   torV   r   r   r   r   r   r   7  
   ,,z	Boxes.padc                 C  s   t |jdkr|ddkstd|j d| jd  |ddddf j| jjd	8  < | jd
  |ddddf j| jjd	8  < | S r   r   r   r   r   r   unpadD  r   zBoxes.unpadNtopleft"Optional[Tensor | tuple[int, int]]botrightc           
      C  s`  t |tr
t |tst|r| j}n| j }|d d d d df d|dd}||d |k  |d |d |k < |d d d dd f d|dd}||d |k  |d |d |k < |d d d d df d|dd}||d |k |d |d |k< |d d d dd f d|dd}||d |k |d |d |k< |r| S |  }	||	_|	S )NrT   r$   rW   rX   )r   r   r;   r   rn   repeatr   )
r   r   r   r   r   	topleft_x	topleft_y
botright_x
botright_yr   r   r   r   clampQ  s$   
& & & & zBoxes.clampcorrespondence_preservec                 C  s   t )a  Trim out zero padded boxes.

        Given box arrangements of shape :math:`(4, 4, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  --  0  -- Box -- Box --
            --  0  -- Box --  0  --  0  --
            --  0  --  0  --  0  --  0  --
            == === == === == === == === ==

        Nothing will change if correspondence_preserve is True. Only pure zero layers will be removed, resulting in
        shape :math:`(4, 3, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  --  0  -- Box -- Box --
            --  0  -- Box --  0  --  0  --
            == === == === == === == === ==

        Otherwise, you will get :math:`(4, 2, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  -- Box -- Box -- Box --
            == === == === == === == === ==
        )r;   )r   r   r   r   r   r   trimo  s   z
Boxes.trimmin_areaOptional[float]max_areac                 C  s`   |   }|r
| j}n| j }|d urd|||k < |d ur#d|||k< |r'| S |  }||_|S )Ng        )compute_arear   rn   )r   r   r   r   arear   r   r   r   r   filter_boxes_by_area  s   
zBoxes.filter_boxes_by_arear?   c           
      C  s   | j jdkr| j dn| j }|jddd}t|d |d  |d |d  }tj|ddd\}}t|d|d		d	d	d
}|d |d }}dt
tj|t|dd |t|dd  dd }	| j jdkrw|	| j jdd
 S |	S )zReturn :math:`(B, N)`.r$   )rB   r$   r%   rT   T)r(   keepdimrX   rW   )r(   
descendingrB   r%   g      ?rv   N)r   rF   rE   meanr   atan2sortgatherrG   expandabssumrollr'   )
r   coordscentroidangles_clockwise_indicesordered_cornersxyr   r   r   r   r     s   $8&zBoxes.compute_arear_   ra   c                   s<   t |tjrt| d}n
 fdd|D }| |d S )ab	  Create :class:`Boxes` from boxes stored in another format.

        Args:
            boxes: 2D boxes, shape of :math:`(N, 4)`, :math:`(B, N, 4)`, :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
            mode: The format in which the boxes are provided.

                * 'xyxy': boxes are assumed to be in the format ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin``
                  and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                  With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'xywh': boxes are assumed to be in the format ``xmin, ymin, width, height`` where
                  ``width = xmax - xmin`` and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
                  box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
                  With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
                * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. ymin + 1``.
                  With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.

            validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width
                and height >= 1 (>= 2 when mode ends with '_plus' suffix).

        Returns:
            :class:`Boxes` class containing the original `boxes` in the format specified by ``mode``.

        Examples:
            >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy, mode='xyxy')
            >>> boxes.data  # (2, 4, 2)
            tensor([[[0., 3.],
                     [0., 3.],
                     [0., 3.],
                     [0., 3.]],
            <BLANKLINE>
                    [[5., 1.],
                     [7., 1.],
                     [7., 3.],
                     [5., 3.]]])

        r`   ra   c                   s   g | ]}t | qS r   )rs   r)   r   r   r   r.     s    z%Boxes.from_tensor.<locals>.<listcomp>F)r   r   r   rs   )clsr   r`   ra   rr   r   r   r   from_tensor  s   /zBoxes.from_tensorOptional[str]r   c                 C  sR  | j r| jn| jd}tj|jdd|jddgdd|jd |jd d}|du r/| j	}|
 }|dv r8n)|dv rZ|d	 |d
  d |d |d  d }}||d< ||d	< ntd| |dv rutjg d|j|jd}|| }|drt|d |d
 |d |d	 }| jdur|sdd t|| jD }|S | j r|n|d}|S )au  Cast :class:`Boxes` to a tensor.

        ``mode`` controls which 2D boxes format should be use to represent boxes in the tensor.

        Args:
            mode: the output box format. It could be:

                * 'xyxy': boxes are defined as ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin`` and
                  ``height = ymax - ymin``.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                * 'xywh': boxes are defined as ``xmin, ymin, width, height`` where ``width = xmax - xmin``
                  and ``height = ymax - ymin``.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
                  box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. ymin + 1``.
            as_padded_sequence: whether to keep the pads for a list of boxes. This parameter is only valid
                if the boxes are from a box list whilst `from_tensor`.

        Returns:
            Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:

                * 'vertices' or 'verticies_plus': :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
                * Any other value: :math:`(N, 4)` or :math:`(B, N, 4)`.

        Examples:
            >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy)
            >>> assert (boxes_xyxy == boxes.to_tensor(mode='xyxy')).all()

        r   r#   rv   rT   r$   N)r_   rh   )ri   rb   re   rf   rX   rg   rW   rd   )r_   rb   )r   r   rT   rT   rU   rb   c                 S  s<   g | ]\}}t jj|t|jd  ddg d| g qS )rT   r   )r   r7   
functionalr   r   r'   )r*   onr   r   r   r.   )  s   < z#Boxes.to_tensor.<locals>.<listcomp>)r   r   rG   r   r   aminamaxrE   r'   r`   rl   rH   	as_tensorrV   r   rm   r^   r   ziprp   )r   r`   r   batched_boxesr   rS   rR   offsetr   r   r   r     s0   $"*

zBoxes.to_tensorrS   intrR   c                 C  s  | j jrtd| j}| j}| j}|jdkr| jr0tj| j j	d | j j	d ||f| j| jd}ntj| j j	d ||f| j| jd}t
tj| jddd}|d	d
d
df d| |d	dd
df d| t|d|||dd  D ]\}}	d||	d |	d |	d |	d f< qx|S |r| j	d | j	d ||f}
n| j	d ||f}
t
t| jddd}|d	d
d
df d| |d	dd
df d| |dd  }|d
d
df |d
d
df |d
d
df |d
d
df f\}}}}|d|}|d|}|d|}|d|}tj||d}tj||d}|d
d
d
f |d
d
d
f k|d
d
d
f |d
d
d
f d k @ }|d
d
d
f |d
d
d
f k|d
d
d
f |d
d
d
f d k @ }|d|d@ |}|j|
 S )a  Convert 2D boxes to masks. Covered area is 1 and the remaining is 0.

        Args:
            height: height of the masked image/images.
            width: width of the masked image/images.

        Returns:
            the output mask tensor, shape of :math:`(N, width, height)` or :math:`(B,N, width, height)` and dtype of
            :func:`Boxes.dtype` (it can be any floating point dtype).

        Note:
            It is currently non-differentiable.

        Examples:
            >>> boxes = Boxes(torch.tensor([[  # Equivalent to boxes = Boxes.from_tensor([[1,1,4,3]])
            ...        [1., 1.],
            ...        [4., 1.],
            ...        [4., 3.],
            ...        [1., 3.],
            ...   ]]))  # 1x4x2
            >>> boxes.to_mask(5, 5)
            tensor([[[0., 0., 0., 0., 0.],
                     [0., 1., 1., 1., 1.],
                     [0., 1., 1., 1., 1.],
                     [0., 1., 1., 1., 1.],
                     [0., 0., 0., 0., 0.]]])

        cBoxes.to_tensor isn't differentiable. Please, create boxes from tensors with `requires_grad=False`.cudar   rT   r   rV   r_   Tr   .Nr%   rB   r$   r&   r   )r   requires_gradr   r   r   rV   r   r   r	   r'   r   r   r   clamp_r   rE   roundr   longr   arangerG   r   )r   rS   rR   
is_batchedr   rV   maskclipped_boxes_xyxymask_channelbox_xyxy	out_shaper_   x1y1x2y2ysxsy_maskx_maskmasksr   r   r   to_mask.  sN   
""&DDD
zBoxes.to_maskr@   c                 C  sj   d|j   krdkrn n	|jdd dkrtd|j dt| j|}|r,|| _| S |  }||_|S )a  Apply a transformation matrix to the 2D boxes.

        Args:
            M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes.

        r%   r&   r#   N)r&   r&   zAThe transformation matrix shape must be (3, 3) or (B, 3, 3). Got r/   )rF   r'   rH   rO   r   rn   )r   r@   r   rN   r   r   r   r   transform_boxes  s   ,zBoxes.transform_boxesc                 C     | j |ddS )z1Inplace version of :func:`Boxes.transform_boxes`.Tr   r  r   r@   r   r   r   transform_boxes_     zBoxes.transform_boxes_warpr   r    c                 C  sH   |dkrt |dkrnt td|}||dddddf< | j||dS )a#  Translate boxes by the provided size.

        Args:
            size: translate size for x, y direction, shape of :math:`(B, 2)`.
            method: "warp" or "fast".
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes.

        fastr  r&   Nr%   r  )r;   r   r  )r   r   r    r   r@   r   r   r   	translate  s   
zBoxes.translatec                 C     | j S r   r   r   r   r   r   r        z
Boxes.datac                 C  r  r   r   r   r   r   r   r`     r  z
Boxes.modetorch.devicec                 C  r   zReturns boxes device.r   rV   r   r   r   r   rV        zBoxes.devicer   c                 C  r   zReturns boxes dtype.r   r   r   r   r   r   r     r  zBoxes.dtyperV   Optional[torch.device]r   Optional[torch.dtype]c                 C  .   |durt |std| jj||d| _| S z)Like :func:`torch.nn.Module.to()` method.NzBoxes must be in floating pointrU   r   rH   r   r   r   rV   r   r   r   r   r        zBoxes.toc                 C  s0   t | | j d}| j|_| j|_| j|_|S r   )r   r   rn   r   r   r   r   r   r   r   r   rn     s
   zBoxes.clonec                 C  s   | j || _ | S r   )r   r   )r   r   r   r   r   r     s   z
Boxes.type)Tre   )r   r|   r}   r   r`   r!   r   r~   )r   r   r   r   )r   r   r   r   r   r   r   r   )r   r   F)r   r   r   r   r   r   )r   r   r   r   r   r   r   r   )r   r   r   r   )NNF)r   r   r   r   r   r   r   r   )FF)r   r   r   r   r   r   )r   r   r   r   r   r   r   r   r   r?   r_   T)r   r|   r`   r!   ra   r   r   r   r   )r`   r   r   r   r   r|   )rS   r   rR   r   r   r?   )r@   r?   r   r   r   r   )r@   r?   r   r   )r  F)r   r   r    r!   r   r   r   r   r   r!   r   r  r   r   NN)rV   r  r   r  r   r   )r   r   )r   r   r   r   ) __name__
__module____qualname____doc__r   r   r   propertyr'   r   r   r   r   r   r   r   r   r   classmethodr   r   r  r  r	  r  r   r`   rV   r   r   rn   r   r   r   r   r   r      sV    





6
H^

c                      sF   e Zd ZU ded< e	ddd	d
Zdd fddZdddZ  ZS )
VideoBoxesr   temporal_channel_sizeTr   r|   ra   r   r   c              	   C  s   t |tfs| dks|jdd  tddgkrtd|d}t|	|d|d d|d	|dd
|d}| |dd
}||_
|S )N   r#   r$   r%   zQInput box type is not yet supported. Please input an `BxTxNx4x2` tensor directly.rT   r   rB   r&   re   r   F)r   r   r(   r'   r   r   rH   r   rs   rE   r.  )r   r   ra   r.  rr   outr   r   r   r     s   4
*zVideoBoxes.from_tensorNr`   r   c                   sL   t  j|dd}t|tr|jd jg|jdd  R  S  fdd|D S )NFr   rB   rT   c                   s,   g | ]}|j d  jg|jdd R  qS )rB   rT   N)rE   r.  r'   )r*   _outr   r   r   r.     s   , z(VideoBoxes.to_tensor.<locals>.<listcomp>)superr   r   r   rE   r.  r'   )r   r`   r0  	__class__r   r   r     s   
 zVideoBoxes.to_tensorc                 C  s8   t | | j d}| j|_| j|_| j|_| j|_|S r   )r   r   rn   r   r   r   r.  r  r   r   r   rn      s   zVideoBoxes.clone)T)r   r|   ra   r   r   r-  r   )r`   r   r   r|   )r   r-  )	r'  r(  r)  __annotations__r,  r   r   rn   __classcell__r   r   r3  r   r-    s   
 r-  c                   @  s   e Zd ZdZ	d?d@ddZdAddZdBddZedCddZdDddZ	e
dEdFddZdGdHdd ZdId%d&ZdJdKd*d+ZdLd,d-ZedMd.d/ZedNd0d1ZedOd3d4ZedPd6d7ZdQdRd=d>Zd8S )Sr   a  3D boxes containing N or BxN boxes.

    Args:
        boxes: 3D boxes, shape of :math:`(N,8,3)` or :math:`(B,N,8,3)`. See below for more details.
        raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a floating
            point tensor. True to raise an error when `boxes` isn't a floating point tensor, False to cast to float.

    Note:
        **3D boxes format** is defined as a floating data type tensor of shape ``Nx8x3`` or ``BxNx8x3`` where each box
        is a `hexahedron <https://en.wikipedia.org/wiki/Hexahedron>`_ defined by it's 8 vertices coordinates.
        Coordinates must be in ``x, y, z`` order. The height, width and depth of a box is defined as
        ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``. Examples of
        `hexahedrons <https://en.wikipedia.org/wiki/Hexahedron>`_ are cubes and rhombohedrons.

    Txyzxyz_plusr   r?   r}   r   r`   r!   r   r~   c                 C  s   t |tjstdt| d| s#|rtd|j d| }t	|j
dkr/|d}d|j  kr:dkrEn n	|j
dd  d	ksNtd
|j
 d|jdkrUdnd| _|| _|| _d S )Nr   r/   r   r   )rB      r&   r$   r#   )   r&   z53D bbox shape must be (N, 8, 3) or (B, N, 8, 3). Got FT)r   r   r   r5   r   rC   rH   r   rD   r   r'   r   rF   r   r   r   r   r   r   r   r     s   
,
zBoxes3D.__init__r   r   c                 C  s    t | j| ddd}| j|_|S )NFr7  r`   )r   r   r   r   r   r   r   r   1  r   zBoxes3D.__getitem__r   c                 C  r   r   r   r   r   r   r   r   6  r   zBoxes3D.__setitem__r   c                 C  r   r   r   r   r   r   r   r'   :  r   zBoxes3D.shape/tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s2   | j dd}|d |d |d }}}|||fS )a*  Compute boxes heights and widths.

        Returns:
            - Boxes depths, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.

        Example:
            >>> boxes_xyzxyz = torch.tensor([[ 0,  1,  2, 10, 21, 32], [3, 4, 5, 43, 54, 65]])
            >>> boxes3d = Boxes3D.from_tensor(boxes_xyzxyz)
            >>> boxes3d.get_boxes_shape()
            (tensor([30., 60.]), tensor([20., 50.]), tensor([10., 40.]))

        xyzwhdr:  rf   .r$   .r/  )r   )r   boxes_xyzwhdr   r   depthsr   r   r   r   >  s   
zBoxes3D.get_boxes_shapexyzxyzra   c                 C  s  d|j   krdkrn n|jd dkstd|j d|j dk}|r&|n|d}| r1|n| }|d |d	 |d
 }}}| }|dkrd|d |d  }|d |d	  }	|d |d
  }
n=|dkr|d |d  d }|d |d	  d }	|d |d
  d }
n|dkr|d |d |d }
}	}ntd| |r|dk rtd|	dk rtd|
dk rtdt|||||	|
}|r|n|	d}| |d|dS )a  Create :class:`Boxes3D` from 3D boxes stored in another format.

        Args:
            boxes: 3D boxes, shape of :math:`(N,6)` or :math:`(B,N,6)`.
            mode: The format in which the 3D boxes are provided.

                * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
                  ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
                * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.

            validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width, height
                and depth >= 1 (>= 2 when mode ends with '_plus' suffix).

        Returns:
            :class:`Boxes3D` class containing the original `boxes` in the format specified by ``mode``.

        Examples:
            >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
            >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
            >>> boxes.data  # (2, 8, 3)
            tensor([[[0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 7.],
                     [0., 3., 7.],
                     [0., 3., 7.],
                     [0., 3., 7.]],
            <BLANKLINE>
                    [[5., 1., 3.],
                     [7., 1., 3.],
                     [7., 3., 3.],
                     [5., 3., 3.],
                     [5., 1., 8.],
                     [7., 1., 8.],
                     [7., 3., 8.],
                     [5., 3., 8.]]])

        r%   r&   rB   r8  z,BBox shape must be (N, 6) or (B, N, 6). Got r/   r   rW   rX   rg   rA  rf   r=  r>  r7  rT   r<  rd   rj   rk   z%Some boxes have negative depths or 0.F)r}   r`   )
rF   r'   rH   rG   rC   rD   rl   ro   r{   rp   )r   r   r`   ra   rq   rP   rQ   rt   rR   rS   ru   hexahedronsr   r   r   r   Q  s8   (,
zBoxes3D.from_tensorc                 C  st  | j jrtd| jr| j n| j d}t|jdd|jddgdd|j	d |j	d d}|
 }|dv r8n6|dv rg|d	 |d
  d }|d |d  d }|d |d  d }||d	< ||d< ||d< ntd| |dv rtjg d|j|jd}|| }|dr|d
 |d |d }}	}
|d	 |d |d }}}t||	|
|||}| jr|}|S |d}|S )a  Cast :class:`Boxes3D` to a tensor.

        ``mode`` controls which 3D boxes format should be use to represent boxes in the tensor.

        Args:
            mode: The format in which the boxes are provided.

                * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
                   ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
                * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
                  back-top-right, back-bottom-right,  back-bottom-left*. Vertices coordinates are in (x,y, z) order.
                  Finally, box width, height and depth are defined as ``width = xmax - xmin``, ``height = ymax - ymin``
                  and ``depth = zmax - zmin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width, length and depth are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.

        Returns:
            3D Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:

                * 'vertices' or 'verticies_plus': :math:`(N, 8, 3)` or :math:`(B, N, 8, 3)`.
                * Any other value: :math:`(N, 6)` or :math:`(B, N, 6)`.

        Note:
            It is currently non-differentiable due to a bug. See github issue
            `#1304 <https://github.com/kornia/kornia/issues/1396>`_.

        Examples:
            >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
            >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
            >>> assert (boxes.to_tensor(mode='xyzxyz') == boxes_xyzxyz).all()

        a  Boxes3D.to_tensor doesn't support computing gradients since they aren't accurate. Please, create boxes from tensors with `requires_grad=False`. This is a known bug. Help is needed to fix it. For more information, see https://github.com/kornia/kornia/issues/1396.r   r#   rv   rT   r8  )rA  r7  )r<  rb   re   rf   rW   r=  rX   r>  rg   rd   )rA  rb   )r   r   r   rT   rT   rT   rU   rb   )r   r   r   r   rG   r   r   r   rE   r'   rl   rH   r   r   rV   r   rm   r{   rp   )r   r`   r   r   rR   rS   ru   r   rP   rQ   rt   r   r   r   r     s<   & 



zBoxes3D.to_tensorru   r   rS   rR   c                 C  s,  | j jrtd| jr$t| j jd | j jd |||f| j j| j jd}nt| j jd |||f| j j| j jd}| d}|ddddf 	d| |ddddf 	d| |dd	ddf 	d| t
|d
||||d
d  D ]\}}d||d	 |d |d |d |d |d f< qv|S )u  Convert ·D boxes to masks. Covered area is 1 and the remaining is 0.

        Args:
            depth: depth of the masked image/images.
            height: height of the masked image/images.
            width: width of the masked image/images.

        Returns:
            the output mask tensor, shape of :math:`(N, depth, width, height)` or :math:`(B,N, depth, width, height)`
             and dtype of :func:`Boxes3D.dtype` (it can be any floating point dtype).

        Note:
            It is currently non-differentiable.

        Examples:
            >>> boxes = Boxes3D(torch.tensor([[  # Equivalent to boxes = Boxes.3Dfrom_tensor([[1,1,1,3,3,2]])
            ...     [1., 1., 1.],
            ...     [3., 1., 1.],
            ...     [3., 3., 1.],
            ...     [1., 3., 1.],
            ...     [1., 1., 2.],
            ...     [3., 1., 2.],
            ...     [3., 3., 2.],
            ...     [1., 3., 2.],
            ... ]]))  # 1x8x3
            >>> boxes.to_mask(4, 5, 5)
            tensor([[[[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]]]])

        r   r   rT   r   rA  .Nr&   r%   rB   r8  r/  r$   )r   r   r   r   r	   r'   r   rV   r   r   r   rE   r   r   )r   ru   rS   rR   r   clipped_boxes_xyzxyzr   
box_xyzxyzr   r   r   r    s.   4&
 ,zBoxes3D.to_maskFr@   r   c                 C  sd   d|j   krdkrn n	|jdd dkrtd|j dt| j|}|r,|| _| S t|dd	S )
a  Apply a transformation matrix to the 3D boxes.

        Args:
            M: The transformation matrix to be applied, shape of :math:`(4, 4)` or :math:`(B, 4, 4)`.
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes.

        r%   r&   r#   N)r$   r$   zAThe transformation matrix shape must be (4, 4) or (B, 4, 4). Got r/   Fr7  )rF   r'   rH   rO   r   r   )r   r@   r   rN   r   r   r   r  B  s   ,zBoxes3D.transform_boxesc                 C  r  )z3Inplace version of :func:`Boxes3D.transform_boxes`.Tr  r  r  r   r   r   r	  W  r
  zBoxes3D.transform_boxes_c                 C  r  r   r   r   r   r   r   r   [  r  zBoxes3D.datac                 C  r  r   r  r   r   r   r   r`   _  r  zBoxes3D.moder  c                 C  r   r  r  r   r   r   r   rV   c  r  zBoxes3D.devicer   c                 C  r   r  r  r   r   r   r   r   h  r  zBoxes3D.dtypeNrV   r  r   r  c                 C  r  r  r  r  r   r   r   r   m  r  z
Boxes3D.to)Tr7  )r   r?   r}   r   r`   r!   r   r~   )r   r   r   r   )r   r   r   r   r   r   r  )r   r;  )rA  T)r   r?   r`   r!   ra   r   r   r   )rA  )r`   r!   r   r?   )ru   r   rS   r   rR   r   r   r?   r   )r@   r?   r   r   r   r   )r@   r?   r   r   r!  r#  r$  r%  r&  )rV   r  r   r  r   r   )r'  r(  r)  r*  r   r   r   r+  r'   r   r,  r   r   r  r  r	  r   r`   rV   r   r   r   r   r   r   r   	  s0    


M
OT
)r   r   r   r   )r   )r   r   r    r!   r   r"   )r   r?   r@   r?   r   r?   )
rP   r?   rQ   r?   rR   r?   rS   r?   r   r?   r"  )r   r?   r`   r!   ra   r   r   r?   )rP   r?   rQ   r?   rt   r?   rR   r?   rS   r?   ru   r?   r   r?   )
__future__r   typingr   r   r   r   r   kornia.corer   r   r	   kornia.geometry.bboxr
   kornia.geometry.linalgr   kornia.utilsr   __all__r   r>   rO   r^   rs   r{   r   r-  r   r   r   r   r   <module>   s,   



:    +%