"""Bot session state, persona loading, transcript manager, chunk state, and cleanup."""

import asyncio
import json
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional

from loguru import logger
from sqlmodel import Session, select

from app.database import engine
from app.models import Persona


# --- Constants ---

SAMPLE_RATE = 16000  # audio sample rate; presumably Hz — NOTE(review): confirm against transport config
NO_USER_TIMEOUT = 30  # seconds to wait for user before closing

# Fallback prompt if persona not found in DB.
# load_persona() returns this as "system_prompt" when no active persona row
# matches or the DB lookup raises.
DEFAULT_SYSTEM_PROMPT = """You are Maya, a friendly and helpful AI assistant. Keep your responses concise but warm.

Remember to:
- Be conversational and friendly
- Handle errors gracefully
- Never output asterisks or markdown formatting
- You are talking to the user over a voice call

Start by warmly introducing yourself and asking how you can help today.
"""


def _parse_persona_tools(raw_tools: Optional[str], persona_id: str) -> List[str]:
    """Decode a persona's JSON ``tools`` column; return [] if empty, malformed, or not a list.

    Kept separate from the DB lookup so a bad tools value does not discard the
    rest of an otherwise valid persona (name, prompt, voice).
    """
    try:
        tools = json.loads(raw_tools or "[]")
    except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
        logger.error(f"Invalid tools JSON for persona {persona_id}: {e}")
        return []
    if not isinstance(tools, list):
        logger.error(f"Tools for persona {persona_id} is not a JSON list: {tools!r}")
        return []
    return tools


def load_persona(persona_id: str) -> dict:
    """Load persona config from DB.

    Only rows with ``active == 1`` are considered.

    Args:
        persona_id: identifier matched against ``Persona.persona_id``.

    Returns:
        A dict with keys ``persona_id``, ``name``, ``system_prompt``, ``tools``
        and ``tts_voice``. When no active persona matches, or the DB lookup
        fails, a default "Maya" persona is returned instead. Errors are logged,
        never raised.
    """
    try:
        with Session(engine) as session:
            persona = session.exec(
                select(Persona)
                .where(Persona.persona_id == persona_id, Persona.active == 1)
            ).first()
            if persona:
                # Build the dict while the session is still open.
                return {
                    "persona_id": persona.persona_id,
                    "name": persona.name,
                    # An empty prompt would leave the LLM without instructions.
                    "system_prompt": persona.system_prompt or DEFAULT_SYSTEM_PROMPT,
                    "tools": _parse_persona_tools(persona.tools, persona_id),
                    "tts_voice": persona.tts_voice or "aura-2-helena-en",
                }
            logger.warning(f"Persona {persona_id} not found or inactive; using default persona")
    except Exception as e:
        logger.exception(f"Failed to load persona {persona_id}: {e}")
    return {
        "persona_id": persona_id,
        "name": "Maya",
        "system_prompt": DEFAULT_SYSTEM_PROMPT,
        # A fresh list per call, so callers can mutate it safely.
        "tools": ["cf_image", "generate_video", "generate_video_from_image"],
        "tts_voice": "aura-2-helena-en",
    }


class ChunkState:
    """Bookkeeping for the audio recording chunks produced during one session.

    ``counter`` is how many chunks have been registered, which is also the index
    the next chunk will get. ``chunks`` holds the registered keys in order.
    """

    def __init__(self):
        self.session_id: Optional[str] = None
        self.chunks: List[str] = []
        self.counter: int = 0

    def reset(self, session_id: str):
        """Attach to ``session_id`` with a zeroed counter and a newly allocated chunk list."""
        self.counter, self.chunks, self.session_id = 0, [], session_id

    def clear(self):
        """Empty the existing chunk list in place, zero the counter and detach from the session."""
        del self.chunks[:]
        self.session_id = None
        self.counter = 0

    def get_chunk_key(self, session_id: str, chunk_num: int) -> str:
        """Build the storage key for chunk ``chunk_num`` (zero-padded to 3 digits) of ``session_id``."""
        return f"sessions/{session_id}/chunks/chunk_{chunk_num:03d}.wav"

    def add_chunk(self, chunk_key: str) -> int:
        """Register ``chunk_key`` and return the zero-based index it was assigned."""
        assigned = self.counter
        self.chunks.append(chunk_key)
        self.counter = assigned + 1
        return assigned

    @property
    def total_chunks(self) -> int:
        """Current value of the chunk counter."""
        return self.counter

    @property
    def chunk_list(self) -> List[str]:
        """Shallow copy of the registered keys; mutating it does not affect this object."""
        return list(self.chunks)


