mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
Pin python version
This commit is contained in:
parent
3e28ec207a
commit
80121a15d9
180
haystack/components/agents/state/state.py
Normal file
180
haystack/components/agents/state/state.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from haystack.dataclasses import ChatMessage
|
||||||
|
from haystack.utils import deserialize_value, serialize_value
|
||||||
|
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||||
|
from haystack.utils.type_serialization import deserialize_type, serialize_type
|
||||||
|
|
||||||
|
from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert a schema dictionary to a serializable format.
|
||||||
|
|
||||||
|
Converts each parameter's type and optional handler function into a serializable
|
||||||
|
format using type and callable serialization utilities.
|
||||||
|
|
||||||
|
:param schema: Dictionary mapping parameter names to their type and handler configs
|
||||||
|
:returns: Dictionary with serialized type and handler information
|
||||||
|
"""
|
||||||
|
serialized_schema = {}
|
||||||
|
for param, config in schema.items():
|
||||||
|
serialized_schema[param] = {"type": serialize_type(config["type"])}
|
||||||
|
if config.get("handler"):
|
||||||
|
serialized_schema[param]["handler"] = serialize_callable(config["handler"])
|
||||||
|
|
||||||
|
return serialized_schema
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert a serialized schema dictionary back to its original format.
|
||||||
|
|
||||||
|
Deserializes the type and optional handler function for each parameter from their
|
||||||
|
serialized format back into Python types and callables.
|
||||||
|
|
||||||
|
:param schema: Dictionary containing serialized schema information
|
||||||
|
:returns: Dictionary with deserialized type and handler configurations
|
||||||
|
"""
|
||||||
|
deserialized_schema = {}
|
||||||
|
for param, config in schema.items():
|
||||||
|
deserialized_schema[param] = {"type": deserialize_type(config["type"])}
|
||||||
|
|
||||||
|
if config.get("handler"):
|
||||||
|
deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])
|
||||||
|
|
||||||
|
return deserialized_schema
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_schema(schema: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate that a schema dictionary meets all required constraints.
|
||||||
|
|
||||||
|
Checks that each parameter definition has a valid type field and that any handler
|
||||||
|
specified is a callable function.
|
||||||
|
|
||||||
|
:param schema: Dictionary mapping parameter names to their type and handler configs
|
||||||
|
:raises ValueError: If schema validation fails due to missing or invalid fields
|
||||||
|
"""
|
||||||
|
for param, definition in schema.items():
|
||||||
|
if "type" not in definition:
|
||||||
|
raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
|
||||||
|
if not _is_valid_type(definition["type"]):
|
||||||
|
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
|
||||||
|
if definition.get("handler") is not None and not callable(definition["handler"]):
|
||||||
|
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
|
||||||
|
if param == "messages" and definition["type"] is not List[ChatMessage]:
|
||||||
|
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
"""
|
||||||
|
A class that wraps a StateSchema and maintains an internal _data dictionary.
|
||||||
|
|
||||||
|
Each schema entry has:
|
||||||
|
"parameter_name": {
|
||||||
|
"type": SomeType,
|
||||||
|
"handler": Optional[Callable[[Any, Any], Any]]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
|
||||||
|
"""
|
||||||
|
Initialize a State object with a schema and optional data.
|
||||||
|
|
||||||
|
:param schema: Dictionary mapping parameter names to their type and handler configs.
|
||||||
|
Type must be a valid Python type, and handler must be a callable function or None.
|
||||||
|
If handler is None, the default handler for the type will be used. The default handlers are:
|
||||||
|
- For list types: `haystack.dataclasses.state_utils.merge_lists`
|
||||||
|
- For all other types: `haystack.dataclasses.state_utils.replace_values`
|
||||||
|
:param data: Optional dictionary of initial data to populate the state
|
||||||
|
"""
|
||||||
|
_validate_schema(schema)
|
||||||
|
self.schema = deepcopy(schema)
|
||||||
|
if self.schema.get("messages") is None:
|
||||||
|
self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
|
||||||
|
self._data = data or {}
|
||||||
|
|
||||||
|
# Set default handlers if not provided in schema
|
||||||
|
for definition in self.schema.values():
|
||||||
|
# Skip if handler is already defined and not None
|
||||||
|
if definition.get("handler") is not None:
|
||||||
|
continue
|
||||||
|
# Set default handler based on type
|
||||||
|
if _is_list_type(definition["type"]):
|
||||||
|
definition["handler"] = merge_lists
|
||||||
|
else:
|
||||||
|
definition["handler"] = replace_values
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
Retrieve a value from the state by key.
|
||||||
|
|
||||||
|
:param key: Key to look up in the state
|
||||||
|
:param default: Value to return if key is not found
|
||||||
|
:returns: Value associated with key or default if not found
|
||||||
|
"""
|
||||||
|
return deepcopy(self._data.get(key, default))
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Set or merge a value in the state according to schema rules.
|
||||||
|
|
||||||
|
Value is merged or overwritten according to these rules:
|
||||||
|
- if handler_override is given, use that
|
||||||
|
- else use the handler defined in the schema for 'key'
|
||||||
|
|
||||||
|
:param key: Key to store the value under
|
||||||
|
:param value: Value to store or merge
|
||||||
|
:param handler_override: Optional function to override the default merge behavior
|
||||||
|
"""
|
||||||
|
# If key not in schema, we throw an error
|
||||||
|
definition = self.schema.get(key, None)
|
||||||
|
if definition is None:
|
||||||
|
raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")
|
||||||
|
|
||||||
|
# Get current value from state and apply handler
|
||||||
|
current_value = self._data.get(key, None)
|
||||||
|
handler = handler_override or definition["handler"]
|
||||||
|
self._data[key] = handler(current_value, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
"""
|
||||||
|
All current data of the state.
|
||||||
|
"""
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
def has(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a key exists in the state.
|
||||||
|
|
||||||
|
:param key: Key to check for existence
|
||||||
|
:returns: True if key exists in state, False otherwise
|
||||||
|
"""
|
||||||
|
return key in self._data
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Convert the State object to a dictionary.
|
||||||
|
"""
|
||||||
|
serialized = {}
|
||||||
|
serialized["schema"] = _schema_to_dict(self.schema)
|
||||||
|
|
||||||
|
serialized["data"] = serialize_value(self._data)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Convert a dictionary back to a State object.
|
||||||
|
"""
|
||||||
|
schema = _schema_from_dict(data.get("schema", {}))
|
||||||
|
deserialized_data = deserialize_value(data.get("data", {}))
|
||||||
|
return State(schema, deserialized_data)
|
77
haystack/components/agents/state/state_utils.py
Normal file
77
haystack/components/agents/state/state_utils.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, List, TypeVar, Union, get_origin
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_type(obj: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an object is a valid type annotation.
|
||||||
|
|
||||||
|
Valid types include:
|
||||||
|
- Normal classes (str, dict, CustomClass)
|
||||||
|
- Generic types (List[str], Dict[str, int])
|
||||||
|
- Union types (Union[str, int], Optional[str])
|
||||||
|
|
||||||
|
:param obj: The object to check
|
||||||
|
:return: True if the object is a valid type annotation, False otherwise
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> _is_valid_type(str)
|
||||||
|
True
|
||||||
|
>>> _is_valid_type(List[int])
|
||||||
|
True
|
||||||
|
>>> _is_valid_type(Union[str, int])
|
||||||
|
True
|
||||||
|
>>> _is_valid_type(42)
|
||||||
|
False
|
||||||
|
"""
|
||||||
|
# Handle Union types (including Optional)
|
||||||
|
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Handle normal classes and generic types
|
||||||
|
return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_list_type(type_hint: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a type hint represents a list type.
|
||||||
|
|
||||||
|
:param type_hint: The type hint to check
|
||||||
|
:return: True if the type hint represents a list, False otherwise
|
||||||
|
"""
|
||||||
|
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
|
||||||
|
"""
|
||||||
|
Merges two values into a single list.
|
||||||
|
|
||||||
|
If either `current` or `new` is not already a list, it is converted into one.
|
||||||
|
The function ensures that both inputs are treated as lists and concatenates them.
|
||||||
|
|
||||||
|
If `current` is None, it is treated as an empty list.
|
||||||
|
|
||||||
|
:param current: The existing value(s), either a single item or a list.
|
||||||
|
:param new: The new value(s) to merge, either a single item or a list.
|
||||||
|
:return: A list containing elements from both `current` and `new`.
|
||||||
|
"""
|
||||||
|
current_list = [] if current is None else current if isinstance(current, list) else [current]
|
||||||
|
new_list = new if isinstance(new, list) else [new]
|
||||||
|
return current_list + new_list
|
||||||
|
|
||||||
|
|
||||||
|
def replace_values(current: Any, new: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Replace the `current` value with the `new` value.
|
||||||
|
|
||||||
|
:param current: The existing value
|
||||||
|
:param new: The new value to replace
|
||||||
|
:return: The new value
|
||||||
|
"""
|
||||||
|
return new
|
@ -8,7 +8,7 @@ dynamic = ["version"]
|
|||||||
description = "LLM framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data."
|
description = "LLM framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9, <3.13"
|
||||||
authors = [{ name = "deepset.ai", email = "malte.pietsch@deepset.ai" }]
|
authors = [{ name = "deepset.ai", email = "malte.pietsch@deepset.ai" }]
|
||||||
keywords = [
|
keywords = [
|
||||||
"BERT",
|
"BERT",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user