# This file was auto-generated by Fern from our API Definition.

import json
import typing

import websockets
import websockets.sync.connection as websockets_sync_connection
from ..core.events import EventEmitterMixin, EventType
from ..core.pydantic_utilities import parse_obj_as
from ..types.audio_data import AudioData
from ..types.audio_message import AudioMessage
from ..types.speech_to_text_streaming_response import (
    SpeechToTextStreamingResponse,
)
from ..types.stt_flush_signal import SttFlushSignal

# Type of every message the socket clients below parse and hand back to callers.
# It is spelled as a Union with a single member; typing collapses such a Union
# to the member itself, so today this is an alias for
# SpeechToTextStreamingResponse.
SpeechToTextStreamingSocketClientResponse = typing.Union[SpeechToTextStreamingResponse]


class AsyncSpeechToTextStreamingSocketClient(EventEmitterMixin):
    """
    Asynchronous speech-to-text streaming client built on a websocket.

    Incoming messages can be consumed in three ways: ``async for`` over the
    client, one at a time through ``recv()``, or through events emitted by
    ``start_listening()``.
    """

    def __init__(self, *, websocket: websockets.WebSocketClientProtocol):
        """
        :param websocket: Websocket connection used for all sends and receives
        """
        super().__init__()
        self._websocket = websocket

    @staticmethod
    def _parse_message(
        raw_message: typing.Any,
    ) -> SpeechToTextStreamingSocketClientResponse:
        """
        Convert one raw websocket frame into a response object.

        Text frames are decoded as JSON first; any other payload is handed to
        ``parse_obj_as`` unchanged.

        :raises json.JSONDecodeError: if a text frame is not valid JSON
        """
        payload = json.loads(raw_message) if isinstance(raw_message, str) else raw_message
        return parse_obj_as(
            SpeechToTextStreamingSocketClientResponse, payload
        )  # type: ignore

    async def __aiter__(
        self,
    ) -> typing.AsyncIterator[SpeechToTextStreamingSocketClientResponse]:
        """
        Yield each message received on the websocket as a parsed response.
        """
        async for message in self._websocket:
            yield self._parse_message(message)

    async def start_listening(self) -> None:
        """
        Start listening for messages on the websocket connection.

        Emits events in the following order:
        - EventType.OPEN once, as soon as listening starts
        - EventType.MESSAGE for each message received, with the parsed response
        - EventType.ERROR if the websocket raises or a text frame is not valid
          JSON; the exception is passed to the listeners and listening stops
        - EventType.CLOSE once, when listening ends for any reason

        Any other exception (for example one raised by ``parse_obj_as``) still
        propagates to the caller after EventType.CLOSE has been emitted.
        """
        self._emit(EventType.OPEN, None)
        try:
            async for raw_message in self._websocket:
                self._emit(EventType.MESSAGE, self._parse_message(raw_message))
        except (websockets.WebSocketException, json.JSONDecodeError) as exc:
            # A malformed frame is reported the same way as a transport error,
            # so ERROR listeners see it instead of the loop dying silently.
            self._emit(EventType.ERROR, exc)
        finally:
            self._emit(EventType.CLOSE, None)

    async def transcribe(
        self, audio: str, encoding="audio/wav", sample_rate=16000
    ) -> None:
        """
        Sends transcription request to the server.
        :param audio: Base64 encoded audio data
        :param encoding: Audio encoding format (default is "audio/wav")
        :param sample_rate: Audio sample rate in Hz (default is 16000)
        """
        audio_data = AudioData(data=audio, sample_rate=sample_rate, encoding=encoding)
        await self._send_speech_to_text_streaming_audio_message(
            message=AudioMessage(audio=audio_data)
        )

    async def flush(self) -> None:
        """
        Signal to flush the audio buffer and force finalize partial transcriptions.
        Use this to force processing of any remaining audio that hasn't been
        transcribed yet.
        """
        await self._send_model(SttFlushSignal())

    async def _send_speech_to_text_streaming_audio_message(
        self, message: AudioMessage
    ) -> None:
        """
        Send a message to the websocket connection.
        The message will be sent as a AudioMessage.
        """
        await self._send_model(message)

    async def recv(self) -> SpeechToTextStreamingSocketClientResponse:
        """
        Receive a single message from the websocket connection and parse it.

        :raises json.JSONDecodeError: if a text frame is not valid JSON
        """
        return self._parse_message(await self._websocket.recv())

    async def _send(self, data: typing.Any) -> None:
        """
        Send a message to the websocket connection.

        A ``dict`` is serialized to a JSON string first; any other value is
        passed to the websocket unchanged.
        """
        if isinstance(data, dict):
            data = json.dumps(data)
        await self._websocket.send(data)

    async def _send_model(self, data: typing.Any) -> None:
        """
        Send a Pydantic model to the websocket connection.

        The model is converted with its ``dict()`` method and then sent
        through ``_send``.
        """
        await self._send(data.dict())