class TranscriptManager:
    """Collects all messages during a session for saving.

    Every entry is a dict whose first keys are ``timestamp`` and ``role``.
    ``timestamp`` is ``datetime.now().isoformat()``, i.e. naive local time.
    """

    def __init__(self):
        self._messages: List[Dict] = []

    def clear(self):
        """Remove all collected entries in place."""
        self._messages.clear()

    @staticmethod
    def _new_entry(role: str, **fields) -> Dict:
        """Return a new entry stamped with the current time and ``role``, followed by ``fields`` in order."""
        return {"timestamp": datetime.now().isoformat(), "role": role, **fields}

    def add_message(self, role: str, content: str, msg_type: Optional[str] = None, **kwargs):
        """Append a conversational message.

        Args:
            role: speaker, e.g. "user" or "assistant".
            content: message text.
            msg_type: optional ``type`` tag; omitted from the entry when falsy.
            **kwargs: extra fields merged in last (they can override earlier keys).
        """
        entry = self._new_entry(role, content=content)
        if msg_type:
            entry["type"] = msg_type
        entry.update(kwargs)
        self._messages.append(entry)

    def add_generation(self, generation_type: str, content: Dict):
        """Append a ``system`` entry of type "generation" describing a generated artefact."""
        self._messages.append(
            self._new_entry(
                "system",
                type="generation",
                generation_type=generation_type,
                content=content,
            )
        )

    def add_image_upload(self, mime_type: str, size_bytes: int, r2_url: Optional[str] = None):
        """Append a ``user`` entry of type "image_upload" with the image's metadata."""
        self._messages.append(
            self._new_entry(
                "user",
                type="image_upload",
                content={"mime_type": mime_type, "size_bytes": size_bytes, "r2_url": r2_url},
            )
        )

    @property
    def messages(self) -> List[Dict]:
        """The live list of entries (not a copy); mutating it changes the transcript."""
        return self._messages

    @property
    def count(self) -> int:
        """Number of collected entries."""
        return len(self._messages)

    def get_user_messages(self) -> List[Dict]:
        """All entries whose role is "user" (messages and image uploads)."""
        return [m for m in self._messages if m.get("role") == "user"]

    def get_assistant_messages(self) -> List[Dict]:
        """All entries whose role is "assistant"."""
        return [m for m in self._messages if m.get("role") == "assistant"]

    def get_generations(self) -> List[Dict]:
        """All entries of type "generation"."""
        return [m for m in self._messages if m.get("type") == "generation"]

    def _count_generation_type(self, generation_type: str) -> int:
        """Count entries whose ``generation_type`` equals ``generation_type``."""
        return sum(1 for m in self._messages if m.get("generation_type") == generation_type)

    def get_image_generations_count(self) -> int:
        """Number of "image" generations recorded."""
        return self._count_generation_type("image")

    def get_image_edits_count(self) -> int:
        """Number of "image_edit" generations recorded."""
        return self._count_generation_type("image_edit")


