o
    oic]                     @  s   d dl mZ d dlZd dlmZ d dlZd dlmZmZm	Z	 ddl
mZ g dZd4ddZd4ddZd5ddZd6ddZd7ddZd8ddZd9d!d"Zd:d%d&Z	d;d<d-d.Zd=d2d3ZdS )>    )annotationsN)Optional)arangestackwhere   )transform_points)
bbox_generatorbbox_generator3dbbox_to_maskbbox_to_mask3dinfer_bbox_shapeinfer_bbox_shape3dnmstransform_bboxvalidate_bboxvalidate_bbox3dboxestorch.Tensorreturnboolc                 C  s
  t | jdv r| jdd tddgksdS t | jdkr%| ddd} | d | d	 }}| d
 | d }}| d | d }}| d | d }}|| d || d }	}
|| d || d }}t|	|
 }t|| }t|dkrzdS t|dkrdS dS )a  Validate if a 2D bounding box usable or not. This function checks if the boxes are rectangular or not.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
            bottom-left. The coordinates must be in the x, y order.

          Nr      F).r   r   ).r   r   ).r   r   ).r   r   ).r   r   ).r   r   ).r   r   ).r   r   r   g-C6?T)lenshapetorchSizeviewabsany)r   x_tly_tlx_try_trx_bry_brx_bly_blwidth_twidth_bheight_theight_b
width_diffheight_diff r2   H/home/ubuntu/.local/lib/python3.10/site-packages/kornia/geometry/bbox.pyr   +   s"   *	r   c              	   C  s  t | jdv r| jdd tddgkstd| j dt | jdkr,| d	dd} t| d
tjg d| jtj	ddddddf }t| d
tjg d| jtj	ddddddf }|| d
 }t
|d
d|dddf s~td| dt| d
tjg d| jtj	dddddd
f }t| d
tjg d| jtj	dddddd
f }|| d
 }t
|d
d|dddf std| d| dddddf | dddddf  d
 }t
|d
d|dddf std| ddS )a  Validate if a 3D bounding box usable or not. This function checks if the boxes are cube or not.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
            front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
            The coordinates must be in the x, y, z order.

    r   r   N   r   z1Box shape must be (B, 8, 3) or (B, N, 8, 3). Got .r   r   r   r   r         devicedtyper   r   r   r      z4Boxes must have be cube, while get different widths r   r   r8   r=   r   r   r   r7   z5Boxes must have be cube, while get different heights r   z4Boxes must have be cube, while get different depths T)r   r   r   r    AssertionErrorr!   index_selecttensorr:   longallclosepermuter   leftrightwidthsbotupperheightsdepthsr2   r2   r3   r   P   s$   *
44 44 0"r   !tuple[torch.Tensor, torch.Tensor]c                 C  s`   t |  | ddddf | ddddf  d }| ddddf | ddddf  d }||fS )az  Auto-infer the output sizes for the given 2D bounding boxes.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
            bottom-left. The coordinates must be in the x, y order.

    Returns:
        - Bounding box heights, shape of :math:`(B,)`.
        - Boundingbox widths, shape of :math:`(B,)`.

    Example:
        >>> boxes = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ], [
        ...     [1., 1.],
        ...     [3., 1.],
        ...     [3., 2.],
        ...     [1., 2.],
        ... ]])  # 2x4x2
        >>> infer_bbox_shape(boxes)
        (tensor([2., 2.]), tensor([2., 3.]))

    Nr   r   r   )r   )r   widthheightr2   r2   r3   r   s   s   ((r   /tuple[torch.Tensor, torch.Tensor, torch.Tensor]c              	   C  sN  t |  t| dtjg d| jtjddddddf }t| dtjg d| jtjddddddf }|| d dddf }t| dtjg d| jtjddddddf }t| dtjg d| jtjddddddf }|| d dddf }| ddd	dd
