from concurrent import futures
import torch
from models import build_model
from collections import deque
import grpc
import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry
import fastAPI
from providers.audio_provider import get_audio_bytes, dummy_bytes, generate_audio_from_chunks
from providers.llm_provider import getResponseWithRAG, getResponse


# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model built once at import time from the kokoro-v0_19.pth checkpoint.
MODEL = build_model('kokoro-v0_19.pth', device)

# Known voice names; the first entry is the default selected at startup.
_VOICE_CHOICES = (
    'af',
    'af_bella',
    'af_sarah',
    'am_adam',
    'am_michael',
    'bf_emma',
    'bf_isabella',
    'bm_george',
    'bm_lewis',
    'af_nicole',
    'af_sky',
)
VOICE_NAME = _VOICE_CHOICES[0]


# Voicepack tensor for the default voice, moved to the selected device.
_default_voice_path = f'voices/{VOICE_NAME}.pt'
VOICEPACK = torch.load(_default_voice_path, weights_only=True).to(device)


class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    """Streaming text-to-speech servicer.

    Every ``ProcessText`` stream keeps its own settings dict (session id,
    active voice and voicepack, sentence queue, sequence counters). Incoming
    text is sent to ``getResponse``; the streamed answer is cut into
    sentences, each sentence is synthesized and streamed back to the client
    as a ``ProcessTextResponse``.
    """

    # A buffered fragment is handed to synthesis once it ends with one of these.
    _SENTENCE_ENDINGS = ('.', '!', '?')
    # Phrase in the first answer that triggers a second, RAG-backed answer.
    _RAG_TRIGGER = "Let me check"
    # Metadata fields echoed to stdout when a metadata message arrives.
    _METADATA_FIELDS = (
        "session_id", "silenceDuration", "threshold",
        "temperature", "activeVoice", "maxTokens",
    )

    def ProcessText(self, request_iterator, context):
        """Handle one ``ProcessText`` stream.

        Args:
            request_iterator: Iterator of request messages whose
                ``request_data`` oneof is ``metadata``, ``text`` or ``status``.
                Messages with any other (or no) field set are ignored.
            context: gRPC servicer context (not used).

        Yields:
            ``ProcessTextResponse`` messages: for every non-empty text an echo
            of that text with ``sequence_id="-2"``, followed by one message
            per synthesized sentence.

        Any exception raised while handling the stream is printed and ends
        the stream.
        """
        try:
            print("Received new request")
            # Per-stream state. The voice and its voicepack live here rather
            # than in module globals so one client's voice choice cannot leak
            # into another stream.
            parameters = {
                "processing_active": False,
                "queue": deque(),
                "file_number": 0,
                "session_id": "",
                "interrupt_seq": 0,
                # NOTE(review): temperature and maxTokens are recorded but not
                # forwarded to the LLM calls below -- confirm whether
                # getResponse should receive them.
                "temperature": 1,
                "activeVoice": VOICE_NAME,
                "maxTokens": 500,
                "voicepack": VOICEPACK,
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    self._apply_metadata(request.metadata, parameters)
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue
                    yield from self._answer_text(text, parameters)
                elif field == 'status':
                    status = request.status
                    parameters["interrupt_seq"] = status.interrupt_seq
                    save_chat_entry(
                        parameters["session_id"], "assistant",
                        status.transcript)
        except Exception as e:
            print("Error in ProcessText:", e)

    def _apply_metadata(self, meta, parameters):
        """Log a metadata message and copy its non-empty fields into ``parameters``."""
        print("Metadata received:")
        for name in self._METADATA_FIELDS:
            print(f"  {name}:", getattr(meta, name))
        print("Metadata : ", meta)
        if meta.session_id:
            parameters["session_id"] = meta.session_id
        if meta.temperature:
            parameters["temperature"] = meta.temperature
        if meta.maxTokens:
            parameters["maxTokens"] = meta.maxTokens
        # Reload the voicepack only when the requested voice actually changes.
        if meta.activeVoice and meta.activeVoice != parameters["activeVoice"]:
            parameters["voicepack"] = self._load_voicepack(meta.activeVoice)
            parameters["activeVoice"] = meta.activeVoice

    @staticmethod
    def _load_voicepack(voice_name):
        """Load ``voices/<voice_name>.pt`` and move it to the module's device.

        Raises:
            ValueError: If ``voice_name`` contains anything other than
                letters, digits and underscores. The name comes from the
                client and is interpolated into a file path, so path
                separators and dots must not get through.
        """
        if not voice_name.replace('_', '').isalnum():
            raise ValueError(f"Invalid voice name: {voice_name!r}")
        return torch.load(
            f'voices/{voice_name}.pt', weights_only=True).to(device)

    def _answer_text(self, text, parameters):
        """Store the user text, echo it, then stream the spoken answer(s)."""
        session_id = parameters["session_id"]
        save_chat_entry(session_id, "user", text)
        # Sentences still queued from an earlier text are dropped.
        parameters["queue"].clear()
        yield text_to_speech_pb2.ProcessTextResponse(
            buffer=dummy_bytes(),
            session_id=session_id,
            sequence_id="-2",
            transcript=text,
        )
        reply = yield from self._speak_stream(
            getResponse(text, session_id), parameters)
        if self._RAG_TRIGGER in reply:
            yield from self._speak_stream(
                getResponseWithRAG(text, session_id), parameters)

    def _speak_stream(self, chunks, parameters):
        """Split a streamed LLM answer into sentences and synthesize them.

        Yields the responses produced by ``process_queue`` and returns the
        complete answer text (available to the caller through ``yield from``).
        """
        pending = ""
        pieces = []
        for chunk in chunks:
            # Skip chunks that carry no choices instead of failing on [0].
            if not chunk.choices:
                continue
            msg = chunk.choices[0].delta.content
            if not msg:
                continue
            pending += msg
            pieces.append(msg)
            if pending.endswith(self._SENTENCE_ENDINGS):
                self._enqueue_sentence(parameters, pending)
                pending = ""
                if not parameters["processing_active"]:
                    yield from self.process_queue(parameters)

        # Whatever is left after the stream ends is spoken as a last sentence.
        if pending:
            self._enqueue_sentence(parameters, pending)
            if not parameters["processing_active"]:
                yield from self.process_queue(parameters)
        return "".join(pieces)

    @staticmethod
    def _enqueue_sentence(parameters, sentence):
        """Append ``sentence`` to the queue under the next sequence number."""
        parameters["file_number"] += 1
        parameters["queue"].append((sentence, parameters["file_number"]))

    def process_queue(self, parameters):
        """Synthesize queued sentences and yield them until the queue is empty.

        Args:
            parameters: Per-stream settings dict. ``queue`` holds
                ``(sentence, file_number)`` tuples; ``voicepack`` and
                ``activeVoice`` fall back to the module defaults when absent.

        Yields:
            One ``ProcessTextResponse`` per synthesized sentence, with
            ``sequence_id`` set to the sentence's file number.

        Sentences numbered at or below ``interrupt_seq`` are discarded
        without synthesis. An exception is printed and stops the drain;
        entries not yet popped stay in the queue.
        """
        try:
            voicepack = parameters.get("voicepack", VOICEPACK)
            # Pass the voice that matches the voicepack in use.
            # NOTE(review): presumably generate_audio_from_chunks derives
            # voice-specific settings from this name -- confirm in the provider.
            voice_name = parameters.get("activeVoice", VOICE_NAME)
            while parameters["queue"]:
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                if file_number <= int(parameters["interrupt_seq"]):
                    continue

                combined_audio = generate_audio_from_chunks(
                    sentence, MODEL, voicepack, voice_name)
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
            parameters["processing_active"] = False
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


def serve(port=8081, max_workers=1):
    """Start the TextToSpeech gRPC server and block until it terminates.

    Args:
        port: TCP port to listen on; the server binds ``[::]:<port>`` without
            transport security.
        max_workers: Number of threads in the pool that runs RPC handlers.

    Raises:
        RuntimeError: If the port could not be bound.
    """
    print("Starting gRPC server...")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(
        TextToSpeechServicer(), server)
    bound_port = server.add_insecure_port(f'[::]:{port}')
    # Some grpcio versions report a failed bind by returning 0 instead of
    # raising; fail loudly rather than start a server nobody can reach.
    if not bound_port:
        raise RuntimeError(f"Could not bind gRPC server to port {port}")
    server.start()
    print(f"gRPC server is running on port {bound_port}")
    server.wait_for_termination()


# Start the server only when this file is run as a script, not when imported.
if __name__ == "__main__":
    serve()