feat: Add unsafe init arg in ConditionalRouter and OutputAdapter to enable previous behaviour (#8176)

* Add unsafe behaviour to OutputAdapter

* Add unsafe behaviour to ConditionalRouter

* Add release notes

* Fix mypy

* Add documentation links

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
Silvano Cerza 2024-09-02 16:14:54 +02:00 committed by GitHub
parent e614fa0c62
commit 3e3f79b928
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 100 additions and 19 deletions

View File

@ -5,9 +5,11 @@
import ast
import contextlib
from typing import Any, Callable, Dict, Optional, Set
from warnings import warn
import jinja2.runtime
from jinja2 import TemplateSyntaxError, meta
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias
@ -37,7 +39,13 @@ class OutputAdapter:
```
"""
def __init__(self, template: str, output_type: TypeAlias, custom_filters: Optional[Dict[str, Callable]] = None):
def __init__(
self,
template: str,
output_type: TypeAlias,
custom_filters: Optional[Dict[str, Callable]] = None,
unsafe: bool = False,
):
"""
Create an OutputAdapter component.
@ -54,13 +62,25 @@ class OutputAdapter:
The type of output this instance will return.
:param custom_filters:
A dictionary of custom Jinja filters used in the template.
:param unsafe:
Enable execution of arbitrary code in the Jinja template.
This should only be used if you trust the source of the template, as it can lead to remote code execution.
"""
self.custom_filters = {**(custom_filters or {})}
input_types: Set[str] = set()
# Create a Jinja native environment, we need it to:
# a) add custom filters to the environment for filter compilation stage
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
self._unsafe = unsafe
if self._unsafe:
msg = (
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
"Use this only if you trust the source of the template."
)
warn(msg)
self._env = (
NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
)
try:
self._env.parse(template) # Validate template syntax
self.template = template
@ -108,7 +128,8 @@ class OutputAdapter:
# This must be done because the output could be any of several different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
output_result = ast.literal_eval(output_result)
if not self._unsafe:
output_result = ast.literal_eval(output_result)
adapted_outputs["output"] = output_result
except Exception as e:
@ -124,7 +145,11 @@ class OutputAdapter:
"""
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(
self, template=self.template, output_type=serialize_type(self.output_type), custom_filters=se_filters
self,
template=self.template,
output_type=serialize_type(self.output_type),
custom_filters=se_filters,
unsafe=self._unsafe,
)
@classmethod
@ -148,11 +173,11 @@ class OutputAdapter:
}
return default_from_dict(cls, data)
def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]:
def _extract_variables(self, env: Environment) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.
:param env: A Jinja native environment.
:param env: A Jinja environment.
:return: A set of variable names extracted from the template strings.
"""
ast = env.parse(self.template)

View File

@ -5,6 +5,7 @@
import ast
import contextlib
from typing import Any, Callable, Dict, List, Optional, Set
from warnings import warn
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
@ -106,7 +107,7 @@ class ConditionalRouter:
```
"""
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False):
"""
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
@ -123,15 +124,26 @@ class ConditionalRouter:
- `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
`{{ my_var|my_filter }}` can then be used inside a route condition expression:
`"condition": "{{ my_var|my_filter == 'foo' }}"`.
:param unsafe:
Enable execution of arbitrary code in the Jinja template.
This should only be used if you trust the source of the template, as it can lead to remote code execution.
"""
self._validate_routes(routes)
self.routes: List[dict] = routes
self.custom_filters = custom_filters or {}
self._unsafe = unsafe
# Create a Jinja native environment to inspect variables in the condition templates
self._env = SandboxedEnvironment()
# Create a Jinja environment to inspect variables in the condition templates
if self._unsafe:
msg = (
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
"Use this only if you trust the source of the template."
)
warn(msg)
self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment()
self._env.filters.update(self.custom_filters)
self._validate_routes(routes)
# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
output_types: Dict[str, str] = {}
@ -158,7 +170,7 @@ class ConditionalRouter:
# output_type needs to be serialized to a string
route["output_type"] = serialize_type(route["output_type"])
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(self, routes=self.routes, custom_filters=se_filters)
return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
@ -202,7 +214,9 @@ class ConditionalRouter:
try:
t = self._env.from_string(route["condition"])
rendered = t.render(**kwargs)
if ast.literal_eval(rendered):
if not self._unsafe:
rendered = ast.literal_eval(rendered)
if rendered:
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
@ -211,7 +225,8 @@ class ConditionalRouter:
# This must be done because the output could be any of several different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
output = ast.literal_eval(output)
if not self._unsafe:
output = ast.literal_eval(output)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
except Exception as e:
@ -225,7 +240,6 @@ class ConditionalRouter:
:param routes: A list of routes.
"""
env = NativeEnvironment()
for route in routes:
try:
keys = set(route.keys())
@ -239,10 +253,10 @@ class ConditionalRouter:
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
)
for field in ["condition", "output"]:
if not self._validate_template(env, route[field]):
if not self._validate_template(self._env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.

View File

@ -0,0 +1,8 @@
---
features:
- |
Add `unsafe` argument to enable behaviour that could lead to remote code execution in `ConditionalRouter` and `OutputAdapter`.
By default, unsafe behaviour is not enabled; the user must set it explicitly to `True`.
This means that user types like `ChatMessage`, `Document`, and `Answer` can be used as output types when `unsafe` is `True`.
We recommend using `unsafe` behaviour only when the source of the Jinja templates is trusted.
For more info, see the documentation for [`ConditionalRouter`](https://docs.haystack.deepset.ai/docs/conditionalrouter#unsafe-behaviour) and [`OutputAdapter`](https://docs.haystack.deepset.ai/docs/outputadapter#unsafe-behaviour).

View File

@ -8,6 +8,7 @@ import json
import pytest
from haystack import Pipeline, component
from haystack.dataclasses import Document
from haystack.components.converters import OutputAdapter
from haystack.components.converters.output_adapter import OutputAdaptationException
@ -150,6 +151,7 @@ class TestOutputAdapter:
"template": "{{ documents[0].content}}",
"output_type": "str",
"custom_filters": None,
"unsafe": False,
},
}
)
@ -157,6 +159,7 @@ class TestOutputAdapter:
assert component.template == "{{ documents[0].content}}"
assert component.output_type == str
assert component.custom_filters == {}
assert not component._unsafe
def test_output_adapter_in_pipeline(self):
@component
@ -179,3 +182,13 @@ class TestOutputAdapter:
result = pipe.run(data={})
assert result
assert result["output_adapter"]["output"] == {"framework": "Haystack"}
def test_unsafe(self):
adapter = OutputAdapter(template="{{ documents[0] }}", output_type=Document, unsafe=True)
documents = [
Document(content="Test document"),
Document(content="Another test document"),
Document(content="Yet another test document"),
]
res = adapter.run(documents=documents)
assert res["output"] == documents[0]

View File

@ -307,3 +307,24 @@ class TestRouter:
assert deserialized_router.custom_filters == router.custom_filters
assert deserialized_router.custom_filters["custom_filter_to_sede"]("123-456-789") == 123
assert result == deserialized_router.run(**kwargs)
def test_unsafe(self):
routes = [
{
"condition": "{{streams|length < 2}}",
"output": "{{message}}",
"output_type": ChatMessage,
"output_name": "message",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
router = ConditionalRouter(routes, unsafe=True)
streams = [1]
message = ChatMessage.from_user(content="This is a message")
res = router.run(streams=streams, message=message)
assert res == {"message": message}