import asyncio
import traceback
from typing import Optional

from pathlib import Path

import typer

from speech_enhancement.audio.audio_file_writer import AudioFileWriter
from speech_enhancement.audio.audio_utils       import AudioFormat, base64_to_bytes, stream_file, stream_microphone
from speech_enhancement.core                    import constants
from speech_enhancement.core.websocket_client   import WebSocketClient
from speech_enhancement.core.logger             import Logger
from speech_enhancement.api.vdkclient           import VdkClient

# Module-level Typer application; ``main()`` runs it.
app = typer.Typer()

@app.command()
def start(
    scheme:          str           = typer.Option(constants.Defaults.DEFAULT_VDK_SCHEME, "--scheme", "-P", help="Protocol to use for the VDK service (http or https)."),
    host:            str           = typer.Option(constants.Defaults.DEFAULT_VDK_HOST,   "--host",   "-h", help="VDK service host."),
    port:            int           = typer.Option(constants.Defaults.DEFAULT_VDK_PORT,   "--port",   "-p", help="VDK service port."),
    input_file:      Optional[str] = typer.Option(None,  "--input",     "-i", help="Input audio file."),
    input_reference: Optional[str] = typer.Option(None,  "--reference", "-r", help="Input reference audio file. (AEC for Barge-In)"),
    output_file:     Optional[str] = typer.Option(None,  "--output",    "-o", help="Output audio file. (if not ext .wav : output PCM 16Khz 16 mono) "),
    enhancer:        Optional[str] = typer.Option(None,  "--enhancer",  "-e", help="Enhancer to use."),
    listEnhancers:   bool          = typer.Option(False, "--list",      "-l", help="List available enhancers."),
    verbose:         bool          = typer.Option(False, "--verbose",   "-v"),
):
    """Run a speech-enhancement session against a VDK service.

    Delegates all work to the async ``run`` coroutine and turns its integer
    result (0 on success, 1 on failure) into the process exit code.
    """
    code = asyncio.run(run(scheme=scheme, host=host, port=port, input_file=input_file, input_reference=input_reference,
                           output_file=output_file, enhancer=enhancer, listEnhancers=listEnhancers, verbose=verbose))
    # typer.Exit is an exception: it must be *raised* for the code to reach
    # the shell. Merely constructing it discards the code (always exit 0).
    raise typer.Exit(code=code)

def _is_missing_file(path: Optional[str]) -> bool:
    """Return True when *path* is given but does not name an existing regular file."""
    # Path.is_file() is already False for non-existent paths.
    return bool(path) and not Path(path).is_file()

def _format_exception(ex: BaseException) -> str:
    """Render *ex* with its traceback as a single string."""
    # format_exception entries already end with a newline, so join with "".
    return "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))

