"""Pipecat Voice Agent — Maya.

A Hindi voice agent with function calling for
image generation, video generation, web search, and shopping.

Required AI services:
- Soniox (Speech-to-Text)
- Groq (LLM)
- ElevenLabs (Text-to-Speech — Hindi)
- Replicate (Image + Video Generation)

Run the bot locally::

    uv run bot.py
"""

import os
import sys

from dotenv import load_dotenv
from loguru import logger

# Heavy imports are deliberately interleaved with progress output: importing
# the Silero VAD module loads (and on first run, downloads) the model, which
# dominates startup time — the prints/logs tell the operator what is stalling.
print("Starting voice agent...")
print("Loading models and imports (may take ~20s on first run)\n")

logger.info("Loading Silero VAD model...")
from pipecat.audio.vad.silero import SileroVADAnalyzer

logger.info("Silero VAD model loaded")

from pipecat.frames.frames import LLMRunFrame

logger.info("Loading pipeline components...")
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
    LLMContextAggregatorPair,
    LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import (
    RTVIConfig,
    RTVIFunctionCallReportLevel,
    RTVIObserverParams,
    RTVIProcessor,
    RTVIServerMessageFrame,
)
from pipecat.runner.types import DailyRunnerArguments, RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.groq.llm import GroqLLMService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
# from pipecat.services.azure.tts import AzureTTSService  # Commented — keep for future
from pipecat.services.soniox.stt import SonioxSTTService, SonioxInputParams
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams, DailyTransport
from pipecat.transcriptions.language import Language

from config import (
    DEFAULT_LANGUAGE,
    GROQ_MODEL,
    SUPPORTED_LANGUAGES,
)
from functions.handlers import register_all_handlers
from functions.registry import all_tools
from prompts.system_prompt import SYSTEM_PROMPT

# ---------------------------------------------------------------------------
# Simple in-memory session state for uploaded images
# ---------------------------------------------------------------------------
# Module-level mutable state shared with the RTVI / transport handlers in
# run_bot(). Holds at most one image per process; set by the
# "user_image_upload" client message and cleared on connect/disconnect.
uploaded_image: dict | None = None  # {"base64": str, "mime_type": str}

logger.info("All components loaded successfully!")

