mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 15:38:36 +00:00
feat: Add unsafe init arg in ConditionalRouter and OutputAdapter to enable previous behaviour (#8176)
* Add unsafe behaviour to OutputAdapter * Add unsafe behaviour to ConditionalRouter * Add release notes * Fix mypy * Add documentation links --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent
e614fa0c62
commit
3e3f79b928
@ -5,9 +5,11 @@
|
||||
import ast
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, Optional, Set
|
||||
from warnings import warn
|
||||
|
||||
import jinja2.runtime
|
||||
from jinja2 import TemplateSyntaxError, meta
|
||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
@ -37,7 +39,13 @@ class OutputAdapter:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, template: str, output_type: TypeAlias, custom_filters: Optional[Dict[str, Callable]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
output_type: TypeAlias,
|
||||
custom_filters: Optional[Dict[str, Callable]] = None,
|
||||
unsafe: bool = False,
|
||||
):
|
||||
"""
|
||||
Create an OutputAdapter component.
|
||||
|
||||
@ -54,13 +62,25 @@ class OutputAdapter:
|
||||
The type of output this instance will return.
|
||||
:param custom_filters:
|
||||
A dictionary of custom Jinja filters used in the template.
|
||||
:param unsafe:
|
||||
Enable execution of arbitrary code in the Jinja template.
|
||||
This should only be used if you trust the source of the template, as it can lead to remote code execution.
|
||||
"""
|
||||
self.custom_filters = {**(custom_filters or {})}
|
||||
input_types: Set[str] = set()
|
||||
|
||||
# Create a Jinja native environment, we need it to:
|
||||
# a) add custom filters to the environment for filter compilation stage
|
||||
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
|
||||
self._unsafe = unsafe
|
||||
|
||||
if self._unsafe:
|
||||
msg = (
|
||||
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
|
||||
"Use this only if you trust the source of the template."
|
||||
)
|
||||
warn(msg)
|
||||
self._env = (
|
||||
NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
|
||||
)
|
||||
|
||||
try:
|
||||
self._env.parse(template) # Validate template syntax
|
||||
self.template = template
|
||||
@ -108,7 +128,8 @@ class OutputAdapter:
|
||||
# This must be done because the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
output_result = ast.literal_eval(output_result)
|
||||
if not self._unsafe:
|
||||
output_result = ast.literal_eval(output_result)
|
||||
|
||||
adapted_outputs["output"] = output_result
|
||||
except Exception as e:
|
||||
@ -124,7 +145,11 @@ class OutputAdapter:
|
||||
"""
|
||||
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
|
||||
return default_to_dict(
|
||||
self, template=self.template, output_type=serialize_type(self.output_type), custom_filters=se_filters
|
||||
self,
|
||||
template=self.template,
|
||||
output_type=serialize_type(self.output_type),
|
||||
custom_filters=se_filters,
|
||||
unsafe=self._unsafe,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -148,11 +173,11 @@ class OutputAdapter:
|
||||
}
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]:
|
||||
def _extract_variables(self, env: Environment) -> Set[str]:
|
||||
"""
|
||||
Extracts all variables from a list of Jinja template strings.
|
||||
|
||||
:param env: A Jinja native environment.
|
||||
:param env: A Jinja environment.
|
||||
:return: A set of variable names extracted from the template strings.
|
||||
"""
|
||||
ast = env.parse(self.template)
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import ast
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from warnings import warn
|
||||
|
||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
@ -106,7 +107,7 @@ class ConditionalRouter:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
|
||||
def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False):
|
||||
"""
|
||||
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
|
||||
|
||||
@ -123,15 +124,26 @@ class ConditionalRouter:
|
||||
- `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
|
||||
`{{ my_var|my_filter }}` can then be used inside a route condition expression:
|
||||
`"condition": "{{ my_var|my_filter == 'foo' }}"`.
|
||||
:param unsafe:
|
||||
Enable execution of arbitrary code in the Jinja template.
|
||||
This should only be used if you trust the source of the template, as it can lead to remote code execution.
|
||||
"""
|
||||
self._validate_routes(routes)
|
||||
self.routes: List[dict] = routes
|
||||
self.custom_filters = custom_filters or {}
|
||||
self._unsafe = unsafe
|
||||
|
||||
# Create a Jinja native environment to inspect variables in the condition templates
|
||||
self._env = SandboxedEnvironment()
|
||||
# Create a Jinja environment to inspect variables in the condition templates
|
||||
if self._unsafe:
|
||||
msg = (
|
||||
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
|
||||
"Use this only if you trust the source of the template."
|
||||
)
|
||||
warn(msg)
|
||||
|
||||
self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment()
|
||||
self._env.filters.update(self.custom_filters)
|
||||
|
||||
self._validate_routes(routes)
|
||||
# Inspect the routes to determine input and output types.
|
||||
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
||||
output_types: Dict[str, str] = {}
|
||||
@ -158,7 +170,7 @@ class ConditionalRouter:
|
||||
# output_type needs to be serialized to a string
|
||||
route["output_type"] = serialize_type(route["output_type"])
|
||||
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
|
||||
return default_to_dict(self, routes=self.routes, custom_filters=se_filters)
|
||||
return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
||||
@ -202,7 +214,9 @@ class ConditionalRouter:
|
||||
try:
|
||||
t = self._env.from_string(route["condition"])
|
||||
rendered = t.render(**kwargs)
|
||||
if ast.literal_eval(rendered):
|
||||
if not self._unsafe:
|
||||
rendered = ast.literal_eval(rendered)
|
||||
if rendered:
|
||||
# We now evaluate the `output` expression to determine the route output
|
||||
t_output = self._env.from_string(route["output"])
|
||||
output = t_output.render(**kwargs)
|
||||
@ -211,7 +225,8 @@ class ConditionalRouter:
|
||||
# This must be done because the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
output = ast.literal_eval(output)
|
||||
if not self._unsafe:
|
||||
output = ast.literal_eval(output)
|
||||
# and return the output as a dictionary under the output_name key
|
||||
return {route["output_name"]: output}
|
||||
except Exception as e:
|
||||
@ -225,7 +240,6 @@ class ConditionalRouter:
|
||||
|
||||
:param routes: A list of routes.
|
||||
"""
|
||||
env = NativeEnvironment()
|
||||
for route in routes:
|
||||
try:
|
||||
keys = set(route.keys())
|
||||
@ -239,10 +253,10 @@ class ConditionalRouter:
|
||||
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
|
||||
)
|
||||
for field in ["condition", "output"]:
|
||||
if not self._validate_template(env, route[field]):
|
||||
if not self._validate_template(self._env, route[field]):
|
||||
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
|
||||
|
||||
def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
|
||||
def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
|
||||
"""
|
||||
Extracts all variables from a list of Jinja template strings.
|
||||
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add `unsafe` argument to enable behaviour that could lead to remote code execution in `ConditionalRouter` and `OutputAdapter`.
|
||||
By default, unsafe behaviour is not enabled; the user must explicitly set `unsafe` to `True`.
|
||||
This means that user types like `ChatMessage`, `Document`, and `Answer` can be used as output types when `unsafe` is `True`.
|
||||
We recommend using `unsafe` behaviour only when the source of the Jinja templates is trusted.
|
||||
For more info, see the documentation for [`ConditionalRouter`](https://docs.haystack.deepset.ai/docs/conditionalrouter#unsafe-behaviour) and [`OutputAdapter`](https://docs.haystack.deepset.ai/docs/outputadapter#unsafe-behaviour).
|
||||
@ -8,6 +8,7 @@ import json
|
||||
import pytest
|
||||
|
||||
from haystack import Pipeline, component
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.components.converters import OutputAdapter
|
||||
from haystack.components.converters.output_adapter import OutputAdaptationException
|
||||
|
||||
@ -150,6 +151,7 @@ class TestOutputAdapter:
|
||||
"template": "{{ documents[0].content}}",
|
||||
"output_type": "str",
|
||||
"custom_filters": None,
|
||||
"unsafe": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
@ -157,6 +159,7 @@ class TestOutputAdapter:
|
||||
assert component.template == "{{ documents[0].content}}"
|
||||
assert component.output_type == str
|
||||
assert component.custom_filters == {}
|
||||
assert not component._unsafe
|
||||
|
||||
def test_output_adapter_in_pipeline(self):
|
||||
@component
|
||||
@ -179,3 +182,13 @@ class TestOutputAdapter:
|
||||
result = pipe.run(data={})
|
||||
assert result
|
||||
assert result["output_adapter"]["output"] == {"framework": "Haystack"}
|
||||
|
||||
def test_unsafe(self):
|
||||
adapter = OutputAdapter(template="{{ documents[0] }}", output_type=Document, unsafe=True)
|
||||
documents = [
|
||||
Document(content="Test document"),
|
||||
Document(content="Another test document"),
|
||||
Document(content="Yet another test document"),
|
||||
]
|
||||
res = adapter.run(documents=documents)
|
||||
assert res["output"] == documents[0]
|
||||
|
||||
@ -307,3 +307,24 @@ class TestRouter:
|
||||
assert deserialized_router.custom_filters == router.custom_filters
|
||||
assert deserialized_router.custom_filters["custom_filter_to_sede"]("123-456-789") == 123
|
||||
assert result == deserialized_router.run(**kwargs)
|
||||
|
||||
def test_unsafe(self):
|
||||
routes = [
|
||||
{
|
||||
"condition": "{{streams|length < 2}}",
|
||||
"output": "{{message}}",
|
||||
"output_type": ChatMessage,
|
||||
"output_name": "message",
|
||||
},
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": "{{streams}}",
|
||||
"output_type": List[int],
|
||||
"output_name": "streams",
|
||||
},
|
||||
]
|
||||
router = ConditionalRouter(routes, unsafe=True)
|
||||
streams = [1]
|
||||
message = ChatMessage.from_user(content="This is a message")
|
||||
res = router.run(streams=streams, message=message)
|
||||
assert res == {"message": message}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user