f | dddd	d
f  d dddf }|||fS )aL  Auto-infer the output sizes for the given 3D bounding boxes.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
            front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
            The coordinates must be in the x, y, z order.

    Returns:
        - Bounding box depths, shape of :math:`(B,)`.
        - Bounding box heights, shape of :math:`(B,)`.
        - Bounding box widths, shape of :math:`(B,)`.

    Example:
        >>> boxes = torch.tensor([[[ 0,  1,  2],
        ...         [10,  1,  2],
        ...         [10, 21,  2],
        ...         [ 0, 21,  2],
        ...         [ 0,  1, 32],
        ...         [10,  1, 32],
        ...         [10, 21, 32],
        ...         [ 0, 21, 32]],
        ...        [[ 3,  4,  5],
        ...         [43,  4,  5],
        ...         [43, 54,  5],
        ...         [ 3, 54,  5],
        ...         [ 3,  4, 65],
        ...         [43,  4, 65],
        ...         [43, 54, 65],
        ...         [ 3, 54, 65]]]) # 2x8x3
        >>> infer_bbox_shape3d(boxes)
        (tensor([31, 61]), tensor([21, 51]), tensor([11, 41]))

    r   r6   r9   Nr   r<   r>   r?   r   r   )r   r   rA   rB   r:   rC   rF   r2   r2   r3   r      s   #4444<
r   rO   intrP   c           
      C  s   t |  tj|| j| jd|d}tj|| j| jdd|}| ddddf ddd}| ddddf ddd}| ddddf ddd}| ddddf ddd}||k||k@ ||k@ ||k@ }	|	| jS )a   Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right
            and bottom-left. The coordinates must be in the x, y order.
        width: width of the masked image.
        height: height of the masked image.

    Returns:
        the output mask tensor.

    Note:
        It is currently non-differentiable.

    Examples:
        >>> boxes = torch.tensor([[
        ...        [1., 1.],
        ...        [3., 1.],
        ...        [3., 2.],
        ...        [1., 2.],
        ...   ]])  # 1x4x2
        >>> bbox_to_mask(boxes, 5, 5)
        tensor([[[0., 0., 0., 0., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.]]])

    r9   r   Nr   r   r   )r   r   r   r:   r;   r!   to)
r   rO   rP   yyxxx_miny_minx_maxy_maxmaskr2   r2   r3   r      s    r   sizetuple[int, int, int]c                 C  s.  t |  |\}}}| ddddf  }| ddddf  }| ddddf  }| ddddf  }| ddddf  }	| ddddf  }
t|| jtjd}t|| jtjd}t|| jtjd}|dddf |dddf k|dddf |dddf k@ dddddddf |dddf |dddf k|dddf |dddf k@ dddddddf B |dddf |	dddf k|dddf |
dddf k@ dddddddf B  }|jddd	jddd	}|jddd	jddd	}|jddd	jddd	}|| | }| S )
aI  Convert 3D bounding boxes to masks. Covered area is 1. and the remaining is 0.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
            front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
            The coordinates must be in the x, y, z order.
        size: depth, height and width of the masked image.

    Returns:
        the output mask tensor.

    Examples:
        >>> boxes = torch.tensor([[
        ...     [1., 1., 1.],
        ...     [2., 1., 1.],
        ...     [2., 2., 1.],
        ...     [1., 2., 1.],
        ...     [1., 1., 2.],
        ...     [2., 1., 2.],
        ...     [2., 2., 2.],
        ...     [1., 2., 2.],
        ... ]])  # 1x8x3
        >>> bbox_to_mask3d(boxes, (4, 5, 5))
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 1., 1., 0., 0.],
                   [0., 1., 1., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 1., 1., 0., 0.],
                   [0., 1., 1., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0.]]]]])

    Nr   r   r   r   r9   r   T)dimkeepdim)r   rC   r   r:   r   floatall)r   r[   D0D1D2z_minz_maxrW   rY   rV   rX   zyxmcond1cond2cond3m_outr2   r2   r3   r      s.   2
TTTr   x_starty_startc              
   C  s  | j |j kr|  dv std|  d| d|j |j kr#| dv s.td| d| d| j|j  krA|j  krA|jksWn td| j d| j d|j d	|j d
	| j|j  krj|j  krj|jksn td| j d| j d|j d	|j d
	tjddgddgddgddggg| j| jd|  dkrdnt| dd}|dddddf  | 	dd7  < |dddddf  |	dd7  < |ddddf  |d 7  < |ddddf  |d 7  < |ddddf  |d 7  < |ddddf  |d 7  < |S )aO  Generate 2D bounding boxes according to the provided start coords, width and height.

    Args:
        x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
            tensor or :math:`(B,)`.
        y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
            tensor or :math:`(B,)`.
        width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
        height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.

    Returns:
        the bounding box tensor.

    Examples:
        >>> x_start = torch.tensor([0, 1])
        >>> y_start = torch.tensor([1, 0])
        >>> width = torch.tensor([5, 3])
        >>> height = torch.tensor([7, 4])
        >>> bbox_generator(x_start, y_start, width, height)
        tensor([[[0, 1],
                 [4, 1],
                 [4, 7],
                 [0, 7]],
        <BLANKLINE>
                [[1, 0],
                 [3, 0],
                 [3, 3],
                 [1, 3]]])

    r   r   z6`x_start` and `y_start` must be a scalar or (B,). Got , r5   z3`width` and `height` must be a scalar or (B,). Got 5All tensors must be in the same dtype. Got `x_start`(), `y_start`(), `width`(), `height`().6All tensors must be in the same device. Got `x_start`(r   r9   r   Nr   r   r   )