# NOTE(review): this runs AFTER the `config` import above — if config.py reads
# env vars at module import time, those reads happen before the .env file is
# loaded. Confirm, and move this call above the project imports if so.
load_dotenv(override=True)


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments) -> None:
    """Build the STT → LLM → TTS pipeline on *transport* and run it to completion.

    Wires Soniox STT, Groq LLM (with function-call handlers registered), and
    ElevenLabs Hindi TTS into one Pipecat pipeline; attaches an RTVI processor
    so the client can send custom messages (image uploads) and receive
    function-call reports; registers connect/disconnect handlers; then blocks
    in the pipeline runner until the session ends.

    Args:
        transport: An already-constructed transport (Daily or local WebRTC).
        runner_args: Runner arguments; only ``pipeline_idle_timeout_secs`` and
            ``handle_sigint`` are read here (both optional, via ``getattr``).
    """
    logger.info("Starting bot")

    # ----- STT: Soniox (Hindi + English) -----
    # Language identification is enabled so mixed Hindi/English speech is handled.
    stt = SonioxSTTService(
        api_key=os.getenv("SONIOX_API_KEY"),
        params=SonioxInputParams(
            model="stt-rt-v4",
            enable_language_identification=True,
        ),
    )

    # ----- LLM: Groq -----
    logger.info(f"Using Groq LLM: {GROQ_MODEL}")
    llm = GroqLLMService(
        api_key=os.getenv("GROQ_API_KEY"),
        model=GROQ_MODEL,
        params=GroqLLMService.InputParams(
            temperature=0.7,
            max_tokens=2048,
        ),
    )

    # Register all function call handlers (image, video, search, shopping, edit)
    register_all_handlers(llm)

    # ----- RTVI Processor (handles client messages like image uploads) -----
    rtvi = RTVIProcessor(config=RTVIConfig(config=[]))

    # ----- TTS: ElevenLabs (Hindi) -----
    tts = ElevenLabsTTSService(
        api_key=os.getenv("ELEVENLABS_API_KEY"),
        model="eleven_turbo_v2_5",
        voice_id="7w5JDCUNbeKrn4ySFgfu",  # hard-coded Maya voice
        params=ElevenLabsTTSService.InputParams(
            language=Language.HI,
        ),
    )

    # ----- TTS: Azure (commented out — switch back by uncommenting) -----
    # tts = AzureTTSService(
    #     api_key=os.getenv("AZURE_SPEECH_KEY"),
    #     region=os.getenv("AZURE_SPEECH_REGION", "southeastasia"),
    #     voice="hi-IN-SwaraNeural",
    #     params=AzureTTSService.InputParams(
    #         language=Language.HI_IN,
    #         rate="1.0",
    #     ),
    # )

    # ----- Context + Aggregators -----
    # The event handlers below append to this same list to steer the LLM;
    # NOTE(review): this assumes LLMContext keeps a live reference to
    # `messages` (rather than copying it) so later appends are visible on
    # the next LLM run — confirm against the pipecat version in use.
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
    ]

    context = LLMContext(messages, all_tools)
    # NOTE(review): VAD analyzer is supplied at the user-aggregator level here
    # rather than on the transport params — verify this is supported by the
    # installed pipecat version's LLMUserAggregatorParams.
    user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
        context,
        user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
    )

    # ----- Pipeline (same order as 8ee2d06; RTVI is added by PipelineTask via rtvi_processor) -----
    pipeline = Pipeline(
        [
            transport.input(),          # Daily WebRTC input
            stt,                        # Soniox STT
            user_aggregator,            # Aggregate user messages
            llm,                        # Groq LLM + function calling
            tts,                        # ElevenLabs TTS (Hindi)
            transport.output(),         # Daily WebRTC output
            assistant_aggregator,       # Aggregate assistant messages
        ]
    )

    # Match 8ee2d06: task adds RTVI + RTVIObserver so images/shopping reach the client.
    # Pass our rtvi so task prepends it and creates the observer (no "no RTVIObserver" error).
    task = PipelineTask(
        pipeline,
        rtvi_processor=rtvi,
        rtvi_observer_params=RTVIObserverParams(
            function_call_report_level={
                "*": RTVIFunctionCallReportLevel.FULL,  # send name, args, results to client
            },
        ),
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
        # Falls back to None when the runner args carry no idle timeout;
        # NOTE(review): confirm PipelineTask treats None as "use default"
        # rather than "disable idle timeout".
        idle_timeout_secs=getattr(runner_args, "pipeline_idle_timeout_secs", None),
    )

    # ----- RTVI client message handler (image uploads) -----
    @rtvi.event_handler("on_client_message")
    async def on_client_message(rtvi_proc, message):
        """Handle custom client messages; currently only "user_image_upload".

        Stores the base64 image in module-level `uploaded_image`, notifies the
        client of success/failure, and nudges the LLM to acknowledge the photo.
        Other message types are logged and ignored.
        """
        global uploaded_image
        try:
            # `message` shape depends on the client SDK; read defensively.
            msg_type = getattr(message, "type", None)
            msg_data = getattr(message, "data", None)
            logger.info(f"[RTVI] Client message: {msg_type}")

            if msg_type == "user_image_upload":
                data = msg_data if isinstance(msg_data, dict) else {}
                image_data = data.get("image", "")
                mime_type = data.get("mime_type", "image/jpeg")

                if image_data:
                    # Kept in memory only; function handlers elsewhere are
                    # expected to read this module-level slot.
                    uploaded_image = {
                        "base64": image_data,
                        "mime_type": mime_type,
                    }
                    logger.info(
                        f"[RTVI] Image stored in memory "
                        f"({len(image_data)} chars, {mime_type})"
                    )

                    # Tell frontend upload succeeded
                    await task.queue_frames([
                        RTVIServerMessageFrame(data={
                            "type": "image_upload_status",
                            "data": {"status": "uploaded"},
                        })
                    ])

                    # Tell the LLM that user uploaded an image
                    messages.append({
                        "role": "system",
                        "content": (
                            "The user just uploaded a photo. Acknowledge it briefly "
                            "and ask them what they want to do — edit it or create "
                            "a video from it. Keep it short and friendly in Hindi. "
                            "Example: 'Photo mil gayi! Batao, edit karni hai ya "
                            "isse video banana hai?'"
                        ),
                    })
                    await task.queue_frames([LLMRunFrame()])
                else:
                    logger.warning("[RTVI] Empty image data received")
                    await task.queue_frames([
                        RTVIServerMessageFrame(data={
                            "type": "image_upload_status",
                            "data": {"status": "failed", "error": "Empty image data"},
                        })
                    ])

        except Exception as e:
            # Broad catch: a bad client message must not crash the session.
            logger.error(f"[RTVI] Error handling client message: {e}", exc_info=True)

    # ----- Transport event handlers -----
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        """Reset per-session state and have the bot greet the user in Hindi."""
        global uploaded_image
        uploaded_image = None  # Clear any stale image from previous session
        logger.info("Client connected")
        # Greet the user warmly in Hindi
        messages.append(
            {
                "role": "system",
                "content": (
                    "Greet the user warmly in Hindi. Introduce yourself as Maya — "
                    "their friendly AI dost who can images banaa sakti hai, "
                    "videos create kar sakti hai, web search kar sakti hai, "
                    "aur shopping mein help kar sakti hai. Keep it brief, "
                    "fun, and casual like a real Indian friend. Speak in Hindi."
                ),
            }
        )
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        """Drop session state and cancel the pipeline when the client leaves."""
        global uploaded_image
        uploaded_image = None  # Clean up memory
        logger.info("Client disconnected")
        await task.cancel()

    # ----- Run -----
    # Blocks until the task completes or is cancelled (e.g. on disconnect).
    runner = PipelineRunner(handle_sigint=getattr(runner_args, "handle_sigint", False))
    await runner.run(task)


