zakerytclarke commited on
Commit
dcd8a1e
·
verified ·
1 Parent(s): 511cc6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -14
app.py CHANGED
@@ -12,25 +12,28 @@ st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:",
12
 
13
  DISCORD_TOKEN = os.environ.get("discord_key")
14
 
15
- # ========= CONFIG =========
16
- CONFIG = {
17
- # "OneTrainer": TeapotAI(
18
- # documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
19
- # settings=TeapotAISettings(rag_num_results=7)
20
- # ),
21
- "Teapot AI": TeapotAI(
22
- documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
23
- settings=TeapotAISettings(rag_num_results=3)
24
- ),
25
- }
26
 
27
- # ========= SEARCH API =========
28
 
29
- API_KEY = os.environ.get("brave_api_key")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def brave_search_context(query, count=3):
32
  url = "https://api.search.brave.com/res/v1/web/search"
33
- headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
34
  params = {"q": query, "count": count}
35
 
36
  response = requests.get(url, headers=headers, params=params)
@@ -43,6 +46,155 @@ def brave_search_context(query, count=3):
43
  print(f"Error: {response.status_code}, {response.text}")
44
  return ""
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ========= DISCORD CLIENT =========
47
  intents = discord.Intents.default()
48
  intents.messages = True
 
12
 
13
  DISCORD_TOKEN = os.environ.get("discord_key")
14
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # ======= API KEYS =======
17
 
18
+ BRAVE_API_KEY = os.environ.get("brave_api_key")
19
+ WEATHER_API_KEY = os.environ.get("weather_api_key")
20
+
21
+ # ======== TOOLS ===========
22
+ import requests
23
+ from typing import Optional
24
+ from teapotai import TeapotTool
25
+ import re
26
+ import math
27
+ import pandas as pd
28
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging
29
+
30
+ ### SEARCH TOOL
31
+ class BraveWebSearch(BaseModel):
32
+ search_query: str = Field(..., description="the search string to answer the question")
33
 
34
  def brave_search_context(query, count=3):
35
  url = "https://api.search.brave.com/res/v1/web/search"
36
+ headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
37
  params = {"q": query, "count": count}
38
 
39
  response = requests.get(url, headers=headers, params=params)
 
46
  print(f"Error: {response.status_code}, {response.text}")
47
  return ""
48
 