class BotSessionState:
    """Tracks all state for a single bot session.

    Holds identity, image/video artefacts, the pipeline/transport references
    used during cleanup, and usage quotas loaded once at session start.
    """

    def __init__(self, session_id: str, user_id: str):
        self.session_id = session_id
        self.user_id = user_id
        self.bot_start_time = time.time()
        self.user_joined = False
        self.current_user_id: Optional[str] = None
        self.no_user_timeout_task: Optional[asyncio.Task] = None
        # Audio
        self.audio_buffer = None

        # Images
        self.uploaded_image: Optional[Dict] = None
        self.last_image: Optional[Dict] = None
        self.last_generated_image: Optional[Dict] = None
        self.is_processing = False
        # time.monotonic() of the last accepted upload. -inf means "never", so
        # the first upload always passes the cooldown check.
        self._last_upload_time: float = float("-inf")

        # Videos
        self.generated_videos: List[Dict] = []

        # Pipeline references for cleanup
        self.transport = None
        self.pipeline = None
        self._background_tasks: List[asyncio.Task] = []

        # Character
        self.character_id: Optional[str] = None
        self.character_name: Optional[str] = None

        # Subscription
        self.is_premium: bool = False
        self.plan: str = ""
        # Usage counters (loaded once at session start)
        self.plan_images_remaining: int = 0
        self.plan_videos_remaining: int = 0
        self.plan_ppts_remaining: int = 0
        self.topup_images_remaining: int = 0
        self.topup_videos_remaining: int = 0
        self._topup_images_used_this_session: int = 0
        self._topup_videos_used_this_session: int = 0

    def get_session_info(self) -> str:
        """Short ``key=value`` description of this session for log lines."""
        info = f"session_id={self.session_id} user_id={self.user_id}"
        if self.character_id:
            info += f" character={self.character_id}"
        return info

    def _quota_status(self, plan_remaining: int, topup_remaining: int = 0) -> str:
        """Shared quota rule behind the ``can_generate_*`` methods.

        Remaining plan allowance always permits generation. Top-up credits only
        count for premium users.

        Returns:
            '' if allowed. Otherwise 'subscribe' for non-premium users and
            'topup' for premium users.
        """
        if plan_remaining > 0:
            return ""
        if not self.is_premium:
            return "subscribe"
        if topup_remaining > 0:
            return ""
        return "topup"

    def can_generate_image(self) -> str:
        """Returns '' if allowed, or a message explaining why not."""
        return self._quota_status(self.plan_images_remaining, self.topup_images_remaining)

    def can_generate_video(self) -> str:
        """Returns '' if allowed, or a message explaining why not."""
        return self._quota_status(self.plan_videos_remaining, self.topup_videos_remaining)

    def can_generate_ppt(self) -> str:
        """Returns '' if allowed, or a message explaining why not.

        No PPT top-up counter is tracked here, so only the plan allowance counts.
        """
        return self._quota_status(self.plan_ppts_remaining)

    def use_ppt(self):
        """Decrement PPT counter (never below zero)."""
        if self.plan_ppts_remaining > 0:
            self.plan_ppts_remaining -= 1

    def use_image(self):
        """Decrement image counter. Plan first, then topup."""
        if self.plan_images_remaining > 0:
            self.plan_images_remaining -= 1
        elif self.topup_images_remaining > 0:
            self.topup_images_remaining -= 1
            self._topup_images_used_this_session += 1

    def use_video(self):
        """Decrement video counter. Plan first, then topup."""
        if self.plan_videos_remaining > 0:
            self.plan_videos_remaining -= 1
        elif self.topup_videos_remaining > 0:
            self.topup_videos_remaining -= 1
            self._topup_videos_used_this_session += 1

    def can_upload_image(self, cooldown: float = 2.0) -> bool:
        """Rate-limit uploads: True if ``cooldown`` seconds passed since the last accepted one.

        Side effect: an accepted call records the current time. A monotonic
        clock is used so wall-clock adjustments cannot block or bypass the cooldown.
        """
        now = time.monotonic()
        if now - self._last_upload_time < cooldown:
            return False
        self._last_upload_time = now
        return True

    def set_uploaded_image(self, image_info: Dict):
        """Record a user upload; it supersedes any previously generated image."""
        self.last_generated_image = None
        self.uploaded_image = image_info
        self.last_image = image_info

    def set_generated_image(self, image_info: Dict):
        """Record a newly generated image, keeping the original upload reference."""
        self.last_generated_image = image_info
        self.last_image = image_info

    def set_edited_image(self, image_info: Dict):
        """Record an edit result; the previous upload is dropped."""
        self.last_generated_image = image_info
        self.last_image = image_info
        self.uploaded_image = None

    def get_image_to_edit(self) -> Optional[Dict]:
        """Most recent image of any kind, or None."""
        return self.last_image or self.uploaded_image or self.last_generated_image

    def get_uploaded_image(self) -> Optional[Dict]:
        """The user's upload, falling back to the last generated image."""
        return self.uploaded_image or self.last_generated_image

    def clear_all_images(self):
        """Drop every image reference."""
        self.uploaded_image = None
        self.last_image = None
        self.last_generated_image = None

    def add_generated_video(self, video_type: str, metadata: Dict, prompt: Optional[str] = None):
        """Append a record of a generated video to ``generated_videos``."""
        self.generated_videos.append({
            "video_type": video_type, "metadata": metadata, "prompt": prompt,
        })

    def cancel_timeout_task(self):
        """Request cancellation of the no-user timeout task if it is still pending (does not await it)."""
        if self.no_user_timeout_task and not self.no_user_timeout_task.done():
            self.no_user_timeout_task.cancel()

    async def cancel_background_tasks(self):
        """Cancel every tracked background task and wait for all of them to finish.

        Loops until the tracking list is empty, so tasks registered while we
        are awaiting are also cancelled rather than silently dropped.
        """
        cancelled = 0
        while self._background_tasks:
            batch = list(self._background_tasks)
            for task in batch:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*batch, return_exceptions=True)
            cancelled += len(batch)
            # Remove only this batch, in place. Anything appended during the
            # await stays in the list and is handled by the next iteration.
            handled = set(batch)
            self._background_tasks[:] = [t for t in self._background_tasks if t not in handled]
        if cancelled:
            logger.debug(f"CLEANUP: cancelled {cancelled} background tasks session_id={self.session_id}")


