from typing import Optional

from sqlmodel import Session, func, select

from app.models import TopupPlan, TopupPurchase


class TopupRepository:
    """Data access for top-up plans and top-up credit purchases.

    Wraps a SQLModel ``Session``. The write methods (``create_purchase`` and
    ``deduct_credits``) commit the session themselves.
    """

    def __init__(self, session: Session):
        self.session = session

    def get_active_plans(self) -> list[TopupPlan]:
        """Return every plan whose ``is_active`` flag equals 1."""
        return list(self.session.exec(
            select(TopupPlan).where(TopupPlan.is_active == 1)
        ).all())

    def get_plan_by_product(self, product_id: str) -> Optional[TopupPlan]:
        """Return the active plan for ``product_id``, or ``None`` if there is none."""
        return self.session.exec(
            select(TopupPlan)
            .where(TopupPlan.product_id == product_id, TopupPlan.is_active == 1)
        ).first()

    def get_purchase_by_token(self, purchase_token: str) -> Optional[TopupPurchase]:
        """Return the purchase recorded with ``purchase_token``, or ``None``."""
        return self.session.exec(
            select(TopupPurchase).where(TopupPurchase.purchase_token == purchase_token)
        ).first()

    def create_purchase(
        self,
        user_id: str,
        topup_plan_id: str,
        order_id: str,
        purchase_token: str,
        images: int,
        videos: int,
    ) -> TopupPurchase:
        """Record a new ``"active"`` purchase with full credit balances and commit it.

        Args:
            user_id: Owner of the purchase.
            topup_plan_id: Plan that was bought.
            order_id: Store order identifier.
            purchase_token: Store purchase token.
            images: Initial image credit balance.
            videos: Initial video credit balance.

        Returns:
            The persisted ``TopupPurchase``, refreshed after the commit so its
            attributes (including any columns populated by the database) are
            loaded.
        """
        purchase = TopupPurchase(
            user_id=user_id,
            topup_plan_id=topup_plan_id,
            order_id=order_id,
            purchase_token=purchase_token,
            images_remaining=images,
            videos_remaining=videos,
            status="active",
        )
        self.session.add(purchase)
        self.session.commit()
        # By default commit() expires the instance. Reload it now instead of
        # lazily on first attribute access, which fails once the session is closed.
        self.session.refresh(purchase)
        return purchase

    def get_credits(self, user_id: str) -> tuple[int, int]:
        """Sum remaining credits over the user's ``"active"`` purchases.

        Returns:
            ``(images, videos)``; ``(0, 0)`` when the user has no active purchases.
        """
        row = self.session.exec(
            select(
                func.coalesce(func.sum(TopupPurchase.images_remaining), 0),
                func.coalesce(func.sum(TopupPurchase.videos_remaining), 0),
            )
            .where(TopupPurchase.user_id == user_id, TopupPurchase.status == "active")
        ).first()
        if row is None:
            return (0, 0)
        # SUM() may come back as a non-int numeric type on some backends
        # (e.g. Decimal); coerce to honor the declared tuple[int, int].
        return (int(row[0]), int(row[1]))

    def deduct_credits(self, user_id: str, images_used: int, videos_used: int) -> None:
        """Deduct top-up credits, draining the oldest purchases first.

        Image and video credits are drained independently, each ordered by
        ``purchased_at``. Non-positive amounts are ignored. Any amount that
        exceeds the user's remaining credits is silently dropped, so callers
        that care should check ``get_credits`` first. Active purchases left with
        no image and no video credits are then marked ``"used"``. Commits the
        session.
        """
        # NOTE(review): rows are read without locking, so concurrent deductions
        # for the same user could both see the same balances. Consider
        # ``.with_for_update()`` if this can run concurrently.
        self._drain(user_id, "images_remaining", images_used)
        self._drain(user_id, "videos_remaining", videos_used)

        # Push the in-memory decrements to the database before querying. This
        # lets the exhaustion check below see them even if the session was
        # created with autoflush disabled.
        self.session.flush()

        exhausted = self.session.exec(
            select(TopupPurchase)
            .where(
                TopupPurchase.user_id == user_id,
                TopupPurchase.status == "active",
                TopupPurchase.images_remaining <= 0,
                TopupPurchase.videos_remaining <= 0,
            )
        ).all()
        for purchase in exhausted:
            purchase.status = "used"

        self.session.commit()

    def _drain(self, user_id: str, field: str, amount: int) -> None:
        """Subtract ``amount`` from ``field`` across active purchases, oldest first.

        ``field`` names a counter column on ``TopupPurchase`` (for example
        ``"images_remaining"``). The session is not flushed or committed here.
        """
        if amount <= 0:
            return
        column = getattr(TopupPurchase, field)
        purchases = self.session.exec(
            select(TopupPurchase)
            .where(
                TopupPurchase.user_id == user_id,
                TopupPurchase.status == "active",
                column > 0,
            )
            .order_by(TopupPurchase.purchased_at)  # type: ignore
        ).all()
        remaining = amount
        for purchase in purchases:
            if remaining <= 0:
                break
            available = getattr(purchase, field)
            taken = min(remaining, available)
            setattr(purchase, field, available - taken)
            remaining -= taken
