feat: Add route output type validation in ConditionalRouter (#8500)

This commit is contained in:
Silvano Cerza 2024-10-29 18:06:54 +01:00 committed by GitHub
parent 33675b4caf
commit 8a35e792b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 38 deletions

View File

@ -4,7 +4,7 @@
import ast
import contextlib
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
from warnings import warn
from jinja2 import Environment, TemplateSyntaxError, meta
@ -107,7 +107,13 @@ class ConditionalRouter:
```
"""
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False):
def __init__(
self,
routes: List[Dict],
custom_filters: Optional[Dict[str, Callable]] = None,
unsafe: bool = False,
validate_output_type: bool = False,
):
"""
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
@ -127,10 +133,14 @@ class ConditionalRouter:
:param unsafe:
Enable execution of arbitrary code in the Jinja template.
This should only be used if you trust the source of the template, as it can lead to remote code execution.
:param validate_output_type:
Enable validation of routes' output.
If a route's output doesn't match its declared type, a ValueError is raised when the component runs.
"""
self.routes: List[dict] = routes
self.custom_filters = custom_filters or {}
self._unsafe = unsafe
self._validate_output_type = validate_output_type
# Create a Jinja environment to inspect variables in the condition templates
if self._unsafe:
@ -170,7 +180,13 @@ class ConditionalRouter:
# output_type needs to be serialized to a string
route["output_type"] = serialize_type(route["output_type"])
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe)
return default_to_dict(
self,
routes=self.routes,
custom_filters=se_filters,
unsafe=self._unsafe,
validate_output_type=self._validate_output_type,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
@ -210,9 +226,12 @@ class ConditionalRouter:
:returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
of the selected route.
:raises NoRouteSelectedException: If no `condition' in the routes is `True`.
:raises RouteConditionException: If there is an error parsing or evaluating the `condition` expression in the
routes.
:raises NoRouteSelectedException:
If no `condition` in the routes is `True`.
:raises RouteConditionException:
If there is an error parsing or evaluating the `condition` expression in the routes.
:raises ValueError:
If type validation is enabled and route type doesn't match actual value type.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
for route in self.routes:
@ -221,21 +240,28 @@ class ConditionalRouter:
rendered = t.render(**kwargs)
if not self._unsafe:
rendered = ast.literal_eval(rendered)
if rendered:
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
if not rendered:
continue
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)
except Exception as e:
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e
msg = f"Error evaluating condition for route '{route}': {e}"
raise RouteConditionException(msg) from e
if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
raise ValueError(msg)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
@ -288,3 +314,53 @@ class ConditionalRouter:
return True
except TemplateSyntaxError:
return False
def _output_matches_type(self, value: Any, expected_type: type) -> bool:  # noqa: PLR0911 # pylint: disable=too-many-return-statements
    """
    Checks whether `value` type matches the `expected_type`.

    Supported declarations are `Any`, plain classes, `Union`/`Optional` (including the
    `X | Y` syntax), and parameterized or bare tuple, sequence and mapping generics.
    Any other parameterized type is not supported and never matches.

    :param value:
        The object to validate.
    :param expected_type:
        The declared type to validate `value` against.
    :returns:
        `True` if `value` conforms to `expected_type`, `False` otherwise.
    """
    # Local import so this method does not depend on a module-level import being present.
    import types

    # Handle Any type
    if expected_type is Any:
        return True

    # Get the origin type (list, dict, etc) and type arguments
    origin = get_origin(expected_type)
    args = get_args(expected_type)

    # Handle basic types (int, str, etc)
    if origin is None:
        return isinstance(value, expected_type)

    # Handle Union types (including Optional). `X | Y` reports `types.UnionType` as its
    # origin on Python versions that support that syntax, so both spellings are accepted.
    if origin is Union or origin is getattr(types, "UnionType", None):
        return any(self._output_matches_type(value, arg) for arg in args)

    # Remaining special forms (Literal, etc) are not supported
    if not isinstance(origin, type):
        return False

    # Handle Tuple types: these are positional, unlike the homogeneous sequences below
    if issubclass(origin, tuple):
        if not isinstance(value, origin):
            return False
        # Bare `Tuple` carries no element constraints
        if not args:
            return True
        # `Tuple[()]` is reported as `((),)` by older Python versions: only the empty tuple matches
        if args == ((),):
            return not value
        # `Tuple[X, ...]` is a variable-length homogeneous tuple
        if len(args) == 2 and args[1] is Ellipsis:
            return all(self._output_matches_type(item, args[0]) for item in value)
        # Fixed-length tuple: length and every position must match
        return len(value) == len(args) and all(
            self._output_matches_type(item, item_type) for item, item_type in zip(value, args)
        )

    # Handle Sequence types (List, Sequence, etc)
    if issubclass(origin, Sequence):
        # The value must be an instance of the declared container, so a plain string
        # is not accepted where a list of strings was declared.
        if not isinstance(value, origin):
            return False
        # Bare `List` carries no element constraints
        if not args:
            return True
        # Check each element against the sequence's type parameter (an empty sequence is valid)
        return all(self._output_matches_type(item, args[0]) for item in value)

    # Handle Mapping types (Dict, etc)
    if issubclass(origin, Mapping):
        if not isinstance(value, origin):
            return False
        # Bare `Dict`, or a mapping alias without a key/value pair of parameters, is not checked further
        if len(args) != 2:
            return True
        key_type, value_type = args
        # Check all keys and values match their respective types (an empty mapping is valid)
        return all(
            self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type)
            for k, v in value.items()
        )

    return False

