from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc

from ray.data.aggregate import AbsMax, ApproximateQuantile, Max, Mean, Min, Std
from ray.data.block import BlockAccessor
from ray.data.preprocessor import Preprocessor, SerializablePreprocessorBase
from ray.data.preprocessors.version_support import SerializablePreprocessor
from ray.data.util.data_batch_conversion import BatchFormat
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
    from ray.data.dataset import Dataset

# Small epsilon value to handle near-zero values in division operations.
# Scaling denominators smaller than this threshold (the standard deviation in
# StandardScaler, the ``max - min`` range in MinMaxScaler) are replaced by 1,
# which prevents division by zero and numerical instability when scaling
# constant or near-constant columns. Similar to sklearn's approach.
_EPSILON = 1e-8


@PublicAPI(stability="alpha")
@SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.standard_scaler")
class StandardScaler(SerializablePreprocessorBase):
    r"""Translate and scale each column by its mean and standard deviation,
    respectively.

    The general formula is given by

    .. math::

        x' = \frac{x - \bar{x}}{s}

    where :math:`x` is the column, :math:`x'` is the transformed column,
    :math:`\bar{x}` is the column average, and :math:`s` is the column's
    population standard deviation (computed with ``ddof=0``). If :math:`s` is
    zero or negligibly small (i.e., the column is constant or near-constant),
    the column is only centered and not rescaled, so a constant-valued column
    is transformed into zeros.

    .. warning::
        :class:`StandardScaler` works best when your data is normal. If your data isn't
        approximately normal, then the transformed features won't be meaningful.

    Examples:
        >>> import pandas as pd
        >>> import ray
        >>> from ray.data.preprocessors import StandardScaler
        >>>
        >>> df = pd.DataFrame({"X1": [-2, 0, 2], "X2": [-3, -3, 3], "X3": [1, 1, 1]})
        >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
        >>> ds.to_pandas()  # doctest: +SKIP
           X1  X2  X3
        0  -2  -3   1
        1   0  -3   1
        2   2   3   1

        Columns are scaled separately.

        >>> preprocessor = StandardScaler(columns=["X1", "X2"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
                 X1        X2  X3
        0 -1.224745 -0.707107   1
        1  0.000000 -0.707107   1
        2  1.224745  1.414214   1

        Constant-valued columns get filled with zeros.

        >>> preprocessor = StandardScaler(columns=["X3"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2   X3
        0  -2  -3  0.0
        1   0  -3  0.0
        2   2   3  0.0

        >>> preprocessor = StandardScaler(
        ...     columns=["X1", "X2"],
        ...     output_columns=["X1_scaled", "X2_scaled"]
        ... )
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2  X3  X1_scaled  X2_scaled
        0  -2  -3   1  -1.224745  -0.707107
        1   0  -3   1   0.000000  -0.707107
        2   2   3   1   1.224745   1.414214

    Args:
        columns: The columns to separately scale.
        output_columns: The names of the transformed columns. If None, the transformed
            columns will be the same as the input columns. If not None, the length of
            ``output_columns`` must match the length of ``columns``, otherwise an error
            will be raised.
    """

    def __init__(self, columns: List[str], output_columns: Optional[List[str]] = None):
        super().__init__()
        self.columns = columns
        self.output_columns = Preprocessor._derive_and_validate_output_columns(
            columns, output_columns
        )

    def _fit(self, dataset: "Dataset") -> Preprocessor:
        """Register the mean and population-std aggregators for every column.

        The aggregators are only added to ``self.stat_computation_plan`` here;
        ``dataset`` itself is not touched in this method.
        NOTE(review): the plan is presumably executed by the base class to
        populate ``self.stats_`` -- confirm against ``Preprocessor.fit``.

        Returns:
            This preprocessor.
        """
        self.stat_computation_plan.add_aggregator(
            aggregator_fn=Mean,
            columns=self.columns,
        )
        self.stat_computation_plan.add_aggregator(
            # ddof=0 -> population (not sample) standard deviation.
            aggregator_fn=lambda col: Std(col, ddof=0),
            columns=self.columns,
        )
        return self

    def _transform_pandas(self, df: pd.DataFrame):
        """Standardize ``self.columns`` of a pandas batch into ``self.output_columns``.

        Reads ``mean(<col>)`` and ``std(<col>)`` from ``self.stats_``. Columns
        whose statistics are missing (``None``) are emitted as all-NaN.

        Returns:
            The same ``df`` object with the output columns assigned.
        """

        def column_standard_scaler(s: pd.Series):
            s_mean = self.stats_[f"mean({s.name})"]
            s_std = self.stats_[f"std({s.name})"]

            if s_std is None or s_mean is None:
                # Build a fresh float column instead of assigning NaN into ``s``
                # in place: writing NaN into an integer Series is an upcasting
                # setitem, which pandas has deprecated (PDEP-6).
                return pd.Series(np.nan, index=s.index, name=s.name)

            # Handle division by zero and near-zero values for numerical stability.
            # If standard deviation is very small (constant or near-constant column),
            # treat it as 1 to avoid numerical instability.
            if s_std < _EPSILON:
                s_std = 1

            return (s - s_mean) / s_std

        df[self.output_columns] = df[self.columns].transform(column_standard_scaler)
        return df

    @staticmethod
    def _scale_column(column: pa.Array, mean: float, std: float) -> pa.Array:
        """Return ``(column - mean) / std`` computed with PyArrow kernels.

        A ``std`` below ``_EPSILON`` is replaced by 1 so that constant or
        near-constant columns are only centered.
        """
        # Handle division by zero and near-zero values for numerical stability.
        if std < _EPSILON:
            std = 1

        return pc.divide(
            pc.subtract(column, pa.scalar(float(mean))), pa.scalar(float(std))
        )

    def _transform_arrow(self, table: pa.Table) -> pa.Table:
        """Transform using fast native PyArrow operations.

        Reads ``mean(<col>)`` and ``std(<col>)`` from ``self.stats_`` and
        upserts one output column per input column. Columns whose statistics
        are missing (``None``) are emitted as all-null, keeping the input type.
        """
        # Read all input columns first to avoid reading modified data when
        # output_columns[i] == columns[j] for i < j
        input_columns = [table.column(input_col) for input_col in self.columns]

        for input_col, output_col, column in zip(
            self.columns, self.output_columns, input_columns
        ):
            s_mean = self.stats_[f"mean({input_col})"]
            s_std = self.stats_[f"std({input_col})"]

            if s_std is None or s_mean is None:
                # Return column filled with nulls, preserving original column type
                null_array = pa.nulls(len(column), type=column.type)
                table = BlockAccessor.for_block(table).upsert_column(
                    output_col, null_array
                )
                continue

            scaled_column = self._scale_column(column, s_mean, s_std)

            table = BlockAccessor.for_block(table).upsert_column(
                output_col, scaled_column
            )

        return table

    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        """Return the batch format this preprocessor prefers (Arrow)."""
        return BatchFormat.ARROW

    def _get_serializable_fields(self) -> Dict[str, Any]:
        """Return the configuration fields and the fitted flag for serialization.

        NOTE(review): ``stats_`` is not part of this dict; presumably it is
        persisted separately by the base class -- confirm.
        """
        return {
            "columns": self.columns,
            "output_columns": self.output_columns,
            "_fitted": getattr(self, "_fitted", None),
        }

    def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
        """Restore the attributes written by ``_get_serializable_fields``.

        ``version`` is accepted but not used by this implementation.
        """
        # required fields
        self.columns = fields["columns"]
        self.output_columns = fields["output_columns"]
        # optional fields
        self._fitted = fields.get("_fitted")

    def __repr__(self):
        return f"{self.__class__.__name__}(columns={self.columns!r}, output_columns={self.output_columns!r})"