r   r]   r@   r;   r:   r   rB   repeatr   r!   )rn   ro   rO   rP   bboxr2   r2   r3   r	   A  sL   !((.&&r	   z_startdepthc                 C  s  | j |j   kr|j krn n|  dv s#td|  d| d| d|j |j   kr0|j kr8n n| dv sFtd| d| d| d| j|j  kre|j  kre|j  kre|j  kre|jksn td| j d| j d| j d	|j d
|j d|j d| j|j  kr|j  kr|j  kr|j  kr|jksn td| j d| j d| j d	|j d
|j d|j dtjg dg dg dg dgg| j| jdt| dd}|dddddf  | 	dd7  < |dddddf  |	dd7  < |dddddf  |	dd7  < |ddddf  |7  < |ddddf  |7  < |ddddf  |7  < |ddddf  |7  < |
 }|dddddf  |jdddd7  < tj||gdd}|S )a  Generate 3D bounding boxes according to the provided start coords, width, height and depth.

    Args:
        x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
            tensor or :math:`(B,)`.
        y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
            tensor or :math:`(B,)`.
        z_start: a tensor containing the z coordinates of the bounding boxes to be extracted. Shape must be a scalar
            tensor or :math:`(B,)`.
        width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
        height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
        depth: depths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.

    Returns:
        the 3d bounding box tensor :math:`(B, 8, 3)`.

    Examples:
        >>> x_start = torch.tensor([0, 3])
        >>> y_start = torch.tensor([1, 4])
        >>> z_start = torch.tensor([2, 5])
        >>> width = torch.tensor([10, 40])
        >>> height = torch.tensor([20, 50])
        >>> depth = torch.tensor([30, 60])
        >>> bbox_generator3d(x_start, y_start, z_start, width, height, depth)
        tensor([[[ 0,  1,  2],
                 [10,  1,  2],
                 [10, 21,  2],
                 [ 0, 21,  2],
                 [ 0,  1, 32],
                 [10,  1, 32],
                 [10, 21, 32],
                 [ 0, 21, 32]],
        <BLANKLINE>
                [[ 3,  4,  5],
                 [43,  4,  5],
                 [43, 54,  5],
                 [ 3, 54,  5],
                 [ 3,  4, 65],
                 [43,  4, 65],
                 [43, 54, 65],
                 [ 3, 54, 65]]])

    rp   zA`x_start`, `y_start` and `z_start` must be a scalar or (B,). Got rq   r5   z<`width`, `height` and `depth` must be a scalar or (B,). Got rr   rs   z), `z_start`(rt   ru   z) and `depth`(rv   rw   )r   r   r   r9   r   Nr   r   r   r   r]   r   )r   r]   r@   r;   r:   r   rB   rx   r   r!   clone	unsqueezecat)rn   ro   rz   rO   rP   r{   ry   	bbox_backr2   r2   r3   r
     sl   *3*@@$&&&.r
   xyxy	trans_matmodestrrestore_coordinatesOptional[bool]c                 C  s  t |tstdt| |dvrtd| |du r2|jdd tddgks2tj	dd	d
 |dkrJ|d |d  |d< |d |d  |d< t
| ||jd dd}||}|du sb|r|jdd tddgks| }tj|dddgf ddd |d< tj|dd	dgf ddd |d< tj|dddgf ddd |d< tj|dd	dgf ddd |d< |}|dkr|d |d  |d< |d |d  |d< |S )ar  Apply a transformation matrix to a box or batch of boxes.

    Args:
        trans_mat: The transformation matrix to be applied with a shape of :math:`(3, 3)`
            or batched as :math:`(B, 3, 3)`.
        boxes: The boxes to be transformed with a common shape of :math:`(N, 4)` or batched as :math:`(B, N, 4)`, the
            polygon shape of :math:`(B, N, 4, 2)` is also supported.
        mode: The format in which the boxes are provided. If set to 'xyxy' the boxes are assumed to be in the format
            ``xmin, ymin, xmax, ymax``. If set to 'xywh' the boxes are assumed to be in the format
            ``xmin, ymin, width, height``
        restore_coordinates: In case the boxes are flipped, adding a post processing step to restore the
            coordinates to a valid bounding box.

    Returns:
        The set of transformed points in the specified mode

    zMode must be a string. Got )r   xywhz(Mode must be one of 'xyxy', 'xywh'. Got Nr   r   r   aP  Previous behaviour produces incorrect box coordinates if a flip transformation performed on boxes.The previous wrong behaviour has been corrected and will be removed in the future versions.If you wish to keep the previous behaviour, please set `restore_coordinates=False`.Otherwise, set `restore_coordinates=True` as an acknowledgement.r   )
