From da49e782e2c53cee15e98bdcf03fbfabdfd71152 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:09:51 +0200 Subject: [PATCH] chore: Make `arrow` an optional dependency (#8345) * Make arrow an optional dependency * Fix imports --- haystack/components/builders/prompt_builder.py | 7 ++++++- haystack/utils/jinja2_extensions.py | 7 ++++++- pyproject.toml | 2 +- .../components/builders/test_prompt_builder.py | 18 ++++++++++++++---- test/utils/test_jinja2_extensions.py | 8 ++++++++ 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index 68b51b20a..b9899c922 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -160,8 +160,13 @@ class PromptBuilder: self._variables = variables self._required_variables = required_variables self.required_variables = required_variables or [] + try: + # The Jinja2TimeExtension needs an optional dependency to be installed. + # If it's not available we can do without it and use the PromptBuilder as is. + self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension]) + except ImportError: + self._env = SandboxedEnvironment() - self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension]) self.template = self._env.from_string(template) if not variables: # infer variables from template diff --git a/haystack/utils/jinja2_extensions.py b/haystack/utils/jinja2_extensions.py index 4dd8d88da..94d0dc8fd 100644 --- a/haystack/utils/jinja2_extensions.py +++ b/haystack/utils/jinja2_extensions.py @@ -4,10 +4,14 @@ from typing import Any, List, Optional, Union -import arrow from jinja2 import Environment, nodes from jinja2.ext import Extension +from haystack.lazy_imports import LazyImport + +with LazyImport(message='Run "pip install arrow>=1.3.0"') as arrow_import: + import arrow + class Jinja2TimeExtension(Extension): # Syntax for current date @@ -20,6 +24,7 @@ class Jinja2TimeExtension(Extension): :param environment: The Jinja2 environment to initialize the extension with. It provides the context where the extension will operate. """ + arrow_import.check() super().__init__(environment) @staticmethod diff --git a/pyproject.toml b/pyproject.toml index 2f17e8898..e5c525d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ dependencies = [ "numpy<2", "python-dateutil", "haystack-experimental", - "arrow>=1.3.0" # Jinja2TimeExtension ] [tool.hatch.envs.default] @@ -86,6 +85,7 @@ extra-dependencies = [ "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber + "arrow>=1.3.0", # Jinja2TimeExtension # NamedEntityExtractor "spacy>=3.7,<3.8", diff --git a/test/components/builders/test_prompt_builder.py b/test/components/builders/test_prompt_builder.py index 1c57bd660..7461327f4 100644 --- a/test/components/builders/test_prompt_builder.py +++ b/test/components/builders/test_prompt_builder.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict, List, Optional -from jinja2 import TemplateSyntaxError -import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch + import arrow -from haystack.components.builders.prompt_builder import PromptBuilder +import pytest +from jinja2 import TemplateSyntaxError + from haystack import component +from haystack.components.builders.prompt_builder import PromptBuilder from haystack.core.pipeline.pipeline import Pipeline from haystack.dataclasses.document import Document @@ -77,6 +79,14 @@ class TestPromptBuilder: assert set(outputs.keys()) == {"prompt"} assert outputs["prompt"].type == str + @patch("haystack.components.builders.prompt_builder.Jinja2TimeExtension") + def test_init_with_missing_extension_dependency(self, extension_mock): + extension_mock.side_effect = ImportError + builder = PromptBuilder(template="This is a {{ variable }}") + assert builder._env.extensions == {} + res = builder.run(variable="test") + assert res == {"prompt": "This is a test"} + def test_to_dict(self): builder = PromptBuilder( template="This is a {{ variable }}", variables=["var1", "var2"], required_variables=["var1", "var3"] diff --git a/test/utils/test_jinja2_extensions.py b/test/utils/test_jinja2_extensions.py index c11140232..56d5ab573 100644 --- a/test/utils/test_jinja2_extensions.py +++ b/test/utils/test_jinja2_extensions.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import patch + import pytest from jinja2 import Environment import arrow @@ -17,6 +19,12 @@ class TestJinja2TimeExtension: def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension: return Jinja2TimeExtension(jinja_env) + @patch("haystack.utils.jinja2_extensions.arrow_import") + def test_init_fails_without_arrow(self, arrow_import_mock) -> None: + arrow_import_mock.check.side_effect = ImportError + with pytest.raises(ImportError): + Jinja2TimeExtension(Environment()) + def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None: result = jinja_extension._get_datetime( "UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"