feat: add custom jinja filter handling to ConditionalRouter (#7957)

* add custom jinja filter handling to ConditionalRouter

* add release notes for custom filters

* align sede to existing patterns and update docstring example

* update sede unit test route condition to be more explicit

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
Chris Pappalardo 2024-07-04 01:08:12 -07:00 committed by GitHub
parent cafcf51cb0
commit 7178aa0253
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 5 deletions

View File

@ -2,13 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Set
from typing import Any, Callable, Dict, List, Optional, Set
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_type, serialize_type
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
logger = logging.getLogger(__name__)
@ -102,7 +102,7 @@ class ConditionalRouter:
```
"""
def __init__(self, routes: List[Dict]):
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
"""
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
@ -113,12 +113,20 @@ class ConditionalRouter:
- `output_type`: The type of the output data (e.g., str, List[int]).
- `output_name`: The name under which the `output` value of the route is published. This name is used to
connect the router to other components in the pipeline.
:param custom_filters: A dictionary of custom Jinja2 filters to be used in the condition expressions.
For example, passing `{"my_filter": my_filter_fcn}` where:
- `my_filter` is the name of the custom filter.
- `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
`{{ my_var|my_filter }}` can then be used inside a route condition expression like so:
`"condition": "{{ my_var|my_filter == 'foo' }}"`.
"""
self._validate_routes(routes)
self.routes: List[dict] = routes
self.custom_filters = custom_filters or {}
# Create a Jinja native environment to inspect variables in the condition templates
env = NativeEnvironment()
env.filters.update(self.custom_filters)
# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
@ -145,8 +153,8 @@ class ConditionalRouter:
for route in self.routes:
# output_type needs to be serialized to a string
route["output_type"] = serialize_type(route["output_type"])
return default_to_dict(self, routes=self.routes)
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(self, routes=self.routes, custom_filters=se_filters)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
@ -163,6 +171,8 @@ class ConditionalRouter:
for route in routes:
# output_type needs to be deserialized from a string to a type
route["output_type"] = deserialize_type(route["output_type"])
for name, filter_func in init_params.get("custom_filters", {}).items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)
def run(self, **kwargs):
@ -185,6 +195,7 @@ class ConditionalRouter:
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
env = NativeEnvironment()
env.filters.update(self.custom_filters)
for route in self.routes:
try:

View File

@ -0,0 +1,6 @@
---
features:
- |
Added support for custom filters to ConditionalRouter. Users can now pass in
one or more custom Jinja2 filter callables and use those filters
when defining condition expressions in routes.

View File

@ -12,6 +12,11 @@ from haystack.components.routers.conditional_router import NoRouteSelectedExcept
from haystack.dataclasses import ChatMessage
def custom_filter_to_sede(value: str) -> int:
    """
    Return the part of `value` that precedes the first hyphen, converted to an integer.

    Used by the tests in this module as a custom Jinja2 filter for `ConditionalRouter`,
    including the serialization/deserialization round-trip test.

    :param value: A hyphen-separated string whose leading segment is an integer literal,
        for example `"123-456-7890"`.
    :returns: The leading segment as an `int` (`123` for the example above).
    :raises ValueError: If the leading segment cannot be parsed as an integer.
    """
    # Only the leading segment is needed, so stop splitting after the first hyphen.
    return int(value.split("-", 1)[0])
class TestRouter:
@pytest.fixture
def routes(self):
@ -288,3 +293,47 @@ class TestRouter:
router_dict_first_invocation = copy.deepcopy(router.to_dict())
router_dict_second_invocation = router.to_dict()
assert router_dict_first_invocation == router_dict_second_invocation
def test_custom_filter(self):
    """A custom Jinja2 filter passed to the router can be used inside route conditions."""
    area_code_matches = {
        "condition": "{{phone_num|get_area_code == 123}}",
        "output": "Phone number has a 123 area code",
        "output_name": "good_phone_num",
        "output_type": str,
    }
    area_code_differs = {
        "condition": "{{phone_num|get_area_code != 123}}",
        "output": "Phone number does not have 123 area code",
        "output_name": "bad_phone_num",
        "output_type": str,
    }
    # The filter is registered under the name the conditions refer to: `get_area_code`.
    router = ConditionalRouter(
        [area_code_matches, area_code_differs], custom_filters={"get_area_code": custom_filter_to_sede}
    )

    # Leading segment is 123, so the first route's output is expected.
    outcome = router.run(phone_num="123-456-7890")
    assert outcome == {"good_phone_num": "Phone number has a 123 area code"}

    # Leading segment is 321, so the second route's output is expected.
    outcome = router.run(phone_num="321-456-7890")
    assert outcome == {"bad_phone_num": "Phone number does not have 123 area code"}
def test_sede_with_custom_filter(self):
    """A router with custom filters keeps working after a `to_dict` / `from_dict` round trip."""
    router = ConditionalRouter(
        [
            {
                "condition": "{{ test|custom_filter_to_sede == 123 }}",
                "output": "123",
                "output_name": "test",
                "output_type": int,
            }
        ],
        custom_filters={"custom_filter_to_sede": custom_filter_to_sede},
    )
    inputs = {"test": "123-456-789"}

    # Baseline behaviour of the original router; the expected output value is the integer 123.
    original_output = router.run(**inputs)
    assert original_output == {"test": 123}

    # Round-trip the router through its dictionary representation.
    as_dict = router.to_dict()
    restored = ConditionalRouter.from_dict(as_dict)

    # The restored router carries the same filters, and the restored filter is still callable.
    assert restored.custom_filters == router.custom_filters
    assert restored.custom_filters["custom_filter_to_sede"]("123-456-789") == 123
    # Running the restored router gives the same result as the original one.
    assert original_output == restored.run(**inputs)