from datetime import datetime, timedelta, timezone

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import ValidationError

from app.config import JWT_ALGORITHM, JWT_EXPIRY_DAYS, JWT_SECRET
from app.schemas.auth import CurrentUser

# Bearer-token extractor injected into ``get_current_user``. With
# ``auto_error=False`` FastAPI passes ``None`` to the dependency instead of
# rejecting the request itself when no bearer credentials are present, so the
# missing-token case is answered by this module's own 401 response.
security = HTTPBearer(auto_error=False)


def create_jwt(
    user_id: str,
    email: str,
    is_premium: bool = False,
    plan: str = "",
    premium_expiry: str = "",
) -> str:
    """Build and sign a JWT describing the given user.

    Args:
        user_id: Stored as the ``sub`` claim.
        email: Stored as the ``email`` claim.
        is_premium: Stored as the ``premium`` claim.
        plan: Stored as the ``plan`` claim.
        premium_expiry: Stored verbatim as the ``premiumExpiry`` claim.

    Returns:
        The encoded token, signed with ``JWT_SECRET`` using ``JWT_ALGORITHM``,
        with ``exp`` set ``JWT_EXPIRY_DAYS`` days after ``iat``.
    """
    # Take one timestamp and derive both claims from it. Two separate now()
    # calls could land in different seconds after PyJWT converts datetimes to
    # integer timestamps, making ``exp - iat`` differ from the configured
    # lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": user_id,
        "email": email,
        "premium": is_premium,
        "plan": plan,
        "premiumExpiry": premium_expiry,
        "exp": issued_at + timedelta(days=JWT_EXPIRY_DAYS),
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


def verify_jwt(token: str) -> CurrentUser:
    """Decode and validate a JWT and return the user it describes.

    Args:
        token: The raw encoded token (without any ``Bearer`` prefix).

    Returns:
        A ``CurrentUser`` built from the token's claims.

    Raises:
        HTTPException: 401 with detail "Token expired" when ``exp`` has
            passed; "Invalid token" when the signature, format, or required
            ``exp`` claim is bad or missing; "Invalid token payload" when the
            claims do not validate as a ``CurrentUser``.
    """
    # RFC 7235 section 3.1: every 401 response must carry a WWW-Authenticate
    # challenge.
    challenge = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            # PyJWT only validates ``exp`` when the claim is present; require
            # it so a signed token without an expiry is never accepted forever.
            options={"require": ["exp"]},
        )
    except jwt.ExpiredSignatureError as exc:
        # Must be caught before InvalidTokenError, which it subclasses.
        raise HTTPException(
            status_code=401, detail="Token expired", headers=challenge
        ) from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=challenge
        ) from exc

    try:
        return CurrentUser(**payload)
    except ValidationError as exc:
        raise HTTPException(
            status_code=401, detail="Invalid token payload", headers=challenge
        ) from exc


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> CurrentUser:
    """FastAPI dependency that authenticates the request's bearer token.

    Args:
        credentials: The parsed bearer credentials, or ``None`` when FastAPI's
            ``HTTPBearer`` (configured with ``auto_error=False``) could not
            extract a bearer token from the request.

    Returns:
        The ``CurrentUser`` decoded from the token by ``verify_jwt``.

    Raises:
        HTTPException: 401 "Missing authorization" when no credentials were
            supplied; any 401 raised by ``verify_jwt`` propagates unchanged.
    """
    if not credentials:
        # RFC 7235 section 3.1: a 401 must include a WWW-Authenticate challenge.
        raise HTTPException(
            status_code=401,
            detail="Missing authorization",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return verify_jwt(credentials.credentials)
