|
from concurrent import futures |
|
import asyncio |
|
import torch |
|
from models import build_model |
|
from collections import deque |
|
import grpc |
|
import text_to_speech_pb2 |
|
import text_to_speech_pb2_grpc |
|
from chat_database import save_chat_entry |
|
from providers.audio_provider import get_audio_bytes, dummy_bytes, generate_audio_stream |
|
from providers.llm_provider import getResponseWithRagAsync, getResponseAsync |
|
import numpy as np |
|
import os |
|
import re |
|
import time |
|
from silero_vad import load_silero_vad, VADIterator |
|
import random |
|
from providers.filler_words import filler_phrases |
|
from scipy.io.wavfile import write |
|
from faster_whisper import WhisperModel |
|
|
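# Silero VAD runs on 16 kHz mono audio and consumes fixed 512-sample frames.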
|
sampling_rate = 16_000 |
|
vad_model = load_silero_vad() |
|
frame_size = 512 |
|
DEFAULT_VAD_THRESHOLD = 0.50 |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
whisper_model = WhisperModel( |
|
"small", |
|
device=device, |
|
compute_type="int8", |
|
cpu_threads=os.cpu_count(), |
|
download_root="./models" |
|
) |
|
|
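# Kokoro TTS model and the default voicepack; VOICE_NAME picks 'af' from the list below.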
|
MODEL = build_model('kokoro-v0_19.pth', device) |
|
VOICE_NAME = [ |
|
'af', |
|
'af_bella', 'af_sarah', 'am_adam', 'am_michael', |
|
'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis', |
|
'af_nicole', 'af_sky', |
|
][0] |
|
|
|
|
|
VOICEPACK = torch.load( |
|
f'voices/{VOICE_NAME}.pt', weights_only=True).to(device) |
|
|
|
AUDIO_FILES_DIR = 'audio_files' |
|
os.makedirs(AUDIO_FILES_DIR, exist_ok=True) |
|
|
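# Pre-roll retained ahead of speech onset: 20 frames x 512 samples x 2 bytes (int16)
# = 20,480 bytes, about 0.64 s at 16 kHz.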
|
PRE_CHUNK_LIMIT_BYTES = frame_size * 2 * 20 |
|
|
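# Dedicated pool so blocking Whisper decodes never stall the asyncio event loop.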
|
transcription_pool = futures.ThreadPoolExecutor(max_workers=10) |
|
|
|
terminators = ['.', '?', '!', '...', '…', '?!', '!?', '‽', '。', '؟', '۔'] |
|
|
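# Filler lines that must not be persisted as assistant turns when the client
# echoes them back in a status update.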
|
BLACKLIST = { |
|
"Give me a minute.", |
|
"Let me check the details.", |
|
"Give me a minute. Let me check the details." |
|
} |
|
|
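# 9,600 bytes of 16-bit mono PCM at 16 kHz is 0.3 s; utterances shorter than this
# get the pre-roll re-appended when VAD reports "end".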
|
SHORT_UTTERANCE_BYTES = 9_600 |
|
|
|
|
|
# 5 s of silence (float32 at 16 kHz), kept as a placeholder buffer.
dummy_audio = np.zeros(int(sampling_rate * 5.0), dtype=np.float32)
|
|
|
|
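# Blocking faster-whisper decode: greedy search (beam_size=1), forced English,
# int8 weights for speed.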
|
def _fw_transcribe_block(audio_f32: np.ndarray) -> dict: |
|
segments, info = whisper_model.transcribe( |
|
audio_f32, |
|
language="en", |
|
beam_size=1, |
|
vad_filter=False, |
|
initial_prompt="Indian English accent; do not make up words.", |
|
no_speech_threshold=0.25, |
|
log_prob_threshold=-0.6, |
|
temperature=0 |
|
) |
|
text = "".join(seg.text for seg in segments) |
|
return {"text": text, "language": info.language, |
|
"language_probability": info.language_probability} |
|
|
|
|
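# Async wrapper that runs the blocking decode on the thread pool, keeping the
# event loop responsive.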
|
async def safe_transcribe(audio_float32: np.ndarray): |
|
loop = asyncio.get_running_loop() |
|
return await loop.run_in_executor( |
|
transcription_pool, |
|
lambda: _fw_transcribe_block(audio_float32) |
|
) |
|
|
|
|
|
class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer): |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self._transcribe_lock = asyncio.Lock() |
|
|
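    # Negative sequence_id values act as control signals for the client (meanings
    # inferred from their usage below): -10 metadata acknowledged, -3 end of
    # utterance detected, -5 empty transcript, -4 clear the playback queue
    # (barge-in), -2 echo of the user's transcript.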
|
async def ProcessText(self, request_iterator, context): |
|
try: |
|
|
|
print("New connection") |
|
|
|
tts_queue = asyncio.Queue() |
|
response_queue = asyncio.Queue() |
|
|
|
|
|
|
|
parameters = { |
|
"processing_active": False, |
|
"queue": deque(), |
|
"file_number": 0, |
|
"session_id": "", |
|
"interrupt_seq": 0, |
|
"temperature": 1, |
|
"activeVoice": "af", |
|
"in_speech": False, |
|
"maxTokens": 500, |
|
"audio_buffer": bytearray(), |
|
"pre_chunks": bytearray(), |
|
"silence_counter": 0.0, |
|
"silence_duration": 0.8, |
|
"silence_threshold": DEFAULT_VAD_THRESHOLD, |
|
"VOICEPACK": VOICEPACK, |
|
"audio_count": 0, |
|
"user_sequence": 0, |
|
"last_file_number": 0, |
|
"vad_iter": VADIterator(vad_model, sampling_rate=sampling_rate, threshold=DEFAULT_VAD_THRESHOLD) |
|
} |
|
|
|
reader = asyncio.create_task( |
|
self._read_requests(request_iterator, tts_queue, response_queue, parameters)) |
|
|
|
tts = asyncio.create_task(self._tts_queue_worker( |
|
tts_queue, response_queue, parameters)) |
|
|
|
try: |
|
while True: |
|
resp = await response_queue.get() |
|
if resp is None: |
|
break |
|
yield resp |
|
finally: |
|
for t in (reader, tts): |
|
t.cancel() |
|
try: |
|
await t |
|
except asyncio.CancelledError: |
|
pass |
|
|
|
except Exception as e: |
|
print("Error in ProcessText:", e) |
|
|
|
async def _read_requests(self, request_iterator, tts_queue: asyncio.Queue, response_queue: asyncio.Queue, parameters): |
|
async for request in request_iterator: |
|
field = request.WhichOneof('request_data') |
|
if field == 'metadata': |
|
meta = request.metadata |
|
|
|
if meta.session_id: |
|
parameters["session_id"] = meta.session_id |
|
if meta.temperature: |
|
parameters["temperature"] = meta.temperature |
|
if meta.maxTokens: |
|
parameters["maxTokens"] = meta.maxTokens |
|
if meta.activeVoice: |
|
parameters["activeVoice"] = meta.activeVoice |
|
parameters["VOICEPACK"] = torch.load( |
|
f'voices/{parameters["activeVoice"]}.pt', weights_only=True).to(device) |
|
print("\n\nVoice model loaded successfully") |
|
if meta.silenceDuration: |
|
silence_duration = meta.silenceDuration / 1000 |
|
parameters["silence_duration"] = silence_duration |
|
if meta.threshold: |
|
parameters["silence_threshold"] = meta.threshold / 100 |
|
parameters["vad_iter"] = VADIterator( |
|
vad_model, sampling_rate=sampling_rate, threshold=parameters["silence_threshold"]) |
|
|
|
print("\n\nPatameter : ", parameters) |
|
|
|
|
|
|
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=dummy_bytes(), |
|
session_id=parameters["session_id"], |
|
sequence_id="-10", |
|
transcript="", |
|
) |
|
await response_queue.put(resp) |
|
|
|
continue |
|
elif field == 'audio_data': |
|
|
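                # Interpret the payload as 16-bit PCM and normalize to float32
                # in [-1.0, 1.0) for the VAD.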
|
buffer = request.audio_data.buffer |
|
|
|
audio_data = np.frombuffer(buffer, dtype=np.int16) |
|
|
|
float_chunk = audio_data.astype(np.float32) / 32768.0 |
|
|
|
vad_result = parameters["vad_iter"](float_chunk) |
|
|
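                # Rolling pre-roll buffer: keep only the newest PRE_CHUNK_LIMIT_BYTES
                # so a speech onset is never clipped.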
|
parameters["pre_chunks"].extend(buffer) |
|
if len(parameters["pre_chunks"]) > PRE_CHUNK_LIMIT_BYTES: |
|
overflow = len( |
|
parameters["pre_chunks"]) - PRE_CHUNK_LIMIT_BYTES |
|
del parameters["pre_chunks"][:overflow] |
|
|
|
if vad_result: |
|
if "start" in vad_result: |
|
parameters["in_speech"] = True |
|
parameters["audio_buffer"].extend( |
|
parameters["pre_chunks"]) |
|
if "end" in vad_result: |
|
if (len(parameters["audio_buffer"]) < SHORT_UTTERANCE_BYTES): |
|
parameters["audio_buffer"].extend( |
|
parameters["pre_chunks"]) |
|
parameters["in_speech"] = False |
|
|
|
if parameters["in_speech"]: |
|
parameters["audio_buffer"].extend(buffer) |
|
parameters["silence_counter"] = 0.0 |
|
parameters["audio_count"] += 1 |
|
else: |
|
                    duration = len(audio_data) / sampling_rate

                    parameters["silence_counter"] += duration
|
|
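                    # Enough trailing silence: close the utterance and hand it to Whisper.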
|
if parameters["silence_counter"] >= parameters["silence_duration"]: |
|
parameters["silence_counter"] = 0.0 |
|
if parameters["audio_count"] < 2: |
|
parameters["audio_count"] = 0 |
|
continue |
|
parameters["audio_count"] = 0 |
|
print("Silence ") |
|
|
|
|
|
audio_float = np.frombuffer( |
|
parameters["audio_buffer"], dtype=np.int16).astype(np.float32) / 32768.0 |
|
|
|
parameters["audio_buffer"] = bytearray() |
|
|
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=dummy_bytes(), |
|
session_id=parameters["session_id"], |
|
sequence_id="-3", |
|
transcript="", |
|
) |
|
await response_queue.put(resp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
whisper_start_time = time.time() |
|
result = "" |
|
try: |
|
                            result = await safe_transcribe(audio_float)
|
|
|
                        except Exception as e:

                            await tts_queue.put(("Sorry, I didn't catch that. Could you say that again, please?", parameters["file_number"]))

                            print("Error while transcribing:", e)

                            continue
|
|
|
whisper_end_time = time.time() |
|
time_taken_to_transcribe = whisper_end_time - whisper_start_time |
|
                        print(f"Transcribing time: {time_taken_to_transcribe:.4f} seconds")

                        transcribed_text = result["text"]

                        print("Transcribed text:", transcribed_text)
|
|
|
if not transcribed_text.strip(): |
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=dummy_bytes(), |
|
session_id=parameters["session_id"], |
|
sequence_id="-5", |
|
transcript="", |
|
) |
|
await response_queue.put(resp) |
|
continue |
|
|
|
|
|
|
|
|
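                        # Barge-in: a new utterance invalidates any in-flight LLM
                        # stream and drains both queues of stale output.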
|
if transcribed_text: |
|
parameters["queue"].clear() |
|
parameters["user_sequence"] += 1 |
|
parameters["last_file_number"] = parameters["file_number"] |
|
while not response_queue.empty(): |
|
try: |
|
response_queue.get_nowait() |
|
response_queue.task_done() |
|
except asyncio.QueueEmpty: |
|
break |
|
while not tts_queue.empty(): |
|
try: |
|
tts_queue.get_nowait() |
|
tts_queue.task_done() |
|
except asyncio.QueueEmpty: |
|
break |
|
|
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=dummy_bytes(), |
|
session_id=parameters["session_id"], |
|
sequence_id="-4", |
|
transcript="", |
|
) |
|
await response_queue.put(resp) |
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=dummy_bytes(), |
|
session_id=parameters["session_id"], |
|
sequence_id="-2", |
|
transcript=transcribed_text, |
|
) |
|
save_chat_entry( |
|
parameters["session_id"], "user", transcribed_text) |
|
await response_queue.put(resp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
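                            # Stream the LLM reply, cutting each delta at the first
                            # sentence terminator so finished sentences reach TTS
                            # immediately; a delta containing several terminators is
                            # only split at the first one.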
|
final_response = "" |
|
complete_response = "" |
|
current_user_sequence = parameters["user_sequence"] |
|
response = await getResponseAsync( |
|
transcribed_text, parameters["session_id"]) |
|
if response is None: |
|
continue |
|
for chunk in response: |
|
if (current_user_sequence != parameters["user_sequence"]): |
|
break |
|
msg = chunk.choices[0].delta.content |
|
if msg: |
|
complete_response += msg |
|
m = re.search(r'[.?!]', msg) |
|
if m: |
|
idx = m.start() |
|
segment = msg[:idx+1] |
|
leftover = msg[idx+1:] |
|
else: |
|
segment, leftover = msg, '' |
|
|
|
final_response += segment |
|
|
|
if segment.endswith(('.', '!', '?')): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
await tts_queue.put((final_response, parameters["file_number"])) |
|
final_response = leftover |
|
|
|
if final_response.strip(): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
await tts_queue.put((final_response, parameters["file_number"])) |
|
|
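                            # The first pass produced the "Let me check" filler, so
                            # re-query with RAG and stream the grounded answer
                            # through the same sentence splitter.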
|
if ("Let me check" in complete_response): |
|
final_response = "" |
|
complete_response = "" |
|
current_user_sequence = parameters["user_sequence"] |
|
response = await getResponseWithRagAsync( |
|
transcribed_text, parameters["session_id"]) |
|
for chunk in response: |
|
if (current_user_sequence != parameters["user_sequence"]): |
|
break |
|
msg = chunk.choices[0].delta.content |
|
if msg: |
|
m = re.search(r'[.?!]', msg) |
|
if m: |
|
idx = m.start() |
|
segment = msg[:idx+1] |
|
leftover = msg[idx+1:] |
|
else: |
|
segment, leftover = msg, '' |
|
|
|
final_response += segment |
|
complete_response += segment |
|
|
|
if segment.endswith(('.', '!', '?')): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
await tts_queue.put((final_response, parameters["file_number"])) |
|
final_response = leftover |
|
|
|
if final_response.strip(): |
|
parameters["file_number"] += 1 |
|
parameters["queue"].append( |
|
(final_response, parameters["file_number"])) |
|
await tts_queue.put((final_response, parameters["file_number"])) |
|
|
|
continue |
|
|
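            # Playback status from the client: interrupt_seq is the newest sequence
            # the client has invalidated, and _tts_queue_worker drops anything at or
            # below it.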
|
elif field == 'status': |
|
transcript = request.status.transcript |
|
played_seq = request.status.played_seq |
|
interrupt_seq = request.status.interrupt_seq |
|
parameters["interrupt_seq"] = interrupt_seq |
|
text = transcript.strip() if transcript else "" |
|
if text and text not in BLACKLIST: |
|
save_chat_entry( |
|
parameters["session_id"], |
|
"assistant", |
|
transcript |
|
) |
|
continue |
|
else: |
|
continue |
|
|
|
async def _tts_queue_worker(self, tts_queue: asyncio.Queue, |
|
response_queue: asyncio.Queue, |
|
params: dict): |
|
""" |
|
Pull (text, seq) off tts_queue, run generate_audio_stream, wrap each chunk |
|
in ProcessTextResponse, and push into response_queue. |
|
""" |
|
while True: |
|
item = await tts_queue.get() |
|
tts_queue.task_done() |
|
if item is None: |
|
break |
|
|
|
sentence, seq = item |
|
|
|
if seq <= int(params["interrupt_seq"]): |
|
continue |
|
|
|
|
|
|
|
async for audio_chunk in generate_audio_stream( |
|
sentence, MODEL, params["VOICEPACK"], VOICE_NAME |
|
): |
|
                if seq <= int(params["last_file_number"]):

                    break

                audio_bytes = get_audio_bytes(audio_chunk)
|
resp = text_to_speech_pb2.ProcessTextResponse( |
|
buffer=audio_bytes, |
|
session_id=params["session_id"], |
|
sequence_id=str(seq), |
|
transcript=sentence, |
|
) |
|
await response_queue.put(resp) |
|
|
|
|
|
async def serve(): |
|
print("Starting gRPC server...") |
|
|
|
|
|
server = grpc.aio.server() |
|
text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server( |
|
TextToSpeechServicer(), server) |
|
server.add_insecure_port('[::]:8081') |
|
|
|
await server.start() |
|
print("gRPC server is running on port 8081") |
|
|
|
|
|
await server.wait_for_termination() |
|
|
|
if __name__ == "__main__": |
|
|
|
asyncio.run(serve()) |
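

# A minimal client sketch for manual testing. The message names used here
# (ProcessTextRequest, Metadata, AudioData) are assumptions inferred from the
# fields accessed above; check them against the generated text_to_speech_pb2 stubs.
#
#     async def demo():
#         async with grpc.aio.insecure_channel("localhost:8081") as channel:
#             stub = text_to_speech_pb2_grpc.TextToSpeechServiceStub(channel)
#
#             async def requests():
#                 # Announce the session, then stream 16 kHz, 16-bit mono PCM.
#                 yield text_to_speech_pb2.ProcessTextRequest(
#                     metadata=text_to_speech_pb2.Metadata(session_id="demo"))
#                 # yield text_to_speech_pb2.ProcessTextRequest(
#                 #     audio_data=text_to_speech_pb2.AudioData(buffer=pcm_bytes))
#
#             async for resp in stub.ProcessText(requests()):
#                 print(resp.sequence_id, resp.transcript, len(resp.buffer))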
|
|