# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Journaling system for ONNX IR operations."""

from __future__ import annotations

import weakref
from typing import Any

__all__ = ["Journal", "JournalEntry", "get_current_journal"]

import dataclasses
import datetime
import time
import traceback
from collections.abc import Callable, Sequence

from typing_extensions import Self

from onnx_ir.journaling import _wrappers

_current_journal: Journal | None = None


@dataclasses.dataclass(frozen=True)
class JournalEntry:
    """A single journal entry recording an operation on the IR.

    Attributes:
        timestamp: The time at which the operation was performed.
        operation: The name of the operation performed.
        class_: The class of the object on which the operation was performed.
        class_name: The name of the class of the object.
        ref: A weak reference to the object on which the operation was performed.
            To access the object, call ``ref()``. Note that ``ref`` may be ``None``,
            and ``ref()`` may return ``None`` if the object has been garbage-collected.
        obj: The referenced object, or None if it has been garbage-collected or not recorded.
            This is the same as calling ``entry.ref() if entry.ref is not None else None``.
        object_id: The unique identifier of the object (id()).
        stack_trace: The stack trace at the time of the operation.
        details: Additional details about the operation.
    """

    timestamp: float
    operation: str
    class_: type
    class_name: str
    ref: weakref.ref | None
    object_id: int
    stack_trace: list[traceback.FrameSummary]
    details: str | None

    @property
    def obj(self) -> Any | None:
        """Get the referenced object, or None if it has been garbage-collected or not recorded."""
        if self.ref is None:
            return None
        return self.ref()

    def display(self) -> None:
        """Display the journal entry in a detailed multi-line format."""
        # Header with timestamp
        timestamp = datetime.datetime.fromtimestamp(self.timestamp).strftime(
            "%Y-%m-%d %H:%M:%S.%f"
        )
        print(f"\033[1m{'=' * 80}\033[0m")
        print(f"\033[1mTimestamp:\033[0m {timestamp}")
        print(f"\033[1mOperation:\033[0m {self.operation}")
        print(f"\033[1mClass:\033[0m {self.class_name} (id={self.object_id})")

        # Object representation
        if self.ref is None:
            object_repr = "<no ref>"
        elif (obj := self.ref()) is not None:
            object_repr = repr(obj)
        else:
            object_repr = "<deleted>"
        print("\033[1mObject:\033[0m")
        for line in object_repr.split("\n"):
            print(f"  {line}")

        # Details
        if self.details:
            print("\033[1mDetails:\033[0m")
            for line in self.details.split("\n"):
                print(f"  {line}")

        # Stack trace - find user code frame
        if self.stack_trace:
            user_frame = None
            for f in reversed(self.stack_trace):
                filename = f.filename.replace("\\", "/")
                if "onnx_ir" not in filename or "onnx_ir/passes" in filename:
                    user_frame = f
                    break

            print("\033[1mUser Code Location:\033[0m")
            if user_frame is not None:
                print(
                    f"  \033[90m{user_frame.filename}:{user_frame.lineno} in {user_frame.name}\033[0m"
                )
                if user_frame.line:
                    print(f"  \033[90m>>> {user_frame.line}\033[0m")
            else:
                print("  \033[90m<unknown>\033[0m")

            print("\033[1mFull Stack Trace:\033[0m")
            for f in self.stack_trace:
                print(f"  \033[90m{f.filename}:{f.lineno} in {f.name}\033[0m")
                if f.line:
                    print(f"    \033[90m{f.line}\033[0m")
        print(f"\033[1m{'=' * 80}\033[0m")


def get_current_journal() -> Journal | None:
    """Get the current journal, if any."""
    return _current_journal


def _get_stack_trace() -> list[traceback.FrameSummary]:
    return traceback.extract_stack()[:-3]


class Journal:
    """A journaling system to record operations on the ONNX IR.

    It can be used as a context manager to automatically record operations within a block.

    Example::

        from onnx_ir.journaling import Journal

        journal = Journal()

        with Journal() as journal:
            # Perform operations on the ONNX IR
            pass

        for entry in journal.entries:
            print(f"Operation: {entry.operation} on {entry.class_name}")


    You can also filter the entries by operation or class name using the `filter` method::

        filtered_entries = [
            entry for entry in journal.entries
            if entry.operation == "set_name" and entry.class_name == "Node"
        ]
    """

    def __init__(self) -> None:
        self._entries: list[JournalEntry] = []
        self._previous_journal: Journal | None = None
        self._hooks: list[Callable[[JournalEntry], None]] = []
        self._original_methods: dict[str, Callable] = {}

    def __enter__(self) -> Self:
        global _current_journal
        self._previous_journal = _current_journal
        _current_journal = self
        self._original_methods = _wrappers.wrap_ir_classes(self)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        _wrappers.restore_ir_classes(self._original_methods)
        global _current_journal
        _current_journal = self._previous_journal

    @property
    def entries(self) -> Sequence[JournalEntry]:
        """Get all recorded journal entries."""
        return self._entries

    def record(self, obj: Any, operation: str, details: str | None = None) -> None:
        """Record a new journal entry."""
        entry = JournalEntry(
            timestamp=time.time(),
            operation=operation,
            class_=obj.__class__,
            class_name=obj.__class__.__name__,
            ref=weakref.ref(obj) if obj is not None else None,
            object_id=id(obj),
            stack_trace=_get_stack_trace(),
            details=details,
        )
        self._entries.append(entry)
        for hook in self._hooks:
            hook(entry)

    def add_hook(self, hook: Callable[[JournalEntry], None]) -> None:
        """Add a hook that will be called whenever a new journal entry is recorded."""
        self._hooks.append(hook)

    def clear_hooks(self) -> None:
        """Clear all hooks."""
        self._hooks.clear()

    def display(self) -> None:
        """Display all journal entries."""
        for entry in self._entries:
            details = f" [{entry.details}]" if entry.details else ""
            timestamp = datetime.datetime.fromtimestamp(entry.timestamp).strftime(
                "%Y-%m-%d %H:%M:%S.%f"
            )
            if entry.stack_trace:
                # Find the first frame that is not from internal onnx_ir modules
                frame = None
                for f in reversed(entry.stack_trace):
                    # Normalize path separators for cross-platform compatibility
                    filename = f.filename.replace("\\", "/")
                    if "onnx_ir" not in filename or "onnx_ir/passes" in filename:
                        frame = f
                        break
                if frame is not None:
                    location = f"{frame.filename}:{frame.lineno} in {frame.name}"
                else:
                    location = "<unknown>"
            else:
                location = "<unknown>"
            print()
            print(f"[{timestamp}] \033[90m{location}\033[0m")
            if entry.ref is None:
                object_repr = "<no ref>"
            elif (obj := entry.ref()) is not None:
                object_repr = repr(obj).replace("\n", "\\n")
                if len(object_repr) > 100:
                    object_repr = object_repr[:95] + "[...]"
            else:
                object_repr = "<deleted>"
            details_text = details.replace("\n", "\\n")
            if len(details_text) > 100:
                details_text = details_text[:95] + "[...]"
            print(
                f"Class: {entry.class_name}(id={entry.object_id}). Operation: {entry.operation}. Object: {object_repr}."
            )
            if details:
                print(f"\033[90mDetails: {details_text}\033[0m")
