from typing import Optional

from sqlmodel import Session, select

from app.models import PlanLimit, PricingTier, UserPricing


class PricingRepository:
    """Data-access helpers for pricing tiers, plan limits and user tier assignments.

    All queries run on the ``Session`` supplied at construction time; write
    methods commit on that session.
    """

    # Subscription statuses that count as "active" for usage tracking.
    ACTIVE_SUBSCRIPTION_STATUSES: tuple[str, ...] = (
        "active",
        "grace_period",
        "pending_verification",
    )

    def __init__(self, session: Session):
        self.session = session

    def get_user_tier(self, user_id: str) -> Optional[UserPricing]:
        """Return the ``UserPricing`` row for ``user_id``, or ``None`` if absent."""
        return self.session.exec(
            select(UserPricing).where(UserPricing.user_id == user_id)
        ).first()

    def get_tier(self, tier_id: str) -> Optional[PricingTier]:
        """Return the ``PricingTier`` with the given ``tier_id``, or ``None``."""
        return self.session.exec(
            select(PricingTier).where(PricingTier.tier_id == tier_id)
        ).first()

    def get_all_tiers(self) -> list[PricingTier]:
        """Return every ``PricingTier``, ordered by ``tier_id`` ascending."""
        return list(self.session.exec(
            select(PricingTier).order_by(PricingTier.tier_id)  # type: ignore
        ).all())

    def get_limits(self, tier_id: str) -> Optional[PlanLimit]:
        """Return the ``PlanLimit`` for ``tier_id``, or ``None``.

        ``Session.get`` looks up by primary key, so this relies on ``tier_id``
        being ``PlanLimit``'s primary key.
        """
        return self.session.get(PlanLimit, tier_id)

    def assign_tier(self, user_id: str, tier_id: str) -> UserPricing:
        """Assign ``tier_id`` to ``user_id``, creating or updating the assignment.

        If the user already has a ``UserPricing`` row, its ``tier_id`` is
        updated in place. Blindly inserting a second row would either violate
        the key or leave duplicate assignments, and ``get_user_tier`` would
        then pick one arbitrarily. The change is committed immediately.

        Returns:
            The persisted ``UserPricing`` instance, refreshed from the database.
        """
        up = self.get_user_tier(user_id)
        if up is None:
            up = UserPricing(user_id=user_id, tier_id=tier_id)
        else:
            up.tier_id = tier_id
        self.session.add(up)
        self.session.commit()
        # commit() expires loaded attributes. Reload now so callers can read
        # the returned object even after the session is closed.
        self.session.refresh(up)
        return up

    def get_user_tier_id_if_active(self, user_id: str) -> Optional[str]:
        """Get tier_id only if user has an active subscription. Used by usage tracking.

        A subscription counts as active when its status is one of
        ``ACTIVE_SUBSCRIPTION_STATUSES``. When several matching subscriptions
        exist, the row with the latest ``expiry_date`` is used.

        Returns:
            The user's ``tier_id``, or ``None`` if there is no matching
            subscription (a falsy ``tier_id`` is also reported as ``None``).
        """
        # Local import kept as in the original; the reason it is not imported
        # at module level is not visible here.
        from app.models import Subscription
        row = self.session.exec(
            select(UserPricing.tier_id)
            .join(Subscription, Subscription.user_id == UserPricing.user_id)
            .where(
                UserPricing.user_id == user_id,
                Subscription.status.in_(self.ACTIVE_SUBSCRIPTION_STATUSES),  # type: ignore
            )
            .order_by(Subscription.expiry_date.desc())  # type: ignore
        ).first()
        return row or None
