diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index 70325f4d5..be72c7cc5 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -2,13 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Set +from typing import Any, Callable, Dict, List, Optional, Set from jinja2 import Environment, TemplateSyntaxError, meta from jinja2.nativetypes import NativeEnvironment from haystack import component, default_from_dict, default_to_dict, logging -from haystack.utils import deserialize_type, serialize_type +from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type logger = logging.getLogger(__name__) @@ -102,7 +102,7 @@ class ConditionalRouter: ``` """ - def __init__(self, routes: List[Dict]): + def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None): """ Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing. @@ -113,12 +113,20 @@ class ConditionalRouter: - `output_type`: The type of the output data (e.g., str, List[int]). - `output_name`: The name under which the `output` value of the route is published. This name is used to connect the router to other components in the pipeline. + :param custom_filters: A dictionary of custom Jinja2 filters to be used in the condition expressions. + For example, passing `{"my_filter": my_filter_fcn}` where: + - `my_filter` is the name of the custom filter. + - `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`. + `{{ my_var|my_filter }}` can then be used inside a route condition expression like so: + `"condition": "{{ my_var|my_filter == 'foo' }}"`. """ self._validate_routes(routes) self.routes: List[dict] = routes + self.custom_filters = custom_filters or {} # Create a Jinja native environment to inspect variables in the condition templates env = NativeEnvironment() + env.filters.update(self.custom_filters) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any @@ -145,8 +153,8 @@ class ConditionalRouter: for route in self.routes: # output_type needs to be serialized to a string route["output_type"] = serialize_type(route["output_type"]) - - return default_to_dict(self, routes=self.routes) + se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()} + return default_to_dict(self, routes=self.routes, custom_filters=se_filters) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter": @@ -163,6 +171,8 @@ class ConditionalRouter: for route in routes: # output_type needs to be deserialized from a string to a type route["output_type"] = deserialize_type(route["output_type"]) + for name, filter_func in init_params.get("custom_filters", {}).items(): + init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None return default_from_dict(cls, data) def run(self, **kwargs): @@ -185,6 +195,7 @@ class ConditionalRouter: """ # Create a Jinja native environment to evaluate the condition templates as Python expressions env = NativeEnvironment() + env.filters.update(self.custom_filters) for route in self.routes: try: diff --git a/releasenotes/notes/add-custom-filters-to-conditional-router-631eba8bab3c2ae7.yaml b/releasenotes/notes/add-custom-filters-to-conditional-router-631eba8bab3c2ae7.yaml new file mode 100644 index 000000000..9f4cdde9f --- /dev/null +++ b/releasenotes/notes/add-custom-filters-to-conditional-router-631eba8bab3c2ae7.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added custom filters support to ConditionalRouter. Users can now pass in + one or more custom Jinja2 filter callables and be able to access those + filters when defining condition expressions in routes. diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 93d454784..0b9059571 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -12,6 +12,11 @@ from haystack.components.routers.conditional_router import NoRouteSelectedExcept from haystack.dataclasses import ChatMessage +def custom_filter_to_sede(value): + """splits by hyphen and returns the first part""" + return int(value.split("-")[0]) + + class TestRouter: @pytest.fixture def routes(self): @@ -288,3 +293,47 @@ class TestRouter: router_dict_first_invocation = copy.deepcopy(router.to_dict()) router_dict_second_invocation = router.to_dict() assert router_dict_first_invocation == router_dict_second_invocation + + def test_custom_filter(self): + routes = [ + { + "condition": "{{phone_num|get_area_code == 123}}", + "output": "Phone number has a 123 area code", + "output_name": "good_phone_num", + "output_type": str, + }, + { + "condition": "{{phone_num|get_area_code != 123}}", + "output": "Phone number does not have 123 area code", + "output_name": "bad_phone_num", + "output_type": str, + }, + ] + + router = ConditionalRouter(routes, custom_filters={"get_area_code": custom_filter_to_sede}) + kwargs = {"phone_num": "123-456-7890"} + result = router.run(**kwargs) + assert result == {"good_phone_num": "Phone number has a 123 area code"} + kwargs = {"phone_num": "321-456-7890"} + result = router.run(**kwargs) + assert result == {"bad_phone_num": "Phone number does not have 123 area code"} + + def test_sede_with_custom_filter(self): + routes = [ + { + "condition": "{{ test|custom_filter_to_sede == 123 }}", + "output": "123", + "output_name": "test", + "output_type": int, + } + ] + custom_filters = {"custom_filter_to_sede": custom_filter_to_sede} + router = ConditionalRouter(routes, custom_filters=custom_filters) + kwargs = {"test": "123-456-789"} + result = router.run(**kwargs) + assert result == {"test": 123} + serialized_router = router.to_dict() + deserialized_router = ConditionalRouter.from_dict(serialized_router) + assert deserialized_router.custom_filters == router.custom_filters + assert deserialized_router.custom_filters["custom_filter_to_sede"]("123-456-789") == 123 + assert result == deserialized_router.run(**kwargs)