diff --git a/haystack/components/agents/state/state.py b/haystack/components/agents/state/state.py new file mode 100644 index 000000000..857a034d1 --- /dev/null +++ b/haystack/components/agents/state/state.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional + +from haystack.dataclasses import ChatMessage +from haystack.utils import deserialize_value, serialize_value +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.utils.type_serialization import deserialize_type, serialize_type + +from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values + + +def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert a schema dictionary to a serializable format. + + Converts each parameter's type and optional handler function into a serializable + format using type and callable serialization utilities. + + :param schema: Dictionary mapping parameter names to their type and handler configs + :returns: Dictionary with serialized type and handler information + """ + serialized_schema = {} + for param, config in schema.items(): + serialized_schema[param] = {"type": serialize_type(config["type"])} + if config.get("handler"): + serialized_schema[param]["handler"] = serialize_callable(config["handler"]) + + return serialized_schema + + +def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert a serialized schema dictionary back to its original format. + + Deserializes the type and optional handler function for each parameter from their + serialized format back into Python types and callables. + + :param schema: Dictionary containing serialized schema information + :returns: Dictionary with deserialized type and handler configurations + """ + deserialized_schema = {} + for param, config in schema.items(): + deserialized_schema[param] = {"type": deserialize_type(config["type"])} + + if config.get("handler"): + deserialized_schema[param]["handler"] = deserialize_callable(config["handler"]) + + return deserialized_schema + + +def _validate_schema(schema: Dict[str, Any]) -> None: + """ + Validate that a schema dictionary meets all required constraints. + + Checks that each parameter definition has a valid type field and that any handler + specified is a callable function. + + :param schema: Dictionary mapping parameter names to their type and handler configs + :raises ValueError: If schema validation fails due to missing or invalid fields + """ + for param, definition in schema.items(): + if "type" not in definition: + raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.") + if not _is_valid_type(definition["type"]): + raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") + if definition.get("handler") is not None and not callable(definition["handler"]): + raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") + if param == "messages" and definition["type"] is not List[ChatMessage]: + raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}") + + +class State: + """ + A class that wraps a StateSchema and maintains an internal _data dictionary. + + Each schema entry has: + "parameter_name": { + "type": SomeType, + "handler": Optional[Callable[[Any, Any], Any]] + } + """ + + def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None): + """ + Initialize a State object with a schema and optional data. + + :param schema: Dictionary mapping parameter names to their type and handler configs. + Type must be a valid Python type, and handler must be a callable function or None. + If handler is None, the default handler for the type will be used. The default handlers are: + - For list types: `haystack.dataclasses.state_utils.merge_lists` + - For all other types: `haystack.dataclasses.state_utils.replace_values` + :param data: Optional dictionary of initial data to populate the state + """ + _validate_schema(schema) + self.schema = deepcopy(schema) + if self.schema.get("messages") is None: + self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists} + self._data = data or {} + + # Set default handlers if not provided in schema + for definition in self.schema.values(): + # Skip if handler is already defined and not None + if definition.get("handler") is not None: + continue + # Set default handler based on type + if _is_list_type(definition["type"]): + definition["handler"] = merge_lists + else: + definition["handler"] = replace_values + + def get(self, key: str, default: Any = None) -> Any: + """ + Retrieve a value from the state by key. + + :param key: Key to look up in the state + :param default: Value to return if key is not found + :returns: Value associated with key or default if not found + """ + return deepcopy(self._data.get(key, default)) + + def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None: + """ + Set or merge a value in the state according to schema rules. + + Value is merged or overwritten according to these rules: + - if handler_override is given, use that + - else use the handler defined in the schema for 'key' + + :param key: Key to store the value under + :param value: Value to store or merge + :param handler_override: Optional function to override the default merge behavior + """ + # If key not in schema, we throw an error + definition = self.schema.get(key, None) + if definition is None: + raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}") + + # Get current value from state and apply handler + current_value = self._data.get(key, None) + handler = handler_override or definition["handler"] + self._data[key] = handler(current_value, value) + + @property + def data(self): + """ + All current data of the state. + """ + return self._data + + def has(self, key: str) -> bool: + """ + Check if a key exists in the state. + + :param key: Key to check for existence + :returns: True if key exists in state, False otherwise + """ + return key in self._data + + def to_dict(self): + """ + Convert the State object to a dictionary. + """ + serialized = {} + serialized["schema"] = _schema_to_dict(self.schema) + + serialized["data"] = serialize_value(self._data) + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + """ + Convert a dictionary back to a State object. + """ + schema = _schema_from_dict(data.get("schema", {})) + deserialized_data = deserialize_value(data.get("data", {})) + return State(schema, deserialized_data) diff --git a/haystack/components/agents/state/state_utils.py b/haystack/components/agents/state/state_utils.py new file mode 100644 index 000000000..2b392d812 --- /dev/null +++ b/haystack/components/agents/state/state_utils.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Any, List, TypeVar, Union, get_origin + +T = TypeVar("T") + + +def _is_valid_type(obj: Any) -> bool: + """ + Check if an object is a valid type annotation. + + Valid types include: + - Normal classes (str, dict, CustomClass) + - Generic types (List[str], Dict[str, int]) + - Union types (Union[str, int], Optional[str]) + + :param obj: The object to check + :return: True if the object is a valid type annotation, False otherwise + + Example usage: + >>> _is_valid_type(str) + True + >>> _is_valid_type(List[int]) + True + >>> _is_valid_type(Union[str, int]) + True + >>> _is_valid_type(42) + False + """ + # Handle Union types (including Optional) + if hasattr(obj, "__origin__") and obj.__origin__ is Union: + return True + + # Handle normal classes and generic types + return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"} + + +def _is_list_type(type_hint: Any) -> bool: + """ + Check if a type hint represents a list type. + + :param type_hint: The type hint to check + :return: True if the type hint represents a list, False otherwise + """ + return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list) + + +def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]: + """ + Merges two values into a single list. + + If either `current` or `new` is not already a list, it is converted into one. + The function ensures that both inputs are treated as lists and concatenates them. + + If `current` is None, it is treated as an empty list. + + :param current: The existing value(s), either a single item or a list. + :param new: The new value(s) to merge, either a single item or a list. + :return: A list containing elements from both `current` and `new`. + """ + current_list = [] if current is None else current if isinstance(current, list) else [current] + new_list = new if isinstance(new, list) else [new] + return current_list + new_list + + +def replace_values(current: Any, new: Any) -> Any: + """ + Replace the `current` value with the `new` value. + + :param current: The existing value + :param new: The new value to replace + :return: The new value + """ + return new diff --git a/pyproject.toml b/pyproject.toml index 00ee08b62..171bf94e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "LLM framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data." readme = "README.md" license = "Apache-2.0" -requires-python = ">=3.9" +requires-python = ">=3.9, <3.13" authors = [{ name = "deepset.ai", email = "malte.pietsch@deepset.ai" }] keywords = [ "BERT",