from typing import Optional
import numpy as np
import sounddevice as sd
import time
import threading
from collections import deque

from voice_synthesis.core.logger import Logger

class AudioPCMPlayer:
    """Play raw 16-bit PCM audio pushed incrementally by producer code.

    Producers enqueue byte chunks with :meth:`add_samples`. A ``sounddevice``
    output-stream callback pulls bytes off the queue and writes them to the
    stream. It emits silence whenever the queue is empty.

    Bytes are decoded as native-endian ``int16`` samples. With more than one
    channel, consecutive samples are grouped into frames of ``channels``
    samples (interleaved). Chunk boundaries need not align to samples or
    frames: an incomplete trailing frame is carried over to the next callback.

    Usable as a context manager; leaving the ``with`` block calls :meth:`close`.
    """

    def __init__(self, sample_rate: int = 22050, channels: int = 1):
        self.sample_rate = sample_rate
        self.channels = channels

        # Pending PCM byte chunks: appended by add_samples(), drained by
        # _callback(). Guarded by _lock since the callback is invoked by
        # sounddevice (presumably from its own audio thread).
        self.buffer = deque()
        self._lock = threading.Lock()
        # Bytes that did not complete a frame in the last callback; prepended
        # to the next read so decoding never sees a misaligned length.
        self._partial = b""

        Logger.info(f"Creating audio player: {sample_rate}Hz, {channels} channel(s)")

        self.stream = sd.OutputStream(
            samplerate=sample_rate,
            channels=channels,
            dtype='int16',
            callback=self._callback
        )
        try:
            self.stream.start()
        except Exception:
            # Don't leak the opened stream if it cannot be started.
            self.stream.close()
            raise

    def _callback(self, outdata, frames, time_info, status):
        """Fill ``outdata`` with the next ``frames`` frames of queued audio.

        Pads the block with zeros (silence) when not enough audio is queued.
        Only whole frames are decoded; any trailing partial frame is kept in
        ``self._partial`` for the next call.
        """
        if status:
            Logger.debug(f"Audio status: {status}")

        frame_bytes = self.channels * 2  # int16 => 2 bytes per sample
        bytes_needed = frames * frame_bytes

        with self._lock:
            if not self.buffer:
                outdata.fill(0)
                return

            collected = bytearray(self._partial)
            while self.buffer and len(collected) < bytes_needed:
                chunk = self.buffer.popleft()  # popping audio
                needed = bytes_needed - len(collected)

                if len(chunk) <= needed:
                    collected.extend(chunk)
                else:
                    collected.extend(chunk[:needed])
                    self.buffer.appendleft(chunk[needed:])  # put the rest back
                    break

            # np.frombuffer/reshape need a whole number of frames; stash the
            # remainder instead of crashing the callback on odd-sized chunks.
            usable = len(collected) - len(collected) % frame_bytes
            self._partial = bytes(collected[usable:])
            del collected[usable:]

        if not collected:
            outdata.fill(0)
            return

        samples = np.frombuffer(collected, dtype='int16')
        filled = len(samples) // self.channels
        outdata[:filled] = samples.reshape(-1, self.channels)
        outdata[filled:] = 0  # under-run: pad the rest of the block with silence

    def add_samples(self, pcm_data: bytes) -> None:
        """Enqueue a chunk of raw int16 PCM bytes for playback.

        The chunk may have any length; it need not contain whole samples or frames.
        """
        with self._lock:
            self.buffer.append(pcm_data)

    def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
        """Block until every queued chunk has been consumed by the callback.

        Polls every 50 ms. Returning does not mean the device has finished
        playing the last block the callback wrote. An incomplete trailing frame
        held back by the callback does not count as pending.

        Args:
            timeout: Maximum number of seconds to wait. ``None`` (the default)
                waits with no limit.

        Returns:
            ``True`` if the queue drained. ``False`` if ``timeout`` elapsed, or
            if the stream is closed or no longer active; in those cases the
            queue would never drain.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while self.buffer:
            stream = self.stream
            if stream is None or not stream.active:
                return False
            if deadline is not None and time.monotonic() >= deadline:
                return False
            time.sleep(0.05)
        return True

    def close(self) -> None:
        """Stop and close the output stream and discard any queued audio.

        Safe to call more than once; later calls only clear the queue.
        """
        stream = self.stream
        self.stream = None  # makes repeated close() (e.g. close + __exit__) a no-op
        if stream is not None:
            try:
                # stop() rather than abort(): per sounddevice docs, stop() lets
                # pending buffers play out.
                stream.stop()
            finally:
                stream.close()
        with self._lock:
            self.buffer.clear()
            self._partial = b""

    def __enter__(self):
        """Return the player itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the player; exceptions from the ``with`` body propagate."""
        self.close()
