Fix component tool parameters (#9342)

* Starting property schema refactor

* Adding more tests

* More tests

* Handle null type explicitly

* More updates of tests to accommodate Optional properly

* Fix more tests

* Remove unnecessary check

* Some cleanup

* Update test

* Add reno

* Fix typing

* Add license header

* Use docstrings of dataclasses in parameter spec generation

* More tests of Haystack dataclass types

* Properly handle Sequence

* Fix license header

* Update OpenAI tests to add more complicated tool parameter signature

* Properly set required for dataclasses

* Add integration test for azure that includes additionalProperties

* Add more complicated integration test for HuggingFaceAPIChatGenerator

* Alternate approach using pydantic like we do in from_function.py

* Cleanup and fix other affected tests

* Fix mypy

* PR comments

* PR comment

* Remove test from HF API

* Update reno

* Update reno
This commit is contained in:
Sebastian Husch Lee 2025-05-15 09:51:06 +02:00 committed by GitHub
parent 42b378950f
commit 9ae76e1653
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 567 additions and 295 deletions

View File

@ -11,6 +11,10 @@ from typing import Any, Dict, Optional
class ByteStream:
"""
Base data class representing a binary object in the Haystack API.
:param data: The binary data stored in Bytestream.
:param meta: Additional metadata to be stored with the ByteStream.
:param mime_type: The mime type of the binary data.
"""
data: bytes

View File

@ -2,11 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
from pydantic import TypeAdapter
from pydantic import Field, TypeAdapter, create_model
from haystack import logging
from haystack.core.component import Component
@ -16,15 +14,12 @@ from haystack.core.serialization import (
generate_qualified_class_name,
import_class_by_name,
)
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool
from haystack.tools.errors import SchemaGenerationError
from haystack.tools.from_function import _remove_title_from_schema
from haystack.tools.parameters_schema_utils import _get_param_descriptions, _resolve_type
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
from docstring_parser import parse
logger = logging.getLogger(__name__)
@ -275,10 +270,10 @@ class ComponentTool(Tool):
:raises SchemaGenerationError: If schema generation fails
:returns: OpenAI tools schema for the component's run method parameters.
"""
properties = {}
required = []
component_run_description, param_descriptions = _get_param_descriptions(component.run)
param_descriptions = self._get_param_descriptions(component.run)
# collect fields (types and defaults) and descriptions from function parameters
fields: Dict[str, Any] = {}
for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined]
if inputs_from_state is not None and input_name in inputs_from_state:
@ -286,135 +281,23 @@ class ComponentTool(Tool):
input_type = socket.type
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
try:
property_schema = self._create_property_schema(input_type, description)
except Exception as e:
raise SchemaGenerationError(
f"Error processing input '{input_name}': {e}. "
f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
f"and lists of these types as input types for component's run method."
) from e
# if the parameter has no default value, Pydantic requires an Ellipsis (...)
# to explicitly mark the parameter as required
default = ... if socket.is_mandatory else socket.default_value
resolved_type = _resolve_type(input_type)
fields[input_name] = (resolved_type, Field(default=default, description=description))
properties[input_name] = property_schema
try:
model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
parameters_schema = model.model_json_schema()
except Exception as e:
raise SchemaGenerationError(
f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
) from e
# Use socket.is_mandatory to check if the input is required
if socket.is_mandatory:
required.append(input_name)
parameters_schema = {"type": "object", "properties": properties}
if required:
parameters_schema["required"] = required
# we don't want to include title keywords in the schema, as they contain redundant information
# there is no programmatic way to prevent Pydantic from adding them, so we remove them later
# see https://github.com/pydantic/pydantic/discussions/8504
_remove_title_from_schema(parameters_schema)
return parameters_schema
@staticmethod
def _get_param_descriptions(method: Callable) -> Dict[str, str]:
"""
Extracts parameter descriptions from the method's docstring using docstring_parser.
:param method: The method to extract parameter descriptions from.
:returns: A dictionary mapping parameter names to their descriptions.
"""
docstring = getdoc(method)
if not docstring:
return {}
docstring_parser_import.check()
parsed_doc = parse(docstring)
param_descriptions = {}
for param in parsed_doc.params:
if not param.description:
logger.warning(
"Missing description for parameter '%s'. Please add a description in the component's "
"run() method docstring using the format ':param %%s: <description>'. "
"This description helps the LLM understand how to use this parameter." % param.arg_name
)
param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
return param_descriptions
@staticmethod
def _is_nullable_type(python_type: Any) -> bool:
"""
Checks if the type is a Union with NoneType (i.e., Optional).
:param python_type: The Python type to check.
:returns: True if the type is a Union with NoneType, False otherwise.
"""
origin = get_origin(python_type)
if origin is Union:
return type(None) in get_args(python_type)
return False
def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
    """
    Build the schema of a list-typed property.

    :param item_type: The type of the elements held by the list.
    :param description: The description attached to the list itself.
    :returns: A dictionary with "type" set to "array", the given description and the schema of the items.
    """
    element_schema = self._create_property_schema(item_type, "")
    # Only the list carries a description; drop the (empty) one generated for its elements.
    element_schema.pop("description", None)

    list_schema: Dict[str, Any] = {"type": "array"}
    list_schema["description"] = description
    list_schema["items"] = element_schema
    return list_schema
def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
    """
    Build the schema of a dataclass-typed property.

    Every dataclass field becomes one entry under "properties", described as
    "Field '<field name>' of '<dataclass name>'.".

    :param python_type: The dataclass type (a dataclass instance is accepted as well; its class is used).
    :param description: The description attached to the dataclass property.
    :returns: A dictionary with "type" set to "object", the given description and the per-field schemas.
    """
    # Accept both the dataclass itself and an instance of it.
    owner = python_type if isinstance(python_type, type) else python_type.__class__

    field_schemas: Dict[str, Any] = {}
    for dc_field in fields(owner):
        label = f"Field '{dc_field.name}' of '{owner.__name__}'."
        field_schemas[dc_field.name] = self._create_property_schema(dc_field.type, label)

    return {"type": "object", "description": description, "properties": field_schemas}
@staticmethod
def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
"""
Creates a schema for a basic Python type.
:param python_type: The Python type.
:param description: The description of the type.
:returns: A dictionary representing the basic type schema.
"""
type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
return {"type": type_mapping.get(python_type, "string"), "description": description}
def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
    """
    Build the schema of a single property, recursing into lists and dataclasses.

    :param python_type: The Python type of the property.
    :param description: The description of the property.
    :param default: The default value of the property; it is only written to the schema when it is not None.
    :returns: A dictionary representing the property schema.
    :raises SchemaGenerationError: If the type exposes `model_validate` (e.g. Pydantic v2 models), which is
        not supported.
    """
    if self._is_nullable_type(python_type):
        # Unwrap Optional[...]: continue with the first non-None member, or str if there is none.
        candidates = [member for member in get_args(python_type) if member is not type(None)]
        python_type = candidates[0] if candidates else str

    if get_origin(python_type) is list:
        type_args = get_args(python_type)
        element_type = type_args[0] if type_args else Any
        schema = self._create_list_schema(element_type, description)
    elif is_dataclass(python_type):
        schema = self._create_dataclass_schema(python_type, description)
    elif hasattr(python_type, "model_validate"):
        raise SchemaGenerationError(
            f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
            f"component's run method."
        )
    else:
        schema = self._create_basic_type_schema(python_type, description)

    if default is not None:
        schema["default"] = default
    return schema

View File

@ -210,9 +210,16 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
:param schema:
The JSON schema to remove the 'title' keyword from.
"""
schema.pop("title", None)
for property_schema in schema["properties"].values():
for key in list(property_schema.keys()):
if key == "title":
del property_schema[key]
for key, value in list(schema.items()):
# Make sure not to remove parameters named title
if key == "properties" and isinstance(value, dict) and "title" in value:
for sub_val in value.values():
_remove_title_from_schema(sub_val)
elif key == "title":
del schema[key]
elif isinstance(value, dict):
_remove_title_from_schema(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
_remove_title_from_schema(item)

View File

@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import collections
from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, get_args, get_origin
from pydantic import BaseModel, Field, create_model
from haystack import logging
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
from docstring_parser import parse
logger = logging.getLogger(__name__)
def _get_param_descriptions(method: Callable) -> Tuple[str, Dict[str, str]]:
"""
Extracts parameter descriptions from the method's docstring using docstring_parser.
:param method: The method to extract parameter descriptions from.
:returns:
A tuple including the short description of the method and a dictionary mapping parameter names to their
descriptions.
"""
docstring = getdoc(method)
if not docstring:
return "", {}
docstring_parser_import.check()
parsed_doc = parse(docstring)
param_descriptions = {}
for param in parsed_doc.params:
if not param.description:
logger.warning(
"Missing description for parameter '%s'. Please add a description in the component's "
"run() method docstring using the format ':param %%s: <description>'. "
"This description helps the LLM understand how to use this parameter." % param.arg_name
)
param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
return parsed_doc.short_description or "", param_descriptions
def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]:
    """
    Convert a Python dataclass to an equivalent Pydantic model.

    Field types are passed through `_resolve_type` (string annotations are forwarded unchanged), defaults are
    carried over, and field descriptions are taken from the `:param ...:` entries of the dataclass docstring
    when present. Fields without a documented description get "Field '<name>' of '<dataclass name>'.".

    :param dc_type: The dataclass type to convert. A dataclass instance is accepted as well; its class is used
        for naming and for the ChatMessage special case.
    :returns:
        A dynamically generated Pydantic model class, named after the dataclass, with fields and types derived
        from the dataclass definition.
    """
    _, param_descriptions = _get_param_descriptions(dc_type)
    cls = dc_type if isinstance(dc_type, type) else dc_type.__class__

    field_defs: Dict[str, Any] = {}
    for field in fields(dc_type):
        f_type = field.type if isinstance(field.type, str) else _resolve_type(field.type)

        # A default factory wins over a plain default; it is invoked once, here, to obtain the default value.
        # Fields with neither get an Ellipsis, which marks them as required for Pydantic.
        if callable(field.default_factory):
            default = field.default_factory()
        elif field.default is not MISSING:
            default = field.default
        else:
            default = ...

        # Special handling for ChatMessage since pydantic doesn't allow for field names with leading underscores.
        # The check uses `cls` (not `dc_type`) so that it also applies when an instance was passed in.
        field_name = field.name
        if cls is ChatMessage and field_name.startswith("_"):
            # We remove the underscore since ChatMessage.from_dict does allow for field names without the underscore
            field_name = field_name[1:]

        description = param_descriptions.get(field_name, f"Field '{field_name}' of '{cls.__name__}'.")
        field_defs[field_name] = (f_type, Field(default, description=description))

    return create_model(cls.__name__, **field_defs)
def _resolve_type(_type: Any) -> Any:
"""
Recursively resolve and convert complex type annotations, transforming dataclasses into Pydantic-compatible types.
This function walks through nested type annotations (e.g., List, Dict, Union) and converts any dataclass types
it encounters into corresponding Pydantic models.
:param _type: The type annotation to resolve. If the type is a dataclass, it will be converted to a Pydantic model.
For generic types (like List[SomeDataclass]), the inner types are also resolved recursively.
:returns:
A fully resolved type, with all dataclass types converted to Pydantic models
"""
if is_dataclass(_type):
return _dataclass_to_pydantic_model(_type)
origin = get_origin(_type)
args = get_args(_type)
if origin is list:
return List[_resolve_type(args[0]) if args else Any] # type: ignore[misc]
if origin is collections.abc.Sequence:
return Sequence[_resolve_type(args[0]) if args else Any] # type: ignore[misc]
if origin is Union:
return Union[tuple(_resolve_type(a) for a in args)] # type: ignore[misc]
if origin is dict:
return Dict[args[0] if args else Any, _resolve_type(args[1]) if args else Any] # type: ignore[misc]
return _type

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Refactored JSON Schema generation for ComponentTool parameters using Pydantic's model_json_schema, enabling expanded type support (e.g. Union, Enum, Dict, etc.).
We also convert dataclasses to Pydantic models before calling model_json_schema to preserve docstring descriptions of the parameters in the schema.
This means dataclasses like ChatMessage, Document, etc. now have correctly defined JSON schemas.

View File

@ -2,36 +2,62 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Dict, List, Optional
import pytest
from openai import OpenAIError
from haystack import Pipeline
from haystack import component, Pipeline
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools.tool import Tool
from haystack.tools import ComponentTool, Tool
from haystack.tools.toolset import Toolset
from haystack.utils.auth import Secret
from haystack.utils.azure import default_azure_ad_token_provider
def get_weather(city: str) -> str:
"""Get weather information for a city."""
return f"Weather info for {city}"
def get_weather(city: str) -> Dict[str, Any]:
weather_info = {
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
"Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
"Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
}
return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
@component
class MessageExtractor:
@component.output_types(messages=List[str], meta=Dict[str, Any])
def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Extracts the text content of ChatMessage objects
:param messages: List of Haystack ChatMessage objects
:param meta: Optional metadata to include in the response.
:returns:
A dictionary with keys "messages" and "meta".
"""
if meta is None:
meta = {}
return {"messages": [m.text for m in messages], "meta": meta}
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
tool = Tool(
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=get_weather,
)
return [tool]
# We add a tool that has a more complex parameter signature
message_extractor_tool = ComponentTool(
component=MessageExtractor(),
name="message_extractor",
description="Useful for returning the text content of ChatMessage objects",
)
return [weather_tool, message_extractor_tool]
class TestAzureOpenAIChatGenerator:
@ -307,7 +333,7 @@ class TestAzureOpenAIChatGenerator:
def test_to_dict_with_toolset(self, tools, monkeypatch):
"""Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
toolset = Toolset(tools[:1])
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
data = component.to_dict()

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
import os
from typing import Any, Dict
from unittest.mock import MagicMock, Mock, AsyncMock, patch
import pytest
@ -43,22 +44,24 @@ def chat_messages():
]
def get_weather(city: str) -> str:
"""Get weather information for a city."""
return f"Weather info for {city}"
def get_weather(city: str) -> Dict[str, Any]:
weather_info = {
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
"Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
"Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
}
return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
tool = Tool(
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=get_weather,
)
return [tool]
return [weather_tool]
@pytest.fixture
@ -974,7 +977,7 @@ class TestHuggingFaceAPIChatGenerator:
def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
"""Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
toolset = Toolset(tools)
toolset = Toolset(tools[:1])
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
)

View File

@ -8,6 +8,7 @@ import pytest
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
@ -16,11 +17,12 @@ from openai.types.completion_usage import CompletionTokensDetails, CompletionUsa
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat import chat_completion_chunk
from haystack import component
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.tools import ComponentTool, Tool
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.tools.toolset import Toolset
@ -72,21 +74,47 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream):
yield mock_chat_completion_create
def mock_tool_function(x):
return x
def weather_function(city: str) -> Dict[str, Any]:
weather_info = {
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
"Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
"Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
}
return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
@component
class MessageExtractor:
@component.output_types(messages=List[str], meta=Dict[str, Any])
def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Extracts the text content of ChatMessage objects
:param messages: List of Haystack ChatMessage objects
:param meta: Optional metadata to include in the response.
:returns:
A dictionary with keys "messages" and "meta".
"""
if meta is None:
meta = {}
return {"messages": [m.text for m in messages], "meta": meta}
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
tool = Tool(
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
function=mock_tool_function,
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)
return [tool]
# We add a tool that has a more complex parameter signature
message_extractor_tool = ComponentTool(
component=MessageExtractor(),
name="message_extractor",
description="Useful for returning the text content of ChatMessage objects",
)
return [weather_tool, message_extractor_tool]
class TestOpenAIChatGenerator:
@ -462,7 +490,9 @@ class TestOpenAIChatGenerator:
mock_chat_completion_create.return_value = completion
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools, tools_strict=True)
component = OpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"), tools=tools[:1], tools_strict=True
)
response = component.run([ChatMessage.from_user("What's the weather like in Paris?")])
# ensure that the tools are passed to the OpenAI API