async def bot(runner_args):
    """Entry point for one bot session: pick a transport, then run the agent.

    Handles both deployment shapes:
    - ``DailyRunnerArguments`` (local dev or Pipecat Cloud with a Daily
      room): a ``DailyTransport`` is built directly from the room URL/token.
    - Anything else (e.g. local WebRTC): ``create_transport`` selects a
      matching transport from a factory table; unsupported argument types
      are logged and the session is skipped.
    """
    if isinstance(runner_args, DailyRunnerArguments):
        # Daily transport (local dev + Pipecat Cloud with Daily room).
        daily_transport = DailyTransport(
            runner_args.room_url,
            runner_args.token,
            "Maya AI",
            DailyParams(
                audio_in_enabled=True,
                audio_out_enabled=True,
            ),
        )
        await run_bot(daily_transport, runner_args)
        return

    # No Daily room: let create_transport resolve webrtc/other local types.
    param_factories = {
        "daily": lambda: DailyParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
        ),
        "webrtc": lambda: TransportParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
        ),
    }
    try:
        selected_transport = await create_transport(runner_args, param_factories)
    except ValueError:
        # Cloud sent a session type we cannot serve — log and bail out.
        logger.warning(
            f"Unsupported runner args type {type(runner_args).__name__}, "
            "skipping session (client must request createDailyRoom=true)"
        )
        return

    await run_bot(selected_transport, runner_args)


if __name__ == "__main__":
    # Deferred import: only needed when executed directly as a script.
    # pipecat's dev runner parses CLI args and presumably discovers the
    # `bot()` coroutine in this module — verify against the runner docs.
    from pipecat.runner.run import main

    main()
