"""Usage tracking service: load usage state and deduct topup credits."""

from datetime import datetime, timedelta, timezone
from typing import Dict

from loguru import logger
from sqlmodel import Session

from app.repositories.pricing import PricingRepository
from app.repositories.session import SessionRepository
from app.repositories.subscription import SubscriptionRepository
from app.repositories.topup import TopupRepository


class UsageService:
    """Compute per-user usage quotas and consume topup credits.

    Wraps the pricing, session, subscription and topup repositories, all bound
    to the same database ``Session`` passed to the constructor.
    """

    # Defaults used when no tier limits can be loaded. Kept in one place so the
    # normal path and the error fallback cannot drift apart.
    _FALLBACK_MAX_IMAGES = 5
    _FALLBACK_MAX_VIDEOS = 0
    _FALLBACK_MAX_PPTS = 0
    _DEFAULT_WINDOW_HOURS = 24

    def __init__(self, session: Session):
        self.pricing = PricingRepository(session)
        self.sessions = SessionRepository(session)
        self.subscriptions = SubscriptionRepository(session)
        self.topups = TopupRepository(session)

    def get_usage_state(self, user_id: str) -> Dict:
        """Load full usage state for a user.

        Called at bot session start and by the /usage endpoint.

        Premium users whose tier limits have ``window_hours=None`` are metered
        over their active subscription period; everyone else is metered over a
        rolling window of ``window_hours`` hours (24 if unset). Topup credits
        are only reported for premium users.

        Args:
            user_id: The user whose usage state to load.

        Returns:
            A dict with keys ``is_premium``, ``tier_id``,
            ``plan_images_remaining``, ``plan_videos_remaining``,
            ``plan_ppts_remaining``, ``topup_images_remaining`` and
            ``topup_videos_remaining``. If any step raises, the error is logged
            with its traceback and a hard-coded free-tier state is returned.
        """
        try:
            tier_id = self.pricing.get_user_tier_id_if_active(user_id)
            is_premium = tier_id is not None
            effective_tier = tier_id or "free"

            limits = self.pricing.get_limits(effective_tier)
            if not limits:
                # Tier has no limits row: fall back to the free tier's limits.
                limits = self.pricing.get_limits("free")

            max_images = limits.max_images if limits else self._FALLBACK_MAX_IMAGES
            max_videos = limits.max_videos if limits else self._FALLBACK_MAX_VIDEOS
            max_ppts = limits.max_ppts if limits else self._FALLBACK_MAX_PPTS
            window_hours = limits.window_hours if limits else self._DEFAULT_WINDOW_HOURS

            images_used, videos_used, ppts_used = self._count_usage(user_id, is_premium, window_hours)

            if is_premium:
                topup_images, topup_videos = self.topups.get_credits(user_id)
            else:
                topup_images = topup_videos = 0

            plan_images_remaining = max(0, max_images - images_used)
            plan_videos_remaining = max(0, max_videos - videos_used)
            plan_ppts_remaining = max(0, max_ppts - ppts_used)

            result = {
                "is_premium": is_premium,
                "tier_id": effective_tier,
                "plan_images_remaining": plan_images_remaining,
                "plan_videos_remaining": plan_videos_remaining,
                "plan_ppts_remaining": plan_ppts_remaining,
                "topup_images_remaining": topup_images,
                "topup_videos_remaining": topup_videos,
            }
            logger.info(
                f"USAGE_LOADED: user_id={user_id} tier={effective_tier} "
                f"plan_img={plan_images_remaining}/{max_images} plan_vid={plan_videos_remaining}/{max_videos} "
                f"plan_ppt={plan_ppts_remaining}/{max_ppts} "
                f"topup_img={topup_images} topup_vid={topup_videos}"
            )
            return result

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"USAGE_LOAD_FAILED: user_id={user_id} error={e}")
            return self._fallback_state()

    def _count_usage(self, user_id: str, is_premium: bool, window_hours):
        """Return ``(images_used, videos_used, ppts_used)`` for the user's metering period."""
        if is_premium and window_hours is None:
            # Premium tiers without a rolling window are metered per subscription period.
            sub = self.subscriptions.get_active(user_id)
            if not (sub and sub.created_at and sub.expiry_date):
                return 0, 0, 0
            start, end = sub.created_at, sub.expiry_date
            return (
                self.sessions.count_images_in_period(user_id, start, end),
                self.sessions.count_videos_in_period(user_id, start, end),
                self.sessions.count_ppts_in_period(user_id, start, end),
            )

        if window_hours is None:
            # A non-premium user has no subscription period to meter against;
            # without this default, timedelta(hours=None) raises TypeError and
            # the caller would hand out the fallback state regardless of usage.
            window_hours = self._DEFAULT_WINDOW_HOURS
        cutoff = (datetime.now(timezone.utc) - timedelta(hours=window_hours)).strftime("%Y-%m-%d %H:%M:%S")
        return (
            self.sessions.count_images_since(user_id, cutoff),
            self.sessions.count_videos_since(user_id, cutoff),
            self.sessions.count_ppts_since(user_id, cutoff),
        )

    def _fallback_state(self) -> Dict:
        """Return a fresh free-tier usage state used when loading fails."""
        return {
            "is_premium": False,
            "tier_id": "free",
            "plan_images_remaining": self._FALLBACK_MAX_IMAGES,
            "plan_videos_remaining": self._FALLBACK_MAX_VIDEOS,
            "plan_ppts_remaining": self._FALLBACK_MAX_PPTS,
            "topup_images_remaining": 0,
            "topup_videos_remaining": 0,
        }

    def deduct_topup_credits(self, user_id: str, images_used: int, videos_used: int) -> None:
        """Deduct topup credits at session end.

        Best-effort: failures are logged (with traceback) and swallowed.

        Args:
            user_id: The user whose topup balance to charge.
            images_used: Images to deduct; negative values are treated as 0.
            videos_used: Videos to deduct; negative values are treated as 0.
        """
        # Clamp each count independently. Previously a negative count paired
        # with a positive one was forwarded as a negative deduction, which
        # could increase the balance unless the repository rejects it.
        images_used = max(0, images_used)
        videos_used = max(0, videos_used)
        if images_used == 0 and videos_used == 0:
            return
        try:
            self.topups.deduct_credits(user_id, images_used, videos_used)
            logger.info(f"TOPUP_DEDUCTED: user_id={user_id} images={images_used} videos={videos_used}")
        except Exception as e:
            logger.exception(f"TOPUP_DEDUCT_FAILED: user_id={user_id} error={e}")
