mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
Fix component tool parameters (#9342)
* Starting property schema refactor
* Adding more tests
* More tests
* Handle null type explicitly
* More updates of tests to accommodate Optional properly
* Fix more tests
* Remove unnecessary check
* Some cleanup
* Update test
* Add reno
* Fix typing
* Add license header
* Use docstrings of dataclasses in parameter spec generation
* More tests of Haystack dataclass types
* Properly handle Sequence
* Fix license header
* Update OpenAI tests to add more complicated tool parameter signature
* Properly set required for dataclasses
* Add integration test for azure that includes additionalProperties
* Add more complicated integration test for HuggingFaceAPIChatGenerator
* Alternate approach using pydantic like we do in from_function.py
* Cleanup and fix other affected tests
* Fix mypy
* PR comments
* PR comment
* Remove test from HF API
* Update reno
This commit is contained in:
parent 42b378950f
commit 9ae76e1653
@@ -11,6 +11,10 @@ from typing import Any, Dict, Optional

class ByteStream:
    """
    Base data class representing a binary object in the Haystack API.

    :param data: The binary data stored in Bytestream.
    :param meta: Additional metadata to be stored with the ByteStream.
    :param mime_type: The mime type of the binary data.
    """

    data: bytes
@@ -2,11 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin

from pydantic import TypeAdapter
from pydantic import Field, TypeAdapter, create_model

from haystack import logging
from haystack.core.component import Component
@@ -16,15 +14,12 @@ from haystack.core.serialization import (
    generate_qualified_class_name,
    import_class_by_name,
)
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool
from haystack.tools.errors import SchemaGenerationError
from haystack.tools.from_function import _remove_title_from_schema
from haystack.tools.parameters_schema_utils import _get_param_descriptions, _resolve_type
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
    from docstring_parser import parse


logger = logging.getLogger(__name__)
@@ -275,10 +270,10 @@ class ComponentTool(Tool):
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        properties = {}
        required = []
        component_run_description, param_descriptions = _get_param_descriptions(component.run)

        param_descriptions = self._get_param_descriptions(component.run)
        # collect fields (types and defaults) and descriptions from function parameters
        fields: Dict[str, Any] = {}

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if inputs_from_state is not None and input_name in inputs_from_state:
@@ -286,135 +281,23 @@ class ComponentTool(Tool):
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            try:
                property_schema = self._create_property_schema(input_type, description)
            except Exception as e:
                raise SchemaGenerationError(
                    f"Error processing input '{input_name}': {e}. "
                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
                    f"and lists of these types as input types for component's run method."
                ) from e
            # if the parameter does not have a default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = ... if socket.is_mandatory else socket.default_value
            resolved_type = _resolve_type(input_type)
            fields[input_name] = (resolved_type, Field(default=default, description=description))

            properties[input_name] = property_schema
        try:
            model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
            parameters_schema = model.model_json_schema()
        except Exception as e:
            raise SchemaGenerationError(
                f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
            ) from e

            # Use socket.is_mandatory to check if the input is required
            if socket.is_mandatory:
                required.append(input_name)

        parameters_schema = {"type": "object", "properties": properties}

        if required:
            parameters_schema["required"] = required
        # we don't want to include title keywords in the schema, as they contain redundant information
        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
        # see https://github.com/pydantic/pydantic/discussions/8504
        _remove_title_from_schema(parameters_schema)

        return parameters_schema

    @staticmethod
    def _get_param_descriptions(method: Callable) -> Dict[str, str]:
        """
        Extracts parameter descriptions from the method's docstring using docstring_parser.

        :param method: The method to extract parameter descriptions from.
        :returns: A dictionary mapping parameter names to their descriptions.
        """
        docstring = getdoc(method)
        if not docstring:
            return {}

        docstring_parser_import.check()
        parsed_doc = parse(docstring)
        param_descriptions = {}
        for param in parsed_doc.params:
            if not param.description:
                logger.warning(
                    "Missing description for parameter '%s'. Please add a description in the component's "
                    "run() method docstring using the format ':param %%s: <description>'. "
                    "This description helps the LLM understand how to use this parameter." % param.arg_name
                )
            param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
        return param_descriptions

    @staticmethod
    def _is_nullable_type(python_type: Any) -> bool:
        """
        Checks if the type is a Union with NoneType (i.e., Optional).

        :param python_type: The Python type to check.
        :returns: True if the type is a Union with NoneType, False otherwise.
        """
        origin = get_origin(python_type)
        if origin is Union:
            return type(None) in get_args(python_type)
        return False

    def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a list type.

        :param item_type: The type of items in the list.
        :param description: The description of the list.
        :returns: A dictionary representing the list schema.
        """
        items_schema = self._create_property_schema(item_type, "")
        items_schema.pop("description", None)
        return {"type": "array", "description": description, "items": items_schema}

    def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a dataclass.

        :param python_type: The dataclass type.
        :param description: The description of the dataclass.
        :returns: A dictionary representing the dataclass schema.
        """
        schema = {"type": "object", "description": description, "properties": {}}
        cls = python_type if isinstance(python_type, type) else python_type.__class__
        for field in fields(cls):
            field_description = f"Field '{field.name}' of '{cls.__name__}'."
            if isinstance(schema["properties"], dict):
                schema["properties"][field.name] = self._create_property_schema(field.type, field_description)
        return schema

    @staticmethod
    def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a basic Python type.

        :param python_type: The Python type.
        :param description: The description of the type.
        :returns: A dictionary representing the basic type schema.
        """
        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
        return {"type": type_mapping.get(python_type, "string"), "description": description}

    def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
        """
        Creates a property schema for a given Python type, recursively if necessary.

        :param python_type: The Python type to create a property schema for.
        :param description: The description of the property.
        :param default: The default value of the property.
        :returns: A dictionary representing the property schema.
        :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models
        """
        nullable = self._is_nullable_type(python_type)
        if nullable:
            non_none_types = [t for t in get_args(python_type) if t is not type(None)]
            python_type = non_none_types[0] if non_none_types else str

        origin = get_origin(python_type)
        if origin is list:
            schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
        elif is_dataclass(python_type):
            schema = self._create_dataclass_schema(python_type, description)
        elif hasattr(python_type, "model_validate"):
            raise SchemaGenerationError(
                f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
                f"component's run method."
            )
        else:
            schema = self._create_basic_type_schema(python_type, description)

        if default is not None:
            schema["default"] = default

        return schema
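The replacement code above no longer builds the "required" list by hand: it lets Pydantic derive it. A minimal sketch (assuming pydantic v2; the parameter names city and top_k are illustrative, not from the PR) of why an Ellipsis default marks a field as required in the generated JSON schema:

from pydantic import Field, create_model

# A socket without a default gets Ellipsis (...), which pydantic reads as
# "no default, therefore required"; a socket with a default keeps it.
model = create_model(
    "run",
    city=(str, Field(default=..., description="The city to look up.")),
    top_k=(int, Field(default=5, description="How many results to return.")),
)
schema = model.model_json_schema()
print(schema["required"])                        # ['city']
print(schema["properties"]["top_k"]["default"])  # 5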
@@ -210,9 +210,16 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
    :param schema:
        The JSON schema to remove the 'title' keyword from.
    """
    schema.pop("title", None)

    for property_schema in schema["properties"].values():
        for key in list(property_schema.keys()):
            if key == "title":
                del property_schema[key]
    for key, value in list(schema.items()):
        # Make sure not to remove parameters named title
        if key == "properties" and isinstance(value, dict) and "title" in value:
            for sub_val in value.values():
                _remove_title_from_schema(sub_val)
        elif key == "title":
            del schema[key]
        elif isinstance(value, dict):
            _remove_title_from_schema(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    _remove_title_from_schema(item)
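A small usage sketch (not part of the diff) of the recursive behavior above: pydantic's auto-generated "title" keys are stripped at every nesting level, while a genuine parameter that happens to be named "title" survives.

from haystack.tools.from_function import _remove_title_from_schema

schema = {
    "title": "run",  # added by pydantic, should be removed
    "type": "object",
    # a real parameter that is itself named "title" must be preserved
    "properties": {"title": {"type": "string", "title": "Title"}},
}
_remove_title_from_schema(schema)
assert schema == {"type": "object", "properties": {"title": {"type": "string"}}}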
112
haystack/tools/parameters_schema_utils.py
Normal file
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import collections
from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, get_args, get_origin

from pydantic import BaseModel, Field, create_model

from haystack import logging
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
    from docstring_parser import parse


logger = logging.getLogger(__name__)


def _get_param_descriptions(method: Callable) -> Tuple[str, Dict[str, str]]:
    """
    Extracts parameter descriptions from the method's docstring using docstring_parser.

    :param method: The method to extract parameter descriptions from.
    :returns:
        A tuple including the short description of the method and a dictionary mapping parameter names to their
        descriptions.
    """
    docstring = getdoc(method)
    if not docstring:
        return "", {}

    docstring_parser_import.check()
    parsed_doc = parse(docstring)
    param_descriptions = {}
    for param in parsed_doc.params:
        if not param.description:
            logger.warning(
                "Missing description for parameter '%s'. Please add a description in the component's "
                "run() method docstring using the format ':param %%s: <description>'. "
                "This description helps the LLM understand how to use this parameter." % param.arg_name
            )
        param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
    return parsed_doc.short_description or "", param_descriptions


def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]:
    """
    Convert a Python dataclass to an equivalent Pydantic model.

    :param dc_type: The dataclass type to convert.
    :returns:
        A dynamically generated Pydantic model class with fields and types derived from the dataclass definition.
        Field descriptions are extracted from docstrings when available.
    """
    _, param_descriptions = _get_param_descriptions(dc_type)
    cls = dc_type if isinstance(dc_type, type) else dc_type.__class__

    field_defs: Dict[str, Any] = {}
    for field in fields(dc_type):
        f_type = field.type if isinstance(field.type, str) else _resolve_type(field.type)
        default = field.default if field.default is not MISSING else ...
        default = field.default_factory() if callable(field.default_factory) else default

        # Special handling for ChatMessage since pydantic doesn't allow for field names with leading underscores
        field_name = field.name
        if dc_type is ChatMessage and field_name.startswith("_"):
            # We remove the underscore since ChatMessage.from_dict does allow for field names without the underscore
            field_name = field_name[1:]

        description = param_descriptions.get(field_name, f"Field '{field_name}' of '{cls.__name__}'.")
        field_defs[field_name] = (f_type, Field(default, description=description))

    model = create_model(cls.__name__, **field_defs)
    return model


def _resolve_type(_type: Any) -> Any:
    """
    Recursively resolve and convert complex type annotations, transforming dataclasses into Pydantic-compatible types.

    This function walks through nested type annotations (e.g., List, Dict, Union) and converts any dataclass types
    it encounters into corresponding Pydantic models.

    :param _type: The type annotation to resolve. If the type is a dataclass, it will be converted to a Pydantic model.
        For generic types (like List[SomeDataclass]), the inner types are also resolved recursively.

    :returns:
        A fully resolved type, with all dataclass types converted to Pydantic models
    """
    if is_dataclass(_type):
        return _dataclass_to_pydantic_model(_type)

    origin = get_origin(_type)
    args = get_args(_type)

    if origin is list:
        return List[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is collections.abc.Sequence:
        return Sequence[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is Union:
        return Union[tuple(_resolve_type(a) for a in args)]  # type: ignore[misc]

    if origin is dict:
        return Dict[args[0] if args else Any, _resolve_type(args[1]) if args else Any]  # type: ignore[misc]

    return _type
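A brief usage sketch of how the helpers in this new module cooperate (the City dataclass is illustrative, and docstring extraction assumes the optional docstring-parser package is installed): a dataclass annotation is converted to a pydantic model, then fed to create_model so model_json_schema() can describe it.

from dataclasses import dataclass
from typing import List

from pydantic import Field, create_model

from haystack.tools.parameters_schema_utils import _resolve_type


@dataclass
class City:
    """
    :param name: Name of the city.
    """

    name: str


# List[City] becomes List[<generated pydantic model for City>]
resolved = _resolve_type(List[City])
model = create_model("run", cities=(resolved, Field(default=..., description="Cities to process.")))
schema = model.model_json_schema()
# schema["$defs"]["City"] now carries the ":param name:" docstring text as the
# description of the "name" field.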
@@ -0,0 +1,6 @@
---
enhancements:
  - |
    Refactored JSON Schema generation for ComponentTool parameters using Pydantic's model_json_schema, enabling expanded type support (e.g. Union, Enum, Dict, etc.).
    We also convert dataclasses to Pydantic models before calling model_json_schema to preserve docstring descriptions of the parameters in the schema.
    This means dataclasses like ChatMessage, Document, etc. now have correctly defined JSON schemas.
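An end-to-end sketch of the enhancement described in this note (the GreetUser component and User dataclass are illustrative, not from the PR): a component whose run method takes a dataclass argument now gets a complete, pydantic-derived JSON schema through ComponentTool.

from dataclasses import dataclass
from typing import Dict

from haystack import component
from haystack.tools import ComponentTool


@dataclass
class User:
    """
    :param name: The user's name.
    """

    name: str = "Anonymous"


@component
class GreetUser:
    @component.output_types(greeting=str)
    def run(self, user: User) -> Dict[str, str]:
        """
        Greets a user.

        :param user: The User to greet.
        """
        return {"greeting": f"Hello, {user.name}!"}


tool = ComponentTool(component=GreetUser())
# tool.parameters now contains a "$defs" entry with a "User" schema that keeps
# the docstring description and the "Anonymous" default for the "name" field.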
@@ -2,36 +2,62 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Dict, List, Optional

import pytest
from openai import OpenAIError

from haystack import Pipeline
from haystack import component, Pipeline
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools.tool import Tool
from haystack.tools import ComponentTool, Tool
from haystack.tools.toolset import Toolset
from haystack.utils.auth import Secret
from haystack.utils.azure import default_azure_ad_token_provider


def get_weather(city: str) -> str:
    """Get weather information for a city."""
    return f"Weather info for {city}"
def get_weather(city: str) -> Dict[str, Any]:
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@component
class MessageExtractor:
    @component.output_types(messages=List[str], meta=Dict[str, Any])
    def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Extracts the text content of ChatMessage objects

        :param messages: List of Haystack ChatMessage objects
        :param meta: Optional metadata to include in the response.
        :returns:
            A dictionary with keys "messages" and "meta".
        """
        if meta is None:
            meta = {}
        return {"messages": [m.text for m in messages], "meta": meta}


@pytest.fixture
def tools():
    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    tool = Tool(
    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters=tool_parameters,
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    return [tool]
    # We add a tool that has a more complex parameter signature
    message_extractor_tool = ComponentTool(
        component=MessageExtractor(),
        name="message_extractor",
        description="Useful for returning the text content of ChatMessage objects",
    )
    return [weather_tool, message_extractor_tool]


class TestAzureOpenAIChatGenerator:
@@ -307,7 +333,7 @@ class TestAzureOpenAIChatGenerator:
    def test_to_dict_with_toolset(self, tools, monkeypatch):
        """Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
        toolset = Toolset(tools)
        toolset = Toolset(tools[:1])
        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
        data = component.to_dict()
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
import os
from typing import Any, Dict
from unittest.mock import MagicMock, Mock, AsyncMock, patch

import pytest
@@ -43,22 +44,24 @@ def chat_messages():
    ]


def get_weather(city: str) -> str:
    """Get weather information for a city."""
    return f"Weather info for {city}"
def get_weather(city: str) -> Dict[str, Any]:
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@pytest.fixture
def tools():
    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    tool = Tool(
    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters=tool_parameters,
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    return [tool]
    return [weather_tool]


@pytest.fixture
@@ -974,7 +977,7 @@ class TestHuggingFaceAPIChatGenerator:

    def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
        toolset = Toolset(tools)
        toolset = Toolset(tools[:1])
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
@@ -8,6 +8,7 @@ import pytest
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
@@ -16,11 +17,12 @@ from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat import chat_completion_chunk

from haystack import component
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.tools import ComponentTool, Tool
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.tools.toolset import Toolset
@@ -72,21 +74,47 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream):
    yield mock_chat_completion_create


def mock_tool_function(x):
    return x
def weather_function(city: str) -> Dict[str, Any]:
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@component
class MessageExtractor:
    @component.output_types(messages=List[str], meta=Dict[str, Any])
    def run(self, messages: List[ChatMessage], meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Extracts the text content of ChatMessage objects

        :param messages: List of Haystack ChatMessage objects
        :param meta: Optional metadata to include in the response.
        :returns:
            A dictionary with keys "messages" and "meta".
        """
        if meta is None:
            meta = {}
        return {"messages": [m.text for m in messages], "meta": meta}


@pytest.fixture
def tools():
    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    tool = Tool(
    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters=tool_parameters,
        function=mock_tool_function,
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=weather_function,
    )

    return [tool]
    # We add a tool that has a more complex parameter signature
    message_extractor_tool = ComponentTool(
        component=MessageExtractor(),
        name="message_extractor",
        description="Useful for returning the text content of ChatMessage objects",
    )
    return [weather_tool, message_extractor_tool]


class TestOpenAIChatGenerator:
@@ -462,7 +490,9 @@ class TestOpenAIChatGenerator:

        mock_chat_completion_create.return_value = completion

        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools, tools_strict=True)
        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), tools=tools[:1], tools_strict=True
        )
        response = component.run([ChatMessage.from_user("What's the weather like in Paris?")])

        # ensure that the tools are passed to the OpenAI API
@@ -24,7 +24,10 @@ from haystack.dataclasses import ChatMessage, ChatRole, Document
from haystack.tools import ComponentTool
from haystack.utils.auth import Secret

### Component and Model Definitions
from test.tools.test_parameters_schema_utils import BYTE_STREAM_SCHEMA, DOCUMENT_SCHEMA, SPARSE_EMBEDDING_SCHEMA


# Component and Model Definitions


@component
@@ -134,17 +137,18 @@ def output_handler(old, new):
    return old + new


## Unit tests
class TestToolComponent:
# TODO Add test for Builder components that have dynamic input types
# Does create_parameters schema work in these cases?
# Unit tests
class TestComponentTool:
    def test_from_component_basic(self):
        component = SimpleComponent()

        tool = ComponentTool(component=component)
        tool = ComponentTool(component=SimpleComponent())

        assert tool.name == "simple_component"
        assert tool.description == "A simple component that generates text."
        assert tool.parameters == {
            "type": "object",
            "description": "A simple component that generates text.",
            "properties": {"text": {"type": "string", "description": "user's name"}},
            "required": ["text"],
        }
@@ -156,44 +160,39 @@
        assert result["reply"] == "Hello, world!"

    def test_from_component_long_description(self):
        component = SimpleComponent()
        tool = ComponentTool(component=component, description="".join(["A"] * 1024))

        tool = ComponentTool(component=SimpleComponent(), description="".join(["A"] * 1024))
        assert len(tool.description) == 1024

    def test_from_component_with_inputs(self):
        component = SimpleComponent()

        tool = ComponentTool(component=component, inputs_from_state={"text": "text"})

        tool = ComponentTool(component=SimpleComponent(), inputs_from_state={"text": "text"})
        assert tool.inputs_from_state == {"text": "text"}
        # Inputs should be excluded from schema generation
        assert tool.parameters == {"type": "object", "properties": {}}
        assert tool.parameters == {
            "type": "object",
            "properties": {},
            "description": "A simple component that generates text.",
        }

    def test_from_component_with_outputs(self):
        component = SimpleComponent()

        tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}})

        tool = ComponentTool(component=SimpleComponent(), outputs_to_state={"replies": {"source": "reply"}})
        assert tool.outputs_to_state == {"replies": {"source": "reply"}}

    def test_from_component_with_dataclass(self):
        component = UserGreeter()

        tool = ComponentTool(component=component)
        tool = ComponentTool(component=UserGreeter())
        assert tool.parameters == {
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "description": "The User object to process.",
            "$defs": {
                "User": {
                    "properties": {
                        "name": {"type": "string", "description": "Field 'name' of 'User'."},
                        "age": {"type": "integer", "description": "Field 'age' of 'User'."},
                        "name": {"description": "Field 'name' of 'User'.", "type": "string", "default": "Anonymous"},
                        "age": {"description": "Field 'age' of 'User'.", "type": "integer", "default": 0},
                    },
                    "type": "object",
                }
            },
            "description": "A simple component that processes a User.",
            "properties": {"user": {"$ref": "#/$defs/User", "description": "The User object to process."}},
            "required": ["user"],
            "type": "object",
        }

        assert tool.name == "user_greeter"
@@ -206,14 +205,13 @@
        assert result["message"] == "User Alice is 30 years old"

    def test_from_component_with_list_input(self):
        component = ListProcessor()

        tool = ComponentTool(
            component=component, name="list_processing_tool", description="A tool that concatenates strings"
            component=ListProcessor(), name="list_processing_tool", description="A tool that concatenates strings"
        )

        assert tool.parameters == {
            "type": "object",
            "description": "Concatenates a list of strings into a single string.",
            "properties": {
                "texts": {
                    "type": "array",
@@ -231,30 +229,33 @@
        assert result["concatenated"] == "hello world"

    def test_from_component_with_nested_dataclass(self):
        component = PersonProcessor()

        tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people")
        tool = ComponentTool(
            component=PersonProcessor(), name="person_tool", description="A tool that processes people"
        )

        assert tool.parameters == {
            "type": "object",
            "properties": {
                "person": {
                    "type": "object",
                    "description": "The Person to process.",
            "$defs": {
                "Address": {
                    "properties": {
                        "name": {"type": "string", "description": "Field 'name' of 'Person'."},
                        "address": {
                            "type": "object",
                            "description": "Field 'address' of 'Person'.",
                            "properties": {
                                "street": {"type": "string", "description": "Field 'street' of 'Address'."},
                                "city": {"type": "string", "description": "Field 'city' of 'Address'."},
                            },
                        },
                        "street": {"description": "Field 'street' of 'Address'.", "type": "string"},
                        "city": {"description": "Field 'city' of 'Address'.", "type": "string"},
                    },
                }
                    "required": ["street", "city"],
                    "type": "object",
                },
                "Person": {
                    "properties": {
                        "name": {"description": "Field 'name' of 'Person'.", "type": "string"},
                        "address": {"$ref": "#/$defs/Address", "description": "Field 'address' of 'Person'."},
                    },
                    "required": ["name", "address"],
                    "type": "object",
                },
            },
            "description": "Creates information about the person.",
            "properties": {"person": {"$ref": "#/$defs/Person", "description": "The Person to process."}},
            "required": ["person"],
            "type": "object",
        }

        # Test tool invocation
@@ -264,64 +265,29 @@
        assert result["info"] == "Diana lives at 123 Elm Street, Metropolis."

    def test_from_component_with_document_list(self):
        component = DocumentProcessor()

        tool = ComponentTool(
            component=component, name="document_processor", description="A tool that concatenates document contents"
            component=DocumentProcessor(),
            name="document_processor",
            description="A tool that concatenates document contents",
        )

        assert tool.parameters == {
            "type": "object",
            "$defs": {
                "ByteStream": BYTE_STREAM_SCHEMA,
                "Document": DOCUMENT_SCHEMA,
                "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA,
            },
            "description": "Concatenates the content of multiple documents with newlines.",
            "properties": {
                "documents": {
                    "type": "array",
                    "description": "List of Documents whose content will be concatenated",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string", "description": "Field 'id' of 'Document'."},
                            "content": {"type": "string", "description": "Field 'content' of 'Document'."},
                            "blob": {
                                "type": "object",
                                "description": "Field 'blob' of 'Document'.",
                                "properties": {
                                    "data": {"type": "string", "description": "Field 'data' of 'ByteStream'."},
                                    "meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."},
                                    "mime_type": {
                                        "type": "string",
                                        "description": "Field 'mime_type' of 'ByteStream'.",
                                    },
                                },
                            },
                            "meta": {"type": "string", "description": "Field 'meta' of 'Document'."},
                            "score": {"type": "number", "description": "Field 'score' of 'Document'."},
                            "embedding": {
                                "type": "array",
                                "description": "Field 'embedding' of 'Document'.",
                                "items": {"type": "number"},
                            },
                            "sparse_embedding": {
                                "type": "object",
                                "description": "Field 'sparse_embedding' of 'Document'.",
                                "properties": {
                                    "indices": {
                                        "type": "array",
                                        "description": "Field 'indices' of 'SparseEmbedding'.",
                                        "items": {"type": "integer"},
                                    },
                                    "values": {
                                        "type": "array",
                                        "description": "Field 'values' of 'SparseEmbedding'.",
                                        "items": {"type": "number"},
                                    },
                                },
                            },
                        },
                    },
                    "items": {"$ref": "#/$defs/Document"},
                    "type": "array",
                },
                "top_k": {"description": "The number of top documents to concatenate", "type": "integer"},
                "top_k": {"description": "The number of top documents to concatenate", "type": "integer", "default": 5},
            },
            "required": ["documents"],
            "type": "object",
        }

        # Test tool invocation
@@ -341,15 +307,16 @@
            ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail")


## Integration tests
# Integration tests
class TestToolComponentInPipelineWithOpenAI:
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_component_tool_in_pipeline(self):
        # Create component and convert it to tool
        component = SimpleComponent()
        tool = ComponentTool(
            component=component, name="hello_tool", description="A tool that generates a greeting message for the user"
            component=SimpleComponent(),
            name="hello_tool",
            description="A tool that generates a greeting message for the user",
        )

        # Create pipeline with OpenAIChatGenerator and ToolInvoker
@@ -378,9 +345,10 @@
    @pytest.mark.integration
    def test_component_tool_in_pipeline_openai_tools_strict(self):
        # Create component and convert it to tool
        component = SimpleComponent()
        tool = ComponentTool(
            component=component, name="hello_tool", description="A tool that generates a greeting message for the user"
            component=SimpleComponent(),
            name="hello_tool",
            description="A tool that generates a greeting message for the user",
        )

        # Create pipeline with OpenAIChatGenerator and ToolInvoker
@@ -408,9 +376,8 @@
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_user_greeter_in_pipeline(self):
        component = UserGreeter()
        tool = ComponentTool(
            component=component, name="user_greeter", description="A tool that greets users with their name and age"
            component=UserGreeter(), name="user_greeter", description="A tool that greets users with their name and age"
        )

        pipeline = Pipeline()
@@ -432,9 +399,8 @@
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_list_processor_in_pipeline(self):
        component = ListProcessor()
        tool = ComponentTool(
            component=component, name="list_processor", description="A tool that concatenates a list of strings"
            component=ListProcessor(), name="list_processor", description="A tool that concatenates a list of strings"
        )

        pipeline = Pipeline()
@@ -456,9 +422,8 @@
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_person_processor_in_pipeline(self):
        component = PersonProcessor()
        tool = ComponentTool(
            component=component,
            component=PersonProcessor(),
            name="person_processor",
            description="A tool that processes information about a person and their address",
        )
@@ -482,9 +447,8 @@
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_document_processor_in_pipeline(self):
        component = DocumentProcessor()
        tool = ComponentTool(
            component=component,
            component=DocumentProcessor(),
            name="document_processor",
            description="A tool that concatenates the content of multiple documents",
        )
@@ -516,9 +480,8 @@
    def test_lost_in_middle_ranker_in_pipeline(self):
        from haystack.components.rankers import LostInTheMiddleRanker

        component = LostInTheMiddleRanker()
        tool = ComponentTool(
            component=component,
            component=LostInTheMiddleRanker(),
            name="lost_in_middle_ranker",
            description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results",
        )
@@ -543,9 +506,10 @@
    @pytest.mark.skipif(not os.environ.get("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY not set")
    @pytest.mark.integration
    def test_serper_dev_web_search_in_pipeline(self):
        component = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)
        tool = ComponentTool(
            component=component, name="web_search", description="Search the web for current information on any topic"
            component=SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3),
            name="web_search",
            description="Search the web for current information on any topic",
        )

        pipeline = Pipeline()
@@ -603,10 +567,8 @@
        assert new_pipeline == pipeline

    def test_component_tool_serde(self):
        component = SimpleComponent()

        tool = ComponentTool(
            component=component,
            component=SimpleComponent(),
            name="simple_tool",
            description="A simple tool",
            inputs_from_state={"test": "input"},
@@ -632,16 +594,16 @@
        assert isinstance(new_tool._component, SimpleComponent)

    def test_pipeline_component_fails(self):
        component = SimpleComponent()
        comp = SimpleComponent()

        # Create a pipeline and add the component to it
        pipeline = Pipeline()
        pipeline.add_component("simple", component)
        pipeline.add_component("simple", comp)

        # Try to create a tool from the component and it should fail because the component has been added to a pipeline and
        # thus can't be used as tool
        with pytest.raises(ValueError, match="Component has been added to a pipeline"):
            ComponentTool(component=component)
            ComponentTool(component=comp)

    def test_deepcopy_with_jinja_based_component(self):
        builder = PromptBuilder("{{query}}")
239
test/tools/test_parameters_schema_utils.py
Normal file
@@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from typing import List

from haystack.dataclasses import ByteStream, ChatMessage, Document, TextContent, ToolCall, ToolCallResult
from pydantic import Field, create_model
from haystack.tools.parameters_schema_utils import _resolve_type
from haystack.tools.from_function import _remove_title_from_schema


BYTE_STREAM_SCHEMA = {
    "type": "object",
    "properties": {
        "data": {"type": "string", "description": "The binary data stored in Bytestream.", "format": "binary"},
        "meta": {
            "type": "object",
            "default": {},
            "description": "Additional metadata to be stored with the ByteStream.",
            "additionalProperties": True,
        },
        "mime_type": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "description": "The mime type of the binary data.",
        },
    },
    "required": ["data"],
}

SPARSE_EMBEDDING_SCHEMA = {
    "type": "object",
    "properties": {
        "indices": {
            "type": "array",
            "description": "List of indices of non-zero elements in the embedding.",
            "items": {"type": "integer"},
        },
        "values": {
            "type": "array",
            "description": "List of values of non-zero elements in the embedding.",
            "items": {"type": "number"},
        },
    },
    "required": ["indices", "values"],
}

DOCUMENT_SCHEMA = {
    "type": "object",
    "properties": {
        "id": {
            "type": "string",
            "description": "Unique identifier for the document. When not set, it's generated based on the Document fields' values.",
            "default": "",
        },
        "content": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "description": "Text of the document, if the document contains text.",
        },
        "blob": {
            "anyOf": [{"$ref": "#/$defs/ByteStream"}, {"type": "null"}],
            "default": None,
            "description": "Binary data associated with the document, if the document has any binary data associated with it.",
        },
        "meta": {
            "type": "object",
            "description": "Additional custom metadata for the document. Must be JSON-serializable.",
            "default": {},
            "additionalProperties": True,
        },
        "score": {
            "anyOf": [{"type": "number"}, {"type": "null"}],
            "default": None,
            "description": "Score of the document. Used for ranking, usually assigned by retrievers.",
        },
        "embedding": {
            "anyOf": [{"type": "array", "items": {"type": "number"}}, {"type": "null"}],
            "default": None,
            "description": "dense vector representation of the document.",
        },
        "sparse_embedding": {
            "anyOf": [{"$ref": "#/$defs/SparseEmbedding"}, {"type": "null"}],
            "default": None,
            "description": "sparse vector representation of the document.",
        },
    },
}

TEXT_CONTENT_SCHEMA = {
    "type": "object",
    "properties": {"text": {"type": "string", "description": "The text content of the message."}},
    "required": ["text"],
}

TOOL_CALL_SCHEMA = {
    "type": "object",
    "properties": {
        "tool_name": {"type": "string", "description": "The name of the Tool to call."},
        "arguments": {
            "type": "object",
            "description": "The arguments to call the Tool with.",
            "additionalProperties": True,
        },
        "id": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "description": "The ID of the Tool call.",
        },
    },
    "required": ["tool_name", "arguments"],
}

TOOL_CALL_RESULT_SCHEMA = {
    "type": "object",
    "properties": {
        "result": {"type": "string", "description": "The result of the Tool invocation."},
        "origin": {"$ref": "#/$defs/ToolCall", "description": "The Tool call that produced this result."},
        "error": {"type": "boolean", "description": "Whether the Tool invocation resulted in an error."},
    },
    "required": ["result", "origin", "error"],
}

CHAT_ROLE_SCHEMA = {
    "description": "Enumeration representing the roles within a chat.",
    "enum": ["user", "system", "assistant", "tool"],
    "type": "string",
}

CHAT_MESSAGE_SCHEMA = {
    "type": "object",
    "properties": {
        "role": {"$ref": "#/$defs/ChatRole", "description": "Field 'role' of 'ChatMessage'."},
        "content": {
            "type": "array",
            "description": "Field 'content' of 'ChatMessage'.",
            "items": {
                "anyOf": [
                    {"$ref": "#/$defs/TextContent"},
                    {"$ref": "#/$defs/ToolCall"},
                    {"$ref": "#/$defs/ToolCallResult"},
                ]
            },
        },
        "name": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "description": "Field 'name' of 'ChatMessage'.",
        },
        "meta": {
            "type": "object",
            "description": "Field 'meta' of 'ChatMessage'.",
            "default": {},
            "additionalProperties": True,
        },
    },
    "required": ["role", "content"],
}
@pytest.mark.parametrize(
    "python_type, description, expected_schema, expected_defs_schema",
    [
        (
            ByteStream,
            "A byte stream",
            {"$ref": "#/$defs/ByteStream", "description": "A byte stream"},
            {"ByteStream": BYTE_STREAM_SCHEMA},
        ),
        (
            Document,
            "A document",
            {"$ref": "#/$defs/Document", "description": "A document"},
            {"Document": DOCUMENT_SCHEMA, "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA, "ByteStream": BYTE_STREAM_SCHEMA},
        ),
        (
            TextContent,
            "A text content",
            {"$ref": "#/$defs/TextContent", "description": "A text content"},
            {"TextContent": TEXT_CONTENT_SCHEMA},
        ),
        (
            ToolCall,
            "A tool call",
            {"$ref": "#/$defs/ToolCall", "description": "A tool call"},
            {"ToolCall": TOOL_CALL_SCHEMA},
        ),
        (
            ToolCallResult,
            "A tool call result",
            {"$ref": "#/$defs/ToolCallResult", "description": "A tool call result"},
            {"ToolCallResult": TOOL_CALL_RESULT_SCHEMA, "ToolCall": TOOL_CALL_SCHEMA},
        ),
        (
            ChatMessage,
            "A chat message",
            {"$ref": "#/$defs/ChatMessage", "description": "A chat message"},
            {
                "ChatMessage": CHAT_MESSAGE_SCHEMA,
                "TextContent": TEXT_CONTENT_SCHEMA,
                "ToolCall": TOOL_CALL_SCHEMA,
                "ToolCallResult": TOOL_CALL_RESULT_SCHEMA,
                "ChatRole": CHAT_ROLE_SCHEMA,
            },
        ),
        (
            List[Document],
            "A list of documents",
            {"type": "array", "description": "A list of documents", "items": {"$ref": "#/$defs/Document"}},
            {"Document": DOCUMENT_SCHEMA, "SparseEmbedding": SPARSE_EMBEDDING_SCHEMA, "ByteStream": BYTE_STREAM_SCHEMA},
        ),
        (
            List[ChatMessage],
            "A list of chat messages",
            {"type": "array", "description": "A list of chat messages", "items": {"$ref": "#/$defs/ChatMessage"}},
            {
                "ChatMessage": CHAT_MESSAGE_SCHEMA,
                "TextContent": TEXT_CONTENT_SCHEMA,
                "ToolCall": TOOL_CALL_SCHEMA,
                "ToolCallResult": TOOL_CALL_RESULT_SCHEMA,
                "ChatRole": CHAT_ROLE_SCHEMA,
            },
        ),
    ],
)
def test_create_parameters_schema_haystack_dataclasses(python_type, description, expected_schema, expected_defs_schema):
    resolved_type = _resolve_type(python_type)
    fields = {"input_name": (resolved_type, Field(default=..., description=description))}
    model = create_model("run", __doc__="A test function", **fields)
    parameters_schema = model.model_json_schema()
    _remove_title_from_schema(parameters_schema)

    defs_schema = parameters_schema["$defs"]
    assert defs_schema == expected_defs_schema

    property_schema = parameters_schema["properties"]["input_name"]
    assert property_schema == expected_schema