feat: integrate updates of Tool, ToolInvoker, State, create_tool_from_function, ComponentTool from haystack-experimental (#9113)

* update Tool,ToolInvoker,ComponentTool,create_tool_from_function

* add State and its utils

* add tests for State and its utils

* update tests for Tool etc.

* reno

* fix circular imports

* update experimental imports in tests

* fix unit tests

* fix ChatGenerator unit tests

* mypy

* add State to init and pydoc

* explain State in more detail in release note

* add test from #8913

* re-add _check_duplicate_tool_names and refactor imports

* rename inputs and outputs
This commit is contained in:
Julian Risch 2025-03-28 10:49:23 +01:00 committed by GitHub
parent 726b7ef0c4
commit 657d09d7f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1261 additions and 112 deletions

View File

@ -2,7 +2,7 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/dataclasses]
modules:
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding",]
["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -2,32 +2,44 @@
#
# SPDX-License-Identifier: Apache-2.0
import inspect
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Union
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.sockets import Sockets
from haystack.dataclasses import ChatMessage, State, ToolCall
from haystack.tools.component_tool import ComponentTool
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}."
_TOOL_RESULT_CONVERSION_FAILURE = (
"Failed to convert tool result to string using '{conversion_function}'. Error: {error}."
)
logger = logging.getLogger(__name__)
class ToolNotFoundException(Exception):
"""
Exception raised when a tool is not found in the list of available tools.
"""
class ToolInvokerError(Exception):
"""Base exception class for ToolInvoker errors."""
pass
def __init__(self, message: str):
super().__init__(message)
class StringConversionError(Exception):
"""
Exception raised when the conversion of a tool result to a string fails.
"""
class ToolNotFoundException(ToolInvokerError):
"""Exception raised when a tool is not found in the list of available tools."""
def __init__(self, tool_name: str, available_tools: List[str]):
message = f"Tool '{tool_name}' not found. Available tools: {', '.join(available_tools)}"
super().__init__(message)
class StringConversionError(ToolInvokerError):
"""Exception raised when the conversion of a tool result to a string fails."""
def __init__(self, tool_name: str, conversion_function: str, error: Exception):
message = f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
super().__init__(message)
class ToolOutputMergeError(ToolInvokerError):
"""Exception raised when merging tool outputs into state fails."""
pass
@ -37,6 +49,7 @@ class ToolInvoker:
"""
Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
Also handles reading/writing from a shared `State`.
At initialization, the ToolInvoker component is provided with a list of available tools.
At runtime, the component processes a list of ChatMessage object containing tool calls
and invokes the corresponding tools.
@ -111,20 +124,36 @@ class ToolInvoker:
:param convert_result_to_json_string:
If True, the tool invocation result will be converted to a string using `json.dumps`.
If False, the tool invocation result will be converted to a string using `str`.
:raises ValueError:
If no tools are provided or if duplicate tool names are found.
"""
if not tools:
raise ValueError("ToolInvoker requires at least one tool to be provided.")
raise ValueError("ToolInvoker requires at least one tool.")
_check_duplicate_tool_names(tools)
tool_names = [tool.name for tool in tools]
duplicates = {name for name in tool_names if tool_names.count(name) > 1}
if duplicates:
raise ValueError(f"Duplicate tool names found: {duplicates}")
self.tools = tools
self._tools_with_names = dict(zip([tool.name for tool in tools], tools))
self._tools_with_names = dict(zip(tool_names, tools))
self.raise_on_failure = raise_on_failure
self.convert_result_to_json_string = convert_result_to_json_string
def _handle_error(self, error: Exception) -> str:
"""
Handles errors by logging and either raising or returning a fallback error message.
:param error: The exception instance.
:returns: The fallback error message when `raise_on_failure` is False.
:raises: The provided error if `raise_on_failure` is True.
"""
logger.error("{error_exception}", error_exception=error)
if self.raise_on_failure:
# We re-raise the original error maintaining the exception chain
raise error
return str(error)
def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
"""
Prepares a ChatMessage with the result of a tool invocation.
@ -133,40 +162,141 @@ class ToolInvoker:
The tool result.
:returns:
A ChatMessage object containing the tool result as a string.
:raises
StringConversionError: If the conversion of the tool result to a string fails
and `raise_on_failure` is True.
"""
error = False
if self.convert_result_to_json_string:
try:
try:
if self.convert_result_to_json_string:
# We disable ensure_ascii so special chars like emojis are not converted
tool_result_str = json.dumps(result, ensure_ascii=False)
except Exception as e:
if self.raise_on_failure:
raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps")
error = True
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
try:
tool_result_str = str(result)
else:
tool_result_str = str(result)
except Exception as e:
if self.raise_on_failure:
raise StringConversionError("Failed to convert tool result to string using `str`") from e
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str")
error = True
conversion_method = "json.dumps" if self.convert_result_to_json_string else "str"
try:
tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, conversion_method, e))
error = True
except StringConversionError as conversion_error:
# If _handle_error re-raises, this properly preserves the chain
raise conversion_error from e
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
@component.output_types(tool_messages=List[ChatMessage])
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
@staticmethod
def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
"""
Combine LLM-provided arguments (llm_args) with state-based arguments.
Tool arguments take precedence in the following order:
- LLM overrides state if the same param is present in both
- local tool.inputs mappings (if any)
- function signature name matching
"""
final_args = dict(llm_args) # start with LLM-provided
# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
# to find out which parameters the tool accepts.
if isinstance(tool, ComponentTool):
# mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
assert hasattr(tool._component, "__haystack_input__") and isinstance(
tool._component.__haystack_input__, Sockets
)
func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
else:
func_params = set(inspect.signature(tool.function).parameters.keys())
# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
param_mappings = tool.inputs_from_state
else:
param_mappings = {name: name for name in func_params}
# Populate final_args from state if not provided by LLM
for state_key, param_name in param_mappings.items():
if param_name not in final_args and state.has(state_key):
final_args[param_name] = state.get(state_key)
return final_args
def _merge_tool_outputs(self, tool: Tool, result: Any, state: State) -> Any:
"""
Merges the tool result into the global state and determines the response message.
This method processes the output of a tool execution and integrates it into the global state.
It also determines what message, if any, should be returned for further processing in a conversation.
Processing Steps:
1. If `result` is not a dictionary, nothing is stored into state and the full `result` is returned.
2. If the `tool` does not define an `outputs_to_state` mapping nothing is stored into state.
The return value in this case is simply the full `result` dictionary.
3. If the tool defines an `outputs_to_state` mapping (a dictionary describing how the tool's output should be
processed), the method delegates to `_handle_tool_outputs` to process the output accordingly.
This allows certain fields in `result` to be mapped explicitly to state fields or formatted using custom
handlers.
:param tool: Tool instance containing optional `outputs_to_state` mapping to guide result processing.
:param result: The output from tool execution. Can be a dictionary, or any other type.
:param state: The global State object to which results should be merged.
:returns: Three possible values:
- A string message for conversation
- The merged result dictionary
- Or the raw result if not a dictionary
"""
# If result is not a dictionary, return it as the output message.
if not isinstance(result, dict):
return result
# If there is no specific `outputs_to_state` mapping, we just return the full result
if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
return result
# Handle tool outputs with specific mapping for message and state updates
return self._handle_tool_outputs(tool.outputs_to_state, result, state)
@staticmethod
def _handle_tool_outputs(outputs_to_state: dict, result: dict, state: State) -> Union[dict, str]:
"""
Handles the `outputs_to_state` mapping from the tool and updates the state accordingly.
:param outputs_to_state: Mapping of outputs from the tool.
:param result: Result of the tool execution.
:param state: Global state to merge results into.
:returns: Final message for LLM or the entire result.
"""
message_content = None
for state_key, config in outputs_to_state.items():
# Get the source key from the output config, otherwise use the entire result
source_key = config.get("source", None)
output_value = result if source_key is None else result.get(source_key)
# Get the handler function, if any
handler = config.get("handler", None)
if state_key == "message":
# Handle the message output separately
if handler is not None:
message_content = handler(output_value)
else:
message_content = str(output_value)
else:
# Merge other outputs into the state
state.set(state_key, output_value, handler_override=handler)
# If no "message" key was found, return the result or message content
return message_content if message_content is not None else result
@component.output_types(tool_messages=List[ChatMessage], state=State)
def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
"""
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
:param messages:
A list of ChatMessage objects.
:param state: The runtime state that should be used by the tools.
:returns:
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
Each ChatMessage objects wraps the result of a tool invocation.
@ -177,39 +307,61 @@ class ToolInvoker:
If the tool invocation fails and `raise_on_failure` is True.
:raises StringConversionError:
If the conversion of the tool result to a string fails and `raise_on_failure` is True.
:raises ToolOutputMergeError:
If merging tool outputs into state fails and `raise_on_failure` is True.
"""
if state is None:
state = State(schema={})
# Only keep messages with tool calls
messages_with_tool_calls = [message for message in messages if message.tool_calls]
tool_messages = []
for message in messages:
tool_calls = message.tool_calls
if not tool_calls:
continue
for tool_call in tool_calls:
for message in messages_with_tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.tool_name
tool_arguments = tool_call.arguments
if not tool_name in self._tools_with_names:
msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys())
if self.raise_on_failure:
raise ToolNotFoundException(msg)
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
# Check if the tool is available, otherwise return an error message
if tool_name not in self._tools_with_names:
error_message = self._handle_error(
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
tool_to_invoke = self._tools_with_names[tool_name]
# 1) Combine user + state inputs
llm_args = tool_call.arguments.copy()
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
# 2) Invoke the tool
try:
tool_result = tool_to_invoke.invoke(**tool_arguments)
tool_result = tool_to_invoke.invoke(**final_args)
except ToolInvocationError as e:
if self.raise_on_failure:
raise e
msg = _TOOL_INVOCATION_FAILURE.format(error=e)
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
error_message = self._handle_error(e)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
tool_message = self._prepare_tool_result_message(tool_result, tool_call)
tool_messages.append(tool_message)
# 3) Merge outputs into state & create a single ChatMessage for the LLM
try:
tool_text = self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
return {"tool_messages": tool_messages}
tool_messages.append(self._prepare_tool_result_message(result=tool_text, tool_call=tool_call))
return {"tool_messages": tool_messages, "state": state}
def to_dict(self) -> Dict[str, Any]:
"""

View File

@ -13,6 +13,7 @@ _import_structure = {
"chat_message": ["ChatMessage", "ChatRole", "TextContent", "ToolCall", "ToolCallResult"],
"document": ["Document"],
"sparse_embedding": ["SparseEmbedding"],
"state": ["State"],
"streaming_chunk": [
"StreamingChunk",
"AsyncStreamingCallbackT",
@ -28,6 +29,7 @@ if TYPE_CHECKING:
from .chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult
from .document import Document
from .sparse_embedding import SparseEmbedding
from .state import State
from .streaming_chunk import (
AsyncStreamingCallbackT,
StreamingCallbackT,

View File

@ -0,0 +1,153 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional
from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type
def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a schema dictionary to a serializable format.
Converts each parameter's type and optional handler function into a serializable
format using type and callable serialization utilities.
:param schema: Dictionary mapping parameter names to their type and handler configs
:returns: Dictionary with serialized type and handler information
"""
serialized_schema = {}
for param, config in schema.items():
serialized_schema[param] = {"type": serialize_type(config["type"])}
if config.get("handler"):
serialized_schema[param]["handler"] = serialize_callable(config["handler"])
return serialized_schema
def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a serialized schema dictionary back to its original format.
Deserializes the type and optional handler function for each parameter from their
serialized format back into Python types and callables.
:param schema: Dictionary containing serialized schema information
:returns: Dictionary with deserialized type and handler configurations
"""
deserialized_schema = {}
for param, config in schema.items():
deserialized_schema[param] = {"type": deserialize_type(config["type"])}
if config.get("handler"):
deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])
return deserialized_schema
def _validate_schema(schema: Dict[str, Any]) -> None:
"""
Validate that a schema dictionary meets all required constraints.
Checks that each parameter definition has a valid type field and that any handler
specified is a callable function.
:param schema: Dictionary mapping parameter names to their type and handler configs
:raises ValueError: If schema validation fails due to missing or invalid fields
"""
for param, definition in schema.items():
if "type" not in definition:
raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
if not _is_valid_type(definition["type"]):
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
if definition.get("handler") is not None and not callable(definition["handler"]):
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
class State:
"""
A dataclass that wraps a StateSchema and maintains an internal _data dictionary.
Each schema entry has:
"parameter_name": {
"type": SomeType,
"handler": Optional[Callable[[Any, Any], Any]]
}
"""
def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
"""
Initialize a State object with a schema and optional data.
:param schema: Dictionary mapping parameter names to their type and handler configs.
Type must be a valid Python type, and handler must be a callable function or None.
If handler is None, the default handler for the type will be used. The default handlers are:
- For list types: `haystack.dataclasses.state_utils.merge_lists`
- For all other types: `haystack.dataclasses.state_utils.replace_values`
:param data: Optional dictionary of initial data to populate the state
"""
_validate_schema(schema)
self.schema = schema
self._data = data or {}
# Set default handlers if not provided in schema
for definition in schema.values():
# Skip if handler is already defined and not None
if definition.get("handler") is not None:
continue
# Set default handler based on type
if _is_list_type(definition["type"]):
definition["handler"] = merge_lists
else:
definition["handler"] = replace_values
def get(self, key: str, default: Any = None) -> Any:
"""
Retrieve a value from the state by key.
:param key: Key to look up in the state
:param default: Value to return if key is not found
:returns: Value associated with key or default if not found
"""
return self._data.get(key, default)
def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
"""
Set or merge a value in the state according to schema rules.
Value is merged or overwritten according to these rules:
- if handler_override is given, use that
- else use the handler defined in the schema for 'key'
:param key: Key to store the value under
:param value: Value to store or merge
:param handler_override: Optional function to override the default merge behavior
"""
# If key not in schema, we throw an error
definition = self.schema.get(key, None)
if definition is None:
raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")
# Get current value from state and apply handler
current_value = self._data.get(key, None)
handler = handler_override or definition["handler"]
self._data[key] = handler(current_value, value)
@property
def data(self):
"""
All current data of the state.
"""
return self._data
def has(self, key: str) -> bool:
"""
Check if a key exists in the state.
:param key: Key to check for existence
:returns: True if key exists in state, False otherwise
"""
return key in self._data

View File

@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, List, TypeVar, Union, get_origin
T = TypeVar("T")
def _is_valid_type(obj: Any) -> bool:
"""
Check if an object is a valid type annotation.
Valid types include:
- Normal classes (str, dict, CustomClass)
- Generic types (List[str], Dict[str, int])
- Union types (Union[str, int], Optional[str])
:param obj: The object to check
:return: True if the object is a valid type annotation, False otherwise
Example usage:
>>> _is_valid_type(str)
True
>>> _is_valid_type(List[int])
True
>>> _is_valid_type(Union[str, int])
True
>>> _is_valid_type(42)
False
"""
# Handle Union types (including Optional)
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
return True
# Handle normal classes and generic types
return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}
def _is_list_type(type_hint: Any) -> bool:
"""
Check if a type hint represents a list type.
:param type_hint: The type hint to check
:return: True if the type hint represents a list, False otherwise
"""
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]:
"""
Merges two values into a single list.
If either `current` or `new` is not already a list, it is converted into one.
The function ensures that both inputs are treated as lists and concatenates them.
If `current` is None, it is treated as an empty list.
:param current: The existing value(s), either a single item or a list.
:param new: The new value(s) to merge, either a single item or a list.
:return: A list containing elements from both `current` and `new`.
"""
current_list = [] if current is None else current if isinstance(current, list) else [current]
new_list = new if isinstance(new, list) else [new]
return current_list + new_list
def replace_values(current: Any, new: Any) -> Any:
"""
Replace the `current` value with the `new` value.
:param current: The existing value
:param new: The new value to replace
:return: The new value
"""
return new

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from copy import copy, deepcopy
from dataclasses import fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
@ -19,6 +20,7 @@ from haystack.core.serialization import (
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool
from haystack.tools.errors import SchemaGenerationError
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
from docstring_parser import parse
@ -87,13 +89,34 @@ class ComponentTool(Tool):
"""
def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None):
def __init__(
self,
component: Component,
name: Optional[str] = None,
description: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
*,
inputs_from_state: Optional[Dict[str, Any]] = None,
outputs_to_state: Optional[Dict[str, Any]] = None,
):
"""
Create a Tool instance from a Haystack component.
:param component: The Haystack component to wrap as a tool.
:param name: Optional name for the tool (defaults to snake_case of component class name).
:param description: Optional description (defaults to component's docstring).
:param parameters:
A JSON schema defining the parameters expected by the Tool.
Will fall back to the parameters defined in the component's run method signature if not provided.
:param inputs_from_state:
Optional dictionary mapping state keys to tool parameter names.
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
:param outputs_to_state:
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
Example: {
"documents": {"source": "docs", "handler": custom_handler},
"message": {"source": "summary", "handler": format_summary}
}
:raises ValueError: If the component is invalid or schema generation fails.
"""
if not isinstance(component, Component):
@ -110,8 +133,9 @@ class ComponentTool(Tool):
)
raise ValueError(msg)
self._unresolved_parameters = parameters
# Create the tools schema from the component run method parameters
tool_schema = self._create_tool_parameters_schema(component)
tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})
def component_invoker(**kwargs):
"""
@ -155,16 +179,39 @@ class ComponentTool(Tool):
description = description or component.__doc__ or name
# Create the Tool instance with the component invoker as the function to be called and the schema
super().__init__(name, description, tool_schema, component_invoker)
super().__init__(
name=name,
description=description,
parameters=tool_schema,
function=component_invoker,
inputs_from_state=inputs_from_state,
outputs_to_state=outputs_to_state,
)
self._component = component
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the ComponentTool to a dictionary.
"""
# we do not serialize the function in this case: it can be recreated from the component at deserialization time
serialized = {"name": self.name, "description": self.description, "parameters": self.parameters}
serialized["component"] = component_to_dict(obj=self._component, name=self.name)
serialized_component = component_to_dict(obj=self._component, name=self.name)
serialized = {
"component": serialized_component,
"name": self.name,
"description": self.description,
"parameters": self._unresolved_parameters,
"inputs_from_state": self.inputs_from_state,
}
if self.outputs_to_state is not None:
serialized_outputs = {}
for key, config in self.outputs_to_state.items():
serialized_config = config.copy()
if "handler" in config:
serialized_config["handler"] = serialize_callable(config["handler"])
serialized_outputs[key] = serialized_config
serialized["outputs_to_state"] = serialized_outputs
return {"type": generate_qualified_class_name(type(self)), "data": serialized}
@classmethod
@ -175,9 +222,26 @@ class ComponentTool(Tool):
inner_data = data["data"]
component_class = import_class_by_name(inner_data["component"]["type"])
component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])
return cls(component=component, name=inner_data["name"], description=inner_data["description"])
def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]:
if "outputs_to_state" in inner_data and inner_data["outputs_to_state"]:
deserialized_outputs = {}
for key, config in inner_data["outputs_to_state"].items():
deserialized_config = config.copy()
if "handler" in config:
deserialized_config["handler"] = deserialize_callable(config["handler"])
deserialized_outputs[key] = deserialized_config
inner_data["outputs_to_state"] = deserialized_outputs
return cls(
component=component,
name=inner_data["name"],
description=inner_data["description"],
parameters=inner_data.get("parameters", None),
inputs_from_state=inner_data.get("inputs_from_state", None),
outputs_to_state=inner_data.get("outputs_to_state", None),
)
def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
"""
Creates an OpenAI tools schema from a component's run method parameters.
@ -191,6 +255,8 @@ class ComponentTool(Tool):
param_descriptions = self._get_param_descriptions(component.run)
for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined]
if inputs_from_state is not None and input_name in inputs_from_state:
continue
input_type = socket.type
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
@ -326,3 +392,29 @@ class ComponentTool(Tool):
schema["default"] = default
return schema
def __deepcopy__(self, memo: Dict[Any, Any]) -> "ComponentTool":
# Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
# When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
# We overwrite ComponentTool.__deepcopy__ to fix this until a more comprehensive fix is merged.
# We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
result = copy(self)
# Add the object to the memo dictionary to handle circular references
memo[id(self)] = result
# Deep copy all attributes with exception handling
for key, value in self.__dict__.items():
try:
# Try to deep copy the attribute
setattr(result, key, deepcopy(value, memo))
except TypeError:
# Fall back to using the original attribute for components that use Jinja2-templates
logger.debug(
"deepcopy of ComponentTool {tool_name} failed. Using original attribute '{attribute}' instead.",
tool_name=self.name,
attribute=key,
)
setattr(result, key, getattr(self, key))
return result

View File

@ -3,16 +3,20 @@
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from pydantic import create_model
from haystack.tools.errors import SchemaGenerationError
from haystack.tools.tool import Tool
from .errors import SchemaGenerationError
from .tool import Tool
def create_tool_from_function(
function: Callable, name: Optional[str] = None, description: Optional[str] = None
function: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
inputs_from_state: Optional[Dict[str, str]] = None,
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None,
) -> "Tool":
"""
Create a Tool instance from a function.
@ -62,7 +66,15 @@ def create_tool_from_function(
:param description:
The description of the Tool. If not provided, the docstring of the function will be used.
To intentionally leave the description empty, pass an empty string.
:param inputs_from_state:
Optional dictionary mapping state keys to tool parameter names.
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
:param outputs_to_state:
Optional dictionary defining how tool outputs map to state and message handling.
Example: {
"documents": {"source": "docs", "handler": custom_handler},
"message": {"source": "summary", "handler": format_summary}
}
:returns:
The Tool created from the function.
@ -71,7 +83,6 @@ def create_tool_from_function(
:raises SchemaGenerationError:
If there is an error generating the JSON schema for the Tool.
"""
tool_description = description if description is not None else (function.__doc__ or "")
signature = inspect.signature(function)
@ -81,6 +92,10 @@ def create_tool_from_function(
descriptions = {}
for param_name, param in signature.parameters.items():
# Skip adding parameter names that will be passed to the tool from State
if inputs_from_state and param_name in inputs_from_state.values():
continue
if param.annotation is param.empty:
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")
@ -109,15 +124,33 @@ def create_tool_from_function(
if param_name in schema["properties"]:
schema["properties"][param_name]["description"] = param_description
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
return Tool(
name=name or function.__name__,
description=tool_description,
parameters=schema,
function=function,
inputs_from_state=inputs_from_state,
outputs_to_state=outputs_to_state,
)
def tool(function: Callable) -> Tool:
def tool(
function: Optional[Callable] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
inputs_from_state: Optional[Dict[str, str]] = None,
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Union[Tool, Callable[[Callable], Tool]]:
"""
Decorator to convert a function into a Tool.
Tool name, description, and parameters are inferred from the function.
If you need to customize more the Tool, use `create_tool_from_function` instead.
Can be used with or without parameters:
@tool # without parameters
def my_function(): ...
@tool(name="custom_name") # with parameters
def my_function(): ...
### Usage example
```python
@ -147,8 +180,27 @@ def tool(function: Callable) -> Tool:
>>> },
>>> function=<function get_weather at 0x7f7b3a8a9b80>)
```
:param function: The function to decorate (when used without parameters)
:param name: Optional custom name for the tool
:param description: Optional custom description
:param inputs_from_state: Optional dictionary mapping state keys to tool parameter names
:param outputs_to_state: Optional dictionary defining how tool outputs map to state and message handling
:return: Either a Tool instance or a decorator function that will create one
"""
return create_tool_from_function(function)
def decorator(func: Callable) -> Tool:
return create_tool_from_function(
function=func,
name=name,
description=description,
inputs_from_state=inputs_from_state,
outputs_to_state=outputs_to_state,
)
if function is None:
return decorator
return decorator(function)
def _remove_title_from_schema(schema: Dict[str, Any]):

View File

@ -29,12 +29,23 @@ class Tool:
A JSON schema defining the parameters expected by the Tool.
:param function:
The function that will be invoked when the Tool is called.
:param inputs_from_state:
Optional dictionary mapping state keys to tool parameter names.
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
:param outputs_to_state:
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
Example: {
"documents": {"source": "docs", "handler": custom_handler},
"message": {"source": "summary", "handler": format_summary}
}
"""
name: str
description: str
parameters: Dict[str, Any]
function: Callable
inputs_from_state: Optional[Dict[str, str]] = None
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None
def __post_init__(self):
# Check that the parameters define a valid JSON schema
@ -43,6 +54,16 @@ class Tool:
except SchemaError as e:
raise ValueError("The provided parameters do not define a valid JSON schema") from e
# Validate outputs structure if provided
if self.outputs_to_state is not None:
for key, config in self.outputs_to_state.items():
if not isinstance(config, dict):
raise ValueError(f"Output configuration for key '{key}' must be a dictionary")
if "source" in config and not isinstance(config["source"], str):
raise ValueError(f"Output source for key '{key}' must be a string.")
if "handler" in config and not callable(config["handler"]):
raise ValueError(f"Output handler for key '{key}' must be callable")
@property
def tool_spec(self) -> Dict[str, Any]:
"""
@ -54,11 +75,12 @@ class Tool:
"""
Invoke the Tool with the provided keyword arguments.
"""
try:
result = self.function(**kwargs)
except Exception as e:
raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e
raise ToolInvocationError(
f"Failed to invoke Tool `{self.name}` with parameters {kwargs}. Error: {e}"
) from e
return result
def to_dict(self) -> Dict[str, Any]:
@ -68,9 +90,19 @@ class Tool:
:returns:
Dictionary with serialized data.
"""
data = asdict(self)
data["function"] = serialize_callable(self.function)
# Serialize output handlers if they exist
if self.outputs_to_state:
serialized_outputs = {}
for key, config in self.outputs_to_state.items():
serialized_config = config.copy()
if "handler" in config:
serialized_config["handler"] = serialize_callable(config["handler"])
serialized_outputs[key] = serialized_config
data["outputs_to_state"] = serialized_outputs
return {"type": generate_qualified_class_name(type(self)), "data": data}
@classmethod
@ -85,6 +117,17 @@ class Tool:
"""
init_parameters = data["data"]
init_parameters["function"] = deserialize_callable(init_parameters["function"])
# Deserialize output handlers if they exist
if "outputs_to_state" in init_parameters and init_parameters["outputs_to_state"]:
deserialized_outputs = {}
for key, config in init_parameters["outputs_to_state"].items():
deserialized_config = config.copy()
if "handler" in config:
deserialized_config["handler"] = deserialize_callable(config["handler"])
deserialized_outputs[key] = deserialized_config
init_parameters["outputs_to_state"] = deserialized_outputs
return cls(**init_parameters)

View File

@ -0,0 +1,8 @@
---
features:
- |
Introduce new State dataclass with a customizable schema for managing Agent state.
Enhance error logging of Tool and extend the ToolInvoker component to work with newly introduced State.
fixes:
- |
Fix an issue that prevented Jinja2-based ComponentTools from being passed into pipelines at runtime.

View File

@ -258,7 +258,9 @@ class TestHuggingFaceAPIChatGenerator:
"data": {
"description": "description",
"function": "builtins.print",
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"parameters": {"x": {"type": "string"}},
},
}
@ -319,7 +321,9 @@ class TestHuggingFaceAPIChatGenerator:
{
"type": "haystack.tools.tool.Tool",
"data": {
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"description": "description",
"parameters": {"x": {"type": "string"}},
"function": "builtins.print",

View File

@ -192,7 +192,9 @@ class TestHuggingFaceLocalChatGenerator:
{
"type": "haystack.tools.tool.Tool",
"data": {
"inputs_from_state": None,
"name": "weather",
"outputs_to_state": None,
"description": "useful to determine the weather in a given location",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
"function": "generators.chat.test_hugging_face_local.get_weather",

View File

@ -205,7 +205,9 @@ class TestOpenAIChatGenerator:
"data": {
"description": "description",
"function": "builtins.print",
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"parameters": {"x": {"type": "string"}},
},
}

View File

@ -3,10 +3,14 @@ import datetime
from haystack import Pipeline
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.tools.tool import Tool, ToolInvocationError
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.dataclasses.state import State
from haystack.tools import ComponentTool, Tool
from haystack.tools.errors import ToolInvocationError
def weather_function(location):
@ -82,6 +86,49 @@ class TestToolInvoker:
with pytest.raises(ValueError):
ToolInvoker(tools=[weather_tool, new_tool])
def test_inject_state_args_no_tool_inputs(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
)
state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
assert args == {"location": "Berlin"}
def test_inject_state_args_no_tool_inputs_component_tool(self):
comp = PromptBuilder(template="Hello, {{name}}!")
prompt_tool = ComponentTool(
component=comp, name="prompt_tool", description="Creates a personalized greeting prompt."
)
state = State(schema={"name": {"type": str}}, data={"name": "James"})
args = ToolInvoker._inject_state_args(tool=prompt_tool, llm_args={}, state=state)
assert args == {"name": "James"}
def test_inject_state_args_with_tool_inputs(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
inputs_from_state={"loc": "location"},
)
state = State(schema={"location": {"type": str}}, data={"loc": "Berlin"})
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
assert args == {"location": "Berlin"}
def test_inject_state_args_param_in_state_and_llm(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
)
state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
assert args == {"location": "Paris"}
def test_run(self, invoker):
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
@ -104,14 +151,14 @@ class TestToolInvoker:
def test_run_no_messages(self, invoker):
result = invoker.run(messages=[])
assert result == {"tool_messages": []}
assert result["tool_messages"] == []
def test_run_no_tool_calls(self, invoker):
user_message = ChatMessage.from_user(text="Hello!")
assistant_message = ChatMessage.from_assistant(text="How can I help you?")
result = invoker.run(messages=[user_message, assistant_message])
assert result == {"tool_messages": []}
assert result["tool_messages"] == []
def test_run_multiple_tool_calls(self, invoker):
tool_calls = [
@ -143,8 +190,8 @@ class TestToolInvoker:
with pytest.raises(ToolNotFoundException):
invoker.run(messages=[tool_call_message])
def test_tool_not_found_does_not_raise_exception(self, invoker):
invoker.raise_on_failure = False
def test_tool_not_found_does_not_raise_exception(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=False)
tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"})
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
@ -162,8 +209,8 @@ class TestToolInvoker:
with pytest.raises(ToolInvocationError):
faulty_invoker.run(messages=[tool_call_message])
def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker):
faulty_invoker.raise_on_failure = False
def test_tool_invocation_error_does_not_raise_exception(self, faulty_tool):
faulty_invoker = ToolInvoker(tools=[faulty_tool], raise_on_failure=False, convert_result_to_json_string=False)
tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
@ -171,10 +218,10 @@ class TestToolInvoker:
result = faulty_invoker.run(messages=[tool_call_message])
tool_message = result["tool_messages"][0]
assert tool_message.tool_call_results[0].error
assert "invocation failed" in tool_message.tool_call_results[0].result
assert "Failed to invoke" in tool_message.tool_call_results[0].result
def test_string_conversion_error(self, invoker):
invoker.convert_result_to_json_string = True
def test_string_conversion_error(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=True)
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
@ -182,9 +229,8 @@ class TestToolInvoker:
with pytest.raises(StringConversionError):
invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call)
def test_string_conversion_error_does_not_raise_exception(self, invoker):
invoker.convert_result_to_json_string = True
invoker.raise_on_failure = False
def test_string_conversion_error_does_not_raise_exception(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=True)
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
@ -231,8 +277,8 @@ class TestToolInvoker:
pipeline_dict = pipeline.to_dict()
assert pipeline_dict == {
"metadata": {},
"max_runs_per_component": 100,
"connection_type_validation": True,
"max_runs_per_component": 100,
"components": {
"invoker": {
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
@ -249,6 +295,8 @@ class TestToolInvoker:
"required": ["location"],
},
"function": "tools.test_tool_invoker.weather_function",
"inputs_from_state": None,
"outputs_to_state": None,
},
}
],
@ -279,3 +327,93 @@ class TestToolInvoker:
new_pipeline = Pipeline.loads(pipeline_yaml)
assert new_pipeline == pipeline
class TestMergeToolOutputs:
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"weather": {"type": str}})
merged_results = invoker._merge_tool_outputs(tool=weather_tool, result="test", state=state)
assert merged_results == "test"
assert state.data == {}
def test_merge_tool_outputs_empty_dict(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"weather": {"type": str}})
merged_results = invoker._merge_tool_outputs(tool=weather_tool, result={}, state=state)
assert merged_results == {}
assert state.data == {}
def test_merge_tool_outputs_no_output_mapping(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"weather": {"type": str}})
merged_results = invoker._merge_tool_outputs(
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
)
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
assert state.data == {}
def test_merge_tool_outputs_with_output_mapping(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
outputs_to_state={"weather": {"source": "weather"}},
)
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"weather": {"type": str}})
merged_results = invoker._merge_tool_outputs(
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
)
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
assert state.data == {"weather": "sunny"}
def test_merge_tool_outputs_with_output_mapping_2(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
outputs_to_state={"all_weather_results": {}},
)
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"all_weather_results": {"type": str}})
merged_results = invoker._merge_tool_outputs(
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
)
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
assert state.data == {"all_weather_results": {"weather": "sunny", "temperature": 14, "unit": "celsius"}}
def test_merge_tool_outputs_with_output_mapping_and_handler(self):
handler = lambda old, new: f"{new}"
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
outputs_to_state={"temperature": {"source": "temperature", "handler": handler}},
)
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={"temperature": {"type": str}})
merged_results = invoker._merge_tool_outputs(
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
)
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
assert state.data == {"temperature": "14"}
def test_merge_tool_outputs_with_message_output_mapping(self):
weather_tool = Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
outputs_to_state={"message": {"source": "weather"}},
)
invoker = ToolInvoker(tools=[weather_tool])
state = State(schema={})
merged_results = invoker._merge_tool_outputs(
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
)
assert merged_results == "sunny"
assert state.data == {}

View File

@ -0,0 +1,146 @@
import pytest
from typing import List, Dict
from haystack.dataclasses.state import State, _validate_schema
@pytest.fixture
def basic_schema():
return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}
@pytest.fixture
def complex_schema():
return {
"numbers": {
"type": list,
"handler": lambda current, new: sorted(set(current + new)) if current else sorted(set(new)),
},
"metadata": {"type": dict},
"name": {"type": str},
}
def test_validate_schema_valid(basic_schema):
# Should not raise any exceptions
_validate_schema(basic_schema)
def test_validate_schema_invalid_type():
invalid_schema = {"test": {"type": "not_a_type"}}
with pytest.raises(ValueError, match="must be a Python type"):
_validate_schema(invalid_schema)
def test_validate_schema_missing_type():
invalid_schema = {"test": {"handler": lambda x, y: x + y}}
with pytest.raises(ValueError, match="missing a 'type' entry"):
_validate_schema(invalid_schema)
def test_validate_schema_invalid_handler():
invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
with pytest.raises(ValueError, match="must be callable or None"):
_validate_schema(invalid_schema)
def test_state_initialization(basic_schema):
# Test empty initialization
state = State(basic_schema)
assert state.data == {}
# Test initialization with data
initial_data = {"numbers": [1, 2, 3], "name": "test"}
state = State(basic_schema, initial_data)
assert state.data["numbers"] == [1, 2, 3]
assert state.data["name"] == "test"
def test_state_get(basic_schema):
state = State(basic_schema, {"name": "test"})
assert state.get("name") == "test"
assert state.get("non_existent") is None
assert state.get("non_existent", "default") == "default"
def test_state_set_basic(basic_schema):
state = State(basic_schema)
# Test setting new values
state.set("numbers", [1, 2])
assert state.get("numbers") == [1, 2]
# Test updating existing values
state.set("numbers", [3, 4])
assert state.get("numbers") == [1, 2, 3, 4]
def test_state_set_with_handler(complex_schema):
state = State(complex_schema)
# Test custom handler for numbers
state.set("numbers", [3, 2, 1])
assert state.get("numbers") == [1, 2, 3]
state.set("numbers", [6, 5, 4])
assert state.get("numbers") == [1, 2, 3, 4, 5, 6]
def test_state_set_with_handler_override(basic_schema):
state = State(basic_schema)
# Custom handler that concatenates strings
custom_handler = lambda current, new: f"{current}-{new}" if current else new
state.set("name", "first")
state.set("name", "second", handler_override=custom_handler)
assert state.get("name") == "first-second"
def test_state_has(basic_schema):
state = State(basic_schema, {"name": "test"})
assert state.has("name") is True
assert state.has("non_existent") is False
def test_state_empty_schema():
state = State({})
assert state.data == {}
with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
state.set("any_key", "value")
def test_state_none_values(basic_schema):
state = State(basic_schema)
state.set("name", None)
assert state.get("name") is None
state.set("name", "value")
assert state.get("name") == "value"
def test_state_merge_lists(basic_schema):
state = State(basic_schema)
state.set("numbers", "not_a_list")
assert state.get("numbers") == ["not_a_list"]
state.set("numbers", [1, 2])
assert state.get("numbers") == ["not_a_list", 1, 2]
def test_state_nested_structures():
schema = {
"complex": {
"type": Dict[str, List[int]],
"handler": lambda current, new: {
k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())
}
if current
else new,
}
}
state = State(schema)
state.set("complex", {"a": [1, 2], "b": [3, 4]})
state.set("complex", {"b": [5, 6], "c": [7, 8]})
expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
assert state.get("complex") == expected

View File

@ -0,0 +1,161 @@
import pytest
from typing import List, Dict, Optional, Union, TypeVar, Generic
from dataclasses import dataclass
from haystack.dataclasses.state_utils import _is_list_type, merge_lists, _is_valid_type
import inspect
def test_is_list_type():
assert _is_list_type(list) is True
assert _is_list_type(List[int]) is True
assert _is_list_type(List[str]) is True
assert _is_list_type(dict) is False
assert _is_list_type(int) is False
assert _is_list_type(Union[List[int], None]) is False
class TestMergeLists:
def test_merge_two_lists(self):
current = [1, 2, 3]
new = [4, 5, 6]
result = merge_lists(current, new)
assert result == [1, 2, 3, 4, 5, 6]
# Ensure original lists weren't modified
assert current == [1, 2, 3]
assert new == [4, 5, 6]
def test_append_to_list(self):
current = [1, 2, 3]
new = 4
result = merge_lists(current, new)
assert result == [1, 2, 3, 4]
assert current == [1, 2, 3] # Ensure original wasn't modified
def test_create_new_list(self):
current = 1
new = 2
result = merge_lists(current, new)
assert result == [1, 2]
def test_replace_with_list(self):
current = 1
new = [2, 3]
result = merge_lists(current, new)
assert result == [1, 2, 3]
class TestIsValidType:
def test_builtin_types(self):
assert _is_valid_type(str) is True
assert _is_valid_type(int) is True
assert _is_valid_type(dict) is True
assert _is_valid_type(list) is True
assert _is_valid_type(tuple) is True
assert _is_valid_type(set) is True
assert _is_valid_type(bool) is True
assert _is_valid_type(float) is True
def test_generic_types(self):
assert _is_valid_type(List[str]) is True
assert _is_valid_type(Dict[str, int]) is True
assert _is_valid_type(List[Dict[str, int]]) is True
assert _is_valid_type(Dict[str, List[int]]) is True
def test_custom_classes(self):
@dataclass
class CustomClass:
value: int
T = TypeVar("T")
class GenericCustomClass(Generic[T]):
pass
# Test regular and generic custom classes
assert _is_valid_type(CustomClass) is True
assert _is_valid_type(GenericCustomClass) is True
assert _is_valid_type(GenericCustomClass[int]) is True
# Test generic types with custom classes
assert _is_valid_type(List[CustomClass]) is True
assert _is_valid_type(Dict[str, CustomClass]) is True
assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True
def test_invalid_types(self):
# Test regular values
assert _is_valid_type(42) is False
assert _is_valid_type("string") is False
assert _is_valid_type([1, 2, 3]) is False
assert _is_valid_type({"a": 1}) is False
assert _is_valid_type(True) is False
# Test class instances
@dataclass
class SampleClass:
value: int
instance = SampleClass(42)
assert _is_valid_type(instance) is False
# Test callable objects
assert _is_valid_type(len) is False
assert _is_valid_type(lambda x: x) is False
assert _is_valid_type(print) is False
def test_union_and_optional_types(self):
# Test basic Union types
assert _is_valid_type(Union[str, int]) is True
assert _is_valid_type(Union[str, None]) is True
assert _is_valid_type(Union[List[int], Dict[str, str]]) is True
# Test Optional types (which are Union[T, None])
assert _is_valid_type(Optional[str]) is True
assert _is_valid_type(Optional[List[int]]) is True
assert _is_valid_type(Optional[Dict[str, list]]) is True
# Test that Union itself is not a valid type (only instantiated Unions are)
assert _is_valid_type(Union) is False
def test_nested_generic_types(self):
assert _is_valid_type(List[List[Dict[str, List[int]]]]) is True
assert _is_valid_type(Dict[str, List[Dict[str, set]]]) is True
assert _is_valid_type(Dict[str, Optional[List[int]]]) is True
assert _is_valid_type(List[Union[str, Dict[str, List[int]]]]) is True
def test_edge_cases(self):
# Test None and NoneType
assert _is_valid_type(None) is False
assert _is_valid_type(type(None)) is True
# Test functions and methods
def sample_func():
pass
assert _is_valid_type(sample_func) is False
assert _is_valid_type(type(sample_func)) is True
# Test modules
assert _is_valid_type(inspect) is False
# Test type itself
assert _is_valid_type(type) is True
@pytest.mark.parametrize(
"test_input,expected",
[
(str, True),
(int, True),
(List[int], True),
(Dict[str, int], True),
(Union[str, int], True),
(Optional[str], True),
(42, False),
("string", False),
([1, 2, 3], False),
(lambda x: x, False),
],
)
def test_parametrized_cases(self, test_input, expected):
assert _is_valid_type(test_input) is expected

View File

@ -4,20 +4,21 @@
import json
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List
import pytest
from haystack import Pipeline, component
from haystack.components.builders import PromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.tools.tool_invoker import ToolInvoker
from haystack.components.tools import ToolInvoker
from haystack.components.websearch.serper_dev import SerperDevWebSearch
from haystack.dataclasses import ChatMessage, ChatRole, Document
from haystack.tools import ComponentTool
from haystack.utils.auth import Secret
### Component and Model Definitions
@ -121,6 +122,13 @@ class DocumentProcessor:
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])}
def output_handler(old, new):
"""
Output handler to test serialization.
"""
return old + new
## Unit tests
class TestToolComponent:
def test_from_component_basic(self):
@ -148,6 +156,22 @@ class TestToolComponent:
assert len(tool.description) == 1024
def test_from_component_with_inputs(self):
component = SimpleComponent()
tool = ComponentTool(component=component, inputs_from_state={"text": "text"})
assert tool.inputs_from_state == {"text": "text"}
# Inputs should be excluded from schema generation
assert tool.parameters == {"type": "object", "properties": {}}
def test_from_component_with_outputs(self):
component = SimpleComponent()
tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}})
assert tool.outputs_to_state == {"replies": {"source": "reply"}}
def test_from_component_with_dataclass(self):
component = UserGreeter()
@ -466,7 +490,7 @@ class TestToolComponentInPipelineWithOpenAI:
pipeline.connect("llm.replies", "tool_invoker.messages")
message = ChatMessage.from_user(
text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields."
text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
)
result = pipeline.run({"llm": {"messages": [message]}})
@ -500,7 +524,7 @@ class TestToolComponentInPipelineWithOpenAI:
pipeline.connect("llm.replies", "tool_invoker.messages")
message = ChatMessage.from_user(
text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields."
text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
)
result = pipeline.run({"llm": {"messages": [message]}})
@ -576,7 +600,13 @@ class TestToolComponentInPipelineWithOpenAI:
def test_component_tool_serde(self):
component = SimpleComponent()
tool = ComponentTool(component=component, name="simple_tool", description="A simple tool")
tool = ComponentTool(
component=component,
name="simple_tool",
description="A simple tool",
inputs_from_state={"test": "input"},
outputs_to_state={"output": {"source": "out", "handler": output_handler}},
)
# Test serialization
tool_dict = tool.to_dict()
@ -584,12 +614,16 @@ class TestToolComponentInPipelineWithOpenAI:
assert tool_dict["data"]["name"] == "simple_tool"
assert tool_dict["data"]["description"] == "A simple tool"
assert "component" in tool_dict["data"]
assert tool_dict["data"]["inputs_from_state"] == {"test": "input"}
assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler"
# Test deserialization
new_tool = ComponentTool.from_dict(tool_dict)
assert new_tool.name == tool.name
assert new_tool.description == tool.description
assert new_tool.parameters == tool.parameters
assert new_tool.inputs_from_state == tool.inputs_from_state
assert new_tool.outputs_to_state == tool.outputs_to_state
assert isinstance(new_tool._component, SimpleComponent)
def test_pipeline_component_fails(self):
@ -603,3 +637,21 @@ class TestToolComponentInPipelineWithOpenAI:
# thus can't be used as tool
with pytest.raises(ValueError, match="Component has been added to a pipeline"):
ComponentTool(component=component)
def test_deepcopy_with_jinja_based_component(self):
# Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
# When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
# We overwrite ComponentTool.__deepcopy__ to fix this in experimental until a more comprehensive fix is merged.
# We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
builder = PromptBuilder("{{query}}")
tool = ComponentTool(component=builder)
result = tool.function(query="Hello")
tool_copy = deepcopy(tool)
result_from_copy = tool_copy.function(query="Hello")
assert "prompt" in result_from_copy
assert result_from_copy["prompt"] == result["prompt"]

View File

@ -1,7 +1,7 @@
import pytest
from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool
from haystack.tools.errors import SchemaGenerationError
from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool
from typing import Annotated, Literal, Optional
@ -133,6 +133,38 @@ def test_tool_decorator_with_annotated_params():
assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"
def test_tool_decorator_with_parameters():
@tool(name="fetch_weather", description="A tool to check the weather.")
def get_weather(
city: Annotated[str, "The target city"] = "Berlin",
format: Annotated[Literal["short", "long"], "Output format"] = "short",
) -> str:
"""Get weather report for a city."""
return f"Weather report for {city} ({format} format): 20°C, sunny"
assert get_weather.name == "fetch_weather"
assert get_weather.description == "A tool to check the weather."
def test_tool_decorator_with_inputs_and_outputs():
@tool(inputs_from_state={"format": "format"}, outputs_to_state={"output": {"source": "output"}})
def get_weather(
city: Annotated[str, "The target city"] = "Berlin",
format: Annotated[Literal["short", "long"], "Output format"] = "short",
) -> str:
"""Get weather report for a city."""
return f"Weather report for {city} ({format} format): 20°C, sunny"
assert get_weather.name == "get_weather"
assert get_weather.inputs_from_state == {"format": "format"}
assert get_weather.outputs_to_state == {"output": {"source": "output"}}
# Inputs should be excluded from auto-generated parameters
assert get_weather.parameters == {
"type": "object",
"properties": {"city": {"type": "string", "description": "The target city", "default": "Berlin"}},
}
def test_remove_title_from_schema():
complex_schema = {
"properties": {

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import re
import pytest
from haystack.tools.tool import Tool, ToolInvocationError, deserialize_tools_inplace, _check_duplicate_tool_names
@ -24,12 +25,31 @@ class TestTool:
assert tool.description == "Get weather report"
assert tool.parameters == parameters
assert tool.function == get_weather_report
assert tool.inputs_from_state is None
assert tool.outputs_to_state is None
def test_init_invalid_parameters(self):
parameters = {"type": "invalid", "properties": {"city": {"type": "string"}}}
params = {"type": "invalid", "properties": {"city": {"type": "string"}}}
with pytest.raises(ValueError):
Tool(name="irrelevant", description="irrelevant", parameters=parameters, function=get_weather_report)
Tool(name="irrelevant", description="irrelevant", parameters=params, function=get_weather_report)
@pytest.mark.parametrize(
"outputs_to_state",
[
pytest.param({"documents": ["some_value"]}, id="config-not-a-dict"),
pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"),
pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"),
],
)
def test_init_invalid_output_structure(self, outputs_to_state):
with pytest.raises(ValueError):
Tool(
name="irrelevant",
description="irrelevant",
parameters={"type": "object", "properties": {"city": {"type": "string"}}},
function=get_weather_report,
outputs_to_state=outputs_to_state,
)
def test_tool_spec(self):
tool = Tool(
@ -49,13 +69,21 @@ class TestTool:
tool = Tool(
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
)
with pytest.raises(ToolInvocationError):
with pytest.raises(
ToolInvocationError,
match=re.escape(
"Failed to invoke Tool `weather` with parameters {}. Error: get_weather_report() missing 1 required positional argument: 'city'"
),
):
tool.invoke()
def test_to_dict(self):
tool = Tool(
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
name="weather",
description="Get weather report",
parameters=parameters,
function=get_weather_report,
outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
)
assert tool.to_dict() == {
@ -65,6 +93,8 @@ class TestTool:
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"inputs_from_state": None,
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
},
}
@ -76,6 +106,7 @@ class TestTool:
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
},
}
@ -85,6 +116,8 @@ class TestTool:
assert tool.description == "Get weather report"
assert tool.parameters == parameters
assert tool.function == get_weather_report
assert tool.outputs_to_state["documents"]["source"] == "docs"
assert tool.outputs_to_state["documents"]["handler"] == get_weather_report
def test_deserialize_tools_inplace():