# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import math

import torch
import torch.nn.functional as F

from kornia.color import rgb_to_ycbcr, ycbcr_to_rgb
from kornia.constants import pi
from kornia.core import Device, Dtype, Parameter, Tensor
from kornia.core import ImageModule as Module
from kornia.core.check import (
    KORNIA_CHECK,
    KORNIA_CHECK_IS_TENSOR,
    KORNIA_CHECK_SHAPE,
)
from kornia.geometry.transform.affwarp import rescale
from kornia.utils.image import perform_keep_shape_image
from kornia.utils.misc import (
    differentiable_clipping,
    differentiable_polynomial_floor,
    differentiable_polynomial_rounding,
)

# Public API of this module; the underscore-prefixed helpers defined below are internal.
__all__ = ["JPEGCodecDifferentiable", "jpeg_codec_differentiable"]


def _get_default_qt_y(device: Device, dtype: Dtype) -> Tensor:
    """Build the default 8 x 8 quantization table of the Y (luma) channel.

    Args:
        device: Device on which the table is created.
        dtype: Data type of the created table.

    Returns:
        Tensor of the shape :math:`(8, 8)` holding the default Y quantization table.

    """
    rows = (
        (16, 11, 10, 16, 24, 40, 51, 61),
        (12, 12, 14, 19, 26, 58, 60, 55),
        (14, 13, 16, 24, 40, 57, 69, 56),
        (14, 17, 22, 29, 51, 87, 80, 62),
        (18, 22, 37, 56, 68, 109, 103, 77),
        (24, 35, 55, 64, 81, 104, 113, 92),
        (49, 64, 78, 87, 103, 121, 120, 101),
        (72, 92, 95, 98, 112, 100, 103, 99),
    )
    return torch.tensor(rows, device=device, dtype=dtype)


def _get_default_qt_c(device: Device, dtype: Dtype) -> Tensor:
    """Build the default 8 x 8 quantization table of the C (chroma) channels.

    Args:
        device: Device on which the table is created.
        dtype: Data type of the created table.

    Returns:
        Tensor of the shape :math:`(8, 8)` holding the default C quantization table.

    """
    rows = (
        (17, 18, 24, 47, 99, 99, 99, 99),
        (18, 21, 26, 66, 99, 99, 99, 99),
        (24, 26, 56, 99, 99, 99, 99, 99),
        (47, 66, 99, 99, 99, 99, 99, 99),
        (99, 99, 99, 99, 99, 99, 99, 99),
        (99, 99, 99, 99, 99, 99, 99, 99),
        (99, 99, 99, 99, 99, 99, 99, 99),
        (99, 99, 99, 99, 99, 99, 99, 99),
    )
    return torch.tensor(rows, device=device, dtype=dtype)