async def _release_session_component(bot_session: BotSessionState, attr: str) -> None:
    """Await ``cleanup()`` on ``bot_session.<attr>`` if it is set, and always drop the reference.

    Errors from ``cleanup()`` are logged, not raised. Cancellation still
    propagates, but the reference is cleared either way.
    """
    component = getattr(bot_session, attr)
    if not component:
        return
    try:
        if hasattr(component, "cleanup"):
            await component.cleanup()
    except Exception as e:
        logger.exception(f"CLEANUP_ERROR: {attr}: {e}")
    finally:
        # Clear even if cleanup() failed, so a broken component is not kept alive.
        setattr(bot_session, attr, None)


async def cleanup_session(
    bot_session: BotSessionState,
    session_id: str,
    transcript_manager: Optional[TranscriptManager] = None,
    chunk_state: Optional[ChunkState] = None,
):
    """Full memory cleanup after session ends.

    Each step is best-effort; failures are logged so later steps still run.
    The steps are:

    1. Cancel the no-user timeout task.
    2. Cancel background tasks.
    3. Release the pipeline, then the transport.
    4. Drop image/audio/video state.
    5. Clear the optional transcript and chunk trackers.

    Thread counts are logged before and after to help spot leaked native threads.

    Args:
        bot_session: state of the session being torn down.
        session_id: identifier used in the log lines.
        transcript_manager: cleared if given.
        chunk_state: cleared if given.
    """
    threads_start = threading.active_count()
    logger.info(f"CLEANUP_START: session_id={session_id} threads={threads_start}")

    # The timeout task is not tracked in _background_tasks, so cancel it here.
    # Skip it if cleanup is running inside that very task: cancelling ourselves
    # would abort the rest of the cleanup at the next await.
    timeout_task = bot_session.no_user_timeout_task
    if timeout_task is not None and timeout_task is not asyncio.current_task():
        bot_session.cancel_timeout_task()

    try:
        await bot_session.cancel_background_tasks()
    except Exception as e:
        logger.exception(f"CLEANUP_ERROR: cancel_background_tasks: {e}")

    # Pipeline cleanup releases transport + native threads
    await _release_session_component(bot_session, "pipeline")
    await _release_session_component(bot_session, "transport")

    bot_session.clear_all_images()
    bot_session.audio_buffer = None
    bot_session.generated_videos.clear()

    if transcript_manager:
        transcript_manager.clear()
    if chunk_state:
        chunk_state.clear()

    threads_end = threading.active_count()
    logger.info(f"CLEANUP_END: session_id={session_id} threads={threads_start}->{threads_end}")
