haystack/test/dataclasses/test_state.py
Julian Risch 657d09d7f1
feat: integrate updates of Tool, ToolInvoker, State, create_tool_from_function, ComponentTool from haystack-experimental (#9113)
* update Tool,ToolInvoker,ComponentTool,create_tool_from_function

* add State and its utils

* add tests for State and its utils

* update tests for Tool etc.

* reno

* fix circular imports

* update experimental imports in tests

* fix unit tests

* fix ChatGenerator unit tests

* mypy

* add State to init and pydoc

* explain State in more detail in release note

* add test from #8913

* re-add _check_duplicate_tool_names and refactor imports

* rename inputs and outputs
2025-03-28 10:49:23 +01:00

147 lines
4.2 KiB
Python

import pytest
from typing import List, Dict
from haystack.dataclasses.state import State, _validate_schema
@pytest.fixture
def basic_schema():
    """Provide a schema with a list, a dict and a str field, none of which declares a custom handler."""
    field_types = {"numbers": list, "metadata": dict, "name": str}
    return {field: {"type": field_type} for field, field_type in field_types.items()}
@pytest.fixture
def complex_schema():
    """Provide a schema whose list field uses a custom handler that merges, deduplicates and sorts values."""

    def merge_unique_sorted(current, new):
        # With a falsy `current` only the incoming values are used; otherwise both lists are combined.
        if not current:
            return sorted(set(new))
        return sorted(set(current + new))

    return {
        "numbers": {"type": list, "handler": merge_unique_sorted},
        "metadata": {"type": dict},
        "name": {"type": str},
    }
def test_validate_schema_valid(basic_schema):
    """A well-formed schema passes validation; the test fails only if an exception is raised."""
    _validate_schema(basic_schema)
def test_validate_schema_invalid_type():
    """A 'type' entry that is a string rather than a Python type is rejected with a ValueError."""
    with pytest.raises(ValueError, match="must be a Python type"):
        _validate_schema({"test": {"type": "not_a_type"}})
def test_validate_schema_missing_type():
    """A schema entry that has a handler but no 'type' key is rejected with a ValueError."""
    entry_without_type = {"handler": lambda left, right: left + right}
    with pytest.raises(ValueError, match="missing a 'type' entry"):
        _validate_schema({"test": entry_without_type})
def test_validate_schema_invalid_handler():
    """A 'handler' entry that is not callable (here a plain string) is rejected with a ValueError."""
    entry_with_bad_handler = {"type": list, "handler": "not_callable"}
    with pytest.raises(ValueError, match="must be callable or None"):
        _validate_schema({"test": entry_with_bad_handler})
def test_state_initialization(basic_schema):
    """State built without initial data is empty; initial data passed to the constructor is exposed via `data`."""
    # No initial data: nothing is stored.
    assert State(basic_schema).data == {}

    # Initial data: the supplied values are readable straight away.
    seeded = State(basic_schema, {"numbers": [1, 2, 3], "name": "test"})
    assert seeded.data["numbers"] == [1, 2, 3]
    assert seeded.data["name"] == "test"
def test_state_get(basic_schema):
    """`get` returns a stored value, None for an unknown key, or the caller's default when one is given."""
    named = State(basic_schema, {"name": "test"})
    assert named.get("name") == "test"

    # Unknown key: None by default, otherwise the explicit fallback.
    assert named.get("non_existent") is None
    assert named.get("non_existent", "default") == "default"
def test_state_set_basic(basic_schema):
    """Without a custom handler, repeated writes to a list-typed key accumulate instead of overwriting."""
    state = State(basic_schema)

    # First write: the list is stored as given.
    state.set("numbers", [1, 2])
    assert state.get("numbers") == [1, 2]

    # Second write: the new items are appended to the existing list, not substituted for it.
    state.set("numbers", [3, 4])
    assert state.get("numbers") == [1, 2, 3, 4]
def test_state_set_with_handler(complex_schema):
    """The handler declared in the schema decides how writes are merged (here: deduplicated and sorted)."""
    state = State(complex_schema)
    # Each step writes a batch and checks the cumulative, sorted result.
    steps = (
        ([3, 2, 1], [1, 2, 3]),
        ([6, 5, 4], [1, 2, 3, 4, 5, 6]),
    )
    for batch, expected in steps:
        state.set("numbers", batch)
        assert state.get("numbers") == expected
def test_state_set_with_handler_override(basic_schema):
    """Passing `handler_override` to `State.set` controls how the new value is combined with the stored one."""
    state = State(basic_schema)

    # A named function instead of a lambda bound to a variable (PEP 8 / ruff E731).
    def join_with_dash(current, new):
        # Join existing and incoming strings with "-"; with a falsy `current`, keep `new` unchanged.
        return f"{current}-{new}" if current else new

    state.set("name", "first")
    state.set("name", "second", handler_override=join_with_dash)
    assert state.get("name") == "first-second"
def test_state_has(basic_schema):
    """`has` returns exactly True for a stored key and exactly False for an unknown one."""
    populated = State(basic_schema, {"name": "test"})
    assert populated.has("name") is True
    assert populated.has("non_existent") is False
def test_state_empty_schema():
    """With an empty schema the state holds no data and refuses writes to any key."""
    empty = State({})
    assert empty.data == {}

    # No key is declared, so any write must be rejected.
    with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
        empty.set("any_key", "value")
def test_state_none_values(basic_schema):
    """None can be stored for a str-typed key and is later replaced by a real value."""
    state = State(basic_schema)

    # Storing None is allowed and reads back as None.
    state.set("name", None)
    assert state.get("name") is None

    # A subsequent string write takes the place of the stored None.
    state.set("name", "value")
    assert state.get("name") == "value"
def test_state_merge_lists(basic_schema):
    """A scalar written to a list-typed key is wrapped in a list; a later list write is appended to it."""
    state = State(basic_schema)

    # Scalar input ends up as a one-element list.
    state.set("numbers", "not_a_list")
    assert state.get("numbers") == ["not_a_list"]

    # List input extends the existing list.
    state.set("numbers", [1, 2])
    assert state.get("numbers") == ["not_a_list", 1, 2]
def test_state_nested_structures():
    """A custom handler can merge nested values: here dicts mapping str keys to lists of ints."""

    def merge_list_dicts(current, new):
        # With a falsy `current`, store `new` unchanged.
        if not current:
            return new
        # For every key on either side, concatenate the lists (missing side contributes []).
        return {key: current.get(key, []) + new.get(key, []) for key in current.keys() | new.keys()}

    state = State({"complex": {"type": Dict[str, List[int]], "handler": merge_list_dicts}})
    state.set("complex", {"a": [1, 2], "b": [3, 4]})
    state.set("complex", {"b": [5, 6], "c": [7, 8]})

    # "b" appears in both writes, so its lists are concatenated; "a" and "c" pass through.
    assert state.get("complex") == {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}