@PublicAPI(stability="alpha")
@SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.min_max_scaler")
class MinMaxScaler(SerializablePreprocessorBase):
    r"""Scale each column by its range.

    The general formula is given by

    .. math::

        x' = \frac{x - \min(x)}{\max{x} - \min{x}}

    where :math:`x` is the column and :math:`x'` is the transformed column. If
    :math:`\max{x} - \min{x}` is zero or negligibly small (i.e., the column is
    constant or near-constant), the column is only shifted by its minimum, so a
    constant-valued column gets filled with zeros.

    Values that lie within the fitted minimum and maximum are mapped into the
    range :math:`[0, 1]`. Values outside the fitted range are mapped outside
    :math:`[0, 1]` accordingly.

    .. tip::
        This can be used as an alternative to :py:class:`StandardScaler`.

    Examples:
        >>> import pandas as pd
        >>> import ray
        >>> from ray.data.preprocessors import MinMaxScaler
        >>>
        >>> df = pd.DataFrame({"X1": [-2, 0, 2], "X2": [-3, -3, 3], "X3": [1, 1, 1]})   # noqa: E501
        >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
        >>> ds.to_pandas()  # doctest: +SKIP
           X1  X2  X3
        0  -2  -3   1
        1   0  -3   1
        2   2   3   1

        Columns are scaled separately.

        >>> preprocessor = MinMaxScaler(columns=["X1", "X2"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
            X1   X2  X3
        0  0.0  0.0   1
        1  0.5  0.0   1
        2  1.0  1.0   1

        Constant-valued columns get filled with zeros.

        >>> preprocessor = MinMaxScaler(columns=["X3"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2   X3
        0  -2  -3  0.0
        1   0  -3  0.0
        2   2   3  0.0

        >>> preprocessor = MinMaxScaler(
        ...     columns=["X1", "X2"],
        ...     output_columns=["X1_scaled", "X2_scaled"]
        ... )
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2  X3  X1_scaled  X2_scaled
        0  -2  -3   1        0.0        0.0
        1   0  -3   1        0.5        0.0
        2   2   3   1        1.0        1.0

    Args:
        columns: The columns to separately scale.
        output_columns: The names of the transformed columns. If None, the transformed
            columns will be the same as the input columns. If not None, the length of
            ``output_columns`` must match the length of ``columns``, otherwise an error
            will be raised.
    """

    def __init__(self, columns: List[str], output_columns: Optional[List[str]] = None):
        super().__init__()
        self.columns = columns
        self.output_columns = Preprocessor._derive_and_validate_output_columns(
            columns, output_columns
        )

    def _fit(self, dataset: "Dataset") -> Preprocessor:
        """Compute the minimum and maximum of every column and store them in
        ``self.stats_``.

        Returns:
            This preprocessor.
        """
        # All Min aggregators first, then all Max aggregators, in column order.
        aggregates = [Agg(col) for Agg in [Min, Max] for col in self.columns]
        self.stats_ = dataset.aggregate(*aggregates)
        return self

    def _transform_pandas(self, df: pd.DataFrame):
        """Min-max scale ``self.columns`` of a pandas batch into
        ``self.output_columns``.

        Reads ``min(<col>)`` and ``max(<col>)`` from ``self.stats_``. Columns
        whose statistics are missing (``None``) are emitted as all-NaN.

        Returns:
            The same ``df`` object with the output columns assigned.
        """

        def column_min_max_scaler(s: pd.Series):
            s_min = self.stats_[f"min({s.name})"]
            s_max = self.stats_[f"max({s.name})"]

            # Mirror StandardScaler: a missing statistic yields an all-NaN
            # column instead of a TypeError from ``None - None``.
            # NOTE(review): assumes the aggregators report None when a column
            # has no usable values -- confirm against Min/Max.
            if s_min is None or s_max is None:
                return pd.Series(np.nan, index=s.index, name=s.name)

            diff = s_max - s_min

            # Handle division by zero and near-zero values for numerical stability.
            # If range is very small (constant or near-constant column),
            # treat it as 1 to avoid numerical instability.
            if diff < _EPSILON:
                diff = 1

            return (s - s_min) / diff

        df[self.output_columns] = df[self.columns].transform(column_min_max_scaler)
        return df

    def _get_serializable_fields(self) -> Dict[str, Any]:
        """Return the configuration fields and the fitted flag for serialization.

        NOTE(review): ``stats_`` is not part of this dict; presumably it is
        persisted separately by the base class -- confirm.
        """
        return {
            "columns": self.columns,
            "output_columns": self.output_columns,
            "_fitted": getattr(self, "_fitted", None),
        }

    def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
        """Restore the attributes written by ``_get_serializable_fields``.

        ``version`` is accepted but not used by this implementation.
        """
        # required fields
        self.columns = fields["columns"]
        self.output_columns = fields["output_columns"]
        # optional fields
        self._fitted = fields.get("_fitted")

    def __repr__(self):
        return f"{self.__class__.__name__}(columns={self.columns!r}, output_columns={self.output_columns!r})"


