import hashlib

from fastapi import APIRouter, Depends
from loguru import logger
from sqlmodel import Session

from app.auth import get_current_user
from app.database import get_session
from app.models import PlanLimit, PricingTier
from app.repositories.pricing import PricingRepository
from app.schemas.auth import CurrentUser
from app.schemas.pricing import PricingResponse

# Router holding this module's pricing endpoint(s); presumably included into the
# FastAPI application elsewhere (not visible here) — confirm.
router: APIRouter = APIRouter()


def _hash_user_to_tier(user_id: str, tiers: list[PricingTier]) -> PricingTier:
    """Deterministically pick one of ``tiers`` for ``user_id``.

    The SHA-256 digest of the UTF-8 encoded user id is read as a big-endian
    integer and reduced modulo ``len(tiers)``. The same user id and the same
    ``tiers`` list therefore always map to the same tier. Note that the result
    depends on the *order* of ``tiers``.

    Args:
        user_id: Identifier to bucket.
        tiers: Candidate tiers, in a stable order.

    Returns:
        One element of ``tiers``.

    Raises:
        ValueError: If ``tiers`` is empty.
    """
    if not tiers:
        # Without this guard the modulo below fails with an opaque ZeroDivisionError.
        raise ValueError("cannot assign a pricing tier: no tiers available")
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    # Equal to int(hexdigest, 16), so existing user->tier mappings are unchanged.
    return tiers[int.from_bytes(digest, "big") % len(tiers)]


def _format_tier(tier: PricingTier, limits: PlanLimit | None = None) -> PricingResponse:
    """Render a pricing tier (and its optional limits) as a ``PricingResponse``.

    ``tier.product_id`` may carry a base plan id as ``"<productId>:<basePlanId>"``.
    It is split on the first ``":"`` only. When there is no ``":"``, the whole
    value becomes ``productId`` and ``basePlanId`` is ``""``.

    Args:
        tier: Tier row to render.
        limits: Usage limits for the tier. When ``None``, ``maxImages`` and
            ``maxVideos`` are ``None``.

    Returns:
        The API response model for ``tier``.
    """
    # partition() splits on the first ":" only, so a product id containing extra
    # ":" characters no longer breaks the two-name unpacking.
    sub_id, _, base_plan = tier.product_id.partition(":")
    return PricingResponse(
        tierId=tier.tier_id,
        name=tier.name,
        price=tier.price,
        period=tier.period,
        trialDays=tier.trial_days,
        productId=sub_id,
        basePlanId=base_plan,
        maxImages=limits.max_images if limits is not None else None,
        maxVideos=limits.max_videos if limits is not None else None,
    )


@router.get("/pricing", response_model=PricingResponse)
async def get_pricing(
    current_user: CurrentUser = Depends(get_current_user),
    session: Session = Depends(get_session),
):
    """Get pricing tier for authenticated user. Auto-assigns via hash if not already assigned."""
    # Resolution order:
    #   1. the user's stored tier assignment, if it still points at an existing tier;
    #   2. otherwise a tier picked by hashing the user id, which is then persisted;
    #   3. if no tiers exist at all, a hard-coded default response (not persisted).
    #
    # NOTE(review): the repository calls below are synchronous (not awaited) inside
    # an ``async def``, so they presumably block the event loop during DB I/O.
    # Declaring this ``def`` would let FastAPI run it in its threadpool; kept async
    # here to preserve the coroutine interface — confirm with callers.
    user_id = current_user.sub
    repo = PricingRepository(session)

    user_pricing = repo.get_user_tier(user_id)
    if user_pricing:
        tier = repo.get_tier(user_pricing.tier_id)
        if tier:
            return _format_tier(tier, repo.get_limits(tier.tier_id))
        # The stored assignment references a tier that no longer exists. Fall
        # through to hash re-assignment, but make the repair visible in the logs.
        # NOTE(review): assumes repo.assign_tier overwrites an existing row — confirm.
        logger.warning(
            f"PRICING_STALE_ASSIGNMENT: user={user_id} missing_tier={user_pricing.tier_id}"
        )

    all_tiers = repo.get_all_tiers()
    if not all_tiers:
        # No tiers configured: serve a fixed default so the client still gets a response.
        logger.warning("PRICING_NO_TIERS: no pricing tiers configured; returning default tier")
        return PricingResponse(
            tierId="default", name="Monthly", price="9",
            period="month", trialDays=0,
            productId="maya_pro", basePlanId="mayapro2",
        )

    chosen = _hash_user_to_tier(user_id, all_tiers)
    repo.assign_tier(user_id, chosen.tier_id)
    limits = repo.get_limits(chosen.tier_id)

    logger.info(f"PRICING_HASH_ASSIGNED: user={user_id} tier={chosen.tier_id} price=${chosen.price}")
    return _format_tier(chosen, limits)
