import keras
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.distillation.distillation_loss import _convert_loss_to_function
from keras.src.models.model import Model
from keras.src.saving import serialization_lib


@keras_export("keras.distillation.Distiller")
class Distiller(Model):
    """Distillation model for transferring knowledge from teacher to student.

    Knowledge distillation transfers knowledge from a large, complex model
    (teacher) to a smaller, simpler model (student). The student learns
    from both ground truth labels and the teacher's predictions, often
    achieving better performance than training on labels alone.

    Arguments:
        teacher: A trained `keras.Model` that serves as the knowledge source.
            The teacher model is frozen during distillation.
        student: A `keras.Model` to be trained through distillation.
        distillation_losses: List of distillation losses to apply. Can be a
            single distillation loss or a list of distillation losses like
            `keras.distillation.LogitsDistillation`,
            `keras.distillation.FeatureDistillation`, or custom distillation
            losses.
        distillation_loss_weights: List of weights for each distillation loss.
            Must have the same length as `distillation_losses`. If `None`,
            equal weights are used.
        student_loss_weight: Weight for the student's supervised loss component.
            Must be between 0 and 1. Defaults to 0.5.
        name: Name for the distiller model. Defaults to `"distiller"`.
        **kwargs: Additional keyword arguments passed to the parent `Model`
            class.

    Attributes:
        student: The student model being trained. Access this to get the trained
            student model for independent use after distillation training.
        teacher: The teacher model providing knowledge. This model is frozen
            during training.

    Examples:

    ```python
    # Basic distillation with KerasHub models
    import keras_hub as hub

    teacher = hub.models.CausalLM.from_preset("gemma_2b_en")
    student = hub.models.CausalLM.from_preset(
        "gemma_1.1_2b_en", load_weights=False
    )

    # Single distillation loss
    distiller = Distiller(
        teacher=teacher,
        student=student,
        distillation_losses=LogitsDistillation(temperature=3.0),
    )

    # Compile the distiller (like any Keras model)
    distiller.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # Train the distiller
    distiller.fit(x_train, y_train, epochs=10)

    # Access the trained student model
    trained_student = distiller.student

    # Multiple distillation losses
    distiller = Distiller(
        teacher=teacher,
        student=student,
        distillation_losses=[
            LogitsDistillation(temperature=3.0),
            FeatureDistillation(
                teacher_layer_name="dense_1",
                student_layer_name="dense_1"
            )
        ],
        distillation_loss_weights=[1.0, 0.5],
    )

    # Compile with custom settings
    distiller.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    ```
    """

    def __init__(
        self,
        teacher,
        student,
        distillation_losses,
        distillation_loss_weights=None,
        student_loss_weight=0.5,
        name="distiller",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        # Validate inputs
        self._validate_models(teacher, student)

        # Store configuration
        self.teacher = teacher
        self.student = student

        # Validate student_loss_weight
        if not isinstance(student_loss_weight, (int, float)):
            raise ValueError(
                f"student_loss_weight must be a number, got "
                f"{type(student_loss_weight)}"
            )
        if student_loss_weight < 0.0 or student_loss_weight > 1.0:
            raise ValueError(
                f"student_loss_weight must be between 0.0 and 1.0, "
                f"got {student_loss_weight}"
            )
        self.student_loss_weight = student_loss_weight

        # Handle distillation losses configuration
        if distillation_losses is None:
            raise ValueError(
                "'distillation_losses' cannot be `None`. Provide a "
                "distillation loss (e.g., LogitsDistillation or "
                "FeatureDistillation) or a list of distillation losses."
            )

        # Convert single distillation loss to list for uniform handling
        if not isinstance(distillation_losses, (list, tuple)):
            self.distillation_losses = [distillation_losses]
            self.distillation_loss_weights = [1.0]
        else:
            self.distillation_losses = distillation_losses
            # Set default weights if not provided
            if distillation_loss_weights is None:
                self.distillation_loss_weights = [1.0] * len(
                    distillation_losses
                )
            else:
                if len(distillation_loss_weights) != len(distillation_losses):
                    raise ValueError(
                        f"Number of distillation_loss_weights "
                        f"({len(distillation_loss_weights)}) must match "
                        f"number of distillation_losses "
                        f"({len(distillation_losses)})"
                    )
                self.distillation_loss_weights = distillation_loss_weights

        # Validate distillation loss compatibility and create extractors
        for distillation_loss in self.distillation_losses:
            self._validate_distillation_loss_compatibility(
                teacher, student, distillation_loss
            )

        self._create_multi_feature_extractors()

        # Freeze teacher model
        self.teacher.trainable = False

        # Initialize loss tracking metrics
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.distillation_loss_tracker = keras.metrics.Mean(
            name="distillation_loss"
        )
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")

    def _validate_models(self, teacher, student):
        """Validate that teacher and student models are compatible."""
        if not isinstance(teacher, keras.Model):
            raise ValueError(
                f"Teacher must be a keras.Model, got {type(teacher)}"
            )
        if not isinstance(student, keras.Model):
            raise ValueError(
                f"Student must be a keras.Model, got {type(student)}"
            )

        self._validate_input_compatibility(teacher, student)
        self._validate_output_compatibility(teacher, student)
        self._validate_dtype_compatibility(teacher, student)

    def _assert_shapes_are_compatible(self, shape1, shape2, context):
        """Assert that two shapes are compatible."""
        if len(shape1) != len(shape2):
            raise ValueError(
                f"Teacher and student {context} shapes have different "
                f"dimensions. Teacher: {shape1}, Student: {shape2}."
            )

        for dim1, dim2 in zip(shape1, shape2):
            if dim1 is not None and dim2 is not None and dim1 != dim2:
                raise ValueError(
                    f"Teacher and student {context} shapes are incompatible. "
                    f"Teacher: {shape1}, Student: {shape2}. "
                    f"All dimensions must match."
                )

    def _assert_same_dtype(self, teacher_dtype, student_dtype, context):
        """Assert that teacher and student dtypes are the same."""
        if teacher_dtype != student_dtype:
            raise ValueError(
                f"Teacher and student {context} dtypes must match. "
                f"Teacher: {teacher_dtype}, Student: {student_dtype}."
            )

    def _validate_input_compatibility(self, teacher, student):
        """Validate that teacher and student have compatible input shapes."""
        if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
            return
        teacher_inputs = getattr(teacher, "inputs")
        student_inputs = getattr(student, "inputs")
        if teacher_inputs is None or student_inputs is None:
            return

        tree.map_structure(
            lambda ti, si: self._assert_shapes_are_compatible(
                ti.shape, si.shape, "input"
            ),
            teacher_inputs,
            student_inputs,
        )

    def _validate_output_compatibility(self, teacher, student):
        """Validate that teacher and student have compatible output shapes."""
        if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
            return
        teacher_outputs = getattr(teacher, "outputs")
        student_outputs = getattr(student, "outputs")
        if teacher_outputs is None or student_outputs is None:
            return

        tree.map_structure(
            lambda to, so: self._assert_shapes_are_compatible(
                to.shape, so.shape, "output"
            ),
            teacher_outputs,
            student_outputs,
        )

    def _validate_dtype_compatibility(self, teacher, student):
        """Validate that teacher and student have compatible data types."""
        if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
            return
        if teacher.inputs is None or student.inputs is None:
            return

        tree.map_structure(
            lambda ti, si: self._assert_same_dtype(ti.dtype, si.dtype, "input"),
            teacher.inputs,
            student.inputs,
        )

        if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
            return
        if teacher.outputs is None or student.outputs is None:
            return

        tree.map_structure(
            lambda to, so: self._assert_same_dtype(
                to.dtype, so.dtype, "output"
            ),
            teacher.outputs,
            student.outputs,
        )

    def _validate_distillation_loss_compatibility(
        self, teacher, student, distillation_loss
    ):
        """Validate that the distillation loss is compatible with teacher
        and student models."""
        distillation_loss.validate_model_compatibility(teacher, student)

    def _create_multi_feature_extractors(self):
        """Create feature extractors for efficient multi-layer extraction."""
        teacher_layer_names = []
        student_layer_names = []

        for distillation_loss in self.distillation_losses:
            if (
                hasattr(distillation_loss, "teacher_layer_name")
                and distillation_loss.teacher_layer_name
            ):
                if (
                    distillation_loss.teacher_layer_name
                    not in teacher_layer_names
                ):
                    teacher_layer_names.append(
                        distillation_loss.teacher_layer_name
                    )
            if (
                hasattr(distillation_loss, "student_layer_name")
                and distillation_loss.student_layer_name
            ):
                if (
                    distillation_loss.student_layer_name
                    not in student_layer_names
                ):
                    student_layer_names.append(
                        distillation_loss.student_layer_name
                    )

        self._teacher_feature_extractor = self._create_feature_extractor(
            self.teacher, teacher_layer_names
        )
        self._student_feature_extractor = self._create_feature_extractor(
            self.student, student_layer_names
        )

    def _create_feature_extractor(self, model, layer_names):
        """Create a feature extractor for a model.

        Arguments:
            model: The model to create an extractor for.
            layer_names: List of layer names to extract features from.

        Returns:
            Feature extractor model or `None` if no layer names provided.

        Raises:
            ValueError: If model has no symbolic inputs/outputs.
        """
        if not layer_names:
            return None

        if not hasattr(model, "inputs") or model.inputs is None:
            raise ValueError(
                f"Cannot create feature extractor for {model.name}. "
                f"The model has no symbolic inputs attribute."
            )

        if isinstance(model, keras.Sequential):
            final_output = model.layers[-1].output
        else:
            final_output = model.output

        outputs = {"final_output": final_output}
        for layer_name in layer_names:
            layer = model.get_layer(name=layer_name)
            outputs[layer_name] = layer.output

        return keras.Model(
            inputs=model.inputs,
            outputs=outputs,
            name=f"{model.name}_multi_feature_extractor",
        )

    def _extract_all_teacher_features(self, x):
        """Extract all teacher features in a single forward pass."""
        if self._teacher_feature_extractor is not None:
            return self._teacher_feature_extractor(x, training=False)
        else:
            return {"final_output": self.teacher(x, training=False)}

    def _extract_all_student_features(self, x, y_pred):
        """Extract all student features in a single forward pass."""
        if self._student_feature_extractor is not None:
            return self._student_feature_extractor(x, training=True)
        else:
            return {"final_output": y_pred}

    def _get_distillation_loss_features(
        self, distillation_loss, all_features, is_teacher
    ):
        """Get the specific features needed by a distillation loss."""
        if is_teacher:
            layer_name = distillation_loss.teacher_layer_name or "final_output"
        else:
            layer_name = distillation_loss.student_layer_name or "final_output"

        if layer_name not in all_features:
            raise ValueError(
                f"Layer '{layer_name}' not found in extracted features. "
                f"Available: {list(all_features.keys())}"
            )

        return all_features[layer_name]

    def compile(self, optimizer="adam", loss=None, metrics=None, **kwargs):
        """Compile the distiller with proper integration.

        Arguments:
            optimizer: Optimizer for training the student model.
            loss: Student loss function for the student's supervised learning.
                Can be a string identifier or a loss function instance.
            metrics: Additional metrics to track during training.
            **kwargs: Additional arguments passed to parent compile.
        """
        if loss is None:
            raise ValueError("'loss' cannot be `None`.")

        self._student_loss = tree.map_structure(_convert_loss_to_function, loss)
        self._student_loss_for_serialization = loss

        if metrics is not None and not isinstance(metrics, (list, tuple)):
            raise ValueError(
                f"metrics must be a list or tuple, got {type(metrics)}"
            )

        super().compile(
            optimizer=optimizer,
            loss=None,
            metrics=metrics,
            **kwargs,
        )

    def call(self, inputs, training=None, **kwargs):
        """Forward pass returns student predictions."""
        return self.student(inputs, training=training, **kwargs)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        """Compute combined distillation loss.

        Arguments:
            x: Input data.
            y: Target data.
            y_pred: Model predictions.
            sample_weight: Sample weights (currently unused).
            training: Whether the model is in training mode.

        Returns:
            Combined loss tensor.
        """
        # Handle case where y_pred is not provided
        if y_pred is None:
            y_pred = self(x, training=training)
        # Compute student loss
        student_loss = 0.0
        if self.student_loss_weight > 0.0 and y is not None:
            loss_values = tree.map_structure(
                lambda l, o, o_pred: l(o, o_pred),
                self._student_loss,
                y,
                y_pred,
            )
            flat_losses = tree.flatten(loss_values)
            student_loss = (
                keras.ops.sum(keras.ops.stack(flat_losses))
                if len(flat_losses) > 1
                else flat_losses[0]
            )

            # Ensure student_loss is a scalar
            if hasattr(student_loss, "shape") and len(student_loss.shape) > 0:
                student_loss = keras.ops.mean(student_loss)

        # Compute distillation loss
        distillation_loss = 0.0
        if self.student_loss_weight < 1.0:
            teacher_features = self._extract_all_teacher_features(x)
            student_features = self._extract_all_student_features(x, y_pred)

            # Apply distillation losses using pre-extracted features
            for distillation_loss_fn, weight in zip(
                self.distillation_losses, self.distillation_loss_weights
            ):
                # Get appropriate outputs/features for this distillation loss
                if (
                    hasattr(distillation_loss_fn, "teacher_layer_name")
                    and distillation_loss_fn.teacher_layer_name is not None
                ):
                    # FeatureDistillation with specific layers
                    try:
                        distillation_loss_teacher_output = (
                            self._get_distillation_loss_features(
                                distillation_loss_fn,
                                teacher_features,
                                is_teacher=True,
                            )
                        )
                        distillation_loss_student_output = (
                            self._get_distillation_loss_features(
                                distillation_loss_fn,
                                student_features,
                                is_teacher=False,
                            )
                        )
                    except ValueError as e:
                        # Re-raise with context about which loss failed
                        raise RuntimeError(
                            f"Failed to extract features for "
                            f"{type(distillation_loss_fn).__name__} "
                            f"targeting teacher layer "
                            f"'{distillation_loss_fn.teacher_layer_name}' "
                            f"and student layer "
                            f"'{distillation_loss_fn.student_layer_name}'. "
                            f"Original error: {e}"
                        ) from e
                else:
                    # LogitsDistillation or FeatureDistillation (final outputs)
                    distillation_loss_teacher_output = teacher_features[
                        "final_output"
                    ]
                    distillation_loss_student_output = y_pred

                # Validate outputs are compatible for this distillation loss
                distillation_loss_fn.validate_outputs(
                    distillation_loss_teacher_output,
                    distillation_loss_student_output,
                )

                # Compute loss for this distillation loss
                current_distillation_loss = distillation_loss_fn.compute_loss(
                    distillation_loss_teacher_output,
                    distillation_loss_student_output,
                )

                # Validate that distillation loss returns a scalar
                if (
                    hasattr(current_distillation_loss, "shape")
                    and len(current_distillation_loss.shape) > 0
                ):
                    raise ValueError(
                        f"Distillation loss "
                        f"{distillation_loss_fn.__class__.__name__} "
                        f"returned a non-scalar loss with shape "
                        f"{current_distillation_loss.shape}. "
                        f"The compute_loss method must return a scalar "
                        f"tensor."
                    )

                # Apply weight and add to total
                distillation_loss = keras.ops.add(
                    distillation_loss,
                    keras.ops.multiply(weight, current_distillation_loss),
                )

        # Combine losses
        total_loss = keras.ops.add(
            keras.ops.multiply(self.student_loss_weight, student_loss),
            keras.ops.multiply(
                keras.ops.subtract(1.0, self.student_loss_weight),
                distillation_loss,
            ),
        )

        # Update metrics
        self.student_loss_tracker.update_state(student_loss)
        self.distillation_loss_tracker.update_state(distillation_loss)
        self.total_loss_tracker.update_state(total_loss)

        return total_loss

    def reset_metrics(self):
        """Reset all metrics."""
        super().reset_metrics()
        self.student_loss_tracker.reset_state()
        self.distillation_loss_tracker.reset_state()
        self.total_loss_tracker.reset_state()

    def get_config(self):
        """Get configuration for serialization."""
        config = super().get_config()
        config.update(
            {
                "teacher": serialization_lib.serialize_keras_object(
                    self.teacher
                ),
                "student": serialization_lib.serialize_keras_object(
                    self.student
                ),
                "distillation_losses": [
                    serialization_lib.serialize_keras_object(distillation_loss)
                    for distillation_loss in self.distillation_losses
                ],
                "distillation_loss_weights": self.distillation_loss_weights,
                "student_loss_weight": self.student_loss_weight,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Create instance from configuration."""
        config = config.copy()

        # Deserialize objects
        config["teacher"] = serialization_lib.deserialize_keras_object(
            config["teacher"]
        )
        config["student"] = serialization_lib.deserialize_keras_object(
            config["student"]
        )
        config["distillation_losses"] = [
            serialization_lib.deserialize_keras_object(distillation_loss)
            for distillation_loss in config["distillation_losses"]
        ]

        return cls(**config)
