mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 16:08:38 +00:00
fix: Fix issue that could lead to RCE if using unsecure Jinja templates (#8095)
* Fix issue that could lead to RCE if using unsecure Jinja templates * Add comment explaining exception suppression * Update release note * Update release note
This commit is contained in:
parent
47f4db8698
commit
3fed1366c4
@ -4,7 +4,8 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from jinja2 import Template, meta
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
|
||||
@ -123,12 +124,12 @@ class ChatPromptBuilder:
|
||||
self.required_variables = required_variables or []
|
||||
self.template = template
|
||||
variables = variables or []
|
||||
self._env = SandboxedEnvironment()
|
||||
if template and not variables:
|
||||
for message in template:
|
||||
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
|
||||
# infere variables from template
|
||||
msg_template = Template(message.content)
|
||||
ast = msg_template.environment.parse(message.content)
|
||||
ast = self._env.parse(message.content)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
variables += list(template_variables)
|
||||
|
||||
@ -194,7 +195,8 @@ class ChatPromptBuilder:
|
||||
for message in template:
|
||||
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
|
||||
self._validate_variables(set(template_variables_combined.keys()))
|
||||
compiled_template = Template(message.content)
|
||||
|
||||
compiled_template = self._env.from_string(message.content)
|
||||
rendered_content = compiled_template.render(template_variables_combined)
|
||||
rendered_message = (
|
||||
ChatMessage.from_user(rendered_content)
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from jinja2 import Template, meta
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
|
||||
@ -158,10 +159,12 @@ class PromptBuilder:
|
||||
self._variables = variables
|
||||
self._required_variables = required_variables
|
||||
self.required_variables = required_variables or []
|
||||
self.template = Template(template)
|
||||
|
||||
self._env = SandboxedEnvironment()
|
||||
self.template = self._env.from_string(template)
|
||||
if not variables:
|
||||
# infere variables from template
|
||||
ast = self.template.environment.parse(template)
|
||||
ast = self._env.parse(template)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
variables = list(template_variables)
|
||||
|
||||
@ -216,8 +219,8 @@ class PromptBuilder:
|
||||
self._validate_variables(set(template_variables_combined.keys()))
|
||||
|
||||
compiled_template = self.template
|
||||
if isinstance(template, str):
|
||||
compiled_template = Template(template)
|
||||
if template is not None:
|
||||
compiled_template = self._env.from_string(template)
|
||||
|
||||
result = compiled_template.render(template_variables_combined)
|
||||
return {"prompt": result}
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, Optional, Set
|
||||
|
||||
import jinja2.runtime
|
||||
from jinja2 import TemplateSyntaxError, meta
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
@ -58,18 +60,18 @@ class OutputAdapter:
|
||||
|
||||
# Create a Jinja native environment, we need it to:
|
||||
# a) add custom filters to the environment for filter compilation stage
|
||||
env = NativeEnvironment()
|
||||
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
|
||||
try:
|
||||
env.parse(template) # Validate template syntax
|
||||
self._env.parse(template) # Validate template syntax
|
||||
self.template = template
|
||||
except TemplateSyntaxError as e:
|
||||
raise ValueError(f"Invalid Jinja template '{template}': {e}") from e
|
||||
|
||||
for name, filter_func in self.custom_filters.items():
|
||||
env.filters[name] = filter_func
|
||||
self._env.filters[name] = filter_func
|
||||
|
||||
# b) extract variables in the template
|
||||
route_input_names = self._extract_variables(env)
|
||||
route_input_names = self._extract_variables(self._env)
|
||||
input_types.update(route_input_names)
|
||||
|
||||
# the env is not needed, discarded automatically
|
||||
@ -92,16 +94,22 @@ class OutputAdapter:
|
||||
# check if kwargs are empty
|
||||
if not kwargs:
|
||||
raise ValueError("No input data provided for output adaptation")
|
||||
env = NativeEnvironment()
|
||||
for name, filter_func in self.custom_filters.items():
|
||||
env.filters[name] = filter_func
|
||||
self._env.filters[name] = filter_func
|
||||
adapted_outputs = {}
|
||||
try:
|
||||
adapted_output_template = env.from_string(self.template)
|
||||
adapted_output_template = self._env.from_string(self.template)
|
||||
output_result = adapted_output_template.render(**kwargs)
|
||||
if isinstance(output_result, jinja2.runtime.Undefined):
|
||||
raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}")
|
||||
|
||||
# We suppress the exception in case the output is already a string, otherwise
|
||||
# we try to evaluate it and would fail.
|
||||
# This must be done cause the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
output_result = ast.literal_eval(output_result)
|
||||
|
||||
adapted_outputs["output"] = output_result
|
||||
except Exception as e:
|
||||
raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e
|
||||
@ -135,14 +143,12 @@ class OutputAdapter:
|
||||
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _extract_variables(self, env: NativeEnvironment) -> Set[str]:
|
||||
def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]:
|
||||
"""
|
||||
Extracts all variables from a list of Jinja template strings.
|
||||
|
||||
:param env: A Jinja native environment.
|
||||
:return: A set of variable names extracted from the template strings.
|
||||
"""
|
||||
variables = set()
|
||||
ast = env.parse(self.template)
|
||||
variables.update(meta.find_undeclared_variables(ast))
|
||||
return variables
|
||||
return meta.find_undeclared_variables(ast)
|
||||
|
||||
@ -2,10 +2,13 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from jinja2 import Environment, TemplateSyntaxError, meta
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
|
||||
@ -125,8 +128,8 @@ class ConditionalRouter:
|
||||
self.custom_filters = custom_filters or {}
|
||||
|
||||
# Create a Jinja native environment to inspect variables in the condition templates
|
||||
env = NativeEnvironment()
|
||||
env.filters.update(self.custom_filters)
|
||||
self._env = SandboxedEnvironment()
|
||||
self._env.filters.update(self.custom_filters)
|
||||
|
||||
# Inspect the routes to determine input and output types.
|
||||
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
||||
@ -134,7 +137,7 @@ class ConditionalRouter:
|
||||
|
||||
for route in routes:
|
||||
# extract inputs
|
||||
route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
|
||||
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
|
||||
input_types.update(route_input_names)
|
||||
|
||||
# extract outputs
|
||||
@ -194,16 +197,20 @@ class ConditionalRouter:
|
||||
routes.
|
||||
"""
|
||||
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
||||
env = NativeEnvironment()
|
||||
env.filters.update(self.custom_filters)
|
||||
|
||||
for route in self.routes:
|
||||
try:
|
||||
t = env.from_string(route["condition"])
|
||||
if t.render(**kwargs):
|
||||
t = self._env.from_string(route["condition"])
|
||||
rendered = t.render(**kwargs)
|
||||
if ast.literal_eval(rendered):
|
||||
# We now evaluate the `output` expression to determine the route output
|
||||
t_output = env.from_string(route["output"])
|
||||
t_output = self._env.from_string(route["output"])
|
||||
output = t_output.render(**kwargs)
|
||||
# We suppress the exception in case the output is already a string, otherwise
|
||||
# we try to evaluate it and would fail.
|
||||
# This must be done cause the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
output = ast.literal_eval(output)
|
||||
# and return the output as a dictionary under the output_name key
|
||||
return {route["output_name"]: output}
|
||||
except Exception as e:
|
||||
@ -234,7 +241,7 @@ class ConditionalRouter:
|
||||
if not self._validate_template(env, route[field]):
|
||||
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
|
||||
|
||||
def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
|
||||
def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
|
||||
"""
|
||||
Extracts all variables from a list of Jinja template strings.
|
||||
|
||||
|
||||
@ -6,7 +6,8 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from jinja2 import Environment, PackageLoader, TemplateSyntaxError, meta
|
||||
from jinja2 import PackageLoader, TemplateSyntaxError, meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
TEMPLATE_FILE_EXTENSION = ".yaml.jinja2"
|
||||
TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined"
|
||||
@ -74,7 +75,7 @@ class PipelineTemplate:
|
||||
|
||||
:param template_content: The raw template source to use in the template.
|
||||
"""
|
||||
env = Environment(
|
||||
env = SandboxedEnvironment(
|
||||
loader=PackageLoader("haystack.core.pipeline", "predefined"), trim_blocks=True, lstrip_blocks=True
|
||||
)
|
||||
try:
|
||||
|
||||
14
releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml
Normal file
14
releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
`OutputAdapter` and `ConditionalRouter` can't return users inputs anymore.
|
||||
security:
|
||||
- |
|
||||
Fix issue that could lead to remote code execution when using insecure Jinja template in the following Components:
|
||||
|
||||
- `PromptBuilder`
|
||||
- `ChatPromptBuilder`
|
||||
- `OutputAdapter`
|
||||
- `ConditionalRouter`
|
||||
|
||||
The same issue has been fixed in the `PipelineTemplate` class too.
|
||||
@ -97,7 +97,23 @@ class TestRouter:
|
||||
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
|
||||
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
|
||||
|
||||
def test_router_evaluate_condition_expressions(self, router):
|
||||
def test_router_evaluate_condition_expressions(self):
|
||||
router = ConditionalRouter(
|
||||
[
|
||||
{
|
||||
"condition": "{{streams|length < 2}}",
|
||||
"output": "{{query}}",
|
||||
"output_type": str,
|
||||
"output_name": "query",
|
||||
},
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": "{{streams}}",
|
||||
"output_type": List[int],
|
||||
"output_name": "streams",
|
||||
},
|
||||
]
|
||||
)
|
||||
# first route should be selected
|
||||
kwargs = {"streams": [1, 2, 3], "query": "test"}
|
||||
result = router.run(**kwargs)
|
||||
@ -227,52 +243,6 @@ class TestRouter:
|
||||
# check that the result is the same and correct
|
||||
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}
|
||||
|
||||
def test_router_de_serialization_user_type(self):
|
||||
routes = [
|
||||
{
|
||||
"condition": "{{streams|length < 2}}",
|
||||
"output": "{{message}}",
|
||||
"output_type": ChatMessage,
|
||||
"output_name": "message",
|
||||
},
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": "{{streams}}",
|
||||
"output_type": List[int],
|
||||
"output_name": "streams",
|
||||
},
|
||||
]
|
||||
router = ConditionalRouter(routes)
|
||||
router_dict = router.to_dict()
|
||||
|
||||
# assert that the router dict is correct, with all keys and values being strings
|
||||
for route in router_dict["init_parameters"]["routes"]:
|
||||
for key in route.keys():
|
||||
assert isinstance(key, str)
|
||||
assert isinstance(route[key], str)
|
||||
|
||||
# check that the output_type is a string and a proper class name
|
||||
assert (
|
||||
router_dict["init_parameters"]["routes"][0]["output_type"]
|
||||
== "haystack.dataclasses.chat_message.ChatMessage"
|
||||
)
|
||||
|
||||
# deserialize the router
|
||||
new_router = ConditionalRouter.from_dict(router_dict)
|
||||
|
||||
# check that the output_type is the right class
|
||||
assert new_router.routes[0]["output_type"] == ChatMessage
|
||||
assert router.routes == new_router.routes
|
||||
|
||||
# now use both routers to run the same message
|
||||
message = ChatMessage.from_user("ciao")
|
||||
kwargs = {"streams": [1], "message": message}
|
||||
result1 = router.run(**kwargs)
|
||||
result2 = new_router.run(**kwargs)
|
||||
|
||||
# check that the result is the same and correct
|
||||
assert result1 == result2 and result1["message"].content == message.content
|
||||
|
||||
def test_router_serialization_idempotence(self):
|
||||
routes = [
|
||||
{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user