mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 23:37:36 +00:00
feat: add custom jinja filter handling to ConditionalRouter (#7957)
* add custom jinja filter handling to ConditionalRouter
* add release notes for custom filters
* align sede to existing patterns and update docstring example
* update sede unit test route condition to be more explicit

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
parent
cafcf51cb0
commit
7178aa0253
@ -2,13 +2,13 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||||
from jinja2.nativetypes import NativeEnvironment
|
from jinja2.nativetypes import NativeEnvironment
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.utils import deserialize_type, serialize_type
|
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class ConditionalRouter:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, routes: List[Dict]):
|
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
|
||||||
"""
|
"""
|
||||||
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
|
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
|
||||||
|
|
||||||
@ -113,12 +113,20 @@ class ConditionalRouter:
|
|||||||
- `output_type`: The type of the output data (e.g., str, List[int]).
|
- `output_type`: The type of the output data (e.g., str, List[int]).
|
||||||
- `output_name`: The name under which the `output` value of the route is published. This name is used to
|
- `output_name`: The name under which the `output` value of the route is published. This name is used to
|
||||||
connect the router to other components in the pipeline.
|
connect the router to other components in the pipeline.
|
||||||
|
:param custom_filters: A dictionary of custom Jinja2 filters to be used in the condition expressions.
|
||||||
|
For example, passing `{"my_filter": my_filter_fcn}` where:
|
||||||
|
- `my_filter` is the name of the custom filter.
|
||||||
|
- `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
|
||||||
|
`{{ my_var|my_filter }}` can then be used inside a route condition expression like so:
|
||||||
|
`"condition": "{{ my_var|my_filter == 'foo' }}"`.
|
||||||
"""
|
"""
|
||||||
self._validate_routes(routes)
|
self._validate_routes(routes)
|
||||||
self.routes: List[dict] = routes
|
self.routes: List[dict] = routes
|
||||||
|
self.custom_filters = custom_filters or {}
|
||||||
|
|
||||||
# Create a Jinja native environment to inspect variables in the condition templates
|
# Create a Jinja native environment to inspect variables in the condition templates
|
||||||
env = NativeEnvironment()
|
env = NativeEnvironment()
|
||||||
|
env.filters.update(self.custom_filters)
|
||||||
|
|
||||||
# Inspect the routes to determine input and output types.
|
# Inspect the routes to determine input and output types.
|
||||||
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
||||||
@ -145,8 +153,8 @@ class ConditionalRouter:
|
|||||||
for route in self.routes:
|
for route in self.routes:
|
||||||
# output_type needs to be serialized to a string
|
# output_type needs to be serialized to a string
|
||||||
route["output_type"] = serialize_type(route["output_type"])
|
route["output_type"] = serialize_type(route["output_type"])
|
||||||
|
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
|
||||||
return default_to_dict(self, routes=self.routes)
|
return default_to_dict(self, routes=self.routes, custom_filters=se_filters)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
||||||
@ -163,6 +171,8 @@ class ConditionalRouter:
|
|||||||
for route in routes:
|
for route in routes:
|
||||||
# output_type needs to be deserialized from a string to a type
|
# output_type needs to be deserialized from a string to a type
|
||||||
route["output_type"] = deserialize_type(route["output_type"])
|
route["output_type"] = deserialize_type(route["output_type"])
|
||||||
|
for name, filter_func in init_params.get("custom_filters", {}).items():
|
||||||
|
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
@ -185,6 +195,7 @@ class ConditionalRouter:
|
|||||||
"""
|
"""
|
||||||
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
||||||
env = NativeEnvironment()
|
env = NativeEnvironment()
|
||||||
|
env.filters.update(self.custom_filters)
|
||||||
|
|
||||||
for route in self.routes:
|
for route in self.routes:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Added custom filters support to ConditionalRouter. Users can now pass in
|
||||||
|
one or more custom Jinja2 filter callables and be able to access those
|
||||||
|
filters when defining condition expressions in routes.
|
||||||
@ -12,6 +12,11 @@ from haystack.components.routers.conditional_router import NoRouteSelectedExcept
|
|||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
|
|
||||||
|
|
||||||
|
def custom_filter_to_sede(value):
    """
    Jinja filter helper used by the custom-filter and serialization tests.

    Splits `value` on hyphens and returns the first segment converted to an int,
    for example `"123-456-789"` becomes `123`.

    :param value: A string whose text before the first hyphen is an integer literal.
    :returns: The leading segment of `value` as an int.
    :raises ValueError: If the leading segment cannot be parsed as an int.
    """
    # The int conversion matters: the route conditions compare the result against the number 123.
    return int(value.split("-")[0])
|
||||||
|
|
||||||
|
|
||||||
class TestRouter:
|
class TestRouter:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def routes(self):
|
def routes(self):
|
||||||
@ -288,3 +293,47 @@ class TestRouter:
|
|||||||
router_dict_first_invocation = copy.deepcopy(router.to_dict())
|
router_dict_first_invocation = copy.deepcopy(router.to_dict())
|
||||||
router_dict_second_invocation = router.to_dict()
|
router_dict_second_invocation = router.to_dict()
|
||||||
assert router_dict_first_invocation == router_dict_second_invocation
|
assert router_dict_first_invocation == router_dict_second_invocation
|
||||||
|
|
||||||
|
def test_custom_filter(self):
    """Check that a custom Jinja filter registered on the router can be used in route conditions."""
    matching_route = {
        "condition": "{{phone_num|get_area_code == 123}}",
        "output": "Phone number has a 123 area code",
        "output_name": "good_phone_num",
        "output_type": str,
    }
    fallback_route = {
        "condition": "{{phone_num|get_area_code != 123}}",
        "output": "Phone number does not have 123 area code",
        "output_name": "bad_phone_num",
        "output_type": str,
    }
    # The filter is exposed to the templates under the name "get_area_code"
    router = ConditionalRouter(
        [matching_route, fallback_route], custom_filters={"get_area_code": custom_filter_to_sede}
    )

    # Leading segment is 123, so the first route is selected
    outcome = router.run(phone_num="123-456-7890")
    assert outcome == {"good_phone_num": "Phone number has a 123 area code"}

    # Leading segment is 321, so the second route is selected
    outcome = router.run(phone_num="321-456-7890")
    assert outcome == {"bad_phone_num": "Phone number does not have 123 area code"}
|
||||||
|
|
||||||
|
def test_sede_with_custom_filter(self):
    """Check that a router with a custom filter survives a to_dict / from_dict round trip."""
    single_route = {
        "condition": "{{ test|custom_filter_to_sede == 123 }}",
        "output": "123",
        "output_name": "test",
        "output_type": int,
    }
    filters = {"custom_filter_to_sede": custom_filter_to_sede}
    router = ConditionalRouter([single_route], custom_filters=filters)

    run_inputs = {"test": "123-456-789"}
    expected = router.run(**run_inputs)
    # The route's output string "123" is expected to come back as the int 123
    assert expected == {"test": 123}

    # Round trip: serialize the router, then rebuild it from the serialized form
    restored = ConditionalRouter.from_dict(router.to_dict())

    # The rebuilt router carries the same filter mapping, and the filter is still callable
    assert restored.custom_filters == router.custom_filters
    assert restored.custom_filters["custom_filter_to_sede"]("123-456-789") == 123
    # The rebuilt router routes the same input to the same result
    assert expected == restored.run(**run_inputs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user