From 3e3f79b9285c5b56432aac3e4ef2309e5f31ea74 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:14:54 +0200 Subject: [PATCH] feat: Add `unsafe` init arg in `ConditionalRouter` and `OutputAdapter` to enable previous behaviour (#8176) * Add unsafe behaviour to OutputAdapter * Add unsafe behaviour to ConditionalRouter * Add release notes * Fix mypy * Add documentation links --------- Co-authored-by: Madeesh Kannan --- .../components/converters/output_adapter.py | 43 +++++++++++++++---- .../components/routers/conditional_router.py | 34 ++++++++++----- .../unsafe-behaviour-e8b41d957113e0c3.yaml | 8 ++++ .../converters/test_output_adapter.py | 13 ++++++ .../routers/test_conditional_router.py | 21 +++++++++ 5 files changed, 100 insertions(+), 19 deletions(-) create mode 100644 releasenotes/notes/unsafe-behaviour-e8b41d957113e0c3.yaml diff --git a/haystack/components/converters/output_adapter.py b/haystack/components/converters/output_adapter.py index 64cac7996..50ffd5391 100644 --- a/haystack/components/converters/output_adapter.py +++ b/haystack/components/converters/output_adapter.py @@ -5,9 +5,11 @@ import ast import contextlib from typing import Any, Callable, Dict, Optional, Set +from warnings import warn import jinja2.runtime -from jinja2 import TemplateSyntaxError, meta +from jinja2 import Environment, TemplateSyntaxError, meta +from jinja2.nativetypes import NativeEnvironment from jinja2.sandbox import SandboxedEnvironment from typing_extensions import TypeAlias @@ -37,7 +39,13 @@ class OutputAdapter: ``` """ - def __init__(self, template: str, output_type: TypeAlias, custom_filters: Optional[Dict[str, Callable]] = None): + def __init__( + self, + template: str, + output_type: TypeAlias, + custom_filters: Optional[Dict[str, Callable]] = None, + unsafe: bool = False, + ): """ Create an OutputAdapter component. @@ -54,13 +62,25 @@ class OutputAdapter: The type of output this instance will return. :param custom_filters: A dictionary of custom Jinja filters used in the template. + :param unsafe: + Enable execution of arbitrary code in the Jinja template. + This should only be used if you trust the source of the template as it can be lead to remote code execution. """ self.custom_filters = {**(custom_filters or {})} input_types: Set[str] = set() - # Create a Jinja native environment, we need it to: - # a) add custom filters to the environment for filter compilation stage - self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) + self._unsafe = unsafe + + if self._unsafe: + msg = ( + "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. " + "Use this only if you trust the source of the template." + ) + warn(msg) + self._env = ( + NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) + ) + try: self._env.parse(template) # Validate template syntax self.template = template @@ -108,7 +128,8 @@ class OutputAdapter: # This must be done cause the output could be different literal structures. # This doesn't support any user types. with contextlib.suppress(Exception): - output_result = ast.literal_eval(output_result) + if not self._unsafe: + output_result = ast.literal_eval(output_result) adapted_outputs["output"] = output_result except Exception as e: @@ -124,7 +145,11 @@ class OutputAdapter: """ se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()} return default_to_dict( - self, template=self.template, output_type=serialize_type(self.output_type), custom_filters=se_filters + self, + template=self.template, + output_type=serialize_type(self.output_type), + custom_filters=se_filters, + unsafe=self._unsafe, ) @classmethod @@ -148,11 +173,11 @@ class OutputAdapter: } return default_from_dict(cls, data) - def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]: + def _extract_variables(self, env: Environment) -> Set[str]: """ Extracts all variables from a list of Jinja template strings. - :param env: A Jinja native environment. + :param env: A Jinja environment. :return: A set of variable names extracted from the template strings. """ ast = env.parse(self.template) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index a72bb4fc0..ccac555d3 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -5,6 +5,7 @@ import ast import contextlib from typing import Any, Callable, Dict, List, Optional, Set +from warnings import warn from jinja2 import Environment, TemplateSyntaxError, meta from jinja2.nativetypes import NativeEnvironment @@ -106,7 +107,7 @@ class ConditionalRouter: ``` """ - def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None): + def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False): """ Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing. @@ -123,15 +124,26 @@ class ConditionalRouter: - `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`. `{{ my_var|my_filter }}` can then be used inside a route condition expression: `"condition": "{{ my_var|my_filter == 'foo' }}"`. + :param unsafe: + Enable execution of arbitrary code in the Jinja template. + This should only be used if you trust the source of the template as it can be lead to remote code execution. """ - self._validate_routes(routes) self.routes: List[dict] = routes self.custom_filters = custom_filters or {} + self._unsafe = unsafe - # Create a Jinja native environment to inspect variables in the condition templates - self._env = SandboxedEnvironment() + # Create a Jinja environment to inspect variables in the condition templates + if self._unsafe: + msg = ( + "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. " + "Use this only if you trust the source of the template." + ) + warn(msg) + + self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment() self._env.filters.update(self.custom_filters) + self._validate_routes(routes) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any output_types: Dict[str, str] = {} @@ -158,7 +170,7 @@ class ConditionalRouter: # output_type needs to be serialized to a string route["output_type"] = serialize_type(route["output_type"]) se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()} - return default_to_dict(self, routes=self.routes, custom_filters=se_filters) + return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter": @@ -202,7 +214,9 @@ class ConditionalRouter: try: t = self._env.from_string(route["condition"]) rendered = t.render(**kwargs) - if ast.literal_eval(rendered): + if not self._unsafe: + rendered = ast.literal_eval(rendered) + if rendered: # We now evaluate the `output` expression to determine the route output t_output = self._env.from_string(route["output"]) output = t_output.render(**kwargs) @@ -211,7 +225,8 @@ class ConditionalRouter: # This must be done cause the output could be different literal structures. # This doesn't support any user types. with contextlib.suppress(Exception): - output = ast.literal_eval(output) + if not self._unsafe: + output = ast.literal_eval(output) # and return the output as a dictionary under the output_name key return {route["output_name"]: output} except Exception as e: @@ -225,7 +240,6 @@ class ConditionalRouter: :param routes: A list of routes. """ - env = NativeEnvironment() for route in routes: try: keys = set(route.keys()) @@ -239,10 +253,10 @@ class ConditionalRouter: f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}" ) for field in ["condition", "output"]: - if not self._validate_template(env, route[field]): + if not self._validate_template(self._env, route[field]): raise ValueError(f"Invalid template for field '{field}': {route[field]}") - def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]: + def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]: """ Extracts all variables from a list of Jinja template strings. diff --git a/releasenotes/notes/unsafe-behaviour-e8b41d957113e0c3.yaml b/releasenotes/notes/unsafe-behaviour-e8b41d957113e0c3.yaml new file mode 100644 index 000000000..d6242432e --- /dev/null +++ b/releasenotes/notes/unsafe-behaviour-e8b41d957113e0c3.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Add `unsafe` argument to enable behaviour that could lead to remote code execution in `ConditionalRouter` and `OutputAdapter`. + By default unsafe behaviour is not enabled, the user must set it explicitly to `True`. + This means that user types like `ChatMessage`, `Document`, and `Answer` can be used as output types when `unsafe` is `True`. + We recommend using `unsafe` behaviour only when the Jinja templates source is trusted. + For more info see the documentation for [`ConditionalRouter`](https://docs.haystack.deepset.ai/docs/conditionalrouter#unsafe-behaviour) and [`OutputAdapter`](https://docs.haystack.deepset.ai/docs/outputadapter#unsafe-behaviour) diff --git a/test/components/converters/test_output_adapter.py b/test/components/converters/test_output_adapter.py index 25b5eb7e0..547ce433e 100644 --- a/test/components/converters/test_output_adapter.py +++ b/test/components/converters/test_output_adapter.py @@ -8,6 +8,7 @@ import json import pytest from haystack import Pipeline, component +from haystack.dataclasses import Document from haystack.components.converters import OutputAdapter from haystack.components.converters.output_adapter import OutputAdaptationException @@ -150,6 +151,7 @@ class TestOutputAdapter: "template": "{{ documents[0].content}}", "output_type": "str", "custom_filters": None, + "unsafe": False, }, } ) @@ -157,6 +159,7 @@ class TestOutputAdapter: assert component.template == "{{ documents[0].content}}" assert component.output_type == str assert component.custom_filters == {} + assert not component._unsafe def test_output_adapter_in_pipeline(self): @component @@ -179,3 +182,13 @@ class TestOutputAdapter: result = pipe.run(data={}) assert result assert result["output_adapter"]["output"] == {"framework": "Haystack"} + + def test_unsafe(self): + adapter = OutputAdapter(template="{{ documents[0] }}", output_type=Document, unsafe=True) + documents = [ + Document(content="Test document"), + Document(content="Another test document"), + Document(content="Yet another test document"), + ] + res = adapter.run(documents=documents) + assert res["output"] == documents[0] diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index b324cf806..47873e3d3 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -307,3 +307,24 @@ class TestRouter: assert deserialized_router.custom_filters == router.custom_filters assert deserialized_router.custom_filters["custom_filter_to_sede"]("123-456-789") == 123 assert result == deserialized_router.run(**kwargs) + + def test_unsafe(self): + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{message}}", + "output_type": ChatMessage, + "output_name": "message", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes, unsafe=True) + streams = [1] + message = ChatMessage.from_user(content="This is a message") + res = router.run(streams=streams, message=message) + assert res == {"message": message}