def _patchify_8x8(input: Tensor) -> Tensor:
    """Split an image into non-overlapping 8 x 8 patches.

    Patches are enumerated row by row, i.e. patch ``n`` covers block row ``n // (W // 8)`` and
    block column ``n % (W // 8)``.

    Args:
        input (Tensor): Input image of the shape :math:`(B, H, W)`.

    Returns:
        output (Tensor): Patches of the shape :math:`(B, N, 8, 8)` with :math:`N = (H / 8) (W / 8)`.

    """
    batch, height, width = input.shape
    # Split both spatial axes into (number of blocks, 8): [B, H // 8, 8, W // 8, 8]
    blocks: Tensor = input.view(batch, height // 8, 8, width // 8, 8)
    # Bring the two block axes next to each other and flatten them: [B, N, 8, 8]
    return blocks.transpose(2, 3).reshape(batch, -1, 8, 8)


def _unpatchify_8x8(input: Tensor, H: int, W: int) -> Tensor:
    """Reassemble an image from non-overlapping 8 x 8 patches.

    Args:
        input (Tensor): Patches of the shape :math:`(B, N, 8, 8)`, enumerated row by row.
        H: height of resulting tensor.
        W: width of resulting tensor.

    Returns:
        output (Tensor): Reassembled image of the shape :math:`(B, H, W)`.

    """
    batch, _ = input.shape[:2]
    # Arrange the patches on their 2D block grid: [B, H // 8, W // 8, 8, 8]
    grid: Tensor = input.view(batch, H // 8, W // 8, 8, 8)
    # Interleave block rows with pixel rows, then merge into the image: [B, H, W]
    return grid.transpose(2, 3).reshape(batch, H, W)


def _dct_8x8(input: Tensor) -> Tensor:
    """Perform an 8 x 8 discrete cosine transform.

    The input is level-shifted by 128 before the transform is applied.

    Args:
        input (Tensor): Patched input tensor of the shape :math:`(B, N, 8, 8)`.

    Returns:
        output (Tensor): DCT output tensor of the shape :math:`(B, N, 8, 8)`.

    """
    # Get dtype and device
    dtype: Dtype = input.dtype
    device: Device = input.device
    index: Tensor = torch.arange(8, dtype=dtype, device=device)
    # 1D cosine basis indexed as [spatial position, frequency]
    basis: Tensor = ((2.0 * index[:, None] + 1.0) * index[None, :] * pi / 16.0).cos()
    # Separable 2D basis indexed as [x, y, u, v]. Broadcasting the 1D basis avoids torch.meshgrid
    # (which warns when called without ``indexing``) and evaluates 64 instead of 2 * 4096 cosines.
    dct_tensor: Tensor = basis[:, None, :, None] * basis[None, :, None, :]
    # Normalization: 1 / sqrt(2) for the zero frequency along each axis, times the global 1 / 4
    alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
    alpha[0] = 1.0 / (2**0.5)
    dct_scale: Tensor = torch.outer(alpha, alpha) * 0.25
    # Contract the two spatial axes of the shifted patches with the [x, y] axes of the basis
    output: Tensor = dct_scale[None, None] * torch.tensordot(input - 128.0, dct_tensor, dims=2)
    return output


def _idct_8x8(input: Tensor) -> Tensor:
    """Perform an 8 x 8 inverse discrete cosine transform.

    After the inverse transform the level shift of 128 is added back.

    Args:
        input (Tensor): Patched DCT coefficients of the shape :math:`(B, N, 8, 8)`.

    Returns:
        output (Tensor): Inverse DCT output tensor of the shape :math:`(B, N, 8, 8)`.

    """
    # Get dtype and device
    dtype: Dtype = input.dtype
    device: Device = input.device
    # Normalize the coefficients: 1 / sqrt(2) for the zero frequency along each axis
    alpha: Tensor = torch.ones(8, dtype=dtype, device=device)
    alpha[0] = 1.0 / (2**0.5)
    dct_scale: Tensor = torch.outer(alpha, alpha)
    input = input * dct_scale[None, None]
    index: Tensor = torch.arange(8, dtype=dtype, device=device)
    # 1D cosine basis indexed as [frequency, spatial position]
    basis: Tensor = ((2.0 * index[None, :] + 1.0) * index[:, None] * pi / 16.0).cos()
    # Separable 2D basis indexed as [x, y, u, v] (x, y: frequencies; u, v: spatial positions).
    # Broadcasting the 1D basis avoids torch.meshgrid (which warns when called without ``indexing``).
    idct_tensor: Tensor = basis[:, None, :, None] * basis[None, :, None, :]
    # Contract the two frequency axes, apply the global 1 / 4 factor, and undo the level shift
    output: Tensor = 0.25 * torch.tensordot(input, idct_tensor, dims=2) + 128.0
    return output


def _jpeg_quality_to_scale(
    compression_strength: Tensor,
) -> Tensor:
    """Map a JPEG quality to the scaling factor of the quantization tables.

    The mapping is piecewise: ``5000 / q`` for ``q < 50`` and ``200 - 2 q`` otherwise, followed by a
    differentiable floor.

    Args:
        compression_strength (Tensor): Compression strength ranging from 0 to 100. Any shape is supported.

    Returns:
        scale (Tensor): Scaling factor to be applied to quantization matrix. Same shape as input.

    """
    # NOTE(review): a strength of exactly 0 divides by zero in the low-quality branch — confirm
    # whether callers are expected to exclude 0.
    low_quality_scale: Tensor = 5000.0 / compression_strength
    high_quality_scale: Tensor = 200.0 - 2.0 * compression_strength
    # Both branches are evaluated element-wise; torch.where picks one per entry
    piecewise_scale: Tensor = torch.where(compression_strength < 50, low_quality_scale, high_quality_scale)
    return differentiable_polynomial_floor(piecewise_scale)


def _quantize(
    input: Tensor,
    jpeg_quality: Tensor,
    quantization_table: Tensor,
) -> Tensor:
    """Quantize DCT coefficients with a quality-scaled quantization table.

    Args:
        input (Tensor): Input tensor of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
        quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.

    Returns:
        output (Tensor): Quantized output tensor of the shape :math:`(B, N, 8, 8)`.

    """
    # Per-sample scaling factor, shaped to broadcast against the table with a patch axis inserted
    scale: Tensor = _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
    scaled_table: Tensor = quantization_table[:, None] * scale
    # Convert to quantization step sizes: differentiable clipping w.r.t. [1, 255], then differentiable floor
    step_size: Tensor = differentiable_polynomial_floor(
        differentiable_clipping((scaled_table + 50.0) / 100.0, 1, 255),
    )
    # Divide by the step size and round differentiably
    return differentiable_polynomial_rounding(input / step_size)


def _dequantize(
    input: Tensor,
    jpeg_quality: Tensor,
    quantization_table: Tensor,
) -> Tensor:
    """Dequantize coefficients with a quality-scaled quantization table.

    Args:
        input (Tensor): Input tensor of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength to be applied, shape is :math:`(B)`.
        quantization_table (Tensor): Quantization table of the shape :math:`(1, 8, 8)` or :math:`(B, 8, 8)`.

    Returns:
        output (Tensor): Dequantized output tensor of the shape :math:`(B, N, 8, 8)`.

    """
    # Per-sample scaling factor, shaped to broadcast against the table with a patch axis inserted
    scale: Tensor = _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
    scaled_table: Tensor = quantization_table[:, None] * scale
    # Same step-size computation as used for quantization
    step_size: Tensor = differentiable_polynomial_floor(
        differentiable_clipping((scaled_table + 50.0) / 100.0, 1, 255),
    )
    # Multiply the coefficients by the step size
    return input * step_size


def _chroma_subsampling(input_ycbcr: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Split a YCbCr image into its components and downsample the chroma components by two.

    Args:
        input_ycbcr (Tensor): YCbCr input tensor of the shape :math:`(B, 3, H, W)`.

    Returns:
        output_y (Tensor): Y component (not-subsampled), shape is :math:`(B, H, W)`.
        output_cb (Tensor): Cb component (subsampled), shape is :math:`(B, H // 2, W // 2)`.
        output_cr (Tensor): Cr component (subsampled), shape is :math:`(B, H // 2, W // 2)`.

    """

    def _halve(component: Tensor) -> Tensor:
        # rescale is called with an explicit channel axis, which is dropped again afterwards
        halved: Tensor = rescale(
            component[:, None],
            factor=0.5,
            interpolation="bilinear",
            align_corners=False,
            antialias=True,
        )
        return halved[:, 0]

    # The luma component keeps its full resolution
    output_y: Tensor = input_ycbcr[:, 0]
    # Antialiased bilinear downsampling of the Cb and Cr components
    output_cb: Tensor = _halve(input_ycbcr[:, 1])
    output_cr: Tensor = _halve(input_ycbcr[:, 2])
    return output_y, output_cb, output_cr


def _chroma_upsampling(input_c: Tensor) -> Tensor:
    """Upsample a chroma component by a factor of two using bilinear interpolation.

    Args:
        input_c (Tensor): Cb or Cr component to be upsampled of the shape :math:`(B, H, W)`.

    Returns:
        output_c (Tensor): Upsampled C(b or r) component of the shape :math:`(B, H * 2, W * 2)`.

    """
    # rescale is called with an explicit channel axis
    with_channel: Tensor = input_c.unsqueeze(1)
    upsampled: Tensor = rescale(
        with_channel,
        factor=2.0,
        interpolation="bilinear",
        align_corners=False,
        antialias=False,
    )
    # Drop the channel axis again
    return upsampled[:, 0]


def _jpeg_encode(
    image_rgb: Tensor,
    jpeg_quality: Tensor,
    quantization_table_y: Tensor,
    quantization_table_c: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
    """Encode RGB images into quantized DCT coefficients.

    The pipeline is: RGB to YCbCr, scaling by 255, chroma subsampling, 8 x 8 patching, DCT, and
    quantization.

    Args:
        image_rgb (Tensor): RGB input images of the shape :math:`(B, 3, H, W)`.
        jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
        quantization_table_y (Tensor): Quantization table for Y channel.
        quantization_table_c (Tensor): Quantization table for C channels.

    Returns:
        y_encoded (Tensor): Encoded Y component of the shape :math:`(B, N, 8, 8)`.
        cb_encoded (Tensor): Encoded Cb component of the shape :math:`(B, N, 8, 8)`.
        cr_encoded (Tensor): Encoded Cr component of the shape :math:`(B, N, 8, 8)`.

    """
    # YCbCr conversion followed by scaling of the pixel values by 255
    ycbcr_scaled: Tensor = 255.0 * rgb_to_ycbcr(image_rgb)
    # Full-resolution luma, half-resolution chroma
    luma, chroma_b, chroma_r = _chroma_subsampling(ycbcr_scaled)
    # 8 x 8 patches; both chroma components are stacked along the patch axis so they are processed jointly
    luma_patches: Tensor = _patchify_8x8(luma)
    chroma_patches: Tensor = torch.cat((_patchify_8x8(chroma_b), _patchify_8x8(chroma_r)), dim=1)
    # DCT and quantization with the respective tables
    y_encoded: Tensor = _quantize(_dct_8x8(luma_patches), jpeg_quality, quantization_table_y)
    chroma_encoded: Tensor = _quantize(_dct_8x8(chroma_patches), jpeg_quality, quantization_table_c)
    # Split the jointly processed chroma patches back into Cb and Cr
    cb_encoded, cr_encoded = chroma_encoded.chunk(2, dim=1)
    return y_encoded, cb_encoded, cr_encoded


def _jpeg_decode(
    input_y: Tensor,
    input_cb: Tensor,
    input_cr: Tensor,
    jpeg_quality: Tensor,
    H: int,
    W: int,
    quantization_table_y: Tensor,
    quantization_table_c: Tensor,
) -> Tensor:
    """Decode quantized DCT coefficients back into RGB images.

    The pipeline is: dequantization, inverse DCT, unpatching, chroma upsampling, division by 255, and
    YCbCr to RGB.

    Args:
        input_y (Tensor): Compressed Y component of the shape :math:`(B, N, 8, 8)`.
        input_cb (Tensor): Compressed Cb component of the shape :math:`(B, N, 8, 8)`.
        input_cr (Tensor): Compressed Cr component of the shape :math:`(B, N, 8, 8)`.
        jpeg_quality (Tensor): Compression strength of the shape :math:`(B)`.
        H (int): Original image height.
        W (int): Original image width.
        quantization_table_y (Tensor): Quantization table for Y channel.
        quantization_table_c (Tensor): Quantization table for C channels.

    Returns:
        rgb_decoded (Tensor): Decompressed RGB image of the shape :math:`(B, 3, H, W)`.

    """
    # Dequantize; both chroma components are stacked along the patch axis so they are processed jointly
    luma_coefficients: Tensor = _dequantize(input_y, jpeg_quality, quantization_table_y)
    chroma_coefficients: Tensor = _dequantize(
        torch.cat((input_cb, input_cr), dim=1),
        jpeg_quality,
        quantization_table_c,
    )
    # Inverse DCT, then split the chroma patches back into Cb and Cr
    luma_patches: Tensor = _idct_8x8(luma_coefficients)
    cb_patches, cr_patches = _idct_8x8(chroma_coefficients).chunk(2, dim=1)
    # Reassemble the images; chroma is stored at half resolution and upsampled afterwards
    half_h, half_w = H // 2, W // 2
    image_y: Tensor = _unpatchify_8x8(luma_patches, H, W)
    image_cb: Tensor = _chroma_upsampling(_unpatchify_8x8(cb_patches, half_h, half_w))
    image_cr: Tensor = _chroma_upsampling(_unpatchify_8x8(cr_patches, half_h, half_w))
    # Stack to [B, 3, H, W] and go back to the [0, 1] pixel-range
    image_ycbcr: Tensor = torch.stack((image_y, image_cb, image_cr), dim=1) / 255.0
    # Convert back to RGB space
    return ycbcr_to_rgb(image_ycbcr)


def _perform_padding(image: Tensor) -> tuple[Tensor, int, int]:
    """Pad a given image such that its height and width are divisible by 16.

    Args:
        image: Image of the shape :math:`(*, 3, H, W)`.

    Returns:
        image_padded: Padded image of the shape :math:`(*, 3, H_{new}, W_{new})`.
        h_pad: Number of rows appended at the bottom (padding along the height).
        w_pad: Number of columns appended at the right (padding along the width).

    """
    height, width = image.shape[-2:]
    # Distance to the next multiple of 16 (zero if the size already is a multiple of 16)
    h_pad: int = -height % 16
    w_pad: int = -width % 16
    # F.pad expects (left, right, top, bottom); we follow JPEG and pad only the right and bottom side
    image_padded: Tensor = F.pad(image, (0, w_pad, 0, h_pad), "replicate")
    return image_padded, h_pad, w_pad


@perform_keep_shape_image
def jpeg_codec_differentiable(
    image_rgb: Tensor,
    jpeg_quality: Tensor,
    quantization_table_y: Tensor | None = None,
    quantization_table_c: Tensor | None = None,
) -> Tensor:
    r"""Differentiable JPEG encoding-decoding module.

    Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:

    .. image:: _static/img/jpeg_codec_differentiable.png

    .. math::

        \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}

    Where:
       - :math:`I` is the original image to be coded.
       - :math:`q` is the JPEG quality controlling the compression strength.
       - :math:`QT_{y}` is the luma quantization table.
       - :math:`QT_{c}` is the chroma quantization table.
       - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.

    .. note::
        The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
        required to first perform denormalization followed by normalizing the output images again.

        Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
        however, does not allow the access of the JPEG-coded byte file itself.
        For more details please refer to :cite:`reich2024`.

        This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
        There we provide an optimized Rust implementation for fast JPEG loading.

    Args:
        image_rgb: the RGB image to be coded.
        jpeg_quality: JPEG quality in the range :math:`[0, 100]` controlling the compression strength.
        quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
          quantization table.
        quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
          quantization table.

    Shape:
        - image_rgb: :math:`(*, 3, H, W)`.
        - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).

    Return:
        JPEG coded image of the shape :math:`(B, 3, H, W)`

    Example:
        To perform JPEG coding with the standard quantization tables just provide a JPEG quality

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        You also have the option to provide custom quantization tables

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
        >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
        >>> img_jpeg.sum().backward()

        In case you want to control the quantization purely based on the quantization tables use a JPEG quality of
        99.5. Setting the JPEG quality to 99.5 leads to a QT scaling of 1, see Eq. 2 of :cite:`reich2024` for details.

        >>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.ones(3) * 99.5
        >>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
        >>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
        >>> img_jpeg.sum().backward()

    """
    # Check that inputs are tensors
    KORNIA_CHECK_IS_TENSOR(image_rgb)
    KORNIA_CHECK_IS_TENSOR(jpeg_quality)
    # Get device and dtype
    dtype: Dtype = image_rgb.dtype
    device: Device = image_rgb.device
    # Use default QT if QT is not given
    quantization_table_y = _get_default_qt_y(device, dtype) if quantization_table_y is None else quantization_table_y
    quantization_table_c = _get_default_qt_c(device, dtype) if quantization_table_c is None else quantization_table_c
    KORNIA_CHECK_IS_TENSOR(quantization_table_y)
    KORNIA_CHECK_IS_TENSOR(quantization_table_c)
    # Check shape of inputs
    KORNIA_CHECK_SHAPE(image_rgb, ["*", "3", "H", "W"])
    KORNIA_CHECK_SHAPE(jpeg_quality, ["B"])
    # Add batch dimension to quantization tables if needed
    if quantization_table_y.ndim == 2:
        quantization_table_y = quantization_table_y.unsqueeze(dim=0)
    if quantization_table_c.ndim == 2:
        quantization_table_c = quantization_table_c.unsqueeze(dim=0)
    # Check resulting shape of quantization tables
    KORNIA_CHECK_SHAPE(quantization_table_y, ["B", "8", "8"])
    KORNIA_CHECK_SHAPE(quantization_table_c, ["B", "8", "8"])
    # Check value range of JPEG quality (each bound is read back only once to limit host synchronizations)
    quality_min = jpeg_quality.amin().item()
    quality_max = jpeg_quality.amax().item()
    KORNIA_CHECK(
        (quality_min >= 0.0) and (quality_max <= 100.0),
        f"JPEG quality is out of range. Expected range is [0, 100], "
        f"got [{quality_min}, {quality_max}]. Consider clipping jpeg_quality.",
    )
    # Check matching batch dimensions before any computation is done (a batch dimension of 1 is broadcast)
    batch_size: int = image_rgb.shape[0]
    if quantization_table_y.shape[0] != 1:
        KORNIA_CHECK(
            quantization_table_y.shape[0] == batch_size,
            f"Batch dimensions do not match. "
            f"Got {batch_size} images and {quantization_table_y.shape[0]} quantization tables (Y).",
        )
    if quantization_table_c.shape[0] != 1:
        KORNIA_CHECK(
            quantization_table_c.shape[0] == batch_size,
            f"Batch dimensions do not match. "
            f"Got {batch_size} images and {quantization_table_c.shape[0]} quantization tables (C).",
        )
    if jpeg_quality.shape[0] != 1:
        KORNIA_CHECK(
            jpeg_quality.shape[0] == batch_size,
            f"Batch dimensions do not match. Got {batch_size} images and {jpeg_quality.shape[0]} JPEG qualities.",
        )
    # Pad the image to a shape divisible by 16 (8 x 8 patches on chroma planes of half resolution)
    image_rgb, h_pad, w_pad = _perform_padding(image_rgb)
    # Get height and width of the padded image
    H, W = image_rgb.shape[-2:]
    # JPEG quality and quantization tables to the same device and dtype as the input image
    jpeg_quality = jpeg_quality.to(device, dtype)
    quantization_table_y = quantization_table_y.to(device, dtype)
    quantization_table_c = quantization_table_c.to(device, dtype)
    # Perform encoding
    y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
        image_rgb=image_rgb,
        jpeg_quality=jpeg_quality,
        quantization_table_c=quantization_table_c,
        quantization_table_y=quantization_table_y,
    )
    # Perform decoding
    image_rgb_jpeg: Tensor = _jpeg_decode(
        input_y=y_encoded,
        input_cb=cb_encoded,
        input_cr=cr_encoded,
        jpeg_quality=jpeg_quality,
        H=H,
        W=W,
        quantization_table_c=quantization_table_c,
        quantization_table_y=quantization_table_y,
    )
    # Clip coded image. The decoder already returns the [0, 1] pixel-range, so the upper bound is 1
    # (an upper bound of 255 would never take effect).
    image_rgb_jpeg = differentiable_clipping(input=image_rgb_jpeg, min_val=0.0, max_val=1.0)
    # Crop the image again to the original shape
    image_rgb_jpeg = image_rgb_jpeg[..., : H - h_pad, : W - w_pad]
    return image_rgb_jpeg


class JPEGCodecDifferentiable(Module):
    r"""Differentiable JPEG encoding-decoding module.

    Based on :cite:`reich2024` :cite:`shin2017`, we perform differentiable JPEG encoding-decoding as follows:

    .. math::

        \text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}

    Where:
       - :math:`I` is the original image to be coded.
       - :math:`q` is the JPEG quality controlling the compression strength.
       - :math:`QT_{y}` is the luma quantization table.
       - :math:`QT_{c}` is the chroma quantization table.
       - :math:`\hat{I}` is the resulting JPEG encoded-decoded image.

    .. image:: _static/img/jpeg_codec_differentiable.png

    .. note::
        The input (and output) pixel range is :math:`[0, 1]`. In case you want to handle normalized images you are
        required to first perform denormalization followed by normalizing the output images again.

        Note, that this implementation models the encoding-decoding mapping of JPEG in a differentiable setting,
        however, does not allow the access of the JPEG-coded byte file itself.
        For more details please refer to :cite:`reich2024`.

        This implementation is not meant for data loading. For loading JPEG images please refer to `kornia.io`.
        There we provide an optimized Rust implementation for fast JPEG loading.

    Args:
        quantization_table_y: quantization table for Y channel. Default: `None`, which will load the standard
          quantization table.
        quantization_table_c: quantization table for C channels. Default: `None`, which will load the standard
          quantization table.

    Shape:
        - quantization_table_y: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - quantization_table_c: :math:`(8, 8)` or :math:`(B, 8, 8)` (if used batch dim. needs to match w/ image_rgb).
        - image_rgb: :math:`(*, 3, H, W)`.
        - jpeg_quality: :math:`(1)` or :math:`(B)` (if used batch dim. needs to match w/ image_rgb).

    Example:
        You can use the differentiable JPEG module with standard quantization tables by

        >>> diff_jpeg_module = JPEGCodecDifferentiable()
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        You can also specify custom quantization tables to be used by

        >>> quantization_table_y = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
        >>> quantization_table_c = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
        >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

        In case you want to learn the quantization tables just pass parameters `nn.Parameter`

        >>> quantization_table_y = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
        >>> quantization_table_c = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
        >>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
        >>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
        >>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
        >>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
        >>> img_jpeg.sum().backward()

    """

    def __init__(
        self,
        quantization_table_y: Tensor | Parameter | None = None,
        quantization_table_c: Tensor | Parameter | None = None,
    ) -> None:
        super().__init__()
        # (attribute name, user-provided table, factory of the default table), Y first and C second
        table_specs = (
            ("quantization_table_y", quantization_table_y, _get_default_qt_y),
            ("quantization_table_c", quantization_table_c, _get_default_qt_c),
        )
        for attribute_name, given_table, make_default in table_specs:
            # Fall back to the default table if none is given
            table = make_default(None, None) if given_table is None else given_table
            # Parameters are registered as parameters, any other table is registered as a buffer
            if isinstance(table, Parameter):
                self.register_parameter(attribute_name, table)
            else:
                self.register_buffer(attribute_name, table)

    def forward(
        self,
        image_rgb: Tensor,
        jpeg_quality: Tensor,
    ) -> Tensor:
        # The registered tables are left untouched; copies matching the device and dtype of the
        # input image are handed to the functional implementation.
        target_device = image_rgb.device
        target_dtype = image_rgb.dtype
        return jpeg_codec_differentiable(
            image_rgb,
            jpeg_quality=jpeg_quality,
            quantization_table_c=self.quantization_table_c.to(target_device, target_dtype),
            quantization_table_y=self.quantization_table_y.to(target_device, target_dtype),
        )