zakerytclarke's picture
Update app.py
c800610 verified
raw
history blame
4.71 kB
import streamlit as st
import os
import aiohttp
import asyncio
import discord
import pandas as pd
import requests
from teapotai import TeapotAI, TeapotAISettings
# Configure the Streamlit page: browser title, icon and wide layout.
st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
# Discord bot token, read from the "discord_key" environment variable
# (None when the variable is not set).
DISCORD_TOKEN = os.environ.get("discord_key")
# ========= CONFIG =========
# Maps a name to a TeapotAI instance. handle_teapot_inference and
# debug_teapot_inference look entries up by Discord server (guild) name and
# fall back to the "Teapot AI" entry when the name is not present.
# Building an entry downloads its document sheet over the network at import time.
CONFIG = {
# Disabled entry for a server named "OneTrainer", kept for reference:
# "OneTrainer": TeapotAI(
# documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
# settings=TeapotAISettings(rag_num_results=7)
# ),
"Teapot AI": TeapotAI(
# Read the sheet's CSV export, split every cell of its `content` column on
# blank lines, and flatten the pieces into one list of document strings.
documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
# NOTE(review): rag_num_results presumably sets how many documents are
# retrieved per query - confirm against the teapotai documentation.
settings=TeapotAISettings(rag_num_results=7)
),
}
# ========= SEARCH API =========
# Brave Search API key, read from the "brave_api_key" environment variable
# (None when unset); brave_search_context sends it as the
# X-Subscription-Token header.
API_KEY = os.environ.get("brave_api_key")
def brave_search_context(query, count=1, timeout=10):
    """Fetch web search snippets for *query* from the Brave Search API.

    Args:
        query: Search text, sent as the ``q`` request parameter.
        count: Number of results to request.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        One entry per result, made of the title, a newline and the
        description (with ``<strong>`` highlight tags removed); entries are
        separated by blank lines. An empty string is returned when the
        request fails, times out, answers with a non-200 status, or the
        body is not valid JSON.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as exc:
        # Treat transport failures like HTTP errors: log and return no context.
        print(f"Error: {exc}")
        return ""

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""

    try:
        payload = response.json()
    except ValueError as exc:
        print(f"Error: {exc}")
        return ""

    # Tolerate a missing or null "web"/"results" section.
    results = (payload.get("web") or {}).get("results") or []
    print(results)

    snippets = []
    for res in results:
        # Results are not guaranteed to carry both fields; default to "".
        title = res.get("title") or ""
        description = res.get("description") or ""
        description = description.replace("<strong>", "").replace("</strong>", "")
        snippets.append(title + "\n" + description)
    return "\n\n".join(snippets)
# ========= DISCORD CLIENT =========
# Start from discord.py's default intent set, switch the `messages` flag on
# explicitly, and build the single client that the @client.event handlers in
# this file register against.
# NOTE(review): on_message reads message.content; confirm whether the
# privileged message_content intent is also required for this deployment.
intents = discord.Intents.default()
intents.messages = True
client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """Answer *user_input* with the TeapotAI instance configured for a server.

    Args:
        server_name: Key looked up in CONFIG; unknown names fall back to the
            "Teapot AI" entry.
        user_input: The question text to answer.

    Returns:
        Whatever the instance's ``query`` method returns for the question,
        given the web-search context from brave_search_context.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both steps are blocking calls, so each runs in a worker thread. The
    # search must be awaited on its own: passing brave_search_context(...) as
    # an argument would execute the HTTP request on the event loop first.
    search_context = await asyncio.to_thread(brave_search_context, user_input)
    response = await asyncio.to_thread(
        teapot_instance.query, query=user_input, context=search_context
    )
    return response
async def debug_teapot_inference(server_name, user_input):
    """Collect the retrieval inputs behind an answer, for debugging.

    Args:
        server_name: Key looked up in CONFIG; unknown names fall back to the
            "Teapot AI" entry.
        user_input: The question text to inspect.

    Returns:
        A ``(rag_text, search_text)`` tuple: the items returned by the
        instance's ``rag`` method joined by blank lines, and the string
        returned by brave_search_context.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Run the blocking search and retrieval calls in worker threads so the
    # Discord event loop stays responsive while they execute.
    search_result = await asyncio.to_thread(brave_search_context, user_input)
    rag_results = await asyncio.to_thread(teapot_instance.rag, query=user_input)
    return "\n\n".join(rag_results), search_result
@client.event
async def on_ready():
    """Handle discord.py's on_ready event by printing the logged-in user."""
    bot_user = client.user
    print(f'Logged in as {bot_user}')
