import sounddevice as sd
import time
from typing import Callable, Optional, Awaitable
import asyncio

from voice_synthesis.audio.audio_utils import AudioFormat
from voice_synthesis.core.logger import Logger

class AudioCapture:
    """
    Capture microphone audio in 16-bit PCM format.

    Audio is read by a ``sounddevice.InputStream`` in blocks of ``period_ms``
    milliseconds. Each block is converted to raw bytes and passed to the async
    ``on_audio_data`` handler, which is scheduled on the event loop given to
    :meth:`start`. The stream callback is invoked by sounddevice outside the
    event loop, which is why coroutines are handed over with
    ``asyncio.run_coroutine_threadsafe``.

    Can be used as a context manager: ``with AudioCapture(fmt) as cap: ...``.
    """
    def __init__(self, format: AudioFormat):
        """
        Args:
            format: Audio format; its ``sample_rate`` and ``channels`` are used
                to configure the input stream.
        """
        self.format = format
        # Duration of each captured block in milliseconds; start() derives the
        # stream blocksize (in frames) from it.
        self.period_ms = 30
        self.stream = None
        self.on_audio_data: Optional[Callable[[bytes], Awaitable[None]]] = None  # Async callback
        # Event loop that receives on_audio_data coroutines; assigned in start().
        self._loop: Optional[asyncio.AbstractEventLoop] = None

        self._is_recording = False
        # Set when capture stops (awaited by wait_for_completion); cleared on
        # each successful start() so the instance can be restarted.
        self._stop_event = asyncio.Event()

    @property
    def sample_rate(self):
        """Sample rate taken from the configured audio format."""
        return self.format.sample_rate

    @property
    def channels(self):
        """Channel count taken from the configured audio format."""
        return self.format.channels

    def _callback(self, indata, frames, time_info, status):
        """sounddevice stream callback: forward one captured block.

        Blocks are dropped when no ``on_audio_data`` handler is set or the
        target event loop is not running.
        """
        # Read the handler once so a concurrent reassignment to None between
        # the check and the call cannot cause a TypeError.
        handler = self.on_audio_data
        if handler is None:
            return

        if status:
            Logger.debug(f"Audio stream status: {status}")

        loop = self._loop
        if loop is None or not loop.is_running():
            return

        coro = handler(indata.tobytes())
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop closed between the is_running() check and scheduling;
            # close the coroutine so it is not reported as "never awaited".
            coro.close()
            return
        # Without an observer, exceptions raised by the handler would be
        # silently discarded inside the returned future.
        future.add_done_callback(self._report_handler_error)

    @staticmethod
    def _report_handler_error(future):
        """Log the exception raised by an on_audio_data coroutine, if any."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            Logger.info(f"on_audio_data handler failed: {exc!r}")

    def start(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """Start capturing microphone audio.

        No-op if already recording.

        Args:
            loop: Event loop on which ``on_audio_data`` coroutines are
                scheduled. Defaults to ``asyncio.get_event_loop()``; on newer
                Python versions calling that outside a running loop is
                deprecated, so pass the loop explicitly from synchronous code.

        Raises:
            Any exception raised while opening or starting the stream. In that
            case the instance is left not recording and holds no stream.
        """
        if self._is_recording:
            return

        self._is_recording = True
        self._loop = loop or asyncio.get_event_loop()

        stream = None
        try:
            stream = sd.InputStream(
                samplerate=self.sample_rate,
                channels=self.channels,
                dtype="int16",
                blocksize=int(self.sample_rate * self.period_ms / 1000),
                callback=self._callback
            )
            stream.start()
        except Exception:
            self._is_recording = False
            if stream is not None:
                # The stream was opened but failed to start: release the device
                # instead of leaking it (stop() would skip it since we are not
                # marked as recording).
                try:
                    stream.close()
                except Exception as close_err:
                    Logger.info(f"Error closing stream: {close_err}")
            raise

        self.stream = stream
        # Re-arm the completion signal so wait_for_completion() blocks again
        # after a previous stop().
        self._stop_event.clear()

    def stop(self):
        """Stop audio capture and release the input stream.

        No-op if not recording. Errors while stopping or closing the stream are
        logged rather than raised. Signals ``wait_for_completion()`` afterwards.
        """
        if not self._is_recording:
            return

        self._is_recording = False

        stream, self.stream = self.stream, None
        if stream is not None:
            try:
                stream.stop()
            except Exception as e:
                Logger.info(f"Error stopping stream: {e}")
            finally:
                # Close even if stop() failed so the device is always released.
                try:
                    stream.close()
                except Exception as e:
                    Logger.info(f"Error closing stream: {e}")

        self._signal_stopped()

    def _signal_stopped(self):
        """Set the stop event from whichever thread stop() was called on.

        asyncio.Event is not thread-safe: when the target loop is running on a
        different thread, the set() is marshalled onto that loop instead of
        being performed directly.
        """
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                current = asyncio.get_running_loop()
            except RuntimeError:
                current = None
            if current is not loop:
                try:
                    loop.call_soon_threadsafe(self._stop_event.set)
                    return
                except RuntimeError:
                    # Loop closed in the meantime; nothing can be waiting on
                    # it anymore, so setting directly is safe.
                    pass
        self._stop_event.set()

    async def wait_for_completion(self):
        """Wait until recording stops."""
        await self._stop_event.wait()

    @property
    def is_recording(self):
        """Check if currently recording."""
        return self._is_recording

    def __enter__(self):
        """Start capture (using the default event loop) and return self."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop capture; exceptions from the ``with`` body are not suppressed."""
        self.stop()
