mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-09 05:37:25 +00:00
refactor: create_tool_from_function + tool decorator (#8697)
* create_tool_from_function + decorator * release note * improve usage example * add imports to @tool usage example * clarify docstrings * small docstring addition
This commit is contained in:
parent
dd9660f90d
commit
08cf09f83f
@ -2,7 +2,7 @@ loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/tools]
|
||||
modules:
|
||||
["tool"]
|
||||
["tool", "from_function"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from haystack.tools.from_function import create_tool_from_function, tool
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
|
||||
__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"]
|
||||
__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace", "create_tool_from_function", "tool"]
|
||||
|
||||
19
haystack/tools/errors.py
Normal file
19
haystack/tools/errors.py
Normal file
@ -0,0 +1,19 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
class SchemaGenerationError(Exception):
|
||||
"""
|
||||
Exception raised when automatic schema generation fails.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolInvocationError(Exception):
|
||||
"""
|
||||
Exception raised when a Tool invocation fails.
|
||||
"""
|
||||
|
||||
pass
|
||||
166
haystack/tools/from_function.py
Normal file
166
haystack/tools/from_function.py
Normal file
@ -0,0 +1,166 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from haystack.tools.errors import SchemaGenerationError
|
||||
from haystack.tools.tool import Tool
|
||||
|
||||
|
||||
def create_tool_from_function(
|
||||
function: Callable, name: Optional[str] = None, description: Optional[str] = None
|
||||
) -> "Tool":
|
||||
"""
|
||||
Create a Tool instance from a function.
|
||||
|
||||
Allows customizing the Tool name and description.
|
||||
For simpler use cases, consider using the `@tool` decorator.
|
||||
|
||||
### Usage example
|
||||
|
||||
```python
|
||||
from typing import Annotated, Literal
|
||||
from haystack.tools import create_tool_from_function
|
||||
|
||||
def get_weather(
|
||||
city: Annotated[str, "the city for which to get the weather"] = "Munich",
|
||||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"):
|
||||
'''A simple function to get the current weather for a location.'''
|
||||
return f"Weather report for {city}: 20 {unit}, sunny"
|
||||
|
||||
tool = create_tool_from_function(get_weather)
|
||||
|
||||
print(tool)
|
||||
>>> Tool(name='get_weather', description='A simple function to get the current weather for a location.',
|
||||
>>> parameters={
|
||||
>>> 'type': 'object',
|
||||
>>> 'properties': {
|
||||
>>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'},
|
||||
>>> 'unit': {
|
||||
>>> 'type': 'string',
|
||||
>>> 'enum': ['Celsius', 'Fahrenheit'],
|
||||
>>> 'description': 'the unit for the temperature',
|
||||
>>> 'default': 'Celsius',
|
||||
>>> },
|
||||
>>> }
|
||||
>>> },
|
||||
>>> function=<function get_weather at 0x7f7b3a8a9b80>)
|
||||
```
|
||||
|
||||
:param function:
|
||||
The function to be converted into a Tool.
|
||||
The function must include type hints for all parameters.
|
||||
The function is expected to have basic python input types (str, int, float, bool, list, dict, tuple).
|
||||
Other input types may work but are not guaranteed.
|
||||
If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description.
|
||||
:param name:
|
||||
The name of the Tool. If not provided, the name of the function will be used.
|
||||
:param description:
|
||||
The description of the Tool. If not provided, the docstring of the function will be used.
|
||||
To intentionally leave the description empty, pass an empty string.
|
||||
|
||||
:returns:
|
||||
The Tool created from the function.
|
||||
|
||||
:raises ValueError:
|
||||
If any parameter of the function lacks a type hint.
|
||||
:raises SchemaGenerationError:
|
||||
If there is an error generating the JSON schema for the Tool.
|
||||
"""
|
||||
|
||||
tool_description = description if description is not None else (function.__doc__ or "")
|
||||
|
||||
signature = inspect.signature(function)
|
||||
|
||||
# collect fields (types and defaults) and descriptions from function parameters
|
||||
fields: Dict[str, Any] = {}
|
||||
descriptions = {}
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation is param.empty:
|
||||
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")
|
||||
|
||||
# if the parameter has not a default value, Pydantic requires an Ellipsis (...)
|
||||
# to explicitly indicate that the parameter is required
|
||||
default = param.default if param.default is not param.empty else ...
|
||||
fields[param_name] = (param.annotation, default)
|
||||
|
||||
if hasattr(param.annotation, "__metadata__"):
|
||||
descriptions[param_name] = param.annotation.__metadata__[0]
|
||||
|
||||
# create Pydantic model and generate JSON schema
|
||||
try:
|
||||
model = create_model(function.__name__, **fields)
|
||||
schema = model.model_json_schema()
|
||||
except Exception as e:
|
||||
raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e
|
||||
|
||||
# we don't want to include title keywords in the schema, as they contain redundant information
|
||||
# there is no programmatic way to prevent Pydantic from adding them, so we remove them later
|
||||
# see https://github.com/pydantic/pydantic/discussions/8504
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
# add parameters descriptions to the schema
|
||||
for param_name, param_description in descriptions.items():
|
||||
if param_name in schema["properties"]:
|
||||
schema["properties"][param_name]["description"] = param_description
|
||||
|
||||
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
|
||||
|
||||
|
||||
def tool(function: Callable) -> Tool:
|
||||
"""
|
||||
Decorator to convert a function into a Tool.
|
||||
|
||||
Tool name, description, and parameters are inferred from the function.
|
||||
If you need to customize more the Tool, use `create_tool_from_function` instead.
|
||||
|
||||
### Usage example
|
||||
```python
|
||||
from typing import Annotated, Literal
|
||||
from haystack.tools import tool
|
||||
|
||||
@tool
|
||||
def get_weather(
|
||||
city: Annotated[str, "the city for which to get the weather"] = "Munich",
|
||||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"):
|
||||
'''A simple function to get the current weather for a location.'''
|
||||
return f"Weather report for {city}: 20 {unit}, sunny"
|
||||
|
||||
print(get_weather)
|
||||
>>> Tool(name='get_weather', description='A simple function to get the current weather for a location.',
|
||||
>>> parameters={
|
||||
>>> 'type': 'object',
|
||||
>>> 'properties': {
|
||||
>>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'},
|
||||
>>> 'unit': {
|
||||
>>> 'type': 'string',
|
||||
>>> 'enum': ['Celsius', 'Fahrenheit'],
|
||||
>>> 'description': 'the unit for the temperature',
|
||||
>>> 'default': 'Celsius',
|
||||
>>> },
|
||||
>>> }
|
||||
>>> },
|
||||
>>> function=<function get_weather at 0x7f7b3a8a9b80>)
|
||||
```
|
||||
"""
|
||||
return create_tool_from_function(function)
|
||||
|
||||
|
||||
def _remove_title_from_schema(schema: Dict[str, Any]):
|
||||
"""
|
||||
Remove the 'title' keyword from JSON schema and contained property schemas.
|
||||
|
||||
:param schema:
|
||||
The JSON schema to remove the 'title' keyword from.
|
||||
"""
|
||||
schema.pop("title", None)
|
||||
|
||||
for property_schema in schema["properties"].values():
|
||||
for key in list(property_schema.keys()):
|
||||
if key == "title":
|
||||
del property_schema[key]
|
||||
@ -2,14 +2,12 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools.errors import ToolInvocationError
|
||||
from haystack.utils import deserialize_callable, serialize_callable
|
||||
|
||||
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
|
||||
@ -17,22 +15,6 @@ with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
|
||||
from jsonschema.exceptions import SchemaError
|
||||
|
||||
|
||||
class ToolInvocationError(Exception):
|
||||
"""
|
||||
Exception raised when a Tool invocation fails.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SchemaGenerationError(Exception):
|
||||
"""
|
||||
Exception raised when automatic schema generation fails.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""
|
||||
@ -108,115 +90,6 @@ class Tool:
|
||||
init_parameters["function"] = deserialize_callable(init_parameters["function"])
|
||||
return cls(**init_parameters)
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool":
|
||||
"""
|
||||
Create a Tool instance from a function.
|
||||
|
||||
### Usage example
|
||||
|
||||
```python
|
||||
from typing import Annotated, Literal
|
||||
from haystack.dataclasses import Tool
|
||||
|
||||
def get_weather(
|
||||
city: Annotated[str, "the city for which to get the weather"] = "Munich",
|
||||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"):
|
||||
'''A simple function to get the current weather for a location.'''
|
||||
return f"Weather report for {city}: 20 {unit}, sunny"
|
||||
|
||||
tool = Tool.from_function(get_weather)
|
||||
|
||||
print(tool)
|
||||
>>> Tool(name='get_weather', description='A simple function to get the current weather for a location.',
|
||||
>>> parameters={
|
||||
>>> 'type': 'object',
|
||||
>>> 'properties': {
|
||||
>>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'},
|
||||
>>> 'unit': {
|
||||
>>> 'type': 'string',
|
||||
>>> 'enum': ['Celsius', 'Fahrenheit'],
|
||||
>>> 'description': 'the unit for the temperature',
|
||||
>>> 'default': 'Celsius',
|
||||
>>> },
|
||||
>>> }
|
||||
>>> },
|
||||
>>> function=<function get_weather at 0x7f7b3a8a9b80>)
|
||||
```
|
||||
|
||||
:param function:
|
||||
The function to be converted into a Tool.
|
||||
The function must include type hints for all parameters.
|
||||
If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description.
|
||||
:param name:
|
||||
The name of the Tool. If not provided, the name of the function will be used.
|
||||
:param description:
|
||||
The description of the Tool. If not provided, the docstring of the function will be used.
|
||||
To intentionally leave the description empty, pass an empty string.
|
||||
|
||||
:returns:
|
||||
The Tool created from the function.
|
||||
|
||||
:raises ValueError:
|
||||
If any parameter of the function lacks a type hint.
|
||||
:raises SchemaGenerationError:
|
||||
If there is an error generating the JSON schema for the Tool.
|
||||
"""
|
||||
|
||||
tool_description = description if description is not None else (function.__doc__ or "")
|
||||
|
||||
signature = inspect.signature(function)
|
||||
|
||||
# collect fields (types and defaults) and descriptions from function parameters
|
||||
fields: Dict[str, Any] = {}
|
||||
descriptions = {}
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation is param.empty:
|
||||
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")
|
||||
|
||||
# if the parameter has not a default value, Pydantic requires an Ellipsis (...)
|
||||
# to explicitly indicate that the parameter is required
|
||||
default = param.default if param.default is not param.empty else ...
|
||||
fields[param_name] = (param.annotation, default)
|
||||
|
||||
if hasattr(param.annotation, "__metadata__"):
|
||||
descriptions[param_name] = param.annotation.__metadata__[0]
|
||||
|
||||
# create Pydantic model and generate JSON schema
|
||||
try:
|
||||
model = create_model(function.__name__, **fields)
|
||||
schema = model.model_json_schema()
|
||||
except Exception as e:
|
||||
raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e
|
||||
|
||||
# we don't want to include title keywords in the schema, as they contain redundant information
|
||||
# there is no programmatic way to prevent Pydantic from adding them, so we remove them later
|
||||
# see https://github.com/pydantic/pydantic/discussions/8504
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
# add parameters descriptions to the schema
|
||||
for param_name, param_description in descriptions.items():
|
||||
if param_name in schema["properties"]:
|
||||
schema["properties"][param_name]["description"] = param_description
|
||||
|
||||
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
|
||||
|
||||
|
||||
def _remove_title_from_schema(schema: Dict[str, Any]):
|
||||
"""
|
||||
Remove the 'title' keyword from JSON schema and contained property schemas.
|
||||
|
||||
:param schema:
|
||||
The JSON schema to remove the 'title' keyword from.
|
||||
"""
|
||||
schema.pop("title", None)
|
||||
|
||||
for property_schema in schema["properties"].values():
|
||||
for key in list(property_schema.keys()):
|
||||
if key == "title":
|
||||
del property_schema[key]
|
||||
|
||||
|
||||
def _check_duplicate_tool_names(tools: Optional[List[Tool]]) -> None:
|
||||
"""
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added a new `create_tool_from_function` function to create a `Tool` instance from a function, with automatic
|
||||
generation of name, description and parameters.
|
||||
Added a `tool` decorator to achieve the same result.
|
||||
231
test/tools/test_from_function.py
Normal file
231
test/tools/test_from_function.py
Normal file
@ -0,0 +1,231 @@
|
||||
import pytest
|
||||
|
||||
from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool
|
||||
from haystack.tools.errors import SchemaGenerationError
|
||||
from typing import Literal, Optional
|
||||
|
||||
try:
|
||||
from typing import Annotated
|
||||
except ImportError:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
def function_with_docstring(city: str) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
|
||||
def test_from_function_description_from_docstring():
|
||||
tool = create_tool_from_function(function=function_with_docstring)
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == "Get weather report for a city."
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
|
||||
def test_from_function_with_empty_description():
|
||||
tool = create_tool_from_function(function=function_with_docstring, description="")
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == ""
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
|
||||
def test_from_function_with_custom_description():
|
||||
tool = create_tool_from_function(function=function_with_docstring, description="custom description")
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == "custom description"
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
|
||||
def test_from_function_with_custom_name():
|
||||
tool = create_tool_from_function(function=function_with_docstring, name="custom_name")
|
||||
|
||||
assert tool.name == "custom_name"
|
||||
assert tool.description == "Get weather report for a city."
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
|
||||
def test_from_function_annotated():
|
||||
def function_with_annotations(
|
||||
city: Annotated[str, "the city for which to get the weather"] = "Munich",
|
||||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius",
|
||||
nullable_param: Annotated[Optional[str], "a nullable parameter"] = None,
|
||||
) -> str:
|
||||
"""A simple function to get the current weather for a location."""
|
||||
return f"Weather report for {city}: 20 {unit}, sunny"
|
||||
|
||||
tool = create_tool_from_function(function=function_with_annotations)
|
||||
|
||||
assert tool.name == "function_with_annotations"
|
||||
assert tool.description == "A simple function to get the current weather for a location."
|
||||
assert tool.parameters == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["Celsius", "Fahrenheit"],
|
||||
"description": "the unit for the temperature",
|
||||
"default": "Celsius",
|
||||
},
|
||||
"nullable_param": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"description": "a nullable parameter",
|
||||
"default": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_from_function_missing_type_hint():
|
||||
def function_missing_type_hint(city) -> str:
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_tool_from_function(function=function_missing_type_hint)
|
||||
|
||||
|
||||
def test_from_function_schema_generation_error():
|
||||
def function_with_invalid_type_hint(city: "invalid") -> str:
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
with pytest.raises(SchemaGenerationError):
|
||||
create_tool_from_function(function=function_with_invalid_type_hint)
|
||||
|
||||
|
||||
def test_tool_decorator():
|
||||
@tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
assert get_weather.name == "get_weather"
|
||||
assert get_weather.description == "Get weather report for a city."
|
||||
assert get_weather.parameters == {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
}
|
||||
assert callable(get_weather.function)
|
||||
assert get_weather.function("Berlin") == "Weather report for Berlin: 20°C, sunny"
|
||||
|
||||
|
||||
def test_tool_decorator_with_annotated_params():
|
||||
@tool
|
||||
def get_weather(
|
||||
city: Annotated[str, "The target city"] = "Berlin",
|
||||
format: Annotated[Literal["short", "long"], "Output format"] = "short",
|
||||
) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city} ({format} format): 20°C, sunny"
|
||||
|
||||
assert get_weather.name == "get_weather"
|
||||
assert get_weather.description == "Get weather report for a city."
|
||||
assert get_weather.parameters == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The target city", "default": "Berlin"},
|
||||
"format": {"type": "string", "enum": ["short", "long"], "description": "Output format", "default": "short"},
|
||||
},
|
||||
}
|
||||
assert callable(get_weather.function)
|
||||
assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"
|
||||
|
||||
|
||||
def test_remove_title_from_schema():
|
||||
complex_schema = {
|
||||
"properties": {
|
||||
"parameter1": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
"default": "default_value",
|
||||
"title": "Parameter1",
|
||||
},
|
||||
"parameter2": {
|
||||
"default": [1, 2, 3],
|
||||
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"title": "Parameter2",
|
||||
"type": "array",
|
||||
},
|
||||
"parameter3": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
{"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
|
||||
],
|
||||
"default": 42,
|
||||
"title": "Parameter3",
|
||||
},
|
||||
"parameter4": {
|
||||
"anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
|
||||
"default": {"key": "value"},
|
||||
"title": "Parameter4",
|
||||
},
|
||||
},
|
||||
"title": "complex_function",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(complex_schema)
|
||||
|
||||
assert complex_schema == {
|
||||
"properties": {
|
||||
"parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"},
|
||||
"parameter2": {
|
||||
"default": [1, 2, 3],
|
||||
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"type": "array",
|
||||
},
|
||||
"parameter3": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
{"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
|
||||
],
|
||||
"default": 42,
|
||||
},
|
||||
"parameter4": {
|
||||
"anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
|
||||
"default": {"key": "value"},
|
||||
},
|
||||
},
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_remove_title_from_schema_do_not_remove_title_property():
|
||||
"""Test that the utility function only removes the 'title' keywords and not the 'title' property (if present)."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"parameter1": {"type": "string", "title": "Parameter1"},
|
||||
"title": {"type": "string", "title": "Title"},
|
||||
},
|
||||
"title": "complex_function",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"}
|
||||
|
||||
|
||||
def test_remove_title_from_schema_handle_no_title_in_top_level():
|
||||
schema = {
|
||||
"properties": {
|
||||
"parameter1": {"type": "string", "title": "Parameter1"},
|
||||
"parameter2": {"type": "integer", "title": "Parameter2"},
|
||||
},
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
assert schema == {
|
||||
"properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
|
||||
"type": "object",
|
||||
}
|
||||
@ -2,22 +2,9 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
from haystack.tools.tool import (
|
||||
SchemaGenerationError,
|
||||
Tool,
|
||||
ToolInvocationError,
|
||||
_remove_title_from_schema,
|
||||
deserialize_tools_inplace,
|
||||
_check_duplicate_tool_names,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import Annotated
|
||||
except ImportError:
|
||||
from typing_extensions import Annotated
|
||||
from haystack.tools.tool import Tool, ToolInvocationError, deserialize_tools_inplace, _check_duplicate_tool_names
|
||||
|
||||
|
||||
def get_weather_report(city: str) -> str:
|
||||
@ -27,11 +14,6 @@ def get_weather_report(city: str) -> str:
|
||||
parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
|
||||
|
||||
def function_with_docstring(city: str) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
|
||||
class TestTool:
|
||||
def test_init(self):
|
||||
tool = Tool(
|
||||
@ -104,83 +86,6 @@ class TestTool:
|
||||
assert tool.parameters == parameters
|
||||
assert tool.function == get_weather_report
|
||||
|
||||
def test_from_function_description_from_docstring(self):
|
||||
tool = Tool.from_function(function=function_with_docstring)
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == "Get weather report for a city."
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
def test_from_function_with_empty_description(self):
|
||||
tool = Tool.from_function(function=function_with_docstring, description="")
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == ""
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
def test_from_function_with_custom_description(self):
|
||||
tool = Tool.from_function(function=function_with_docstring, description="custom description")
|
||||
|
||||
assert tool.name == "function_with_docstring"
|
||||
assert tool.description == "custom description"
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
def test_from_function_with_custom_name(self):
|
||||
tool = Tool.from_function(function=function_with_docstring, name="custom_name")
|
||||
|
||||
assert tool.name == "custom_name"
|
||||
assert tool.description == "Get weather report for a city."
|
||||
assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
assert tool.function == function_with_docstring
|
||||
|
||||
def test_from_function_missing_type_hint(self):
|
||||
def function_missing_type_hint(city) -> str:
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Tool.from_function(function=function_missing_type_hint)
|
||||
|
||||
def test_from_function_schema_generation_error(self):
|
||||
def function_with_invalid_type_hint(city: "invalid") -> str:
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
with pytest.raises(SchemaGenerationError):
|
||||
Tool.from_function(function=function_with_invalid_type_hint)
|
||||
|
||||
def test_from_function_annotated(self):
|
||||
def function_with_annotations(
|
||||
city: Annotated[str, "the city for which to get the weather"] = "Munich",
|
||||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius",
|
||||
nullable_param: Annotated[Optional[str], "a nullable parameter"] = None,
|
||||
) -> str:
|
||||
"""A simple function to get the current weather for a location."""
|
||||
return f"Weather report for {city}: 20 {unit}, sunny"
|
||||
|
||||
tool = Tool.from_function(function=function_with_annotations)
|
||||
|
||||
assert tool.name == "function_with_annotations"
|
||||
assert tool.description == "A simple function to get the current weather for a location."
|
||||
assert tool.parameters == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["Celsius", "Fahrenheit"],
|
||||
"description": "the unit for the temperature",
|
||||
"default": "Celsius",
|
||||
},
|
||||
"nullable_param": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"description": "a nullable parameter",
|
||||
"default": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_deserialize_tools_inplace():
|
||||
tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report)
|
||||
@ -221,99 +126,6 @@ def test_deserialize_tools_inplace_failures():
|
||||
deserialize_tools_inplace(data)
|
||||
|
||||
|
||||
def test_remove_title_from_schema():
|
||||
complex_schema = {
|
||||
"properties": {
|
||||
"parameter1": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
"default": "default_value",
|
||||
"title": "Parameter1",
|
||||
},
|
||||
"parameter2": {
|
||||
"default": [1, 2, 3],
|
||||
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"title": "Parameter2",
|
||||
"type": "array",
|
||||
},
|
||||
"parameter3": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
{"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
|
||||
],
|
||||
"default": 42,
|
||||
"title": "Parameter3",
|
||||
},
|
||||
"parameter4": {
|
||||
"anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
|
||||
"default": {"key": "value"},
|
||||
"title": "Parameter4",
|
||||
},
|
||||
},
|
||||
"title": "complex_function",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(complex_schema)
|
||||
|
||||
assert complex_schema == {
|
||||
"properties": {
|
||||
"parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"},
|
||||
"parameter2": {
|
||||
"default": [1, 2, 3],
|
||||
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"type": "array",
|
||||
},
|
||||
"parameter3": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
{"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
|
||||
],
|
||||
"default": 42,
|
||||
},
|
||||
"parameter4": {
|
||||
"anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
|
||||
"default": {"key": "value"},
|
||||
},
|
||||
},
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_remove_title_from_schema_do_not_remove_title_property():
|
||||
"""Test that the utility function only removes the 'title' keywords and not the 'title' property (if present)."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"parameter1": {"type": "string", "title": "Parameter1"},
|
||||
"title": {"type": "string", "title": "Title"},
|
||||
},
|
||||
"title": "complex_function",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"}
|
||||
|
||||
|
||||
def test_remove_title_from_schema_handle_no_title_in_top_level():
|
||||
schema = {
|
||||
"properties": {
|
||||
"parameter1": {"type": "string", "title": "Parameter1"},
|
||||
"parameter2": {"type": "integer", "title": "Parameter2"},
|
||||
},
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
_remove_title_from_schema(schema)
|
||||
|
||||
assert schema == {
|
||||
"properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_check_duplicate_tool_names():
|
||||
tools = [
|
||||
Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user