"""gRPC text-to-speech streaming server.

Receives a bidirectional stream of client requests (metadata / text /
playback status), generates an LLM reply for each text message,
synthesizes the reply sentence-by-sentence with the Kokoro TTS model,
and streams audio buffers back to the client as they become ready.
"""

from collections import deque
from concurrent import futures

import grpc
import torch

import fastAPI  # NOTE(review): appears unused; kept in case importing it has side effects
import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry
from models import build_model
from providers.audio_provider import (
    dummy_bytes,
    generate_audio_from_chunks,
    get_audio_bytes,
)
from providers.llm_provider import getResponse, getResponseWithRAG

# Run inference on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = build_model('kokoro-v0_19.pth', device)

# Default voice: first entry of the supported-voice list ('af').
VOICE_NAME = [
    'af', 'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]

# Currently loaded voice embedding; replaced at runtime when a client's
# metadata selects a different voice (see _apply_metadata).
VOICEPACK = torch.load(
    f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)

# Characters that mark the end of a speakable sentence in the LLM stream.
_SENTENCE_ENDINGS = ('.', '!', '?')


class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    """Implements the TextToSpeechService bidirectional streaming RPC."""

    def ProcessText(self, request_iterator, context):
        """Handle one client stream.

        Dispatches each incoming message on its oneof field:

        * ``metadata`` — update per-stream settings (and possibly the
          global voicepack).
        * ``text``     — a user utterance; stream back synthesized audio
          for the LLM's reply.
        * ``status``   — client playback progress; record the interrupt
          point and persist the transcript actually played.

        Yields ``ProcessTextResponse`` messages to the client.
        """
        try:
            global VOICEPACK
            print("Received new request")
            # Per-stream mutable state shared with the helpers below.
            parameters = {
                "processing_active": False,  # True while process_queue is draining
                "queue": deque(),            # (sentence, sequence_number) awaiting TTS
                "file_number": 0,            # monotonically increasing sequence id
                "session_id": "",
                "interrupt_seq": 0,          # last sequence the client interrupted at
                "temperature": 1,
                "activeVoice": "af",
                "maxTokens": 500,
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    self._apply_metadata(request, parameters)
                elif field == 'text':
                    yield from self._handle_text(request.text, parameters)
                elif field == 'status':
                    status = request.status
                    # Remember where the client interrupted so already
                    # superseded sentences are skipped during synthesis.
                    parameters["interrupt_seq"] = status.interrupt_seq
                    save_chat_entry(
                        parameters["session_id"], "assistant",
                        status.transcript)
                # Any other/unset field is ignored.
        except Exception as e:
            # Top-level RPC boundary: report and end the stream.
            print("Error in ProcessText:", e)

    def _apply_metadata(self, request, parameters):
        """Copy client-supplied settings into the per-stream state.

        Also reloads the module-level VOICEPACK when a new voice is
        selected.
        """
        global VOICEPACK
        meta = request.metadata
        print("Metadata received:")
        print(" session_id:", meta.session_id)
        print(" silenceDuration:", meta.silenceDuration)
        print(" threshold:", meta.threshold)
        print(" temperature:", meta.temperature)
        print(" activeVoice:", meta.activeVoice)
        print(" maxTokens:", meta.maxTokens)
        print("Metadata : ", request.metadata)
        if meta.session_id:
            parameters["session_id"] = meta.session_id
        if meta.temperature:
            parameters["temperature"] = meta.temperature
        if meta.maxTokens:
            parameters["maxTokens"] = meta.maxTokens
        if meta.activeVoice:
            parameters["activeVoice"] = meta.activeVoice
            # NOTE(review): mutates module-level state; safe only while
            # the server runs a single worker thread (see serve()).
            VOICEPACK = torch.load(
                f'voices/{parameters["activeVoice"]}.pt',
                weights_only=True).to(device)

    def _handle_text(self, text, parameters):
        """Generate and stream the spoken reply for one user utterance.

        Yields ``ProcessTextResponse`` messages (a start marker, then
        audio buffers). If the first LLM pass asks to consult retrieval
        ("Let me check"), a second RAG-backed pass is streamed as well.
        """
        if not text:
            return
        save_chat_entry(parameters["session_id"], "user", text)
        # Drop any sentences still queued from a previous utterance.
        parameters["queue"].clear()
        # sequence_id "-2" signals the client that a new answer starts.
        yield text_to_speech_pb2.ProcessTextResponse(
            buffer=dummy_bytes(),
            session_id=parameters["session_id"],
            sequence_id="-2",
            transcript=text,
        )
        complete_response = yield from self._stream_llm_to_queue(
            getResponse(text, parameters["session_id"]), parameters)
        # The model signals that it needs retrieval with this phrase.
        if "Let me check" in complete_response:
            yield from self._stream_llm_to_queue(
                getResponseWithRAG(text, parameters["session_id"]),
                parameters)

    def _stream_llm_to_queue(self, response, parameters):
        """Consume an LLM token stream, enqueueing complete sentences.

        Accumulates streamed deltas until a sentence terminator is seen,
        enqueues each complete sentence (and any trailing fragment) with
        its sequence number, and kicks off synthesis whenever the queue
        is not already being drained.

        Yields audio responses produced by ``process_queue``; returns the
        full concatenated response text (retrieved by the caller via
        ``yield from``, PEP 380).
        """
        pending = ""
        complete = ""
        for chunk in response:
            msg = chunk.choices[0].delta.content
            if not msg:
                continue
            pending += msg
            complete += msg
            if pending.endswith(_SENTENCE_ENDINGS):
                parameters["file_number"] += 1
                parameters["queue"].append(
                    (pending, parameters["file_number"]))
                pending = ""
                if not parameters["processing_active"]:
                    yield from self.process_queue(parameters)
        # Flush any trailing partial sentence.
        if pending:
            parameters["file_number"] += 1
            parameters["queue"].append(
                (pending, parameters["file_number"]))
            if not parameters["processing_active"]:
                yield from self.process_queue(parameters)
        return complete

    def process_queue(self, parameters):
        """Drain the sentence queue, yielding one audio response each.

        Skips sentences whose sequence number the client has already
        interrupted past. Clears ``processing_active`` when the queue is
        empty or on error so a later enqueue can restart draining.
        """
        global VOICEPACK
        try:
            while parameters["queue"]:
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Client already moved past this sequence — skip it.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                combined_audio = generate_audio_from_chunks(
                    sentence, MODEL, VOICEPACK, VOICE_NAME)
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
            parameters["processing_active"] = False
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


def serve():
    """Start the gRPC server on port 8081 and block until termination."""
    print("Starting gRPC server...")
    # Single worker: requests are serialized, which keeps the global
    # VOICEPACK mutation in _apply_metadata safe.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(
        TextToSpeechServicer(), server)
    server.add_insecure_port('[::]:8081')
    server.start()
    print("gRPC server is running on port 8081")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()