from typing import Callable, Any, Dict
import asyncio
import base64
import json
import websockets

from app.core.logger import Logger

# A decoded JSON object (string keys, arbitrary values).
JsonDict = Dict[str, Any]
# Signature expected for WebSocketClient.on_message: called with each decoded message.
OnMessageCallback = Callable[[JsonDict], None]
# Signature expected for WebSocketClient.on_error: called with the caught exception.
OnErrorCallback = Callable[[Exception], None]
# Signature expected for WebSocketClient.on_close: called with no arguments.
OnCloseCallback = Callable[[], None]

class WebSocketClient:
    """Asyncio websocket client that exchanges JSON messages with a server.

    After :meth:`connect` succeeds, a background task decodes every incoming
    frame with ``json.loads`` and hands it to ``on_message``. Failures are
    reported through ``on_error`` instead of being raised, and ``on_close`` is
    invoked once when the receive loop ends. All three callbacks are optional
    plain (non-async) callables; they are called without ``await`` from the
    event loop, so they should return quickly.
    """

    def __init__(self):
        # Set when the server sends {"type": "EndOfConnection"}, and also when
        # the receive loop stops for any other reason (close, drop, error).
        self.connection_ended = asyncio.Event()
        # Connection object returned by websockets.connect(), or None.
        self.websocket = None
        self.receive_task: "asyncio.Task | None" = None
        self.on_message: "OnMessageCallback | None" = None
        self.on_error: "OnErrorCallback | None" = None
        self.on_close: "OnCloseCallback | None" = None

    @property
    def is_connected(self) -> bool:
        """Return True while a websocket exists and reports itself open.

        Legacy ``websockets`` connections expose an ``open`` flag; the newer
        asyncio implementation only exposes a ``state`` enum, so fall back to
        checking ``state.name == "OPEN"`` when ``open`` is absent.
        """
        ws = self.websocket
        if ws is None:
            return False
        is_open = getattr(ws, "open", None)
        if is_open is not None:
            return bool(is_open)
        state = getattr(ws, "state", None)
        return getattr(state, "name", None) == "OPEN"

    async def wait_for_completion(self) -> None:
        """Wait until EndOfConnection is received or the receive loop ends.

        Returning on loop exit as well means a dropped or closed connection
        cannot leave the caller waiting forever.
        """
        await self.connection_ended.wait()

    async def connect(self, url: str) -> bool:
        """Open a websocket to *url* and start the background receive loop.

        Returns True on success. On failure the exception is passed to
        ``on_error`` (if set), is not re-raised, and False is returned.
        Call :meth:`close` before reusing the client for a new connection.
        """
        # Re-arm the completion flag so wait_for_completion() tracks this
        # connection rather than a previous one that has already ended.
        self.connection_ended.clear()
        try:
            self.websocket = await websockets.connect(url)
            Logger.info(f"Connected to {url}")
            self.receive_task = asyncio.create_task(self._receive_loop())
            return True

        except Exception as ex:
            if self.on_error:
                self.on_error(ex)
        return False

    async def send_json(self, obj: Any) -> None:
        """Serialize *obj* with ``json.dumps`` and send it over the websocket.

        Errors (including calling this before :meth:`connect`) are passed to
        ``on_error`` and are not re-raised.
        """
        try:
            if not self.websocket:
                raise RuntimeError("WebSocket not connected")

            await self.websocket.send(json.dumps(obj))

        except Exception as ex:
            if self.on_error:
                self.on_error(ex)

    async def send_audio(
        self,
        pcm_bytes: bytes,
        is_last: bool = False,
        json: "JsonDict | None" = None,
    ) -> None:
        """Send a chunk of PCM audio as a base64 data URI inside a JSON message.

        Args:
            pcm_bytes: Raw PCM audio bytes for this chunk.
            is_last: Value of the message's ``"last"`` field.
            json: Extra fields merged into the message after ``"data"`` and
                ``"last"``, so they override those keys if present.

        Encoding errors are logged and passed to ``on_error``; send errors
        are handled by :meth:`send_json`. Nothing is re-raised.
        """
        # NOTE: the ``json`` parameter shadows the json module inside this
        # method; the name is kept for keyword-argument callers. Serialization
        # happens in send_json, where the module is in scope.
        try:
            base64_audio = base64.b64encode(pcm_bytes).decode("utf-8")

            extra = json if json is not None else {}
            msg = {
                "data": "data:audio/pcm;base64,{}".format(base64_audio),
                "last": is_last,
                **extra,
            }

            # Debug-level trace to confirm audio chunks are being sent.
            Logger.debug(f"Sending bytes: {len(pcm_bytes)} bytes - {is_last}")

            await self.send_json(msg)

        except Exception as ex:
            Logger.error(f"Error sending audio: {ex}")
            if self.on_error:
                self.on_error(ex)

    async def _receive_loop(self):
        """Dispatch incoming messages until the connection ends.

        Each frame is decoded with ``json.loads`` and passed to ``on_message``.
        Decode and callback errors are reported to ``on_error`` and the loop
        keeps going. On exit, for whatever reason, ``connection_ended`` is set
        and ``on_close`` is invoked exactly once.
        """
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                except Exception as ex:
                    if self.on_error:
                        self.on_error(ex)
                    continue

                if self.on_message:
                    try:
                        self.on_message(data)
                    except Exception as ex:
                        if self.on_error:
                            self.on_error(ex)

                # Checked independently of on_message so a failing callback
                # cannot hide the end marker; non-object payloads (e.g. a JSON
                # array) have no "type" key.
                if isinstance(data, dict) and data.get("type") == "EndOfConnection":
                    self.connection_ended.set()

        except websockets.ConnectionClosedOK:
            # Normal closing handshake, not an error. on_close is fired from
            # the finally clause below so it runs only once.
            pass

        except Exception as ex:
            if self.on_error:
                self.on_error(ex)

        finally:
            # Wake wait_for_completion() even if EndOfConnection never arrived;
            # otherwise a dropped connection would block it forever.
            self.connection_ended.set()
            if self.on_close:
                self.on_close()

    async def close(self) -> None:
        """Close the websocket and stop the receive loop.

        If a receive task was started, this waits for it to finish, so
        ``on_close`` has already run by the time this coroutine returns.
        """
        try:
            if self.websocket:
                await self.websocket.close()
        finally:
            task = self.receive_task
            if task is not None and not task.done():
                task.cancel()
                # asyncio.wait() does not re-raise the task's CancelledError,
                # but still propagates a cancellation of close() itself.
                await asyncio.wait({task})
