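"""Voice chat demo: Groq Whisper transcribes the user's speech, a local
distilgpt2 pipeline drafts a short reply, and fastrtc streams the reply back
as audio through a FastAPI app with a custom HTML frontend."""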
import json
import os
import time
from pathlib import Path
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_tts_model,
    get_twilio_turn_credentials,
)
from fastrtc.utils import audio_to_bytes
from gradio.utils import get_space
from groq import Groq
from pydantic import BaseModel
from transformers import pipeline
load_dotenv()
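# Groq is used only for Whisper speech-to-text; the client reads
# GROQ_API_KEY from the environment (loaded above by load_dotenv)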
groq_client = Groq()
# Using a better but still small free model - distilGPT2 is more coherent than OPT-125M
text_generation = pipeline(
    "text-generation",
    model="distilgpt2",
    device_map="auto",
    truncation=True,
)
tts_client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])
curr_dir = Path(__file__).parent
tts_model = get_tts_model()
# Keep track of last response to prevent repetition
last_response = ""
def clean_response(text):
    """Clean the generated text to avoid repetition and improve quality."""
    # Remove repeated sentences (simple approach)
    sentences = text.split('. ')
    cleaned_sentences = []
    for s in sentences:
        if s and s not in cleaned_sentences:
            cleaned_sentences.append(s)
    cleaned_text = '. '.join(cleaned_sentences)
    if not cleaned_text.endswith('.'):
        cleaned_text += '.'
    # Limit length to avoid very long responses
    if len(cleaned_text) > 200:
        cleaned_text = cleaned_text[:197] + "..."
    return cleaned_text
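# Example: clean_response("Hi there. Hi there. How are you") drops the
# repeated sentence and restores the final period,
# returning "Hi there. How are you."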
def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    global last_response
    chatbot = chatbot or []
    try:
        prompt = groq_client.audio.transcriptions.create(
            file=("audio-file.mp3", audio_to_bytes(audio)),
            model="whisper-large-v3-turbo",
            response_format="verbose_json",
        ).text
        chatbot.append({"role": "user", "content": prompt})
        yield AdditionalOutputs(chatbot)
        # Build a short prompt; ending with "Assistant:" cues the model to
        # continue as the assistant rather than echo the user
        context = "You are a helpful assistant. Keep your responses short and to the point."
        if len(chatbot) > 1:
            # chatbot[-2] is the previous assistant reply (the user turn was just appended)
            context += f"\nAssistant: {chatbot[-2]['content']}\nUser: {prompt}\nAssistant:"
        else:
            context += f"\nUser: {prompt}\nAssistant:"
        # Generate a reply with the local Hugging Face model;
        # return_full_text=False returns only the newly generated text, and
        # max_new_tokens caps the reply without counting the prompt length
        generated_text = text_generation(
            context,
            max_new_tokens=100,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            repetition_penalty=1.2,  # penalize repetition
            return_full_text=False,
        )
        response_text = generated_text[0]["generated_text"].strip()
        # Clean up the response
        response_text = clean_response(response_text)
        # Prevent exact repetition of the previous response
        if response_text == last_response:
            response_text = "I understand. Can you elaborate on that?"
        # Add a fallback if the response is empty or too short
        if len(response_text) < 10:
            response_text = "I see. Could you tell me more about that?"
        # Record the final reply only after the fallbacks have been applied
        last_response = response_text
        chatbot.append({"role": "assistant", "content": response_text})
        yield AdditionalOutputs(chatbot)  # Send chatbot update first
        # Stream TTS sentence by sentence so audio starts sooner and the
        # entire message is never re-synthesized in one piece
        sentences = response_text.split('. ')
        start = time.time()
        print("starting tts", start)
        for i, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue
            sentence += "."
            for chunk in tts_model.stream_tts_sync(sentence):
                print(f"tts chunk {i}", time.time() - start)
                yield chunk
        print("finished tts", time.time() - start)
    except Exception as e:
        print(f"Error in response generation: {e}")
        error_message = "Sorry, I encountered an error processing your request."
        chatbot.append({"role": "assistant", "content": error_message})
        yield AdditionalOutputs(chatbot)
        # Still try to produce audio for the error message
        try:
            for chunk in tts_model.stream_tts_sync(error_message):
                yield chunk
        except Exception:
            # If even TTS fails, just return with the error in the chatbot
            pass
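# The Chatbot component is wired in as both an additional input and output,
# so the conversation history round-trips through every turn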
chatbot = gr.Chatbot(type="messages")
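# Twilio TURN credentials and the concurrency/time limits are only applied
# when running on a hosted Space (get_space() is truthy there)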
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
)
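# Request schema for the /input_hook endpoint used by the custom frontend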
class Message(BaseModel):
    role: str
    content: str


class InputData(BaseModel):
    webrtc_id: str
    chatbot: list[Message]
app = FastAPI()
stream.mount(app)
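# stream.mount(app) registers fastrtc's WebRTC routes on the app; the
# handlers below serve the custom page and bridge chat state in and out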
@app.get("/")
async def _():
rtc_config = get_twilio_turn_credentials() if get_space() else None
html_content = (curr_dir / "index.html").read_text()
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content, status_code=200)
@app.post("/input_hook")
async def _(body: InputData):
    stream.set_input(body.webrtc_id, body.model_dump()["chatbot"])
    return {"status": "ok"}
@app.get("/outputs")
def _(webrtc_id: str):
async def output_stream():
async for output in stream.output_stream(webrtc_id):
chatbot = output.args[0]
if chatbot and len(chatbot) > 0:
yield f"event: output\ndata: {json.dumps(chatbot[-1])}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
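# MODE selects the entry point: the auto-generated Gradio UI, fastrtc's
# fastphone mode, or the plain FastAPI app under uvicorn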
if __name__ == "__main__":
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)