@PublicAPI(stability="alpha")
@SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.max_abs_scaler")
class MaxAbsScaler(SerializablePreprocessorBase):
    r"""Scale each column by its absolute max value.

    The general formula is given by

    .. math::

        x' = \frac{x}{\max{\vert x \vert}}

    where :math:`x` is the column and :math:`x'` is the transformed column. If
    :math:`\max{\vert x \vert} = 0` (i.e., the column contains all zeros), then the
    column is unmodified.

    .. tip::
        This is the recommended way to scale sparse data. If your data isn't sparse,
        you can use :class:`MinMaxScaler` or :class:`StandardScaler` instead.

    Examples:
        >>> import pandas as pd
        >>> import ray
        >>> from ray.data.preprocessors import MaxAbsScaler
        >>>
        >>> df = pd.DataFrame({"X1": [-6, 3], "X2": [2, -4], "X3": [0, 0]})   # noqa: E501
        >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
        >>> ds.to_pandas()  # doctest: +SKIP
           X1  X2  X3
        0  -6   2   0
        1   3  -4   0

        Columns are scaled separately.

        >>> preprocessor = MaxAbsScaler(columns=["X1", "X2"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
            X1   X2  X3
        0 -1.0  0.5   0
        1  0.5 -1.0   0

        Zero-valued columns aren't scaled.

        >>> preprocessor = MaxAbsScaler(columns=["X3"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2   X3
        0  -6   2  0.0
        1   3  -4  0.0

        >>> preprocessor = MaxAbsScaler(
        ...     columns=["X1", "X2"],
        ...     output_columns=["X1_scaled", "X2_scaled"]
        ... )
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2  X3  X1_scaled  X2_scaled
        0  -6   2   0       -1.0        0.5
        1   3  -4   0        0.5       -1.0

    Args:
        columns: The columns to separately scale.
        output_columns: The names of the transformed columns. If None, the transformed
            columns will be the same as the input columns. If not None, the length of
            ``output_columns`` must match the length of ``columns``, otherwise an error
            will be raised.
    """

    def __init__(self, columns: List[str], output_columns: Optional[List[str]] = None):
        super().__init__()
        self.columns = columns
        self.output_columns = Preprocessor._derive_and_validate_output_columns(
            columns, output_columns
        )

    def _fit(self, dataset: "Dataset") -> Preprocessor:
        """Compute the absolute maximum of every column and store the results
        in ``self.stats_``.

        Returns:
            This preprocessor.
        """
        aggregates = [AbsMax(col) for col in self.columns]
        self.stats_ = dataset.aggregate(*aggregates)
        return self

    def _transform_pandas(self, df: pd.DataFrame):
        """Divide ``self.columns`` of a pandas batch by their fitted absolute
        maximum and write the result into ``self.output_columns``.

        Reads ``abs_max(<col>)`` from ``self.stats_``. Columns whose statistic
        is missing (``None``) are emitted as all-NaN.

        Returns:
            The same ``df`` object with the output columns assigned.
        """

        def column_abs_max_scaler(s: pd.Series):
            s_abs_max = self.stats_[f"abs_max({s.name})"]

            # Mirror StandardScaler: a missing statistic yields an all-NaN
            # column instead of a TypeError from dividing by ``None``.
            # NOTE(review): assumes the aggregator reports None when a column
            # has no usable values -- confirm against AbsMax.
            if s_abs_max is None:
                return pd.Series(np.nan, index=s.index, name=s.name)

            # Handle division by zero.
            # All values are 0.
            if s_abs_max == 0:
                s_abs_max = 1

            return s / s_abs_max

        df[self.output_columns] = df[self.columns].transform(column_abs_max_scaler)
        return df

    def _get_serializable_fields(self) -> Dict[str, Any]:
        """Return the configuration fields and the fitted flag for serialization.

        NOTE(review): ``stats_`` is not part of this dict; presumably it is
        persisted separately by the base class -- confirm.
        """
        return {
            "columns": self.columns,
            "output_columns": self.output_columns,
            "_fitted": getattr(self, "_fitted", None),
        }

    def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
        """Restore the attributes written by ``_get_serializable_fields``.

        ``version`` is accepted but not used by this implementation.
        """
        # required fields
        self.columns = fields["columns"]
        self.output_columns = fields["output_columns"]
        # optional fields
        self._fitted = fields.get("_fitted")

    def __repr__(self):
        return f"{self.__class__.__name__}(columns={self.columns!r}, output_columns={self.output_columns!r})"


