mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 15:27:06 +00:00
feat: add custom jinja filter handling to ConditionalRouter (#7957)
* add custom jinja filter handling to ConditionalRouter * add release notes for custom filters * align sede to existing patterns and update docstring example * update sede unit test route condition to be more explicit --------- Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
parent
cafcf51cb0
commit
7178aa0253
@ -2,13 +2,13 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.utils import deserialize_type, serialize_type
|
||||
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -102,7 +102,7 @@ class ConditionalRouter:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, routes: List[Dict]):
|
||||
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
|
||||
"""
|
||||
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
|
||||
|
||||
@ -113,12 +113,20 @@ class ConditionalRouter:
|
||||
- `output_type`: The type of the output data (e.g., str, List[int]).
|
||||
- `output_name`: The name under which the `output` value of the route is published. This name is used to
|
||||
connect the router to other components in the pipeline.
|
||||
:param custom_filters: A dictionary of custom Jinja2 filters to be used in the condition expressions.
|
||||
For example, passing `{"my_filter": my_filter_fcn}` where:
|
||||
- `my_filter` is the name of the custom filter.
|
||||
- `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
|
||||
`{{ my_var|my_filter }}` can then be used inside a route condition expression like so:
|
||||
`"condition": "{{ my_var|my_filter == 'foo' }}"`.
|
||||
"""
|
||||
self._validate_routes(routes)
|
||||
self.routes: List[dict] = routes
|
||||
self.custom_filters = custom_filters or {}
|
||||
|
||||
# Create a Jinja native environment to inspect variables in the condition templates
|
||||
env = NativeEnvironment()
|
||||
env.filters.update(self.custom_filters)
|
||||
|
||||
# Inspect the routes to determine input and output types.
|
||||
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
||||
@ -145,8 +153,8 @@ class ConditionalRouter:
|
||||
for route in self.routes:
|
||||
# output_type needs to be serialized to a string
|
||||
route["output_type"] = serialize_type(route["output_type"])
|
||||
|
||||
return default_to_dict(self, routes=self.routes)
|
||||
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
|
||||
return default_to_dict(self, routes=self.routes, custom_filters=se_filters)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
||||
@ -163,6 +171,8 @@ class ConditionalRouter:
|
||||
for route in routes:
|
||||
# output_type needs to be deserialized from a string to a type
|
||||
route["output_type"] = deserialize_type(route["output_type"])
|
||||
for name, filter_func in init_params.get("custom_filters", {}).items():
|
||||
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def run(self, **kwargs):
|
||||
@ -185,6 +195,7 @@ class ConditionalRouter:
|
||||
"""
|
||||
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
||||
env = NativeEnvironment()
|
||||
env.filters.update(self.custom_filters)
|
||||
|
||||
for route in self.routes:
|
||||
try:
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added custom filters support to ConditionalRouter. Users can now pass in
|
||||
one or more custom Jinja2 filter callables and be able to access those
|
||||
filters when defining condition expressions in routes.
|
||||
@ -12,6 +12,11 @@ from haystack.components.routers.conditional_router import NoRouteSelectedExcept
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
def custom_filter_to_sede(value):
|
||||
"""splits by hyphen and returns the first part"""
|
||||
return int(value.split("-")[0])
|
||||
|
||||
|
||||
class TestRouter:
|
||||
@pytest.fixture
|
||||
def routes(self):
|
||||
@ -288,3 +293,47 @@ class TestRouter:
|
||||
router_dict_first_invocation = copy.deepcopy(router.to_dict())
|
||||
router_dict_second_invocation = router.to_dict()
|
||||
assert router_dict_first_invocation == router_dict_second_invocation
|
||||
|
||||
def test_custom_filter(self):
|
||||
routes = [
|
||||
{
|
||||
"condition": "{{phone_num|get_area_code == 123}}",
|
||||
"output": "Phone number has a 123 area code",
|
||||
"output_name": "good_phone_num",
|
||||
"output_type": str,
|
||||
},
|
||||
{
|
||||
"condition": "{{phone_num|get_area_code != 123}}",
|
||||
"output": "Phone number does not have 123 area code",
|
||||
"output_name": "bad_phone_num",
|
||||
"output_type": str,
|
||||
},
|
||||
]
|
||||
|
||||
router = ConditionalRouter(routes, custom_filters={"get_area_code": custom_filter_to_sede})
|
||||
kwargs = {"phone_num": "123-456-7890"}
|
||||
result = router.run(**kwargs)
|
||||
assert result == {"good_phone_num": "Phone number has a 123 area code"}
|
||||
kwargs = {"phone_num": "321-456-7890"}
|
||||
result = router.run(**kwargs)
|
||||
assert result == {"bad_phone_num": "Phone number does not have 123 area code"}
|
||||
|
||||
def test_sede_with_custom_filter(self):
|
||||
routes = [
|
||||
{
|
||||
"condition": "{{ test|custom_filter_to_sede == 123 }}",
|
||||
"output": "123",
|
||||
"output_name": "test",
|
||||
"output_type": int,
|
||||
}
|
||||
]
|
||||
custom_filters = {"custom_filter_to_sede": custom_filter_to_sede}
|
||||
router = ConditionalRouter(routes, custom_filters=custom_filters)
|
||||
kwargs = {"test": "123-456-789"}
|
||||
result = router.run(**kwargs)
|
||||
assert result == {"test": 123}
|
||||
serialized_router = router.to_dict()
|
||||
deserialized_router = ConditionalRouter.from_dict(serialized_router)
|
||||
assert deserialized_router.custom_filters == router.custom_filters
|
||||
assert deserialized_router.custom_filters["custom_filter_to_sede"]("123-456-789") == 123
|
||||
assert result == deserialized_router.run(**kwargs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user