# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dynamo roundtrip: graph break scenarios and resume function roundtripping.

When Dynamo encounters an unsupported operation (e.g. ``print()``), it
performs a *graph break*: it splits the function into compiled graph segments
connected by resume functions (``__resume_at_XX``).  The transformed bytecode
and the resume functions are all Dynamo-generated code objects that must
roundtrip correctly.

Gradient computation is disabled globally by the ``_no_grad`` autouse fixture
in ``conftest.py``, rather than via ``with torch.no_grad()`` inside the
compiled function.  ``with`` blocks produce ``BEFORE_WITH`` / ``POP_EXCEPT``
exception-handling bytecode that Dynamo copies verbatim into resume function
code objects; the decompiler cannot process those opcodes.
"""

import torch
import torch.nn as nn

from magi_compiler.magi_depyf import decompile
from magi_compiler.magi_depyf.decompile.recompiler import CodeRecompiler
from tests.magi_depyf.decompile.dynamo_roundtrip import helpers
from tests.magi_depyf.decompile.dynamo_roundtrip.helpers import _assert_close, _reset, get_cache_entries, roundtrip_and_verify

# ---------------------------------------------------------------------------
# Helper: recursively roundtrip all code objects (main + resume functions)
# ---------------------------------------------------------------------------


def _collect_resume_replacements(transformed_code, fn_globals, replacements, visited):
    """Gather roundtripped code for every resume function reachable from
    *transformed_code*.

    Each name in ``transformed_code.co_names`` that starts with ``__resume``
    is looked up in *fn_globals*. For each resolved function-like object not
    already seen, its **original** ``__code__`` is decompiled (must yield
    non-empty source) and recompiled (must not yield ``None``). The triple
    ``(fn, old_code, new_code)`` is then appended to *replacements*.

    The transformed code of that resume function's first cache entry may
    itself mention further resume functions, so the walk recurses into it.
    *visited* holds ``id()`` values of already-processed functions and is
    updated in place, so each function is handled at most once.
    """
    candidate_names = (n for n in transformed_code.co_names if n.startswith("__resume"))
    for resume_name in candidate_names:
        target = fn_globals.get(resume_name)
        # Skip unbound names, objects without bytecode, and functions already handled.
        if target is None or not hasattr(target, "__code__") or id(target) in visited:
            continue
        visited.add(id(target))

        original = target.__code__
        decompiled_src = decompile(original)
        assert decompiled_src, f"Empty decompilation for resume fn {original.co_name}"
        rebuilt = CodeRecompiler.recompile(code_to_decompile=original, reference_code=original)
        assert rebuilt is not None, f"Recompile failed for {original.co_name}"
        replacements.append((target, original, rebuilt))

        # Follow the chain: this resume fn's own transformed code may call more resume fns.
        nested_entries = get_cache_entries(target)
        if nested_entries:
            _collect_resume_replacements(nested_entries[0].code, fn_globals, replacements, visited)


def _roundtrip_all_entries(fn, inputs, backend="eager", atol=1e-5, **compile_kw):
    """Compile → decompile → recompile → replace code **recursively**
    (top-level fn + all resume functions) → re-compile and verify output.

    Top-level fn: replaced with decompile(TRANSFORMED code from cache entry).
    Resume functions: replaced with decompile(ORIGINAL code) — the
    Dynamo-generated bytecode that doesn't reference one-shot compiled fns.

    Args:
        fn: Function to compile. Its ``__code__``, and that of every
            discovered resume function, is swapped temporarily. All of them
            are restored before returning, even if a swap or the
            verification fails.
        inputs: Positional arguments passed to both compiled invocations.
        backend: ``torch.compile`` backend used for both compilations.
        atol: Absolute tolerance forwarded to ``_assert_close``.
        **compile_kw: Extra keyword arguments forwarded to ``torch.compile``.

    Raises:
        AssertionError: if no cache entry exists for *fn* or the top-level
            recompile returns ``None``. ``_collect_resume_replacements`` and
            ``_assert_close`` may raise as well.
    """
    _reset()
    torch.manual_seed(42)
    compiled = torch.compile(fn, backend=backend, **compile_kw)
    expected = compiled(*inputs)

    entries = get_cache_entries(fn)
    assert len(entries) >= 1, f"No cache entries for {fn.__code__.co_name}"

    transformed = entries[0].code
    top_recompiled = CodeRecompiler.recompile(code_to_decompile=transformed, reference_code=transformed)
    # Fail with a clear message rather than a TypeError on ``__code__`` assignment.
    assert top_recompiled is not None, f"Recompile failed for {transformed.co_name}"

    resume_replacements = []
    _collect_resume_replacements(transformed, fn.__globals__, resume_replacements, visited=set())

    old_code = fn.__code__
    try:
        # Perform the swaps inside ``try`` so that a failing assignment (e.g.
        # CPython rejecting a code object with a different free-var count)
        # still restores every function in ``finally``.
        fn.__code__ = top_recompiled
        for resume_fn, _, new_code in resume_replacements:
            resume_fn.__code__ = new_code
        _reset()
        torch.manual_seed(42)
        compiled2 = torch.compile(fn, backend=backend, **compile_kw)
        actual = compiled2(*inputs)
        _assert_close(actual, expected, atol=atol)
    finally:
        fn.__code__ = old_code
        for resume_fn, orig_code, _ in resume_replacements:
            resume_fn.__code__ = orig_code


class TestGraphBreak:
    """Scenarios where Dynamo splits a function at a graph break, producing
    resume functions that must also survive the roundtrip."""

    @staticmethod
    def _install_global_linear(in_features, out_features):
        """Create an eval-mode ``nn.Linear`` and publish it as
        ``helpers.GLOBAL_MODULE`` for the compiled functions to call."""
        linear = nn.Linear(in_features, out_features)
        linear.eval()
        helpers.GLOBAL_MODULE = linear

    def test_print_graph_break(self):
        """A single print() splits the function and yields __resume_at functions."""
        self._install_global_linear(32, 16)

        def fn(x):
            y = helpers.GLOBAL_MODULE(x)
            print("shape:", y.shape)
            return y * 2

        sample = torch.randn(2, 32)
        _roundtrip_all_entries(fn, (sample,))

    def test_multi_graph_break(self):
        """Two print() calls around two invocations of the same module."""
        self._install_global_linear(32, 32)

        def fn(x):
            y = helpers.GLOBAL_MODULE(x)
            print("after first:", y.shape)
            z = helpers.GLOBAL_MODULE(y)
            print("after second:", z.shape)
            return z

        sample = torch.randn(2, 32)
        _roundtrip_all_entries(fn, (sample,))

    def test_explicit_graph_break(self):
        """An explicit torch._dynamo.graph_break() call forces the split."""
        self._install_global_linear(16, 16)

        def fn(x):
            y = helpers.GLOBAL_MODULE(x)
            torch._dynamo.graph_break()
            return y + 1

        sample = torch.randn(2, 16)
        _roundtrip_all_entries(fn, (sample,))

    def test_conditional_specialization(self):
        """Branch on a plain Python bool: Dynamo specializes on it, so no graph break occurs."""
        self._install_global_linear(16, 16)

        def fn(x, flag):
            y = helpers.GLOBAL_MODULE(x)
            if flag:
                return y + 1
            return y - 1

        sample = torch.randn(2, 16)
        roundtrip_and_verify(fn, (sample, True))

    def test_resume_function_recursive_roundtrip(self):
        """Every resume function in a chain of graph breaks is decompiled,
        recompiled, and the whole tree re-executes with matching output."""
        self._install_global_linear(16, 16)

        def fn(x):
            y = helpers.GLOBAL_MODULE(x)
            print("break1")
            z = y * 2
            print("break2")
            return z + 1

        sample = torch.randn(2, 16)
        _roundtrip_all_entries(fn, (sample,))
