mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 12:37:27 +00:00
feat: Add current date in UTC to PromptBuilder (#8233)
* initial commit * add unit tests * add release notes * update function name
This commit is contained in:
parent
e31b3edda1
commit
75955922b9
@ -8,6 +8,7 @@ from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
from haystack.utils import Jinja2TimeExtension
|
||||
|
||||
|
||||
@component
|
||||
@ -160,10 +161,10 @@ class PromptBuilder:
|
||||
self._required_variables = required_variables
|
||||
self.required_variables = required_variables or []
|
||||
|
||||
self._env = SandboxedEnvironment()
|
||||
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
|
||||
self.template = self._env.from_string(template)
|
||||
if not variables:
|
||||
# infere variables from template
|
||||
# infer variables from template
|
||||
ast = self._env.parse(template)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
variables = list(template_variables)
|
||||
|
||||
@ -8,6 +8,7 @@ from .device import ComponentDevice, Device, DeviceMap, DeviceType
|
||||
from .docstore_deserialization import deserialize_document_store_in_init_params_inplace
|
||||
from .expit import expit
|
||||
from .filters import document_matches_filter
|
||||
from .jinja2_extensions import Jinja2TimeExtension
|
||||
from .jupyter import is_in_jupyter
|
||||
from .requests_utils import request_with_retry
|
||||
from .type_serialization import deserialize_type, serialize_type
|
||||
@ -28,4 +29,5 @@ __all__ = [
|
||||
"serialize_type",
|
||||
"deserialize_type",
|
||||
"deserialize_document_store_in_init_params_inplace",
|
||||
"Jinja2TimeExtension",
|
||||
]
|
||||
|
||||
91
haystack/utils/jinja2_extensions.py
Normal file
91
haystack/utils/jinja2_extensions.py
Normal file
@ -0,0 +1,91 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import arrow
|
||||
from jinja2 import Environment, nodes
|
||||
from jinja2.ext import Extension
|
||||
|
||||
|
||||
class Jinja2TimeExtension(Extension):
|
||||
# Syntax for current date
|
||||
tags = {"now"}
|
||||
|
||||
def __init__(self, environment: Environment): # pylint: disable=useless-parent-delegation
|
||||
"""
|
||||
Initializes the JinjaTimeExtension object.
|
||||
|
||||
:param environment: The Jinja2 environment to initialize the extension with.
|
||||
It provides the context where the extension will operate.
|
||||
"""
|
||||
super().__init__(environment)
|
||||
|
||||
@staticmethod
|
||||
def _get_datetime(
|
||||
timezone: str,
|
||||
operator: Optional[str] = None,
|
||||
offset: Optional[str] = None,
|
||||
datetime_format: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the current datetime based on timezone, apply any offset if provided, and format the result.
|
||||
|
||||
:param timezone: The timezone string (e.g., 'UTC' or 'America/New_York') for which the current
|
||||
time should be fetched.
|
||||
:param operator: The operator ('+' or '-') to apply to the offset (used for adding/subtracting intervals).
|
||||
Defaults to None if no offset is applied, otherwise default is '+'.
|
||||
:param offset: The offset string in the format 'interval=value' (e.g., 'hours=2,days=1') specifying how much
|
||||
to adjust the datetime. The intervals can be any valid interval accepted
|
||||
by Arrow (e.g., hours, days, weeks, months). Defaults to None if no adjustment is needed.
|
||||
:param datetime_format: The format string to use for formatting the output datetime.
|
||||
Defaults to '%Y-%m-%d %H:%M:%S' if not provided.
|
||||
"""
|
||||
try:
|
||||
dt = arrow.now(timezone)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid timezone {timezone}: {e}")
|
||||
|
||||
if offset and operator:
|
||||
try:
|
||||
# Parse the offset and apply it to the datetime object
|
||||
replace_params = {
|
||||
interval.strip(): float(operator + value.strip())
|
||||
for param in offset.split(",")
|
||||
for interval, value in [param.split("=")]
|
||||
}
|
||||
# Shift the datetime fields based on the parsed offset
|
||||
dt = dt.shift(**replace_params)
|
||||
except (ValueError, AttributeError) as e:
|
||||
raise ValueError(f"Invalid offset or operator {offset}, {operator}: {e}")
|
||||
|
||||
# Use the provided format or fallback to the default one
|
||||
datetime_format = datetime_format or "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
return dt.strftime(datetime_format)
|
||||
|
||||
def parse(self, parser: Any) -> Union[nodes.Node, List[nodes.Node]]:
|
||||
"""
|
||||
Parse the template expression to determine how to handle the datetime formatting.
|
||||
|
||||
:param parser: The parser object that processes the template expressions and manages the syntax tree.
|
||||
It's used to interpret the template's structure.
|
||||
"""
|
||||
lineno = next(parser.stream).lineno
|
||||
node = parser.parse_expression()
|
||||
# Check if a custom datetime format is provided after a comma
|
||||
datetime_format = parser.parse_expression() if parser.stream.skip_if("comma") else nodes.Const(None)
|
||||
|
||||
# Default Add when no operator is provided
|
||||
operator = "+" if isinstance(node, nodes.Add) else "-"
|
||||
# Call the _get_datetime method with the appropriate operator and offset, if exist
|
||||
call_method = self.call_method(
|
||||
"_get_datetime",
|
||||
[node.left, nodes.Const(operator), node.right, datetime_format]
|
||||
if isinstance(node, (nodes.Add, nodes.Sub))
|
||||
else [node, nodes.Const(None), nodes.Const(None), datetime_format],
|
||||
lineno=lineno,
|
||||
)
|
||||
|
||||
return nodes.Output([call_method], lineno=lineno)
|
||||
@ -57,6 +57,7 @@ dependencies = [
|
||||
"numpy<2",
|
||||
"python-dateutil",
|
||||
"haystack-experimental",
|
||||
"arrow>=1.3.0" # Jinja2TimeExtension
|
||||
]
|
||||
|
||||
[tool.hatch.envs.default]
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add the current date inside a template in `PromptBuilder` using the `utc_now()`.
|
||||
Users can also specify the date format, such as `utc_now("%Y-%m-%d")`.
|
||||
@ -4,7 +4,8 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from jinja2 import TemplateSyntaxError
|
||||
import pytest
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
import arrow
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack import component
|
||||
from haystack.core.pipeline.pipeline import Pipeline
|
||||
@ -254,3 +255,59 @@ class TestPromptBuilder:
|
||||
"prompt_builder": {"prompt": "This is the dynamic prompt:\n Query: Where does the speaker live?"}
|
||||
}
|
||||
assert result == expected_dynamic
|
||||
|
||||
def test_with_custom_dateformat(self) -> None:
|
||||
template = "Formatted date: {% now 'UTC', '%Y-%m-%d' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
result = builder.run()["prompt"]
|
||||
|
||||
now_formatted = f"Formatted date: {arrow.now('UTC').strftime('%Y-%m-%d')}"
|
||||
|
||||
assert now_formatted == result
|
||||
|
||||
def test_with_different_timezone(self) -> None:
|
||||
template = "Current time in New York is: {% now 'America/New_York' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
result = builder.run()["prompt"]
|
||||
|
||||
now_ny = f"Current time in New York is: {arrow.now('America/New_York').strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
assert now_ny == result
|
||||
|
||||
def test_date_with_addition_offset(self) -> None:
|
||||
template = "Time after 2 hours is: {% now 'UTC' + 'hours=2' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
result = builder.run()["prompt"]
|
||||
|
||||
now_plus_2 = f"Time after 2 hours is: {(arrow.now('UTC').shift(hours=+2)).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
assert now_plus_2 == result
|
||||
|
||||
def test_date_with_substraction_offset(self) -> None:
|
||||
template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
result = builder.run()["prompt"]
|
||||
|
||||
now_plus_2 = f"Time after 12 days is: {(arrow.now('UTC').shift(days=-12)).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
assert now_plus_2 == result
|
||||
|
||||
def test_invalid_timezone(self) -> None:
|
||||
template = "Current time is: {% now 'Invalid/Timezone' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
# Expect ValueError for invalid timezone
|
||||
with pytest.raises(ValueError, match="Invalid timezone"):
|
||||
builder.run()
|
||||
|
||||
def test_invalid_offset(self) -> None:
|
||||
template = "Time after invalid offset is: {% now 'UTC' + 'invalid_offset' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
# Expect ValueError for invalid offset
|
||||
with pytest.raises(ValueError, match="Invalid offset or operator"):
|
||||
builder.run()
|
||||
|
||||
104
test/utils/test_jinja2_extensions.py
Normal file
104
test/utils/test_jinja2_extensions.py
Normal file
@ -0,0 +1,104 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from jinja2 import Environment
|
||||
import arrow
|
||||
from haystack.utils import Jinja2TimeExtension
|
||||
|
||||
|
||||
class TestJinja2TimeExtension:
|
||||
@pytest.fixture
|
||||
def jinja_env(self) -> Environment:
|
||||
return Environment(extensions=[Jinja2TimeExtension])
|
||||
|
||||
@pytest.fixture
|
||||
def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension:
|
||||
return Jinja2TimeExtension(jinja_env)
|
||||
|
||||
def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime(
|
||||
"UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 19
|
||||
|
||||
def test_parse_valid_expression(self, jinja_env: Environment) -> None:
|
||||
template = "{% now 'UTC' + 'hours=2', '%Y-%m-%d %H:%M:%S' %}"
|
||||
result = jinja_env.from_string(template).render()
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 19
|
||||
|
||||
def test_get_datetime_no_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime("UTC")
|
||||
expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_get_datetime_with_offset_add(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime("UTC", operator="+", offset="hours=1")
|
||||
expected = arrow.now("UTC").shift(hours=1).strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_get_datetime_with_offset_subtract(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime("UTC", operator="-", offset="days=1")
|
||||
expected = arrow.now("UTC").shift(days=-1).strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_get_datetime_with_offset_subtract_days_hours(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime("UTC", operator="-", offset="days=1, hours=2")
|
||||
expected = arrow.now("UTC").shift(days=-1, hours=-2).strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_get_datetime_with_custom_format(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
result = jinja_extension._get_datetime("UTC", datetime_format="%d-%m-%Y")
|
||||
expected = arrow.now("UTC").strftime("%d-%m-%Y")
|
||||
assert result == expected
|
||||
|
||||
def test_get_datetime_new_york_timezone(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'America/New_York' %}")
|
||||
result = template.render()
|
||||
expected = arrow.now("America/New_York").strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_parse_no_operator(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'UTC' %}")
|
||||
result = template.render()
|
||||
expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_parse_with_add(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'UTC' + 'hours=2' %}")
|
||||
result = template.render()
|
||||
expected = arrow.now("UTC").shift(hours=2).strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_parse_with_subtract(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'UTC' - 'days=1' %}")
|
||||
result = template.render()
|
||||
expected = arrow.now("UTC").shift(days=-1).strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert result == expected
|
||||
|
||||
def test_parse_with_custom_format(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'UTC', '%d-%m-%Y' %}")
|
||||
result = template.render()
|
||||
expected = arrow.now("UTC").strftime("%d-%m-%Y")
|
||||
assert result == expected
|
||||
|
||||
def test_default_format(self, jinja_env: Environment) -> None:
|
||||
template = jinja_env.from_string("{% now 'UTC'%}")
|
||||
result = template.render()
|
||||
expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S") # default format
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_timezone(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
with pytest.raises(ValueError, match="Invalid timezone"):
|
||||
jinja_extension._get_datetime("Invalid/Timezone")
|
||||
|
||||
def test_invalid_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
with pytest.raises(ValueError, match="Invalid offset or operator"):
|
||||
jinja_extension._get_datetime("UTC", operator="+", offset="invalid_format")
|
||||
|
||||
def test_invalid_operator(self, jinja_extension: Jinja2TimeExtension) -> None:
|
||||
with pytest.raises(ValueError, match="Invalid offset or operator"):
|
||||
jinja_extension._get_datetime("UTC", operator="*", offset="hours=2")
|
||||
Loading…
x
Reference in New Issue
Block a user