import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """
    LoRA 线性层：直接持有 weight/bias，保持与 nn.Linear 相同的 state_dict key 结构。

    state_dict 结构：
        - weight: 原始权重（与 nn.Linear 一致）
        - bias: 原始偏置（与 nn.Linear 一致）
        - lora_A: LoRA 低秩矩阵 A
        - lora_B: LoRA 低秩矩阵 B

    这样设计的好处：加载预训练权重时无需做 key 转换。
    """

    def __init__(
        self,
        base: nn.Linear,
        r: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."

        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.alpha = alpha
        self._base_scaling = alpha / r if r > 0 else 0.0
        
        # 使用 buffer 存储 scaling，这样修改值不会触发 torch.compile 重编译
        # persistent=False 表示不保存到 state_dict，避免加载时 missing key
        self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)

        # 直接持有 weight 和 bias（从原始 Linear 转移过来）
        self.weight = base.weight
        self.bias = base.bias  # 可能是 None

        # LoRA 参数
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
            self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        else:
            self.register_parameter("lora_A", None)
            self.register_parameter("lora_B", None)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 基础 Linear 计算
        result = F.linear(x, self.weight, self.bias)
        if self.r <= 0 or self.lora_A is None:
            return result
        # LoRA: result + dropout(x @ A^T @ B^T) * scaling
        lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
        return result + self.dropout(lora_out) * self.scaling

    def reset_lora_parameters(self):
        """重置 LoRA 参数到初始状态"""
        if self.r > 0 and self.lora_A is not None:
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def set_enabled(self, enabled: bool):
        """启用/禁用 LoRA（通过 scaling 控制，兼容 torch.compile）"""
        # 使用 fill_ 原地修改 buffer 值，不会触发重编译
        self.scaling.fill_(self._base_scaling if enabled else 0.0)

    @property
    def enabled(self) -> bool:
        return self.scaling.item() != 0.0


def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
    """
    根据类似 'layers.0.self_attn.q_proj' 的全名，返回 parent module（即 q_proj 的上一级）。
    """
    parts = name.split(".")
    if len(parts) == 1:
        return root
    parent = root
    for p in parts[:-1]:
        if not hasattr(parent, p):
            return None
        parent = getattr(parent, p)
    return parent


def apply_lora_to_named_linear_modules(
    root: nn.Module,
    *,
    target_submodule_names: list[str],
    r: int,
    alpha: float,
    dropout: float,
) -> None:
    """
    在给定模块及其子模块中，对名字以 target_submodule_names 结尾的 Linear 层注入 LoRA。

    例如 target_submodule_names=["q_proj", "v_proj"] 时，
    会在所有名为 *.q_proj / *.v_proj 的 nn.Linear 上替换为 LoRALinear。
    """
    for full_name, module in list(root.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        short_name = full_name.split(".")[-1]
        if short_name not in target_submodule_names:
            continue

        parent = _get_parent_module(root, full_name)
        if parent is None:
            continue

        # 用 LoRALinear 替换原始 Linear
        lora_layer = LoRALinear(
            base=module,
            r=r,
            alpha=alpha,
            dropout=dropout,
        )
        setattr(parent, short_name, lora_layer)



