mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 14:08:27 +00:00
feat: Add route output type validation in ConditionalRouter (#8500)
This commit is contained in:
parent
33675b4caf
commit
8a35e792b9
@ -4,7 +4,7 @@
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
|
||||
from warnings import warn
|
||||
|
||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||
@ -107,7 +107,13 @@ class ConditionalRouter:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
routes: List[Dict],
|
||||
custom_filters: Optional[Dict[str, Callable]] = None,
|
||||
unsafe: bool = False,
|
||||
validate_output_type: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
|
||||
|
||||
@ -127,10 +133,14 @@ class ConditionalRouter:
|
||||
:param unsafe:
|
||||
Enable execution of arbitrary code in the Jinja template.
|
||||
This should only be used if you trust the source of the template, as it can lead to remote code execution.
|
||||
:param validate_output_type:
|
||||
Enable validation of routes' output.
|
||||
If a route output doesn't match the declared type, a ValueError is raised at run time.
|
||||
"""
|
||||
self.routes: List[dict] = routes
|
||||
self.custom_filters = custom_filters or {}
|
||||
self._unsafe = unsafe
|
||||
self._validate_output_type = validate_output_type
|
||||
|
||||
# Create a Jinja environment to inspect variables in the condition templates
|
||||
if self._unsafe:
|
||||
@ -170,7 +180,13 @@ class ConditionalRouter:
|
||||
# output_type needs to be serialized to a string
|
||||
route["output_type"] = serialize_type(route["output_type"])
|
||||
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
|
||||
return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe)
|
||||
return default_to_dict(
|
||||
self,
|
||||
routes=self.routes,
|
||||
custom_filters=se_filters,
|
||||
unsafe=self._unsafe,
|
||||
validate_output_type=self._validate_output_type,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
||||
@ -210,9 +226,12 @@ class ConditionalRouter:
|
||||
:returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
|
||||
of the selected route.
|
||||
|
||||
:raises NoRouteSelectedException: If no `condition' in the routes is `True`.
|
||||
:raises RouteConditionException: If there is an error parsing or evaluating the `condition` expression in the
|
||||
routes.
|
||||
:raises NoRouteSelectedException:
|
||||
If no `condition` in the routes is `True`.
|
||||
:raises RouteConditionException:
|
||||
If there is an error parsing or evaluating the `condition` expression in the routes.
|
||||
:raises ValueError:
|
||||
If type validation is enabled and route type doesn't match actual value type.
|
||||
"""
|
||||
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
||||
for route in self.routes:
|
||||
@ -221,21 +240,28 @@ class ConditionalRouter:
|
||||
rendered = t.render(**kwargs)
|
||||
if not self._unsafe:
|
||||
rendered = ast.literal_eval(rendered)
|
||||
if rendered:
|
||||
# We now evaluate the `output` expression to determine the route output
|
||||
t_output = self._env.from_string(route["output"])
|
||||
output = t_output.render(**kwargs)
|
||||
# We suppress the exception in case the output is already a string, otherwise
|
||||
# we try to evaluate it and would fail.
|
||||
# This must be done cause the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
if not self._unsafe:
|
||||
output = ast.literal_eval(output)
|
||||
# and return the output as a dictionary under the output_name key
|
||||
return {route["output_name"]: output}
|
||||
if not rendered:
|
||||
continue
|
||||
# We now evaluate the `output` expression to determine the route output
|
||||
t_output = self._env.from_string(route["output"])
|
||||
output = t_output.render(**kwargs)
|
||||
# We suppress the exception in case the output is already a string, otherwise
|
||||
# we try to evaluate it and would fail.
|
||||
# This must be done cause the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
if not self._unsafe:
|
||||
output = ast.literal_eval(output)
|
||||
except Exception as e:
|
||||
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e
|
||||
msg = f"Error evaluating condition for route '{route}': {e}"
|
||||
raise RouteConditionException(msg) from e
|
||||
|
||||
if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
|
||||
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
|
||||
raise ValueError(msg)
|
||||
|
||||
# and return the output as a dictionary under the output_name key
|
||||
return {route["output_name"]: output}
|
||||
|
||||
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
|
||||
|
||||
@ -288,3 +314,53 @@ class ConditionalRouter:
|
||||
return True
|
||||
except TemplateSyntaxError:
|
||||
return False
|
||||
|
||||
def _output_matches_type(self, value: Any, expected_type: type) -> bool:  # noqa: PLR0911 # pylint: disable=too-many-return-statements
    """
    Checks whether `value` type matches the `expected_type`.

    :param value:
        The route output to check.
    :param expected_type:
        The declared type: `Any`, a plain class, or a `typing` generic such as `List[int]`,
        `Tuple[int, str]`, `Tuple[int, ...]`, `Dict[str, int]`, `Optional[str]` or `Union[int, str]`.
        Bare generics (`List`, `Dict`) only have their container kind checked.
    :returns:
        `True` if `value` (and, recursively, its elements) matches `expected_type`.
        `False` otherwise, including for generic types this check does not know how to handle.
    """
    # Local import: only needed to recognize `X | Y` unions, which exist on Python 3.10+.
    import types  # pylint: disable=import-outside-toplevel

    # Handle Any type
    if expected_type is Any:
        return True

    # Get the origin type (List, Dict, etc) and type arguments
    origin = get_origin(expected_type)
    args = get_args(expected_type)

    # Handle basic types (int, str, etc)
    if origin is None:
        return isinstance(value, expected_type)

    # Handle Union types (including Optional and PEP 604 `X | Y` unions)
    if origin is Union or origin is getattr(types, "UnionType", None):
        return any(self._output_matches_type(value, arg) for arg in args)

    # Anything else we can validate must be a concrete container class
    if not isinstance(origin, type):
        return False

    # Handle Sequence types (List, Tuple, etc)
    if issubclass(origin, Sequence):
        if not isinstance(value, Sequence):
            return False
        # str/bytes are Sequences too: without this guard a plain string would pass as `List[str]`.
        # They are still accepted when the declared origin itself admits them (e.g. `Sequence[str]`).
        if isinstance(value, (str, bytes, bytearray)) and not isinstance(value, origin):
            return False
        # Empty sequence is valid; a bare generic (e.g. `List`) has no element type to check
        if not value or not args:
            return True
        # Fixed-shape tuple (e.g. `Tuple[int, str]`): one declared type per position
        if issubclass(origin, tuple) and not (len(args) == 2 and args[1] is Ellipsis):
            return len(value) == len(args) and all(
                self._output_matches_type(item, arg) for item, arg in zip(value, args)
            )
        # Homogeneous sequence: check each element against the sequence's type parameter
        return all(self._output_matches_type(item, args[0]) for item in value)

    # Handle Mapping types (Dict, etc)
    if issubclass(origin, Mapping):
        if not isinstance(value, Mapping):
            return False
        # Empty mapping is valid; without exactly (key, value) parameters there is nothing more to check
        if not value or len(args) != 2:
            return True
        key_type, value_type = args
        # Check all keys and values match their respective types
        return all(
            self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type)
            for k, v in value.items()
        )

    return False
