import json
from typing import Dict, List, Any

from sqlmodel import Session, col, func, select

from app.models import BotSession, SessionImage, SessionPpt, SessionVideo


class SessionRepository:
    """Data-access helpers for bot sessions and the media generated in them.

    Wraps a SQLModel ``Session``; every method operates through
    ``self.session``.
    """

    def __init__(self, session: Session):
        self.session = session

    def save_complete(
        self,
        session_id: str,
        user_id: str,
        session_data: Dict[str, Any],
        images: List[Dict[str, Any]],
        videos: List[Dict[str, Any]],
    ) -> None:
        """Save a session plus its images and videos in one transaction.

        If a ``BotSession`` with ``session_id`` already exists it is deleted
        and replaced. The delete and the inserts are committed together, so a
        failure while writing the new rows rolls back to the previous state
        instead of losing the old session.

        Args:
            session_id: Primary key of the ``BotSession`` to write.
            user_id: Owner recorded on the session and on every image/video.
            session_data: Session fields. Missing counters and
                ``duration_seconds`` default to 0, ``status`` defaults to
                ``"completed"``. A truthy ``"metadata"`` value is stored as
                JSON in ``extra_data``; otherwise ``extra_data`` is ``None``.
            images: Dicts with optional ``image_type``, ``prompt`` and
                ``metadata`` (stored as JSON, ``{}`` when absent).
            videos: Dicts with optional ``video_type``, ``prompt`` and
                ``metadata`` (stored as JSON, ``{}`` when absent).

        Raises:
            Any exception raised by the database layer is re-raised after the
            transaction has been rolled back.
        """
        try:
            existing = self.session.get(BotSession, session_id)
            if existing is not None:
                self.session.delete(existing)
                # Flush rather than commit: the DELETE is sent before the
                # INSERT that reuses the same primary key, but both stay in
                # one transaction so a later failure can still roll back.
                # NOTE(review): SessionImage/SessionVideo rows from the
                # previous save are not deleted here; whether they go away
                # depends on the model/DB cascade configuration - confirm,
                # otherwise the count_* methods will also count them.
                self.session.flush()

            session_metadata = session_data.get("metadata")
            self.session.add(BotSession(
                session_id=session_id,
                user_id=user_id,
                character_id=session_data.get("character_id"),
                recording_r2_key=session_data.get("recording_r2_key"),
                transcription_r2_key=session_data.get("transcription_r2_key"),
                session_metadata_r2_key=session_data.get("session_metadata_r2_key"),
                image_generations_count=session_data.get("image_generations_count", 0),
                image_edits_count=session_data.get("image_edits_count", 0),
                video_count=session_data.get("video_count", 0),
                duration_seconds=session_data.get("duration_seconds", 0),
                started_at=session_data.get("started_at"),
                ended_at=session_data.get("ended_at"),
                status=session_data.get("status", "completed"),
                extra_data=json.dumps(session_metadata) if session_metadata else None,
            ))

            self.session.add_all([
                SessionImage(
                    session_id=session_id,
                    user_id=user_id,
                    image_type=img.get("image_type"),
                    extra_data=json.dumps(img.get("metadata", {})),
                    prompt=img.get("prompt"),
                )
                for img in images
            ])

            self.session.add_all([
                SessionVideo(
                    session_id=session_id,
                    user_id=user_id,
                    video_type=vid.get("video_type"),
                    extra_data=json.dumps(vid.get("metadata", {})),
                    prompt=vid.get("prompt"),
                )
                for vid in videos
            ])

            self.session.commit()
        except Exception:
            # Leave the Session usable for the caller and keep the old rows.
            self.session.rollback()
            raise

    def _count_user_rows(self, model: Any, user_id: str, *conditions: Any) -> int:
        """Count ``model`` rows owned by ``user_id`` that match ``conditions``.

        ``model`` must expose a ``user_id`` column. ``conditions`` are extra
        SQL expressions combined with AND.
        """
        statement = (
            select(func.count())
            .select_from(model)
            .where(model.user_id == user_id, *conditions)
        )
        return self.session.exec(statement).one()

    def count_images_since(self, user_id: str, cutoff: str) -> int:
        """Count images of ``user_id`` with ``created_at >= cutoff``.

        NOTE(review): ``cutoff`` is compared directly with the ``created_at``
        column; presumably a timestamp in the column's format - confirm.
        """
        return self._count_user_rows(
            SessionImage, user_id, col(SessionImage.created_at) >= cutoff
        )

    def count_videos_since(self, user_id: str, cutoff: str) -> int:
        """Count videos of ``user_id`` with ``created_at >= cutoff``."""
        return self._count_user_rows(
            SessionVideo, user_id, col(SessionVideo.created_at) >= cutoff
        )

    def count_ppts_since(self, user_id: str, cutoff: str) -> int:
        """Count presentations of ``user_id`` with ``created_at >= cutoff``."""
        return self._count_user_rows(
            SessionPpt, user_id, col(SessionPpt.created_at) >= cutoff
        )

    def count_images_in_period(self, user_id: str, start: str, end: str) -> int:
        """Count images of ``user_id`` created in ``[start, end]`` (both inclusive)."""
        created_at = col(SessionImage.created_at)
        return self._count_user_rows(
            SessionImage, user_id, created_at >= start, created_at <= end
        )

    def count_videos_in_period(self, user_id: str, start: str, end: str) -> int:
        """Count videos of ``user_id`` created in ``[start, end]`` (both inclusive)."""
        created_at = col(SessionVideo.created_at)
        return self._count_user_rows(
            SessionVideo, user_id, created_at >= start, created_at <= end
        )

    def count_ppts_in_period(self, user_id: str, start: str, end: str) -> int:
        """Count presentations of ``user_id`` created in ``[start, end]`` (both inclusive)."""
        created_at = col(SessionPpt.created_at)
        return self._count_user_rows(
            SessionPpt, user_id, created_at >= start, created_at <= end
        )
