diff --git a/docs/pydoc/config/data_classess_api.yml b/docs/pydoc/config/data_classess_api.yml index 0a17ce3d4..5a33103a6 100644 --- a/docs/pydoc/config/data_classess_api.yml +++ b/docs/pydoc/config/data_classess_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/dataclasses] modules: - ["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding",] + ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 8703a6695..e70ea794b 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -2,32 +2,44 @@ # # SPDX-License-Identifier: Apache-2.0 +import inspect import json -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union -from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses.chat_message import ChatMessage, ToolCall +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.core.component.sockets import Sockets +from haystack.dataclasses import ChatMessage, State, ToolCall +from haystack.tools.component_tool import ComponentTool from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace -_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}." -_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}." -_TOOL_RESULT_CONVERSION_FAILURE = ( - "Failed to convert tool result to string using '{conversion_function}'. Error: {error}." 
-) +logger = logging.getLogger(__name__) -class ToolNotFoundException(Exception): - """ - Exception raised when a tool is not found in the list of available tools. - """ +class ToolInvokerError(Exception): + """Base exception class for ToolInvoker errors.""" - pass + def __init__(self, message: str): + super().__init__(message) -class StringConversionError(Exception): - """ - Exception raised when the conversion of a tool result to a string fails. - """ +class ToolNotFoundException(ToolInvokerError): + """Exception raised when a tool is not found in the list of available tools.""" + + def __init__(self, tool_name: str, available_tools: List[str]): + message = f"Tool '{tool_name}' not found. Available tools: {', '.join(available_tools)}" + super().__init__(message) + + +class StringConversionError(ToolInvokerError): + """Exception raised when the conversion of a tool result to a string fails.""" + + def __init__(self, tool_name: str, conversion_function: str, error: Exception): + message = f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}" + super().__init__(message) + + +class ToolOutputMergeError(ToolInvokerError): + """Exception raised when merging tool outputs into state fails.""" pass @@ -37,6 +49,7 @@ class ToolInvoker: """ Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects. + Also handles reading/writing from a shared `State`. At initialization, the ToolInvoker component is provided with a list of available tools. At runtime, the component processes a list of ChatMessage object containing tool calls and invokes the corresponding tools. @@ -111,20 +124,36 @@ class ToolInvoker: :param convert_result_to_json_string: If True, the tool invocation result will be converted to a string using `json.dumps`. If False, the tool invocation result will be converted to a string using `str`. - :raises ValueError: If no tools are provided or if duplicate tool names are found. 
""" - if not tools: - raise ValueError("ToolInvoker requires at least one tool to be provided.") + raise ValueError("ToolInvoker requires at least one tool.") _check_duplicate_tool_names(tools) + tool_names = [tool.name for tool in tools] + duplicates = {name for name in tool_names if tool_names.count(name) > 1} + if duplicates: + raise ValueError(f"Duplicate tool names found: {duplicates}") self.tools = tools - self._tools_with_names = dict(zip([tool.name for tool in tools], tools)) + self._tools_with_names = dict(zip(tool_names, tools)) self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string + def _handle_error(self, error: Exception) -> str: + """ + Handles errors by logging and either raising or returning a fallback error message. + + :param error: The exception instance. + :returns: The fallback error message when `raise_on_failure` is False. + :raises: The provided error if `raise_on_failure` is True. + """ + logger.error("{error_exception}", error_exception=error) + if self.raise_on_failure: + # We re-raise the original error maintaining the exception chain + raise error + return str(error) + def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage: """ Prepares a ChatMessage with the result of a tool invocation. @@ -133,40 +162,141 @@ class ToolInvoker: The tool result. :returns: A ChatMessage object containing the tool result as a string. - :raises StringConversionError: If the conversion of the tool result to a string fails and `raise_on_failure` is True. 
""" error = False - - if self.convert_result_to_json_string: - try: + try: + if self.convert_result_to_json_string: # We disable ensure_ascii so special chars like emojis are not converted tool_result_str = json.dumps(result, ensure_ascii=False) - except Exception as e: - if self.raise_on_failure: - raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e - tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps") - error = True - return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) - - try: - tool_result_str = str(result) + else: + tool_result_str = str(result) except Exception as e: - if self.raise_on_failure: - raise StringConversionError("Failed to convert tool result to string using `str`") from e - tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str") - error = True + conversion_method = "json.dumps" if self.convert_result_to_json_string else "str" + try: + tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, conversion_method, e)) + error = True + except StringConversionError as conversion_error: + # If _handle_error re-raises, this properly preserves the chain + raise conversion_error from e + return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) - @component.output_types(tool_messages=List[ChatMessage]) - def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: + @staticmethod + def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]: + """ + Combine LLM-provided arguments (llm_args) with state-based arguments. 
+ + Tool arguments take precedence in the following order: + - LLM overrides state if the same param is present in both + - local `tool.inputs_from_state` mappings (if any) + - function signature name matching + """ + final_args = dict(llm_args) # start with LLM-provided + + # ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets + # to find out which parameters the tool accepts. + if isinstance(tool, ComponentTool): + # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component + assert hasattr(tool._component, "__haystack_input__") and isinstance( + tool._component.__haystack_input__, Sockets + ) + func_params = set(tool._component.__haystack_input__._sockets_dict.keys()) + else: + func_params = set(inspect.signature(tool.function).parameters.keys()) + + # Determine the source of parameter mappings (explicit tool inputs or direct function parameters) + # Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"} + if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict): + param_mappings = tool.inputs_from_state + else: + param_mappings = {name: name for name in func_params} + + # Populate final_args from state if not provided by LLM + for state_key, param_name in param_mappings.items(): + if param_name not in final_args and state.has(state_key): + final_args[param_name] = state.get(state_key) + + return final_args + + def _merge_tool_outputs(self, tool: Tool, result: Any, state: State) -> Any: + """ + Merges the tool result into the global state and determines the response message. + + This method processes the output of a tool execution and integrates it into the global state. + It also determines what message, if any, should be returned for further processing in a conversation. + + Processing Steps: + 1. If `result` is not a dictionary, nothing is stored into state and the full `result` is returned. + 2. 
If the `tool` does not define an `outputs_to_state` mapping, nothing is stored into state. + The return value in this case is simply the full `result` dictionary. + 3. If the tool defines an `outputs_to_state` mapping (a dictionary describing how the tool's output should be + processed), the method delegates to `_handle_tool_outputs` to process the output accordingly. + This allows certain fields in `result` to be mapped explicitly to state fields or formatted using custom + handlers. + + :param tool: Tool instance containing optional `outputs_to_state` mapping to guide result processing. + :param result: The output from tool execution. Can be a dictionary, or any other type. + :param state: The global State object to which results should be merged. + :returns: Three possible values: + - A string message for conversation + - The full result dictionary + - Or the raw result if not a dictionary + """ + # If result is not a dictionary, return it as the output message. + if not isinstance(result, dict): + return result + + # If there is no specific `outputs_to_state` mapping, we just return the full result + if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict): + return result + + # Handle tool outputs with specific mapping for message and state updates + return self._handle_tool_outputs(tool.outputs_to_state, result, state) + + @staticmethod + def _handle_tool_outputs(outputs_to_state: dict, result: dict, state: State) -> Union[dict, str]: + """ + Handles the `outputs_to_state` mapping from the tool and updates the state accordingly. + + :param outputs_to_state: Mapping of outputs from the tool. + :param result: Result of the tool execution. + :param state: Global state to merge results into. + :returns: Final message for LLM or the entire result. 
+ """ + message_content = None + + for state_key, config in outputs_to_state.items(): + # Get the source key from the output config, otherwise use the entire result + source_key = config.get("source", None) + output_value = result if source_key is None else result.get(source_key) + + # Get the handler function, if any + handler = config.get("handler", None) + + if state_key == "message": + # Handle the message output separately + if handler is not None: + message_content = handler(output_value) + else: + message_content = str(output_value) + else: + # Merge other outputs into the state + state.set(state_key, output_value, handler_override=handler) + + # Return the message content if a "message" output produced one, otherwise return the full result + return message_content if message_content is not None else result + + @component.output_types(tool_messages=List[ChatMessage], state=State) + def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]: """ Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available. :param messages: A list of ChatMessage objects. + :param state: The runtime state that should be used by the tools. :returns: A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role. Each ChatMessage objects wraps the result of a tool invocation. @@ -177,39 +307,61 @@ class ToolInvoker: If the tool invocation fails and `raise_on_failure` is True. :raises StringConversionError: If the conversion of the tool result to a string fails and `raise_on_failure` is True. + :raises ToolOutputMergeError: + If merging tool outputs into state fails and `raise_on_failure` is True. 
""" + if state is None: + state = State(schema={}) + + # Only keep messages with tool calls + messages_with_tool_calls = [message for message in messages if message.tool_calls] + tool_messages = [] - - for message in messages: - tool_calls = message.tool_calls - if not tool_calls: - continue - - for tool_call in tool_calls: + for message in messages_with_tool_calls: + for tool_call in message.tool_calls: tool_name = tool_call.tool_name - tool_arguments = tool_call.arguments - if not tool_name in self._tools_with_names: - msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys()) - if self.raise_on_failure: - raise ToolNotFoundException(msg) - tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) + # Check if the tool is available, otherwise return an error message + if tool_name not in self._tools_with_names: + error_message = self._handle_error( + ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) + ) + tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) continue tool_to_invoke = self._tools_with_names[tool_name] + + # 1) Combine user + state inputs + llm_args = tool_call.arguments.copy() + final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + + # 2) Invoke the tool try: - tool_result = tool_to_invoke.invoke(**tool_arguments) + tool_result = tool_to_invoke.invoke(**final_args) except ToolInvocationError as e: - if self.raise_on_failure: - raise e - msg = _TOOL_INVOCATION_FAILURE.format(error=e) - tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) + error_message = self._handle_error(e) + tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) continue - tool_message = self._prepare_tool_result_message(tool_result, tool_call) - tool_messages.append(tool_message) + # 3) Merge outputs into state & create a single ChatMessage 
for the LLM + try: + tool_text = self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e - return {"tool_messages": tool_messages} + tool_messages.append(self._prepare_tool_result_message(result=tool_text, tool_call=tool_call)) + + return {"tool_messages": tool_messages, "state": state} def to_dict(self) -> Dict[str, Any]: """ diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index f8c9b8441..15da1ae84 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -13,6 +13,7 @@ _import_structure = { "chat_message": ["ChatMessage", "ChatRole", "TextContent", "ToolCall", "ToolCallResult"], "document": ["Document"], "sparse_embedding": ["SparseEmbedding"], + "state": ["State"], "streaming_chunk": [ "StreamingChunk", "AsyncStreamingCallbackT", @@ -28,6 +29,7 @@ if TYPE_CHECKING: from .chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult from .document import Document from .sparse_embedding import SparseEmbedding + from .state import State from .streaming_chunk import ( AsyncStreamingCallbackT, StreamingCallbackT, diff --git a/haystack/dataclasses/state.py b/haystack/dataclasses/state.py new file mode 100644 index 000000000..daf815e0b --- /dev/null +++ b/haystack/dataclasses/state.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional + +from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values +from 
haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.utils.type_serialization import deserialize_type, serialize_type + + +def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert a schema dictionary to a serializable format. + + Converts each parameter's type and optional handler function into a serializable + format using type and callable serialization utilities. + + :param schema: Dictionary mapping parameter names to their type and handler configs + :returns: Dictionary with serialized type and handler information + """ + serialized_schema = {} + for param, config in schema.items(): + serialized_schema[param] = {"type": serialize_type(config["type"])} + if config.get("handler"): + serialized_schema[param]["handler"] = serialize_callable(config["handler"]) + + return serialized_schema + + +def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert a serialized schema dictionary back to its original format. + + Deserializes the type and optional handler function for each parameter from their + serialized format back into Python types and callables. + + :param schema: Dictionary containing serialized schema information + :returns: Dictionary with deserialized type and handler configurations + """ + deserialized_schema = {} + for param, config in schema.items(): + deserialized_schema[param] = {"type": deserialize_type(config["type"])} + + if config.get("handler"): + deserialized_schema[param]["handler"] = deserialize_callable(config["handler"]) + + return deserialized_schema + + +def _validate_schema(schema: Dict[str, Any]) -> None: + """ + Validate that a schema dictionary meets all required constraints. + + Checks that each parameter definition has a valid type field and that any handler + specified is a callable function. 
+ + :param schema: Dictionary mapping parameter names to their type and handler configs + :raises ValueError: If schema validation fails due to missing or invalid fields + """ + for param, definition in schema.items(): + if "type" not in definition: + raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.") + if not _is_valid_type(definition["type"]): + raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") + if definition.get("handler") is not None and not callable(definition["handler"]): + raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") + + +class State: + """ + A class that wraps a state schema and maintains an internal _data dictionary. + + Each schema entry has: + "parameter_name": { + "type": SomeType, + "handler": Optional[Callable[[Any, Any], Any]] + } + """ + + def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None): + """ + Initialize a State object with a schema and optional data. + + :param schema: Dictionary mapping parameter names to their type and handler configs. + Type must be a valid Python type, and handler must be a callable function or None. + If handler is None, the default handler for the type will be used. 
The default handlers are: + - For list types: `haystack.dataclasses.state_utils.merge_lists` + - For all other types: `haystack.dataclasses.state_utils.replace_values` + :param data: Optional dictionary of initial data to populate the state + """ + _validate_schema(schema) + self.schema = schema + self._data = data or {} + + # Set default handlers if not provided in schema + for definition in schema.values(): + # Skip if handler is already defined and not None + if definition.get("handler") is not None: + continue + # Set default handler based on type + if _is_list_type(definition["type"]): + definition["handler"] = merge_lists + else: + definition["handler"] = replace_values + + def get(self, key: str, default: Any = None) -> Any: + """ + Retrieve a value from the state by key. + + :param key: Key to look up in the state + :param default: Value to return if key is not found + :returns: Value associated with key or default if not found + """ + return self._data.get(key, default) + + def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None: + """ + Set or merge a value in the state according to schema rules. + + Value is merged or overwritten according to these rules: + - if handler_override is given, use that + - else use the handler defined in the schema for 'key' + + :param key: Key to store the value under + :param value: Value to store or merge + :param handler_override: Optional function to override the default merge behavior + """ + # If key not in schema, we throw an error + definition = self.schema.get(key, None) + if definition is None: + raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}") + + # Get current value from state and apply handler + current_value = self._data.get(key, None) + handler = handler_override or definition["handler"] + self._data[key] = handler(current_value, value) + + @property + def data(self): + """ + All current data of the state. 
+ """ + return self._data + + def has(self, key: str) -> bool: + """ + Check if a key exists in the state. + + :param key: Key to check for existence + :returns: True if key exists in state, False otherwise + """ + return key in self._data diff --git a/haystack/dataclasses/state_utils.py b/haystack/dataclasses/state_utils.py new file mode 100644 index 000000000..19bcf1ded --- /dev/null +++ b/haystack/dataclasses/state_utils.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Any, List, TypeVar, Union, get_origin + +T = TypeVar("T") + + +def _is_valid_type(obj: Any) -> bool: + """ + Check if an object is a valid type annotation. + + Valid types include: + - Normal classes (str, dict, CustomClass) + - Generic types (List[str], Dict[str, int]) + - Union types (Union[str, int], Optional[str]) + + :param obj: The object to check + :return: True if the object is a valid type annotation, False otherwise + + Example usage: + >>> _is_valid_type(str) + True + >>> _is_valid_type(List[int]) + True + >>> _is_valid_type(Union[str, int]) + True + >>> _is_valid_type(42) + False + """ + # Handle Union types (including Optional) + if hasattr(obj, "__origin__") and obj.__origin__ is Union: + return True + + # Handle normal classes and generic types + return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"} + + +def _is_list_type(type_hint: Any) -> bool: + """ + Check if a type hint represents a list type. + + :param type_hint: The type hint to check + :return: True if the type hint represents a list, False otherwise + """ + return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list) + + +def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]: + """ + Merges two values into a single list. + + If either `current` or `new` is not already a list, it is converted into one. 
+ The function ensures that both inputs are treated as lists and concatenates them. + + If `current` is None, it is treated as an empty list. + + :param current: The existing value(s), either a single item or a list. + :param new: The new value(s) to merge, either a single item or a list. + :return: A list containing elements from both `current` and `new`. + """ + current_list = [] if current is None else current if isinstance(current, list) else [current] + new_list = new if isinstance(new, list) else [new] + return current_list + new_list + + +def replace_values(current: Any, new: Any) -> Any: + """ + Replace the `current` value with the `new` value. + + :param current: The existing value + :param new: The new value to replace + :return: The new value + """ + return new diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py index 93e384c37..ca083f496 100644 --- a/haystack/tools/component_tool.py +++ b/haystack/tools/component_tool.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from copy import copy, deepcopy from dataclasses import fields, is_dataclass from inspect import getdoc from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin @@ -19,6 +20,7 @@ from haystack.core.serialization import ( from haystack.lazy_imports import LazyImport from haystack.tools import Tool from haystack.tools.errors import SchemaGenerationError +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import: from docstring_parser import parse @@ -87,13 +89,34 @@ class ComponentTool(Tool): """ - def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None): + def __init__( + self, + component: Component, + name: Optional[str] = None, + description: Optional[str] = None, + parameters: Optional[Dict[str, Any]] = None, + *, + inputs_from_state: Optional[Dict[str, Any]] = 
None, + outputs_to_state: Optional[Dict[str, Any]] = None, + ): """ Create a Tool instance from a Haystack component. :param component: The Haystack component to wrap as a tool. :param name: Optional name for the tool (defaults to snake_case of component class name). :param description: Optional description (defaults to component's docstring). + :param parameters: + A JSON schema defining the parameters expected by the Tool. + Will fall back to the parameters defined in the component's run method signature if not provided. + :param inputs_from_state: + Optional dictionary mapping state keys to tool parameter names. + Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter. + :param outputs_to_state: + Optional dictionary defining how tool outputs map to keys within state as well as optional handlers. + Example: { + "documents": {"source": "docs", "handler": custom_handler}, + "message": {"source": "summary", "handler": format_summary} + } :raises ValueError: If the component is invalid or schema generation fails. 
""" if not isinstance(component, Component): @@ -110,8 +133,9 @@ class ComponentTool(Tool): ) raise ValueError(msg) + self._unresolved_parameters = parameters # Create the tools schema from the component run method parameters - tool_schema = self._create_tool_parameters_schema(component) + tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {}) def component_invoker(**kwargs): """ @@ -155,16 +179,39 @@ class ComponentTool(Tool): description = description or component.__doc__ or name # Create the Tool instance with the component invoker as the function to be called and the schema - super().__init__(name, description, tool_schema, component_invoker) + super().__init__( + name=name, + description=description, + parameters=tool_schema, + function=component_invoker, + inputs_from_state=inputs_from_state, + outputs_to_state=outputs_to_state, + ) self._component = component def to_dict(self) -> Dict[str, Any]: """ Serializes the ComponentTool to a dictionary. 
""" - # we do not serialize the function in this case: it can be recreated from the component at deserialization time - serialized = {"name": self.name, "description": self.description, "parameters": self.parameters} - serialized["component"] = component_to_dict(obj=self._component, name=self.name) + serialized_component = component_to_dict(obj=self._component, name=self.name) + + serialized = { + "component": serialized_component, + "name": self.name, + "description": self.description, + "parameters": self._unresolved_parameters, + "inputs_from_state": self.inputs_from_state, + } + + if self.outputs_to_state is not None: + serialized_outputs = {} + for key, config in self.outputs_to_state.items(): + serialized_config = config.copy() + if "handler" in config: + serialized_config["handler"] = serialize_callable(config["handler"]) + serialized_outputs[key] = serialized_config + serialized["outputs_to_state"] = serialized_outputs + return {"type": generate_qualified_class_name(type(self)), "data": serialized} @classmethod @@ -175,9 +222,26 @@ class ComponentTool(Tool): inner_data = data["data"] component_class = import_class_by_name(inner_data["component"]["type"]) component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"]) - return cls(component=component, name=inner_data["name"], description=inner_data["description"]) - def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]: + if "outputs_to_state" in inner_data and inner_data["outputs_to_state"]: + deserialized_outputs = {} + for key, config in inner_data["outputs_to_state"].items(): + deserialized_config = config.copy() + if "handler" in config: + deserialized_config["handler"] = deserialize_callable(config["handler"]) + deserialized_outputs[key] = deserialized_config + inner_data["outputs_to_state"] = deserialized_outputs + + return cls( + component=component, + name=inner_data["name"], + description=inner_data["description"], + 
parameters=inner_data.get("parameters", None), + inputs_from_state=inner_data.get("inputs_from_state", None), + outputs_to_state=inner_data.get("outputs_to_state", None), + ) + + def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]: """ Creates an OpenAI tools schema from a component's run method parameters. @@ -191,6 +255,8 @@ class ComponentTool(Tool): param_descriptions = self._get_param_descriptions(component.run) for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined] + if inputs_from_state is not None and input_name in inputs_from_state: + continue input_type = socket.type description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") @@ -326,3 +392,29 @@ class ComponentTool(Tool): schema["default"] = default return schema + + def __deepcopy__(self, memo: Dict[Any, Any]) -> "ComponentTool": + # Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758) + # When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool + # We overwrite ComponentTool.__deepcopy__ to fix this until a more comprehensive fix is merged. + # We track the issue here: https://github.com/deepset-ai/haystack/issues/9011 + result = copy(self) + + # Add the object to the memo dictionary to handle circular references + memo[id(self)] = result + + # Deep copy all attributes with exception handling + for key, value in self.__dict__.items(): + try: + # Try to deep copy the attribute + setattr(result, key, deepcopy(value, memo)) + except TypeError: + # Fall back to using the original attribute for components that use Jinja2-templates + logger.debug( + "deepcopy of ComponentTool {tool_name} failed. 
Using original attribute '{attribute}' instead.", + tool_name=self.name, + attribute=key, + ) + setattr(result, key, getattr(self, key)) + + return result diff --git a/haystack/tools/from_function.py b/haystack/tools/from_function.py index 67fd47620..f48579b9f 100644 --- a/haystack/tools/from_function.py +++ b/haystack/tools/from_function.py @@ -3,16 +3,20 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union from pydantic import create_model -from haystack.tools.errors import SchemaGenerationError -from haystack.tools.tool import Tool +from .errors import SchemaGenerationError +from .tool import Tool def create_tool_from_function( - function: Callable, name: Optional[str] = None, description: Optional[str] = None + function: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + inputs_from_state: Optional[Dict[str, str]] = None, + outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None, ) -> "Tool": """ Create a Tool instance from a function. @@ -62,7 +66,15 @@ def create_tool_from_function( :param description: The description of the Tool. If not provided, the docstring of the function will be used. To intentionally leave the description empty, pass an empty string. - + :param inputs_from_state: + Optional dictionary mapping state keys to tool parameter names. + Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter. + :param outputs_to_state: + Optional dictionary defining how tool outputs map to state and message handling. + Example: { + "documents": {"source": "docs", "handler": custom_handler}, + "message": {"source": "summary", "handler": format_summary} + } :returns: The Tool created from the function. @@ -71,7 +83,6 @@ def create_tool_from_function( :raises SchemaGenerationError: If there is an error generating the JSON schema for the Tool. 
""" - tool_description = description if description is not None else (function.__doc__ or "") signature = inspect.signature(function) @@ -81,6 +92,10 @@ def create_tool_from_function( descriptions = {} for param_name, param in signature.parameters.items(): + # Skip adding parameter names that will be passed to the tool from State + if inputs_from_state and param_name in inputs_from_state.values(): + continue + if param.annotation is param.empty: raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") @@ -109,15 +124,33 @@ def create_tool_from_function( if param_name in schema["properties"]: schema["properties"][param_name]["description"] = param_description - return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + return Tool( + name=name or function.__name__, + description=tool_description, + parameters=schema, + function=function, + inputs_from_state=inputs_from_state, + outputs_to_state=outputs_to_state, + ) -def tool(function: Callable) -> Tool: +def tool( + function: Optional[Callable] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + inputs_from_state: Optional[Dict[str, str]] = None, + outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None, +) -> Union[Tool, Callable[[Callable], Tool]]: """ Decorator to convert a function into a Tool. - Tool name, description, and parameters are inferred from the function. - If you need to customize more the Tool, use `create_tool_from_function` instead. + Can be used with or without parameters: + @tool # without parameters + def my_function(): ... + + @tool(name="custom_name") # with parameters + def my_function(): ... 
### Usage example ```python @@ -147,8 +180,27 @@ def tool(function: Callable) -> Tool: >>> }, >>> function=) ``` + + :param function: The function to decorate (when used without parameters) + :param name: Optional custom name for the tool + :param description: Optional custom description + :param inputs_from_state: Optional dictionary mapping state keys to tool parameter names + :param outputs_to_state: Optional dictionary defining how tool outputs map to state and message handling + :return: Either a Tool instance or a decorator function that will create one """ - return create_tool_from_function(function) + + def decorator(func: Callable) -> Tool: + return create_tool_from_function( + function=func, + name=name, + description=description, + inputs_from_state=inputs_from_state, + outputs_to_state=outputs_to_state, + ) + + if function is None: + return decorator + return decorator(function) def _remove_title_from_schema(schema: Dict[str, Any]): diff --git a/haystack/tools/tool.py b/haystack/tools/tool.py index fd4802879..71e0904ac 100644 --- a/haystack/tools/tool.py +++ b/haystack/tools/tool.py @@ -29,12 +29,23 @@ class Tool: A JSON schema defining the parameters expected by the Tool. :param function: The function that will be invoked when the Tool is called. + :param inputs_from_state: + Optional dictionary mapping state keys to tool parameter names. + Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter. + :param outputs_to_state: + Optional dictionary defining how tool outputs map to keys within state as well as optional handlers. 
+ Example: { + "documents": {"source": "docs", "handler": custom_handler}, + "message": {"source": "summary", "handler": format_summary} + } """ name: str description: str parameters: Dict[str, Any] function: Callable + inputs_from_state: Optional[Dict[str, str]] = None + outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None def __post_init__(self): # Check that the parameters define a valid JSON schema @@ -43,6 +54,16 @@ class Tool: except SchemaError as e: raise ValueError("The provided parameters do not define a valid JSON schema") from e + # Validate outputs structure if provided + if self.outputs_to_state is not None: + for key, config in self.outputs_to_state.items(): + if not isinstance(config, dict): + raise ValueError(f"Output configuration for key '{key}' must be a dictionary") + if "source" in config and not isinstance(config["source"], str): + raise ValueError(f"Output source for key '{key}' must be a string.") + if "handler" in config and not callable(config["handler"]): + raise ValueError(f"Output handler for key '{key}' must be callable") + @property def tool_spec(self) -> Dict[str, Any]: """ @@ -54,11 +75,12 @@ class Tool: """ Invoke the Tool with the provided keyword arguments. """ - try: result = self.function(**kwargs) except Exception as e: - raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e + raise ToolInvocationError( + f"Failed to invoke Tool `{self.name}` with parameters {kwargs}. Error: {e}" + ) from e return result def to_dict(self) -> Dict[str, Any]: @@ -68,9 +90,19 @@ class Tool: :returns: Dictionary with serialized data. 
""" - data = asdict(self) data["function"] = serialize_callable(self.function) + + # Serialize output handlers if they exist + if self.outputs_to_state: + serialized_outputs = {} + for key, config in self.outputs_to_state.items(): + serialized_config = config.copy() + if "handler" in config: + serialized_config["handler"] = serialize_callable(config["handler"]) + serialized_outputs[key] = serialized_config + data["outputs_to_state"] = serialized_outputs + return {"type": generate_qualified_class_name(type(self)), "data": data} @classmethod @@ -85,6 +117,17 @@ class Tool: """ init_parameters = data["data"] init_parameters["function"] = deserialize_callable(init_parameters["function"]) + + # Deserialize output handlers if they exist + if "outputs_to_state" in init_parameters and init_parameters["outputs_to_state"]: + deserialized_outputs = {} + for key, config in init_parameters["outputs_to_state"].items(): + deserialized_config = config.copy() + if "handler" in config: + deserialized_config["handler"] = deserialize_callable(config["handler"]) + deserialized_outputs[key] = deserialized_config + init_parameters["outputs_to_state"] = deserialized_outputs + return cls(**init_parameters) diff --git a/releasenotes/notes/move-tool-ff98d464d3e5d775.yaml b/releasenotes/notes/move-tool-ff98d464d3e5d775.yaml new file mode 100644 index 000000000..96f7d12b9 --- /dev/null +++ b/releasenotes/notes/move-tool-ff98d464d3e5d775.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Introduce new State dataclass with a customizable schema for managing Agent state. + Enhance error logging of Tool and extend the ToolInvoker component to work with newly introduced State. +fixes: + - | + Fix an issue that prevented Jinja2-based ComponentTools from being passed into pipelines at runtime. 
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index dc0f8c706..14a9f4d13 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -258,7 +258,9 @@ class TestHuggingFaceAPIChatGenerator: "data": { "description": "description", "function": "builtins.print", + "inputs_from_state": None, "name": "name", + "outputs_to_state": None, "parameters": {"x": {"type": "string"}}, }, } @@ -319,7 +321,9 @@ class TestHuggingFaceAPIChatGenerator: { "type": "haystack.tools.tool.Tool", "data": { + "inputs_from_state": None, "name": "name", + "outputs_to_state": None, "description": "description", "parameters": {"x": {"type": "string"}}, "function": "builtins.print", diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 8bfc1a10b..dd407e72c 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -192,7 +192,9 @@ class TestHuggingFaceLocalChatGenerator: { "type": "haystack.tools.tool.Tool", "data": { + "inputs_from_state": None, "name": "weather", + "outputs_to_state": None, "description": "useful to determine the weather in a given location", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, "function": "generators.chat.test_hugging_face_local.get_weather", diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 5b57dddf9..e76722ca5 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -205,7 +205,9 @@ class TestOpenAIChatGenerator: "data": { "description": "description", "function": "builtins.print", + "inputs_from_state": None, "name": "name", + "outputs_to_state": None, "parameters": {"x": 
{"type": "string"}}, }, } diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 4ed4126e6..135631305 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -3,10 +3,14 @@ import datetime from haystack import Pipeline -from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole -from haystack.tools.tool import Tool, ToolInvocationError -from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError + +from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.generators.chat.openai import OpenAIChatGenerator +from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError +from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole +from haystack.dataclasses.state import State +from haystack.tools import ComponentTool, Tool +from haystack.tools.errors import ToolInvocationError def weather_function(location): @@ -82,6 +86,49 @@ class TestToolInvoker: with pytest.raises(ValueError): ToolInvoker(tools=[weather_tool, new_tool]) + def test_inject_state_args_no_tool_inputs(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + ) + state = State(schema={"location": {"type": str}}, data={"location": "Berlin"}) + args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state) + assert args == {"location": "Berlin"} + + def test_inject_state_args_no_tool_inputs_component_tool(self): + comp = PromptBuilder(template="Hello, {{name}}!") + prompt_tool = ComponentTool( + component=comp, name="prompt_tool", description="Creates a personalized greeting prompt." 
+ ) + state = State(schema={"name": {"type": str}}, data={"name": "James"}) + args = ToolInvoker._inject_state_args(tool=prompt_tool, llm_args={}, state=state) + assert args == {"name": "James"} + + def test_inject_state_args_with_tool_inputs(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + inputs_from_state={"loc": "location"}, + ) + state = State(schema={"location": {"type": str}}, data={"loc": "Berlin"}) + args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state) + assert args == {"location": "Berlin"} + + def test_inject_state_args_param_in_state_and_llm(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + ) + state = State(schema={"location": {"type": str}}, data={"location": "Berlin"}) + args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state) + assert args == {"location": "Paris"} + def test_run(self, invoker): tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) message = ChatMessage.from_assistant(tool_calls=[tool_call]) @@ -104,14 +151,14 @@ class TestToolInvoker: def test_run_no_messages(self, invoker): result = invoker.run(messages=[]) - assert result == {"tool_messages": []} + assert result["tool_messages"] == [] def test_run_no_tool_calls(self, invoker): user_message = ChatMessage.from_user(text="Hello!") assistant_message = ChatMessage.from_assistant(text="How can I help you?") result = invoker.run(messages=[user_message, assistant_message]) - assert result == {"tool_messages": []} + assert result["tool_messages"] == [] def test_run_multiple_tool_calls(self, invoker): tool_calls = [ @@ -143,8 +190,8 @@ class TestToolInvoker: with pytest.raises(ToolNotFoundException): 
invoker.run(messages=[tool_call_message]) - def test_tool_not_found_does_not_raise_exception(self, invoker): - invoker.raise_on_failure = False + def test_tool_not_found_does_not_raise_exception(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=False) tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"}) tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) @@ -162,8 +209,8 @@ class TestToolInvoker: with pytest.raises(ToolInvocationError): faulty_invoker.run(messages=[tool_call_message]) - def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker): - faulty_invoker.raise_on_failure = False + def test_tool_invocation_error_does_not_raise_exception(self, faulty_tool): + faulty_invoker = ToolInvoker(tools=[faulty_tool], raise_on_failure=False, convert_result_to_json_string=False) tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) @@ -171,10 +218,10 @@ class TestToolInvoker: result = faulty_invoker.run(messages=[tool_call_message]) tool_message = result["tool_messages"][0] assert tool_message.tool_call_results[0].error - assert "invocation failed" in tool_message.tool_call_results[0].result + assert "Failed to invoke" in tool_message.tool_call_results[0].result - def test_string_conversion_error(self, invoker): - invoker.convert_result_to_json_string = True + def test_string_conversion_error(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=True) tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) @@ -182,9 +229,8 @@ class TestToolInvoker: with pytest.raises(StringConversionError): invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call) - def test_string_conversion_error_does_not_raise_exception(self, invoker): - 
invoker.convert_result_to_json_string = True - invoker.raise_on_failure = False + def test_string_conversion_error_does_not_raise_exception(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=True) tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) @@ -231,8 +277,8 @@ class TestToolInvoker: pipeline_dict = pipeline.to_dict() assert pipeline_dict == { "metadata": {}, - "max_runs_per_component": 100, "connection_type_validation": True, + "max_runs_per_component": 100, "components": { "invoker": { "type": "haystack.components.tools.tool_invoker.ToolInvoker", @@ -249,6 +295,8 @@ class TestToolInvoker: "required": ["location"], }, "function": "tools.test_tool_invoker.weather_function", + "inputs_from_state": None, + "outputs_to_state": None, }, } ], @@ -279,3 +327,93 @@ class TestToolInvoker: new_pipeline = Pipeline.loads(pipeline_yaml) assert new_pipeline == pipeline + + +class TestMergeToolOutputs: + def test_merge_tool_outputs_result_not_a_dict(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"weather": {"type": str}}) + merged_results = invoker._merge_tool_outputs(tool=weather_tool, result="test", state=state) + assert merged_results == "test" + assert state.data == {} + + def test_merge_tool_outputs_empty_dict(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"weather": {"type": str}}) + merged_results = invoker._merge_tool_outputs(tool=weather_tool, result={}, state=state) + assert merged_results == {} + assert state.data == {} + + def test_merge_tool_outputs_no_output_mapping(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"weather": {"type": str}}) + merged_results = invoker._merge_tool_outputs( + tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state + ) + assert merged_results == {"weather": 
"sunny", "temperature": 14, "unit": "celsius"} + assert state.data == {} + + def test_merge_tool_outputs_with_output_mapping(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + outputs_to_state={"weather": {"source": "weather"}}, + ) + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"weather": {"type": str}}) + merged_results = invoker._merge_tool_outputs( + tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state + ) + assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"} + assert state.data == {"weather": "sunny"} + + def test_merge_tool_outputs_with_output_mapping_2(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + outputs_to_state={"all_weather_results": {}}, + ) + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"all_weather_results": {"type": str}}) + merged_results = invoker._merge_tool_outputs( + tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state + ) + assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"} + assert state.data == {"all_weather_results": {"weather": "sunny", "temperature": 14, "unit": "celsius"}} + + def test_merge_tool_outputs_with_output_mapping_and_handler(self): + handler = lambda old, new: f"{new}" + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + outputs_to_state={"temperature": {"source": "temperature", "handler": handler}}, + ) + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={"temperature": {"type": str}}) + merged_results = 
invoker._merge_tool_outputs( + tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state + ) + assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"} + assert state.data == {"temperature": "14"} + + def test_merge_tool_outputs_with_message_output_mapping(self): + weather_tool = Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + outputs_to_state={"message": {"source": "weather"}}, + ) + invoker = ToolInvoker(tools=[weather_tool]) + state = State(schema={}) + merged_results = invoker._merge_tool_outputs( + tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state + ) + assert merged_results == "sunny" + assert state.data == {} diff --git a/test/dataclasses/test_state.py b/test/dataclasses/test_state.py new file mode 100644 index 000000000..748bf2dc6 --- /dev/null +++ b/test/dataclasses/test_state.py @@ -0,0 +1,146 @@ +import pytest +from typing import List, Dict + +from haystack.dataclasses.state import State, _validate_schema + + +@pytest.fixture +def basic_schema(): + return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}} + + +@pytest.fixture +def complex_schema(): + return { + "numbers": { + "type": list, + "handler": lambda current, new: sorted(set(current + new)) if current else sorted(set(new)), + }, + "metadata": {"type": dict}, + "name": {"type": str}, + } + + +def test_validate_schema_valid(basic_schema): + # Should not raise any exceptions + _validate_schema(basic_schema) + + +def test_validate_schema_invalid_type(): + invalid_schema = {"test": {"type": "not_a_type"}} + with pytest.raises(ValueError, match="must be a Python type"): + _validate_schema(invalid_schema) + + +def test_validate_schema_missing_type(): + invalid_schema = {"test": {"handler": lambda x, y: x + y}} + with pytest.raises(ValueError, 
match="missing a 'type' entry"): + _validate_schema(invalid_schema) + + +def test_validate_schema_invalid_handler(): + invalid_schema = {"test": {"type": list, "handler": "not_callable"}} + with pytest.raises(ValueError, match="must be callable or None"): + _validate_schema(invalid_schema) + + +def test_state_initialization(basic_schema): + # Test empty initialization + state = State(basic_schema) + assert state.data == {} + + # Test initialization with data + initial_data = {"numbers": [1, 2, 3], "name": "test"} + state = State(basic_schema, initial_data) + assert state.data["numbers"] == [1, 2, 3] + assert state.data["name"] == "test" + + +def test_state_get(basic_schema): + state = State(basic_schema, {"name": "test"}) + assert state.get("name") == "test" + assert state.get("non_existent") is None + assert state.get("non_existent", "default") == "default" + + +def test_state_set_basic(basic_schema): + state = State(basic_schema) + + # Test setting new values + state.set("numbers", [1, 2]) + assert state.get("numbers") == [1, 2] + + # Test updating existing values + state.set("numbers", [3, 4]) + assert state.get("numbers") == [1, 2, 3, 4] + + +def test_state_set_with_handler(complex_schema): + state = State(complex_schema) + + # Test custom handler for numbers + state.set("numbers", [3, 2, 1]) + assert state.get("numbers") == [1, 2, 3] + + state.set("numbers", [6, 5, 4]) + assert state.get("numbers") == [1, 2, 3, 4, 5, 6] + + +def test_state_set_with_handler_override(basic_schema): + state = State(basic_schema) + + # Custom handler that concatenates strings + custom_handler = lambda current, new: f"{current}-{new}" if current else new + + state.set("name", "first") + state.set("name", "second", handler_override=custom_handler) + assert state.get("name") == "first-second" + + +def test_state_has(basic_schema): + state = State(basic_schema, {"name": "test"}) + assert state.has("name") is True + assert state.has("non_existent") is False + + +def 
test_state_empty_schema(): + state = State({}) + assert state.data == {} + with pytest.raises(ValueError, match="Key 'any_key' not found in schema"): + state.set("any_key", "value") + + +def test_state_none_values(basic_schema): + state = State(basic_schema) + state.set("name", None) + assert state.get("name") is None + state.set("name", "value") + assert state.get("name") == "value" + + +def test_state_merge_lists(basic_schema): + state = State(basic_schema) + state.set("numbers", "not_a_list") + assert state.get("numbers") == ["not_a_list"] + state.set("numbers", [1, 2]) + assert state.get("numbers") == ["not_a_list", 1, 2] + + +def test_state_nested_structures(): + schema = { + "complex": { + "type": Dict[str, List[int]], + "handler": lambda current, new: { + k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys()) + } + if current + else new, + } + } + + state = State(schema) + state.set("complex", {"a": [1, 2], "b": [3, 4]}) + state.set("complex", {"b": [5, 6], "c": [7, 8]}) + + expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]} + assert state.get("complex") == expected diff --git a/test/dataclasses/test_state_utils.py b/test/dataclasses/test_state_utils.py new file mode 100644 index 000000000..150ecca14 --- /dev/null +++ b/test/dataclasses/test_state_utils.py @@ -0,0 +1,161 @@ +import pytest +from typing import List, Dict, Optional, Union, TypeVar, Generic +from dataclasses import dataclass + +from haystack.dataclasses.state_utils import _is_list_type, merge_lists, _is_valid_type + +import inspect + + +def test_is_list_type(): + assert _is_list_type(list) is True + assert _is_list_type(List[int]) is True + assert _is_list_type(List[str]) is True + assert _is_list_type(dict) is False + assert _is_list_type(int) is False + assert _is_list_type(Union[List[int], None]) is False + + +class TestMergeLists: + def test_merge_two_lists(self): + current = [1, 2, 3] + new = [4, 5, 6] + result = merge_lists(current, new) + assert result 
== [1, 2, 3, 4, 5, 6] + # Ensure original lists weren't modified + assert current == [1, 2, 3] + assert new == [4, 5, 6] + + def test_append_to_list(self): + current = [1, 2, 3] + new = 4 + result = merge_lists(current, new) + assert result == [1, 2, 3, 4] + assert current == [1, 2, 3] # Ensure original wasn't modified + + def test_create_new_list(self): + current = 1 + new = 2 + result = merge_lists(current, new) + assert result == [1, 2] + + def test_replace_with_list(self): + current = 1 + new = [2, 3] + result = merge_lists(current, new) + assert result == [1, 2, 3] + + +class TestIsValidType: + def test_builtin_types(self): + assert _is_valid_type(str) is True + assert _is_valid_type(int) is True + assert _is_valid_type(dict) is True + assert _is_valid_type(list) is True + assert _is_valid_type(tuple) is True + assert _is_valid_type(set) is True + assert _is_valid_type(bool) is True + assert _is_valid_type(float) is True + + def test_generic_types(self): + assert _is_valid_type(List[str]) is True + assert _is_valid_type(Dict[str, int]) is True + assert _is_valid_type(List[Dict[str, int]]) is True + assert _is_valid_type(Dict[str, List[int]]) is True + + def test_custom_classes(self): + @dataclass + class CustomClass: + value: int + + T = TypeVar("T") + + class GenericCustomClass(Generic[T]): + pass + + # Test regular and generic custom classes + assert _is_valid_type(CustomClass) is True + assert _is_valid_type(GenericCustomClass) is True + assert _is_valid_type(GenericCustomClass[int]) is True + + # Test generic types with custom classes + assert _is_valid_type(List[CustomClass]) is True + assert _is_valid_type(Dict[str, CustomClass]) is True + assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True + + def test_invalid_types(self): + # Test regular values + assert _is_valid_type(42) is False + assert _is_valid_type("string") is False + assert _is_valid_type([1, 2, 3]) is False + assert _is_valid_type({"a": 1}) is False + assert _is_valid_type(True) 
is False + + # Test class instances + @dataclass + class SampleClass: + value: int + + instance = SampleClass(42) + assert _is_valid_type(instance) is False + + # Test callable objects + assert _is_valid_type(len) is False + assert _is_valid_type(lambda x: x) is False + assert _is_valid_type(print) is False + + def test_union_and_optional_types(self): + # Test basic Union types + assert _is_valid_type(Union[str, int]) is True + assert _is_valid_type(Union[str, None]) is True + assert _is_valid_type(Union[List[int], Dict[str, str]]) is True + + # Test Optional types (which are Union[T, None]) + assert _is_valid_type(Optional[str]) is True + assert _is_valid_type(Optional[List[int]]) is True + assert _is_valid_type(Optional[Dict[str, list]]) is True + + # Test that Union itself is not a valid type (only instantiated Unions are) + assert _is_valid_type(Union) is False + + def test_nested_generic_types(self): + assert _is_valid_type(List[List[Dict[str, List[int]]]]) is True + assert _is_valid_type(Dict[str, List[Dict[str, set]]]) is True + assert _is_valid_type(Dict[str, Optional[List[int]]]) is True + assert _is_valid_type(List[Union[str, Dict[str, List[int]]]]) is True + + def test_edge_cases(self): + # Test None and NoneType + assert _is_valid_type(None) is False + assert _is_valid_type(type(None)) is True + + # Test functions and methods + def sample_func(): + pass + + assert _is_valid_type(sample_func) is False + assert _is_valid_type(type(sample_func)) is True + + # Test modules + assert _is_valid_type(inspect) is False + + # Test type itself + assert _is_valid_type(type) is True + + @pytest.mark.parametrize( + "test_input,expected", + [ + (str, True), + (int, True), + (List[int], True), + (Dict[str, int], True), + (Union[str, int], True), + (Optional[str], True), + (42, False), + ("string", False), + ([1, 2, 3], False), + (lambda x: x, False), + ], + ) + def test_parametrized_cases(self, test_input, expected): + assert _is_valid_type(test_input) is expected diff 
--git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index 338f420ee..c8ae49c35 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -4,20 +4,21 @@ import json import os +from copy import deepcopy from dataclasses import dataclass from typing import Dict, List import pytest from haystack import Pipeline, component +from haystack.components.builders import PromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator -from haystack.components.tools.tool_invoker import ToolInvoker +from haystack.components.tools import ToolInvoker from haystack.components.websearch.serper_dev import SerperDevWebSearch from haystack.dataclasses import ChatMessage, ChatRole, Document from haystack.tools import ComponentTool from haystack.utils.auth import Secret - ### Component and Model Definitions @@ -121,6 +122,13 @@ class DocumentProcessor: return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])} +def output_handler(old, new): + """ + Output handler to test serialization. 
+ """ + return old + new + + ## Unit tests class TestToolComponent: def test_from_component_basic(self): @@ -148,6 +156,22 @@ class TestToolComponent: assert len(tool.description) == 1024 + def test_from_component_with_inputs(self): + component = SimpleComponent() + + tool = ComponentTool(component=component, inputs_from_state={"text": "text"}) + + assert tool.inputs_from_state == {"text": "text"} + # Inputs should be excluded from schema generation + assert tool.parameters == {"type": "object", "properties": {}} + + def test_from_component_with_outputs(self): + component = SimpleComponent() + + tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}}) + + assert tool.outputs_to_state == {"replies": {"source": "reply"}} + def test_from_component_with_dataclass(self): component = UserGreeter() @@ -466,7 +490,7 @@ class TestToolComponentInPipelineWithOpenAI: pipeline.connect("llm.replies", "tool_invoker.messages") message = ChatMessage.from_user( - text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields." + text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." ) result = pipeline.run({"llm": {"messages": [message]}}) @@ -500,7 +524,7 @@ class TestToolComponentInPipelineWithOpenAI: pipeline.connect("llm.replies", "tool_invoker.messages") message = ChatMessage.from_user( - text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields." 
+ text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." ) result = pipeline.run({"llm": {"messages": [message]}}) @@ -576,7 +600,13 @@ class TestToolComponentInPipelineWithOpenAI: def test_component_tool_serde(self): component = SimpleComponent() - tool = ComponentTool(component=component, name="simple_tool", description="A simple tool") + tool = ComponentTool( + component=component, + name="simple_tool", + description="A simple tool", + inputs_from_state={"test": "input"}, + outputs_to_state={"output": {"source": "out", "handler": output_handler}}, + ) # Test serialization tool_dict = tool.to_dict() @@ -584,12 +614,16 @@ class TestToolComponentInPipelineWithOpenAI: assert tool_dict["data"]["name"] == "simple_tool" assert tool_dict["data"]["description"] == "A simple tool" assert "component" in tool_dict["data"] + assert tool_dict["data"]["inputs_from_state"] == {"test": "input"} + assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler" # Test deserialization new_tool = ComponentTool.from_dict(tool_dict) assert new_tool.name == tool.name assert new_tool.description == tool.description assert new_tool.parameters == tool.parameters + assert new_tool.inputs_from_state == tool.inputs_from_state + assert new_tool.outputs_to_state == tool.outputs_to_state assert isinstance(new_tool._component, SimpleComponent) def test_pipeline_component_fails(self): @@ -603,3 +637,21 @@ class TestToolComponentInPipelineWithOpenAI: # thus can't be used as tool with pytest.raises(ValueError, match="Component has been added to a pipeline"): ComponentTool(component=component) + + def test_deepcopy_with_jinja_based_component(self): + # Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758) + # When we use a 
ComponentTool in a pipeline at runtime, we deepcopy the tool + # We override ComponentTool.__deepcopy__ to fix this in experimental until a more comprehensive fix is merged. + # We track the issue here: https://github.com/deepset-ai/haystack/issues/9011 + + builder = PromptBuilder("{{query}}") + + tool = ComponentTool(component=builder) + result = tool.function(query="Hello") + + tool_copy = deepcopy(tool) + + result_from_copy = tool_copy.function(query="Hello") + + assert "prompt" in result_from_copy + assert result_from_copy["prompt"] == result["prompt"] diff --git a/test/tools/test_from_function.py b/test/tools/test_from_function.py index 5834b737d..478d7d59b 100644 --- a/test/tools/test_from_function.py +++ b/test/tools/test_from_function.py @@ -1,7 +1,7 @@ import pytest -from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool from haystack.tools.errors import SchemaGenerationError +from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool from typing import Annotated, Literal, Optional @@ -133,6 +133,38 @@ def test_tool_decorator_with_annotated_params(): assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny" +def test_tool_decorator_with_parameters(): + @tool(name="fetch_weather", description="A tool to check the weather.") + def get_weather( + city: Annotated[str, "The target city"] = "Berlin", + format: Annotated[Literal["short", "long"], "Output format"] = "short", + ) -> str: + """Get weather report for a city.""" + return f"Weather report for {city} ({format} format): 20°C, sunny" + + assert get_weather.name == "fetch_weather" + assert get_weather.description == "A tool to check the weather."
+ + +def test_tool_decorator_with_inputs_and_outputs(): + @tool(inputs_from_state={"format": "format"}, outputs_to_state={"output": {"source": "output"}}) + def get_weather( + city: Annotated[str, "The target city"] = "Berlin", + format: Annotated[Literal["short", "long"], "Output format"] = "short", + ) -> str: + """Get weather report for a city.""" + return f"Weather report for {city} ({format} format): 20°C, sunny" + + assert get_weather.name == "get_weather" + assert get_weather.inputs_from_state == {"format": "format"} + assert get_weather.outputs_to_state == {"output": {"source": "output"}} + # Inputs should be excluded from auto-generated parameters + assert get_weather.parameters == { + "type": "object", + "properties": {"city": {"type": "string", "description": "The target city", "default": "Berlin"}}, + } + + def test_remove_title_from_schema(): complex_schema = { "properties": { diff --git a/test/tools/test_tool.py b/test/tools/test_tool.py index 43ed42044..36d945cd8 100644 --- a/test/tools/test_tool.py +++ b/test/tools/test_tool.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import re import pytest from haystack.tools.tool import Tool, ToolInvocationError, deserialize_tools_inplace, _check_duplicate_tool_names @@ -24,12 +25,31 @@ class TestTool: assert tool.description == "Get weather report" assert tool.parameters == parameters assert tool.function == get_weather_report + assert tool.inputs_from_state is None + assert tool.outputs_to_state is None def test_init_invalid_parameters(self): - parameters = {"type": "invalid", "properties": {"city": {"type": "string"}}} - + params = {"type": "invalid", "properties": {"city": {"type": "string"}}} with pytest.raises(ValueError): - Tool(name="irrelevant", description="irrelevant", parameters=parameters, function=get_weather_report) + Tool(name="irrelevant", description="irrelevant", parameters=params, function=get_weather_report) + + @pytest.mark.parametrize( + "outputs_to_state", + [ + 
pytest.param({"documents": ["some_value"]}, id="config-not-a-dict"), + pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"), + pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"), + ], + ) + def test_init_invalid_output_structure(self, outputs_to_state): + with pytest.raises(ValueError): + Tool( + name="irrelevant", + description="irrelevant", + parameters={"type": "object", "properties": {"city": {"type": "string"}}}, + function=get_weather_report, + outputs_to_state=outputs_to_state, + ) def test_tool_spec(self): tool = Tool( @@ -49,13 +69,21 @@ class TestTool: tool = Tool( name="weather", description="Get weather report", parameters=parameters, function=get_weather_report ) - - with pytest.raises(ToolInvocationError): + with pytest.raises( + ToolInvocationError, + match=re.escape( + "Failed to invoke Tool `weather` with parameters {}. Error: get_weather_report() missing 1 required positional argument: 'city'" + ), + ): tool.invoke() def test_to_dict(self): tool = Tool( - name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + name="weather", + description="Get weather report", + parameters=parameters, + function=get_weather_report, + outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}}, ) assert tool.to_dict() == { @@ -65,6 +93,8 @@ class TestTool: "description": "Get weather report", "parameters": parameters, "function": "test_tool.get_weather_report", + "inputs_from_state": None, + "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}}, }, } @@ -76,6 +106,7 @@ class TestTool: "description": "Get weather report", "parameters": parameters, "function": "test_tool.get_weather_report", + "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}}, }, } @@ -85,6 +116,8 @@ class TestTool: assert tool.description == "Get weather 
report" assert tool.parameters == parameters assert tool.function == get_weather_report + assert tool.outputs_to_state["documents"]["source"] == "docs" + assert tool.outputs_to_state["documents"]["handler"] == get_weather_report def test_deserialize_tools_inplace():