diff --git a/haystack/dataclasses/byte_stream.py b/haystack/dataclasses/byte_stream.py
index 34b66add8..fac6e8eb4 100644
--- a/haystack/dataclasses/byte_stream.py
+++ b/haystack/dataclasses/byte_stream.py
@@ -11,6 +11,10 @@ from typing import Any, Dict, Optional
 class ByteStream:
     """
     Base data class representing a binary object in the Haystack API.
+
+    :param data: The binary data stored in Bytestream.
+    :param meta: Additional metadata to be stored with the ByteStream.
+    :param mime_type: The mime type of the binary data.
     """

     data: bytes
diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py
index 1e58459d4..82ed1b744 100644
--- a/haystack/tools/component_tool.py
+++ b/haystack/tools/component_tool.py
@@ -2,11 +2,9 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from dataclasses import fields, is_dataclass
-from inspect import getdoc
 from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin

-from pydantic import TypeAdapter
+from pydantic import Field, TypeAdapter, create_model

 from haystack import logging
 from haystack.core.component import Component
@@ -16,15 +14,12 @@ from haystack.core.serialization import (
     generate_qualified_class_name,
     import_class_by_name,
 )
-from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool
 from haystack.tools.errors import SchemaGenerationError
+from haystack.tools.from_function import _remove_title_from_schema
+from haystack.tools.parameters_schema_utils import _get_param_descriptions, _resolve_type
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

-with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
-    from docstring_parser import parse
-
-
 logger = logging.getLogger(__name__)
@@ -275,10 +270,10 @@ class ComponentTool(Tool):
         :raises SchemaGenerationError: If schema generation fails
         :returns: OpenAI tools schema for the component's run method parameters.
         """
-        properties = {}
-        required = []
+        component_run_description, param_descriptions = _get_param_descriptions(component.run)

-        param_descriptions = self._get_param_descriptions(component.run)
+        # collect fields (types and defaults) and descriptions from function parameters
+        fields: Dict[str, Any] = {}

         for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
             if inputs_from_state is not None and input_name in inputs_from_state:
@@ -286,135 +281,23 @@
             input_type = socket.type
             description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

-            try:
-                property_schema = self._create_property_schema(input_type, description)
-            except Exception as e:
-                raise SchemaGenerationError(
-                    f"Error processing input '{input_name}': {e}. "
-                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
-                    f"and lists of these types as input types for component's run method."
-                ) from e
+            # if the parameter does not have a default value, Pydantic requires an Ellipsis (...)
+            # to explicitly indicate that the parameter is required
+            default = ... if socket.is_mandatory else socket.default_value
+            resolved_type = _resolve_type(input_type)
+            fields[input_name] = (resolved_type, Field(default=default, description=description))

-            properties[input_name] = property_schema
+        try:
+            model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
+            parameters_schema = model.model_json_schema()
+        except Exception as e:
+            raise SchemaGenerationError(
+                f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
+            ) from e

-            # Use socket.is_mandatory to check if the input is required
-            if socket.is_mandatory:
-                required.append(input_name)
-
-        parameters_schema = {"type": "object", "properties": properties}
-
-        if required:
-            parameters_schema["required"] = required
+        # we don't want to include title keywords in the schema, as they contain redundant information
+        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
+        # see https://github.com/pydantic/pydantic/discussions/8504
+        _remove_title_from_schema(parameters_schema)

         return parameters_schema
-
-    @staticmethod
-    def _get_param_descriptions(method: Callable) -> Dict[str, str]:
-        """
-        Extracts parameter descriptions from the method's docstring using docstring_parser.
-
-        :param method: The method to extract parameter descriptions from.
-        :returns: A dictionary mapping parameter names to their descriptions.
-        """
-        docstring = getdoc(method)
-        if not docstring:
-            return {}
-
-        docstring_parser_import.check()
-        parsed_doc = parse(docstring)
-        param_descriptions = {}
-        for param in parsed_doc.params:
-            if not param.description:
-                logger.warning(
-                    "Missing description for parameter '%s'. Please add a description in the component's "
-                    "run() method docstring using the format ':param %%s: <description>'. "
-                    "This description helps the LLM understand how to use this parameter." % param.arg_name
-                )
-            param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
-        return param_descriptions
-
-    @staticmethod
-    def _is_nullable_type(python_type: Any) -> bool:
-        """
-        Checks if the type is a Union with NoneType (i.e., Optional).
-
-        :param python_type: The Python type to check.
-        :returns: True if the type is a Union with NoneType, False otherwise.
-        """
-        origin = get_origin(python_type)
-        if origin is Union:
-            return type(None) in get_args(python_type)
-        return False
-
-    def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
-        """
-        Creates a schema for a list type.
-
-        :param item_type: The type of items in the list.
-        :param description: The description of the list.
-        :returns: A dictionary representing the list schema.
-        """
-        items_schema = self._create_property_schema(item_type, "")
-        items_schema.pop("description", None)
-        return {"type": "array", "description": description, "items": items_schema}
-
-    def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
-        """
-        Creates a schema for a dataclass.
-
-        :param python_type: The dataclass type.
-        :param description: The description of the dataclass.
-        :returns: A dictionary representing the dataclass schema.
-        """
-        schema = {"type": "object", "description": description, "properties": {}}
-        cls = python_type if isinstance(python_type, type) else python_type.__class__
-        for field in fields(cls):
-            field_description = f"Field '{field.name}' of '{cls.__name__}'."
- if isinstance(schema["properties"], dict): - schema["properties"][field.name] = self._create_property_schema(field.type, field_description) - return schema - - @staticmethod - def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]: - """ - Creates a schema for a basic Python type. - - :param python_type: The Python type. - :param description: The description of the type. - :returns: A dictionary representing the basic type schema. - """ - type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} - return {"type": type_mapping.get(python_type, "string"), "description": description} - - def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: - """ - Creates a property schema for a given Python type, recursively if necessary. - - :param python_type: The Python type to create a property schema for. - :param description: The description of the property. - :param default: The default value of the property. - :returns: A dictionary representing the property schema. - :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models - """ - nullable = self._is_nullable_type(python_type) - if nullable: - non_none_types = [t for t in get_args(python_type) if t is not type(None)] - python_type = non_none_types[0] if non_none_types else str - - origin = get_origin(python_type) - if origin is list: - schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) - elif is_dataclass(python_type): - schema = self._create_dataclass_schema(python_type, description) - elif hasattr(python_type, "model_validate"): - raise SchemaGenerationError( - f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for " - f"component's run method." - ) - else: - schema = self._create_basic_type_schema(python_type, description) - - if default is not None: - schema["default"] = default - - return schema diff --git a/haystack/tools/from_function.py b/haystack/tools/from_function.py index f48579b9f..edea78420 100644 --- a/haystack/tools/from_function.py +++ b/haystack/tools/from_function.py @@ -210,9 +210,16 @@ def _remove_title_from_schema(schema: Dict[str, Any]): :param schema: The JSON schema to remove the 'title' keyword from. 
""" - schema.pop("title", None) - - for property_schema in schema["properties"].values(): - for key in list(property_schema.keys()): - if key == "title": - del property_schema[key] + for key, value in list(schema.items()): + # Make sure not to remove parameters named title + if key == "properties" and isinstance(value, dict) and "title" in value: + for sub_val in value.values(): + _remove_title_from_schema(sub_val) + elif key == "title": + del schema[key] + elif isinstance(value, dict): + _remove_title_from_schema(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + _remove_title_from_schema(item) diff --git a/haystack/tools/parameters_schema_utils.py b/haystack/tools/parameters_schema_utils.py new file mode 100644 index 000000000..c322785a7 --- /dev/null +++ b/haystack/tools/parameters_schema_utils.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import collections +from dataclasses import MISSING, fields, is_dataclass +from inspect import getdoc +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, get_args, get_origin + +from pydantic import BaseModel, Field, create_model + +from haystack import logging +from haystack.dataclasses import ChatMessage +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import: + from docstring_parser import parse + + +logger = logging.getLogger(__name__) + + +def _get_param_descriptions(method: Callable) -> Tuple[str, Dict[str, str]]: + """ + Extracts parameter descriptions from the method's docstring using docstring_parser. + + :param method: The method to extract parameter descriptions from. + :returns: + A tuple including the short description of the method and a dictionary mapping parameter names to their + descriptions. + """ + docstring = getdoc(method) + if not docstring: + return "", {} + + docstring_parser_import.check() + parsed_doc = parse(docstring) + param_descriptions = {} + for param in parsed_doc.params: + if not param.description: + logger.warning( + "Missing description for parameter '%s'. Please add a description in the component's " + "run() method docstring using the format ':param %%s: '. " + "This description helps the LLM understand how to use this parameter." % param.arg_name + ) + param_descriptions[param.arg_name] = param.description.strip() if param.description else "" + return parsed_doc.short_description or "", param_descriptions + + +def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]: + """ + Convert a Python dataclass to an equivalent Pydantic model. + + :param dc_type: The dataclass type to convert. + :returns: + A dynamically generated Pydantic model class with fields and types derived from the dataclass definition. + Field descriptions are extracted from docstrings when available. + """ + _, param_descriptions = _get_param_descriptions(dc_type) + cls = dc_type if isinstance(dc_type, type) else dc_type.__class__ + + field_defs: Dict[str, Any] = {} + for field in fields(dc_type): + f_type = field.type if isinstance(field.type, str) else _resolve_type(field.type) + default = field.default if field.default is not MISSING else ... 
+        default = field.default_factory() if callable(field.default_factory) else default
+
+        # Special handling for ChatMessage since pydantic doesn't allow field names with leading underscores
+        field_name = field.name
+        if dc_type is ChatMessage and field_name.startswith("_"):
+            # We remove the underscore since ChatMessage.from_dict accepts field names without the leading underscore
+            field_name = field_name[1:]
+
+        description = param_descriptions.get(field_name, f"Field '{field_name}' of '{cls.__name__}'.")
+        field_defs[field_name] = (f_type, Field(default, description=description))
+
+    model = create_model(cls.__name__, **field_defs)
+    return model
+
+
+def _resolve_type(_type: Any) -> Any:
+    """
+    Recursively resolve and convert complex type annotations, transforming dataclasses into Pydantic-compatible types.
+
+    This function walks through nested type annotations (e.g., List, Dict, Union) and converts any dataclass types
+    it encounters into corresponding Pydantic models.
+
+    :param _type: The type annotation to resolve. If the type is a dataclass, it will be converted to a Pydantic model.
+        For generic types (like List[SomeDataclass]), the inner types are also resolved recursively.
+
+    :returns:
+        A fully resolved type, with all dataclass types converted to Pydantic models.
+    """
+    if is_dataclass(_type):
+        return _dataclass_to_pydantic_model(_type)
+
+    origin = get_origin(_type)
+    args = get_args(_type)
+
+    if origin is list:
+        return List[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]
+
+    if origin is collections.abc.Sequence:
+        return Sequence[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]
+
+    if origin is Union:
+        return Union[tuple(_resolve_type(a) for a in args)]  # type: ignore[misc]
+
+    if origin is dict:
+        return Dict[args[0] if args else Any, _resolve_type(args[1]) if args else Any]  # type: ignore[misc]
+
+    return _type
diff --git a/releasenotes/notes/expand-json-schema-parameter-generation-e4d60bbeada14b6c.yaml b/releasenotes/notes/expand-json-schema-parameter-generation-e4d60bbeada14b6c.yaml
new file mode 100644
index 000000000..6de36b01d
--- /dev/null
+++ b/releasenotes/notes/expand-json-schema-parameter-generation-e4d60bbeada14b6c.yaml
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    Refactored JSON Schema generation for ComponentTool parameters to use Pydantic's model_json_schema, enabling support for a wider range of types (e.g. Union, Enum, and Dict).
+    We also convert dataclasses to Pydantic models before calling model_json_schema to preserve docstring descriptions of the parameters in the schema.
+    This means dataclasses such as ChatMessage and Document now have correctly defined JSON schemas.
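Taken together, `_resolve_type`, `create_model`, and `_remove_title_from_schema` replace the hand-rolled schema builder. Below is a minimal sketch of the new flow, mirroring what `ComponentTool._create_parameters_schema` now does internally; it is not part of the diff, and the `City` dataclass and literal strings are purely illustrative.

```python
from dataclasses import dataclass

from pydantic import Field, create_model

from haystack.tools.from_function import _remove_title_from_schema
from haystack.tools.parameters_schema_utils import _resolve_type


@dataclass
class City:
    """A city.

    :param name: The name of the city.
    """

    name: str


# Dataclass annotations are converted to Pydantic models first
resolved_type = _resolve_type(City)

# Ellipsis (...) marks the parameter as required, exactly as in _create_parameters_schema
fields = {"city": (resolved_type, Field(default=..., description="The city to look up."))}
model = create_model("run", __doc__="Looks up a city.", **fields)

parameters_schema = model.model_json_schema()
_remove_title_from_schema(parameters_schema)  # strip the redundant 'title' keywords

# parameters_schema now holds a '$defs' entry for City and a '$ref' under
# properties['city'], analogous to the expectations in the updated tests below.
```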
diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index ad1c23feb..e5afcce04 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -2,36 +2,62 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from typing import Any, Dict, List, Optional import pytest from openai import OpenAIError -from haystack import Pipeline +from haystack import component, Pipeline from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ToolCall -from haystack.tools.tool import Tool +from haystack.tools import ComponentTool, Tool from haystack.tools.toolset import Toolset from haystack.utils.auth import Secret from haystack.utils.azure import default_azure_ad_token_provider -def get_weather(city: str) -> str: - """Get weather information for a city.""" - return f"Weather info for {city}" +def get_weather(city: str) -> Dict[str, Any]: + weather_info = { + "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, + "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, + "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, + } + return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) + + +@component +class MessageExtractor: + @component.output_types(messages=List[str], meta=Dict[str, Any]) + def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Extracts the text content of ChatMessage objects + + :param messages: List of Haystack ChatMessage objects + :param meta: Optional metadata to include in the response. + :returns: + A dictionary with keys "messages" and "meta". 
+ """ + if meta is None: + meta = {} + return {"messages": [m.text for m in messages], "meta": meta} @pytest.fixture def tools(): - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - tool = Tool( + weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", - parameters=tool_parameters, + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, function=get_weather, ) - - return [tool] + # We add a tool that has a more complex parameter signature + message_extractor_tool = ComponentTool( + component=MessageExtractor(), + name="message_extractor", + description="Useful for returning the text content of ChatMessage objects", + ) + return [weather_tool, message_extractor_tool] class TestAzureOpenAIChatGenerator: @@ -307,7 +333,7 @@ class TestAzureOpenAIChatGenerator: def test_to_dict_with_toolset(self, tools, monkeypatch): """Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset.""" monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") - toolset = Toolset(tools) + toolset = Toolset(tools[:1]) component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset) data = component.to_dict() diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index f376886c5..f12a37ee3 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from datetime import datetime import os +from typing import Any, Dict from unittest.mock import MagicMock, Mock, AsyncMock, patch import pytest @@ -43,22 +44,24 @@ def chat_messages(): ] -def get_weather(city: str) -> str: - """Get weather information for a city.""" - return f"Weather info for {city}" +def get_weather(city: str) -> Dict[str, Any]: + weather_info = { + "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, + "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, + "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, + } + return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) @pytest.fixture def tools(): - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - tool = Tool( + weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", - parameters=tool_parameters, + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, function=get_weather, ) - - return [tool] + return [weather_tool] @pytest.fixture @@ -974,7 +977,7 @@ class TestHuggingFaceAPIChatGenerator: def test_to_dict_with_toolset(self, mock_check_valid_model, tools): """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset.""" - toolset = Toolset(tools) + toolset = Toolset(tools[:1]) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset ) diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 343a4f9fb..d393e33a1 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -8,6 +8,7 @@ import pytest import logging import os from 
datetime import datetime +from typing import Any, Dict, List, Optional from openai import OpenAIError from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall @@ -16,11 +17,12 @@ from openai.types.completion_usage import CompletionTokensDetails, CompletionUsa from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk +from haystack import component from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret from haystack.dataclasses import ChatMessage, ToolCall -from haystack.tools import Tool +from haystack.tools import ComponentTool, Tool from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.tools.toolset import Toolset @@ -72,21 +74,47 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream): yield mock_chat_completion_create -def mock_tool_function(x): - return x +def weather_function(city: str) -> Dict[str, Any]: + weather_info = { + "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, + "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, + "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, + } + return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) + + +@component +class MessageExtractor: + @component.output_types(messages=List[str], meta=Dict[str, Any]) + def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Extracts the text content of ChatMessage objects + + :param messages: List of Haystack ChatMessage objects + :param meta: Optional metadata to include in the response. + :returns: + A dictionary with keys "messages" and "meta". 
+ """ + if meta is None: + meta = {} + return {"messages": [m.text for m in messages], "meta": meta} @pytest.fixture def tools(): - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - tool = Tool( + weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", - parameters=tool_parameters, - function=mock_tool_function, + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=weather_function, ) - - return [tool] + # We add a tool that has a more complex parameter signature + message_extractor_tool = ComponentTool( + component=MessageExtractor(), + name="message_extractor", + description="Useful for returning the text content of ChatMessage objects", + ) + return [weather_tool, message_extractor_tool] class TestOpenAIChatGenerator: @@ -462,7 +490,9 @@ class TestOpenAIChatGenerator: mock_chat_completion_create.return_value = completion - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools, tools_strict=True) + component = OpenAIChatGenerator( + api_key=Secret.from_token("test-api-key"), tools=tools[:1], tools_strict=True + ) response = component.run([ChatMessage.from_user("What's the weather like in Paris?")]) # ensure that the tools are passed to the OpenAI API diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index e2f292fa0..549200a66 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -24,7 +24,10 @@ from haystack.dataclasses import ChatMessage, ChatRole, Document from haystack.tools import ComponentTool from haystack.utils.auth import Secret -### Component and Model Definitions +from test.tools.test_parameters_schema_utils import BYTE_STREAM_SCHEMA, DOCUMENT_SCHEMA, SPARSE_EMBEDDING_SCHEMA + + +# Component and Model Definitions @component @@ -134,17 +137,18 @@ def output_handler(old, new): return old + new -## Unit tests -class TestToolComponent: +# TODO Add test for Builder components that have dynamic input types +# Does create_parameters schema work in these cases? +# Unit tests +class TestComponentTool: def test_from_component_basic(self): - component = SimpleComponent() - - tool = ComponentTool(component=component) + tool = ComponentTool(component=SimpleComponent()) assert tool.name == "simple_component" assert tool.description == "A simple component that generates text." assert tool.parameters == { "type": "object", + "description": "A simple component that generates text.", "properties": {"text": {"type": "string", "description": "user's name"}}, "required": ["text"], } @@ -156,44 +160,39 @@ class TestToolComponent: assert result["reply"] == "Hello, world!" 
def test_from_component_long_description(self): - component = SimpleComponent() - tool = ComponentTool(component=component, description="".join(["A"] * 1024)) - + tool = ComponentTool(component=SimpleComponent(), description="".join(["A"] * 1024)) assert len(tool.description) == 1024 def test_from_component_with_inputs(self): - component = SimpleComponent() - - tool = ComponentTool(component=component, inputs_from_state={"text": "text"}) - + tool = ComponentTool(component=SimpleComponent(), inputs_from_state={"text": "text"}) assert tool.inputs_from_state == {"text": "text"} # Inputs should be excluded from schema generation - assert tool.parameters == {"type": "object", "properties": {}} + assert tool.parameters == { + "type": "object", + "properties": {}, + "description": "A simple component that generates text.", + } def test_from_component_with_outputs(self): - component = SimpleComponent() - - tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}}) - + tool = ComponentTool(component=SimpleComponent(), outputs_to_state={"replies": {"source": "reply"}}) assert tool.outputs_to_state == {"replies": {"source": "reply"}} def test_from_component_with_dataclass(self): - component = UserGreeter() - - tool = ComponentTool(component=component) + tool = ComponentTool(component=UserGreeter()) assert tool.parameters == { - "type": "object", - "properties": { - "user": { - "type": "object", - "description": "The User object to process.", + "$defs": { + "User": { "properties": { - "name": {"type": "string", "description": "Field 'name' of 'User'."}, - "age": {"type": "integer", "description": "Field 'age' of 'User'."}, + "name": {"description": "Field 'name' of 'User'.", "type": "string", "default": "Anonymous"}, + "age": {"description": "Field 'age' of 'User'.", "type": "integer", "default": 0}, }, + "type": "object", } }, + "description": "A simple component that processes a User.", + "properties": {"user": {"$ref": "#/$defs/User", "description": "The User object to process."}}, "required": ["user"], + "type": "object", } assert tool.name == "user_greeter" @@ -206,14 +205,13 @@ class TestToolComponent: assert result["message"] == "User Alice is 30 years old" def test_from_component_with_list_input(self): - component = ListProcessor() - tool = ComponentTool( - component=component, name="list_processing_tool", description="A tool that concatenates strings" + component=ListProcessor(), name="list_processing_tool", description="A tool that concatenates strings" ) assert tool.parameters == { "type": "object", + "description": "Concatenates a list of strings into a single string.", "properties": { "texts": { "type": "array", @@ -231,30 +229,33 @@ class TestToolComponent: assert result["concatenated"] == "hello world" def test_from_component_with_nested_dataclass(self): - component = PersonProcessor() - - tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people") + tool = ComponentTool( + component=PersonProcessor(), name="person_tool", description="A tool that processes people" + ) assert tool.parameters == { - "type": "object", - "properties": { - "person": { - "type": "object", - "description": "The Person to process.", + "$defs": { + "Address": { "properties": { - "name": {"type": "string", "description": "Field 'name' of 'Person'."}, - "address": { - "type": "object", - "description": "Field 'address' of 'Person'.", - "properties": { - "street": {"type": "string", "description": "Field 'street' of 'Address'."}, - 
"city": {"type": "string", "description": "Field 'city' of 'Address'."}, - }, - }, + "street": {"description": "Field 'street' of 'Address'.", "type": "string"}, + "city": {"description": "Field 'city' of 'Address'.", "type": "string"}, }, - } + "required": ["street", "city"], + "type": "object", + }, + "Person": { + "properties": { + "name": {"description": "Field 'name' of 'Person'.", "type": "string"}, + "address": {"$ref": "#/$defs/Address", "description": "Field 'address' of 'Person'."}, + }, + "required": ["name", "address"], + "type": "object", + }, }, + "description": "Creates information about the person.", + "properties": {"person": {"$ref": "#/$defs/Person", "description": "The Person to process."}}, "required": ["person"], + "type": "object", } # Test tool invocation @@ -264,64 +265,29 @@ class TestToolComponent: assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." def test_from_component_with_document_list(self): - component = DocumentProcessor() - tool = ComponentTool( - component=component, name="document_processor", description="A tool that concatenates document contents" + component=DocumentProcessor(), + name="document_processor", + description="A tool that concatenates document contents", ) assert tool.parameters == { - "type": "object", + "$defs": { + "ByteStream": BYTE_STREAM_SCHEMA, + "Document": DOCUMENT_SCHEMA, + "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA, + }, + "description": "Concatenates the content of multiple documents with newlines.", "properties": { "documents": { - "type": "array", "description": "List of Documents whose content will be concatenated", - "items": { - "type": "object", - "properties": { - "id": {"type": "string", "description": "Field 'id' of 'Document'."}, - "content": {"type": "string", "description": "Field 'content' of 'Document'."}, - "blob": { - "type": "object", - "description": "Field 'blob' of 'Document'.", - "properties": { - "data": {"type": "string", "description": "Field 'data' of 'ByteStream'."}, - "meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."}, - "mime_type": { - "type": "string", - "description": "Field 'mime_type' of 'ByteStream'.", - }, - }, - }, - "meta": {"type": "string", "description": "Field 'meta' of 'Document'."}, - "score": {"type": "number", "description": "Field 'score' of 'Document'."}, - "embedding": { - "type": "array", - "description": "Field 'embedding' of 'Document'.", - "items": {"type": "number"}, - }, - "sparse_embedding": { - "type": "object", - "description": "Field 'sparse_embedding' of 'Document'.", - "properties": { - "indices": { - "type": "array", - "description": "Field 'indices' of 'SparseEmbedding'.", - "items": {"type": "integer"}, - }, - "values": { - "type": "array", - "description": "Field 'values' of 'SparseEmbedding'.", - "items": {"type": "number"}, - }, - }, - }, - }, - }, + "items": {"$ref": "#/$defs/Document"}, + "type": "array", }, - "top_k": {"description": "The number of top documents to concatenate", "type": "integer"}, + "top_k": {"description": "The number of top documents to concatenate", "type": "integer", "default": 5}, }, "required": ["documents"], + "type": "object", } # Test tool invocation @@ -341,15 +307,16 @@ class TestToolComponent: ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") -## Integration tests +# Integration tests class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration 
def test_component_tool_in_pipeline(self): # Create component and convert it to tool - component = SimpleComponent() tool = ComponentTool( - component=component, name="hello_tool", description="A tool that generates a greeting message for the user" + component=SimpleComponent(), + name="hello_tool", + description="A tool that generates a greeting message for the user", ) # Create pipeline with OpenAIChatGenerator and ToolInvoker @@ -378,9 +345,10 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.integration def test_component_tool_in_pipeline_openai_tools_strict(self): # Create component and convert it to tool - component = SimpleComponent() tool = ComponentTool( - component=component, name="hello_tool", description="A tool that generates a greeting message for the user" + component=SimpleComponent(), + name="hello_tool", + description="A tool that generates a greeting message for the user", ) # Create pipeline with OpenAIChatGenerator and ToolInvoker @@ -408,9 +376,8 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_user_greeter_in_pipeline(self): - component = UserGreeter() tool = ComponentTool( - component=component, name="user_greeter", description="A tool that greets users with their name and age" + component=UserGreeter(), name="user_greeter", description="A tool that greets users with their name and age" ) pipeline = Pipeline() @@ -432,9 +399,8 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_list_processor_in_pipeline(self): - component = ListProcessor() tool = ComponentTool( - component=component, name="list_processor", description="A tool that concatenates a list of strings" + component=ListProcessor(), name="list_processor", description="A tool that concatenates a list of strings" ) pipeline = Pipeline() @@ -456,9 +422,8 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_person_processor_in_pipeline(self): - component = PersonProcessor() tool = ComponentTool( - component=component, + component=PersonProcessor(), name="person_processor", description="A tool that processes information about a person and their address", ) @@ -482,9 +447,8 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_document_processor_in_pipeline(self): - component = DocumentProcessor() tool = ComponentTool( - component=component, + component=DocumentProcessor(), name="document_processor", description="A tool that concatenates the content of multiple documents", ) @@ -516,9 +480,8 @@ class TestToolComponentInPipelineWithOpenAI: def test_lost_in_middle_ranker_in_pipeline(self): from haystack.components.rankers import LostInTheMiddleRanker - component = LostInTheMiddleRanker() tool = ComponentTool( - component=component, + component=LostInTheMiddleRanker(), name="lost_in_middle_ranker", description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results", ) @@ -543,9 +506,10 @@ class TestToolComponentInPipelineWithOpenAI: @pytest.mark.skipif(not os.environ.get("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY not set") @pytest.mark.integration def 
test_serper_dev_web_search_in_pipeline(self): - component = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3) tool = ComponentTool( - component=component, name="web_search", description="Search the web for current information on any topic" + component=SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3), + name="web_search", + description="Search the web for current information on any topic", ) pipeline = Pipeline() @@ -603,10 +567,8 @@ class TestToolComponentInPipelineWithOpenAI: assert new_pipeline == pipeline def test_component_tool_serde(self): - component = SimpleComponent() - tool = ComponentTool( - component=component, + component=SimpleComponent(), name="simple_tool", description="A simple tool", inputs_from_state={"test": "input"}, @@ -632,16 +594,16 @@ class TestToolComponentInPipelineWithOpenAI: assert isinstance(new_tool._component, SimpleComponent) def test_pipeline_component_fails(self): - component = SimpleComponent() + comp = SimpleComponent() # Create a pipeline and add the component to it pipeline = Pipeline() - pipeline.add_component("simple", component) + pipeline.add_component("simple", comp) # Try to create a tool from the component and it should fail because the component has been added to a pipeline and # thus can't be used as tool with pytest.raises(ValueError, match="Component has been added to a pipeline"): - ComponentTool(component=component) + ComponentTool(component=comp) def test_deepcopy_with_jinja_based_component(self): builder = PromptBuilder("{{query}}") diff --git a/test/tools/test_parameters_schema_utils.py b/test/tools/test_parameters_schema_utils.py new file mode 100644 index 000000000..9f7bff3f4 --- /dev/null +++ b/test/tools/test_parameters_schema_utils.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from typing import List + +from haystack.dataclasses import ByteStream, ChatMessage, Document, TextContent, ToolCall, ToolCallResult +from pydantic import Field, create_model +from haystack.tools.parameters_schema_utils import _resolve_type +from haystack.tools.from_function import _remove_title_from_schema + + +BYTE_STREAM_SCHEMA = { + "type": "object", + "properties": { + "data": {"type": "string", "description": "The binary data stored in Bytestream.", "format": "binary"}, + "meta": { + "type": "object", + "default": {}, + "description": "Additional metadata to be stored with the ByteStream.", + "additionalProperties": True, + }, + "mime_type": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "The mime type of the binary data.", + }, + }, + "required": ["data"], +} + +SPARSE_EMBEDDING_SCHEMA = { + "type": "object", + "properties": { + "indices": { + "type": "array", + "description": "List of indices of non-zero elements in the embedding.", + "items": {"type": "integer"}, + }, + "values": { + "type": "array", + "description": "List of values of non-zero elements in the embedding.", + "items": {"type": "number"}, + }, + }, + "required": ["indices", "values"], +} + +DOCUMENT_SCHEMA = { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the document. 
When not set, it's generated based on the Document fields' values.", + "default": "", + }, + "content": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Text of the document, if the document contains text.", + }, + "blob": { + "anyOf": [{"$ref": "#/$defs/ByteStream"}, {"type": "null"}], + "default": None, + "description": "Binary data associated with the document, if the document has any binary data associated with it.", + }, + "meta": { + "type": "object", + "description": "Additional custom metadata for the document. Must be JSON-serializable.", + "default": {}, + "additionalProperties": True, + }, + "score": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "default": None, + "description": "Score of the document. Used for ranking, usually assigned by retrievers.", + }, + "embedding": { + "anyOf": [{"type": "array", "items": {"type": "number"}}, {"type": "null"}], + "default": None, + "description": "dense vector representation of the document.", + }, + "sparse_embedding": { + "anyOf": [{"$ref": "#/$defs/SparseEmbedding"}, {"type": "null"}], + "default": None, + "description": "sparse vector representation of the document.", + }, + }, +} + +TEXT_CONTENT_SCHEMA = { + "type": "object", + "properties": {"text": {"type": "string", "description": "The text content of the message."}}, + "required": ["text"], +} + +TOOL_CALL_SCHEMA = { + "type": "object", + "properties": { + "tool_name": {"type": "string", "description": "The name of the Tool to call."}, + "arguments": { + "type": "object", + "description": "The arguments to call the Tool with.", + "additionalProperties": True, + }, + "id": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "The ID of the Tool call.", + }, + }, + "required": ["tool_name", "arguments"], +} + +TOOL_CALL_RESULT_SCHEMA = { + "type": "object", + "properties": { + "result": {"type": "string", "description": "The result of the Tool invocation."}, + "origin": {"$ref": "#/$defs/ToolCall", "description": "The Tool call that produced this result."}, + "error": {"type": "boolean", "description": "Whether the Tool invocation resulted in an error."}, + }, + "required": ["result", "origin", "error"], +} + +CHAT_ROLE_SCHEMA = { + "description": "Enumeration representing the roles within a chat.", + "enum": ["user", "system", "assistant", "tool"], + "type": "string", +} + +CHAT_MESSAGE_SCHEMA = { + "type": "object", + "properties": { + "role": {"$ref": "#/$defs/ChatRole", "description": "Field 'role' of 'ChatMessage'."}, + "content": { + "type": "array", + "description": "Field 'content' of 'ChatMessage'.", + "items": { + "anyOf": [ + {"$ref": "#/$defs/TextContent"}, + {"$ref": "#/$defs/ToolCall"}, + {"$ref": "#/$defs/ToolCallResult"}, + ] + }, + }, + "name": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Field 'name' of 'ChatMessage'.", + }, + "meta": { + "type": "object", + "description": "Field 'meta' of 'ChatMessage'.", + "default": {}, + "additionalProperties": True, + }, + }, + "required": ["role", "content"], +} + + +@pytest.mark.parametrize( + "python_type, description, expected_schema, expected_defs_schema", + [ + ( + ByteStream, + "A byte stream", + {"$ref": "#/$defs/ByteStream", "description": "A byte stream"}, + {"ByteStream": BYTE_STREAM_SCHEMA}, + ), + ( + Document, + "A document", + {"$ref": "#/$defs/Document", "description": "A document"}, + {"Document": DOCUMENT_SCHEMA, "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA, "ByteStream": 
BYTE_STREAM_SCHEMA}, + ), + ( + TextContent, + "A text content", + {"$ref": "#/$defs/TextContent", "description": "A text content"}, + {"TextContent": TEXT_CONTENT_SCHEMA}, + ), + ( + ToolCall, + "A tool call", + {"$ref": "#/$defs/ToolCall", "description": "A tool call"}, + {"ToolCall": TOOL_CALL_SCHEMA}, + ), + ( + ToolCallResult, + "A tool call result", + {"$ref": "#/$defs/ToolCallResult", "description": "A tool call result"}, + {"ToolCallResult": TOOL_CALL_RESULT_SCHEMA, "ToolCall": TOOL_CALL_SCHEMA}, + ), + ( + ChatMessage, + "A chat message", + {"$ref": "#/$defs/ChatMessage", "description": "A chat message"}, + { + "ChatMessage": CHAT_MESSAGE_SCHEMA, + "TextContent": TEXT_CONTENT_SCHEMA, + "ToolCall": TOOL_CALL_SCHEMA, + "ToolCallResult": TOOL_CALL_RESULT_SCHEMA, + "ChatRole": CHAT_ROLE_SCHEMA, + }, + ), + ( + List[Document], + "A list of documents", + {"type": "array", "description": "A list of documents", "items": {"$ref": "#/$defs/Document"}}, + {"Document": DOCUMENT_SCHEMA, "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA, "ByteStream": BYTE_STREAM_SCHEMA}, + ), + ( + List[ChatMessage], + "A list of chat messages", + {"type": "array", "description": "A list of chat messages", "items": {"$ref": "#/$defs/ChatMessage"}}, + { + "ChatMessage": CHAT_MESSAGE_SCHEMA, + "TextContent": TEXT_CONTENT_SCHEMA, + "ToolCall": TOOL_CALL_SCHEMA, + "ToolCallResult": TOOL_CALL_RESULT_SCHEMA, + "ChatRole": CHAT_ROLE_SCHEMA, + }, + ), + ], +) +def test_create_parameters_schema_haystack_dataclasses(python_type, description, expected_schema, expected_defs_schema): + resolved_type = _resolve_type(python_type) + fields = {"input_name": (resolved_type, Field(default=..., description=description))} + model = create_model("run", __doc__="A test function", **fields) + parameters_schema = model.model_json_schema() + _remove_title_from_schema(parameters_schema) + + defs_schema = parameters_schema["$defs"] + assert defs_schema == expected_defs_schema + + property_schema = parameters_schema["properties"]["input_name"] + assert property_schema == expected_schema
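
To round off, here is a self-contained usage sketch of what the new generation enables. It is not part of the diff: it assumes a Haystack installation with this patch applied and mirrors the `MessageExtractor` fixture from the generator tests above.

```python
from typing import Any, Dict, List, Optional

from haystack import component
from haystack.dataclasses import ChatMessage
from haystack.tools import ComponentTool


@component
class MessageExtractor:
    @component.output_types(messages=List[str], meta=Dict[str, Any])
    def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Extracts the text content of ChatMessage objects.

        :param messages: List of Haystack ChatMessage objects.
        :param meta: Optional metadata to include in the response.
        :returns:
            A dictionary with keys "messages" and "meta".
        """
        if meta is None:
            meta = {}
        return {"messages": [m.text for m in messages], "meta": meta}


tool = ComponentTool(
    component=MessageExtractor(),
    name="message_extractor",
    description="Useful for returning the text content of ChatMessage objects",
)

# With the Pydantic-based generation, the List[ChatMessage] input is expanded into a
# full JSON schema with $defs entries for ChatMessage, ToolCall, ToolCallResult, etc.,
# matching the CHAT_MESSAGE_SCHEMA expectations asserted above:
# {"type": "array", "items": {"$ref": "#/$defs/ChatMessage"}, "description": "..."}
print(tool.parameters["properties"]["messages"])
```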