# VocRT/app.py
from concurrent import futures
import asyncio
import torch
from models import build_model
from collections import deque
import grpc
import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry
from providers.audio_provider import get_audio_bytes, dummy_bytes, generate_audio_stream
from providers.llm_provider import getResponseWithRagAsync, getResponseAsync
import numpy as np
import os
import re
import time
from silero_vad import load_silero_vad, VADIterator
import random
from providers.filler_words import filler_phrases
from scipy.io.wavfile import write
from faster_whisper import WhisperModel
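
# Silero VAD consumes 512-sample frames (32 ms at 16 kHz); the threshold is
# the speech-probability cutoff above which a frame counts as speech.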
sampling_rate = 16_000
vad_model = load_silero_vad()
frame_size = 512
DEFAULT_VAD_THRESHOLD = 0.50
device = 'cuda' if torch.cuda.is_available() else 'cpu'
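# faster-whisper "small" with int8 compute keeps latency manageable on
# CPU-only hosts; weights are cached under ./models.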
whisper_model = WhisperModel(
"small",
device=device,
compute_type="int8",
cpu_threads=os.cpu_count(),
download_root="./models"
)
MODEL = build_model('kokoro-v0_19.pth', device)
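# Default voice is the first entry ('af'); per-session voices are swapped in
# when metadata arrives (see _read_requests).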
VOICE_NAME = [
'af',
'af_bella', 'af_sarah', 'am_adam', 'am_michael',
'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
'af_nicole', 'af_sky',
][0]
VOICEPACK = torch.load(
f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
AUDIO_FILES_DIR = 'audio_files'
os.makedirs(AUDIO_FILES_DIR, exist_ok=True)
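# Rolling pre-roll budget: 20 frames * 512 samples * 2 bytes (int16) = 20,480
# bytes, i.e. ~640 ms of audio kept ahead of a VAD "start" event.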
PRE_CHUNK_LIMIT_BYTES = frame_size * 2 * 20
transcription_pool = futures.ThreadPoolExecutor(max_workers=10)
terminators = ['.', '?', '!', '...', '…', '?!', '!?', '‽', '。', '؟', '۔']
BLACKLIST = {
"Give me a minute.",
"Let me check the details.",
"Give me a minute. Let me check the details."
}
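
# Negative sequence_id values act as control signals for the client; the
# mapping below is inferred from how they are used in this file:
#   -2  user transcript ready       -3  end of speech detected
#   -4  flush queued audio          -5  empty transcript
#   -10 metadata acknowledged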
SHORT_UTTERANCE_BYTES = 9_600 # 300 ms
# Five seconds of silence at 16 kHz (placeholder buffer).
dummy_audio = np.zeros(int(16_000 * 5.0), dtype=np.float32)
def _fw_transcribe_block(audio_f32: np.ndarray) -> dict:
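    """Blocking faster-whisper transcription of a mono float32 16 kHz buffer."""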
segments, info = whisper_model.transcribe(
audio_f32,
language="en",
beam_size=1,
vad_filter=False,
initial_prompt="Indian English accent; do not make up words.",
no_speech_threshold=0.25,
log_prob_threshold=-0.6,
temperature=0
)
text = "".join(seg.text for seg in segments)
return {"text": text, "language": info.language,
"language_probability": info.language_probability}
async def safe_transcribe(audio_float32: np.ndarray):
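    """Run the blocking transcriber on the shared thread pool so the event loop stays free."""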
loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        transcription_pool, _fw_transcribe_block, audio_float32
    )
class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
def __init__(self):
super().__init__()
self._transcribe_lock = asyncio.Lock()
async def ProcessText(self, request_iterator, context):
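        """Bidirectional streaming RPC: a reader task parses inbound requests
        while this coroutine drains response_queue back to the client."""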
try:
print("New connection")
tts_queue = asyncio.Queue()
response_queue = asyncio.Queue()
parameters = {
"processing_active": False,
"queue": deque(),
"file_number": 0,
"session_id": "",
"interrupt_seq": 0,
"temperature": 1,
"activeVoice": "af",
"in_speech": False,
"maxTokens": 500,
"audio_buffer": bytearray(),
"pre_chunks": bytearray(),
"silence_counter": 0.0,
"silence_duration": 0.8, # default duration in seconds
"silence_threshold": DEFAULT_VAD_THRESHOLD, # default amplitude threshold
"VOICEPACK": VOICEPACK,
"audio_count": 0,
"user_sequence": 0,
"last_file_number": 0,
"vad_iter": VADIterator(vad_model, sampling_rate=sampling_rate, threshold=DEFAULT_VAD_THRESHOLD)
}
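            # Two background tasks per connection: the reader parses inbound
            # requests; the TTS worker turns queued sentences into audio.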
reader = asyncio.create_task(
self._read_requests(request_iterator, tts_queue, response_queue, parameters))
tts = asyncio.create_task(self._tts_queue_worker(
tts_queue, response_queue, parameters))
try:
while True:
resp = await response_queue.get()
if resp is None:
break
yield resp
finally:
for t in (reader, tts):
t.cancel()
try:
await t
except asyncio.CancelledError:
pass
except Exception as e:
print("Error in ProcessText:", e)
async def _read_requests(self, request_iterator, tts_queue: asyncio.Queue, response_queue: asyncio.Queue, parameters):
async for request in request_iterator:
field = request.WhichOneof('request_data')
if field == 'metadata':
meta = request.metadata
# print("\n\nMetadata : ", meta)
if meta.session_id:
parameters["session_id"] = meta.session_id
if meta.temperature:
parameters["temperature"] = meta.temperature
if meta.maxTokens:
parameters["maxTokens"] = meta.maxTokens
if meta.activeVoice:
parameters["activeVoice"] = meta.activeVoice
parameters["VOICEPACK"] = torch.load(
f'voices/{parameters["activeVoice"]}.pt', weights_only=True).to(device)
print("\n\nVoice model loaded successfully")
                if meta.silenceDuration:
                    # Client sends milliseconds; we track seconds.
                    parameters["silence_duration"] = meta.silenceDuration / 1000
                if meta.threshold:
                    # Client sends 0-100; the VAD expects a 0-1 probability.
                    parameters["silence_threshold"] = meta.threshold / 100
parameters["vad_iter"] = VADIterator(
vad_model, sampling_rate=sampling_rate, threshold=parameters["silence_threshold"])
print("\n\nPatameter : ", parameters)
# output = await safe_transcribe("output2.wav")
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=dummy_bytes(),
session_id=parameters["session_id"],
sequence_id="-10",
transcript="",
)
await response_queue.put(resp)
continue
elif field == 'audio_data':
buffer = request.audio_data.buffer
audio_data = np.frombuffer(buffer, dtype=np.int16)
float_chunk = audio_data.astype(np.float32) / 32768.0
vad_result = parameters["vad_iter"](float_chunk)
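                # Maintain a bounded pre-roll so audio captured just before a
                # VAD "start" event is not lost.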
parameters["pre_chunks"].extend(buffer)
if len(parameters["pre_chunks"]) > PRE_CHUNK_LIMIT_BYTES:
overflow = len(
parameters["pre_chunks"]) - PRE_CHUNK_LIMIT_BYTES
del parameters["pre_chunks"][:overflow]
if vad_result:
if "start" in vad_result:
parameters["in_speech"] = True
parameters["audio_buffer"].extend(
parameters["pre_chunks"])
if "end" in vad_result:
if (len(parameters["audio_buffer"]) < SHORT_UTTERANCE_BYTES):
parameters["audio_buffer"].extend(
parameters["pre_chunks"])
parameters["in_speech"] = False
if parameters["in_speech"]:
parameters["audio_buffer"].extend(buffer)
parameters["silence_counter"] = 0.0
parameters["audio_count"] += 1
else:
                    duration = len(audio_data) / sampling_rate
                    parameters["silence_counter"] += duration
if parameters["silence_counter"] >= parameters["silence_duration"]:
parameters["silence_counter"] = 0.0
if parameters["audio_count"] < 2:
parameters["audio_count"] = 0
continue
parameters["audio_count"] = 0
print("Silence ")
sample_rate = 16000
audio_float = np.frombuffer(
parameters["audio_buffer"], dtype=np.int16).astype(np.float32) / 32768.0
parameters["audio_buffer"] = bytearray()
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=dummy_bytes(),
session_id=parameters["session_id"],
sequence_id="-3",
transcript="",
)
await response_queue.put(resp)
whisper_start_time = time.time()
result = ""
try:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
transcription_pool, lambda: _fw_transcribe_block(
audio_float)
)
# result = await safe_transcribe(audio_float)
except Exception as e:
await tts_queue.put(("Sorry! I am not able to catch that can you repeat again please!", parameters["file_number"]))
print("Error in transcribing text : ", e)
continue
whisper_end_time = time.time()
time_taken_to_transcribe = whisper_end_time - whisper_start_time
print(
f"Transcribing time: {time_taken_to_transcribe:.4f} seconds")
                        transcribed_text = result["text"]
                        print("Transcribed text:", transcribed_text)
if not transcribed_text.strip():
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=dummy_bytes(),
session_id=parameters["session_id"],
sequence_id="-5",
transcript="",
)
await response_queue.put(resp)
continue
# Transcript Detected ------------------------------------------------------------------------------------
if transcribed_text:
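                            # Barge-in: a new utterance invalidates all queued
                            # and in-flight responses.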
parameters["queue"].clear()
parameters["user_sequence"] += 1
parameters["last_file_number"] = parameters["file_number"]
while not response_queue.empty():
try:
response_queue.get_nowait()
response_queue.task_done()
except asyncio.QueueEmpty:
break
while not tts_queue.empty():
try:
tts_queue.get_nowait()
tts_queue.task_done()
except asyncio.QueueEmpty:
break
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=dummy_bytes(),
session_id=parameters["session_id"],
sequence_id="-4",
transcript="",
)
await response_queue.put(resp)
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=dummy_bytes(),
session_id=parameters["session_id"],
sequence_id="-2",
transcript=transcribed_text,
)
save_chat_entry(
parameters["session_id"], "user", transcribed_text)
await response_queue.put(resp)
# try:
# filler = random.choice(filler_phrases)
# # await tts_queue.put((filler, parameters["file_number"]))
# loop = asyncio.get_event_loop()
# loop.call_later(
# 0,
# # 1.0,
# lambda: asyncio.create_task(
# tts_queue.put(
# (filler, parameters["file_number"]))
# )
# )
# except Exception as e:
# print("Error in sendign error : ", e)
final_response = ""
complete_response = ""
current_user_sequence = parameters["user_sequence"]
response = await getResponseAsync(
transcribed_text, parameters["session_id"])
if response is None:
continue
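                            # Stream tokens, cutting at sentence terminators so each
                            # complete sentence can be synthesized immediately.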
for chunk in response:
                                # A newer user utterance arrived; abandon this response.
                                if current_user_sequence != parameters["user_sequence"]:
                                    break
msg = chunk.choices[0].delta.content
if msg:
complete_response += msg
m = re.search(r'[.?!]', msg)
if m:
idx = m.start()
segment = msg[:idx+1]
leftover = msg[idx+1:]
else:
segment, leftover = msg, ''
final_response += segment
if segment.endswith(('.', '!', '?')):
parameters["file_number"] += 1
parameters["queue"].append(
(final_response, parameters["file_number"]))
await tts_queue.put((final_response, parameters["file_number"]))
final_response = leftover
if final_response.strip():
parameters["file_number"] += 1
parameters["queue"].append(
(final_response, parameters["file_number"]))
await tts_queue.put((final_response, parameters["file_number"]))
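                            # A "Let me check" reply signals the model wants retrieval;
                            # re-run the query through the RAG pipeline.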
if ("Let me check" in complete_response):
final_response = ""
complete_response = ""
current_user_sequence = parameters["user_sequence"]
response = await getResponseWithRagAsync(
transcribed_text, parameters["session_id"])
                                if response is None:
                                    continue
                                for chunk in response:
                                    if current_user_sequence != parameters["user_sequence"]:
                                        break
msg = chunk.choices[0].delta.content
if msg:
m = re.search(r'[.?!]', msg)
if m:
idx = m.start()
segment = msg[:idx+1]
leftover = msg[idx+1:]
else:
segment, leftover = msg, ''
final_response += segment
complete_response += segment
if segment.endswith(('.', '!', '?')):
parameters["file_number"] += 1
parameters["queue"].append(
(final_response, parameters["file_number"]))
await tts_queue.put((final_response, parameters["file_number"]))
final_response = leftover
if final_response.strip():
parameters["file_number"] += 1
parameters["queue"].append(
(final_response, parameters["file_number"]))
await tts_queue.put((final_response, parameters["file_number"]))
continue
elif field == 'status':
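                # Playback status from the client: persist what was actually
                # spoken, skipping canned filler phrases.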
transcript = request.status.transcript
played_seq = request.status.played_seq
interrupt_seq = request.status.interrupt_seq
parameters["interrupt_seq"] = interrupt_seq
text = transcript.strip() if transcript else ""
if text and text not in BLACKLIST:
save_chat_entry(
parameters["session_id"],
"assistant",
transcript
)
continue
else:
continue
async def _tts_queue_worker(self, tts_queue: asyncio.Queue,
response_queue: asyncio.Queue,
params: dict):
"""
Pull (text, seq) off tts_queue, run generate_audio_stream, wrap each chunk
in ProcessTextResponse, and push into response_queue.
"""
while True:
item = await tts_queue.get()
tts_queue.task_done()
if item is None:
break
sentence, seq = item
# drop anything the client has already played:
if seq <= int(params["interrupt_seq"]):
continue
            # stream the audio chunks, pack into gRPC responses
async for audio_chunk in generate_audio_stream(
sentence, MODEL, params["VOICEPACK"], VOICE_NAME
):
audio_bytes = get_audio_bytes(audio_chunk)
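                # A newer utterance superseded this sentence; stop streaming it.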
if seq <= int(params["last_file_number"]):
break
resp = text_to_speech_pb2.ProcessTextResponse(
buffer=audio_bytes,
session_id=params["session_id"],
sequence_id=str(seq),
transcript=sentence,
)
await response_queue.put(resp)
async def serve():
print("Starting gRPC server...")
# Use grpc.aio.server for the gRPC async server
server = grpc.aio.server()
text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(
TextToSpeechServicer(), server)
server.add_insecure_port('[::]:8081')
await server.start()
print("gRPC server is running on port 8081")
# The serve method should wait for the server to terminate asynchronously
await server.wait_for_termination()
if __name__ == "__main__":
# Use asyncio.run to run the asynchronous serve function
asyncio.run(serve())