stacklevelr   ).r   ).r   ).r   ).r   r   r   .r|   r   )
isinstancer   	TypeErrortype
ValueErrorr   r   r    warningswarnr   r!   view_asr}   minmax)r   r   r   r   transformed_boxesrestored_boxesr2   r2   r3   r     s2   
$	
(""""r   scoresiou_thresholdr_   c                 C  s  t | jdkr| jd dkrtd| j dt |jdkr'td|j d| jd |jd kr=td	| j|jf d| d\}}}}|| ||  }|jd
d\}}	g }
|	jd dkr|	d }|
| t|| ||	dd  }t|| ||	dd  }t|| ||	dd  }t|| ||	dd  }tj	|| dd}tj	|| dd}|| }||| ||	dd   |  }t
||kd }|	|d  }	|	jd dks_t |
dkrt|
S t|
S )aU  Perform non-maxima suppression (NMS) on tensor of bounding boxes according to the intersection-over-union (IoU).

    Args:
        boxes: tensor containing the encoded bounding boxes with the shape :math:`(N, (x_1, y_1, x_2, y_2))`.
        scores: tensor containing the scores associated to each bounding box with shape :math:`(N,)`.
        iou_threshold: the throshold to discard the overlapping boxes.

    Return:
        A tensor mask with the indices to keep from the input set of boxes and scores.

    Example:
        >>> boxes = torch.tensor([
        ...     [10., 10., 20., 20.],
        ...     [15., 5., 15., 25.],
        ...     [100., 100., 200., 200.],
        ...     [100., 100., 200., 200.]])
        >>> scores = torch.tensor([0.9, 0.8, 0.7, 0.9])
        >>> nms(boxes, scores, iou_threshold=0.8)
        tensor([0, 3, 1])

    r   r   r   zboxes expected as Nx4. Got: r5   r   zscores expected as N. Got: r   z+boxes and scores mus have same shape. Got: T)
descendingNg        )r   )r   r   r   unbindsortappendr   r   r   clampr   r   rB   )r   r   r   x1y1x2y2areas_orderkeepixx1yy1xx2yy2whinterovrindsr2   r2   r3   r     s6   
 
r   )r   r   r   r   )r   r   r   rN   )r   r   r   rQ   )r   r   rO   rR   rP   rR   r   r   )r   r   r[   r\   r   r   )
rn   r   ro   r   rO   r   rP   r   r   r   )rn   r   ro   r   rz   r   rO   r   rP   r   r{   r   r   r   )r   N)
r   r   r   r   r   r   r   r   r   r   )r   r   r   r   r   r_   r   r   )
__future__r   r   typingr   r   kornia.corer   r   r   linalgr   __all__r   r   r   r   r   r   r	   r
   r   r   r2   r2   r2   r3   <module>   s$   

%
#
"
1
+
P
?\<