fix: make from dict conditional router more resilient (#8343)

* fix: make from dict conditional router more resilient

* refactor: remove

* dos: add release notes

* fix: format
This commit is contained in:
ArzelaAscoIi 2024-09-09 15:11:52 +02:00 committed by GitHub
parent 75955922b9
commit 720e54970f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 2 deletions

View File

@ -187,8 +187,13 @@ class ConditionalRouter:
for route in routes:
# output_type needs to be deserialized from a string to a type
route["output_type"] = deserialize_type(route["output_type"])
for name, filter_func in init_params.get("custom_filters", {}).items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
# Since the custom_filters are typed as optional in the init signature, we catch the
# case where they are not present in the serialized data and set them to an empty dict.
custom_filters = init_params.get("custom_filters", {})
if custom_filters is not None:
for name, filter_func in custom_filters.items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)
def run(self, **kwargs):

View File

@ -0,0 +1,6 @@
---
fixes:
- |
The `from_dict` method of `ConditionalRouter` now correctly handles
the case where the `dict` passed to it contains the key `custom_filters` explicitly
set to `None`. Previously this was causing an `AttributeError`

View File

@ -243,6 +243,36 @@ class TestRouter:
# check that the result is the same and correct
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}
def test_router_de_serialization_with_none_argument(self):
new_router = ConditionalRouter.from_dict(
{
"type": "haystack.components.routers.conditional_router.ConditionalRouter",
"init_parameters": {
"routes": [
{
"condition": "{{streams|length < 2}}",
"output": "{{query}}",
"output_type": "str",
"output_name": "query",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": "typing.List[int]",
"output_name": "streams",
},
],
"custom_filters": None,
"unsafe": False,
},
}
)
# now use both routers with the same input
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
result2 = new_router.run(**kwargs)
assert result2 == {"streams": [1, 2, 3]}
def test_router_serialization_idempotence(self):
routes = [
{