class SpeechToTextStreamingSocketClient(EventEmitterMixin):
    """
    Synchronous speech-to-text streaming client built on a websocket.

    Incoming messages can be consumed in three ways: ``for`` iteration over
    the client, one at a time through ``recv()``, or through events emitted by
    ``start_listening()``.
    """

    def __init__(self, *, websocket: websockets_sync_connection.Connection):
        """
        :param websocket: Websocket connection used for all sends and receives
        """
        super().__init__()
        self._websocket = websocket

    @staticmethod
    def _parse_message(
        raw_message: typing.Any,
    ) -> SpeechToTextStreamingSocketClientResponse:
        """
        Convert one raw websocket frame into a response object.

        Text frames are decoded as JSON first; any other payload is handed to
        ``parse_obj_as`` unchanged.

        :raises json.JSONDecodeError: if a text frame is not valid JSON
        """
        payload = json.loads(raw_message) if isinstance(raw_message, str) else raw_message
        return parse_obj_as(
            SpeechToTextStreamingSocketClientResponse, payload
        )  # type: ignore

    def __iter__(
        self,
    ) -> typing.Iterator[SpeechToTextStreamingSocketClientResponse]:
        """
        Yield each message received on the websocket as a parsed response.
        """
        for message in self._websocket:
            yield self._parse_message(message)

    def start_listening(self) -> None:
        """
        Start listening for messages on the websocket connection.

        Emits events in the following order:
        - EventType.OPEN once, as soon as listening starts
        - EventType.MESSAGE for each message received, with the parsed response
        - EventType.ERROR if the websocket raises or a text frame is not valid
          JSON; the exception is passed to the listeners and listening stops
        - EventType.CLOSE once, when listening ends for any reason

        Any other exception (for example one raised by ``parse_obj_as``) still
        propagates to the caller after EventType.CLOSE has been emitted.
        """
        self._emit(EventType.OPEN, None)
        try:
            for raw_message in self._websocket:
                self._emit(EventType.MESSAGE, self._parse_message(raw_message))
        except (websockets.WebSocketException, json.JSONDecodeError) as exc:
            # A malformed frame is reported the same way as a transport error,
            # so ERROR listeners see it instead of the loop dying silently.
            self._emit(EventType.ERROR, exc)
        finally:
            self._emit(EventType.CLOSE, None)

    def transcribe(self, audio: str, encoding="audio/wav", sample_rate=16000) -> None:
        """
        Sends transcription request to the server.
        :param audio: Base64 encoded audio data
        :param encoding: Audio encoding format (default is "audio/wav")
        :param sample_rate: Audio sample rate in Hz (default is 16000)
        """
        audio_data = AudioData(data=audio, sample_rate=sample_rate, encoding=encoding)
        self._send_speech_to_text_streaming_audio_message(
            message=AudioMessage(audio=audio_data)
        )

    def flush(self) -> None:
        """
        Signal to flush the audio buffer and force finalize partial transcriptions.
        Use this to force processing of any remaining audio that hasn't been
        transcribed yet.
        """
        self._send_model(SttFlushSignal())

    def recv(self) -> SpeechToTextStreamingSocketClientResponse:
        """
        Receive a single message from the websocket connection and parse it.

        :raises json.JSONDecodeError: if a text frame is not valid JSON
        """
        return self._parse_message(self._websocket.recv())

    def _send_speech_to_text_streaming_audio_message(
        self, message: AudioMessage
    ) -> None:
        """
        Send a message to the websocket connection.
        The message will be sent as a AudioMessage.
        """
        self._send_model(message)

    def _send(self, data: typing.Any) -> None:
        """
        Send a message to the websocket connection.

        A ``dict`` is serialized to a JSON string first; any other value is
        passed to the websocket unchanged.
        """
        if isinstance(data, dict):
            data = json.dumps(data)
        self._websocket.send(data)

    def _send_model(self, data: typing.Any) -> None:
        """
        Send a Pydantic model to the websocket connection.

        The model is converted with its ``dict()`` method and then sent
        through ``_send``.
        """
        self._send(data.dict())