View File

@ -0,0 +1,7 @@
---
enhancements:
- |
Add output type validation in `ConditionalRouter`.
Setting `validate_output_type` to `True` will enable a check to verify if
the actual output of a route returns the declared type.
If it doesn't match a `ValueError` is raised.

View File

@ -18,22 +18,6 @@ def custom_filter_to_sede(value):
class TestRouter:
@pytest.fixture
def routes(self):
return [
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
@pytest.fixture
def router(self, routes):
return ConditionalRouter(routes)
def test_missing_mandatory_fields(self):
"""
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
@ -90,7 +74,16 @@ class TestRouter:
with pytest.raises(ValueError):
ConditionalRouter(routes)
def test_router_initialized(self, routes):
def test_router_initialized(self):
routes = [
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
router = ConditionalRouter(routes)
assert router.routes == routes
@ -166,7 +159,7 @@ class TestRouter:
result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
assert result == {"streams": [1, 2, 3]}
def test_router_no_route(self, router):
def test_router_no_route(self):
# should raise an exception
router = ConditionalRouter(
[
@ -358,3 +351,49 @@ class TestRouter:
message = ChatMessage.from_user(content="This is a message")
res = router.run(streams=streams, message=message)
assert res == {"message": message}
def test_validate_output_type_without_unsafe(self):
    """
    With validation enabled and `unsafe` left at its default, selecting the `message` route
    (declared as `ChatMessage`) is expected to raise a `ValueError` about the type mismatch.
    """
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    router = ConditionalRouter([message_route, streams_route], validate_output_type=True)
    user_message = ChatMessage.from_user(content="This is a message")

    # A single stream selects the `message` route
    with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"):
        router.run(streams=[1], message=user_message)
def test_validate_output_type_with_unsafe(self):
    """
    With both `unsafe` and validation enabled, the `message` route is expected to return a
    `ChatMessage`, while string items on the `streams` route (declared as `List[int]`)
    are expected to raise a `ValueError`.
    """
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    router = ConditionalRouter([message_route, streams_route], unsafe=True, validate_output_type=True)
    user_message = ChatMessage.from_user(content="This is a message")

    # A single stream selects the `message` route
    result = router.run(streams=[1], message=user_message)
    assert isinstance(result["message"], ChatMessage)

    # Two or more streams select the `streams` route; its items are strings, not ints
    with pytest.raises(ValueError, match="Route 'streams' type doesn't match expected type"):
        router.run(streams=["1", "2", "3", "4"], message=user_message)