async def run(
    scheme: str,
    host: str,
    port: int,
    input_file: Optional[str] = None,
    input_reference: Optional[str] = None,
    output_file: Optional[str] = None,
    enhancer: Optional[str] = None,
    listEnhancers: bool = False,
    verbose: bool = False
) -> int:
    """Run one enhancement session (or just list enhancers) and return an exit code.

    Flow: validate arguments, check service health, fetch the enhancer list,
    request a token for the chosen enhancer, open a WebSocket, stream the
    optional reference file and then the input file (or the microphone), and
    write every audio chunk received back into the output file.

    Args:
        scheme: Protocol for the VDK service (http or https).
        host: VDK service host.
        port: VDK service port.
        input_file: Audio file to enhance; the microphone is streamed when omitted.
        input_reference: Optional reference audio, streamed before the input
            (used for AEC / barge-in according to the CLI help).
        output_file: Destination file; defaults to ``<input stem>_enhanced.wav``
            or ``mic_enhanced.wav``.
        enhancer: Enhancer name; required unless ``listEnhancers`` is set.
        listEnhancers: Only log the available enhancers and return.
        verbose: Enable debug logging.

    Returns:
        0 on success, 1 on any validation, service or runtime failure.
    """
    Logger.verbose = verbose

    # -- validating arguments --
    # Listing enhancers needs nothing else; otherwise an enhancer is mandatory
    # and any given input/reference path must be an existing file.
    if not listEnhancers:
        if not enhancer:
            Logger.error("You must specify an enhancer.")
            return 1
        for path in (input_file, input_reference):
            if _is_missing_file(path):
                Logger.error(f"File {path} does not exist.")
                return 1

    # -- Logging configuration --
    Logger.debug(" -- configuration -- ")
    Logger.debug(f" scheme:                 {scheme}")
    Logger.debug(f" host:                   {host}")
    Logger.debug(f" port:                   {port}")
    Logger.debug(f" input_file:             {input_file}")
    Logger.debug(f" input_reference:        {input_reference}")
    Logger.debug(f" output_file:            {output_file}")
    Logger.debug(f" enhancer:               {enhancer}")
    Logger.debug(f" list:                   {listEnhancers}")
    Logger.debug(f" verbose:                {verbose}")
    Logger.debug(" ------------------------ ")

    # -----------------------------------------------------------------
    # ---------------------- running application ----------------------
    # -----------------------------------------------------------------

    client = VdkClient(scheme, host, port)
    ws_client = None
    writer: Optional[AudioFileWriter] = None
    try:
        # ------- Making sure service is reachable and ready. -------
        if not await client.health.check_healthz():
            Logger.error("❌ The VDK service is not ready.")
            return 1

        Logger.info("✅ VDK service is reachable.")

        enhancers = await client.speech_enhancement.get_available_enhancers()

        # ------- Did the user request a list of enhancers? -------
        if listEnhancers:
            Logger.info(f"Available enhancers: {enhancers}")
            return 0

        # Checked before the membership test so an empty/None list reports
        # clearly instead of failing on `in`.
        if not enhancers:
            Logger.error("❌ No enhancers available.")
            return 1

        if enhancer and enhancer not in enhancers:
            Logger.error(f"❌ Enhancer {enhancer} is not available.")
            return 1

        selected_enhancer = enhancer or enhancers[0]

        # ------- Asking for an enhancement (retrieving a token) -------
        token = await client.speech_enhancement.enhance(selected_enhancer)
        if token is None:
            Logger.error("❌ Could not retrieve a token.")
            return 1

        # NOTE(review): this prints the session token at info level; consider debug.
        Logger.info(f"✅ Token: {token}")

        # ------- Preparing the output file -------
        if output_file:
            output_filename = output_file
        elif input_file:
            output_filename = f"{Path(input_file).stem}_enhanced.wav"
        else:
            output_filename = "mic_enhanced.wav"

        # Presumably (sample rate, channels, bits per sample) — confirm against AudioFormat.
        audio_format = AudioFormat(16000, 1, 16)
        writer = AudioFileWriter(output_filename, audio_format)

        # ------- Starting the websocket connection -------
        ws_client = WebSocketClient()
        ws_client.on_error = lambda ex: Logger.error("❌ WebSocket error: " + _format_exception(ex))
        ws_client.on_close = lambda: Logger.debug("WebSocket connection closed.")
        ws_client.on_message = lambda data: handle_message(data, writer)
        await ws_client.connect(client.ws_uri(token))

        Logger.info("✅ WebSocket connection established.")

        # Streaming reference if there is.
        if input_reference:
            await stream_file(ws_client, input_reference, realtime=True, is_reference=True)

        # Streaming input file, or the microphone when no file was given.
        if input_file:
            await stream_file(ws_client, input_file, realtime=True)
        else:
            await stream_microphone(ws_client)

        await ws_client.wait_for_completion()

        Logger.info("✅ Done.")

    except Exception as e:
        Logger.error("❌ Application failed: " + _format_exception(e))
        return 1
    finally:
        if ws_client is not None:
            await ws_client.close()
        await client.close()
        # Closed last (no more incoming chunks) and on every path, including
        # errors and Ctrl+C / cancellation, so the output file is not leaked.
        if writer is not None:
            writer.close()

    return 0

def handle_message(data, output_file: AudioFileWriter):
    """Process one message received on the WebSocket.

    The message is always debug-logged first. Then exactly one branch applies,
    checked in this order:
      * ``"event"`` key  -> log the event's ``code_string`` at info level;
      * ``"error"`` key  -> log the error's ``code`` at error level;
      * ``"data"`` and ``"last"`` keys -> decode the base64 payload (skipping
        the ``constants.B64_PCM_PREFIX`` prefix) and append it to *output_file*.
    Any other message is ignored after logging.
    """
    logMessage(data)

    if "event" in data:
        event_code = data["event"]["code_string"]
        Logger.info(f"Event: {event_code}")
        return

    if "error" in data:
        error_code = data["error"]["code"]
        Logger.error(f"Error: {error_code}")
        return

    if "data" in data and "last" in data:
        encoded_payload = data["data"]
        prefix_length = len(constants.B64_PCM_PREFIX)
        pcm_chunk = base64_to_bytes(encoded_payload, prefix_length)
        output_file.write_chunk(pcm_chunk)

def logMessage(data):
    """Debug-log a socket message with its audio payload redacted.

    A shallow copy of *data* is logged so the caller's message is left
    untouched. The value under ``"data"`` is swapped for a
    ``<base64data>`` placeholder, keeping the log readable.
    """
    redacted = {
        key: ("<base64data>" if key == "data" else value)
        for key, value in dict(data).items()
    }
    Logger.debug(f"Socket message: {redacted}")

def main():
    """Entry point that runs the module-level Typer ``app``."""
    cli = app
    cli()
