Upload 9 files
- examples/agent_from_any_llm.py +61 -0
- examples/gradio_ui.py +16 -0
- examples/inspect_multiagent_run.py +39 -0
- examples/multi_llm_agent.py +46 -0
- examples/multiple_tools.py +256 -0
- examples/rag.py +70 -0
- examples/rag_using_chromadb.py +130 -0
- examples/sandboxed_execution.py +12 -0
- examples/text_to_sql.py +79 -0
examples/agent_from_any_llm.py
ADDED
@@ -0,0 +1,61 @@
from smolagents import (
    CodeAgent,
    InferenceClientModel,
    LiteLLMModel,
    OpenAIServerModel,
    ToolCallingAgent,
    TransformersModel,
    tool,
)


# Choose which inference type to use!

available_inferences = ["inference_client", "transformers", "ollama", "litellm", "openai"]
chosen_inference = "inference_client"

print(f"Chose inference: '{chosen_inference}'")

if chosen_inference == "inference_client":
    model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="nebius")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with a remote OpenAI-compatible server if necessary
        api_key="your-api-key",  # replace with an API key if necessary
        num_ctx=8192,  # ollama's default is 2048, which will often fail horribly. 8192 works for easy tasks; more is better. Check https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator to calculate how much VRAM this will need for the selected model.
    )

elif chosen_inference == "litellm":
    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")

elif chosen_inference == "openai":
    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = OpenAIServerModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: bool | None = False) -> str:
    """
    Get the weather in the next days at the given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to return the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model, verbosity_level=2)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model, verbosity_level=2, stream_outputs=True)

print("CodeAgent:", agent.run("What's the weather like in Paris?"))
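A note on the @tool pattern above: the decorator derives the schema the LLM sees from the function's type hints and docstring Args section, which is why that docstring matters. A minimal introspection sketch; the attribute names are an assumption here, mirroring the name/description/inputs/output_type class attributes that the Tool subclasses later in this upload define:

# Assumed attributes on the decorated tool; confirm against the smolagents docs.
print(get_weather.name)         # "get_weather", taken from the function name
print(get_weather.description)  # taken from the docstring
print(get_weather.inputs)       # argument schema parsed from Args + type hints
print(get_weather.output_type)  # "string", from the -> str annotation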
examples/gradio_ui.py
ADDED
@@ -0,0 +1,16 @@
from smolagents import CodeAgent, GradioUI, InferenceClientModel, WebSearchTool


agent = CodeAgent(
    tools=[WebSearchTool()],
    model=InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="fireworks-ai"),
    verbosity_level=1,
    planning_interval=3,
    name="example_agent",
    description="This is an example agent.",
    step_callbacks=[],
    stream_outputs=True,
    # use_structured_outputs_internally=True,
)

GradioUI(agent, file_upload_folder="./data").launch()
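The empty step_callbacks list above is the hook for observing each agent step as the Gradio app runs. A hypothetical sketch of one such callback, assuming smolagents invokes callbacks with the completed step object (confirm the exact signature against the smolagents docs):

# Hypothetical callback; *args/**kwargs absorb any extra arguments the caller passes.
def log_step(step, *args, **kwargs):
    # Print a one-line trace for every finished agent step.
    print(f"Finished step: {type(step).__name__}")

# It would then be passed as step_callbacks=[log_step] in the CodeAgent constructor above.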
examples/inspect_multiagent_run.py
ADDED
@@ -0,0 +1,39 @@
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register


register()
SmolagentsInstrumentor().instrument(skip_dep_check=True)


from smolagents import (
    CodeAgent,
    InferenceClientModel,
    ToolCallingAgent,
    VisitWebpageTool,
    WebSearchTool,
)


# Then we run the agentic part!
model = InferenceClientModel(provider="nebius")

search_agent = ToolCallingAgent(
    tools=[WebSearchTool(), VisitWebpageTool()],
    model=model,
    name="search_agent",
    description="This is an agent that can do web search.",
    return_full_result=True,
)

manager_agent = CodeAgent(
    tools=[],
    model=model,
    managed_agents=[search_agent],
    return_full_result=True,
)
run_result = manager_agent.run(
    "If the US keeps its 2024 growth rate, how many years would it take for the GDP to double?"
)
print("Here is the token usage for the manager agent:", run_result.token_usage)
print("Here is the timing information for the manager agent:", run_result.timing)
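Because both agents are constructed with return_full_result=True, run() returns a result object rather than a bare answer. Besides the token_usage and timing attributes printed above, the final answer should also live on this object; the field name below is an assumption to verify against the smolagents RunResult documentation:

# Assumed field name: the final answer carried on the full result object.
print("Final answer:", run_result.output)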
examples/multi_llm_agent.py
ADDED
@@ -0,0 +1,46 @@
import os

from smolagents import CodeAgent, LiteLLMRouterModel, WebSearchTool


# Make sure to set up the necessary environment variables!

llm_loadbalancer_model_list = [
    {
        "model_name": "model-group-1",
        "litellm_params": {
            "model": "gpt-4o-mini",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
    {
        "model_name": "model-group-1",
        "litellm_params": {
            "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_region_name": os.getenv("AWS_REGION"),
        },
    },
    # {
    #     "model_name": "model-group-2",
    #     "litellm_params": {
    #         "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    #         "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
    #         "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    #         "aws_region_name": os.getenv("AWS_REGION"),
    #     },
    # },
]


model = LiteLLMRouterModel(
    model_id="model-group-1",
    model_list=llm_loadbalancer_model_list,
    client_kwargs={"routing_strategy": "simple-shuffle"},
)
agent = CodeAgent(tools=[WebSearchTool()], model=model, stream_outputs=True, return_full_result=True)

full_result = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")

print(full_result)
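The commented-out model-group-2 entry hints at a second use for the router beyond load balancing: failover. A sketch, assuming client_kwargs is forwarded to litellm's Router constructor (which accepts a fallbacks list in this shape) and that the model-group-2 entry above has been uncommented:

model_with_fallbacks = LiteLLMRouterModel(
    model_id="model-group-1",
    model_list=llm_loadbalancer_model_list,
    client_kwargs={
        "routing_strategy": "simple-shuffle",
        # Assumption: litellm Router fallback syntax; fails over to model-group-2
        # when the deployments in model-group-1 error out.
        "fallbacks": [{"model-group-1": ["model-group-2"]}],
    },
)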
examples/multiple_tools.py
ADDED
@@ -0,0 +1,256 @@
import requests

# from smolagents.agents import ToolCallingAgent
from smolagents import CodeAgent, InferenceClientModel, tool


# Choose which LLM engine to use!
model = InferenceClientModel()
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
# model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: bool | None = False) -> str:
    """
    Get the current weather at the given location using the WeatherStack API.

    Args:
        location: The location (city name).
        celsius: Whether to return the temperature in Celsius (default is False, which returns Fahrenheit).

    Returns:
        A string describing the current weather at the location.
    """
    api_key = "your_api_key"  # Replace with your API key from https://weatherstack.com/
    units = "m" if celsius else "f"  # 'm' for Celsius, 'f' for Fahrenheit

    url = f"http://api.weatherstack.com/current?access_key={api_key}&query={location}&units={units}"

    try:
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for HTTP errors

        data = response.json()

        if data.get("error"):  # Check if there's an error in the response
            return f"Error: {data['error'].get('info', 'Unable to fetch weather data.')}"

        weather = data["current"]["weather_descriptions"][0]
        temp = data["current"]["temperature"]
        temp_unit = "°C" if celsius else "°F"

        return f"The current weather in {location} is {weather} with a temperature of {temp} {temp_unit}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching weather data: {str(e)}"


@tool
def convert_currency(amount: float, from_currency: str, to_currency: str) -> str:
    """
    Converts a specified amount from one currency to another using the ExchangeRate-API.

    Args:
        amount: The amount of money to convert.
        from_currency: The currency code of the currency to convert from (e.g., 'USD').
        to_currency: The currency code of the currency to convert to (e.g., 'EUR').

    Returns:
        str: A string describing the converted amount in the target currency, or an error message if the conversion fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request to the ExchangeRate-API.
    """
    api_key = "your_api_key"  # Replace with your actual API key from https://www.exchangerate-api.com/
    url = f"https://v6.exchangerate-api.com/v6/{api_key}/latest/{from_currency}"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()
        exchange_rate = data["conversion_rates"].get(to_currency)

        if not exchange_rate:
            return f"Error: Unable to find exchange rate for {from_currency} to {to_currency}."

        converted_amount = amount * exchange_rate
        return f"{amount} {from_currency} is equal to {converted_amount} {to_currency}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching conversion data: {str(e)}"


@tool
def get_news_headlines() -> str:
    """
    Fetches the top news headlines from the News API for the United States.
    This function makes a GET request to the News API to retrieve the top news headlines
    for the United States. It returns the titles and sources of the top 5 articles as a
    formatted string. If no articles are available, it returns a message indicating that
    no news is available. In case of a request error, it returns an error message.

    Returns:
        str: A string containing the top 5 news headlines and their sources, or an error message.
    """
    api_key = "your_api_key"  # Replace with your actual API key from https://newsapi.org/
    url = f"https://newsapi.org/v2/top-headlines?country=us&apiKey={api_key}"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()
        articles = data["articles"]

        if not articles:
            return "No news available at the moment."

        headlines = [f"{article['title']} - {article['source']['name']}" for article in articles[:5]]
        return "\n".join(headlines)

    except requests.exceptions.RequestException as e:
        return f"Error fetching news data: {str(e)}"


@tool
def get_joke() -> str:
    """
    Fetches a random joke from the JokeAPI.
    This function sends a GET request to the JokeAPI to retrieve a random joke.
    It handles both single jokes and two-part jokes (setup and delivery).
    If the request fails or the response does not contain a joke, an error message is returned.

    Returns:
        str: The joke as a string, or an error message if the joke could not be fetched.
    """
    url = "https://v2.jokeapi.dev/joke/Any?type=single"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()

        if "joke" in data:
            return data["joke"]
        elif "setup" in data and "delivery" in data:
            return f"{data['setup']} - {data['delivery']}"
        else:
            return "Error: Unable to fetch joke."

    except requests.exceptions.RequestException as e:
        return f"Error fetching joke: {str(e)}"


@tool
def get_time_in_timezone(location: str) -> str:
    """
    Fetches the current time for a given location using the World Time API.

    Args:
        location: The location for which to fetch the current time, formatted as 'Region/City'.

    Returns:
        str: A string indicating the current time in the specified location, or an error message if the request fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request.
    """
    url = f"http://worldtimeapi.org/api/timezone/{location}.json"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()
        current_time = data["datetime"]

        return f"The current time in {location} is {current_time}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching time data: {str(e)}"


@tool
def get_random_fact() -> str:
    """
    Fetches a random fact from the "uselessfacts.jsph.pl" API.

    Returns:
        str: A string containing the random fact or an error message if the request fails.
    """
    url = "https://uselessfacts.jsph.pl/random.json?language=en"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()

        return f"Random Fact: {data['text']}"

    except requests.exceptions.RequestException as e:
        return f"Error fetching random fact: {str(e)}"


@tool
def search_wikipedia(query: str) -> str:
    """
    Fetches a summary of a Wikipedia page for a given query.

    Args:
        query: The search term to look up on Wikipedia.

    Returns:
        str: A summary of the Wikipedia page if successful, or an error message if the request fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request.
    """
    url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{query}"

    try:
        response = requests.get(url)
        response.raise_for_status()

        data = response.json()
        title = data["title"]
        extract = data["extract"]

        return f"Summary for {title}: {extract}"

    except requests.exceptions.RequestException as e:
        return f"Error fetching Wikipedia data: {str(e)}"


# If you want to use the ToolCallingAgent instead, uncomment the following lines, as both will work

# agent = ToolCallingAgent(
#     tools=[
#         convert_currency,
#         get_weather,
#         get_news_headlines,
#         get_joke,
#         get_random_fact,
#         search_wikipedia,
#     ],
#     model=model,
# )


agent = CodeAgent(
    tools=[
        convert_currency,
        get_weather,
        get_news_headlines,
        get_joke,
        get_random_fact,
        search_wikipedia,
    ],
    model=model,
    stream_outputs=True,
)

# Uncomment one of the lines below to run the agent with a different query

agent.run("Convert 5000 dollars to Euros")
# agent.run("What is the weather in New York?")
# agent.run("Give me the top news headlines")
# agent.run("Tell me a joke")
# agent.run("Tell me a random fact")
# agent.run("Who is Elon Musk?")
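The "your_api_key" placeholders above would normally come from the environment rather than being hardcoded in source. A minimal sketch; the environment variable names are hypothetical, pick whatever your deployment uses:

import os

# Hypothetical variable names; fall back to the placeholder if unset.
weatherstack_key = os.environ.get("WEATHERSTACK_API_KEY", "your_api_key")
exchangerate_key = os.environ.get("EXCHANGERATE_API_KEY", "your_api_key")
newsapi_key = os.environ.get("NEWSAPI_API_KEY", "your_api_key")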
examples/rag.py
ADDED
@@ -0,0 +1,70 @@
# from huggingface_hub import login

# login()
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever


knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))

source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)

from smolagents import Tool


class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses lexical search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be lexically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
        )


from smolagents import CodeAgent, InferenceClientModel


retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
    tools=[retriever_tool],
    model=InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    max_steps=4,
    verbosity_level=2,
    stream_outputs=True,
)

agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")

print("Final output:")
print(agent_output)
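Before handing the retriever to the agent, the tool can be sanity-checked on its own, assuming Tool instances are callable and dispatch to forward() (which is how the agent invokes them):

# Quick standalone check of the BM25 retriever tool (assumed callable interface):
print(retriever_tool(query="training speed of forward versus backward pass"))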
examples/rag_using_chromadb.py
ADDED
@@ -0,0 +1,130 @@
import os

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma

# from langchain_community.document_loaders import PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
from transformers import AutoTokenizer

# from langchain_openai import OpenAIEmbeddings
from smolagents import LiteLLMModel, Tool
from smolagents.agents import CodeAgent


# from smolagents.agents import ToolCallingAgent


knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]

## For your own PDFs, you can use the following code to load them into source_docs
# pdf_directory = "pdfs"
# pdf_files = [
#     os.path.join(pdf_directory, f)
#     for f in os.listdir(pdf_directory)
#     if f.endswith(".pdf")
# ]
# source_docs = []

# for file_path in pdf_files:
#     loader = PyPDFLoader(file_path)
#     source_docs.extend(loader.load())

text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    AutoTokenizer.from_pretrained("thenlper/gte-small"),
    chunk_size=200,
    chunk_overlap=20,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)

# Split docs and keep only unique ones
print("Splitting documents...")
docs_processed = []
unique_texts = {}
for doc in tqdm(source_docs):
    new_docs = text_splitter.split_documents([doc])
    for new_doc in new_docs:
        if new_doc.page_content not in unique_texts:
            unique_texts[new_doc.page_content] = True
            docs_processed.append(new_doc)


print("Embedding documents... This should take a few minutes (about 5 minutes on an M1 Pro MacBook)")
# Initialize embeddings and ChromaDB vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


# embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

vector_store = Chroma.from_documents(docs_processed, embeddings, persist_directory="./chroma_db")


class RetrieverTool(Tool):
    name = "retriever"
    description = (
        "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vector_store, **kwargs):
        super().__init__(**kwargs)
        self.vector_store = vector_store

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.vector_store.similarity_search(query, k=3)
        return "\nRetrieved documents:\n" + "".join(
            [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
        )


retriever_tool = RetrieverTool(vector_store)

# Choose which LLM engine to use!

# from smolagents import InferenceClientModel
# model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

# from smolagents import TransformersModel
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620' and also change to 'os.environ.get("ANTHROPIC_API_KEY")'
model = LiteLLMModel(
    model_id="groq/llama-3.3-70b-versatile",
    api_key=os.environ.get("GROQ_API_KEY"),
)

# # You can also use the ToolCallingAgent class
# agent = ToolCallingAgent(
#     tools=[retriever_tool],
#     model=model,
#     verbose=True,
# )

agent = CodeAgent(
    tools=[retriever_tool],
    model=model,
    max_steps=4,
    verbosity_level=2,
)

agent_output = agent.run("How can I push a model to the Hub?")


print("Final output:")
print(agent_output)
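Since Chroma.from_documents persists the index to ./chroma_db, subsequent runs can skip the few minutes of embedding. A sketch of reloading the saved store with langchain_chroma's constructor (worth confirming against its docs):

# Reload the persisted index instead of rebuilding it on later runs:
vector_store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)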
examples/sandboxed_execution.py
ADDED
@@ -0,0 +1,12 @@
from smolagents import CodeAgent, InferenceClientModel, WebSearchTool


model = InferenceClientModel()

agent = CodeAgent(tools=[WebSearchTool()], model=model, executor_type="docker")
output = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
print("Docker executor result:", output)

agent = CodeAgent(tools=[WebSearchTool()], model=model, executor_type="e2b")
output = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
print("E2B executor result:", output)
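Both executors rely on external setup: executor_type="docker" assumes a local Docker daemon, and executor_type="e2b" assumes E2B credentials. A rough pre-flight sketch; E2B_API_KEY is the variable the e2b SDK conventionally reads, but treat both checks as assumptions:

import os
import shutil

# Checks that the Docker CLI is on PATH (not that the daemon is actually running).
if shutil.which("docker") is None:
    print("Docker CLI not found; executor_type='docker' will not work.")
if "E2B_API_KEY" not in os.environ:
    print("E2B_API_KEY is not set; executor_type='e2b' will not work.")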
examples/text_to_sql.py
ADDED
@@ -0,0 +1,79 @@
from sqlalchemy import (
    Column,
    Float,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    insert,
    inspect,
    text,
)


engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create the receipts SQL table
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16), primary_key=True),
    Column("price", Float),
    Column("tip", Float),
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        connection.execute(stmt)

inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

table_description = "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)

from smolagents import tool


@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    output = ""
    with engine.connect() as con:
        rows = con.execute(text(query))
        for row in rows:
            output += "\n" + str(row)
    return output


from smolagents import CodeAgent, InferenceClientModel


agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
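The sql_engine tool can also be exercised directly, without the agent; the query the model is expected to produce for the task above reduces to something like this (assuming, as with the other tools in this upload, that @tool objects are callable):

# Direct call to the tool, bypassing the LLM:
print(sql_engine("SELECT customer_name, price FROM receipts ORDER BY price DESC LIMIT 1"))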