View File

@ -24,7 +24,10 @@ from haystack.dataclasses import ChatMessage, ChatRole, Document
from haystack.tools import ComponentTool
from haystack.utils.auth import Secret
### Component and Model Definitions
from test.tools.test_parameters_schema_utils import BYTE_STREAM_SCHEMA, DOCUMENT_SCHEMA, SPARSE_EMBEDDING_SCHEMA
# Component and Model Definitions
@component
@ -134,17 +137,18 @@ def output_handler(old, new):
return old + new
## Unit tests
class TestToolComponent:
# TODO Add test for Builder components that have dynamic input types
# Does create_parameters schema work in these cases?
# Unit tests
class TestComponentTool:
def test_from_component_basic(self):
component = SimpleComponent()
tool = ComponentTool(component=component)
tool = ComponentTool(component=SimpleComponent())
assert tool.name == "simple_component"
assert tool.description == "A simple component that generates text."
assert tool.parameters == {
"type": "object",
"description": "A simple component that generates text.",
"properties": {"text": {"type": "string", "description": "user's name"}},
"required": ["text"],
}
@ -156,44 +160,39 @@ class TestToolComponent:
assert result["reply"] == "Hello, world!"
def test_from_component_long_description(self):
component = SimpleComponent()
tool = ComponentTool(component=component, description="".join(["A"] * 1024))
tool = ComponentTool(component=SimpleComponent(), description="".join(["A"] * 1024))
assert len(tool.description) == 1024
def test_from_component_with_inputs(self):
component = SimpleComponent()
tool = ComponentTool(component=component, inputs_from_state={"text": "text"})
tool = ComponentTool(component=SimpleComponent(), inputs_from_state={"text": "text"})
assert tool.inputs_from_state == {"text": "text"}
# Inputs should be excluded from schema generation
assert tool.parameters == {"type": "object", "properties": {}}
assert tool.parameters == {
"type": "object",
"properties": {},
"description": "A simple component that generates text.",
}
def test_from_component_with_outputs(self):
component = SimpleComponent()
tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}})
tool = ComponentTool(component=SimpleComponent(), outputs_to_state={"replies": {"source": "reply"}})
assert tool.outputs_to_state == {"replies": {"source": "reply"}}
def test_from_component_with_dataclass(self):
component = UserGreeter()
tool = ComponentTool(component=component)
tool = ComponentTool(component=UserGreeter())
assert tool.parameters == {
"type": "object",
"properties": {
"user": {
"type": "object",
"description": "The User object to process.",
"$defs": {
"User": {
"properties": {
"name": {"type": "string", "description": "Field 'name' of 'User'."},
"age": {"type": "integer", "description": "Field 'age' of 'User'."},
"name": {"description": "Field 'name' of 'User'.", "type": "string", "default": "Anonymous"},
"age": {"description": "Field 'age' of 'User'.", "type": "integer", "default": 0},
},
"type": "object",
}
},
"description": "A simple component that processes a User.",
"properties": {"user": {"$ref": "#/$defs/User", "description": "The User object to process."}},
"required": ["user"],
"type": "object",
}
assert tool.name == "user_greeter"
@ -206,14 +205,13 @@ class TestToolComponent:
assert result["message"] == "User Alice is 30 years old"
def test_from_component_with_list_input(self):
component = ListProcessor()
tool = ComponentTool(
component=component, name="list_processing_tool", description="A tool that concatenates strings"
component=ListProcessor(), name="list_processing_tool", description="A tool that concatenates strings"
)
assert tool.parameters == {
"type": "object",
"description": "Concatenates a list of strings into a single string.",
"properties": {
"texts": {
"type": "array",
@ -231,30 +229,33 @@ class TestToolComponent:
assert result["concatenated"] == "hello world"
def test_from_component_with_nested_dataclass(self):
component = PersonProcessor()
tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people")
tool = ComponentTool(
component=PersonProcessor(), name="person_tool", description="A tool that processes people"
)
assert tool.parameters == {
"type": "object",
"properties": {
"person": {
"type": "object",
"description": "The Person to process.",
"$defs": {
"Address": {
"properties": {
"name": {"type": "string", "description": "Field 'name' of 'Person'."},
"address": {
"type": "object",
"description": "Field 'address' of 'Person'.",
"properties": {
"street": {"type": "string", "description": "Field 'street' of 'Address'."},
"city": {"type": "string", "description": "Field 'city' of 'Address'."},
},
},
"street": {"description": "Field 'street' of 'Address'.", "type": "string"},
"city": {"description": "Field 'city' of 'Address'.", "type": "string"},
},
}
"required": ["street", "city"],
"type": "object",
},
"Person": {
"properties": {
"name": {"description": "Field 'name' of 'Person'.", "type": "string"},
"address": {"$ref": "#/$defs/Address", "description": "Field 'address' of 'Person'."},
},
"required": ["name", "address"],
"type": "object",
},
},
"description": "Creates information about the person.",
"properties": {"person": {"$ref": "#/$defs/Person", "description": "The Person to process."}},
"required": ["person"],
"type": "object",
}
# Test tool invocation
@ -264,64 +265,29 @@ class TestToolComponent:
assert result["info"] == "Diana lives at 123 Elm Street, Metropolis."
def test_from_component_with_document_list(self):
component = DocumentProcessor()
tool = ComponentTool(
component=component, name="document_processor", description="A tool that concatenates document contents"
component=DocumentProcessor(),
name="document_processor",
description="A tool that concatenates document contents",
)
assert tool.parameters == {
"type": "object",
"$defs": {
"ByteStream": BYTE_STREAM_SCHEMA,
"Document": DOCUMENT_SCHEMA,
"SparseEmbedding": SPARSE_EMBEDDING_SCHEMA,
},
"description": "Concatenates the content of multiple documents with newlines.",
"properties": {
"documents": {
"type": "array",
"description": "List of Documents whose content will be concatenated",
"items": {
"type": "object",
"properties": {
"id": {"type": "string", "description": "Field 'id' of 'Document'."},
"content": {"type": "string", "description": "Field 'content' of 'Document'."},
"blob": {
"type": "object",
"description": "Field 'blob' of 'Document'.",
"properties": {
"data": {"type": "string", "description": "Field 'data' of 'ByteStream'."},
"meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."},
"mime_type": {
"type": "string",
"description": "Field 'mime_type' of 'ByteStream'.",
},
},
},
"meta": {"type": "string", "description": "Field 'meta' of 'Document'."},
"score": {"type": "number", "description": "Field 'score' of 'Document'."},
"embedding": {
"type": "array",
"description": "Field 'embedding' of 'Document'.",
"items": {"type": "number"},
},
"sparse_embedding": {
"type": "object",
"description": "Field 'sparse_embedding' of 'Document'.",
"properties": {
"indices": {
"type": "array",
"description": "Field 'indices' of 'SparseEmbedding'.",
"items": {"type": "integer"},
},
"values": {
"type": "array",
"description": "Field 'values' of 'SparseEmbedding'.",
"items": {"type": "number"},
},
},
},
},
},
"items": {"$ref": "#/$defs/Document"},
"type": "array",
},
"top_k": {"description": "The number of top documents to concatenate", "type": "integer"},
"top_k": {"description": "The number of top documents to concatenate", "type": "integer", "default": 5},
},
"required": ["documents"],
"type": "object",
}
# Test tool invocation
@ -341,15 +307,16 @@ class TestToolComponent:
ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail")
## Integration tests
# Integration tests
class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_component_tool_in_pipeline(self):
# Create component and convert it to tool
component = SimpleComponent()
tool = ComponentTool(
component=component, name="hello_tool", description="A tool that generates a greeting message for the user"
component=SimpleComponent(),
name="hello_tool",
description="A tool that generates a greeting message for the user",
)
# Create pipeline with OpenAIChatGenerator and ToolInvoker
@ -378,9 +345,10 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.integration
def test_component_tool_in_pipeline_openai_tools_strict(self):
# Create component and convert it to tool
component = SimpleComponent()
tool = ComponentTool(
component=component, name="hello_tool", description="A tool that generates a greeting message for the user"
component=SimpleComponent(),
name="hello_tool",
description="A tool that generates a greeting message for the user",
)
# Create pipeline with OpenAIChatGenerator and ToolInvoker
@ -408,9 +376,8 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_user_greeter_in_pipeline(self):
component = UserGreeter()
tool = ComponentTool(
component=component, name="user_greeter", description="A tool that greets users with their name and age"
component=UserGreeter(), name="user_greeter", description="A tool that greets users with their name and age"
)
pipeline = Pipeline()
@ -432,9 +399,8 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_list_processor_in_pipeline(self):
component = ListProcessor()
tool = ComponentTool(
component=component, name="list_processor", description="A tool that concatenates a list of strings"
component=ListProcessor(), name="list_processor", description="A tool that concatenates a list of strings"
)
pipeline = Pipeline()
@ -456,9 +422,8 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_person_processor_in_pipeline(self):
component = PersonProcessor()
tool = ComponentTool(
component=component,
component=PersonProcessor(),
name="person_processor",
description="A tool that processes information about a person and their address",
)
@ -482,9 +447,8 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_document_processor_in_pipeline(self):
component = DocumentProcessor()
tool = ComponentTool(
component=component,
component=DocumentProcessor(),
name="document_processor",
description="A tool that concatenates the content of multiple documents",
)
@ -516,9 +480,8 @@ class TestToolComponentInPipelineWithOpenAI:
def test_lost_in_middle_ranker_in_pipeline(self):
from haystack.components.rankers import LostInTheMiddleRanker
component = LostInTheMiddleRanker()
tool = ComponentTool(
component=component,
component=LostInTheMiddleRanker(),
name="lost_in_middle_ranker",
description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results",
)
@ -543,9 +506,10 @@ class TestToolComponentInPipelineWithOpenAI:
@pytest.mark.skipif(not os.environ.get("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY not set")
@pytest.mark.integration
def test_serper_dev_web_search_in_pipeline(self):
component = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)
tool = ComponentTool(
component=component, name="web_search", description="Search the web for current information on any topic"
component=SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3),
name="web_search",
description="Search the web for current information on any topic",
)
pipeline = Pipeline()
@ -603,10 +567,8 @@ class TestToolComponentInPipelineWithOpenAI:
assert new_pipeline == pipeline
def test_component_tool_serde(self):
component = SimpleComponent()
tool = ComponentTool(
component=component,
component=SimpleComponent(),
name="simple_tool",
description="A simple tool",
inputs_from_state={"test": "input"},
@ -632,16 +594,16 @@ class TestToolComponentInPipelineWithOpenAI:
assert isinstance(new_tool._component, SimpleComponent)
def test_pipeline_component_fails(self):
    """Check that ComponentTool rejects a component that was already added to a pipeline."""
    comp = SimpleComponent()

    # Create a pipeline and add the component to it
    pipeline = Pipeline()
    pipeline.add_component("simple", comp)

    # Creating a tool from the component should fail because the component has been added to a
    # pipeline and thus can't be used as a tool
    with pytest.raises(ValueError, match="Component has been added to a pipeline"):
        ComponentTool(component=comp)
def test_deepcopy_with_jinja_based_component(self):
builder = PromptBuilder("{{query}}")

View File

@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from typing import List
from haystack.dataclasses import ByteStream, ChatMessage, Document, TextContent, ToolCall, ToolCallResult
from pydantic import Field, create_model
from haystack.tools.parameters_schema_utils import _resolve_type
from haystack.tools.from_function import _remove_title_from_schema
# Expected JSON-schema fragments for the Haystack dataclasses. They are used as the
# expected `$defs` entries by the parametrized test in this module.


def _typed_schema(json_type, description, **extra):
    """Return a property schema with a JSON `type`, a `description` and any extra keywords."""
    return {"type": json_type, "description": description, **extra}


def _nullable_schema(schema, description):
    """Return a property schema that accepts `schema` or null and defaults to None."""
    return {"anyOf": [schema, {"type": "null"}], "default": None, "description": description}


def _ref_schema(definition_name, description=None):
    """Return a `$ref` into `#/$defs`, optionally annotated with a description."""
    reference = {"$ref": "#/$defs/" + definition_name}
    if description is not None:
        reference["description"] = description
    return reference


def _object_schema(properties, required=None):
    """Return an object schema; the `required` key is only emitted when a list is given."""
    schema = {"type": "object", "properties": properties}
    if required is not None:
        schema["required"] = required
    return schema


BYTE_STREAM_SCHEMA = _object_schema(
    {
        "data": _typed_schema("string", "The binary data stored in Bytestream.", format="binary"),
        "meta": _typed_schema(
            "object",
            "Additional metadata to be stored with the ByteStream.",
            default={},
            additionalProperties=True,
        ),
        "mime_type": _nullable_schema({"type": "string"}, "The mime type of the binary data."),
    },
    required=["data"],
)

SPARSE_EMBEDDING_SCHEMA = _object_schema(
    {
        "indices": _typed_schema(
            "array", "List of indices of non-zero elements in the embedding.", items={"type": "integer"}
        ),
        "values": _typed_schema(
            "array", "List of values of non-zero elements in the embedding.", items={"type": "number"}
        ),
    },
    required=["indices", "values"],
)

# No "required" key here: every property of this schema carries a default.
DOCUMENT_SCHEMA = _object_schema(
    {
        "id": _typed_schema(
            "string",
            "Unique identifier for the document. When not set, it's generated based on the Document fields' values.",
            default="",
        ),
        "content": _nullable_schema({"type": "string"}, "Text of the document, if the document contains text."),
        "blob": _nullable_schema(
            _ref_schema("ByteStream"),
            "Binary data associated with the document, if the document has any binary data associated with it.",
        ),
        "meta": _typed_schema(
            "object",
            "Additional custom metadata for the document. Must be JSON-serializable.",
            default={},
            additionalProperties=True,
        ),
        "score": _nullable_schema(
            {"type": "number"}, "Score of the document. Used for ranking, usually assigned by retrievers."
        ),
        "embedding": _nullable_schema(
            {"type": "array", "items": {"type": "number"}}, "dense vector representation of the document."
        ),
        "sparse_embedding": _nullable_schema(
            _ref_schema("SparseEmbedding"), "sparse vector representation of the document."
        ),
    }
)

TEXT_CONTENT_SCHEMA = _object_schema(
    {"text": _typed_schema("string", "The text content of the message.")}, required=["text"]
)

TOOL_CALL_SCHEMA = _object_schema(
    {
        "tool_name": _typed_schema("string", "The name of the Tool to call."),
        "arguments": _typed_schema("object", "The arguments to call the Tool with.", additionalProperties=True),
        "id": _nullable_schema({"type": "string"}, "The ID of the Tool call."),
    },
    required=["tool_name", "arguments"],
)

TOOL_CALL_RESULT_SCHEMA = _object_schema(
    {
        "result": _typed_schema("string", "The result of the Tool invocation."),
        "origin": _ref_schema("ToolCall", "The Tool call that produced this result."),
        "error": _typed_schema("boolean", "Whether the Tool invocation resulted in an error."),
    },
    required=["result", "origin", "error"],
)

CHAT_ROLE_SCHEMA = _typed_schema(
    "string", "Enumeration representing the roles within a chat.", enum=["user", "system", "assistant", "tool"]
)

CHAT_MESSAGE_SCHEMA = _object_schema(
    {
        "role": _ref_schema("ChatRole", "Field 'role' of 'ChatMessage'."),
        "content": _typed_schema(
            "array",
            "Field 'content' of 'ChatMessage'.",
            items={"anyOf": [_ref_schema(name) for name in ("TextContent", "ToolCall", "ToolCallResult")]},
        ),
        "name": _nullable_schema({"type": "string"}, "Field 'name' of 'ChatMessage'."),
        "meta": _typed_schema("object", "Field 'meta' of 'ChatMessage'.", default={}, additionalProperties=True),
    },
    required=["role", "content"],
)
@pytest.mark.parametrize(
    "python_type, description, expected_schema, expected_defs_schema",
    [
        (
            ByteStream,
            "A byte stream",
            {"$ref": "#/$defs/ByteStream", "description": "A byte stream"},
            {"ByteStream": BYTE_STREAM_SCHEMA},
        ),
        (
            Document,
            "A document",
            {"$ref": "#/$defs/Document", "description": "A document"},
            {
                "Document": DOCUMENT_SCHEMA,
                "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA,
                "ByteStream": BYTE_STREAM_SCHEMA,
            },
        ),
        (
            TextContent,
            "A text content",
            {"$ref": "#/$defs/TextContent", "description": "A text content"},
            {"TextContent": TEXT_CONTENT_SCHEMA},
        ),
        (
            ToolCall,
            "A tool call",
            {"$ref": "#/$defs/ToolCall", "description": "A tool call"},
            {"ToolCall": TOOL_CALL_SCHEMA},
        ),
        (
            ToolCallResult,
            "A tool call result",
            {"$ref": "#/$defs/ToolCallResult", "description": "A tool call result"},
            {"ToolCallResult": TOOL_CALL_RESULT_SCHEMA, "ToolCall": TOOL_CALL_SCHEMA},
        ),
        (
            ChatMessage,
            "A chat message",
            {"$ref": "#/$defs/ChatMessage", "description": "A chat message"},
            {
                "ChatMessage": CHAT_MESSAGE_SCHEMA,
                "TextContent": TEXT_CONTENT_SCHEMA,
                "ToolCall": TOOL_CALL_SCHEMA,
                "ToolCallResult": TOOL_CALL_RESULT_SCHEMA,
                "ChatRole": CHAT_ROLE_SCHEMA,
            },
        ),
        (
            List[Document],
            "A list of documents",
            {"type": "array", "description": "A list of documents", "items": {"$ref": "#/$defs/Document"}},
            {
                "Document": DOCUMENT_SCHEMA,
                "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA,
                "ByteStream": BYTE_STREAM_SCHEMA,
            },
        ),
        (
            List[ChatMessage],
            "A list of chat messages",
            {"type": "array", "description": "A list of chat messages", "items": {"$ref": "#/$defs/ChatMessage"}},
            {
                "ChatMessage": CHAT_MESSAGE_SCHEMA,
                "TextContent": TEXT_CONTENT_SCHEMA,
                "ToolCall": TOOL_CALL_SCHEMA,
                "ToolCallResult": TOOL_CALL_RESULT_SCHEMA,
                "ChatRole": CHAT_ROLE_SCHEMA,
            },
        ),
    ],
)
def test_create_parameters_schema_haystack_dataclasses(python_type, description, expected_schema, expected_defs_schema):
    """
    Check the JSON schema generated for a parameter typed as a Haystack dataclass.

    The type is passed through `_resolve_type`, wrapped in a pydantic model named "run" with a single
    `input_name` field carrying `description`, and the model's JSON schema (after
    `_remove_title_from_schema`) is compared with the expected property schema and `$defs`.

    :param python_type: The type annotation under test.
    :param description: The description attached to the `input_name` field.
    :param expected_schema: The expected schema of the `input_name` property.
    :param expected_defs_schema: The expected content of the schema's `$defs` section.
    """
    annotation = _resolve_type(python_type)
    field_info = Field(default=..., description=description)
    run_model = create_model("run", __doc__="A test function", input_name=(annotation, field_info))

    schema = run_model.model_json_schema()
    _remove_title_from_schema(schema)

    # Referenced dataclasses are expected under "$defs"; the property itself only holds a "$ref" to them.
    assert schema["$defs"] == expected_defs_schema
    assert schema["properties"]["input_name"] == expected_schema