diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index b16ae3d34..27312f43a 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -4,7 +4,7 @@ import ast import contextlib -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin from warnings import warn from jinja2 import Environment, TemplateSyntaxError, meta @@ -107,7 +107,13 @@ class ConditionalRouter: ``` """ - def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False): + def __init__( + self, + routes: List[Dict], + custom_filters: Optional[Dict[str, Callable]] = None, + unsafe: bool = False, + validate_output_type: bool = False, + ): """ Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing. @@ -127,10 +133,14 @@ class ConditionalRouter: :param unsafe: Enable execution of arbitrary code in the Jinja template. This should only be used if you trust the source of the template as it can be lead to remote code execution. + :param validate_output_type: + Enable validation of routes' output. + If a route output doesn't match the declared type a ValueError is raised running. """ self.routes: List[dict] = routes self.custom_filters = custom_filters or {} self._unsafe = unsafe + self._validate_output_type = validate_output_type # Create a Jinja environment to inspect variables in the condition templates if self._unsafe: @@ -170,7 +180,13 @@ class ConditionalRouter: # output_type needs to be serialized to a string route["output_type"] = serialize_type(route["output_type"]) se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()} - return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe) + return default_to_dict( + self, + routes=self.routes, + custom_filters=se_filters, + unsafe=self._unsafe, + validate_output_type=self._validate_output_type, + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter": @@ -210,9 +226,12 @@ class ConditionalRouter: :returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output` of the selected route. - :raises NoRouteSelectedException: If no `condition' in the routes is `True`. - :raises RouteConditionException: If there is an error parsing or evaluating the `condition` expression in the - routes. + :raises NoRouteSelectedException: + If no `condition' in the routes is `True`. + :raises RouteConditionException: + If there is an error parsing or evaluating the `condition` expression in the routes. + :raises ValueError: + If type validation is enabled and route type doesn't match actual value type. """ # Create a Jinja native environment to evaluate the condition templates as Python expressions for route in self.routes: @@ -221,21 +240,28 @@ class ConditionalRouter: rendered = t.render(**kwargs) if not self._unsafe: rendered = ast.literal_eval(rendered) - if rendered: - # We now evaluate the `output` expression to determine the route output - t_output = self._env.from_string(route["output"]) - output = t_output.render(**kwargs) - # We suppress the exception in case the output is already a string, otherwise - # we try to evaluate it and would fail. - # This must be done cause the output could be different literal structures. - # This doesn't support any user types. - with contextlib.suppress(Exception): - if not self._unsafe: - output = ast.literal_eval(output) - # and return the output as a dictionary under the output_name key - return {route["output_name"]: output} + if not rendered: + continue + # We now evaluate the `output` expression to determine the route output + t_output = self._env.from_string(route["output"]) + output = t_output.render(**kwargs) + # We suppress the exception in case the output is already a string, otherwise + # we try to evaluate it and would fail. + # This must be done cause the output could be different literal structures. + # This doesn't support any user types. + with contextlib.suppress(Exception): + if not self._unsafe: + output = ast.literal_eval(output) except Exception as e: - raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e + msg = f"Error evaluating condition for route '{route}': {e}" + raise RouteConditionException(msg) from e + + if self._validate_output_type and not self._output_matches_type(output, route["output_type"]): + msg = f"""Route '{route["output_name"]}' type doesn't match expected type""" + raise ValueError(msg) + + # and return the output as a dictionary under the output_name key + return {route["output_name"]: output} raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}") @@ -288,3 +314,53 @@ class ConditionalRouter: return True except TemplateSyntaxError: return False + + def _output_matches_type(self, value: Any, expected_type: type): # noqa: PLR0911 # pylint: disable=too-many-return-statements + """ + Checks whether `value` type matches the `expected_type`. + """ + # Handle Any type + if expected_type is Any: + return True + + # Get the origin type (List, Dict, etc) and type arguments + origin = get_origin(expected_type) + args = get_args(expected_type) + + # Handle basic types (int, str, etc) + if origin is None: + return isinstance(value, expected_type) + + # Handle Sequence types (List, Tuple, etc) + if isinstance(origin, type) and issubclass(origin, Sequence): + if not isinstance(value, Sequence): + return False + # Empty sequence is valid + if not value: + return True + # Check each element against the sequence's type parameter + return all(self._output_matches_type(item, args[0]) for item in value) + + # Handle basic types (int, str, etc) + if origin is None: + return isinstance(value, expected_type) + + # Handle Mapping types (Dict, etc) + if isinstance(origin, type) and issubclass(origin, Mapping): + if not isinstance(value, Mapping): + return False + # Empty mapping is valid + if not value: + return True + key_type, value_type = args + # Check all keys and values match their respective types + return all( + self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type) + for k, v in value.items() + ) + + # Handle Union types (including Optional) + if origin is Union: + return any(self._output_matches_type(value, arg) for arg in args) + + return False diff --git a/releasenotes/notes/conditional-routes-validation-b46fc506d35894d4.yaml b/releasenotes/notes/conditional-routes-validation-b46fc506d35894d4.yaml new file mode 100644 index 000000000..315510b28 --- /dev/null +++ b/releasenotes/notes/conditional-routes-validation-b46fc506d35894d4.yaml @@ -0,0 +1,7 @@ +--- +enhancements: + - | + Add output type validation in `ConditionalRouter`. + Setting `validate_output_type` to `True` will enable a check to verify if + the actual output of a route returns the declared type. + If it doesn't match a `ValueError` is raised. diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 461c58f96..8ea3f86d9 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -18,22 +18,6 @@ def custom_filter_to_sede(value): class TestRouter: - @pytest.fixture - def routes(self): - return [ - {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, - { - "condition": "{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - - @pytest.fixture - def router(self, routes): - return ConditionalRouter(routes) - def test_missing_mandatory_fields(self): """ Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys @@ -90,7 +74,16 @@ class TestRouter: with pytest.raises(ValueError): ConditionalRouter(routes) - def test_router_initialized(self, routes): + def test_router_initialized(self): + routes = [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] router = ConditionalRouter(routes) assert router.routes == routes @@ -166,7 +159,7 @@ class TestRouter: result = router.run(messages=[message], streams=[1, 2, 3], query="my query") assert result == {"streams": [1, 2, 3]} - def test_router_no_route(self, router): + def test_router_no_route(self): # should raise an exception router = ConditionalRouter( [ @@ -358,3 +351,49 @@ class TestRouter: message = ChatMessage.from_user(content="This is a message") res = router.run(streams=streams, message=message) assert res == {"message": message} + + def test_validate_output_type_without_unsafe(self): + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{message}}", + "output_type": ChatMessage, + "output_name": "message", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes, validate_output_type=True) + streams = [1] + message = ChatMessage.from_user(content="This is a message") + with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"): + router.run(streams=streams, message=message) + + def test_validate_output_type_with_unsafe(self): + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{message}}", + "output_type": ChatMessage, + "output_name": "message", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes, unsafe=True, validate_output_type=True) + streams = [1] + message = ChatMessage.from_user(content="This is a message") + res = router.run(streams=streams, message=message) + assert isinstance(res["message"], ChatMessage) + + streams = ["1", "2", "3", "4"] + with pytest.raises(ValueError, match="Route 'streams' type doesn't match expected type"): + router.run(streams=streams, message=message)