Upload 21 files
Browse files- tests/__init__.py +0 -0
- tests/conftest.py +23 -0
- tests/test_agents.py +2089 -0
- tests/test_all_docs.py +176 -0
- tests/test_cli.py +112 -0
- tests/test_default_tools.py +134 -0
- tests/test_final_answer.py +56 -0
- tests/test_function_type_hints_utils.py +514 -0
- tests/test_gradio_ui.py +385 -0
- tests/test_import.py +31 -0
- tests/test_local_python_executor.py +2353 -0
- tests/test_mcp_client.py +60 -0
- tests/test_memory.py +228 -0
- tests/test_models.py +763 -0
- tests/test_monitoring.py +252 -0
- tests/test_remote_executors.py +335 -0
- tests/test_search.py +35 -0
- tests/test_tool_validation.py +189 -0
- tests/test_tools.py +731 -0
- tests/test_types.py +121 -0
- tests/test_utils.py +495 -0
tests/__init__.py
ADDED
File without changes
|
tests/conftest.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest.mock import patch
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
|
5 |
+
from smolagents.agents import MultiStepAgent
|
6 |
+
from smolagents.monitoring import LogLevel
|
7 |
+
|
8 |
+
|
9 |
+
# Import fixture modules as plugins
|
10 |
+
pytest_plugins = ["tests.fixtures.agents", "tests.fixtures.tools"]
|
11 |
+
|
12 |
+
original_multi_step_agent_init = MultiStepAgent.__init__
|
13 |
+
|
14 |
+
|
15 |
+
@pytest.fixture(autouse=True)
|
16 |
+
def patch_multi_step_agent_with_suppressed_logging():
|
17 |
+
with patch.object(MultiStepAgent, "__init__", autospec=True) as mock_init:
|
18 |
+
|
19 |
+
def init_with_suppressed_logging(self, *args, verbosity_level=LogLevel.OFF, **kwargs):
|
20 |
+
original_multi_step_agent_init(self, *args, verbosity_level=verbosity_level, **kwargs)
|
21 |
+
|
22 |
+
mock_init.side_effect = init_with_suppressed_logging
|
23 |
+
yield
|
tests/test_agents.py
ADDED
@@ -0,0 +1,2089 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import io
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import tempfile
|
19 |
+
import uuid
|
20 |
+
import warnings
|
21 |
+
from collections.abc import Generator
|
22 |
+
from contextlib import nullcontext as does_not_raise
|
23 |
+
from dataclasses import dataclass
|
24 |
+
from pathlib import Path
|
25 |
+
from textwrap import dedent
|
26 |
+
from typing import Optional
|
27 |
+
from unittest.mock import MagicMock, patch
|
28 |
+
|
29 |
+
import pytest
|
30 |
+
from huggingface_hub import (
|
31 |
+
ChatCompletionOutputFunctionDefinition,
|
32 |
+
ChatCompletionOutputMessage,
|
33 |
+
ChatCompletionOutputToolCall,
|
34 |
+
)
|
35 |
+
from rich.console import Console
|
36 |
+
|
37 |
+
from smolagents import EMPTY_PROMPT_TEMPLATES
|
38 |
+
from smolagents.agent_types import AgentImage, AgentText
|
39 |
+
from smolagents.agents import (
|
40 |
+
AgentError,
|
41 |
+
AgentMaxStepsError,
|
42 |
+
AgentToolCallError,
|
43 |
+
CodeAgent,
|
44 |
+
MultiStepAgent,
|
45 |
+
ToolCall,
|
46 |
+
ToolCallingAgent,
|
47 |
+
ToolOutput,
|
48 |
+
populate_template,
|
49 |
+
)
|
50 |
+
from smolagents.default_tools import DuckDuckGoSearchTool, FinalAnswerTool, PythonInterpreterTool, VisitWebpageTool
|
51 |
+
from smolagents.memory import (
|
52 |
+
ActionStep,
|
53 |
+
PlanningStep,
|
54 |
+
TaskStep,
|
55 |
+
)
|
56 |
+
from smolagents.models import (
|
57 |
+
ChatMessage,
|
58 |
+
ChatMessageToolCall,
|
59 |
+
ChatMessageToolCallFunction,
|
60 |
+
InferenceClientModel,
|
61 |
+
MessageRole,
|
62 |
+
Model,
|
63 |
+
TransformersModel,
|
64 |
+
)
|
65 |
+
from smolagents.monitoring import AgentLogger, LogLevel, TokenUsage
|
66 |
+
from smolagents.tools import Tool, tool
|
67 |
+
from smolagents.utils import (
|
68 |
+
BASE_BUILTIN_MODULES,
|
69 |
+
AgentExecutionError,
|
70 |
+
AgentGenerationError,
|
71 |
+
AgentToolExecutionError,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class ChoiceDeltaToolCallFunction:
|
77 |
+
arguments: Optional[str] = None
|
78 |
+
name: Optional[str] = None
|
79 |
+
|
80 |
+
|
81 |
+
@dataclass
|
82 |
+
class ChoiceDeltaToolCall:
|
83 |
+
index: Optional[int] = None
|
84 |
+
id: Optional[str] = None
|
85 |
+
function: Optional[ChoiceDeltaToolCallFunction] = None
|
86 |
+
type: Optional[str] = None
|
87 |
+
|
88 |
+
|
89 |
+
@dataclass
|
90 |
+
class ChoiceDelta:
|
91 |
+
content: Optional[str] = None
|
92 |
+
function_call: Optional[str] = None
|
93 |
+
refusal: Optional[str] = None
|
94 |
+
role: Optional[str] = None
|
95 |
+
tool_calls: Optional[list] = None
|
96 |
+
|
97 |
+
|
98 |
+
def get_new_path(suffix="") -> str:
|
99 |
+
directory = tempfile.mkdtemp()
|
100 |
+
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
101 |
+
|
102 |
+
|
103 |
+
@pytest.fixture
|
104 |
+
def agent_logger():
|
105 |
+
return AgentLogger(
|
106 |
+
LogLevel.DEBUG, console=Console(record=True, no_color=True, force_terminal=False, file=io.StringIO())
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
class FakeToolCallModel(Model):
|
111 |
+
def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
|
112 |
+
if len(messages) < 3:
|
113 |
+
return ChatMessage(
|
114 |
+
role=MessageRole.ASSISTANT,
|
115 |
+
content="",
|
116 |
+
tool_calls=[
|
117 |
+
ChatMessageToolCall(
|
118 |
+
id="call_0",
|
119 |
+
type="function",
|
120 |
+
function=ChatMessageToolCallFunction(
|
121 |
+
name="python_interpreter", arguments={"code": "2*3.6452"}
|
122 |
+
),
|
123 |
+
)
|
124 |
+
],
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
return ChatMessage(
|
128 |
+
role=MessageRole.ASSISTANT,
|
129 |
+
content="",
|
130 |
+
tool_calls=[
|
131 |
+
ChatMessageToolCall(
|
132 |
+
id="call_1",
|
133 |
+
type="function",
|
134 |
+
function=ChatMessageToolCallFunction(name="final_answer", arguments={"answer": "7.2904"}),
|
135 |
+
)
|
136 |
+
],
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
class FakeToolCallModelImage(Model):
|
141 |
+
def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
|
142 |
+
if len(messages) < 3:
|
143 |
+
return ChatMessage(
|
144 |
+
role=MessageRole.ASSISTANT,
|
145 |
+
content="",
|
146 |
+
tool_calls=[
|
147 |
+
ChatMessageToolCall(
|
148 |
+
id="call_0",
|
149 |
+
type="function",
|
150 |
+
function=ChatMessageToolCallFunction(
|
151 |
+
name="fake_image_generation_tool",
|
152 |
+
arguments={"prompt": "An image of a cat"},
|
153 |
+
),
|
154 |
+
)
|
155 |
+
],
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
return ChatMessage(
|
159 |
+
role=MessageRole.ASSISTANT,
|
160 |
+
content="",
|
161 |
+
tool_calls=[
|
162 |
+
ChatMessageToolCall(
|
163 |
+
id="call_1",
|
164 |
+
type="function",
|
165 |
+
function=ChatMessageToolCallFunction(name="final_answer", arguments="image.png"),
|
166 |
+
)
|
167 |
+
],
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
class FakeToolCallModelVL(Model):
|
172 |
+
def generate(self, messages, tools_to_call_from=None, stop_sequences=None):
|
173 |
+
if len(messages) < 3:
|
174 |
+
return ChatMessage(
|
175 |
+
role=MessageRole.ASSISTANT,
|
176 |
+
content="",
|
177 |
+
tool_calls=[
|
178 |
+
ChatMessageToolCall(
|
179 |
+
id="call_0",
|
180 |
+
type="function",
|
181 |
+
function=ChatMessageToolCallFunction(
|
182 |
+
name="fake_image_understanding_tool",
|
183 |
+
arguments={
|
184 |
+
"prompt": "What is in this image?",
|
185 |
+
"image": "image.png",
|
186 |
+
},
|
187 |
+
),
|
188 |
+
)
|
189 |
+
],
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
return ChatMessage(
|
193 |
+
role=MessageRole.ASSISTANT,
|
194 |
+
content="",
|
195 |
+
tool_calls=[
|
196 |
+
ChatMessageToolCall(
|
197 |
+
id="call_1",
|
198 |
+
type="function",
|
199 |
+
function=ChatMessageToolCallFunction(name="final_answer", arguments="The image is a cat."),
|
200 |
+
)
|
201 |
+
],
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
class FakeCodeModel(Model):
|
206 |
+
def generate(self, messages, stop_sequences=None):
|
207 |
+
prompt = str(messages)
|
208 |
+
if "special_marker" not in prompt:
|
209 |
+
return ChatMessage(
|
210 |
+
role=MessageRole.ASSISTANT,
|
211 |
+
content="""
|
212 |
+
Thought: I should multiply 2 by 3.6452. special_marker
|
213 |
+
<code>
|
214 |
+
result = 2**3.6452
|
215 |
+
</code>
|
216 |
+
""",
|
217 |
+
)
|
218 |
+
else: # We're at step 2
|
219 |
+
return ChatMessage(
|
220 |
+
role=MessageRole.ASSISTANT,
|
221 |
+
content="""
|
222 |
+
Thought: I can now answer the initial question
|
223 |
+
<code>
|
224 |
+
final_answer(7.2904)
|
225 |
+
</code>
|
226 |
+
""",
|
227 |
+
)
|
228 |
+
|
229 |
+
|
230 |
+
class FakeCodeModelPlanning(Model):
|
231 |
+
def generate(self, messages, stop_sequences=None):
|
232 |
+
prompt = str(messages)
|
233 |
+
if "planning_marker" not in prompt:
|
234 |
+
return ChatMessage(
|
235 |
+
role=MessageRole.ASSISTANT,
|
236 |
+
content="llm plan update planning_marker",
|
237 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=10),
|
238 |
+
)
|
239 |
+
elif "action_marker" not in prompt:
|
240 |
+
return ChatMessage(
|
241 |
+
role=MessageRole.ASSISTANT,
|
242 |
+
content="""
|
243 |
+
Thought: I should multiply 2 by 3.6452. action_marker
|
244 |
+
<code>
|
245 |
+
result = 2**3.6452
|
246 |
+
</code>
|
247 |
+
""",
|
248 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=10),
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
return ChatMessage(
|
252 |
+
role=MessageRole.ASSISTANT,
|
253 |
+
content="llm plan again",
|
254 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=10),
|
255 |
+
)
|
256 |
+
|
257 |
+
|
258 |
+
class FakeCodeModelError(Model):
|
259 |
+
def generate(self, messages, stop_sequences=None):
|
260 |
+
prompt = str(messages)
|
261 |
+
if "special_marker" not in prompt:
|
262 |
+
return ChatMessage(
|
263 |
+
role=MessageRole.ASSISTANT,
|
264 |
+
content="""
|
265 |
+
Thought: I should multiply 2 by 3.6452. special_marker
|
266 |
+
<code>
|
267 |
+
print("Flag!")
|
268 |
+
def error_function():
|
269 |
+
raise ValueError("error")
|
270 |
+
|
271 |
+
error_function()
|
272 |
+
</code>
|
273 |
+
""",
|
274 |
+
)
|
275 |
+
else: # We're at step 2
|
276 |
+
return ChatMessage(
|
277 |
+
role=MessageRole.ASSISTANT,
|
278 |
+
content="""
|
279 |
+
Thought: I faced an error in the previous step.
|
280 |
+
<code>
|
281 |
+
final_answer("got an error")
|
282 |
+
</code>
|
283 |
+
""",
|
284 |
+
)
|
285 |
+
|
286 |
+
|
287 |
+
class FakeCodeModelSyntaxError(Model):
|
288 |
+
def generate(self, messages, stop_sequences=None):
|
289 |
+
prompt = str(messages)
|
290 |
+
if "special_marker" not in prompt:
|
291 |
+
return ChatMessage(
|
292 |
+
role=MessageRole.ASSISTANT,
|
293 |
+
content="""
|
294 |
+
Thought: I should multiply 2 by 3.6452. special_marker
|
295 |
+
<code>
|
296 |
+
a = 2
|
297 |
+
b = a * 2
|
298 |
+
print("Failing due to unexpected indent")
|
299 |
+
print("Ok, calculation done!")
|
300 |
+
</code>
|
301 |
+
""",
|
302 |
+
)
|
303 |
+
else: # We're at step 2
|
304 |
+
return ChatMessage(
|
305 |
+
role=MessageRole.ASSISTANT,
|
306 |
+
content="""
|
307 |
+
Thought: I can now answer the initial question
|
308 |
+
<code>
|
309 |
+
final_answer("got an error")
|
310 |
+
</code>
|
311 |
+
""",
|
312 |
+
)
|
313 |
+
|
314 |
+
|
315 |
+
class FakeCodeModelImport(Model):
|
316 |
+
def generate(self, messages, stop_sequences=None):
|
317 |
+
return ChatMessage(
|
318 |
+
role=MessageRole.ASSISTANT,
|
319 |
+
content="""
|
320 |
+
Thought: I can answer the question
|
321 |
+
<code>
|
322 |
+
import numpy as np
|
323 |
+
final_answer("got an error")
|
324 |
+
</code>
|
325 |
+
""",
|
326 |
+
)
|
327 |
+
|
328 |
+
|
329 |
+
class FakeCodeModelFunctionDef(Model):
|
330 |
+
def generate(self, messages, stop_sequences=None):
|
331 |
+
prompt = str(messages)
|
332 |
+
if "special_marker" not in prompt:
|
333 |
+
return ChatMessage(
|
334 |
+
role=MessageRole.ASSISTANT,
|
335 |
+
content="""
|
336 |
+
Thought: Let's define the function. special_marker
|
337 |
+
<code>
|
338 |
+
import numpy as np
|
339 |
+
|
340 |
+
def moving_average(x, w):
|
341 |
+
return np.convolve(x, np.ones(w), 'valid') / w
|
342 |
+
</code>
|
343 |
+
""",
|
344 |
+
)
|
345 |
+
else: # We're at step 2
|
346 |
+
return ChatMessage(
|
347 |
+
role=MessageRole.ASSISTANT,
|
348 |
+
content="""
|
349 |
+
Thought: I can now answer the initial question
|
350 |
+
<code>
|
351 |
+
x, w = [0, 1, 2, 3, 4, 5], 2
|
352 |
+
res = moving_average(x, w)
|
353 |
+
final_answer(res)
|
354 |
+
</code>
|
355 |
+
""",
|
356 |
+
)
|
357 |
+
|
358 |
+
|
359 |
+
class FakeCodeModelSingleStep(Model):
|
360 |
+
def generate(self, messages, stop_sequences=None):
|
361 |
+
return ChatMessage(
|
362 |
+
role=MessageRole.ASSISTANT,
|
363 |
+
content="""
|
364 |
+
Thought: I should multiply 2 by 3.6452. special_marker
|
365 |
+
<code>
|
366 |
+
result = python_interpreter(code="2*3.6452")
|
367 |
+
final_answer(result)
|
368 |
+
```
|
369 |
+
""",
|
370 |
+
)
|
371 |
+
|
372 |
+
|
373 |
+
class FakeCodeModelNoReturn(Model):
|
374 |
+
def generate(self, messages, stop_sequences=None):
|
375 |
+
return ChatMessage(
|
376 |
+
role=MessageRole.ASSISTANT,
|
377 |
+
content="""
|
378 |
+
Thought: I should multiply 2 by 3.6452. special_marker
|
379 |
+
<code>
|
380 |
+
result = python_interpreter(code="2*3.6452")
|
381 |
+
print(result)
|
382 |
+
```
|
383 |
+
""",
|
384 |
+
)
|
385 |
+
|
386 |
+
|
387 |
+
class TestAgent:
|
388 |
+
def test_fake_toolcalling_agent(self):
|
389 |
+
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
|
390 |
+
output = agent.run("What is 2 multiplied by 3.6452?")
|
391 |
+
assert isinstance(output, str)
|
392 |
+
assert "7.2904" in output
|
393 |
+
assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
|
394 |
+
assert "7.2904" in agent.memory.steps[1].observations
|
395 |
+
assert (
|
396 |
+
agent.memory.steps[2].model_output
|
397 |
+
== "Tool call call_1: calling 'final_answer' with arguments: {'answer': '7.2904'}"
|
398 |
+
)
|
399 |
+
|
400 |
+
def test_toolcalling_agent_handles_image_tool_outputs(self, shared_datadir):
|
401 |
+
import PIL.Image
|
402 |
+
|
403 |
+
@tool
|
404 |
+
def fake_image_generation_tool(prompt: str) -> PIL.Image.Image:
|
405 |
+
"""Tool that generates an image.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
prompt: The prompt
|
409 |
+
"""
|
410 |
+
|
411 |
+
import PIL.Image
|
412 |
+
|
413 |
+
return PIL.Image.open(shared_datadir / "000000039769.png")
|
414 |
+
|
415 |
+
agent = ToolCallingAgent(
|
416 |
+
tools=[fake_image_generation_tool], model=FakeToolCallModelImage(), verbosity_level=10
|
417 |
+
)
|
418 |
+
output = agent.run("Make me an image.")
|
419 |
+
assert isinstance(output, AgentImage)
|
420 |
+
assert isinstance(agent.state["image.png"], PIL.Image.Image)
|
421 |
+
|
422 |
+
def test_toolcalling_agent_handles_image_inputs(self, shared_datadir):
|
423 |
+
import PIL.Image
|
424 |
+
|
425 |
+
image = PIL.Image.open(shared_datadir / "000000039769.png") # dummy input
|
426 |
+
|
427 |
+
@tool
|
428 |
+
def fake_image_understanding_tool(prompt: str, image: PIL.Image.Image) -> str:
|
429 |
+
"""Tool that creates a caption for an image.
|
430 |
+
|
431 |
+
Args:
|
432 |
+
prompt: The prompt
|
433 |
+
image: The image
|
434 |
+
"""
|
435 |
+
return "The image is a cat."
|
436 |
+
|
437 |
+
agent = ToolCallingAgent(tools=[fake_image_understanding_tool], model=FakeToolCallModelVL())
|
438 |
+
output = agent.run("Caption this image.", images=[image])
|
439 |
+
assert output == "The image is a cat."
|
440 |
+
|
441 |
+
def test_fake_code_agent(self):
|
442 |
+
agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel(), verbosity_level=10)
|
443 |
+
output = agent.run("What is 2 multiplied by 3.6452?")
|
444 |
+
assert isinstance(output, float)
|
445 |
+
assert output == 7.2904
|
446 |
+
assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
|
447 |
+
assert agent.memory.steps[2].tool_calls == [
|
448 |
+
ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_2")
|
449 |
+
]
|
450 |
+
|
451 |
+
def test_additional_args_added_to_task(self):
|
452 |
+
agent = CodeAgent(tools=[], model=FakeCodeModel())
|
453 |
+
agent.run(
|
454 |
+
"What is 2 multiplied by 3.6452?",
|
455 |
+
additional_args={"instruction": "Remember this."},
|
456 |
+
)
|
457 |
+
assert "Remember this" in agent.task
|
458 |
+
|
459 |
+
def test_reset_conversations(self):
|
460 |
+
agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel())
|
461 |
+
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
|
462 |
+
assert output == 7.2904
|
463 |
+
assert len(agent.memory.steps) == 3
|
464 |
+
|
465 |
+
output = agent.run("What is 2 multiplied by 3.6452?", reset=False)
|
466 |
+
assert output == 7.2904
|
467 |
+
assert len(agent.memory.steps) == 5
|
468 |
+
|
469 |
+
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
|
470 |
+
assert output == 7.2904
|
471 |
+
assert len(agent.memory.steps) == 3
|
472 |
+
|
473 |
+
def test_setup_agent_with_empty_toolbox(self):
|
474 |
+
ToolCallingAgent(model=FakeToolCallModel(), tools=[])
|
475 |
+
|
476 |
+
def test_fails_max_steps(self):
|
477 |
+
agent = CodeAgent(
|
478 |
+
tools=[PythonInterpreterTool()],
|
479 |
+
model=FakeCodeModelNoReturn(), # use this callable because it never ends
|
480 |
+
max_steps=5,
|
481 |
+
)
|
482 |
+
answer = agent.run("What is 2 multiplied by 3.6452?")
|
483 |
+
assert len(agent.memory.steps) == 7 # Task step + 5 action steps + Final answer
|
484 |
+
assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
|
485 |
+
assert isinstance(answer, str)
|
486 |
+
|
487 |
+
agent = CodeAgent(
|
488 |
+
tools=[PythonInterpreterTool()],
|
489 |
+
model=FakeCodeModelNoReturn(), # use this callable because it never ends
|
490 |
+
max_steps=5,
|
491 |
+
)
|
492 |
+
answer = agent.run("What is 2 multiplied by 3.6452?", max_steps=3)
|
493 |
+
assert len(agent.memory.steps) == 5 # Task step + 3 action steps + Final answer
|
494 |
+
assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
|
495 |
+
assert isinstance(answer, str)
|
496 |
+
|
497 |
+
def test_tool_descriptions_get_baked_in_system_prompt(self):
|
498 |
+
tool = PythonInterpreterTool()
|
499 |
+
tool.name = "fake_tool_name"
|
500 |
+
tool.description = "fake_tool_description"
|
501 |
+
agent = CodeAgent(tools=[tool], model=FakeCodeModel())
|
502 |
+
agent.run("Empty task")
|
503 |
+
assert agent.system_prompt is not None
|
504 |
+
assert f"def {tool.name}(" in agent.system_prompt
|
505 |
+
assert f'"""{tool.description}' in agent.system_prompt
|
506 |
+
|
507 |
+
def test_module_imports_get_baked_in_system_prompt(self):
|
508 |
+
agent = CodeAgent(tools=[], model=FakeCodeModel())
|
509 |
+
agent.run("Empty task")
|
510 |
+
for module in BASE_BUILTIN_MODULES:
|
511 |
+
assert module in agent.system_prompt
|
512 |
+
|
513 |
+
def test_init_agent_with_different_toolsets(self):
|
514 |
+
toolset_1 = []
|
515 |
+
agent = CodeAgent(tools=toolset_1, model=FakeCodeModel())
|
516 |
+
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
|
517 |
+
|
518 |
+
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
519 |
+
with pytest.raises(ValueError) as e:
|
520 |
+
agent = CodeAgent(tools=toolset_2, model=FakeCodeModel())
|
521 |
+
assert "Each tool or managed_agent should have a unique name!" in str(e)
|
522 |
+
|
523 |
+
with pytest.raises(ValueError) as e:
|
524 |
+
agent.name = "python_interpreter"
|
525 |
+
agent.description = "empty"
|
526 |
+
CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModel(), managed_agents=[agent])
|
527 |
+
assert "Each tool or managed_agent should have a unique name!" in str(e)
|
528 |
+
|
529 |
+
# check that python_interpreter base tool does not get added to CodeAgent
|
530 |
+
agent = CodeAgent(tools=[], model=FakeCodeModel(), add_base_tools=True)
|
531 |
+
assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage
|
532 |
+
|
533 |
+
# check that python_interpreter base tool gets added to ToolCallingAgent
|
534 |
+
agent = ToolCallingAgent(tools=[], model=FakeCodeModel(), add_base_tools=True)
|
535 |
+
assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage
|
536 |
+
|
537 |
+
def test_function_persistence_across_steps(self):
|
538 |
+
agent = CodeAgent(
|
539 |
+
tools=[],
|
540 |
+
model=FakeCodeModelFunctionDef(),
|
541 |
+
max_steps=2,
|
542 |
+
additional_authorized_imports=["numpy"],
|
543 |
+
verbosity_level=100,
|
544 |
+
)
|
545 |
+
res = agent.run("ok")
|
546 |
+
assert res[0] == 0.5
|
547 |
+
|
548 |
+
def test_init_managed_agent(self):
|
549 |
+
agent = CodeAgent(tools=[], model=FakeCodeModelFunctionDef(), name="managed_agent", description="Empty")
|
550 |
+
assert agent.name == "managed_agent"
|
551 |
+
assert agent.description == "Empty"
|
552 |
+
|
553 |
+
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
554 |
+
managed_agent = CodeAgent(
|
555 |
+
tools=[], model=FakeCodeModelFunctionDef(), name="managed_agent", description="Empty"
|
556 |
+
)
|
557 |
+
manager_agent = CodeAgent(
|
558 |
+
tools=[],
|
559 |
+
model=FakeCodeModelFunctionDef(),
|
560 |
+
managed_agents=[managed_agent],
|
561 |
+
)
|
562 |
+
assert "You can also give tasks to team members." not in managed_agent.system_prompt
|
563 |
+
assert "{{managed_agents_descriptions}}" not in managed_agent.system_prompt
|
564 |
+
assert "You can also give tasks to team members." in manager_agent.system_prompt
|
565 |
+
|
566 |
+
def test_replay_shows_logs(self, agent_logger):
|
567 |
+
agent = CodeAgent(
|
568 |
+
tools=[],
|
569 |
+
model=FakeCodeModelImport(),
|
570 |
+
verbosity_level=0,
|
571 |
+
additional_authorized_imports=["numpy"],
|
572 |
+
logger=agent_logger,
|
573 |
+
)
|
574 |
+
agent.run("Count to 3")
|
575 |
+
|
576 |
+
str_output = agent_logger.console.export_text()
|
577 |
+
|
578 |
+
assert "New run" in str_output
|
579 |
+
assert 'final_answer("got' in str_output
|
580 |
+
assert "</code>" in str_output
|
581 |
+
|
582 |
+
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel(), verbosity_level=0)
|
583 |
+
agent.logger = agent_logger
|
584 |
+
|
585 |
+
agent.run("What is 2 multiplied by 3.6452?")
|
586 |
+
agent.replay()
|
587 |
+
|
588 |
+
str_output = agent_logger.console.export_text()
|
589 |
+
assert "Tool call" in str_output
|
590 |
+
assert "arguments" in str_output
|
591 |
+
|
592 |
+
def test_code_nontrivial_final_answer_works(self):
|
593 |
+
class FakeCodeModelFinalAnswer(Model):
|
594 |
+
def generate(self, messages, stop_sequences=None):
|
595 |
+
return ChatMessage(
|
596 |
+
role=MessageRole.ASSISTANT,
|
597 |
+
content="""<code>
|
598 |
+
def nested_answer():
|
599 |
+
final_answer("Correct!")
|
600 |
+
|
601 |
+
nested_answer()
|
602 |
+
</code>""",
|
603 |
+
)
|
604 |
+
|
605 |
+
agent = CodeAgent(tools=[], model=FakeCodeModelFinalAnswer())
|
606 |
+
|
607 |
+
output = agent.run("Count to 3")
|
608 |
+
assert output == "Correct!"
|
609 |
+
|
610 |
+
def test_transformers_toolcalling_agent(self):
|
611 |
+
@tool
|
612 |
+
def weather_api(location: str, celsius: str = "") -> str:
|
613 |
+
"""
|
614 |
+
Gets the weather in the next days at given location.
|
615 |
+
Secretly this tool does not care about the location, it hates the weather everywhere.
|
616 |
+
|
617 |
+
Args:
|
618 |
+
location: the location
|
619 |
+
celsius: the temperature type
|
620 |
+
"""
|
621 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
622 |
+
|
623 |
+
model = TransformersModel(
|
624 |
+
model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
|
625 |
+
max_new_tokens=100,
|
626 |
+
device_map="auto",
|
627 |
+
do_sample=False,
|
628 |
+
)
|
629 |
+
agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1, verbosity_level=10)
|
630 |
+
task = "What is the weather in Paris? "
|
631 |
+
agent.run(task)
|
632 |
+
assert agent.memory.steps[0].task == task
|
633 |
+
assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
|
634 |
+
step_memory_dict = agent.memory.get_succinct_steps()[1]
|
635 |
+
assert step_memory_dict["model_output_message"]["tool_calls"][0]["function"]["name"] == "weather_api"
|
636 |
+
assert step_memory_dict["model_output_message"]["raw"]["completion_kwargs"]["max_new_tokens"] == 100
|
637 |
+
assert "model_input_messages" in agent.memory.get_full_steps()[1]
|
638 |
+
assert step_memory_dict["token_usage"]["total_tokens"] > 100
|
639 |
+
assert step_memory_dict["timing"]["duration"] > 0.1
|
640 |
+
|
641 |
+
def test_final_answer_checks(self):
|
642 |
+
error_string = "failed with error"
|
643 |
+
|
644 |
+
def check_always_fails(final_answer, agent_memory):
|
645 |
+
assert False, "Error raised in check"
|
646 |
+
|
647 |
+
agent = CodeAgent(model=FakeCodeModel(), tools=[], final_answer_checks=[check_always_fails])
|
648 |
+
agent.run("Dummy task.")
|
649 |
+
assert error_string in str(agent.write_memory_to_messages())
|
650 |
+
assert "Error raised in check" in str(agent.write_memory_to_messages())
|
651 |
+
|
652 |
+
agent = CodeAgent(
|
653 |
+
model=FakeCodeModel(),
|
654 |
+
tools=[],
|
655 |
+
final_answer_checks=[lambda x, y: x == 7.2904],
|
656 |
+
verbosity_level=1000,
|
657 |
+
)
|
658 |
+
output = agent.run("Dummy task.")
|
659 |
+
assert output == 7.2904 # Check that output is correct
|
660 |
+
assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
|
661 |
+
assert error_string not in str(agent.write_memory_to_messages())
|
662 |
+
|
663 |
+
def test_generation_errors_are_raised(self):
|
664 |
+
class FakeCodeModel(Model):
|
665 |
+
def generate(self, messages, stop_sequences=None):
|
666 |
+
assert False, "Generation failed"
|
667 |
+
|
668 |
+
agent = CodeAgent(model=FakeCodeModel(), tools=[])
|
669 |
+
with pytest.raises(AgentGenerationError) as e:
|
670 |
+
agent.run("Dummy task.")
|
671 |
+
assert len(agent.memory.steps) == 2
|
672 |
+
assert "Generation failed" in str(e)
|
673 |
+
|
674 |
+
def test_planning_step_with_injected_memory(self):
|
675 |
+
"""Test that agent properly uses update plan prompts when memory is injected before a run.
|
676 |
+
|
677 |
+
This test verifies:
|
678 |
+
1. Planning steps are created with the correct frequency
|
679 |
+
2. Injected memory is included in planning context
|
680 |
+
3. Messages are properly formatted with expected roles and content
|
681 |
+
"""
|
682 |
+
planning_interval = 1
|
683 |
+
max_steps = 4
|
684 |
+
task = "Continuous task"
|
685 |
+
previous_task = "Previous user request"
|
686 |
+
|
687 |
+
# Create agent with planning capability
|
688 |
+
agent = CodeAgent(
|
689 |
+
tools=[],
|
690 |
+
planning_interval=planning_interval,
|
691 |
+
model=FakeCodeModelPlanning(),
|
692 |
+
max_steps=max_steps,
|
693 |
+
)
|
694 |
+
|
695 |
+
# Inject memory before run to simulate existing conversation history
|
696 |
+
previous_step = TaskStep(task=previous_task)
|
697 |
+
agent.memory.steps.append(previous_step)
|
698 |
+
|
699 |
+
# Run the agent
|
700 |
+
agent.run(task, reset=False)
|
701 |
+
|
702 |
+
# Extract and validate planning steps
|
703 |
+
planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
|
704 |
+
assert len(planning_steps) > 2, "Expected multiple planning steps to be generated"
|
705 |
+
|
706 |
+
# Verify first planning step incorporates injected memory
|
707 |
+
first_planning_step = planning_steps[0]
|
708 |
+
input_messages = first_planning_step.model_input_messages
|
709 |
+
|
710 |
+
# Check message structure and content
|
711 |
+
assert len(input_messages) == 4, (
|
712 |
+
"First planning step should have 4 messages: system-plan-pre-update + memory + task + user-plan-post-update"
|
713 |
+
)
|
714 |
+
|
715 |
+
# Verify system message contains current task
|
716 |
+
system_message = input_messages[0]
|
717 |
+
assert system_message.role == "system", "First message should have system role"
|
718 |
+
assert task in system_message.content[0]["text"], f"System message should contain the current task: '{task}'"
|
719 |
+
|
720 |
+
# Verify memory message contains previous task
|
721 |
+
memory_message = input_messages[1]
|
722 |
+
assert previous_task in memory_message.content[0]["text"], (
|
723 |
+
f"Memory message should contain previous task: '{previous_task}'"
|
724 |
+
)
|
725 |
+
|
726 |
+
# Verify task message contains current task
|
727 |
+
task_message = input_messages[2]
|
728 |
+
assert task in task_message.content[0]["text"], f"Task message should contain current task: '{task}'"
|
729 |
+
|
730 |
+
# Verify user message for planning
|
731 |
+
user_message = input_messages[3]
|
732 |
+
assert user_message.role == "user", "Fourth message should have user role"
|
733 |
+
|
734 |
+
# Verify second planning step has more context from first agent actions
|
735 |
+
second_planning_step = planning_steps[1]
|
736 |
+
second_messages = second_planning_step.model_input_messages
|
737 |
+
|
738 |
+
# Check that conversation history is growing appropriately
|
739 |
+
assert len(second_messages) == 6, "Second planning step should have 6 messages including tool interactions"
|
740 |
+
|
741 |
+
# Verify all conversation elements are present
|
742 |
+
conversation_text = "".join([msg.content[0]["text"] for msg in second_messages if hasattr(msg, "content")])
|
743 |
+
assert previous_task in conversation_text, "Previous task should be included in the conversation history"
|
744 |
+
assert task in conversation_text, "Current task should be included in the conversation history"
|
745 |
+
assert "tools" in conversation_text, "Tool interactions should be included in the conversation history"
|
746 |
+
|
747 |
+
|
748 |
+
class CustomFinalAnswerTool(FinalAnswerTool):
|
749 |
+
def forward(self, answer) -> str:
|
750 |
+
return answer + "CUSTOM"
|
751 |
+
|
752 |
+
|
753 |
+
class MockTool(Tool):
|
754 |
+
def __init__(self, name):
|
755 |
+
self.name = name
|
756 |
+
self.description = "Mock tool description"
|
757 |
+
self.inputs = {}
|
758 |
+
self.output_type = "string"
|
759 |
+
|
760 |
+
def forward(self):
|
761 |
+
return "Mock tool output"
|
762 |
+
|
763 |
+
|
764 |
+
class MockAgent:
|
765 |
+
def __init__(self, name, tools, description="Mock agent description"):
|
766 |
+
self.name = name
|
767 |
+
self.tools = {t.name: t for t in tools}
|
768 |
+
self.description = description
|
769 |
+
|
770 |
+
|
771 |
+
class DummyMultiStepAgent(MultiStepAgent):
|
772 |
+
def step(self, memory_step: ActionStep) -> Generator[None]:
|
773 |
+
yield None
|
774 |
+
|
775 |
+
def initialize_system_prompt(self):
|
776 |
+
pass
|
777 |
+
|
778 |
+
|
779 |
+
class TestMultiStepAgent:
|
780 |
+
def test_instantiation_disables_logging_to_terminal(self):
|
781 |
+
fake_model = MagicMock()
|
782 |
+
agent = DummyMultiStepAgent(tools=[], model=fake_model)
|
783 |
+
assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
|
784 |
+
|
785 |
+
def test_instantiation_with_prompt_templates(self, prompt_templates):
|
786 |
+
agent = DummyMultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
|
787 |
+
assert agent.prompt_templates == prompt_templates
|
788 |
+
assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."
|
789 |
+
assert "managed_agent" in agent.prompt_templates
|
790 |
+
assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}"
|
791 |
+
assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}"
|
792 |
+
|
793 |
+
@pytest.mark.parametrize(
|
794 |
+
"tools, expected_final_answer_tool",
|
795 |
+
[([], FinalAnswerTool), ([CustomFinalAnswerTool()], CustomFinalAnswerTool)],
|
796 |
+
)
|
797 |
+
def test_instantiation_with_final_answer_tool(self, tools, expected_final_answer_tool):
|
798 |
+
agent = DummyMultiStepAgent(tools=tools, model=MagicMock())
|
799 |
+
assert "final_answer" in agent.tools
|
800 |
+
assert isinstance(agent.tools["final_answer"], expected_final_answer_tool)
|
801 |
+
|
802 |
+
def test_instantiation_with_deprecated_grammar(self):
|
803 |
+
class SimpleAgent(MultiStepAgent):
|
804 |
+
def initialize_system_prompt(self) -> str:
|
805 |
+
return "Test system prompt"
|
806 |
+
|
807 |
+
# Test with a non-None grammar parameter
|
808 |
+
with pytest.warns(
|
809 |
+
FutureWarning, match="Parameter 'grammar' is deprecated and will be removed in version 1.20."
|
810 |
+
):
|
811 |
+
SimpleAgent(tools=[], model=MagicMock(), grammar={"format": "json"}, verbosity_level=LogLevel.DEBUG)
|
812 |
+
|
813 |
+
# Verify no warning when grammar is None
|
814 |
+
with warnings.catch_warnings():
|
815 |
+
warnings.simplefilter("error") # Turn warnings into errors
|
816 |
+
SimpleAgent(tools=[], model=MagicMock(), grammar=None, verbosity_level=LogLevel.DEBUG)
|
817 |
+
|
818 |
+
def test_system_prompt_property(self):
|
819 |
+
"""Test that system_prompt property is read-only and calls initialize_system_prompt."""
|
820 |
+
|
821 |
+
class SimpleAgent(MultiStepAgent):
|
822 |
+
def initialize_system_prompt(self) -> str:
|
823 |
+
return "Test system prompt"
|
824 |
+
|
825 |
+
def step(self, memory_step: ActionStep) -> Generator[None]:
|
826 |
+
yield None
|
827 |
+
|
828 |
+
# Create a simple agent with mocked model
|
829 |
+
model = MagicMock()
|
830 |
+
agent = SimpleAgent(tools=[], model=model)
|
831 |
+
|
832 |
+
# Test reading the property works and calls initialize_system_prompt
|
833 |
+
assert agent.system_prompt == "Test system prompt"
|
834 |
+
|
835 |
+
# Test setting the property raises AttributeError with correct message
|
836 |
+
with pytest.raises(
|
837 |
+
AttributeError,
|
838 |
+
match=re.escape(
|
839 |
+
"""The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
|
840 |
+
),
|
841 |
+
):
|
842 |
+
agent.system_prompt = "New system prompt"
|
843 |
+
|
844 |
+
# assert "read-only" in str(exc_info.value)
|
845 |
+
# assert "Use 'self.prompt_templates[\"system_prompt\"]' instead" in str(exc_info.value)
|
846 |
+
|
847 |
+
def test_logs_display_thoughts_even_if_error(self):
|
848 |
+
class FakeJsonModelNoCall(Model):
|
849 |
+
def generate(self, messages, stop_sequences=None, tools_to_call_from=None):
|
850 |
+
return ChatMessage(
|
851 |
+
role=MessageRole.ASSISTANT,
|
852 |
+
content="""I don't want to call tools today""",
|
853 |
+
tool_calls=None,
|
854 |
+
raw="""I don't want to call tools today""",
|
855 |
+
)
|
856 |
+
|
857 |
+
agent_toolcalling = ToolCallingAgent(model=FakeJsonModelNoCall(), tools=[], max_steps=1, verbosity_level=10)
|
858 |
+
with agent_toolcalling.logger.console.capture() as capture:
|
859 |
+
agent_toolcalling.run("Dummy task")
|
860 |
+
assert "don't" in capture.get() and "want" in capture.get()
|
861 |
+
|
862 |
+
class FakeCodeModelNoCall(Model):
|
863 |
+
def generate(self, messages, stop_sequences=None):
|
864 |
+
return ChatMessage(
|
865 |
+
role=MessageRole.ASSISTANT,
|
866 |
+
content="""I don't want to write an action today""",
|
867 |
+
)
|
868 |
+
|
869 |
+
agent_code = CodeAgent(model=FakeCodeModelNoCall(), tools=[], max_steps=1, verbosity_level=10)
|
870 |
+
with agent_code.logger.console.capture() as capture:
|
871 |
+
agent_code.run("Dummy task")
|
872 |
+
assert "don't" in capture.get() and "want" in capture.get()
|
873 |
+
|
874 |
+
def test_step_number(self):
|
875 |
+
fake_model = MagicMock()
|
876 |
+
fake_model.generate.return_value = ChatMessage(
|
877 |
+
role=MessageRole.ASSISTANT,
|
878 |
+
content="Model output.",
|
879 |
+
tool_calls=None,
|
880 |
+
raw="Model output.",
|
881 |
+
token_usage=None,
|
882 |
+
)
|
883 |
+
max_steps = 2
|
884 |
+
agent = CodeAgent(tools=[], model=fake_model, max_steps=max_steps)
|
885 |
+
assert hasattr(agent, "step_number"), "step_number attribute should be defined"
|
886 |
+
assert agent.step_number == 0, "step_number should be initialized to 0"
|
887 |
+
agent.run("Test task")
|
888 |
+
assert hasattr(agent, "step_number"), "step_number attribute should be defined"
|
889 |
+
assert agent.step_number == max_steps + 1, "step_number should be max_steps + 1 after run method is called"
|
890 |
+
|
891 |
+
@pytest.mark.parametrize(
|
892 |
+
"step, expected_messages_list",
|
893 |
+
[
|
894 |
+
(
|
895 |
+
1,
|
896 |
+
[
|
897 |
+
[
|
898 |
+
ChatMessage(
|
899 |
+
role=MessageRole.USER, content=[{"type": "text", "text": "INITIAL_PLAN_USER_PROMPT"}]
|
900 |
+
),
|
901 |
+
],
|
902 |
+
],
|
903 |
+
),
|
904 |
+
(
|
905 |
+
2,
|
906 |
+
[
|
907 |
+
[
|
908 |
+
ChatMessage(
|
909 |
+
role=MessageRole.SYSTEM,
|
910 |
+
content=[{"type": "text", "text": "UPDATE_PLAN_SYSTEM_PROMPT"}],
|
911 |
+
),
|
912 |
+
ChatMessage(
|
913 |
+
role=MessageRole.USER,
|
914 |
+
content=[{"type": "text", "text": "UPDATE_PLAN_USER_PROMPT"}],
|
915 |
+
),
|
916 |
+
],
|
917 |
+
],
|
918 |
+
),
|
919 |
+
],
|
920 |
+
)
|
921 |
+
def test_planning_step(self, step, expected_messages_list):
|
922 |
+
fake_model = MagicMock()
|
923 |
+
agent = CodeAgent(
|
924 |
+
tools=[],
|
925 |
+
model=fake_model,
|
926 |
+
)
|
927 |
+
task = "Test task"
|
928 |
+
|
929 |
+
planning_step = list(agent._generate_planning_step(task, is_first_step=(step == 1), step=step))[-1]
|
930 |
+
expected_message_texts = {
|
931 |
+
"INITIAL_PLAN_USER_PROMPT": populate_template(
|
932 |
+
agent.prompt_templates["planning"]["initial_plan"],
|
933 |
+
variables=dict(
|
934 |
+
task=task,
|
935 |
+
tools=agent.tools,
|
936 |
+
managed_agents=agent.managed_agents,
|
937 |
+
answer_facts=planning_step.model_output_message.content,
|
938 |
+
),
|
939 |
+
),
|
940 |
+
"UPDATE_PLAN_SYSTEM_PROMPT": populate_template(
|
941 |
+
agent.prompt_templates["planning"]["update_plan_pre_messages"], variables=dict(task=task)
|
942 |
+
),
|
943 |
+
"UPDATE_PLAN_USER_PROMPT": populate_template(
|
944 |
+
agent.prompt_templates["planning"]["update_plan_post_messages"],
|
945 |
+
variables=dict(
|
946 |
+
task=task,
|
947 |
+
tools=agent.tools,
|
948 |
+
managed_agents=agent.managed_agents,
|
949 |
+
facts_update=planning_step.model_output_message.content,
|
950 |
+
remaining_steps=agent.max_steps - step,
|
951 |
+
),
|
952 |
+
),
|
953 |
+
}
|
954 |
+
for expected_messages in expected_messages_list:
|
955 |
+
for expected_message in expected_messages:
|
956 |
+
expected_message.content[0]["text"] = expected_message_texts[expected_message.content[0]["text"]]
|
957 |
+
assert isinstance(planning_step, PlanningStep)
|
958 |
+
expected_model_input_messages = expected_messages_list[0]
|
959 |
+
model_input_messages = planning_step.model_input_messages
|
960 |
+
assert isinstance(model_input_messages, list)
|
961 |
+
assert len(model_input_messages) == len(expected_model_input_messages) # 2
|
962 |
+
for message, expected_message in zip(model_input_messages, expected_model_input_messages):
|
963 |
+
assert isinstance(message, ChatMessage)
|
964 |
+
assert message.role in MessageRole.__members__.values()
|
965 |
+
assert message.role == expected_message.role
|
966 |
+
assert isinstance(message.content, list)
|
967 |
+
for content, expected_content in zip(message.content, expected_message.content):
|
968 |
+
assert content == expected_content
|
969 |
+
# Test calls to model
|
970 |
+
assert len(fake_model.generate.call_args_list) == 1
|
971 |
+
for call_args, expected_messages in zip(fake_model.generate.call_args_list, expected_messages_list):
|
972 |
+
assert len(call_args.args) == 1
|
973 |
+
messages = call_args.args[0]
|
974 |
+
assert isinstance(messages, list)
|
975 |
+
assert len(messages) == len(expected_messages)
|
976 |
+
for message, expected_message in zip(messages, expected_messages):
|
977 |
+
assert isinstance(message, ChatMessage)
|
978 |
+
assert message.role in MessageRole.__members__.values()
|
979 |
+
assert message.role == expected_message.role
|
980 |
+
assert isinstance(message.content, list)
|
981 |
+
for content, expected_content in zip(message.content, expected_message.content):
|
982 |
+
assert content == expected_content
|
983 |
+
|
984 |
+
@pytest.mark.parametrize(
|
985 |
+
"images, expected_messages_list",
|
986 |
+
[
|
987 |
+
(
|
988 |
+
None,
|
989 |
+
[
|
990 |
+
[
|
991 |
+
ChatMessage(
|
992 |
+
role=MessageRole.SYSTEM,
|
993 |
+
content=[{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
|
994 |
+
),
|
995 |
+
ChatMessage(
|
996 |
+
role=MessageRole.USER,
|
997 |
+
content=[{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}],
|
998 |
+
),
|
999 |
+
]
|
1000 |
+
],
|
1001 |
+
),
|
1002 |
+
(
|
1003 |
+
["image1.png"],
|
1004 |
+
[
|
1005 |
+
[
|
1006 |
+
ChatMessage(
|
1007 |
+
role=MessageRole.SYSTEM,
|
1008 |
+
content=[
|
1009 |
+
{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"},
|
1010 |
+
{"type": "image", "image": "image1.png"},
|
1011 |
+
],
|
1012 |
+
),
|
1013 |
+
ChatMessage(
|
1014 |
+
role=MessageRole.USER,
|
1015 |
+
content=[{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}],
|
1016 |
+
),
|
1017 |
+
]
|
1018 |
+
],
|
1019 |
+
),
|
1020 |
+
],
|
1021 |
+
)
|
1022 |
+
def test_provide_final_answer(self, images, expected_messages_list):
|
1023 |
+
fake_model = MagicMock()
|
1024 |
+
fake_model.generate.return_value = ChatMessage(
|
1025 |
+
role=MessageRole.ASSISTANT,
|
1026 |
+
content="Final answer.",
|
1027 |
+
tool_calls=None,
|
1028 |
+
raw="Final answer.",
|
1029 |
+
token_usage=None,
|
1030 |
+
)
|
1031 |
+
agent = CodeAgent(
|
1032 |
+
tools=[],
|
1033 |
+
model=fake_model,
|
1034 |
+
)
|
1035 |
+
task = "Test task"
|
1036 |
+
final_answer = agent.provide_final_answer(task, images=images).content
|
1037 |
+
expected_message_texts = {
|
1038 |
+
"FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
|
1039 |
+
"FINAL_ANSWER_USER_PROMPT": populate_template(
|
1040 |
+
agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task)
|
1041 |
+
),
|
1042 |
+
}
|
1043 |
+
for expected_messages in expected_messages_list:
|
1044 |
+
for expected_message in expected_messages:
|
1045 |
+
for expected_content in expected_message.content:
|
1046 |
+
if "text" in expected_content:
|
1047 |
+
expected_content["text"] = expected_message_texts[expected_content["text"]]
|
1048 |
+
assert final_answer == "Final answer."
|
1049 |
+
# Test calls to model
|
1050 |
+
assert len(fake_model.generate.call_args_list) == 1
|
1051 |
+
for call_args, expected_messages in zip(fake_model.generate.call_args_list, expected_messages_list):
|
1052 |
+
assert len(call_args.args) == 1
|
1053 |
+
messages = call_args.args[0]
|
1054 |
+
assert isinstance(messages, list)
|
1055 |
+
assert len(messages) == len(expected_messages)
|
1056 |
+
for message, expected_message in zip(messages, expected_messages):
|
1057 |
+
assert isinstance(message, ChatMessage)
|
1058 |
+
assert message.role in MessageRole.__members__.values()
|
1059 |
+
assert message.role == expected_message.role
|
1060 |
+
assert isinstance(message.content, list)
|
1061 |
+
for content, expected_content in zip(message.content, expected_message.content):
|
1062 |
+
assert content == expected_content
|
1063 |
+
|
1064 |
+
def test_interrupt(self):
|
1065 |
+
fake_model = MagicMock()
|
1066 |
+
fake_model.generate.return_value = ChatMessage(
|
1067 |
+
role=MessageRole.ASSISTANT,
|
1068 |
+
content="Model output.",
|
1069 |
+
tool_calls=None,
|
1070 |
+
raw="Model output.",
|
1071 |
+
token_usage=None,
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
def interrupt_callback(memory_step, agent):
|
1075 |
+
agent.interrupt()
|
1076 |
+
|
1077 |
+
agent = CodeAgent(
|
1078 |
+
tools=[],
|
1079 |
+
model=fake_model,
|
1080 |
+
step_callbacks=[interrupt_callback],
|
1081 |
+
)
|
1082 |
+
with pytest.raises(AgentError) as e:
|
1083 |
+
agent.run("Test task")
|
1084 |
+
assert "Agent interrupted" in str(e)
|
1085 |
+
|
1086 |
+
@pytest.mark.parametrize(
|
1087 |
+
"tools, managed_agents, name, expectation",
|
1088 |
+
[
|
1089 |
+
# Valid case: no duplicates
|
1090 |
+
(
|
1091 |
+
[MockTool("tool1"), MockTool("tool2")],
|
1092 |
+
[MockAgent("agent1", [MockTool("tool3")])],
|
1093 |
+
"test_agent",
|
1094 |
+
does_not_raise(),
|
1095 |
+
),
|
1096 |
+
# Invalid case: duplicate tool names
|
1097 |
+
([MockTool("tool1"), MockTool("tool1")], [], "test_agent", pytest.raises(ValueError)),
|
1098 |
+
# Invalid case: tool name same as managed agent name
|
1099 |
+
(
|
1100 |
+
[MockTool("tool1")],
|
1101 |
+
[MockAgent("tool1", [MockTool("final_answer")])],
|
1102 |
+
"test_agent",
|
1103 |
+
pytest.raises(ValueError),
|
1104 |
+
),
|
1105 |
+
# Valid case: tool name same as managed agent's tool name
|
1106 |
+
([MockTool("tool1")], [MockAgent("agent1", [MockTool("tool1")])], "test_agent", does_not_raise()),
|
1107 |
+
# Invalid case: duplicate managed agent name and managed agent tool name
|
1108 |
+
([MockTool("tool1")], [], "tool1", pytest.raises(ValueError)),
|
1109 |
+
# Valid case: duplicate tool names across managed agents
|
1110 |
+
(
|
1111 |
+
[MockTool("tool1")],
|
1112 |
+
[
|
1113 |
+
MockAgent("agent1", [MockTool("tool2"), MockTool("final_answer")]),
|
1114 |
+
MockAgent("agent2", [MockTool("tool2"), MockTool("final_answer")]),
|
1115 |
+
],
|
1116 |
+
"test_agent",
|
1117 |
+
does_not_raise(),
|
1118 |
+
),
|
1119 |
+
],
|
1120 |
+
)
|
1121 |
+
def test_validate_tools_and_managed_agents(self, tools, managed_agents, name, expectation):
|
1122 |
+
fake_model = MagicMock()
|
1123 |
+
with expectation:
|
1124 |
+
DummyMultiStepAgent(
|
1125 |
+
tools=tools,
|
1126 |
+
model=fake_model,
|
1127 |
+
name=name,
|
1128 |
+
managed_agents=managed_agents,
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
def test_from_dict(self):
|
1132 |
+
# Create a test agent dictionary
|
1133 |
+
agent_dict = {
|
1134 |
+
"model": {"class": "TransformersModel", "data": {"model_id": "test/model"}},
|
1135 |
+
"tools": [
|
1136 |
+
{
|
1137 |
+
"name": "valid_tool_function",
|
1138 |
+
"code": 'from smolagents import Tool\nfrom typing import Any, Optional\n\nclass SimpleTool(Tool):\n name = "valid_tool_function"\n description = "A valid tool function."\n inputs = {"input":{"type":"string","description":"Input string."}}\n output_type = "string"\n\n def forward(self, input: str) -> str:\n """A valid tool function.\n\n Args:\n input (str): Input string.\n """\n return input.upper()',
|
1139 |
+
"requirements": {"smolagents"},
|
1140 |
+
}
|
1141 |
+
],
|
1142 |
+
"managed_agents": {},
|
1143 |
+
"prompt_templates": EMPTY_PROMPT_TEMPLATES,
|
1144 |
+
"max_steps": 15,
|
1145 |
+
"verbosity_level": 2,
|
1146 |
+
"planning_interval": 3,
|
1147 |
+
"name": "test_agent",
|
1148 |
+
"description": "Test agent description",
|
1149 |
+
}
|
1150 |
+
|
1151 |
+
# Call from_dict
|
1152 |
+
with patch("smolagents.models.TransformersModel") as mock_model_class:
|
1153 |
+
mock_model_instance = mock_model_class.from_dict.return_value
|
1154 |
+
agent = DummyMultiStepAgent.from_dict(agent_dict)
|
1155 |
+
|
1156 |
+
# Verify the agent was created correctly
|
1157 |
+
assert agent.model == mock_model_instance
|
1158 |
+
assert mock_model_class.from_dict.call_args.args[0] == {"model_id": "test/model"}
|
1159 |
+
assert agent.max_steps == 15
|
1160 |
+
assert agent.logger.level == 2
|
1161 |
+
assert agent.planning_interval == 3
|
1162 |
+
assert agent.name == "test_agent"
|
1163 |
+
assert agent.description == "Test agent description"
|
1164 |
+
# Verify the tool was created correctly
|
1165 |
+
assert sorted(agent.tools.keys()) == ["final_answer", "valid_tool_function"]
|
1166 |
+
assert agent.tools["valid_tool_function"].name == "valid_tool_function"
|
1167 |
+
assert agent.tools["valid_tool_function"].description == "A valid tool function."
|
1168 |
+
assert agent.tools["valid_tool_function"].inputs == {
|
1169 |
+
"input": {"type": "string", "description": "Input string."}
|
1170 |
+
}
|
1171 |
+
assert agent.tools["valid_tool_function"]("test") == "TEST"
|
1172 |
+
|
1173 |
+
# Test overriding with kwargs
|
1174 |
+
with patch("smolagents.models.TransformersModel") as mock_model_class:
|
1175 |
+
agent = DummyMultiStepAgent.from_dict(agent_dict, max_steps=30)
|
1176 |
+
assert agent.max_steps == 30
|
1177 |
+
|
1178 |
+
|
1179 |
+
class TestToolCallingAgent:
|
1180 |
+
def test_toolcalling_agent_instructions(self):
|
1181 |
+
agent = ToolCallingAgent(tools=[], model=MagicMock(), instructions="Test instructions")
|
1182 |
+
assert agent.instructions == "Test instructions"
|
1183 |
+
assert "Test instructions" in agent.system_prompt
|
1184 |
+
|
1185 |
+
def test_toolcalling_agent_passes_both_tools_and_managed_agents(self, test_tool):
|
1186 |
+
"""Test that both tools and managed agents are passed to the model."""
|
1187 |
+
managed_agent = MagicMock()
|
1188 |
+
managed_agent.name = "managed_agent"
|
1189 |
+
model = MagicMock()
|
1190 |
+
model.generate.return_value = ChatMessage(
|
1191 |
+
role=MessageRole.ASSISTANT,
|
1192 |
+
content="",
|
1193 |
+
tool_calls=[
|
1194 |
+
ChatMessageToolCall(
|
1195 |
+
id="call_0",
|
1196 |
+
type="function",
|
1197 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "test_value"}),
|
1198 |
+
)
|
1199 |
+
],
|
1200 |
+
)
|
1201 |
+
agent = ToolCallingAgent(tools=[test_tool], managed_agents=[managed_agent], model=model)
|
1202 |
+
# Run the agent one step to trigger the model call
|
1203 |
+
next(agent.run("Test task", stream=True))
|
1204 |
+
# Check that the model was called with both tools and managed agents:
|
1205 |
+
# - Get all tool_to_call_from names passed to the model
|
1206 |
+
tools_to_call_from_names = [tool.name for tool in model.generate.call_args.kwargs["tools_to_call_from"]]
|
1207 |
+
# - Verify both regular tools and managed agents are included
|
1208 |
+
assert "test_tool" in tools_to_call_from_names # The regular tool
|
1209 |
+
assert "managed_agent" in tools_to_call_from_names # The managed agent
|
1210 |
+
assert "final_answer" in tools_to_call_from_names # The final_answer tool (added by default)
|
1211 |
+
|
1212 |
+
@patch("huggingface_hub.InferenceClient")
|
1213 |
+
def test_toolcalling_agent_api(self, mock_inference_client):
|
1214 |
+
mock_client = mock_inference_client.return_value
|
1215 |
+
mock_response = mock_client.chat_completion.return_value
|
1216 |
+
mock_response.choices[0].message = ChatCompletionOutputMessage(
|
1217 |
+
role=MessageRole.ASSISTANT,
|
1218 |
+
content='{"name": "weather_api", "arguments": {"location": "Paris", "date": "today"}}',
|
1219 |
+
)
|
1220 |
+
mock_response.usage.prompt_tokens = 10
|
1221 |
+
mock_response.usage.completion_tokens = 20
|
1222 |
+
|
1223 |
+
model = InferenceClientModel(model_id="test-model")
|
1224 |
+
|
1225 |
+
from smolagents import tool
|
1226 |
+
|
1227 |
+
@tool
|
1228 |
+
def weather_api(location: str, date: str) -> str:
|
1229 |
+
"""
|
1230 |
+
Gets the weather in the next days at given location.
|
1231 |
+
Args:
|
1232 |
+
location: the location
|
1233 |
+
date: the date
|
1234 |
+
"""
|
1235 |
+
return f"The weather in {location} on date:{date} is sunny."
|
1236 |
+
|
1237 |
+
agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1)
|
1238 |
+
agent.run("What's the weather in Paris?")
|
1239 |
+
assert agent.memory.steps[0].task == "What's the weather in Paris?"
|
1240 |
+
assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
|
1241 |
+
assert agent.memory.steps[1].tool_calls[0].arguments == {"location": "Paris", "date": "today"}
|
1242 |
+
assert agent.memory.steps[1].observations == "The weather in Paris on date:today is sunny."
|
1243 |
+
|
1244 |
+
mock_response.choices[0].message = ChatCompletionOutputMessage(
|
1245 |
+
role=MessageRole.ASSISTANT,
|
1246 |
+
content=None,
|
1247 |
+
tool_calls=[
|
1248 |
+
ChatCompletionOutputToolCall(
|
1249 |
+
function=ChatCompletionOutputFunctionDefinition(
|
1250 |
+
name="weather_api", arguments='{"location": "Paris", "date": "today"}'
|
1251 |
+
),
|
1252 |
+
id="call_0",
|
1253 |
+
type="function",
|
1254 |
+
)
|
1255 |
+
],
|
1256 |
+
)
|
1257 |
+
|
1258 |
+
agent.run("What's the weather in Paris?")
|
1259 |
+
assert agent.memory.steps[0].task == "What's the weather in Paris?"
|
1260 |
+
assert agent.memory.steps[1].tool_calls[0].name == "weather_api"
|
1261 |
+
assert agent.memory.steps[1].tool_calls[0].arguments == {"location": "Paris", "date": "today"}
|
1262 |
+
assert agent.memory.steps[1].observations == "The weather in Paris on date:today is sunny."
|
1263 |
+
|
1264 |
+
@patch("openai.OpenAI")
|
1265 |
+
def test_toolcalling_agent_stream_outputs_multiple_tool_calls(self, mock_openai_client, test_tool):
|
1266 |
+
"""Test that ToolCallingAgent with stream_outputs=True returns the first final_answer when multiple are called."""
|
1267 |
+
mock_client = mock_openai_client.return_value
|
1268 |
+
from smolagents import OpenAIServerModel
|
1269 |
+
|
1270 |
+
# Mock streaming response with multiple final_answer calls
|
1271 |
+
mock_deltas = [
|
1272 |
+
ChoiceDelta(role=MessageRole.ASSISTANT),
|
1273 |
+
ChoiceDelta(
|
1274 |
+
tool_calls=[
|
1275 |
+
ChoiceDeltaToolCall(
|
1276 |
+
index=0,
|
1277 |
+
id="call_1",
|
1278 |
+
function=ChoiceDeltaToolCallFunction(name="final_answer"),
|
1279 |
+
type="function",
|
1280 |
+
)
|
1281 |
+
]
|
1282 |
+
),
|
1283 |
+
ChoiceDelta(
|
1284 |
+
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"an'))]
|
1285 |
+
),
|
1286 |
+
ChoiceDelta(
|
1287 |
+
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='swer"'))]
|
1288 |
+
),
|
1289 |
+
ChoiceDelta(
|
1290 |
+
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments=': "out'))]
|
1291 |
+
),
|
1292 |
+
ChoiceDelta(
|
1293 |
+
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments="put1"))]
|
1294 |
+
),
|
1295 |
+
ChoiceDelta(
|
1296 |
+
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))]
|
1297 |
+
),
|
1298 |
+
ChoiceDelta(
|
1299 |
+
tool_calls=[
|
1300 |
+
ChoiceDeltaToolCall(
|
1301 |
+
index=1,
|
1302 |
+
id="call_2",
|
1303 |
+
function=ChoiceDeltaToolCallFunction(name="test_tool"),
|
1304 |
+
type="function",
|
1305 |
+
)
|
1306 |
+
]
|
1307 |
+
),
|
1308 |
+
ChoiceDelta(
|
1309 |
+
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"in'))]
|
1310 |
+
),
|
1311 |
+
ChoiceDelta(
|
1312 |
+
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='put"'))]
|
1313 |
+
),
|
1314 |
+
ChoiceDelta(
|
1315 |
+
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments=': "out'))]
|
1316 |
+
),
|
1317 |
+
ChoiceDelta(
|
1318 |
+
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments="put2"))]
|
1319 |
+
),
|
1320 |
+
ChoiceDelta(
|
1321 |
+
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"}'))]
|
1322 |
+
),
|
1323 |
+
]
|
1324 |
+
|
1325 |
+
class MockChoice:
|
1326 |
+
def __init__(self, delta):
|
1327 |
+
self.delta = delta
|
1328 |
+
|
1329 |
+
class MockChunk:
|
1330 |
+
def __init__(self, delta):
|
1331 |
+
self.choices = [MockChoice(delta)]
|
1332 |
+
self.usage = None
|
1333 |
+
|
1334 |
+
mock_client.chat.completions.create.return_value = (MockChunk(delta) for delta in mock_deltas)
|
1335 |
+
|
1336 |
+
# Mock usage for non-streaming fallback
|
1337 |
+
mock_usage = MagicMock()
|
1338 |
+
mock_usage.prompt_tokens = 10
|
1339 |
+
mock_usage.completion_tokens = 20
|
1340 |
+
|
1341 |
+
model = OpenAIServerModel(model_id="fakemodel")
|
1342 |
+
|
1343 |
+
agent = ToolCallingAgent(model=model, tools=[test_tool], max_steps=1, stream_outputs=True)
|
1344 |
+
result = agent.run("Make 2 calls to final answer: return both 'output1' and 'output2'")
|
1345 |
+
assert len(agent.memory.steps[-1].model_output_message.tool_calls) == 2
|
1346 |
+
assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
|
1347 |
+
assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
|
1348 |
+
|
1349 |
+
# The agent should return the final answer call
|
1350 |
+
assert result == "output1"
|
1351 |
+
|
1352 |
+
@patch("huggingface_hub.InferenceClient")
|
1353 |
+
def test_toolcalling_agent_api_misformatted_output(self, mock_inference_client):
|
1354 |
+
"""Test that even misformatted json blobs don't interrupt the run for a ToolCallingAgent."""
|
1355 |
+
mock_client = mock_inference_client.return_value
|
1356 |
+
mock_response = mock_client.chat_completion.return_value
|
1357 |
+
mock_response.choices[0].message = ChatCompletionOutputMessage(
|
1358 |
+
role=MessageRole.ASSISTANT,
|
1359 |
+
content='{"name": weather_api", "arguments": {"location": "Paris", "date": "today"}}',
|
1360 |
+
)
|
1361 |
+
|
1362 |
+
mock_response.usage.prompt_tokens = 10
|
1363 |
+
mock_response.usage.completion_tokens = 20
|
1364 |
+
|
1365 |
+
model = InferenceClientModel(model_id="test-model")
|
1366 |
+
|
1367 |
+
logger = AgentLogger(console=Console(markup=False, no_color=True))
|
1368 |
+
|
1369 |
+
agent = ToolCallingAgent(model=model, tools=[], max_steps=2, verbosity_level=1, logger=logger)
|
1370 |
+
with agent.logger.console.capture() as capture:
|
1371 |
+
agent.run("What's the weather in Paris?")
|
1372 |
+
assert agent.memory.steps[0].task == "What's the weather in Paris?"
|
1373 |
+
assert agent.memory.steps[1].tool_calls is None
|
1374 |
+
assert "The JSON blob you used is invalid" in agent.memory.steps[1].error.message
|
1375 |
+
assert "Error while parsing" in capture.get()
|
1376 |
+
assert len(agent.memory.steps) == 4
|
1377 |
+
|
1378 |
+
def test_change_tools_after_init(self):
|
1379 |
+
from smolagents import tool
|
1380 |
+
|
1381 |
+
@tool
|
1382 |
+
def fake_tool_1() -> str:
|
1383 |
+
"""Fake tool"""
|
1384 |
+
return "1"
|
1385 |
+
|
1386 |
+
@tool
|
1387 |
+
def fake_tool_2() -> str:
|
1388 |
+
"""Fake tool"""
|
1389 |
+
return "2"
|
1390 |
+
|
1391 |
+
class FakeCodeModel(Model):
|
1392 |
+
def generate(self, messages, stop_sequences=None):
|
1393 |
+
return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(fake_tool_1())\n</code>")
|
1394 |
+
|
1395 |
+
agent = CodeAgent(tools=[fake_tool_1], model=FakeCodeModel())
|
1396 |
+
|
1397 |
+
agent.tools["final_answer"] = CustomFinalAnswerTool()
|
1398 |
+
agent.tools["fake_tool_1"] = fake_tool_2
|
1399 |
+
|
1400 |
+
answer = agent.run("Fake task.")
|
1401 |
+
assert answer == "2CUSTOM"
|
1402 |
+
|
1403 |
+
def test_custom_final_answer_with_custom_inputs(self, test_tool):
|
1404 |
+
class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
|
1405 |
+
inputs = {
|
1406 |
+
"answer1": {"type": "string", "description": "First part of the answer."},
|
1407 |
+
"answer2": {"type": "string", "description": "Second part of the answer."},
|
1408 |
+
}
|
1409 |
+
|
1410 |
+
def forward(self, answer1: str, answer2: str) -> str:
|
1411 |
+
return answer1 + " and " + answer2
|
1412 |
+
|
1413 |
+
model = MagicMock()
|
1414 |
+
model.generate.return_value = ChatMessage(
|
1415 |
+
role=MessageRole.ASSISTANT,
|
1416 |
+
content=None,
|
1417 |
+
tool_calls=[
|
1418 |
+
ChatMessageToolCall(
|
1419 |
+
id="call_0",
|
1420 |
+
type="function",
|
1421 |
+
function=ChatMessageToolCallFunction(
|
1422 |
+
name="final_answer", arguments={"answer1": "1", "answer2": "2"}
|
1423 |
+
),
|
1424 |
+
),
|
1425 |
+
ChatMessageToolCall(
|
1426 |
+
id="call_1",
|
1427 |
+
type="function",
|
1428 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "3"}),
|
1429 |
+
),
|
1430 |
+
],
|
1431 |
+
)
|
1432 |
+
agent = ToolCallingAgent(tools=[test_tool, CustomFinalAnswerToolWithCustomInputs()], model=model)
|
1433 |
+
answer = agent.run("Fake task.")
|
1434 |
+
assert answer == "1 and 2"
|
1435 |
+
assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
|
1436 |
+
assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
|
1437 |
+
|
1438 |
+
@pytest.mark.parametrize(
|
1439 |
+
"test_case",
|
1440 |
+
[
|
1441 |
+
# Case 0: Single valid tool call
|
1442 |
+
{
|
1443 |
+
"tool_calls": [
|
1444 |
+
ChatMessageToolCall(
|
1445 |
+
id="call_1",
|
1446 |
+
type="function",
|
1447 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "test_value"}),
|
1448 |
+
)
|
1449 |
+
],
|
1450 |
+
"expected_model_output": "Tool call call_1: calling 'test_tool' with arguments: {'input': 'test_value'}",
|
1451 |
+
"expected_observations": "Processed: test_value",
|
1452 |
+
"expected_final_outputs": ["Processed: test_value"],
|
1453 |
+
"expected_error": None,
|
1454 |
+
},
|
1455 |
+
# Case 1: Multiple tool calls
|
1456 |
+
{
|
1457 |
+
"tool_calls": [
|
1458 |
+
ChatMessageToolCall(
|
1459 |
+
id="call_1",
|
1460 |
+
type="function",
|
1461 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "value1"}),
|
1462 |
+
),
|
1463 |
+
ChatMessageToolCall(
|
1464 |
+
id="call_2",
|
1465 |
+
type="function",
|
1466 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "value2"}),
|
1467 |
+
),
|
1468 |
+
],
|
1469 |
+
"expected_model_output": "Tool call call_1: calling 'test_tool' with arguments: {'input': 'value1'}\nTool call call_2: calling 'test_tool' with arguments: {'input': 'value2'}",
|
1470 |
+
"expected_observations": "Processed: value1\nProcessed: value2",
|
1471 |
+
"expected_final_outputs": ["Processed: value1", "Processed: value2"],
|
1472 |
+
"expected_error": None,
|
1473 |
+
},
|
1474 |
+
# Case 2: Invalid tool name
|
1475 |
+
{
|
1476 |
+
"tool_calls": [
|
1477 |
+
ChatMessageToolCall(
|
1478 |
+
id="call_1",
|
1479 |
+
type="function",
|
1480 |
+
function=ChatMessageToolCallFunction(name="nonexistent_tool", arguments={"input": "test"}),
|
1481 |
+
)
|
1482 |
+
],
|
1483 |
+
"expected_error": AgentToolExecutionError,
|
1484 |
+
},
|
1485 |
+
# Case 3: Tool execution error
|
1486 |
+
{
|
1487 |
+
"tool_calls": [
|
1488 |
+
ChatMessageToolCall(
|
1489 |
+
id="call_1",
|
1490 |
+
type="function",
|
1491 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "error"}),
|
1492 |
+
)
|
1493 |
+
],
|
1494 |
+
"expected_error": AgentToolExecutionError,
|
1495 |
+
},
|
1496 |
+
# Case 4: Empty tool calls list
|
1497 |
+
{
|
1498 |
+
"tool_calls": [],
|
1499 |
+
"expected_model_output": "",
|
1500 |
+
"expected_observations": "",
|
1501 |
+
"expected_final_outputs": [],
|
1502 |
+
"expected_error": None,
|
1503 |
+
},
|
1504 |
+
# Case 5: Final answer call
|
1505 |
+
{
|
1506 |
+
"tool_calls": [
|
1507 |
+
ChatMessageToolCall(
|
1508 |
+
id="call_1",
|
1509 |
+
type="function",
|
1510 |
+
function=ChatMessageToolCallFunction(
|
1511 |
+
name="final_answer", arguments={"answer": "This is the final answer"}
|
1512 |
+
),
|
1513 |
+
)
|
1514 |
+
],
|
1515 |
+
"expected_model_output": "Tool call call_1: calling 'final_answer' with arguments: {'answer': 'This is the final answer'}",
|
1516 |
+
"expected_observations": "This is the final answer",
|
1517 |
+
"expected_final_outputs": ["This is the final answer"],
|
1518 |
+
"expected_error": None,
|
1519 |
+
},
|
1520 |
+
# Case 6: Invalid arguments
|
1521 |
+
{
|
1522 |
+
"tool_calls": [
|
1523 |
+
ChatMessageToolCall(
|
1524 |
+
id="call_1",
|
1525 |
+
type="function",
|
1526 |
+
function=ChatMessageToolCallFunction(name="test_tool", arguments={"wrong_param": "value"}),
|
1527 |
+
)
|
1528 |
+
],
|
1529 |
+
"expected_error": AgentToolCallError,
|
1530 |
+
},
|
1531 |
+
],
|
1532 |
+
)
|
1533 |
+
def test_process_tool_calls(self, test_case, test_tool):
|
1534 |
+
# Create a ToolCallingAgent instance with the test tool
|
1535 |
+
agent = ToolCallingAgent(tools=[test_tool], model=MagicMock())
|
1536 |
+
# Create chat message with the specified tool calls for process_tool_calls
|
1537 |
+
chat_message = ChatMessage(role=MessageRole.ASSISTANT, content="", tool_calls=test_case["tool_calls"])
|
1538 |
+
# Create a memory step for process_tool_calls
|
1539 |
+
memory_step = ActionStep(step_number=10, timing="mock_timing")
|
1540 |
+
|
1541 |
+
# Process tool calls
|
1542 |
+
if test_case["expected_error"]:
|
1543 |
+
with pytest.raises(test_case["expected_error"]):
|
1544 |
+
list(agent.process_tool_calls(chat_message, memory_step))
|
1545 |
+
else:
|
1546 |
+
final_outputs = list(agent.process_tool_calls(chat_message, memory_step))
|
1547 |
+
assert memory_step.model_output == test_case["expected_model_output"]
|
1548 |
+
assert memory_step.observations == test_case["expected_observations"]
|
1549 |
+
assert [
|
1550 |
+
final_output.output for final_output in final_outputs if isinstance(final_output, ToolOutput)
|
1551 |
+
] == test_case["expected_final_outputs"]
|
1552 |
+
# Verify memory step tool calls were updated correctly
|
1553 |
+
if test_case["tool_calls"]:
|
1554 |
+
assert memory_step.tool_calls == [
|
1555 |
+
ToolCall(name=tool_call.function.name, arguments=tool_call.function.arguments, id=tool_call.id)
|
1556 |
+
for tool_call in test_case["tool_calls"]
|
1557 |
+
]
|
1558 |
+
|
1559 |
+
|
1560 |
+
class TestCodeAgent:
|
1561 |
+
def test_code_agent_instructions(self):
|
1562 |
+
agent = CodeAgent(tools=[], model=MagicMock(), instructions="Test instructions")
|
1563 |
+
assert agent.instructions == "Test instructions"
|
1564 |
+
assert "Test instructions" in agent.system_prompt
|
1565 |
+
|
1566 |
+
agent = CodeAgent(
|
1567 |
+
tools=[], model=MagicMock(), instructions="Test instructions", use_structured_outputs_internally=True
|
1568 |
+
)
|
1569 |
+
assert agent.instructions == "Test instructions"
|
1570 |
+
assert "Test instructions" in agent.system_prompt
|
1571 |
+
|
1572 |
+
@pytest.mark.filterwarnings("ignore") # Ignore FutureWarning for deprecated grammar parameter
|
1573 |
+
def test_init_with_incompatible_grammar_and_use_structured_outputs_internally(self):
|
1574 |
+
# Test that using both parameters raises ValueError with correct message
|
1575 |
+
with pytest.raises(
|
1576 |
+
ValueError, match="You cannot use 'grammar' and 'use_structured_outputs_internally' at the same time."
|
1577 |
+
):
|
1578 |
+
CodeAgent(
|
1579 |
+
tools=[],
|
1580 |
+
model=MagicMock(),
|
1581 |
+
grammar={"format": "json"},
|
1582 |
+
use_structured_outputs_internally=True,
|
1583 |
+
verbosity_level=LogLevel.DEBUG,
|
1584 |
+
)
|
1585 |
+
|
1586 |
+
# Verify no error when only one option is used
|
1587 |
+
# Only grammar
|
1588 |
+
agent_with_grammar = CodeAgent(
|
1589 |
+
tools=[],
|
1590 |
+
model=MagicMock(),
|
1591 |
+
grammar={"format": "json"},
|
1592 |
+
use_structured_outputs_internally=False,
|
1593 |
+
verbosity_level=LogLevel.DEBUG,
|
1594 |
+
)
|
1595 |
+
assert agent_with_grammar.grammar is not None
|
1596 |
+
assert agent_with_grammar._use_structured_outputs_internally is False
|
1597 |
+
|
1598 |
+
# Only structured output
|
1599 |
+
agent_with_structured = CodeAgent(
|
1600 |
+
tools=[],
|
1601 |
+
model=MagicMock(),
|
1602 |
+
grammar=None,
|
1603 |
+
use_structured_outputs_internally=True,
|
1604 |
+
verbosity_level=LogLevel.DEBUG,
|
1605 |
+
)
|
1606 |
+
assert agent_with_structured.grammar is None
|
1607 |
+
assert agent_with_structured._use_structured_outputs_internally is True
|
1608 |
+
|
1609 |
+
@pytest.mark.parametrize("provide_run_summary", [False, True])
|
1610 |
+
def test_call_with_provide_run_summary(self, provide_run_summary):
|
1611 |
+
agent = CodeAgent(tools=[], model=MagicMock(), provide_run_summary=provide_run_summary)
|
1612 |
+
assert agent.provide_run_summary is provide_run_summary
|
1613 |
+
agent.name = "test_agent"
|
1614 |
+
agent.run = MagicMock(return_value="Test output")
|
1615 |
+
agent.write_memory_to_messages = MagicMock(return_value=[{"content": "Test summary"}])
|
1616 |
+
|
1617 |
+
result = agent("Test request")
|
1618 |
+
expected_summary = "Here is the final answer from your managed agent 'test_agent':\nTest output"
|
1619 |
+
if provide_run_summary:
|
1620 |
+
expected_summary += (
|
1621 |
+
"\n\nFor more detail, find below a summary of this agent's work:\n"
|
1622 |
+
"<summary_of_work>\n\nTest summary\n---\n</summary_of_work>"
|
1623 |
+
)
|
1624 |
+
assert result == expected_summary
|
1625 |
+
|
1626 |
+
def test_errors_logging(self):
|
1627 |
+
class FakeCodeModel(Model):
|
1628 |
+
def generate(self, messages, stop_sequences=None):
|
1629 |
+
return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nsecret=3;['1', '2'][secret]\n</code>")
|
1630 |
+
|
1631 |
+
agent = CodeAgent(tools=[], model=FakeCodeModel(), verbosity_level=1)
|
1632 |
+
|
1633 |
+
with agent.logger.console.capture() as capture:
|
1634 |
+
agent.run("Test request")
|
1635 |
+
assert "secret\\\\" in repr(capture.get())
|
1636 |
+
|
1637 |
+
def test_missing_import_triggers_advice_in_error_log(self):
|
1638 |
+
# Set explicit verbosity level to 1 to override the default verbosity level of -1 set in CI fixture
|
1639 |
+
agent = CodeAgent(tools=[], model=FakeCodeModelImport(), verbosity_level=1)
|
1640 |
+
|
1641 |
+
with agent.logger.console.capture() as capture:
|
1642 |
+
agent.run("Count to 3")
|
1643 |
+
str_output = capture.get()
|
1644 |
+
assert "`additional_authorized_imports`" in str_output.replace("\n", "")
|
1645 |
+
|
1646 |
+
def test_errors_show_offending_line_and_error(self):
|
1647 |
+
agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError())
|
1648 |
+
output = agent.run("What is 2 multiplied by 3.6452?")
|
1649 |
+
assert isinstance(output, AgentText)
|
1650 |
+
assert output == "got an error"
|
1651 |
+
assert "Code execution failed at line 'error_function()'" in str(agent.memory.steps[1].error)
|
1652 |
+
assert "ValueError" in str(agent.memory.steps)
|
1653 |
+
|
1654 |
+
def test_error_saves_previous_print_outputs(self):
|
1655 |
+
agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError(), verbosity_level=10)
|
1656 |
+
agent.run("What is 2 multiplied by 3.6452?")
|
1657 |
+
assert "Flag!" in str(agent.memory.steps[1].observations)
|
1658 |
+
|
1659 |
+
def test_syntax_error_show_offending_lines(self):
|
1660 |
+
agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelSyntaxError())
|
1661 |
+
output = agent.run("What is 2 multiplied by 3.6452?")
|
1662 |
+
assert isinstance(output, AgentText)
|
1663 |
+
assert output == "got an error"
|
1664 |
+
assert ' print("Failing due to unexpected indent")' in str(agent.memory.steps)
|
1665 |
+
assert isinstance(agent.memory.steps[-2], ActionStep)
|
1666 |
+
assert agent.memory.steps[-2].code_action == dedent("""a = 2
|
1667 |
+
b = a * 2
|
1668 |
+
print("Failing due to unexpected indent")
|
1669 |
+
print("Ok, calculation done!")""")
|
1670 |
+
|
1671 |
+
def test_end_code_appending(self):
|
1672 |
+
# Checking original output message
|
1673 |
+
orig_output = FakeCodeModelNoReturn().generate([])
|
1674 |
+
assert not orig_output.content.endswith("<end_code>")
|
1675 |
+
|
1676 |
+
# Checking the step output
|
1677 |
+
agent = CodeAgent(
|
1678 |
+
tools=[PythonInterpreterTool()],
|
1679 |
+
model=FakeCodeModelNoReturn(),
|
1680 |
+
max_steps=1,
|
1681 |
+
)
|
1682 |
+
answer = agent.run("What is 2 multiplied by 3.6452?")
|
1683 |
+
assert answer
|
1684 |
+
|
1685 |
+
memory_steps = agent.memory.steps
|
1686 |
+
actions_steps = [s for s in memory_steps if isinstance(s, ActionStep)]
|
1687 |
+
|
1688 |
+
outputs = [s.model_output for s in actions_steps if s.model_output]
|
1689 |
+
assert outputs
|
1690 |
+
assert all(o.endswith("<end_code>") for o in outputs)
|
1691 |
+
|
1692 |
+
messages = [s.model_output_message for s in actions_steps if s.model_output_message]
|
1693 |
+
assert messages
|
1694 |
+
assert all(m.content.endswith("<end_code>") for m in messages)
|
1695 |
+
|
1696 |
+
def test_change_tools_after_init(self):
|
1697 |
+
from smolagents import tool
|
1698 |
+
|
1699 |
+
@tool
|
1700 |
+
def fake_tool_1() -> str:
|
1701 |
+
"""Fake tool"""
|
1702 |
+
return "1"
|
1703 |
+
|
1704 |
+
@tool
|
1705 |
+
def fake_tool_2() -> str:
|
1706 |
+
"""Fake tool"""
|
1707 |
+
return "2"
|
1708 |
+
|
1709 |
+
class FakeCodeModel(Model):
|
1710 |
+
def generate(self, messages, stop_sequences=None):
|
1711 |
+
return ChatMessage(role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(fake_tool_1())\n</code>")
|
1712 |
+
|
1713 |
+
agent = CodeAgent(tools=[fake_tool_1], model=FakeCodeModel())
|
1714 |
+
|
1715 |
+
agent.tools["final_answer"] = CustomFinalAnswerTool()
|
1716 |
+
agent.tools["fake_tool_1"] = fake_tool_2
|
1717 |
+
|
1718 |
+
answer = agent.run("Fake task.")
|
1719 |
+
assert answer == "2CUSTOM"
|
1720 |
+
|
1721 |
+
def test_local_python_executor_with_custom_functions(self):
|
1722 |
+
model = MagicMock()
|
1723 |
+
model.generate.return_value = ChatMessage(
|
1724 |
+
role=MessageRole.ASSISTANT,
|
1725 |
+
content="",
|
1726 |
+
tool_calls=None,
|
1727 |
+
raw="",
|
1728 |
+
token_usage=None,
|
1729 |
+
)
|
1730 |
+
agent = CodeAgent(tools=[], model=model, executor_kwargs={"additional_functions": {"open": open}})
|
1731 |
+
agent.run("Test run")
|
1732 |
+
assert "open" in agent.python_executor.static_tools
|
1733 |
+
|
1734 |
+
@pytest.mark.parametrize("agent_dict_version", ["v1.9", "v1.10"])
|
1735 |
+
def test_from_folder(self, agent_dict_version, get_agent_dict):
|
1736 |
+
agent_dict = get_agent_dict(agent_dict_version)
|
1737 |
+
with (
|
1738 |
+
patch("smolagents.agents.Path") as mock_path,
|
1739 |
+
patch("smolagents.models.InferenceClientModel") as mock_model,
|
1740 |
+
):
|
1741 |
+
import json
|
1742 |
+
|
1743 |
+
mock_path.return_value.__truediv__.return_value.read_text.return_value = json.dumps(agent_dict)
|
1744 |
+
mock_model.from_dict.return_value.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
|
1745 |
+
agent = CodeAgent.from_folder("ignored_dummy_folder")
|
1746 |
+
assert isinstance(agent, CodeAgent)
|
1747 |
+
assert agent.name == "test_agent"
|
1748 |
+
assert agent.description == "dummy description"
|
1749 |
+
assert agent.max_steps == 10
|
1750 |
+
assert agent.planning_interval == 2
|
1751 |
+
assert agent.additional_authorized_imports == ["pandas"]
|
1752 |
+
assert "pandas" in agent.authorized_imports
|
1753 |
+
assert agent.executor_type == "local"
|
1754 |
+
assert agent.executor_kwargs == {}
|
1755 |
+
assert agent.max_print_outputs_length is None
|
1756 |
+
assert agent.managed_agents == {}
|
1757 |
+
assert set(agent.tools.keys()) == {"final_answer"}
|
1758 |
+
assert agent.model == mock_model.from_dict.return_value
|
1759 |
+
assert mock_model.from_dict.call_args.args[0]["model_id"] == "Qwen/Qwen2.5-Coder-32B-Instruct"
|
1760 |
+
assert agent.model.model_id == "Qwen/Qwen2.5-Coder-32B-Instruct"
|
1761 |
+
assert agent.logger.level == 2
|
1762 |
+
assert agent.prompt_templates["system_prompt"] == "dummy system prompt"
|
1763 |
+
|
1764 |
+
def test_from_dict(self):
|
1765 |
+
# Create a test agent dictionary
|
1766 |
+
agent_dict = {
|
1767 |
+
"model": {"class": "InferenceClientModel", "data": {"model_id": "Qwen/Qwen2.5-Coder-32B-Instruct"}},
|
1768 |
+
"tools": [
|
1769 |
+
{
|
1770 |
+
"name": "valid_tool_function",
|
1771 |
+
"code": 'from smolagents import Tool\nfrom typing import Any, Optional\n\nclass SimpleTool(Tool):\n name = "valid_tool_function"\n description = "A valid tool function."\n inputs = {"input":{"type":"string","description":"Input string."}}\n output_type = "string"\n\n def forward(self, input: str) -> str:\n """A valid tool function.\n\n Args:\n input (str): Input string.\n """\n return input.upper()',
|
1772 |
+
"requirements": {"smolagents"},
|
1773 |
+
}
|
1774 |
+
],
|
1775 |
+
"managed_agents": {},
|
1776 |
+
"prompt_templates": EMPTY_PROMPT_TEMPLATES,
|
1777 |
+
"max_steps": 15,
|
1778 |
+
"verbosity_level": 2,
|
1779 |
+
"use_structured_output": False,
|
1780 |
+
"planning_interval": 3,
|
1781 |
+
"name": "test_code_agent",
|
1782 |
+
"description": "Test code agent description",
|
1783 |
+
"authorized_imports": ["pandas", "numpy"],
|
1784 |
+
"executor_type": "local",
|
1785 |
+
"executor_kwargs": {"max_print_outputs_length": 10_000},
|
1786 |
+
"max_print_outputs_length": 1000,
|
1787 |
+
}
|
1788 |
+
|
1789 |
+
# Call from_dict
|
1790 |
+
with patch("smolagents.models.InferenceClientModel") as mock_model_class:
|
1791 |
+
mock_model_instance = mock_model_class.from_dict.return_value
|
1792 |
+
agent = CodeAgent.from_dict(agent_dict)
|
1793 |
+
|
1794 |
+
# Verify the agent was created correctly with CodeAgent-specific parameters
|
1795 |
+
assert agent.model == mock_model_instance
|
1796 |
+
assert agent.additional_authorized_imports == ["pandas", "numpy"]
|
1797 |
+
assert agent.executor_type == "local"
|
1798 |
+
assert agent.executor_kwargs == {"max_print_outputs_length": 10_000}
|
1799 |
+
assert agent.max_print_outputs_length == 1000
|
1800 |
+
|
1801 |
+
# Test with missing optional parameters
|
1802 |
+
minimal_agent_dict = {
|
1803 |
+
"model": {"class": "InferenceClientModel", "data": {"model_id": "Qwen/Qwen2.5-Coder-32B-Instruct"}},
|
1804 |
+
"tools": [],
|
1805 |
+
"managed_agents": {},
|
1806 |
+
}
|
1807 |
+
|
1808 |
+
with patch("smolagents.models.InferenceClientModel"):
|
1809 |
+
agent = CodeAgent.from_dict(minimal_agent_dict)
|
1810 |
+
# Verify defaults are used
|
1811 |
+
assert agent.max_steps == 20 # default from MultiStepAgent.__init__
|
1812 |
+
|
1813 |
+
# Test overriding with kwargs
|
1814 |
+
with patch("smolagents.models.InferenceClientModel"):
|
1815 |
+
agent = CodeAgent.from_dict(
|
1816 |
+
agent_dict,
|
1817 |
+
additional_authorized_imports=["matplotlib"],
|
1818 |
+
executor_kwargs={"max_print_outputs_length": 5_000},
|
1819 |
+
)
|
1820 |
+
assert agent.additional_authorized_imports == ["matplotlib"]
|
1821 |
+
assert agent.executor_kwargs == {"max_print_outputs_length": 5_000}
|
1822 |
+
|
1823 |
+
def test_custom_final_answer_with_custom_inputs(self):
|
1824 |
+
class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
|
1825 |
+
inputs = {
|
1826 |
+
"answer1": {"type": "string", "description": "First part of the answer."},
|
1827 |
+
"answer2": {"type": "string", "description": "Second part of the answer."},
|
1828 |
+
}
|
1829 |
+
|
1830 |
+
def forward(self, answer1: str, answer2: str) -> str:
|
1831 |
+
return answer1 + "CUSTOM" + answer2
|
1832 |
+
|
1833 |
+
model = MagicMock()
|
1834 |
+
model.generate.return_value = ChatMessage(
|
1835 |
+
role=MessageRole.ASSISTANT, content="<code>\nfinal_answer(answer1='1', answer2='2')\n</code>"
|
1836 |
+
)
|
1837 |
+
agent = CodeAgent(tools=[CustomFinalAnswerToolWithCustomInputs()], model=model)
|
1838 |
+
answer = agent.run("Fake task.")
|
1839 |
+
assert answer == "1CUSTOM2"
|
1840 |
+
|
1841 |
+
|
1842 |
+
class TestMultiAgents:
|
1843 |
+
def test_multiagents_save(self, tmp_path):
|
1844 |
+
model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
|
1845 |
+
|
1846 |
+
web_agent = ToolCallingAgent(
|
1847 |
+
model=model,
|
1848 |
+
tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()],
|
1849 |
+
name="web_agent",
|
1850 |
+
description="does web searches",
|
1851 |
+
)
|
1852 |
+
code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular")
|
1853 |
+
|
1854 |
+
agent = CodeAgent(
|
1855 |
+
model=model,
|
1856 |
+
tools=[],
|
1857 |
+
additional_authorized_imports=["pandas", "datetime"],
|
1858 |
+
managed_agents=[web_agent, code_agent],
|
1859 |
+
max_print_outputs_length=1000,
|
1860 |
+
executor_type="local",
|
1861 |
+
executor_kwargs={"max_print_outputs_length": 10_000},
|
1862 |
+
)
|
1863 |
+
agent.save(tmp_path)
|
1864 |
+
|
1865 |
+
expected_structure = {
|
1866 |
+
"managed_agents": {
|
1867 |
+
"useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]},
|
1868 |
+
"web_agent": {
|
1869 |
+
"tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]},
|
1870 |
+
"files": ["agent.json", "prompts.yaml"],
|
1871 |
+
},
|
1872 |
+
},
|
1873 |
+
"tools": {"files": ["final_answer.py"]},
|
1874 |
+
"files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"],
|
1875 |
+
}
|
1876 |
+
|
1877 |
+
def verify_structure(current_path: Path, structure: dict):
|
1878 |
+
for dir_name, contents in structure.items():
|
1879 |
+
if dir_name != "files":
|
1880 |
+
# For directories, verify they exist and recurse into them
|
1881 |
+
dir_path = current_path / dir_name
|
1882 |
+
assert dir_path.exists(), f"Directory {dir_path} does not exist"
|
1883 |
+
assert dir_path.is_dir(), f"{dir_path} is not a directory"
|
1884 |
+
verify_structure(dir_path, contents)
|
1885 |
+
else:
|
1886 |
+
# For files, verify each exists in the current path
|
1887 |
+
for file_name in contents:
|
1888 |
+
file_path = current_path / file_name
|
1889 |
+
assert file_path.exists(), f"File {file_path} does not exist"
|
1890 |
+
assert file_path.is_file(), f"{file_path} is not a file"
|
1891 |
+
|
1892 |
+
verify_structure(tmp_path, expected_structure)
|
1893 |
+
|
1894 |
+
# Test that re-loaded agents work as expected.
|
1895 |
+
agent2 = CodeAgent.from_folder(tmp_path, planning_interval=5)
|
1896 |
+
assert agent2.planning_interval == 5 # Check that kwargs are used
|
1897 |
+
assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
|
1898 |
+
assert agent2.max_print_outputs_length == 1000
|
1899 |
+
assert agent2.executor_type == "local"
|
1900 |
+
assert agent2.executor_kwargs == {"max_print_outputs_length": 10_000}
|
1901 |
+
assert (
|
1902 |
+
agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10
|
1903 |
+
) # For now tool init parameters are forgotten
|
1904 |
+
assert agent2.model.kwargs["temperature"] == pytest.approx(0.5)
|
1905 |
+
|
1906 |
+
def test_multiagents(self):
|
1907 |
+
class FakeModelMultiagentsManagerAgent(Model):
|
1908 |
+
model_id = "fake_model"
|
1909 |
+
|
1910 |
+
def generate(
|
1911 |
+
self,
|
1912 |
+
messages,
|
1913 |
+
stop_sequences=None,
|
1914 |
+
tools_to_call_from=None,
|
1915 |
+
):
|
1916 |
+
if tools_to_call_from is not None:
|
1917 |
+
if len(messages) < 3:
|
1918 |
+
return ChatMessage(
|
1919 |
+
role=MessageRole.ASSISTANT,
|
1920 |
+
content="",
|
1921 |
+
tool_calls=[
|
1922 |
+
ChatMessageToolCall(
|
1923 |
+
id="call_0",
|
1924 |
+
type="function",
|
1925 |
+
function=ChatMessageToolCallFunction(
|
1926 |
+
name="search_agent",
|
1927 |
+
arguments="Who is the current US president?",
|
1928 |
+
),
|
1929 |
+
)
|
1930 |
+
],
|
1931 |
+
)
|
1932 |
+
else:
|
1933 |
+
assert "Report on the current US president" in str(messages)
|
1934 |
+
return ChatMessage(
|
1935 |
+
role=MessageRole.ASSISTANT,
|
1936 |
+
content="",
|
1937 |
+
tool_calls=[
|
1938 |
+
ChatMessageToolCall(
|
1939 |
+
id="call_0",
|
1940 |
+
type="function",
|
1941 |
+
function=ChatMessageToolCallFunction(
|
1942 |
+
name="final_answer", arguments="Final report."
|
1943 |
+
),
|
1944 |
+
)
|
1945 |
+
],
|
1946 |
+
)
|
1947 |
+
else:
|
1948 |
+
if len(messages) < 3:
|
1949 |
+
return ChatMessage(
|
1950 |
+
role=MessageRole.ASSISTANT,
|
1951 |
+
content="""
|
1952 |
+
Thought: Let's call our search agent.
|
1953 |
+
<code>
|
1954 |
+
result = search_agent("Who is the current US president?")
|
1955 |
+
</code>
|
1956 |
+
""",
|
1957 |
+
)
|
1958 |
+
else:
|
1959 |
+
assert "Report on the current US president" in str(messages)
|
1960 |
+
return ChatMessage(
|
1961 |
+
role=MessageRole.ASSISTANT,
|
1962 |
+
content="""
|
1963 |
+
Thought: Let's return the report.
|
1964 |
+
<code>
|
1965 |
+
final_answer("Final report.")
|
1966 |
+
</code>
|
1967 |
+
""",
|
1968 |
+
)
|
1969 |
+
|
1970 |
+
manager_model = FakeModelMultiagentsManagerAgent()
|
1971 |
+
|
1972 |
+
class FakeModelMultiagentsManagedAgent(Model):
|
1973 |
+
model_id = "fake_model"
|
1974 |
+
|
1975 |
+
def generate(
|
1976 |
+
self,
|
1977 |
+
messages,
|
1978 |
+
tools_to_call_from=None,
|
1979 |
+
stop_sequences=None,
|
1980 |
+
):
|
1981 |
+
return ChatMessage(
|
1982 |
+
role=MessageRole.ASSISTANT,
|
1983 |
+
content="Here is the secret content: FLAG1",
|
1984 |
+
tool_calls=[
|
1985 |
+
ChatMessageToolCall(
|
1986 |
+
id="call_0",
|
1987 |
+
type="function",
|
1988 |
+
function=ChatMessageToolCallFunction(
|
1989 |
+
name="final_answer",
|
1990 |
+
arguments="Report on the current US president",
|
1991 |
+
),
|
1992 |
+
)
|
1993 |
+
],
|
1994 |
+
)
|
1995 |
+
|
1996 |
+
managed_model = FakeModelMultiagentsManagedAgent()
|
1997 |
+
|
1998 |
+
web_agent = ToolCallingAgent(
|
1999 |
+
tools=[],
|
2000 |
+
model=managed_model,
|
2001 |
+
max_steps=10,
|
2002 |
+
name="search_agent",
|
2003 |
+
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
|
2004 |
+
verbosity_level=2,
|
2005 |
+
)
|
2006 |
+
|
2007 |
+
manager_code_agent = CodeAgent(
|
2008 |
+
tools=[],
|
2009 |
+
model=manager_model,
|
2010 |
+
managed_agents=[web_agent],
|
2011 |
+
additional_authorized_imports=["time", "numpy", "pandas"],
|
2012 |
+
)
|
2013 |
+
|
2014 |
+
report = manager_code_agent.run("Fake question.")
|
2015 |
+
assert report == "Final report."
|
2016 |
+
|
2017 |
+
manager_toolcalling_agent = ToolCallingAgent(
|
2018 |
+
tools=[],
|
2019 |
+
model=manager_model,
|
2020 |
+
managed_agents=[web_agent],
|
2021 |
+
)
|
2022 |
+
|
2023 |
+
with web_agent.logger.console.capture() as capture:
|
2024 |
+
report = manager_toolcalling_agent.run("Fake question.")
|
2025 |
+
assert report == "Final report."
|
2026 |
+
assert "FLAG1" in capture.get() # Check that managed agent's output is properly logged
|
2027 |
+
|
2028 |
+
# Test that visualization works
|
2029 |
+
with manager_toolcalling_agent.logger.console.capture() as capture:
|
2030 |
+
manager_toolcalling_agent.visualize()
|
2031 |
+
assert "├──" in capture.get()
|
2032 |
+
|
2033 |
+
|
2034 |
+
@pytest.fixture
|
2035 |
+
def prompt_templates():
|
2036 |
+
return {
|
2037 |
+
"system_prompt": "This is a test system prompt.",
|
2038 |
+
"managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"},
|
2039 |
+
"planning": {
|
2040 |
+
"initial_plan": "The plan.",
|
2041 |
+
"update_plan_pre_messages": "custom",
|
2042 |
+
"update_plan_post_messages": "custom",
|
2043 |
+
},
|
2044 |
+
"final_answer": {"pre_messages": "custom", "post_messages": "custom"},
|
2045 |
+
}
|
2046 |
+
|
2047 |
+
|
2048 |
+
@pytest.mark.parametrize(
|
2049 |
+
"arguments",
|
2050 |
+
[
|
2051 |
+
{},
|
2052 |
+
{"arg": "bar"},
|
2053 |
+
{None: None},
|
2054 |
+
[1, 2, 3],
|
2055 |
+
],
|
2056 |
+
)
|
2057 |
+
def test_tool_calling_agents_raises_tool_call_error_being_invoked_with_wrong_arguments(arguments):
|
2058 |
+
@tool
|
2059 |
+
def _sample_tool(prompt: str) -> str:
|
2060 |
+
"""Tool that returns same string
|
2061 |
+
Args:
|
2062 |
+
prompt: The string to return
|
2063 |
+
Returns:
|
2064 |
+
The same string
|
2065 |
+
"""
|
2066 |
+
|
2067 |
+
return prompt
|
2068 |
+
|
2069 |
+
agent = ToolCallingAgent(model=FakeToolCallModel(), tools=[_sample_tool])
|
2070 |
+
with pytest.raises(AgentToolCallError):
|
2071 |
+
agent.execute_tool_call(_sample_tool.name, arguments)
|
2072 |
+
|
2073 |
+
|
2074 |
+
def test_tool_calling_agents_raises_agent_execution_error_when_tool_raises():
|
2075 |
+
@tool
|
2076 |
+
def _sample_tool(_: str) -> float:
|
2077 |
+
"""Tool that fails
|
2078 |
+
|
2079 |
+
Args:
|
2080 |
+
_: The pointless string
|
2081 |
+
Returns:
|
2082 |
+
Some number
|
2083 |
+
"""
|
2084 |
+
|
2085 |
+
return 1 / 0
|
2086 |
+
|
2087 |
+
agent = ToolCallingAgent(model=FakeToolCallModel(), tools=[_sample_tool])
|
2088 |
+
with pytest.raises(AgentExecutionError):
|
2089 |
+
agent.execute_tool_call(_sample_tool.name, "sample")
|
tests/test_all_docs.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import ast
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import shutil
|
20 |
+
import subprocess
|
21 |
+
import tempfile
|
22 |
+
import traceback
|
23 |
+
from pathlib import Path
|
24 |
+
|
25 |
+
import pytest
|
26 |
+
from dotenv import load_dotenv
|
27 |
+
|
28 |
+
from .utils.markers import require_run_all
|
29 |
+
|
30 |
+
|
31 |
+
class SubprocessCallException(Exception):
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
def run_command(command: list[str], return_stdout=False, env=None):
|
36 |
+
"""
|
37 |
+
Runs command with subprocess.check_output and returns stdout if requested.
|
38 |
+
Properly captures and handles errors during command execution.
|
39 |
+
"""
|
40 |
+
for i, c in enumerate(command):
|
41 |
+
if isinstance(c, Path):
|
42 |
+
command[i] = str(c)
|
43 |
+
|
44 |
+
if env is None:
|
45 |
+
env = os.environ.copy()
|
46 |
+
|
47 |
+
try:
|
48 |
+
output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
|
49 |
+
if return_stdout:
|
50 |
+
if hasattr(output, "decode"):
|
51 |
+
output = output.decode("utf-8")
|
52 |
+
return output
|
53 |
+
except subprocess.CalledProcessError as e:
|
54 |
+
raise SubprocessCallException(
|
55 |
+
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
56 |
+
) from e
|
57 |
+
|
58 |
+
|
59 |
+
class DocCodeExtractor:
|
60 |
+
"""Handles extraction and validation of Python code from markdown files."""
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def extract_python_code(content: str) -> list[str]:
|
64 |
+
"""Extract Python code blocks from markdown content."""
|
65 |
+
pattern = r"```(?:python|py)\n(.*?)\n```"
|
66 |
+
matches = re.finditer(pattern, content, re.DOTALL)
|
67 |
+
return [match.group(1).strip() for match in matches]
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def create_test_script(code_blocks: list[str], tmp_dir: str) -> Path:
|
71 |
+
"""Create a temporary Python script from code blocks."""
|
72 |
+
combined_code = "\n\n".join(code_blocks)
|
73 |
+
assert len(combined_code) > 0, "Code is empty!"
|
74 |
+
tmp_file = Path(tmp_dir) / "test_script.py"
|
75 |
+
|
76 |
+
with open(tmp_file, "w", encoding="utf-8") as f:
|
77 |
+
f.write(combined_code)
|
78 |
+
|
79 |
+
return tmp_file
|
80 |
+
|
81 |
+
|
82 |
+
# Skip: slow tests + require API keys
|
83 |
+
@require_run_all
|
84 |
+
class TestDocs:
|
85 |
+
"""Test case for documentation code testing."""
|
86 |
+
|
87 |
+
@classmethod
|
88 |
+
def setup_class(cls):
|
89 |
+
cls._tmpdir = tempfile.mkdtemp()
|
90 |
+
cls.launch_args = ["python3"]
|
91 |
+
cls.docs_dir = Path(__file__).parent.parent / "docs" / "source" / "en"
|
92 |
+
cls.extractor = DocCodeExtractor()
|
93 |
+
|
94 |
+
if not cls.docs_dir.exists():
|
95 |
+
raise ValueError(f"Docs directory not found at {cls.docs_dir}")
|
96 |
+
|
97 |
+
load_dotenv()
|
98 |
+
|
99 |
+
cls.md_files = list(cls.docs_dir.rglob("*.md")) + list(cls.docs_dir.rglob("*.mdx"))
|
100 |
+
if not cls.md_files:
|
101 |
+
raise ValueError(f"No markdown files found in {cls.docs_dir}")
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def teardown_class(cls):
|
105 |
+
shutil.rmtree(cls._tmpdir)
|
106 |
+
|
107 |
+
@pytest.mark.timeout(100)
|
108 |
+
def test_single_doc(self, doc_path: Path):
|
109 |
+
"""Test a single documentation file."""
|
110 |
+
with open(doc_path, "r", encoding="utf-8") as f:
|
111 |
+
content = f.read()
|
112 |
+
|
113 |
+
code_blocks = self.extractor.extract_python_code(content)
|
114 |
+
excluded_snippets = [
|
115 |
+
"ToolCollection",
|
116 |
+
"image_generation_tool", # We don't want to run this expensive operation
|
117 |
+
"from_langchain", # Langchain is not a dependency
|
118 |
+
"while llm_should_continue(memory):", # This is pseudo code
|
119 |
+
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
|
120 |
+
"model = TransformersModel(model_id=model_id)", # Exclude testing with transformers model
|
121 |
+
"SmolagentsInstrumentor", # Exclude telemetry since it needs additional installs
|
122 |
+
]
|
123 |
+
code_blocks = [
|
124 |
+
block
|
125 |
+
for block in code_blocks
|
126 |
+
if not any(
|
127 |
+
[snippet in block for snippet in excluded_snippets]
|
128 |
+
) # Exclude these tools that take longer to run and add dependencies
|
129 |
+
]
|
130 |
+
if len(code_blocks) == 0:
|
131 |
+
pytest.skip(f"No Python code blocks found in {doc_path.name}")
|
132 |
+
|
133 |
+
# Validate syntax of each block individually by parsing it
|
134 |
+
for i, block in enumerate(code_blocks, 1):
|
135 |
+
ast.parse(block)
|
136 |
+
|
137 |
+
# Create and execute test script
|
138 |
+
print("\n\nCollected code block:==========\n".join(code_blocks))
|
139 |
+
try:
|
140 |
+
code_blocks = [
|
141 |
+
(
|
142 |
+
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
|
143 |
+
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
|
144 |
+
.replace("{your_username}", "m-ric")
|
145 |
+
)
|
146 |
+
for block in code_blocks
|
147 |
+
]
|
148 |
+
test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
|
149 |
+
run_command(self.launch_args + [str(test_script)])
|
150 |
+
|
151 |
+
except SubprocessCallException as e:
|
152 |
+
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
|
153 |
+
except Exception:
|
154 |
+
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
|
155 |
+
|
156 |
+
@pytest.fixture(autouse=True)
|
157 |
+
def _setup(self):
|
158 |
+
"""Fixture to ensure temporary directory exists for each test."""
|
159 |
+
os.makedirs(self._tmpdir, exist_ok=True)
|
160 |
+
yield
|
161 |
+
# Clean up test files after each test
|
162 |
+
for file in Path(self._tmpdir).glob("*"):
|
163 |
+
file.unlink()
|
164 |
+
|
165 |
+
|
166 |
+
def pytest_generate_tests(metafunc):
|
167 |
+
"""Generate test cases for each markdown file."""
|
168 |
+
if "doc_path" in metafunc.fixturenames:
|
169 |
+
test_class = metafunc.cls
|
170 |
+
|
171 |
+
# Initialize the class if needed
|
172 |
+
if not hasattr(test_class, "md_files"):
|
173 |
+
test_class.setup_class()
|
174 |
+
|
175 |
+
# Parameterize with the markdown files
|
176 |
+
metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])
|
tests/test_cli.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest.mock import patch
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
|
5 |
+
from smolagents.cli import load_model
|
6 |
+
from smolagents.local_python_executor import LocalPythonExecutor
|
7 |
+
from smolagents.models import InferenceClientModel, LiteLLMModel, OpenAIServerModel, TransformersModel
|
8 |
+
|
9 |
+
|
10 |
+
@pytest.fixture
|
11 |
+
def set_env_vars(monkeypatch):
|
12 |
+
monkeypatch.setenv("FIREWORKS_API_KEY", "test_fireworks_api_key")
|
13 |
+
monkeypatch.setenv("HF_TOKEN", "test_hf_api_key")
|
14 |
+
|
15 |
+
|
16 |
+
def test_load_model_openai_server_model(set_env_vars):
|
17 |
+
with patch("openai.OpenAI") as MockOpenAI:
|
18 |
+
model = load_model("OpenAIServerModel", "test_model_id")
|
19 |
+
assert isinstance(model, OpenAIServerModel)
|
20 |
+
assert model.model_id == "test_model_id"
|
21 |
+
assert MockOpenAI.call_count == 1
|
22 |
+
assert MockOpenAI.call_args.kwargs["base_url"] == "https://api.fireworks.ai/inference/v1"
|
23 |
+
assert MockOpenAI.call_args.kwargs["api_key"] == "test_fireworks_api_key"
|
24 |
+
|
25 |
+
|
26 |
+
def test_load_model_litellm_model():
|
27 |
+
model = load_model("LiteLLMModel", "test_model_id", api_key="test_api_key", api_base="https://api.test.com")
|
28 |
+
assert isinstance(model, LiteLLMModel)
|
29 |
+
assert model.api_key == "test_api_key"
|
30 |
+
assert model.api_base == "https://api.test.com"
|
31 |
+
assert model.model_id == "test_model_id"
|
32 |
+
|
33 |
+
|
34 |
+
def test_load_model_transformers_model():
|
35 |
+
with (
|
36 |
+
patch(
|
37 |
+
"transformers.AutoModelForImageTextToText.from_pretrained",
|
38 |
+
side_effect=ValueError("Unrecognized configuration class"),
|
39 |
+
),
|
40 |
+
patch("transformers.AutoModelForCausalLM.from_pretrained"),
|
41 |
+
patch("transformers.AutoTokenizer.from_pretrained"),
|
42 |
+
):
|
43 |
+
model = load_model("TransformersModel", "test_model_id")
|
44 |
+
assert isinstance(model, TransformersModel)
|
45 |
+
assert model.model_id == "test_model_id"
|
46 |
+
|
47 |
+
|
48 |
+
def test_load_model_hf_api_model(set_env_vars):
|
49 |
+
with patch("huggingface_hub.InferenceClient") as huggingface_hub_InferenceClient:
|
50 |
+
model = load_model("InferenceClientModel", "test_model_id")
|
51 |
+
assert isinstance(model, InferenceClientModel)
|
52 |
+
assert model.model_id == "test_model_id"
|
53 |
+
assert huggingface_hub_InferenceClient.call_count == 1
|
54 |
+
assert huggingface_hub_InferenceClient.call_args.kwargs["token"] == "test_hf_api_key"
|
55 |
+
|
56 |
+
|
57 |
+
def test_load_model_invalid_model_type():
|
58 |
+
with pytest.raises(ValueError, match="Unsupported model type: InvalidModel"):
|
59 |
+
load_model("InvalidModel", "test_model_id")
|
60 |
+
|
61 |
+
|
62 |
+
def test_cli_main(capsys):
|
63 |
+
with patch("smolagents.cli.load_model") as mock_load_model:
|
64 |
+
mock_load_model.return_value = "mock_model"
|
65 |
+
with patch("smolagents.cli.CodeAgent") as mock_code_agent:
|
66 |
+
from smolagents.cli import run_smolagent
|
67 |
+
|
68 |
+
run_smolagent("test_prompt", [], "InferenceClientModel", "test_model_id", provider="hf-inference")
|
69 |
+
# load_model
|
70 |
+
assert len(mock_load_model.call_args_list) == 1
|
71 |
+
assert mock_load_model.call_args.args == ("InferenceClientModel", "test_model_id")
|
72 |
+
assert mock_load_model.call_args.kwargs == {"api_base": None, "api_key": None, "provider": "hf-inference"}
|
73 |
+
# CodeAgent
|
74 |
+
assert len(mock_code_agent.call_args_list) == 1
|
75 |
+
assert mock_code_agent.call_args.args == ()
|
76 |
+
assert mock_code_agent.call_args.kwargs == {
|
77 |
+
"tools": [],
|
78 |
+
"model": "mock_model",
|
79 |
+
"additional_authorized_imports": None,
|
80 |
+
}
|
81 |
+
# agent.run
|
82 |
+
assert len(mock_code_agent.return_value.run.call_args_list) == 1
|
83 |
+
assert mock_code_agent.return_value.run.call_args.args == ("test_prompt",)
|
84 |
+
# print
|
85 |
+
captured = capsys.readouterr()
|
86 |
+
assert "Running agent with these tools: []" in captured.out
|
87 |
+
|
88 |
+
|
89 |
+
def test_vision_web_browser_main():
|
90 |
+
with patch("smolagents.vision_web_browser.helium"):
|
91 |
+
with patch("smolagents.vision_web_browser.load_model") as mock_load_model:
|
92 |
+
mock_load_model.return_value = "mock_model"
|
93 |
+
with patch("smolagents.vision_web_browser.CodeAgent") as mock_code_agent:
|
94 |
+
from smolagents.vision_web_browser import helium_instructions, run_webagent
|
95 |
+
|
96 |
+
run_webagent("test_prompt", "InferenceClientModel", "test_model_id", provider="hf-inference")
|
97 |
+
# load_model
|
98 |
+
assert len(mock_load_model.call_args_list) == 1
|
99 |
+
assert mock_load_model.call_args.args == ("InferenceClientModel", "test_model_id")
|
100 |
+
# CodeAgent
|
101 |
+
assert len(mock_code_agent.call_args_list) == 1
|
102 |
+
assert mock_code_agent.call_args.args == ()
|
103 |
+
assert len(mock_code_agent.call_args.kwargs["tools"]) == 4
|
104 |
+
assert mock_code_agent.call_args.kwargs["model"] == "mock_model"
|
105 |
+
assert mock_code_agent.call_args.kwargs["additional_authorized_imports"] == ["helium"]
|
106 |
+
# agent.python_executor
|
107 |
+
assert len(mock_code_agent.return_value.python_executor.call_args_list) == 1
|
108 |
+
assert mock_code_agent.return_value.python_executor.call_args.args == ("from helium import *",)
|
109 |
+
assert LocalPythonExecutor(["helium"])("from helium import *") == (None, "", False)
|
110 |
+
# agent.run
|
111 |
+
assert len(mock_code_agent.return_value.run.call_args_list) == 1
|
112 |
+
assert mock_code_agent.return_value.run.call_args.args == ("test_prompt" + helium_instructions,)
|
tests/test_default_tools.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import unittest
|
16 |
+
|
17 |
+
import pytest
|
18 |
+
|
19 |
+
from smolagents.agent_types import _AGENT_TYPE_MAPPING
|
20 |
+
from smolagents.default_tools import (
|
21 |
+
DuckDuckGoSearchTool,
|
22 |
+
PythonInterpreterTool,
|
23 |
+
SpeechToTextTool,
|
24 |
+
VisitWebpageTool,
|
25 |
+
WikipediaSearchTool,
|
26 |
+
)
|
27 |
+
|
28 |
+
from .test_tools import ToolTesterMixin
|
29 |
+
from .utils.markers import require_run_all
|
30 |
+
|
31 |
+
|
32 |
+
class DefaultToolTests(unittest.TestCase):
|
33 |
+
def test_visit_webpage(self):
|
34 |
+
arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
|
35 |
+
result = VisitWebpageTool()(arguments)
|
36 |
+
assert isinstance(result, str)
|
37 |
+
assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About
|
38 |
+
|
39 |
+
@require_run_all
|
40 |
+
def test_ddgs_with_kwargs(self):
|
41 |
+
result = DuckDuckGoSearchTool(timeout=20)("DeepSeek parent company")
|
42 |
+
assert isinstance(result, str)
|
43 |
+
|
44 |
+
|
45 |
+
class TestPythonInterpreterTool(ToolTesterMixin):
|
46 |
+
def setup_method(self):
|
47 |
+
self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
|
48 |
+
self.tool.setup()
|
49 |
+
|
50 |
+
def test_exact_match_arg(self):
|
51 |
+
result = self.tool("(2 / 2) * 4")
|
52 |
+
assert result == "Stdout:\n\nOutput: 4.0"
|
53 |
+
|
54 |
+
def test_exact_match_kwarg(self):
|
55 |
+
result = self.tool(code="(2 / 2) * 4")
|
56 |
+
assert result == "Stdout:\n\nOutput: 4.0"
|
57 |
+
|
58 |
+
def test_agent_type_output(self):
|
59 |
+
inputs = ["2 * 2"]
|
60 |
+
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
61 |
+
output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
|
62 |
+
assert isinstance(output, output_type)
|
63 |
+
|
64 |
+
def test_agent_types_inputs(self):
|
65 |
+
inputs = ["2 * 2"]
|
66 |
+
_inputs = []
|
67 |
+
|
68 |
+
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
69 |
+
input_type = expected_input["type"]
|
70 |
+
if isinstance(input_type, list):
|
71 |
+
_inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
72 |
+
else:
|
73 |
+
_inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))
|
74 |
+
|
75 |
+
# Should not raise an error
|
76 |
+
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
77 |
+
output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
|
78 |
+
assert isinstance(output, output_type)
|
79 |
+
|
80 |
+
def test_imports_work(self):
|
81 |
+
result = self.tool("import numpy as np")
|
82 |
+
assert "import from numpy is not allowed" not in result.lower()
|
83 |
+
|
84 |
+
def test_unauthorized_imports_fail(self):
|
85 |
+
with pytest.raises(Exception) as e:
|
86 |
+
self.tool("import sympy as sp")
|
87 |
+
assert "sympy" in str(e).lower()
|
88 |
+
|
89 |
+
|
90 |
+
class TestSpeechToTextTool:
|
91 |
+
def test_new_instance(self):
|
92 |
+
from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
93 |
+
|
94 |
+
tool = SpeechToTextTool()
|
95 |
+
assert tool is not None
|
96 |
+
assert tool.pre_processor_class == WhisperProcessor
|
97 |
+
assert tool.model_class == WhisperForConditionalGeneration
|
98 |
+
|
99 |
+
def test_initialization(self):
|
100 |
+
from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
101 |
+
|
102 |
+
tool = SpeechToTextTool(model="dummy_model_id")
|
103 |
+
assert tool is not None
|
104 |
+
assert tool.pre_processor_class == WhisperProcessor
|
105 |
+
assert tool.model_class == WhisperForConditionalGeneration
|
106 |
+
|
107 |
+
|
108 |
+
@pytest.mark.parametrize(
|
109 |
+
"language, content_type, extract_format, query",
|
110 |
+
[
|
111 |
+
("en", "summary", "HTML", "Python_(programming_language)"), # English, Summary Mode, HTML format
|
112 |
+
("en", "text", "WIKI", "Python_(programming_language)"), # English, Full Text Mode, WIKI format
|
113 |
+
("es", "summary", "HTML", "Python_(lenguaje_de_programación)"), # Spanish, Summary Mode, HTML format
|
114 |
+
("es", "text", "WIKI", "Python_(lenguaje_de_programación)"), # Spanish, Full Text Mode, WIKI format
|
115 |
+
],
|
116 |
+
)
|
117 |
+
def test_wikipedia_search(language, content_type, extract_format, query):
|
118 |
+
tool = WikipediaSearchTool(
|
119 |
+
user_agent="TestAgent ([email protected])",
|
120 |
+
language=language,
|
121 |
+
content_type=content_type,
|
122 |
+
extract_format=extract_format,
|
123 |
+
)
|
124 |
+
|
125 |
+
result = tool.forward(query)
|
126 |
+
|
127 |
+
assert isinstance(result, str), "Output should be a string"
|
128 |
+
assert "✅ **Wikipedia Page:**" in result, "Response should contain Wikipedia page title"
|
129 |
+
assert "🔗 **Read more:**" in result, "Response should contain Wikipedia page URL"
|
130 |
+
|
131 |
+
if content_type == "summary":
|
132 |
+
assert len(result.split()) < 1000, "Summary mode should return a shorter text"
|
133 |
+
if content_type == "text":
|
134 |
+
assert len(result.split()) > 1000, "Full text mode should return a longer text"
|
tests/test_final_answer.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import PIL.Image
|
19 |
+
import pytest
|
20 |
+
|
21 |
+
from smolagents.agent_types import _AGENT_TYPE_MAPPING
|
22 |
+
from smolagents.default_tools import FinalAnswerTool
|
23 |
+
|
24 |
+
from .test_tools import ToolTesterMixin
|
25 |
+
from .utils.markers import require_torch
|
26 |
+
|
27 |
+
|
28 |
+
class TestFinalAnswerTool(ToolTesterMixin):
|
29 |
+
def setup_method(self):
|
30 |
+
self.inputs = {"answer": "Final answer"}
|
31 |
+
self.tool = FinalAnswerTool()
|
32 |
+
|
33 |
+
def test_exact_match_arg(self):
|
34 |
+
result = self.tool("Final answer")
|
35 |
+
assert result == "Final answer"
|
36 |
+
|
37 |
+
def test_exact_match_kwarg(self):
|
38 |
+
result = self.tool(answer=self.inputs["answer"])
|
39 |
+
assert result == "Final answer"
|
40 |
+
|
41 |
+
@require_torch
|
42 |
+
def test_agent_type_output(self, inputs):
|
43 |
+
for input_type, input in inputs.items():
|
44 |
+
output = self.tool(**input, sanitize_inputs_outputs=True)
|
45 |
+
agent_type = _AGENT_TYPE_MAPPING[input_type]
|
46 |
+
assert isinstance(output, agent_type)
|
47 |
+
|
48 |
+
@pytest.fixture
|
49 |
+
def inputs(self, shared_datadir):
|
50 |
+
import torch
|
51 |
+
|
52 |
+
return {
|
53 |
+
"string": {"answer": "Text input"},
|
54 |
+
"image": {"answer": PIL.Image.open(shared_datadir / "000000039769.png").resize((512, 512))},
|
55 |
+
"audio": {"answer": torch.Tensor(np.ones(3000))},
|
56 |
+
}
|
tests/test_function_type_hints_utils.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from typing import Any
|
16 |
+
|
17 |
+
import pytest
|
18 |
+
|
19 |
+
from smolagents._function_type_hints_utils import DocstringParsingException, get_imports, get_json_schema
|
20 |
+
|
21 |
+
|
22 |
+
@pytest.fixture
|
23 |
+
def valid_func():
|
24 |
+
"""A well-formed function with docstring, type hints, and return block."""
|
25 |
+
|
26 |
+
def multiply(x: int, y: float) -> float:
|
27 |
+
"""
|
28 |
+
Multiplies two numbers.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
x: The first number.
|
32 |
+
y: The second number.
|
33 |
+
Returns:
|
34 |
+
Product of x and y.
|
35 |
+
"""
|
36 |
+
return x * y
|
37 |
+
|
38 |
+
return multiply
|
39 |
+
|
40 |
+
|
41 |
+
@pytest.fixture
|
42 |
+
def no_docstring_func():
|
43 |
+
"""Function with no docstring."""
|
44 |
+
|
45 |
+
def sample(x: int):
|
46 |
+
return x
|
47 |
+
|
48 |
+
return sample
|
49 |
+
|
50 |
+
|
51 |
+
@pytest.fixture
|
52 |
+
def missing_arg_doc_func():
|
53 |
+
"""Function with docstring but missing an argument description."""
|
54 |
+
|
55 |
+
def add(x: int, y: int):
|
56 |
+
"""
|
57 |
+
Adds two numbers.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
x: The first number.
|
61 |
+
"""
|
62 |
+
return x + y
|
63 |
+
|
64 |
+
return add
|
65 |
+
|
66 |
+
|
67 |
+
@pytest.fixture
|
68 |
+
def bad_return_func():
|
69 |
+
"""Function docstring with missing return description (allowed)."""
|
70 |
+
|
71 |
+
def do_nothing(x: str | None = None):
|
72 |
+
"""
|
73 |
+
Does nothing.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
x: Some optional string.
|
77 |
+
"""
|
78 |
+
pass
|
79 |
+
|
80 |
+
return do_nothing
|
81 |
+
|
82 |
+
|
83 |
+
@pytest.fixture
|
84 |
+
def complex_types_func():
|
85 |
+
def process_data(items: list[str], config: dict[str, float], point: tuple[int, int]) -> dict:
|
86 |
+
"""
|
87 |
+
Process some data.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
items: List of items to process.
|
91 |
+
config: Configuration parameters.
|
92 |
+
point: A position as (x,y).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Processed data result.
|
96 |
+
"""
|
97 |
+
return {"result": True}
|
98 |
+
|
99 |
+
return process_data
|
100 |
+
|
101 |
+
|
102 |
+
@pytest.fixture
|
103 |
+
def optional_types_func():
|
104 |
+
def process_with_optional(required_arg: str, optional_arg: int | None = None) -> str:
|
105 |
+
"""
|
106 |
+
Process with optional argument.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
required_arg: A required string argument.
|
110 |
+
optional_arg: An optional integer argument.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Processing result.
|
114 |
+
"""
|
115 |
+
return "processed"
|
116 |
+
|
117 |
+
return process_with_optional
|
118 |
+
|
119 |
+
|
120 |
+
@pytest.fixture
|
121 |
+
def enum_choices_func():
|
122 |
+
def select_color(color: str) -> str:
|
123 |
+
"""
|
124 |
+
Select a color.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
color: The color to select (choices: ["red", "green", "blue"])
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Selected color.
|
131 |
+
"""
|
132 |
+
return color
|
133 |
+
|
134 |
+
return select_color
|
135 |
+
|
136 |
+
|
137 |
+
@pytest.fixture
|
138 |
+
def union_types_func():
|
139 |
+
def process_union(value: int | str) -> bool | str:
|
140 |
+
"""
|
141 |
+
Process a value that can be either int or string.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
value: An integer or string value.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Processing result.
|
148 |
+
"""
|
149 |
+
return True if isinstance(value, int) else "string result"
|
150 |
+
|
151 |
+
return process_union
|
152 |
+
|
153 |
+
|
154 |
+
@pytest.fixture
|
155 |
+
def nested_types_func():
|
156 |
+
def process_nested_data(data: list[dict[str, Any]]) -> list[str]:
|
157 |
+
"""
|
158 |
+
Process nested data structure.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
data: List of dictionaries to process.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
List of processed results.
|
165 |
+
"""
|
166 |
+
return ["result"]
|
167 |
+
|
168 |
+
return process_nested_data
|
169 |
+
|
170 |
+
|
171 |
+
@pytest.fixture
|
172 |
+
def typed_docstring_func():
|
173 |
+
def calculate(x: int, y: float) -> float:
|
174 |
+
"""
|
175 |
+
Calculate something.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
x (int): An integer parameter with type in docstring.
|
179 |
+
y (float): A float parameter with type in docstring.
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
float: The calculated result.
|
183 |
+
"""
|
184 |
+
return x * y
|
185 |
+
|
186 |
+
return calculate
|
187 |
+
|
188 |
+
|
189 |
+
@pytest.fixture
|
190 |
+
def mismatched_types_func():
|
191 |
+
def convert(value: int) -> str:
|
192 |
+
"""
|
193 |
+
Convert a value.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
value (str): A string value (type mismatch with hint).
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
int: Converted value (type mismatch with hint).
|
200 |
+
"""
|
201 |
+
return str(value)
|
202 |
+
|
203 |
+
return convert
|
204 |
+
|
205 |
+
|
206 |
+
@pytest.fixture
|
207 |
+
def complex_docstring_types_func():
|
208 |
+
def process(data: dict[str, list[int]]) -> list[dict[str, Any]]:
|
209 |
+
"""
|
210 |
+
Process complex data.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
data (Dict[str, List[int]]): Nested structure with types.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
List[Dict[str, Any]]: Processed results with types.
|
217 |
+
"""
|
218 |
+
return [{"result": sum(v) for k, v in data.items()}]
|
219 |
+
|
220 |
+
return process
|
221 |
+
|
222 |
+
|
223 |
+
@pytest.fixture
|
224 |
+
def keywords_in_description_func():
|
225 |
+
def process(value: str) -> str:
|
226 |
+
"""
|
227 |
+
Function with Args: or Returns: keywords in its description.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
value: A string value.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
str: Processed value.
|
234 |
+
"""
|
235 |
+
return value.upper()
|
236 |
+
|
237 |
+
return process
|
238 |
+
|
239 |
+
|
240 |
+
class TestGetJsonSchema:
|
241 |
+
def test_get_json_schema_example(self):
|
242 |
+
def fn(x: int, y: tuple[str, str, float] | None = None) -> None:
|
243 |
+
"""
|
244 |
+
Test function
|
245 |
+
Args:
|
246 |
+
x: The first input
|
247 |
+
y: The second input
|
248 |
+
"""
|
249 |
+
pass
|
250 |
+
|
251 |
+
schema = get_json_schema(fn)
|
252 |
+
expected_schema = {
|
253 |
+
"name": "fn",
|
254 |
+
"description": "Test function",
|
255 |
+
"parameters": {
|
256 |
+
"type": "object",
|
257 |
+
"properties": {
|
258 |
+
"x": {"type": "integer", "description": "The first input"},
|
259 |
+
"y": {
|
260 |
+
"type": "array",
|
261 |
+
"description": "The second input",
|
262 |
+
"nullable": True,
|
263 |
+
"prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}],
|
264 |
+
},
|
265 |
+
},
|
266 |
+
"required": ["x"],
|
267 |
+
},
|
268 |
+
"return": {"type": "null"},
|
269 |
+
}
|
270 |
+
assert schema["function"]["parameters"]["properties"]["y"] == expected_schema["parameters"]["properties"]["y"]
|
271 |
+
assert schema["function"] == expected_schema
|
272 |
+
|
273 |
+
@pytest.mark.parametrize(
|
274 |
+
"fixture_name,should_fail",
|
275 |
+
[
|
276 |
+
("valid_func", False),
|
277 |
+
# ('no_docstring_func', True),
|
278 |
+
# ('missing_arg_doc_func', True),
|
279 |
+
("bad_return_func", False),
|
280 |
+
],
|
281 |
+
)
|
282 |
+
def test_get_json_schema(self, request, fixture_name, should_fail):
|
283 |
+
func = request.getfixturevalue(fixture_name)
|
284 |
+
schema = get_json_schema(func)
|
285 |
+
assert schema["type"] == "function"
|
286 |
+
assert "function" in schema
|
287 |
+
assert "parameters" in schema["function"]
|
288 |
+
|
289 |
+
@pytest.mark.parametrize(
|
290 |
+
"fixture_name,should_fail",
|
291 |
+
[
|
292 |
+
# ('valid_func', False),
|
293 |
+
("no_docstring_func", True),
|
294 |
+
("missing_arg_doc_func", True),
|
295 |
+
# ('bad_return_func', False),
|
296 |
+
],
|
297 |
+
)
|
298 |
+
def test_get_json_schema_raises(self, request, fixture_name, should_fail):
|
299 |
+
func = request.getfixturevalue(fixture_name)
|
300 |
+
with pytest.raises(DocstringParsingException):
|
301 |
+
get_json_schema(func)
|
302 |
+
|
303 |
+
@pytest.mark.parametrize(
|
304 |
+
"fixture_name,expected_properties",
|
305 |
+
[
|
306 |
+
("valid_func", {"x": "integer", "y": "number"}),
|
307 |
+
("bad_return_func", {"x": "string"}),
|
308 |
+
],
|
309 |
+
)
|
310 |
+
def test_property_types(self, request, fixture_name, expected_properties):
|
311 |
+
"""Test that property types are correctly mapped."""
|
312 |
+
func = request.getfixturevalue(fixture_name)
|
313 |
+
schema = get_json_schema(func)
|
314 |
+
|
315 |
+
properties = schema["function"]["parameters"]["properties"]
|
316 |
+
for prop_name, expected_type in expected_properties.items():
|
317 |
+
assert properties[prop_name]["type"] == expected_type
|
318 |
+
|
319 |
+
def test_schema_basic_structure(self, valid_func):
|
320 |
+
"""Test that basic schema structure is correct."""
|
321 |
+
schema = get_json_schema(valid_func)
|
322 |
+
# Check schema type
|
323 |
+
assert schema["type"] == "function"
|
324 |
+
assert "function" in schema
|
325 |
+
# Check function schema
|
326 |
+
function_schema = schema["function"]
|
327 |
+
assert function_schema["name"] == "multiply"
|
328 |
+
assert "description" in function_schema
|
329 |
+
assert function_schema["description"] == "Multiplies two numbers."
|
330 |
+
# Check parameters schema
|
331 |
+
assert "parameters" in function_schema
|
332 |
+
params = function_schema["parameters"]
|
333 |
+
assert params["type"] == "object"
|
334 |
+
assert "properties" in params
|
335 |
+
assert "required" in params
|
336 |
+
assert set(params["required"]) == {"x", "y"}
|
337 |
+
properties = params["properties"]
|
338 |
+
assert properties["x"]["type"] == "integer"
|
339 |
+
assert properties["y"]["type"] == "number"
|
340 |
+
# Check return schema
|
341 |
+
assert "return" in function_schema
|
342 |
+
return_schema = function_schema["return"]
|
343 |
+
assert return_schema["type"] == "number"
|
344 |
+
assert return_schema["description"] == "Product of x and y."
|
345 |
+
|
346 |
+
def test_complex_types(self, complex_types_func):
|
347 |
+
"""Test schema generation for complex types."""
|
348 |
+
schema = get_json_schema(complex_types_func)
|
349 |
+
properties = schema["function"]["parameters"]["properties"]
|
350 |
+
# Check list type
|
351 |
+
assert properties["items"]["type"] == "array"
|
352 |
+
# Check dict type
|
353 |
+
assert properties["config"]["type"] == "object"
|
354 |
+
# Check tuple type
|
355 |
+
assert properties["point"]["type"] == "array"
|
356 |
+
assert len(properties["point"]["prefixItems"]) == 2
|
357 |
+
assert properties["point"]["prefixItems"][0]["type"] == "integer"
|
358 |
+
assert properties["point"]["prefixItems"][1]["type"] == "integer"
|
359 |
+
|
360 |
+
def test_optional_types(self, optional_types_func):
|
361 |
+
"""Test schema generation for optional arguments."""
|
362 |
+
schema = get_json_schema(optional_types_func)
|
363 |
+
params = schema["function"]["parameters"]
|
364 |
+
# Required argument should be in required list
|
365 |
+
assert "required_arg" in params["required"]
|
366 |
+
# Optional argument should not be in required list
|
367 |
+
assert "optional_arg" not in params["required"]
|
368 |
+
# Optional argument should be nullable
|
369 |
+
assert params["properties"]["optional_arg"]["nullable"] is True
|
370 |
+
assert params["properties"]["optional_arg"]["type"] == "integer"
|
371 |
+
|
372 |
+
def test_enum_choices(self, enum_choices_func):
|
373 |
+
"""Test schema generation for enum choices in docstring."""
|
374 |
+
schema = get_json_schema(enum_choices_func)
|
375 |
+
color_prop = schema["function"]["parameters"]["properties"]["color"]
|
376 |
+
assert "enum" in color_prop
|
377 |
+
assert color_prop["enum"] == ["red", "green", "blue"]
|
378 |
+
|
379 |
+
def test_union_types(self, union_types_func):
|
380 |
+
"""Test schema generation for union types."""
|
381 |
+
schema = get_json_schema(union_types_func)
|
382 |
+
value_prop = schema["function"]["parameters"]["properties"]["value"]
|
383 |
+
return_prop = schema["function"]["return"]
|
384 |
+
# Check union in parameter
|
385 |
+
assert len(value_prop["type"]) == 2
|
386 |
+
# Check union in return type: should be converted to "any"
|
387 |
+
assert return_prop["type"] == "any"
|
388 |
+
|
389 |
+
def test_nested_types(self, nested_types_func):
|
390 |
+
"""Test schema generation for nested complex types."""
|
391 |
+
schema = get_json_schema(nested_types_func)
|
392 |
+
data_prop = schema["function"]["parameters"]["properties"]["data"]
|
393 |
+
assert data_prop["type"] == "array"
|
394 |
+
|
395 |
+
def test_typed_docstring_parsing(self, typed_docstring_func):
|
396 |
+
"""Test parsing of docstrings with type annotations."""
|
397 |
+
schema = get_json_schema(typed_docstring_func)
|
398 |
+
# Type hints should take precedence over docstring types
|
399 |
+
assert schema["function"]["parameters"]["properties"]["x"]["type"] == "integer"
|
400 |
+
assert schema["function"]["parameters"]["properties"]["y"]["type"] == "number"
|
401 |
+
# Description should be extracted correctly
|
402 |
+
assert (
|
403 |
+
schema["function"]["parameters"]["properties"]["x"]["description"]
|
404 |
+
== "An integer parameter with type in docstring."
|
405 |
+
)
|
406 |
+
assert (
|
407 |
+
schema["function"]["parameters"]["properties"]["y"]["description"]
|
408 |
+
== "A float parameter with type in docstring."
|
409 |
+
)
|
410 |
+
# Return type and description should be correct
|
411 |
+
assert schema["function"]["return"]["type"] == "number"
|
412 |
+
assert schema["function"]["return"]["description"] == "The calculated result."
|
413 |
+
|
414 |
+
def test_mismatched_docstring_types(self, mismatched_types_func):
|
415 |
+
"""Test that type hints take precedence over docstring types when they conflict."""
|
416 |
+
schema = get_json_schema(mismatched_types_func)
|
417 |
+
# Type hints should take precedence over docstring types
|
418 |
+
assert schema["function"]["parameters"]["properties"]["value"]["type"] == "integer"
|
419 |
+
# Return type from type hint should be used, not docstring
|
420 |
+
assert schema["function"]["return"]["type"] == "string"
|
421 |
+
|
422 |
+
def test_complex_docstring_types(self, complex_docstring_types_func):
|
423 |
+
"""Test parsing of complex type annotations in docstrings."""
|
424 |
+
schema = get_json_schema(complex_docstring_types_func)
|
425 |
+
# Check that complex nested type is parsed correctly from type hints
|
426 |
+
data_prop = schema["function"]["parameters"]["properties"]["data"]
|
427 |
+
assert data_prop["type"] == "object"
|
428 |
+
# Check return type
|
429 |
+
return_prop = schema["function"]["return"]
|
430 |
+
assert return_prop["type"] == "array"
|
431 |
+
# Description should include the type information from docstring
|
432 |
+
assert data_prop["description"] == "Nested structure with types."
|
433 |
+
assert return_prop["description"] == "Processed results with types."
|
434 |
+
|
435 |
+
@pytest.mark.parametrize(
|
436 |
+
"fixture_name,expected_description",
|
437 |
+
[
|
438 |
+
("typed_docstring_func", "An integer parameter with type in docstring."),
|
439 |
+
("complex_docstring_types_func", "Nested structure with types."),
|
440 |
+
],
|
441 |
+
)
|
442 |
+
def test_type_in_description_handling(self, request, fixture_name, expected_description):
|
443 |
+
"""Test that type information in docstrings is preserved in description."""
|
444 |
+
func = request.getfixturevalue(fixture_name)
|
445 |
+
schema = get_json_schema(func)
|
446 |
+
# First parameter description should contain the expected text
|
447 |
+
first_param_name = list(schema["function"]["parameters"]["properties"].keys())[0]
|
448 |
+
assert schema["function"]["parameters"]["properties"][first_param_name]["description"] == expected_description
|
449 |
+
|
450 |
+
def test_with_special_words_in_description_func(self, keywords_in_description_func):
|
451 |
+
schema = get_json_schema(keywords_in_description_func)
|
452 |
+
assert schema["function"]["description"] == "Function with Args: or Returns: keywords in its description."
|
453 |
+
|
454 |
+
|
455 |
+
class TestGetCode:
|
456 |
+
@pytest.mark.parametrize(
|
457 |
+
"code, expected",
|
458 |
+
[
|
459 |
+
(
|
460 |
+
"""
|
461 |
+
import numpy
|
462 |
+
import pandas
|
463 |
+
""",
|
464 |
+
["numpy", "pandas"],
|
465 |
+
),
|
466 |
+
# From imports
|
467 |
+
(
|
468 |
+
"""
|
469 |
+
from torch import nn
|
470 |
+
from transformers import AutoModel
|
471 |
+
""",
|
472 |
+
["torch", "transformers"],
|
473 |
+
),
|
474 |
+
# Mixed case with nested imports
|
475 |
+
(
|
476 |
+
"""
|
477 |
+
import numpy as np
|
478 |
+
from torch.nn import Linear
|
479 |
+
import os.path
|
480 |
+
""",
|
481 |
+
["numpy", "torch", "os"],
|
482 |
+
),
|
483 |
+
# Try/except block (should be filtered)
|
484 |
+
(
|
485 |
+
"""
|
486 |
+
try:
|
487 |
+
import torch
|
488 |
+
except ImportError:
|
489 |
+
pass
|
490 |
+
import numpy
|
491 |
+
""",
|
492 |
+
["numpy"],
|
493 |
+
),
|
494 |
+
# Flash attention block (should be filtered)
|
495 |
+
(
|
496 |
+
"""
|
497 |
+
if is_flash_attn_2_available():
|
498 |
+
from flash_attn import flash_attn_func
|
499 |
+
import transformers
|
500 |
+
""",
|
501 |
+
["transformers"],
|
502 |
+
),
|
503 |
+
# Relative imports (should be excluded)
|
504 |
+
(
|
505 |
+
"""
|
506 |
+
from .utils import helper
|
507 |
+
from ..models import transformer
|
508 |
+
""",
|
509 |
+
[],
|
510 |
+
),
|
511 |
+
],
|
512 |
+
)
|
513 |
+
def test_get_imports(self, code: str, expected: list[str]):
|
514 |
+
assert sorted(get_imports(code)) == sorted(expected)
|
tests/test_gradio_ui.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
import tempfile
|
19 |
+
import unittest
|
20 |
+
from unittest.mock import Mock, patch
|
21 |
+
|
22 |
+
import pytest
|
23 |
+
|
24 |
+
from smolagents.agent_types import AgentAudio, AgentImage, AgentText
|
25 |
+
from smolagents.gradio_ui import GradioUI, pull_messages_from_step, stream_to_gradio
|
26 |
+
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, ToolCall
|
27 |
+
from smolagents.models import ChatMessageStreamDelta
|
28 |
+
from smolagents.monitoring import Timing, TokenUsage
|
29 |
+
|
30 |
+
|
31 |
+
class GradioUITester(unittest.TestCase):
|
32 |
+
def setUp(self):
|
33 |
+
"""Initialize test environment"""
|
34 |
+
self.temp_dir = tempfile.mkdtemp()
|
35 |
+
self.mock_agent = Mock()
|
36 |
+
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
|
37 |
+
self.allowed_types = [".pdf", ".docx", ".txt"]
|
38 |
+
|
39 |
+
def tearDown(self):
|
40 |
+
"""Clean up test environment"""
|
41 |
+
shutil.rmtree(self.temp_dir)
|
42 |
+
|
43 |
+
def test_upload_file_default_types(self):
|
44 |
+
"""Test default allowed file types"""
|
45 |
+
default_types = [".pdf", ".docx", ".txt"]
|
46 |
+
for file_type in default_types:
|
47 |
+
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
|
48 |
+
mock_file = Mock()
|
49 |
+
mock_file.name = temp_file.name
|
50 |
+
|
51 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [])
|
52 |
+
|
53 |
+
self.assertIn("File uploaded:", textbox.value)
|
54 |
+
self.assertEqual(len(uploads_log), 1)
|
55 |
+
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
|
56 |
+
|
57 |
+
def test_upload_file_default_types_disallowed(self):
|
58 |
+
"""Test default disallowed file types"""
|
59 |
+
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
|
60 |
+
for file_type in disallowed_types:
|
61 |
+
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
|
62 |
+
mock_file = Mock()
|
63 |
+
mock_file.name = temp_file.name
|
64 |
+
|
65 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [])
|
66 |
+
|
67 |
+
self.assertEqual(textbox.value, "File type disallowed")
|
68 |
+
self.assertEqual(len(uploads_log), 0)
|
69 |
+
|
70 |
+
def test_upload_file_success(self):
|
71 |
+
"""Test successful file upload scenario"""
|
72 |
+
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
|
73 |
+
mock_file = Mock()
|
74 |
+
mock_file.name = temp_file.name
|
75 |
+
|
76 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [])
|
77 |
+
|
78 |
+
self.assertIn("File uploaded:", textbox.value)
|
79 |
+
self.assertEqual(len(uploads_log), 1)
|
80 |
+
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
|
81 |
+
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
|
82 |
+
|
83 |
+
def test_upload_file_none(self):
|
84 |
+
"""Test scenario when no file is selected"""
|
85 |
+
textbox, uploads_log = self.ui.upload_file(None, [])
|
86 |
+
|
87 |
+
self.assertEqual(textbox.value, "No file uploaded")
|
88 |
+
self.assertEqual(len(uploads_log), 0)
|
89 |
+
|
90 |
+
def test_upload_file_invalid_type(self):
|
91 |
+
"""Test disallowed file type"""
|
92 |
+
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
|
93 |
+
mock_file = Mock()
|
94 |
+
mock_file.name = temp_file.name
|
95 |
+
|
96 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [])
|
97 |
+
|
98 |
+
self.assertEqual(textbox.value, "File type disallowed")
|
99 |
+
self.assertEqual(len(uploads_log), 0)
|
100 |
+
|
101 |
+
def test_upload_file_special_chars(self):
|
102 |
+
"""Test scenario with special characters in filename"""
|
103 |
+
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
|
104 |
+
# Create a new temporary file with special characters
|
105 |
+
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
|
106 |
+
shutil.copy(temp_file.name, special_char_name)
|
107 |
+
try:
|
108 |
+
mock_file = Mock()
|
109 |
+
mock_file.name = special_char_name
|
110 |
+
|
111 |
+
with patch("shutil.copy"):
|
112 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [])
|
113 |
+
|
114 |
+
self.assertIn("File uploaded:", textbox.value)
|
115 |
+
self.assertEqual(len(uploads_log), 1)
|
116 |
+
self.assertIn("test_____", uploads_log[0])
|
117 |
+
finally:
|
118 |
+
# Clean up the special character file
|
119 |
+
if os.path.exists(special_char_name):
|
120 |
+
os.remove(special_char_name)
|
121 |
+
|
122 |
+
def test_upload_file_custom_types(self):
|
123 |
+
"""Test custom allowed file types"""
|
124 |
+
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
|
125 |
+
mock_file = Mock()
|
126 |
+
mock_file.name = temp_file.name
|
127 |
+
|
128 |
+
textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])
|
129 |
+
|
130 |
+
self.assertIn("File uploaded:", textbox.value)
|
131 |
+
self.assertEqual(len(uploads_log), 1)
|
132 |
+
|
133 |
+
|
134 |
+
class TestStreamToGradio:
|
135 |
+
"""Tests for the stream_to_gradio function."""
|
136 |
+
|
137 |
+
@patch("smolagents.gradio_ui.pull_messages_from_step")
|
138 |
+
def test_stream_to_gradio_memory_step(self, mock_pull_messages):
|
139 |
+
"""Test streaming a memory step"""
|
140 |
+
# Create mock agent and memory step
|
141 |
+
mock_agent = Mock()
|
142 |
+
mock_agent.run = Mock(return_value=[Mock(spec=ActionStep)])
|
143 |
+
mock_agent.model = Mock()
|
144 |
+
mock_agent.model.last_input_token_count = 100
|
145 |
+
mock_agent.model.last_output_token_count = 200
|
146 |
+
# Mock the pull_messages_from_step function to return some messages
|
147 |
+
mock_message = Mock()
|
148 |
+
mock_pull_messages.return_value = [mock_message]
|
149 |
+
# Call stream_to_gradio
|
150 |
+
result = list(stream_to_gradio(mock_agent, "test task"))
|
151 |
+
# Verify that pull_messages_from_step was called and the message was yielded
|
152 |
+
mock_pull_messages.assert_called_once()
|
153 |
+
assert result == [mock_message]
|
154 |
+
|
155 |
+
def test_stream_to_gradio_stream_delta(self):
|
156 |
+
"""Test streaming a ChatMessageStreamDelta"""
|
157 |
+
# Create mock agent and stream delta
|
158 |
+
mock_agent = Mock()
|
159 |
+
mock_delta = ChatMessageStreamDelta(content="Hello")
|
160 |
+
mock_agent.run = Mock(return_value=[mock_delta])
|
161 |
+
mock_agent.model = Mock()
|
162 |
+
mock_agent.model.last_input_token_count = 100
|
163 |
+
mock_agent.model.last_output_token_count = 200
|
164 |
+
# Call stream_to_gradio
|
165 |
+
result = list(stream_to_gradio(mock_agent, "test task"))
|
166 |
+
# Verify that the content was yielded
|
167 |
+
assert result == ["Hello"]
|
168 |
+
|
169 |
+
def test_stream_to_gradio_multiple_deltas(self):
|
170 |
+
"""Test streaming multiple ChatMessageStreamDeltas"""
|
171 |
+
# Create mock agent and stream deltas
|
172 |
+
mock_agent = Mock()
|
173 |
+
mock_delta1 = ChatMessageStreamDelta(content="Hello")
|
174 |
+
mock_delta2 = ChatMessageStreamDelta(content=" world")
|
175 |
+
mock_agent.run = Mock(return_value=[mock_delta1, mock_delta2])
|
176 |
+
mock_agent.model = Mock()
|
177 |
+
mock_agent.model.last_input_token_count = 100
|
178 |
+
mock_agent.model.last_output_token_count = 200
|
179 |
+
# Call stream_to_gradio
|
180 |
+
result = list(stream_to_gradio(mock_agent, "test task"))
|
181 |
+
# Verify that the content was accumulated and yielded
|
182 |
+
assert result == ["Hello", "Hello world"]
|
183 |
+
|
184 |
+
@pytest.mark.parametrize(
|
185 |
+
"task,task_images,reset_memory,additional_args",
|
186 |
+
[
|
187 |
+
("simple task", None, False, None),
|
188 |
+
("task with images", ["image1.png", "image2.png"], False, None),
|
189 |
+
("task with reset", None, True, None),
|
190 |
+
("task with args", None, False, {"arg1": "value1"}),
|
191 |
+
("complex task", ["image.png"], True, {"arg1": "value1", "arg2": "value2"}),
|
192 |
+
],
|
193 |
+
)
|
194 |
+
def test_stream_to_gradio_parameters(self, task, task_images, reset_memory, additional_args):
|
195 |
+
"""Test that stream_to_gradio passes parameters correctly to agent.run"""
|
196 |
+
# Create mock agent
|
197 |
+
mock_agent = Mock()
|
198 |
+
mock_agent.run = Mock(return_value=[])
|
199 |
+
# Call stream_to_gradio
|
200 |
+
list(
|
201 |
+
stream_to_gradio(
|
202 |
+
mock_agent,
|
203 |
+
task=task,
|
204 |
+
task_images=task_images,
|
205 |
+
reset_agent_memory=reset_memory,
|
206 |
+
additional_args=additional_args,
|
207 |
+
)
|
208 |
+
)
|
209 |
+
# Verify that agent.run was called with the right parameters
|
210 |
+
mock_agent.run.assert_called_once_with(
|
211 |
+
task, images=task_images, stream=True, reset=reset_memory, additional_args=additional_args
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
class TestPullMessagesFromStep:
|
216 |
+
def test_action_step_basic(
|
217 |
+
self,
|
218 |
+
):
|
219 |
+
"""Test basic ActionStep processing."""
|
220 |
+
step = ActionStep(
|
221 |
+
step_number=1,
|
222 |
+
model_output="This is the model output",
|
223 |
+
observations="Some execution logs",
|
224 |
+
error=None,
|
225 |
+
timing=Timing(start_time=1.0, end_time=3.5),
|
226 |
+
token_usage=TokenUsage(input_tokens=100, output_tokens=50),
|
227 |
+
)
|
228 |
+
messages = list(pull_messages_from_step(step))
|
229 |
+
assert len(messages) == 5 # step number, model_output, logs, footnote, divider
|
230 |
+
for message, expected_content in zip(
|
231 |
+
messages,
|
232 |
+
[
|
233 |
+
"**Step 1**",
|
234 |
+
"This is the model output",
|
235 |
+
"execution logs",
|
236 |
+
"Input tokens: 100 | Output tokens: 50 | Duration: 2.5",
|
237 |
+
"-----",
|
238 |
+
],
|
239 |
+
):
|
240 |
+
assert expected_content in message.content
|
241 |
+
|
242 |
+
def test_action_step_with_tool_calls(self):
|
243 |
+
"""Test ActionStep with tool calls."""
|
244 |
+
step = ActionStep(
|
245 |
+
step_number=2,
|
246 |
+
tool_calls=[ToolCall(name="test_tool", arguments={"answer": "Test answer"}, id="tool_call_1")],
|
247 |
+
observations="Tool execution logs",
|
248 |
+
timing=Timing(start_time=1.0, end_time=2.5),
|
249 |
+
token_usage=TokenUsage(input_tokens=100, output_tokens=50),
|
250 |
+
)
|
251 |
+
messages = list(pull_messages_from_step(step))
|
252 |
+
assert len(messages) == 5 # step, tool call, logs, footnote, divider
|
253 |
+
assert messages[1].content == "Test answer"
|
254 |
+
assert "Used tool test_tool" in messages[1].metadata["title"]
|
255 |
+
|
256 |
+
@pytest.mark.parametrize(
|
257 |
+
"tool_name, args, expected",
|
258 |
+
[
|
259 |
+
("python_interpreter", "print('Hello')", "```python\nprint('Hello')\n```"),
|
260 |
+
("regular_tool", {"key": "value"}, "{'key': 'value'}"),
|
261 |
+
("string_args_tool", "simple string", "simple string"),
|
262 |
+
],
|
263 |
+
)
|
264 |
+
def test_action_step_tool_call_formats(self, tool_name, args, expected):
|
265 |
+
"""Test different formats of tool calls."""
|
266 |
+
tool_call = Mock()
|
267 |
+
tool_call.name = tool_name
|
268 |
+
tool_call.arguments = args
|
269 |
+
step = ActionStep(
|
270 |
+
step_number=1,
|
271 |
+
tool_calls=[tool_call],
|
272 |
+
timing=Timing(start_time=1.0, end_time=2.5),
|
273 |
+
token_usage=TokenUsage(input_tokens=100, output_tokens=50),
|
274 |
+
)
|
275 |
+
messages = list(pull_messages_from_step(step))
|
276 |
+
tool_message = next(
|
277 |
+
msg
|
278 |
+
for msg in messages
|
279 |
+
if msg.role == "assistant" and msg.metadata and msg.metadata.get("title", "").startswith("🛠️")
|
280 |
+
)
|
281 |
+
assert expected in tool_message.content
|
282 |
+
|
283 |
+
def test_action_step_with_error(self):
|
284 |
+
"""Test ActionStep with error."""
|
285 |
+
step = ActionStep(
|
286 |
+
step_number=3,
|
287 |
+
error="This is an error message",
|
288 |
+
timing=Timing(start_time=1.0, end_time=2.0),
|
289 |
+
token_usage=TokenUsage(input_tokens=100, output_tokens=200),
|
290 |
+
)
|
291 |
+
messages = list(pull_messages_from_step(step))
|
292 |
+
error_message = next((m for m in messages if "error" in str(m.content).lower()), None)
|
293 |
+
assert error_message is not None
|
294 |
+
assert "This is an error message" in error_message.content
|
295 |
+
|
296 |
+
def test_action_step_with_images(self):
|
297 |
+
"""Test ActionStep with observation images."""
|
298 |
+
step = ActionStep(
|
299 |
+
step_number=4,
|
300 |
+
observations_images=["image1.png", "image2.jpg"],
|
301 |
+
token_usage=TokenUsage(input_tokens=100, output_tokens=200),
|
302 |
+
timing=Timing(start_time=1.0, end_time=2.0),
|
303 |
+
)
|
304 |
+
with patch("smolagents.gradio_ui.AgentImage") as mock_agent_image:
|
305 |
+
mock_agent_image.return_value.to_string.side_effect = lambda: "path/to/image.png"
|
306 |
+
messages = list(pull_messages_from_step(step))
|
307 |
+
image_messages = [m for m in messages if "image" in str(m).lower()]
|
308 |
+
assert len(image_messages) == 2
|
309 |
+
assert "path/to/image.png" in str(image_messages[0])
|
310 |
+
|
311 |
+
@pytest.mark.parametrize(
|
312 |
+
"skip_model_outputs, expected_messages_length, token_usage",
|
313 |
+
[(False, 4, TokenUsage(input_tokens=80, output_tokens=30)), (True, 2, None)],
|
314 |
+
)
|
315 |
+
def test_planning_step(self, skip_model_outputs, expected_messages_length, token_usage):
|
316 |
+
"""Test PlanningStep processing."""
|
317 |
+
step = PlanningStep(
|
318 |
+
plan="1. First step\n2. Second step",
|
319 |
+
model_input_messages=Mock(),
|
320 |
+
model_output_message=Mock(),
|
321 |
+
token_usage=token_usage,
|
322 |
+
timing=Timing(start_time=1.0, end_time=2.0),
|
323 |
+
)
|
324 |
+
messages = list(pull_messages_from_step(step, skip_model_outputs=skip_model_outputs))
|
325 |
+
assert len(messages) == expected_messages_length # [header, plan,] footnote, divider
|
326 |
+
expected_contents = [
|
327 |
+
"**Planning step**",
|
328 |
+
"1. First step\n2. Second step",
|
329 |
+
"Input tokens: 80 | Output tokens: 30" if token_usage else "",
|
330 |
+
"-----",
|
331 |
+
]
|
332 |
+
for message, expected_content in zip(messages, expected_contents[-expected_messages_length:]):
|
333 |
+
assert expected_content in message.content
|
334 |
+
|
335 |
+
if not token_usage:
|
336 |
+
assert "Input tokens: 80 | Output tokens: 30" not in message.content
|
337 |
+
|
338 |
+
@pytest.mark.parametrize(
|
339 |
+
"answer_type, answer_value, expected_content",
|
340 |
+
[
|
341 |
+
(AgentText, "This is a text answer", "**Final answer:**\nThis is a text answer\n"),
|
342 |
+
(lambda: "Plain string", "Plain string", "**Final answer:** Plain string"),
|
343 |
+
],
|
344 |
+
)
|
345 |
+
def test_final_answer_step(self, answer_type, answer_value, expected_content):
|
346 |
+
"""Test FinalAnswerStep with different answer types."""
|
347 |
+
try:
|
348 |
+
final_answer = answer_type()
|
349 |
+
except TypeError:
|
350 |
+
with patch.object(answer_type, "to_string", return_value=answer_value):
|
351 |
+
final_answer = answer_type(answer_value)
|
352 |
+
step = FinalAnswerStep(
|
353 |
+
output=final_answer,
|
354 |
+
)
|
355 |
+
messages = list(pull_messages_from_step(step))
|
356 |
+
assert len(messages) == 1
|
357 |
+
assert messages[0].content == expected_content
|
358 |
+
|
359 |
+
def test_final_answer_step_image(self):
|
360 |
+
"""Test FinalAnswerStep with image answer."""
|
361 |
+
with patch.object(AgentImage, "to_string", return_value="path/to/image.png"):
|
362 |
+
step = FinalAnswerStep(output=AgentImage("path/to/image.png"))
|
363 |
+
messages = list(pull_messages_from_step(step))
|
364 |
+
assert len(messages) == 1
|
365 |
+
assert messages[0].content["path"] == "path/to/image.png"
|
366 |
+
assert messages[0].content["mime_type"] == "image/png"
|
367 |
+
|
368 |
+
def test_final_answer_step_audio(self):
|
369 |
+
"""Test FinalAnswerStep with audio answer."""
|
370 |
+
with patch.object(AgentAudio, "to_string", return_value="path/to/audio.wav"):
|
371 |
+
step = FinalAnswerStep(output=AgentAudio("path/to/audio.wav"))
|
372 |
+
messages = list(pull_messages_from_step(step))
|
373 |
+
assert len(messages) == 1
|
374 |
+
assert messages[0].content["path"] == "path/to/audio.wav"
|
375 |
+
assert messages[0].content["mime_type"] == "audio/wav"
|
376 |
+
|
377 |
+
def test_unsupported_step_type(self):
|
378 |
+
"""Test handling of unsupported step types."""
|
379 |
+
|
380 |
+
class UnsupportedStep(Mock):
|
381 |
+
pass
|
382 |
+
|
383 |
+
step = UnsupportedStep()
|
384 |
+
with pytest.raises(ValueError, match="Unsupported step type"):
|
385 |
+
list(pull_messages_from_step(step))
|
tests/test_import.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import tempfile
|
4 |
+
|
5 |
+
|
6 |
+
def test_import_smolagents_without_extras(monkeypatch):
|
7 |
+
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
8 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
9 |
+
# Create a virtual environment
|
10 |
+
venv_dir = os.path.join(temp_dir, "venv")
|
11 |
+
subprocess.run(["uv", "venv", venv_dir], check=True)
|
12 |
+
|
13 |
+
# Install smolagents in the virtual environment
|
14 |
+
subprocess.run(
|
15 |
+
["uv", "pip", "install", "--python", os.path.join(venv_dir, "bin", "python"), "smolagents @ ."], check=True
|
16 |
+
)
|
17 |
+
|
18 |
+
# Run the import test in the virtual environment
|
19 |
+
result = subprocess.run(
|
20 |
+
[os.path.join(venv_dir, "bin", "python"), "-c", "import smolagents"],
|
21 |
+
capture_output=True,
|
22 |
+
text=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
# Check if the import was successful
|
26 |
+
assert result.returncode == 0, (
|
27 |
+
"Import failed with error: "
|
28 |
+
+ (result.stderr.splitlines()[-1] if result.stderr else "No error message")
|
29 |
+
+ "\n"
|
30 |
+
+ result.stderr
|
31 |
+
)
|
tests/test_local_python_executor.py
ADDED
@@ -0,0 +1,2353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import ast
|
17 |
+
import types
|
18 |
+
from contextlib import nullcontext as does_not_raise
|
19 |
+
from textwrap import dedent
|
20 |
+
from unittest.mock import patch
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import pandas as pd
|
24 |
+
import pytest
|
25 |
+
|
26 |
+
from smolagents.default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
|
27 |
+
from smolagents.local_python_executor import (
|
28 |
+
DANGEROUS_FUNCTIONS,
|
29 |
+
DANGEROUS_MODULES,
|
30 |
+
InterpreterError,
|
31 |
+
LocalPythonExecutor,
|
32 |
+
PrintContainer,
|
33 |
+
check_import_authorized,
|
34 |
+
evaluate_boolop,
|
35 |
+
evaluate_condition,
|
36 |
+
evaluate_delete,
|
37 |
+
evaluate_python_code,
|
38 |
+
evaluate_subscript,
|
39 |
+
fix_final_answer_code,
|
40 |
+
get_safe_module,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
# Fake function we will use as tool
|
45 |
+
def add_two(x):
|
46 |
+
return x + 2
|
47 |
+
|
48 |
+
|
49 |
+
class TestEvaluatePythonCode:
|
50 |
+
def assertDictEqualNoPrint(self, dict1, dict2):
|
51 |
+
assert {k: v for k, v in dict1.items() if k != "_print_outputs"} == {
|
52 |
+
k: v for k, v in dict2.items() if k != "_print_outputs"
|
53 |
+
}
|
54 |
+
|
55 |
+
def test_evaluate_assign(self):
|
56 |
+
code = "x = 3"
|
57 |
+
state = {}
|
58 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
59 |
+
assert result == 3
|
60 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "_operations_count": {"counter": 2}})
|
61 |
+
|
62 |
+
code = "x = y"
|
63 |
+
state = {"y": 5}
|
64 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
65 |
+
# evaluate returns the value of the last assignment.
|
66 |
+
assert result == 5
|
67 |
+
self.assertDictEqualNoPrint(state, {"x": 5, "y": 5, "_operations_count": {"counter": 2}})
|
68 |
+
|
69 |
+
code = "a=1;b=None"
|
70 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
71 |
+
# evaluate returns the value of the last assignment.
|
72 |
+
assert result is None
|
73 |
+
|
74 |
+
def test_assignment_cannot_overwrite_tool(self):
|
75 |
+
code = "print = '3'"
|
76 |
+
with pytest.raises(InterpreterError) as e:
|
77 |
+
evaluate_python_code(code, {"print": print}, state={})
|
78 |
+
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
|
79 |
+
|
80 |
+
def test_subscript_call(self):
|
81 |
+
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
|
82 |
+
state = {}
|
83 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
84 |
+
assert result == 64
|
85 |
+
assert state["result_foo"] == 8
|
86 |
+
assert state["result_boo"] == 64
|
87 |
+
|
88 |
+
def test_evaluate_call(self):
|
89 |
+
code = "y = add_two(x)"
|
90 |
+
state = {"x": 3}
|
91 |
+
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
92 |
+
assert result == 5
|
93 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": {"counter": 3}})
|
94 |
+
|
95 |
+
# Should not work without the tool
|
96 |
+
with pytest.raises(InterpreterError, match="Forbidden function evaluation: 'add_two'"):
|
97 |
+
evaluate_python_code(code, {}, state=state)
|
98 |
+
|
99 |
+
def test_evaluate_class_def(self):
|
100 |
+
code = dedent('''\
|
101 |
+
class MyClass:
|
102 |
+
"""A class with a value."""
|
103 |
+
|
104 |
+
def __init__(self, value):
|
105 |
+
self.value = value
|
106 |
+
|
107 |
+
def get_value(self):
|
108 |
+
return self.value
|
109 |
+
|
110 |
+
instance = MyClass(42)
|
111 |
+
result = instance.get_value()
|
112 |
+
''')
|
113 |
+
state = {}
|
114 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
115 |
+
assert result == 42
|
116 |
+
assert state["instance"].__doc__ == "A class with a value."
|
117 |
+
|
118 |
+
def test_evaluate_class_def_with_assign_attribute_target(self):
|
119 |
+
"""
|
120 |
+
Test evaluate_class_def function when stmt is an instance of ast.Assign with ast.Attribute target.
|
121 |
+
"""
|
122 |
+
code = dedent("""
|
123 |
+
class TestSubClass:
|
124 |
+
attr1 = 1
|
125 |
+
class TestClass:
|
126 |
+
data = TestSubClass()
|
127 |
+
data.attr1 = "value1"
|
128 |
+
data.attr2 = "value2"
|
129 |
+
result = (TestClass.data.attr1, TestClass.data.attr2)
|
130 |
+
""")
|
131 |
+
|
132 |
+
state = {}
|
133 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
134 |
+
|
135 |
+
assert result == ("value1", "value2")
|
136 |
+
assert isinstance(state["TestClass"], type)
|
137 |
+
assert state["TestClass"].data.attr1 == "value1"
|
138 |
+
assert state["TestClass"].data.attr2 == "value2"
|
139 |
+
|
140 |
+
def test_evaluate_constant(self):
|
141 |
+
code = "x = 3"
|
142 |
+
state = {}
|
143 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
144 |
+
assert result == 3
|
145 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "_operations_count": {"counter": 2}})
|
146 |
+
|
147 |
+
def test_evaluate_dict(self):
|
148 |
+
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
149 |
+
state = {"x": 3}
|
150 |
+
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
151 |
+
assert result == {"x": 3, "y": 5}
|
152 |
+
self.assertDictEqualNoPrint(
|
153 |
+
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": {"counter": 7}}
|
154 |
+
)
|
155 |
+
|
156 |
+
def test_evaluate_expression(self):
|
157 |
+
code = "x = 3\ny = 5"
|
158 |
+
state = {}
|
159 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
160 |
+
# evaluate returns the value of the last assignment.
|
161 |
+
assert result == 5
|
162 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": {"counter": 4}})
|
163 |
+
|
164 |
+
def test_evaluate_f_string(self):
|
165 |
+
code = "text = f'This is x: {x}.'"
|
166 |
+
state = {"x": 3}
|
167 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
168 |
+
# evaluate returns the value of the last assignment.
|
169 |
+
assert result == "This is x: 3."
|
170 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "text": "This is x: 3.", "_operations_count": {"counter": 6}})
|
171 |
+
|
172 |
+
def test_evaluate_f_string_with_format(self):
|
173 |
+
code = "text = f'This is x: {x:.2f}.'"
|
174 |
+
state = {"x": 3.336}
|
175 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
176 |
+
assert result == "This is x: 3.34."
|
177 |
+
self.assertDictEqualNoPrint(
|
178 |
+
state, {"x": 3.336, "text": "This is x: 3.34.", "_operations_count": {"counter": 8}}
|
179 |
+
)
|
180 |
+
|
181 |
+
def test_evaluate_f_string_with_complex_format(self):
|
182 |
+
code = "text = f'This is x: {x:>{width}.{precision}f}.'"
|
183 |
+
state = {"x": 3.336, "width": 10, "precision": 2}
|
184 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
185 |
+
assert result == "This is x: 3.34."
|
186 |
+
self.assertDictEqualNoPrint(
|
187 |
+
state,
|
188 |
+
{
|
189 |
+
"x": 3.336,
|
190 |
+
"width": 10,
|
191 |
+
"precision": 2,
|
192 |
+
"text": "This is x: 3.34.",
|
193 |
+
"_operations_count": {"counter": 14},
|
194 |
+
},
|
195 |
+
)
|
196 |
+
|
197 |
+
def test_evaluate_if(self):
|
198 |
+
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
199 |
+
state = {"x": 3}
|
200 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
201 |
+
# evaluate returns the value of the last assignment.
|
202 |
+
assert result == 2
|
203 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "y": 2, "_operations_count": {"counter": 6}})
|
204 |
+
|
205 |
+
state = {"x": 8}
|
206 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
207 |
+
# evaluate returns the value of the last assignment.
|
208 |
+
assert result == 5
|
209 |
+
self.assertDictEqualNoPrint(state, {"x": 8, "y": 5, "_operations_count": {"counter": 6}})
|
210 |
+
|
211 |
+
def test_evaluate_list(self):
|
212 |
+
code = "test_list = [x, add_two(x)]"
|
213 |
+
state = {"x": 3}
|
214 |
+
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
215 |
+
assert result == [3, 5]
|
216 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": {"counter": 5}})
|
217 |
+
|
218 |
+
def test_evaluate_name(self):
|
219 |
+
code = "y = x"
|
220 |
+
state = {"x": 3}
|
221 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
222 |
+
assert result == 3
|
223 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "y": 3, "_operations_count": {"counter": 2}})
|
224 |
+
|
225 |
+
def test_evaluate_subscript(self):
|
226 |
+
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
227 |
+
state = {"x": 3}
|
228 |
+
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
229 |
+
assert result == 5
|
230 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": {"counter": 9}})
|
231 |
+
|
232 |
+
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
233 |
+
state = {"x": 3}
|
234 |
+
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
235 |
+
assert result == 5
|
236 |
+
self.assertDictEqualNoPrint(
|
237 |
+
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": {"counter": 11}}
|
238 |
+
)
|
239 |
+
|
240 |
+
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
241 |
+
state = {}
|
242 |
+
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
243 |
+
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
244 |
+
|
245 |
+
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
246 |
+
code = """
|
247 |
+
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
|
248 |
+
for result in search_results:
|
249 |
+
if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
|
250 |
+
current_weather_url = result['href']
|
251 |
+
print(current_weather_url)
|
252 |
+
break"""
|
253 |
+
with pytest.raises(InterpreterError) as e:
|
254 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
255 |
+
assert "You're trying to subscript a string with a string index" in e
|
256 |
+
|
257 |
+
def test_evaluate_for(self):
|
258 |
+
code = "x = 0\nfor i in range(3):\n x = i"
|
259 |
+
state = {}
|
260 |
+
result, _ = evaluate_python_code(code, {"range": range}, state=state)
|
261 |
+
assert result == 2
|
262 |
+
self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": {"counter": 11}})
|
263 |
+
|
264 |
+
def test_evaluate_binop(self):
|
265 |
+
code = "y + x"
|
266 |
+
state = {"x": 3, "y": 6}
|
267 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
268 |
+
assert result == 9
|
269 |
+
self.assertDictEqualNoPrint(state, {"x": 3, "y": 6, "_operations_count": {"counter": 4}})
|
270 |
+
|
271 |
+
def test_recursive_function(self):
|
272 |
+
code = """
|
273 |
+
def recur_fibo(n):
|
274 |
+
if n <= 1:
|
275 |
+
return n
|
276 |
+
else:
|
277 |
+
return(recur_fibo(n-1) + recur_fibo(n-2))
|
278 |
+
recur_fibo(6)"""
|
279 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
280 |
+
assert result == 8
|
281 |
+
|
282 |
+
def test_max_operations(self):
|
283 |
+
# Check that operation counter is not reset in functions
|
284 |
+
code = dedent(
|
285 |
+
"""
|
286 |
+
def func(a):
|
287 |
+
for j in range(10):
|
288 |
+
a += j
|
289 |
+
return a
|
290 |
+
|
291 |
+
for i in range(5):
|
292 |
+
func(i)
|
293 |
+
"""
|
294 |
+
)
|
295 |
+
with patch("smolagents.local_python_executor.MAX_OPERATIONS", 100):
|
296 |
+
with pytest.raises(InterpreterError) as exception_info:
|
297 |
+
evaluate_python_code(code, {"range": range}, state={})
|
298 |
+
assert "Reached the max number of operations" in str(exception_info.value)
|
299 |
+
|
300 |
+
def test_operations_count(self):
|
301 |
+
# Check that operation counter is not reset in functions
|
302 |
+
code = dedent(
|
303 |
+
"""
|
304 |
+
def func():
|
305 |
+
return 0
|
306 |
+
|
307 |
+
func()
|
308 |
+
"""
|
309 |
+
)
|
310 |
+
state = {}
|
311 |
+
evaluate_python_code(code, {"range": range}, state=state)
|
312 |
+
assert state["_operations_count"]["counter"] == 5
|
313 |
+
|
314 |
+
def test_evaluate_string_methods(self):
|
315 |
+
code = "'hello'.replace('h', 'o').split('e')"
|
316 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
317 |
+
assert result == ["o", "llo"]
|
318 |
+
|
319 |
+
def test_evaluate_slicing(self):
|
320 |
+
code = "'hello'[1:3][::-1]"
|
321 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
322 |
+
assert result == "le"
|
323 |
+
|
324 |
+
def test_access_attributes(self):
|
325 |
+
class A:
|
326 |
+
attr = 2
|
327 |
+
|
328 |
+
code = "A.attr"
|
329 |
+
result, _ = evaluate_python_code(code, {}, state={"A": A})
|
330 |
+
assert result == 2
|
331 |
+
|
332 |
+
def test_list_comprehension(self):
|
333 |
+
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
334 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
335 |
+
assert result == "t-h-e-s-e-a-g-u-l-l"
|
336 |
+
|
337 |
+
def test_string_indexing(self):
|
338 |
+
code = """text_block = [
|
339 |
+
"THESE",
|
340 |
+
"AGULL"
|
341 |
+
]
|
342 |
+
sentence = ""
|
343 |
+
for block in text_block:
|
344 |
+
for col in range(len(text_block[0])):
|
345 |
+
sentence += block[col]
|
346 |
+
"""
|
347 |
+
result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
348 |
+
assert result == "THESEAGULL"
|
349 |
+
|
350 |
+
def test_tuples(self):
|
351 |
+
code = "x = (1, 2, 3)\nx[1]"
|
352 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
353 |
+
assert result == 2
|
354 |
+
|
355 |
+
code = """
|
356 |
+
digits, i = [1, 2, 3], 1
|
357 |
+
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
358 |
+
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
|
359 |
+
|
360 |
+
code = """
|
361 |
+
def calculate_isbn_10_check_digit(number):
|
362 |
+
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
|
363 |
+
remainder = total % 11
|
364 |
+
check_digit = 11 - remainder
|
365 |
+
if check_digit == 10:
|
366 |
+
return 'X'
|
367 |
+
elif check_digit == 11:
|
368 |
+
return '0'
|
369 |
+
else:
|
370 |
+
return str(check_digit)
|
371 |
+
|
372 |
+
# Given 9-digit numbers
|
373 |
+
numbers = [
|
374 |
+
"478225952",
|
375 |
+
"643485613",
|
376 |
+
"739394228",
|
377 |
+
"291726859",
|
378 |
+
"875262394",
|
379 |
+
"542617795",
|
380 |
+
"031810713",
|
381 |
+
"957007669",
|
382 |
+
"871467426"
|
383 |
+
]
|
384 |
+
|
385 |
+
# Calculate check digits for each number
|
386 |
+
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
|
387 |
+
print(check_digits)
|
388 |
+
"""
|
389 |
+
state = {}
|
390 |
+
evaluate_python_code(
|
391 |
+
code,
|
392 |
+
{
|
393 |
+
"range": range,
|
394 |
+
"print": print,
|
395 |
+
"sum": sum,
|
396 |
+
"enumerate": enumerate,
|
397 |
+
"int": int,
|
398 |
+
"str": str,
|
399 |
+
},
|
400 |
+
state,
|
401 |
+
)
|
402 |
+
|
403 |
+
def test_listcomp(self):
|
404 |
+
code = "x = [i for i in range(3)]"
|
405 |
+
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
406 |
+
assert result == [0, 1, 2]
|
407 |
+
|
408 |
+
def test_setcomp(self):
|
409 |
+
code = "batman_times = {entry['time'] for entry in [{'time': 10}, {'time': 19}, {'time': 20}]}"
|
410 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
411 |
+
assert result == {10, 19, 20}
|
412 |
+
|
413 |
+
def test_break_continue(self):
|
414 |
+
code = "for i in range(10):\n if i == 5:\n break\ni"
|
415 |
+
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
416 |
+
assert result == 5
|
417 |
+
|
418 |
+
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
419 |
+
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
420 |
+
assert result == 9
|
421 |
+
|
422 |
+
def test_call_int(self):
|
423 |
+
code = "import math\nstr(math.ceil(149))"
|
424 |
+
result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
425 |
+
assert result == "149"
|
426 |
+
|
427 |
+
def test_lambda(self):
|
428 |
+
code = "f = lambda x: x + 2\nf(3)"
|
429 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
430 |
+
assert result == 5
|
431 |
+
|
432 |
+
def test_dictcomp(self):
|
433 |
+
code = "x = {i: i**2 for i in range(3)}"
|
434 |
+
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
435 |
+
assert result == {0: 0, 1: 1, 2: 4}
|
436 |
+
|
437 |
+
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
438 |
+
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
439 |
+
assert result == {102: "b"}
|
440 |
+
|
441 |
+
code = """
|
442 |
+
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
443 |
+
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
444 |
+
"""
|
445 |
+
result, _ = evaluate_python_code(code, {}, state={})
|
446 |
+
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
447 |
+
|
448 |
+
def test_tuple_assignment(self):
|
449 |
+
code = "a, b = 0, 1\nb"
|
450 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
451 |
+
assert result == 1
|
452 |
+
|
453 |
+
def test_while(self):
|
454 |
+
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
455 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
456 |
+
assert result == 3
|
457 |
+
|
458 |
+
# test infinite loop
|
459 |
+
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
460 |
+
with patch("smolagents.local_python_executor.MAX_WHILE_ITERATIONS", 100):
|
461 |
+
with pytest.raises(InterpreterError, match=".*Maximum number of 100 iterations in While loop exceeded"):
|
462 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
463 |
+
|
464 |
+
# test lazy evaluation
|
465 |
+
code = dedent(
|
466 |
+
"""
|
467 |
+
house_positions = [0, 7, 10, 15, 18, 22, 22]
|
468 |
+
i, n, loc = 0, 7, 30
|
469 |
+
while i < n and house_positions[i] <= loc:
|
470 |
+
i += 1
|
471 |
+
"""
|
472 |
+
)
|
473 |
+
state = {}
|
474 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
475 |
+
|
476 |
+
def test_generator(self):
|
477 |
+
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
478 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
479 |
+
assert result == [1, 4, 9, 16, 25]
|
480 |
+
|
481 |
+
def test_boolops(self):
|
482 |
+
code = """if (not (a > b and a > c)) or d > e:
|
483 |
+
best_city = "Brooklyn"
|
484 |
+
else:
|
485 |
+
best_city = "Manhattan"
|
486 |
+
best_city
|
487 |
+
"""
|
488 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
489 |
+
assert result == "Brooklyn"
|
490 |
+
|
491 |
+
code = """if d > e and a < b:
|
492 |
+
best_city = "Brooklyn"
|
493 |
+
elif d < e and a < b:
|
494 |
+
best_city = "Sacramento"
|
495 |
+
else:
|
496 |
+
best_city = "Manhattan"
|
497 |
+
best_city
|
498 |
+
"""
|
499 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
500 |
+
assert result == "Sacramento"
|
501 |
+
|
502 |
+
# Short-circuit evaluation:
|
503 |
+
# (T and 0) or (T and T) => 0 or True => True
|
504 |
+
code = "result = (x > 3 and y) or (z == 10 and not y)\nresult"
|
505 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"x": 5, "y": 0, "z": 10})
|
506 |
+
assert result
|
507 |
+
|
508 |
+
# (None or "") or "Found" => "" or "Found" => "Found"
|
509 |
+
code = "result = (a or c) or b\nresult"
|
510 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": None, "b": "Found", "c": ""})
|
511 |
+
assert result == "Found"
|
512 |
+
|
513 |
+
# ("First" and "") or "Third" => "" or "Third" -> "Third"
|
514 |
+
code = "result = (a and b) or c\nresult"
|
515 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": "First", "b": "", "c": "Third"})
|
516 |
+
assert result == "Third"
|
517 |
+
|
518 |
+
def test_if_conditions(self):
|
519 |
+
code = """char='a'
|
520 |
+
if char.isalpha():
|
521 |
+
print('2')"""
|
522 |
+
state = {}
|
523 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
524 |
+
assert state["_print_outputs"].value == "2\n"
|
525 |
+
|
526 |
+
def test_imports(self):
|
527 |
+
code = "import math\nmath.sqrt(4)"
|
528 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
529 |
+
assert result == 2.0
|
530 |
+
|
531 |
+
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
532 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
533 |
+
assert result == "lose"
|
534 |
+
|
535 |
+
code = "import time, re\ntime.sleep(0.1)"
|
536 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
537 |
+
assert result is None
|
538 |
+
|
539 |
+
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
540 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
541 |
+
assert result == 1
|
542 |
+
|
543 |
+
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
544 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
545 |
+
assert result == [0, 1, 2]
|
546 |
+
|
547 |
+
code = "import re\nre.search('a', 'abc').group()"
|
548 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
549 |
+
assert result == "a"
|
550 |
+
|
551 |
+
code = "import stat\nstat.S_ISREG(0o100644)"
|
552 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
553 |
+
assert result
|
554 |
+
|
555 |
+
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
556 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
557 |
+
assert result == 2.8
|
558 |
+
|
559 |
+
code = "import unicodedata\nunicodedata.name('A')"
|
560 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
561 |
+
assert result == "LATIN CAPITAL LETTER A"
|
562 |
+
|
563 |
+
# Test submodules are handled properly, thus not raising error
|
564 |
+
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
565 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
|
566 |
+
|
567 |
+
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
568 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
|
569 |
+
|
570 |
+
def test_additional_imports(self):
|
571 |
+
code = "import numpy as np"
|
572 |
+
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
573 |
+
|
574 |
+
# Test that allowing 'numpy.*' allows numpy root package and its submodules
|
575 |
+
code = "import numpy as np\nnp.random.default_rng(123)\nnp.array([1, 2])"
|
576 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.*"])
|
577 |
+
|
578 |
+
# Test that allowing 'numpy.*' allows importing a submodule
|
579 |
+
code = "import numpy.random as rd\nrd.default_rng(12345)"
|
580 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.*"])
|
581 |
+
|
582 |
+
code = "import numpy.random as rd"
|
583 |
+
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
|
584 |
+
evaluate_python_code(code, authorized_imports=["numpy.*"], state={})
|
585 |
+
evaluate_python_code(code, authorized_imports=["*"], state={})
|
586 |
+
with pytest.raises(InterpreterError):
|
587 |
+
evaluate_python_code(code, authorized_imports=["random"], state={})
|
588 |
+
|
589 |
+
with pytest.raises(InterpreterError):
|
590 |
+
evaluate_python_code(code, authorized_imports=["numpy.a"], state={})
|
591 |
+
with pytest.raises(InterpreterError):
|
592 |
+
evaluate_python_code(code, authorized_imports=["numpy.a.*"], state={})
|
593 |
+
|
594 |
+
def test_multiple_comparators(self):
|
595 |
+
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
596 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
597 |
+
assert not result
|
598 |
+
|
599 |
+
code = "0 <= 1 < 4 and 0 <= -5 < 4"
|
600 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
601 |
+
assert not result
|
602 |
+
|
603 |
+
code = "0 <= 4 < 4 and 0 <= 3 < 4"
|
604 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
605 |
+
assert not result
|
606 |
+
|
607 |
+
code = "0 <= 3 < 4 and 0 <= 3 < 4"
|
608 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
609 |
+
assert result
|
610 |
+
|
611 |
+
def test_print_output(self):
|
612 |
+
code = "print('Hello world!')\nprint('Ok no one cares')"
|
613 |
+
state = {}
|
614 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
615 |
+
assert result is None
|
616 |
+
assert state["_print_outputs"].value == "Hello world!\nOk no one cares\n"
|
617 |
+
|
618 |
+
# Test print in function (state copy)
|
619 |
+
code = """
|
620 |
+
print("1")
|
621 |
+
def function():
|
622 |
+
print("2")
|
623 |
+
function()"""
|
624 |
+
state = {}
|
625 |
+
evaluate_python_code(code, {"print": print}, state=state)
|
626 |
+
assert state["_print_outputs"].value == "1\n2\n"
|
627 |
+
|
628 |
+
# Test print in list comprehension (state copy)
|
629 |
+
code = """
|
630 |
+
print("1")
|
631 |
+
def function():
|
632 |
+
print("2")
|
633 |
+
[function() for i in range(10)]"""
|
634 |
+
state = {}
|
635 |
+
evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
636 |
+
assert state["_print_outputs"].value == "1\n2\n2\n2\n2\n2\n2\n2\n2\n2\n2\n"
|
637 |
+
|
638 |
+
def test_tuple_target_in_iterator(self):
|
639 |
+
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
640 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
641 |
+
assert result == "Samuel"
|
642 |
+
|
643 |
+
def test_classes(self):
|
644 |
+
code = """
|
645 |
+
class Animal:
|
646 |
+
species = "Generic Animal"
|
647 |
+
|
648 |
+
def __init__(self, name, age):
|
649 |
+
self.name = name
|
650 |
+
self.age = age
|
651 |
+
|
652 |
+
def sound(self):
|
653 |
+
return "The animal makes a sound."
|
654 |
+
|
655 |
+
def __str__(self):
|
656 |
+
return f"{self.name}, {self.age} years old"
|
657 |
+
|
658 |
+
class Dog(Animal):
|
659 |
+
species = "Canine"
|
660 |
+
|
661 |
+
def __init__(self, name, age, breed):
|
662 |
+
super().__init__(name, age)
|
663 |
+
self.breed = breed
|
664 |
+
|
665 |
+
def sound(self):
|
666 |
+
return "The dog barks."
|
667 |
+
|
668 |
+
def __str__(self):
|
669 |
+
return f"{self.name}, {self.age} years old, {self.breed}"
|
670 |
+
|
671 |
+
class Cat(Animal):
|
672 |
+
def sound(self):
|
673 |
+
return "The cat meows."
|
674 |
+
|
675 |
+
def __str__(self):
|
676 |
+
return f"{self.name}, {self.age} years old, {self.species}"
|
677 |
+
|
678 |
+
|
679 |
+
# Testing multiple instances
|
680 |
+
dog1 = Dog("Fido", 3, "Labrador")
|
681 |
+
dog2 = Dog("Buddy", 5, "Golden Retriever")
|
682 |
+
|
683 |
+
# Testing method with built-in function
|
684 |
+
animals = [dog1, dog2, Cat("Whiskers", 2)]
|
685 |
+
num_animals = len(animals)
|
686 |
+
|
687 |
+
# Testing exceptions in methods
|
688 |
+
class ExceptionTest:
|
689 |
+
def method_that_raises(self):
|
690 |
+
raise ValueError("An error occurred")
|
691 |
+
|
692 |
+
try:
|
693 |
+
exc_test = ExceptionTest()
|
694 |
+
exc_test.method_that_raises()
|
695 |
+
except ValueError as e:
|
696 |
+
exception_message = str(e)
|
697 |
+
|
698 |
+
|
699 |
+
# Collecting results
|
700 |
+
dog1_sound = dog1.sound()
|
701 |
+
dog1_str = str(dog1)
|
702 |
+
dog2_sound = dog2.sound()
|
703 |
+
dog2_str = str(dog2)
|
704 |
+
cat = Cat("Whiskers", 2)
|
705 |
+
cat_sound = cat.sound()
|
706 |
+
cat_str = str(cat)
|
707 |
+
"""
|
708 |
+
state = {}
|
709 |
+
evaluate_python_code(
|
710 |
+
code,
|
711 |
+
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
|
712 |
+
state=state,
|
713 |
+
)
|
714 |
+
|
715 |
+
# Assert results
|
716 |
+
assert state["dog1_sound"] == "The dog barks."
|
717 |
+
assert state["dog1_str"] == "Fido, 3 years old, Labrador"
|
718 |
+
assert state["dog2_sound"] == "The dog barks."
|
719 |
+
assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
|
720 |
+
assert state["cat_sound"] == "The cat meows."
|
721 |
+
assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
|
722 |
+
assert state["num_animals"] == 3
|
723 |
+
assert state["exception_message"] == "An error occurred"
|
724 |
+
|
725 |
+
def test_variable_args(self):
|
726 |
+
code = """
|
727 |
+
def var_args_method(self, *args, **kwargs):
|
728 |
+
return sum(args) + sum(kwargs.values())
|
729 |
+
|
730 |
+
var_args_method(1, 2, 3, x=4, y=5)
|
731 |
+
"""
|
732 |
+
state = {}
|
733 |
+
result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
|
734 |
+
assert result == 15
|
735 |
+
|
736 |
+
def test_exceptions(self):
|
737 |
+
code = """
|
738 |
+
def method_that_raises(self):
|
739 |
+
raise ValueError("An error occurred")
|
740 |
+
|
741 |
+
try:
|
742 |
+
method_that_raises()
|
743 |
+
except ValueError as e:
|
744 |
+
exception_message = str(e)
|
745 |
+
"""
|
746 |
+
state = {}
|
747 |
+
evaluate_python_code(
|
748 |
+
code,
|
749 |
+
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
|
750 |
+
state=state,
|
751 |
+
)
|
752 |
+
assert state["exception_message"] == "An error occurred"
|
753 |
+
|
754 |
+
def test_print(self):
|
755 |
+
code = "print(min([1, 2, 3]))"
|
756 |
+
state = {}
|
757 |
+
evaluate_python_code(code, {"min": min, "print": print}, state=state)
|
758 |
+
assert state["_print_outputs"].value == "1\n"
|
759 |
+
|
760 |
+
def test_types_as_objects(self):
|
761 |
+
code = "type_a = float(2); type_b = str; type_c = int"
|
762 |
+
state = {}
|
763 |
+
result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
764 |
+
# Type objects are not wrapped by safer_func
|
765 |
+
assert not hasattr(result, "__wrapped__")
|
766 |
+
assert result is int
|
767 |
+
|
768 |
+
def test_tuple_id(self):
|
769 |
+
code = """
|
770 |
+
food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
|
771 |
+
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
|
772 |
+
"""
|
773 |
+
state = {}
|
774 |
+
result, is_final_answer = evaluate_python_code(code, {}, state=state)
|
775 |
+
assert result == ["orange", "pear"]
|
776 |
+
|
777 |
+
def test_nonsimple_augassign(self):
|
778 |
+
code = """
|
779 |
+
counts_dict = {'a': 0}
|
780 |
+
counts_dict['a'] += 1
|
781 |
+
counts_list = [1, 2, 3]
|
782 |
+
counts_list += [4, 5, 6]
|
783 |
+
|
784 |
+
class Counter:
|
785 |
+
def __init__(self):
|
786 |
+
self.count = 0
|
787 |
+
|
788 |
+
a = Counter()
|
789 |
+
a.count += 1
|
790 |
+
"""
|
791 |
+
state = {}
|
792 |
+
evaluate_python_code(code, {}, state=state)
|
793 |
+
assert state["counts_dict"] == {"a": 1}
|
794 |
+
assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
|
795 |
+
assert state["a"].count == 1
|
796 |
+
|
797 |
+
def test_adding_int_to_list_raises_error(self):
|
798 |
+
code = """
|
799 |
+
counts = [1, 2, 3]
|
800 |
+
counts += 1"""
|
801 |
+
with pytest.raises(InterpreterError) as e:
|
802 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
803 |
+
assert "Cannot add non-list value 1 to a list." in str(e)
|
804 |
+
|
805 |
+
def test_error_highlights_correct_line_of_code(self):
|
806 |
+
code = """a = 1
|
807 |
+
b = 2
|
808 |
+
|
809 |
+
counts = [1, 2, 3]
|
810 |
+
counts += 1
|
811 |
+
b += 1"""
|
812 |
+
with pytest.raises(InterpreterError) as e:
|
813 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
814 |
+
assert "Code execution failed at line 'counts += 1" in str(e)
|
815 |
+
|
816 |
+
def test_error_type_returned_in_function_call(self):
|
817 |
+
code = """def error_function():
|
818 |
+
raise ValueError("error")
|
819 |
+
|
820 |
+
error_function()"""
|
821 |
+
with pytest.raises(InterpreterError) as e:
|
822 |
+
evaluate_python_code(code)
|
823 |
+
assert "error" in str(e)
|
824 |
+
assert "ValueError" in str(e)
|
825 |
+
|
826 |
+
def test_assert(self):
|
827 |
+
code = """
|
828 |
+
assert 1 == 1
|
829 |
+
assert 1 == 2
|
830 |
+
"""
|
831 |
+
with pytest.raises(InterpreterError) as e:
|
832 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
833 |
+
assert "1 == 2" in str(e) and "1 == 1" not in str(e)
|
834 |
+
|
835 |
+
def test_with_context_manager(self):
|
836 |
+
code = """
|
837 |
+
class SimpleLock:
|
838 |
+
def __init__(self):
|
839 |
+
self.locked = False
|
840 |
+
|
841 |
+
def __enter__(self):
|
842 |
+
self.locked = True
|
843 |
+
return self
|
844 |
+
|
845 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
846 |
+
self.locked = False
|
847 |
+
|
848 |
+
lock = SimpleLock()
|
849 |
+
|
850 |
+
with lock as l:
|
851 |
+
assert l.locked == True
|
852 |
+
|
853 |
+
assert lock.locked == False
|
854 |
+
"""
|
855 |
+
state = {}
|
856 |
+
tools = {}
|
857 |
+
evaluate_python_code(code, tools, state=state)
|
858 |
+
|
859 |
+
def test_default_arg_in_function(self):
|
860 |
+
code = """
|
861 |
+
def f(a, b=333, n=1000):
|
862 |
+
return b + n
|
863 |
+
n = f(1, n=667)
|
864 |
+
"""
|
865 |
+
res, is_final_answer = evaluate_python_code(code, {}, {})
|
866 |
+
assert res == 1000
|
867 |
+
assert not is_final_answer
|
868 |
+
|
869 |
+
def test_set(self):
|
870 |
+
code = """
|
871 |
+
S1 = {'a', 'b', 'c'}
|
872 |
+
S2 = {'b', 'c', 'd'}
|
873 |
+
S3 = S1.difference(S2)
|
874 |
+
S4 = S1.intersection(S2)
|
875 |
+
"""
|
876 |
+
state = {}
|
877 |
+
evaluate_python_code(code, {}, state=state)
|
878 |
+
assert state["S3"] == {"a"}
|
879 |
+
assert state["S4"] == {"b", "c"}
|
880 |
+
|
881 |
+
def test_break(self):
|
882 |
+
code = """
|
883 |
+
i = 0
|
884 |
+
|
885 |
+
while True:
|
886 |
+
i+= 1
|
887 |
+
if i==3:
|
888 |
+
break
|
889 |
+
|
890 |
+
i"""
|
891 |
+
result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
|
892 |
+
assert result == 3
|
893 |
+
assert not is_final_answer
|
894 |
+
|
895 |
+
def test_return(self):
|
896 |
+
# test early returns
|
897 |
+
code = """
|
898 |
+
def add_one(n, shift):
|
899 |
+
if True:
|
900 |
+
return n + shift
|
901 |
+
return n
|
902 |
+
|
903 |
+
add_one(1, 1)
|
904 |
+
"""
|
905 |
+
state = {}
|
906 |
+
result, is_final_answer = evaluate_python_code(
|
907 |
+
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
908 |
+
)
|
909 |
+
assert result == 2
|
910 |
+
|
911 |
+
# test returning None
|
912 |
+
code = """
|
913 |
+
def returns_none(a):
|
914 |
+
return
|
915 |
+
|
916 |
+
returns_none(1)
|
917 |
+
"""
|
918 |
+
state = {}
|
919 |
+
result, is_final_answer = evaluate_python_code(
|
920 |
+
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
921 |
+
)
|
922 |
+
assert result is None
|
923 |
+
|
924 |
+
def test_nested_for_loop(self):
|
925 |
+
code = """
|
926 |
+
all_res = []
|
927 |
+
for i in range(10):
|
928 |
+
subres = []
|
929 |
+
for j in range(i):
|
930 |
+
subres.append(j)
|
931 |
+
all_res.append(subres)
|
932 |
+
|
933 |
+
out = [i for sublist in all_res for i in sublist]
|
934 |
+
out[:10]
|
935 |
+
"""
|
936 |
+
state = {}
|
937 |
+
result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
938 |
+
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
939 |
+
|
940 |
+
def test_pandas(self):
|
941 |
+
code = """
|
942 |
+
import pandas as pd
|
943 |
+
|
944 |
+
df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
|
945 |
+
|
946 |
+
df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
|
947 |
+
|
948 |
+
parts_with_5_set_count = df[df['SetCount'] == 5.0]
|
949 |
+
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
950 |
+
"""
|
951 |
+
state = {}
|
952 |
+
result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
|
953 |
+
assert np.array_equal(result, [-1, 5])
|
954 |
+
|
955 |
+
code = """
|
956 |
+
import pandas as pd
|
957 |
+
|
958 |
+
df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
|
959 |
+
|
960 |
+
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
961 |
+
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
962 |
+
"""
|
963 |
+
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
964 |
+
assert np.array_equal(result.values[0], [104, 1])
|
965 |
+
|
966 |
+
# Test groupby
|
967 |
+
code = """import pandas as pd
|
968 |
+
data = pd.DataFrame.from_dict([
|
969 |
+
{"Pclass": 1, "Survived": 1},
|
970 |
+
{"Pclass": 2, "Survived": 0},
|
971 |
+
{"Pclass": 2, "Survived": 1}
|
972 |
+
])
|
973 |
+
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
974 |
+
"""
|
975 |
+
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
976 |
+
assert result.values[1] == 0.5
|
977 |
+
|
978 |
+
# Test loc and iloc
|
979 |
+
code = """import pandas as pd
|
980 |
+
data = pd.DataFrame.from_dict([
|
981 |
+
{"Pclass": 1, "Survived": 1},
|
982 |
+
{"Pclass": 2, "Survived": 0},
|
983 |
+
{"Pclass": 2, "Survived": 1}
|
984 |
+
])
|
985 |
+
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
986 |
+
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
987 |
+
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
|
988 |
+
"""
|
989 |
+
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
990 |
+
|
991 |
+
def test_starred(self):
|
992 |
+
code = """
|
993 |
+
from math import radians, sin, cos, sqrt, atan2
|
994 |
+
|
995 |
+
def haversine(lat1, lon1, lat2, lon2):
|
996 |
+
R = 6371000 # Radius of the Earth in meters
|
997 |
+
lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
|
998 |
+
dlat = lat2 - lat1
|
999 |
+
dlon = lon2 - lon1
|
1000 |
+
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
|
1001 |
+
c = 2 * atan2(sqrt(a), sqrt(1 - a))
|
1002 |
+
distance = R * c
|
1003 |
+
return distance
|
1004 |
+
|
1005 |
+
coords_geneva = (46.1978, 6.1342)
|
1006 |
+
coords_barcelona = (41.3869, 2.1660)
|
1007 |
+
|
1008 |
+
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
1009 |
+
"""
|
1010 |
+
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
|
1011 |
+
assert round(result, 1) == 622395.4
|
1012 |
+
|
1013 |
+
def test_for(self):
|
1014 |
+
code = """
|
1015 |
+
shifts = {
|
1016 |
+
"Worker A": ("6:45 pm", "8:00 pm"),
|
1017 |
+
"Worker B": ("10:00 am", "11:45 am")
|
1018 |
+
}
|
1019 |
+
|
1020 |
+
shift_intervals = {}
|
1021 |
+
for worker, (start, end) in shifts.items():
|
1022 |
+
shift_intervals[worker] = end
|
1023 |
+
shift_intervals
|
1024 |
+
"""
|
1025 |
+
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
|
1026 |
+
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
1027 |
+
|
1028 |
+
def test_syntax_error_points_error(self):
|
1029 |
+
code = "a = ;"
|
1030 |
+
with pytest.raises(InterpreterError) as e:
|
1031 |
+
evaluate_python_code(code)
|
1032 |
+
assert "SyntaxError" in str(e)
|
1033 |
+
assert " ^" in str(e)
|
1034 |
+
|
1035 |
+
def test_close_matches_subscript(self):
|
1036 |
+
code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
|
1037 |
+
with pytest.raises(Exception) as e:
|
1038 |
+
evaluate_python_code(code)
|
1039 |
+
assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "")
|
1040 |
+
|
1041 |
+
def test_dangerous_builtins_calls_are_blocked(self):
|
1042 |
+
unsafe_code = "import os"
|
1043 |
+
dangerous_code = f"""
|
1044 |
+
exec = callable.__self__.exec
|
1045 |
+
compile = callable.__self__.compile
|
1046 |
+
exec(compile('{unsafe_code}', 'no filename', 'exec'))
|
1047 |
+
"""
|
1048 |
+
|
1049 |
+
with pytest.raises(InterpreterError):
|
1050 |
+
evaluate_python_code(unsafe_code, static_tools=BASE_PYTHON_TOOLS)
|
1051 |
+
|
1052 |
+
with pytest.raises(InterpreterError):
|
1053 |
+
evaluate_python_code(dangerous_code, static_tools=BASE_PYTHON_TOOLS)
|
1054 |
+
|
1055 |
+
def test_final_answer_accepts_kwarg_answer(self):
|
1056 |
+
code = "final_answer(answer=2)"
|
1057 |
+
result, _ = evaluate_python_code(code, {"final_answer": (lambda answer: 2 * answer)}, state={})
|
1058 |
+
assert result == 4
|
1059 |
+
|
1060 |
+
def test_dangerous_builtins_are_callable_if_explicitly_added(self):
|
1061 |
+
dangerous_code = dedent("""
|
1062 |
+
eval("1 + 1")
|
1063 |
+
exec(compile("1 + 1", "no filename", "exec"))
|
1064 |
+
""")
|
1065 |
+
evaluate_python_code(
|
1066 |
+
dangerous_code, static_tools={"compile": compile, "eval": eval, "exec": exec} | BASE_PYTHON_TOOLS
|
1067 |
+
)
|
1068 |
+
|
1069 |
+
def test_can_import_os_if_explicitly_authorized(self):
|
1070 |
+
dangerous_code = "import os; os.listdir('./')"
|
1071 |
+
evaluate_python_code(dangerous_code, authorized_imports=["os"])
|
1072 |
+
|
1073 |
+
def test_can_import_os_if_all_imports_authorized(self):
|
1074 |
+
dangerous_code = "import os; os.listdir('./')"
|
1075 |
+
evaluate_python_code(dangerous_code, authorized_imports=["*"])
|
1076 |
+
|
1077 |
+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
1078 |
+
def test_can_import_scipy_if_explicitly_authorized(self):
|
1079 |
+
code = "import scipy"
|
1080 |
+
evaluate_python_code(code, authorized_imports=["scipy"])
|
1081 |
+
|
1082 |
+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
1083 |
+
def test_can_import_sklearn_if_explicitly_authorized(self):
|
1084 |
+
code = "import sklearn"
|
1085 |
+
evaluate_python_code(code, authorized_imports=["sklearn"])
|
1086 |
+
|
1087 |
+
def test_function_def_recovers_source_code(self):
|
1088 |
+
executor = LocalPythonExecutor([])
|
1089 |
+
|
1090 |
+
executor.send_tools({"final_answer": FinalAnswerTool()})
|
1091 |
+
|
1092 |
+
res, _, _ = executor(
|
1093 |
+
dedent(
|
1094 |
+
"""
|
1095 |
+
def target_function():
|
1096 |
+
return "Hello world"
|
1097 |
+
|
1098 |
+
final_answer(target_function)
|
1099 |
+
"""
|
1100 |
+
)
|
1101 |
+
)
|
1102 |
+
assert res.__name__ == "target_function"
|
1103 |
+
assert res.__source__ == "def target_function():\n return 'Hello world'"
|
1104 |
+
|
1105 |
+
def test_evaluate_class_def_with_pass(self):
|
1106 |
+
code = dedent("""
|
1107 |
+
class TestClass:
|
1108 |
+
pass
|
1109 |
+
|
1110 |
+
instance = TestClass()
|
1111 |
+
instance.attr = "value"
|
1112 |
+
result = instance.attr
|
1113 |
+
""")
|
1114 |
+
state = {}
|
1115 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
1116 |
+
assert result == "value"
|
1117 |
+
|
1118 |
+
def test_evaluate_class_def_with_ann_assign_name(self):
|
1119 |
+
"""
|
1120 |
+
Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Name target.
|
1121 |
+
|
1122 |
+
This test verifies that annotated assignments within a class definition are correctly evaluated.
|
1123 |
+
"""
|
1124 |
+
code = dedent("""
|
1125 |
+
class TestClass:
|
1126 |
+
x: int = 5
|
1127 |
+
y: str = "test"
|
1128 |
+
|
1129 |
+
instance = TestClass()
|
1130 |
+
result = (instance.x, instance.y)
|
1131 |
+
""")
|
1132 |
+
|
1133 |
+
state = {}
|
1134 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
1135 |
+
|
1136 |
+
assert result == (5, "test")
|
1137 |
+
assert isinstance(state["TestClass"], type)
|
1138 |
+
# Type objects are not wrapped by safer_func
|
1139 |
+
for value in state["TestClass"].__annotations__.values():
|
1140 |
+
assert not hasattr(value, "__wrapped__")
|
1141 |
+
assert state["TestClass"].__annotations__ == {"x": int, "y": str}
|
1142 |
+
assert state["TestClass"].x == 5
|
1143 |
+
assert state["TestClass"].y == "test"
|
1144 |
+
assert isinstance(state["instance"], state["TestClass"])
|
1145 |
+
assert state["instance"].x == 5
|
1146 |
+
assert state["instance"].y == "test"
|
1147 |
+
|
1148 |
+
def test_evaluate_class_def_with_ann_assign_attribute(self):
|
1149 |
+
"""
|
1150 |
+
Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Attribute target.
|
1151 |
+
|
1152 |
+
This test ensures that class attributes using attribute notation are correctly handled.
|
1153 |
+
"""
|
1154 |
+
code = dedent("""
|
1155 |
+
class TestSubClass:
|
1156 |
+
attr = 1
|
1157 |
+
class TestClass:
|
1158 |
+
data: TestSubClass = TestSubClass()
|
1159 |
+
data.attr: str = "value"
|
1160 |
+
|
1161 |
+
result = TestClass.data.attr
|
1162 |
+
""")
|
1163 |
+
|
1164 |
+
state = {}
|
1165 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
1166 |
+
|
1167 |
+
assert result == "value"
|
1168 |
+
assert isinstance(state["TestClass"], type)
|
1169 |
+
assert state["TestClass"].__annotations__.keys() == {"data"}
|
1170 |
+
assert isinstance(state["TestClass"].__annotations__["data"], type)
|
1171 |
+
assert state["TestClass"].__annotations__["data"].__name__ == "TestSubClass"
|
1172 |
+
assert state["TestClass"].data.attr == "value"
|
1173 |
+
|
1174 |
+
def test_evaluate_class_def_with_ann_assign_subscript(self):
|
1175 |
+
"""
|
1176 |
+
Test evaluate_class_def function when stmt is an instance of ast.AnnAssign with ast.Subscript target.
|
1177 |
+
|
1178 |
+
This test ensures that class attributes using subscript notation are correctly handled.
|
1179 |
+
"""
|
1180 |
+
code = dedent("""
|
1181 |
+
class TestClass:
|
1182 |
+
key_data: dict = {}
|
1183 |
+
key_data["key"]: str = "value"
|
1184 |
+
index_data: list = [10, 20, 30]
|
1185 |
+
index_data[0:2]: list[str] = ["a", "b"]
|
1186 |
+
|
1187 |
+
result = (TestClass.key_data['key'], TestClass.index_data[1:])
|
1188 |
+
""")
|
1189 |
+
|
1190 |
+
state = {}
|
1191 |
+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
1192 |
+
|
1193 |
+
assert result == ("value", ["b", 30])
|
1194 |
+
assert isinstance(state["TestClass"], type)
|
1195 |
+
# Type objects are not wrapped by safer_func
|
1196 |
+
for value in state["TestClass"].__annotations__.values():
|
1197 |
+
assert not hasattr(value, "__wrapped__")
|
1198 |
+
assert state["TestClass"].__annotations__ == {"key_data": dict, "index_data": list}
|
1199 |
+
assert state["TestClass"].key_data == {"key": "value"}
|
1200 |
+
assert state["TestClass"].index_data == ["a", "b", 30]
|
1201 |
+
|
1202 |
+
def test_evaluate_annassign(self):
|
1203 |
+
code = dedent("""\
|
1204 |
+
# Basic annotated assignment
|
1205 |
+
x: int = 42
|
1206 |
+
|
1207 |
+
# Type annotations with expressions
|
1208 |
+
y: float = x / 2
|
1209 |
+
|
1210 |
+
# Type annotation without assignment
|
1211 |
+
z: list
|
1212 |
+
|
1213 |
+
# Type annotation with complex value
|
1214 |
+
names: list = ["Alice", "Bob", "Charlie"]
|
1215 |
+
|
1216 |
+
# Type hint shouldn't restrict values at runtime
|
1217 |
+
s: str = 123 # Would be a type error in static checking, but valid at runtime
|
1218 |
+
|
1219 |
+
# Access the values
|
1220 |
+
result = (x, y, names, s)
|
1221 |
+
""")
|
1222 |
+
state = {}
|
1223 |
+
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
1224 |
+
assert state["x"] == 42
|
1225 |
+
assert state["y"] == 21.0
|
1226 |
+
assert "z" not in state # z should be not be defined
|
1227 |
+
assert state["names"] == ["Alice", "Bob", "Charlie"]
|
1228 |
+
assert state["s"] == 123 # Type hints don't restrict at runtime
|
1229 |
+
assert state["result"] == (42, 21.0, ["Alice", "Bob", "Charlie"], 123)
|
1230 |
+
|
1231 |
+
@pytest.mark.parametrize(
|
1232 |
+
"code, expected_result",
|
1233 |
+
[
|
1234 |
+
(
|
1235 |
+
dedent("""\
|
1236 |
+
x = 1
|
1237 |
+
x += 2
|
1238 |
+
"""),
|
1239 |
+
3,
|
1240 |
+
),
|
1241 |
+
(
|
1242 |
+
dedent("""\
|
1243 |
+
x = "a"
|
1244 |
+
x += "b"
|
1245 |
+
"""),
|
1246 |
+
"ab",
|
1247 |
+
),
|
1248 |
+
(
|
1249 |
+
dedent("""\
|
1250 |
+
class Custom:
|
1251 |
+
def __init__(self, value):
|
1252 |
+
self.value = value
|
1253 |
+
def __iadd__(self, other):
|
1254 |
+
self.value += other * 10
|
1255 |
+
return self
|
1256 |
+
|
1257 |
+
x = Custom(1)
|
1258 |
+
x += 2
|
1259 |
+
x.value
|
1260 |
+
"""),
|
1261 |
+
21,
|
1262 |
+
),
|
1263 |
+
],
|
1264 |
+
)
|
1265 |
+
def test_evaluate_augassign(self, code, expected_result):
|
1266 |
+
state = {}
|
1267 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
1268 |
+
assert result == expected_result
|
1269 |
+
|
1270 |
+
@pytest.mark.parametrize(
|
1271 |
+
"operator, expected_result",
|
1272 |
+
[
|
1273 |
+
("+=", 7),
|
1274 |
+
("-=", 3),
|
1275 |
+
("*=", 10),
|
1276 |
+
("/=", 2.5),
|
1277 |
+
("//=", 2),
|
1278 |
+
("%=", 1),
|
1279 |
+
("**=", 25),
|
1280 |
+
("&=", 0),
|
1281 |
+
("|=", 7),
|
1282 |
+
("^=", 7),
|
1283 |
+
(">>=", 1),
|
1284 |
+
("<<=", 20),
|
1285 |
+
],
|
1286 |
+
)
|
1287 |
+
def test_evaluate_augassign_number(self, operator, expected_result):
|
1288 |
+
code = dedent("""\
|
1289 |
+
x = 5
|
1290 |
+
x {operator} 2
|
1291 |
+
""").format(operator=operator)
|
1292 |
+
state = {}
|
1293 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
1294 |
+
assert result == expected_result
|
1295 |
+
|
1296 |
+
@pytest.mark.parametrize(
|
1297 |
+
"operator, expected_result",
|
1298 |
+
[
|
1299 |
+
("+=", 7),
|
1300 |
+
("-=", 3),
|
1301 |
+
("*=", 10),
|
1302 |
+
("/=", 2.5),
|
1303 |
+
("//=", 2),
|
1304 |
+
("%=", 1),
|
1305 |
+
("**=", 25),
|
1306 |
+
("&=", 0),
|
1307 |
+
("|=", 7),
|
1308 |
+
("^=", 7),
|
1309 |
+
(">>=", 1),
|
1310 |
+
("<<=", 20),
|
1311 |
+
],
|
1312 |
+
)
|
1313 |
+
def test_evaluate_augassign_custom(self, operator, expected_result):
|
1314 |
+
operator_names = {
|
1315 |
+
"+=": "iadd",
|
1316 |
+
"-=": "isub",
|
1317 |
+
"*=": "imul",
|
1318 |
+
"/=": "itruediv",
|
1319 |
+
"//=": "ifloordiv",
|
1320 |
+
"%=": "imod",
|
1321 |
+
"**=": "ipow",
|
1322 |
+
"&=": "iand",
|
1323 |
+
"|=": "ior",
|
1324 |
+
"^=": "ixor",
|
1325 |
+
">>=": "irshift",
|
1326 |
+
"<<=": "ilshift",
|
1327 |
+
}
|
1328 |
+
code = dedent("""\
|
1329 |
+
class Custom:
|
1330 |
+
def __init__(self, value):
|
1331 |
+
self.value = value
|
1332 |
+
def __{operator_name}__(self, other):
|
1333 |
+
self.value {operator} other
|
1334 |
+
return self
|
1335 |
+
|
1336 |
+
x = Custom(5)
|
1337 |
+
x {operator} 2
|
1338 |
+
x.value
|
1339 |
+
""").format(operator=operator, operator_name=operator_names[operator])
|
1340 |
+
state = {}
|
1341 |
+
result, _ = evaluate_python_code(code, {}, state=state)
|
1342 |
+
assert result == expected_result
|
1343 |
+
|
1344 |
+
@pytest.mark.parametrize(
|
1345 |
+
"code, expected_error_message",
|
1346 |
+
[
|
1347 |
+
(
|
1348 |
+
dedent("""\
|
1349 |
+
x = 5
|
1350 |
+
del x
|
1351 |
+
x
|
1352 |
+
"""),
|
1353 |
+
"The variable `x` is not defined",
|
1354 |
+
),
|
1355 |
+
(
|
1356 |
+
dedent("""\
|
1357 |
+
x = [1, 2, 3]
|
1358 |
+
del x[2]
|
1359 |
+
x[2]
|
1360 |
+
"""),
|
1361 |
+
"IndexError: list index out of range",
|
1362 |
+
),
|
1363 |
+
(
|
1364 |
+
dedent("""\
|
1365 |
+
x = {"key": "value"}
|
1366 |
+
del x["key"]
|
1367 |
+
x["key"]
|
1368 |
+
"""),
|
1369 |
+
"Could not index {} with 'key'",
|
1370 |
+
),
|
1371 |
+
(
|
1372 |
+
dedent("""\
|
1373 |
+
del x
|
1374 |
+
"""),
|
1375 |
+
"Cannot delete name 'x': name is not defined",
|
1376 |
+
),
|
1377 |
+
],
|
1378 |
+
)
|
1379 |
+
def test_evaluate_delete(self, code, expected_error_message):
|
1380 |
+
state = {}
|
1381 |
+
with pytest.raises(InterpreterError) as exception_info:
|
1382 |
+
evaluate_python_code(code, {}, state=state)
|
1383 |
+
assert expected_error_message in str(exception_info.value)
|
1384 |
+
|
1385 |
+
def test_non_standard_comparisons(self):
|
1386 |
+
code = dedent("""\
|
1387 |
+
class NonStdEqualsResult:
|
1388 |
+
def __init__(self, left:object, right:object):
|
1389 |
+
self._left = left
|
1390 |
+
self._right = right
|
1391 |
+
def __str__(self) -> str:
|
1392 |
+
return f'{self._left} == {self._right}'
|
1393 |
+
|
1394 |
+
class NonStdComparisonClass:
|
1395 |
+
def __init__(self, value: str ):
|
1396 |
+
self._value = value
|
1397 |
+
def __str__(self):
|
1398 |
+
return self._value
|
1399 |
+
def __eq__(self, other):
|
1400 |
+
return NonStdEqualsResult(self, other)
|
1401 |
+
a = NonStdComparisonClass("a")
|
1402 |
+
b = NonStdComparisonClass("b")
|
1403 |
+
result = a == b
|
1404 |
+
""")
|
1405 |
+
result, _ = evaluate_python_code(code, state={})
|
1406 |
+
assert not isinstance(result, bool)
|
1407 |
+
assert str(result) == "a == b"
|
1408 |
+
|
1409 |
+
|
1410 |
+
class TestEvaluateBoolop:
|
1411 |
+
@pytest.mark.parametrize("a", [1, 0])
|
1412 |
+
@pytest.mark.parametrize("b", [2, 0])
|
1413 |
+
@pytest.mark.parametrize("c", [3, 0])
|
1414 |
+
def test_evaluate_boolop_and(self, a, b, c):
|
1415 |
+
boolop_ast = ast.parse("a and b and c").body[0].value
|
1416 |
+
state = {"a": a, "b": b, "c": c}
|
1417 |
+
result = evaluate_boolop(boolop_ast, state, {}, {}, [])
|
1418 |
+
assert result == (a and b and c)
|
1419 |
+
|
1420 |
+
@pytest.mark.parametrize("a", [1, 0])
|
1421 |
+
@pytest.mark.parametrize("b", [2, 0])
|
1422 |
+
@pytest.mark.parametrize("c", [3, 0])
|
1423 |
+
def test_evaluate_boolop_or(self, a, b, c):
|
1424 |
+
boolop_ast = ast.parse("a or b or c").body[0].value
|
1425 |
+
state = {"a": a, "b": b, "c": c}
|
1426 |
+
result = evaluate_boolop(boolop_ast, state, {}, {}, [])
|
1427 |
+
assert result == (a or b or c)
|
1428 |
+
|
1429 |
+
|
1430 |
+
class TestEvaluateDelete:
|
1431 |
+
@pytest.mark.parametrize(
|
1432 |
+
"code, state, expectation",
|
1433 |
+
[
|
1434 |
+
("del x", {"x": 1}, {}),
|
1435 |
+
("del x[1]", {"x": [1, 2, 3]}, {"x": [1, 3]}),
|
1436 |
+
("del x['key']", {"x": {"key": "value"}}, {"x": {}}),
|
1437 |
+
("del x", {}, InterpreterError("Cannot delete name 'x': name is not defined")),
|
1438 |
+
],
|
1439 |
+
)
|
1440 |
+
def test_evaluate_delete(self, code, state, expectation):
|
1441 |
+
delete_node = ast.parse(code).body[0]
|
1442 |
+
if isinstance(expectation, Exception):
|
1443 |
+
with pytest.raises(type(expectation)) as exception_info:
|
1444 |
+
evaluate_delete(delete_node, state, {}, {}, [])
|
1445 |
+
assert str(expectation) in str(exception_info.value)
|
1446 |
+
else:
|
1447 |
+
evaluate_delete(delete_node, state, {}, {}, [])
|
1448 |
+
_ = state.pop("_operations_count", None)
|
1449 |
+
assert state == expectation
|
1450 |
+
|
1451 |
+
|
1452 |
+
class TestEvaluateCondition:
|
1453 |
+
@pytest.mark.parametrize(
|
1454 |
+
"condition, state, expected_result",
|
1455 |
+
[
|
1456 |
+
("a == b", {"a": 1, "b": 1}, True),
|
1457 |
+
("a == b", {"a": 1, "b": 2}, False),
|
1458 |
+
("a != b", {"a": 1, "b": 1}, False),
|
1459 |
+
("a != b", {"a": 1, "b": 2}, True),
|
1460 |
+
("a < b", {"a": 1, "b": 1}, False),
|
1461 |
+
("a < b", {"a": 1, "b": 2}, True),
|
1462 |
+
("a < b", {"a": 2, "b": 1}, False),
|
1463 |
+
("a <= b", {"a": 1, "b": 1}, True),
|
1464 |
+
("a <= b", {"a": 1, "b": 2}, True),
|
1465 |
+
("a <= b", {"a": 2, "b": 1}, False),
|
1466 |
+
("a > b", {"a": 1, "b": 1}, False),
|
1467 |
+
("a > b", {"a": 1, "b": 2}, False),
|
1468 |
+
("a > b", {"a": 2, "b": 1}, True),
|
1469 |
+
("a >= b", {"a": 1, "b": 1}, True),
|
1470 |
+
("a >= b", {"a": 1, "b": 2}, False),
|
1471 |
+
("a >= b", {"a": 2, "b": 1}, True),
|
1472 |
+
("a is b", {"a": 1, "b": 1}, True),
|
1473 |
+
("a is b", {"a": 1, "b": 2}, False),
|
1474 |
+
("a is not b", {"a": 1, "b": 1}, False),
|
1475 |
+
("a is not b", {"a": 1, "b": 2}, True),
|
1476 |
+
("a in b", {"a": 1, "b": [1, 2, 3]}, True),
|
1477 |
+
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
1478 |
+
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
1479 |
+
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
1480 |
+
# Chained conditions:
|
1481 |
+
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
1482 |
+
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
1483 |
+
("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
|
1484 |
+
("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
|
1485 |
+
],
|
1486 |
+
)
|
1487 |
+
def test_evaluate_condition(self, condition, state, expected_result):
|
1488 |
+
condition_ast = ast.parse(condition, mode="eval").body
|
1489 |
+
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
1490 |
+
assert result == expected_result
|
1491 |
+
|
1492 |
+
@pytest.mark.parametrize(
|
1493 |
+
"condition, state, expected_result",
|
1494 |
+
[
|
1495 |
+
("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
|
1496 |
+
("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
|
1497 |
+
("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
|
1498 |
+
("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
|
1499 |
+
("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
|
1500 |
+
("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
|
1501 |
+
(
|
1502 |
+
"a == b",
|
1503 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
1504 |
+
pd.DataFrame({"x": [True, True], "y": [True, False]}),
|
1505 |
+
),
|
1506 |
+
(
|
1507 |
+
"a != b",
|
1508 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
1509 |
+
pd.DataFrame({"x": [False, False], "y": [False, True]}),
|
1510 |
+
),
|
1511 |
+
(
|
1512 |
+
"a < b",
|
1513 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
1514 |
+
pd.DataFrame({"x": [True, False], "y": [False, False]}),
|
1515 |
+
),
|
1516 |
+
(
|
1517 |
+
"a <= b",
|
1518 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
1519 |
+
pd.DataFrame({"x": [True, True], "y": [False, False]}),
|
1520 |
+
),
|
1521 |
+
(
|
1522 |
+
"a > b",
|
1523 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
1524 |
+
pd.DataFrame({"x": [False, False], "y": [True, True]}),
|
1525 |
+
),
|
1526 |
+
(
|
1527 |
+
"a >= b",
|
1528 |
+
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
1529 |
+
pd.DataFrame({"x": [False, True], "y": [True, True]}),
|
1530 |
+
),
|
1531 |
+
],
|
1532 |
+
)
|
1533 |
+
def test_evaluate_condition_with_pandas(self, condition, state, expected_result):
|
1534 |
+
condition_ast = ast.parse(condition, mode="eval").body
|
1535 |
+
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
1536 |
+
if isinstance(result, pd.Series):
|
1537 |
+
pd.testing.assert_series_equal(result, expected_result)
|
1538 |
+
else:
|
1539 |
+
pd.testing.assert_frame_equal(result, expected_result)
|
1540 |
+
|
1541 |
+
@pytest.mark.parametrize(
|
1542 |
+
"condition, state, expected_exception",
|
1543 |
+
[
|
1544 |
+
# Chained conditions:
|
1545 |
+
(
|
1546 |
+
"a == b == c",
|
1547 |
+
{
|
1548 |
+
"a": pd.Series([1, 2, 3]),
|
1549 |
+
"b": pd.Series([2, 2, 2]),
|
1550 |
+
"c": pd.Series([3, 3, 3]),
|
1551 |
+
},
|
1552 |
+
ValueError(
|
1553 |
+
"The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
1554 |
+
),
|
1555 |
+
),
|
1556 |
+
(
|
1557 |
+
"a == b == c",
|
1558 |
+
{
|
1559 |
+
"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
|
1560 |
+
"b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
|
1561 |
+
"c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
|
1562 |
+
},
|
1563 |
+
ValueError(
|
1564 |
+
"The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
1565 |
+
),
|
1566 |
+
),
|
1567 |
+
],
|
1568 |
+
)
|
1569 |
+
def test_evaluate_condition_with_pandas_exceptions(self, condition, state, expected_exception):
|
1570 |
+
condition_ast = ast.parse(condition, mode="eval").body
|
1571 |
+
with pytest.raises(type(expected_exception)) as exception_info:
|
1572 |
+
_ = evaluate_condition(condition_ast, state, {}, {}, [])
|
1573 |
+
assert str(expected_exception) in str(exception_info.value)
|
1574 |
+
|
1575 |
+
|
1576 |
+
class TestEvaluateSubscript:
|
1577 |
+
@pytest.mark.parametrize(
|
1578 |
+
"subscript, state, expected_result",
|
1579 |
+
[
|
1580 |
+
("dct[1]", {"dct": {1: 11, 2: 22}}, 11),
|
1581 |
+
("dct[2]", {"dct": {1: "a", 2: "b"}}, "b"),
|
1582 |
+
("dct['b']", {"dct": {"a": 1, "b": 2}}, 2),
|
1583 |
+
("dct['a']", {"dct": {"a": "aa", "b": "bb"}}, "aa"),
|
1584 |
+
("dct[1, 2]", {"dct": {(1, 2): 3}}, 3), # tuple-index
|
1585 |
+
("dct['a']['b']", {"dct": {"a": {"b": 1}}}, 1), # nested
|
1586 |
+
("lst[0]", {"lst": [1, 2, 3]}, 1),
|
1587 |
+
("lst[-1]", {"lst": [1, 2, 3]}, 3),
|
1588 |
+
("lst[1:3]", {"lst": [1, 2, 3, 4]}, [2, 3]),
|
1589 |
+
("lst[:]", {"lst": [1, 2, 3]}, [1, 2, 3]),
|
1590 |
+
("lst[::2]", {"lst": [1, 2, 3, 4]}, [1, 3]),
|
1591 |
+
("lst[::-1]", {"lst": [1, 2, 3]}, [3, 2, 1]),
|
1592 |
+
("tup[1]", {"tup": (1, 2, 3)}, 2),
|
1593 |
+
("tup[-1]", {"tup": (1, 2, 3)}, 3),
|
1594 |
+
("tup[1:3]", {"tup": (1, 2, 3, 4)}, (2, 3)),
|
1595 |
+
("tup[:]", {"tup": (1, 2, 3)}, (1, 2, 3)),
|
1596 |
+
("tup[::2]", {"tup": (1, 2, 3, 4)}, (1, 3)),
|
1597 |
+
("tup[::-1]", {"tup": (1, 2, 3)}, (3, 2, 1)),
|
1598 |
+
("st[1]", {"str": "abc"}, "b"),
|
1599 |
+
("st[-1]", {"str": "abc"}, "c"),
|
1600 |
+
("st[1:3]", {"str": "abcd"}, "bc"),
|
1601 |
+
("st[:]", {"str": "abc"}, "abc"),
|
1602 |
+
("st[::2]", {"str": "abcd"}, "ac"),
|
1603 |
+
("st[::-1]", {"str": "abc"}, "cba"),
|
1604 |
+
("arr[1]", {"arr": np.array([1, 2, 3])}, 2),
|
1605 |
+
("arr[1:3]", {"arr": np.array([1, 2, 3, 4])}, np.array([2, 3])),
|
1606 |
+
("arr[:]", {"arr": np.array([1, 2, 3])}, np.array([1, 2, 3])),
|
1607 |
+
("arr[::2]", {"arr": np.array([1, 2, 3, 4])}, np.array([1, 3])),
|
1608 |
+
("arr[::-1]", {"arr": np.array([1, 2, 3])}, np.array([3, 2, 1])),
|
1609 |
+
("arr[1, 2]", {"arr": np.array([[1, 2, 3], [4, 5, 6]])}, 6),
|
1610 |
+
("ser[1]", {"ser": pd.Series([1, 2, 3])}, 2),
|
1611 |
+
("ser.loc[1]", {"ser": pd.Series([1, 2, 3])}, 2),
|
1612 |
+
("ser.loc[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 3),
|
1613 |
+
("ser.iloc[1]", {"ser": pd.Series([1, 2, 3])}, 2),
|
1614 |
+
("ser.iloc[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 2),
|
1615 |
+
("ser.at[1]", {"ser": pd.Series([1, 2, 3])}, 2),
|
1616 |
+
("ser.at[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 3),
|
1617 |
+
("ser.iat[1]", {"ser": pd.Series([1, 2, 3])}, 2),
|
1618 |
+
("ser.iat[1]", {"ser": pd.Series([1, 2, 3], index=[2, 3, 1])}, 2),
|
1619 |
+
("ser[1:3]", {"ser": pd.Series([1, 2, 3, 4])}, pd.Series([2, 3], index=[1, 2])),
|
1620 |
+
("ser[:]", {"ser": pd.Series([1, 2, 3])}, pd.Series([1, 2, 3])),
|
1621 |
+
("ser[::2]", {"ser": pd.Series([1, 2, 3, 4])}, pd.Series([1, 3], index=[0, 2])),
|
1622 |
+
("ser[::-1]", {"ser": pd.Series([1, 2, 3])}, pd.Series([3, 2, 1], index=[2, 1, 0])),
|
1623 |
+
("df['y'][1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
|
1624 |
+
("df['y'][5]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
|
1625 |
+
("df.loc[1, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
|
1626 |
+
("df.loc[5, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
|
1627 |
+
("df.iloc[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
|
1628 |
+
("df.iloc[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 4),
|
1629 |
+
("df.at[1, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
|
1630 |
+
("df.at[5, 'y']", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 3),
|
1631 |
+
("df.iat[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]})}, 4),
|
1632 |
+
("df.iat[1, 1]", {"df": pd.DataFrame({"x": [1, 2], "y": [3, 4]}, index=[5, 6])}, 4),
|
1633 |
+
],
|
1634 |
+
)
|
1635 |
+
def test_evaluate_subscript(self, subscript, state, expected_result):
|
1636 |
+
subscript_ast = ast.parse(subscript).body[0].value
|
1637 |
+
result = evaluate_subscript(subscript_ast, state, {}, {}, [])
|
1638 |
+
try:
|
1639 |
+
assert result == expected_result
|
1640 |
+
except ValueError:
|
1641 |
+
assert (result == expected_result).all()
|
1642 |
+
|
1643 |
+
@pytest.mark.parametrize(
|
1644 |
+
"subscript, state, expected_error_message",
|
1645 |
+
[
|
1646 |
+
("dct['a']", {"dct": {}}, "KeyError: 'a'"),
|
1647 |
+
("dct[0]", {"dct": {}}, "KeyError: 0"),
|
1648 |
+
("dct['c']", {"dct": {"a": 1, "b": 2}}, "KeyError: 'c'"),
|
1649 |
+
("dct[1, 2, 3]", {"dct": {(1, 2): 3}}, "KeyError: (1, 2, 3)"),
|
1650 |
+
("lst[0]", {"lst": []}, "IndexError: list index out of range"),
|
1651 |
+
("lst[3]", {"lst": [1, 2, 3]}, "IndexError: list index out of range"),
|
1652 |
+
("lst[-4]", {"lst": [1, 2, 3]}, "IndexError: list index out of range"),
|
1653 |
+
("value[0]", {"value": 1}, "TypeError: 'int' object is not subscriptable"),
|
1654 |
+
],
|
1655 |
+
)
|
1656 |
+
def test_evaluate_subscript_error(self, subscript, state, expected_error_message):
|
1657 |
+
subscript_ast = ast.parse(subscript).body[0].value
|
1658 |
+
with pytest.raises(InterpreterError, match="Could not index") as exception_info:
|
1659 |
+
_ = evaluate_subscript(subscript_ast, state, {}, {}, [])
|
1660 |
+
assert expected_error_message in str(exception_info.value)
|
1661 |
+
|
1662 |
+
@pytest.mark.parametrize(
|
1663 |
+
"subscriptable_class, expectation",
|
1664 |
+
[
|
1665 |
+
(True, 20),
|
1666 |
+
(False, InterpreterError("TypeError: 'Custom' object is not subscriptable")),
|
1667 |
+
],
|
1668 |
+
)
|
1669 |
+
def test_evaluate_subscript_with_custom_class(self, subscriptable_class, expectation):
|
1670 |
+
if subscriptable_class:
|
1671 |
+
|
1672 |
+
class Custom:
|
1673 |
+
def __getitem__(self, key):
|
1674 |
+
return key * 10
|
1675 |
+
else:
|
1676 |
+
|
1677 |
+
class Custom:
|
1678 |
+
pass
|
1679 |
+
|
1680 |
+
state = {"obj": Custom()}
|
1681 |
+
subscript = "obj[2]"
|
1682 |
+
subscript_ast = ast.parse(subscript).body[0].value
|
1683 |
+
if isinstance(expectation, Exception):
|
1684 |
+
with pytest.raises(type(expectation), match="Could not index") as exception_info:
|
1685 |
+
evaluate_subscript(subscript_ast, state, {}, {}, [])
|
1686 |
+
assert "TypeError: 'Custom' object is not subscriptable" in str(exception_info.value)
|
1687 |
+
else:
|
1688 |
+
result = evaluate_subscript(subscript_ast, state, {}, {}, [])
|
1689 |
+
assert result == expectation
|
1690 |
+
|
1691 |
+
|
1692 |
+
def test_get_safe_module_handle_lazy_imports():
|
1693 |
+
class FakeModule(types.ModuleType):
|
1694 |
+
def __init__(self, name):
|
1695 |
+
super().__init__(name)
|
1696 |
+
self.non_lazy_attribute = "ok"
|
1697 |
+
|
1698 |
+
def __getattr__(self, name):
|
1699 |
+
if name == "lazy_attribute":
|
1700 |
+
raise ImportError("lazy import failure")
|
1701 |
+
return super().__getattr__(name)
|
1702 |
+
|
1703 |
+
def __dir__(self):
|
1704 |
+
return super().__dir__() + ["lazy_attribute"]
|
1705 |
+
|
1706 |
+
fake_module = FakeModule("fake_module")
|
1707 |
+
safe_module = get_safe_module(fake_module, authorized_imports=set())
|
1708 |
+
assert not hasattr(safe_module, "lazy_attribute")
|
1709 |
+
assert getattr(safe_module, "non_lazy_attribute") == "ok"
|
1710 |
+
|
1711 |
+
|
1712 |
+
class TestPrintContainer:
|
1713 |
+
def test_initial_value(self):
|
1714 |
+
pc = PrintContainer()
|
1715 |
+
assert pc.value == ""
|
1716 |
+
|
1717 |
+
def test_append(self):
|
1718 |
+
pc = PrintContainer()
|
1719 |
+
pc.append("Hello")
|
1720 |
+
assert pc.value == "Hello"
|
1721 |
+
|
1722 |
+
def test_iadd(self):
|
1723 |
+
pc = PrintContainer()
|
1724 |
+
pc += "World"
|
1725 |
+
assert pc.value == "World"
|
1726 |
+
|
1727 |
+
def test_str(self):
|
1728 |
+
pc = PrintContainer()
|
1729 |
+
pc.append("Hello")
|
1730 |
+
assert str(pc) == "Hello"
|
1731 |
+
|
1732 |
+
def test_repr(self):
|
1733 |
+
pc = PrintContainer()
|
1734 |
+
pc.append("Hello")
|
1735 |
+
assert repr(pc) == "PrintContainer(Hello)"
|
1736 |
+
|
1737 |
+
def test_len(self):
|
1738 |
+
pc = PrintContainer()
|
1739 |
+
pc.append("Hello")
|
1740 |
+
assert len(pc) == 5
|
1741 |
+
|
1742 |
+
|
1743 |
+
def test_fix_final_answer_code():
|
1744 |
+
test_cases = [
|
1745 |
+
(
|
1746 |
+
"final_answer = 3.21\nfinal_answer(final_answer)",
|
1747 |
+
"final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
|
1748 |
+
),
|
1749 |
+
(
|
1750 |
+
"x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
|
1751 |
+
"x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
|
1752 |
+
),
|
1753 |
+
(
|
1754 |
+
"def func():\n final_answer = 42\n return final_answer(final_answer)",
|
1755 |
+
"def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
|
1756 |
+
),
|
1757 |
+
(
|
1758 |
+
"final_answer(5) # Should not change function calls",
|
1759 |
+
"final_answer(5) # Should not change function calls",
|
1760 |
+
),
|
1761 |
+
(
|
1762 |
+
"obj.final_answer = 5 # Should not change object attributes",
|
1763 |
+
"obj.final_answer = 5 # Should not change object attributes",
|
1764 |
+
),
|
1765 |
+
(
|
1766 |
+
"final_answer=3.21;final_answer(final_answer)",
|
1767 |
+
"final_answer_variable=3.21;final_answer(final_answer_variable)",
|
1768 |
+
),
|
1769 |
+
]
|
1770 |
+
|
1771 |
+
for i, (input_code, expected) in enumerate(test_cases, 1):
|
1772 |
+
result = fix_final_answer_code(input_code)
|
1773 |
+
assert result == expected, f"""
|
1774 |
+
Test case {i} failed:
|
1775 |
+
Input: {input_code}
|
1776 |
+
Expected: {expected}
|
1777 |
+
Got: {result}
|
1778 |
+
"""
|
1779 |
+
|
1780 |
+
|
1781 |
+
@pytest.mark.parametrize(
|
1782 |
+
"module,authorized_imports,expected",
|
1783 |
+
[
|
1784 |
+
("os", ["other", "*"], True),
|
1785 |
+
("AnyModule", ["*"], True),
|
1786 |
+
("os", ["os"], True),
|
1787 |
+
("AnyModule", ["AnyModule"], True),
|
1788 |
+
("Module.os", ["Module"], False),
|
1789 |
+
("Module.os", ["Module", "Module.os"], True),
|
1790 |
+
("os.path", ["os.*"], True),
|
1791 |
+
("os", ["os.path"], True),
|
1792 |
+
],
|
1793 |
+
)
|
1794 |
+
def test_check_import_authorized(module: str, authorized_imports: list[str], expected: bool):
|
1795 |
+
assert check_import_authorized(module, authorized_imports) == expected
|
1796 |
+
|
1797 |
+
|
1798 |
+
class TestLocalPythonExecutor:
|
1799 |
+
def test_state_name(self):
|
1800 |
+
executor = LocalPythonExecutor(additional_authorized_imports=[])
|
1801 |
+
assert executor.state.get("__name__") == "__main__"
|
1802 |
+
|
1803 |
+
@pytest.mark.parametrize(
|
1804 |
+
"code",
|
1805 |
+
[
|
1806 |
+
"d = {'func': lambda x: x + 10}; func = d['func']; func(1)",
|
1807 |
+
"d = {'func': lambda x: x + 10}; d['func'](1)",
|
1808 |
+
],
|
1809 |
+
)
|
1810 |
+
def test_call_from_dict(self, code):
|
1811 |
+
executor = LocalPythonExecutor([])
|
1812 |
+
result, _, _ = executor(code)
|
1813 |
+
assert result == 11
|
1814 |
+
|
1815 |
+
@pytest.mark.parametrize(
|
1816 |
+
"code",
|
1817 |
+
[
|
1818 |
+
"a = b = 1; a",
|
1819 |
+
"a = b = 1; b",
|
1820 |
+
"a, b = c, d = 1, 1; a",
|
1821 |
+
"a, b = c, d = 1, 1; b",
|
1822 |
+
"a, b = c, d = 1, 1; c",
|
1823 |
+
"a, b = c, d = {1, 2}; a",
|
1824 |
+
"a, b = c, d = {1, 2}; c",
|
1825 |
+
"a, b = c, d = {1: 10, 2: 20}; a",
|
1826 |
+
"a, b = c, d = {1: 10, 2: 20}; c",
|
1827 |
+
"a = b = (lambda: 1)(); b",
|
1828 |
+
"a = b = (lambda: 1)(); lambda x: 10; b",
|
1829 |
+
"a = b = (lambda x: lambda y: x + y)(0)(1); b",
|
1830 |
+
dedent("""
|
1831 |
+
def foo():
|
1832 |
+
return 1;
|
1833 |
+
a = b = foo(); b"""),
|
1834 |
+
dedent("""
|
1835 |
+
def foo(*args, **kwargs):
|
1836 |
+
return sum(args)
|
1837 |
+
a = b = foo(1,-1,1); b"""),
|
1838 |
+
"a, b = 1, 2; a, b = b, a; b",
|
1839 |
+
],
|
1840 |
+
)
|
1841 |
+
def test_chained_assignments(self, code):
|
1842 |
+
executor = LocalPythonExecutor([])
|
1843 |
+
executor.send_tools({})
|
1844 |
+
result, _, _ = executor(code)
|
1845 |
+
assert result == 1
|
1846 |
+
|
1847 |
+
def test_evaluate_assign_error(self):
|
1848 |
+
code = "a, b = 1, 2, 3; a"
|
1849 |
+
executor = LocalPythonExecutor([])
|
1850 |
+
with pytest.raises(InterpreterError, match=".*Cannot unpack tuple of wrong size"):
|
1851 |
+
executor(code)
|
1852 |
+
|
1853 |
+
def test_function_def_recovers_source_code(self):
|
1854 |
+
executor = LocalPythonExecutor([])
|
1855 |
+
executor.send_tools({"final_answer": FinalAnswerTool()})
|
1856 |
+
res, _, _ = executor(
|
1857 |
+
dedent(
|
1858 |
+
"""
|
1859 |
+
def target_function():
|
1860 |
+
return "Hello world"
|
1861 |
+
|
1862 |
+
final_answer(target_function)
|
1863 |
+
"""
|
1864 |
+
)
|
1865 |
+
)
|
1866 |
+
assert res.__name__ == "target_function"
|
1867 |
+
assert res.__source__ == "def target_function():\n return 'Hello world'"
|
1868 |
+
|
1869 |
+
@pytest.mark.parametrize(
|
1870 |
+
"code, expected_result",
|
1871 |
+
[("isinstance(5, int)", True), ("isinstance('foo', str)", True), ("isinstance(5, str)", False)],
|
1872 |
+
)
|
1873 |
+
def test_isinstance_builtin_type(self, code, expected_result):
|
1874 |
+
executor = LocalPythonExecutor([])
|
1875 |
+
executor.send_tools({})
|
1876 |
+
result, _, _ = executor(code)
|
1877 |
+
assert result is expected_result
|
1878 |
+
|
1879 |
+
|
1880 |
+
class TestLocalPythonExecutorSecurity:
|
1881 |
+
@pytest.mark.parametrize(
|
1882 |
+
"additional_authorized_imports, expected_error",
|
1883 |
+
[([], InterpreterError("Import of os is not allowed")), (["os"], None)],
|
1884 |
+
)
|
1885 |
+
def test_vulnerability_import(self, additional_authorized_imports, expected_error):
|
1886 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1887 |
+
with (
|
1888 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1889 |
+
if isinstance(expected_error, Exception)
|
1890 |
+
else does_not_raise()
|
1891 |
+
):
|
1892 |
+
executor("import os")
|
1893 |
+
|
1894 |
+
@pytest.mark.parametrize(
|
1895 |
+
"additional_authorized_imports, expected_error",
|
1896 |
+
[([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)],
|
1897 |
+
)
|
1898 |
+
def test_vulnerability_builtins(self, additional_authorized_imports, expected_error):
|
1899 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1900 |
+
with (
|
1901 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1902 |
+
if isinstance(expected_error, Exception)
|
1903 |
+
else does_not_raise()
|
1904 |
+
):
|
1905 |
+
executor("import builtins")
|
1906 |
+
|
1907 |
+
@pytest.mark.parametrize(
|
1908 |
+
"additional_authorized_imports, expected_error",
|
1909 |
+
[([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)],
|
1910 |
+
)
|
1911 |
+
def test_vulnerability_builtins_safe_functions(self, additional_authorized_imports, expected_error):
|
1912 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1913 |
+
with (
|
1914 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1915 |
+
if isinstance(expected_error, Exception)
|
1916 |
+
else does_not_raise()
|
1917 |
+
):
|
1918 |
+
executor("import builtins; builtins.print(1)")
|
1919 |
+
|
1920 |
+
@pytest.mark.parametrize(
|
1921 |
+
"additional_authorized_imports, additional_tools, expected_error",
|
1922 |
+
[
|
1923 |
+
([], [], InterpreterError("Import of builtins is not allowed")),
|
1924 |
+
(["builtins"], [], InterpreterError("Forbidden access to function: exec")),
|
1925 |
+
(["builtins"], ["exec"], None),
|
1926 |
+
],
|
1927 |
+
)
|
1928 |
+
def test_vulnerability_builtins_dangerous_functions(
|
1929 |
+
self, additional_authorized_imports, additional_tools, expected_error
|
1930 |
+
):
|
1931 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1932 |
+
if additional_tools:
|
1933 |
+
from builtins import exec
|
1934 |
+
|
1935 |
+
executor.send_tools({"exec": exec})
|
1936 |
+
with (
|
1937 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1938 |
+
if isinstance(expected_error, Exception)
|
1939 |
+
else does_not_raise()
|
1940 |
+
):
|
1941 |
+
executor("import builtins; builtins.exec")
|
1942 |
+
|
1943 |
+
@pytest.mark.parametrize(
|
1944 |
+
"additional_authorized_imports, additional_tools, expected_error",
|
1945 |
+
[
|
1946 |
+
([], [], InterpreterError("Import of os is not allowed")),
|
1947 |
+
(["os"], [], InterpreterError("Forbidden access to function: popen")),
|
1948 |
+
(["os"], ["popen"], None),
|
1949 |
+
],
|
1950 |
+
)
|
1951 |
+
def test_vulnerability_dangerous_functions(self, additional_authorized_imports, additional_tools, expected_error):
|
1952 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1953 |
+
if additional_tools:
|
1954 |
+
from os import popen
|
1955 |
+
|
1956 |
+
executor.send_tools({"popen": popen})
|
1957 |
+
with (
|
1958 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1959 |
+
if isinstance(expected_error, Exception)
|
1960 |
+
else does_not_raise()
|
1961 |
+
):
|
1962 |
+
executor("import os; os.popen")
|
1963 |
+
|
1964 |
+
@pytest.mark.parametrize("dangerous_function", DANGEROUS_FUNCTIONS)
|
1965 |
+
def test_vulnerability_for_all_dangerous_functions(self, dangerous_function):
|
1966 |
+
dangerous_module_name, dangerous_function_name = dangerous_function.rsplit(".", 1)
|
1967 |
+
# Skip test if module is not installed: posix module is not installed on Windows
|
1968 |
+
pytest.importorskip(dangerous_module_name)
|
1969 |
+
executor = LocalPythonExecutor([dangerous_module_name])
|
1970 |
+
if "__" in dangerous_function_name:
|
1971 |
+
error_match = f".*Forbidden access to dunder attribute: {dangerous_function_name}"
|
1972 |
+
else:
|
1973 |
+
error_match = f".*Forbidden access to function: {dangerous_function_name}.*"
|
1974 |
+
with pytest.raises(InterpreterError, match=error_match):
|
1975 |
+
executor(f"import {dangerous_module_name}; {dangerous_function}")
|
1976 |
+
|
1977 |
+
@pytest.mark.parametrize(
|
1978 |
+
"additional_authorized_imports, expected_error",
|
1979 |
+
[
|
1980 |
+
([], InterpreterError("Import of sys is not allowed")),
|
1981 |
+
(["sys"], InterpreterError("Forbidden access to module: os")),
|
1982 |
+
(["sys", "os"], None),
|
1983 |
+
],
|
1984 |
+
)
|
1985 |
+
def test_vulnerability_via_sys(self, additional_authorized_imports, expected_error):
|
1986 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
1987 |
+
with (
|
1988 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
1989 |
+
if isinstance(expected_error, Exception)
|
1990 |
+
else does_not_raise()
|
1991 |
+
):
|
1992 |
+
executor(
|
1993 |
+
dedent(
|
1994 |
+
"""
|
1995 |
+
import sys
|
1996 |
+
sys.modules["os"].system(":")
|
1997 |
+
"""
|
1998 |
+
)
|
1999 |
+
)
|
2000 |
+
|
2001 |
+
@pytest.mark.parametrize("dangerous_module", DANGEROUS_MODULES)
|
2002 |
+
def test_vulnerability_via_sys_for_all_dangerous_modules(self, dangerous_module):
|
2003 |
+
import sys
|
2004 |
+
|
2005 |
+
if dangerous_module not in sys.modules or dangerous_module == "sys":
|
2006 |
+
pytest.skip("module not present in sys.modules")
|
2007 |
+
executor = LocalPythonExecutor(["sys"])
|
2008 |
+
with pytest.raises(InterpreterError) as exception_info:
|
2009 |
+
executor(
|
2010 |
+
dedent(
|
2011 |
+
f"""
|
2012 |
+
import sys
|
2013 |
+
sys.modules["{dangerous_module}"]
|
2014 |
+
"""
|
2015 |
+
)
|
2016 |
+
)
|
2017 |
+
assert f"Forbidden access to module: {dangerous_module}" in str(exception_info.value)
|
2018 |
+
|
2019 |
+
@pytest.mark.parametrize(
|
2020 |
+
"additional_authorized_imports, expected_error",
|
2021 |
+
[(["importlib"], InterpreterError("Forbidden access to module: os")), (["importlib", "os"], None)],
|
2022 |
+
)
|
2023 |
+
def test_vulnerability_via_importlib(self, additional_authorized_imports, expected_error):
|
2024 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2025 |
+
with (
|
2026 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2027 |
+
if isinstance(expected_error, Exception)
|
2028 |
+
else does_not_raise()
|
2029 |
+
):
|
2030 |
+
executor(
|
2031 |
+
dedent(
|
2032 |
+
"""
|
2033 |
+
import importlib
|
2034 |
+
importlib.import_module("os").system(":")
|
2035 |
+
"""
|
2036 |
+
)
|
2037 |
+
)
|
2038 |
+
|
2039 |
+
@pytest.mark.parametrize(
|
2040 |
+
"code, additional_authorized_imports, expected_error",
|
2041 |
+
[
|
2042 |
+
# os submodule
|
2043 |
+
(
|
2044 |
+
"import queue; queue.threading._os.system(':')",
|
2045 |
+
[],
|
2046 |
+
InterpreterError("Forbidden access to module: threading"),
|
2047 |
+
),
|
2048 |
+
(
|
2049 |
+
"import queue; queue.threading._os.system(':')",
|
2050 |
+
["threading"],
|
2051 |
+
InterpreterError("Forbidden access to module: os"),
|
2052 |
+
),
|
2053 |
+
("import random; random._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
|
2054 |
+
(
|
2055 |
+
"import random; random.__dict__['_os'].system(':')",
|
2056 |
+
[],
|
2057 |
+
InterpreterError("Forbidden access to dunder attribute: __dict__"),
|
2058 |
+
),
|
2059 |
+
(
|
2060 |
+
"import doctest; doctest.inspect.os.system(':')",
|
2061 |
+
["doctest"],
|
2062 |
+
InterpreterError("Forbidden access to module: inspect"),
|
2063 |
+
),
|
2064 |
+
(
|
2065 |
+
"import doctest; doctest.inspect.os.system(':')",
|
2066 |
+
["doctest", "inspect"],
|
2067 |
+
InterpreterError("Forbidden access to module: os"),
|
2068 |
+
),
|
2069 |
+
# subprocess submodule
|
2070 |
+
(
|
2071 |
+
"import asyncio; asyncio.base_events.events.subprocess",
|
2072 |
+
["asyncio"],
|
2073 |
+
InterpreterError("Forbidden access to module: asyncio.base_events"),
|
2074 |
+
),
|
2075 |
+
(
|
2076 |
+
"import asyncio; asyncio.base_events.events.subprocess",
|
2077 |
+
["asyncio", "asyncio.base_events"],
|
2078 |
+
InterpreterError("Forbidden access to module: asyncio.events"),
|
2079 |
+
),
|
2080 |
+
(
|
2081 |
+
"import asyncio; asyncio.base_events.events.subprocess",
|
2082 |
+
["asyncio", "asyncio.base_events", "asyncio.base_events.events"],
|
2083 |
+
InterpreterError("Forbidden access to module: asyncio.events"),
|
2084 |
+
),
|
2085 |
+
# sys submodule
|
2086 |
+
(
|
2087 |
+
"import queue; queue.threading._sys.modules['os'].system(':')",
|
2088 |
+
[],
|
2089 |
+
InterpreterError("Forbidden access to module: threading"),
|
2090 |
+
),
|
2091 |
+
(
|
2092 |
+
"import queue; queue.threading._sys.modules['os'].system(':')",
|
2093 |
+
["threading"],
|
2094 |
+
InterpreterError("Forbidden access to module: sys"),
|
2095 |
+
),
|
2096 |
+
("import warnings; warnings.sys", ["warnings"], InterpreterError("Forbidden access to module: sys")),
|
2097 |
+
# Allowed
|
2098 |
+
("import pandas; pandas.io", ["pandas", "pandas.io"], None),
|
2099 |
+
],
|
2100 |
+
)
|
2101 |
+
def test_vulnerability_via_submodules(self, code, additional_authorized_imports, expected_error):
|
2102 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2103 |
+
with (
|
2104 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2105 |
+
if isinstance(expected_error, Exception)
|
2106 |
+
else does_not_raise()
|
2107 |
+
):
|
2108 |
+
executor(code)
|
2109 |
+
|
2110 |
+
@pytest.mark.parametrize(
|
2111 |
+
"code, additional_authorized_imports, expected_error",
|
2112 |
+
[
|
2113 |
+
# Using filter with functools.partial
|
2114 |
+
(
|
2115 |
+
dedent(
|
2116 |
+
"""
|
2117 |
+
import functools
|
2118 |
+
import warnings
|
2119 |
+
list(filter(functools.partial(getattr, warnings), ["sys"]))
|
2120 |
+
"""
|
2121 |
+
),
|
2122 |
+
["warnings", "functools"],
|
2123 |
+
InterpreterError("Forbidden access to module: sys"),
|
2124 |
+
),
|
2125 |
+
# Using map
|
2126 |
+
(
|
2127 |
+
dedent(
|
2128 |
+
"""
|
2129 |
+
import warnings
|
2130 |
+
list(map(getattr, [warnings], ["sys"]))
|
2131 |
+
"""
|
2132 |
+
),
|
2133 |
+
["warnings"],
|
2134 |
+
InterpreterError("Forbidden access to module: sys"),
|
2135 |
+
),
|
2136 |
+
# Using map with functools.partial
|
2137 |
+
(
|
2138 |
+
dedent(
|
2139 |
+
"""
|
2140 |
+
import functools
|
2141 |
+
import warnings
|
2142 |
+
list(map(functools.partial(getattr, warnings), ["sys"]))
|
2143 |
+
"""
|
2144 |
+
),
|
2145 |
+
["warnings", "functools"],
|
2146 |
+
InterpreterError("Forbidden access to module: sys"),
|
2147 |
+
),
|
2148 |
+
],
|
2149 |
+
)
|
2150 |
+
def test_vulnerability_via_submodules_through_indirect_attribute_access(
|
2151 |
+
self, code, additional_authorized_imports, expected_error
|
2152 |
+
):
|
2153 |
+
# warnings.sys
|
2154 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2155 |
+
executor.send_tools({})
|
2156 |
+
with pytest.raises(type(expected_error), match=f".*{expected_error}"):
|
2157 |
+
executor(code)
|
2158 |
+
|
2159 |
+
@pytest.mark.parametrize(
|
2160 |
+
"additional_authorized_imports, additional_tools, expected_error",
|
2161 |
+
[
|
2162 |
+
([], [], InterpreterError("Import of sys is not allowed")),
|
2163 |
+
(["sys"], [], InterpreterError("Forbidden access to module: builtins")),
|
2164 |
+
(
|
2165 |
+
["sys", "builtins"],
|
2166 |
+
[],
|
2167 |
+
InterpreterError("Forbidden access to function: __import__"),
|
2168 |
+
),
|
2169 |
+
(["sys", "builtins"], ["__import__"], InterpreterError("Forbidden access to module: os")),
|
2170 |
+
(["sys", "builtins", "os"], ["__import__"], None),
|
2171 |
+
],
|
2172 |
+
)
|
2173 |
+
def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, additional_tools, expected_error):
|
2174 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2175 |
+
if additional_tools:
|
2176 |
+
from builtins import __import__
|
2177 |
+
|
2178 |
+
executor.send_tools({"__import__": __import__})
|
2179 |
+
with (
|
2180 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2181 |
+
if isinstance(expected_error, Exception)
|
2182 |
+
else does_not_raise()
|
2183 |
+
):
|
2184 |
+
executor(
|
2185 |
+
dedent(
|
2186 |
+
"""
|
2187 |
+
import sys
|
2188 |
+
builtins = sys._getframe().f_builtins
|
2189 |
+
builtins_import = builtins["__import__"]
|
2190 |
+
os_module = builtins_import("os")
|
2191 |
+
os_module.system(":")
|
2192 |
+
"""
|
2193 |
+
)
|
2194 |
+
)
|
2195 |
+
|
2196 |
+
@pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
|
2197 |
+
@pytest.mark.parametrize(
|
2198 |
+
"additional_authorized_imports, additional_tools, expected_error",
|
2199 |
+
[
|
2200 |
+
([], [], InterpreterError("Forbidden access to dunder attribute: __traceback__")),
|
2201 |
+
(
|
2202 |
+
["builtins", "os"],
|
2203 |
+
["__import__"],
|
2204 |
+
InterpreterError("Forbidden access to dunder attribute: __traceback__"),
|
2205 |
+
),
|
2206 |
+
],
|
2207 |
+
)
|
2208 |
+
def test_vulnerability_builtins_via_traceback(
|
2209 |
+
self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
|
2210 |
+
):
|
2211 |
+
if patch_builtin_import_module:
|
2212 |
+
monkeypatch.setattr("builtins.__import__.__module__", None) # inspect.getmodule(func) = None
|
2213 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2214 |
+
if additional_tools:
|
2215 |
+
from builtins import __import__
|
2216 |
+
|
2217 |
+
executor.send_tools({"__import__": __import__})
|
2218 |
+
with (
|
2219 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2220 |
+
if isinstance(expected_error, Exception)
|
2221 |
+
else does_not_raise()
|
2222 |
+
):
|
2223 |
+
executor(
|
2224 |
+
dedent(
|
2225 |
+
"""
|
2226 |
+
try:
|
2227 |
+
1 / 0
|
2228 |
+
except Exception as e:
|
2229 |
+
builtins = e.__traceback__.tb_frame.f_back.f_globals["__builtins__"]
|
2230 |
+
builtins_import = builtins["__import__"]
|
2231 |
+
os_module = builtins_import("os")
|
2232 |
+
os_module.system(":")
|
2233 |
+
"""
|
2234 |
+
)
|
2235 |
+
)
|
2236 |
+
|
2237 |
+
@pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
|
2238 |
+
@pytest.mark.parametrize(
|
2239 |
+
"additional_authorized_imports, additional_tools, expected_error",
|
2240 |
+
[
|
2241 |
+
([], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
|
2242 |
+
(["warnings"], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
|
2243 |
+
(
|
2244 |
+
["warnings", "builtins"],
|
2245 |
+
[],
|
2246 |
+
InterpreterError("Forbidden access to dunder attribute: __base__"),
|
2247 |
+
),
|
2248 |
+
(["warnings", "builtins", "os"], [], InterpreterError("Forbidden access to dunder attribute: __base__")),
|
2249 |
+
(
|
2250 |
+
["warnings", "builtins", "os"],
|
2251 |
+
["__import__"],
|
2252 |
+
InterpreterError("Forbidden access to dunder attribute: __base__"),
|
2253 |
+
),
|
2254 |
+
],
|
2255 |
+
)
|
2256 |
+
def test_vulnerability_builtins_via_class_catch_warnings(
|
2257 |
+
self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
|
2258 |
+
):
|
2259 |
+
if patch_builtin_import_module:
|
2260 |
+
monkeypatch.setattr("builtins.__import__.__module__", None) # inspect.getmodule(func) = None
|
2261 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2262 |
+
if additional_tools:
|
2263 |
+
from builtins import __import__
|
2264 |
+
|
2265 |
+
executor.send_tools({"__import__": __import__})
|
2266 |
+
if isinstance(expected_error, tuple): # different error depending on patch status
|
2267 |
+
expected_error = expected_error[patch_builtin_import_module]
|
2268 |
+
if isinstance(expected_error, Exception):
|
2269 |
+
expectation = pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2270 |
+
elif expected_error is None:
|
2271 |
+
expectation = does_not_raise()
|
2272 |
+
with expectation:
|
2273 |
+
executor(
|
2274 |
+
dedent(
|
2275 |
+
"""
|
2276 |
+
classes = {}.__class__.__base__.__subclasses__()
|
2277 |
+
for cls in classes:
|
2278 |
+
if cls.__name__ == "catch_warnings":
|
2279 |
+
break
|
2280 |
+
builtins = cls()._module.__builtins__
|
2281 |
+
builtins_import = builtins["__import__"]
|
2282 |
+
os_module = builtins_import('os')
|
2283 |
+
os_module.system(":")
|
2284 |
+
"""
|
2285 |
+
)
|
2286 |
+
)
|
2287 |
+
|
2288 |
+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
2289 |
+
@pytest.mark.parametrize(
|
2290 |
+
"additional_authorized_imports, expected_error",
|
2291 |
+
[
|
2292 |
+
([], InterpreterError("Forbidden access to dunder attribute: __base__")),
|
2293 |
+
(["os"], InterpreterError("Forbidden access to dunder attribute: __base__")),
|
2294 |
+
],
|
2295 |
+
)
|
2296 |
+
def test_vulnerability_load_module_via_builtin_importer(self, additional_authorized_imports, expected_error):
|
2297 |
+
executor = LocalPythonExecutor(additional_authorized_imports)
|
2298 |
+
with (
|
2299 |
+
pytest.raises(type(expected_error), match=f".*{expected_error}")
|
2300 |
+
if isinstance(expected_error, Exception)
|
2301 |
+
else does_not_raise()
|
2302 |
+
):
|
2303 |
+
executor(
|
2304 |
+
dedent(
|
2305 |
+
"""
|
2306 |
+
classes = {}.__class__.__base__.__subclasses__()
|
2307 |
+
for cls in classes:
|
2308 |
+
if cls.__name__ == "BuiltinImporter":
|
2309 |
+
break
|
2310 |
+
os_module = cls().load_module("os")
|
2311 |
+
os_module.system(":")
|
2312 |
+
"""
|
2313 |
+
)
|
2314 |
+
)
|
2315 |
+
|
2316 |
+
def test_vulnerability_class_via_subclasses(self):
|
2317 |
+
# Subclass: subprocess.Popen
|
2318 |
+
executor = LocalPythonExecutor([])
|
2319 |
+
code = dedent(
|
2320 |
+
"""
|
2321 |
+
for cls in ().__class__.__base__.__subclasses__():
|
2322 |
+
if 'Popen' in cls.__class__.__repr__(cls):
|
2323 |
+
break
|
2324 |
+
cls(["sh", "-c", ":"]).wait()
|
2325 |
+
"""
|
2326 |
+
)
|
2327 |
+
with pytest.raises(InterpreterError, match="Forbidden access to dunder attribute: __base__"):
|
2328 |
+
executor(code)
|
2329 |
+
|
2330 |
+
code = dedent(
|
2331 |
+
"""
|
2332 |
+
[c for c in ().__class__.__base__.__subclasses__() if "Popen" in c.__class__.__repr__(c)][0](
|
2333 |
+
["sh", "-c", ":"]
|
2334 |
+
).wait()
|
2335 |
+
"""
|
2336 |
+
)
|
2337 |
+
with pytest.raises(InterpreterError, match="Forbidden access to dunder attribute: __base__"):
|
2338 |
+
executor(code)
|
2339 |
+
|
2340 |
+
@pytest.mark.parametrize(
|
2341 |
+
"code, dunder_attribute",
|
2342 |
+
[("a = (); b = a.__class__", "__class__"), ("class A:\n attr=1\nx = A()\nx_dict = x.__dict__", "__dict__")],
|
2343 |
+
)
|
2344 |
+
def test_vulnerability_via_dunder_access(self, code, dunder_attribute):
|
2345 |
+
executor = LocalPythonExecutor([])
|
2346 |
+
with pytest.raises(InterpreterError, match=f"Forbidden access to dunder attribute: {dunder_attribute}"):
|
2347 |
+
executor(code)
|
2348 |
+
|
2349 |
+
def test_vulnerability_via_dunder_indirect_access(self):
|
2350 |
+
executor = LocalPythonExecutor([])
|
2351 |
+
code = "a = (); b = getattr(a, '__class__')"
|
2352 |
+
with pytest.raises(InterpreterError, match="Forbidden function evaluation: 'getattr'"):
|
2353 |
+
executor(code)
|
tests/test_mcp_client.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from textwrap import dedent
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
from mcp import StdioServerParameters
|
5 |
+
|
6 |
+
from smolagents.mcp_client import MCPClient
|
7 |
+
|
8 |
+
|
9 |
+
@pytest.fixture
|
10 |
+
def echo_server_script():
|
11 |
+
return dedent(
|
12 |
+
'''
|
13 |
+
from mcp.server.fastmcp import FastMCP
|
14 |
+
|
15 |
+
mcp = FastMCP("Echo Server")
|
16 |
+
|
17 |
+
@mcp.tool()
|
18 |
+
def echo_tool(text: str) -> str:
|
19 |
+
"""Echo the input text"""
|
20 |
+
return f"Echo: {text}"
|
21 |
+
|
22 |
+
mcp.run()
|
23 |
+
'''
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def test_mcp_client_with_syntax(echo_server_script: str):
|
28 |
+
"""Test the MCPClient with the context manager syntax."""
|
29 |
+
server_parameters = StdioServerParameters(command="python", args=["-c", echo_server_script])
|
30 |
+
with MCPClient(server_parameters) as tools:
|
31 |
+
assert len(tools) == 1
|
32 |
+
assert tools[0].name == "echo_tool"
|
33 |
+
assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
|
34 |
+
|
35 |
+
|
36 |
+
def test_mcp_client_try_finally_syntax(echo_server_script: str):
|
37 |
+
"""Test the MCPClient with the try ... finally syntax."""
|
38 |
+
server_parameters = StdioServerParameters(command="python", args=["-c", echo_server_script])
|
39 |
+
mcp_client = MCPClient(server_parameters)
|
40 |
+
try:
|
41 |
+
tools = mcp_client.get_tools()
|
42 |
+
assert len(tools) == 1
|
43 |
+
assert tools[0].name == "echo_tool"
|
44 |
+
assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
|
45 |
+
finally:
|
46 |
+
mcp_client.disconnect()
|
47 |
+
|
48 |
+
|
49 |
+
def test_multiple_servers(echo_server_script: str):
|
50 |
+
"""Test the MCPClient with multiple servers."""
|
51 |
+
server_parameters = [
|
52 |
+
StdioServerParameters(command="python", args=["-c", echo_server_script]),
|
53 |
+
StdioServerParameters(command="python", args=["-c", echo_server_script]),
|
54 |
+
]
|
55 |
+
with MCPClient(server_parameters) as tools:
|
56 |
+
assert len(tools) == 2
|
57 |
+
assert tools[0].name == "echo_tool"
|
58 |
+
assert tools[1].name == "echo_tool"
|
59 |
+
assert tools[0].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
|
60 |
+
assert tools[1].forward(**{"text": "Hello, world!"}) == "Echo: Hello, world!"
|
tests/test_memory.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
from smolagents.agents import ToolCall
|
5 |
+
from smolagents.memory import (
|
6 |
+
ActionStep,
|
7 |
+
AgentMemory,
|
8 |
+
ChatMessage,
|
9 |
+
MemoryStep,
|
10 |
+
MessageRole,
|
11 |
+
PlanningStep,
|
12 |
+
SystemPromptStep,
|
13 |
+
TaskStep,
|
14 |
+
)
|
15 |
+
from smolagents.monitoring import Timing, TokenUsage
|
16 |
+
|
17 |
+
|
18 |
+
class TestAgentMemory:
|
19 |
+
def test_initialization(self):
|
20 |
+
system_prompt = "This is a system prompt."
|
21 |
+
memory = AgentMemory(system_prompt=system_prompt)
|
22 |
+
assert memory.system_prompt.system_prompt == system_prompt
|
23 |
+
assert memory.steps == []
|
24 |
+
|
25 |
+
def test_return_all_code_actions(self):
|
26 |
+
memory = AgentMemory(system_prompt="This is a system prompt.")
|
27 |
+
memory.steps = [
|
28 |
+
ActionStep(step_number=1, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('Hello')"),
|
29 |
+
ActionStep(step_number=2, timing=Timing(start_time=0.0, end_time=1.0), code_action=None),
|
30 |
+
ActionStep(step_number=3, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('World')"),
|
31 |
+
] # type: ignore
|
32 |
+
assert memory.return_full_code() == "print('Hello')\n\nprint('World')"
|
33 |
+
|
34 |
+
|
35 |
+
class TestMemoryStep:
|
36 |
+
def test_initialization(self):
|
37 |
+
step = MemoryStep()
|
38 |
+
assert isinstance(step, MemoryStep)
|
39 |
+
|
40 |
+
def test_dict(self):
|
41 |
+
step = MemoryStep()
|
42 |
+
assert step.dict() == {}
|
43 |
+
|
44 |
+
def test_to_messages(self):
|
45 |
+
step = MemoryStep()
|
46 |
+
with pytest.raises(NotImplementedError):
|
47 |
+
step.to_messages()
|
48 |
+
|
49 |
+
|
50 |
+
def test_action_step_dict():
|
51 |
+
action_step = ActionStep(
|
52 |
+
model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
|
53 |
+
tool_calls=[
|
54 |
+
ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
|
55 |
+
],
|
56 |
+
timing=Timing(start_time=0.0, end_time=1.0),
|
57 |
+
step_number=1,
|
58 |
+
error=None,
|
59 |
+
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
|
60 |
+
model_output="Hi",
|
61 |
+
observations="This is a nice observation",
|
62 |
+
observations_images=[Image.new("RGB", (100, 100))],
|
63 |
+
action_output="Output",
|
64 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
|
65 |
+
)
|
66 |
+
action_step_dict = action_step.dict()
|
67 |
+
# Check each key individually for better test failure messages
|
68 |
+
assert "model_input_messages" in action_step_dict
|
69 |
+
assert action_step_dict["model_input_messages"] == [ChatMessage(role=MessageRole.USER, content="Hello")]
|
70 |
+
|
71 |
+
assert "tool_calls" in action_step_dict
|
72 |
+
assert len(action_step_dict["tool_calls"]) == 1
|
73 |
+
assert action_step_dict["tool_calls"][0] == {
|
74 |
+
"id": "id",
|
75 |
+
"type": "function",
|
76 |
+
"function": {
|
77 |
+
"name": "get_weather",
|
78 |
+
"arguments": {"location": "Paris"},
|
79 |
+
},
|
80 |
+
}
|
81 |
+
|
82 |
+
assert "timing" in action_step_dict
|
83 |
+
assert action_step_dict["timing"] == {"start_time": 0.0, "end_time": 1.0, "duration": 1.0}
|
84 |
+
|
85 |
+
assert "token_usage" in action_step_dict
|
86 |
+
assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
|
87 |
+
|
88 |
+
assert "step_number" in action_step_dict
|
89 |
+
assert action_step_dict["step_number"] == 1
|
90 |
+
|
91 |
+
assert "error" in action_step_dict
|
92 |
+
assert action_step_dict["error"] is None
|
93 |
+
|
94 |
+
assert "model_output_message" in action_step_dict
|
95 |
+
assert action_step_dict["model_output_message"] == {
|
96 |
+
"role": "assistant",
|
97 |
+
"content": "Hi",
|
98 |
+
"tool_calls": None,
|
99 |
+
"raw": None,
|
100 |
+
"token_usage": None,
|
101 |
+
}
|
102 |
+
|
103 |
+
assert "model_output" in action_step_dict
|
104 |
+
assert action_step_dict["model_output"] == "Hi"
|
105 |
+
|
106 |
+
assert "observations" in action_step_dict
|
107 |
+
assert action_step_dict["observations"] == "This is a nice observation"
|
108 |
+
|
109 |
+
assert "observations_images" in action_step_dict
|
110 |
+
|
111 |
+
assert "action_output" in action_step_dict
|
112 |
+
assert action_step_dict["action_output"] == "Output"
|
113 |
+
|
114 |
+
|
115 |
+
def test_action_step_to_messages():
|
116 |
+
action_step = ActionStep(
|
117 |
+
model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
|
118 |
+
tool_calls=[
|
119 |
+
ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
|
120 |
+
],
|
121 |
+
timing=Timing(start_time=0.0, end_time=1.0),
|
122 |
+
step_number=1,
|
123 |
+
error=None,
|
124 |
+
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
|
125 |
+
model_output="Hi",
|
126 |
+
observations="This is a nice observation",
|
127 |
+
observations_images=[Image.new("RGB", (100, 100))],
|
128 |
+
action_output="Output",
|
129 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
|
130 |
+
)
|
131 |
+
messages = action_step.to_messages()
|
132 |
+
assert len(messages) == 4
|
133 |
+
for message in messages:
|
134 |
+
assert isinstance(message, ChatMessage)
|
135 |
+
assistant_message = messages[0]
|
136 |
+
assert assistant_message.role == MessageRole.ASSISTANT
|
137 |
+
assert len(assistant_message.content) == 1
|
138 |
+
assert assistant_message.content[0]["type"] == "text"
|
139 |
+
assert assistant_message.content[0]["text"] == "Hi"
|
140 |
+
message = messages[1]
|
141 |
+
assert message.role == MessageRole.TOOL_CALL
|
142 |
+
|
143 |
+
assert len(message.content) == 1
|
144 |
+
assert message.content[0]["type"] == "text"
|
145 |
+
assert "Calling tools:" in message.content[0]["text"]
|
146 |
+
|
147 |
+
image_message = messages[2]
|
148 |
+
assert image_message.content[0]["type"] == "image" # type: ignore
|
149 |
+
|
150 |
+
observation_message = messages[3]
|
151 |
+
assert observation_message.role == MessageRole.TOOL_RESPONSE
|
152 |
+
assert "Observation:\nThis is a nice observation" in observation_message.content[0]["text"]
|
153 |
+
|
154 |
+
|
155 |
+
def test_action_step_to_messages_no_tool_calls_with_observations():
|
156 |
+
action_step = ActionStep(
|
157 |
+
model_input_messages=None,
|
158 |
+
tool_calls=None,
|
159 |
+
timing=Timing(start_time=0.0, end_time=1.0),
|
160 |
+
step_number=1,
|
161 |
+
error=None,
|
162 |
+
model_output_message=None,
|
163 |
+
model_output=None,
|
164 |
+
observations="This is an observation.",
|
165 |
+
observations_images=None,
|
166 |
+
action_output=None,
|
167 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
|
168 |
+
)
|
169 |
+
messages = action_step.to_messages()
|
170 |
+
assert len(messages) == 1
|
171 |
+
observation_message = messages[0]
|
172 |
+
assert observation_message.role == MessageRole.TOOL_RESPONSE
|
173 |
+
assert "Observation:\nThis is an observation." in observation_message.content[0]["text"]
|
174 |
+
|
175 |
+
|
176 |
+
def test_planning_step_to_messages():
|
177 |
+
planning_step = PlanningStep(
|
178 |
+
model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
|
179 |
+
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Plan"),
|
180 |
+
plan="This is a plan.",
|
181 |
+
timing=Timing(start_time=0.0, end_time=1.0),
|
182 |
+
)
|
183 |
+
messages = planning_step.to_messages(summary_mode=False)
|
184 |
+
assert len(messages) == 2
|
185 |
+
for message in messages:
|
186 |
+
assert isinstance(message, ChatMessage)
|
187 |
+
assert isinstance(message.content, list)
|
188 |
+
assert len(message.content) == 1
|
189 |
+
for content in message.content:
|
190 |
+
assert isinstance(content, dict)
|
191 |
+
assert "type" in content
|
192 |
+
assert "text" in content
|
193 |
+
assert messages[0].role == MessageRole.ASSISTANT
|
194 |
+
assert messages[1].role == MessageRole.USER
|
195 |
+
|
196 |
+
|
197 |
+
def test_task_step_to_messages():
|
198 |
+
task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])
|
199 |
+
messages = task_step.to_messages(summary_mode=False)
|
200 |
+
assert len(messages) == 1
|
201 |
+
for message in messages:
|
202 |
+
assert isinstance(message, ChatMessage)
|
203 |
+
assert message.role == MessageRole.USER
|
204 |
+
assert isinstance(message.content, list)
|
205 |
+
assert len(message.content) == 2
|
206 |
+
text_content = message.content[0]
|
207 |
+
assert isinstance(text_content, dict)
|
208 |
+
assert "type" in text_content
|
209 |
+
assert "text" in text_content
|
210 |
+
for image_content in message.content[1:]:
|
211 |
+
assert isinstance(image_content, dict)
|
212 |
+
assert "type" in image_content
|
213 |
+
assert "image" in image_content
|
214 |
+
|
215 |
+
|
216 |
+
def test_system_prompt_step_to_messages():
|
217 |
+
system_prompt_step = SystemPromptStep(system_prompt="This is a system prompt.")
|
218 |
+
messages = system_prompt_step.to_messages(summary_mode=False)
|
219 |
+
assert len(messages) == 1
|
220 |
+
for message in messages:
|
221 |
+
assert isinstance(message, ChatMessage)
|
222 |
+
assert message.role == MessageRole.SYSTEM
|
223 |
+
assert isinstance(message.content, list)
|
224 |
+
assert len(message.content) == 1
|
225 |
+
for content in message.content:
|
226 |
+
assert isinstance(content, dict)
|
227 |
+
assert "type" in content
|
228 |
+
assert "text" in content
|
tests/test_models.py
ADDED
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import json
|
16 |
+
import sys
|
17 |
+
import unittest
|
18 |
+
from contextlib import ExitStack
|
19 |
+
from unittest.mock import MagicMock, patch
|
20 |
+
|
21 |
+
import pytest
|
22 |
+
from huggingface_hub import ChatCompletionOutputMessage
|
23 |
+
|
24 |
+
from smolagents.default_tools import FinalAnswerTool
|
25 |
+
from smolagents.models import (
|
26 |
+
AmazonBedrockServerModel,
|
27 |
+
AzureOpenAIServerModel,
|
28 |
+
ChatMessage,
|
29 |
+
ChatMessageToolCall,
|
30 |
+
InferenceClientModel,
|
31 |
+
LiteLLMModel,
|
32 |
+
LiteLLMRouterModel,
|
33 |
+
MessageRole,
|
34 |
+
MLXModel,
|
35 |
+
Model,
|
36 |
+
OpenAIServerModel,
|
37 |
+
TransformersModel,
|
38 |
+
get_clean_message_list,
|
39 |
+
get_tool_call_from_text,
|
40 |
+
get_tool_json_schema,
|
41 |
+
parse_json_if_needed,
|
42 |
+
supports_stop_parameter,
|
43 |
+
)
|
44 |
+
from smolagents.tools import tool
|
45 |
+
|
46 |
+
from .utils.markers import require_run_all
|
47 |
+
|
48 |
+
|
49 |
+
class TestModel:
|
50 |
+
def test_agglomerate_stream_deltas(self):
|
51 |
+
from smolagents.models import (
|
52 |
+
ChatMessageStreamDelta,
|
53 |
+
ChatMessageToolCallFunction,
|
54 |
+
ChatMessageToolCallStreamDelta,
|
55 |
+
TokenUsage,
|
56 |
+
agglomerate_stream_deltas,
|
57 |
+
)
|
58 |
+
|
59 |
+
stream_deltas = [
|
60 |
+
ChatMessageStreamDelta(
|
61 |
+
content="Hi",
|
62 |
+
tool_calls=[
|
63 |
+
ChatMessageToolCallStreamDelta(
|
64 |
+
index=0,
|
65 |
+
type="function",
|
66 |
+
function=ChatMessageToolCallFunction(arguments="", name="web_search", description=None),
|
67 |
+
)
|
68 |
+
],
|
69 |
+
token_usage=None,
|
70 |
+
),
|
71 |
+
ChatMessageStreamDelta(
|
72 |
+
content=" everyone",
|
73 |
+
tool_calls=[
|
74 |
+
ChatMessageToolCallStreamDelta(
|
75 |
+
index=0,
|
76 |
+
type="function",
|
77 |
+
function=ChatMessageToolCallFunction(arguments=' {"', name="web_search", description=None),
|
78 |
+
)
|
79 |
+
],
|
80 |
+
token_usage=None,
|
81 |
+
),
|
82 |
+
ChatMessageStreamDelta(
|
83 |
+
content=", it's",
|
84 |
+
tool_calls=[
|
85 |
+
ChatMessageToolCallStreamDelta(
|
86 |
+
index=0,
|
87 |
+
type="function",
|
88 |
+
function=ChatMessageToolCallFunction(
|
89 |
+
arguments='query": "current pope name and date of birth"}',
|
90 |
+
name="web_search",
|
91 |
+
description=None,
|
92 |
+
),
|
93 |
+
)
|
94 |
+
],
|
95 |
+
token_usage=None,
|
96 |
+
),
|
97 |
+
ChatMessageStreamDelta(
|
98 |
+
content="",
|
99 |
+
tool_calls=None,
|
100 |
+
token_usage=TokenUsage(input_tokens=1348, output_tokens=24),
|
101 |
+
),
|
102 |
+
]
|
103 |
+
agglomerated_stream_delta = agglomerate_stream_deltas(stream_deltas)
|
104 |
+
assert agglomerated_stream_delta.content == "Hi everyone, it's"
|
105 |
+
assert (
|
106 |
+
agglomerated_stream_delta.tool_calls[0].function.arguments
|
107 |
+
== ' {"query": "current pope name and date of birth"}'
|
108 |
+
)
|
109 |
+
assert agglomerated_stream_delta.token_usage.total_tokens == 1372
|
110 |
+
|
111 |
+
@pytest.mark.parametrize(
|
112 |
+
"model_id, stop_sequences, should_contain_stop",
|
113 |
+
[
|
114 |
+
("regular-model", ["stop1", "stop2"], True), # Regular model should include stop
|
115 |
+
("openai/o3", ["stop1", "stop2"], False), # o3 model should not include stop
|
116 |
+
("openai/o4-mini", ["stop1", "stop2"], False), # o4-mini model should not include stop
|
117 |
+
("something/else/o3", ["stop1", "stop2"], False), # Path ending with o3 should not include stop
|
118 |
+
("something/else/o4-mini", ["stop1", "stop2"], False), # Path ending with o4-mini should not include stop
|
119 |
+
("o3", ["stop1", "stop2"], False), # Exact o3 model should not include stop
|
120 |
+
("o4-mini", ["stop1", "stop2"], False), # Exact o4-mini model should not include stop
|
121 |
+
("regular-model", None, False), # None stop_sequences should not add stop parameter
|
122 |
+
],
|
123 |
+
)
|
124 |
+
def test_prepare_completion_kwargs_stop_sequences(self, model_id, stop_sequences, should_contain_stop):
|
125 |
+
model = Model()
|
126 |
+
model.model_id = model_id
|
127 |
+
completion_kwargs = model._prepare_completion_kwargs(
|
128 |
+
messages=[
|
129 |
+
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}]),
|
130 |
+
],
|
131 |
+
stop_sequences=stop_sequences,
|
132 |
+
)
|
133 |
+
# Verify that the stop parameter is only included when appropriate
|
134 |
+
if should_contain_stop:
|
135 |
+
assert "stop" in completion_kwargs
|
136 |
+
assert completion_kwargs["stop"] == stop_sequences
|
137 |
+
else:
|
138 |
+
assert "stop" not in completion_kwargs
|
139 |
+
|
140 |
+
@pytest.mark.parametrize(
|
141 |
+
"with_tools, tool_choice, expected_result",
|
142 |
+
[
|
143 |
+
# Default behavior: With tools but no explicit tool_choice, should default to "required"
|
144 |
+
(True, ..., {"has_tool_choice": True, "value": "required"}),
|
145 |
+
# Custom value: With tools and explicit tool_choice="auto"
|
146 |
+
(True, "auto", {"has_tool_choice": True, "value": "auto"}),
|
147 |
+
# Tool name as string
|
148 |
+
(True, "valid_tool_function", {"has_tool_choice": True, "value": "valid_tool_function"}),
|
149 |
+
# Tool choice as dictionary
|
150 |
+
(
|
151 |
+
True,
|
152 |
+
{"type": "function", "function": {"name": "valid_tool_function"}},
|
153 |
+
{"has_tool_choice": True, "value": {"type": "function", "function": {"name": "valid_tool_function"}}},
|
154 |
+
),
|
155 |
+
# With tools but explicit None tool_choice: should exclude tool_choice
|
156 |
+
(True, None, {"has_tool_choice": False, "value": None}),
|
157 |
+
# Without tools: tool_choice should never be included
|
158 |
+
(False, "required", {"has_tool_choice": False, "value": None}),
|
159 |
+
(False, "auto", {"has_tool_choice": False, "value": None}),
|
160 |
+
(False, None, {"has_tool_choice": False, "value": None}),
|
161 |
+
(False, ..., {"has_tool_choice": False, "value": None}),
|
162 |
+
],
|
163 |
+
)
|
164 |
+
def test_prepare_completion_kwargs_tool_choice(self, with_tools, tool_choice, expected_result, example_tool):
|
165 |
+
model = Model()
|
166 |
+
kwargs = {"messages": [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}])]}
|
167 |
+
if with_tools:
|
168 |
+
kwargs["tools_to_call_from"] = [example_tool]
|
169 |
+
if tool_choice is not ...:
|
170 |
+
kwargs["tool_choice"] = tool_choice
|
171 |
+
|
172 |
+
completion_kwargs = model._prepare_completion_kwargs(**kwargs)
|
173 |
+
|
174 |
+
if expected_result["has_tool_choice"]:
|
175 |
+
assert "tool_choice" in completion_kwargs
|
176 |
+
assert completion_kwargs["tool_choice"] == expected_result["value"]
|
177 |
+
else:
|
178 |
+
assert "tool_choice" not in completion_kwargs
|
179 |
+
|
180 |
+
def test_get_json_schema_has_nullable_args(self):
|
181 |
+
@tool
|
182 |
+
def get_weather(location: str, celsius: bool | None = False) -> str:
|
183 |
+
"""
|
184 |
+
Get weather in the next days at given location.
|
185 |
+
Secretly this tool does not care about the location, it hates the weather everywhere.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
location: the location
|
189 |
+
celsius: the temperature type
|
190 |
+
"""
|
191 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
192 |
+
|
193 |
+
assert "nullable" in get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
194 |
+
|
195 |
+
def test_chatmessage_has_model_dumps_json(self):
|
196 |
+
message = ChatMessage("user", [{"type": "text", "text": "Hello!"}])
|
197 |
+
data = json.loads(message.model_dump_json())
|
198 |
+
assert data["content"] == [{"type": "text", "text": "Hello!"}]
|
199 |
+
|
200 |
+
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
201 |
+
def test_get_mlx_message_no_tool(self):
|
202 |
+
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
|
203 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
204 |
+
output = model(messages, stop_sequences=["great"]).content
|
205 |
+
assert output.startswith("Hello")
|
206 |
+
|
207 |
+
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
208 |
+
def test_get_mlx_message_tricky_stop_sequence(self):
|
209 |
+
# In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'"
|
210 |
+
# which is required to test capturing stop_sequences that have extra chars at the end.
|
211 |
+
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100)
|
212 |
+
stop_sequence = " print '>"
|
213 |
+
messages = [
|
214 |
+
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": f"Please{stop_sequence}'"}]),
|
215 |
+
]
|
216 |
+
# check our assumption that that ">" is followed by "'"
|
217 |
+
assert model.tokenizer.vocab[">'"]
|
218 |
+
assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'"
|
219 |
+
# check stop_sequence capture when output has trailing chars
|
220 |
+
assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you"
|
221 |
+
|
222 |
+
def test_transformers_message_no_tool(self, monkeypatch):
|
223 |
+
monkeypatch.setattr("huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT", 30) # instead of 10
|
224 |
+
model = TransformersModel(
|
225 |
+
model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
|
226 |
+
max_new_tokens=5,
|
227 |
+
device_map="cpu",
|
228 |
+
do_sample=False,
|
229 |
+
)
|
230 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
231 |
+
output = model.generate(messages).content
|
232 |
+
assert output == "Hello! I'm here"
|
233 |
+
|
234 |
+
output = model.generate_stream(messages, stop_sequences=["great"])
|
235 |
+
output_str = ""
|
236 |
+
for el in output:
|
237 |
+
output_str += el.content
|
238 |
+
assert output_str == "Hello! I'm here"
|
239 |
+
|
240 |
+
def test_transformers_message_vl_no_tool(self, shared_datadir, monkeypatch):
|
241 |
+
monkeypatch.setattr("huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT", 30) # instead of 10
|
242 |
+
import PIL.Image
|
243 |
+
|
244 |
+
img = PIL.Image.open(shared_datadir / "000000039769.png")
|
245 |
+
model = TransformersModel(
|
246 |
+
model_id="llava-hf/llava-interleave-qwen-0.5b-hf",
|
247 |
+
max_new_tokens=4,
|
248 |
+
device_map="cpu",
|
249 |
+
do_sample=False,
|
250 |
+
)
|
251 |
+
messages = [
|
252 |
+
ChatMessage(
|
253 |
+
role=MessageRole.USER,
|
254 |
+
content=[{"type": "text", "text": "What is this?"}, {"type": "image", "image": img}],
|
255 |
+
)
|
256 |
+
]
|
257 |
+
output = model.generate(messages).content
|
258 |
+
assert output == "This is a very"
|
259 |
+
|
260 |
+
output = model.generate_stream(messages, stop_sequences=["great"])
|
261 |
+
output_str = ""
|
262 |
+
for el in output:
|
263 |
+
output_str += el.content
|
264 |
+
assert output_str == "This is a very"
|
265 |
+
|
266 |
+
def test_parse_json_if_needed(self):
|
267 |
+
args = "abc"
|
268 |
+
parsed_args = parse_json_if_needed(args)
|
269 |
+
assert parsed_args == "abc"
|
270 |
+
|
271 |
+
args = '{"a": 3}'
|
272 |
+
parsed_args = parse_json_if_needed(args)
|
273 |
+
assert parsed_args == {"a": 3}
|
274 |
+
|
275 |
+
args = "3"
|
276 |
+
parsed_args = parse_json_if_needed(args)
|
277 |
+
assert parsed_args == 3
|
278 |
+
|
279 |
+
args = 3
|
280 |
+
parsed_args = parse_json_if_needed(args)
|
281 |
+
assert parsed_args == 3
|
282 |
+
|
283 |
+
|
284 |
+
class TestInferenceClientModel:
|
285 |
+
def test_call_with_custom_role_conversions(self):
|
286 |
+
custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
|
287 |
+
model = InferenceClientModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
|
288 |
+
model.client = MagicMock()
|
289 |
+
mock_response = model.client.chat_completion.return_value
|
290 |
+
mock_response.choices[0].message = ChatCompletionOutputMessage(role=MessageRole.ASSISTANT)
|
291 |
+
messages = [ChatMessage(role=MessageRole.USER, content="Test message")]
|
292 |
+
_ = model(messages)
|
293 |
+
# Verify that the role conversion was applied
|
294 |
+
assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
|
295 |
+
"role conversion should be applied"
|
296 |
+
)
|
297 |
+
|
298 |
+
def test_init_model_with_tokens(self):
|
299 |
+
model = InferenceClientModel(model_id="test-model", token="abc")
|
300 |
+
assert model.client.token == "abc"
|
301 |
+
|
302 |
+
model = InferenceClientModel(model_id="test-model", api_key="abc")
|
303 |
+
assert model.client.token == "abc"
|
304 |
+
|
305 |
+
with pytest.raises(ValueError, match="Received both `token` and `api_key` arguments."):
|
306 |
+
InferenceClientModel(model_id="test-model", token="abc", api_key="def")
|
307 |
+
|
308 |
+
def test_structured_outputs_with_unsupported_provider(self):
|
309 |
+
with pytest.raises(
|
310 |
+
ValueError, match="InferenceClientModel only supports structured outputs with these providers:"
|
311 |
+
):
|
312 |
+
model = InferenceClientModel(model_id="test-model", token="abc", provider="some_provider")
|
313 |
+
model.generate(
|
314 |
+
messages=[ChatMessage(role=MessageRole.USER, content="Hello!")],
|
315 |
+
response_format={"type": "json_object"},
|
316 |
+
)
|
317 |
+
|
318 |
+
@require_run_all
|
319 |
+
def test_get_hfapi_message_no_tool(self):
|
320 |
+
model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=10)
|
321 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
322 |
+
model(messages, stop_sequences=["great"])
|
323 |
+
|
324 |
+
@require_run_all
|
325 |
+
def test_get_hfapi_message_no_tool_external_provider(self):
|
326 |
+
model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
327 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
328 |
+
model(messages, stop_sequences=["great"])
|
329 |
+
|
330 |
+
@require_run_all
|
331 |
+
def test_get_hfapi_message_stream_no_tool(self):
|
332 |
+
model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=10)
|
333 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
334 |
+
for el in model.generate_stream(messages, stop_sequences=["great"]):
|
335 |
+
assert el.content is not None
|
336 |
+
|
337 |
+
@require_run_all
|
338 |
+
def test_get_hfapi_message_stream_no_tool_external_provider(self):
|
339 |
+
model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
340 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}])]
|
341 |
+
for el in model.generate_stream(messages, stop_sequences=["great"]):
|
342 |
+
assert el.content is not None
|
343 |
+
|
344 |
+
|
345 |
+
class TestLiteLLMModel:
|
346 |
+
@pytest.mark.parametrize(
|
347 |
+
"model_id, error_flag",
|
348 |
+
[
|
349 |
+
("groq/llama-3.3-70b", "Invalid API Key"),
|
350 |
+
("cerebras/llama-3.3-70b", "The api_key client option must be set"),
|
351 |
+
("mistral/mistral-tiny", "The api_key client option must be set"),
|
352 |
+
],
|
353 |
+
)
|
354 |
+
def test_call_different_providers_without_key(self, model_id, error_flag):
|
355 |
+
model = LiteLLMModel(model_id=model_id)
|
356 |
+
messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Test message"}])]
|
357 |
+
with pytest.raises(Exception) as e:
|
358 |
+
# This should raise 401 error because of missing API key, not fail for any "bad format" reason
|
359 |
+
model.generate(messages)
|
360 |
+
assert error_flag in str(e)
|
361 |
+
with pytest.raises(Exception) as e:
|
362 |
+
# This should raise 401 error because of missing API key, not fail for any "bad format" reason
|
363 |
+
for el in model.generate_stream(messages):
|
364 |
+
assert el.content is not None
|
365 |
+
assert error_flag in str(e)
|
366 |
+
|
367 |
+
def test_passing_flatten_messages(self):
|
368 |
+
model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
|
369 |
+
assert not model.flatten_messages_as_text
|
370 |
+
|
371 |
+
model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
|
372 |
+
assert model.flatten_messages_as_text
|
373 |
+
|
374 |
+
|
375 |
+
class TestLiteLLMRouterModel:
|
376 |
+
@pytest.mark.parametrize(
|
377 |
+
"model_id, expected",
|
378 |
+
[
|
379 |
+
("llama-3.3-70b", False),
|
380 |
+
("llama-3.3-70b", True),
|
381 |
+
("mistral-tiny", True),
|
382 |
+
],
|
383 |
+
)
|
384 |
+
def test_flatten_messages_as_text(self, model_id, expected):
|
385 |
+
model_list = [
|
386 |
+
{"model_name": "llama-3.3-70b", "litellm_params": {"model": "groq/llama-3.3-70b"}},
|
387 |
+
{"model_name": "llama-3.3-70b", "litellm_params": {"model": "cerebras/llama-3.3-70b"}},
|
388 |
+
{"model_name": "mistral-tiny", "litellm_params": {"model": "mistral/mistral-tiny"}},
|
389 |
+
]
|
390 |
+
model = LiteLLMRouterModel(model_id=model_id, model_list=model_list, flatten_messages_as_text=expected)
|
391 |
+
assert model.flatten_messages_as_text is expected
|
392 |
+
|
393 |
+
def test_create_client(self):
|
394 |
+
model_list = [
|
395 |
+
{"model_name": "llama-3.3-70b", "litellm_params": {"model": "groq/llama-3.3-70b"}},
|
396 |
+
{"model_name": "llama-3.3-70b", "litellm_params": {"model": "cerebras/llama-3.3-70b"}},
|
397 |
+
]
|
398 |
+
with patch("litellm.router.Router") as mock_router:
|
399 |
+
router_model = LiteLLMRouterModel(
|
400 |
+
model_id="model-group-1", model_list=model_list, client_kwargs={"routing_strategy": "simple-shuffle"}
|
401 |
+
)
|
402 |
+
# Ensure that the Router constructor was called with the expected keyword arguments
|
403 |
+
mock_router.assert_called_once()
|
404 |
+
assert mock_router.call_count == 1
|
405 |
+
assert mock_router.call_args.kwargs["model_list"] == model_list
|
406 |
+
assert mock_router.call_args.kwargs["routing_strategy"] == "simple-shuffle"
|
407 |
+
assert router_model.client == mock_router.return_value
|
408 |
+
|
409 |
+
|
410 |
+
class TestOpenAIServerModel:
|
411 |
+
def test_client_kwargs_passed_correctly(self):
|
412 |
+
model_id = "gpt-3.5-turbo"
|
413 |
+
api_base = "https://api.openai.com/v1"
|
414 |
+
api_key = "test_api_key"
|
415 |
+
organization = "test_org"
|
416 |
+
project = "test_project"
|
417 |
+
client_kwargs = {"max_retries": 5}
|
418 |
+
|
419 |
+
with patch("openai.OpenAI") as MockOpenAI:
|
420 |
+
model = OpenAIServerModel(
|
421 |
+
model_id=model_id,
|
422 |
+
api_base=api_base,
|
423 |
+
api_key=api_key,
|
424 |
+
organization=organization,
|
425 |
+
project=project,
|
426 |
+
client_kwargs=client_kwargs,
|
427 |
+
)
|
428 |
+
MockOpenAI.assert_called_once_with(
|
429 |
+
base_url=api_base, api_key=api_key, organization=organization, project=project, max_retries=5
|
430 |
+
)
|
431 |
+
assert model.client == MockOpenAI.return_value
|
432 |
+
|
433 |
+
@require_run_all
|
434 |
+
def test_streaming_tool_calls(self):
|
435 |
+
model = OpenAIServerModel(model_id="gpt-4o-mini")
|
436 |
+
messages = [
|
437 |
+
ChatMessage(
|
438 |
+
role=MessageRole.USER,
|
439 |
+
content=[
|
440 |
+
{
|
441 |
+
"type": "text",
|
442 |
+
"text": "Hello! Please return the final answer 'blob' and the final answer 'blob2' in two parallel tool calls",
|
443 |
+
}
|
444 |
+
],
|
445 |
+
),
|
446 |
+
]
|
447 |
+
for el in model.generate_stream(messages, tools_to_call_from=[FinalAnswerTool()]):
|
448 |
+
if el.tool_calls:
|
449 |
+
assert el.tool_calls[0].function.name == "final_answer"
|
450 |
+
args = el.tool_calls[0].function.arguments
|
451 |
+
if len(el.tool_calls) > 1:
|
452 |
+
assert el.tool_calls[1].function.name == "final_answer"
|
453 |
+
args2 = el.tool_calls[1].function.arguments
|
454 |
+
assert args == '{"answer": "blob"}'
|
455 |
+
assert args2 == '{"answer": "blob2"}'
|
456 |
+
|
457 |
+
|
458 |
+
class TestAmazonBedrockServerModel:
|
459 |
+
def test_client_for_bedrock(self):
|
460 |
+
model_id = "us.amazon.nova-pro-v1:0"
|
461 |
+
|
462 |
+
with patch("boto3.client") as MockBoto3:
|
463 |
+
model = AmazonBedrockServerModel(
|
464 |
+
model_id=model_id,
|
465 |
+
)
|
466 |
+
|
467 |
+
assert model.client == MockBoto3.return_value
|
468 |
+
|
469 |
+
|
470 |
+
class TestAzureOpenAIServerModel:
|
471 |
+
def test_client_kwargs_passed_correctly(self):
|
472 |
+
model_id = "gpt-3.5-turbo"
|
473 |
+
api_key = "test_api_key"
|
474 |
+
api_version = "2023-12-01-preview"
|
475 |
+
azure_endpoint = "https://example-resource.azure.openai.com/"
|
476 |
+
organization = "test_org"
|
477 |
+
project = "test_project"
|
478 |
+
client_kwargs = {"max_retries": 5}
|
479 |
+
|
480 |
+
with patch("openai.OpenAI") as MockOpenAI, patch("openai.AzureOpenAI") as MockAzureOpenAI:
|
481 |
+
model = AzureOpenAIServerModel(
|
482 |
+
model_id=model_id,
|
483 |
+
api_key=api_key,
|
484 |
+
api_version=api_version,
|
485 |
+
azure_endpoint=azure_endpoint,
|
486 |
+
organization=organization,
|
487 |
+
project=project,
|
488 |
+
client_kwargs=client_kwargs,
|
489 |
+
)
|
490 |
+
assert MockOpenAI.call_count == 0
|
491 |
+
MockAzureOpenAI.assert_called_once_with(
|
492 |
+
base_url=None,
|
493 |
+
api_key=api_key,
|
494 |
+
api_version=api_version,
|
495 |
+
azure_endpoint=azure_endpoint,
|
496 |
+
organization=organization,
|
497 |
+
project=project,
|
498 |
+
max_retries=5,
|
499 |
+
)
|
500 |
+
assert model.client == MockAzureOpenAI.return_value
|
501 |
+
|
502 |
+
|
503 |
+
class TestTransformersModel:
|
504 |
+
@pytest.mark.parametrize(
|
505 |
+
"patching",
|
506 |
+
[
|
507 |
+
[
|
508 |
+
(
|
509 |
+
"transformers.AutoModelForImageTextToText.from_pretrained",
|
510 |
+
{"side_effect": ValueError("Unrecognized configuration class")},
|
511 |
+
),
|
512 |
+
("transformers.AutoModelForCausalLM.from_pretrained", {}),
|
513 |
+
("transformers.AutoTokenizer.from_pretrained", {}),
|
514 |
+
],
|
515 |
+
[
|
516 |
+
("transformers.AutoModelForImageTextToText.from_pretrained", {}),
|
517 |
+
("transformers.AutoProcessor.from_pretrained", {}),
|
518 |
+
],
|
519 |
+
],
|
520 |
+
)
|
521 |
+
def test_init(self, patching):
|
522 |
+
with ExitStack() as stack:
|
523 |
+
mocks = {target: stack.enter_context(patch(target, **kwargs)) for target, kwargs in patching}
|
524 |
+
model = TransformersModel(
|
525 |
+
model_id="test-model", device_map="cpu", torch_dtype="float16", trust_remote_code=True
|
526 |
+
)
|
527 |
+
assert model.model_id == "test-model"
|
528 |
+
if "transformers.AutoTokenizer.from_pretrained" in mocks:
|
529 |
+
assert model.model == mocks["transformers.AutoModelForCausalLM.from_pretrained"].return_value
|
530 |
+
assert mocks["transformers.AutoModelForCausalLM.from_pretrained"].call_args.kwargs == {
|
531 |
+
"device_map": "cpu",
|
532 |
+
"torch_dtype": "float16",
|
533 |
+
"trust_remote_code": True,
|
534 |
+
}
|
535 |
+
assert model.tokenizer == mocks["transformers.AutoTokenizer.from_pretrained"].return_value
|
536 |
+
assert mocks["transformers.AutoTokenizer.from_pretrained"].call_args.args == ("test-model",)
|
537 |
+
assert mocks["transformers.AutoTokenizer.from_pretrained"].call_args.kwargs == {"trust_remote_code": True}
|
538 |
+
elif "transformers.AutoProcessor.from_pretrained" in mocks:
|
539 |
+
assert model.model == mocks["transformers.AutoModelForImageTextToText.from_pretrained"].return_value
|
540 |
+
assert mocks["transformers.AutoModelForImageTextToText.from_pretrained"].call_args.kwargs == {
|
541 |
+
"device_map": "cpu",
|
542 |
+
"torch_dtype": "float16",
|
543 |
+
"trust_remote_code": True,
|
544 |
+
}
|
545 |
+
assert model.processor == mocks["transformers.AutoProcessor.from_pretrained"].return_value
|
546 |
+
assert mocks["transformers.AutoProcessor.from_pretrained"].call_args.args == ("test-model",)
|
547 |
+
assert mocks["transformers.AutoProcessor.from_pretrained"].call_args.kwargs == {"trust_remote_code": True}
|
548 |
+
|
549 |
+
|
550 |
+
def test_get_clean_message_list_basic():
|
551 |
+
messages = [
|
552 |
+
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}]),
|
553 |
+
ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": "Hi there!"}]),
|
554 |
+
]
|
555 |
+
result = get_clean_message_list(messages)
|
556 |
+
assert len(result) == 2
|
557 |
+
assert result[0]["role"] == "user"
|
558 |
+
assert result[0]["content"][0]["text"] == "Hello!"
|
559 |
+
assert result[1]["role"] == "assistant"
|
560 |
+
assert result[1]["content"][0]["text"] == "Hi there!"
|
561 |
+
|
562 |
+
|
563 |
+
def test_get_clean_message_list_role_conversions():
|
564 |
+
messages = [
|
565 |
+
ChatMessage(role=MessageRole.TOOL_CALL, content=[{"type": "text", "text": "Calling tool..."}]),
|
566 |
+
ChatMessage(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": "Tool response"}]),
|
567 |
+
]
|
568 |
+
result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"})
|
569 |
+
assert len(result) == 2
|
570 |
+
assert result[0]["role"] == "assistant"
|
571 |
+
assert result[0]["content"][0]["text"] == "Calling tool..."
|
572 |
+
assert result[1]["role"] == "user"
|
573 |
+
assert result[1]["content"][0]["text"] == "Tool response"
|
574 |
+
|
575 |
+
|
576 |
+
@pytest.mark.parametrize(
|
577 |
+
"convert_images_to_image_urls, expected_clean_message",
|
578 |
+
[
|
579 |
+
(
|
580 |
+
False,
|
581 |
+
dict(
|
582 |
+
role=MessageRole.USER,
|
583 |
+
content=[
|
584 |
+
{"type": "image", "image": "encoded_image"},
|
585 |
+
{"type": "image", "image": "second_encoded_image"},
|
586 |
+
],
|
587 |
+
),
|
588 |
+
),
|
589 |
+
(
|
590 |
+
True,
|
591 |
+
dict(
|
592 |
+
role=MessageRole.USER,
|
593 |
+
content=[
|
594 |
+
{"type": "image_url", "image_url": {"url": "data:image/png;base64,encoded_image"}},
|
595 |
+
{"type": "image_url", "image_url": {"url": "data:image/png;base64,second_encoded_image"}},
|
596 |
+
],
|
597 |
+
),
|
598 |
+
),
|
599 |
+
],
|
600 |
+
)
|
601 |
+
def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message):
|
602 |
+
message = ChatMessage(
|
603 |
+
role=MessageRole.USER,
|
604 |
+
content=[{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
|
605 |
+
)
|
606 |
+
with patch("smolagents.models.encode_image_base64") as mock_encode:
|
607 |
+
mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
|
608 |
+
result = get_clean_message_list([message], convert_images_to_image_urls=convert_images_to_image_urls)
|
609 |
+
mock_encode.assert_any_call(b"image_data")
|
610 |
+
mock_encode.assert_any_call(b"second_image_data")
|
611 |
+
assert len(result) == 1
|
612 |
+
assert result[0] == expected_clean_message
|
613 |
+
|
614 |
+
|
615 |
+
def test_get_clean_message_list_flatten_messages_as_text():
|
616 |
+
messages = [
|
617 |
+
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}]),
|
618 |
+
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "How are you?"}]),
|
619 |
+
]
|
620 |
+
result = get_clean_message_list(messages, flatten_messages_as_text=True)
|
621 |
+
assert len(result) == 1
|
622 |
+
assert result[0]["role"] == "user"
|
623 |
+
assert result[0]["content"] == "Hello!\nHow are you?"
|
624 |
+
|
625 |
+
|
626 |
+
@pytest.mark.parametrize(
|
627 |
+
"model_class, model_kwargs, patching, expected_flatten_messages_as_text",
|
628 |
+
[
|
629 |
+
(AzureOpenAIServerModel, {}, ("openai.AzureOpenAI", {}), False),
|
630 |
+
(InferenceClientModel, {}, ("huggingface_hub.InferenceClient", {}), False),
|
631 |
+
(LiteLLMModel, {}, None, False),
|
632 |
+
(LiteLLMModel, {"model_id": "ollama"}, None, True),
|
633 |
+
(LiteLLMModel, {"model_id": "groq"}, None, True),
|
634 |
+
(LiteLLMModel, {"model_id": "cerebras"}, None, True),
|
635 |
+
(MLXModel, {}, ("mlx_lm.load", {"return_value": (MagicMock(), MagicMock())}), True),
|
636 |
+
(OpenAIServerModel, {}, ("openai.OpenAI", {}), False),
|
637 |
+
(OpenAIServerModel, {"flatten_messages_as_text": True}, ("openai.OpenAI", {}), True),
|
638 |
+
(
|
639 |
+
TransformersModel,
|
640 |
+
{},
|
641 |
+
[
|
642 |
+
(
|
643 |
+
"transformers.AutoModelForImageTextToText.from_pretrained",
|
644 |
+
{"side_effect": ValueError("Unrecognized configuration class")},
|
645 |
+
),
|
646 |
+
("transformers.AutoModelForCausalLM.from_pretrained", {}),
|
647 |
+
("transformers.AutoTokenizer.from_pretrained", {}),
|
648 |
+
],
|
649 |
+
True,
|
650 |
+
),
|
651 |
+
(
|
652 |
+
TransformersModel,
|
653 |
+
{},
|
654 |
+
[
|
655 |
+
("transformers.AutoModelForImageTextToText.from_pretrained", {}),
|
656 |
+
("transformers.AutoProcessor.from_pretrained", {}),
|
657 |
+
],
|
658 |
+
False,
|
659 |
+
),
|
660 |
+
],
|
661 |
+
)
|
662 |
+
def test_flatten_messages_as_text_for_all_models(
|
663 |
+
model_class, model_kwargs, patching, expected_flatten_messages_as_text
|
664 |
+
):
|
665 |
+
with ExitStack() as stack:
|
666 |
+
if isinstance(patching, list):
|
667 |
+
for target, kwargs in patching:
|
668 |
+
stack.enter_context(patch(target, **kwargs))
|
669 |
+
elif patching:
|
670 |
+
target, kwargs = patching
|
671 |
+
stack.enter_context(patch(target, **kwargs))
|
672 |
+
|
673 |
+
model = model_class(**{"model_id": "test-model", **model_kwargs})
|
674 |
+
assert model.flatten_messages_as_text is expected_flatten_messages_as_text, f"{model_class.__name__} failed"
|
675 |
+
|
676 |
+
|
677 |
+
@pytest.mark.parametrize(
|
678 |
+
"model_id,expected",
|
679 |
+
[
|
680 |
+
# Unsupported base models
|
681 |
+
("o3", False),
|
682 |
+
("o4-mini", False),
|
683 |
+
# Unsupported versioned models
|
684 |
+
("o3-2025-04-16", False),
|
685 |
+
("o4-mini-2025-04-16", False),
|
686 |
+
# Unsupported models with path prefixes
|
687 |
+
("openai/o3", False),
|
688 |
+
("openai/o4-mini", False),
|
689 |
+
("openai/o3-2025-04-16", False),
|
690 |
+
("openai/o4-mini-2025-04-16", False),
|
691 |
+
# Supported models
|
692 |
+
("o3-mini", True), # Different from o3
|
693 |
+
("o3-mini-2025-01-31", True), # Different from o3
|
694 |
+
("o4", True), # Different from o4-mini
|
695 |
+
("o4-turbo", True), # Different from o4-mini
|
696 |
+
("gpt-4", True),
|
697 |
+
("claude-3-5-sonnet", True),
|
698 |
+
("mistral-large", True),
|
699 |
+
# Supported models with path prefixes
|
700 |
+
("openai/gpt-4", True),
|
701 |
+
("anthropic/claude-3-5-sonnet", True),
|
702 |
+
("mistralai/mistral-large", True),
|
703 |
+
# Edge cases
|
704 |
+
("", True), # Empty string doesn't match pattern
|
705 |
+
("o3x", True), # Not exactly o3
|
706 |
+
("o3_mini", True), # Not o3-mini format
|
707 |
+
("prefix-o3", True), # o3 not at start
|
708 |
+
],
|
709 |
+
)
|
710 |
+
def test_supports_stop_parameter(model_id, expected):
|
711 |
+
"""Test the supports_stop_parameter function with various model IDs"""
|
712 |
+
assert supports_stop_parameter(model_id) == expected, f"Failed for model_id: {model_id}"
|
713 |
+
|
714 |
+
|
715 |
+
class TestGetToolCallFromText:
|
716 |
+
@pytest.fixture(autouse=True)
|
717 |
+
def mock_uuid4(self):
|
718 |
+
with patch("uuid.uuid4", return_value="test-uuid"):
|
719 |
+
yield
|
720 |
+
|
721 |
+
def test_get_tool_call_from_text_basic(self):
|
722 |
+
text = '{"name": "weather_tool", "arguments": "New York"}'
|
723 |
+
result = get_tool_call_from_text(text, "name", "arguments")
|
724 |
+
assert isinstance(result, ChatMessageToolCall)
|
725 |
+
assert result.id == "test-uuid"
|
726 |
+
assert result.type == "function"
|
727 |
+
assert result.function.name == "weather_tool"
|
728 |
+
assert result.function.arguments == "New York"
|
729 |
+
|
730 |
+
def test_get_tool_call_from_text_name_key_missing(self):
|
731 |
+
text = '{"action": "weather_tool", "arguments": "New York"}'
|
732 |
+
with pytest.raises(ValueError) as exc_info:
|
733 |
+
get_tool_call_from_text(text, "name", "arguments")
|
734 |
+
error_msg = str(exc_info.value)
|
735 |
+
assert "Key tool_name_key='name' not found" in error_msg
|
736 |
+
assert "'action', 'arguments'" in error_msg
|
737 |
+
|
738 |
+
def test_get_tool_call_from_text_json_object_args(self):
|
739 |
+
text = '{"name": "weather_tool", "arguments": {"city": "New York"}}'
|
740 |
+
result = get_tool_call_from_text(text, "name", "arguments")
|
741 |
+
assert result.function.arguments == {"city": "New York"}
|
742 |
+
|
743 |
+
def test_get_tool_call_from_text_json_string_args(self):
|
744 |
+
text = '{"name": "weather_tool", "arguments": "{\\"city\\": \\"New York\\"}"}'
|
745 |
+
result = get_tool_call_from_text(text, "name", "arguments")
|
746 |
+
assert result.function.arguments == {"city": "New York"}
|
747 |
+
|
748 |
+
def test_get_tool_call_from_text_missing_args(self):
|
749 |
+
text = '{"name": "weather_tool"}'
|
750 |
+
result = get_tool_call_from_text(text, "name", "arguments")
|
751 |
+
assert result.function.arguments is None
|
752 |
+
|
753 |
+
def test_get_tool_call_from_text_custom_keys(self):
|
754 |
+
text = '{"tool": "weather_tool", "params": "New York"}'
|
755 |
+
result = get_tool_call_from_text(text, "tool", "params")
|
756 |
+
assert result.function.name == "weather_tool"
|
757 |
+
assert result.function.arguments == "New York"
|
758 |
+
|
759 |
+
def test_get_tool_call_from_text_numeric_args(self):
|
760 |
+
text = '{"name": "calculator", "arguments": 42}'
|
761 |
+
result = get_tool_call_from_text(text, "name", "arguments")
|
762 |
+
assert result.function.name == "calculator"
|
763 |
+
assert result.function.arguments == 42
|
tests/test_monitoring.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import unittest
|
17 |
+
|
18 |
+
import PIL.Image
|
19 |
+
import pytest
|
20 |
+
|
21 |
+
from smolagents import (
|
22 |
+
CodeAgent,
|
23 |
+
RunResult,
|
24 |
+
ToolCallingAgent,
|
25 |
+
stream_to_gradio,
|
26 |
+
)
|
27 |
+
from smolagents.models import (
|
28 |
+
ChatMessage,
|
29 |
+
ChatMessageToolCall,
|
30 |
+
ChatMessageToolCallFunction,
|
31 |
+
MessageRole,
|
32 |
+
Model,
|
33 |
+
TokenUsage,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
class FakeLLMModel(Model):
|
38 |
+
def __init__(self, give_token_usage: bool = True):
|
39 |
+
self.give_token_usage = give_token_usage
|
40 |
+
|
41 |
+
def generate(self, prompt, tools_to_call_from=None, **kwargs):
|
42 |
+
if tools_to_call_from is not None:
|
43 |
+
return ChatMessage(
|
44 |
+
role=MessageRole.ASSISTANT,
|
45 |
+
content="",
|
46 |
+
tool_calls=[
|
47 |
+
ChatMessageToolCall(
|
48 |
+
id="fake_id",
|
49 |
+
type="function",
|
50 |
+
function=ChatMessageToolCallFunction(name="final_answer", arguments={"answer": "image"}),
|
51 |
+
)
|
52 |
+
],
|
53 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
return ChatMessage(
|
57 |
+
role=MessageRole.ASSISTANT,
|
58 |
+
content="""<code>
|
59 |
+
final_answer('This is the final answer.')
|
60 |
+
</code>""",
|
61 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
class MonitoringTester(unittest.TestCase):
|
66 |
+
def test_code_agent_metrics(self):
|
67 |
+
agent = CodeAgent(
|
68 |
+
tools=[],
|
69 |
+
model=FakeLLMModel(),
|
70 |
+
max_steps=1,
|
71 |
+
)
|
72 |
+
agent.run("Fake task")
|
73 |
+
|
74 |
+
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
75 |
+
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
76 |
+
|
77 |
+
def test_toolcalling_agent_metrics(self):
|
78 |
+
agent = ToolCallingAgent(
|
79 |
+
tools=[],
|
80 |
+
model=FakeLLMModel(),
|
81 |
+
max_steps=1,
|
82 |
+
)
|
83 |
+
|
84 |
+
agent.run("Fake task")
|
85 |
+
|
86 |
+
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
87 |
+
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
88 |
+
|
89 |
+
def test_code_agent_metrics_max_steps(self):
|
90 |
+
class FakeLLMModelMalformedAnswer(Model):
|
91 |
+
def generate(self, prompt, **kwargs):
|
92 |
+
return ChatMessage(
|
93 |
+
role=MessageRole.ASSISTANT,
|
94 |
+
content="Malformed answer",
|
95 |
+
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
|
96 |
+
)
|
97 |
+
|
98 |
+
agent = CodeAgent(
|
99 |
+
tools=[],
|
100 |
+
model=FakeLLMModelMalformedAnswer(),
|
101 |
+
max_steps=1,
|
102 |
+
)
|
103 |
+
|
104 |
+
agent.run("Fake task")
|
105 |
+
|
106 |
+
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
107 |
+
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
108 |
+
|
109 |
+
def test_code_agent_metrics_generation_error(self):
|
110 |
+
class FakeLLMModelGenerationException(Model):
|
111 |
+
def generate(self, prompt, **kwargs):
|
112 |
+
raise Exception("Cannot generate")
|
113 |
+
|
114 |
+
agent = CodeAgent(
|
115 |
+
tools=[],
|
116 |
+
model=FakeLLMModelGenerationException(),
|
117 |
+
max_steps=1,
|
118 |
+
)
|
119 |
+
with pytest.raises(Exception) as e:
|
120 |
+
agent.run("Fake task")
|
121 |
+
assert "Cannot generate" in str(e.value)
|
122 |
+
|
123 |
+
def test_streaming_agent_text_output(self):
|
124 |
+
agent = CodeAgent(
|
125 |
+
tools=[],
|
126 |
+
model=FakeLLMModel(),
|
127 |
+
max_steps=1,
|
128 |
+
planning_interval=2,
|
129 |
+
)
|
130 |
+
|
131 |
+
# Use stream_to_gradio to capture the output
|
132 |
+
outputs = list(stream_to_gradio(agent, task="Test task"))
|
133 |
+
|
134 |
+
self.assertEqual(len(outputs), 11)
|
135 |
+
plan_message = outputs[1]
|
136 |
+
self.assertEqual(plan_message.role, "assistant")
|
137 |
+
self.assertIn("<code>", plan_message.content)
|
138 |
+
final_message = outputs[-1]
|
139 |
+
self.assertEqual(final_message.role, "assistant")
|
140 |
+
self.assertIn("This is the final answer.", final_message.content)
|
141 |
+
|
142 |
+
def test_streaming_agent_image_output(self):
|
143 |
+
agent = ToolCallingAgent(
|
144 |
+
tools=[],
|
145 |
+
model=FakeLLMModel(),
|
146 |
+
max_steps=1,
|
147 |
+
verbosity_level=100,
|
148 |
+
)
|
149 |
+
|
150 |
+
# Use stream_to_gradio to capture the output
|
151 |
+
outputs = list(
|
152 |
+
stream_to_gradio(
|
153 |
+
agent,
|
154 |
+
task="Test task",
|
155 |
+
additional_args=dict(image=PIL.Image.new("RGB", (100, 100))),
|
156 |
+
)
|
157 |
+
)
|
158 |
+
|
159 |
+
self.assertEqual(len(outputs), 7)
|
160 |
+
final_message = outputs[-1]
|
161 |
+
self.assertEqual(final_message.role, "assistant")
|
162 |
+
self.assertIsInstance(final_message.content, dict)
|
163 |
+
self.assertEqual(final_message.content["mime_type"], "image/png")
|
164 |
+
|
165 |
+
def test_streaming_with_agent_error(self):
|
166 |
+
class DummyModel(Model):
|
167 |
+
def generate(self, prompt, **kwargs):
|
168 |
+
return ChatMessage(role=MessageRole.ASSISTANT, content="Malformed call")
|
169 |
+
|
170 |
+
agent = CodeAgent(
|
171 |
+
tools=[],
|
172 |
+
model=DummyModel(),
|
173 |
+
max_steps=1,
|
174 |
+
)
|
175 |
+
|
176 |
+
# Use stream_to_gradio to capture the output
|
177 |
+
outputs = list(stream_to_gradio(agent, task="Test task"))
|
178 |
+
|
179 |
+
self.assertEqual(len(outputs), 11)
|
180 |
+
final_message = outputs[-1]
|
181 |
+
self.assertEqual(final_message.role, "assistant")
|
182 |
+
self.assertIn("Malformed call", final_message.content)
|
183 |
+
|
184 |
+
def test_run_return_full_result(self):
|
185 |
+
agent = CodeAgent(
|
186 |
+
tools=[],
|
187 |
+
model=FakeLLMModel(),
|
188 |
+
max_steps=1,
|
189 |
+
return_full_result=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
result = agent.run("Fake task")
|
193 |
+
|
194 |
+
self.assertIsInstance(result, RunResult)
|
195 |
+
self.assertEqual(result.output, "This is the final answer.")
|
196 |
+
self.assertEqual(result.state, "success")
|
197 |
+
self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
|
198 |
+
self.assertIsInstance(result.messages, list)
|
199 |
+
self.assertGreater(result.timing.duration, 0)
|
200 |
+
|
201 |
+
agent = ToolCallingAgent(
|
202 |
+
tools=[],
|
203 |
+
model=FakeLLMModel(),
|
204 |
+
max_steps=1,
|
205 |
+
return_full_result=True,
|
206 |
+
)
|
207 |
+
|
208 |
+
result = agent.run("Fake task")
|
209 |
+
|
210 |
+
self.assertIsInstance(result, RunResult)
|
211 |
+
self.assertEqual(result.output, "image")
|
212 |
+
self.assertEqual(result.state, "success")
|
213 |
+
self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
|
214 |
+
self.assertIsInstance(result.messages, list)
|
215 |
+
self.assertGreater(result.timing.duration, 0)
|
216 |
+
|
217 |
+
# Below 2 lines should be removed when the attributes are removed
|
218 |
+
assert agent.monitor.total_input_token_count == 10
|
219 |
+
assert agent.monitor.total_output_token_count == 20
|
220 |
+
|
221 |
+
def test_run_result_no_token_usage(self):
|
222 |
+
agent = CodeAgent(
|
223 |
+
tools=[],
|
224 |
+
model=FakeLLMModel(give_token_usage=False),
|
225 |
+
max_steps=1,
|
226 |
+
return_full_result=True,
|
227 |
+
)
|
228 |
+
|
229 |
+
result = agent.run("Fake task")
|
230 |
+
|
231 |
+
self.assertIsInstance(result, RunResult)
|
232 |
+
self.assertEqual(result.output, "This is the final answer.")
|
233 |
+
self.assertEqual(result.state, "success")
|
234 |
+
self.assertIsNone(result.token_usage)
|
235 |
+
self.assertIsInstance(result.messages, list)
|
236 |
+
self.assertGreater(result.timing.duration, 0)
|
237 |
+
|
238 |
+
agent = ToolCallingAgent(
|
239 |
+
tools=[],
|
240 |
+
model=FakeLLMModel(give_token_usage=False),
|
241 |
+
max_steps=1,
|
242 |
+
return_full_result=True,
|
243 |
+
)
|
244 |
+
|
245 |
+
result = agent.run("Fake task")
|
246 |
+
|
247 |
+
self.assertIsInstance(result, RunResult)
|
248 |
+
self.assertEqual(result.output, "image")
|
249 |
+
self.assertEqual(result.state, "success")
|
250 |
+
self.assertIsNone(result.token_usage)
|
251 |
+
self.assertIsInstance(result.messages, list)
|
252 |
+
self.assertGreater(result.timing.duration, 0)
|
tests/test_remote_executors.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from textwrap import dedent
|
3 |
+
from unittest.mock import MagicMock, patch
|
4 |
+
|
5 |
+
import docker
|
6 |
+
import PIL.Image
|
7 |
+
import pytest
|
8 |
+
from rich.console import Console
|
9 |
+
|
10 |
+
from smolagents.default_tools import FinalAnswerTool, WikipediaSearchTool
|
11 |
+
from smolagents.monitoring import AgentLogger, LogLevel
|
12 |
+
from smolagents.remote_executors import DockerExecutor, E2BExecutor, RemotePythonExecutor
|
13 |
+
from smolagents.utils import AgentError
|
14 |
+
|
15 |
+
from .utils.markers import require_run_all
|
16 |
+
|
17 |
+
|
18 |
+
class TestRemotePythonExecutor:
|
19 |
+
def test_send_tools_empty_tools(self):
|
20 |
+
executor = RemotePythonExecutor(additional_imports=[], logger=MagicMock())
|
21 |
+
executor.run_code_raise_errors = MagicMock()
|
22 |
+
executor.send_tools({})
|
23 |
+
assert executor.run_code_raise_errors.call_count == 1
|
24 |
+
# No new packages should be installed
|
25 |
+
assert "!pip install" not in executor.run_code_raise_errors.call_args.args[0]
|
26 |
+
|
27 |
+
@require_run_all
|
28 |
+
def test_send_tools_with_default_wikipedia_search_tool(self):
|
29 |
+
tool = WikipediaSearchTool()
|
30 |
+
executor = RemotePythonExecutor(additional_imports=[], logger=MagicMock())
|
31 |
+
executor.run_code_raise_errors = MagicMock()
|
32 |
+
executor.run_code_raise_errors.return_value = (None, "", False)
|
33 |
+
executor.send_tools({"wikipedia_search": tool})
|
34 |
+
assert executor.run_code_raise_errors.call_count == 2
|
35 |
+
assert "!pip install wikipedia-api" == executor.run_code_raise_errors.call_args_list[0].args[0]
|
36 |
+
assert "class WikipediaSearchTool(Tool)" in executor.run_code_raise_errors.call_args_list[1].args[0]
|
37 |
+
|
38 |
+
|
39 |
+
class TestE2BExecutorUnit:
|
40 |
+
def test_e2b_executor_instantiation(self):
|
41 |
+
logger = MagicMock()
|
42 |
+
with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
|
43 |
+
mock_sandbox.return_value.commands.run.return_value.error = None
|
44 |
+
mock_sandbox.return_value.run_code.return_value.error = None
|
45 |
+
executor = E2BExecutor(
|
46 |
+
additional_imports=[], logger=logger, api_key="dummy-api-key", template="dummy-template-id", timeout=60
|
47 |
+
)
|
48 |
+
assert isinstance(executor, E2BExecutor)
|
49 |
+
assert executor.logger == logger
|
50 |
+
assert executor.sandbox == mock_sandbox.return_value
|
51 |
+
assert mock_sandbox.call_count == 1
|
52 |
+
assert mock_sandbox.call_args.kwargs == {
|
53 |
+
"api_key": "dummy-api-key",
|
54 |
+
"template": "dummy-template-id",
|
55 |
+
"timeout": 60,
|
56 |
+
}
|
57 |
+
|
58 |
+
def test_cleanup(self):
|
59 |
+
"""Test that the cleanup method properly shuts down the sandbox"""
|
60 |
+
logger = MagicMock()
|
61 |
+
with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
|
62 |
+
# Setup mock
|
63 |
+
mock_sandbox.return_value.kill = MagicMock()
|
64 |
+
|
65 |
+
# Create executor
|
66 |
+
executor = E2BExecutor(additional_imports=[], logger=logger, api_key="dummy-api-key")
|
67 |
+
|
68 |
+
# Call cleanup
|
69 |
+
executor.cleanup()
|
70 |
+
|
71 |
+
# Verify sandbox was killed
|
72 |
+
mock_sandbox.return_value.kill.assert_called_once()
|
73 |
+
assert logger.log.call_count >= 2 # Should log start and completion messages
|
74 |
+
|
75 |
+
|
76 |
+
@pytest.fixture
|
77 |
+
def e2b_executor():
|
78 |
+
executor = E2BExecutor(
|
79 |
+
additional_imports=["pillow", "numpy"],
|
80 |
+
logger=AgentLogger(LogLevel.INFO, Console(force_terminal=False, file=io.StringIO())),
|
81 |
+
)
|
82 |
+
yield executor
|
83 |
+
executor.cleanup()
|
84 |
+
|
85 |
+
|
86 |
+
@require_run_all
|
87 |
+
class TestE2BExecutorIntegration:
|
88 |
+
@pytest.fixture(autouse=True)
|
89 |
+
def set_executor(self, e2b_executor):
|
90 |
+
self.executor = e2b_executor
|
91 |
+
|
92 |
+
@pytest.mark.parametrize(
|
93 |
+
"code_action, expected_result",
|
94 |
+
[
|
95 |
+
(
|
96 |
+
dedent('''
|
97 |
+
final_answer("""This is
|
98 |
+
a multiline
|
99 |
+
final answer""")
|
100 |
+
'''),
|
101 |
+
"This is\na multiline\nfinal answer",
|
102 |
+
),
|
103 |
+
(
|
104 |
+
dedent("""
|
105 |
+
text = '''Text containing
|
106 |
+
final_answer(5)
|
107 |
+
'''
|
108 |
+
final_answer(text)
|
109 |
+
"""),
|
110 |
+
"Text containing\nfinal_answer(5)\n",
|
111 |
+
),
|
112 |
+
(
|
113 |
+
dedent("""
|
114 |
+
num = 2
|
115 |
+
if num == 1:
|
116 |
+
final_answer("One")
|
117 |
+
elif num == 2:
|
118 |
+
final_answer("Two")
|
119 |
+
"""),
|
120 |
+
"Two",
|
121 |
+
),
|
122 |
+
],
|
123 |
+
)
|
124 |
+
def test_final_answer_patterns(self, code_action, expected_result):
|
125 |
+
self.executor.send_tools({"final_answer": FinalAnswerTool()})
|
126 |
+
result, logs, final_answer = self.executor(code_action)
|
127 |
+
assert final_answer is True
|
128 |
+
assert result == expected_result
|
129 |
+
|
130 |
+
def test_custom_final_answer(self):
|
131 |
+
class CustomFinalAnswerTool(FinalAnswerTool):
|
132 |
+
def forward(self, answer: str) -> str:
|
133 |
+
return "CUSTOM" + answer
|
134 |
+
|
135 |
+
self.executor.send_tools({"final_answer": CustomFinalAnswerTool()})
|
136 |
+
code_action = dedent("""
|
137 |
+
final_answer(answer="_answer")
|
138 |
+
""")
|
139 |
+
result, logs, final_answer = self.executor(code_action)
|
140 |
+
assert final_answer is True
|
141 |
+
assert result == "CUSTOM_answer"
|
142 |
+
|
143 |
+
def test_custom_final_answer_with_custom_inputs(self):
|
144 |
+
class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
|
145 |
+
inputs = {
|
146 |
+
"answer1": {"type": "string", "description": "First part of the answer."},
|
147 |
+
"answer2": {"type": "string", "description": "Second part of the answer."},
|
148 |
+
}
|
149 |
+
|
150 |
+
def forward(self, answer1: str, answer2: str) -> str:
|
151 |
+
return answer1 + "CUSTOM" + answer2
|
152 |
+
|
153 |
+
self.executor.send_tools({"final_answer": CustomFinalAnswerToolWithCustomInputs()})
|
154 |
+
code_action = dedent("""
|
155 |
+
final_answer(
|
156 |
+
answer1="answer1_",
|
157 |
+
answer2="_answer2"
|
158 |
+
)
|
159 |
+
""")
|
160 |
+
result, logs, final_answer = self.executor(code_action)
|
161 |
+
assert final_answer is True
|
162 |
+
assert result == "answer1_CUSTOM_answer2"
|
163 |
+
|
164 |
+
|
165 |
+
@pytest.fixture
|
166 |
+
def docker_executor():
|
167 |
+
executor = DockerExecutor(
|
168 |
+
additional_imports=["pillow", "numpy"],
|
169 |
+
logger=AgentLogger(LogLevel.INFO, Console(force_terminal=False, file=io.StringIO())),
|
170 |
+
)
|
171 |
+
yield executor
|
172 |
+
executor.delete()
|
173 |
+
|
174 |
+
|
175 |
+
@require_run_all
|
176 |
+
class TestDockerExecutorIntegration:
|
177 |
+
@pytest.fixture(autouse=True)
|
178 |
+
def set_executor(self, docker_executor):
|
179 |
+
self.executor = docker_executor
|
180 |
+
|
181 |
+
def test_initialization(self):
|
182 |
+
"""Check if DockerExecutor initializes without errors"""
|
183 |
+
assert self.executor.container is not None, "Container should be initialized"
|
184 |
+
|
185 |
+
def test_state_persistence(self):
|
186 |
+
"""Test that variables and imports form one snippet persist in the next"""
|
187 |
+
code_action = "import numpy as np; a = 2"
|
188 |
+
self.executor(code_action)
|
189 |
+
|
190 |
+
code_action = "print(np.sqrt(a))"
|
191 |
+
result, logs, final_answer = self.executor(code_action)
|
192 |
+
assert "1.41421" in logs
|
193 |
+
|
194 |
+
def test_execute_output(self):
|
195 |
+
"""Test execution that returns a string"""
|
196 |
+
code_action = 'final_answer("This is the final answer")'
|
197 |
+
result, logs, final_answer = self.executor(code_action)
|
198 |
+
assert result == "This is the final answer", "Result should be 'This is the final answer'"
|
199 |
+
|
200 |
+
def test_execute_multiline_output(self):
|
201 |
+
"""Test execution that returns a string"""
|
202 |
+
code_action = 'result = "This is the final answer"\nfinal_answer(result)'
|
203 |
+
result, logs, final_answer = self.executor(code_action)
|
204 |
+
assert result == "This is the final answer", "Result should be 'This is the final answer'"
|
205 |
+
|
206 |
+
def test_execute_image_output(self):
|
207 |
+
"""Test execution that returns a base64 image"""
|
208 |
+
code_action = dedent("""
|
209 |
+
import base64
|
210 |
+
from PIL import Image
|
211 |
+
from io import BytesIO
|
212 |
+
image = Image.new("RGB", (10, 10), (255, 0, 0))
|
213 |
+
final_answer(image)
|
214 |
+
""")
|
215 |
+
result, logs, final_answer = self.executor(code_action)
|
216 |
+
assert isinstance(result, PIL.Image.Image), "Result should be a PIL Image"
|
217 |
+
|
218 |
+
def test_syntax_error_handling(self):
|
219 |
+
"""Test handling of syntax errors"""
|
220 |
+
code_action = 'print("Missing Parenthesis' # Syntax error
|
221 |
+
with pytest.raises(AgentError) as exception_info:
|
222 |
+
self.executor(code_action)
|
223 |
+
assert "SyntaxError" in str(exception_info.value), "Should raise a syntax error"
|
224 |
+
|
225 |
+
def test_cleanup_on_deletion(self):
|
226 |
+
"""Test if Docker container stops and removes on deletion"""
|
227 |
+
container_id = self.executor.container.id
|
228 |
+
self.executor.delete() # Trigger cleanup
|
229 |
+
|
230 |
+
client = docker.from_env()
|
231 |
+
containers = [c.id for c in client.containers.list(all=True)]
|
232 |
+
assert container_id not in containers, "Container should be removed"
|
233 |
+
|
234 |
+
@pytest.mark.parametrize(
|
235 |
+
"code_action, expected_result",
|
236 |
+
[
|
237 |
+
(
|
238 |
+
dedent('''
|
239 |
+
final_answer("""This is
|
240 |
+
a multiline
|
241 |
+
final answer""")
|
242 |
+
'''),
|
243 |
+
"This is\na multiline\nfinal answer",
|
244 |
+
),
|
245 |
+
(
|
246 |
+
dedent("""
|
247 |
+
text = '''Text containing
|
248 |
+
final_answer(5)
|
249 |
+
'''
|
250 |
+
final_answer(text)
|
251 |
+
"""),
|
252 |
+
"Text containing\nfinal_answer(5)\n",
|
253 |
+
),
|
254 |
+
(
|
255 |
+
dedent("""
|
256 |
+
num = 2
|
257 |
+
if num == 1:
|
258 |
+
final_answer("One")
|
259 |
+
elif num == 2:
|
260 |
+
final_answer("Two")
|
261 |
+
"""),
|
262 |
+
"Two",
|
263 |
+
),
|
264 |
+
],
|
265 |
+
)
|
266 |
+
def test_final_answer_patterns(self, code_action, expected_result):
|
267 |
+
self.executor.send_tools({"final_answer": FinalAnswerTool()})
|
268 |
+
result, logs, final_answer = self.executor(code_action)
|
269 |
+
assert final_answer is True
|
270 |
+
assert result == expected_result
|
271 |
+
|
272 |
+
def test_custom_final_answer(self):
|
273 |
+
class CustomFinalAnswerTool(FinalAnswerTool):
|
274 |
+
def forward(self, answer: str) -> str:
|
275 |
+
return "CUSTOM" + answer
|
276 |
+
|
277 |
+
self.executor.send_tools({"final_answer": CustomFinalAnswerTool()})
|
278 |
+
code_action = dedent("""
|
279 |
+
final_answer(answer="_answer")
|
280 |
+
""")
|
281 |
+
result, logs, final_answer = self.executor(code_action)
|
282 |
+
assert final_answer is True
|
283 |
+
assert result == "CUSTOM_answer"
|
284 |
+
|
285 |
+
def test_custom_final_answer_with_custom_inputs(self):
|
286 |
+
class CustomFinalAnswerToolWithCustomInputs(FinalAnswerTool):
|
287 |
+
inputs = {
|
288 |
+
"answer1": {"type": "string", "description": "First part of the answer."},
|
289 |
+
"answer2": {"type": "string", "description": "Second part of the answer."},
|
290 |
+
}
|
291 |
+
|
292 |
+
def forward(self, answer1: str, answer2: str) -> str:
|
293 |
+
return answer1 + "CUSTOM" + answer2
|
294 |
+
|
295 |
+
self.executor.send_tools({"final_answer": CustomFinalAnswerToolWithCustomInputs()})
|
296 |
+
code_action = dedent("""
|
297 |
+
final_answer(
|
298 |
+
answer1="answer1_",
|
299 |
+
answer2="_answer2"
|
300 |
+
)
|
301 |
+
""")
|
302 |
+
result, logs, final_answer = self.executor(code_action)
|
303 |
+
assert final_answer is True
|
304 |
+
assert result == "answer1_CUSTOM_answer2"
|
305 |
+
|
306 |
+
|
307 |
+
class TestDockerExecutorUnit:
|
308 |
+
def test_cleanup(self):
|
309 |
+
"""Test that cleanup properly stops and removes the container"""
|
310 |
+
logger = MagicMock()
|
311 |
+
with (
|
312 |
+
patch("docker.from_env") as mock_docker_client,
|
313 |
+
patch("requests.post") as mock_post,
|
314 |
+
patch("websocket.create_connection"),
|
315 |
+
):
|
316 |
+
# Setup mocks
|
317 |
+
mock_container = MagicMock()
|
318 |
+
mock_container.status = "running"
|
319 |
+
mock_container.short_id = "test123"
|
320 |
+
|
321 |
+
mock_docker_client.return_value.containers.run.return_value = mock_container
|
322 |
+
mock_docker_client.return_value.images.get.return_value = MagicMock()
|
323 |
+
|
324 |
+
mock_post.return_value.status_code = 201
|
325 |
+
mock_post.return_value.json.return_value = {"id": "test-kernel-id"}
|
326 |
+
|
327 |
+
# Create executor
|
328 |
+
executor = DockerExecutor(additional_imports=[], logger=logger, build_new_image=False)
|
329 |
+
|
330 |
+
# Call cleanup
|
331 |
+
executor.cleanup()
|
332 |
+
|
333 |
+
# Verify container was stopped and removed
|
334 |
+
mock_container.stop.assert_called_once()
|
335 |
+
mock_container.remove.assert_called_once()
|
tests/test_search.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
from smolagents import DuckDuckGoSearchTool
|
18 |
+
|
19 |
+
from .test_tools import ToolTesterMixin
|
20 |
+
from .utils.markers import require_run_all
|
21 |
+
|
22 |
+
|
23 |
+
class TestDuckDuckGoSearchTool(ToolTesterMixin):
|
24 |
+
def setup_method(self):
|
25 |
+
self.tool = DuckDuckGoSearchTool()
|
26 |
+
self.tool.setup()
|
27 |
+
|
28 |
+
@require_run_all
|
29 |
+
def test_exact_match_arg(self):
|
30 |
+
result = self.tool("Agents")
|
31 |
+
assert isinstance(result, str)
|
32 |
+
|
33 |
+
@require_run_all
|
34 |
+
def test_agent_type_output(self):
|
35 |
+
super().test_agent_type_output()
|
tests/test_tool_validation.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
from textwrap import dedent
|
3 |
+
|
4 |
+
import pytest
|
5 |
+
|
6 |
+
from smolagents.default_tools import (
|
7 |
+
DuckDuckGoSearchTool,
|
8 |
+
GoogleSearchTool,
|
9 |
+
SpeechToTextTool,
|
10 |
+
VisitWebpageTool,
|
11 |
+
WebSearchTool,
|
12 |
+
)
|
13 |
+
from smolagents.tool_validation import MethodChecker, validate_tool_attributes
|
14 |
+
from smolagents.tools import Tool, tool
|
15 |
+
|
16 |
+
|
17 |
+
UNDEFINED_VARIABLE = "undefined_variable"
|
18 |
+
|
19 |
+
|
20 |
+
@pytest.mark.parametrize(
|
21 |
+
"tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, SpeechToTextTool, VisitWebpageTool, WebSearchTool]
|
22 |
+
)
|
23 |
+
def test_validate_tool_attributes_with_default_tools(tool_class):
|
24 |
+
assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool"
|
25 |
+
|
26 |
+
|
27 |
+
class ValidTool(Tool):
|
28 |
+
name = "valid_tool"
|
29 |
+
description = "A valid tool"
|
30 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
31 |
+
output_type = "string"
|
32 |
+
simple_attr = "string"
|
33 |
+
dict_attr = {"key": "value"}
|
34 |
+
|
35 |
+
def __init__(self, optional_param="default"):
|
36 |
+
super().__init__()
|
37 |
+
self.param = optional_param
|
38 |
+
|
39 |
+
def forward(self, input: str) -> str:
|
40 |
+
return input.upper()
|
41 |
+
|
42 |
+
|
43 |
+
@tool
|
44 |
+
def valid_tool_function(input: str) -> str:
|
45 |
+
"""A valid tool function.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
input (str): Input string.
|
49 |
+
"""
|
50 |
+
return input.upper()
|
51 |
+
|
52 |
+
|
53 |
+
@pytest.mark.parametrize("tool_class", [ValidTool, valid_tool_function.__class__])
|
54 |
+
def test_validate_tool_attributes_valid(tool_class):
|
55 |
+
assert validate_tool_attributes(tool_class) is None
|
56 |
+
|
57 |
+
|
58 |
+
class InvalidToolName(Tool):
|
59 |
+
name = "invalid tool name"
|
60 |
+
description = "Tool with invalid name"
|
61 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
62 |
+
output_type = "string"
|
63 |
+
|
64 |
+
def __init__(self):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
def forward(self, input: str) -> str:
|
68 |
+
return input
|
69 |
+
|
70 |
+
|
71 |
+
class InvalidToolComplexAttrs(Tool):
|
72 |
+
name = "invalid_tool"
|
73 |
+
description = "Tool with complex class attributes"
|
74 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
75 |
+
output_type = "string"
|
76 |
+
complex_attr = [x for x in range(3)] # Complex class attribute
|
77 |
+
|
78 |
+
def __init__(self):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
def forward(self, input: str) -> str:
|
82 |
+
return input
|
83 |
+
|
84 |
+
|
85 |
+
class InvalidToolRequiredParams(Tool):
|
86 |
+
name = "invalid_tool"
|
87 |
+
description = "Tool with required params"
|
88 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
89 |
+
output_type = "string"
|
90 |
+
|
91 |
+
def __init__(self, required_param, kwarg1=1): # No default value
|
92 |
+
super().__init__()
|
93 |
+
self.param = required_param
|
94 |
+
|
95 |
+
def forward(self, input: str) -> str:
|
96 |
+
return input
|
97 |
+
|
98 |
+
|
99 |
+
class InvalidToolNonLiteralDefaultParam(Tool):
|
100 |
+
name = "invalid_tool"
|
101 |
+
description = "Tool with non-literal default parameter value"
|
102 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
103 |
+
output_type = "string"
|
104 |
+
|
105 |
+
def __init__(self, default_param=UNDEFINED_VARIABLE): # UNDEFINED_VARIABLE as default is non-literal
|
106 |
+
super().__init__()
|
107 |
+
self.default_param = default_param
|
108 |
+
|
109 |
+
def forward(self, input: str) -> str:
|
110 |
+
return input
|
111 |
+
|
112 |
+
|
113 |
+
class InvalidToolUndefinedNames(Tool):
|
114 |
+
name = "invalid_tool"
|
115 |
+
description = "Tool with undefined names"
|
116 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
117 |
+
output_type = "string"
|
118 |
+
|
119 |
+
def forward(self, input: str) -> str:
|
120 |
+
return UNDEFINED_VARIABLE # Undefined name
|
121 |
+
|
122 |
+
|
123 |
+
@pytest.mark.parametrize(
|
124 |
+
"tool_class, expected_error",
|
125 |
+
[
|
126 |
+
(
|
127 |
+
InvalidToolName,
|
128 |
+
"Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found 'invalid tool name'",
|
129 |
+
),
|
130 |
+
(InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"),
|
131 |
+
(InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"),
|
132 |
+
(
|
133 |
+
InvalidToolNonLiteralDefaultParam,
|
134 |
+
"Parameters in __init__ must have literal default values, found non-literal defaults",
|
135 |
+
),
|
136 |
+
(InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"),
|
137 |
+
],
|
138 |
+
)
|
139 |
+
def test_validate_tool_attributes_exceptions(tool_class, expected_error):
|
140 |
+
with pytest.raises(ValueError, match=expected_error):
|
141 |
+
validate_tool_attributes(tool_class)
|
142 |
+
|
143 |
+
|
144 |
+
class MultipleAssignmentsTool(Tool):
|
145 |
+
name = "multiple_assignments_tool"
|
146 |
+
description = "Tool with multiple assignments"
|
147 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
148 |
+
output_type = "string"
|
149 |
+
|
150 |
+
def __init__(self):
|
151 |
+
super().__init__()
|
152 |
+
|
153 |
+
def forward(self, input: str) -> str:
|
154 |
+
a, b = "1", "2"
|
155 |
+
return a + b
|
156 |
+
|
157 |
+
|
158 |
+
def test_validate_tool_attributes_multiple_assignments():
|
159 |
+
validate_tool_attributes(MultipleAssignmentsTool)
|
160 |
+
|
161 |
+
|
162 |
+
@tool
|
163 |
+
def tool_function_with_multiple_assignments(input: str) -> str:
|
164 |
+
"""A valid tool function.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
input (str): Input string.
|
168 |
+
"""
|
169 |
+
a, b = "1", "2"
|
170 |
+
return input.upper() + a + b
|
171 |
+
|
172 |
+
|
173 |
+
@pytest.mark.parametrize("tool_instance", [MultipleAssignmentsTool(), tool_function_with_multiple_assignments])
|
174 |
+
def test_tool_to_dict_validation_with_multiple_assignments(tool_instance):
|
175 |
+
tool_instance.to_dict()
|
176 |
+
|
177 |
+
|
178 |
+
class TestMethodChecker:
|
179 |
+
def test_multiple_assignments(self):
|
180 |
+
source_code = dedent(
|
181 |
+
"""
|
182 |
+
def forward(self) -> str:
|
183 |
+
a, b = "1", "2"
|
184 |
+
return a + b
|
185 |
+
"""
|
186 |
+
)
|
187 |
+
method_checker = MethodChecker(set())
|
188 |
+
method_checker.visit(ast.parse(source_code))
|
189 |
+
assert method_checker.errors == []
|
tests/test_tools.py
ADDED
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import inspect
|
16 |
+
import os
|
17 |
+
from textwrap import dedent
|
18 |
+
from typing import Any, Literal
|
19 |
+
from unittest.mock import MagicMock, patch
|
20 |
+
|
21 |
+
import mcp
|
22 |
+
import numpy as np
|
23 |
+
import PIL.Image
|
24 |
+
import pytest
|
25 |
+
|
26 |
+
from smolagents.agent_types import _AGENT_TYPE_MAPPING
|
27 |
+
from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, launch_gradio_demo, tool, validate_tool_arguments
|
28 |
+
|
29 |
+
from .utils.markers import require_run_all
|
30 |
+
|
31 |
+
|
32 |
+
class ToolTesterMixin:
|
33 |
+
def test_inputs_output(self):
|
34 |
+
assert hasattr(self.tool, "inputs")
|
35 |
+
assert hasattr(self.tool, "output_type")
|
36 |
+
|
37 |
+
inputs = self.tool.inputs
|
38 |
+
assert isinstance(inputs, dict)
|
39 |
+
|
40 |
+
for _, input_spec in inputs.items():
|
41 |
+
assert "type" in input_spec
|
42 |
+
assert "description" in input_spec
|
43 |
+
assert input_spec["type"] in AUTHORIZED_TYPES
|
44 |
+
assert isinstance(input_spec["description"], str)
|
45 |
+
|
46 |
+
output_type = self.tool.output_type
|
47 |
+
assert output_type in AUTHORIZED_TYPES
|
48 |
+
|
49 |
+
def test_common_attributes(self):
|
50 |
+
assert hasattr(self.tool, "description")
|
51 |
+
assert hasattr(self.tool, "name")
|
52 |
+
assert hasattr(self.tool, "inputs")
|
53 |
+
assert hasattr(self.tool, "output_type")
|
54 |
+
|
55 |
+
def test_agent_type_output(self, create_inputs):
|
56 |
+
inputs = create_inputs(self.tool.inputs)
|
57 |
+
output = self.tool(**inputs, sanitize_inputs_outputs=True)
|
58 |
+
if self.tool.output_type != "any":
|
59 |
+
agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
|
60 |
+
assert isinstance(output, agent_type)
|
61 |
+
|
62 |
+
@pytest.fixture
|
63 |
+
def create_inputs(self, shared_datadir):
|
64 |
+
def _create_inputs(tool_inputs: dict[str, dict[str | type, str]]) -> dict[str, Any]:
|
65 |
+
inputs = {}
|
66 |
+
|
67 |
+
for input_name, input_desc in tool_inputs.items():
|
68 |
+
input_type = input_desc["type"]
|
69 |
+
|
70 |
+
if input_type == "string":
|
71 |
+
inputs[input_name] = "Text input"
|
72 |
+
elif input_type == "image":
|
73 |
+
inputs[input_name] = PIL.Image.open(shared_datadir / "000000039769.png").resize((512, 512))
|
74 |
+
elif input_type == "audio":
|
75 |
+
inputs[input_name] = np.ones(3000)
|
76 |
+
else:
|
77 |
+
raise ValueError(f"Invalid type requested: {input_type}")
|
78 |
+
|
79 |
+
return inputs
|
80 |
+
|
81 |
+
return _create_inputs
|
82 |
+
|
83 |
+
|
84 |
+
class TestTool:
|
85 |
+
def test_tool_init_with_decorator(self):
|
86 |
+
@tool
|
87 |
+
def coolfunc(a: str, b: int) -> float:
|
88 |
+
"""Cool function
|
89 |
+
|
90 |
+
Args:
|
91 |
+
a: The first argument
|
92 |
+
b: The second one
|
93 |
+
"""
|
94 |
+
return b + 2, a
|
95 |
+
|
96 |
+
assert coolfunc.output_type == "number"
|
97 |
+
|
98 |
+
def test_tool_init_vanilla(self):
|
99 |
+
class HFModelDownloadsTool(Tool):
|
100 |
+
name = "model_download_counter"
|
101 |
+
description = """
|
102 |
+
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
|
103 |
+
It returns the name of the checkpoint."""
|
104 |
+
|
105 |
+
inputs = {
|
106 |
+
"task": {
|
107 |
+
"type": "string",
|
108 |
+
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
109 |
+
}
|
110 |
+
}
|
111 |
+
output_type = "string"
|
112 |
+
|
113 |
+
def forward(self, task: str) -> str:
|
114 |
+
return "best model"
|
115 |
+
|
116 |
+
tool = HFModelDownloadsTool()
|
117 |
+
assert list(tool.inputs.keys())[0] == "task"
|
118 |
+
|
119 |
+
def test_tool_init_decorator_raises_issues(self):
|
120 |
+
with pytest.raises(Exception) as e:
|
121 |
+
|
122 |
+
@tool
|
123 |
+
def coolfunc(a: str, b: int):
|
124 |
+
"""Cool function
|
125 |
+
|
126 |
+
Args:
|
127 |
+
a: The first argument
|
128 |
+
b: The second one
|
129 |
+
"""
|
130 |
+
return a + b
|
131 |
+
|
132 |
+
assert coolfunc.output_type == "number"
|
133 |
+
assert "Tool return type not found" in str(e)
|
134 |
+
|
135 |
+
with pytest.raises(Exception) as e:
|
136 |
+
|
137 |
+
@tool
|
138 |
+
def coolfunc(a: str, b: int) -> int:
|
139 |
+
"""Cool function
|
140 |
+
|
141 |
+
Args:
|
142 |
+
a: The first argument
|
143 |
+
"""
|
144 |
+
return b + a
|
145 |
+
|
146 |
+
assert coolfunc.output_type == "number"
|
147 |
+
assert "docstring has no description for the argument" in str(e)
|
148 |
+
|
149 |
+
def test_saving_tool_raises_error_imports_outside_function(self, tmp_path):
|
150 |
+
with pytest.raises(Exception) as e:
|
151 |
+
import numpy as np
|
152 |
+
|
153 |
+
@tool
|
154 |
+
def get_current_time() -> str:
|
155 |
+
"""
|
156 |
+
Gets the current time.
|
157 |
+
"""
|
158 |
+
return str(np.random.random())
|
159 |
+
|
160 |
+
get_current_time.save(tmp_path)
|
161 |
+
|
162 |
+
assert "np" in str(e)
|
163 |
+
|
164 |
+
# Also test with classic definition
|
165 |
+
with pytest.raises(Exception) as e:
|
166 |
+
|
167 |
+
class GetCurrentTimeTool(Tool):
|
168 |
+
name = "get_current_time_tool"
|
169 |
+
description = "Gets the current time"
|
170 |
+
inputs = {}
|
171 |
+
output_type = "string"
|
172 |
+
|
173 |
+
def forward(self):
|
174 |
+
return str(np.random.random())
|
175 |
+
|
176 |
+
get_current_time = GetCurrentTimeTool()
|
177 |
+
get_current_time.save(tmp_path)
|
178 |
+
|
179 |
+
assert "np" in str(e)
|
180 |
+
|
181 |
+
def test_tool_definition_raises_no_error_imports_in_function(self):
|
182 |
+
@tool
|
183 |
+
def get_current_time() -> str:
|
184 |
+
"""
|
185 |
+
Gets the current time.
|
186 |
+
"""
|
187 |
+
from datetime import datetime
|
188 |
+
|
189 |
+
return str(datetime.now())
|
190 |
+
|
191 |
+
class GetCurrentTimeTool(Tool):
|
192 |
+
name = "get_current_time_tool"
|
193 |
+
description = "Gets the current time"
|
194 |
+
inputs = {}
|
195 |
+
output_type = "string"
|
196 |
+
|
197 |
+
def forward(self):
|
198 |
+
from datetime import datetime
|
199 |
+
|
200 |
+
return str(datetime.now())
|
201 |
+
|
202 |
+
def test_tool_to_dict_allows_no_arg_in_init(self):
|
203 |
+
"""Test that a tool cannot be saved with required args in init"""
|
204 |
+
|
205 |
+
class FailTool(Tool):
|
206 |
+
name = "specific"
|
207 |
+
description = "test description"
|
208 |
+
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
209 |
+
output_type = "string"
|
210 |
+
|
211 |
+
def __init__(self, url):
|
212 |
+
super().__init__(self)
|
213 |
+
self.url = url
|
214 |
+
|
215 |
+
def forward(self, string_input: str) -> str:
|
216 |
+
return self.url + string_input
|
217 |
+
|
218 |
+
fail_tool = FailTool("dummy_url")
|
219 |
+
with pytest.raises(Exception) as e:
|
220 |
+
fail_tool.to_dict()
|
221 |
+
assert "Parameters in __init__ must have default values, found required parameters" in str(e)
|
222 |
+
|
223 |
+
class PassTool(Tool):
|
224 |
+
name = "specific"
|
225 |
+
description = "test description"
|
226 |
+
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
227 |
+
output_type = "string"
|
228 |
+
|
229 |
+
def __init__(self, url: str | None = "none"):
|
230 |
+
super().__init__(self)
|
231 |
+
self.url = url
|
232 |
+
|
233 |
+
def forward(self, string_input: str) -> str:
|
234 |
+
return self.url + string_input
|
235 |
+
|
236 |
+
fail_tool = PassTool()
|
237 |
+
fail_tool.to_dict()
|
238 |
+
|
239 |
+
def test_saving_tool_allows_no_imports_from_outside_methods(self, tmp_path):
|
240 |
+
# Test that using imports from outside functions fails
|
241 |
+
import numpy as np
|
242 |
+
|
243 |
+
class FailTool(Tool):
|
244 |
+
name = "specific"
|
245 |
+
description = "test description"
|
246 |
+
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
247 |
+
output_type = "string"
|
248 |
+
|
249 |
+
def useless_method(self):
|
250 |
+
self.client = np.random.random()
|
251 |
+
return ""
|
252 |
+
|
253 |
+
def forward(self, string_input):
|
254 |
+
return self.useless_method() + string_input
|
255 |
+
|
256 |
+
fail_tool = FailTool()
|
257 |
+
with pytest.raises(Exception) as e:
|
258 |
+
fail_tool.save(tmp_path)
|
259 |
+
assert "'np' is undefined" in str(e)
|
260 |
+
|
261 |
+
# Test that putting these imports inside functions works
|
262 |
+
class SuccessTool(Tool):
|
263 |
+
name = "specific"
|
264 |
+
description = "test description"
|
265 |
+
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
266 |
+
output_type = "string"
|
267 |
+
|
268 |
+
def useless_method(self):
|
269 |
+
import numpy as np
|
270 |
+
|
271 |
+
self.client = np.random.random()
|
272 |
+
return ""
|
273 |
+
|
274 |
+
def forward(self, string_input):
|
275 |
+
return self.useless_method() + string_input
|
276 |
+
|
277 |
+
success_tool = SuccessTool()
|
278 |
+
success_tool.save(tmp_path)
|
279 |
+
|
280 |
+
def test_tool_missing_class_attributes_raises_error(self):
|
281 |
+
with pytest.raises(Exception) as e:
|
282 |
+
|
283 |
+
class GetWeatherTool(Tool):
|
284 |
+
name = "get_weather"
|
285 |
+
description = "Get weather in the next days at given location."
|
286 |
+
inputs = {
|
287 |
+
"location": {"type": "string", "description": "the location"},
|
288 |
+
"celsius": {
|
289 |
+
"type": "string",
|
290 |
+
"description": "the temperature type",
|
291 |
+
},
|
292 |
+
}
|
293 |
+
|
294 |
+
def forward(self, location: str, celsius: bool | None = False) -> str:
|
295 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
296 |
+
|
297 |
+
GetWeatherTool()
|
298 |
+
assert "You must set an attribute output_type" in str(e)
|
299 |
+
|
300 |
+
def test_tool_from_decorator_optional_args(self):
|
301 |
+
@tool
|
302 |
+
def get_weather(location: str, celsius: bool | None = False) -> str:
|
303 |
+
"""
|
304 |
+
Get weather in the next days at given location.
|
305 |
+
Secretly this tool does not care about the location, it hates the weather everywhere.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
location: the location
|
309 |
+
celsius: the temperature type
|
310 |
+
"""
|
311 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
312 |
+
|
313 |
+
assert "nullable" in get_weather.inputs["celsius"]
|
314 |
+
assert get_weather.inputs["celsius"]["nullable"]
|
315 |
+
assert "nullable" not in get_weather.inputs["location"]
|
316 |
+
|
317 |
+
def test_tool_mismatching_nullable_args_raises_error(self):
|
318 |
+
with pytest.raises(Exception) as e:
|
319 |
+
|
320 |
+
class GetWeatherTool(Tool):
|
321 |
+
name = "get_weather"
|
322 |
+
description = "Get weather in the next days at given location."
|
323 |
+
inputs = {
|
324 |
+
"location": {"type": "string", "description": "the location"},
|
325 |
+
"celsius": {
|
326 |
+
"type": "string",
|
327 |
+
"description": "the temperature type",
|
328 |
+
},
|
329 |
+
}
|
330 |
+
output_type = "string"
|
331 |
+
|
332 |
+
def forward(self, location: str, celsius: bool | None = False) -> str:
|
333 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
334 |
+
|
335 |
+
GetWeatherTool()
|
336 |
+
assert "Nullable" in str(e)
|
337 |
+
|
338 |
+
with pytest.raises(Exception) as e:
|
339 |
+
|
340 |
+
class GetWeatherTool2(Tool):
|
341 |
+
name = "get_weather"
|
342 |
+
description = "Get weather in the next days at given location."
|
343 |
+
inputs = {
|
344 |
+
"location": {"type": "string", "description": "the location"},
|
345 |
+
"celsius": {
|
346 |
+
"type": "string",
|
347 |
+
"description": "the temperature type",
|
348 |
+
},
|
349 |
+
}
|
350 |
+
output_type = "string"
|
351 |
+
|
352 |
+
def forward(self, location: str, celsius: bool = False) -> str:
|
353 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
354 |
+
|
355 |
+
GetWeatherTool2()
|
356 |
+
assert "Nullable" in str(e)
|
357 |
+
|
358 |
+
with pytest.raises(Exception) as e:
|
359 |
+
|
360 |
+
class GetWeatherTool3(Tool):
|
361 |
+
name = "get_weather"
|
362 |
+
description = "Get weather in the next days at given location."
|
363 |
+
inputs = {
|
364 |
+
"location": {"type": "string", "description": "the location"},
|
365 |
+
"celsius": {
|
366 |
+
"type": "string",
|
367 |
+
"description": "the temperature type",
|
368 |
+
"nullable": True,
|
369 |
+
},
|
370 |
+
}
|
371 |
+
output_type = "string"
|
372 |
+
|
373 |
+
def forward(self, location, celsius: str) -> str:
|
374 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
375 |
+
|
376 |
+
GetWeatherTool3()
|
377 |
+
assert "Nullable" in str(e)
|
378 |
+
|
379 |
+
def test_tool_default_parameters_is_nullable(self):
|
380 |
+
@tool
|
381 |
+
def get_weather(location: str, celsius: bool = False) -> str:
|
382 |
+
"""
|
383 |
+
Get weather in the next days at given location.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
location: The location to get the weather for.
|
387 |
+
celsius: is the temperature given in celsius?
|
388 |
+
"""
|
389 |
+
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
390 |
+
|
391 |
+
assert get_weather.inputs["celsius"]["nullable"]
|
392 |
+
|
393 |
+
def test_tool_supports_any_none(self, tmp_path):
|
394 |
+
@tool
|
395 |
+
def get_weather(location: Any) -> None:
|
396 |
+
"""
|
397 |
+
Get weather in the next days at given location.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
location: The location to get the weather for.
|
401 |
+
"""
|
402 |
+
return
|
403 |
+
|
404 |
+
get_weather.save(tmp_path)
|
405 |
+
assert get_weather.inputs["location"]["type"] == "any"
|
406 |
+
assert get_weather.output_type == "null"
|
407 |
+
|
408 |
+
def test_tool_supports_array(self):
|
409 |
+
@tool
|
410 |
+
def get_weather(locations: list[str], months: tuple[str, str] | None = None) -> dict[str, float]:
|
411 |
+
"""
|
412 |
+
Get weather in the next days at given locations.
|
413 |
+
|
414 |
+
Args:
|
415 |
+
locations: The locations to get the weather for.
|
416 |
+
months: The months to get the weather for
|
417 |
+
"""
|
418 |
+
return
|
419 |
+
|
420 |
+
assert get_weather.inputs["locations"]["type"] == "array"
|
421 |
+
assert get_weather.inputs["months"]["type"] == "array"
|
422 |
+
|
423 |
+
def test_tool_supports_string_literal(self):
|
424 |
+
@tool
|
425 |
+
def get_weather(unit: Literal["celsius", "fahrenheit"] = "celsius") -> None:
|
426 |
+
"""
|
427 |
+
Get weather in the next days at given location.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
unit: The unit of temperature
|
431 |
+
"""
|
432 |
+
return
|
433 |
+
|
434 |
+
assert get_weather.inputs["unit"]["type"] == "string"
|
435 |
+
assert get_weather.inputs["unit"]["enum"] == ["celsius", "fahrenheit"]
|
436 |
+
|
437 |
+
def test_tool_supports_numeric_literal(self):
|
438 |
+
@tool
|
439 |
+
def get_choice(choice: Literal[1, 2, 3]) -> None:
|
440 |
+
"""
|
441 |
+
Get choice based on the provided numeric literal.
|
442 |
+
|
443 |
+
Args:
|
444 |
+
choice: The numeric choice to be made.
|
445 |
+
"""
|
446 |
+
return
|
447 |
+
|
448 |
+
assert get_choice.inputs["choice"]["type"] == "integer"
|
449 |
+
assert get_choice.inputs["choice"]["enum"] == [1, 2, 3]
|
450 |
+
|
451 |
+
def test_tool_supports_nullable_literal(self):
|
452 |
+
@tool
|
453 |
+
def get_choice(choice: Literal[1, 2, 3, None]) -> None:
|
454 |
+
"""
|
455 |
+
Get choice based on the provided value.
|
456 |
+
|
457 |
+
Args:
|
458 |
+
choice: The numeric choice to be made.
|
459 |
+
"""
|
460 |
+
return
|
461 |
+
|
462 |
+
assert get_choice.inputs["choice"]["type"] == "integer"
|
463 |
+
assert get_choice.inputs["choice"]["nullable"] is True
|
464 |
+
assert get_choice.inputs["choice"]["enum"] == [1, 2, 3]
|
465 |
+
|
466 |
+
def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self, tmp_path):
|
467 |
+
@tool
|
468 |
+
def get_weather(location: Any) -> None:
|
469 |
+
"""
|
470 |
+
Get weather in the next days at given location.
|
471 |
+
And works pretty well.
|
472 |
+
|
473 |
+
Args:
|
474 |
+
location: The location to get the weather for.
|
475 |
+
"""
|
476 |
+
return
|
477 |
+
|
478 |
+
get_weather.save(tmp_path)
|
479 |
+
with open(os.path.join(tmp_path, "tool.py"), "r", encoding="utf-8") as f:
|
480 |
+
source_code = f.read()
|
481 |
+
compile(source_code, f.name, "exec")
|
482 |
+
|
483 |
+
@pytest.mark.parametrize("fixture_name", ["boolean_default_tool_class", "boolean_default_tool_function"])
|
484 |
+
def test_to_dict_boolean_default_input(self, fixture_name, request):
|
485 |
+
"""Test that boolean input parameter with default value is correctly represented in to_dict output"""
|
486 |
+
tool = request.getfixturevalue(fixture_name)
|
487 |
+
result = tool.to_dict()
|
488 |
+
# Check that the boolean default annotation is preserved
|
489 |
+
assert "flag: bool = False" in result["code"]
|
490 |
+
# Check nullable attribute is set for the parameter with default value
|
491 |
+
assert "'nullable': True" in result["code"]
|
492 |
+
|
493 |
+
@pytest.mark.parametrize("fixture_name", ["optional_input_tool_class", "optional_input_tool_function"])
|
494 |
+
def test_to_dict_optional_input(self, fixture_name, request):
|
495 |
+
"""Test that Optional/nullable input parameter is correctly represented in to_dict output"""
|
496 |
+
tool = request.getfixturevalue(fixture_name)
|
497 |
+
result = tool.to_dict()
|
498 |
+
# Check the Optional type annotation is preserved
|
499 |
+
assert "optional_text: str | None = None" in result["code"]
|
500 |
+
# Check that the input is marked as nullable in the code
|
501 |
+
assert "'nullable': True" in result["code"]
|
502 |
+
|
503 |
+
def test_from_dict_roundtrip(self, example_tool):
|
504 |
+
# Convert to dict
|
505 |
+
tool_dict = example_tool.to_dict()
|
506 |
+
# Create from dict
|
507 |
+
recreated_tool = Tool.from_dict(tool_dict)
|
508 |
+
# Verify properties
|
509 |
+
assert recreated_tool.name == example_tool.name
|
510 |
+
assert recreated_tool.description == example_tool.description
|
511 |
+
assert recreated_tool.inputs == example_tool.inputs
|
512 |
+
assert recreated_tool.output_type == example_tool.output_type
|
513 |
+
# Verify functionality
|
514 |
+
test_input = "Hello, world!"
|
515 |
+
assert recreated_tool(test_input) == test_input.upper()
|
516 |
+
|
517 |
+
def test_tool_from_dict_invalid(self):
|
518 |
+
# Missing code key
|
519 |
+
with pytest.raises(ValueError) as e:
|
520 |
+
Tool.from_dict({"name": "invalid_tool"})
|
521 |
+
assert "must contain 'code' key" in str(e)
|
522 |
+
|
523 |
+
def test_tool_decorator_preserves_original_function(self):
|
524 |
+
# Define a test function with type hints and docstring
|
525 |
+
def test_function(items: list[str]) -> str:
|
526 |
+
"""Join a list of strings.
|
527 |
+
Args:
|
528 |
+
items: A list of strings to join
|
529 |
+
Returns:
|
530 |
+
The joined string
|
531 |
+
"""
|
532 |
+
return ", ".join(items)
|
533 |
+
|
534 |
+
# Store original function signature, name, and source
|
535 |
+
original_signature = inspect.signature(test_function)
|
536 |
+
original_name = test_function.__name__
|
537 |
+
original_docstring = test_function.__doc__
|
538 |
+
|
539 |
+
# Create a tool from the function
|
540 |
+
test_tool = tool(test_function)
|
541 |
+
|
542 |
+
# Check that the original function is unchanged
|
543 |
+
assert original_signature == inspect.signature(test_function)
|
544 |
+
assert original_name == test_function.__name__
|
545 |
+
assert original_docstring == test_function.__doc__
|
546 |
+
|
547 |
+
# Verify that the tool's forward method has a different signature (it has 'self')
|
548 |
+
tool_forward_sig = inspect.signature(test_tool.forward)
|
549 |
+
assert list(tool_forward_sig.parameters.keys())[0] == "self"
|
550 |
+
|
551 |
+
# Original function should not have 'self' parameter
|
552 |
+
assert "self" not in original_signature.parameters
|
553 |
+
|
554 |
+
def test_tool_with_union_type_return(self):
|
555 |
+
@tool
|
556 |
+
def union_type_return_tool_function(param: int) -> str | bool:
|
557 |
+
"""
|
558 |
+
Tool with output union type.
|
559 |
+
|
560 |
+
Args:
|
561 |
+
param: Input parameter.
|
562 |
+
"""
|
563 |
+
return str(param) if param > 0 else False
|
564 |
+
|
565 |
+
assert isinstance(union_type_return_tool_function, Tool)
|
566 |
+
assert union_type_return_tool_function.output_type == "any"
|
567 |
+
|
568 |
+
|
569 |
+
@pytest.fixture
|
570 |
+
def mock_server_parameters():
|
571 |
+
return MagicMock()
|
572 |
+
|
573 |
+
|
574 |
+
@pytest.fixture
|
575 |
+
def mock_mcp_adapt():
|
576 |
+
with patch("mcpadapt.core.MCPAdapt") as mock:
|
577 |
+
mock.return_value.__enter__.return_value = ["tool1", "tool2"]
|
578 |
+
mock.return_value.__exit__.return_value = None
|
579 |
+
yield mock
|
580 |
+
|
581 |
+
|
582 |
+
@pytest.fixture
|
583 |
+
def mock_smolagents_adapter():
|
584 |
+
with patch("mcpadapt.smolagents_adapter.SmolAgentsAdapter") as mock:
|
585 |
+
yield mock
|
586 |
+
|
587 |
+
|
588 |
+
class TestToolCollection:
|
589 |
+
def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
|
590 |
+
with ToolCollection.from_mcp(mock_server_parameters, trust_remote_code=True) as tool_collection:
|
591 |
+
assert isinstance(tool_collection, ToolCollection)
|
592 |
+
assert len(tool_collection.tools) == 2
|
593 |
+
assert "tool1" in tool_collection.tools
|
594 |
+
assert "tool2" in tool_collection.tools
|
595 |
+
|
596 |
+
@require_run_all
|
597 |
+
def test_integration_from_mcp(self):
|
598 |
+
# define the most simple mcp server with one tool that echoes the input text
|
599 |
+
mcp_server_script = dedent("""\
|
600 |
+
from mcp.server.fastmcp import FastMCP
|
601 |
+
|
602 |
+
mcp = FastMCP("Echo Server")
|
603 |
+
|
604 |
+
@mcp.tool()
|
605 |
+
def echo_tool(text: str) -> str:
|
606 |
+
return text
|
607 |
+
|
608 |
+
mcp.run()
|
609 |
+
""").strip()
|
610 |
+
|
611 |
+
mcp_server_params = mcp.StdioServerParameters(
|
612 |
+
command="python",
|
613 |
+
args=["-c", mcp_server_script],
|
614 |
+
)
|
615 |
+
|
616 |
+
with ToolCollection.from_mcp(mcp_server_params, trust_remote_code=True) as tool_collection:
|
617 |
+
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
618 |
+
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
|
619 |
+
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
|
620 |
+
|
621 |
+
def test_integration_from_mcp_with_streamable_http(self):
|
622 |
+
import subprocess
|
623 |
+
import time
|
624 |
+
|
625 |
+
# define the most simple mcp server with one tool that echoes the input text
|
626 |
+
mcp_server_script = dedent("""\
|
627 |
+
from mcp.server.fastmcp import FastMCP
|
628 |
+
|
629 |
+
mcp = FastMCP("Echo Server", host="127.0.0.1", port=8000)
|
630 |
+
|
631 |
+
@mcp.tool()
|
632 |
+
def echo_tool(text: str) -> str:
|
633 |
+
return text
|
634 |
+
|
635 |
+
mcp.run(transport="streamable-http")
|
636 |
+
""").strip()
|
637 |
+
|
638 |
+
# start the SSE mcp server in a subprocess
|
639 |
+
server_process = subprocess.Popen(
|
640 |
+
["python", "-c", mcp_server_script],
|
641 |
+
)
|
642 |
+
|
643 |
+
# wait for the server to start
|
644 |
+
time.sleep(1)
|
645 |
+
|
646 |
+
try:
|
647 |
+
with ToolCollection.from_mcp(
|
648 |
+
{"url": "http://127.0.0.1:8000/mcp", "transport": "streamable-http"}, trust_remote_code=True
|
649 |
+
) as tool_collection:
|
650 |
+
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
651 |
+
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
|
652 |
+
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
|
653 |
+
finally:
|
654 |
+
# clean up the process when test is done
|
655 |
+
server_process.kill()
|
656 |
+
server_process.wait()
|
657 |
+
|
658 |
+
def test_integration_from_mcp_with_sse(self):
|
659 |
+
import subprocess
|
660 |
+
import time
|
661 |
+
|
662 |
+
# define the most simple mcp server with one tool that echoes the input text
|
663 |
+
mcp_server_script = dedent("""\
|
664 |
+
from mcp.server.fastmcp import FastMCP
|
665 |
+
|
666 |
+
mcp = FastMCP("Echo Server", host="127.0.0.1", port=8000)
|
667 |
+
|
668 |
+
@mcp.tool()
|
669 |
+
def echo_tool(text: str) -> str:
|
670 |
+
return text
|
671 |
+
|
672 |
+
mcp.run("sse")
|
673 |
+
""").strip()
|
674 |
+
|
675 |
+
# start the SSE mcp server in a subprocess
|
676 |
+
server_process = subprocess.Popen(
|
677 |
+
["python", "-c", mcp_server_script],
|
678 |
+
)
|
679 |
+
|
680 |
+
# wait for the server to start
|
681 |
+
time.sleep(1)
|
682 |
+
|
683 |
+
try:
|
684 |
+
with ToolCollection.from_mcp(
|
685 |
+
{"url": "http://127.0.0.1:8000/sse", "transport": "sse"}, trust_remote_code=True
|
686 |
+
) as tool_collection:
|
687 |
+
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
688 |
+
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
|
689 |
+
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
|
690 |
+
finally:
|
691 |
+
# clean up the process when test is done
|
692 |
+
server_process.kill()
|
693 |
+
server_process.wait()
|
694 |
+
|
695 |
+
|
696 |
+
@pytest.mark.parametrize("tool_fixture_name", ["boolean_default_tool_class"])
|
697 |
+
def test_launch_gradio_demo_does_not_raise(tool_fixture_name, request):
|
698 |
+
tool = request.getfixturevalue(tool_fixture_name)
|
699 |
+
with patch("gradio.Interface.launch") as mock_launch:
|
700 |
+
launch_gradio_demo(tool)
|
701 |
+
assert mock_launch.call_count == 1
|
702 |
+
|
703 |
+
|
704 |
+
@pytest.mark.parametrize(
|
705 |
+
"tool_input_type, expected_input, expects_error",
|
706 |
+
[
|
707 |
+
(bool, True, False),
|
708 |
+
(str, "b", False),
|
709 |
+
(int, 1, False),
|
710 |
+
(list, ["a", "b"], False),
|
711 |
+
(list[str], ["a", "b"], False),
|
712 |
+
(dict[str, str], {"a": "b"}, False),
|
713 |
+
(dict[str, str], "b", True),
|
714 |
+
(bool, "b", True),
|
715 |
+
],
|
716 |
+
)
|
717 |
+
def test_validate_tool_arguments(tool_input_type, expected_input, expects_error):
|
718 |
+
@tool
|
719 |
+
def test_tool(argument_a: tool_input_type) -> str:
|
720 |
+
"""Fake tool
|
721 |
+
|
722 |
+
Args:
|
723 |
+
argument_a: The input
|
724 |
+
"""
|
725 |
+
return argument_a
|
726 |
+
|
727 |
+
error = validate_tool_arguments(test_tool, {"argument_a": expected_input})
|
728 |
+
if expects_error:
|
729 |
+
assert error is not None
|
730 |
+
else:
|
731 |
+
assert error is None
|
tests/test_types.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import os
|
16 |
+
import tempfile
|
17 |
+
import unittest
|
18 |
+
import uuid
|
19 |
+
|
20 |
+
import PIL.Image
|
21 |
+
from transformers.testing_utils import (
|
22 |
+
require_soundfile,
|
23 |
+
)
|
24 |
+
|
25 |
+
from smolagents.agent_types import AgentAudio, AgentImage, AgentText
|
26 |
+
|
27 |
+
from .utils.markers import require_torch
|
28 |
+
|
29 |
+
|
30 |
+
def get_new_path(suffix="") -> str:
|
31 |
+
directory = tempfile.mkdtemp()
|
32 |
+
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
33 |
+
|
34 |
+
|
35 |
+
@require_soundfile
|
36 |
+
@require_torch
|
37 |
+
class AgentAudioTests(unittest.TestCase):
|
38 |
+
def test_from_tensor(self):
|
39 |
+
import soundfile as sf
|
40 |
+
import torch
|
41 |
+
|
42 |
+
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
43 |
+
agent_type = AgentAudio(tensor)
|
44 |
+
path = str(agent_type.to_string())
|
45 |
+
|
46 |
+
# Ensure that the tensor and the agent_type's tensor are the same
|
47 |
+
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
48 |
+
|
49 |
+
del agent_type
|
50 |
+
|
51 |
+
# Ensure the path remains even after the object deletion
|
52 |
+
self.assertTrue(os.path.exists(path))
|
53 |
+
|
54 |
+
# Ensure that the file contains the same value as the original tensor
|
55 |
+
new_tensor, _ = sf.read(path)
|
56 |
+
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
57 |
+
|
58 |
+
def test_from_string(self):
|
59 |
+
import soundfile as sf
|
60 |
+
import torch
|
61 |
+
|
62 |
+
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
63 |
+
path = get_new_path(suffix=".wav")
|
64 |
+
sf.write(path, tensor, 16000)
|
65 |
+
|
66 |
+
agent_type = AgentAudio(path)
|
67 |
+
|
68 |
+
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
69 |
+
self.assertEqual(agent_type.to_string(), path)
|
70 |
+
|
71 |
+
|
72 |
+
@require_torch
|
73 |
+
class TestAgentImage:
|
74 |
+
def test_from_tensor(self):
|
75 |
+
import torch
|
76 |
+
|
77 |
+
tensor = torch.randint(0, 256, (64, 64, 3))
|
78 |
+
agent_type = AgentImage(tensor)
|
79 |
+
path = str(agent_type.to_string())
|
80 |
+
|
81 |
+
# Ensure that the tensor and the agent_type's tensor are the same
|
82 |
+
assert torch.allclose(tensor, agent_type._tensor, atol=1e-4)
|
83 |
+
|
84 |
+
assert isinstance(agent_type.to_raw(), PIL.Image.Image)
|
85 |
+
|
86 |
+
# Ensure the path remains even after the object deletion
|
87 |
+
del agent_type
|
88 |
+
assert os.path.exists(path)
|
89 |
+
|
90 |
+
def test_from_string(self, shared_datadir):
|
91 |
+
path = shared_datadir / "000000039769.png"
|
92 |
+
image = PIL.Image.open(path)
|
93 |
+
agent_type = AgentImage(path)
|
94 |
+
|
95 |
+
assert path.samefile(agent_type.to_string())
|
96 |
+
assert image == agent_type.to_raw()
|
97 |
+
|
98 |
+
# Ensure the path remains even after the object deletion
|
99 |
+
del agent_type
|
100 |
+
assert os.path.exists(path)
|
101 |
+
|
102 |
+
def test_from_image(self, shared_datadir):
|
103 |
+
path = shared_datadir / "000000039769.png"
|
104 |
+
image = PIL.Image.open(path)
|
105 |
+
agent_type = AgentImage(image)
|
106 |
+
|
107 |
+
assert not path.samefile(agent_type.to_string())
|
108 |
+
assert image == agent_type.to_raw()
|
109 |
+
|
110 |
+
# Ensure the path remains even after the object deletion
|
111 |
+
del agent_type
|
112 |
+
assert os.path.exists(path)
|
113 |
+
|
114 |
+
|
115 |
+
class AgentTextTests(unittest.TestCase):
|
116 |
+
def test_from_string(self):
|
117 |
+
string = "Hey!"
|
118 |
+
agent_type = AgentText(string)
|
119 |
+
|
120 |
+
self.assertEqual(string, agent_type.to_string())
|
121 |
+
self.assertEqual(string, agent_type.to_raw())
|
tests/test_utils.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import inspect
|
16 |
+
import os
|
17 |
+
import textwrap
|
18 |
+
import unittest
|
19 |
+
|
20 |
+
import pytest
|
21 |
+
from IPython.core.interactiveshell import InteractiveShell
|
22 |
+
|
23 |
+
from smolagents import Tool
|
24 |
+
from smolagents.tools import tool
|
25 |
+
from smolagents.utils import get_source, instance_to_source, is_valid_name, parse_code_blobs, parse_json_blob
|
26 |
+
|
27 |
+
|
28 |
+
class ValidTool(Tool):
|
29 |
+
name = "valid_tool"
|
30 |
+
description = "A valid tool"
|
31 |
+
inputs = {"input": {"type": "string", "description": "input"}}
|
32 |
+
output_type = "string"
|
33 |
+
simple_attr = "string"
|
34 |
+
dict_attr = {"key": "value"}
|
35 |
+
|
36 |
+
def __init__(self, optional_param="default"):
|
37 |
+
super().__init__()
|
38 |
+
self.param = optional_param
|
39 |
+
|
40 |
+
def forward(self, input: str) -> str:
|
41 |
+
return input.upper()
|
42 |
+
|
43 |
+
|
44 |
+
@tool
|
45 |
+
def valid_tool_function(input: str) -> str:
|
46 |
+
"""A valid tool function.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
input (str): Input string.
|
50 |
+
"""
|
51 |
+
return input.upper()
|
52 |
+
|
53 |
+
|
54 |
+
VALID_TOOL_SOURCE = """\
|
55 |
+
from smolagents.tools import Tool
|
56 |
+
|
57 |
+
class ValidTool(Tool):
|
58 |
+
name = "valid_tool"
|
59 |
+
description = "A valid tool"
|
60 |
+
inputs = {'input': {'type': 'string', 'description': 'input'}}
|
61 |
+
output_type = "string"
|
62 |
+
simple_attr = "string"
|
63 |
+
dict_attr = {'key': 'value'}
|
64 |
+
|
65 |
+
def __init__(self, optional_param="default"):
|
66 |
+
super().__init__()
|
67 |
+
self.param = optional_param
|
68 |
+
|
69 |
+
def forward(self, input: str) -> str:
|
70 |
+
return input.upper()
|
71 |
+
"""
|
72 |
+
|
73 |
+
VALID_TOOL_FUNCTION_SOURCE = '''\
|
74 |
+
from smolagents.tools import Tool
|
75 |
+
|
76 |
+
class SimpleTool(Tool):
|
77 |
+
name = "valid_tool_function"
|
78 |
+
description = "A valid tool function."
|
79 |
+
inputs = {'input': {'type': 'string', 'description': 'Input string.'}}
|
80 |
+
output_type = "string"
|
81 |
+
|
82 |
+
def __init__(self):
|
83 |
+
self.is_initialized = True
|
84 |
+
|
85 |
+
def forward(self, input: str) -> str:
|
86 |
+
"""A valid tool function.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
input (str): Input string.
|
90 |
+
"""
|
91 |
+
return input.upper()
|
92 |
+
'''
|
93 |
+
|
94 |
+
|
95 |
+
class AgentTextTests(unittest.TestCase):
|
96 |
+
def test_parse_code_blobs(self):
|
97 |
+
with pytest.raises(ValueError):
|
98 |
+
parse_code_blobs("Wrong blob!")
|
99 |
+
|
100 |
+
# Parsing mardkwon with code blobs should work
|
101 |
+
output = parse_code_blobs("""
|
102 |
+
Here is how to solve the problem:
|
103 |
+
<code>
|
104 |
+
import numpy as np
|
105 |
+
</code>
|
106 |
+
""")
|
107 |
+
assert output == "import numpy as np"
|
108 |
+
|
109 |
+
# Parsing code blobs should work
|
110 |
+
code_blob = "import numpy as np"
|
111 |
+
output = parse_code_blobs(code_blob)
|
112 |
+
assert output == code_blob
|
113 |
+
|
114 |
+
# Allow whitespaces after header
|
115 |
+
output = parse_code_blobs("<code> \ncode_a\n</code>")
|
116 |
+
assert output == "code_a"
|
117 |
+
|
118 |
+
def test_multiple_code_blobs(self):
|
119 |
+
test_input = "<code>\nFoo\n</code>\n\n<code>\ncode_a\n</code>\n\n<code>\ncode_b\n</code>"
|
120 |
+
result = parse_code_blobs(test_input)
|
121 |
+
assert result == "Foo\n\ncode_a\n\ncode_b"
|
122 |
+
|
123 |
+
|
124 |
+
@pytest.fixture(scope="function")
|
125 |
+
def ipython_shell():
|
126 |
+
"""Reset IPython shell before and after each test."""
|
127 |
+
shell = InteractiveShell.instance()
|
128 |
+
shell.reset() # Clean before test
|
129 |
+
yield shell
|
130 |
+
shell.reset() # Clean after test
|
131 |
+
|
132 |
+
|
133 |
+
@pytest.mark.parametrize(
|
134 |
+
"obj_name, code_blob",
|
135 |
+
[
|
136 |
+
("test_func", "def test_func():\n return 42"),
|
137 |
+
("TestClass", "class TestClass:\n ..."),
|
138 |
+
],
|
139 |
+
)
|
140 |
+
def test_get_source_ipython(ipython_shell, obj_name, code_blob):
|
141 |
+
ipython_shell.run_cell(code_blob, store_history=True)
|
142 |
+
obj = ipython_shell.user_ns[obj_name]
|
143 |
+
assert get_source(obj) == code_blob
|
144 |
+
|
145 |
+
|
146 |
+
def test_get_source_standard_class():
|
147 |
+
class TestClass: ...
|
148 |
+
|
149 |
+
source = get_source(TestClass)
|
150 |
+
assert source == "class TestClass: ..."
|
151 |
+
assert source == textwrap.dedent(inspect.getsource(TestClass)).strip()
|
152 |
+
|
153 |
+
|
154 |
+
def test_get_source_standard_function():
|
155 |
+
def test_func(): ...
|
156 |
+
|
157 |
+
source = get_source(test_func)
|
158 |
+
assert source == "def test_func(): ..."
|
159 |
+
assert source == textwrap.dedent(inspect.getsource(test_func)).strip()
|
160 |
+
|
161 |
+
|
162 |
+
def test_get_source_ipython_errors_empty_cells(ipython_shell):
|
163 |
+
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
|
164 |
+
ipython_shell.user_ns["In"] = [""]
|
165 |
+
ipython_shell.run_cell(test_code, store_history=True)
|
166 |
+
with pytest.raises(ValueError, match="No code cells found in IPython session"):
|
167 |
+
get_source(ipython_shell.user_ns["TestClass"])
|
168 |
+
|
169 |
+
|
170 |
+
def test_get_source_ipython_errors_definition_not_found(ipython_shell):
|
171 |
+
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
|
172 |
+
ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
|
173 |
+
ipython_shell.run_cell(test_code, store_history=True)
|
174 |
+
with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"):
|
175 |
+
get_source(ipython_shell.user_ns["TestClass"])
|
176 |
+
|
177 |
+
|
178 |
+
def test_get_source_ipython_errors_type_error():
|
179 |
+
with pytest.raises(TypeError, match="Expected class or callable"):
|
180 |
+
get_source(None)
|
181 |
+
|
182 |
+
|
183 |
+
@pytest.mark.parametrize(
|
184 |
+
"tool, expected_tool_source", [(ValidTool(), VALID_TOOL_SOURCE), (valid_tool_function, VALID_TOOL_FUNCTION_SOURCE)]
|
185 |
+
)
|
186 |
+
def test_instance_to_source(tool, expected_tool_source):
|
187 |
+
tool_source = instance_to_source(tool, base_cls=Tool)
|
188 |
+
assert tool_source == expected_tool_source
|
189 |
+
|
190 |
+
|
191 |
+
def test_e2e_class_tool_save(tmp_path):
|
192 |
+
class TestTool(Tool):
|
193 |
+
name = "test_tool"
|
194 |
+
description = "Test tool description"
|
195 |
+
inputs = {
|
196 |
+
"task": {
|
197 |
+
"type": "string",
|
198 |
+
"description": "tool input",
|
199 |
+
}
|
200 |
+
}
|
201 |
+
output_type = "string"
|
202 |
+
|
203 |
+
def forward(self, task: str):
|
204 |
+
import IPython # noqa: F401
|
205 |
+
|
206 |
+
return task
|
207 |
+
|
208 |
+
test_tool = TestTool()
|
209 |
+
test_tool.save(tmp_path, make_gradio_app=True)
|
210 |
+
assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
|
211 |
+
assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
|
212 |
+
"""\
|
213 |
+
from typing import Any, Optional
|
214 |
+
from smolagents.tools import Tool
|
215 |
+
import IPython
|
216 |
+
|
217 |
+
class TestTool(Tool):
|
218 |
+
name = "test_tool"
|
219 |
+
description = "Test tool description"
|
220 |
+
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
221 |
+
output_type = "string"
|
222 |
+
|
223 |
+
def forward(self, task: str):
|
224 |
+
import IPython # noqa: F401
|
225 |
+
|
226 |
+
return task
|
227 |
+
|
228 |
+
def __init__(self, *args, **kwargs):
|
229 |
+
self.is_initialized = False
|
230 |
+
"""
|
231 |
+
)
|
232 |
+
requirements = set((tmp_path / "requirements.txt").read_text().split())
|
233 |
+
assert requirements == {"IPython", "smolagents"}
|
234 |
+
assert (tmp_path / "app.py").read_text() == textwrap.dedent(
|
235 |
+
"""\
|
236 |
+
from smolagents import launch_gradio_demo
|
237 |
+
from tool import TestTool
|
238 |
+
|
239 |
+
tool = TestTool()
|
240 |
+
launch_gradio_demo(tool)
|
241 |
+
"""
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
def test_e2e_ipython_class_tool_save(tmp_path):
|
246 |
+
shell = InteractiveShell.instance()
|
247 |
+
code_blob = textwrap.dedent(
|
248 |
+
f"""\
|
249 |
+
from smolagents.tools import Tool
|
250 |
+
class TestTool(Tool):
|
251 |
+
name = "test_tool"
|
252 |
+
description = "Test tool description"
|
253 |
+
inputs = {{"task": {{"type": "string",
|
254 |
+
"description": "tool input",
|
255 |
+
}}
|
256 |
+
}}
|
257 |
+
output_type = "string"
|
258 |
+
|
259 |
+
def forward(self, task: str):
|
260 |
+
import IPython # noqa: F401
|
261 |
+
|
262 |
+
return task
|
263 |
+
TestTool().save("{tmp_path}", make_gradio_app=True)
|
264 |
+
"""
|
265 |
+
)
|
266 |
+
assert shell.run_cell(code_blob, store_history=True).success
|
267 |
+
assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
|
268 |
+
assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
|
269 |
+
"""\
|
270 |
+
from typing import Any, Optional
|
271 |
+
from smolagents.tools import Tool
|
272 |
+
import IPython
|
273 |
+
|
274 |
+
class TestTool(Tool):
|
275 |
+
name = "test_tool"
|
276 |
+
description = "Test tool description"
|
277 |
+
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
278 |
+
output_type = "string"
|
279 |
+
|
280 |
+
def forward(self, task: str):
|
281 |
+
import IPython # noqa: F401
|
282 |
+
|
283 |
+
return task
|
284 |
+
|
285 |
+
def __init__(self, *args, **kwargs):
|
286 |
+
self.is_initialized = False
|
287 |
+
"""
|
288 |
+
)
|
289 |
+
requirements = set((tmp_path / "requirements.txt").read_text().split())
|
290 |
+
assert requirements == {"IPython", "smolagents"}
|
291 |
+
assert (tmp_path / "app.py").read_text() == textwrap.dedent(
|
292 |
+
"""\
|
293 |
+
from smolagents import launch_gradio_demo
|
294 |
+
from tool import TestTool
|
295 |
+
|
296 |
+
tool = TestTool()
|
297 |
+
launch_gradio_demo(tool)
|
298 |
+
"""
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
def test_e2e_function_tool_save(tmp_path):
|
303 |
+
@tool
|
304 |
+
def test_tool(task: str) -> str:
|
305 |
+
"""
|
306 |
+
Test tool description
|
307 |
+
|
308 |
+
Args:
|
309 |
+
task: tool input
|
310 |
+
"""
|
311 |
+
import IPython # noqa: F401
|
312 |
+
|
313 |
+
return task
|
314 |
+
|
315 |
+
test_tool.save(tmp_path, make_gradio_app=True)
|
316 |
+
assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
|
317 |
+
assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
|
318 |
+
"""\
|
319 |
+
from smolagents import Tool
|
320 |
+
from typing import Any, Optional
|
321 |
+
|
322 |
+
class SimpleTool(Tool):
|
323 |
+
name = "test_tool"
|
324 |
+
description = "Test tool description"
|
325 |
+
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
326 |
+
output_type = "string"
|
327 |
+
|
328 |
+
def forward(self, task: str) -> str:
|
329 |
+
\"""
|
330 |
+
Test tool description
|
331 |
+
|
332 |
+
Args:
|
333 |
+
task: tool input
|
334 |
+
\"""
|
335 |
+
import IPython # noqa: F401
|
336 |
+
|
337 |
+
return task"""
|
338 |
+
)
|
339 |
+
requirements = set((tmp_path / "requirements.txt").read_text().split())
|
340 |
+
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
|
341 |
+
assert (tmp_path / "app.py").read_text() == textwrap.dedent(
|
342 |
+
"""\
|
343 |
+
from smolagents import launch_gradio_demo
|
344 |
+
from tool import SimpleTool
|
345 |
+
|
346 |
+
tool = SimpleTool()
|
347 |
+
launch_gradio_demo(tool)
|
348 |
+
"""
|
349 |
+
)
|
350 |
+
|
351 |
+
|
352 |
+
def test_e2e_ipython_function_tool_save(tmp_path):
|
353 |
+
shell = InteractiveShell.instance()
|
354 |
+
code_blob = textwrap.dedent(
|
355 |
+
f"""
|
356 |
+
from smolagents import tool
|
357 |
+
|
358 |
+
@tool
|
359 |
+
def test_tool(task: str) -> str:
|
360 |
+
\"""
|
361 |
+
Test tool description
|
362 |
+
|
363 |
+
Args:
|
364 |
+
task: tool input
|
365 |
+
\"""
|
366 |
+
import IPython # noqa: F401
|
367 |
+
|
368 |
+
return task
|
369 |
+
|
370 |
+
test_tool.save("{tmp_path}", make_gradio_app=True)
|
371 |
+
"""
|
372 |
+
)
|
373 |
+
assert shell.run_cell(code_blob, store_history=True).success
|
374 |
+
assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
|
375 |
+
assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
|
376 |
+
"""\
|
377 |
+
from smolagents import Tool
|
378 |
+
from typing import Any, Optional
|
379 |
+
|
380 |
+
class SimpleTool(Tool):
|
381 |
+
name = "test_tool"
|
382 |
+
description = "Test tool description"
|
383 |
+
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
384 |
+
output_type = "string"
|
385 |
+
|
386 |
+
def forward(self, task: str) -> str:
|
387 |
+
\"""
|
388 |
+
Test tool description
|
389 |
+
|
390 |
+
Args:
|
391 |
+
task: tool input
|
392 |
+
\"""
|
393 |
+
import IPython # noqa: F401
|
394 |
+
|
395 |
+
return task"""
|
396 |
+
)
|
397 |
+
requirements = set((tmp_path / "requirements.txt").read_text().split())
|
398 |
+
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
|
399 |
+
assert (tmp_path / "app.py").read_text() == textwrap.dedent(
|
400 |
+
"""\
|
401 |
+
from smolagents import launch_gradio_demo
|
402 |
+
from tool import SimpleTool
|
403 |
+
|
404 |
+
tool = SimpleTool()
|
405 |
+
launch_gradio_demo(tool)
|
406 |
+
"""
|
407 |
+
)
|
408 |
+
|
409 |
+
|
410 |
+
@pytest.mark.parametrize(
|
411 |
+
"raw_json, expected_data, expected_blob",
|
412 |
+
[
|
413 |
+
(
|
414 |
+
"""{}""",
|
415 |
+
{},
|
416 |
+
"",
|
417 |
+
),
|
418 |
+
(
|
419 |
+
"""Text{}""",
|
420 |
+
{},
|
421 |
+
"Text",
|
422 |
+
),
|
423 |
+
(
|
424 |
+
"""{"simple": "json"}""",
|
425 |
+
{"simple": "json"},
|
426 |
+
"",
|
427 |
+
),
|
428 |
+
(
|
429 |
+
"""With text here{"simple": "json"}""",
|
430 |
+
{"simple": "json"},
|
431 |
+
"With text here",
|
432 |
+
),
|
433 |
+
(
|
434 |
+
"""{"simple": "json"}With text after""",
|
435 |
+
{"simple": "json"},
|
436 |
+
"",
|
437 |
+
),
|
438 |
+
(
|
439 |
+
"""With text before{"simple": "json"}And text after""",
|
440 |
+
{"simple": "json"},
|
441 |
+
"With text before",
|
442 |
+
),
|
443 |
+
],
|
444 |
+
)
|
445 |
+
def test_parse_json_blob_with_valid_json(raw_json, expected_data, expected_blob):
|
446 |
+
data, blob = parse_json_blob(raw_json)
|
447 |
+
|
448 |
+
assert data == expected_data
|
449 |
+
assert blob == expected_blob
|
450 |
+
|
451 |
+
|
452 |
+
@pytest.mark.parametrize(
|
453 |
+
"raw_json",
|
454 |
+
[
|
455 |
+
"""simple": "json"}""",
|
456 |
+
"""With text here"simple": "json"}""",
|
457 |
+
"""{"simple": ""json"}With text after""",
|
458 |
+
"""{"simple": "json"With text after""",
|
459 |
+
"}}",
|
460 |
+
],
|
461 |
+
)
|
462 |
+
def test_parse_json_blob_with_invalid_json(raw_json):
|
463 |
+
with pytest.raises(Exception):
|
464 |
+
parse_json_blob(raw_json)
|
465 |
+
|
466 |
+
|
467 |
+
@pytest.mark.parametrize(
|
468 |
+
"name,expected",
|
469 |
+
[
|
470 |
+
# Valid identifiers
|
471 |
+
("valid_name", True),
|
472 |
+
("ValidName", True),
|
473 |
+
("valid123", True),
|
474 |
+
("_private", True),
|
475 |
+
# Invalid identifiers
|
476 |
+
("", False),
|
477 |
+
("123invalid", False),
|
478 |
+
("invalid-name", False),
|
479 |
+
("invalid name", False),
|
480 |
+
("invalid.name", False),
|
481 |
+
# Python keywords
|
482 |
+
("if", False),
|
483 |
+
("for", False),
|
484 |
+
("class", False),
|
485 |
+
("return", False),
|
486 |
+
# Non-string inputs
|
487 |
+
(123, False),
|
488 |
+
(None, False),
|
489 |
+
([], False),
|
490 |
+
({}, False),
|
491 |
+
],
|
492 |
+
)
|
493 |
+
def test_is_valid_name(name, expected):
|
494 |
+
"""Test the is_valid_name function with various inputs."""
|
495 |
+
assert is_valid_name(name) is expected
|