|
||||
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add output type validation in `ConditionalRouter`.
|
||||
Setting `validate_output_type` to `True` will enable a check to verify if
|
||||
the actual output of a route returns the declared type.
|
||||
If it doesn't match, a `ValueError` is raised.
|
||||
@ -18,22 +18,6 @@ def custom_filter_to_sede(value):
|
||||
|
||||
|
||||
class TestRouter:
|
||||
@pytest.fixture
|
||||
def routes(self):
|
||||
return [
|
||||
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": "{{streams}}",
|
||||
"output_type": List[int],
|
||||
"output_name": "streams",
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def router(self, routes):
|
||||
return ConditionalRouter(routes)
|
||||
|
||||
def test_missing_mandatory_fields(self):
|
||||
"""
|
||||
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
|
||||
@ -90,7 +74,16 @@ class TestRouter:
|
||||
with pytest.raises(ValueError):
|
||||
ConditionalRouter(routes)
|
||||
|
||||
def test_router_initialized(self, routes):
|
||||
def test_router_initialized(self):
|
||||
routes = [
|
||||
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": "{{streams}}",
|
||||
"output_type": List[int],
|
||||
"output_name": "streams",
|
||||
},
|
||||
]
|
||||
router = ConditionalRouter(routes)
|
||||
|
||||
assert router.routes == routes
|
||||
@ -166,7 +159,7 @@ class TestRouter:
|
||||
result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
|
||||
assert result == {"streams": [1, 2, 3]}
|
||||
|
||||
def test_router_no_route(self, router):
|
||||
def test_router_no_route(self):
|
||||
# should raise an exception
|
||||
router = ConditionalRouter(
|
||||
[
|
||||
@ -358,3 +351,49 @@ class TestRouter:
|
||||
message = ChatMessage.from_user(content="This is a message")
|
||||
res = router.run(streams=streams, message=message)
|
||||
assert res == {"message": message}
|
||||
|
||||
def test_validate_output_type_without_unsafe(self):
    """
    With `validate_output_type=True` and the default safe mode, routing a `ChatMessage`
    through the `message` route is expected to raise a `ValueError` about the type mismatch.
    """
    # NOTE(review): presumably the safe-mode output is a rendered string rather than the
    # original ChatMessage object, which is why validation fails here - confirm against `run`.
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    user_message = ChatMessage.from_user(content="This is a message")
    validating_router = ConditionalRouter([message_route, streams_route], validate_output_type=True)

    # A single stream selects the `message` route
    with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"):
        validating_router.run(streams=[1], message=user_message)
|
||||
|
||||
def test_validate_output_type_with_unsafe(self):
    """
    With `unsafe=True` and `validate_output_type=True`, the `message` route returns a
    `ChatMessage` that passes validation, while a list of strings routed through the
    `streams` route (declared `List[int]`) is expected to raise a `ValueError`.
    """
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    unsafe_router = ConditionalRouter(
        [message_route, streams_route],
        unsafe=True,
        validate_output_type=True,
    )
    user_message = ChatMessage.from_user(content="This is a message")

    # A single stream selects the `message` route, whose output must be a ChatMessage
    outcome = unsafe_router.run(streams=[1], message=user_message)
    assert isinstance(outcome["message"], ChatMessage)

    # Two or more streams select the `streams` route; strings do not satisfy List[int]
    string_streams = ["1", "2", "3", "4"]
    with pytest.raises(ValueError, match="Route 'streams' type doesn't match expected type"):
        unsafe_router.run(streams=string_streams, message=user_message)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user