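"""gRPC text-to-speech streaming service.

Bridges a streaming LLM (providers.llm_provider) and the Kokoro TTS model:
text arrives over a bidirectional ProcessText stream, the LLM reply is cut
into sentences as tokens arrive, and each sentence is synthesised and
streamed back to the client as an audio frame tagged with a sequence id.

Run this module directly to serve on port 8081.
"""
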
from collections import deque
from concurrent import futures

import grpc
import torch

import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry
from models import build_model
from providers.audio_provider import (
    dummy_bytes,
    generate_audio_from_chunks,
    get_audio_bytes,
)
from providers.llm_provider import getResponse, getResponseWithRAG

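# Model state shared across requests: the Kokoro checkpoint and the default
# voicepack ('af') are loaded once at startup; ProcessText reloads VOICEPACK
# when a client selects a different voice.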
device = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL = build_model('kokoro-v0_19.pth', device)

VOICE_NAME = [
    'af',
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]

VOICEPACK = torch.load(
    f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)


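# Wire protocol: a client opens one bidirectional ProcessText stream per
# conversation, usually sending a `metadata` message first, then `text`
# utterances; `status` messages report playback progress back to the server.
# Responses carry sequence_id "-2" for the transcript-echo / interrupt
# sentinel, or the 1-based index of a synthesised sentence.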
class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):

    def ProcessText(self, request_iterator, context):
        global VOICEPACK
        try:
            print("Received new request")
            # Per-stream session state.
            parameters = {
                "processing_active": False,  # True while process_queue is draining
                "queue": deque(),            # (sentence, sequence number) pairs awaiting TTS
                "file_number": 0,            # monotonically increasing sequence id
                "session_id": "",
                "interrupt_seq": 0,          # last sequence id the client interrupted at
                "temperature": 1,            # kept from metadata but not consumed here
                "activeVoice": "af",
                "maxTokens": 500,            # kept from metadata but not consumed here
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    meta = request.metadata
                    print("Metadata received:")
                    print("  session_id:", meta.session_id)
                    print("  silenceDuration:", meta.silenceDuration)
                    print("  threshold:", meta.threshold)
                    print("  temperature:", meta.temperature)
                    print("  activeVoice:", meta.activeVoice)
                    print("  maxTokens:", meta.maxTokens)
                    # proto3 scalars default to falsy values, so unset fields
                    # leave the corresponding defaults in place.
                    if meta.session_id:
                        parameters["session_id"] = meta.session_id
                    if meta.temperature:
                        parameters["temperature"] = meta.temperature
                    if meta.maxTokens:
                        parameters["maxTokens"] = meta.maxTokens
                    if meta.activeVoice:
                        parameters["activeVoice"] = meta.activeVoice
                        # Load the voicepack for the newly selected voice.
                        VOICEPACK = torch.load(
                            f'voices/{parameters["activeVoice"]}.pt',
                            weights_only=True).to(device)
                    continue
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue
                    save_chat_entry(parameters["session_id"], "user", text)
                    # A new utterance invalidates anything still queued for TTS.
                    parameters["queue"].clear()
                    # Sentinel frame: sequence_id "-2" carries the user's
                    # transcript with a dummy audio buffer so the client can
                    # cut playback and display the text immediately.
                    yield text_to_speech_pb2.ProcessTextResponse(
                        buffer=dummy_bytes(),
                        session_id=parameters["session_id"],
                        sequence_id="-2",
                        transcript=text,
                    )
final_response = "" |
|
complete_response = "" |
|
response = getResponse(text, parameters["session_id"]) |
|
for chunk in response: |
|
msg = chunk.choices[0].delta.content |
|
if msg: |
|
final_response += msg |
|
complete_response += msg |
|
if final_response.endswith(('.', '!', '?')): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
final_response = "" |
|
if not parameters["processing_active"]: |
|
yield from self.process_queue(parameters) |
|
|
|
if final_response: |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
if not parameters["processing_active"]: |
|
yield from self.process_queue(parameters) |
|
|
|
if ("Let me check" in complete_response): |
|
final_response = "" |
|
complete_response = "" |
|
response = getResponseWithRAG( |
|
text, parameters["session_id"]) |
|
for chunk in response: |
|
msg = chunk.choices[0].delta.content |
|
if msg: |
|
final_response += msg |
|
complete_response += msg |
|
if final_response.endswith(('.', '!', '?')): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
final_response = "" |
|
if not parameters["processing_active"]: |
|
yield from self.process_queue(parameters) |
|
|
|
if final_response: |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
if not parameters["processing_active"]: |
|
yield from self.process_queue(parameters) |
|
|
|
elif field == 'status': |
|
transcript = request.status.transcript |
|
played_seq = request.status.played_seq |
|
interrupt_seq = request.status.interrupt_seq |
|
parameters["interrupt_seq"] = interrupt_seq |
|
save_chat_entry( |
|
parameters["session_id"], "assistant", transcript) |
|
continue |
|
else: |
|
continue |
|
except Exception as e: |
|
print("Error in ProcessText:", e) |
|
|
|
    def process_queue(self, parameters):
        """Drain the sentence queue through the TTS model, yielding one
        audio frame per sentence."""
        try:
            while True:
                if not parameters["queue"]:
                    parameters["processing_active"] = False
                    break
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Skip sentences the client has already interrupted past.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                # Synthesise with the session's currently active voice.
                combined_audio = generate_audio_from_chunks(
                    sentence, MODEL, VOICEPACK, parameters["activeVoice"])
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


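# NOTE: a single worker thread serialises all RPCs, which keeps the
# module-level VOICEPACK swap race-free but means one long-lived stream
# blocks every other client. Raise max_workers only after moving the
# voicepack into per-session state.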
def serve():
    print("Starting gRPC server...")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(
        TextToSpeechServicer(), server)
    server.add_insecure_port('[::]:8081')
    server.start()
    print("gRPC server is running on port 8081")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
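
# Example client, a minimal sketch: the request message and field names
# (ProcessTextRequest, Metadata, `text`) are assumed from the oneof fields
# handled above and may not match the actual .proto definitions.
#
#     channel = grpc.insecure_channel('localhost:8081')
#     stub = text_to_speech_pb2_grpc.TextToSpeechServiceStub(channel)
#
#     def requests():
#         yield text_to_speech_pb2.ProcessTextRequest(
#             metadata=text_to_speech_pb2.Metadata(session_id="demo"))
#         yield text_to_speech_pb2.ProcessTextRequest(text="Hello there.")
#
#     for reply in stub.ProcessText(requests()):
#         handle_audio(reply.buffer)  # hypothetical playback sink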