@PublicAPI(stability="alpha")
@SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.robust_scaler")
class RobustScaler(SerializablePreprocessorBase):
    r"""Scale and translate each column using approximate quantiles.

    The general formula is given by

    .. math::
        x' = \frac{x - \mu_{1/2}}{\mu_h - \mu_l}

    where :math:`x` is the column, :math:`x'` is the transformed column,
    :math:`\mu_{1/2}` is the column median. :math:`\mu_{h}` and :math:`\mu_{l}` are the
    high and low quantiles, respectively. By default, :math:`\mu_{h}` is the third
    quartile and :math:`\mu_{l}` is the first quartile. If
    :math:`\mu_h - \mu_l = 0`, the transformed column is filled with zeros.

    Internally, the `ApproximateQuantile` aggregator is used to calculate the
    approximate quantiles.

    .. tip::
        This scaler works well when your data contains many outliers.

    Examples:
        >>> import pandas as pd
        >>> import ray
        >>> from ray.data.preprocessors import RobustScaler
        >>>
        >>> df = pd.DataFrame({
        ...     "X1": [1, 2, 3, 4, 5],
        ...     "X2": [13, 5, 14, 2, 8],
        ...     "X3": [1, 2, 2, 2, 3],
        ... })
        >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
        >>> ds.to_pandas()  # doctest: +SKIP
           X1  X2  X3
        0   1  13   1
        1   2   5   2
        2   3  14   2
        3   4   2   2
        4   5   8   3

        :class:`RobustScaler` separately scales each column.

        >>> preprocessor = RobustScaler(columns=["X1", "X2"])
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
            X1     X2  X3
        0 -1.0  0.625   1
        1 -0.5 -0.375   2
        2  0.0  0.750   2
        3  0.5 -0.750   2
        4  1.0  0.000   3

        >>> preprocessor = RobustScaler(
        ...    columns=["X1", "X2"],
        ...    output_columns=["X1_scaled", "X2_scaled"]
        ... )
        >>> preprocessor.fit_transform(ds).to_pandas()  # doctest: +SKIP
           X1  X2  X3  X1_scaled  X2_scaled
        0   1  13   1       -1.0      0.625
        1   2   5   2       -0.5     -0.375
        2   3  14   2        0.0      0.750
        3   4   2   2        0.5     -0.750
        4   5   8   3        1.0      0.000

    Args:
        columns: The columns to separately scale.
        quantile_range: A tuple that defines the lower and upper quantiles. Values
            must be between 0 and 1. Defaults to the 1st and 3rd quartiles:
            ``(0.25, 0.75)``.
        output_columns: The names of the transformed columns. If None, the transformed
            columns will be the same as the input columns. If not None, the length of
            ``output_columns`` must match the length of ``columns``, otherwise an error
            will be raised.
        quantile_precision: Controls the accuracy and memory footprint of the sketch (K in KLL);
            higher values yield lower error but use more memory. Defaults to 800. See
            https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html
            for details on accuracy and size.
    """

    DEFAULT_QUANTILE_PRECISION = 800

    def __init__(
        self,
        columns: List[str],
        quantile_range: Tuple[float, float] = (0.25, 0.75),
        output_columns: Optional[List[str]] = None,
        quantile_precision: int = DEFAULT_QUANTILE_PRECISION,
    ):
        super().__init__()
        self.columns = columns
        self.quantile_range = quantile_range
        self.quantile_precision = quantile_precision

        self.output_columns = Preprocessor._derive_and_validate_output_columns(
            columns, output_columns
        )

    def _fit(self, dataset: "Dataset") -> Preprocessor:
        """Compute the low quantile, median and high quantile of every column.

        One ``ApproximateQuantile`` aggregator is run per column; its three
        results are stored in ``self.stats_`` under ``low_quantile(<col>)``,
        ``median(<col>)`` and ``high_quantile(<col>)``.

        Returns:
            This preprocessor.
        """
        # Order matters: the results are unpacked positionally below.
        quantiles = [
            self.quantile_range[0],
            0.50,
            self.quantile_range[1],
        ]
        aggregates = [
            ApproximateQuantile(
                on=col,
                quantiles=quantiles,
                quantile_precision=self.quantile_precision,
            )
            for col in self.columns
        ]
        aggregated = dataset.aggregate(*aggregates)

        self.stats_ = {}
        for col in self.columns:
            low_q, med_q, high_q = aggregated[f"approx_quantile({col})"]
            self.stats_[f"low_quantile({col})"] = low_q
            self.stats_[f"median({col})"] = med_q
            self.stats_[f"high_quantile({col})"] = high_q

        return self

    def _transform_pandas(self, df: pd.DataFrame):
        """Center ``self.columns`` on the fitted median and divide by the
        fitted quantile range, writing into ``self.output_columns``.

        Returns:
            The same ``df`` object with the output columns assigned.
        """

        def column_robust_scaler(s: pd.Series):
            s_low_q = self.stats_[f"low_quantile({s.name})"]
            s_median = self.stats_[f"median({s.name})"]
            s_high_q = self.stats_[f"high_quantile({s.name})"]
            diff = s_high_q - s_low_q

            # Handle division by zero.
            # Return all zeros.
            if diff == 0:
                return np.zeros_like(s)

            return (s - s_median) / diff

        df[self.output_columns] = df[self.columns].transform(column_robust_scaler)
        return df

    def _get_serializable_fields(self) -> Dict[str, Any]:
        """Return the configuration fields and the fitted flag for serialization.

        NOTE(review): ``stats_`` is not part of this dict; presumably it is
        persisted separately by the base class -- confirm.
        """
        return {
            "columns": self.columns,
            "output_columns": self.output_columns,
            "quantile_range": self.quantile_range,
            "quantile_precision": self.quantile_precision,
            "_fitted": getattr(self, "_fitted", None),
        }

    def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
        """Restore the attributes written by ``_get_serializable_fields``.

        ``version`` is accepted but not used by this implementation.
        """
        # required fields
        self.columns = fields["columns"]
        self.output_columns = fields["output_columns"]
        self.quantile_range = fields["quantile_range"]
        self.quantile_precision = fields["quantile_precision"]
        # optional fields
        self._fitted = fields.get("_fitted")

    def __repr__(self):
        # Single closing parenthesis at the very end keeps the repr balanced.
        return (
            f"{self.__class__.__name__}(columns={self.columns!r}, "
            f"quantile_range={self.quantile_range!r}, "
            f"output_columns={self.output_columns!r})"
        )
