from pathlib import Path
import wave
from typing import Optional

from app.audio.audio_utils import AudioFormat
import app.core.constants as constants

class AudioFileWriter:
    """Write a stream of audio byte chunks to a ``.wav`` or ``.pcm`` file.

    The container is chosen from the file extension (case-insensitive):

    * ``.wav`` -- written with the stdlib :mod:`wave` module. An
      ``AudioFormat`` supplies the channel count, the sample width
      (``bits_per_sample // 8`` bytes) and the sample rate for the header.
    * ``.pcm`` -- headerless; chunks are written verbatim and flushed after
      every write. No format may be attached.

    The writer can be used as a context manager, which calls :meth:`close`
    on exit. For WAV output, ``close()`` is what finalises the header's
    length fields, so it must always be called.
    """

    # Sentinel for "format argument omitted". Using it instead of a concrete
    # default lets the fallback depend on the extension: WAV gets the project
    # default format, PCM gets None (passing any format for PCM is an error).
    _FORMAT_NOT_GIVEN = object()

    def __init__(self,
                 file_path: str,
                 format: Optional[AudioFormat] = _FORMAT_NOT_GIVEN):
        """Validate the arguments, create parent directories and open the file.

        Args:
            file_path: Destination path; its suffix must be ``.wav`` or
                ``.pcm``. Missing parent directories are created.
            format: Audio parameters for WAV output. If omitted, WAV uses
                ``constants.Defaults.DEFAULT_AUDIO_FORMAT`` and PCM uses
                ``None``. Must be ``None`` (or omitted) for PCM.

        Raises:
            ValueError: unsupported extension; WAV without a format; WAV
                whose ``bits_per_sample`` is not a whole number of bytes;
                PCM with a format.
            wave.Error: the ``wave`` module rejected a WAV parameter (the
                partially opened file handle is released before this
                propagates).
        """
        self.path = Path(file_path)
        ext = self.path.suffix.lower()
        self.is_wav = ext == ".wav"
        self.file = None

        if format is self._FORMAT_NOT_GIVEN:
            format = (constants.Defaults.DEFAULT_AUDIO_FORMAT
                      if self.is_wav else None)

        # Validate everything up front so a rejected call does not create
        # directories first.
        if ext == ".wav":
            if format is None:
                raise ValueError("Writing WAV requires an AudioFormat")
            if format.bits_per_sample % 8:
                # The WAV header stores the sample width in whole bytes;
                # floor division would silently record the wrong width.
                raise ValueError(
                    "WAV bits_per_sample must be a multiple of 8, "
                    f"got {format.bits_per_sample}")
        elif ext == ".pcm":
            if format is not None:
                raise ValueError("PCM files must not have a format")
        else:
            raise ValueError(f"Unsupported extension: {ext}")

        self.format = format
        self.path.parent.mkdir(parents=True, exist_ok=True)

        if self.is_wav:
            self.file = self._open_wav(format)
        else:
            self.file = open(self.path, "wb")

    def _open_wav(self, fmt: AudioFormat) -> "wave.Wave_write":
        """Open ``self.path`` for WAV writing and apply *fmt*'s parameters.

        If any parameter is rejected, the half-initialised writer is closed
        before the original error is re-raised, so no file handle leaks.
        """
        wav = wave.open(str(self.path), "wb")
        try:
            wav.setnchannels(fmt.channels)
            wav.setsampwidth(fmt.bits_per_sample // 8)
            wav.setframerate(fmt.sample_rate)
        except BaseException:
            # Wave_write.close() tries to write the header first, which fails
            # while parameters are missing, but it releases the underlying
            # file in a ``finally`` regardless. Swallow that secondary error
            # so the original one is what propagates.
            try:
                wav.close()
            except wave.Error:
                pass
            raise
        return wav

    def write_chunk(self, data: bytes) -> None:
        """Append *data* to the file.

        For WAV, the bytes go through ``writeframesraw``, which does not
        update the header's frame count; ``close()`` corrects it. For PCM,
        the bytes are written verbatim and flushed immediately.
        """
        if self.is_wav:
            self.file.writeframesraw(data)
        else:
            self.file.write(data)
            self.file.flush()

    def close(self) -> None:
        """Close the underlying file (patching the WAV header if needed).

        Calling this more than once is harmless.
        """
        if getattr(self, "file", None):
            self.file.close()

    def __enter__(self) -> "AudioFileWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always close; returning None lets any in-flight exception propagate.
        self.close()
