diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index a1c47a508..a82f7eef7 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Set -from jinja2 import Template, meta +from jinja2 import meta +from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses.chat_message import ChatMessage, ChatRole @@ -123,12 +124,12 @@ class ChatPromptBuilder: self.required_variables = required_variables or [] self.template = template variables = variables or [] + self._env = SandboxedEnvironment() if template and not variables: for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): # infere variables from template - msg_template = Template(message.content) - ast = msg_template.environment.parse(message.content) + ast = self._env.parse(message.content) template_variables = meta.find_undeclared_variables(ast) variables += list(template_variables) @@ -194,7 +195,8 @@ class ChatPromptBuilder: for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): self._validate_variables(set(template_variables_combined.keys())) - compiled_template = Template(message.content) + + compiled_template = self._env.from_string(message.content) rendered_content = compiled_template.render(template_variables_combined) rendered_message = ( ChatMessage.from_user(rendered_content) diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index b071eb121..ca8fd1728 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Set -from jinja2 import Template, meta +from jinja2 import meta +from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_to_dict @@ -158,10 +159,12 @@ class PromptBuilder: self._variables = variables self._required_variables = required_variables self.required_variables = required_variables or [] - self.template = Template(template) + + self._env = SandboxedEnvironment() + self.template = self._env.from_string(template) if not variables: # infere variables from template - ast = self.template.environment.parse(template) + ast = self._env.parse(template) template_variables = meta.find_undeclared_variables(ast) variables = list(template_variables) @@ -216,8 +219,8 @@ class PromptBuilder: self._validate_variables(set(template_variables_combined.keys())) compiled_template = self.template - if isinstance(template, str): - compiled_template = Template(template) + if template is not None: + compiled_template = self._env.from_string(template) result = compiled_template.render(template_variables_combined) return {"prompt": result} diff --git a/haystack/components/converters/output_adapter.py b/haystack/components/converters/output_adapter.py index aabab1f2b..ce0e1a4e9 100644 --- a/haystack/components/converters/output_adapter.py +++ b/haystack/components/converters/output_adapter.py @@ -2,11 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +import ast +import contextlib from typing import Any, Callable, Dict, Optional, Set import jinja2.runtime from jinja2 import TemplateSyntaxError, meta -from jinja2.nativetypes import NativeEnvironment +from jinja2.sandbox import SandboxedEnvironment from typing_extensions import TypeAlias from haystack import component, default_from_dict, default_to_dict @@ -58,18 +60,18 @@ class OutputAdapter: # Create a Jinja native environment, we need it to: # a) add custom filters to the environment for filter compilation stage - env = NativeEnvironment() + self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) try: - env.parse(template) # Validate template syntax + self._env.parse(template) # Validate template syntax self.template = template except TemplateSyntaxError as e: raise ValueError(f"Invalid Jinja template '{template}': {e}") from e for name, filter_func in self.custom_filters.items(): - env.filters[name] = filter_func + self._env.filters[name] = filter_func # b) extract variables in the template - route_input_names = self._extract_variables(env) + route_input_names = self._extract_variables(self._env) input_types.update(route_input_names) # the env is not needed, discarded automatically @@ -92,16 +94,22 @@ class OutputAdapter: # check if kwargs are empty if not kwargs: raise ValueError("No input data provided for output adaptation") - env = NativeEnvironment() for name, filter_func in self.custom_filters.items(): - env.filters[name] = filter_func + self._env.filters[name] = filter_func adapted_outputs = {} try: - adapted_output_template = env.from_string(self.template) + adapted_output_template = self._env.from_string(self.template) output_result = adapted_output_template.render(**kwargs) if isinstance(output_result, jinja2.runtime.Undefined): raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}") + # We suppress the exception in case the output is already a string, otherwise + # we try to evaluate it and would fail. + # This must be done cause the output could be different literal structures. + # This doesn't support any user types. + with contextlib.suppress(Exception): + output_result = ast.literal_eval(output_result) + adapted_outputs["output"] = output_result except Exception as e: raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e @@ -135,14 +143,12 @@ class OutputAdapter: init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None return default_from_dict(cls, data) - def _extract_variables(self, env: NativeEnvironment) -> Set[str]: + def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]: """ Extracts all variables from a list of Jinja template strings. :param env: A Jinja native environment. :return: A set of variable names extracted from the template strings. """ - variables = set() ast = env.parse(self.template) - variables.update(meta.find_undeclared_variables(ast)) - return variables + return meta.find_undeclared_variables(ast) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index be72c7cc5..74dcf3293 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +import ast +import contextlib from typing import Any, Callable, Dict, List, Optional, Set from jinja2 import Environment, TemplateSyntaxError, meta from jinja2.nativetypes import NativeEnvironment +from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type @@ -125,8 +128,8 @@ class ConditionalRouter: self.custom_filters = custom_filters or {} # Create a Jinja native environment to inspect variables in the condition templates - env = NativeEnvironment() - env.filters.update(self.custom_filters) + self._env = SandboxedEnvironment() + self._env.filters.update(self.custom_filters) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any @@ -134,7 +137,7 @@ class ConditionalRouter: for route in routes: # extract inputs - route_input_names = self._extract_variables(env, [route["output"], route["condition"]]) + route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]]) input_types.update(route_input_names) # extract outputs @@ -194,16 +197,20 @@ class ConditionalRouter: routes. """ # Create a Jinja native environment to evaluate the condition templates as Python expressions - env = NativeEnvironment() - env.filters.update(self.custom_filters) - for route in self.routes: try: - t = env.from_string(route["condition"]) - if t.render(**kwargs): + t = self._env.from_string(route["condition"]) + rendered = t.render(**kwargs) + if ast.literal_eval(rendered): # We now evaluate the `output` expression to determine the route output - t_output = env.from_string(route["output"]) + t_output = self._env.from_string(route["output"]) output = t_output.render(**kwargs) + # We suppress the exception in case the output is already a string, otherwise + # we try to evaluate it and would fail. + # This must be done cause the output could be different literal structures. + # This doesn't support any user types. + with contextlib.suppress(Exception): + output = ast.literal_eval(output) # and return the output as a dictionary under the output_name key return {route["output_name"]: output} except Exception as e: @@ -234,7 +241,7 @@ class ConditionalRouter: if not self._validate_template(env, route[field]): raise ValueError(f"Invalid template for field '{field}': {route[field]}") - def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]: + def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]: """ Extracts all variables from a list of Jinja template strings. diff --git a/haystack/core/pipeline/template.py b/haystack/core/pipeline/template.py index 514c17280..338e6b3fb 100644 --- a/haystack/core/pipeline/template.py +++ b/haystack/core/pipeline/template.py @@ -6,7 +6,8 @@ from enum import Enum from pathlib import Path from typing import Any, Dict, Optional, Union -from jinja2 import Environment, PackageLoader, TemplateSyntaxError, meta +from jinja2 import PackageLoader, TemplateSyntaxError, meta +from jinja2.sandbox import SandboxedEnvironment TEMPLATE_FILE_EXTENSION = ".yaml.jinja2" TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined" @@ -74,7 +75,7 @@ class PipelineTemplate: :param template_content: The raw template source to use in the template. """ - env = Environment( + env = SandboxedEnvironment( loader=PackageLoader("haystack.core.pipeline", "predefined"), trim_blocks=True, lstrip_blocks=True ) try: diff --git a/releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml b/releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml new file mode 100644 index 000000000..328fae814 --- /dev/null +++ b/releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml @@ -0,0 +1,14 @@ +--- +upgrade: + - | + `OutputAdapter` and `ConditionalRouter` can't return users inputs anymore. +security: + - | + Fix issue that could lead to remote code execution when using insecure Jinja template in the following Components: + + - `PromptBuilder` + - `ChatPromptBuilder` + - `OutputAdapter` + - `ConditionalRouter` + + The same issue has been fixed in the `PipelineTemplate` class too. diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 0b9059571..b324cf806 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -97,7 +97,23 @@ class TestRouter: assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"} assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"} - def test_router_evaluate_condition_expressions(self, router): + def test_router_evaluate_condition_expressions(self): + router = ConditionalRouter( + [ + { + "condition": "{{streams|length < 2}}", + "output": "{{query}}", + "output_type": str, + "output_name": "query", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + ) # first route should be selected kwargs = {"streams": [1, 2, 3], "query": "test"} result = router.run(**kwargs) @@ -227,52 +243,6 @@ class TestRouter: # check that the result is the same and correct assert result1 == result2 and result1 == {"streams": [1, 2, 3]} - def test_router_de_serialization_user_type(self): - routes = [ - { - "condition": "{{streams|length < 2}}", - "output": "{{message}}", - "output_type": ChatMessage, - "output_name": "message", - }, - { - "condition": "{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - router = ConditionalRouter(routes) - router_dict = router.to_dict() - - # assert that the router dict is correct, with all keys and values being strings - for route in router_dict["init_parameters"]["routes"]: - for key in route.keys(): - assert isinstance(key, str) - assert isinstance(route[key], str) - - # check that the output_type is a string and a proper class name - assert ( - router_dict["init_parameters"]["routes"][0]["output_type"] - == "haystack.dataclasses.chat_message.ChatMessage" - ) - - # deserialize the router - new_router = ConditionalRouter.from_dict(router_dict) - - # check that the output_type is the right class - assert new_router.routes[0]["output_type"] == ChatMessage - assert router.routes == new_router.routes - - # now use both routers to run the same message - message = ChatMessage.from_user("ciao") - kwargs = {"streams": [1], "message": message} - result1 = router.run(**kwargs) - result2 = new_router.run(**kwargs) - - # check that the result is the same and correct - assert result1 == result2 and result1["message"].content == message.content - def test_router_serialization_idempotence(self): routes = [ {