@client.event
async def on_message(message):
    """Reply to messages that mention the bot with a TeapotAI answer.

    Messages written by the bot itself and messages that do not contain a
    mention of the bot are ignored. The mention is stripped from the text
    before it is passed to handle_teapot_inference, and the answer is sent
    as a reply to the triggering message.
    """
    if message.author == client.user:
        return

    # A user mention is written <@id>; the legacy nickname form is <@!id>.
    bot_id = client.user.id
    mention_tokens = (f'<@{bot_id}>', f'<@!{bot_id}>')
    if not any(token in message.content for token in mention_tokens):
        return

    # Direct messages have no guild, so they use the default configuration.
    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)

    async with message.channel.typing():
        cleaned_message = message.content
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        cleaned_message = cleaned_message.strip()

        response = await handle_teapot_inference(server_name, cleaned_message)
        # Discord rejects message content longer than 2000 characters, so cap
        # the reply instead of letting the send fail.
        discord_message_limit = 2000
        await message.reply(str(response)[:discord_message_limit])
@client.event
async def on_reaction_add(reaction, user):
    """Post debug details when someone reacts to a bot answer with a question mark.

    Only "❓" / "❔" reactions from other users on a bot message that is a
    reply are handled. The message the bot replied to is fetched, its text
    is run through debug_teapot_inference, and the RAG and search results
    are posted in a thread on the bot's answer (or directly in the channel
    for direct messages, where threads cannot be created).
    """
    if user == client.user:
        return
    if str(reaction.emoji) not in ["❓", "❔"]:
        return

    message = reaction.message
    # Make sure it's a bot message that was a reply
    if message.author != client.user or not message.reference:
        return

    bot_id = client.user.id
    cleaned_message = message.content.replace(f'<@{bot_id}>', "").strip()

    # Fetch the original message that this bot message replied to
    try:
        original_message = await message.channel.fetch_message(message.reference.message_id)
    except discord.NotFound:
        # The question was deleted, so there is nothing left to debug.
        return

    # Strip the bot mention exactly as on_message does, so the debug output
    # reflects the query that actually produced the answer.
    user_input = original_message.content
    for token in (f'<@{bot_id}>', f'<@!{bot_id}>'):
        user_input = user_input.replace(token, "")
    user_input = user_input.strip()

    server_name = message.guild.name if message.guild else "Teapot AI"

    if message.guild is None:
        # Threads need a guild channel; in direct messages post in place.
        destination = message.channel
    else:
        # Create a thread or use existing one
        destination = message.thread
        if destination is None:
            destination = await message.create_thread(
                name=f"Debug Thread: '{cleaned_message[0:30]}...'",
                auto_archive_duration=60,
            )

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)

    # Discord rejects content longer than 2000 characters: keep the last 1000
    # characters of the RAG text and give the search text whatever remains.
    discord_message_limit = 2000
    rag_text = discord.utils.escape_markdown(rag_result)[-1000:]
    head = "## RAG:\n```" + rag_text + "```\n\n## Search:\n```"
    tail = "```"
    search_budget = discord_message_limit - len(head) - len(tail)
    search_text = discord.utils.escape_markdown(search_result)[:search_budget]

    debug_response = head + search_text + tail
    await destination.send(debug_response)
# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """Start the Discord client from the Streamlit script.

    The function is wrapped in ``st.cache_resource`` and called once at the
    bottom of the script. When no Discord token is configured it reports the
    problem on the page and returns instead of starting the client.

    NOTE(review): discord.py documents ``Client.run`` as a blocking call, so
    the ``st.write`` below presumably only executes after the client stops -
    confirm this is the intended behaviour.
    """
    st.session_state["initialized"] = True
    if not DISCORD_TOKEN:
        # Starting the client without a token only produces an opaque error
        # from discord.py; name the missing setting instead.
        st.error("Discord token is not configured: set the 'discord_key' environment variable.")
        return
    client.run(DISCORD_TOKEN)
    st.write("418 I'm a teapot")
    return


discord_loop()