chore: Make arrow an optional dependency (#8345)

* Make arrow an optional dependency

* Fix imports
This commit is contained in:
Silvano Cerza 2024-09-09 16:09:51 +02:00 committed by GitHub
parent 720e54970f
commit da49e782e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 35 additions and 7 deletions

View File

@ -160,8 +160,13 @@ class PromptBuilder:
self._variables = variables
self._required_variables = required_variables
self.required_variables = required_variables or []
try:
# The Jinja2TimeExtension needs an optional dependency to be installed.
# If it's not available we can do without it and use the PromptBuilder as is.
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
except ImportError:
self._env = SandboxedEnvironment()
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
self.template = self._env.from_string(template)
if not variables:
# infer variables from template

View File

@ -4,10 +4,14 @@
from typing import Any, List, Optional, Union
import arrow
from jinja2 import Environment, nodes
from jinja2.ext import Extension
from haystack.lazy_imports import LazyImport
with LazyImport(message='Run "pip install arrow>=1.3.0"') as arrow_import:
import arrow
class Jinja2TimeExtension(Extension):
# Syntax for current date
@ -20,6 +24,7 @@ class Jinja2TimeExtension(Extension):
:param environment: The Jinja2 environment to initialize the extension with.
It provides the context where the extension will operate.
"""
arrow_import.check()
super().__init__(environment)
@staticmethod

View File

@ -57,7 +57,6 @@ dependencies = [
"numpy<2",
"python-dateutil",
"haystack-experimental",
"arrow>=1.3.0" # Jinja2TimeExtension
]
[tool.hatch.envs.default]
@ -86,6 +85,7 @@ extra-dependencies = [
"sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
"arrow>=1.3.0", # Jinja2TimeExtension
# NamedEntityExtractor
"spacy>=3.7,<3.8",

View File

@ -2,12 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from jinja2 import TemplateSyntaxError
import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import patch
import arrow
from haystack.components.builders.prompt_builder import PromptBuilder
import pytest
from jinja2 import TemplateSyntaxError
from haystack import component
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.core.pipeline.pipeline import Pipeline
from haystack.dataclasses.document import Document
@ -77,6 +79,14 @@ class TestPromptBuilder:
assert set(outputs.keys()) == {"prompt"}
assert outputs["prompt"].type == str
@patch("haystack.components.builders.prompt_builder.Jinja2TimeExtension")
def test_init_with_missing_extension_dependency(self, extension_mock):
extension_mock.side_effect = ImportError
builder = PromptBuilder(template="This is a {{ variable }}")
assert builder._env.extensions == {}
res = builder.run(variable="test")
assert res == {"prompt": "This is a test"}
def test_to_dict(self):
builder = PromptBuilder(
template="This is a {{ variable }}", variables=["var1", "var2"], required_variables=["var1", "var3"]

View File

@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch
import pytest
from jinja2 import Environment
import arrow
@ -17,6 +19,12 @@ class TestJinja2TimeExtension:
def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension:
return Jinja2TimeExtension(jinja_env)
@patch("haystack.utils.jinja2_extensions.arrow_import")
def test_init_fails_without_arrow(self, arrow_import_mock) -> None:
arrow_import_mock.check.side_effect = ImportError
with pytest.raises(ImportError):
Jinja2TimeExtension(Environment())
def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None:
result = jinja_extension._get_datetime(
"UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"