import json
import os
import time
from pathlib import Path

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_tts_model,
    get_twilio_turn_credentials,
)
from fastrtc.utils import audio_to_bytes
from gradio.utils import get_space
from groq import Groq
from pydantic import BaseModel
from transformers import pipeline

load_dotenv()

groq_client = Groq()

# Using a better but still small free model: distilgpt2 is more coherent than OPT-125M
text_generation = pipeline(
    "text-generation",
    model="distilgpt2",
    device_map="auto",
    truncation=True,
)

tts_client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])

curr_dir = Path(__file__).parent

tts_model = get_tts_model()

# Keep track of the last response to prevent repetition
last_response = ""


def clean_response(text):
    """Clean the generated text to avoid repetition and improve quality."""
    # Remove repeated sentences (simple approach)
    sentences = text.split(". ")
    cleaned_sentences = []
    for s in sentences:
        if s and s not in cleaned_sentences:
            cleaned_sentences.append(s)

    cleaned_text = ". ".join(cleaned_sentences)
    if not cleaned_text.endswith("."):
        cleaned_text += "."

    # Limit length to avoid very long responses
    if len(cleaned_text) > 200:
        cleaned_text = cleaned_text[:197] + "..."

    return cleaned_text


def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    global last_response
    chatbot = chatbot or []
    try:
        prompt = groq_client.audio.transcriptions.create(
            file=("audio-file.mp3", audio_to_bytes(audio)),
            model="whisper-large-v3-turbo",
            response_format="verbose_json",
        ).text

        chatbot.append({"role": "user", "content": prompt})
        yield AdditionalOutputs(chatbot)

        # Build a prompt for the model, including the last exchange for context
        context = "You are a helpful assistant. Keep your responses short and to the point."
        if len(chatbot) > 1:
            context += f"\nPrevious: {chatbot[-2]['content']}\nYou: {prompt}"
        else:
            context += f"\nUser: {prompt}"

        # Generate a response with the Hugging Face model
        generated_text = text_generation(
            context,
            max_new_tokens=150,  # cap the completion; max_length would also count the prompt
            num_return_sequences=1,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            repetition_penalty=1.2,  # Penalize repetition
        )

        # Keep only the newly generated content
        full_text = generated_text[0]["generated_text"]
        response_text = full_text.replace(context, "").strip()

        # Clean up the response
        response_text = clean_response(response_text)

        # Prevent exact repetition of the previous response
        if response_text == last_response:
            response_text = "I understand. Can you elaborate on that?"

        # Fall back if the response is empty or too short
        if len(response_text) < 10:
            response_text = "I see. Could you tell me more about that?"

        last_response = response_text

        chatbot.append({"role": "assistant", "content": response_text})
        yield AdditionalOutputs(chatbot)  # Send the chatbot update first

        # Split the reply into sentences so TTS never repeats the entire message
        sentences = response_text.split(". ")

        start = time.time()
        print("starting tts", start)

        # Process each sentence separately for TTS
        for i, sentence in enumerate(sentences):
            if not sentence.strip():
                continue
            sentence = sentence.strip() + "."
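            # Stream TTS sentence-by-sentence: yielding audio chunks as each
            # sentence is synthesized lets playback begin before the full reply
            # has been rendered, and ReplyOnPause forwards every yielded chunk
            # straight to the connected client.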
            for chunk in tts_model.stream_tts_sync(sentence):
                print(f"chunk {i}.{chunk}", time.time() - start)
                yield chunk

        print("finished tts", time.time() - start)

    except Exception as e:
        print(f"Error in response generation: {e}")
        error_message = "Sorry, I encountered an error processing your request."
        chatbot.append({"role": "assistant", "content": error_message})
        yield AdditionalOutputs(chatbot)
        # Still try to produce audio for the error message
        try:
            for chunk in tts_model.stream_tts_sync(error_message):
                yield chunk
        except Exception:
            # If even TTS fails, just return with the error in the chatbot
            pass


chatbot = gr.Chatbot(type="messages")
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
)


class Message(BaseModel):
    role: str
    content: str


class InputData(BaseModel):
    webrtc_id: str
    chatbot: list[Message]


app = FastAPI()
stream.mount(app)


@app.get("/")
async def _():
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (curr_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content, status_code=200)


@app.post("/input_hook")
async def _(body: InputData):
    stream.set_input(body.webrtc_id, body.model_dump()["chatbot"])
    return {"status": "ok"}


@app.get("/outputs")
def _(webrtc_id: str):
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            chatbot = output.args[0]
            if chatbot and len(chatbot) > 0:
                yield f"event: output\ndata: {json.dumps(chatbot[-1])}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
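
# Example usage (a sketch; assumes this file is saved as app.py and that
# GROQ_API_KEY and ELEVENLABS_API_KEY are set in the environment or in .env):
#
#   MODE=UI python app.py      # Gradio UI on http://0.0.0.0:7860
#   MODE=PHONE python app.py   # fastphone telephone interface
#   python app.py              # FastAPI app serving index.html at /
#
# The browser client posts its chat history to /input_hook and subscribes to
# the /outputs server-sent-events endpoint for transcript updates, e.g.:
#
#   curl -N "http://localhost:7860/outputs?webrtc_id=<id>"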