mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 07:59:27 +00:00
chore: Make arrow an optional dependency (#8345)
* Make arrow an optional dependency * Fix imports
This commit is contained in:
parent
720e54970f
commit
da49e782e2
@ -160,8 +160,13 @@ class PromptBuilder:
|
||||
self._variables = variables
|
||||
self._required_variables = required_variables
|
||||
self.required_variables = required_variables or []
|
||||
try:
|
||||
# The Jinja2TimeExtension needs an optional dependency to be installed.
|
||||
# If it's not available we can do without it and use the PromptBuilder as is.
|
||||
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
|
||||
except ImportError:
|
||||
self._env = SandboxedEnvironment()
|
||||
|
||||
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
|
||||
self.template = self._env.from_string(template)
|
||||
if not variables:
|
||||
# infer variables from template
|
||||
|
||||
@ -4,10 +4,14 @@
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import arrow
|
||||
from jinja2 import Environment, nodes
|
||||
from jinja2.ext import Extension
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
with LazyImport(message='Run "pip install arrow>=1.3.0"') as arrow_import:
|
||||
import arrow
|
||||
|
||||
|
||||
class Jinja2TimeExtension(Extension):
|
||||
# Syntax for current date
|
||||
@ -20,6 +24,7 @@ class Jinja2TimeExtension(Extension):
|
||||
:param environment: The Jinja2 environment to initialize the extension with.
|
||||
It provides the context where the extension will operate.
|
||||
"""
|
||||
arrow_import.check()
|
||||
super().__init__(environment)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -57,7 +57,6 @@ dependencies = [
|
||||
"numpy<2",
|
||||
"python-dateutil",
|
||||
"haystack-experimental",
|
||||
"arrow>=1.3.0" # Jinja2TimeExtension
|
||||
]
|
||||
|
||||
[tool.hatch.envs.default]
|
||||
@ -86,6 +85,7 @@ extra-dependencies = [
|
||||
"sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
|
||||
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
|
||||
"openai-whisper>=20231106", # LocalWhisperTranscriber
|
||||
"arrow>=1.3.0", # Jinja2TimeExtension
|
||||
|
||||
# NamedEntityExtractor
|
||||
"spacy>=3.7,<3.8",
|
||||
|
||||
@ -2,12 +2,14 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, List, Optional
|
||||
from jinja2 import TemplateSyntaxError
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import arrow
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
import pytest
|
||||
from jinja2 import TemplateSyntaxError
|
||||
|
||||
from haystack import component
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.core.pipeline.pipeline import Pipeline
|
||||
from haystack.dataclasses.document import Document
|
||||
|
||||
@ -77,6 +79,14 @@ class TestPromptBuilder:
|
||||
assert set(outputs.keys()) == {"prompt"}
|
||||
assert outputs["prompt"].type == str
|
||||
|
||||
@patch("haystack.components.builders.prompt_builder.Jinja2TimeExtension")
|
||||
def test_init_with_missing_extension_dependency(self, extension_mock):
|
||||
extension_mock.side_effect = ImportError
|
||||
builder = PromptBuilder(template="This is a {{ variable }}")
|
||||
assert builder._env.extensions == {}
|
||||
res = builder.run(variable="test")
|
||||
assert res == {"prompt": "This is a test"}
|
||||
|
||||
def test_to_dict(self):
|
||||
builder = PromptBuilder(
|
||||
template="This is a {{ variable }}", variables=["var1", "var2"], required_variables=["var1", "var3"]
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from jinja2 import Environment
|
||||
import arrow
|
||||
@ -17,6 +19,12 @@ class TestJinja2TimeExtension:
|
||||
def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension:
|
||||
return Jinja2TimeExtension(jinja_env)
|
||||
|
||||
@patch("haystack.utils.jinja2_extensions.arrow_import")
|
||||
def test_init_fails_without_arrow(self, arrow_import_mock) -> None:
|
||||
arrow_import_mock.check.side_effect = ImportError
|
||||
with pytest.raises(ImportError):
|
||||
Jinja2TimeExtension(Environment())
|
||||
|
||||
def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime(
|
||||
"UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user