Duibonduil commited on
Commit
9c31777
·
verified ·
1 Parent(s): 1e0a254

Upload 21 files

Browse files
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import patch
2
+
3
+ import pytest
4
+
5
+ from smolagents.agents import MultiStepAgent
6
+ from smolagents.monitoring import LogLevel
7
+
8
+
9
+ # Import fixture modules as plugins
10
+ pytest_plugins = ["tests.fixtures.agents", "tests.fixtures.tools"]
11
+
12
+ original_multi_step_agent_init = MultiStepAgent.__init__
13
+
14
+
15
+ @pytest.fixture(autouse=True)
16
+ def patch_multi_step_agent_with_suppressed_logging():
17
+ with patch.object(MultiStepAgent, "__init__", autospec=True) as mock_init:
18
+
19
+ def init_with_suppressed_logging(self, *args, verbosity_level=LogLevel.OFF, **kwargs):
20
+ original_multi_step_agent_init(self, *args, verbosity_level=verbosity_level, **kwargs)
21
+
22
+ mock_init.side_effect = init_with_suppressed_logging
23
+ yield
tests/test_agents.py ADDED
@@ -0,0 +1,2089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import io
16
+ import os
17
+ import re
18
+ import tempfile
19
+ import uuid
20
+ import warnings
21
+ from collections.abc import Generator
22
+ from contextlib import nullcontext as does_not_raise
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+ from textwrap import dedent
26
+ from typing import Optional
27
+ from unittest.mock import MagicMock, patch
28
+
29
+ import pytest
30
+ from huggingface_hub import (
31
+ ChatCompletionOutputFunctionDefinition,
32
+ ChatCompletionOutputMessage,
33
+ ChatCompletionOutputToolCall,
34
+ )
35
+ from rich.console import Console
36
+
37
+ from smolagents import EMPTY_PROMPT_TEMPLATES
38
+ from smolagents.agent_types import AgentImage, AgentText
39
+ from smolagents.agents import (
40
+ AgentError,
41
+ AgentMaxStepsError,
42
+ AgentToolCallError,
43
+ CodeAgent,
44
+ MultiStepAgent,
45
+ ToolCall,
46
+ ToolCallingAgent,
47
+ ToolOutput,
48
+ populate_template,
49
+ )
50
+ from smolagents.default_tools import DuckDuckGoSearchTool, FinalAnswerTool, PythonInterpreterTool, VisitWebpageTool
51
+ from smolagents.memory import (
52
+ ActionStep,
53
+ PlanningStep,
54
+ TaskStep,
55
+ )
56
+ from smolagents.models import (
57
+ ChatMessage,
58
+ ChatMessageToolCall,
59
+ ChatMessageToolCallFunction,
60
+ InferenceClientModel,
61
+ MessageRole,
62
+ Model,
63
+ TransformersModel,
64
+ )
65
+ from smolagents.monitoring import AgentLogger, LogLevel, TokenUsage
66
+ from smolagents.tools import Tool, tool
67
+ from smolagents.utils import (
68
+ BASE_BUILTIN_MODULES,
69
+ AgentExecutionError,
70
+ AgentGenerationError,
71
+ AgentToolExecutionError,
72
+ )
73
+
74
+
75
+ @dataclass
76
+ class ChoiceDeltaToolCallFunction:
77
+ arguments: Optional[str] = None
78
+ name: Optional[str] = None
79
+
80
+
81
+ @dataclass
82
+ class ChoiceDeltaToolCall:
83
+ index: Optional[int] = None
84
+ id: Optional[str] = None
85
+ function: Optional[ChoiceDeltaToolCallFunction] = None
86
+ type: Optional[str] = None
87
+
88
+
89
+ @dataclass
90
+ class ChoiceDelta:
91
+ content: Optional[str] = None
92
+ function_call: Optional[str] = None
93
+ refusal: Optional[str] = None
94
+ role: Optional[str] = None
95
+ tool_calls: Optional[list] = None
96
+
97
+
98
+ def get_new_path(suffix="") -> str:
99
+ directory = tempfile.mkdtemp()
100
+ return os.path.join(directory, str(uuid.uuid4()) + suffix)
101
+
102
+
103
+ @pytest.fixture
104
+ def agent_logger():
105
+ return AgentLogger(
106
+ LogLevel.DEBUG, console=Console(record=True, no_color=True, force_terminal=False, file=io.StringIO())
107
+ )
108
+
109
+
110
+ class FakeToolCallModel(Model):
111
+ def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
112
+ if len(messages) < 3:
113
+ return ChatMessage(
114
+ role=MessageRole.ASSISTANT,
115
+ content="",
116
+ tool_calls=[
117
+ ChatMessageToolCall(
118
+ id="call_0",
119
+ type="function",
120
+ function=ChatMessageToolCallFunction(
121
+ name="python_interpreter", arguments={"code": "2*3.6452"}
122
+ ),
123
+ )
124
+ ],
125
+ )
126
+ else:
127
+ return ChatMessage(
128
+ role=MessageRole.ASSISTANT,
129
+ content="",
130
+ tool_calls=[
131
+ ChatMessageToolCall(
132
+ id="call_1",
133
+ type="function",
134
+ function=ChatMessageToolCallFunction(name="final_answer", arguments={"answer": "7.2904"}),
135
+ )
136
+ ],
137
+ )
138
+
139
+
140
+ class FakeToolCallModelImage(Model):
141
+ def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
142
+ if len(messages) < 3:
143
+ return ChatMessage(
144
+ role=MessageRole.ASSISTANT,
145
+ content="",
146
+ tool_calls=[
147
+ ChatMessageToolCall(
148
+ id="call_0",
149
+ type="function",
150
+ function=ChatMessageToolCallFunction(
151
+ name="fake_image_generation_tool",
152
+ arguments={"prompt": "An image of a cat"},
153
+ ),
154
+ )
155
+ ],
156
+ )
157
+ else:
158
+ return ChatMessage(
159
+ role=MessageRole.ASSISTANT,
160
+ content="",
161
+ tool_calls=[
162
+ ChatMessageToolCall(
163
+ id="call_1",
164
+ type="function",
165
+ function=ChatMessageToolCallFunction(name="final_answer", arguments="image.png"),
166
+ )
167
+ ],
168
+ )
169
+
170
+
171
+ class FakeToolCallModelVL(Model):
172
+ def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
173
+ if len(messages) < 3:
174
+ return ChatMessage(
175
+ role=MessageRole.ASSISTANT,
176
+ content="",
177
+ tool_calls=[
178
+ ChatMessageToolCall(
179
+ id="call_0",
180
+ type="function",
181
+ function=ChatMessageToolCallFunction(
182
+ name="fake_image_understanding_tool",
183
+ arguments={
184
+ "prompt": "What is in this image?",
185
+ "image": "image.png",
186
+ },
187
+ ),
188
+ )
189
+ ],
190
+ )
191
+ else:
192
+ return ChatMessage(
193
+ role=MessageRole.ASSISTANT,
194
+ content="",
195
+ tool_calls=[
196
+ ChatMessageToolCall(
197
+ id="call_1",
198
+ type="function",
199
+ function=ChatMessageToolCallFunction(name="final_answer", arguments="The image is a cat."),
200
+ )
201
+ ],
202
+ )
203
+
204
+
205
+ class FakeCodeModel(Model):
206
+ def generate(self, messages, stop_sequences=None):
207
+ prompt = str(messages)
208
+ if "special_marker" not in prompt:
209
+ return ChatMessage(
210
+ role=MessageRole.ASSISTANT,
211
+ content="""
212
+ Thought: I should multiply 2 by 3.6452. special_marker
213
+ <code>
214
+ result = 2**3.6452
215
+ </code>
216
+ """,
217
+ )
218
+ else: # We're at step 2
219
+ return ChatMessage(
220
+ role=MessageRole.ASSISTANT,
221
+ content="""
222
+ Thought: I can now answer the initial question
223
+ <code>
224
+ final_answer(7.2904)
225
+ </code>
226
+ """,
227
+ )
228
+
229
+
230
+ class FakeCodeModelPlanning(Model):
231
+ def generate(self, messages, stop_sequences=None):
232
+ prompt = str(messages)
233
+ if "planning_marker" not in prompt:
234
+ return ChatMessage(
235
+ role=MessageRole.ASSISTANT,
236
+ content="llm plan update planning_marker",
237
+ token_usage=TokenUsage(input_tokens=10, output_tokens=10),
238
+ )
239
+ elif "action_marker" not in prompt:
240
+ return ChatMessage(
241
+ role=MessageRole.ASSISTANT,
242
+ content="""
243
+ Thought: I should multiply 2 by 3.6452. action_marker
244
+ <code>
245
+ result = 2**3.6452
246
+ </code>
247
+ """,
248
+ token_usage=TokenUsage(input_tokens=10, output_tokens=10),
249
+ )
250
+ else:
251
+ return ChatMessage(
252
+ role=MessageRole.ASSISTANT,
253
+ content="llm plan again",
254
+ token_usage=TokenUsage(input_tokens=10, output_tokens=10),
255
+ )
256
+
257
+
258
+ class FakeCodeModelError(Model):
259
+ def generate(self, messages, stop_sequences=None):
260
+ prompt = str(messages)
261
+ if "special_marker" not in prompt:
262
+ return ChatMessage(
263
+ role=MessageRole.ASSISTANT,
264
+ content="""
265
+ Thought: I should multiply 2 by 3.6452. special_marker
266
+ <code>
267
+ print("Flag!")
268
+ def error_function():
269
+ raise ValueError("error")
270
+
271
+ error_function()
272
+ </code>
273
+ """,
274
+ )
275
+ else: # We're at step 2
276
+ return ChatMessage(
277
+ role=MessageRole.ASSISTANT,
278
+ content="""
279
+ Thought: I faced an error in the previous step.
280
+ <code>
281
+ final_answer("got an error")
282
+ </code>
283
+ """,
284
+ )
285
+
286
+
287
+ class FakeCodeModelSyntaxError(Model):
288
+ def generate(self, messages, stop_sequences=None):
289
+ prompt = str(messages)
290
+ if "special_marker" not in prompt:
291
+ return ChatMessage(
292
+ role=MessageRole.ASSISTANT,
293
+ content="""
294
+ Thought: I should multiply 2 by 3.6452. special_marker
295
+ <code>
296
+ a = 2
297
+ b = a * 2
298
+ print("Failing due to unexpected indent")
299
+ print("Ok, calculation done!")
300
+ </code>
301
+ """,
302
+ )
303
+ else: # We're at step 2
304
+ return ChatMessage(
305
+ role=MessageRole.ASSISTANT,
306
+ content="""
307
+ Thought: I can now answer the initial question
308
+ <code>
309
+ final_answer("got an error")
310
+ </code>
311
+ """,
312
+ )
313
+
314
+
315
+ class FakeCodeModelImport(Model):
316
+ def generate(self, messages, stop_sequences=None):
317
+ return ChatMessage(
318
+ role=MessageRole.ASSISTANT,
319
+ content="""
320
+ Thought: I can answer the question
321
+ <code>
322
+ import numpy as np
323
+ final_answer("got an error")
324
+ </code>
325
+ """,
326
+ )
327
+
328
+
329
+ class FakeCodeModelFunctionDef(Model):
330
+ def generate(self, messages, stop_sequences=None):
331
+ prompt = str(messages)
332
+ if "special_marker" not in prompt:
333
+ return ChatMessage(
334
+ role=MessageRole.ASSISTANT,
335
+ content="""
336
+ Thought: Let's define the function. special_marker
337
+ <code>
338
+ import numpy as np
339
+
340
+ def moving_average(x, w):
341
+ return np.convolve(x, np.ones(w), 'valid') / w
342
+ </code>
343
+ """,
344
+ )
345
+ else: # We're at step 2
346
+ return ChatMessage(
347
+ role=MessageRole.ASSISTANT,
348
+ content="""
349
+ Thought: I can now answer the initial question
350
+ <code>
351
+ x, w = [0, 1, 2, 3, 4, 5], 2
352
+ res = moving_average(x, w)
353
+ final_answer(res)
354
+ </code>
355
+ """,
356
+ )
357
+
358
+
359
+ class FakeCodeModelSingleStep(Model):
360
+ def generate(self, messages, stop_sequences=None):
361
+ return ChatMessage(
362
+ role=MessageRole.ASSISTANT,
363
+ content="""
364
+ Thought: I should multiply 2 by 3.6452. special_marker
365
+ <code>
366
+ result = python_interpreter(code="2*3.6452")
367
+ final_answer(result)
368
+ ```
369
+ """,
370
+ )
371
+
372
+
373
+ class FakeCodeModelNoReturn(Model):
374
+ def generate(self, messages, stop_sequences=None):
375
+ return ChatMessage(
376
+ role=MessageRole.ASSISTANT,
377
+ content="""
378
+ Thought: I should multiply 2 by 3.6452. special_marker
379
+ <code>
380
+ result = python_interpreter(code="2*3.6452")
381
+ print(result)
382
+ ```
383
+ """,
384
+ )
385
+
386
+
387
+ class TestAgent:
388
+ def test_fake_toolcalling_agent(self):
389
+ agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
390
+ output = agent.run("What is 2 multiplied by 3.6452?")
391
+ assert isinstance(output, str)
392
+ assert "7.2904" in output
393
+ assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
394
+ assert "7.2904" in agent.memory.steps[1].observations
395
+ assert (
396
+ agent.memory.steps[2].model_output
397
+ == "Tool call call_1: calling 'final_answer' with arguments: {'answer': '7.2904'}"
398
+ )
399
+
400
+ def test_toolcalling_agent_handles_image_tool_outputs(self, shared_datadir):
401
+ import PIL.Image
402
+
403
+ @tool
404
+ def fake_image_generation_tool(prompt: str) -> PIL.Image.Image:
405
+ """Tool that generates an image.
406
+
407
+ Args:
408
+ prompt: The prompt
409
+ """
410
+
411
+ import PIL.Image
412
+
413
+ return PIL.Image.open(shared_datadir / "000000039769.png")
414
+
415
+ agent = ToolCallingAgent(
416
+ tools=[fake_image_generation_tool], model=FakeToolCallModelImage(), verbosity_level=10
417
+ )
418
+ output = agent.run("Make me an image.")
419
+ assert isinstance(output, AgentImage)
420
+ assert isinstance(agent.state["image.png"], PIL.Image.Image)
421
+
422
+ def test_toolcalling_agent_handles_image_inputs(self, shared_datadir):
423
+ import PIL.Image
424
+
425
+ image = PIL.Image.open(shared_datadir / "000000039769.png") # dummy input
426
+
427
+ @tool
428
+ def fake_image_understanding_tool(prompt: str, image: PIL.Image.Image) -> str:
429
+ """Tool that creates a caption for an image.
430
+
431
+ Args:
432
+ prompt: The prompt
433
+ image: The image
434
+ """
435
+ return "The image is a cat."
436
+
437
+ agent = ToolCallingAgent(tools=[fake_image_understanding_tool], model=FakeToolCallModelVL())
438
+ output = agent.run("Caption this image.", images=[image])
439
+ assert output == "The image is a cat."
440
+
441
+ def test_fake_code_agent(self):
442
+ agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel(), verbosity_level=10)
443
+ output = agent.run("What is 2 multiplied by 3.6452?")
444
+ assert isinstance(output, float)
445
+ assert output == 7.2904
446
+ assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
447
+ assert agent.memory.steps[2].tool_calls == [
448
+ ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_2")
449
+ ]
450
+
451
+ def test_additional_args_added_to_task(self):
452
+ agent = CodeAgent(tools=[], model=FakeCodeModel())
453
+ agent.run(
454
+ "What is 2 multiplied by 3.6452?",
455
+ additional_args={"instruction": "Remember this."},
456
+ )
457
+ assert "Remember this" in agent.task
458
+
459
+ def test_reset_conversations(self):
460
+ agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel())
461
+ output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
462
+ assert output == 7.2904
463
+ assert len(agent.memory.steps) == 3
464
+
465
+ output = agent.run("What is 2 multiplied by 3.6452?", reset=False)
466
+ assert output == 7.2904
467
+ assert len(agent.memory.steps) == 5
468
+
469
+ output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
470
+ assert output == 7.2904
471
+ assert len(agent.memory.steps) == 3
472
+
473
+ def test_setup_agent_with_empty_toolbox(self):
474
+ ToolCallingAgent(model=FakeToolCallModel(), tools=[])
475
+
476
+ def test_fails_max_steps(self):
477
+ agent = CodeAgent(
478
+ tools=[PythonInterpreterTool()],
479
+ model=FakeCodeModelNoReturn(), # use this callable because it never ends
480
+ max_steps=5,
481
+ )
482
+ answer = agent.run("What is 2 multiplied by 3.6452?")
483
+ assert len(agent.memory.steps) == 7 # Task step + 5 action steps + Final answer
484
+ assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
485
+ assert isinstance(answer, str)
486
+
487
+ agent = CodeAgent(
488
+ tools=[PythonInterpreterTool()],
489
+ model=FakeCodeModelNoReturn(), # use this callable because it never ends
490
+ max_steps=5,
491
+ )
492
+ answer = agent.run("What is 2 multiplied by 3.6452?", max_steps=3)
493
+ assert len(agent.memory.steps) == 5 # Task step + 3 action steps + Final answer
494
+ assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
495
+ assert isinstance(answer, str)
496
+
497
+ def test_tool_descriptions_get_baked_in_system_prompt(self):
498
+ tool = PythonInterpreterTool()
499
+ tool.name = "fake_tool_name"
500
+ tool.description = "fake_tool_description"
501
+ agent = CodeAgent(tools=[tool], model=FakeCodeModel())
502
+ agent.run("Empty task")
503
+ assert agent.system_prompt is not None
504
+ assert f"def {tool.name}(" in agent.system_prompt
505
+ assert f'"""{tool.description}' in agent.system_prompt
506
+
507
+ def test_module_imports_get_baked_in_system_prompt(self):
508
+ agent = CodeAgent(tools=[], model=FakeCodeModel())
509
+ agent.run("Empty task")
510
+ for module in BASE_BUILTIN_MODULES:
511
+ assert module in agent.system_prompt
512
+
513
+ def test_init_agent_with_different_toolsets(self):
514
+ toolset_1 = []
515
+ agent = CodeAgent(tools=toolset_1, model=FakeCodeModel())
516
+ assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
517
+
518
+ toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
519
+ with pytest.raises(ValueError) as e:
520
+ agent = CodeAgent(tools=toolset_2, model=FakeCodeModel())
521
+ assert "Each tool or managed_agent should have a unique name!" in str(e)
522
+
523
+ with pytest.raises(ValueError) as e:
524
+ agent.name = "python_interpreter"
525
+ agent.description = "empty"
526
+ CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel(), managed_agents=[agent])
527
+ assert "Each tool or managed_agent should have a unique name!" in str(e)
528
+
529
+ # check that python_interpreter base tool does not get added to CodeAgent
530
+ agent = CodeAgent(tools=[], model=FakeCodeModel(), add_base_tools=True)
531
+ assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage
532
+
533
+ # check that python_interpreter base tool gets added to ToolCallingAgent
534
+ agent = ToolCallingAgent(tools=[], model=FakeCodeModel(), add_base_tools=True)
535
+ assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage
536
+
537
+ def test_function_persistence_across_steps(self):
538
+ agent = CodeAgent(
539
+ tools=[],
540
+ model=FakeCodeModelFunctionDef(),
541
+ max_steps=2,
542
+ additional_authorized_imports=["numpy"],
543
+ verbosity_level=100,
544
+ )
545
+ res = agent.run("ok")
546
+ assert res[0] == 0.5
547
+
548
+ def test_init_managed_agent(self):
549
+ agent = CodeAgent(tools=[], model=FakeCodeModelFunctionDef(), name="managed_agent", description="Empty")
550
+ assert agent.name == "managed_agent"
551
+ assert agent.description == "Empty"
552
+
553
+ def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
554
+ managed_agent = CodeAgent(
555
+ tools=[], model=FakeCodeModelFunctionDef(), name="managed_agent", description="Empty"
556
+ )
557
+ manager_agent = CodeAgent(
558
+ tools=[],
559
+ model=FakeCodeModelFunctionDef(),
560
+ managed_agents=[managed_agent],
561
+ )
562
+ assert "You can also give tasks to team members." not in managed_agent.system_prompt
563
+ assert "{{managed_agents_descriptions}}" not in managed_agent.system_prompt
564
+ assert "You can also give tasks to team members." in manager_agent.system_prompt
565
+
566
+ def test_replay_shows_logs(self, agent_logger):
567
+ agent = CodeAgent(
568
+ tools=[],
569
+ model=FakeCodeModelImport(),
570
+ verbosity_level=0,
571
+ additional_authorized_imports=["numpy"],
572
+ logger=agent_logger,
573
+ )
574
+ agent.run("Count to 3")
575
+
576
+ str_output = agent_logger.console.export_text()
577
+
578
+ assert "New run" in str_output
579
+ assert 'final_answer("got' in str_output
580
+ assert "</code>" in str_output
581
+
582
+ agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel(), verbosity_level=0)
583
+ agent.logger = agent_logger
584
+
585
+ agent.run("What is 2 multiplied by 3.6452?")
586
+ agent.replay()
587
+
588
+ str_output = agent_logger.console.export_text()
589
+ assert "Tool call" in str_output
590
+ assert "arguments" in str_output
591
+
592
+ def test_code_nontrivial_final_answer_works(self):
593
+ class FakeCodeModelFinalAnswer(Model):
594
+ def generate(self, messages, stop_sequences=None):
595
+ return ChatMessage(
596
+ role=MessageRole.ASSISTANT,
597
+ content="""<code>
598
+ def nested_answer():
599
+ final_answer("Correct!")
600
+
601
+ nested_answer()
602
+ </code>""",
603
+ )
604
+
605
+ agent = CodeAgent(tools=[], model=FakeCodeModelFinalAnswer())
606
+
607
+ output = agent.run("Count to 3")
608
+ assert output == "Correct!"
609
+
610
+ def test_transformers_toolcalling_agent(self):
611
+ @tool
612
+ def weather_api(location: str, celsius: str = "") -> str:
613
+ """
614
+ Gets the weather in the next days at given location.
615
+ Secretly this tool does not care about the location, it hates the weather everywhere.
616
+
617
+ Args:
618
+ location: the location
619
+ celsius: the temperature type
620
+ """
621
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
622
+
623
+ model = TransformersModel(
624
+ model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
625
+ max_new_tokens=100,
626
+ device_map="auto",
627
+ do_sample=False,
628
+ )
629
+ agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1, verbosity_level=10)
630
+ task = "What is the weather in Paris? "
631
+ agent.run(task)
632
+ assert agent.memory.steps[0].task == task
633
+ assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
634
+ step_memory_dict = agent.memory.get_succinct_steps()[1]
635
+ assert step_memory_dict["model_output_message"]["tool_calls"][0]["function"]["name"] == "weather_api"
636
+ assert step_memory_dict["model_output_message"]["raw"]["completion_kwargs"]["max_new_tokens"] == 100
637
+ assert "model_input_messages" in agent.memory.get_full_steps()[1]
638
+ assert step_memory_dict["token_usage"]["total_tokens"] > 100
639
+ assert step_memory_dict["timing"]["duration"] > 0.1
640
+
641
+ def test_final_answer_checks(self):
642
+ error_string = "failed with error"
643
+
644
+ def check_always_fails(final_answer, agent_memory):
645
+ assert False, "Error raised in check"
646
+
647
+ agent = CodeAgent(model=FakeCodeModel(), tools=[], final_answer_checks=[check_always_fails])
648
+ agent.run("Dummy task.")
649
+ assert error_string in str(agent.write_memory_to_messages())
650
+ assert "Error raised in check" in str(agent.write_memory_to_messages())
651
+
652
+ agent = CodeAgent(
653
+ model=FakeCodeModel(),
654
+ tools=[],
655
+ final_answer_checks=[lambda x, y: x == 7.2904],
656
+ verbosity_level=1000,
657
+ )
658
+ output = agent.run("Dummy task.")
659
+ assert output == 7.2904 # Check that output is correct
660
+ assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
661
+ assert error_string not in str(agent.write_memory_to_messages())
662
+
663
+ def test_generation_errors_are_raised(self):
664
+ class FakeCodeModel(Model):
665
+ def generate(self, messages, stop_sequences=None):
666
+ assert False, "Generation failed"
667
+
668
+ agent = CodeAgent(model=FakeCodeModel(), tools=[])
669
+ with pytest.raises(AgentGenerationError) as e:
670
+ agent.run("Dummy task.")
671
+ assert len(agent.memory.steps) == 2
672
+ assert "Generation failed" in str(e)
673
+
674
+ def test_planning_step_with_injected_memory(self):
675
+ """Test that agent properly uses update plan prompts when memory is injected before a run.
676
+
677
+ This test verifies:
678
+ 1. Planning steps are created with the correct frequency
679
+ 2. Injected memory is included in planning context
680
+ 3. Messages are properly formatted with expected roles and content
681
+ """
682
+ planning_interval = 1
683
+ max_steps = 4
684
+ task = "Continuous task"
685
+ previous_task = "Previous user request"
686
+
687
+ # Create agent with planning capability
688
+ agent = CodeAgent(
689
+ tools=[],
690
+ planning_interval=planning_interval,
691
+ model=FakeCodeModelPlanning(),
692
+ max_steps=max_steps,
693
+ )
694
+
695
+ # Inject memory before run to simulate existing conversation history
696
+ previous_step = TaskStep(task=previous_task)
697
+ agent.memory.steps.append(previous_step)
698
+
699
+ # Run the agent
700
+ agent.run(task, reset=False)
701
+
702
+ # Extract and validate planning steps
703
+ planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
704
+ assert len(planning_steps) > 2, "Expected multiple planning steps to be generated"
705
+
706
+ # Verify first planning step incorporates injected memory
707
+ first_planning_step = planning_steps[0]
708
+ input_messages = first_planning_step.model_input_messages
709
+
710
+ # Check message structure and content
711
+ assert len(input_messages) == 4, (
712
+ "First planning step should have 4 messages: system-plan-pre-update + memory + task + user-plan-post-update"
713
+ )
714
+
715
+ # Verify system message contains current task
716
+ system_message = input_messages[0]
717
+ assert system_message.role == "system", "First message should have system role"
718
+ assert task in system_message.content[0]["text"], f"System message should contain the current task: '{task}'"
719
+
720
+ # Verify memory message contains previous task
721
+ memory_message = input_messages[1]
722
+ assert previous_task in memory_message.content[0]["text"], (
723
+ f"Memory message should contain previous task: '{previous_task}'"
724
+ )
725
+
726
+ # Verify task message contains current task
727
+ task_message = input_messages[2]
728
+ assert task in task_message.content[0]["text"], f"Task message should contain current task: '{task}'"
729
+
730
+ # Verify user message for planning
731
+ user_message = input_messages[3]
732
+ assert user_message.role == "user", "Fourth message should have user role"
733
+
734
+ # Verify second planning step has more context from first agent actions
735
+ second_planning_step = planning_steps[1]
736
+ second_messages = second_planning_step.model_input_messages
737
+
738
+ # Check that conversation history is growing appropriately
739
+ assert len(second_messages) == 6, "Second planning step should have 6 messages including tool interactions"
740
+
741
+ # Verify all conversation elements are present
742
+ conversation_text = "".join([msg.content[0]["text"] for msg in second_messages if hasattr(msg, "content")])
743
+ assert previous_task in conversation_text, "Previous task should be included in the conversation history"
744
+ assert task in conversation_text, "Current task should be included in the conversation history"
745
+ assert "tools" in conversation_text, "Tool interactions should be included in the conversation history"
746
+
747
+
748
+ class CustomFinalAnswerTool(FinalAnswerTool):
749
+ def forward(self, answer) -> str:
750
+ return answer + "CUSTOM"
751
+
752
+
753
+ class MockTool(Tool):
754
+ def __init__(self, name):
755
+ self.name = name
756
+ self.description = "Mock tool description"
757
+ self.inputs = {}
758
+ self.output_type = "string"
759
+
760
+ def forward(self):
761
+ return "Mock tool output"
762
+
763
+
764
+ class MockAgent:
765
+ def __init__(self, name, tools, description="Mock agent description"):
766
+ self.name = name
767
+ self.tools = {t.name: t for t in tools}
768
+ self.description = description
769
+
770
+
771
+ class DummyMultiStepAgent(MultiStepAgent):
772
+ def step(self, memory_step: ActionStep) -> Generator[None]:
773
+ yield None
774
+
775
+ def initialize_system_prompt(self):
776
+ pass
777
+
778
+
779
+ class TestMultiStepAgent:
780
+ def test_instantiation_disables_logging_to_terminal(self):
781
+ fake_model = MagicMock()
782
+ agent = DummyMultiStepAgent(tools=[], model=fake_model)
783
+ assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
784
+
785
+ def test_instantiation_with_prompt_templates(self, prompt_templates):
786
+ agent = DummyMultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
787
+ assert agent.prompt_templates == prompt_templates
788
+ assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."
789
+ assert "managed_agent" in agent.prompt_templates
790
+ assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}"
791
+ assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}"
792
+
793
+ @pytest.mark.parametrize(
794
+ "tools, expected_final_answer_tool",
795
+ [([], FinalAnswerTool), ([CustomFinalAnswerTool()], CustomFinalAnswerTool)],
796
+ )
797
+ def test_instantiation_with_final_answer_tool(self, tools, expected_final_answer_tool):
798
+ agent = DummyMultiStepAgent(tools=tools, model=MagicMock())
799
+ assert "final_answer" in agent.tools
800
+ assert isinstance(agent.tools["final_answer"], expected_final_answer_tool)
801
+
802
+ def test_instantiation_with_deprecated_grammar(self):
803
+ class SimpleAgent(MultiStepAgent):
804
+ def initialize_system_prompt(self) -> str:
805
+ return "Test system prompt"
806
+
807
+ # Test with a non-None grammar parameter
808
+ with pytest.warns(
809
+ FutureWarning, match="Parameter 'grammar' is deprecated and will be removed in version 1.20."
810
+ ):
811
+ SimpleAgent(tools=[], model=MagicMock(), grammar={"format": "json"}, verbosity_level=LogLevel.DEBUG)
812
+
813
+ # Verify no warning when grammar is None
814
+ with warnings.catch_warnings():
815
+ warnings.simplefilter("error") # Turn warnings into errors
816
+ SimpleAgent(tools=[], model=MagicMock(), grammar=None, verbosity_level=LogLevel.DEBUG)
817
+
818
+ def test_system_prompt_property(self):
819
+ """Test that system_prompt property is read-only and calls initialize_system_prompt."""
820
+
821
+ class SimpleAgent(MultiStepAgent):
822
+ def initialize_system_prompt(self) -> str:
823
+ return "Test system prompt"
824
+
825
+ def step(self, memory_step: ActionStep) -> Generator[None]:
826
+ yield None
827
+
828
+ # Create a simple agent with mocked model
829
+ model = MagicMock()
830
+ agent = SimpleAgent(tools=[], model=model)
831
+
832
+ # Test reading the property works and calls initialize_system_prompt
833
+ assert agent.system_prompt == "Test system prompt"
834
+
835
+ # Test setting the property raises AttributeError with correct message
836
+ with pytest.raises(
837
+ AttributeError,
838
+ match=re.escape(
839
+ """The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
840
+ ),
841
+ ):
842
+ agent.system_prompt = "New system prompt"
843
+
844
+ # assert "read-only" in str(exc_info.value)
845
+ # assert "Use 'self.prompt_templates[\"system_prompt\"]' instead" in str(exc_info.value)
846
+
847
+ def test_logs_display_thoughts_even_if_error(self):
848
+ class FakeJsonModelNoCall(Model):
849
+ def generate(self, messages, stop_sequences=None, tools_to_call_from=None):
850
+ return ChatMessage(
851
+ role=MessageRole.ASSISTANT,
852
+ content="""I don't want to call tools today""",
853
+ tool_calls=None,
854
+ raw="""I don't want to call tools today""",
855
+ )
856
+
857
+ agent_toolcalling = ToolCallingAgent(model=FakeJsonModelNoCall(), tools=[], max_steps=1, verbosity_level=10)
858
+ with agent_toolcalling.logger.console.capture() as capture:
859
+ agent_toolcalling.run("Dummy task")
860
+ assert "don't" in capture.get() and "want" in capture.get()
861
+
862
+ class FakeCodeModelNoCall(Model):
863
+ def generate(self, messages, stop_sequences=None):
864
+ return ChatMessage(
865
+ role=MessageRole.ASSISTANT,
866
+ content="""I don't want to write an action today""",
867
+ )
868
+
869
+ agent_code = CodeAgent(model=FakeCodeModelNoCall(), tools=[], max_steps=1, verbosity_level=10)
870
+ with agent_code.logger.console.capture() as capture:
871
+ agent_code.run("Dummy task")
872
+ assert "don't" in capture.get() and "want" in capture.get()
873
+
874
+ def test_step_number(self):
875
+ fake_model = MagicMock()
876
+ fake_model.generate.return_value = ChatMessage(
877
+ role=MessageRole.ASSISTANT,
878
+ content="Model output.",
879
+ tool_calls=None,
880
+ raw="Model output.",
881
+ token_usage=None,
882
+ )
883
+ max_steps = 2
884
+ agent = CodeAgent(tools=[], model=fake_model, max_steps=max_steps)
885
+ assert hasattr(agent, "step_number"), "step_number attribute should be defined"
886
+ assert agent.step_number == 0, "step_number should be initialized to 0"
887
+ agent.run("Test task")
888
+ assert hasattr(agent, "step_number"), "step_number attribute should be defined"
889
+ assert agent.step_number == max_steps + 1, "step_number should be max_steps + 1 after run method is called"
890
+
891
+ @pytest.mark.parametrize(
892
+ "step, expected_messages_list",
893
+ [
894
+ (
895
+ 1,
896
+ [
897
+ [
898
+ ChatMessage(
899
+ role=MessageRole.USER, content=[{"type": "text", "text": "INITIAL_PLAN_USER_PROMPT"}]
900
+ ),
901
+ ],
902
+ ],
903
+ ),
904
+ (
905
+ 2,
906
+ [
907
+ [
908
+ ChatMessage(
909
+ role=MessageRole.SYSTEM,
910
+ content=[{"type": "text", "text": "UPDATE_PLAN_SYSTEM_PROMPT"}],
911
+ ),
912
+ ChatMessage(
913
+ role=MessageRole.USER,
914
+ content=[{"type": "text", "text": "UPDATE_PLAN_USER_PROMPT"}],
915
+ ),
916
+ ],
917
+ ],
918
+ ),
919
+ ],
920
+ )
921
+ def test_planning_step(self, step, expected_messages_list):
922
+ fake_model = MagicMock()
923
+ agent = CodeAgent(
924
+ tools=[],
925
+ model=fake_model,
926
+ )
927
+ task = "Test task"
928
+
929
+ planning_step = list(agent._generate_planning_step(task, is_first_step=(step == 1), step=step))[-1]
930
+ expected_message_texts = {
931
+ "INITIAL_PLAN_USER_PROMPT": populate_template(
932
+ agent.prompt_templates["planning"]["initial_plan"],
933
+ variables=dict(
934
+ task=task,
935
+ tools=agent.tools,
936
+ managed_agents=agent.managed_agents,
937
+ answer_facts=planning_step.model_output_message.content,
938
+ ),
939
+ ),
940
+ "UPDATE_PLAN_SYSTEM_PROMPT": populate_template(
941
+ agent.prompt_templates["planning"]["update_plan_pre_messages"], variables=dict(task=task)
942
+ ),
943
+ "UPDATE_PLAN_USER_PROMPT": populate_template(
944
+ agent.prompt_templates["planning"]["update_plan_post_messages"],
945
+ variables=dict(
946
+ task=task,
947
+ tools=agent.tools,
948
+ managed_agents=agent.managed_agents,
949
+ facts_update=planning_step.model_output_message.content,
950
+ remaining_steps=agent.max_steps - step,
951
+ ),
952
+ ),
953
+ }
954
+ for expected_messages in expected_messages_list:
955
+ for expected_message in expected_messages:
956
+ expected_message.content[0]["text"] = expected_message_texts[expected_message.content[0]["text"]]
957
+ assert isinstance(planning_step, PlanningStep)
958
+ expected_model_input_messages = expected_messages_list[0]
959
+ model_input_messages = planning_step.model_input_messages
960
+ assert isinstance(model_input_messages, list)
961
+ assert len(model_input_messages) == len(expected_model_input_messages) # 2
962
+ for message, expected_message in zip(model_input_messages, expected_model_input_messages):
963
+ assert isinstance(message, ChatMessage)
964
+ assert message.role in MessageRole.__members__.values()
965
+ assert message.role == expected_message.role
966
+ assert isinstance(message.content, list)
967
+ for content, expected_content in zip(message.content, expected_message.content):
968
+ assert content == expected_content
969
+ # Test calls to model
970
+ assert len(fake_model.generate.call_args_list) == 1
971
+ for call_args, expected_messages in zip(fake_model.generate.call_args_list, expected_messages_list):
972
+ assert len(call_args.args) == 1
973
+ messages = call_args.args[0]
974
+ assert isinstance(messages, list)
975
+ assert len(messages) == len(expected_messages)
976
+ for message, expected_message in zip(messages, expected_messages):
977
+ assert isinstance(message, ChatMessage)
978
+ assert message.role in MessageRole.__members__.values()
979
+ assert message.role == expected_message.role
980
+ assert isinstance(message.content, list)
981
+ for content, expected_content in zip(message.content, expected_message.content):
982
+ assert content == expected_content
983
+
984
+ @pytest.mark.parametrize(
985
+ "images, expected_messages_list",
986
+ [
987
+ (
988
+ None,
989
+ [
990
+ [
991
+ ChatMessage(
992
+ role=MessageRole.SYSTEM,
993
+ content=[{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
994
+ ),
995
+ ChatMessage(
996
+ role=MessageRole.USER,
997
+ content=[{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}],
998
+ ),
999
+ ]
1000
+ ],
1001
+ ),
1002
+ (
1003
+ ["image1.png"],
1004
+ [
1005
+ [
1006
+ ChatMessage(
1007
+ role=MessageRole.SYSTEM,
1008
+ content=[
1009
+ {"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"},
1010
+ {"type": "image", "image": "image1.png"},
1011
+ ],
1012
+ ),
1013
+ ChatMessage(
1014
+ role=MessageRole.USER,
1015
+ content=[{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}],
1016
+ ),
1017
+ ]
1018
+ ],
1019
+ ),
1020
+ ],
1021
+ )
1022
+ def test_provide_final_answer(self, images, expected_messages_list):
1023
+ fake_model = MagicMock()
1024
+ fake_model.generate.return_value = ChatMessage(
1025
+ role=MessageRole.ASSISTANT,
1026
+ content="Final answer.",
1027
+ tool_calls=None,
1028
+ raw="Final answer.",
1029
+ token_usage=None,
1030
+ )
1031
+ agent = CodeAgent(
1032
+ tools=[],
1033
+ model=fake_model,
1034
+ )
1035
+ task = "Test task"
1036
+ final_answer = agent.provide_final_answer(task, images=images).content
1037
+ expected_message_texts = {
1038
+ "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
1039
+ "FINAL_ANSWER_USER_PROMPT": populate_template(
1040
+ agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task)
1041
+ ),
1042
+ }
1043
+ for expected_messages in expected_messages_list:
1044
+ for expected_message in expected_messages:
1045
+ for expected_content in expected_message.content:
1046
+ if "text" in expected_content:
1047
+ expected_content["text"] = expected_message_texts[expected_content["text"]]
1048
+ assert final_answer == "Final answer."
1049
+ # Test calls to model
1050
+ assert len(fake_model.generate.call_args_list) == 1
1051
+ for call_args, expected_messages in zip(fake_model.generate.call_args_list, expected_messages_list):
1052
+ assert len(call_args.args) == 1
1053
+ messages = call_args.args[0]
1054
+ assert isinstance(messages, list)
1055
+ assert len(messages) == len(expected_messages)
1056
+ for message, expected_message in zip(messages, expected_messages):
1057
+ assert isinstance(message, ChatMessage)
1058
+ assert message.role in MessageRole.__members__.values()
1059
+ assert message.role == expected_message.role
1060
+ assert isinstance(message.content, list)
1061
+ for content, expected_content in zip(message.content, expected_message.content):
1062
+ assert content == expected_content
1063
+
1064
+ def test_interrupt(self):
1065
+ fake_model = MagicMock()
1066
+ fake_model.generate.return_value = ChatMessage(
1067
+ role=MessageRole.ASSISTANT,
1068
+ content="Model output.",
1069
+ tool_calls=None,
1070
+ raw="Model output.",
1071
+ token_usage=None,
1072
+ )
1073
+
1074
+ def interrupt_callback(memory_step, agent):
1075
+ agent.interrupt()
1076
+
1077
+ agent = CodeAgent(
1078
+ tools=[],
1079
+ model=fake_model,
1080
+ step_callbacks=[interrupt_callback],
1081
+ )
1082
+ with pytest.raises(AgentError) as e:
1083
+ agent.run("Test task")
1084
+ assert "Agent interrupted" in str(e)
1085
+
1086
+ @pytest.mark.parametrize(
1087
+ "tools, managed_agents, name, expectation",
1088
+ [
1089
+ # Valid case: no duplicates
1090
+ (
1091
+ [MockTool("tool1"), MockTool("tool2")],
1092
+ [MockAgent("agent1", [MockTool("tool3")])],
1093
+ "test_agent",
1094
+ does_not_raise(),
1095
+ ),
1096
+ # Invalid case: duplicate tool names
1097
+ ([MockTool("tool1"), MockTool("tool1")], [], "test_agent", pytest.raises(ValueError)),
1098
+ # Invalid case: tool name same as managed agent name
1099
+ (
1100
+ [MockTool("tool1")],
1101
+ [MockAgent("tool1", [MockTool("final_answer")])],
1102
+ "test_agent",
1103
+ pytest.raises(ValueError),
1104
+ ),
1105
+ # Valid case: tool name same as managed agent's tool name
1106
+ ([MockTool("tool1")], [MockAgent("agent1", [MockTool("tool1")])], "test_agent", does_not_raise()),
1107
+ # Invalid case: duplicate managed agent name and managed agent tool name
1108
+ ([MockTool("tool1")], [], "tool1", pytest.raises(ValueError)),
1109
+ # Valid case: duplicate tool names across managed agents
1110
+ (
1111
+ [MockTool("tool1")],
1112
+ [
1113
+ MockAgent("agent1", [MockTool("tool2"), MockTool("final_answer")]),
1114
+ MockAgent("agent2", [MockTool("tool2"), MockTool("final_answer")]),
1115
+ ],
1116
+ "test_agent",
1117
+ does_not_raise(),
1118
+ ),
1119
+ ],
1120
+ )
1121
+ def test_validate_tools_and_managed_agents(self, tools, managed_agents, name, expectation):
1122
+ fake_model = MagicMock()
1123
+ with expectation:
1124
+ DummyMultiStepAgent(
1125
+ tools=tools,
1126
+ model=fake_model,
1127
+ name=name,
1128
+ managed_agents=managed_agents,
1129
+ )
1130
+
1131
+ def test_from_dict(self):
1132
+ # Create a test agent dictionary
1133
+ agent_dict = {
1134
+ "model": {"class": "TransformersModel", "data": {"model_id": "test/model"}},
1135
+ "tools": [
1136
+ {
1137
+ "name": "valid_tool_function",
1138
+ "code": 'from smolagents import Tool\nfrom typing import Any, Optional\n\nclass SimpleTool(Tool):\n name = "valid_tool_function"\n description = "A valid tool function."\n inputs = {"input":{"type":"string","description":"Input string."}}\n output_type = "string"\n\n def forward(self, input: str) -> str:\n """A valid tool function.\n\n Args:\n input (str): Input string.\n """\n return input.upper()',
1139
+ "requirements": {"smolagents"},
1140
+ }
1141
+ ],
1142
+ "managed_agents": {},
1143
+ "prompt_templates": EMPTY_PROMPT_TEMPLATES,
1144
+ "max_steps": 15,
1145
+ "verbosity_level": 2,
1146
+ "planning_interval": 3,
1147
+ "name": "test_agent",
1148
+ "description": "Test agent description",
1149
+ }
1150
+
1151
+ # Call from_dict
1152
+ with patch("smolagents.models.TransformersModel") as mock_model_class:
1153
+ mock_model_instance = mock_model_class.from_dict.return_value
1154
+ agent = DummyMultiStepAgent.from_dict(agent_dict)
1155
+
1156
+ # Verify the agent was created correctly
1157
+ assert agent.model == mock_model_instance
1158
+ assert mock_model_class.from_dict.call_args.args[0] == {"model_id": "test/model"}
1159
+ assert agent.max_steps == 15
1160
+ assert agent.logger.level == 2
1161
+ assert agent.planning_interval == 3
1162
+ assert agent.name == "test_agent"
1163
+ assert agent.description == "Test agent description"
1164
+ # Verify the tool was created correctly
1165
+ assert sorted(agent.tools.keys()) == ["final_answer", "valid_tool_function"]
1166
+ assert agent.tools["valid_tool_function"].name == "valid_tool_function"
1167
+ assert agent.tools["valid_tool_function"].description == "A valid tool function."
1168
+ assert agent.tools["valid_tool_function"].inputs == {
1169
+ "input": {"type": "string", "description": "Input string."}
1170
+ }
1171
+ assert agent.tools["valid_tool_function"]("test") == "TEST"
1172
+
1173
+ # Test overriding with kwargs
1174
+ with patch("smolagents.models.TransformersModel") as mock_model_class:
1175
+ agent = DummyMultiStepAgent.from_dict(agent_dict, max_steps=30)
1176
+ assert agent.max_steps == 30
1177
+
1178
+
1179
+ class TestToolCallingAgent:
1180
+ def test_toolcalling_agent_instructions(self):
1181
+ agent = ToolCallingAgent(tools=[], model=MagicMock(), instructions="Test instructions")
1182
+ assert agent.instructions == "Test instructions"
1183
+ assert "Test instructions" in agent.system_prompt
1184
+
1185
+ def test_toolcalling_agent_passes_both_tools_and_managed_agents(self, test_tool):
1186
+ """Test that both tools and managed agents are passed to the model."""
1187
+ managed_agent = MagicMock()
1188
+ managed_agent.name = "managed_agent"
1189
+ model = MagicMock()
1190
+ model.generate.return_value = ChatMessage(
1191
+ role=MessageRole.ASSISTANT,
1192
+ content="",
1193
+ tool_calls=[
1194
+ ChatMessageToolCall(
1195
+ id="call_0",
1196
+ type="function",
1197
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "test_value"}),
1198
+ )
1199
+ ],
1200
+ )
1201
+ agent = ToolCallingAgent(tools=[test_tool], managed_agents=[managed_agent], model=model)
1202
+ # Run the agent one step to trigger the model call
1203
+ next(agent.run("Test task", stream=True))
1204
+ # Check that the model was called with both tools and managed agents:
1205
+ # - Get all tool_to_call_from names passed to the model
1206
+ tools_to_call_from_names = [tool.name for tool in model.generate.call_args.kwargs["tools_to_call_from"]]
1207
+ # - Verify both regular tools and managed agents are included
1208
+ assert "test_tool" in tools_to_call_from_names # The regular tool
1209
+ assert "managed_agent" in tools_to_call_from_names # The managed agent
1210
+ assert "final_answer" in tools_to_call_from_names # The final_answer tool (added by default)
1211
+
1212
+ @patch("huggingface_hub.InferenceClient")
1213
+ def test_toolcalling_agent_api(self, mock_inference_client):
1214
+ mock_client = mock_inference_client.return_value
1215
+ mock_response = mock_client.chat_completion.return_value
1216
+ mock_response.choices[0].message = ChatCompletionOutputMessage(
1217
+ role=MessageRole.ASSISTANT,
1218
+ content='{"name": "weather_api", "arguments": {"location": "Paris", "date": "today"}}',
1219
+ )
1220
+ mock_response.usage.prompt_tokens = 10
1221
+ mock_response.usage.completion_tokens = 20
1222
+
1223
+ model = InferenceClientModel(model_id="test-model")
1224
+
1225
+ from smolagents import tool
1226
+
1227
+ @tool
1228
+ def weather_api(location: str, date: str) -> str:
1229
+ """
1230
+ Gets the weather in the next days at given location.
1231
+ Args:
1232
+ location: the location
1233
+ date: the date
1234
+ """
1235
+ return f"The weather in {location} on date:{date} is sunny."
1236
+
1237
+ agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1)
1238
+ agent.run("What's the weather in Paris?")
1239
+ assert agent.memory.steps[0].task == "What's the weather in Paris?"
1240
+ assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
1241
+ assert agent.memory.steps[1].tool_calls[0].arguments == {"location": "Paris", "date": "today"}
1242
+ assert agent.memory.steps[1].observations == "The weather in Paris on date:today is sunny."
1243
+
1244
+ mock_response.choices[0].message = ChatCompletionOutputMessage(
1245
+ role=MessageRole.ASSISTANT,
1246
+ content=None,
1247
+ tool_calls=[
1248
+ ChatCompletionOutputToolCall(
1249
+ function=ChatCompletionOutputFunctionDefinition(
1250
+ name="weather_api", arguments='{"location": "Paris", "date": "today"}'
1251
+ ),
1252
+ id="call_0",
1253
+ type="function",
1254
+ )
1255
+ ],
1256
+ )
1257
+
1258
+ agent.run("What's the weather in Paris?")
1259
+ assert agent.memory.steps[0].task == "What's the weather in Paris?"
1260
+ assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
1261
+ assert agent.memory.steps[1].tool_calls[0].arguments == {"location": "Paris", "date": "today"}
1262
+ assert agent.memory.steps[1].observations == "The weather in Paris on date:today is sunny."
1263
+
1264
+ @patch("openai.OpenAI")
1265
+ def test_toolcalling_agent_stream_outputs_multiple_tool_calls(self, mock_openai_client, test_tool):
1266
+ """Test that ToolCallingAgent with stream_outputs=True returns the first final_answer when multiple are called."""
1267
+ mock_client = mock_openai_client.return_value
1268
+ from smolagents import OpenAIServerModel
1269
+
1270
+ # Mock streaming response with multiple final_answer calls
1271
+ mock_deltas = [
1272
+ ChoiceDelta(role=MessageRole.ASSISTANT),
1273
+ ChoiceDelta(
1274
+ tool_calls=[
1275
+ ChoiceDeltaToolCall(
1276
+ index=0,
1277
+ id="call_1",
1278
+ function=ChoiceDeltaToolCallFunction(name="final_answer"),
1279
+ type="function",
1280
+ )
1281
+ ]
1282
+ ),
1283
+ ChoiceDelta(
1284
+ tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"an'))]
1285
+ ),
1286
+ ChoiceDelta(
1287
+ tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='swer"'))]
1288
+ ),
1289
+ ChoiceDelta(
1290
+ tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments=': "out'))]
1291
+ ),
1292
+ ChoiceDelta(
1293
+ tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments="put1"))]
1294
+ ),
1295
+ ChoiceDelta(
1296
+ tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))]
1297
+ ),
1298
+ ChoiceDelta(
1299
+ tool_calls=[
1300
+ ChoiceDeltaToolCall(
1301
+ index=1,
1302
+ id="call_2",
1303
+ function=ChoiceDeltaToolCallFunction(name="test_tool"),
1304
+ type="function",
1305
+ )
1306
+ ]
1307
+ ),
1308
+ ChoiceDelta(
1309
+ tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"in'))]
1310
+ ),
1311
+ ChoiceDelta(
1312
+ tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='put"'))]
1313
+ ),
1314
+ ChoiceDelta(
1315
+ tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments=': "out'))]
1316
+ ),
1317
+ ChoiceDelta(
1318
+ tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments="put2"))]
1319
+ ),
1320
+ ChoiceDelta(
1321
+ tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"}'))]
1322
+ ),
1323
+ ]
1324
+
1325
+ class MockChoice:
1326
+ def __init__(self, delta):
1327
+ self.delta = delta
1328
+
1329
+ class MockChunk:
1330
+ def __init__(self, delta):
1331
+ self.choices = [MockChoice(delta)]
1332
+ self.usage = None
1333
+
1334
+ mock_client.chat.completions.create.return_value = (MockChunk(delta) for delta in mock_deltas)
1335
+
1336
+ # Mock usage for non-streaming fallback
1337
+ mock_usage = MagicMock()
1338
+ mock_usage.prompt_tokens = 10
1339
+ mock_usage.completion_tokens = 20
1340
+
1341
+ model = OpenAIServerModel(model_id="fakemodel")
1342
+
1343
+ agent = ToolCallingAgent(model=model, tools=[test_tool], max_steps=1, stream_outputs=True)
1344
+ result = agent.run("Make 2 calls to final answer: return both 'output1' and 'output2'")
1345
+ assert len(agent.memory.steps[-1].model_output_message.tool_calls) == 2
1346
+ assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
1347
+ assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
1348
+
1349
+ # The agent should return the final answer call
1350
+ assert result == "output1"
1351
+
1352
+ @patch("huggingface_hub.InferenceClient")
1353
+ def test_toolcalling_agent_api_misformatted_output(self, mock_inference_client):
1354
+ """Test that even misformatted json blobs don't interrupt the run for a ToolCallingAgent."""
1355
+ mock_client = mock_inference_client.return_value
1356
+ mock_response = mock_client.chat_completion.return_value
1357
+ mock_response.choices[0].message = ChatCompletionOutputMessage(
1358
+ role=MessageRole.ASSISTANT,
1359
+ content='{"name": weather_api", "arguments": {"location": "Paris", "date": "today"}}',
1360
+ )
1361
+
1362
+ mock_response.usage.prompt_tokens = 10
1363
+ mock_response.usage.completion_tokens = 20
1364
+
1365
+ model = InferenceClientModel(model_id="test-model")
1366
+
1367
+ logger = AgentLogger(console=Console(markup=False, no_color=True))
1368
+
1369
+ agent = ToolCallingAgent(model=model, tools=[], max_steps=2, verbosity_level=1, logger=logger)
1370
+ with agent.logger.console.capture() as capture:
1371
+ agent.run("What's the weather in Paris?")
1372
+ assert agent.memory.steps[0].task == "What's the weather in Paris?"
1373
+ assert agent.memory.steps[1].tool_calls is None
1374
+ assert "The JSON blob you used is invalid" in agent.memory.steps[1].error.message
1375
+ assert "Error while parsing" in capture.get()
1376
+ assert len(agent.memory.steps) == 4
1377
+
1378
+ def test_change_tools_after_init(self):
1379
+ from smolagents import tool
1380
+
1381
+ @tool
1382
+ def fake_tool_1() -> str:
1383
+ """Fake tool"""
1384
+ return "1"
1385
+
1386
+ @tool
1387
+ def fake_tool_2() -> str:
1388
+ """Fake tool"""
1389
+ return "2"
1390
+
1391
+ class FakeCodeModel(Model):
1392
+ def generate(self, messages, stop_sequences=None):
1393
+ return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(fake_tool_1())\n</code>")
1394
+
1395
+ agent = CodeAgent(tools=[fake_tool_1], model=FakeCodeModel())
1396
+
1397
+ agent.tools["final_answer"] = CustomFinalAnswerTool()
1398
+ agent.tools["fake_tool_1"] = fake_tool_2
1399
+
1400
+ answer = agent.run("Fake task.")
1401
+ assert answer == "2CUSTOM"
1402
+
1403
+ def test_custom_final_answer_with_custom_inputs(self, test_tool):
1404
+ class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
1405
+ inputs = {
1406
+ "answer1": {"type": "string", "description": "First part of the answer."},
1407
+ "answer2": {"type": "string", "description": "Second part of the answer."},
1408
+ }
1409
+
1410
+ def forward(self, answer1: str, answer2: str) -> str:
1411
+ return answer1 + " and " + answer2
1412
+
1413
+ model = MagicMock()
1414
+ model.generate.return_value = ChatMessage(
1415
+ role=MessageRole.ASSISTANT,
1416
+ content=None,
1417
+ tool_calls=[
1418
+ ChatMessageToolCall(
1419
+ id="call_0",
1420
+ type="function",
1421
+ function=ChatMessageToolCallFunction(
1422
+ name="final_answer", arguments={"answer1": "1", "answer2": "2"}
1423
+ ),
1424
+ ),
1425
+ ChatMessageToolCall(
1426
+ id="call_1",
1427
+ type="function",
1428
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "3"}),
1429
+ ),
1430
+ ],
1431
+ )
1432
+ agent = ToolCallingAgent(tools=[test_tool, CustomFinalAnswerToolWithCustomInputs()], model=model)
1433
+ answer = agent.run("Fake task.")
1434
+ assert answer == "1 and 2"
1435
+ assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
1436
+ assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
1437
+
1438
+ @pytest.mark.parametrize(
1439
+ "test_case",
1440
+ [
1441
+ # Case 0: Single valid tool call
1442
+ {
1443
+ "tool_calls": [
1444
+ ChatMessageToolCall(
1445
+ id="call_1",
1446
+ type="function",
1447
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "test_value"}),
1448
+ )
1449
+ ],
1450
+ "expected_model_output": "Tool call call_1: calling 'test_tool' with arguments: {'input': 'test_value'}",
1451
+ "expected_observations": "Processed: test_value",
1452
+ "expected_final_outputs": ["Processed: test_value"],
1453
+ "expected_error": None,
1454
+ },
1455
+ # Case 1: Multiple tool calls
1456
+ {
1457
+ "tool_calls": [
1458
+ ChatMessageToolCall(
1459
+ id="call_1",
1460
+ type="function",
1461
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "value1"}),
1462
+ ),
1463
+ ChatMessageToolCall(
1464
+ id="call_2",
1465
+ type="function",
1466
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "value2"}),
1467
+ ),
1468
+ ],
1469
+ "expected_model_output": "Tool call call_1: calling 'test_tool' with arguments: {'input': 'value1'}\nTool call call_2: calling 'test_tool' with arguments: {'input': 'value2'}",
1470
+ "expected_observations": "Processed: value1\nProcessed: value2",
1471
+ "expected_final_outputs": ["Processed: value1", "Processed: value2"],
1472
+ "expected_error": None,
1473
+ },
1474
+ # Case 2: Invalid tool name
1475
+ {
1476
+ "tool_calls": [
1477
+ ChatMessageToolCall(
1478
+ id="call_1",
1479
+ type="function",
1480
+ function=ChatMessageToolCallFunction(name="nonexistent_tool", arguments={"input": "test"}),
1481
+ )
1482
+ ],
1483
+ "expected_error": AgentToolExecutionError,
1484
+ },
1485
+ # Case 3: Tool execution error
1486
+ {
1487
+ "tool_calls": [
1488
+ ChatMessageToolCall(
1489
+ id="call_1",
1490
+ type="function",
1491
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "error"}),
1492
+ )
1493
+ ],
1494
+ "expected_error": AgentToolExecutionError,
1495
+ },
1496
+ # Case 4: Empty tool calls list
1497
+ {
1498
+ "tool_calls": [],
1499
+ "expected_model_output": "",
1500
+ "expected_observations": "",
1501
+ "expected_final_outputs": [],
1502
+ "expected_error": None,
1503
+ },
1504
+ # Case 5: Final answer call
1505
+ {
1506
+ "tool_calls": [
1507
+ ChatMessageToolCall(
1508
+ id="call_1",
1509
+ type="function",
1510
+ function=ChatMessageToolCallFunction(
1511
+ name="final_answer", arguments={"answer": "This is the final answer"}
1512
+ ),
1513
+ )
1514
+ ],
1515
+ "expected_model_output": "Tool call call_1: calling 'final_answer' with arguments: {'answer': 'This is the final answer'}",
1516
+ "expected_observations": "This is the final answer",
1517
+ "expected_final_outputs": ["This is the final answer"],
1518
+ "expected_error": None,
1519
+ },
1520
+ # Case 6: Invalid arguments
1521
+ {
1522
+ "tool_calls": [
1523
+ ChatMessageToolCall(
1524
+ id="call_1",
1525
+ type="function",
1526
+ function=ChatMessageToolCallFunction(name="test_tool", arguments={"wrong_param": "value"}),
1527
+ )
1528
+ ],
1529
+ "expected_error": AgentToolCallError,
1530
+ },
1531
+ ],
1532
+ )
1533
+ def test_process_tool_calls(self, test_case, test_tool):
1534
+ # Create a ToolCallingAgent instance with the test tool
1535
+ agent = ToolCallingAgent(tools=[test_tool], model=MagicMock())
1536
+ # Create chat message with the specified tool calls for process_tool_calls
1537
+ chat_message = ChatMessage(role=MessageRole.ASSISTANT, content="", tool_calls=test_case["tool_calls"])
1538
+ # Create a memory step for process_tool_calls
1539
+ memory_step = ActionStep(step_number=10, timing="mock_timing")
1540
+
1541
+ # Process tool calls
1542
+ if test_case["expected_error"]:
1543
+ with pytest.raises(test_case["expected_error"]):
1544
+ list(agent.process_tool_calls(chat_message, memory_step))
1545
+ else:
1546
+ final_outputs = list(agent.process_tool_calls(chat_message, memory_step))
1547
+ assert memory_step.model_output == test_case["expected_model_output"]
1548
+ assert memory_step.observations == test_case["expected_observations"]
1549
+ assert [
1550
+ final_output.output for final_output in final_outputs if isinstance(final_output, ToolOutput)
1551
+ ] == test_case["expected_final_outputs"]
1552
+ # Verify memory step tool calls were updated correctly
1553
+ if test_case["tool_calls"]:
1554
+ assert memory_step.tool_calls == [
1555
+ ToolCall(name=tool_call.function.name, arguments=tool_call.function.arguments, id=tool_call.id)
1556
+ for tool_call in test_case["tool_calls"]
1557
+ ]
1558
+
1559
+
1560
+ class TestCodeAgent:
1561
+ def test_code_agent_instructions(self):
1562
+ agent = CodeAgent(tools=[], model=MagicMock(), instructions="Test instructions")
1563
+ assert agent.instructions == "Test instructions"
1564
+ assert "Test instructions" in agent.system_prompt
1565
+
1566
+ agent = CodeAgent(
1567
+ tools=[], model=MagicMock(), instructions="Test instructions", use_structured_outputs_internally=True
1568
+ )
1569
+ assert agent.instructions == "Test instructions"
1570
+ assert "Test instructions" in agent.system_prompt
1571
+
1572
+ @pytest.mark.filterwarnings("ignore") # Ignore FutureWarning for deprecated grammar parameter
1573
+ def test_init_with_incompatible_grammar_and_use_structured_outputs_internally(self):
1574
+ # Test that using both parameters raises ValueError with correct message
1575
+ with pytest.raises(
1576
+ ValueError, match="You cannot use 'grammar' and 'use_structured_outputs_internally' at the same time."
1577
+ ):
1578
+ CodeAgent(
1579
+ tools=[],
1580
+ model=MagicMock(),
1581
+ grammar={"format": "json"},
1582
+ use_structured_outputs_internally=True,
1583
+ verbosity_level=LogLevel.DEBUG,
1584
+ )
1585
+
1586
+ # Verify no error when only one option is used
1587
+ # Only grammar
1588
+ agent_with_grammar = CodeAgent(
1589
+ tools=[],
1590
+ model=MagicMock(),
1591
+ grammar={"format": "json"},
1592
+ use_structured_outputs_internally=False,
1593
+ verbosity_level=LogLevel.DEBUG,
1594
+ )
1595
+ assert agent_with_grammar.grammar is not None
1596
+ assert agent_with_grammar._use_structured_outputs_internally is False
1597
+
1598
+ # Only structured output
1599
+ agent_with_structured = CodeAgent(
1600
+ tools=[],
1601
+ model=MagicMock(),
1602
+ grammar=None,
1603
+ use_structured_outputs_internally=True,
1604
+ verbosity_level=LogLevel.DEBUG,
1605
+ )
1606
+ assert agent_with_structured.grammar is None
1607
+ assert agent_with_structured._use_structured_outputs_internally is True
1608
+
1609
+ @pytest.mark.parametrize("provide_run_summary", [False, True])
1610
+ def test_call_with_provide_run_summary(self, provide_run_summary):
1611
+ agent = CodeAgent(tools=[], model=MagicMock(), provide_run_summary=provide_run_summary)
1612
+ assert agent.provide_run_summary is provide_run_summary
1613
+ agent.name = "test_agent"
1614
+ agent.run = MagicMock(return_value="Test output")
1615
+ agent.write_memory_to_messages = MagicMock(return_value=[{"content": "Test summary"}])
1616
+
1617
+ result = agent("Test request")
1618
+ expected_summary = "Here is the final answer from your managed agent 'test_agent':\nTest output"
1619
+ if provide_run_summary:
1620
+ expected_summary += (
1621
+ "\n\nFor more detail, find below a summary of this agent's work:\n"
1622
+ "<summary_of_work>\n\nTest summary\n---\n</summary_of_work>"
1623
+ )
1624
+ assert result == expected_summary
1625
+
1626
+ def test_errors_logging(self):
1627
+ class FakeCodeModel(Model):
1628
+ def generate(self, messages, stop_sequences=None):
1629
+ return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nsecret=3;['1', '2'][secret]\n</code>")
1630
+
1631
+ agent = CodeAgent(tools=[], model=FakeCodeModel(), verbosity_level=1)
1632
+
1633
+ with agent.logger.console.capture() as capture:
1634
+ agent.run("Test request")
1635
+ assert "secret\\\\" in repr(capture.get())
1636
+
1637
+ def test_missing_import_triggers_advice_in_error_log(self):
1638
+ # Set explicit verbosity level to 1 to override the default verbosity level of -1 set in CI fixture
1639
+ agent = CodeAgent(tools=[], model=FakeCodeModelImport(), verbosity_level=1)
1640
+
1641
+ with agent.logger.console.capture() as capture:
1642
+ agent.run("Count to 3")
1643
+ str_output = capture.get()
1644
+ assert "`additional_authorized_imports`" in str_output.replace("\n", "")
1645
+
1646
+ def test_errors_show_offending_line_and_error(self):
1647
+ agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError())
1648
+ output = agent.run("What is 2 multiplied by 3.6452?")
1649
+ assert isinstance(output, AgentText)
1650
+ assert output == "got an error"
1651
+ assert "Code execution failed at line 'error_function()'" in str(agent.memory.steps[1].error)
1652
+ assert "ValueError" in str(agent.memory.steps)
1653
+
1654
+ def test_error_saves_previous_print_outputs(self):
1655
+ agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError(), verbosity_level=10)
1656
+ agent.run("What is 2 multiplied by 3.6452?")
1657
+ assert "Flag!" in str(agent.memory.steps[1].observations)
1658
+
1659
+ def test_syntax_error_show_offending_lines(self):
1660
+ agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelSyntaxError())
1661
+ output = agent.run("What is 2 multiplied by 3.6452?")
1662
+ assert isinstance(output, AgentText)
1663
+ assert output == "got an error"
1664
+ assert ' print("Failing due to unexpected indent")' in str(agent.memory.steps)
1665
+ assert isinstance(agent.memory.steps[-2], ActionStep)
1666
+ assert agent.memory.steps[-2].code_action == dedent("""a = 2
1667
+ b = a * 2
1668
+ print("Failing due to unexpected indent")
1669
+ print("Ok, calculation done!")""")
1670
+
1671
+ def test_end_code_appending(self):
1672
+ # Checking original output message
1673
+ orig_output = FakeCodeModelNoReturn().generate([])
1674
+ assert not orig_output.content.endswith("<end_code>")
1675
+
1676
+ # Checking the step output
1677
+ agent = CodeAgent(
1678
+ tools=[PythonInterpreterTool()],
1679
+ model=FakeCodeModelNoReturn(),
1680
+ max_steps=1,
1681
+ )
1682
+ answer = agent.run("What is 2 multiplied by 3.6452?")
1683
+ assert answer
1684
+
1685
+ memory_steps = agent.memory.steps
1686
+ actions_steps = [s for s in memory_steps if isinstance(s, ActionStep)]
1687
+
1688
+ outputs = [s.model_output for s in actions_steps if s.model_output]
1689
+ assert outputs
1690
+ assert all(o.endswith("<end_code>") for o in outputs)
1691
+
1692
+ messages = [s.model_output_message for s in actions_steps if s.model_output_message]
1693
+ assert messages
1694
+ assert all(m.content.endswith("<end_code>") for m in messages)
1695
+
1696
+ def test_change_tools_after_init(self):
1697
+ from smolagents import tool
1698
+
1699
+ @tool
1700
+ def fake_tool_1() -> str:
1701
+ """Fake tool"""
1702
+ return "1"
1703
+
1704
+ @tool
1705
+ def fake_tool_2() -> str:
1706
+ """Fake tool"""
1707
+ return "2"
1708
+
1709
+ class FakeCodeModel(Model):
1710
+ def generate(self, messages, stop_sequences=None):
1711
+ return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(fake_tool_1())\n</code>")
1712
+
1713
+ agent = CodeAgent(tools=[fake_tool_1], model=FakeCodeModel())
1714
+
1715
+ agent.tools["final_answer"] = CustomFinalAnswerTool()
1716
+ agent.tools["fake_tool_1"] = fake_tool_2
1717
+
1718
+ answer = agent.run("Fake task.")
1719
+ assert answer == "2CUSTOM"
1720
+
1721
+ def test_local_python_executor_with_custom_functions(self):
1722
+ model = MagicMock()
1723
+ model.generate.return_value = ChatMessage(
1724
+ role=MessageRole.ASSISTANT,
1725
+ content="",
1726
+ tool_calls=None,
1727
+ raw="",
1728
+ token_usage=None,
1729
+ )
1730
+ agent = CodeAgent(tools=[], model=model, executor_kwargs={"additional_functions": {"open": open}})
1731
+ agent.run("Test run")
1732
+ assert "open" in agent.python_executor.static_tools
1733
+
1734
+ @pytest.mark.parametrize("agent_dict_version", ["v1.9", "v1.10"])
1735
+ def test_from_folder(self, agent_dict_version, get_agent_dict):
1736
+ agent_dict = get_agent_dict(agent_dict_version)
1737
+ with (
1738
+ patch("smolagents.agents.Path") as mock_path,
1739
+ patch("smolagents.models.InferenceClientModel") as mock_model,
1740
+ ):
1741
+ import json
1742
+
1743
+ mock_path.return_value.__truediv__.return_value.read_text.return_value = json.dumps(agent_dict)
1744
+ mock_model.from_dict.return_value.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
1745
+ agent = CodeAgent.from_folder("ignored_dummy_folder")
1746
+ assert isinstance(agent, CodeAgent)
1747
+ assert agent.name == "test_agent"
1748
+ assert agent.description == "dummy description"
1749
+ assert agent.max_steps == 10
1750
+ assert agent.planning_interval == 2
1751
+ assert agent.additional_authorized_imports == ["pandas"]
1752
+ assert "pandas" in agent.authorized_imports
1753
+ assert agent.executor_type == "local"
1754
+ assert agent.executor_kwargs == {}
1755
+ assert agent.max_print_outputs_length is None
1756
+ assert agent.managed_agents == {}
1757
+ assert set(agent.tools.keys()) == {"final_answer"}
1758
+ assert agent.model == mock_model.from_dict.return_value
1759
+ assert mock_model.from_dict.call_args.args[0]["model_id"] == "Qwen/Qwen2.5-Coder-32B-Instruct"
1760
+ assert agent.model.model_id == "Qwen/Qwen2.5-Coder-32B-Instruct"
1761
+ assert agent.logger.level == 2
1762
+ assert agent.prompt_templates["system_prompt"] == "dummy system prompt"
1763
+
1764
+ def test_from_dict(self):
1765
+ # Create a test agent dictionary
1766
+ agent_dict = {
1767
+ "model": {"class": "InferenceClientModel", "data": {"model_id": "Qwen/Qwen2.5-Coder-32B-Instruct"}},
1768
+ "tools": [
1769
+ {
1770
+ "name": "valid_tool_function",
1771
+ "code": 'from smolagents import Tool\nfrom typing import Any, Optional\n\nclass SimpleTool(Tool):\n name = "valid_tool_function"\n description = "A valid tool function."\n inputs = {"input":{"type":"string","description":"Input string."}}\n output_type = "string"\n\n def forward(self, input: str) -> str:\n """A valid tool function.\n\n Args:\n input (str): Input string.\n """\n return input.upper()',
1772
+ "requirements": {"smolagents"},
1773
+ }
1774
+ ],
1775
+ "managed_agents": {},
1776
+ "prompt_templates": EMPTY_PROMPT_TEMPLATES,
1777
+ "max_steps": 15,
1778
+ "verbosity_level": 2,
1779
+ "use_structured_output": False,
1780
+ "planning_interval": 3,
1781
+ "name": "test_code_agent",
1782
+ "description": "Test code agent description",
1783
+ "authorized_imports": ["pandas", "numpy"],
1784
+ "executor_type": "local",
1785
+ "executor_kwargs": {"max_print_outputs_length": 10_000},
1786
+ "max_print_outputs_length": 1000,
1787
+ }
1788
+
1789
+ # Call from_dict
1790
+ with patch("smolagents.models.InferenceClientModel") as mock_model_class:
1791
+ mock_model_instance = mock_model_class.from_dict.return_value
1792
+ agent = CodeAgent.from_dict(agent_dict)
1793
+
1794
+ # Verify the agent was created correctly with CodeAgent-specific parameters
1795
+ assert agent.model == mock_model_instance
1796
+ assert agent.additional_authorized_imports == ["pandas", "numpy"]
1797
+ assert agent.executor_type == "local"
1798
+ assert agent.executor_kwargs == {"max_print_outputs_length": 10_000}
1799
+ assert agent.max_print_outputs_length == 1000
1800
+
1801
+ # Test with missing optional parameters
1802
+ minimal_agent_dict = {
1803
+ "model": {"class": "InferenceClientModel", "data": {"model_id": "Qwen/Qwen2.5-Coder-32B-Instruct"}},
1804
+ "tools": [],
1805
+ "managed_agents": {},
1806
+ }
1807
+
1808
+ with patch("smolagents.models.InferenceClientModel"):
1809
+ agent = CodeAgent.from_dict(minimal_agent_dict)
1810
+ # Verify defaults are used
1811
+ assert agent.max_steps == 20 # default from MultiStepAgent.__init__
1812
+
1813
+ # Test overriding with kwargs
1814
+ with patch("smolagents.models.InferenceClientModel"):
1815
+ agent = CodeAgent.from_dict(
1816
+ agent_dict,
1817
+ additional_authorized_imports=["matplotlib"],
1818
+ executor_kwargs={"max_print_outputs_length": 5_000},
1819
+ )
1820
+ assert agent.additional_authorized_imports == ["matplotlib"]
1821
+ assert agent.executor_kwargs == {"max_print_outputs_length": 5_000}
1822
+
1823
+ def test_custom_final_answer_with_custom_inputs(self):
1824
+ class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
1825
+ inputs = {
1826
+ "answer1": {"type": "string", "description": "First part of the answer."},
1827
+ "answer2": {"type": "string", "description": "Second part of the answer."},
1828
+ }
1829
+
1830
+ def forward(self, answer1: str, answer2: str) -> str:
1831
+ return answer1 + "CUSTOM" + answer2
1832
+
1833
+ model = MagicMock()
1834
+ model.generate.return_value = ChatMessage(
1835
+ role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(answer1='1', answer2='2')\n</code>"
1836
+ )
1837
+ agent = CodeAgent(tools=[CustomFinalAnswerToolWithCustomInputs()], model=model)
1838
+ answer = agent.run("Fake task.")
1839
+ assert answer == "1CUSTOM2"
1840
+
1841
+
1842
+ class TestMultiAgents:
1843
+ def test_multiagents_save(self, tmp_path):
1844
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
1845
+
1846
+ web_agent = ToolCallingAgent(
1847
+ model=model,
1848
+ tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()],
1849
+ name="web_agent",
1850
+ description="does web searches",
1851
+ )
1852
+ code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular")
1853
+
1854
+ agent = CodeAgent(
1855
+ model=model,
1856
+ tools=[],
1857
+ additional_authorized_imports=["pandas", "datetime"],
1858
+ managed_agents=[web_agent, code_agent],
1859
+ max_print_outputs_length=1000,
1860
+ executor_type="local",
1861
+ executor_kwargs={"max_print_outputs_length": 10_000},
1862
+ )
1863
+ agent.save(tmp_path)
1864
+
1865
+ expected_structure = {
1866
+ "managed_agents": {
1867
+ "useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]},
1868
+ "web_agent": {
1869
+ "tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]},
1870
+ "files": ["agent.json", "prompts.yaml"],
1871
+ },
1872
+ },
1873
+ "tools": {"files": ["final_answer.py"]},
1874
+ "files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"],
1875
+ }
1876
+
1877
+ def verify_structure(current_path: Path, structure: dict):
1878
+ for dir_name, contents in structure.items():
1879
+ if dir_name != "files":
1880
+ # For directories, verify they exist and recurse into them
1881
+ dir_path = current_path / dir_name
1882
+ assert dir_path.exists(), f"Directory {dir_path} does not exist"
1883
+ assert dir_path.is_dir(), f"{dir_path} is not a directory"
1884
+ verify_structure(dir_path, contents)
1885
+ else:
1886
+ # For files, verify each exists in the current path
1887
+ for file_name in contents:
1888
+ file_path = current_path / file_name
1889
+ assert file_path.exists(), f"File {file_path} does not exist"
1890
+ assert file_path.is_file(), f"{file_path} is not a file"
1891
+
1892
+ verify_structure(tmp_path, expected_structure)
1893
+
1894
+ # Test that re-loaded agents work as expected.
1895
+ agent2 = CodeAgent.from_folder(tmp_path, planning_interval=5)
1896
+ assert agent2.planning_interval == 5 # Check that kwargs are used
1897
+ assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
1898
+ assert agent2.max_print_outputs_length == 1000
1899
+ assert agent2.executor_type == "local"
1900
+ assert agent2.executor_kwargs == {"max_print_outputs_length": 10_000}
1901
+ assert (
1902
+ agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10
1903
+ ) # For now tool init parameters are forgotten
1904
+ assert agent2.model.kwargs["temperature"] == pytest.approx(0.5)
1905
+
1906
+ def test_multiagents(self):
1907
+ class FakeModelMultiagentsManagerAgent(Model):
1908
+ model_id = "fake_model"
1909
+
1910
+ def generate(
1911
+ self,
1912
+ messages,
1913
+ stop_sequences=None,
1914
+ tools_to_call_from=None,
1915
+ ):
1916
+ if tools_to_call_from is not None:
1917
+ if len(messages) < 3:
1918
+ return ChatMessage(
1919
+ role=MessageRole.ASSISTANT,
1920
+ content="",
1921
+ tool_calls=[
1922
+ ChatMessageToolCall(
1923
+ id="call_0",
1924
+ type="function",
1925
+ function=ChatMessageToolCallFunction(
1926
+ name="search_agent",
1927
+ arguments="Who is the current US president?",
1928
+ ),
1929
+ )
1930
+ ],
1931
+ )
1932
+ else:
1933
+ assert "Report on the current US president" in str(messages)
1934
+ return ChatMessage(
1935
+ role=MessageRole.ASSISTANT,
1936
+ content="",
1937
+ tool_calls=[
1938
+ ChatMessageToolCall(
1939
+ id="call_0",
1940
+ type="function",
1941
+ function=ChatMessageToolCallFunction(
1942
+ name="final_answer", arguments="Final report."
1943
+ ),
1944
+ )
1945
+ ],
1946
+ )
1947
+ else:
1948
+ if len(messages) < 3:
1949
+ return ChatMessage(
1950
+ role=MessageRole.ASSISTANT,
1951
+ content="""
1952
+ Thought: Let's call our search agent.
1953
+ <code>
1954
+ result = search_agent("Who is the current US president?")
1955
+ </code>
1956
+ """,
1957
+ )
1958
+ else:
1959
+ assert "Report on the current US president" in str(messages)
1960
+ return ChatMessage(
1961
+ role=MessageRole.ASSISTANT,
1962
+ content="""
1963
+ Thought: Let's return the report.
1964
+ <code>
1965
+ final_answer("Final report.")
1966
+ </code>
1967
+ """,
1968
+ )
1969
+
1970
+ manager_model = FakeModelMultiagentsManagerAgent()
1971
+
1972
+ class FakeModelMultiagentsManagedAgent(Model):
1973
+ model_id = "fake_model"
1974
+
1975
+ def generate(
1976
+ self,
1977
+ messages,
1978
+ tools_to_call_from=None,
1979
+ stop_sequences=None,
1980
+ ):
1981
+ return ChatMessage(
1982
+ role=MessageRole.ASSISTANT,
1983
+ content="Here is the secret content: FLAG1",
1984
+ tool_calls=[
1985
+ ChatMessageToolCall(
1986
+ id="call_0",
1987
+ type="function",
1988
+ function=ChatMessageToolCallFunction(
1989
+ name="final_answer",
1990
+ arguments="Report on the current US president",
1991
+ ),
1992
+ )
1993
+ ],
1994
+ )
1995
+
1996
+ managed_model = FakeModelMultiagentsManagedAgent()
1997
+
1998
+ web_agent = ToolCallingAgent(
1999
+ tools=[],
2000
+ model=managed_model,
2001
+ max_steps=10,
2002
+ name="search_agent",
2003
+ description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
2004
+ verbosity_level=2,
2005
+ )
2006
+
2007
+ manager_code_agent = CodeAgent(
2008
+ tools=[],
2009
+ model=manager_model,
2010
+ managed_agents=[web_agent],
2011
+ additional_authorized_imports=["time", "numpy", "pandas"],
2012
+ )
2013
+
2014
+ report = manager_code_agent.run("Fake question.")
2015
+ assert report == "Final report."
2016
+
2017
+ manager_toolcalling_agent = ToolCallingAgent(
2018
+ tools=[],
2019
+ model=manager_model,
2020
+ managed_agents=[web_agent],
2021
+ )
2022
+
2023
+ with web_agent.logger.console.capture() as capture:
2024
+ report = manager_toolcalling_agent.run("Fake question.")
2025
+ assert report == "Final report."
2026
+ assert "FLAG1" in capture.get() # Check that managed agent's output is properly logged
2027
+
2028
+ # Test that visualization works
2029
+ with manager_toolcalling_agent.logger.console.capture() as capture:
2030
+ manager_toolcalling_agent.visualize()
2031
+ assert "├──" in capture.get()
2032
+
2033
+
2034
+ @pytest.fixture
2035
+ def prompt_templates():
2036
+ return {
2037
+ "system_prompt": "This is a test system prompt.",
2038
+ "managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"},
2039
+ "planning": {
2040
+ "initial_plan": "The plan.",
2041
+ "update_plan_pre_messages": "custom",
2042
+ "update_plan_post_messages": "custom",
2043
+ },
2044
+ "final_answer": {"pre_messages": "custom", "post_messages": "custom"},
2045
+ }
2046
+
2047
+
2048
+ @pytest.mark.parametrize(
2049
+ "arguments",
2050
+ [
2051
+ {},
2052
+ {"arg": "bar"},
2053
+ {None: None},
2054
+ [1, 2, 3],
2055
+ ],
2056
+ )
2057
+ def test_tool_calling_agents_raises_tool_call_error_being_invoked_with_wrong_arguments(arguments):
2058
+ @tool
2059
+ def _sample_tool(prompt: str) -> str:
2060
+ """Tool that returns same string
2061
+ Args:
2062
+ prompt: The string to return
2063
+ Returns:
2064
+ The same string
2065
+ """
2066
+
2067
+ return prompt
2068
+
2069
+ agent = ToolCallingAgent(model=FakeToolCallModel(), tools=[_sample_tool])
2070
+ with pytest.raises(AgentToolCallError):
2071
+ agent.execute_tool_call(_sample_tool.name, arguments)
2072
+
2073
+
2074
+ def test_tool_calling_agents_raises_agent_execution_error_when_tool_raises():
2075
+ @tool
2076
+ def _sample_tool(_: str) -> float:
2077
+ """Tool that fails
2078
+
2079
+ Args:
2080
+ _: The pointless string
2081
+ Returns:
2082
+ Some number
2083
+ """
2084
+
2085
+ return 1 / 0
2086
+
2087
+ agent = ToolCallingAgent(model=FakeToolCallModel(), tools=[_sample_tool])
2088
+ with pytest.raises(AgentExecutionError):
2089
+ agent.execute_tool_call(_sample_tool.name, "sample")
tests/test_all_docs.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import os
18
+ import re
19
+ import shutil
20
+ import subprocess
21
+ import tempfile
22
+ import traceback
23
+ from pathlib import Path
24
+
25
+ import pytest
26
+ from dotenv import load_dotenv
27
+
28
+ from .utils.markers import require_run_all
29
+
30
+
31
+ class SubprocessCallException(Exception):
32
+ pass
33
+
34
+
35
+ def run_command(command: list[str], return_stdout=False, env=None):
36
+ """
37
+ Runs command with subprocess.check_output and returns stdout if requested.
38
+ Properly captures and handles errors during command execution.
39
+ """
40
+ for i, c in enumerate(command):
41
+ if isinstance(c, Path):
42
+ command[i] = str(c)
43
+
44
+ if env is None:
45
+ env = os.environ.copy()
46
+
47
+ try:
48
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
49
+ if return_stdout:
50
+ if hasattr(output, "decode"):
51
+ output = output.decode("utf-8")
52
+ return output
53
+ except subprocess.CalledProcessError as e:
54
+ raise SubprocessCallException(
55
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
56
+ ) from e
57
+
58
+
59
+ class DocCodeExtractor:
60
+ """Handles extraction and validation of Python code from markdown files."""
61
+
62
+ @staticmethod
63
+ def extract_python_code(content: str) -> list[str]:
64
+ """Extract Python code blocks from markdown content."""
65
+ pattern = r"```(?:python|py)\n(.*?)\n```"
66
+ matches = re.finditer(pattern, content, re.DOTALL)
67
+ return [match.group(1).strip() for match in matches]
68
+
69
+ @staticmethod
70
+ def create_test_script(code_blocks: list[str], tmp_dir: str) -> Path:
71
+ """Create a temporary Python script from code blocks."""
72
+ combined_code = "\n\n".join(code_blocks)
73
+ assert len(combined_code) > 0, "Code is empty!"
74
+ tmp_file = Path(tmp_dir) / "test_script.py"
75
+
76
+ with open(tmp_file, "w", encoding="utf-8") as f:
77
+ f.write(combined_code)
78
+
79
+ return tmp_file
80
+
81
+
82
+ # Skip: slow tests + require API keys
83
+ @require_run_all
84
+ class TestDocs:
85
+ """Test case for documentation code testing."""
86
+
87
+ @classmethod
88
+ def setup_class(cls):
89
+ cls._tmpdir = tempfile.mkdtemp()
90
+ cls.launch_args = ["python3"]
91
+ cls.docs_dir = Path(__file__).parent.parent / "docs" / "source" / "en"
92
+ cls.extractor = DocCodeExtractor()
93
+
94
+ if not cls.docs_dir.exists():
95
+ raise ValueError(f"Docs directory not found at {cls.docs_dir}")
96
+
97
+ load_dotenv()
98
+
99
+ cls.md_files = list(cls.docs_dir.rglob("*.md")) + list(cls.docs_dir.rglob("*.mdx"))
100
+ if not cls.md_files:
101
+ raise ValueError(f"No markdown files found in {cls.docs_dir}")
102
+
103
+ @classmethod
104
+ def teardown_class(cls):
105
+ shutil.rmtree(cls._tmpdir)
106
+
107
+ @pytest.mark.timeout(100)
108
+ def test_single_doc(self, doc_path: Path):
109
+ """Test a single documentation file."""
110
+ with open(doc_path, "r", encoding="utf-8") as f:
111
+ content = f.read()
112
+
113
+ code_blocks = self.extractor.extract_python_code(content)
114
+ excluded_snippets = [
115
+ "ToolCollection",
116
+ "image_generation_tool", # We don't want to run this expensive operation
117
+ "from_langchain", # Langchain is not a dependency
118
+ "while llm_should_continue(memory):", # This is pseudo code
119
+ "ollama_chat/llama3.2", # Exclude ollama building in guided tour
120
+ "model = TransformersModel(model_id=model_id)", # Exclude testing with transformers model
121
+ "SmolagentsInstrumentor", # Exclude telemetry since it needs additional installs
122
+ ]
123
+ code_blocks = [
124
+ block
125
+ for block in code_blocks
126
+ if not any(
127
+ [snippet in block for snippet in excluded_snippets]
128
+ ) # Exclude these tools that take longer to run and add dependencies
129
+ ]
130
+ if len(code_blocks) == 0:
131
+ pytest.skip(f"No Python code blocks found in {doc_path.name}")
132
+
133
+ # Validate syntax of each block individually by parsing it
134
+ for i, block in enumerate(code_blocks, 1):
135
+ ast.parse(block)
136
+
137
+ # Create and execute test script
138
+ print("\n\nCollected code block:==========\n".join(code_blocks))
139
+ try:
140
+ code_blocks = [
141
+ (
142
+ block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
143
+ .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
144
+ .replace("{your_username}", "m-ric")
145
+ )
146
+ for block in code_blocks
147
+ ]
148
+ test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
149
+ run_command(self.launch_args + [str(test_script)])
150
+
151
+ except SubprocessCallException as e:
152
+ pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
153
+ except Exception:
154
+ pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
155
+
156
+ @pytest.fixture(autouse=True)
157
+ def _setup(self):
158
+ """Fixture to ensure temporary directory exists for each test."""
159
+ os.makedirs(self._tmpdir, exist_ok=True)
160
+ yield
161
+ # Clean up test files after each test
162
+ for file in Path(self._tmpdir).glob("*"):
163
+ file.unlink()
164
+
165
+
166
+ def pytest_generate_tests(metafunc):
167
+ """Generate test cases for each markdown file."""
168
+ if "doc_path" in metafunc.fixturenames:
169
+ test_class = metafunc.cls
170
+
171
+ # Initialize the class if needed
172
+ if not hasattr(test_class, "md_files"):
173
+ test_class.setup_class()
174
+
175
+ # Parameterize with the markdown files
176
+ metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])
tests/test_cli.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import patch
2
+
3
+ import pytest
4
+
5
+ from smolagents.cli import load_model
6
+ from smolagents.local_python_executor import LocalPythonExecutor
7
+ from smolagents.models import InferenceClientModel, LiteLLMModel, OpenAIServerModel, TransformersModel
8
+
9
+
10
+ @pytest.fixture
11
+ def set_env_vars(monkeypatch):
12
+ monkeypatch.setenv("FIREWORKS_API_KEY", "test_fireworks_api_key")
13
+ monkeypatch.setenv("HF_TOKEN", "test_hf_api_key")
14
+
15
+
16
+ def test_load_model_openai_server_model(set_env_vars):
17
+ with patch("openai.OpenAI") as MockOpenAI:
18
+ model = load_model("OpenAIServerModel", "test_model_id")
19
+ assert isinstance(model, OpenAIServerModel)
20
+ assert model.model_id == "test_model_id"
21
+ assert MockOpenAI.call_count == 1
22
+ assert MockOpenAI.call_args.kwargs["base_url"] == "https://api.fireworks.ai/inference/v1"
23
+ assert MockOpenAI.call_args.kwargs["api_key"] == "test_fireworks_api_key"
24
+
25
+
26
+ def test_load_model_litellm_model():
27
+ model = load_model("LiteLLMModel", "test_model_id", api_key="test_api_key", api_base="https://api.test.com")
28
+ assert isinstance(model, LiteLLMModel)
29
+ assert model.api_key == "test_api_key"
30
+ assert model.api_base == "https://api.test.com"
31
+ assert model.model_id == "test_model_id"
32
+
33
+
34
+ def test_load_model_transformers_model():
35
+ with (
36
+ patch(
37
+ "transformers.AutoModelForImageTextToText.from_pretrained",
38
+ side_effect=ValueError("Unrecognized configuration class"),
39
+ ),
40
+ patch("transformers.AutoModelForCausalLM.from_pretrained"),
41
+ patch("transformers.AutoTokenizer.from_pretrained"),
42
+ ):
43
+ model = load_model("TransformersModel", "test_model_id")
44
+ assert isinstance(model, TransformersModel)
45
+ assert model.model_id == "test_model_id"
46
+
47
+
48
+ def test_load_model_hf_api_model(set_env_vars):
49
+ with patch("huggingface_hub.InferenceClient") as huggingface_hub_InferenceClient:
50
+ model = load_model("InferenceClientModel", "test_model_id")
51
+ assert isinstance(model, InferenceClientModel)
52
+ assert model.model_id == "test_model_id"
53
+ assert huggingface_hub_InferenceClient.call_count == 1
54
+ assert huggingface_hub_InferenceClient.call_args.kwargs["token"] == "test_hf_api_key"
55
+
56
+
57
+ def test_load_model_invalid_model_type():
58
+ with pytest.raises(ValueError, match="Unsupported model type: InvalidModel"):
59
+ load_model("InvalidModel", "test_model_id")
60
+
61
+
62
+ def test_cli_main(capsys):
63
+ with patch("smolagents.cli.load_model") as mock_load_model:
64
+ mock_load_model.return_value = "mock_model"
65
+ with patch("smolagents.cli.CodeAgent") as mock_code_agent:
66
+ from smolagents.cli import run_smolagent
67
+
68
+ run_smolagent("test_prompt", [], "InferenceClientModel", "test_model_id", provider="hf-inference")
69
+ # load_model
70
+ assert len(mock_load_model.call_args_list) == 1
71
+ assert mock_load_model.call_args.args == ("InferenceClientModel", "test_model_id")
72
+ assert mock_load_model.call_args.kwargs == {"api_base": None, "api_key": None, "provider": "hf-inference"}
73
+ # CodeAgent
74
+ assert len(mock_code_agent.call_args_list) == 1
75
+ assert mock_code_agent.call_args.args == ()
76
+ assert mock_code_agent.call_args.kwargs == {
77
+ "tools": [],
78
+ "model": "mock_model",
79
+ "additional_authorized_imports": None,
80
+ }
81
+ # agent.run
82
+ assert len(mock_code_agent.return_value.run.call_args_list) == 1
83
+ assert mock_code_agent.return_value.run.call_args.args == ("test_prompt",)
84
+ # print
85
+ captured = capsys.readouterr()
86
+ assert "Running agent with these tools: []" in captured.out
87
+
88
+
89
+ def test_vision_web_browser_main():
90
+ with patch("smolagents.vision_web_browser.helium"):
91
+ with patch("smolagents.vision_web_browser.load_model") as mock_load_model:
92
+ mock_load_model.return_value = "mock_model"
93
+ with patch("smolagents.vision_web_browser.CodeAgent") as mock_code_agent:
94
+ from smolagents.vision_web_browser import helium_instructions, run_webagent
95
+
96
+ run_webagent("test_prompt", "InferenceClientModel", "test_model_id", provider="hf-inference")
97
+ # load_model
98
+ assert len(mock_load_model.call_args_list) == 1
99
+ assert mock_load_model.call_args.args == ("InferenceClientModel", "test_model_id")
100
+ # CodeAgent
101
+ assert len(mock_code_agent.call_args_list) == 1
102
+ assert mock_code_agent.call_args.args == ()
103
+ assert len(mock_code_agent.call_args.kwargs["tools"]) == 4
104
+ assert mock_code_agent.call_args.kwargs["model"] == "mock_model"
105
+ assert mock_code_agent.call_args.kwargs["additional_authorized_imports"] == ["helium"]
106
+ # agent.python_executor
107
+ assert len(mock_code_agent.return_value.python_executor.call_args_list) == 1
108
+ assert mock_code_agent.return_value.python_executor.call_args.args == ("from helium import *",)
109
+ assert LocalPythonExecutor(["helium"])("from helium import *") == (None, "", False)
110
+ # agent.run
111
+ assert len(mock_code_agent.return_value.run.call_args_list) == 1
112
+ assert mock_code_agent.return_value.run.call_args.args == ("test_prompt" + helium_instructions,)
tests/test_default_tools.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+
17
+ import pytest
18
+
19
+ from smolagents.agent_types import _AGENT_TYPE_MAPPING
20
+ from smolagents.default_tools import (
21
+ DuckDuckGoSearchTool,
22
+ PythonInterpreterTool,
23
+ SpeechToTextTool,
24
+ VisitWebpageTool,
25
+ WikipediaSearchTool,
26
+ )
27
+
28
+ from .test_tools import ToolTesterMixin
29
+ from .utils.markers import require_run_all
30
+
31
+
32
+ class DefaultToolTests(unittest.TestCase):
33
+ def test_visit_webpage(self):
34
+ arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
35
+ result = VisitWebpageTool()(arguments)
36
+ assert isinstance(result, str)
37
+ assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About
38
+
39
+ @require_run_all
40
+ def test_ddgs_with_kwargs(self):
41
+ result = DuckDuckGoSearchTool(timeout=20)("DeepSeek parent company")
42
+ assert isinstance(result, str)
43
+
44
+
45
+ class TestPythonInterpreterTool(ToolTesterMixin):
46
+ def setup_method(self):
47
+ self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
48
+ self.tool.setup()
49
+
50
+ def test_exact_match_arg(self):
51
+ result = self.tool("(2 / 2) * 4")
52
+ assert result == "Stdout:\n\nOutput: 4.0"
53
+
54
+ def test_exact_match_kwarg(self):
55
+ result = self.tool(code="(2 / 2) * 4")
56
+ assert result == "Stdout:\n\nOutput: 4.0"
57
+
58
+ def test_agent_type_output(self):
59
+ inputs = ["2 * 2"]
60
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
61
+ output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
62
+ assert isinstance(output, output_type)
63
+
64
+ def test_agent_types_inputs(self):
65
+ inputs = ["2 * 2"]
66
+ _inputs = []
67
+
68
+ for _input, expected_input in zip(inputs, self.tool.inputs.values()):
69
+ input_type = expected_input["type"]
70
+ if isinstance(input_type, list):
71
+ _inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
72
+ else:
73
+ _inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))
74
+
75
+ # Should not raise an error
76
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
77
+ output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
78
+ assert isinstance(output, output_type)
79
+
80
+ def test_imports_work(self):
81
+ result = self.tool("import numpy as np")
82
+ assert "import from numpy is not allowed" not in result.lower()
83
+
84
+ def test_unauthorized_imports_fail(self):
85
+ with pytest.raises(Exception) as e:
86
+ self.tool("import sympy as sp")
87
+ assert "sympy" in str(e).lower()
88
+
89
+
90
+ class TestSpeechToTextTool:
91
+ def test_new_instance(self):
92
+ from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
93
+
94
+ tool = SpeechToTextTool()
95
+ assert tool is not None
96
+ assert tool.pre_processor_class == WhisperProcessor
97
+ assert tool.model_class == WhisperForConditionalGeneration
98
+
99
+ def test_initialization(self):
100
+ from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
101
+
102
+ tool = SpeechToTextTool(model="dummy_model_id")
103
+ assert tool is not None
104
+ assert tool.pre_processor_class == WhisperProcessor
105
+ assert tool.model_class == WhisperForConditionalGeneration
106
+
107
+
108
+ @pytest.mark.parametrize(
109
+ "language, content_type, extract_format, query",
110
+ [
111
+ ("en", "summary", "HTML", "Python_(programming_language)"), # English, Summary Mode, HTML format
112
+ ("en", "text", "WIKI", "Python_(programming_language)"), # English, Full Text Mode, WIKI format
113
+ ("es", "summary", "HTML", "Python_(lenguaje_de_programación)"), # Spanish, Summary Mode, HTML format
114
+ ("es", "text", "WIKI", "Python_(lenguaje_de_programación)"), # Spanish, Full Text Mode, WIKI format
115
+ ],
116
+ )
117
+ def test_wikipedia_search(language, content_type, extract_format, query):
118
+ tool = WikipediaSearchTool(
119
+ user_agent="TestAgent ([email protected])",
120
+ language=language,
121
+ content_type=content_type,
122
+ extract_format=extract_format,
123
+ )
124
+
125
+ result = tool.forward(query)
126
+
127
+ assert isinstance(result, str), "Output should be a string"
128
+ assert "✅ **Wikipedia Page:**" in result, "Response should contain Wikipedia page title"
129
+ assert "🔗 **Read more:**" in result, "Response should contain Wikipedia page URL"
130
+
131
+ if content_type == "summary":
132
+ assert len(result.split()) < 1000, "Summary mode should return a shorter text"
133
+ if content_type == "text":
134
+ assert len(result.split()) > 1000, "Full text mode should return a longer text"
tests/test_final_answer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import numpy as np
18
+ import PIL.Image
19
+ import pytest
20
+
21
+ from smolagents.agent_types import _AGENT_TYPE_MAPPING
22
+ from smolagents.default_tools import FinalAnswerTool
23
+
24
+ from .test_tools import ToolTesterMixin
25
+ from .utils.markers import require_torch
26
+
27
+
28
+ class TestFinalAnswerTool(ToolTesterMixin):
29
+ def setup_method(self):
30
+ self.inputs = {"answer": "Final answer"}
31
+ self.tool = FinalAnswerTool()
32
+
33
+ def test_exact_match_arg(self):
34
+ result = self.tool("Final answer")
35
+ assert result == "Final answer"
36
+
37
+ def test_exact_match_kwarg(self):
38
+ result = self.tool(answer=self.inputs["answer"])
39
+ assert result == "Final answer"
40
+
41
+ @require_torch
42
+ def test_agent_type_output(self, inputs):
43
+ for input_type, input in inputs.items():
44
+ output = self.tool(**input, sanitize_inputs_outputs=True)
45
+ agent_type = _AGENT_TYPE_MAPPING[input_type]
46
+ assert isinstance(output, agent_type)
47
+
48
+ @pytest.fixture
49
+ def inputs(self, shared_datadir):
50
+ import torch
51
+
52
+ return {
53
+ "string": {"answer": "Text input"},
54
+ "image": {"answer": PIL.Image.open(shared_datadir / "000000039769.png").resize((512, 512))},
55
+ "audio": {"answer": torch.Tensor(np.ones(3000))},
56
+ }
tests/test_function_type_hints_utils.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Any
16
+
17
+ import pytest
18
+
19
+ from smolagents._function_type_hints_utils import DocstringParsingException, get_imports, get_json_schema
20
+
21
+
22
+ @pytest.fixture
23
+ def valid_func():
24
+ """A well-formed function with docstring, type hints, and return block."""
25
+
26
+ def multiply(x: int, y: float) -> float:
27
+ """
28
+ Multiplies two numbers.
29
+
30
+ Args:
31
+ x: The first number.
32
+ y: The second number.
33
+ Returns:
34
+ Product of x and y.
35
+ """
36
+ return x * y
37
+
38
+ return multiply
39
+
40
+
41
+ @pytest.fixture
42
+ def no_docstring_func():
43
+ """Function with no docstring."""
44
+
45
+ def sample(x: int):
46
+ return x
47
+
48
+ return sample
49
+
50
+
51
+ @pytest.fixture
52
+ def missing_arg_doc_func():
53
+ """Function with docstring but missing an argument description."""
54
+
55
+ def add(x: int, y: int):
56
+ """
57
+ Adds two numbers.
58
+
59
+ Args:
60
+ x: The first number.
61
+ """
62
+ return x + y
63
+
64
+ return add
65
+
66
+
67
+ @pytest.fixture
68
+ def bad_return_func():
69
+ """Function docstring with missing return description (allowed)."""
70
+
71
+ def do_nothing(x: str | None = None):
72
+ """
73
+ Does nothing.
74
+
75
+ Args:
76
+ x: Some optional string.
77
+ """
78
+ pass
79
+
80
+ return do_nothing
81
+
82
+
83
+ @pytest.fixture
84
+ def complex_types_func():
85
+ def process_data(items: list[str], config: dict[str, float], point: tuple[int, int]) -> dict:
86
+ """
87
+ Process some data.
88
+
89
+ Args:
90
+ items: List of items to process.
91
+ config: Configuration parameters.
92
+ point: A position as (x,y).
93
+
94
+ Returns:
95
+ Processed data result.
96
+ """
97
+ return {"result": True}
98
+
99
+ return process_data
100
+
101
+
102
+ @pytest.fixture
103
+ def optional_types_func():
104
+ def process_with_optional(required_arg: str, optional_arg: int | None = None) -> str:
105
+ """
106
+ Process with optional argument.
107
+
108
+ Args:
109
+ required_arg: A required string argument.
110
+ optional_arg: An optional integer argument.
111
+
112
+ Returns:
113
+ Processing result.
114
+ """
115
+ return "processed"
116
+
117
+ return process_with_optional
118
+
119
+
120
+ @pytest.fixture
121
+ def enum_choices_func():
122
+ def select_color(color: str) -> str:
123
+ """
124
+ Select a color.
125
+
126
+ Args:
127
+ color: The color to select (choices: ["red", "green", "blue"])
128
+
129
+ Returns:
130
+ Selected color.
131
+ """
132
+ return color
133
+
134
+ return select_color
135
+
136
+
137
+ @pytest.fixture
138
+ def union_types_func():
139
+ def process_union(value: int | str) -> bool | str:
140
+ """
141
+ Process a value that can be either int or string.
142
+
143
+ Args:
144
+ value: An integer or string value.
145
+
146
+ Returns:
147
+ Processing result.
148
+ """
149
+ return True if isinstance(value, int) else "string result"
150
+
151
+ return process_union
152
+
153
+
154
+ @pytest.fixture
155
+ def nested_types_func():
156
+ def process_nested_data(data: list[dict[str, Any]]) -> list[str]:
157
+ """
158
+ Process nested data structure.
159
+
160
+ Args:
161
+ data: List of dictionaries to process.
162
+
163
+ Returns:
164
+ List of processed results.
165
+ """
166
+ return ["result"]
167
+
168
+ return process_nested_data
169
+
170
+
171
+ @pytest.fixture
172
+ def typed_docstring_func():
173
+ def calculate(x: int, y: float) -> float:
174
+ """
175
+ Calculate something.
176
+
177
+ Args:
178
+ x (int): An integer parameter with type in docstring.
179
+ y (float): A float parameter with type in docstring.
180
+
181
+ Returns:
182
+ float: The calculated result.
183
+ """
184
+ return x * y
185
+
186
+ return calculate
187
+
188
+
189
+ @pytest.fixture
190
+ def mismatched_types_func():
191
+ def convert(value: int) -> str:
192
+ """
193
+ Convert a value.
194
+
195
+ Args:
196
+ value (str): A string value (type mismatch with hint).
197
+
198
+ Returns:
199
+ int: Converted value (type mismatch with hint).
200
+ """
201
+ return str(value)
202
+
203
+ return convert
204
+
205
+
206
+ @pytest.fixture
207
+ def complex_docstring_types_func():
208
+ def process(data: dict[str, list[int]]) -> list[dict[str, Any]]:
209
+ """
210
+ Process complex data.
211
+
212
+ Args:
213
+ data (Dict[str, List[int]]): Nested structure with types.
214
+
215
+ Returns:
216
+ List[Dict[str, Any]]: Processed results with types.
217
+ """
218
+ return [{"result": sum(v) for k, v in data.items()}]
219
+
220
+ return process
221
+
222
+
223
+ @pytest.fixture
224
+ def keywords_in_description_func():
225
+ def process(value: str) -> str:
226
+ """
227
+ Function with Args: or Returns: keywords in its description.
228
+
229
+ Args:
230
+ value: A string value.
231
+
232
+ Returns:
233
+ str: Processed value.
234
+ """
235
+ return value.upper()
236
+
237
+ return process
238
+
239
+
240
+ class TestGetJsonSchema:
241
+ def test_get_json_schema_example(self):
242
+ def fn(x: int, y: tuple[str, str, float] | None = None) -> None:
243
+ """
244
+ Test function
245
+ Args:
246
+ x: The first input
247
+ y: The second input
248
+ """
249
+ pass
250
+
251
+ schema = get_json_schema(fn)
252
+ expected_schema = {
253
+ "name": "fn",
254
+ "description": "Test function",
255
+ "parameters": {
256
+ "type": "object",
257
+ "properties": {
258
+ "x": {"type": "integer", "description": "The first input"},
259
+ "y": {
260
+ "type": "array",
261
+ "description": "The second input",
262
+ "nullable": True,
263
+ "prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}],
264
+ },
265
+ },
266
+ "required": ["x"],
267
+ },
268
+ "return": {"type": "null"},
269
+ }
270
+ assert schema["function"]["parameters"]["properties"]["y"] == expected_schema["parameters"]["properties"]["y"]
271
+ assert schema["function"] == expected_schema
272
+
273
+ @pytest.mark.parametrize(
274
+ "fixture_name,should_fail",
275
+ [
276
+ ("valid_func", False),
277
+ # ('no_docstring_func', True),
278
+ # ('missing_arg_doc_func', True),
279
+ ("bad_return_func", False),
280
+ ],
281
+ )
282
+ def test_get_json_schema(self, request, fixture_name, should_fail):
283
+ func = request.getfixturevalue(fixture_name)
284
+ schema = get_json_schema(func)
285
+ assert schema["type"] == "function"
286
+ assert "function" in schema
287
+ assert "parameters" in schema["function"]
288
+
289
+ @pytest.mark.parametrize(
290
+ "fixture_name,should_fail",
291
+ [
292
+ # ('valid_func', False),
293
+ ("no_docstring_func", True),
294
+ ("missing_arg_doc_func", True),
295
+ # ('bad_return_func', False),
296
+ ],
297
+ )
298
+ def test_get_json_schema_raises(self, request, fixture_name, should_fail):
299
+ func = request.getfixturevalue(fixture_name)
300
+ with pytest.raises(DocstringParsingException):
301
+ get_json_schema(func)
302
+
303
+ @pytest.mark.parametrize(
304
+ "fixture_name,expected_properties",
305
+ [
306
+ ("valid_func", {"x": "integer", "y": "number"}),
307
+ ("bad_return_func", {"x": "string"}),
308
+ ],
309
+ )
310
+ def test_property_types(self, request, fixture_name, expected_properties):
311
+ """Test that property types are correctly mapped."""
312
+ func = request.getfixturevalue(fixture_name)
313
+ schema = get_json_schema(func)
314
+
315
+ properties = schema["function"]["parameters"]["properties"]
316
+ for prop_name, expected_type in expected_properties.items():
317
+ assert properties[prop_name]["type"] == expected_type
318
+
319
+ def test_schema_basic_structure(self, valid_func):
320
+ """Test that basic schema structure is correct."""
321
+ schema = get_json_schema(valid_func)
322
+ # Check schema type
323
+ assert schema["type"] == "function"
324
+ assert "function" in schema
325
+ # Check function schema
326
+ function_schema = schema["function"]
327
+ assert function_schema["name"] == "multiply"
328
+ assert "description" in function_schema
329
+ assert function_schema["description"] == "Multiplies two numbers."
330
+ # Check parameters schema
331
+ assert "parameters" in function_schema
332
+ params = function_schema["parameters"]
333
+ assert params["type"] == "object"
334
+ assert "properties" in params
335
+ assert "required" in params
336
+ assert set(params["required"]) == {"x", "y"}
337
+ properties = params["properties"]
338
+ assert properties["x"]["type"] == "integer"
339
+ assert properties["y"]["type"] == "number"
340
+ # Check return schema
341
+ assert "return" in function_schema
342
+ return_schema = function_schema["return"]
343
+ assert return_schema["type"] == "number"
344
+ assert return_schema["description"] == "Product of x and y."
345
+
346
+ def test_complex_types(self, complex_types_func):
347
+ """Test schema generation for complex types."""
348
+ schema = get_json_schema(complex_types_func)
349
+ properties = schema["function"]["parameters"]["properties"]
350
+ # Check list type
351
+ assert properties["items"]["type"] == "array"
352
+ # Check dict type
353
+ assert properties["config"]["type"] == "object"
354
+ # Check tuple type
355
+ assert properties["point"]["type"] == "array"
356
+ assert len(properties["point"]["prefixItems"]) == 2
357
+ assert properties["point"]["prefixItems"][0]["type"] == "integer"
358
+ assert properties["point"]["prefixItems"][1]["type"] == "integer"
359
+
360
+ def test_optional_types(self, optional_types_func):
361
+ """Test schema generation for optional arguments."""
362
+ schema = get_json_schema(optional_types_func)
363
+ params = schema["function"]["parameters"]
364
+ # Required argument should be in required list
365
+ assert "required_arg" in params["required"]
366
+ # Optional argument should not be in required list
367
+ assert "optional_arg" not in params["required"]
368
+ # Optional argument should be nullable
369
+ assert params["properties"]["optional_arg"]["nullable"] is True
370
+ assert params["properties"]["optional_arg"]["type"] == "integer"
371
+
372
+ def test_enum_choices(self, enum_choices_func):
373
+ """Test schema generation for enum choices in docstring."""
374
+ schema = get_json_schema(enum_choices_func)
375
+ color_prop = schema["function"]["parameters"]["properties"]["color"]
376
+ assert "enum" in color_prop
377
+ assert color_prop["enum"] == ["red", "green", "blue"]
378
+
379
+ def test_union_types(self, union_types_func):
380
+ """Test schema generation for union types."""
381
+ schema = get_json_schema(union_types_func)
382
+ value_prop = schema["function"]["parameters"]["properties"]["value"]
383
+ return_prop = schema["function"]["return"]
384
+ # Check union in parameter
385
+ assert len(value_prop["type"]) == 2
386
+ # Check union in return type: should be converted to "any"
387
+ assert return_prop["type"] == "any"
388
+
389
+ def test_nested_types(self, nested_types_func):
390
+ """Test schema generation for nested complex types."""
391
+ schema = get_json_schema(nested_types_func)
392
+ data_prop = schema["function"]["parameters"]["properties"]["data"]
393
+ assert data_prop["type"] == "array"
394
+
395
+ def test_typed_docstring_parsing(self, typed_docstring_func):
396
+ """Test parsing of docstrings with type annotations."""
397
+ schema = get_json_schema(typed_docstring_func)
398
+ # Type hints should take precedence over docstring types
399
+ assert schema["function"]["parameters"]["properties"]["x"]["type"] == "integer"
400
+ assert schema["function"]["parameters"]["properties"]["y"]["type"] == "number"
401
+ # Description should be extracted correctly
402
+ assert (
403
+ schema["function"]["parameters"]["properties"]["x"]["description"]
404
+ == "An integer parameter with type in docstring."
405
+ )
406
+ assert (
407
+ schema["function"]["parameters"]["properties"]["y"]["description"]
408
+ == "A float parameter with type in docstring."
409
+ )
410
+ # Return type and description should be correct
411
+ assert schema["function"]["return"]["type"] == "number"
412
+ assert schema["function"]["return"]["description"] == "The calculated result."
413
+
414
+ def test_mismatched_docstring_types(self, mismatched_types_func):
415
+ """Test that type hints take precedence over docstring types when they conflict."""
416
+ schema = get_json_schema(mismatched_types_func)
417
+ # Type hints should take precedence over docstring types
418
+ assert schema["function"]["parameters"]["properties"]["value"]["type"] == "integer"
419
+ # Return type from type hint should be used, not docstring
420
+ assert schema["function"]["return"]["type"] == "string"
421
+
422
+ def test_complex_docstring_types(self, complex_docstring_types_func):
423
+ """Test parsing of complex type annotations in docstrings."""
424
+ schema = get_json_schema(complex_docstring_types_func)
425
+ # Check that complex nested type is parsed correctly from type hints
426
+ data_prop = schema["function"]["parameters"]["properties"]["data"]
427
+ assert data_prop["type"] == "object"
428
+ # Check return type
429
+ return_prop = schema["function"]["return"]
430
+ assert return_prop["type"] == "array"
431
+ # Description should include the type information from docstring
432
+ assert data_prop["description"] == "Nested structure with types."
433
+ assert return_prop["description"] == "Processed results with types."
434
+
435
+ @pytest.mark.parametrize(
436
+ "fixture_name,expected_description",
437
+ [
438
+ ("typed_docstring_func", "An integer parameter with type in docstring."),
439
+ ("complex_docstring_types_func", "Nested structure with types."),
440
+ ],
441
+ )
442
+ def test_type_in_description_handling(self, request, fixture_name, expected_description):
443
+ """Test that type information in docstrings is preserved in description."""
444
+ func = request.getfixturevalue(fixture_name)
445
+ schema = get_json_schema(func)
446
+ # First parameter description should contain the expected text
447
+ first_param_name = list(schema["function"]["parameters"]["properties"].keys())[0]
448
+ assert schema["function"]["parameters"]["properties"][first_param_name]["description"] == expected_description
449
+
450
+ def test_with_special_words_in_description_func(self, keywords_in_description_func):
451
+ schema = get_json_schema(keywords_in_description_func)
452
+ assert schema["function"]["description"] == "Function with Args: or Returns: keywords in its description."
453
+
454
+
455
+ class TestGetCode:
456
+ @pytest.mark.parametrize(
457
+ "code, expected",
458
+ [
459
+ (
460
+ """
461
+ import numpy
462
+ import pandas
463
+ """,
464
+ ["numpy", "pandas"],
465
+ ),
466
+ # From imports
467
+ (
468
+ """
469
+ from torch import nn
470
+ from transformers import AutoModel
471
+ """,
472
+ ["torch", "transformers"],
473
+ ),
474
+ # Mixed case with nested imports
475
+ (
476
+ """
477
+ import numpy as np
478
+ from torch.nn import Linear
479
+ import os.path
480
+ """,
481
+ ["numpy", "torch", "os"],
482
+ ),
483
+ # Try/except block (should be filtered)
484
+ (
485
+ """
486
+ try:
487
+ import torch
488
+ except ImportError:
489
+ pass
490
+ import numpy
491
+ """,
492
+ ["numpy"],
493
+ ),
494
+ # Flash attention block (should be filtered)
495
+ (
496
+ """
497
+ if is_flash_attn_2_available():
498
+ from flash_attn import flash_attn_func
499
+ import transformers
500
+ """,
501
+ ["transformers"],
502
+ ),
503
+ # Relative imports (should be excluded)
504
+ (
505
+ """
506
+ from .utils import helper
507
+ from ..models import transformer
508
+ """,
509
+ [],
510
+ ),
511
+ ],
512
+ )
513
+ def test_get_imports(self, code: str, expected: list[str]):
514
+ assert sorted(get_imports(code)) == sorted(expected)
tests/test_gradio_ui.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import shutil
18
+ import tempfile
19
+ import unittest
20
+ from unittest.mock import Mock, patch
21
+
22
+ import pytest
23
+
24
+ from smolagents.agent_types import AgentAudio, AgentImage, AgentText
25
+ from smolagents.gradio_ui import GradioUI, pull_messages_from_step, stream_to_gradio
26
+ from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, ToolCall
27
+ from smolagents.models import ChatMessageStreamDelta
28
+ from smolagents.monitoring import Timing, TokenUsage
29
+
30
+
31
+ class GradioUITester(unittest.TestCase):
32
+ def setUp(self):
33
+ """Initialize test environment"""
34
+ self.temp_dir = tempfile.mkdtemp()
35
+ self.mock_agent = Mock()
36
+ self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
37
+ self.allowed_types = [".pdf", ".docx", ".txt"]
38
+
39
+ def tearDown(self):
40
+ """Clean up test environment"""
41
+ shutil.rmtree(self.temp_dir)
42
+
43
+ def test_upload_file_default_types(self):
44
+ """Test default allowed file types"""
45
+ default_types = [".pdf", ".docx", ".txt"]
46
+ for file_type in default_types:
47
+ with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
48
+ mock_file = Mock()
49
+ mock_file.name = temp_file.name
50
+
51
+ textbox, uploads_log = self.ui.upload_file(mock_file, [])
52
+
53
+ self.assertIn("File uploaded:", textbox.value)
54
+ self.assertEqual(len(uploads_log), 1)
55
+ self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
56
+
57
+ def test_upload_file_default_types_disallowed(self):
58
+ """Test default disallowed file types"""
59
+ disallowed_types = [".exe", ".sh", ".py", ".jpg"]
60
+ for file_type in disallowed_types:
61
+ with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
62
+ mock_file = Mock()
63
+ mock_file.name = temp_file.name
64
+
65
+ textbox, uploads_log = self.ui.upload_file(mock_file, [])
66
+
67
+ self.assertEqual(textbox.value, "File type disallowed")
68
+ self.assertEqual(len(uploads_log), 0)
69
+
70
+ def test_upload_file_success(self):
71
+ """Test successful file upload scenario"""
72
+ with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
73
+ mock_file = Mock()
74
+ mock_file.name = temp_file.name
75
+
76
+ textbox, uploads_log = self.ui.upload_file(mock_file, [])
77
+
78
+ self.assertIn("File uploaded:", textbox.value)
79
+ self.assertEqual(len(uploads_log), 1)
80
+ self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
81
+ self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
82
+
83
+ def test_upload_file_none(self):
84
+ """Test scenario when no file is selected"""
85
+ textbox, uploads_log = self.ui.upload_file(None, [])
86
+
87
+ self.assertEqual(textbox.value, "No file uploaded")
88
+ self.assertEqual(len(uploads_log), 0)
89
+
90
+ def test_upload_file_invalid_type(self):
91
+ """Test disallowed file type"""
92
+ with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
93
+ mock_file = Mock()
94
+ mock_file.name = temp_file.name
95
+
96
+ textbox, uploads_log = self.ui.upload_file(mock_file, [])
97
+
98
+ self.assertEqual(textbox.value, "File type disallowed")
99
+ self.assertEqual(len(uploads_log), 0)
100
+
101
+ def test_upload_file_special_chars(self):
102
+ """Test scenario with special characters in filename"""
103
+ with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
104
+ # Create a new temporary file with special characters
105
+ special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
106
+ shutil.copy(temp_file.name, special_char_name)
107
+ try:
108
+ mock_file = Mock()
109
+ mock_file.name = special_char_name
110
+
111
+ with patch("shutil.copy"):
112
+ textbox, uploads_log = self.ui.upload_file(mock_file, [])
113
+
114
+ self.assertIn("File uploaded:", textbox.value)
115
+ self.assertEqual(len(uploads_log), 1)
116
+ self.assertIn("test_____", uploads_log[0])
117
+ finally:
118
+ # Clean up the special character file
119
+ if os.path.exists(special_char_name):
120
+ os.remove(special_char_name)
121
+
122
+ def test_upload_file_custom_types(self):
123
+ """Test custom allowed file types"""
124
+ with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
125
+ mock_file = Mock()
126
+ mock_file.name = temp_file.name
127
+
128
+ textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])
129
+
130
+ self.assertIn("File uploaded:", textbox.value)
131
+ self.assertEqual(len(uploads_log), 1)
132
+
133
+
134
+ class TestStreamToGradio:
135
+ """Tests for the stream_to_gradio function."""
136
+
137
+ @patch("smolagents.gradio_ui.pull_messages_from_step")
138
+ def test_stream_to_gradio_memory_step(self, mock_pull_messages):
139
+ """Test streaming a memory step"""
140
+ # Create mock agent and memory step
141
+ mock_agent = Mock()
142
+ mock_agent.run = Mock(return_value=[Mock(spec=ActionStep)])
143
+ mock_agent.model = Mock()
144
+ mock_agent.model.last_input_token_count = 100
145
+ mock_agent.model.last_output_token_count = 200
146
+ # Mock the pull_messages_from_step function to return some messages
147
+ mock_message = Mock()
148
+ mock_pull_messages.return_value = [mock_message]
149
+ # Call stream_to_gradio
150
+ result = list(stream_to_gradio(mock_agent, "test task"))
151
+ # Verify that pull_messages_from_step was called and the message was yielded
152
+ mock_pull_messages.assert_called_once()
153
+ assert result == [mock_message]
154
+
155
+ def test_stream_to_gradio_stream_delta(self):
156
+ """Test streaming a ChatMessageStreamDelta"""
157
+ # Create mock agent and stream delta
158
+ mock_agent = Mock()
159
+ mock_delta = ChatMessageStreamDelta(content="Hello")
160
+ mock_agent.run = Mock(return_value=[mock_delta])
161
+ mock_agent.model = Mock()
162
+ mock_agent.model.last_input_token_count = 100
163
+ mock_agent.model.last_output_token_count = 200
164
+ # Call stream_to_gradio
165
+ result = list(stream_to_gradio(mock_agent, "test task"))
166
+ # Verify that the content was yielded
167
+ assert result == ["Hello"]
168
+
169
+ def test_stream_to_gradio_multiple_deltas(self):
170
+ """Test streaming multiple ChatMessageStreamDeltas"""
171
+ # Create mock agent and stream deltas
172
+ mock_agent = Mock()
173
+ mock_delta1 = ChatMessageStreamDelta(content="Hello")
174
+ mock_delta2 = ChatMessageStreamDelta(content=" world")
175
+ mock_agent.run = Mock(return_value=[mock_delta1, mock_delta2])
176
+ mock_agent.model = Mock()
177
+ mock_agent.model.last_input_token_count = 100
178
+ mock_agent.model.last_output_token_count = 200
179
+ # Call stream_to_gradio
180
+ result = list(stream_to_gradio(mock_agent, "test task"))
181
+ # Verify that the content was accumulated and yielded
182
+ assert result == ["Hello", "Hello world"]
183
+
184
+ @pytest.mark.parametrize(
185
+ "task,task_images,reset_memory,additional_args",
186
+ [
187
+ ("simple task", None, False, None),
188
+ ("task with images", ["image1.png", "image2.png"], False, None),
189
+ ("task with reset", None, True, None),
190
+ ("task with args", None, False, {"arg1": "value1"}),
191
+ ("complex task", ["image.png"], True, {"arg1": "value1", "arg2": "value2"}),
192
+ ],
193
+ )
194
+ def test_stream_to_gradio_parameters(self, task, task_images, reset_memory, additional_args):
195
+ """Test that stream_to_gradio passes parameters correctly to agent.run"""
196
+ # Create mock agent
197
+ mock_agent = Mock()
198
+ mock_agent.run = Mock(return_value=[])
199
+ # Call stream_to_gradio
200
+ list(
201
+ stream_to_gradio(
202
+ mock_agent,
203
+ task=task,
204
+ task_images=task_images,
205
+ reset_agent_memory=reset_memory,
206
+ additional_args=additional_args,
207
+ )
208
+ )
209
+ # Verify that agent.run was called with the right parameters
210
+ mock_agent.run.assert_called_once_with(
211
+ task, images=task_images, stream=True, reset=reset_memory, additional_args=additional_args
212
+ )
213
+
214
+
215
+ class TestPullMessagesFromStep:
216
+ def test_action_step_basic(
217
+ self,
218
+ ):
219
+ """Test basic ActionStep processing."""
220
+ step = ActionStep(
221
+ step_number=1,
222
+ model_output="This is the model output",
223
+ observations="Some execution logs",
224
+ error=None,
225
+ timing=Timing(start_time=1.0, end_time=3.5),
226
+ token_usage=TokenUsage(input_tokens=100, output_tokens=50),
227
+ )
228
+ messages = list(pull_messages_from_step(step))
229
+ assert len(messages) == 5 # step number, model_output, logs, footnote, divider
230
+ for message, expected_content in zip(
231
+ messages,
232
+ [
233
+ "**Step 1**",
234
+ "This is the model output",
235
+ "execution logs",
236
+ "Input tokens: 100 | Output tokens: 50 | Duration: 2.5",
237
+ "-----",
238
+ ],
239
+ ):
240
+ assert expected_content in message.content
241
+
242
+ def test_action_step_with_tool_calls(self):
243
+ """Test ActionStep with tool calls."""
244
+ step = ActionStep(
245
+ step_number=2,
246
+ tool_calls=[ToolCall(name="test_tool", arguments={"answer": "Test answer"}, id="tool_call_1")],
247
+ observations="Tool execution logs",
248
+ timing=Timing(start_time=1.0, end_time=2.5),
249
+ token_usage=TokenUsage(input_tokens=100, output_tokens=50),
250
+ )
251
+ messages = list(pull_messages_from_step(step))
252
+ assert len(messages) == 5 # step, tool call, logs, footnote, divider
253
+ assert messages[1].content == "Test answer"
254
+ assert "Used tool test_tool" in messages[1].metadata["title"]
255
+
256
+ @pytest.mark.parametrize(
257
+ "tool_name, args, expected",
258
+ [
259
+ ("python_interpreter", "print('Hello')", "```python\nprint('Hello')\n```"),
260
+ ("regular_tool", {"key": "value"}, "{'key': 'value'}"),
261
+ ("string_args_tool", "simple string", "simple string"),
262
+ ],
263
+ )
264
+ def test_action_step_tool_call_formats(self, tool_name, args, expected):
265
+ """Test different formats of tool calls."""
266
+ tool_call = Mock()
267
+ tool_call.name = tool_name
268
+ tool_call.arguments = args
269
+ step = ActionStep(
270
+ step_number=1,
271
+ tool_calls=[tool_call],
272
+ timing=Timing(start_time=1.0, end_time=2.5),
273
+ token_usage=TokenUsage(input_tokens=100, output_tokens=50),
274
+ )
275
+ messages = list(pull_messages_from_step(step))
276
+ tool_message = next(
277
+ msg
278
+ for msg in messages
279
+ if msg.role == "assistant" and msg.metadata and msg.metadata.get("title", "").startswith("🛠️")
280
+ )
281
+ assert expected in tool_message.content
282
+
283
+ def test_action_step_with_error(self):
284
+ """Test ActionStep with error."""
285
+ step = ActionStep(
286
+ step_number=3,
287
+ error="This is an error message",
288
+ timing=Timing(start_time=1.0, end_time=2.0),
289
+ token_usage=TokenUsage(input_tokens=100, output_tokens=200),
290
+ )
291
+ messages = list(pull_messages_from_step(step))
292
+ error_message = next((m for m in messages if "error" in str(m.content).lower()), None)
293
+ assert error_message is not None
294
+ assert "This is an error message" in error_message.content
295
+
296
+ def test_action_step_with_images(self):
297
+ """Test ActionStep with observation images."""
298
+ step = ActionStep(
299
+ step_number=4,
300
+ observations_images=["image1.png", "image2.jpg"],
301
+ token_usage=TokenUsage(input_tokens=100, output_tokens=200),
302
+ timing=Timing(start_time=1.0, end_time=2.0),
303
+ )
304
+ with patch("smolagents.gradio_ui.AgentImage") as mock_agent_image:
305
+ mock_agent_image.return_value.to_string.side_effect = lambda: "path/to/image.png"
306
+ messages = list(pull_messages_from_step(step))
307
+ image_messages = [m for m in messages if "image" in str(m).lower()]
308
+ assert len(image_messages) == 2
309
+ assert "path/to/image.png" in str(image_messages[0])
310
+
311
+ @pytest.mark.parametrize(
312
+ "skip_model_outputs, expected_messages_length, token_usage",
313
+ [(False, 4, TokenUsage(input_tokens=80, output_tokens=30)), (True, 2, None)],
314
+ )
315
+ def test_planning_step(self, skip_model_outputs, expected_messages_length, token_usage):
316
+ """Test PlanningStep processing."""
317
+ step = PlanningStep(
318
+ plan="1. First step\n2. Second step",
319
+ model_input_messages=Mock(),
320
+ model_output_message=Mock(),
321
+ token_usage=token_usage,
322
+ timing=Timing(start_time=1.0, end_time=2.0),
323
+ )
324
+ messages = list(pull_messages_from_step(step, skip_model_outputs=skip_model_outputs))
325
+ assert len(messages) == expected_messages_length # [header, plan,] footnote, divider
326
+ expected_contents = [
327
+ "**Planning step**",
328
+ "1. First step\n2. Second step",
329
+ "Input tokens: 80 | Output tokens: 30" if token_usage else "",
330
+ "-----",
331
+ ]
332
+ for message, expected_content in zip(messages, expected_contents[-expected_messages_length:]):
333
+ assert expected_content in message.content
334
+
335
+ if not token_usage:
336
+ assert "Input tokens: 80 | Output tokens: 30" not in message.content
337
+
338
+ @pytest.mark.parametrize(
339
+ "answer_type, answer_value, expected_content",
340
+ [
341
+ (AgentText, "This is a text answer", "**Final answer:**\nThis is a text answer\n"),
342
+ (lambda: "Plain string", "Plain string", "**Final answer:** Plain string"),
343
+ ],
344
+ )
345
+ def test_final_answer_step(self, answer_type, answer_value, expected_content):
346
+ """Test FinalAnswerStep with different answer types."""
347
+ try:
348
+ final_answer = answer_type()
349
+ except TypeError:
350
+ with patch.object(answer_type, "to_string", return_value=answer_value):
351
+ final_answer = answer_type(answer_value)
352
+ step = FinalAnswerStep(
353
+ output=final_answer,
354
+ )
355
+ messages = list(pull_messages_from_step(step))
356
+ assert len(messages) == 1
357
+ assert messages[0].content == expected_content
358
+
359
+ def test_final_answer_step_image(self):
360
+ """Test FinalAnswerStep with image answer."""
361
+ with patch.object(AgentImage, "to_string", return_value="path/to/image.png"):
362
+ step = FinalAnswerStep(output=AgentImage("path/to/image.png"))
363
+ messages = list(pull_messages_from_step(step))
364
+ assert len(messages) == 1
365
+ assert messages[0].content["path"] == "path/to/image.png"
366
+ assert messages[0].content["mime_type"] == "image/png"
367
+
368
+ def test_final_answer_step_audio(self):
369
+ """Test FinalAnswerStep with audio answer."""
370
+ with patch.object(AgentAudio, "to_string", return_value="path/to/audio.wav"):
371
+ step = FinalAnswerStep(output=AgentAudio("path/to/audio.wav"))
372
+ messages = list(pull_messages_from_step(step))
373
+ assert len(messages) == 1
374
+ assert messages[0].content["path"] == "path/to/audio.wav"
375
+ assert messages[0].content["mime_type"] == "audio/wav"
376
+
377
+ def test_unsupported_step_type(self):
378
+ """Test handling of unsupported step types."""
379
+
380
+ class UnsupportedStep(Mock):
381
+ pass
382
+
383
+ step = UnsupportedStep()
384
+ with pytest.raises(ValueError, match="Unsupported step type"):
385
+ list(pull_messages_from_step(step))
tests/test_import.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+
5
+
6
+ def test_import_smolagents_without_extras(monkeypatch):
7
+ monkeypatch.delenv("VIRTUAL_ENV", raising=False)
8
+ with tempfile.TemporaryDirectory() as temp_dir:
9
+ # Create a virtual environment
10
+ venv_dir = os.path.join(temp_dir, "venv")
11
+ subprocess.run(["uv", "venv", venv_dir], check=True)
12
+
13
+ # Install smolagents in the virtual environment
14
+ subprocess.run(
15
+ ["uv", "pip", "install", "--python", os.path.join(venv_dir, "bin", "python"), "smolagents @ ."], check=True
16
+ )
17
+
18
+ # Run the import test in the virtual environment
19
+ result = subprocess.run(
20
+ [os.path.join(venv_dir, "bin", "python"), "-c", "import smolagents"],
21
+ capture_output=True,
22
+ text=True,
23
+ )
24
+
25
+ # Check if the import was successful
26
+ assert result.returncode == 0, (
27
+ "Import failed with error: "
28
+ + (result.stderr.splitlines()[-1] if result.stderr else "No error message")
29
+ + "\n"
30
+ + result.stderr
31
+ )
tests/test_local_python_executor.py ADDED
@@ -0,0 +1,2353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import types
18
+ from contextlib import nullcontext as does_not_raise
19
+ from textwrap import dedent
20
+ from unittest.mock import patch
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ import pytest
25
+
26
+ from smolagents.default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
27
+ from smolagents.local_python_executor import (
28
+ DANGEROUS_FUNCTIONS,
29
+ DANGEROUS_MODULES,
30
+ InterpreterError,
31
+ LocalPythonExecutor,
32
+ PrintContainer,
33
+ check_import_authorized,
34
+ evaluate_boolop,
35
+ evaluate_condition,
36
+ evaluate_delete,
37
+ evaluate_python_code,
38
+ evaluate_subscript,
39
+ fix_final_answer_code,
40
+ get_safe_module,
41
+ )
42
+
43
+
44
+ # Fake function we will use as tool
45
+ def add_two(x):
46
+ return x + 2
47
+
48
+
49
+ class TestEvaluatePythonCode:
50
+ def assertDictEqualNoPrint(self, dict1, dict2):
51
+ assert {k: v for k, v in dict1.items() if k != "_print_outputs"} == {
52
+ k: v for k, v in dict2.items() if k != "_print_outputs"
53
+ }
54
+
55
+ def test_evaluate_assign(self):
56
+ code = "x = 3"
57
+ state = {}
58
+ result, _ = evaluate_python_code(code, {}, state=state)
59
+ assert result == 3
60
+ self.assertDictEqualNoPrint(state, {"x": 3, "_operations_count": {"counter": 2}})
61
+
62
+ code = "x = y"
63
+ state = {"y": 5}
64
+ result, _ = evaluate_python_code(code, {}, state=state)
65
+ # evaluate returns the value of the last assignment.
66
+ assert result == 5
67
+ self.assertDictEqualNoPrint(state, {"x": 5, "y": 5, "_operations_count": {"counter": 2}})
68
+
69
+ code = "a=1;b=None"
70
+ result, _ = evaluate_python_code(code, {}, state={})
71
+ # evaluate returns the value of the last assignment.
72
+ assert result is None
73
+
74
+ def test_assignment_cannot_overwrite_tool(self):
75
+ code = "print = '3'"
76
+ with pytest.raises(InterpreterError) as e:
77
+ evaluate_python_code(code, {"print": print}, state={})
78
+ assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
79
+
80
+ def test_subscript_call(self):
81
+ code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
82
+ state = {}
83
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
84
+ assert result == 64
85
+ assert state["result_foo"] == 8
86
+ assert state["result_boo"] == 64
87
+
88
+ def test_evaluate_call(self):
89
+ code = "y = add_two(x)"
90
+ state = {"x": 3}
91
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
92
+ assert result == 5
93
+ self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": {"counter": 3}})
94
+
95
+ # Should not work without the tool
96
+ with pytest.raises(InterpreterError, match="Forbidden function evaluation: 'add_two'"):
97
+ evaluate_python_code(code, {}, state=state)
98
+
99
+ def test_evaluate_class_def(self):
100
+ code = dedent('''\
101
+ class MyClass:
102
+ """A class with a value."""
103
+
104
+ def __init__(self, value):
105
+ self.value = value
106
+
107
+ def get_value(self):
108
+ return self.value
109
+
110
+ instance = MyClass(42)
111
+ result = instance.get_value()
112
+ ''')
113
+ state = {}
114
+ result, _ = evaluate_python_code(code, {}, state=state)
115
+ assert result == 42
116
+ assert state["instance"].__doc__ == "A class with a value."
117
+
118
+ def test_evaluate_class_def_with_assign_attribute_target(self):
119
+ """
120
+ Test evaluate_class_def function when stmt is an instance of ast.Assign with ast.Attribute target.
121
+ """
122
+ code = dedent("""
123
+ class TestSubClass:
124
+ attr1 = 1
125
+ class TestClass:
126
+ data = TestSubClass()
127
+ data.attr1 = "value1"
128
+ data.attr2 = "value2"
129
+ result = (TestClass.data.attr1, TestClass.data.attr2)
130
+ """)
131
+
132
+ state = {}
133
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
134
+
135
+ assert result == ("value1", "value2")
136
+ assert isinstance(state["TestClass"], type)
137
+ assert state["TestClass"].data.attr1 == "value1"
138
+ assert state["TestClass"].data.attr2 == "value2"
139
+
140
+ def test_evaluate_constant(self):
141
+ code = "x = 3"
142
+ state = {}
143
+ result, _ = evaluate_python_code(code, {}, state=state)
144
+ assert result == 3
145
+ self.assertDictEqualNoPrint(state, {"x": 3, "_operations_count": {"counter": 2}})
146
+
147
+ def test_evaluate_dict(self):
148
+ code = "test_dict = {'x': x, 'y': add_two(x)}"
149
+ state = {"x": 3}
150
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
151
+ assert result == {"x": 3, "y": 5}
152
+ self.assertDictEqualNoPrint(
153
+ state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": {"counter": 7}}
154
+ )
155
+
156
+ def test_evaluate_expression(self):
157
+ code = "x = 3\ny = 5"
158
+ state = {}
159
+ result, _ = evaluate_python_code(code, {}, state=state)
160
+ # evaluate returns the value of the last assignment.
161
+ assert result == 5
162
+ self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": {"counter": 4}})
163
+
164
+ def test_evaluate_f_string(self):
165
+ code = "text = f'This is x: {x}.'"
166
+ state = {"x": 3}
167
+ result, _ = evaluate_python_code(code, {}, state=state)
168
+ # evaluate returns the value of the last assignment.
169
+ assert result == "This is x: 3."
170
+ self.assertDictEqualNoPrint(state, {"x": 3, "text": "This is x: 3.", "_operations_count": {"counter": 6}})
171
+
172
+ def test_evaluate_f_string_with_format(self):
173
+ code = "text = f'This is x: {x:.2f}.'"
174
+ state = {"x": 3.336}
175
+ result, _ = evaluate_python_code(code, {}, state=state)
176
+ assert result == "This is x: 3.34."
177
+ self.assertDictEqualNoPrint(
178
+ state, {"x": 3.336, "text": "This is x: 3.34.", "_operations_count": {"counter": 8}}
179
+ )
180
+
181
+ def test_evaluate_f_string_with_complex_format(self):
182
+ code = "text = f'This is x: {x:>{width}.{precision}f}.'"
183
+ state = {"x": 3.336, "width": 10, "precision": 2}
184
+ result, _ = evaluate_python_code(code, {}, state=state)
185
+ assert result == "This is x: 3.34."
186
+ self.assertDictEqualNoPrint(
187
+ state,
188
+ {
189
+ "x": 3.336,
190
+ "width": 10,
191
+ "precision": 2,
192
+ "text": "This is x: 3.34.",
193
+ "_operations_count": {"counter": 14},
194
+ },
195
+ )
196
+
197
+ def test_evaluate_if(self):
198
+ code = "if x <= 3:\n y = 2\nelse:\n y = 5"
199
+ state = {"x": 3}
200
+ result, _ = evaluate_python_code(code, {}, state=state)
201
+ # evaluate returns the value of the last assignment.
202
+ assert result == 2
203
+ self.assertDictEqualNoPrint(state, {"x": 3, "y": 2, "_operations_count": {"counter": 6}})
204
+
205
+ state = {"x": 8}
206
+ result, _ = evaluate_python_code(code, {}, state=state)
207
+ # evaluate returns the value of the last assignment.
208
+ assert result == 5
209
+ self.assertDictEqualNoPrint(state, {"x": 8, "y": 5, "_operations_count": {"counter": 6}})
210
+
211
+ def test_evaluate_list(self):
212
+ code = "test_list = [x, add_two(x)]"
213
+ state = {"x": 3}
214
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
215
+ assert result == [3, 5]
216
+ self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": {"counter": 5}})
217
+
218
+ def test_evaluate_name(self):
219
+ code = "y = x"
220
+ state = {"x": 3}
221
+ result, _ = evaluate_python_code(code, {}, state=state)
222
+ assert result == 3
223
+ self.assertDictEqualNoPrint(state, {"x": 3, "y": 3, "_operations_count": {"counter": 2}})
224
+
225
+ def test_evaluate_subscript(self):
226
+ code = "test_list = [x, add_two(x)]\ntest_list[1]"
227
+ state = {"x": 3}
228
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
229
+ assert result == 5
230
+ self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": {"counter": 9}})
231
+
232
+ code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
233
+ state = {"x": 3}
234
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
235
+ assert result == 5
236
+ self.assertDictEqualNoPrint(
237
+ state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": {"counter": 11}}
238
+ )
239
+
240
+ code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
241
+ state = {}
242
+ evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
243
+ assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
244
+
245
+ def test_subscript_string_with_string_index_raises_appropriate_error(self):
246
+ code = """
247
+ search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
248
+ for result in search_results:
249
+ if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
250
+ current_weather_url = result['href']
251
+ print(current_weather_url)
252
+ break"""
253
+ with pytest.raises(InterpreterError) as e:
254
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
255
+ assert "You're trying to subscript a string with a string index" in e
256
+
257
+ def test_evaluate_for(self):
258
+ code = "x = 0\nfor i in range(3):\n x = i"
259
+ state = {}
260
+ result, _ = evaluate_python_code(code, {"range": range}, state=state)
261
+ assert result == 2
262
+ self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": {"counter": 11}})
263
+
264
+ def test_evaluate_binop(self):
265
+ code = "y + x"
266
+ state = {"x": 3, "y": 6}
267
+ result, _ = evaluate_python_code(code, {}, state=state)
268
+ assert result == 9
269
+ self.assertDictEqualNoPrint(state, {"x": 3, "y": 6, "_operations_count": {"counter": 4}})
270
+
271
+ def test_recursive_function(self):
272
+ code = """
273
+ def recur_fibo(n):
274
+ if n <= 1:
275
+ return n
276
+ else:
277
+ return(recur_fibo(n-1) + recur_fibo(n-2))
278
+ recur_fibo(6)"""
279
+ result, _ = evaluate_python_code(code, {}, state={})
280
+ assert result == 8
281
+
282
+ def test_max_operations(self):
283
+ # Check that operation counter is not reset in functions
284
+ code = dedent(
285
+ """
286
+ def func(a):
287
+ for j in range(10):
288
+ a += j
289
+ return a
290
+
291
+ for i in range(5):
292
+ func(i)
293
+ """
294
+ )
295
+ with patch("smolagents.local_python_executor.MAX_OPERATIONS", 100):
296
+ with pytest.raises(InterpreterError) as exception_info:
297
+ evaluate_python_code(code, {"range": range}, state={})
298
+ assert "Reached the max number of operations" in str(exception_info.value)
299
+
300
+ def test_operations_count(self):
301
+ # Check that operation counter is not reset in functions
302
+ code = dedent(
303
+ """
304
+ def func():
305
+ return 0
306
+
307
+ func()
308
+ """
309
+ )
310
+ state = {}
311
+ evaluate_python_code(code, {"range": range}, state=state)
312
+ assert state["_operations_count"]["counter"] == 5
313
+
314
+ def test_evaluate_string_methods(self):
315
+ code = "'hello'.replace('h', 'o').split('e')"
316
+ result, _ = evaluate_python_code(code, {}, state={})
317
+ assert result == ["o", "llo"]
318
+
319
+ def test_evaluate_slicing(self):
320
+ code = "'hello'[1:3][::-1]"
321
+ result, _ = evaluate_python_code(code, {}, state={})
322
+ assert result == "le"
323
+
324
+ def test_access_attributes(self):
325
+ class A:
326
+ attr = 2
327
+
328
+ code = "A.attr"
329
+ result, _ = evaluate_python_code(code, {}, state={"A": A})
330
+ assert result == 2
331
+
332
+ def test_list_comprehension(self):
333
+ code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
334
+ result, _ = evaluate_python_code(code, {}, state={})
335
+ assert result == "t-h-e-s-e-a-g-u-l-l"
336
+
337
+ def test_string_indexing(self):
338
+ code = """text_block = [
339
+ "THESE",
340
+ "AGULL"
341
+ ]
342
+ sentence = ""
343
+ for block in text_block:
344
+ for col in range(len(text_block[0])):
345
+ sentence += block[col]
346
+ """
347
+ result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
348
+ assert result == "THESEAGULL"
349
+
350
+ def test_tuples(self):
351
+ code = "x = (1, 2, 3)\nx[1]"
352
+ result, _ = evaluate_python_code(code, {}, state={})
353
+ assert result == 2
354
+
355
+ code = """
356
+ digits, i = [1, 2, 3], 1
357
+ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
358
+ evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
359
+
360
+ code = """
361
+ def calculate_isbn_10_check_digit(number):
362
+ total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
363
+ remainder = total % 11
364
+ check_digit = 11 - remainder
365
+ if check_digit == 10:
366
+ return 'X'
367
+ elif check_digit == 11:
368
+ return '0'
369
+ else:
370
+ return str(check_digit)
371
+
372
+ # Given 9-digit numbers
373
+ numbers = [
374
+ "478225952",
375
+ "643485613",
376
+ "739394228",
377
+ "291726859",
378
+ "875262394",
379
+ "542617795",
380
+ "031810713",
381
+ "957007669",
382
+ "871467426"
383
+ ]
384
+
385
+ # Calculate check digits for each number
386
+ check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
387
+ print(check_digits)
388
+ """
389
+ state = {}
390
+ evaluate_python_code(
391
+ code,
392
+ {
393
+ "range": range,
394
+ "print": print,
395
+ "sum": sum,
396
+ "enumerate": enumerate,
397
+ "int": int,
398
+ "str": str,
399
+ },
400
+ state,
401
+ )
402
+
403
+ def test_listcomp(self):
404
+ code = "x = [i for i in range(3)]"
405
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
406
+ assert result == [0, 1, 2]
407
+
408
+ def test_setcomp(self):
409
+ code = "batman_times = {entry['time'] for entry in [{'time': 10}, {'time': 19}, {'time': 20}]}"
410
+ result, _ = evaluate_python_code(code, {}, state={})
411
+ assert result == {10, 19, 20}
412
+
413
+ def test_break_continue(self):
414
+ code = "for i in range(10):\n if i == 5:\n break\ni"
415
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
416
+ assert result == 5
417
+
418
+ code = "for i in range(10):\n if i == 5:\n continue\ni"
419
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
420
+ assert result == 9
421
+
422
+ def test_call_int(self):
423
+ code = "import math\nstr(math.ceil(149))"
424
+ result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
425
+ assert result == "149"
426
+
427
+ def test_lambda(self):
428
+ code = "f = lambda x: x + 2\nf(3)"
429
+ result, _ = evaluate_python_code(code, {}, state={})
430
+ assert result == 5
431
+
432
+ def test_dictcomp(self):
433
+ code = "x = {i: i**2 for i in range(3)}"
434
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
435
+ assert result == {0: 0, 1: 1, 2: 4}
436
+
437
+ code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
438
+ result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
439
+ assert result == {102: "b"}
440
+
441
+ code = """
442
+ shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
443
+ shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
444
+ """
445
+ result, _ = evaluate_python_code(code, {}, state={})
446
+ assert result == {"A": ("a", "b"), "B": ("a", "b")}
447
+
448
+ def test_tuple_assignment(self):
449
+ code = "a, b = 0, 1\nb"
450
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
451
+ assert result == 1
452
+
453
+ def test_while(self):
454
+ code = "i = 0\nwhile i < 3:\n i += 1\ni"
455
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
456
+ assert result == 3
457
+
458
+ # test infinite loop
459
+ code = "i = 0\nwhile i < 3:\n i -= 1\ni"
460
+ with patch("smolagents.local_python_executor.MAX_WHILE_ITERATIONS", 100):
461
+ with pytest.raises(InterpreterError, match=".*Maximum number of 100 iterations in While loop exceeded"):
462
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
463
+
464
+ # test lazy evaluation
465
+ code = dedent(
466
+ """
467
+ house_positions = [0, 7, 10, 15, 18, 22, 22]
468
+ i, n, loc = 0, 7, 30
469
+ while i < n and house_positions[i] <= loc:
470
+ i += 1
471
+ """
472
+ )
473
+ state = {}
474
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
475
+
476
+ def test_generator(self):
477
+ code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
478
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
479
+ assert result == [1, 4, 9, 16, 25]
480
+
481
+ def test_boolops(self):
482
+ code = """if (not (a > b and a > c)) or d > e:
483
+ best_city = "Brooklyn"
484
+ else:
485
+ best_city = "Manhattan"
486
+ best_city
487
+ """
488
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
489
+ assert result == "Brooklyn"
490
+
491
+ code = """if d > e and a < b:
492
+ best_city = "Brooklyn"
493
+ elif d < e and a < b:
494
+ best_city = "Sacramento"
495
+ else:
496
+ best_city = "Manhattan"
497
+ best_city
498
+ """
499
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
500
+ assert result == "Sacramento"
501
+
502
+ # Short-circuit evaluation:
503
+ # (T and 0) or (T and T) => 0 or True => True
504
+ code = "result = (x > 3 and y) or (z == 10 and not y)\nresult"
505
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"x": 5, "y": 0, "z": 10})
506
+ assert result
507
+
508
+ # (None or "") or "Found" => "" or "Found" => "Found"
509
+ code = "result = (a or c) or b\nresult"
510
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": None, "b": "Found", "c": ""})
511
+ assert result == "Found"
512
+
513
+ # ("First" and "") or "Third" => "" or "Third" -> "Third"
514
+ code = "result = (a and b) or c\nresult"
515
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": "First", "b": "", "c": "Third"})
516
+ assert result == "Third"
517
+
518
+ def test_if_conditions(self):
519
+ code = """char='a'
520
+ if char.isalpha():
521
+ print('2')"""
522
+ state = {}
523
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
524
+ assert state["_print_outputs"].value == "2\n"
525
+
526
+ def test_imports(self):
527
+ code = "import math\nmath.sqrt(4)"
528
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
529
+ assert result == 2.0
530
+
531
+ code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
532
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
533
+ assert result == "lose"
534
+
535
+ code = "import time, re\ntime.sleep(0.1)"
536
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
537
+ assert result is None
538
+
539
+ code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
540
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
541
+ assert result == 1
542
+
543
+ code = "import itertools\nlist(itertools.islice(range(10), 3))"
544
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
545
+ assert result == [0, 1, 2]
546
+
547
+ code = "import re\nre.search('a', 'abc').group()"
548
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
549
+ assert result == "a"
550
+
551
+ code = "import stat\nstat.S_ISREG(0o100644)"
552
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
553
+ assert result
554
+
555
+ code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
556
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
557
+ assert result == 2.8
558
+
559
+ code = "import unicodedata\nunicodedata.name('A')"
560
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
561
+ assert result == "LATIN CAPITAL LETTER A"
562
+
563
+ # Test submodules are handled properly, thus not raising error
564
+ code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
565
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
566
+
567
+ code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
568
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
569
+
570
+ def test_additional_imports(self):
571
+ code = "import numpy as np"
572
+ evaluate_python_code(code, authorized_imports=["numpy"], state={})
573
+
574
+ # Test that allowing 'numpy.*' allows numpy root package and its submodules
575
+ code = "import numpy as np\nnp.random.default_rng(123)\nnp.array([1, 2])"
576
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.*"])
577
+
578
+ # Test that allowing 'numpy.*' allows importing a submodule
579
+ code = "import numpy.random as rd\nrd.default_rng(12345)"
580
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.*"])
581
+
582
+ code = "import numpy.random as rd"
583
+ evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
584
+ evaluate_python_code(code, authorized_imports=["numpy.*"], state={})
585
+ evaluate_python_code(code, authorized_imports=["*"], state={})
586
+ with pytest.raises(InterpreterError):
587
+ evaluate_python_code(code, authorized_imports=["random"], state={})
588
+
589
+ with pytest.raises(InterpreterError):
590
+ evaluate_python_code(code, authorized_imports=["numpy.a"], state={})
591
+ with pytest.raises(InterpreterError):
592
+ evaluate_python_code(code, authorized_imports=["numpy.a.*"], state={})
593
+
594
+ def test_multiple_comparators(self):
595
+ code = "0 <= -1 < 4 and 0 <= -5 < 4"
596
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
597
+ assert not result
598
+
599
+ code = "0 <= 1 < 4 and 0 <= -5 < 4"
600
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
601
+ assert not result
602
+
603
+ code = "0 <= 4 < 4 and 0 <= 3 < 4"
604
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
605
+ assert not result
606
+
607
+ code = "0 <= 3 < 4 and 0 <= 3 < 4"
608
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
609
+ assert result
610
+
611
+ def test_print_output(self):
612
+ code = "print('Hello world!')\nprint('Ok no one cares')"
613
+ state = {}
614
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
615
+ assert result is None
616
+ assert state["_print_outputs"].value == "Hello world!\nOk no one cares\n"
617
+
618
+ # Test print in function (state copy)
619
+ code = """
620
+ print("1")
621
+ def function():
622
+ print("2")
623
+ function()"""
624
+ state = {}
625
+ evaluate_python_code(code, {"print": print}, state=state)
626
+ assert state["_print_outputs"].value == "1\n2\n"
627
+
628
+ # Test print in list comprehension (state copy)
629
+ code = """
630
+ print("1")
631
+ def function():
632
+ print("2")
633
+ [function() for i in range(10)]"""
634
+ state = {}
635
+ evaluate_python_code(code, {"print": print, "range": range}, state=state)
636
+ assert state["_print_outputs"].value == "1\n2\n2\n2\n2\n2\n2\n2\n2\n2\n2\n"
637
+
638
+ def test_tuple_target_in_iterator(self):
639
+ code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
640
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
641
+ assert result == "Samuel"
642
+
643
+ def test_classes(self):
644
+ code = """
645
+ class Animal:
646
+ species = "Generic Animal"
647
+
648
+ def __init__(self, name, age):
649
+ self.name = name
650
+ self.age = age
651
+
652
+ def sound(self):
653
+ return "The animal makes a sound."
654
+
655
+ def __str__(self):
656
+ return f"{self.name}, {self.age} years old"
657
+
658
+ class Dog(Animal):
659
+ species = "Canine"
660
+
661
+ def __init__(self, name, age, breed):
662
+ super().__init__(name, age)
663
+ self.breed = breed
664
+
665
+ def sound(self):
666
+ return "The dog barks."
667
+
668
+ def __str__(self):
669
+ return f"{self.name}, {self.age} years old, {self.breed}"
670
+
671
+ class Cat(Animal):
672
+ def sound(self):
673
+ return "The cat meows."
674
+
675
+ def __str__(self):
676
+ return f"{self.name}, {self.age} years old, {self.species}"
677
+
678
+
679
+ # Testing multiple instances
680
+ dog1 = Dog("Fido", 3, "Labrador")
681
+ dog2 = Dog("Buddy", 5, "Golden Retriever")
682
+
683
+ # Testing method with built-in function
684
+ animals = [dog1, dog2, Cat("Whiskers", 2)]
685
+ num_animals = len(animals)
686
+
687
+ # Testing exceptions in methods
688
+ class ExceptionTest:
689
+ def method_that_raises(self):
690
+ raise ValueError("An error occurred")
691
+
692
+ try:
693
+ exc_test = ExceptionTest()
694
+ exc_test.method_that_raises()
695
+ except ValueError as e:
696
+ exception_message = str(e)
697
+
698
+
699
+ # Collecting results
700
+ dog1_sound = dog1.sound()
701
+ dog1_str = str(dog1)
702
+ dog2_sound = dog2.sound()
703
+ dog2_str = str(dog2)
704
+ cat = Cat("Whiskers", 2)
705
+ cat_sound = cat.sound()
706
+ cat_str = str(cat)
707
+ """
708
+ state = {}
709
+ evaluate_python_code(
710
+ code,
711
+ {"print": print, "len": len, "super": super, "str": str, "sum": sum},
712
+ state=state,
713
+ )
714
+
715
+ # Assert results
716
+ assert state["dog1_sound"] == "The dog barks."
717
+ assert state["dog1_str"] == "Fido, 3 years old, Labrador"
718
+ assert state["dog2_sound"] == "The dog barks."
719
+ assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
720
+ assert state["cat_sound"] == "The cat meows."
721
+ assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
722
+ assert state["num_animals"] == 3
723
+ assert state["exception_message"] == "An error occurred"
724
+
725
+ def test_variable_args(self):
726
+ code = """
727
+ def var_args_method(self, *args, **kwargs):
728
+ return sum(args) + sum(kwargs.values())
729
+
730
+ var_args_method(1, 2, 3, x=4, y=5)
731
+ """
732
+ state = {}
733
+ result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
734
+ assert result == 15
735
+
736
+ def test_exceptions(self):
737
+ code = """
738
+ def method_that_raises(self):
739
+ raise ValueError("An error occurred")
740
+
741
+ try:
742
+ method_that_raises()
743
+ except ValueError as e:
744
+ exception_message = str(e)
745
+ """
746
+ state = {}
747
+ evaluate_python_code(
748
+ code,
749
+ {"print": print, "len": len, "super": super, "str": str, "sum": sum},
750
+ state=state,
751
+ )
752
+ assert state["exception_message"] == "An error occurred"
753
+
754
+ def test_print(self):
755
+ code = "print(min([1, 2, 3]))"
756
+ state = {}
757
+ evaluate_python_code(code, {"min": min, "print": print}, state=state)
758
+ assert state["_print_outputs"].value == "1\n"
759
+
760
+ def test_types_as_objects(self):
761
+ code = "type_a = float(2); type_b = str; type_c = int"
762
+ state = {}
763
+ result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
764
+ # Type objects are not wrapped by safer_func
765
+ assert not hasattr(result, "__wrapped__")
766
+ assert result is int
767
+
768
+ def test_tuple_id(self):
769
+ code = """
770
+ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
771
+ unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
772
+ """
773
+ state = {}
774
+ result, is_final_answer = evaluate_python_code(code, {}, state=state)
775
+ assert result == ["orange", "pear"]
776
+
777
+ def test_nonsimple_augassign(self):
778
+ code = """
779
+ counts_dict = {'a': 0}
780
+ counts_dict['a'] += 1
781
+ counts_list = [1, 2, 3]
782
+ counts_list += [4, 5, 6]
783
+
784
+ class Counter:
785
+ def __init__(self):
786
+ self.count = 0
787
+
788
+ a = Counter()
789
+ a.count += 1
790
+ """
791
+ state = {}
792
+ evaluate_python_code(code, {}, state=state)
793
+ assert state["counts_dict"] == {"a": 1}
794
+ assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
795
+ assert state["a"].count == 1
796
+
797
+ def test_adding_int_to_list_raises_error(self):
798
+ code = """
799
+ counts = [1, 2, 3]
800
+ counts += 1"""
801
+ with pytest.raises(InterpreterError) as e:
802
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
803
+ assert "Cannot add non-list value 1 to a list." in str(e)
804
+
805
+ def test_error_highlights_correct_line_of_code(self):
806
+ code = """a = 1
807
+ b = 2
808
+
809
+ counts = [1, 2, 3]
810
+ counts += 1
811
+ b += 1"""
812
+ with pytest.raises(InterpreterError) as e:
813
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
814
+ assert "Code execution failed at line 'counts += 1" in str(e)
815
+
816
+ def test_error_type_returned_in_function_call(self):
817
+ code = """def error_function():
818
+ raise ValueError("error")
819
+
820
+ error_function()"""
821
+ with pytest.raises(InterpreterError) as e:
822
+ evaluate_python_code(code)
823
+ assert "error" in str(e)
824
+ assert "ValueError" in str(e)
825
+
826
+ def test_assert(self):
827
+ code = """
828
+ assert 1 == 1
829
+ assert 1 == 2
830
+ """
831
+ with pytest.raises(InterpreterError) as e:
832
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
833
+ assert "1 == 2" in str(e) and "1 == 1" not in str(e)
834
+
835
+ def test_with_context_manager(self):
836
+ code = """
837
+ class SimpleLock:
838
+ def __init__(self):
839
+ self.locked = False
840
+
841
+ def __enter__(self):
842
+ self.locked = True
843
+ return self
844
+
845
+ def __exit__(self, exc_type, exc_value, traceback):
846
+ self.locked = False
847
+
848
+ lock = SimpleLock()
849
+
850
+ with lock as l:
851
+ assert l.locked == True
852
+
853
+ assert lock.locked == False
854
+ """
855
+ state = {}
856
+ tools = {}
857
+ evaluate_python_code(code, tools, state=state)
858
+
859
+ def test_default_arg_in_function(self):
860
+ code = """
861
+ def f(a, b=333, n=1000):
862
+ return b + n
863
+ n = f(1, n=667)
864
+ """
865
+ res, is_final_answer = evaluate_python_code(code, {}, {})
866
+ assert res == 1000
867
+ assert not is_final_answer
868
+
869
+ def test_set(self):
870
+ code = """
871
+ S1 = {'a', 'b', 'c'}
872
+ S2 = {'b', 'c', 'd'}
873
+ S3 = S1.difference(S2)
874
+ S4 = S1.intersection(S2)
875
+ """
876
+ state = {}
877
+ evaluate_python_code(code, {}, state=state)
878
+ assert state["S3"] == {"a"}
879
+ assert state["S4"] == {"b", "c"}
880
+
881
+ def test_break(self):
882
+ code = """
883
+ i = 0
884
+
885
+ while True:
886
+ i+= 1
887
+ if i==3:
888
+ break
889
+
890
+ i"""
891
+ result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
892
+ assert result == 3
893
+ assert not is_final_answer
894
+
895
+ def test_return(self):
896
+ # test early returns
897
+ code = """
898
+ def add_one(n, shift):
899
+ if True:
900
+ return n + shift
901
+ return n
902
+
903
+ add_one(1, 1)
904
+ """
905
+ state = {}
906
+ result, is_final_answer = evaluate_python_code(
907
+ code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
908
+ )
909
+ assert result == 2
910
+
911
+ # test returning None
912
+ code = """
913
+ def returns_none(a):
914
+ return
915
+
916
+ returns_none(1)
917
+ """
918
+ state = {}
919
+ result, is_final_answer = evaluate_python_code(
920
+ code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
921
+ )
922
+ assert result is None
923
+
924
+ def test_nested_for_loop(self):
925
+ code = """
926
+ all_res = []
927
+ for i in range(10):
928
+ subres = []
929
+ for j in range(i):
930
+ subres.append(j)
931
+ all_res.append(subres)
932
+
933
+ out = [i for sublist in all_res for i in sublist]
934
+ out[:10]
935
+ """
936
+ state = {}
937
+ result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
938
+ assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
939
+
940
+ def test_pandas(self):
941
+ code = """
942
+ import pandas as pd
943
+
944
+ df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
945
+
946
+ df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
947
+
948
+ parts_with_5_set_count = df[df['SetCount'] == 5.0]
949
+ parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
950
+ """
951
+ state = {}
952
+ result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
953
+ assert np.array_equal(result, [-1, 5])
954
+
955
+ code = """
956
+ import pandas as pd
957
+
958
+ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
959
+
960
+ # Filter the DataFrame to get only the rows with outdated atomic numbers
961
+ filtered_df = df.loc[df['AtomicNumber'].isin([104])]
962
+ """
963
+ result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
964
+ assert np.array_equal(result.values[0], [104, 1])
965
+
966
+ # Test groupby
967
+ code = """import pandas as pd
968
+ data = pd.DataFrame.from_dict([
969
+ {"Pclass": 1, "Survived": 1},
970
+ {"Pclass": 2, "Survived": 0},
971
+ {"Pclass": 2, "Survived": 1}
972
+ ])
973
+ survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
974
+ """
975
+ result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
976
+ assert result.values[1] == 0.5
977
+
978
+ # Test loc and iloc
979
+ code = """import pandas as pd
980
+ data = pd.DataFrame.from_dict([
981
+ {"Pclass": 1, "Survived": 1},
982
+ {"Pclass": 2, "Survived": 0},
983
+ {"Pclass": 2, "Survived": 1}
984
+ ])
985
+ survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
986
+ survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
987
+ survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
988
+ """
989
+ result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
990
+
991
+ def test_starred(self):
992
+ code = """
993
+ from math import radians, sin, cos, sqrt, atan2
994
+
995
+ def haversine(lat1, lon1, lat2, lon2):
996
+ R = 6371000 # Radius of the Earth in meters
997
+ lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
998
+ dlat = lat2 - lat1
999
+ dlon = lon2 - lon1
1000
+ a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
1001
+ c = 2 * atan2(sqrt(a), sqrt(1 - a))
1002
+ distance = R * c
1003
+ return distance
1004
+
1005
+ coords_geneva = (46.1978, 6.1342)
1006
+ coords_barcelona = (41.3869, 2.1660)
1007
+
1008
+ distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
1009
+ """
1010
+ result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
1011
+ assert round(result, 1) == 622395.4
1012
+
1013
+ def test_for(self):
1014
+ code = """
1015
+ shifts = {
1016
+ "Worker A": ("6:45 pm", "8:00 pm"),
1017
+ "Worker B": ("10:00 am", "11:45 am")
1018
+ }
1019
+
1020
+ shift_intervals = {}
1021
+ for worker, (start, end) in shifts.items():
1022
+ shift_intervals[worker] = end
1023
+ shift_intervals
1024
+ """
1025
+ result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
1026
+ assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
1027
+
1028
+ def test_syntax_error_points_error(self):
1029
+ code = "a = ;"
1030
+ with pytest.raises(InterpreterError) as e:
1031
+ evaluate_python_code(code)
1032
+ assert "SyntaxError" in str(e)
1033
+ assert " ^" in str(e)
1034
+
1035
+ def test_close_matches_subscript(self):
1036
+ code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
1037
+ with pytest.raises(Exception) as e:
1038
+ evaluate_python_code(code)
1039
+ assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "")
1040
+
1041
+ def test_dangerous_builtins_calls_are_blocked(self):
1042
+ unsafe_code = "import os"
1043
+ dangerous_code = f"""
1044
+ exec = callable.__self__.exec
1045
+ compile = callable.__self__.compile
1046
+ exec(compile('{unsafe_code}', 'no filename', 'exec'))
1047
+ """
1048
+
1049
+ with pytest.raises(InterpreterError):
1050
+ evaluate_python_code(unsafe_code, static_tools=BASE_PYTHON_TOOLS)
1051
+
1052
+ with pytest.raises(InterpreterError):
1053
+ evaluate_python_code(dangerous_code, static_tools=BASE_PYTHON_TOOLS)
1054
+
1055
+ def test_final_answer_accepts_kwarg_answer(self):
1056
+ code = "final_answer(answer=2)"
1057
+ result, _ = evaluate_python_code(code, {"final_answer": (lambda answer: 2 * answer)}, state={})
1058
+ assert result == 4
1059
+
1060
+ def test_dangerous_builtins_are_callable_if_explicitly_added(self):
1061
+ dangerous_code = dedent("""
1062
+ eval("1 + 1")
1063
+ exec(compile("1 + 1", "no filename", "exec"))
1064
+ """)
1065
+ evaluate_python_code(
1066
+ dangerous_code, static_tools={"compile": compile, "eval": eval, "exec": exec} | BASE_PYTHON_TOOLS
1067
+ )
1068
+
1069
+ def test_can_import_os_if_explicitly_authorized(self):
1070
+ dangerous_code = "import os; os.listdir('./')"
1071
+ evaluate_python_code(dangerous_code, authorized_imports=["os"])
1072
+
1073
+ def test_can_import_os_if_all_imports_authorized(self):
1074
+ dangerous_code = "import os; os.listdir('./')"
1075
+ evaluate_python_code(dangerous_code, authorized_imports=["*"])
1076
+
1077
+ @pytest.mark.filterwarnings("ignore::DeprecationWarning")
1078
+ def test_can_import_scipy_if_explicitly_authorized(self):
1079
+ code = "import scipy"
1080
+ evaluate_python_code(code, authorized_imports=["scipy"])
1081
+
1082
+ @pytest.mark.filterwarnings("ignore::DeprecationWarning")
1083
+ def test_can_import_sklearn_if_explicitly_authorized(self):
1084
+ code = "import sklearn"
1085
+ evaluate_python_code(code, authorized_imports=["sklearn"])
1086
+
1087
+ def test_function_def_recovers_source_code(self):
1088
+ executor = LocalPythonExecutor([])
1089
+
1090
+ executor.send_tools({"final_answer": FinalAnswerTool()})
1091
+
1092
+ res, _, _ = executor(
1093
+ dedent(
1094
+ """
1095
+ def target_function():
1096
+ return "Hello world"
1097
+
1098
+ final_answer(target_function)
1099
+ """
1100
+ )
1101
+ )
1102
+ assert res.__name__ == "target_function"
1103
+ assert res.__source__ == "def target_function():\n return 'Hello world'"
1104
+
1105
+ def test_evaluate_class_def_with_pass(self):
1106
+ code = dedent("""
1107
+ class TestClass:
1108
+ pass
1109
+
1110
+ instance = TestClass()
1111
+ instance.attr = "value"
1112
+ result = instance.attr
1113
+ """)
1114
+ state = {}
1115
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
1116
+ assert result == "value"
1117
+
1118
+ def test_evaluate_class_def_with_ann_assign_name(self):
1119
+ """
1120
+ Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Name target.
1121
+
1122
+ This test verifies that annotated assignments within a class definition are correctly evaluated.
1123
+ """
1124
+ code = dedent("""
1125
+ class TestClass:
1126
+ x: int = 5
1127
+ y: str = "test"
1128
+
1129
+ instance = TestClass()
1130
+ result = (instance.x, instance.y)
1131
+ """)
1132
+
1133
+ state = {}
1134
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
1135
+
1136
+ assert result == (5, "test")
1137
+ assert isinstance(state["TestClass"], type)
1138
+ # Type objects are not wrapped by safer_func
1139
+ for value in state["TestClass"].__annotations__.values():
1140
+ assert not hasattr(value, "__wrapped__")
1141
+ assert state["TestClass"].__annotations__ == {"x": int, "y": str}
1142
+ assert state["TestClass"].x == 5
1143
+ assert state["TestClass"].y == "test"
1144
+ assert isinstance(state["instance"], state["TestClass"])
1145
+ assert state["instance"].x == 5
1146
+ assert state["instance"].y == "test"
1147
+
1148
+ def test_evaluate_class_def_with_ann_assign_attribute(self):
1149
+ """
1150
+ Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Attribute target.
1151
+
1152
+ This test ensures that class attributes using attribute notation are correctly handled.
1153
+ """
1154
+ code = dedent("""
1155
+ class TestSubClass:
1156
+ attr = 1
1157
+ class TestClass:
1158
+ data: TestSubClass = TestSubClass()
1159
+ data.attr: str = "value"
1160
+
1161
+ result = TestClass.data.attr
1162
+ """)
1163
+
1164
+ state = {}
1165
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
1166
+
1167
+ assert result == "value"
1168
+ assert isinstance(state["TestClass"], type)
1169
+ assert state["TestClass"].__annotations__.keys() == {"data"}
1170
+ assert isinstance(state["TestClass"].__annotations__["data"], type)
1171
+ assert state["TestClass"].__annotations__["data"].__name__ == "TestSubClass"
1172
+ assert state["TestClass"].data.attr == "value"
1173
+
1174
+ def test_evaluate_class_def_with_ann_assign_subscript(self):
1175
+ """
1176
+ Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Subscript target.
1177
+
1178
+ This test ensures that class attributes using subscript notation are correctly handled.
1179
+ """
1180
+ code = dedent("""
1181
+ class TestClass:
1182
+ key_data: dict = {}
1183
+ key_data["key"]: str = "value"
1184
+ index_data: list = [10, 20, 30]
1185
+ index_data[0:2]: list[str] = ["a", "b"]
1186
+
1187
+ result = (TestClass.key_data['key'], TestClass.index_data[1:])
1188
+ """)
1189
+
1190
+ state = {}
1191
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
1192
+
1193
+ assert result == ("value", ["b", 30])
1194
+ assert isinstance(state["TestClass"], type)
1195
+ # Type objects are not wrapped by safer_func
1196
+ for value in state["TestClass"].__annotations__.values():
1197
+ assert not hasattr(value, "__wrapped__")
1198
+ assert state["TestClass"].__annotations__ == {"key_data": dict, "index_data": list}
1199
+ assert state["TestClass"].key_data == {"key": "value"}
1200
+ assert state["TestClass"].index_data == ["a", "b", 30]
1201
+
1202
+ def test_evaluate_annassign(self):
1203
+ code = dedent("""\
1204
+ # Basic annotated assignment
1205
+ x: int = 42
1206
+
1207
+ # Type annotations with expressions
1208
+ y: float = x / 2
1209
+
1210
+ # Type annotation without assignment
1211
+ z: list
1212
+
1213
+ # Type annotation with complex value
1214
+ names: list = ["Alice", "Bob", "Charlie"]
1215
+
1216
+ # Type hint shouldn't restrict values at runtime
1217
+ s: str = 123 # Would be a type error in static checking, but valid at runtime
1218
+
1219
+ # Access the values
1220
+ result = (x, y, names, s)
1221
+ """)
1222
+ state = {}
1223
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
1224
+ assert state["x"] == 42
1225
+ assert state["y"] == 21.0
1226
+ assert "z" not in state # z should be not be defined
1227
+ assert state["names"] == ["Alice", "Bob", "Charlie"]
1228
+ assert state["s"] == 123 # Type hints don't restrict at runtime
1229
+ assert state["result"] == (42, 21.0, ["Alice", "Bob", "Charlie"], 123)
1230
+
1231
+ @pytest.mark.parametrize(
1232
+ "code, expected_result",
1233
+ [
1234
+ (
1235
+ dedent("""\
1236
+ x = 1
1237
+ x += 2
1238
+ """),
1239
+ 3,
1240
+ ),
1241
+ (
1242
+ dedent("""\
1243
+ x = "a"
1244
+ x += "b"
1245
+ """),
1246
+ "ab",
1247
+ ),
1248
+ (
1249
+ dedent("""\
1250
+ class Custom:
1251
+ def __init__(self, value):
1252
+ self.value = value
1253
+ def __iadd__(self, other):
1254
+ self.value += other * 10
1255
+ return self
1256
+
1257
+ x = Custom(1)
1258
+ x += 2
1259
+ x.value
1260
+ """),
1261
+ 21,
1262
+ ),
1263
+ ],
1264
+ )
1265
+ def test_evaluate_augassign(self, code, expected_result):
1266
+ state = {}
1267
+ result, _ = evaluate_python_code(code, {}, state=state)
1268
+ assert result == expected_result
1269
+
1270
+ @pytest.mark.parametrize(
1271
+ "operator, expected_result",
1272
+ [
1273
+ ("+=", 7),
1274
+ ("-=", 3),
1275
+ ("*=", 10),
1276
+ ("/=", 2.5),
1277
+ ("//=", 2),
1278
+ ("%=", 1),
1279
+ ("**=", 25),
1280
+ ("&=", 0),
1281
+ ("|=", 7),
1282
+ ("^=", 7),
1283
+ (">>=", 1),
1284
+ ("<<=", 20),
1285
+ ],
1286
+ )
1287
+ def test_evaluate_augassign_number(self, operator, expected_result):
1288
+ code = dedent("""\
1289
+ x = 5
1290
+ x {operator} 2
1291
+ """).format(operator=operator)
1292
+ state = {}
1293
+ result, _ = evaluate_python_code(code, {}, state=state)
1294
+ assert result == expected_result
1295
+
1296
+ @pytest.mark.parametrize(
1297
+ "operator, expected_result",
1298
+ [
1299
+ ("+=", 7),
1300
+ ("-=", 3),
1301
+ ("*=", 10),
1302
+ ("/=", 2.5),
1303
+ ("//=", 2),
1304
+ ("%=", 1),
1305
+ ("**=", 25),
1306
+ ("&=", 0),
1307
+ ("|=", 7),
1308
+ ("^=", 7),
1309
+ (">>=", 1),
1310
+ ("<<=", 20),
1311
+ ],
1312
+ )
1313
+ def test_evaluate_augassign_custom(self, operator, expected_result):
1314
+ operator_names = {
1315
+ "+=": "iadd",
1316
+ "-=": "isub",
1317
+ "*=": "imul",
1318
+ "/=": "itruediv",
1319
+ "//=": "ifloordiv",
1320
+ "%=": "imod",
1321
+ "**=": "ipow",
1322
+ "&=": "iand",
1323
+ "|=": "ior",
1324
+ "^=": "ixor",
1325
+ ">>=": "irshift",
1326
+ "<<=": "ilshift",
1327
+ }
1328
+ code = dedent("""\
1329
+ class Custom:
1330
+ def __init__(self, value):
1331
+ self.value = value
1332
+ def __{operator_name}__(self, other):
1333
+ self.value {operator} other
1334
+ return self
1335
+
1336
+ x = Custom(5)
1337
+ x {operator} 2
1338
+ x.value
1339
+ """).format(operator=operator, operator_name=operator_names[operator])
1340
+ state = {}
1341
+ result, _ = evaluate_python_code(code, {}, state=state)
1342
+ assert result == expected_result
1343
+
1344
+ @pytest.mark.parametrize(
1345
+ "code, expected_error_message",
1346
+ [
1347
+ (
1348
+ dedent("""\
1349
+ x = 5
1350
+ del x
1351
+ x
1352
+ """),
1353
+ "The variable `x` is not defined",
1354
+ ),
1355
+ (
1356
+ dedent("""\
1357
+ x = [1, 2, 3]
1358
+ del x[2]
1359
+ x[2]
1360
+ """),
1361
+ "IndexError: list index out of range",
1362
+ ),
1363
+ (
1364
+ dedent("""\
1365
+ x = {"key": "value"}
1366
+ del x["key"]
1367
+ x["key"]
1368
+ """),
1369
+ "Could not index {} with 'key'",
1370
+ ),
1371
+ (
1372
+ dedent("""\
1373
+ del x
1374
+ """),
1375
+ "Cannot delete name 'x': name is not defined",
1376
+ ),
1377
+ ],
1378
+ )
1379
+ def test_evaluate_delete(self, code, expected_error_message):
1380
+ state = {}
1381
+ with pytest.raises(InterpreterError) as exception_info:
1382
+ evaluate_python_code(code, {}, state=state)
1383
+ assert expected_error_message in str(exception_info.value)
1384
+
1385
+ def test_non_standard_comparisons(self):
1386
+ code = dedent("""\
1387
+ class NonStdEqualsResult:
1388
+ def __init__(self, left:object, right:object):
1389
+ self._left = left
1390
+ self._right = right
1391
+ def __str__(self) -> str:
1392
+ return f'{self._left} == {self._right}'
1393
+
1394
+ class NonStdComparisonClass:
1395
+ def __init__(self, value: str ):
1396
+ self._value = value
1397
+ def __str__(self):
1398
+ return self._value
1399
+ def __eq__(self, other):
1400
+ return NonStdEqualsResult(self, other)
1401
+ a = NonStdComparisonClass("a")
1402
+ b = NonStdComparisonClass("b")
1403
+ result = a == b
1404
+ """)
1405
+ result, _ = evaluate_python_code(code, state={})
1406
+ assert not isinstance(result, bool)
1407
+ assert str(result) == "a == b"
1408
+
1409
+
1410
+ class TestEvaluateBoolop:
1411
+ @pytest.mark.parametrize("a", [1, 0])
1412
+ @pytest.mark.parametrize("b", [2, 0])
1413
+ @pytest.mark.parametrize("c", [3, 0])
1414
+ def test_evaluate_boolop_and(self, a, b, c):
1415
+ boolop_ast = ast.parse("a and b and c").body[0].value
1416
+ state = {"a": a, "b": b, "c": c}
1417
+ result = evaluate_boolop(boolop_ast, state, {}, {}, [])
1418
+ assert result == (a and b and c)
1419
+
1420
+ @pytest.mark.parametrize("a", [1, 0])
1421
+ @pytest.mark.parametrize("b", [2, 0])
1422
+ @pytest.mark.parametrize("c", [3, 0])
1423
+ def test_evaluate_boolop_or(self, a, b, c):
1424
+ boolop_ast = ast.parse("a or b or c").body[0].value
1425
+ state = {"a": a, "b": b, "c": c}
1426
+ result = evaluate_boolop(boolop_ast, state, {}, {}, [])
1427
+ assert result == (a or b or c)
1428
+
1429
+
1430
+ class TestEvaluateDelete:
1431
+ @pytest.mark.parametrize(
1432
+ "code, state, expectation",
1433
+ [
1434
+ ("del x", {"x": 1}, {}),
1435
+ ("del x[1]", {"x": [1, 2, 3]}, {"x": [1, 3]}),
1436
+ ("del x['key']", {"x": {"key": "value"}}, {"x": {}}),
1437
+ ("del x", {}, InterpreterError("Cannot delete name 'x': name is not defined")),
1438
+ ],
1439
+ )
1440
+ def test_evaluate_delete(self, code, state, expectation):
1441
+ delete_node = ast.parse(code).body[0]
1442
+ if isinstance(expectation, Exception):
1443
+ with pytest.raises(type(expectation)) as exception_info:
1444
+ evaluate_delete(delete_node, state, {}, {}, [])
1445
+ assert str(expectation) in str(exception_info.value)
1446
+ else:
1447
+ evaluate_delete(delete_node, state, {}, {}, [])
1448
+ _ = state.pop("_operations_count", None)
1449
+ assert state == expectation
1450
+
1451
+
1452
+ class TestEvaluateCondition:
1453
+ @pytest.mark.parametrize(
1454
+ "condition, state, expected_result",
1455
+ [
1456
+ ("a == b", {"a": 1, "b": 1}, True),
1457
+ ("a == b", {"a": 1, "b": 2}, False),
1458
+ ("a != b", {"a": 1, "b": 1}, False),
1459
+ ("a != b", {"a": 1, "b": 2}, True),
1460
+ ("a < b", {"a": 1, "b": 1}, False),
1461
+ ("a < b", {"a": 1, "b": 2}, True),
1462
+ ("a < b", {"a": 2, "b": 1}, False),
1463
+ ("a <= b", {"a": 1, "b": 1}, True),
1464
+ ("a <= b", {"a": 1, "b": 2}, True),
1465
+ ("a <= b", {"a": 2, "b": 1}, False),
1466
+ ("a > b", {"a": 1, "b": 1}, False),
1467
+ ("a > b", {"a": 1, "b": 2}, False),
1468
+ ("a > b", {"a": 2, "b": 1}, True),
1469
+ ("a >= b", {"a": 1, "b": 1}, True),
1470
+ ("a >= b", {"a": 1, "b": 2}, False),
1471
+ ("a >= b", {"a": 2, "b": 1}, True),
1472
+ ("a is b", {"a": 1, "b": 1}, True),
1473
+ ("a is b", {"a": 1, "b": 2}, False),
1474
+ ("a is not b", {"a": 1, "b": 1}, False),
1475
+ ("a is not b", {"a": 1, "b": 2}, True),
1476
+ ("a in b", {"a": 1, "b": [1, 2, 3]}, True),
1477
+ ("a in b", {"a": 4, "b": [1, 2, 3]}, False),
1478
+ ("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
1479
+ ("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
1480
+ # Chained conditions:
1481
+ ("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
1482
+ ("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
1483
+ ("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
1484
+ ("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
1485
+ ],
1486
+ )
1487
+ def test_evaluate_condition(self, condition, state, expected_result):
1488
+ condition_ast = ast.parse(condition, mode="eval").body
1489
+ result = evaluate_condition(condition_ast, state, {}, {}, [])
1490
+ assert result == expected_result
1491
+
1492
+ @pytest.mark.parametrize(
1493
+ "condition, state, expected_result",
1494
+ [
1495
+ ("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
1496
+ ("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
1497
+ ("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
1498
+ ("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
1499
+ ("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
1500
+ ("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
1501
+ (
1502
+ "a == b",
1503
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
1504
+ pd.DataFrame({"x": [True, True], "y": [True, False]}),
1505
+ ),
1506
+ (
1507
+ "a != b",
1508
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
1509
+ pd.DataFrame({"x": [False, False], "y": [False, True]}),
1510
+ ),
1511
+ (
1512
+ "a < b",
1513
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
1514
+ pd.DataFrame({"x": [True, False], "y": [False, False]}),
1515
+ ),
1516
+ (
1517
+ "a <= b",
1518
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
1519
+ pd.DataFrame({"x": [True, True], "y": [False, False]}),
1520
+ ),
1521
+ (
1522
+ "a > b",
1523
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
1524
+ pd.DataFrame({"x": [False, False], "y": [True, True]}),
1525
+ ),
1526
+ (
1527
+ "a >= b",
1528
+ {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
1529
+ pd.DataFrame({"x": [False, True], "y": [True, True]}),
1530
+ ),
1531
+ ],
1532
+ )
1533
+ def test_evaluate_condition_with_pandas(self, condition, state, expected_result):
1534
+ condition_ast = ast.parse(condition, mode="eval").body
1535
+ result = evaluate_condition(condition_ast, state, {}, {}, [])
1536
+ if isinstance(result, pd.Series):
1537
+ pd.testing.assert_series_equal(result, expected_result)
1538
+ else:
1539
+ pd.testing.assert_frame_equal(result, expected_result)
1540
+
1541
+ @pytest.mark.parametrize(
1542
+ "condition, state, expected_exception",
1543
+ [
1544
+ # Chained conditions:
1545
+ (
1546
+ "a == b == c",
1547
+ {
1548
+ "a": pd.Series([1, 2, 3]),
1549
+ "b": pd.Series([2, 2, 2]),
1550
+ "c": pd.Series([3, 3, 3]),
1551
+ },
1552
+ ValueError(
1553
+ "The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
1554
+ ),
1555
+ ),
1556
+ (
1557
+ "a == b == c",
1558
+ {
1559
+ "a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
1560
+ "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
1561
+ "c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
1562
+ },
1563
+ ValueError(
1564
+ "The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
1565
+ ),
1566
+ ),
1567
+ ],
1568
+ )
1569
+ def test_evaluate_condition_with_pandas_exceptions(self, condition, state, expected_exception):
1570
+ condition_ast = ast.parse(condition, mode="eval").body
1571
+ with pytest.raises(type(expected_exception)) as exception_info:
1572
+ _ = evaluate_condition(condition_ast, state, {}, {}, [])
1573
+ assert str(expected_exception) in str(exception_info.value)
1574
+
1575
+
1576
+ class TestEvaluateSubscript:
1577
+ @pytest.mark.parametrize(
1578
+ "subscript, state, expected_result",
1579
+ [
1580
+ ("dct[1]", {"dct": {1: 11, 2: 22}}, 11),
1581
+ ("dct[2]", {"dct": {1: "a", 2: "b"}}, "b"),
1582
+ ("dct['b']", {"dct": {"a": 1, "b": 2}}, 2),
1583
+ ("dct['a']", {"dct": {"a": "aa", "b": "bb"}}, "aa"),
1584
+ ("dct[1, 2]", {"dct": {(1, 2): 3}}, 3), # tuple-index
1585
+ ("dct['a']['b']", {"dct": {"a": {"b": 1}}}, 1), # nested
1586
+ ("lst[0]", {"lst": [1, 2, 3]}, 1),
1587
+ ("lst[-1]", {"lst": [1, 2, 3]}, 3),
1588
+ ("lst[1:3]", {"lst": [1, 2, 3, 4]}, [2, 3]),
1589
+ ("lst[:]", {"lst": [1, 2, 3]}, [1, 2, 3]),
1590
+ ("lst[::2]", {"lst": [1, 2, 3, 4]}, [1, 3]),
1591
+ ("lst[::-1]", {"lst": [1, 2, 3]}, [3, 2, 1]),
1592
+ ("tup[1]", {"tup": (1, 2, 3)}, 2),
1593
+ ("tup[-1]", {"tup": (1, 2, 3)}, 3),
1594
+ ("tup[1:3]", {"tup": (1, 2, 3, 4)}, (2, 3)),
1595
+ ("tup[:]", {"tup": (1, 2, 3)}, (1, 2, 3)),
1596
+ ("tup[::2]", {"tup": (1, 2, 3, 4)}, (1, 3)),
1597
+ ("tup[::-1]", {"tup": (1, 2, 3)}, (3, 2, 1)),
1598
+ ("st[1]", {"str": "abc"}, "b"),
1599
+ ("st[-1]", {"str": "abc"}, "c"),
1600
+ ("st[1:3]", {"str": "abcd"}, "bc"),
1601
+ ("st[:]", {"str": "abc"}, "abc"),
1602
+ ("st[::2]", {"str": "abcd"}, "ac"),
1603
+ ("st[::-1]", {"str": "abc"}, "cba"),
1604
+ ("arr[1]", {"arr": np.array([1, 2, 3])}, 2),
1605
+ ("arr[1:3]", {"arr": np.array([1, 2, 3, 4])}, np.array([2, 3])),
1606
+ ("arr[:]", {"arr": np.array([1, 2, 3])}, np.array([1, 2, 3])),
1607
+ ("arr[::2]", {"arr": np.array([1, 2, 3, 4])}, np.array([1, 3])),
1608
+ ("arr[::-1]", {"arr": np.array([1, 2, 3])}, np.array([3, 2, 1])),
1609
+ ("arr[1, 2]", {"arr": np.array([[1, 2, 3], [4, 5, 6]])}, 6),
1610
+ ("ser[1]", {"ser": pd.Series([1, 2, 3])}, 2),
1611
+ ("ser.loc[1]", {"ser": pd.Series([1, 2, 3])}, 2),
1612
+ ("ser.loc[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 3),
1613
+ ("ser.iloc[1]", {"ser": pd.Series([1, 2, 3])}, 2),
1614
+ ("ser.iloc[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 2),
1615
+ ("ser.at[1]", {"ser": pd.Series([1, 2, 3])}, 2),
1616
+ ("ser.at[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 3),
1617
+ ("ser.iat[1]", {"ser": pd.Series([1, 2, 3])}, 2),
1618
+ ("ser.iat[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 2),
1619
+ ("ser[1:3]", {"ser": pd.Series([1, 2, 3, 4])}, pd.Series([2, 3], index=[1, 2])),
1620
+ ("ser[:]", {"ser": pd.Series([1, 2, 3])}, pd.Series([1, 2, 3])),
1621
+ ("ser[::2]", {"ser": pd.Series([1, 2, 3, 4])}, pd.Series([1, 3], index=[0, 2])),
1622
+ ("ser[::-1]", {"ser": pd.Series([1, 2, 3])}, pd.Series([3, 2, 1], index=[2, 1, 0])),
1623
+ ("df['y'][1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
1624
+ ("df['y'][5]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
1625
+ ("df.loc[1, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
1626
+ ("df.loc[5, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
1627
+ ("df.iloc[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
1628
+ ("df.iloc[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 4),
1629
+ ("df.at[1, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
1630
+ ("df.at[5, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
1631
+ ("df.iat[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
1632
+ ("df.iat[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 4),
1633
+ ],
1634
+ )
1635
+ def test_evaluate_subscript(self, subscript, state, expected_result):
1636
+ subscript_ast = ast.parse(subscript).body[0].value
1637
+ result = evaluate_subscript(subscript_ast, state, {}, {}, [])
1638
+ try:
1639
+ assert result == expected_result
1640
+ except ValueError:
1641
+ assert (result == expected_result).all()
1642
+
1643
+ @pytest.mark.parametrize(
1644
+ "subscript, state, expected_error_message",
1645
+ [
1646
+ ("dct['a']", {"dct": {}}, "KeyError: 'a'"),
1647
+ ("dct[0]", {"dct": {}}, "KeyError: 0"),
1648
+ ("dct['c']", {"dct": {"a": 1, "b": 2}}, "KeyError: 'c'"),
1649
+ ("dct[1, 2, 3]", {"dct": {(1, 2): 3}}, "KeyError: (1, 2, 3)"),
1650
+ ("lst[0]", {"lst": []}, "IndexError: list index out of range"),
1651
+ ("lst[3]", {"lst": [1, 2, 3]}, "IndexError: list index out of range"),
1652
+ ("lst[-4]", {"lst": [1, 2, 3]}, "IndexError: list index out of range"),
1653
+ ("value[0]", {"value": 1}, "TypeError: 'int' object is not subscriptable"),
1654
+ ],
1655
+ )
1656
+ def test_evaluate_subscript_error(self, subscript, state, expected_error_message):
1657
+ subscript_ast = ast.parse(subscript).body[0].value
1658
+ with pytest.raises(InterpreterError, match="Could not index") as exception_info:
1659
+ _ = evaluate_subscript(subscript_ast, state, {}, {}, [])
1660
+ assert expected_error_message in str(exception_info.value)
1661
+
1662
+ @pytest.mark.parametrize(
1663
+ "subscriptable_class, expectation",
1664
+ [
1665
+ (True, 20),
1666
+ (False, InterpreterError("TypeError: 'Custom' object is not subscriptable")),
1667
+ ],
1668
+ )
1669
+ def test_evaluate_subscript_with_custom_class(self, subscriptable_class, expectation):
1670
+ if subscriptable_class:
1671
+
1672
+ class Custom:
1673
+ def __getitem__(self, key):
1674
+ return key * 10
1675
+ else:
1676
+
1677
+ class Custom:
1678
+ pass
1679
+
1680
+ state = {"obj": Custom()}
1681
+ subscript = "obj[2]"
1682
+ subscript_ast = ast.parse(subscript).body[0].value
1683
+ if isinstance(expectation, Exception):
1684
+ with pytest.raises(type(expectation), match="Could not index") as exception_info:
1685
+ evaluate_subscript(subscript_ast, state, {}, {}, [])
1686
+ assert "TypeError: 'Custom' object is not subscriptable" in str(exception_info.value)
1687
+ else:
1688
+ result = evaluate_subscript(subscript_ast, state, {}, {}, [])
1689
+ assert result == expectation
1690
+
1691
+
1692
+ def test_get_safe_module_handle_lazy_imports():
1693
+ class FakeModule(types.ModuleType):
1694
+ def __init__(self, name):
1695
+ super().__init__(name)
1696
+ self.non_lazy_attribute = "ok"
1697
+
1698
+ def __getattr__(self, name):
1699
+ if name == "lazy_attribute":
1700
+ raise ImportError("lazy import failure")
1701
+ return super().__getattr__(name)
1702
+
1703
+ def __dir__(self):
1704
+ return super().__dir__() + ["lazy_attribute"]
1705
+
1706
+ fake_module = FakeModule("fake_module")
1707
+ safe_module = get_safe_module(fake_module, authorized_imports=set())
1708
+ assert not hasattr(safe_module, "lazy_attribute")
1709
+ assert getattr(safe_module, "non_lazy_attribute") == "ok"
1710
+
1711
+
1712
+ class TestPrintContainer:
1713
+ def test_initial_value(self):
1714
+ pc = PrintContainer()
1715
+ assert pc.value == ""
1716
+
1717
+ def test_append(self):
1718
+ pc = PrintContainer()
1719
+ pc.append("Hello")
1720
+ assert pc.value == "Hello"
1721
+
1722
+ def test_iadd(self):
1723
+ pc = PrintContainer()
1724
+ pc += "World"
1725
+ assert pc.value == "World"
1726
+
1727
+ def test_str(self):
1728
+ pc = PrintContainer()
1729
+ pc.append("Hello")
1730
+ assert str(pc) == "Hello"
1731
+
1732
+ def test_repr(self):
1733
+ pc = PrintContainer()
1734
+ pc.append("Hello")
1735
+ assert repr(pc) == "PrintContainer(Hello)"
1736
+
1737
+ def test_len(self):
1738
+ pc = PrintContainer()
1739
+ pc.append("Hello")
1740
+ assert len(pc) == 5
1741
+
1742
+
1743
+ def test_fix_final_answer_code():
1744
+ test_cases = [
1745
+ (
1746
+ "final_answer = 3.21\nfinal_answer(final_answer)",
1747
+ "final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
1748
+ ),
1749
+ (
1750
+ "x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
1751
+ "x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
1752
+ ),
1753
+ (
1754
+ "def func():\n final_answer = 42\n return final_answer(final_answer)",
1755
+ "def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
1756
+ ),
1757
+ (
1758
+ "final_answer(5) # Should not change function calls",
1759
+ "final_answer(5) # Should not change function calls",
1760
+ ),
1761
+ (
1762
+ "obj.final_answer = 5 # Should not change object attributes",
1763
+ "obj.final_answer = 5 # Should not change object attributes",
1764
+ ),
1765
+ (
1766
+ "final_answer=3.21;final_answer(final_answer)",
1767
+ "final_answer_variable=3.21;final_answer(final_answer_variable)",
1768
+ ),
1769
+ ]
1770
+
1771
+ for i, (input_code, expected) in enumerate(test_cases, 1):
1772
+ result = fix_final_answer_code(input_code)
1773
+ assert result == expected, f"""
1774
+ Test case {i} failed:
1775
+ Input: {input_code}
1776
+ Expected: {expected}
1777
+ Got: {result}
1778
+ """
1779
+
1780
+
1781
+ @pytest.mark.parametrize(
1782
+ "module,authorized_imports,expected",
1783
+ [
1784
+ ("os", ["other", "*"], True),
1785
+ ("AnyModule", ["*"], True),
1786
+ ("os", ["os"], True),
1787
+ ("AnyModule", ["AnyModule"], True),
1788
+ ("Module.os", ["Module"], False),
1789
+ ("Module.os", ["Module", "Module.os"], True),
1790
+ ("os.path", ["os.*"], True),
1791
+ ("os", ["os.path"], True),
1792
+ ],
1793
+ )
1794
+ def test_check_import_authorized(module: str, authorized_imports: list[str], expected: bool):
1795
+ assert check_import_authorized(module, authorized_imports) == expected
1796
+
1797
+
1798
+ class TestLocalPythonExecutor:
1799
+ def test_state_name(self):
1800
+ executor = LocalPythonExecutor(additional_authorized_imports=[])
1801
+ assert executor.state.get("__name__") == "__main__"
1802
+
1803
+ @pytest.mark.parametrize(
1804
+ "code",
1805
+ [
1806
+ "d = {'func': lambda x: x + 10}; func = d['func']; func(1)",
1807
+ "d = {'func': lambda x: x + 10}; d['func'](1)",
1808
+ ],
1809
+ )
1810
+ def test_call_from_dict(self, code):
1811
+ executor = LocalPythonExecutor([])
1812
+ result, _, _ = executor(code)
1813
+ assert result == 11
1814
+
1815
+ @pytest.mark.parametrize(
1816
+ "code",
1817
+ [
1818
+ "a = b = 1; a",
1819
+ "a = b = 1; b",
1820
+ "a, b = c, d = 1, 1; a",
1821
+ "a, b = c, d = 1, 1; b",
1822
+ "a, b = c, d = 1, 1; c",
1823
+ "a, b = c, d = {1, 2}; a",
1824
+ "a, b = c, d = {1, 2}; c",
1825
+ "a, b = c, d = {1: 10, 2: 20}; a",
1826
+ "a, b = c, d = {1: 10, 2: 20}; c",
1827
+ "a = b = (lambda: 1)(); b",
1828
+ "a = b = (lambda: 1)(); lambda x: 10; b",
1829
+ "a = b = (lambda x: lambda y: x + y)(0)(1); b",
1830
+ dedent("""
1831
+ def foo():
1832
+ return 1;
1833
+ a = b = foo(); b"""),
1834
+ dedent("""
1835
+ def foo(*args, **kwargs):
1836
+ return sum(args)
1837
+ a = b = foo(1,-1,1); b"""),
1838
+ "a, b = 1, 2; a, b = b, a; b",
1839
+ ],
1840
+ )
1841
+ def test_chained_assignments(self, code):
1842
+ executor = LocalPythonExecutor([])
1843
+ executor.send_tools({})
1844
+ result, _, _ = executor(code)
1845
+ assert result == 1
1846
+
1847
+ def test_evaluate_assign_error(self):
1848
+ code = "a, b = 1, 2, 3; a"
1849
+ executor = LocalPythonExecutor([])
1850
+ with pytest.raises(InterpreterError, match=".*Cannot unpack tuple of wrong size"):
1851
+ executor(code)
1852
+
1853
+ def test_function_def_recovers_source_code(self):
1854
+ executor = LocalPythonExecutor([])
1855
+ executor.send_tools({"final_answer": FinalAnswerTool()})
1856
+ res, _, _ = executor(
1857
+ dedent(
1858
+ """
1859
+ def target_function():
1860
+ return "Hello world"
1861
+
1862
+ final_answer(target_function)
1863
+ """
1864
+ )
1865
+ )
1866
+ assert res.__name__ == "target_function"
1867
+ assert res.__source__ == "def target_function():\n return 'Hello world'"
1868
+
1869
+ @pytest.mark.parametrize(
1870
+ "code, expected_result",
1871
+ [("isinstance(5, int)", True), ("isinstance('foo', str)", True), ("isinstance(5, str)", False)],
1872
+ )
1873
+ def test_isinstance_builtin_type(self, code, expected_result):
1874
+ executor = LocalPythonExecutor([])
1875
+ executor.send_tools({})
1876
+ result, _, _ = executor(code)
1877
+ assert result is expected_result
1878
+
1879
+
1880
+ class TestLocalPythonExecutorSecurity:
1881
+ @pytest.mark.parametrize(
1882
+ "additional_authorized_imports, expected_error",
1883
+ [([], InterpreterError("Import of os is not allowed")), (["os"], None)],
1884
+ )
1885
+ def test_vulnerability_import(self, additional_authorized_imports, expected_error):
1886
+ executor = LocalPythonExecutor(additional_authorized_imports)
1887
+ with (
1888
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1889
+ if isinstance(expected_error, Exception)
1890
+ else does_not_raise()
1891
+ ):
1892
+ executor("import os")
1893
+
1894
+ @pytest.mark.parametrize(
1895
+ "additional_authorized_imports, expected_error",
1896
+ [([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)],
1897
+ )
1898
+ def test_vulnerability_builtins(self, additional_authorized_imports, expected_error):
1899
+ executor = LocalPythonExecutor(additional_authorized_imports)
1900
+ with (
1901
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1902
+ if isinstance(expected_error, Exception)
1903
+ else does_not_raise()
1904
+ ):
1905
+ executor("import builtins")
1906
+
1907
+ @pytest.mark.parametrize(
1908
+ "additional_authorized_imports, expected_error",
1909
+ [([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)],
1910
+ )
1911
+ def test_vulnerability_builtins_safe_functions(self, additional_authorized_imports, expected_error):
1912
+ executor = LocalPythonExecutor(additional_authorized_imports)
1913
+ with (
1914
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1915
+ if isinstance(expected_error, Exception)
1916
+ else does_not_raise()
1917
+ ):
1918
+ executor("import builtins; builtins.print(1)")
1919
+
1920
+ @pytest.mark.parametrize(
1921
+ "additional_authorized_imports, additional_tools, expected_error",
1922
+ [
1923
+ ([], [], InterpreterError("Import of builtins is not allowed")),
1924
+ (["builtins"], [], InterpreterError("Forbidden access to function: exec")),
1925
+ (["builtins"], ["exec"], None),
1926
+ ],
1927
+ )
1928
+ def test_vulnerability_builtins_dangerous_functions(
1929
+ self, additional_authorized_imports, additional_tools, expected_error
1930
+ ):
1931
+ executor = LocalPythonExecutor(additional_authorized_imports)
1932
+ if additional_tools:
1933
+ from builtins import exec
1934
+
1935
+ executor.send_tools({"exec": exec})
1936
+ with (
1937
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1938
+ if isinstance(expected_error, Exception)
1939
+ else does_not_raise()
1940
+ ):
1941
+ executor("import builtins; builtins.exec")
1942
+
1943
+ @pytest.mark.parametrize(
1944
+ "additional_authorized_imports, additional_tools, expected_error",
1945
+ [
1946
+ ([], [], InterpreterError("Import of os is not allowed")),
1947
+ (["os"], [], InterpreterError("Forbidden access to function: popen")),
1948
+ (["os"], ["popen"], None),
1949
+ ],
1950
+ )
1951
+ def test_vulnerability_dangerous_functions(self, additional_authorized_imports, additional_tools, expected_error):
1952
+ executor = LocalPythonExecutor(additional_authorized_imports)
1953
+ if additional_tools:
1954
+ from os import popen
1955
+
1956
+ executor.send_tools({"popen": popen})
1957
+ with (
1958
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1959
+ if isinstance(expected_error, Exception)
1960
+ else does_not_raise()
1961
+ ):
1962
+ executor("import os; os.popen")
1963
+
1964
+ @pytest.mark.parametrize("dangerous_function", DANGEROUS_FUNCTIONS)
1965
+ def test_vulnerability_for_all_dangerous_functions(self, dangerous_function):
1966
+ dangerous_module_name, dangerous_function_name = dangerous_function.rsplit(".", 1)
1967
+ # Skip test if module is not installed: posix module is not installed on Windows
1968
+ pytest.importorskip(dangerous_module_name)
1969
+ executor = LocalPythonExecutor([dangerous_module_name])
1970
+ if "__" in dangerous_function_name:
1971
+ error_match = f".*Forbidden access to dunder attribute: {dangerous_function_name}"
1972
+ else:
1973
+ error_match = f".*Forbidden access to function: {dangerous_function_name}.*"
1974
+ with pytest.raises(InterpreterError, match=error_match):
1975
+ executor(f"import {dangerous_module_name}; {dangerous_function}")
1976
+
1977
+ @pytest.mark.parametrize(
1978
+ "additional_authorized_imports, expected_error",
1979
+ [
1980
+ ([], InterpreterError("Import of sys is not allowed")),
1981
+ (["sys"], InterpreterError("Forbidden access to module: os")),
1982
+ (["sys", "os"], None),
1983
+ ],
1984
+ )
1985
+ def test_vulnerability_via_sys(self, additional_authorized_imports, expected_error):
1986
+ executor = LocalPythonExecutor(additional_authorized_imports)
1987
+ with (
1988
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
1989
+ if isinstance(expected_error, Exception)
1990
+ else does_not_raise()
1991
+ ):
1992
+ executor(
1993
+ dedent(
1994
+ """
1995
+ import sys
1996
+ sys.modules["os"].system(":")
1997
+ """
1998
+ )
1999
+ )
2000
+
2001
+ @pytest.mark.parametrize("dangerous_module", DANGEROUS_MODULES)
2002
+ def test_vulnerability_via_sys_for_all_dangerous_modules(self, dangerous_module):
2003
+ import sys
2004
+
2005
+ if dangerous_module not in sys.modules or dangerous_module == "sys":
2006
+ pytest.skip("module not present in sys.modules")
2007
+ executor = LocalPythonExecutor(["sys"])
2008
+ with pytest.raises(InterpreterError) as exception_info:
2009
+ executor(
2010
+ dedent(
2011
+ f"""
2012
+ import sys
2013
+ sys.modules["{dangerous_module}"]
2014
+ """
2015
+ )
2016
+ )
2017
+ assert f"Forbidden access to module: {dangerous_module}" in str(exception_info.value)
2018
+
2019
+ @pytest.mark.parametrize(
2020
+ "additional_authorized_imports, expected_error",
2021
+ [(["importlib"], InterpreterError("Forbidden access to module: os")), (["importlib", "os"], None)],
2022
+ )
2023
+ def test_vulnerability_via_importlib(self, additional_authorized_imports, expected_error):
2024
+ executor = LocalPythonExecutor(additional_authorized_imports)
2025
+ with (
2026
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
2027
+ if isinstance(expected_error, Exception)
2028
+ else does_not_raise()
2029
+ ):
2030
+ executor(
2031
+ dedent(
2032
+ """
2033
+ import importlib
2034
+ importlib.import_module("os").system(":")
2035
+ """
2036
+ )
2037
+ )
2038
+
2039
+ @pytest.mark.parametrize(
2040
+ "code, additional_authorized_imports, expected_error",
2041
+ [
2042
+ # os submodule
2043
+ (
2044
+ "import queue; queue.threading._os.system(':')",
2045
+ [],
2046
+ InterpreterError("Forbidden access to module: threading"),
2047
+ ),
2048
+ (
2049
+ "import queue; queue.threading._os.system(':')",
2050
+ ["threading"],
2051
+ InterpreterError("Forbidden access to module: os"),
2052
+ ),
2053
+ ("import random; random._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
2054
+ (
2055
+ "import random; random.__dict__['_os'].system(':')",
2056
+ [],
2057
+ InterpreterError("Forbidden access to dunder attribute: __dict__"),
2058
+ ),
2059
+ (
2060
+ "import doctest; doctest.inspect.os.system(':')",
2061
+ ["doctest"],
2062
+ InterpreterError("Forbidden access to module: inspect"),
2063
+ ),
2064
+ (
2065
+ "import doctest; doctest.inspect.os.system(':')",
2066
+ ["doctest", "inspect"],
2067
+ InterpreterError("Forbidden access to module: os"),
2068
+ ),
2069
+ # subprocess submodule
2070
+ (
2071
+ "import asyncio; asyncio.base_events.events.subprocess",
2072
+ ["asyncio"],
2073
+ InterpreterError("Forbidden access to module: asyncio.base_events"),
2074
+ ),
2075
+ (
2076
+ "import asyncio; asyncio.base_events.events.subprocess",
2077
+ ["asyncio", "asyncio.base_events"],
2078
+ InterpreterError("Forbidden access to module: asyncio.events"),
2079
+ ),
2080
+ (
2081
+ "import asyncio; asyncio.base_events.events.subprocess",
2082
+ ["asyncio", "asyncio.base_events", "asyncio.base_events.events"],
2083
+ InterpreterError("Forbidden access to module: asyncio.events"),
2084
+ ),
2085
+ # sys submodule
2086
+ (
2087
+ "import queue; queue.threading._sys.modules['os'].system(':')",
2088
+ [],
2089
+ InterpreterError("Forbidden access to module: threading"),
2090
+ ),
2091
+ (
2092
+ "import queue; queue.threading._sys.modules['os'].system(':')",
2093
+ ["threading"],
2094
+ InterpreterError("Forbidden access to module: sys"),
2095
+ ),
2096
+ ("import warnings; warnings.sys", ["warnings"], InterpreterError("Forbidden access to module: sys")),
2097
+ # Allowed
2098
+ ("import pandas; pandas.io", ["pandas", "pandas.io"], None),
2099
+ ],
2100
+ )
2101
+ def test_vulnerability_via_submodules(self, code, additional_authorized_imports, expected_error):
2102
+ executor = LocalPythonExecutor(additional_authorized_imports)
2103
+ with (
2104
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
2105
+ if isinstance(expected_error, Exception)
2106
+ else does_not_raise()
2107
+ ):
2108
+ executor(code)
2109
+
2110
+ @pytest.mark.parametrize(
2111
+ "code, additional_authorized_imports, expected_error",
2112
+ [
2113
+ # Using filter with functools.partial
2114
+ (
2115
+ dedent(
2116
+ """
2117
+ import functools
2118
+ import warnings
2119
+ list(filter(functools.partial(getattr, warnings), ["sys"]))
2120
+ """
2121
+ ),
2122
+ ["warnings", "functools"],
2123
+ InterpreterError("Forbidden access to module: sys"),
2124
+ ),
2125
+ # Using map
2126
+ (
2127
+ dedent(
2128
+ """
2129
+ import warnings
2130
+ list(map(getattr, [warnings], ["sys"]))
2131
+ """
2132
+ ),
2133
+ ["warnings"],
2134
+ InterpreterError("Forbidden access to module: sys"),
2135
+ ),
2136
+ # Using map with functools.partial
2137
+ (
2138
+ dedent(
2139
+ """
2140
+ import functools
2141
+ import warnings
2142
+ list(map(functools.partial(getattr, warnings), ["sys"]))
2143
+ """
2144
+ ),
2145
+ ["warnings", "functools"],
2146
+ InterpreterError("Forbidden access to module: sys"),
2147
+ ),
2148
+ ],
2149
+ )
2150
+ def test_vulnerability_via_submodules_through_indirect_attribute_access(
2151
+ self, code, additional_authorized_imports, expected_error
2152
+ ):
2153
+ # warnings.sys
2154
+ executor = LocalPythonExecutor(additional_authorized_imports)
2155
+ executor.send_tools({})
2156
+ with pytest.raises(type(expected_error), match=f".*{expected_error}"):
2157
+ executor(code)
2158
+
2159
+ @pytest.mark.parametrize(
2160
+ "additional_authorized_imports, additional_tools, expected_error",
2161
+ [
2162
+ ([], [], InterpreterError("Import of sys is not allowed")),
2163
+ (["sys"], [], InterpreterError("Forbidden access to module: builtins")),
2164
+ (
2165
+ ["sys", "builtins"],
2166
+ [],
2167
+ InterpreterError("Forbidden access to function: __import__"),
2168
+ ),
2169
+ (["sys", "builtins"], ["__import__"], InterpreterError("Forbidden access to module: os")),
2170
+ (["sys", "builtins", "os"], ["__import__"], None),
2171
+ ],
2172
+ )
2173
+ def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, additional_tools, expected_error):
2174
+ executor = LocalPythonExecutor(additional_authorized_imports)
2175
+ if additional_tools:
2176
+ from builtins import __import__
2177
+
2178
+ executor.send_tools({"__import__": __import__})
2179
+ with (
2180
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
2181
+ if isinstance(expected_error, Exception)
2182
+ else does_not_raise()
2183
+ ):
2184
+ executor(
2185
+ dedent(
2186
+ """
2187
+ import sys
2188
+ builtins = sys._getframe().f_builtins
2189
+ builtins_import = builtins["__import__"]
2190
+ os_module = builtins_import("os")
2191
+ os_module.system(":")
2192
+ """
2193
+ )
2194
+ )
2195
+
2196
+ @pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
2197
+ @pytest.mark.parametrize(
2198
+ "additional_authorized_imports, additional_tools, expected_error",
2199
+ [
2200
+ ([], [], InterpreterError("Forbidden access to dunder attribute: __traceback__")),
2201
+ (
2202
+ ["builtins", "os"],
2203
+ ["__import__"],
2204
+ InterpreterError("Forbidden access to dunder attribute: __traceback__"),
2205
+ ),
2206
+ ],
2207
+ )
2208
+ def test_vulnerability_builtins_via_traceback(
2209
+ self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
2210
+ ):
2211
+ if patch_builtin_import_module:
2212
+ monkeypatch.setattr("builtins.__import__.__module__", None) # inspect.getmodule(func) = None
2213
+ executor = LocalPythonExecutor(additional_authorized_imports)
2214
+ if additional_tools:
2215
+ from builtins import __import__
2216
+
2217
+ executor.send_tools({"__import__": __import__})
2218
+ with (
2219
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
2220
+ if isinstance(expected_error, Exception)
2221
+ else does_not_raise()
2222
+ ):
2223
+ executor(
2224
+ dedent(
2225
+ """
2226
+ try:
2227
+ 1 / 0
2228
+ except Exception as e:
2229
+ builtins = e.__traceback__.tb_frame.f_back.f_globals["__builtins__"]
2230
+ builtins_import = builtins["__import__"]
2231
+ os_module = builtins_import("os")
2232
+ os_module.system(":")
2233
+ """
2234
+ )
2235
+ )
2236
+
2237
+ @pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
2238
+ @pytest.mark.parametrize(
2239
+ "additional_authorized_imports, additional_tools, expected_error",
2240
+ [
2241
+ ([], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
2242
+ (["warnings"], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
2243
+ (
2244
+ ["warnings", "builtins"],
2245
+ [],
2246
+ InterpreterError("Forbidden access to dunder attribute: __base__"),
2247
+ ),
2248
+ (["warnings", "builtins", "os"], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
2249
+ (
2250
+ ["warnings", "builtins", "os"],
2251
+ ["__import__"],
2252
+ InterpreterError("Forbidden access to dunder attribute: __base__"),
2253
+ ),
2254
+ ],
2255
+ )
2256
+ def test_vulnerability_builtins_via_class_catch_warnings(
2257
+ self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
2258
+ ):
2259
+ if patch_builtin_import_module:
2260
+ monkeypatch.setattr("builtins.__import__.__module__", None) # inspect.getmodule(func) = None
2261
+ executor = LocalPythonExecutor(additional_authorized_imports)
2262
+ if additional_tools:
2263
+ from builtins import __import__
2264
+
2265
+ executor.send_tools({"__import__": __import__})
2266
+ if isinstance(expected_error, tuple): # different error depending on patch status
2267
+ expected_error = expected_error[patch_builtin_import_module]
2268
+ if isinstance(expected_error, Exception):
2269
+ expectation = pytest.raises(type(expected_error), match=f".*{expected_error}")
2270
+ elif expected_error is None:
2271
+ expectation = does_not_raise()
2272
+ with expectation:
2273
+ executor(
2274
+ dedent(
2275
+ """
2276
+ classes = {}.__class__.__base__.__subclasses__()
2277
+ for cls in classes:
2278
+ if cls.__name__ == "catch_warnings":
2279
+ break
2280
+ builtins = cls()._module.__builtins__
2281
+ builtins_import = builtins["__import__"]
2282
+ os_module = builtins_import('os')
2283
+ os_module.system(":")
2284
+ """
2285
+ )
2286
+ )
2287
+
2288
+ @pytest.mark.filterwarnings("ignore::DeprecationWarning")
2289
+ @pytest.mark.parametrize(
2290
+ "additional_authorized_imports, expected_error",
2291
+ [
2292
+ ([], InterpreterError("Forbidden access to dunder attribute: __base__")),
2293
+ (["os"], InterpreterError("Forbidden access to dunder attribute: __base__")),
2294
+ ],
2295
+ )
2296
+ def test_vulnerability_load_module_via_builtin_importer(self, additional_authorized_imports, expected_error):
2297
+ executor = LocalPythonExecutor(additional_authorized_imports)
2298
+ with (
2299
+ pytest.raises(type(expected_error), match=f".*{expected_error}")
2300
+ if isinstance(expected_error, Exception)
2301
+ else does_not_raise()
2302
+ ):
2303
+ executor(
2304
+ dedent(
2305
+ """
2306
+ classes = {}.__class__.__base__.__subclasses__()
2307
+ for cls in classes:
2308
+ if cls.__name__ == "BuiltinImporter":
2309
+ break
2310
+ os_module = cls().load_module("os")
2311
+ os_module.system(":")
2312
+ """
2313
+ )
2314
+ )
2315
+
2316
+ def test_vulnerability_class_via_subclasses(self):
2317
+ # Subclass: subprocess.Popen
2318
+ executor = LocalPythonExecutor([])
2319
+ code = dedent(
2320
+ """
2321
+ for cls in ().__class__.__base__.__subclasses__():
2322
+ if 'Popen' in cls.__class__.__repr__(cls):
2323
+ break
2324
+ cls(["sh", "-c", ":"]).wait()
2325
+ """
2326
+ )
2327
+ with pytest.raises(InterpreterError, match="Forbidden access to dunder attribute: __base__"):
2328
+ executor(code)
2329
+
2330
+ code = dedent(
2331
+ """
2332
+ [c for c in ().__class__.__base__.__subclasses__() if "Popen" in c.__class__.__repr__(c)][0](
2333
+ ["sh", "-c", ":"]
2334
+ ).wait()
2335
+ """
2336
+ )
2337
+ with pytest.raises(InterpreterError, match="Forbidden access to dunder attribute: __base__"):
2338
+ executor(code)
2339
+
2340
+ @pytest.mark.parametrize(
2341
+ "code, dunder_attribute",
2342
+ [("a = (); b = a.__class__", "__class__"), ("class A:\n attr=1\nx = A()\nx_dict = x.__dict__", "__dict__")],
2343
+ )
2344
+ def test_vulnerability_via_dunder_access(self, code, dunder_attribute):
2345
+ executor = LocalPythonExecutor([])
2346
+ with pytest.raises(InterpreterError, match=f"Forbidden access to dunder attribute: {dunder_attribute}"):
2347
+ executor(code)
2348
+
2349
+ def test_vulnerability_via_dunder_indirect_access(self):
2350
+ executor = LocalPythonExecutor([])
2351
+ code = "a = (); b = getattr(a, '__class__')"
2352
+ with pytest.raises(InterpreterError, match="Forbidden function evaluation: 'getattr'"):
2353
+ executor(code)
tests/test_mcp_client.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from textwrap import dedent
2
+
3
+ import pytest
4
+ from mcp import StdioServerParameters
5
+
6
+ from smolagents.mcp_client import MCPClient
7
+
8
+
9
+ @pytest.fixture
10
+ def echo_server_script():
11
+ return dedent(
12
+ '''
13
+ from mcp.server.fastmcp import FastMCP
14
+
15
+ mcp = FastMCP("Echo Server")
16
+
17
+ @mcp.tool()
18
+ def echo_tool(text: str) -> str:
19
+ """Echo the input text"""
20
+ return f"Echo: {text}"
21
+
22
+ mcp.run()
23
+ '''
24
+ )
25
+
26
+
27
+ def test_mcp_client_with_syntax(echo_server_script: str):
28
+ """Test the MCPClient with the context manager syntax."""
29
+ server_parameters = StdioServerParameters(command="python", args=["-c", echo_server_script])
30
+ with MCPClient(server_parameters) as tools:
31
+ assert len(tools) == 1
32
+ assert tools[0].name == "echo_tool"
33
+ assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
34
+
35
+
36
+ def test_mcp_client_try_finally_syntax(echo_server_script: str):
37
+ """Test the MCPClient with the try ... finally syntax."""
38
+ server_parameters = StdioServerParameters(command="python", args=["-c", echo_server_script])
39
+ mcp_client = MCPClient(server_parameters)
40
+ try:
41
+ tools = mcp_client.get_tools()
42
+ assert len(tools) == 1
43
+ assert tools[0].name == "echo_tool"
44
+ assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
45
+ finally:
46
+ mcp_client.disconnect()
47
+
48
+
49
+ def test_multiple_servers(echo_server_script: str):
50
+ """Test the MCPClient with multiple servers."""
51
+ server_parameters = [
52
+ StdioServerParameters(command="python", args=["-c", echo_server_script]),
53
+ StdioServerParameters(command="python", args=["-c", echo_server_script]),
54
+ ]
55
+ with MCPClient(server_parameters) as tools:
56
+ assert len(tools) == 2
57
+ assert tools[0].name == "echo_tool"
58
+ assert tools[1].name == "echo_tool"
59
+ assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
60
+ assert tools[1].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
tests/test_memory.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from PIL import Image
3
+
4
+ from smolagents.agents import ToolCall
5
+ from smolagents.memory import (
6
+ ActionStep,
7
+ AgentMemory,
8
+ ChatMessage,
9
+ MemoryStep,
10
+ MessageRole,
11
+ PlanningStep,
12
+ SystemPromptStep,
13
+ TaskStep,
14
+ )
15
+ from smolagents.monitoring import Timing, TokenUsage
16
+
17
+
18
+ class TestAgentMemory:
19
+ def test_initialization(self):
20
+ system_prompt = "This is a system prompt."
21
+ memory = AgentMemory(system_prompt=system_prompt)
22
+ assert memory.system_prompt.system_prompt == system_prompt
23
+ assert memory.steps == []
24
+
25
+ def test_return_all_code_actions(self):
26
+ memory = AgentMemory(system_prompt="This is a system prompt.")
27
+ memory.steps = [
28
+ ActionStep(step_number=1, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('Hello')"),
29
+ ActionStep(step_number=2, timing=Timing(start_time=0.0, end_time=1.0), code_action=None),
30
+ ActionStep(step_number=3, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('World')"),
31
+ ] # type: ignore
32
+ assert memory.return_full_code() == "print('Hello')\n\nprint('World')"
33
+
34
+
35
+ class TestMemoryStep:
36
+ def test_initialization(self):
37
+ step = MemoryStep()
38
+ assert isinstance(step, MemoryStep)
39
+
40
+ def test_dict(self):
41
+ step = MemoryStep()
42
+ assert step.dict() == {}
43
+
44
+ def test_to_messages(self):
45
+ step = MemoryStep()
46
+ with pytest.raises(NotImplementedError):
47
+ step.to_messages()
48
+
49
+
50
+ def test_action_step_dict():
51
+ action_step = ActionStep(
52
+ model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
53
+ tool_calls=[
54
+ ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
55
+ ],
56
+ timing=Timing(start_time=0.0, end_time=1.0),
57
+ step_number=1,
58
+ error=None,
59
+ model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
60
+ model_output="Hi",
61
+ observations="This is a nice observation",
62
+ observations_images=[Image.new("RGB", (100, 100))],
63
+ action_output="Output",
64
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20),
65
+ )
66
+ action_step_dict = action_step.dict()
67
+ # Check each key individually for better test failure messages
68
+ assert "model_input_messages" in action_step_dict
69
+ assert action_step_dict["model_input_messages"] == [ChatMessage(role=MessageRole.USER, content="Hello")]
70
+
71
+ assert "tool_calls" in action_step_dict
72
+ assert len(action_step_dict["tool_calls"]) == 1
73
+ assert action_step_dict["tool_calls"][0] == {
74
+ "id": "id",
75
+ "type": "function",
76
+ "function": {
77
+ "name": "get_weather",
78
+ "arguments": {"location": "Paris"},
79
+ },
80
+ }
81
+
82
+ assert "timing" in action_step_dict
83
+ assert action_step_dict["timing"] == {"start_time": 0.0, "end_time": 1.0, "duration": 1.0}
84
+
85
+ assert "token_usage" in action_step_dict
86
+ assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
87
+
88
+ assert "step_number" in action_step_dict
89
+ assert action_step_dict["step_number"] == 1
90
+
91
+ assert "error" in action_step_dict
92
+ assert action_step_dict["error"] is None
93
+
94
+ assert "model_output_message" in action_step_dict
95
+ assert action_step_dict["model_output_message"] == {
96
+ "role": "assistant",
97
+ "content": "Hi",
98
+ "tool_calls": None,
99
+ "raw": None,
100
+ "token_usage": None,
101
+ }
102
+
103
+ assert "model_output" in action_step_dict
104
+ assert action_step_dict["model_output"] == "Hi"
105
+
106
+ assert "observations" in action_step_dict
107
+ assert action_step_dict["observations"] == "This is a nice observation"
108
+
109
+ assert "observations_images" in action_step_dict
110
+
111
+ assert "action_output" in action_step_dict
112
+ assert action_step_dict["action_output"] == "Output"
113
+
114
+
115
+ def test_action_step_to_messages():
116
+ action_step = ActionStep(
117
+ model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
118
+ tool_calls=[
119
+ ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
120
+ ],
121
+ timing=Timing(start_time=0.0, end_time=1.0),
122
+ step_number=1,
123
+ error=None,
124
+ model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
125
+ model_output="Hi",
126
+ observations="This is a nice observation",
127
+ observations_images=[Image.new("RGB", (100, 100))],
128
+ action_output="Output",
129
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20),
130
+ )
131
+ messages = action_step.to_messages()
132
+ assert len(messages) == 4
133
+ for message in messages:
134
+ assert isinstance(message, ChatMessage)
135
+ assistant_message = messages[0]
136
+ assert assistant_message.role == MessageRole.ASSISTANT
137
+ assert len(assistant_message.content) == 1
138
+ assert assistant_message.content[0]["type"] == "text"
139
+ assert assistant_message.content[0]["text"] == "Hi"
140
+ message = messages[1]
141
+ assert message.role == MessageRole.TOOL_CALL
142
+
143
+ assert len(message.content) == 1
144
+ assert message.content[0]["type"] == "text"
145
+ assert "Calling tools:" in message.content[0]["text"]
146
+
147
+ image_message = messages[2]
148
+ assert image_message.content[0]["type"] == "image" # type: ignore
149
+
150
+ observation_message = messages[3]
151
+ assert observation_message.role == MessageRole.TOOL_RESPONSE
152
+ assert "Observation:\nThis is a nice observation" in observation_message.content[0]["text"]
153
+
154
+
155
+ def test_action_step_to_messages_no_tool_calls_with_observations():
156
+ action_step = ActionStep(
157
+ model_input_messages=None,
158
+ tool_calls=None,
159
+ timing=Timing(start_time=0.0, end_time=1.0),
160
+ step_number=1,
161
+ error=None,
162
+ model_output_message=None,
163
+ model_output=None,
164
+ observations="This is an observation.",
165
+ observations_images=None,
166
+ action_output=None,
167
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20),
168
+ )
169
+ messages = action_step.to_messages()
170
+ assert len(messages) == 1
171
+ observation_message = messages[0]
172
+ assert observation_message.role == MessageRole.TOOL_RESPONSE
173
+ assert "Observation:\nThis is an observation." in observation_message.content[0]["text"]
174
+
175
+
176
+ def test_planning_step_to_messages():
177
+ planning_step = PlanningStep(
178
+ model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
179
+ model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Plan"),
180
+ plan="This is a plan.",
181
+ timing=Timing(start_time=0.0, end_time=1.0),
182
+ )
183
+ messages = planning_step.to_messages(summary_mode=False)
184
+ assert len(messages) == 2
185
+ for message in messages:
186
+ assert isinstance(message, ChatMessage)
187
+ assert isinstance(message.content, list)
188
+ assert len(message.content) == 1
189
+ for content in message.content:
190
+ assert isinstance(content, dict)
191
+ assert "type" in content
192
+ assert "text" in content
193
+ assert messages[0].role == MessageRole.ASSISTANT
194
+ assert messages[1].role == MessageRole.USER
195
+
196
+
197
+ def test_task_step_to_messages():
198
+ task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])
199
+ messages = task_step.to_messages(summary_mode=False)
200
+ assert len(messages) == 1
201
+ for message in messages:
202
+ assert isinstance(message, ChatMessage)
203
+ assert message.role == MessageRole.USER
204
+ assert isinstance(message.content, list)
205
+ assert len(message.content) == 2
206
+ text_content = message.content[0]
207
+ assert isinstance(text_content, dict)
208
+ assert "type" in text_content
209
+ assert "text" in text_content
210
+ for image_content in message.content[1:]:
211
+ assert isinstance(image_content, dict)
212
+ assert "type" in image_content
213
+ assert "image" in image_content
214
+
215
+
216
+ def test_system_prompt_step_to_messages():
217
+ system_prompt_step = SystemPromptStep(system_prompt="This is a system prompt.")
218
+ messages = system_prompt_step.to_messages(summary_mode=False)
219
+ assert len(messages) == 1
220
+ for message in messages:
221
+ assert isinstance(message, ChatMessage)
222
+ assert message.role == MessageRole.SYSTEM
223
+ assert isinstance(message.content, list)
224
+ assert len(message.content) == 1
225
+ for content in message.content:
226
+ assert isinstance(content, dict)
227
+ assert "type" in content
228
+ assert "text" in content
tests/test_models.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import sys
17
+ import unittest
18
+ from contextlib import ExitStack
19
+ from unittest.mock import MagicMock, patch
20
+
21
+ import pytest
22
+ from huggingface_hub import ChatCompletionOutputMessage
23
+
24
+ from smolagents.default_tools import FinalAnswerTool
25
+ from smolagents.models import (
26
+ AmazonBedrockServerModel,
27
+ AzureOpenAIServerModel,
28
+ ChatMessage,
29
+ ChatMessageToolCall,
30
+ InferenceClientModel,
31
+ LiteLLMModel,
32
+ LiteLLMRouterModel,
33
+ MessageRole,
34
+ MLXModel,
35
+ Model,
36
+ OpenAIServerModel,
37
+ TransformersModel,
38
+ get_clean_message_list,
39
+ get_tool_call_from_text,
40
+ get_tool_json_schema,
41
+ parse_json_if_needed,
42
+ supports_stop_parameter,
43
+ )
44
+ from smolagents.tools import tool
45
+
46
+ from .utils.markers import require_run_all
47
+
48
+
49
+ class TestModel:
50
+ def test_agglomerate_stream_deltas(self):
51
+ from smolagents.models import (
52
+ ChatMessageStreamDelta,
53
+ ChatMessageToolCallFunction,
54
+ ChatMessageToolCallStreamDelta,
55
+ TokenUsage,
56
+ agglomerate_stream_deltas,
57
+ )
58
+
59
+ stream_deltas = [
60
+ ChatMessageStreamDelta(
61
+ content="Hi",
62
+ tool_calls=[
63
+ ChatMessageToolCallStreamDelta(
64
+ index=0,
65
+ type="function",
66
+ function=ChatMessageToolCallFunction(arguments="", name="web_search", description=None),
67
+ )
68
+ ],
69
+ token_usage=None,
70
+ ),
71
+ ChatMessageStreamDelta(
72
+ content=" everyone",
73
+ tool_calls=[
74
+ ChatMessageToolCallStreamDelta(
75
+ index=0,
76
+ type="function",
77
+ function=ChatMessageToolCallFunction(arguments=' {"', name="web_search", description=None),
78
+ )
79
+ ],
80
+ token_usage=None,
81
+ ),
82
+ ChatMessageStreamDelta(
83
+ content=", it's",
84
+ tool_calls=[
85
+ ChatMessageToolCallStreamDelta(
86
+ index=0,
87
+ type="function",
88
+ function=ChatMessageToolCallFunction(
89
+ arguments='query": "current pope name and date of birth"}',
90
+ name="web_search",
91
+ description=None,
92
+ ),
93
+ )
94
+ ],
95
+ token_usage=None,
96
+ ),
97
+ ChatMessageStreamDelta(
98
+ content="",
99
+ tool_calls=None,
100
+ token_usage=TokenUsage(input_tokens=1348, output_tokens=24),
101
+ ),
102
+ ]
103
+ agglomerated_stream_delta = agglomerate_stream_deltas(stream_deltas)
104
+ assert agglomerated_stream_delta.content == "Hi everyone, it's"
105
+ assert (
106
+ agglomerated_stream_delta.tool_calls[0].function.arguments
107
+ == ' {"query": "current pope name and date of birth"}'
108
+ )
109
+ assert agglomerated_stream_delta.token_usage.total_tokens == 1372
110
+
111
+ @pytest.mark.parametrize(
112
+ "model_id, stop_sequences, should_contain_stop",
113
+ [
114
+ ("regular-model", ["stop1", "stop2"], True), # Regular model should include stop
115
+ ("openai/o3", ["stop1", "stop2"], False), # o3 model should not include stop
116
+ ("openai/o4-mini", ["stop1", "stop2"], False), # o4-mini model should not include stop
117
+ ("something/else/o3", ["stop1", "stop2"], False), # Path ending with o3 should not include stop
118
+ ("something/else/o4-mini", ["stop1", "stop2"], False), # Path ending with o4-mini should not include stop
119
+ ("o3", ["stop1", "stop2"], False), # Exact o3 model should not include stop
120
+ ("o4-mini", ["stop1", "stop2"], False), # Exact o4-mini model should not include stop
121
+ ("regular-model", None, False), # None stop_sequences should not add stop parameter
122
+ ],
123
+ )
124
+ def test_prepare_completion_kwargs_stop_sequences(self, model_id, stop_sequences, should_contain_stop):
125
+ model = Model()
126
+ model.model_id = model_id
127
+ completion_kwargs = model._prepare_completion_kwargs(
128
+ messages=[
129
+ ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}]),
130
+ ],
131
+ stop_sequences=stop_sequences,
132
+ )
133
+ # Verify that the stop parameter is only included when appropriate
134
+ if should_contain_stop:
135
+ assert "stop" in completion_kwargs
136
+ assert completion_kwargs["stop"] == stop_sequences
137
+ else:
138
+ assert "stop" not in completion_kwargs
139
+
140
+ @pytest.mark.parametrize(
141
+ "with_tools, tool_choice, expected_result",
142
+ [
143
+ # Default behavior: With tools but no explicit tool_choice, should default to "required"
144
+ (True, ..., {"has_tool_choice": True, "value": "required"}),
145
+ # Custom value: With tools and explicit tool_choice="auto"
146
+ (True, "auto", {"has_tool_choice": True, "value": "auto"}),
147
+ # Tool name as string
148
+ (True, "valid_tool_function", {"has_tool_choice": True, "value": "valid_tool_function"}),
149
+ # Tool choice as dictionary
150
+ (
151
+ True,
152
+ {"type": "function", "function": {"name": "valid_tool_function"}},
153
+ {"has_tool_choice": True, "value": {"type": "function", "function": {"name": "valid_tool_function"}}},
154
+ ),
155
+ # With tools but explicit None tool_choice: should exclude tool_choice
156
+ (True, None, {"has_tool_choice": False, "value": None}),
157
+ # Without tools: tool_choice should never be included
158
+ (False, "required", {"has_tool_choice": False, "value": None}),
159
+ (False, "auto", {"has_tool_choice": False, "value": None}),
160
+ (False, None, {"has_tool_choice": False, "value": None}),
161
+ (False, ..., {"has_tool_choice": False, "value": None}),
162
+ ],
163
+ )
164
+ def test_prepare_completion_kwargs_tool_choice(self, with_tools, tool_choice, expected_result, example_tool):
165
+ model = Model()
166
+ kwargs = {"messages": [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}])]}
167
+ if with_tools:
168
+ kwargs["tools_to_call_from"] = [example_tool]
169
+ if tool_choice is not ...:
170
+ kwargs["tool_choice"] = tool_choice
171
+
172
+ completion_kwargs = model._prepare_completion_kwargs(**kwargs)
173
+
174
+ if expected_result["has_tool_choice"]:
175
+ assert "tool_choice" in completion_kwargs
176
+ assert completion_kwargs["tool_choice"] == expected_result["value"]
177
+ else:
178
+ assert "tool_choice" not in completion_kwargs
179
+
180
+ def test_get_json_schema_has_nullable_args(self):
181
+ @tool
182
+ def get_weather(location: str, celsius: bool | None = False) -> str:
183
+ """
184
+ Get weather in the next days at given location.
185
+ Secretly this tool does not care about the location, it hates the weather everywhere.
186
+
187
+ Args:
188
+ location: the location
189
+ celsius: the temperature type
190
+ """
191
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
192
+
193
+ assert "nullable" in get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
194
+
195
+ def test_chatmessage_has_model_dumps_json(self):
196
+ message = ChatMessage("user", [{"type": "text", "text": "Hello!"}])
197
+ data = json.loads(message.model_dump_json())
198
+ assert data["content"] == [{"type": "text", "text": "Hello!"}]
199
+
200
+ @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
201
+ def test_get_mlx_message_no_tool(self):
202
+ model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
203
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
204
+ output = model(messages, stop_sequences=["great"]).content
205
+ assert output.startswith("Hello")
206
+
207
+ @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
208
+ def test_get_mlx_message_tricky_stop_sequence(self):
209
+ # In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'"
210
+ # which is required to test capturing stop_sequences that have extra chars at the end.
211
+ model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100)
212
+ stop_sequence = " print '>"
213
+ messages = [
214
+ ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": f"Please{stop_sequence}'"}]),
215
+ ]
216
+ # check our assumption that that ">" is followed by "'"
217
+ assert model.tokenizer.vocab[">'"]
218
+ assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'"
219
+ # check stop_sequence capture when output has trailing chars
220
+ assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you"
221
+
222
+ def test_transformers_message_no_tool(self, monkeypatch):
223
+ monkeypatch.setattr("huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT", 30) # instead of 10
224
+ model = TransformersModel(
225
+ model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
226
+ max_new_tokens=5,
227
+ device_map="cpu",
228
+ do_sample=False,
229
+ )
230
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
231
+ output = model.generate(messages).content
232
+ assert output == "Hello! I'm here"
233
+
234
+ output = model.generate_stream(messages, stop_sequences=["great"])
235
+ output_str = ""
236
+ for el in output:
237
+ output_str += el.content
238
+ assert output_str == "Hello! I'm here"
239
+
240
+ def test_transformers_message_vl_no_tool(self, shared_datadir, monkeypatch):
241
+ monkeypatch.setattr("huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT", 30) # instead of 10
242
+ import PIL.Image
243
+
244
+ img = PIL.Image.open(shared_datadir / "000000039769.png")
245
+ model = TransformersModel(
246
+ model_id="llava-hf/llava-interleave-qwen-0.5b-hf",
247
+ max_new_tokens=4,
248
+ device_map="cpu",
249
+ do_sample=False,
250
+ )
251
+ messages = [
252
+ ChatMessage(
253
+ role=MessageRole.USER,
254
+ content=[{"type": "text", "text": "What is this?"}, {"type": "image", "image": img}],
255
+ )
256
+ ]
257
+ output = model.generate(messages).content
258
+ assert output == "This is a very"
259
+
260
+ output = model.generate_stream(messages, stop_sequences=["great"])
261
+ output_str = ""
262
+ for el in output:
263
+ output_str += el.content
264
+ assert output_str == "This is a very"
265
+
266
+ def test_parse_json_if_needed(self):
267
+ args = "abc"
268
+ parsed_args = parse_json_if_needed(args)
269
+ assert parsed_args == "abc"
270
+
271
+ args = '{"a": 3}'
272
+ parsed_args = parse_json_if_needed(args)
273
+ assert parsed_args == {"a": 3}
274
+
275
+ args = "3"
276
+ parsed_args = parse_json_if_needed(args)
277
+ assert parsed_args == 3
278
+
279
+ args = 3
280
+ parsed_args = parse_json_if_needed(args)
281
+ assert parsed_args == 3
282
+
283
+
284
+ class TestInferenceClientModel:
285
+ def test_call_with_custom_role_conversions(self):
286
+ custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
287
+ model = InferenceClientModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
288
+ model.client = MagicMock()
289
+ mock_response = model.client.chat_completion.return_value
290
+ mock_response.choices[0].message = ChatCompletionOutputMessage(role=MessageRole.ASSISTANT)
291
+ messages = [ChatMessage(role=MessageRole.USER, content="Test message")]
292
+ _ = model(messages)
293
+ # Verify that the role conversion was applied
294
+ assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
295
+ "role conversion should be applied"
296
+ )
297
+
298
+ def test_init_model_with_tokens(self):
299
+ model = InferenceClientModel(model_id="test-model", token="abc")
300
+ assert model.client.token == "abc"
301
+
302
+ model = InferenceClientModel(model_id="test-model", api_key="abc")
303
+ assert model.client.token == "abc"
304
+
305
+ with pytest.raises(ValueError, match="Received both `token` and `api_key` arguments."):
306
+ InferenceClientModel(model_id="test-model", token="abc", api_key="def")
307
+
308
+ def test_structured_outputs_with_unsupported_provider(self):
309
+ with pytest.raises(
310
+ ValueError, match="InferenceClientModel only supports structured outputs with these providers:"
311
+ ):
312
+ model = InferenceClientModel(model_id="test-model", token="abc", provider="some_provider")
313
+ model.generate(
314
+ messages=[ChatMessage(role=MessageRole.USER, content="Hello!")],
315
+ response_format={"type": "json_object"},
316
+ )
317
+
318
+ @require_run_all
319
+ def test_get_hfapi_message_no_tool(self):
320
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=10)
321
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
322
+ model(messages, stop_sequences=["great"])
323
+
324
+ @require_run_all
325
+ def test_get_hfapi_message_no_tool_external_provider(self):
326
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
327
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
328
+ model(messages, stop_sequences=["great"])
329
+
330
+ @require_run_all
331
+ def test_get_hfapi_message_stream_no_tool(self):
332
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=10)
333
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
334
+ for el in model.generate_stream(messages, stop_sequences=["great"]):
335
+ assert el.content is not None
336
+
337
+ @require_run_all
338
+ def test_get_hfapi_message_stream_no_tool_external_provider(self):
339
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
340
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
341
+ for el in model.generate_stream(messages, stop_sequences=["great"]):
342
+ assert el.content is not None
343
+
344
+
345
+ class TestLiteLLMModel:
346
+ @pytest.mark.parametrize(
347
+ "model_id, error_flag",
348
+ [
349
+ ("groq/llama-3.3-70b", "Invalid API Key"),
350
+ ("cerebras/llama-3.3-70b", "The api_key client option must be set"),
351
+ ("mistral/mistral-tiny", "The api_key client option must be set"),
352
+ ],
353
+ )
354
+ def test_call_different_providers_without_key(self, model_id, error_flag):
355
+ model = LiteLLMModel(model_id=model_id)
356
+ messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Test message"}])]
357
+ with pytest.raises(Exception) as e:
358
+ # This should raise 401 error because of missing API key, not fail for any "bad format" reason
359
+ model.generate(messages)
360
+ assert error_flag in str(e)
361
+ with pytest.raises(Exception) as e:
362
+ # This should raise 401 error because of missing API key, not fail for any "bad format" reason
363
+ for el in model.generate_stream(messages):
364
+ assert el.content is not None
365
+ assert error_flag in str(e)
366
+
367
+ def test_passing_flatten_messages(self):
368
+ model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
369
+ assert not model.flatten_messages_as_text
370
+
371
+ model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
372
+ assert model.flatten_messages_as_text
373
+
374
+
375
+ class TestLiteLLMRouterModel:
376
+ @pytest.mark.parametrize(
377
+ "model_id, expected",
378
+ [
379
+ ("llama-3.3-70b", False),
380
+ ("llama-3.3-70b", True),
381
+ ("mistral-tiny", True),
382
+ ],
383
+ )
384
+ def test_flatten_messages_as_text(self, model_id, expected):
385
+ model_list = [
386
+ {"model_name": "llama-3.3-70b", "litellm_params": {"model": "groq/llama-3.3-70b"}},
387
+ {"model_name": "llama-3.3-70b", "litellm_params": {"model": "cerebras/llama-3.3-70b"}},
388
+ {"model_name": "mistral-tiny", "litellm_params": {"model": "mistral/mistral-tiny"}},
389
+ ]
390
+ model = LiteLLMRouterModel(model_id=model_id, model_list=model_list, flatten_messages_as_text=expected)
391
+ assert model.flatten_messages_as_text is expected
392
+
393
+ def test_create_client(self):
394
+ model_list = [
395
+ {"model_name": "llama-3.3-70b", "litellm_params": {"model": "groq/llama-3.3-70b"}},
396
+ {"model_name": "llama-3.3-70b", "litellm_params": {"model": "cerebras/llama-3.3-70b"}},
397
+ ]
398
+ with patch("litellm.router.Router") as mock_router:
399
+ router_model = LiteLLMRouterModel(
400
+ model_id="model-group-1", model_list=model_list, client_kwargs={"routing_strategy": "simple-shuffle"}
401
+ )
402
+ # Ensure that the Router constructor was called with the expected keyword arguments
403
+ mock_router.assert_called_once()
404
+ assert mock_router.call_count == 1
405
+ assert mock_router.call_args.kwargs["model_list"] == model_list
406
+ assert mock_router.call_args.kwargs["routing_strategy"] == "simple-shuffle"
407
+ assert router_model.client == mock_router.return_value
408
+
409
+
410
+ class TestOpenAIServerModel:
411
+ def test_client_kwargs_passed_correctly(self):
412
+ model_id = "gpt-3.5-turbo"
413
+ api_base = "https://api.openai.com/v1"
414
+ api_key = "test_api_key"
415
+ organization = "test_org"
416
+ project = "test_project"
417
+ client_kwargs = {"max_retries": 5}
418
+
419
+ with patch("openai.OpenAI") as MockOpenAI:
420
+ model = OpenAIServerModel(
421
+ model_id=model_id,
422
+ api_base=api_base,
423
+ api_key=api_key,
424
+ organization=organization,
425
+ project=project,
426
+ client_kwargs=client_kwargs,
427
+ )
428
+ MockOpenAI.assert_called_once_with(
429
+ base_url=api_base, api_key=api_key, organization=organization, project=project, max_retries=5
430
+ )
431
+ assert model.client == MockOpenAI.return_value
432
+
433
+ @require_run_all
434
+ def test_streaming_tool_calls(self):
435
+ model = OpenAIServerModel(model_id="gpt-4o-mini")
436
+ messages = [
437
+ ChatMessage(
438
+ role=MessageRole.USER,
439
+ content=[
440
+ {
441
+ "type": "text",
442
+ "text": "Hello! Please return the final answer 'blob' and the final answer 'blob2' in two parallel tool calls",
443
+ }
444
+ ],
445
+ ),
446
+ ]
447
+ for el in model.generate_stream(messages, tools_to_call_from=[FinalAnswerTool()]):
448
+ if el.tool_calls:
449
+ assert el.tool_calls[0].function.name == "final_answer"
450
+ args = el.tool_calls[0].function.arguments
451
+ if len(el.tool_calls) > 1:
452
+ assert el.tool_calls[1].function.name == "final_answer"
453
+ args2 = el.tool_calls[1].function.arguments
454
+ assert args == '{"answer": "blob"}'
455
+ assert args2 == '{"answer": "blob2"}'
456
+
457
+
458
+ class TestAmazonBedrockServerModel:
459
+ def test_client_for_bedrock(self):
460
+ model_id = "us.amazon.nova-pro-v1:0"
461
+
462
+ with patch("boto3.client") as MockBoto3:
463
+ model = AmazonBedrockServerModel(
464
+ model_id=model_id,
465
+ )
466
+
467
+ assert model.client == MockBoto3.return_value
468
+
469
+
470
+ class TestAzureOpenAIServerModel:
471
+ def test_client_kwargs_passed_correctly(self):
472
+ model_id = "gpt-3.5-turbo"
473
+ api_key = "test_api_key"
474
+ api_version = "2023-12-01-preview"
475
+ azure_endpoint = "https://example-resource.azure.openai.com/"
476
+ organization = "test_org"
477
+ project = "test_project"
478
+ client_kwargs = {"max_retries": 5}
479
+
480
+ with patch("openai.OpenAI") as MockOpenAI, patch("openai.AzureOpenAI") as MockAzureOpenAI:
481
+ model = AzureOpenAIServerModel(
482
+ model_id=model_id,
483
+ api_key=api_key,
484
+ api_version=api_version,
485
+ azure_endpoint=azure_endpoint,
486
+ organization=organization,
487
+ project=project,
488
+ client_kwargs=client_kwargs,
489
+ )
490
+ assert MockOpenAI.call_count == 0
491
+ MockAzureOpenAI.assert_called_once_with(
492
+ base_url=None,
493
+ api_key=api_key,
494
+ api_version=api_version,
495
+ azure_endpoint=azure_endpoint,
496
+ organization=organization,
497
+ project=project,
498
+ max_retries=5,
499
+ )
500
+ assert model.client == MockAzureOpenAI.return_value
501
+
502
+
503
+ class TestTransformersModel:
504
+ @pytest.mark.parametrize(
505
+ "patching",
506
+ [
507
+ [
508
+ (
509
+ "transformers.AutoModelForImageTextToText.from_pretrained",
510
+ {"side_effect": ValueError("Unrecognized configuration class")},
511
+ ),
512
+ ("transformers.AutoModelForCausalLM.from_pretrained", {}),
513
+ ("transformers.AutoTokenizer.from_pretrained", {}),
514
+ ],
515
+ [
516
+ ("transformers.AutoModelForImageTextToText.from_pretrained", {}),
517
+ ("transformers.AutoProcessor.from_pretrained", {}),
518
+ ],
519
+ ],
520
+ )
521
+ def test_init(self, patching):
522
+ with ExitStack() as stack:
523
+ mocks = {target: stack.enter_context(patch(target, **kwargs)) for target, kwargs in patching}
524
+ model = TransformersModel(
525
+ model_id="test-model", device_map="cpu", torch_dtype="float16", trust_remote_code=True
526
+ )
527
+ assert model.model_id == "test-model"
528
+ if "transformers.AutoTokenizer.from_pretrained" in mocks:
529
+ assert model.model == mocks["transformers.AutoModelForCausalLM.from_pretrained"].return_value
530
+ assert mocks["transformers.AutoModelForCausalLM.from_pretrained"].call_args.kwargs == {
531
+ "device_map": "cpu",
532
+ "torch_dtype": "float16",
533
+ "trust_remote_code": True,
534
+ }
535
+ assert model.tokenizer == mocks["transformers.AutoTokenizer.from_pretrained"].return_value
536
+ assert mocks["transformers.AutoTokenizer.from_pretrained"].call_args.args == ("test-model",)
537
+ assert mocks["transformers.AutoTokenizer.from_pretrained"].call_args.kwargs == {"trust_remote_code": True}
538
+ elif "transformers.AutoProcessor.from_pretrained" in mocks:
539
+ assert model.model == mocks["transformers.AutoModelForImageTextToText.from_pretrained"].return_value
540
+ assert mocks["transformers.AutoModelForImageTextToText.from_pretrained"].call_args.kwargs == {
541
+ "device_map": "cpu",
542
+ "torch_dtype": "float16",
543
+ "trust_remote_code": True,
544
+ }
545
+ assert model.processor == mocks["transformers.AutoProcessor.from_pretrained"].return_value
546
+ assert mocks["transformers.AutoProcessor.from_pretrained"].call_args.args == ("test-model",)
547
+ assert mocks["transformers.AutoProcessor.from_pretrained"].call_args.kwargs == {"trust_remote_code": True}
548
+
549
+
550
+ def test_get_clean_message_list_basic():
551
+ messages = [
552
+ ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}]),
553
+ ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": "Hi there!"}]),
554
+ ]
555
+ result = get_clean_message_list(messages)
556
+ assert len(result) == 2
557
+ assert result[0]["role"] == "user"
558
+ assert result[0]["content"][0]["text"] == "Hello!"
559
+ assert result[1]["role"] == "assistant"
560
+ assert result[1]["content"][0]["text"] == "Hi there!"
561
+
562
+
563
+ def test_get_clean_message_list_role_conversions():
564
+ messages = [
565
+ ChatMessage(role=MessageRole.TOOL_CALL, content=[{"type": "text", "text": "Calling tool..."}]),
566
+ ChatMessage(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": "Tool response"}]),
567
+ ]
568
+ result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"})
569
+ assert len(result) == 2
570
+ assert result[0]["role"] == "assistant"
571
+ assert result[0]["content"][0]["text"] == "Calling tool..."
572
+ assert result[1]["role"] == "user"
573
+ assert result[1]["content"][0]["text"] == "Tool response"
574
+
575
+
576
+ @pytest.mark.parametrize(
577
+ "convert_images_to_image_urls, expected_clean_message",
578
+ [
579
+ (
580
+ False,
581
+ dict(
582
+ role=MessageRole.USER,
583
+ content=[
584
+ {"type": "image", "image": "encoded_image"},
585
+ {"type": "image", "image": "second_encoded_image"},
586
+ ],
587
+ ),
588
+ ),
589
+ (
590
+ True,
591
+ dict(
592
+ role=MessageRole.USER,
593
+ content=[
594
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,encoded_image"}},
595
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,second_encoded_image"}},
596
+ ],
597
+ ),
598
+ ),
599
+ ],
600
+ )
601
+ def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message):
602
+ message = ChatMessage(
603
+ role=MessageRole.USER,
604
+ content=[{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
605
+ )
606
+ with patch("smolagents.models.encode_image_base64") as mock_encode:
607
+ mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
608
+ result = get_clean_message_list([message], convert_images_to_image_urls=convert_images_to_image_urls)
609
+ mock_encode.assert_any_call(b"image_data")
610
+ mock_encode.assert_any_call(b"second_image_data")
611
+ assert len(result) == 1
612
+ assert result[0] == expected_clean_message
613
+
614
+
615
+ def test_get_clean_message_list_flatten_messages_as_text():
616
+ messages = [
617
+ ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}]),
618
+ ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "How are you?"}]),
619
+ ]
620
+ result = get_clean_message_list(messages, flatten_messages_as_text=True)
621
+ assert len(result) == 1
622
+ assert result[0]["role"] == "user"
623
+ assert result[0]["content"] == "Hello!\nHow are you?"
624
+
625
+
626
+ @pytest.mark.parametrize(
627
+ "model_class, model_kwargs, patching, expected_flatten_messages_as_text",
628
+ [
629
+ (AzureOpenAIServerModel, {}, ("openai.AzureOpenAI", {}), False),
630
+ (InferenceClientModel, {}, ("huggingface_hub.InferenceClient", {}), False),
631
+ (LiteLLMModel, {}, None, False),
632
+ (LiteLLMModel, {"model_id": "ollama"}, None, True),
633
+ (LiteLLMModel, {"model_id": "groq"}, None, True),
634
+ (LiteLLMModel, {"model_id": "cerebras"}, None, True),
635
+ (MLXModel, {}, ("mlx_lm.load", {"return_value": (MagicMock(), MagicMock())}), True),
636
+ (OpenAIServerModel, {}, ("openai.OpenAI", {}), False),
637
+ (OpenAIServerModel, {"flatten_messages_as_text": True}, ("openai.OpenAI", {}), True),
638
+ (
639
+ TransformersModel,
640
+ {},
641
+ [
642
+ (
643
+ "transformers.AutoModelForImageTextToText.from_pretrained",
644
+ {"side_effect": ValueError("Unrecognized configuration class")},
645
+ ),
646
+ ("transformers.AutoModelForCausalLM.from_pretrained", {}),
647
+ ("transformers.AutoTokenizer.from_pretrained", {}),
648
+ ],
649
+ True,
650
+ ),
651
+ (
652
+ TransformersModel,
653
+ {},
654
+ [
655
+ ("transformers.AutoModelForImageTextToText.from_pretrained", {}),
656
+ ("transformers.AutoProcessor.from_pretrained", {}),
657
+ ],
658
+ False,
659
+ ),
660
+ ],
661
+ )
662
+ def test_flatten_messages_as_text_for_all_models(
663
+ model_class, model_kwargs, patching, expected_flatten_messages_as_text
664
+ ):
665
+ with ExitStack() as stack:
666
+ if isinstance(patching, list):
667
+ for target, kwargs in patching:
668
+ stack.enter_context(patch(target, **kwargs))
669
+ elif patching:
670
+ target, kwargs = patching
671
+ stack.enter_context(patch(target, **kwargs))
672
+
673
+ model = model_class(**{"model_id": "test-model", **model_kwargs})
674
+ assert model.flatten_messages_as_text is expected_flatten_messages_as_text, f"{model_class.__name__} failed"
675
+
676
+
677
+ @pytest.mark.parametrize(
678
+ "model_id,expected",
679
+ [
680
+ # Unsupported base models
681
+ ("o3", False),
682
+ ("o4-mini", False),
683
+ # Unsupported versioned models
684
+ ("o3-2025-04-16", False),
685
+ ("o4-mini-2025-04-16", False),
686
+ # Unsupported models with path prefixes
687
+ ("openai/o3", False),
688
+ ("openai/o4-mini", False),
689
+ ("openai/o3-2025-04-16", False),
690
+ ("openai/o4-mini-2025-04-16", False),
691
+ # Supported models
692
+ ("o3-mini", True), # Different from o3
693
+ ("o3-mini-2025-01-31", True), # Different from o3
694
+ ("o4", True), # Different from o4-mini
695
+ ("o4-turbo", True), # Different from o4-mini
696
+ ("gpt-4", True),
697
+ ("claude-3-5-sonnet", True),
698
+ ("mistral-large", True),
699
+ # Supported models with path prefixes
700
+ ("openai/gpt-4", True),
701
+ ("anthropic/claude-3-5-sonnet", True),
702
+ ("mistralai/mistral-large", True),
703
+ # Edge cases
704
+ ("", True), # Empty string doesn't match pattern
705
+ ("o3x", True), # Not exactly o3
706
+ ("o3_mini", True), # Not o3-mini format
707
+ ("prefix-o3", True), # o3 not at start
708
+ ],
709
+ )
710
+ def test_supports_stop_parameter(model_id, expected):
711
+ """Test the supports_stop_parameter function with various model IDs"""
712
+ assert supports_stop_parameter(model_id) == expected, f"Failed for model_id: {model_id}"
713
+
714
+
715
+ class TestGetToolCallFromText:
716
+ @pytest.fixture(autouse=True)
717
+ def mock_uuid4(self):
718
+ with patch("uuid.uuid4", return_value="test-uuid"):
719
+ yield
720
+
721
+ def test_get_tool_call_from_text_basic(self):
722
+ text = '{"name": "weather_tool", "arguments": "New York"}'
723
+ result = get_tool_call_from_text(text, "name", "arguments")
724
+ assert isinstance(result, ChatMessageToolCall)
725
+ assert result.id == "test-uuid"
726
+ assert result.type == "function"
727
+ assert result.function.name == "weather_tool"
728
+ assert result.function.arguments == "New York"
729
+
730
+ def test_get_tool_call_from_text_name_key_missing(self):
731
+ text = '{"action": "weather_tool", "arguments": "New York"}'
732
+ with pytest.raises(ValueError) as exc_info:
733
+ get_tool_call_from_text(text, "name", "arguments")
734
+ error_msg = str(exc_info.value)
735
+ assert "Key tool_name_key='name' not found" in error_msg
736
+ assert "'action', 'arguments'" in error_msg
737
+
738
+ def test_get_tool_call_from_text_json_object_args(self):
739
+ text = '{"name": "weather_tool", "arguments": {"city": "New York"}}'
740
+ result = get_tool_call_from_text(text, "name", "arguments")
741
+ assert result.function.arguments == {"city": "New York"}
742
+
743
+ def test_get_tool_call_from_text_json_string_args(self):
744
+ text = '{"name": "weather_tool", "arguments": "{\\"city\\": \\"New York\\"}"}'
745
+ result = get_tool_call_from_text(text, "name", "arguments")
746
+ assert result.function.arguments == {"city": "New York"}
747
+
748
+ def test_get_tool_call_from_text_missing_args(self):
749
+ text = '{"name": "weather_tool"}'
750
+ result = get_tool_call_from_text(text, "name", "arguments")
751
+ assert result.function.arguments is None
752
+
753
+ def test_get_tool_call_from_text_custom_keys(self):
754
+ text = '{"tool": "weather_tool", "params": "New York"}'
755
+ result = get_tool_call_from_text(text, "tool", "params")
756
+ assert result.function.name == "weather_tool"
757
+ assert result.function.arguments == "New York"
758
+
759
+ def test_get_tool_call_from_text_numeric_args(self):
760
+ text = '{"name": "calculator", "arguments": 42}'
761
+ result = get_tool_call_from_text(text, "name", "arguments")
762
+ assert result.function.name == "calculator"
763
+ assert result.function.arguments == 42
tests/test_monitoring.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import PIL.Image
19
+ import pytest
20
+
21
+ from smolagents import (
22
+ CodeAgent,
23
+ RunResult,
24
+ ToolCallingAgent,
25
+ stream_to_gradio,
26
+ )
27
+ from smolagents.models import (
28
+ ChatMessage,
29
+ ChatMessageToolCall,
30
+ ChatMessageToolCallFunction,
31
+ MessageRole,
32
+ Model,
33
+ TokenUsage,
34
+ )
35
+
36
+
37
+ class FakeLLMModel(Model):
38
+ def __init__(self, give_token_usage: bool = True):
39
+ self.give_token_usage = give_token_usage
40
+
41
+ def generate(self, prompt, tools_to_call_from=None, **kwargs):
42
+ if tools_to_call_from is not None:
43
+ return ChatMessage(
44
+ role=MessageRole.ASSISTANT,
45
+ content="",
46
+ tool_calls=[
47
+ ChatMessageToolCall(
48
+ id="fake_id",
49
+ type="function",
50
+ function=ChatMessageToolCallFunction(name="final_answer", arguments={"answer": "image"}),
51
+ )
52
+ ],
53
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
54
+ )
55
+ else:
56
+ return ChatMessage(
57
+ role=MessageRole.ASSISTANT,
58
+ content="""<code>
59
+ final_answer('This is the final answer.')
60
+ </code>""",
61
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
62
+ )
63
+
64
+
65
+ class MonitoringTester(unittest.TestCase):
66
+ def test_code_agent_metrics(self):
67
+ agent = CodeAgent(
68
+ tools=[],
69
+ model=FakeLLMModel(),
70
+ max_steps=1,
71
+ )
72
+ agent.run("Fake task")
73
+
74
+ self.assertEqual(agent.monitor.total_input_token_count, 10)
75
+ self.assertEqual(agent.monitor.total_output_token_count, 20)
76
+
77
+ def test_toolcalling_agent_metrics(self):
78
+ agent = ToolCallingAgent(
79
+ tools=[],
80
+ model=FakeLLMModel(),
81
+ max_steps=1,
82
+ )
83
+
84
+ agent.run("Fake task")
85
+
86
+ self.assertEqual(agent.monitor.total_input_token_count, 10)
87
+ self.assertEqual(agent.monitor.total_output_token_count, 20)
88
+
89
+ def test_code_agent_metrics_max_steps(self):
90
+ class FakeLLMModelMalformedAnswer(Model):
91
+ def generate(self, prompt, **kwargs):
92
+ return ChatMessage(
93
+ role=MessageRole.ASSISTANT,
94
+ content="Malformed answer",
95
+ token_usage=TokenUsage(input_tokens=10, output_tokens=20),
96
+ )
97
+
98
+ agent = CodeAgent(
99
+ tools=[],
100
+ model=FakeLLMModelMalformedAnswer(),
101
+ max_steps=1,
102
+ )
103
+
104
+ agent.run("Fake task")
105
+
106
+ self.assertEqual(agent.monitor.total_input_token_count, 20)
107
+ self.assertEqual(agent.monitor.total_output_token_count, 40)
108
+
109
+ def test_code_agent_metrics_generation_error(self):
110
+ class FakeLLMModelGenerationException(Model):
111
+ def generate(self, prompt, **kwargs):
112
+ raise Exception("Cannot generate")
113
+
114
+ agent = CodeAgent(
115
+ tools=[],
116
+ model=FakeLLMModelGenerationException(),
117
+ max_steps=1,
118
+ )
119
+ with pytest.raises(Exception) as e:
120
+ agent.run("Fake task")
121
+ assert "Cannot generate" in str(e.value)
122
+
123
+ def test_streaming_agent_text_output(self):
124
+ agent = CodeAgent(
125
+ tools=[],
126
+ model=FakeLLMModel(),
127
+ max_steps=1,
128
+ planning_interval=2,
129
+ )
130
+
131
+ # Use stream_to_gradio to capture the output
132
+ outputs = list(stream_to_gradio(agent, task="Test task"))
133
+
134
+ self.assertEqual(len(outputs), 11)
135
+ plan_message = outputs[1]
136
+ self.assertEqual(plan_message.role, "assistant")
137
+ self.assertIn("<code>", plan_message.content)
138
+ final_message = outputs[-1]
139
+ self.assertEqual(final_message.role, "assistant")
140
+ self.assertIn("This is the final answer.", final_message.content)
141
+
142
+ def test_streaming_agent_image_output(self):
143
+ agent = ToolCallingAgent(
144
+ tools=[],
145
+ model=FakeLLMModel(),
146
+ max_steps=1,
147
+ verbosity_level=100,
148
+ )
149
+
150
+ # Use stream_to_gradio to capture the output
151
+ outputs = list(
152
+ stream_to_gradio(
153
+ agent,
154
+ task="Test task",
155
+ additional_args=dict(image=PIL.Image.new("RGB", (100, 100))),
156
+ )
157
+ )
158
+
159
+ self.assertEqual(len(outputs), 7)
160
+ final_message = outputs[-1]
161
+ self.assertEqual(final_message.role, "assistant")
162
+ self.assertIsInstance(final_message.content, dict)
163
+ self.assertEqual(final_message.content["mime_type"], "image/png")
164
+
165
+ def test_streaming_with_agent_error(self):
166
+ class DummyModel(Model):
167
+ def generate(self, prompt, **kwargs):
168
+ return ChatMessage(role=MessageRole.ASSISTANT, content="Malformed call")
169
+
170
+ agent = CodeAgent(
171
+ tools=[],
172
+ model=DummyModel(),
173
+ max_steps=1,
174
+ )
175
+
176
+ # Use stream_to_gradio to capture the output
177
+ outputs = list(stream_to_gradio(agent, task="Test task"))
178
+
179
+ self.assertEqual(len(outputs), 11)
180
+ final_message = outputs[-1]
181
+ self.assertEqual(final_message.role, "assistant")
182
+ self.assertIn("Malformed call", final_message.content)
183
+
184
+ def test_run_return_full_result(self):
185
+ agent = CodeAgent(
186
+ tools=[],
187
+ model=FakeLLMModel(),
188
+ max_steps=1,
189
+ return_full_result=True,
190
+ )
191
+
192
+ result = agent.run("Fake task")
193
+
194
+ self.assertIsInstance(result, RunResult)
195
+ self.assertEqual(result.output, "This is the final answer.")
196
+ self.assertEqual(result.state, "success")
197
+ self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
198
+ self.assertIsInstance(result.messages, list)
199
+ self.assertGreater(result.timing.duration, 0)
200
+
201
+ agent = ToolCallingAgent(
202
+ tools=[],
203
+ model=FakeLLMModel(),
204
+ max_steps=1,
205
+ return_full_result=True,
206
+ )
207
+
208
+ result = agent.run("Fake task")
209
+
210
+ self.assertIsInstance(result, RunResult)
211
+ self.assertEqual(result.output, "image")
212
+ self.assertEqual(result.state, "success")
213
+ self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
214
+ self.assertIsInstance(result.messages, list)
215
+ self.assertGreater(result.timing.duration, 0)
216
+
217
+ # Below 2 lines should be removed when the attributes are removed
218
+ assert agent.monitor.total_input_token_count == 10
219
+ assert agent.monitor.total_output_token_count == 20
220
+
221
+ def test_run_result_no_token_usage(self):
222
+ agent = CodeAgent(
223
+ tools=[],
224
+ model=FakeLLMModel(give_token_usage=False),
225
+ max_steps=1,
226
+ return_full_result=True,
227
+ )
228
+
229
+ result = agent.run("Fake task")
230
+
231
+ self.assertIsInstance(result, RunResult)
232
+ self.assertEqual(result.output, "This is the final answer.")
233
+ self.assertEqual(result.state, "success")
234
+ self.assertIsNone(result.token_usage)
235
+ self.assertIsInstance(result.messages, list)
236
+ self.assertGreater(result.timing.duration, 0)
237
+
238
+ agent = ToolCallingAgent(
239
+ tools=[],
240
+ model=FakeLLMModel(give_token_usage=False),
241
+ max_steps=1,
242
+ return_full_result=True,
243
+ )
244
+
245
+ result = agent.run("Fake task")
246
+
247
+ self.assertIsInstance(result, RunResult)
248
+ self.assertEqual(result.output, "image")
249
+ self.assertEqual(result.state, "success")
250
+ self.assertIsNone(result.token_usage)
251
+ self.assertIsInstance(result.messages, list)
252
+ self.assertGreater(result.timing.duration, 0)
tests/test_remote_executors.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from textwrap import dedent
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import docker
6
+ import PIL.Image
7
+ import pytest
8
+ from rich.console import Console
9
+
10
+ from smolagents.default_tools import FinalAnswerTool, WikipediaSearchTool
11
+ from smolagents.monitoring import AgentLogger, LogLevel
12
+ from smolagents.remote_executors import DockerExecutor, E2BExecutor, RemotePythonExecutor
13
+ from smolagents.utils import AgentError
14
+
15
+ from .utils.markers import require_run_all
16
+
17
+
18
+ class TestRemotePythonExecutor:
19
+ def test_send_tools_empty_tools(self):
20
+ executor = RemotePythonExecutor(additional_imports=[], logger=MagicMock())
21
+ executor.run_code_raise_errors = MagicMock()
22
+ executor.send_tools({})
23
+ assert executor.run_code_raise_errors.call_count == 1
24
+ # No new packages should be installed
25
+ assert "!pip install" not in executor.run_code_raise_errors.call_args.args[0]
26
+
27
+ @require_run_all
28
+ def test_send_tools_with_default_wikipedia_search_tool(self):
29
+ tool = WikipediaSearchTool()
30
+ executor = RemotePythonExecutor(additional_imports=[], logger=MagicMock())
31
+ executor.run_code_raise_errors = MagicMock()
32
+ executor.run_code_raise_errors.return_value = (None, "", False)
33
+ executor.send_tools({"wikipedia_search": tool})
34
+ assert executor.run_code_raise_errors.call_count == 2
35
+ assert "!pip install wikipedia-api" == executor.run_code_raise_errors.call_args_list[0].args[0]
36
+ assert "class WikipediaSearchTool(Tool)" in executor.run_code_raise_errors.call_args_list[1].args[0]
37
+
38
+
39
+ class TestE2BExecutorUnit:
40
+ def test_e2b_executor_instantiation(self):
41
+ logger = MagicMock()
42
+ with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
43
+ mock_sandbox.return_value.commands.run.return_value.error = None
44
+ mock_sandbox.return_value.run_code.return_value.error = None
45
+ executor = E2BExecutor(
46
+ additional_imports=[], logger=logger, api_key="dummy-api-key", template="dummy-template-id", timeout=60
47
+ )
48
+ assert isinstance(executor, E2BExecutor)
49
+ assert executor.logger == logger
50
+ assert executor.sandbox == mock_sandbox.return_value
51
+ assert mock_sandbox.call_count == 1
52
+ assert mock_sandbox.call_args.kwargs == {
53
+ "api_key": "dummy-api-key",
54
+ "template": "dummy-template-id",
55
+ "timeout": 60,
56
+ }
57
+
58
+ def test_cleanup(self):
59
+ """Test that the cleanup method properly shuts down the sandbox"""
60
+ logger = MagicMock()
61
+ with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
62
+ # Setup mock
63
+ mock_sandbox.return_value.kill = MagicMock()
64
+
65
+ # Create executor
66
+ executor = E2BExecutor(additional_imports=[], logger=logger, api_key="dummy-api-key")
67
+
68
+ # Call cleanup
69
+ executor.cleanup()
70
+
71
+ # Verify sandbox was killed
72
+ mock_sandbox.return_value.kill.assert_called_once()
73
+ assert logger.log.call_count >= 2 # Should log start and completion messages
74
+
75
+
76
+ @pytest.fixture
77
+ def e2b_executor():
78
+ executor = E2BExecutor(
79
+ additional_imports=["pillow", "numpy"],
80
+ logger=AgentLogger(LogLevel.INFO, Console(force_terminal=False, file=io.StringIO())),
81
+ )
82
+ yield executor
83
+ executor.cleanup()
84
+
85
+
86
+ @require_run_all
87
+ class TestE2BExecutorIntegration:
88
+ @pytest.fixture(autouse=True)
89
+ def set_executor(self, e2b_executor):
90
+ self.executor = e2b_executor
91
+
92
+ @pytest.mark.parametrize(
93
+ "code_action, expected_result",
94
+ [
95
+ (
96
+ dedent('''
97
+ final_answer("""This is
98
+ a multiline
99
+ final answer""")
100
+ '''),
101
+ "This is\na multiline\nfinal answer",
102
+ ),
103
+ (
104
+ dedent("""
105
+ text = '''Text containing
106
+ final_answer(5)
107
+ '''
108
+ final_answer(text)
109
+ """),
110
+ "Text containing\nfinal_answer(5)\n",
111
+ ),
112
+ (
113
+ dedent("""
114
+ num = 2
115
+ if num == 1:
116
+ final_answer("One")
117
+ elif num == 2:
118
+ final_answer("Two")
119
+ """),
120
+ "Two",
121
+ ),
122
+ ],
123
+ )
124
+ def test_final_answer_patterns(self, code_action, expected_result):
125
+ self.executor.send_tools({"final_answer": FinalAnswerTool()})
126
+ result, logs, final_answer = self.executor(code_action)
127
+ assert final_answer is True
128
+ assert result == expected_result
129
+
130
+ def test_custom_final_answer(self):
131
+ class CustomFinalAnswerTool(FinalAnswerTool):
132
+ def forward(self, answer: str) -> str:
133
+ return "CUSTOM" + answer
134
+
135
+ self.executor.send_tools({"final_answer": CustomFinalAnswerTool()})
136
+ code_action = dedent("""
137
+ final_answer(answer="_answer")
138
+ """)
139
+ result, logs, final_answer = self.executor(code_action)
140
+ assert final_answer is True
141
+ assert result == "CUSTOM_answer"
142
+
143
+ def test_custom_final_answer_with_custom_inputs(self):
144
+ class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
145
+ inputs = {
146
+ "answer1": {"type": "string", "description": "First part of the answer."},
147
+ "answer2": {"type": "string", "description": "Second part of the answer."},
148
+ }
149
+
150
+ def forward(self, answer1: str, answer2: str) -> str:
151
+ return answer1 + "CUSTOM" + answer2
152
+
153
+ self.executor.send_tools({"final_answer": CustomFinalAnswerToolWithCustomInputs()})
154
+ code_action = dedent("""
155
+ final_answer(
156
+ answer1="answer1_",
157
+ answer2="_answer2"
158
+ )
159
+ """)
160
+ result, logs, final_answer = self.executor(code_action)
161
+ assert final_answer is True
162
+ assert result == "answer1_CUSTOM_answer2"
163
+
164
+
165
+ @pytest.fixture
166
+ def docker_executor():
167
+ executor = DockerExecutor(
168
+ additional_imports=["pillow", "numpy"],
169
+ logger=AgentLogger(LogLevel.INFO, Console(force_terminal=False, file=io.StringIO())),
170
+ )
171
+ yield executor
172
+ executor.delete()
173
+
174
+
175
+ @require_run_all
176
+ class TestDockerExecutorIntegration:
177
+ @pytest.fixture(autouse=True)
178
+ def set_executor(self, docker_executor):
179
+ self.executor = docker_executor
180
+
181
+ def test_initialization(self):
182
+ """Check if DockerExecutor initializes without errors"""
183
+ assert self.executor.container is not None, "Container should be initialized"
184
+
185
+ def test_state_persistence(self):
186
+ """Test that variables and imports form one snippet persist in the next"""
187
+ code_action = "import numpy as np; a = 2"
188
+ self.executor(code_action)
189
+
190
+ code_action = "print(np.sqrt(a))"
191
+ result, logs, final_answer = self.executor(code_action)
192
+ assert "1.41421" in logs
193
+
194
+ def test_execute_output(self):
195
+ """Test execution that returns a string"""
196
+ code_action = 'final_answer("This is the final answer")'
197
+ result, logs, final_answer = self.executor(code_action)
198
+ assert result == "This is the final answer", "Result should be 'This is the final answer'"
199
+
200
+ def test_execute_multiline_output(self):
201
+ """Test execution that returns a string"""
202
+ code_action = 'result = "This is the final answer"\nfinal_answer(result)'
203
+ result, logs, final_answer = self.executor(code_action)
204
+ assert result == "This is the final answer", "Result should be 'This is the final answer'"
205
+
206
+ def test_execute_image_output(self):
207
+ """Test execution that returns a base64 image"""
208
+ code_action = dedent("""
209
+ import base64
210
+ from PIL import Image
211
+ from io import BytesIO
212
+ image = Image.new("RGB", (10, 10), (255, 0, 0))
213
+ final_answer(image)
214
+ """)
215
+ result, logs, final_answer = self.executor(code_action)
216
+ assert isinstance(result, PIL.Image.Image), "Result should be a PIL Image"
217
+
218
+ def test_syntax_error_handling(self):
219
+ """Test handling of syntax errors"""
220
+ code_action = 'print("Missing Parenthesis' # Syntax error
221
+ with pytest.raises(AgentError) as exception_info:
222
+ self.executor(code_action)
223
+ assert "SyntaxError" in str(exception_info.value), "Should raise a syntax error"
224
+
225
+ def test_cleanup_on_deletion(self):
226
+ """Test if Docker container stops and removes on deletion"""
227
+ container_id = self.executor.container.id
228
+ self.executor.delete() # Trigger cleanup
229
+
230
+ client = docker.from_env()
231
+ containers = [c.id for c in client.containers.list(all=True)]
232
+ assert container_id not in containers, "Container should be removed"
233
+
234
+ @pytest.mark.parametrize(
235
+ "code_action, expected_result",
236
+ [
237
+ (
238
+ dedent('''
239
+ final_answer("""This is
240
+ a multiline
241
+ final answer""")
242
+ '''),
243
+ "This is\na multiline\nfinal answer",
244
+ ),
245
+ (
246
+ dedent("""
247
+ text = '''Text containing
248
+ final_answer(5)
249
+ '''
250
+ final_answer(text)
251
+ """),
252
+ "Text containing\nfinal_answer(5)\n",
253
+ ),
254
+ (
255
+ dedent("""
256
+ num = 2
257
+ if num == 1:
258
+ final_answer("One")
259
+ elif num == 2:
260
+ final_answer("Two")
261
+ """),
262
+ "Two",
263
+ ),
264
+ ],
265
+ )
266
+ def test_final_answer_patterns(self, code_action, expected_result):
267
+ self.executor.send_tools({"final_answer": FinalAnswerTool()})
268
+ result, logs, final_answer = self.executor(code_action)
269
+ assert final_answer is True
270
+ assert result == expected_result
271
+
272
+ def test_custom_final_answer(self):
273
+ class CustomFinalAnswerTool(FinalAnswerTool):
274
+ def forward(self, answer: str) -> str:
275
+ return "CUSTOM" + answer
276
+
277
+ self.executor.send_tools({"final_answer": CustomFinalAnswerTool()})
278
+ code_action = dedent("""
279
+ final_answer(answer="_answer")
280
+ """)
281
+ result, logs, final_answer = self.executor(code_action)
282
+ assert final_answer is True
283
+ assert result == "CUSTOM_answer"
284
+
285
+ def test_custom_final_answer_with_custom_inputs(self):
286
+ class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
287
+ inputs = {
288
+ "answer1": {"type": "string", "description": "First part of the answer."},
289
+ "answer2": {"type": "string", "description": "Second part of the answer."},
290
+ }
291
+
292
+ def forward(self, answer1: str, answer2: str) -> str:
293
+ return answer1 + "CUSTOM" + answer2
294
+
295
+ self.executor.send_tools({"final_answer": CustomFinalAnswerToolWithCustomInputs()})
296
+ code_action = dedent("""
297
+ final_answer(
298
+ answer1="answer1_",
299
+ answer2="_answer2"
300
+ )
301
+ """)
302
+ result, logs, final_answer = self.executor(code_action)
303
+ assert final_answer is True
304
+ assert result == "answer1_CUSTOM_answer2"
305
+
306
+
307
+ class TestDockerExecutorUnit:
308
+ def test_cleanup(self):
309
+ """Test that cleanup properly stops and removes the container"""
310
+ logger = MagicMock()
311
+ with (
312
+ patch("docker.from_env") as mock_docker_client,
313
+ patch("requests.post") as mock_post,
314
+ patch("websocket.create_connection"),
315
+ ):
316
+ # Setup mocks
317
+ mock_container = MagicMock()
318
+ mock_container.status = "running"
319
+ mock_container.short_id = "test123"
320
+
321
+ mock_docker_client.return_value.containers.run.return_value = mock_container
322
+ mock_docker_client.return_value.images.get.return_value = MagicMock()
323
+
324
+ mock_post.return_value.status_code = 201
325
+ mock_post.return_value.json.return_value = {"id": "test-kernel-id"}
326
+
327
+ # Create executor
328
+ executor = DockerExecutor(additional_imports=[], logger=logger, build_new_image=False)
329
+
330
+ # Call cleanup
331
+ executor.cleanup()
332
+
333
+ # Verify container was stopped and removed
334
+ mock_container.stop.assert_called_once()
335
+ mock_container.remove.assert_called_once()
tests/test_search.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from smolagents import DuckDuckGoSearchTool
18
+
19
+ from .test_tools import ToolTesterMixin
20
+ from .utils.markers import require_run_all
21
+
22
+
23
+ class TestDuckDuckGoSearchTool(ToolTesterMixin):
24
+ def setup_method(self):
25
+ self.tool = DuckDuckGoSearchTool()
26
+ self.tool.setup()
27
+
28
+ @require_run_all
29
+ def test_exact_match_arg(self):
30
+ result = self.tool("Agents")
31
+ assert isinstance(result, str)
32
+
33
+ @require_run_all
34
+ def test_agent_type_output(self):
35
+ super().test_agent_type_output()
tests/test_tool_validation.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ from textwrap import dedent
3
+
4
+ import pytest
5
+
6
+ from smolagents.default_tools import (
7
+ DuckDuckGoSearchTool,
8
+ GoogleSearchTool,
9
+ SpeechToTextTool,
10
+ VisitWebpageTool,
11
+ WebSearchTool,
12
+ )
13
+ from smolagents.tool_validation import MethodChecker, validate_tool_attributes
14
+ from smolagents.tools import Tool, tool
15
+
16
+
17
+ UNDEFINED_VARIABLE = "undefined_variable"
18
+
19
+
20
+ @pytest.mark.parametrize(
21
+ "tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool, WebSearchTool]
22
+ )
23
+ def test_validate_tool_attributes_with_default_tools(tool_class):
24
+ assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool"
25
+
26
+
27
+ class ValidTool(Tool):
28
+ name = "valid_tool"
29
+ description = "A valid tool"
30
+ inputs = {"input": {"type": "string", "description": "input"}}
31
+ output_type = "string"
32
+ simple_attr = "string"
33
+ dict_attr = {"key": "value"}
34
+
35
+ def __init__(self, optional_param="default"):
36
+ super().__init__()
37
+ self.param = optional_param
38
+
39
+ def forward(self, input: str) -> str:
40
+ return input.upper()
41
+
42
+
43
+ @tool
44
+ def valid_tool_function(input: str) -> str:
45
+ """A valid tool function.
46
+
47
+ Args:
48
+ input (str): Input string.
49
+ """
50
+ return input.upper()
51
+
52
+
53
+ @pytest.mark.parametrize("tool_class", [ValidTool, valid_tool_function.__class__])
54
+ def test_validate_tool_attributes_valid(tool_class):
55
+ assert validate_tool_attributes(tool_class) is None
56
+
57
+
58
+ class InvalidToolName(Tool):
59
+ name = "invalid tool name"
60
+ description = "Tool with invalid name"
61
+ inputs = {"input": {"type": "string", "description": "input"}}
62
+ output_type = "string"
63
+
64
+ def __init__(self):
65
+ super().__init__()
66
+
67
+ def forward(self, input: str) -> str:
68
+ return input
69
+
70
+
71
+ class InvalidToolComplexAttrs(Tool):
72
+ name = "invalid_tool"
73
+ description = "Tool with complex class attributes"
74
+ inputs = {"input": {"type": "string", "description": "input"}}
75
+ output_type = "string"
76
+ complex_attr = [x for x in range(3)] # Complex class attribute
77
+
78
+ def __init__(self):
79
+ super().__init__()
80
+
81
+ def forward(self, input: str) -> str:
82
+ return input
83
+
84
+
85
+ class InvalidToolRequiredParams(Tool):
86
+ name = "invalid_tool"
87
+ description = "Tool with required params"
88
+ inputs = {"input": {"type": "string", "description": "input"}}
89
+ output_type = "string"
90
+
91
+ def __init__(self, required_param, kwarg1=1): # No default value
92
+ super().__init__()
93
+ self.param = required_param
94
+
95
+ def forward(self, input: str) -> str:
96
+ return input
97
+
98
+
99
+ class InvalidToolNonLiteralDefaultParam(Tool):
100
+ name = "invalid_tool"
101
+ description = "Tool with non-literal default parameter value"
102
+ inputs = {"input": {"type": "string", "description": "input"}}
103
+ output_type = "string"
104
+
105
+ def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal
106
+ super().__init__()
107
+ self.default_param = default_param
108
+
109
+ def forward(self, input: str) -> str:
110
+ return input
111
+
112
+
113
+ class InvalidToolUndefinedNames(Tool):
114
+ name = "invalid_tool"
115
+ description = "Tool with undefined names"
116
+ inputs = {"input": {"type": "string", "description": "input"}}
117
+ output_type = "string"
118
+
119
+ def forward(self, input: str) -> str:
120
+ return UNDEFINED_VARIABLE # Undefined name
121
+
122
+
123
+ @pytest.mark.parametrize(
124
+ "tool_class, expected_error",
125
+ [
126
+ (
127
+ InvalidToolName,
128
+ "Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found 'invalid tool name'",
129
+ ),
130
+ (InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"),
131
+ (InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"),
132
+ (
133
+ InvalidToolNonLiteralDefaultParam,
134
+ "Parameters in __init__ must have literal default values, found non-literal defaults",
135
+ ),
136
+ (InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"),
137
+ ],
138
+ )
139
+ def test_validate_tool_attributes_exceptions(tool_class, expected_error):
140
+ with pytest.raises(ValueError, match=expected_error):
141
+ validate_tool_attributes(tool_class)
142
+
143
+
144
+ class MultipleAssignmentsTool(Tool):
145
+ name = "multiple_assignments_tool"
146
+ description = "Tool with multiple assignments"
147
+ inputs = {"input": {"type": "string", "description": "input"}}
148
+ output_type = "string"
149
+
150
+ def __init__(self):
151
+ super().__init__()
152
+
153
+ def forward(self, input: str) -> str:
154
+ a, b = "1", "2"
155
+ return a + b
156
+
157
+
158
+ def test_validate_tool_attributes_multiple_assignments():
159
+ validate_tool_attributes(MultipleAssignmentsTool)
160
+
161
+
162
+ @tool
163
+ def tool_function_with_multiple_assignments(input: str) -> str:
164
+ """A valid tool function.
165
+
166
+ Args:
167
+ input (str): Input string.
168
+ """
169
+ a, b = "1", "2"
170
+ return input.upper() + a + b
171
+
172
+
173
+ @pytest.mark.parametrize("tool_instance", [MultipleAssignmentsTool(), tool_function_with_multiple_assignments])
174
+ def test_tool_to_dict_validation_with_multiple_assignments(tool_instance):
175
+ tool_instance.to_dict()
176
+
177
+
178
+ class TestMethodChecker:
179
+ def test_multiple_assignments(self):
180
+ source_code = dedent(
181
+ """
182
+ def forward(self) -> str:
183
+ a, b = "1", "2"
184
+ return a + b
185
+ """
186
+ )
187
+ method_checker = MethodChecker(set())
188
+ method_checker.visit(ast.parse(source_code))
189
+ assert method_checker.errors == []
tests/test_tools.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import inspect
16
+ import os
17
+ from textwrap import dedent
18
+ from typing import Any, Literal
19
+ from unittest.mock import MagicMock, patch
20
+
21
+ import mcp
22
+ import numpy as np
23
+ import PIL.Image
24
+ import pytest
25
+
26
+ from smolagents.agent_types import _AGENT_TYPE_MAPPING
27
+ from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, launch_gradio_demo, tool, validate_tool_arguments
28
+
29
+ from .utils.markers import require_run_all
30
+
31
+
32
+ class ToolTesterMixin:
33
+ def test_inputs_output(self):
34
+ assert hasattr(self.tool, "inputs")
35
+ assert hasattr(self.tool, "output_type")
36
+
37
+ inputs = self.tool.inputs
38
+ assert isinstance(inputs, dict)
39
+
40
+ for _, input_spec in inputs.items():
41
+ assert "type" in input_spec
42
+ assert "description" in input_spec
43
+ assert input_spec["type"] in AUTHORIZED_TYPES
44
+ assert isinstance(input_spec["description"], str)
45
+
46
+ output_type = self.tool.output_type
47
+ assert output_type in AUTHORIZED_TYPES
48
+
49
+ def test_common_attributes(self):
50
+ assert hasattr(self.tool, "description")
51
+ assert hasattr(self.tool, "name")
52
+ assert hasattr(self.tool, "inputs")
53
+ assert hasattr(self.tool, "output_type")
54
+
55
+ def test_agent_type_output(self, create_inputs):
56
+ inputs = create_inputs(self.tool.inputs)
57
+ output = self.tool(**inputs, sanitize_inputs_outputs=True)
58
+ if self.tool.output_type != "any":
59
+ agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
60
+ assert isinstance(output, agent_type)
61
+
62
+ @pytest.fixture
63
+ def create_inputs(self, shared_datadir):
64
+ def _create_inputs(tool_inputs: dict[str, dict[str | type, str]]) -> dict[str, Any]:
65
+ inputs = {}
66
+
67
+ for input_name, input_desc in tool_inputs.items():
68
+ input_type = input_desc["type"]
69
+
70
+ if input_type == "string":
71
+ inputs[input_name] = "Text input"
72
+ elif input_type == "image":
73
+ inputs[input_name] = PIL.Image.open(shared_datadir / "000000039769.png").resize((512, 512))
74
+ elif input_type == "audio":
75
+ inputs[input_name] = np.ones(3000)
76
+ else:
77
+ raise ValueError(f"Invalid type requested: {input_type}")
78
+
79
+ return inputs
80
+
81
+ return _create_inputs
82
+
83
+
84
+ class TestTool:
85
+ def test_tool_init_with_decorator(self):
86
+ @tool
87
+ def coolfunc(a: str, b: int) -> float:
88
+ """Cool function
89
+
90
+ Args:
91
+ a: The first argument
92
+ b: The second one
93
+ """
94
+ return b + 2, a
95
+
96
+ assert coolfunc.output_type == "number"
97
+
98
+ def test_tool_init_vanilla(self):
99
+ class HFModelDownloadsTool(Tool):
100
+ name = "model_download_counter"
101
+ description = """
102
+ This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
103
+ It returns the name of the checkpoint."""
104
+
105
+ inputs = {
106
+ "task": {
107
+ "type": "string",
108
+ "description": "the task category (such as text-classification, depth-estimation, etc)",
109
+ }
110
+ }
111
+ output_type = "string"
112
+
113
+ def forward(self, task: str) -> str:
114
+ return "best model"
115
+
116
+ tool = HFModelDownloadsTool()
117
+ assert list(tool.inputs.keys())[0] == "task"
118
+
119
+ def test_tool_init_decorator_raises_issues(self):
120
+ with pytest.raises(Exception) as e:
121
+
122
+ @tool
123
+ def coolfunc(a: str, b: int):
124
+ """Cool function
125
+
126
+ Args:
127
+ a: The first argument
128
+ b: The second one
129
+ """
130
+ return a + b
131
+
132
+ assert coolfunc.output_type == "number"
133
+ assert "Tool return type not found" in str(e)
134
+
135
+ with pytest.raises(Exception) as e:
136
+
137
+ @tool
138
+ def coolfunc(a: str, b: int) -> int:
139
+ """Cool function
140
+
141
+ Args:
142
+ a: The first argument
143
+ """
144
+ return b + a
145
+
146
+ assert coolfunc.output_type == "number"
147
+ assert "docstring has no description for the argument" in str(e)
148
+
149
+ def test_saving_tool_raises_error_imports_outside_function(self, tmp_path):
150
+ with pytest.raises(Exception) as e:
151
+ import numpy as np
152
+
153
+ @tool
154
+ def get_current_time() -> str:
155
+ """
156
+ Gets the current time.
157
+ """
158
+ return str(np.random.random())
159
+
160
+ get_current_time.save(tmp_path)
161
+
162
+ assert "np" in str(e)
163
+
164
+ # Also test with classic definition
165
+ with pytest.raises(Exception) as e:
166
+
167
+ class GetCurrentTimeTool(Tool):
168
+ name = "get_current_time_tool"
169
+ description = "Gets the current time"
170
+ inputs = {}
171
+ output_type = "string"
172
+
173
+ def forward(self):
174
+ return str(np.random.random())
175
+
176
+ get_current_time = GetCurrentTimeTool()
177
+ get_current_time.save(tmp_path)
178
+
179
+ assert "np" in str(e)
180
+
181
+ def test_tool_definition_raises_no_error_imports_in_function(self):
182
+ @tool
183
+ def get_current_time() -> str:
184
+ """
185
+ Gets the current time.
186
+ """
187
+ from datetime import datetime
188
+
189
+ return str(datetime.now())
190
+
191
+ class GetCurrentTimeTool(Tool):
192
+ name = "get_current_time_tool"
193
+ description = "Gets the current time"
194
+ inputs = {}
195
+ output_type = "string"
196
+
197
+ def forward(self):
198
+ from datetime import datetime
199
+
200
+ return str(datetime.now())
201
+
202
+ def test_tool_to_dict_allows_no_arg_in_init(self):
203
+ """Test that a tool cannot be saved with required args in init"""
204
+
205
+ class FailTool(Tool):
206
+ name = "specific"
207
+ description = "test description"
208
+ inputs = {"string_input": {"type": "string", "description": "input description"}}
209
+ output_type = "string"
210
+
211
+ def __init__(self, url):
212
+ super().__init__(self)
213
+ self.url = url
214
+
215
+ def forward(self, string_input: str) -> str:
216
+ return self.url + string_input
217
+
218
+ fail_tool = FailTool("dummy_url")
219
+ with pytest.raises(Exception) as e:
220
+ fail_tool.to_dict()
221
+ assert "Parameters in __init__ must have default values, found required parameters" in str(e)
222
+
223
+ class PassTool(Tool):
224
+ name = "specific"
225
+ description = "test description"
226
+ inputs = {"string_input": {"type": "string", "description": "input description"}}
227
+ output_type = "string"
228
+
229
+ def __init__(self, url: str | None = "none"):
230
+ super().__init__(self)
231
+ self.url = url
232
+
233
+ def forward(self, string_input: str) -> str:
234
+ return self.url + string_input
235
+
236
+ fail_tool = PassTool()
237
+ fail_tool.to_dict()
238
+
239
+ def test_saving_tool_allows_no_imports_from_outside_methods(self, tmp_path):
240
+ # Test that using imports from outside functions fails
241
+ import numpy as np
242
+
243
+ class FailTool(Tool):
244
+ name = "specific"
245
+ description = "test description"
246
+ inputs = {"string_input": {"type": "string", "description": "input description"}}
247
+ output_type = "string"
248
+
249
+ def useless_method(self):
250
+ self.client = np.random.random()
251
+ return ""
252
+
253
+ def forward(self, string_input):
254
+ return self.useless_method() + string_input
255
+
256
+ fail_tool = FailTool()
257
+ with pytest.raises(Exception) as e:
258
+ fail_tool.save(tmp_path)
259
+ assert "'np' is undefined" in str(e)
260
+
261
+ # Test that putting these imports inside functions works
262
+ class SuccessTool(Tool):
263
+ name = "specific"
264
+ description = "test description"
265
+ inputs = {"string_input": {"type": "string", "description": "input description"}}
266
+ output_type = "string"
267
+
268
+ def useless_method(self):
269
+ import numpy as np
270
+
271
+ self.client = np.random.random()
272
+ return ""
273
+
274
+ def forward(self, string_input):
275
+ return self.useless_method() + string_input
276
+
277
+ success_tool = SuccessTool()
278
+ success_tool.save(tmp_path)
279
+
280
+ def test_tool_missing_class_attributes_raises_error(self):
281
+ with pytest.raises(Exception) as e:
282
+
283
+ class GetWeatherTool(Tool):
284
+ name = "get_weather"
285
+ description = "Get weather in the next days at given location."
286
+ inputs = {
287
+ "location": {"type": "string", "description": "the location"},
288
+ "celsius": {
289
+ "type": "string",
290
+ "description": "the temperature type",
291
+ },
292
+ }
293
+
294
+ def forward(self, location: str, celsius: bool | None = False) -> str:
295
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
296
+
297
+ GetWeatherTool()
298
+ assert "You must set an attribute output_type" in str(e)
299
+
300
+ def test_tool_from_decorator_optional_args(self):
301
+ @tool
302
+ def get_weather(location: str, celsius: bool | None = False) -> str:
303
+ """
304
+ Get weather in the next days at given location.
305
+ Secretly this tool does not care about the location, it hates the weather everywhere.
306
+
307
+ Args:
308
+ location: the location
309
+ celsius: the temperature type
310
+ """
311
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
312
+
313
+ assert "nullable" in get_weather.inputs["celsius"]
314
+ assert get_weather.inputs["celsius"]["nullable"]
315
+ assert "nullable" not in get_weather.inputs["location"]
316
+
317
+ def test_tool_mismatching_nullable_args_raises_error(self):
318
+ with pytest.raises(Exception) as e:
319
+
320
+ class GetWeatherTool(Tool):
321
+ name = "get_weather"
322
+ description = "Get weather in the next days at given location."
323
+ inputs = {
324
+ "location": {"type": "string", "description": "the location"},
325
+ "celsius": {
326
+ "type": "string",
327
+ "description": "the temperature type",
328
+ },
329
+ }
330
+ output_type = "string"
331
+
332
+ def forward(self, location: str, celsius: bool | None = False) -> str:
333
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
334
+
335
+ GetWeatherTool()
336
+ assert "Nullable" in str(e)
337
+
338
+ with pytest.raises(Exception) as e:
339
+
340
+ class GetWeatherTool2(Tool):
341
+ name = "get_weather"
342
+ description = "Get weather in the next days at given location."
343
+ inputs = {
344
+ "location": {"type": "string", "description": "the location"},
345
+ "celsius": {
346
+ "type": "string",
347
+ "description": "the temperature type",
348
+ },
349
+ }
350
+ output_type = "string"
351
+
352
+ def forward(self, location: str, celsius: bool = False) -> str:
353
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
354
+
355
+ GetWeatherTool2()
356
+ assert "Nullable" in str(e)
357
+
358
+ with pytest.raises(Exception) as e:
359
+
360
+ class GetWeatherTool3(Tool):
361
+ name = "get_weather"
362
+ description = "Get weather in the next days at given location."
363
+ inputs = {
364
+ "location": {"type": "string", "description": "the location"},
365
+ "celsius": {
366
+ "type": "string",
367
+ "description": "the temperature type",
368
+ "nullable": True,
369
+ },
370
+ }
371
+ output_type = "string"
372
+
373
+ def forward(self, location, celsius: str) -> str:
374
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
375
+
376
+ GetWeatherTool3()
377
+ assert "Nullable" in str(e)
378
+
379
+ def test_tool_default_parameters_is_nullable(self):
380
+ @tool
381
+ def get_weather(location: str, celsius: bool = False) -> str:
382
+ """
383
+ Get weather in the next days at given location.
384
+
385
+ Args:
386
+ location: The location to get the weather for.
387
+ celsius: is the temperature given in celsius?
388
+ """
389
+ return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
390
+
391
+ assert get_weather.inputs["celsius"]["nullable"]
392
+
393
+ def test_tool_supports_any_none(self, tmp_path):
394
+ @tool
395
+ def get_weather(location: Any) -> None:
396
+ """
397
+ Get weather in the next days at given location.
398
+
399
+ Args:
400
+ location: The location to get the weather for.
401
+ """
402
+ return
403
+
404
+ get_weather.save(tmp_path)
405
+ assert get_weather.inputs["location"]["type"] == "any"
406
+ assert get_weather.output_type == "null"
407
+
408
+ def test_tool_supports_array(self):
409
+ @tool
410
+ def get_weather(locations: list[str], months: tuple[str, str] | None = None) -> dict[str, float]:
411
+ """
412
+ Get weather in the next days at given locations.
413
+
414
+ Args:
415
+ locations: The locations to get the weather for.
416
+ months: The months to get the weather for
417
+ """
418
+ return
419
+
420
+ assert get_weather.inputs["locations"]["type"] == "array"
421
+ assert get_weather.inputs["months"]["type"] == "array"
422
+
423
+ def test_tool_supports_string_literal(self):
424
+ @tool
425
+ def get_weather(unit: Literal["celsius", "fahrenheit"] = "celsius") -> None:
426
+ """
427
+ Get weather in the next days at given location.
428
+
429
+ Args:
430
+ unit: The unit of temperature
431
+ """
432
+ return
433
+
434
+ assert get_weather.inputs["unit"]["type"] == "string"
435
+ assert get_weather.inputs["unit"]["enum"] == ["celsius", "fahrenheit"]
436
+
437
+ def test_tool_supports_numeric_literal(self):
438
+ @tool
439
+ def get_choice(choice: Literal[1, 2, 3]) -> None:
440
+ """
441
+ Get choice based on the provided numeric literal.
442
+
443
+ Args:
444
+ choice: The numeric choice to be made.
445
+ """
446
+ return
447
+
448
+ assert get_choice.inputs["choice"]["type"] == "integer"
449
+ assert get_choice.inputs["choice"]["enum"] == [1, 2, 3]
450
+
451
+ def test_tool_supports_nullable_literal(self):
452
+ @tool
453
+ def get_choice(choice: Literal[1, 2, 3, None]) -> None:
454
+ """
455
+ Get choice based on the provided value.
456
+
457
+ Args:
458
+ choice: The numeric choice to be made.
459
+ """
460
+ return
461
+
462
+ assert get_choice.inputs["choice"]["type"] == "integer"
463
+ assert get_choice.inputs["choice"]["nullable"] is True
464
+ assert get_choice.inputs["choice"]["enum"] == [1, 2, 3]
465
+
466
+ def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self, tmp_path):
467
+ @tool
468
+ def get_weather(location: Any) -> None:
469
+ """
470
+ Get weather in the next days at given location.
471
+ And works pretty well.
472
+
473
+ Args:
474
+ location: The location to get the weather for.
475
+ """
476
+ return
477
+
478
+ get_weather.save(tmp_path)
479
+ with open(os.path.join(tmp_path, "tool.py"), "r", encoding="utf-8") as f:
480
+ source_code = f.read()
481
+ compile(source_code, f.name, "exec")
482
+
483
+ @pytest.mark.parametrize("fixture_name", ["boolean_default_tool_class", "boolean_default_tool_function"])
484
+ def test_to_dict_boolean_default_input(self, fixture_name, request):
485
+ """Test that boolean input parameter with default value is correctly represented in to_dict output"""
486
+ tool = request.getfixturevalue(fixture_name)
487
+ result = tool.to_dict()
488
+ # Check that the boolean default annotation is preserved
489
+ assert "flag: bool = False" in result["code"]
490
+ # Check nullable attribute is set for the parameter with default value
491
+ assert "'nullable': True" in result["code"]
492
+
493
+ @pytest.mark.parametrize("fixture_name", ["optional_input_tool_class", "optional_input_tool_function"])
494
+ def test_to_dict_optional_input(self, fixture_name, request):
495
+ """Test that Optional/nullable input parameter is correctly represented in to_dict output"""
496
+ tool = request.getfixturevalue(fixture_name)
497
+ result = tool.to_dict()
498
+ # Check the Optional type annotation is preserved
499
+ assert "optional_text: str | None = None" in result["code"]
500
+ # Check that the input is marked as nullable in the code
501
+ assert "'nullable': True" in result["code"]
502
+
503
+ def test_from_dict_roundtrip(self, example_tool):
504
+ # Convert to dict
505
+ tool_dict = example_tool.to_dict()
506
+ # Create from dict
507
+ recreated_tool = Tool.from_dict(tool_dict)
508
+ # Verify properties
509
+ assert recreated_tool.name == example_tool.name
510
+ assert recreated_tool.description == example_tool.description
511
+ assert recreated_tool.inputs == example_tool.inputs
512
+ assert recreated_tool.output_type == example_tool.output_type
513
+ # Verify functionality
514
+ test_input = "Hello, world!"
515
+ assert recreated_tool(test_input) == test_input.upper()
516
+
517
+ def test_tool_from_dict_invalid(self):
518
+ # Missing code key
519
+ with pytest.raises(ValueError) as e:
520
+ Tool.from_dict({"name": "invalid_tool"})
521
+ assert "must contain 'code' key" in str(e)
522
+
523
+ def test_tool_decorator_preserves_original_function(self):
524
+ # Define a test function with type hints and docstring
525
+ def test_function(items: list[str]) -> str:
526
+ """Join a list of strings.
527
+ Args:
528
+ items: A list of strings to join
529
+ Returns:
530
+ The joined string
531
+ """
532
+ return ", ".join(items)
533
+
534
+ # Store original function signature, name, and source
535
+ original_signature = inspect.signature(test_function)
536
+ original_name = test_function.__name__
537
+ original_docstring = test_function.__doc__
538
+
539
+ # Create a tool from the function
540
+ test_tool = tool(test_function)
541
+
542
+ # Check that the original function is unchanged
543
+ assert original_signature == inspect.signature(test_function)
544
+ assert original_name == test_function.__name__
545
+ assert original_docstring == test_function.__doc__
546
+
547
+ # Verify that the tool's forward method has a different signature (it has 'self')
548
+ tool_forward_sig = inspect.signature(test_tool.forward)
549
+ assert list(tool_forward_sig.parameters.keys())[0] == "self"
550
+
551
+ # Original function should not have 'self' parameter
552
+ assert "self" not in original_signature.parameters
553
+
554
+ def test_tool_with_union_type_return(self):
555
+ @tool
556
+ def union_type_return_tool_function(param: int) -> str | bool:
557
+ """
558
+ Tool with output union type.
559
+
560
+ Args:
561
+ param: Input parameter.
562
+ """
563
+ return str(param) if param > 0 else False
564
+
565
+ assert isinstance(union_type_return_tool_function, Tool)
566
+ assert union_type_return_tool_function.output_type == "any"
567
+
568
+
569
+ @pytest.fixture
570
+ def mock_server_parameters():
571
+ return MagicMock()
572
+
573
+
574
+ @pytest.fixture
575
+ def mock_mcp_adapt():
576
+ with patch("mcpadapt.core.MCPAdapt") as mock:
577
+ mock.return_value.__enter__.return_value = ["tool1", "tool2"]
578
+ mock.return_value.__exit__.return_value = None
579
+ yield mock
580
+
581
+
582
+ @pytest.fixture
583
+ def mock_smolagents_adapter():
584
+ with patch("mcpadapt.smolagents_adapter.SmolAgentsAdapter") as mock:
585
+ yield mock
586
+
587
+
588
+ class TestToolCollection:
589
+ def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
590
+ with ToolCollection.from_mcp(mock_server_parameters, trust_remote_code=True) as tool_collection:
591
+ assert isinstance(tool_collection, ToolCollection)
592
+ assert len(tool_collection.tools) == 2
593
+ assert "tool1" in tool_collection.tools
594
+ assert "tool2" in tool_collection.tools
595
+
596
+ @require_run_all
597
+ def test_integration_from_mcp(self):
598
+ # define the most simple mcp server with one tool that echoes the input text
599
+ mcp_server_script = dedent("""\
600
+ from mcp.server.fastmcp import FastMCP
601
+
602
+ mcp = FastMCP("Echo Server")
603
+
604
+ @mcp.tool()
605
+ def echo_tool(text: str) -> str:
606
+ return text
607
+
608
+ mcp.run()
609
+ """).strip()
610
+
611
+ mcp_server_params = mcp.StdioServerParameters(
612
+ command="python",
613
+ args=["-c", mcp_server_script],
614
+ )
615
+
616
+ with ToolCollection.from_mcp(mcp_server_params, trust_remote_code=True) as tool_collection:
617
+ assert len(tool_collection.tools) == 1, "Expected 1 tool"
618
+ assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
619
+ assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
620
+
621
+ def test_integration_from_mcp_with_streamable_http(self):
622
+ import subprocess
623
+ import time
624
+
625
+ # define the most simple mcp server with one tool that echoes the input text
626
+ mcp_server_script = dedent("""\
627
+ from mcp.server.fastmcp import FastMCP
628
+
629
+ mcp = FastMCP("Echo Server", host="127.0.0.1", port=8000)
630
+
631
+ @mcp.tool()
632
+ def echo_tool(text: str) -> str:
633
+ return text
634
+
635
+ mcp.run(transport="streamable-http")
636
+ """).strip()
637
+
638
+ # start the SSE mcp server in a subprocess
639
+ server_process = subprocess.Popen(
640
+ ["python", "-c", mcp_server_script],
641
+ )
642
+
643
+ # wait for the server to start
644
+ time.sleep(1)
645
+
646
+ try:
647
+ with ToolCollection.from_mcp(
648
+ {"url": "http://127.0.0.1:8000/mcp", "transport": "streamable-http"}, trust_remote_code=True
649
+ ) as tool_collection:
650
+ assert len(tool_collection.tools) == 1, "Expected 1 tool"
651
+ assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
652
+ assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
653
+ finally:
654
+ # clean up the process when test is done
655
+ server_process.kill()
656
+ server_process.wait()
657
+
658
+ def test_integration_from_mcp_with_sse(self):
659
+ import subprocess
660
+ import time
661
+
662
+ # define the most simple mcp server with one tool that echoes the input text
663
+ mcp_server_script = dedent("""\
664
+ from mcp.server.fastmcp import FastMCP
665
+
666
+ mcp = FastMCP("Echo Server", host="127.0.0.1", port=8000)
667
+
668
+ @mcp.tool()
669
+ def echo_tool(text: str) -> str:
670
+ return text
671
+
672
+ mcp.run("sse")
673
+ """).strip()
674
+
675
+ # start the SSE mcp server in a subprocess
676
+ server_process = subprocess.Popen(
677
+ ["python", "-c", mcp_server_script],
678
+ )
679
+
680
+ # wait for the server to start
681
+ time.sleep(1)
682
+
683
+ try:
684
+ with ToolCollection.from_mcp(
685
+ {"url": "http://127.0.0.1:8000/sse", "transport": "sse"}, trust_remote_code=True
686
+ ) as tool_collection:
687
+ assert len(tool_collection.tools) == 1, "Expected 1 tool"
688
+ assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
689
+ assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
690
+ finally:
691
+ # clean up the process when test is done
692
+ server_process.kill()
693
+ server_process.wait()
694
+
695
+
696
+ @pytest.mark.parametrize("tool_fixture_name", ["boolean_default_tool_class"])
697
+ def test_launch_gradio_demo_does_not_raise(tool_fixture_name, request):
698
+ tool = request.getfixturevalue(tool_fixture_name)
699
+ with patch("gradio.Interface.launch") as mock_launch:
700
+ launch_gradio_demo(tool)
701
+ assert mock_launch.call_count == 1
702
+
703
+
704
+ @pytest.mark.parametrize(
705
+ "tool_input_type, expected_input, expects_error",
706
+ [
707
+ (bool, True, False),
708
+ (str, "b", False),
709
+ (int, 1, False),
710
+ (list, ["a", "b"], False),
711
+ (list[str], ["a", "b"], False),
712
+ (dict[str, str], {"a": "b"}, False),
713
+ (dict[str, str], "b", True),
714
+ (bool, "b", True),
715
+ ],
716
+ )
717
+ def test_validate_tool_arguments(tool_input_type, expected_input, expects_error):
718
+ @tool
719
+ def test_tool(argument_a: tool_input_type) -> str:
720
+ """Fake tool
721
+
722
+ Args:
723
+ argument_a: The input
724
+ """
725
+ return argument_a
726
+
727
+ error = validate_tool_arguments(test_tool, {"argument_a": expected_input})
728
+ if expects_error:
729
+ assert error is not None
730
+ else:
731
+ assert error is None
tests/test_types.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+ import uuid
19
+
20
+ import PIL.Image
21
+ from transformers.testing_utils import (
22
+ require_soundfile,
23
+ )
24
+
25
+ from smolagents.agent_types import AgentAudio, AgentImage, AgentText
26
+
27
+ from .utils.markers import require_torch
28
+
29
+
30
+ def get_new_path(suffix="") -> str:
31
+ directory = tempfile.mkdtemp()
32
+ return os.path.join(directory, str(uuid.uuid4()) + suffix)
33
+
34
+
35
+ @require_soundfile
36
+ @require_torch
37
+ class AgentAudioTests(unittest.TestCase):
38
+ def test_from_tensor(self):
39
+ import soundfile as sf
40
+ import torch
41
+
42
+ tensor = torch.rand(12, dtype=torch.float64) - 0.5
43
+ agent_type = AgentAudio(tensor)
44
+ path = str(agent_type.to_string())
45
+
46
+ # Ensure that the tensor and the agent_type's tensor are the same
47
+ self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
48
+
49
+ del agent_type
50
+
51
+ # Ensure the path remains even after the object deletion
52
+ self.assertTrue(os.path.exists(path))
53
+
54
+ # Ensure that the file contains the same value as the original tensor
55
+ new_tensor, _ = sf.read(path)
56
+ self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
57
+
58
+ def test_from_string(self):
59
+ import soundfile as sf
60
+ import torch
61
+
62
+ tensor = torch.rand(12, dtype=torch.float64) - 0.5
63
+ path = get_new_path(suffix=".wav")
64
+ sf.write(path, tensor, 16000)
65
+
66
+ agent_type = AgentAudio(path)
67
+
68
+ self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
69
+ self.assertEqual(agent_type.to_string(), path)
70
+
71
+
72
+ @require_torch
73
+ class TestAgentImage:
74
+ def test_from_tensor(self):
75
+ import torch
76
+
77
+ tensor = torch.randint(0, 256, (64, 64, 3))
78
+ agent_type = AgentImage(tensor)
79
+ path = str(agent_type.to_string())
80
+
81
+ # Ensure that the tensor and the agent_type's tensor are the same
82
+ assert torch.allclose(tensor, agent_type._tensor, atol=1e-4)
83
+
84
+ assert isinstance(agent_type.to_raw(), PIL.Image.Image)
85
+
86
+ # Ensure the path remains even after the object deletion
87
+ del agent_type
88
+ assert os.path.exists(path)
89
+
90
+ def test_from_string(self, shared_datadir):
91
+ path = shared_datadir / "000000039769.png"
92
+ image = PIL.Image.open(path)
93
+ agent_type = AgentImage(path)
94
+
95
+ assert path.samefile(agent_type.to_string())
96
+ assert image == agent_type.to_raw()
97
+
98
+ # Ensure the path remains even after the object deletion
99
+ del agent_type
100
+ assert os.path.exists(path)
101
+
102
+ def test_from_image(self, shared_datadir):
103
+ path = shared_datadir / "000000039769.png"
104
+ image = PIL.Image.open(path)
105
+ agent_type = AgentImage(image)
106
+
107
+ assert not path.samefile(agent_type.to_string())
108
+ assert image == agent_type.to_raw()
109
+
110
+ # Ensure the path remains even after the object deletion
111
+ del agent_type
112
+ assert os.path.exists(path)
113
+
114
+
115
+ class AgentTextTests(unittest.TestCase):
116
+ def test_from_string(self):
117
+ string = "Hey!"
118
+ agent_type = AgentText(string)
119
+
120
+ self.assertEqual(string, agent_type.to_string())
121
+ self.assertEqual(string, agent_type.to_raw())
tests/test_utils.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import inspect
16
+ import os
17
+ import textwrap
18
+ import unittest
19
+
20
+ import pytest
21
+ from IPython.core.interactiveshell import InteractiveShell
22
+
23
+ from smolagents import Tool
24
+ from smolagents.tools import tool
25
+ from smolagents.utils import get_source, instance_to_source, is_valid_name, parse_code_blobs, parse_json_blob
26
+
27
+
28
+ class ValidTool(Tool):
29
+ name = "valid_tool"
30
+ description = "A valid tool"
31
+ inputs = {"input": {"type": "string", "description": "input"}}
32
+ output_type = "string"
33
+ simple_attr = "string"
34
+ dict_attr = {"key": "value"}
35
+
36
+ def __init__(self, optional_param="default"):
37
+ super().__init__()
38
+ self.param = optional_param
39
+
40
+ def forward(self, input: str) -> str:
41
+ return input.upper()
42
+
43
+
44
+ @tool
45
+ def valid_tool_function(input: str) -> str:
46
+ """A valid tool function.
47
+
48
+ Args:
49
+ input (str): Input string.
50
+ """
51
+ return input.upper()
52
+
53
+
54
+ VALID_TOOL_SOURCE = """\
55
+ from smolagents.tools import Tool
56
+
57
+ class ValidTool(Tool):
58
+ name = "valid_tool"
59
+ description = "A valid tool"
60
+ inputs = {'input': {'type': 'string', 'description': 'input'}}
61
+ output_type = "string"
62
+ simple_attr = "string"
63
+ dict_attr = {'key': 'value'}
64
+
65
+ def __init__(self, optional_param="default"):
66
+ super().__init__()
67
+ self.param = optional_param
68
+
69
+ def forward(self, input: str) -> str:
70
+ return input.upper()
71
+ """
72
+
73
+ VALID_TOOL_FUNCTION_SOURCE = '''\
74
+ from smolagents.tools import Tool
75
+
76
+ class SimpleTool(Tool):
77
+ name = "valid_tool_function"
78
+ description = "A valid tool function."
79
+ inputs = {'input': {'type': 'string', 'description': 'Input string.'}}
80
+ output_type = "string"
81
+
82
+ def __init__(self):
83
+ self.is_initialized = True
84
+
85
+ def forward(self, input: str) -> str:
86
+ """A valid tool function.
87
+
88
+ Args:
89
+ input (str): Input string.
90
+ """
91
+ return input.upper()
92
+ '''
93
+
94
+
95
+ class AgentTextTests(unittest.TestCase):
96
+ def test_parse_code_blobs(self):
97
+ with pytest.raises(ValueError):
98
+ parse_code_blobs("Wrong blob!")
99
+
100
+ # Parsing mardkwon with code blobs should work
101
+ output = parse_code_blobs("""
102
+ Here is how to solve the problem:
103
+ <code>
104
+ import numpy as np
105
+ </code>
106
+ """)
107
+ assert output == "import numpy as np"
108
+
109
+ # Parsing code blobs should work
110
+ code_blob = "import numpy as np"
111
+ output = parse_code_blobs(code_blob)
112
+ assert output == code_blob
113
+
114
+ # Allow whitespaces after header
115
+ output = parse_code_blobs("<code> \ncode_a\n</code>")
116
+ assert output == "code_a"
117
+
118
+ def test_multiple_code_blobs(self):
119
+ test_input = "<code>\nFoo\n</code>\n\n<code>\ncode_a\n</code>\n\n<code>\ncode_b\n</code>"
120
+ result = parse_code_blobs(test_input)
121
+ assert result == "Foo\n\ncode_a\n\ncode_b"
122
+
123
+
124
+ @pytest.fixture(scope="function")
125
+ def ipython_shell():
126
+ """Reset IPython shell before and after each test."""
127
+ shell = InteractiveShell.instance()
128
+ shell.reset() # Clean before test
129
+ yield shell
130
+ shell.reset() # Clean after test
131
+
132
+
133
+ @pytest.mark.parametrize(
134
+ "obj_name, code_blob",
135
+ [
136
+ ("test_func", "def test_func():\n return 42"),
137
+ ("TestClass", "class TestClass:\n ..."),
138
+ ],
139
+ )
140
+ def test_get_source_ipython(ipython_shell, obj_name, code_blob):
141
+ ipython_shell.run_cell(code_blob, store_history=True)
142
+ obj = ipython_shell.user_ns[obj_name]
143
+ assert get_source(obj) == code_blob
144
+
145
+
146
+ def test_get_source_standard_class():
147
+ class TestClass: ...
148
+
149
+ source = get_source(TestClass)
150
+ assert source == "class TestClass: ..."
151
+ assert source == textwrap.dedent(inspect.getsource(TestClass)).strip()
152
+
153
+
154
+ def test_get_source_standard_function():
155
+ def test_func(): ...
156
+
157
+ source = get_source(test_func)
158
+ assert source == "def test_func(): ..."
159
+ assert source == textwrap.dedent(inspect.getsource(test_func)).strip()
160
+
161
+
162
+ def test_get_source_ipython_errors_empty_cells(ipython_shell):
163
+ test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
164
+ ipython_shell.user_ns["In"] = [""]
165
+ ipython_shell.run_cell(test_code, store_history=True)
166
+ with pytest.raises(ValueError, match="No code cells found in IPython session"):
167
+ get_source(ipython_shell.user_ns["TestClass"])
168
+
169
+
170
+ def test_get_source_ipython_errors_definition_not_found(ipython_shell):
171
+ test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
172
+ ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
173
+ ipython_shell.run_cell(test_code, store_history=True)
174
+ with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"):
175
+ get_source(ipython_shell.user_ns["TestClass"])
176
+
177
+
178
+ def test_get_source_ipython_errors_type_error():
179
+ with pytest.raises(TypeError, match="Expected class or callable"):
180
+ get_source(None)
181
+
182
+
183
+ @pytest.mark.parametrize(
184
+ "tool, expected_tool_source", [(ValidTool(), VALID_TOOL_SOURCE), (valid_tool_function, VALID_TOOL_FUNCTION_SOURCE)]
185
+ )
186
+ def test_instance_to_source(tool, expected_tool_source):
187
+ tool_source = instance_to_source(tool, base_cls=Tool)
188
+ assert tool_source == expected_tool_source
189
+
190
+
191
+ def test_e2e_class_tool_save(tmp_path):
192
+ class TestTool(Tool):
193
+ name = "test_tool"
194
+ description = "Test tool description"
195
+ inputs = {
196
+ "task": {
197
+ "type": "string",
198
+ "description": "tool input",
199
+ }
200
+ }
201
+ output_type = "string"
202
+
203
+ def forward(self, task: str):
204
+ import IPython # noqa: F401
205
+
206
+ return task
207
+
208
+ test_tool = TestTool()
209
+ test_tool.save(tmp_path, make_gradio_app=True)
210
+ assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
211
+ assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
212
+ """\
213
+ from typing import Any, Optional
214
+ from smolagents.tools import Tool
215
+ import IPython
216
+
217
+ class TestTool(Tool):
218
+ name = "test_tool"
219
+ description = "Test tool description"
220
+ inputs = {'task': {'type': 'string', 'description': 'tool input'}}
221
+ output_type = "string"
222
+
223
+ def forward(self, task: str):
224
+ import IPython # noqa: F401
225
+
226
+ return task
227
+
228
+ def __init__(self, *args, **kwargs):
229
+ self.is_initialized = False
230
+ """
231
+ )
232
+ requirements = set((tmp_path / "requirements.txt").read_text().split())
233
+ assert requirements == {"IPython", "smolagents"}
234
+ assert (tmp_path / "app.py").read_text() == textwrap.dedent(
235
+ """\
236
+ from smolagents import launch_gradio_demo
237
+ from tool import TestTool
238
+
239
+ tool = TestTool()
240
+ launch_gradio_demo(tool)
241
+ """
242
+ )
243
+
244
+
245
+ def test_e2e_ipython_class_tool_save(tmp_path):
246
+ shell = InteractiveShell.instance()
247
+ code_blob = textwrap.dedent(
248
+ f"""\
249
+ from smolagents.tools import Tool
250
+ class TestTool(Tool):
251
+ name = "test_tool"
252
+ description = "Test tool description"
253
+ inputs = {{"task": {{"type": "string",
254
+ "description": "tool input",
255
+ }}
256
+ }}
257
+ output_type = "string"
258
+
259
+ def forward(self, task: str):
260
+ import IPython # noqa: F401
261
+
262
+ return task
263
+ TestTool().save("{tmp_path}", make_gradio_app=True)
264
+ """
265
+ )
266
+ assert shell.run_cell(code_blob, store_history=True).success
267
+ assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
268
+ assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
269
+ """\
270
+ from typing import Any, Optional
271
+ from smolagents.tools import Tool
272
+ import IPython
273
+
274
+ class TestTool(Tool):
275
+ name = "test_tool"
276
+ description = "Test tool description"
277
+ inputs = {'task': {'type': 'string', 'description': 'tool input'}}
278
+ output_type = "string"
279
+
280
+ def forward(self, task: str):
281
+ import IPython # noqa: F401
282
+
283
+ return task
284
+
285
+ def __init__(self, *args, **kwargs):
286
+ self.is_initialized = False
287
+ """
288
+ )
289
+ requirements = set((tmp_path / "requirements.txt").read_text().split())
290
+ assert requirements == {"IPython", "smolagents"}
291
+ assert (tmp_path / "app.py").read_text() == textwrap.dedent(
292
+ """\
293
+ from smolagents import launch_gradio_demo
294
+ from tool import TestTool
295
+
296
+ tool = TestTool()
297
+ launch_gradio_demo(tool)
298
+ """
299
+ )
300
+
301
+
302
+ def test_e2e_function_tool_save(tmp_path):
303
+ @tool
304
+ def test_tool(task: str) -> str:
305
+ """
306
+ Test tool description
307
+
308
+ Args:
309
+ task: tool input
310
+ """
311
+ import IPython # noqa: F401
312
+
313
+ return task
314
+
315
+ test_tool.save(tmp_path, make_gradio_app=True)
316
+ assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
317
+ assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
318
+ """\
319
+ from smolagents import Tool
320
+ from typing import Any, Optional
321
+
322
+ class SimpleTool(Tool):
323
+ name = "test_tool"
324
+ description = "Test tool description"
325
+ inputs = {'task': {'type': 'string', 'description': 'tool input'}}
326
+ output_type = "string"
327
+
328
+ def forward(self, task: str) -> str:
329
+ \"""
330
+ Test tool description
331
+
332
+ Args:
333
+ task: tool input
334
+ \"""
335
+ import IPython # noqa: F401
336
+
337
+ return task"""
338
+ )
339
+ requirements = set((tmp_path / "requirements.txt").read_text().split())
340
+ assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
341
+ assert (tmp_path / "app.py").read_text() == textwrap.dedent(
342
+ """\
343
+ from smolagents import launch_gradio_demo
344
+ from tool import SimpleTool
345
+
346
+ tool = SimpleTool()
347
+ launch_gradio_demo(tool)
348
+ """
349
+ )
350
+
351
+
352
+ def test_e2e_ipython_function_tool_save(tmp_path):
353
+ shell = InteractiveShell.instance()
354
+ code_blob = textwrap.dedent(
355
+ f"""
356
+ from smolagents import tool
357
+
358
+ @tool
359
+ def test_tool(task: str) -> str:
360
+ \"""
361
+ Test tool description
362
+
363
+ Args:
364
+ task: tool input
365
+ \"""
366
+ import IPython # noqa: F401
367
+
368
+ return task
369
+
370
+ test_tool.save("{tmp_path}", make_gradio_app=True)
371
+ """
372
+ )
373
+ assert shell.run_cell(code_blob, store_history=True).success
374
+ assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
375
+ assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
376
+ """\
377
+ from smolagents import Tool
378
+ from typing import Any, Optional
379
+
380
+ class SimpleTool(Tool):
381
+ name = "test_tool"
382
+ description = "Test tool description"
383
+ inputs = {'task': {'type': 'string', 'description': 'tool input'}}
384
+ output_type = "string"
385
+
386
+ def forward(self, task: str) -> str:
387
+ \"""
388
+ Test tool description
389
+
390
+ Args:
391
+ task: tool input
392
+ \"""
393
+ import IPython # noqa: F401
394
+
395
+ return task"""
396
+ )
397
+ requirements = set((tmp_path / "requirements.txt").read_text().split())
398
+ assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
399
+ assert (tmp_path / "app.py").read_text() == textwrap.dedent(
400
+ """\
401
+ from smolagents import launch_gradio_demo
402
+ from tool import SimpleTool
403
+
404
+ tool = SimpleTool()
405
+ launch_gradio_demo(tool)
406
+ """
407
+ )
408
+
409
+
410
+ @pytest.mark.parametrize(
411
+ "raw_json, expected_data, expected_blob",
412
+ [
413
+ (
414
+ """{}""",
415
+ {},
416
+ "",
417
+ ),
418
+ (
419
+ """Text{}""",
420
+ {},
421
+ "Text",
422
+ ),
423
+ (
424
+ """{"simple": "json"}""",
425
+ {"simple": "json"},
426
+ "",
427
+ ),
428
+ (
429
+ """With text here{"simple": "json"}""",
430
+ {"simple": "json"},
431
+ "With text here",
432
+ ),
433
+ (
434
+ """{"simple": "json"}With text after""",
435
+ {"simple": "json"},
436
+ "",
437
+ ),
438
+ (
439
+ """With text before{"simple": "json"}And text after""",
440
+ {"simple": "json"},
441
+ "With text before",
442
+ ),
443
+ ],
444
+ )
445
+ def test_parse_json_blob_with_valid_json(raw_json, expected_data, expected_blob):
446
+ data, blob = parse_json_blob(raw_json)
447
+
448
+ assert data == expected_data
449
+ assert blob == expected_blob
450
+
451
+
452
+ @pytest.mark.parametrize(
453
+ "raw_json",
454
+ [
455
+ """simple": "json"}""",
456
+ """With text here"simple": "json"}""",
457
+ """{"simple": ""json"}With text after""",
458
+ """{"simple": "json"With text after""",
459
+ "}}",
460
+ ],
461
+ )
462
+ def test_parse_json_blob_with_invalid_json(raw_json):
463
+ with pytest.raises(Exception):
464
+ parse_json_blob(raw_json)
465
+
466
+
467
+ @pytest.mark.parametrize(
468
+ "name,expected",
469
+ [
470
+ # Valid identifiers
471
+ ("valid_name", True),
472
+ ("ValidName", True),
473
+ ("valid123", True),
474
+ ("_private", True),
475
+ # Invalid identifiers
476
+ ("", False),
477
+ ("123invalid", False),
478
+ ("invalid-name", False),
479
+ ("invalid name", False),
480
+ ("invalid.name", False),
481
+ # Python keywords
482
+ ("if", False),
483
+ ("for", False),
484
+ ("class", False),
485
+ ("return", False),
486
+ # Non-string inputs
487
+ (123, False),
488
+ (None, False),
489
+ ([], False),
490
+ ({}, False),
491
+ ],
492
+ )
493
+ def test_is_valid_name(name, expected):
494
+ """Test the is_valid_name function with various inputs."""
495
+ assert is_valid_name(name) is expected