fix: Fix issue that could lead to RCE if using unsecure Jinja templates (#8095)

* Fix issue that could lead to RCE if using unsecure Jinja templates

* Add comment explaining exception suppression

* Update release note

* Update release note
This commit is contained in:
Silvano Cerza 2024-07-26 16:02:09 +02:00 committed by GitHub
parent 47f4db8698
commit 3fed1366c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 83 additions and 80 deletions

View File

@ -4,7 +4,8 @@
from typing import Any, Dict, List, Optional, Set
from jinja2 import Template, meta
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
@ -123,12 +124,12 @@ class ChatPromptBuilder:
self.required_variables = required_variables or []
self.template = template
variables = variables or []
self._env = SandboxedEnvironment()
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infer variables from template
msg_template = Template(message.content)
ast = msg_template.environment.parse(message.content)
ast = self._env.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
@ -194,7 +195,8 @@ class ChatPromptBuilder:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))
compiled_template = Template(message.content)
compiled_template = self._env.from_string(message.content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = (
ChatMessage.from_user(rendered_content)

View File

@ -4,7 +4,8 @@
from typing import Any, Dict, List, Optional, Set
from jinja2 import Template, meta
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from haystack import component, default_to_dict
@ -158,10 +159,12 @@ class PromptBuilder:
self._variables = variables
self._required_variables = required_variables
self.required_variables = required_variables or []
self.template = Template(template)
self._env = SandboxedEnvironment()
self.template = self._env.from_string(template)
if not variables:
# infer variables from template
ast = self.template.environment.parse(template)
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
@ -216,8 +219,8 @@ class PromptBuilder:
self._validate_variables(set(template_variables_combined.keys()))
compiled_template = self.template
if isinstance(template, str):
compiled_template = Template(template)
if template is not None:
compiled_template = self._env.from_string(template)
result = compiled_template.render(template_variables_combined)
return {"prompt": result}

View File

@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0
import ast
import contextlib
from typing import Any, Callable, Dict, Optional, Set
import jinja2.runtime
from jinja2 import TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias
from haystack import component, default_from_dict, default_to_dict
@ -58,18 +60,18 @@ class OutputAdapter:
# Create a Jinja native environment, we need it to:
# a) add custom filters to the environment for filter compilation stage
env = NativeEnvironment()
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
try:
env.parse(template) # Validate template syntax
self._env.parse(template) # Validate template syntax
self.template = template
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja template '{template}': {e}") from e
for name, filter_func in self.custom_filters.items():
env.filters[name] = filter_func
self._env.filters[name] = filter_func
# b) extract variables in the template
route_input_names = self._extract_variables(env)
route_input_names = self._extract_variables(self._env)
input_types.update(route_input_names)
# the env is not needed, discarded automatically
@ -92,16 +94,22 @@ class OutputAdapter:
# check if kwargs are empty
if not kwargs:
raise ValueError("No input data provided for output adaptation")
env = NativeEnvironment()
for name, filter_func in self.custom_filters.items():
env.filters[name] = filter_func
self._env.filters[name] = filter_func
adapted_outputs = {}
try:
adapted_output_template = env.from_string(self.template)
adapted_output_template = self._env.from_string(self.template)
output_result = adapted_output_template.render(**kwargs)
if isinstance(output_result, jinja2.runtime.Undefined):
raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}")
# We suppress the exception in case the output is already a string, otherwise
# evaluating it as a literal would fail.
# This must be done because the output could be any of several literal structures.
# User-defined types are not supported.
with contextlib.suppress(Exception):
output_result = ast.literal_eval(output_result)
adapted_outputs["output"] = output_result
except Exception as e:
raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e
@ -135,14 +143,12 @@ class OutputAdapter:
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)
def _extract_variables(self, env: NativeEnvironment) -> Set[str]:
def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.
:param env: A Jinja native environment.
:return: A set of variable names extracted from the template strings.
"""
variables = set()
ast = env.parse(self.template)
variables.update(meta.find_undeclared_variables(ast))
return variables
return meta.find_undeclared_variables(ast)

View File

@ -2,10 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0
import ast
import contextlib
from typing import Any, Callable, Dict, List, Optional, Set
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
@ -125,8 +128,8 @@ class ConditionalRouter:
self.custom_filters = custom_filters or {}
# Create a Jinja native environment to inspect variables in the condition templates
env = NativeEnvironment()
env.filters.update(self.custom_filters)
self._env = SandboxedEnvironment()
self._env.filters.update(self.custom_filters)
# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
@ -134,7 +137,7 @@ class ConditionalRouter:
for route in routes:
# extract inputs
route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
input_types.update(route_input_names)
# extract outputs
@ -194,16 +197,20 @@ class ConditionalRouter:
routes.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
env = NativeEnvironment()
env.filters.update(self.custom_filters)
for route in self.routes:
try:
t = env.from_string(route["condition"])
if t.render(**kwargs):
t = self._env.from_string(route["condition"])
rendered = t.render(**kwargs)
if ast.literal_eval(rendered):
# We now evaluate the `output` expression to determine the route output
t_output = env.from_string(route["output"])
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# evaluating it as a literal would fail.
# This must be done because the output could be any of several literal structures.
# User-defined types are not supported.
with contextlib.suppress(Exception):
output = ast.literal_eval(output)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
except Exception as e:
@ -234,7 +241,7 @@ class ConditionalRouter:
if not self._validate_template(env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.

View File

@ -6,7 +6,8 @@ from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Union
from jinja2 import Environment, PackageLoader, TemplateSyntaxError, meta
from jinja2 import PackageLoader, TemplateSyntaxError, meta
from jinja2.sandbox import SandboxedEnvironment
TEMPLATE_FILE_EXTENSION = ".yaml.jinja2"
TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined"
@ -74,7 +75,7 @@ class PipelineTemplate:
:param template_content: The raw template source to use in the template.
"""
env = Environment(
env = SandboxedEnvironment(
loader=PackageLoader("haystack.core.pipeline", "predefined"), trim_blocks=True, lstrip_blocks=True
)
try:

View File

@ -0,0 +1,14 @@
---
upgrade:
- |
`OutputAdapter` and `ConditionalRouter` can't return user inputs of custom types anymore; their outputs are limited to strings and Python literal structures.
security:
- |
Fix issue that could lead to remote code execution when using insecure Jinja templates in the following components:
- `PromptBuilder`
- `ChatPromptBuilder`
- `OutputAdapter`
- `ConditionalRouter`
The same issue has been fixed in the `PipelineTemplate` class too.

View File

@ -97,7 +97,23 @@ class TestRouter:
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
def test_router_evaluate_condition_expressions(self, router):
def test_router_evaluate_condition_expressions(self):
router = ConditionalRouter(
[
{
"condition": "{{streams|length < 2}}",
"output": "{{query}}",
"output_type": str,
"output_name": "query",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
)
# first route should be selected
kwargs = {"streams": [1, 2, 3], "query": "test"}
result = router.run(**kwargs)
@ -227,52 +243,6 @@ class TestRouter:
# check that the result is the same and correct
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}
def test_router_de_serialization_user_type(self):
routes = [
{
"condition": "{{streams|length < 2}}",
"output": "{{message}}",
"output_type": ChatMessage,
"output_name": "message",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
router = ConditionalRouter(routes)
router_dict = router.to_dict()
# assert that the router dict is correct, with all keys and values being strings
for route in router_dict["init_parameters"]["routes"]:
for key in route.keys():
assert isinstance(key, str)
assert isinstance(route[key], str)
# check that the output_type is a string and a proper class name
assert (
router_dict["init_parameters"]["routes"][0]["output_type"]
== "haystack.dataclasses.chat_message.ChatMessage"
)
# deserialize the router
new_router = ConditionalRouter.from_dict(router_dict)
# check that the output_type is the right class
assert new_router.routes[0]["output_type"] == ChatMessage
assert router.routes == new_router.routes
# now use both routers to run the same message
message = ChatMessage.from_user("ciao")
kwargs = {"streams": [1], "message": message}
result1 = router.run(**kwargs)
result2 = new_router.run(**kwargs)
# check that the result is the same and correct
assert result1 == result2 and result1["message"].content == message.content
def test_router_serialization_idempotence(self):
routes = [
{