Duibonduil committed
Commit dfd6145 · verified · 1 Parent(s): b58d21f

Upload 9 files

examples/agent_from_any_llm.py ADDED
from smolagents import (
    CodeAgent,
    InferenceClientModel,
    LiteLLMModel,
    OpenAIServerModel,
    ToolCallingAgent,
    TransformersModel,
    tool,
)


# Choose which inference type to use!

available_inferences = ["inference_client", "transformers", "ollama", "litellm", "openai"]
chosen_inference = "inference_client"

print(f"Chosen inference type: '{chosen_inference}'")

if chosen_inference == "inference_client":
    model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="nebius")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with a remote OpenAI-compatible server if necessary
        api_key="your-api-key",  # replace with your API key if necessary
        num_ctx=8192,  # Ollama's default is 2048, which will often fail badly. 8192 works for easy tasks; more is better. See https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator to estimate how much VRAM this needs for the selected model.
    )

elif chosen_inference == "litellm":
    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")

elif chosen_inference == "openai":
    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = OpenAIServerModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: bool = False) -> str:
    """
    Get the weather forecast for the coming days at the given location.
    Secretly this tool does not care about the location; it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to report the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model, verbosity_level=2)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model, verbosity_level=2, stream_outputs=True)

print("CodeAgent:", agent.run("What's the weather like in Paris?"))
examples/gradio_ui.py ADDED
from smolagents import CodeAgent, GradioUI, InferenceClientModel, WebSearchTool


agent = CodeAgent(
    tools=[WebSearchTool()],
    model=InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="fireworks-ai"),
    verbosity_level=1,
    planning_interval=3,
    name="example_agent",
    description="This is an example agent.",
    step_callbacks=[],
    stream_outputs=True,
    # use_structured_outputs_internally=True,
)

GradioUI(agent, file_upload_folder="./data").launch()
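step_callbacks is left empty above; as a sketch of what could go there, here is a simple logging callback, assuming callbacks receive the finished memory step (newer smolagents versions may also pass the agent as a keyword argument, which **kwargs absorbs):

def log_step(step, **kwargs):
    # Print a one-line trace per agent step; attributes are read defensively,
    # since different step types expose different fields.
    print(f"[step {getattr(step, 'step_number', '?')}] {type(step).__name__}")

# ...then pass step_callbacks=[log_step] to the CodeAgent above.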
examples/inspect_multiagent_run.py ADDED
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register


register()
SmolagentsInstrumentor().instrument(skip_dep_check=True)


from smolagents import (
    CodeAgent,
    InferenceClientModel,
    ToolCallingAgent,
    VisitWebpageTool,
    WebSearchTool,
)


# Then we run the agentic part!
model = InferenceClientModel(provider="nebius")

search_agent = ToolCallingAgent(
    tools=[WebSearchTool(), VisitWebpageTool()],
    model=model,
    name="search_agent",
    description="This is an agent that can do web search.",
    return_full_result=True,
)

manager_agent = CodeAgent(
    tools=[],
    model=model,
    managed_agents=[search_agent],
    return_full_result=True,
)
run_result = manager_agent.run(
    "If the US keeps its 2024 growth rate, how many years would it take for the GDP to double?"
)
print("Here is the token usage for the manager agent:", run_result.token_usage)
print("Here is the timing information for the manager agent:", run_result.timing)
examples/multi_llm_agent.py ADDED
import os

from smolagents import CodeAgent, LiteLLMRouterModel, WebSearchTool


# Make sure to set up the necessary environment variables!

llm_loadbalancer_model_list = [
    {
        "model_name": "model-group-1",
        "litellm_params": {
            "model": "gpt-4o-mini",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
    {
        "model_name": "model-group-1",
        "litellm_params": {
            "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_region_name": os.getenv("AWS_REGION"),
        },
    },
    # {
    #     "model_name": "model-group-2",
    #     "litellm_params": {
    #         "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    #         "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
    #         "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    #         "aws_region_name": os.getenv("AWS_REGION"),
    #     },
    # },
]


model = LiteLLMRouterModel(
    model_id="model-group-1",
    model_list=llm_loadbalancer_model_list,
    client_kwargs={"routing_strategy": "simple-shuffle"},
)
agent = CodeAgent(tools=[WebSearchTool()], model=model, stream_outputs=True, return_full_result=True)

full_result = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")

print(full_result)
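Missing credentials here only surface as runtime errors inside the router, so a small sketch that fails fast before the model is constructed (the variable list simply mirrors the config above):

import os

required_vars = ["OPENAI_API_KEY", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"]
missing = [v for v in required_vars if not os.getenv(v)]
if missing:
    raise EnvironmentError(f"Missing environment variables: {', '.join(missing)}")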
examples/multiple_tools.py ADDED
import requests

# from smolagents.agents import ToolCallingAgent
from smolagents import CodeAgent, InferenceClientModel, tool


# Choose which LLM engine to use!
model = InferenceClientModel()
# model = TransformersModel(model_id="meta-llama/Llama-3.2-3B-Instruct")

# For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
# model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: bool = False) -> str:
    """
    Get the current weather at the given location using the WeatherStack API.

    Args:
        location: The location (city name).
        celsius: Whether to return the temperature in Celsius (default is False, which returns Fahrenheit).

    Returns:
        A string describing the current weather at the location.
    """
    api_key = "your_api_key"  # Replace with your API key from https://weatherstack.com/
    units = "m" if celsius else "f"  # 'm' for Celsius, 'f' for Fahrenheit

    url = f"http://api.weatherstack.com/current?access_key={api_key}&query={location}&units={units}"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Raise an exception for HTTP errors

        data = response.json()

        if data.get("error"):  # Check if there's an error in the response
            return f"Error: {data['error'].get('info', 'Unable to fetch weather data.')}"

        weather = data["current"]["weather_descriptions"][0]
        temp = data["current"]["temperature"]
        temp_unit = "°C" if celsius else "°F"

        return f"The current weather in {location} is {weather} with a temperature of {temp} {temp_unit}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching weather data: {str(e)}"


@tool
def convert_currency(amount: float, from_currency: str, to_currency: str) -> str:
    """
    Converts a specified amount from one currency to another using the ExchangeRate-API.

    Args:
        amount: The amount of money to convert.
        from_currency: The currency code of the currency to convert from (e.g., 'USD').
        to_currency: The currency code of the currency to convert to (e.g., 'EUR').

    Returns:
        str: A string describing the converted amount in the target currency, or an error message if the conversion fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request to the ExchangeRate-API.
    """
    api_key = "your_api_key"  # Replace with your actual API key from https://www.exchangerate-api.com/
    url = f"https://v6.exchangerate-api.com/v6/{api_key}/latest/{from_currency}"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()
        exchange_rate = data["conversion_rates"].get(to_currency)

        if not exchange_rate:
            return f"Error: Unable to find exchange rate for {from_currency} to {to_currency}."

        converted_amount = amount * exchange_rate
        return f"{amount} {from_currency} is equal to {converted_amount} {to_currency}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching conversion data: {str(e)}"


@tool
def get_news_headlines() -> str:
    """
    Fetches the top news headlines from the News API for the United States.
    This function makes a GET request to the News API to retrieve the top news headlines
    for the United States. It returns the titles and sources of the top 5 articles as a
    formatted string. If no articles are available, it returns a message indicating that
    no news is available. In case of a request error, it returns an error message.

    Returns:
        str: A string containing the top 5 news headlines and their sources, or an error message.
    """
    api_key = "your_api_key"  # Replace with your actual API key from https://newsapi.org/
    url = f"https://newsapi.org/v2/top-headlines?country=us&apiKey={api_key}"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()
        articles = data["articles"]

        if not articles:
            return "No news available at the moment."

        headlines = [f"{article['title']} - {article['source']['name']}" for article in articles[:5]]
        return "\n".join(headlines)

    except requests.exceptions.RequestException as e:
        return f"Error fetching news data: {str(e)}"


@tool
def get_joke() -> str:
    """
    Fetches a random joke from the JokeAPI.
    This function sends a GET request to the JokeAPI to retrieve a random joke.
    It handles both single jokes and two-part jokes (setup and delivery).
    If the request fails or the response does not contain a joke, an error message is returned.

    Returns:
        str: The joke as a string, or an error message if the joke could not be fetched.
    """
    url = "https://v2.jokeapi.dev/joke/Any"  # No type filter, so both single and two-part jokes can occur

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()

        if "joke" in data:
            return data["joke"]
        elif "setup" in data and "delivery" in data:
            return f"{data['setup']} - {data['delivery']}"
        else:
            return "Error: Unable to fetch joke."

    except requests.exceptions.RequestException as e:
        return f"Error fetching joke: {str(e)}"


@tool
def get_time_in_timezone(location: str) -> str:
    """
    Fetches the current time for a given location using the World Time API.

    Args:
        location: The location for which to fetch the current time, formatted as 'Region/City'.

    Returns:
        str: A string indicating the current time in the specified location, or an error message if the request fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request.
    """
    url = f"http://worldtimeapi.org/api/timezone/{location}"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()
        current_time = data["datetime"]

        return f"The current time in {location} is {current_time}."

    except requests.exceptions.RequestException as e:
        return f"Error fetching time data: {str(e)}"


@tool
def get_random_fact() -> str:
    """
    Fetches a random fact from the "uselessfacts.jsph.pl" API.

    Returns:
        str: A string containing the random fact or an error message if the request fails.
    """
    url = "https://uselessfacts.jsph.pl/random.json?language=en"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()

        return f"Random Fact: {data['text']}"

    except requests.exceptions.RequestException as e:
        return f"Error fetching random fact: {str(e)}"


@tool
def search_wikipedia(query: str) -> str:
    """
    Fetches a summary of a Wikipedia page for a given query.

    Args:
        query: The search term to look up on Wikipedia.

    Returns:
        str: A summary of the Wikipedia page if successful, or an error message if the request fails.

    Raises:
        requests.exceptions.RequestException: If there is an issue with the HTTP request.
    """
    url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{query}"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        data = response.json()
        title = data["title"]
        extract = data["extract"]

        return f"Summary for {title}: {extract}"

    except requests.exceptions.RequestException as e:
        return f"Error fetching Wikipedia data: {str(e)}"


# To use the ToolCallingAgent instead, uncomment the following lines; both agent types work here.

# agent = ToolCallingAgent(
#     tools=[
#         convert_currency,
#         get_weather,
#         get_news_headlines,
#         get_joke,
#         get_random_fact,
#         search_wikipedia,
#     ],
#     model=model,
# )


agent = CodeAgent(
    tools=[
        convert_currency,
        get_weather,
        get_news_headlines,
        get_joke,
        get_random_fact,
        search_wikipedia,
    ],
    model=model,
    stream_outputs=True,
)

# Run the agent with one query; uncomment any of the lines below to try the other tools

agent.run("Convert 5000 dollars to Euros")
# agent.run("What is the weather in New York?")
# agent.run("Give me the top news headlines")
# agent.run("Tell me a joke")
# agent.run("Tell me a random fact")
# agent.run("Who is Elon Musk?")
examples/rag.py ADDED
# from huggingface_hub import login

# login()
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever


knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))

source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)

from smolagents import Tool


class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses lexical search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be lexically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
        )


from smolagents import CodeAgent, InferenceClientModel


retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
    tools=[retriever_tool],
    model=InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    max_steps=4,
    verbosity_level=2,
    stream_outputs=True,
)

agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")

print("Final output:")
print(agent_output)
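Before handing the tool to an agent, the retriever can be exercised directly through the forward method defined above; a short sketch (the sample query is illustrative):

# Sanity-check retrieval quality without spending any LLM calls
print(retriever_tool.forward("backward pass speed compared to forward pass"))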
examples/rag_using_chromadb.py ADDED
import os

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma

# from langchain_community.document_loaders import PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
from transformers import AutoTokenizer

# from langchain_openai import OpenAIEmbeddings
from smolagents import LiteLLMModel, Tool
from smolagents.agents import CodeAgent


# from smolagents.agents import ToolCallingAgent


knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]

## For your own PDFs, you can use the following code to load them into source_docs
# pdf_directory = "pdfs"
# pdf_files = [
#     os.path.join(pdf_directory, f)
#     for f in os.listdir(pdf_directory)
#     if f.endswith(".pdf")
# ]
# source_docs = []

# for file_path in pdf_files:
#     loader = PyPDFLoader(file_path)
#     source_docs.extend(loader.load())

text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    AutoTokenizer.from_pretrained("thenlper/gte-small"),
    chunk_size=200,
    chunk_overlap=20,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)

# Split docs and keep only unique ones
print("Splitting documents...")
docs_processed = []
unique_texts = {}
for doc in tqdm(source_docs):
    new_docs = text_splitter.split_documents([doc])
    for new_doc in new_docs:
        if new_doc.page_content not in unique_texts:
            unique_texts[new_doc.page_content] = True
            docs_processed.append(new_doc)


print("Embedding documents... This should take a few minutes (5 minutes on a MacBook with M1 Pro)")
# Initialize embeddings and ChromaDB vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


# embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

vector_store = Chroma.from_documents(docs_processed, embeddings, persist_directory="./chroma_db")


class RetrieverTool(Tool):
    name = "retriever"
    description = (
        "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vector_store, **kwargs):
        super().__init__(**kwargs)
        self.vector_store = vector_store

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.vector_store.similarity_search(query, k=3)
        return "\nRetrieved documents:\n" + "".join(
            [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
        )


retriever_tool = RetrieverTool(vector_store)

# Choose which LLM engine to use!

# from smolagents import InferenceClientModel
# model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

# from smolagents import TransformersModel
# model = TransformersModel(model_id="meta-llama/Llama-3.2-3B-Instruct")

# For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620' and pass os.environ.get("ANTHROPIC_API_KEY") as api_key
model = LiteLLMModel(
    model_id="groq/llama-3.3-70b-versatile",
    api_key=os.environ.get("GROQ_API_KEY"),
)

# # You can also use the ToolCallingAgent class
# agent = ToolCallingAgent(
#     tools=[retriever_tool],
#     model=model,
#     verbose=True,
# )

agent = CodeAgent(
    tools=[retriever_tool],
    model=model,
    max_steps=4,
    verbosity_level=2,
)

agent_output = agent.run("How can I push a model to the Hub?")


print("Final output:")
print(agent_output)
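Because the store is persisted to ./chroma_db above, later runs can skip re-embedding; a sketch of reloading it, assuming the same embedding model is used:

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Reopen the persisted collection instead of calling Chroma.from_documents again
vector_store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)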
examples/sandboxed_execution.py ADDED
from smolagents import CodeAgent, InferenceClientModel, WebSearchTool


model = InferenceClientModel()

agent = CodeAgent(tools=[WebSearchTool()], model=model, executor_type="docker")
output = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
print("Docker executor result:", output)

agent = CodeAgent(tools=[WebSearchTool()], model=model, executor_type="e2b")
output = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
print("E2B executor result:", output)
examples/text_to_sql.py ADDED
from sqlalchemy import (
    Column,
    Float,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    insert,
    inspect,
    text,
)


engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Create the receipts SQL table
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16), primary_key=True),
    Column("price", Float),
    Column("tip", Float),
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        connection.execute(stmt)

inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

table_description = "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)

from smolagents import tool


@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    output = ""
    with engine.connect() as con:
        rows = con.execute(text(query))
        for row in rows:
            output += "\n" + str(row)
    return output


from smolagents import CodeAgent, InferenceClientModel


agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")