49
+ ### CALCULATOR TOOL
50
+ def evaluate_expression(expr) -> str:
51
+ """
52
+ Evaluate a simple algebraic expression string safely.
53
+ Supports +, -, *, /, **, and parentheses.
54
+ Retries evaluation after stripping non-numeric/non-operator characters if needed.
55
+ """
56
+
57
+ expr = expr.expression
58
+
59
+ allowed_names = {
60
+ k: v for k, v in vars(__builtins__).items()
61
+ if k in ("abs", "round")
62
+ }
63
+
64
+ allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})
65
+
66
+ def safe_eval(expression):
67
+ return eval(expression, {"__builtins__": None}, allowed_names)
68
+
69
+ try:
70
+ result = safe_eval(expr)
71
+ return f"{expr} = {result}"
72
+ except Exception as e:
73
+ print(f"Initial evaluation failed: {e}")
74
+ # Strip out any characters that are not numbers, parentheses, or valid operators
75
+ cleaned_expr = re.sub(r"[^0-9\.\+\-\*/\*\*\(\) ]", "", expr)
76
+ try:
77
+ result = safe_eval(cleaned_expr)
78
+ return f"{cleaned_expr} = {result}"
79
+ except Exception as e2:
80
+ print(f"Retry also failed: {e2}")
81
+ return "Sorry, I am unable to calculate that."
82
+
83
+
84
+ class Calculator(BaseModel):
85
+ expression: str = Field(..., description="mathematical expression")
86
+
87
+
88
+ ### Weather Tool
89
+ def get_weather(city_name):
90
+ # OpenWeatherMap API endpoint
91
+ url = f'https://api.openweathermap.org/data/2.5/weather?appid={WEATHER_API_KEY}&units=imperial&q={city_name}'
92
+
93
+
94
+ # Send GET request to the OpenWeatherMap API
95
+ response = requests.get(url)
96
+
97
+ # Check if the request was successful
98
+ if response.status_code == 200:
99
+ data = response.json()
100
+
101
+ # Extract relevant weather information
102
+ city = data['name']
103
+ temperature = round(data['main']['temp'])
104
+ weather_description = data['weather'][0]['description']
105
+
106
+ # Print or return the results
107
+ return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
108
+ else:
109
+ return "City not found or there was an error with the request."
110
+
111
+ class Weather(BaseModel):
112
+ city_name: str = Field(..., description="The name of the city to pull the weather for")
113
+
114
+ def get_weather(weather_schema):
115
+ # OpenWeatherMap API endpoint
116
+ url = f'https://api.openweathermap.org/data/2.5/weather?appid=de016a1d30e7bbe278971c2b17aabca0&units=imperial&q={weather_schema.city_name}'
117
+
118
+
119
+ # Send GET request to the OpenWeatherMap API
120
+ response = requests.get(url)
121
+
122
+ # Check if the request was successful
123
+ if response.status_code == 200:
124
+ data = response.json()
125
+
126
+ # Extract relevant weather information
127
+ city = data['name']
128
+ temperature = round(data['main']['temp'])
129
+ weather_description = data['weather'][0]['description']
130
+
131
+ # Print or return the results
132
+ return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
133
+ else:
134
+ return "City not found or there was an error with the request."
135
+
136
+
137
+ ### Stupid Question Tool
138
+ class CountNumberLetter(BaseModel):
139
+ word: str = Field(..., description="the word to count the number of letters in")
140
+ letter: str = Field(..., description="the letter to count the occurences of")
141
+
142
+ def count_number_letters(obj):
143
+ letter = obj.letter.lower()
144
+ expression = obj.word.lower()
145
+ count = len([l for l in expression if l == letter])
146
+ if count == 1:
147
+ return f"There is 1 '{letter}' in '{expression}'"
148
+ return f"There are {count} '{letter}'s in '{expression}'"
149
+
150
+
151
+
152
+ DEFAULT_TOOLS = [
153
+ TeapotTool(
154
+ name="websearch",
155
+ description="Execute web searches with pagination and filtering",
156
+ schema=BraveWebSearch,
157
+ fn=brave_search_context
158
+ ),
159
+ TeapotTool(
160
+ name="calculator",
161
+ description="Perform calculations",
162
+ schema=Calculator,
163
+ fn=lambda expression: evaluate_expression(expression),
164
+ ),
165
+ TeapotTool(
166
+ name="letter_counter",
167
+ description="Can count how many times a letter occurs in a word.",
168
+ schema=CountNumberLetter,
169
+ fn=count_number_letters
170
+ ),
171
+ TeapotTool(
172
+ name="weather",
173
+ description="Can pull today's weather information for any city.",
174
+ schema=Weather,
175
+ fn=get_weather
176
+ )
177
+ ]
178
+
179
+ # ========= CONFIG =========
180
+ CONFIG = {
181
+ # "OneTrainer": TeapotAI(
182
+ # documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
183
+ # settings=TeapotAISettings(rag_num_results=7)
184
+ # ),
185
+ "Teapot AI": TeapotAI(
186
+ model = AutoModelForSeq2SeqLM.from_pretrained(
187
+ "teapotai/teapotllm",
188
+ revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
189
+ ),
190
+ documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
191
+ settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
192
+ tools=DEFAULT_TOOLS
193
+ ),
194
+ }
195
+
196
+
197
+
198
  # ========= DISCORD CLIENT =========
199
  intents = discord.Intents.default()
200
  intents.messages = True