mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 06:28:33 +00:00
feat: integrate updates of Tool, ToolInvoker, State, create_tool_from_function, ComponentTool from haystack-experimental (#9113)
* update Tool,ToolInvoker,ComponentTool,create_tool_from_function * add State and its utils * add tests for State and its utils * update tests for Tool etc. * reno * fix circular imports * update experimental imports in tests * fix unit tests * fix ChatGenerator unit tests * mypy * add State to init and pydoc * explain State in more detail in release note * add test from #8913 * re-add _check_duplicate_tool_names and refactor imports * rename inputs and outputs
This commit is contained in:
parent
726b7ef0c4
commit
657d09d7f1
@ -2,7 +2,7 @@ loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/dataclasses]
|
||||
modules:
|
||||
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding",]
|
||||
["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -2,32 +2,44 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.core.component.sockets import Sockets
|
||||
from haystack.dataclasses import ChatMessage, State, ToolCall
|
||||
from haystack.tools.component_tool import ComponentTool
|
||||
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
|
||||
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
|
||||
_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}."
|
||||
_TOOL_RESULT_CONVERSION_FAILURE = (
|
||||
"Failed to convert tool result to string using '{conversion_function}'. Error: {error}."
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolNotFoundException(Exception):
|
||||
"""
|
||||
Exception raised when a tool is not found in the list of available tools.
|
||||
"""
|
||||
class ToolInvokerError(Exception):
|
||||
"""Base exception class for ToolInvoker errors."""
|
||||
|
||||
pass
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StringConversionError(Exception):
|
||||
"""
|
||||
Exception raised when the conversion of a tool result to a string fails.
|
||||
"""
|
||||
class ToolNotFoundException(ToolInvokerError):
|
||||
"""Exception raised when a tool is not found in the list of available tools."""
|
||||
|
||||
def __init__(self, tool_name: str, available_tools: List[str]):
|
||||
message = f"Tool '{tool_name}' not found. Available tools: {', '.join(available_tools)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StringConversionError(ToolInvokerError):
|
||||
"""Exception raised when the conversion of a tool result to a string fails."""
|
||||
|
||||
def __init__(self, tool_name: str, conversion_function: str, error: Exception):
|
||||
message = f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ToolOutputMergeError(ToolInvokerError):
|
||||
"""Exception raised when merging tool outputs into state fails."""
|
||||
|
||||
pass
|
||||
|
||||
@ -37,6 +49,7 @@ class ToolInvoker:
|
||||
"""
|
||||
Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
|
||||
|
||||
Also handles reading/writing from a shared `State`.
|
||||
At initialization, the ToolInvoker component is provided with a list of available tools.
|
||||
At runtime, the component processes a list of ChatMessage object containing tool calls
|
||||
and invokes the corresponding tools.
|
||||
@ -111,20 +124,36 @@ class ToolInvoker:
|
||||
:param convert_result_to_json_string:
|
||||
If True, the tool invocation result will be converted to a string using `json.dumps`.
|
||||
If False, the tool invocation result will be converted to a string using `str`.
|
||||
|
||||
:raises ValueError:
|
||||
If no tools are provided or if duplicate tool names are found.
|
||||
"""
|
||||
|
||||
if not tools:
|
||||
raise ValueError("ToolInvoker requires at least one tool to be provided.")
|
||||
raise ValueError("ToolInvoker requires at least one tool.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
duplicates = {name for name in tool_names if tool_names.count(name) > 1}
|
||||
if duplicates:
|
||||
raise ValueError(f"Duplicate tool names found: {duplicates}")
|
||||
|
||||
self.tools = tools
|
||||
self._tools_with_names = dict(zip([tool.name for tool in tools], tools))
|
||||
self._tools_with_names = dict(zip(tool_names, tools))
|
||||
self.raise_on_failure = raise_on_failure
|
||||
self.convert_result_to_json_string = convert_result_to_json_string
|
||||
|
||||
def _handle_error(self, error: Exception) -> str:
|
||||
"""
|
||||
Handles errors by logging and either raising or returning a fallback error message.
|
||||
|
||||
:param error: The exception instance.
|
||||
:returns: The fallback error message when `raise_on_failure` is False.
|
||||
:raises: The provided error if `raise_on_failure` is True.
|
||||
"""
|
||||
logger.error("{error_exception}", error_exception=error)
|
||||
if self.raise_on_failure:
|
||||
# We re-raise the original error maintaining the exception chain
|
||||
raise error
|
||||
return str(error)
|
||||
|
||||
def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
|
||||
"""
|
||||
Prepares a ChatMessage with the result of a tool invocation.
|
||||
@ -133,40 +162,141 @@ class ToolInvoker:
|
||||
The tool result.
|
||||
:returns:
|
||||
A ChatMessage object containing the tool result as a string.
|
||||
|
||||
:raises
|
||||
StringConversionError: If the conversion of the tool result to a string fails
|
||||
and `raise_on_failure` is True.
|
||||
"""
|
||||
error = False
|
||||
|
||||
if self.convert_result_to_json_string:
|
||||
try:
|
||||
try:
|
||||
if self.convert_result_to_json_string:
|
||||
# We disable ensure_ascii so special chars like emojis are not converted
|
||||
tool_result_str = json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
if self.raise_on_failure:
|
||||
raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e
|
||||
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps")
|
||||
error = True
|
||||
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
|
||||
|
||||
try:
|
||||
tool_result_str = str(result)
|
||||
else:
|
||||
tool_result_str = str(result)
|
||||
except Exception as e:
|
||||
if self.raise_on_failure:
|
||||
raise StringConversionError("Failed to convert tool result to string using `str`") from e
|
||||
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str")
|
||||
error = True
|
||||
conversion_method = "json.dumps" if self.convert_result_to_json_string else "str"
|
||||
try:
|
||||
tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, conversion_method, e))
|
||||
error = True
|
||||
except StringConversionError as conversion_error:
|
||||
# If _handle_error re-raises, this properly preserves the chain
|
||||
raise conversion_error from e
|
||||
|
||||
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
|
||||
|
||||
@component.output_types(tool_messages=List[ChatMessage])
|
||||
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
|
||||
@staticmethod
|
||||
def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
|
||||
"""
|
||||
Combine LLM-provided arguments (llm_args) with state-based arguments.
|
||||
|
||||
Tool arguments take precedence in the following order:
|
||||
- LLM overrides state if the same param is present in both
|
||||
- local tool.inputs mappings (if any)
|
||||
- function signature name matching
|
||||
"""
|
||||
final_args = dict(llm_args) # start with LLM-provided
|
||||
|
||||
# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
|
||||
# to find out which parameters the tool accepts.
|
||||
if isinstance(tool, ComponentTool):
|
||||
# mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
|
||||
assert hasattr(tool._component, "__haystack_input__") and isinstance(
|
||||
tool._component.__haystack_input__, Sockets
|
||||
)
|
||||
func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
|
||||
else:
|
||||
func_params = set(inspect.signature(tool.function).parameters.keys())
|
||||
|
||||
# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
|
||||
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
|
||||
if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
|
||||
param_mappings = tool.inputs_from_state
|
||||
else:
|
||||
param_mappings = {name: name for name in func_params}
|
||||
|
||||
# Populate final_args from state if not provided by LLM
|
||||
for state_key, param_name in param_mappings.items():
|
||||
if param_name not in final_args and state.has(state_key):
|
||||
final_args[param_name] = state.get(state_key)
|
||||
|
||||
return final_args
|
||||
|
||||
def _merge_tool_outputs(self, tool: Tool, result: Any, state: State) -> Any:
|
||||
"""
|
||||
Merges the tool result into the global state and determines the response message.
|
||||
|
||||
This method processes the output of a tool execution and integrates it into the global state.
|
||||
It also determines what message, if any, should be returned for further processing in a conversation.
|
||||
|
||||
Processing Steps:
|
||||
1. If `result` is not a dictionary, nothing is stored into state and the full `result` is returned.
|
||||
2. If the `tool` does not define an `outputs_to_state` mapping nothing is stored into state.
|
||||
The return value in this case is simply the full `result` dictionary.
|
||||
3. If the tool defines an `outputs_to_state` mapping (a dictionary describing how the tool's output should be
|
||||
processed), the method delegates to `_handle_tool_outputs` to process the output accordingly.
|
||||
This allows certain fields in `result` to be mapped explicitly to state fields or formatted using custom
|
||||
handlers.
|
||||
|
||||
:param tool: Tool instance containing optional `outputs_to_state` mapping to guide result processing.
|
||||
:param result: The output from tool execution. Can be a dictionary, or any other type.
|
||||
:param state: The global State object to which results should be merged.
|
||||
:returns: Three possible values:
|
||||
- A string message for conversation
|
||||
- The merged result dictionary
|
||||
- Or the raw result if not a dictionary
|
||||
"""
|
||||
# If result is not a dictionary, return it as the output message.
|
||||
if not isinstance(result, dict):
|
||||
return result
|
||||
|
||||
# If there is no specific `outputs_to_state` mapping, we just return the full result
|
||||
if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
|
||||
return result
|
||||
|
||||
# Handle tool outputs with specific mapping for message and state updates
|
||||
return self._handle_tool_outputs(tool.outputs_to_state, result, state)
|
||||
|
||||
@staticmethod
|
||||
def _handle_tool_outputs(outputs_to_state: dict, result: dict, state: State) -> Union[dict, str]:
|
||||
"""
|
||||
Handles the `outputs_to_state` mapping from the tool and updates the state accordingly.
|
||||
|
||||
:param outputs_to_state: Mapping of outputs from the tool.
|
||||
:param result: Result of the tool execution.
|
||||
:param state: Global state to merge results into.
|
||||
:returns: Final message for LLM or the entire result.
|
||||
"""
|
||||
message_content = None
|
||||
|
||||
for state_key, config in outputs_to_state.items():
|
||||
# Get the source key from the output config, otherwise use the entire result
|
||||
source_key = config.get("source", None)
|
||||
output_value = result if source_key is None else result.get(source_key)
|
||||
|
||||
# Get the handler function, if any
|
||||
handler = config.get("handler", None)
|
||||
|
||||
if state_key == "message":
|
||||
# Handle the message output separately
|
||||
if handler is not None:
|
||||
message_content = handler(output_value)
|
||||
else:
|
||||
message_content = str(output_value)
|
||||
else:
|
||||
# Merge other outputs into the state
|
||||
state.set(state_key, output_value, handler_override=handler)
|
||||
|
||||
# If no "message" key was found, return the result or message content
|
||||
return message_content if message_content is not None else result
|
||||
|
||||
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
||||
def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
|
||||
|
||||
:param messages:
|
||||
A list of ChatMessage objects.
|
||||
:param state: The runtime state that should be used by the tools.
|
||||
:returns:
|
||||
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
|
||||
Each ChatMessage objects wraps the result of a tool invocation.
|
||||
@ -177,39 +307,61 @@ class ToolInvoker:
|
||||
If the tool invocation fails and `raise_on_failure` is True.
|
||||
:raises StringConversionError:
|
||||
If the conversion of the tool result to a string fails and `raise_on_failure` is True.
|
||||
:raises ToolOutputMergeError:
|
||||
If merging tool outputs into state fails and `raise_on_failure` is True.
|
||||
"""
|
||||
if state is None:
|
||||
state = State(schema={})
|
||||
|
||||
# Only keep messages with tool calls
|
||||
messages_with_tool_calls = [message for message in messages if message.tool_calls]
|
||||
|
||||
tool_messages = []
|
||||
|
||||
for message in messages:
|
||||
tool_calls = message.tool_calls
|
||||
if not tool_calls:
|
||||
continue
|
||||
|
||||
for tool_call in tool_calls:
|
||||
for message in messages_with_tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name = tool_call.tool_name
|
||||
tool_arguments = tool_call.arguments
|
||||
|
||||
if not tool_name in self._tools_with_names:
|
||||
msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys())
|
||||
if self.raise_on_failure:
|
||||
raise ToolNotFoundException(msg)
|
||||
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
|
||||
# Check if the tool is available, otherwise return an error message
|
||||
if tool_name not in self._tools_with_names:
|
||||
error_message = self._handle_error(
|
||||
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
|
||||
)
|
||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
||||
continue
|
||||
|
||||
tool_to_invoke = self._tools_with_names[tool_name]
|
||||
|
||||
# 1) Combine user + state inputs
|
||||
llm_args = tool_call.arguments.copy()
|
||||
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
|
||||
|
||||
# 2) Invoke the tool
|
||||
try:
|
||||
tool_result = tool_to_invoke.invoke(**tool_arguments)
|
||||
tool_result = tool_to_invoke.invoke(**final_args)
|
||||
except ToolInvocationError as e:
|
||||
if self.raise_on_failure:
|
||||
raise e
|
||||
msg = _TOOL_INVOCATION_FAILURE.format(error=e)
|
||||
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
|
||||
error_message = self._handle_error(e)
|
||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
||||
continue
|
||||
|
||||
tool_message = self._prepare_tool_result_message(tool_result, tool_call)
|
||||
tool_messages.append(tool_message)
|
||||
# 3) Merge outputs into state & create a single ChatMessage for the LLM
|
||||
try:
|
||||
tool_text = self._merge_tool_outputs(tool_to_invoke, tool_result, state)
|
||||
except Exception as e:
|
||||
try:
|
||||
error_message = self._handle_error(
|
||||
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
|
||||
)
|
||||
tool_messages.append(
|
||||
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
||||
)
|
||||
continue
|
||||
except ToolOutputMergeError as propagated_e:
|
||||
# Re-raise with proper error chain
|
||||
raise propagated_e from e
|
||||
|
||||
return {"tool_messages": tool_messages}
|
||||
tool_messages.append(self._prepare_tool_result_message(result=tool_text, tool_call=tool_call))
|
||||
|
||||
return {"tool_messages": tool_messages, "state": state}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -13,6 +13,7 @@ _import_structure = {
|
||||
"chat_message": ["ChatMessage", "ChatRole", "TextContent", "ToolCall", "ToolCallResult"],
|
||||
"document": ["Document"],
|
||||
"sparse_embedding": ["SparseEmbedding"],
|
||||
"state": ["State"],
|
||||
"streaming_chunk": [
|
||||
"StreamingChunk",
|
||||
"AsyncStreamingCallbackT",
|
||||
@ -28,6 +29,7 @@ if TYPE_CHECKING:
|
||||
from .chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult
|
||||
from .document import Document
|
||||
from .sparse_embedding import SparseEmbedding
|
||||
from .state import State
|
||||
from .streaming_chunk import (
|
||||
AsyncStreamingCallbackT,
|
||||
StreamingCallbackT,
|
||||
|
||||
153
haystack/dataclasses/state.py
Normal file
153
haystack/dataclasses/state.py
Normal file
@ -0,0 +1,153 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
|
||||
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||
from haystack.utils.type_serialization import deserialize_type, serialize_type
|
||||
|
||||
|
||||
def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a schema dictionary to a serializable format.
|
||||
|
||||
Converts each parameter's type and optional handler function into a serializable
|
||||
format using type and callable serialization utilities.
|
||||
|
||||
:param schema: Dictionary mapping parameter names to their type and handler configs
|
||||
:returns: Dictionary with serialized type and handler information
|
||||
"""
|
||||
serialized_schema = {}
|
||||
for param, config in schema.items():
|
||||
serialized_schema[param] = {"type": serialize_type(config["type"])}
|
||||
if config.get("handler"):
|
||||
serialized_schema[param]["handler"] = serialize_callable(config["handler"])
|
||||
|
||||
return serialized_schema
|
||||
|
||||
|
||||
def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a serialized schema dictionary back to its original format.
|
||||
|
||||
Deserializes the type and optional handler function for each parameter from their
|
||||
serialized format back into Python types and callables.
|
||||
|
||||
:param schema: Dictionary containing serialized schema information
|
||||
:returns: Dictionary with deserialized type and handler configurations
|
||||
"""
|
||||
deserialized_schema = {}
|
||||
for param, config in schema.items():
|
||||
deserialized_schema[param] = {"type": deserialize_type(config["type"])}
|
||||
|
||||
if config.get("handler"):
|
||||
deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])
|
||||
|
||||
return deserialized_schema
|
||||
|
||||
|
||||
def _validate_schema(schema: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate that a schema dictionary meets all required constraints.
|
||||
|
||||
Checks that each parameter definition has a valid type field and that any handler
|
||||
specified is a callable function.
|
||||
|
||||
:param schema: Dictionary mapping parameter names to their type and handler configs
|
||||
:raises ValueError: If schema validation fails due to missing or invalid fields
|
||||
"""
|
||||
for param, definition in schema.items():
|
||||
if "type" not in definition:
|
||||
raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
|
||||
if not _is_valid_type(definition["type"]):
|
||||
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
|
||||
if definition.get("handler") is not None and not callable(definition["handler"]):
|
||||
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
|
||||
|
||||
|
||||
class State:
|
||||
"""
|
||||
A dataclass that wraps a StateSchema and maintains an internal _data dictionary.
|
||||
|
||||
Each schema entry has:
|
||||
"parameter_name": {
|
||||
"type": SomeType,
|
||||
"handler": Optional[Callable[[Any, Any], Any]]
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialize a State object with a schema and optional data.
|
||||
|
||||
:param schema: Dictionary mapping parameter names to their type and handler configs.
|
||||
Type must be a valid Python type, and handler must be a callable function or None.
|
||||
If handler is None, the default handler for the type will be used. The default handlers are:
|
||||
- For list types: `haystack.dataclasses.state_utils.merge_lists`
|
||||
- For all other types: `haystack.dataclasses.state_utils.replace_values`
|
||||
:param data: Optional dictionary of initial data to populate the state
|
||||
"""
|
||||
_validate_schema(schema)
|
||||
self.schema = schema
|
||||
self._data = data or {}
|
||||
|
||||
# Set default handlers if not provided in schema
|
||||
for definition in schema.values():
|
||||
# Skip if handler is already defined and not None
|
||||
if definition.get("handler") is not None:
|
||||
continue
|
||||
# Set default handler based on type
|
||||
if _is_list_type(definition["type"]):
|
||||
definition["handler"] = merge_lists
|
||||
else:
|
||||
definition["handler"] = replace_values
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Retrieve a value from the state by key.
|
||||
|
||||
:param key: Key to look up in the state
|
||||
:param default: Value to return if key is not found
|
||||
:returns: Value associated with key or default if not found
|
||||
"""
|
||||
return self._data.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
|
||||
"""
|
||||
Set or merge a value in the state according to schema rules.
|
||||
|
||||
Value is merged or overwritten according to these rules:
|
||||
- if handler_override is given, use that
|
||||
- else use the handler defined in the schema for 'key'
|
||||
|
||||
:param key: Key to store the value under
|
||||
:param value: Value to store or merge
|
||||
:param handler_override: Optional function to override the default merge behavior
|
||||
"""
|
||||
# If key not in schema, we throw an error
|
||||
definition = self.schema.get(key, None)
|
||||
if definition is None:
|
||||
raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")
|
||||
|
||||
# Get current value from state and apply handler
|
||||
current_value = self._data.get(key, None)
|
||||
handler = handler_override or definition["handler"]
|
||||
self._data[key] = handler(current_value, value)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""
|
||||
All current data of the state.
|
||||
"""
|
||||
return self._data
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in the state.
|
||||
|
||||
:param key: Key to check for existence
|
||||
:returns: True if key exists in state, False otherwise
|
||||
"""
|
||||
return key in self._data
|
||||
77
haystack/dataclasses/state_utils.py
Normal file
77
haystack/dataclasses/state_utils.py
Normal file
@ -0,0 +1,77 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, List, TypeVar, Union, get_origin
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _is_valid_type(obj: Any) -> bool:
|
||||
"""
|
||||
Check if an object is a valid type annotation.
|
||||
|
||||
Valid types include:
|
||||
- Normal classes (str, dict, CustomClass)
|
||||
- Generic types (List[str], Dict[str, int])
|
||||
- Union types (Union[str, int], Optional[str])
|
||||
|
||||
:param obj: The object to check
|
||||
:return: True if the object is a valid type annotation, False otherwise
|
||||
|
||||
Example usage:
|
||||
>>> _is_valid_type(str)
|
||||
True
|
||||
>>> _is_valid_type(List[int])
|
||||
True
|
||||
>>> _is_valid_type(Union[str, int])
|
||||
True
|
||||
>>> _is_valid_type(42)
|
||||
False
|
||||
"""
|
||||
# Handle Union types (including Optional)
|
||||
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
|
||||
return True
|
||||
|
||||
# Handle normal classes and generic types
|
||||
return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}
|
||||
|
||||
|
||||
def _is_list_type(type_hint: Any) -> bool:
|
||||
"""
|
||||
Check if a type hint represents a list type.
|
||||
|
||||
:param type_hint: The type hint to check
|
||||
:return: True if the type hint represents a list, False otherwise
|
||||
"""
|
||||
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
|
||||
|
||||
|
||||
def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]:
|
||||
"""
|
||||
Merges two values into a single list.
|
||||
|
||||
If either `current` or `new` is not already a list, it is converted into one.
|
||||
The function ensures that both inputs are treated as lists and concatenates them.
|
||||
|
||||
If `current` is None, it is treated as an empty list.
|
||||
|
||||
:param current: The existing value(s), either a single item or a list.
|
||||
:param new: The new value(s) to merge, either a single item or a list.
|
||||
:return: A list containing elements from both `current` and `new`.
|
||||
"""
|
||||
current_list = [] if current is None else current if isinstance(current, list) else [current]
|
||||
new_list = new if isinstance(new, list) else [new]
|
||||
return current_list + new_list
|
||||
|
||||
|
||||
def replace_values(current: Any, new: Any) -> Any:
|
||||
"""
|
||||
Replace the `current` value with the `new` value.
|
||||
|
||||
:param current: The existing value
|
||||
:param new: The new value to replace
|
||||
:return: The new value
|
||||
"""
|
||||
return new
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import fields, is_dataclass
|
||||
from inspect import getdoc
|
||||
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
||||
@ -19,6 +20,7 @@ from haystack.core.serialization import (
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools import Tool
|
||||
from haystack.tools.errors import SchemaGenerationError
|
||||
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||
|
||||
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
|
||||
from docstring_parser import parse
|
||||
@ -87,13 +89,34 @@ class ComponentTool(Tool):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
component: Component,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
inputs_from_state: Optional[Dict[str, Any]] = None,
|
||||
outputs_to_state: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Create a Tool instance from a Haystack component.
|
||||
|
||||
:param component: The Haystack component to wrap as a tool.
|
||||
:param name: Optional name for the tool (defaults to snake_case of component class name).
|
||||
:param description: Optional description (defaults to component's docstring).
|
||||
:param parameters:
|
||||
A JSON schema defining the parameters expected by the Tool.
|
||||
Will fall back to the parameters defined in the component's run method signature if not provided.
|
||||
:param inputs_from_state:
|
||||
Optional dictionary mapping state keys to tool parameter names.
|
||||
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
|
||||
:param outputs_to_state:
|
||||
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
|
||||
Example: {
|
||||
"documents": {"source": "docs", "handler": custom_handler},
|
||||
"message": {"source": "summary", "handler": format_summary}
|
||||
}
|
||||
:raises ValueError: If the component is invalid or schema generation fails.
|
||||
"""
|
||||
if not isinstance(component, Component):
|
||||
@ -110,8 +133,9 @@ class ComponentTool(Tool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self._unresolved_parameters = parameters
|
||||
# Create the tools schema from the component run method parameters
|
||||
tool_schema = self._create_tool_parameters_schema(component)
|
||||
tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})
|
||||
|
||||
def component_invoker(**kwargs):
|
||||
"""
|
||||
@ -155,16 +179,39 @@ class ComponentTool(Tool):
|
||||
description = description or component.__doc__ or name
|
||||
|
||||
# Create the Tool instance with the component invoker as the function to be called and the schema
|
||||
super().__init__(name, description, tool_schema, component_invoker)
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=tool_schema,
|
||||
function=component_invoker,
|
||||
inputs_from_state=inputs_from_state,
|
||||
outputs_to_state=outputs_to_state,
|
||||
)
|
||||
self._component = component
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes the ComponentTool to a dictionary.
|
||||
"""
|
||||
# we do not serialize the function in this case: it can be recreated from the component at deserialization time
|
||||
serialized = {"name": self.name, "description": self.description, "parameters": self.parameters}
|
||||
serialized["component"] = component_to_dict(obj=self._component, name=self.name)
|
||||
serialized_component = component_to_dict(obj=self._component, name=self.name)
|
||||
|
||||
serialized = {
|
||||
"component": serialized_component,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self._unresolved_parameters,
|
||||
"inputs_from_state": self.inputs_from_state,
|
||||
}
|
||||
|
||||
if self.outputs_to_state is not None:
|
||||
serialized_outputs = {}
|
||||
for key, config in self.outputs_to_state.items():
|
||||
serialized_config = config.copy()
|
||||
if "handler" in config:
|
||||
serialized_config["handler"] = serialize_callable(config["handler"])
|
||||
serialized_outputs[key] = serialized_config
|
||||
serialized["outputs_to_state"] = serialized_outputs
|
||||
|
||||
return {"type": generate_qualified_class_name(type(self)), "data": serialized}
|
||||
|
||||
@classmethod
|
||||
@ -175,9 +222,26 @@ class ComponentTool(Tool):
|
||||
inner_data = data["data"]
|
||||
component_class = import_class_by_name(inner_data["component"]["type"])
|
||||
component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])
|
||||
return cls(component=component, name=inner_data["name"], description=inner_data["description"])
|
||||
|
||||
def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]:
|
||||
if "outputs_to_state" in inner_data and inner_data["outputs_to_state"]:
|
||||
deserialized_outputs = {}
|
||||
for key, config in inner_data["outputs_to_state"].items():
|
||||
deserialized_config = config.copy()
|
||||
if "handler" in config:
|
||||
deserialized_config["handler"] = deserialize_callable(config["handler"])
|
||||
deserialized_outputs[key] = deserialized_config
|
||||
inner_data["outputs_to_state"] = deserialized_outputs
|
||||
|
||||
return cls(
|
||||
component=component,
|
||||
name=inner_data["name"],
|
||||
description=inner_data["description"],
|
||||
parameters=inner_data.get("parameters", None),
|
||||
inputs_from_state=inner_data.get("inputs_from_state", None),
|
||||
outputs_to_state=inner_data.get("outputs_to_state", None),
|
||||
)
|
||||
|
||||
def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Creates an OpenAI tools schema from a component's run method parameters.
|
||||
|
||||
@ -191,6 +255,8 @@ class ComponentTool(Tool):
|
||||
param_descriptions = self._get_param_descriptions(component.run)
|
||||
|
||||
for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined]
|
||||
if inputs_from_state is not None and input_name in inputs_from_state:
|
||||
continue
|
||||
input_type = socket.type
|
||||
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
|
||||
|
||||
@ -326,3 +392,29 @@ class ComponentTool(Tool):
|
||||
schema["default"] = default
|
||||
|
||||
return schema
|
||||
|
||||
def __deepcopy__(self, memo: Dict[Any, Any]) -> "ComponentTool":
|
||||
# Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
|
||||
# When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
|
||||
# We overwrite ComponentTool.__deepcopy__ to fix this until a more comprehensive fix is merged.
|
||||
# We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
|
||||
result = copy(self)
|
||||
|
||||
# Add the object to the memo dictionary to handle circular references
|
||||
memo[id(self)] = result
|
||||
|
||||
# Deep copy all attributes with exception handling
|
||||
for key, value in self.__dict__.items():
|
||||
try:
|
||||
# Try to deep copy the attribute
|
||||
setattr(result, key, deepcopy(value, memo))
|
||||
except TypeError:
|
||||
# Fall back to using the original attribute for components that use Jinja2-templates
|
||||
logger.debug(
|
||||
"deepcopy of ComponentTool {tool_name} failed. Using original attribute '{attribute}' instead.",
|
||||
tool_name=self.name,
|
||||
attribute=key,
|
||||
)
|
||||
setattr(result, key, getattr(self, key))
|
||||
|
||||
return result
|
||||
|
||||
@ -3,16 +3,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from haystack.tools.errors import SchemaGenerationError
|
||||
from haystack.tools.tool import Tool
|
||||
from .errors import SchemaGenerationError
|
||||
from .tool import Tool
|
||||
|
||||
|
||||
def create_tool_from_function(
|
||||
function: Callable, name: Optional[str] = None, description: Optional[str] = None
|
||||
function: Callable,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
inputs_from_state: Optional[Dict[str, str]] = None,
|
||||
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> "Tool":
|
||||
"""
|
||||
Create a Tool instance from a function.
|
||||
@ -62,7 +66,15 @@ def create_tool_from_function(
|
||||
:param description:
|
||||
The description of the Tool. If not provided, the docstring of the function will be used.
|
||||
To intentionally leave the description empty, pass an empty string.
|
||||
|
||||
:param inputs_from_state:
|
||||
Optional dictionary mapping state keys to tool parameter names.
|
||||
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
|
||||
:param outputs_to_state:
|
||||
Optional dictionary defining how tool outputs map to state and message handling.
|
||||
Example: {
|
||||
"documents": {"source": "docs", "handler": custom_handler},
|
||||
"message": {"source": "summary", "handler": format_summary}
|
||||
}
|
||||
:returns:
|
||||
The Tool created from the function.
|
||||
|
||||
@ -71,7 +83,6 @@ def create_tool_from_function(
|
||||
:raises SchemaGenerationError:
|
||||
If there is an error generating the JSON schema for the Tool.
|
||||
"""
|
||||
|
||||
tool_description = description if description is not None else (function.__doc__ or "")
|
||||
|
||||
signature = inspect.signature(function)
|
||||
@ -81,6 +92,10 @@ def create_tool_from_function(
|
||||
descriptions = {}
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
# Skip adding parameter names that will be passed to the tool from State
|
||||
if inputs_from_state and param_name in inputs_from_state.values():
|
||||
continue
|
||||
|
||||
if param.annotation is param.empty:
|
||||
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")
|
||||
|
||||
@ -109,15 +124,33 @@ def create_tool_from_function(
|
||||
if param_name in schema["properties"]:
|
||||
schema["properties"][param_name]["description"] = param_description
|
||||
|
||||
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
|
||||
return Tool(
|
||||
name=name or function.__name__,
|
||||
description=tool_description,
|
||||
parameters=schema,
|
||||
function=function,
|
||||
inputs_from_state=inputs_from_state,
|
||||
outputs_to_state=outputs_to_state,
|
||||
)
|
||||
|
||||
|
||||
def tool(function: Callable) -> Tool:
|
||||
def tool(
|
||||
function: Optional[Callable] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
inputs_from_state: Optional[Dict[str, str]] = None,
|
||||
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Union[Tool, Callable[[Callable], Tool]]:
|
||||
"""
|
||||
Decorator to convert a function into a Tool.
|
||||
|
||||
Tool name, description, and parameters are inferred from the function.
|
||||
If you need to customize more the Tool, use `create_tool_from_function` instead.
|
||||
Can be used with or without parameters:
|
||||
@tool # without parameters
|
||||
def my_function(): ...
|
||||
|
||||
@tool(name="custom_name") # with parameters
|
||||
def my_function(): ...
|
||||
|
||||
### Usage example
|
||||
```python
|
||||
@ -147,8 +180,27 @@ def tool(function: Callable) -> Tool:
|
||||
>>> },
|
||||
>>> function=<function get_weather at 0x7f7b3a8a9b80>)
|
||||
```
|
||||
|
||||
:param function: The function to decorate (when used without parameters)
|
||||
:param name: Optional custom name for the tool
|
||||
:param description: Optional custom description
|
||||
:param inputs_from_state: Optional dictionary mapping state keys to tool parameter names
|
||||
:param outputs_to_state: Optional dictionary defining how tool outputs map to state and message handling
|
||||
:return: Either a Tool instance or a decorator function that will create one
|
||||
"""
|
||||
return create_tool_from_function(function)
|
||||
|
||||
def decorator(func: Callable) -> Tool:
|
||||
return create_tool_from_function(
|
||||
function=func,
|
||||
name=name,
|
||||
description=description,
|
||||
inputs_from_state=inputs_from_state,
|
||||
outputs_to_state=outputs_to_state,
|
||||
)
|
||||
|
||||
if function is None:
|
||||
return decorator
|
||||
return decorator(function)
|
||||
|
||||
|
||||
def _remove_title_from_schema(schema: Dict[str, Any]):
|
||||
|
||||
@ -29,12 +29,23 @@ class Tool:
|
||||
A JSON schema defining the parameters expected by the Tool.
|
||||
:param function:
|
||||
The function that will be invoked when the Tool is called.
|
||||
:param inputs_from_state:
|
||||
Optional dictionary mapping state keys to tool parameter names.
|
||||
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
|
||||
:param outputs_to_state:
|
||||
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
|
||||
Example: {
|
||||
"documents": {"source": "docs", "handler": custom_handler},
|
||||
"message": {"source": "summary", "handler": format_summary}
|
||||
}
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
function: Callable
|
||||
inputs_from_state: Optional[Dict[str, str]] = None
|
||||
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Check that the parameters define a valid JSON schema
|
||||
@ -43,6 +54,16 @@ class Tool:
|
||||
except SchemaError as e:
|
||||
raise ValueError("The provided parameters do not define a valid JSON schema") from e
|
||||
|
||||
# Validate outputs structure if provided
|
||||
if self.outputs_to_state is not None:
|
||||
for key, config in self.outputs_to_state.items():
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(f"Output configuration for key '{key}' must be a dictionary")
|
||||
if "source" in config and not isinstance(config["source"], str):
|
||||
raise ValueError(f"Output source for key '{key}' must be a string.")
|
||||
if "handler" in config and not callable(config["handler"]):
|
||||
raise ValueError(f"Output handler for key '{key}' must be callable")
|
||||
|
||||
@property
|
||||
def tool_spec(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -54,11 +75,12 @@ class Tool:
|
||||
"""
|
||||
Invoke the Tool with the provided keyword arguments.
|
||||
"""
|
||||
|
||||
try:
|
||||
result = self.function(**kwargs)
|
||||
except Exception as e:
|
||||
raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e
|
||||
raise ToolInvocationError(
|
||||
f"Failed to invoke Tool `{self.name}` with parameters {kwargs}. Error: {e}"
|
||||
) from e
|
||||
return result
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -68,9 +90,19 @@ class Tool:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
|
||||
data = asdict(self)
|
||||
data["function"] = serialize_callable(self.function)
|
||||
|
||||
# Serialize output handlers if they exist
|
||||
if self.outputs_to_state:
|
||||
serialized_outputs = {}
|
||||
for key, config in self.outputs_to_state.items():
|
||||
serialized_config = config.copy()
|
||||
if "handler" in config:
|
||||
serialized_config["handler"] = serialize_callable(config["handler"])
|
||||
serialized_outputs[key] = serialized_config
|
||||
data["outputs_to_state"] = serialized_outputs
|
||||
|
||||
return {"type": generate_qualified_class_name(type(self)), "data": data}
|
||||
|
||||
@classmethod
|
||||
@ -85,6 +117,17 @@ class Tool:
|
||||
"""
|
||||
init_parameters = data["data"]
|
||||
init_parameters["function"] = deserialize_callable(init_parameters["function"])
|
||||
|
||||
# Deserialize output handlers if they exist
|
||||
if "outputs_to_state" in init_parameters and init_parameters["outputs_to_state"]:
|
||||
deserialized_outputs = {}
|
||||
for key, config in init_parameters["outputs_to_state"].items():
|
||||
deserialized_config = config.copy()
|
||||
if "handler" in config:
|
||||
deserialized_config["handler"] = deserialize_callable(config["handler"])
|
||||
deserialized_outputs[key] = deserialized_config
|
||||
init_parameters["outputs_to_state"] = deserialized_outputs
|
||||
|
||||
return cls(**init_parameters)
|
||||
|
||||
|
||||
|
||||
8
releasenotes/notes/move-tool-ff98d464d3e5d775.yaml
Normal file
8
releasenotes/notes/move-tool-ff98d464d3e5d775.yaml
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduce new State dataclass with a customizable schema for managing Agent state.
|
||||
Enhance error logging of Tool and extend the ToolInvoker component to work with newly introduced State.
|
||||
fixes:
|
||||
- |
|
||||
Fix an issue that prevented Jinja2-based ComponentTools from being passed into pipelines at runtime.
|
||||
@ -258,7 +258,9 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
"data": {
|
||||
"description": "description",
|
||||
"function": "builtins.print",
|
||||
"inputs_from_state": None,
|
||||
"name": "name",
|
||||
"outputs_to_state": None,
|
||||
"parameters": {"x": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
@ -319,7 +321,9 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
{
|
||||
"type": "haystack.tools.tool.Tool",
|
||||
"data": {
|
||||
"inputs_from_state": None,
|
||||
"name": "name",
|
||||
"outputs_to_state": None,
|
||||
"description": "description",
|
||||
"parameters": {"x": {"type": "string"}},
|
||||
"function": "builtins.print",
|
||||
|
||||
@ -192,7 +192,9 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
{
|
||||
"type": "haystack.tools.tool.Tool",
|
||||
"data": {
|
||||
"inputs_from_state": None,
|
||||
"name": "weather",
|
||||
"outputs_to_state": None,
|
||||
"description": "useful to determine the weather in a given location",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
|
||||
"function": "generators.chat.test_hugging_face_local.get_weather",
|
||||
|
||||
@ -205,7 +205,9 @@ class TestOpenAIChatGenerator:
|
||||
"data": {
|
||||
"description": "description",
|
||||
"function": "builtins.print",
|
||||
"inputs_from_state": None,
|
||||
"name": "name",
|
||||
"outputs_to_state": None,
|
||||
"parameters": {"x": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
@ -3,10 +3,14 @@ import datetime
|
||||
|
||||
from haystack import Pipeline
|
||||
|
||||
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
|
||||
from haystack.tools.tool import Tool, ToolInvocationError
|
||||
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
|
||||
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
|
||||
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
|
||||
from haystack.dataclasses.state import State
|
||||
from haystack.tools import ComponentTool, Tool
|
||||
from haystack.tools.errors import ToolInvocationError
|
||||
|
||||
|
||||
def weather_function(location):
|
||||
@ -82,6 +86,49 @@ class TestToolInvoker:
|
||||
with pytest.raises(ValueError):
|
||||
ToolInvoker(tools=[weather_tool, new_tool])
|
||||
|
||||
def test_inject_state_args_no_tool_inputs(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
)
|
||||
state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
|
||||
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
|
||||
assert args == {"location": "Berlin"}
|
||||
|
||||
def test_inject_state_args_no_tool_inputs_component_tool(self):
|
||||
comp = PromptBuilder(template="Hello, {{name}}!")
|
||||
prompt_tool = ComponentTool(
|
||||
component=comp, name="prompt_tool", description="Creates a personalized greeting prompt."
|
||||
)
|
||||
state = State(schema={"name": {"type": str}}, data={"name": "James"})
|
||||
args = ToolInvoker._inject_state_args(tool=prompt_tool, llm_args={}, state=state)
|
||||
assert args == {"name": "James"}
|
||||
|
||||
def test_inject_state_args_with_tool_inputs(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
inputs_from_state={"loc": "location"},
|
||||
)
|
||||
state = State(schema={"location": {"type": str}}, data={"loc": "Berlin"})
|
||||
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
|
||||
assert args == {"location": "Berlin"}
|
||||
|
||||
def test_inject_state_args_param_in_state_and_llm(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
)
|
||||
state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
|
||||
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
|
||||
assert args == {"location": "Paris"}
|
||||
|
||||
def test_run(self, invoker):
|
||||
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
@ -104,14 +151,14 @@ class TestToolInvoker:
|
||||
|
||||
def test_run_no_messages(self, invoker):
|
||||
result = invoker.run(messages=[])
|
||||
assert result == {"tool_messages": []}
|
||||
assert result["tool_messages"] == []
|
||||
|
||||
def test_run_no_tool_calls(self, invoker):
|
||||
user_message = ChatMessage.from_user(text="Hello!")
|
||||
assistant_message = ChatMessage.from_assistant(text="How can I help you?")
|
||||
|
||||
result = invoker.run(messages=[user_message, assistant_message])
|
||||
assert result == {"tool_messages": []}
|
||||
assert result["tool_messages"] == []
|
||||
|
||||
def test_run_multiple_tool_calls(self, invoker):
|
||||
tool_calls = [
|
||||
@ -143,8 +190,8 @@ class TestToolInvoker:
|
||||
with pytest.raises(ToolNotFoundException):
|
||||
invoker.run(messages=[tool_call_message])
|
||||
|
||||
def test_tool_not_found_does_not_raise_exception(self, invoker):
|
||||
invoker.raise_on_failure = False
|
||||
def test_tool_not_found_does_not_raise_exception(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=False)
|
||||
|
||||
tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"})
|
||||
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
@ -162,8 +209,8 @@ class TestToolInvoker:
|
||||
with pytest.raises(ToolInvocationError):
|
||||
faulty_invoker.run(messages=[tool_call_message])
|
||||
|
||||
def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker):
|
||||
faulty_invoker.raise_on_failure = False
|
||||
def test_tool_invocation_error_does_not_raise_exception(self, faulty_tool):
|
||||
faulty_invoker = ToolInvoker(tools=[faulty_tool], raise_on_failure=False, convert_result_to_json_string=False)
|
||||
|
||||
tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
|
||||
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
@ -171,10 +218,10 @@ class TestToolInvoker:
|
||||
result = faulty_invoker.run(messages=[tool_call_message])
|
||||
tool_message = result["tool_messages"][0]
|
||||
assert tool_message.tool_call_results[0].error
|
||||
assert "invocation failed" in tool_message.tool_call_results[0].result
|
||||
assert "Failed to invoke" in tool_message.tool_call_results[0].result
|
||||
|
||||
def test_string_conversion_error(self, invoker):
|
||||
invoker.convert_result_to_json_string = True
|
||||
def test_string_conversion_error(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=True)
|
||||
|
||||
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
|
||||
|
||||
@ -182,9 +229,8 @@ class TestToolInvoker:
|
||||
with pytest.raises(StringConversionError):
|
||||
invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call)
|
||||
|
||||
def test_string_conversion_error_does_not_raise_exception(self, invoker):
|
||||
invoker.convert_result_to_json_string = True
|
||||
invoker.raise_on_failure = False
|
||||
def test_string_conversion_error_does_not_raise_exception(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=True)
|
||||
|
||||
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
|
||||
|
||||
@ -231,8 +277,8 @@ class TestToolInvoker:
|
||||
pipeline_dict = pipeline.to_dict()
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"max_runs_per_component": 100,
|
||||
"components": {
|
||||
"invoker": {
|
||||
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
|
||||
@ -249,6 +295,8 @@ class TestToolInvoker:
|
||||
"required": ["location"],
|
||||
},
|
||||
"function": "tools.test_tool_invoker.weather_function",
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_state": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
@ -279,3 +327,93 @@ class TestToolInvoker:
|
||||
|
||||
new_pipeline = Pipeline.loads(pipeline_yaml)
|
||||
assert new_pipeline == pipeline
|
||||
|
||||
|
||||
class TestMergeToolOutputs:
|
||||
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"weather": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(tool=weather_tool, result="test", state=state)
|
||||
assert merged_results == "test"
|
||||
assert state.data == {}
|
||||
|
||||
def test_merge_tool_outputs_empty_dict(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"weather": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(tool=weather_tool, result={}, state=state)
|
||||
assert merged_results == {}
|
||||
assert state.data == {}
|
||||
|
||||
def test_merge_tool_outputs_no_output_mapping(self, weather_tool):
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"weather": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(
|
||||
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
|
||||
)
|
||||
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
|
||||
assert state.data == {}
|
||||
|
||||
def test_merge_tool_outputs_with_output_mapping(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
outputs_to_state={"weather": {"source": "weather"}},
|
||||
)
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"weather": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(
|
||||
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
|
||||
)
|
||||
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
|
||||
assert state.data == {"weather": "sunny"}
|
||||
|
||||
def test_merge_tool_outputs_with_output_mapping_2(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
outputs_to_state={"all_weather_results": {}},
|
||||
)
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"all_weather_results": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(
|
||||
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
|
||||
)
|
||||
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
|
||||
assert state.data == {"all_weather_results": {"weather": "sunny", "temperature": 14, "unit": "celsius"}}
|
||||
|
||||
def test_merge_tool_outputs_with_output_mapping_and_handler(self):
|
||||
handler = lambda old, new: f"{new}"
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
outputs_to_state={"temperature": {"source": "temperature", "handler": handler}},
|
||||
)
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={"temperature": {"type": str}})
|
||||
merged_results = invoker._merge_tool_outputs(
|
||||
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
|
||||
)
|
||||
assert merged_results == {"weather": "sunny", "temperature": 14, "unit": "celsius"}
|
||||
assert state.data == {"temperature": "14"}
|
||||
|
||||
def test_merge_tool_outputs_with_message_output_mapping(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
outputs_to_state={"message": {"source": "weather"}},
|
||||
)
|
||||
invoker = ToolInvoker(tools=[weather_tool])
|
||||
state = State(schema={})
|
||||
merged_results = invoker._merge_tool_outputs(
|
||||
tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
|
||||
)
|
||||
assert merged_results == "sunny"
|
||||
assert state.data == {}
|
||||
|
||||
146
test/dataclasses/test_state.py
Normal file
146
test/dataclasses/test_state.py
Normal file
@ -0,0 +1,146 @@
|
||||
import pytest
|
||||
from typing import List, Dict
|
||||
|
||||
from haystack.dataclasses.state import State, _validate_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_schema():
|
||||
return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_schema():
|
||||
return {
|
||||
"numbers": {
|
||||
"type": list,
|
||||
"handler": lambda current, new: sorted(set(current + new)) if current else sorted(set(new)),
|
||||
},
|
||||
"metadata": {"type": dict},
|
||||
"name": {"type": str},
|
||||
}
|
||||
|
||||
|
||||
def test_validate_schema_valid(basic_schema):
|
||||
# Should not raise any exceptions
|
||||
_validate_schema(basic_schema)
|
||||
|
||||
|
||||
def test_validate_schema_invalid_type():
|
||||
invalid_schema = {"test": {"type": "not_a_type"}}
|
||||
with pytest.raises(ValueError, match="must be a Python type"):
|
||||
_validate_schema(invalid_schema)
|
||||
|
||||
|
||||
def test_validate_schema_missing_type():
|
||||
invalid_schema = {"test": {"handler": lambda x, y: x + y}}
|
||||
with pytest.raises(ValueError, match="missing a 'type' entry"):
|
||||
_validate_schema(invalid_schema)
|
||||
|
||||
|
||||
def test_validate_schema_invalid_handler():
|
||||
invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
|
||||
with pytest.raises(ValueError, match="must be callable or None"):
|
||||
_validate_schema(invalid_schema)
|
||||
|
||||
|
||||
def test_state_initialization(basic_schema):
|
||||
# Test empty initialization
|
||||
state = State(basic_schema)
|
||||
assert state.data == {}
|
||||
|
||||
# Test initialization with data
|
||||
initial_data = {"numbers": [1, 2, 3], "name": "test"}
|
||||
state = State(basic_schema, initial_data)
|
||||
assert state.data["numbers"] == [1, 2, 3]
|
||||
assert state.data["name"] == "test"
|
||||
|
||||
|
||||
def test_state_get(basic_schema):
|
||||
state = State(basic_schema, {"name": "test"})
|
||||
assert state.get("name") == "test"
|
||||
assert state.get("non_existent") is None
|
||||
assert state.get("non_existent", "default") == "default"
|
||||
|
||||
|
||||
def test_state_set_basic(basic_schema):
|
||||
state = State(basic_schema)
|
||||
|
||||
# Test setting new values
|
||||
state.set("numbers", [1, 2])
|
||||
assert state.get("numbers") == [1, 2]
|
||||
|
||||
# Test updating existing values
|
||||
state.set("numbers", [3, 4])
|
||||
assert state.get("numbers") == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_state_set_with_handler(complex_schema):
|
||||
state = State(complex_schema)
|
||||
|
||||
# Test custom handler for numbers
|
||||
state.set("numbers", [3, 2, 1])
|
||||
assert state.get("numbers") == [1, 2, 3]
|
||||
|
||||
state.set("numbers", [6, 5, 4])
|
||||
assert state.get("numbers") == [1, 2, 3, 4, 5, 6]
|
||||
|
||||
|
||||
def test_state_set_with_handler_override(basic_schema):
|
||||
state = State(basic_schema)
|
||||
|
||||
# Custom handler that concatenates strings
|
||||
custom_handler = lambda current, new: f"{current}-{new}" if current else new
|
||||
|
||||
state.set("name", "first")
|
||||
state.set("name", "second", handler_override=custom_handler)
|
||||
assert state.get("name") == "first-second"
|
||||
|
||||
|
||||
def test_state_has(basic_schema):
|
||||
state = State(basic_schema, {"name": "test"})
|
||||
assert state.has("name") is True
|
||||
assert state.has("non_existent") is False
|
||||
|
||||
|
||||
def test_state_empty_schema():
|
||||
state = State({})
|
||||
assert state.data == {}
|
||||
with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
|
||||
state.set("any_key", "value")
|
||||
|
||||
|
||||
def test_state_none_values(basic_schema):
|
||||
state = State(basic_schema)
|
||||
state.set("name", None)
|
||||
assert state.get("name") is None
|
||||
state.set("name", "value")
|
||||
assert state.get("name") == "value"
|
||||
|
||||
|
||||
def test_state_merge_lists(basic_schema):
|
||||
state = State(basic_schema)
|
||||
state.set("numbers", "not_a_list")
|
||||
assert state.get("numbers") == ["not_a_list"]
|
||||
state.set("numbers", [1, 2])
|
||||
assert state.get("numbers") == ["not_a_list", 1, 2]
|
||||
|
||||
|
||||
def test_state_nested_structures():
|
||||
schema = {
|
||||
"complex": {
|
||||
"type": Dict[str, List[int]],
|
||||
"handler": lambda current, new: {
|
||||
k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())
|
||||
}
|
||||
if current
|
||||
else new,
|
||||
}
|
||||
}
|
||||
|
||||
state = State(schema)
|
||||
state.set("complex", {"a": [1, 2], "b": [3, 4]})
|
||||
state.set("complex", {"b": [5, 6], "c": [7, 8]})
|
||||
|
||||
expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
|
||||
assert state.get("complex") == expected
|
||||
161
test/dataclasses/test_state_utils.py
Normal file
161
test/dataclasses/test_state_utils.py
Normal file
@ -0,0 +1,161 @@
|
||||
import pytest
|
||||
from typing import List, Dict, Optional, Union, TypeVar, Generic
|
||||
from dataclasses import dataclass
|
||||
|
||||
from haystack.dataclasses.state_utils import _is_list_type, merge_lists, _is_valid_type
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
def test_is_list_type():
|
||||
assert _is_list_type(list) is True
|
||||
assert _is_list_type(List[int]) is True
|
||||
assert _is_list_type(List[str]) is True
|
||||
assert _is_list_type(dict) is False
|
||||
assert _is_list_type(int) is False
|
||||
assert _is_list_type(Union[List[int], None]) is False
|
||||
|
||||
|
||||
class TestMergeLists:
|
||||
def test_merge_two_lists(self):
|
||||
current = [1, 2, 3]
|
||||
new = [4, 5, 6]
|
||||
result = merge_lists(current, new)
|
||||
assert result == [1, 2, 3, 4, 5, 6]
|
||||
# Ensure original lists weren't modified
|
||||
assert current == [1, 2, 3]
|
||||
assert new == [4, 5, 6]
|
||||
|
||||
def test_append_to_list(self):
|
||||
current = [1, 2, 3]
|
||||
new = 4
|
||||
result = merge_lists(current, new)
|
||||
assert result == [1, 2, 3, 4]
|
||||
assert current == [1, 2, 3] # Ensure original wasn't modified
|
||||
|
||||
def test_create_new_list(self):
|
||||
current = 1
|
||||
new = 2
|
||||
result = merge_lists(current, new)
|
||||
assert result == [1, 2]
|
||||
|
||||
def test_replace_with_list(self):
|
||||
current = 1
|
||||
new = [2, 3]
|
||||
result = merge_lists(current, new)
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
class TestIsValidType:
|
||||
def test_builtin_types(self):
|
||||
assert _is_valid_type(str) is True
|
||||
assert _is_valid_type(int) is True
|
||||
assert _is_valid_type(dict) is True
|
||||
assert _is_valid_type(list) is True
|
||||
assert _is_valid_type(tuple) is True
|
||||
assert _is_valid_type(set) is True
|
||||
assert _is_valid_type(bool) is True
|
||||
assert _is_valid_type(float) is True
|
||||
|
||||
def test_generic_types(self):
|
||||
assert _is_valid_type(List[str]) is True
|
||||
assert _is_valid_type(Dict[str, int]) is True
|
||||
assert _is_valid_type(List[Dict[str, int]]) is True
|
||||
assert _is_valid_type(Dict[str, List[int]]) is True
|
||||
|
||||
def test_custom_classes(self):
|
||||
@dataclass
|
||||
class CustomClass:
|
||||
value: int
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class GenericCustomClass(Generic[T]):
|
||||
pass
|
||||
|
||||
# Test regular and generic custom classes
|
||||
assert _is_valid_type(CustomClass) is True
|
||||
assert _is_valid_type(GenericCustomClass) is True
|
||||
assert _is_valid_type(GenericCustomClass[int]) is True
|
||||
|
||||
# Test generic types with custom classes
|
||||
assert _is_valid_type(List[CustomClass]) is True
|
||||
assert _is_valid_type(Dict[str, CustomClass]) is True
|
||||
assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True
|
||||
|
||||
def test_invalid_types(self):
|
||||
# Test regular values
|
||||
assert _is_valid_type(42) is False
|
||||
assert _is_valid_type("string") is False
|
||||
assert _is_valid_type([1, 2, 3]) is False
|
||||
assert _is_valid_type({"a": 1}) is False
|
||||
assert _is_valid_type(True) is False
|
||||
|
||||
# Test class instances
|
||||
@dataclass
|
||||
class SampleClass:
|
||||
value: int
|
||||
|
||||
instance = SampleClass(42)
|
||||
assert _is_valid_type(instance) is False
|
||||
|
||||
# Test callable objects
|
||||
assert _is_valid_type(len) is False
|
||||
assert _is_valid_type(lambda x: x) is False
|
||||
assert _is_valid_type(print) is False
|
||||
|
||||
def test_union_and_optional_types(self):
|
||||
# Test basic Union types
|
||||
assert _is_valid_type(Union[str, int]) is True
|
||||
assert _is_valid_type(Union[str, None]) is True
|
||||
assert _is_valid_type(Union[List[int], Dict[str, str]]) is True
|
||||
|
||||
# Test Optional types (which are Union[T, None])
|
||||
assert _is_valid_type(Optional[str]) is True
|
||||
assert _is_valid_type(Optional[List[int]]) is True
|
||||
assert _is_valid_type(Optional[Dict[str, list]]) is True
|
||||
|
||||
# Test that Union itself is not a valid type (only instantiated Unions are)
|
||||
assert _is_valid_type(Union) is False
|
||||
|
||||
def test_nested_generic_types(self):
|
||||
assert _is_valid_type(List[List[Dict[str, List[int]]]]) is True
|
||||
assert _is_valid_type(Dict[str, List[Dict[str, set]]]) is True
|
||||
assert _is_valid_type(Dict[str, Optional[List[int]]]) is True
|
||||
assert _is_valid_type(List[Union[str, Dict[str, List[int]]]]) is True
|
||||
|
||||
def test_edge_cases(self):
|
||||
# Test None and NoneType
|
||||
assert _is_valid_type(None) is False
|
||||
assert _is_valid_type(type(None)) is True
|
||||
|
||||
# Test functions and methods
|
||||
def sample_func():
|
||||
pass
|
||||
|
||||
assert _is_valid_type(sample_func) is False
|
||||
assert _is_valid_type(type(sample_func)) is True
|
||||
|
||||
# Test modules
|
||||
assert _is_valid_type(inspect) is False
|
||||
|
||||
# Test type itself
|
||||
assert _is_valid_type(type) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
[
|
||||
(str, True),
|
||||
(int, True),
|
||||
(List[int], True),
|
||||
(Dict[str, int], True),
|
||||
(Union[str, int], True),
|
||||
(Optional[str], True),
|
||||
(42, False),
|
||||
("string", False),
|
||||
([1, 2, 3], False),
|
||||
(lambda x: x, False),
|
||||
],
|
||||
)
|
||||
def test_parametrized_cases(self, test_input, expected):
|
||||
assert _is_valid_type(test_input) is expected
|
||||
@ -4,20 +4,21 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import Pipeline, component
|
||||
from haystack.components.builders import PromptBuilder
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.tools.tool_invoker import ToolInvoker
|
||||
from haystack.components.tools import ToolInvoker
|
||||
from haystack.components.websearch.serper_dev import SerperDevWebSearch
|
||||
from haystack.dataclasses import ChatMessage, ChatRole, Document
|
||||
from haystack.tools import ComponentTool
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
### Component and Model Definitions
|
||||
|
||||
|
||||
@ -121,6 +122,13 @@ class DocumentProcessor:
|
||||
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])}
|
||||
|
||||
|
||||
def output_handler(old, new):
|
||||
"""
|
||||
Output handler to test serialization.
|
||||
"""
|
||||
return old + new
|
||||
|
||||
|
||||
## Unit tests
|
||||
class TestToolComponent:
|
||||
def test_from_component_basic(self):
|
||||
@ -148,6 +156,22 @@ class TestToolComponent:
|
||||
|
||||
assert len(tool.description) == 1024
|
||||
|
||||
def test_from_component_with_inputs(self):
|
||||
component = SimpleComponent()
|
||||
|
||||
tool = ComponentTool(component=component, inputs_from_state={"text": "text"})
|
||||
|
||||
assert tool.inputs_from_state == {"text": "text"}
|
||||
# Inputs should be excluded from schema generation
|
||||
assert tool.parameters == {"type": "object", "properties": {}}
|
||||
|
||||
def test_from_component_with_outputs(self):
|
||||
component = SimpleComponent()
|
||||
|
||||
tool = ComponentTool(component=component, outputs_to_state={"replies": {"source": "reply"}})
|
||||
|
||||
assert tool.outputs_to_state == {"replies": {"source": "reply"}}
|
||||
|
||||
def test_from_component_with_dataclass(self):
|
||||
component = UserGreeter()
|
||||
|
||||
@ -466,7 +490,7 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
pipeline.connect("llm.replies", "tool_invoker.messages")
|
||||
|
||||
message = ChatMessage.from_user(
|
||||
text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields."
|
||||
text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
|
||||
)
|
||||
|
||||
result = pipeline.run({"llm": {"messages": [message]}})
|
||||
@ -500,7 +524,7 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
pipeline.connect("llm.replies", "tool_invoker.messages")
|
||||
|
||||
message = ChatMessage.from_user(
|
||||
text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, blob fields."
|
||||
text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
|
||||
)
|
||||
|
||||
result = pipeline.run({"llm": {"messages": [message]}})
|
||||
@ -576,7 +600,13 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
def test_component_tool_serde(self):
|
||||
component = SimpleComponent()
|
||||
|
||||
tool = ComponentTool(component=component, name="simple_tool", description="A simple tool")
|
||||
tool = ComponentTool(
|
||||
component=component,
|
||||
name="simple_tool",
|
||||
description="A simple tool",
|
||||
inputs_from_state={"test": "input"},
|
||||
outputs_to_state={"output": {"source": "out", "handler": output_handler}},
|
||||
)
|
||||
|
||||
# Test serialization
|
||||
tool_dict = tool.to_dict()
|
||||
@ -584,12 +614,16 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
assert tool_dict["data"]["name"] == "simple_tool"
|
||||
assert tool_dict["data"]["description"] == "A simple tool"
|
||||
assert "component" in tool_dict["data"]
|
||||
assert tool_dict["data"]["inputs_from_state"] == {"test": "input"}
|
||||
assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler"
|
||||
|
||||
# Test deserialization
|
||||
new_tool = ComponentTool.from_dict(tool_dict)
|
||||
assert new_tool.name == tool.name
|
||||
assert new_tool.description == tool.description
|
||||
assert new_tool.parameters == tool.parameters
|
||||
assert new_tool.inputs_from_state == tool.inputs_from_state
|
||||
assert new_tool.outputs_to_state == tool.outputs_to_state
|
||||
assert isinstance(new_tool._component, SimpleComponent)
|
||||
|
||||
def test_pipeline_component_fails(self):
|
||||
@ -603,3 +637,21 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
# thus can't be used as tool
|
||||
with pytest.raises(ValueError, match="Component has been added to a pipeline"):
|
||||
ComponentTool(component=component)
|
||||
|
||||
def test_deepcopy_with_jinja_based_component(self):
|
||||
# Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
|
||||
# When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
|
||||
# We overwrite ComponentTool.__deepcopy__ to fix this in experimental until a more comprehensive fix is merged.
|
||||
# We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
|
||||
|
||||
builder = PromptBuilder("{{query}}")
|
||||
|
||||
tool = ComponentTool(component=builder)
|
||||
result = tool.function(query="Hello")
|
||||
|
||||
tool_copy = deepcopy(tool)
|
||||
|
||||
result_from_copy = tool_copy.function(query="Hello")
|
||||
|
||||
assert "prompt" in result_from_copy
|
||||
assert result_from_copy["prompt"] == result["prompt"]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool
|
||||
from haystack.tools.errors import SchemaGenerationError
|
||||
from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
|
||||
@ -133,6 +133,38 @@ def test_tool_decorator_with_annotated_params():
|
||||
assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"
|
||||
|
||||
|
||||
def test_tool_decorator_with_parameters():
|
||||
@tool(name="fetch_weather", description="A tool to check the weather.")
|
||||
def get_weather(
|
||||
city: Annotated[str, "The target city"] = "Berlin",
|
||||
format: Annotated[Literal["short", "long"], "Output format"] = "short",
|
||||
) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city} ({format} format): 20°C, sunny"
|
||||
|
||||
assert get_weather.name == "fetch_weather"
|
||||
assert get_weather.description == "A tool to check the weather."
|
||||
|
||||
|
||||
def test_tool_decorator_with_inputs_and_outputs():
|
||||
@tool(inputs_from_state={"format": "format"}, outputs_to_state={"output": {"source": "output"}})
|
||||
def get_weather(
|
||||
city: Annotated[str, "The target city"] = "Berlin",
|
||||
format: Annotated[Literal["short", "long"], "Output format"] = "short",
|
||||
) -> str:
|
||||
"""Get weather report for a city."""
|
||||
return f"Weather report for {city} ({format} format): 20°C, sunny"
|
||||
|
||||
assert get_weather.name == "get_weather"
|
||||
assert get_weather.inputs_from_state == {"format": "format"}
|
||||
assert get_weather.outputs_to_state == {"output": {"source": "output"}}
|
||||
# Inputs should be excluded from auto-generated parameters
|
||||
assert get_weather.parameters == {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string", "description": "The target city", "default": "Berlin"}},
|
||||
}
|
||||
|
||||
|
||||
def test_remove_title_from_schema():
|
||||
complex_schema = {
|
||||
"properties": {
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
import pytest
|
||||
|
||||
from haystack.tools.tool import Tool, ToolInvocationError, deserialize_tools_inplace, _check_duplicate_tool_names
|
||||
@ -24,12 +25,31 @@ class TestTool:
|
||||
assert tool.description == "Get weather report"
|
||||
assert tool.parameters == parameters
|
||||
assert tool.function == get_weather_report
|
||||
assert tool.inputs_from_state is None
|
||||
assert tool.outputs_to_state is None
|
||||
|
||||
def test_init_invalid_parameters(self):
|
||||
parameters = {"type": "invalid", "properties": {"city": {"type": "string"}}}
|
||||
|
||||
params = {"type": "invalid", "properties": {"city": {"type": "string"}}}
|
||||
with pytest.raises(ValueError):
|
||||
Tool(name="irrelevant", description="irrelevant", parameters=parameters, function=get_weather_report)
|
||||
Tool(name="irrelevant", description="irrelevant", parameters=params, function=get_weather_report)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"outputs_to_state",
|
||||
[
|
||||
pytest.param({"documents": ["some_value"]}, id="config-not-a-dict"),
|
||||
pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"),
|
||||
pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"),
|
||||
],
|
||||
)
|
||||
def test_init_invalid_output_structure(self, outputs_to_state):
|
||||
with pytest.raises(ValueError):
|
||||
Tool(
|
||||
name="irrelevant",
|
||||
description="irrelevant",
|
||||
parameters={"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
function=get_weather_report,
|
||||
outputs_to_state=outputs_to_state,
|
||||
)
|
||||
|
||||
def test_tool_spec(self):
|
||||
tool = Tool(
|
||||
@ -49,13 +69,21 @@ class TestTool:
|
||||
tool = Tool(
|
||||
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
|
||||
)
|
||||
|
||||
with pytest.raises(ToolInvocationError):
|
||||
with pytest.raises(
|
||||
ToolInvocationError,
|
||||
match=re.escape(
|
||||
"Failed to invoke Tool `weather` with parameters {}. Error: get_weather_report() missing 1 required positional argument: 'city'"
|
||||
),
|
||||
):
|
||||
tool.invoke()
|
||||
|
||||
def test_to_dict(self):
|
||||
tool = Tool(
|
||||
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
|
||||
name="weather",
|
||||
description="Get weather report",
|
||||
parameters=parameters,
|
||||
function=get_weather_report,
|
||||
outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
|
||||
)
|
||||
|
||||
assert tool.to_dict() == {
|
||||
@ -65,6 +93,8 @@ class TestTool:
|
||||
"description": "Get weather report",
|
||||
"parameters": parameters,
|
||||
"function": "test_tool.get_weather_report",
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
|
||||
},
|
||||
}
|
||||
|
||||
@ -76,6 +106,7 @@ class TestTool:
|
||||
"description": "Get weather report",
|
||||
"parameters": parameters,
|
||||
"function": "test_tool.get_weather_report",
|
||||
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
|
||||
},
|
||||
}
|
||||
|
||||
@ -85,6 +116,8 @@ class TestTool:
|
||||
assert tool.description == "Get weather report"
|
||||
assert tool.parameters == parameters
|
||||
assert tool.function == get_weather_report
|
||||
assert tool.outputs_to_state["documents"]["source"] == "docs"
|
||||
assert tool.outputs_to_state["documents"]["handler"] == get_weather_report
|
||||
|
||||
|
||||
def test_deserialize_tools_inplace():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user