feat: Add current date in UTC to PromptBuilder (#8233)

* initial commit

* add unit tests

* add release notes

* update function name
This commit is contained in:
Mo Sriha 2024-09-09 02:47:03 -05:00 committed by GitHub
parent e31b3edda1
commit 75955922b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 264 additions and 3 deletions

View File

@ -8,6 +8,7 @@ from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from haystack import component, default_to_dict
from haystack.utils import Jinja2TimeExtension
@component
@ -160,10 +161,10 @@ class PromptBuilder:
self._required_variables = required_variables
self.required_variables = required_variables or []
self._env = SandboxedEnvironment()
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
self.template = self._env.from_string(template)
if not variables:
# infere variables from template
# infer variables from template
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)

View File

@ -8,6 +8,7 @@ from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .docstore_deserialization import deserialize_document_store_in_init_params_inplace
from .expit import expit
from .filters import document_matches_filter
from .jinja2_extensions import Jinja2TimeExtension
from .jupyter import is_in_jupyter
from .requests_utils import request_with_retry
from .type_serialization import deserialize_type, serialize_type
@ -28,4 +29,5 @@ __all__ = [
"serialize_type",
"deserialize_type",
"deserialize_document_store_in_init_params_inplace",
"Jinja2TimeExtension",
]

View File

@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional, Union
import arrow
from jinja2 import Environment, nodes
from jinja2.ext import Extension
class Jinja2TimeExtension(Extension):
    """
    Jinja2 extension adding a ``{% now %}`` tag that renders the current date and time.

    Forms handled by :meth:`parse`:

    - ``{% now 'UTC' %}``: current time in the given timezone, rendered with the default format.
    - ``{% now 'UTC' + 'hours=2' %}`` or ``{% now 'UTC' - 'days=1' %}``: current time shifted by an offset.
    - ``{% now 'UTC', '%Y-%m-%d' %}``: current time rendered with a custom ``strftime`` format.
    """

    # Syntax for current date
    tags = {"now"}

    def __init__(self, environment: Environment):  # pylint: disable=useless-parent-delegation
        """
        Initializes the JinjaTimeExtension object.

        :param environment: The Jinja2 environment to initialize the extension with.
            It provides the context where the extension will operate.
        """
        super().__init__(environment)

    @staticmethod
    def _get_datetime(
        timezone: str,
        operator: Optional[str] = None,
        offset: Optional[str] = None,
        datetime_format: Optional[str] = None,
    ) -> str:
        """
        Get the current datetime based on timezone, apply any offset if provided, and format the result.

        :param timezone: The timezone string (e.g., 'UTC' or 'America/New_York') for which the current
            time should be fetched.
        :param operator: The operator ('+' or '-') to apply to the offset (used for adding/subtracting intervals).
            When an offset is given without an operator, '+' is used.
        :param offset: The offset string in the format 'interval=value' (e.g., 'hours=2,days=1') specifying how much
            to adjust the datetime. The intervals can be any valid interval accepted
            by Arrow (e.g., hours, days, weeks, months). Defaults to None if no adjustment is needed.
        :param datetime_format: The format string to use for formatting the output datetime.
            Defaults to '%Y-%m-%d %H:%M:%S' if not provided.
        :returns: The (optionally shifted) current datetime formatted as a string.
        :raises ValueError: If the timezone cannot be resolved, if the operator is not '+' or '-',
            or if the offset cannot be parsed or applied.
        """
        try:
            dt = arrow.now(timezone)
        except Exception as e:
            raise ValueError(f"Invalid timezone {timezone}: {e}") from e

        if offset:
            # An offset without an explicit operator is treated as an addition, as documented above.
            sign = operator or "+"
            if sign not in ("+", "-"):
                raise ValueError(f"Invalid offset or operator {offset}, {operator}: operator must be '+' or '-'")
            try:
                # Parse every 'interval=value' pair into a signed amount keyed by the interval name
                shift_params = {}
                for param in offset.split(","):
                    interval, value = param.split("=")
                    shift_params[interval.strip()] = float(sign + value.strip())
                # Shift the datetime fields based on the parsed offset
                dt = dt.shift(**shift_params)
            except (ValueError, AttributeError) as e:
                raise ValueError(f"Invalid offset or operator {offset}, {operator}: {e}") from e

        # Use the provided format or fallback to the default one
        datetime_format = datetime_format or "%Y-%m-%d %H:%M:%S"
        return dt.strftime(datetime_format)

    def parse(self, parser: Any) -> Union[nodes.Node, List[nodes.Node]]:
        """
        Parse the template expression to determine how to handle the datetime formatting.

        :param parser: The parser object that processes the template expressions and manages the syntax tree.
            It's used to interpret the template's structure.
        :returns: An output node that calls ``_get_datetime`` with the parsed timezone, operator,
            offset and format expressions.
        """
        # The first token is the tag name itself; it only contributes the line number
        lineno = next(parser.stream).lineno
        node = parser.parse_expression()

        # Check if a custom datetime format is provided after a comma
        datetime_format = parser.parse_expression() if parser.stream.skip_if("comma") else nodes.Const(None)

        if isinstance(node, nodes.Add):
            # `'tz' + 'offset'`: the left operand is the timezone, the right one the offset
            args = [node.left, nodes.Const("+"), node.right, datetime_format]
        elif isinstance(node, nodes.Sub):
            # `'tz' - 'offset'`
            args = [node.left, nodes.Const("-"), node.right, datetime_format]
        else:
            # Plain timezone expression: no operator and no offset
            args = [node, nodes.Const(None), nodes.Const(None), datetime_format]

        call_method = self.call_method("_get_datetime", args, lineno=lineno)
        return nodes.Output([call_method], lineno=lineno)

View File

@ -57,6 +57,7 @@ dependencies = [
"numpy<2",
"python-dateutil",
"haystack-experimental",
"arrow>=1.3.0" # Jinja2TimeExtension
]
[tool.hatch.envs.default]

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Add the current date and time inside a template in `PromptBuilder` using the `{% now 'UTC' %}` tag.
Users can also specify the date format, such as `{% now 'UTC', '%Y-%m-%d' %}`.

View File

@ -4,7 +4,8 @@
from typing import Any, Dict, List, Optional
from jinja2 import TemplateSyntaxError
import pytest
from unittest.mock import patch, MagicMock
import arrow
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack import component
from haystack.core.pipeline.pipeline import Pipeline
@ -254,3 +255,59 @@ class TestPromptBuilder:
"prompt_builder": {"prompt": "This is the dynamic prompt:\n Query: Where does the speaker live?"}
}
assert result == expected_dynamic
def test_with_custom_dateformat(self) -> None:
    """Render ``{% now %}`` with an explicit format and compare it against the current UTC date."""
    template = "Formatted date: {% now 'UTC', '%Y-%m-%d' %}"
    builder = PromptBuilder(template=template)

    # Sample the clock on both sides of the render so the test cannot fail
    # when the date rolls over between rendering and building the expectation.
    before = arrow.now("UTC").strftime("%Y-%m-%d")
    result = builder.run()["prompt"]
    after = arrow.now("UTC").strftime("%Y-%m-%d")

    assert result in (f"Formatted date: {before}", f"Formatted date: {after}")
def test_with_different_timezone(self) -> None:
    """Render ``{% now %}`` for a non-UTC timezone using the default format."""
    template = "Current time in New York is: {% now 'America/New_York' %}"
    builder = PromptBuilder(template=template)

    # The output has one-second resolution: sample the clock before and after
    # rendering so a tick in between does not make the test fail.
    before = arrow.now("America/New_York").strftime("%Y-%m-%d %H:%M:%S")
    result = builder.run()["prompt"]
    after = arrow.now("America/New_York").strftime("%Y-%m-%d %H:%M:%S")

    assert result in (f"Current time in New York is: {before}", f"Current time in New York is: {after}")
def test_date_with_addition_offset(self) -> None:
    """Render ``{% now %}`` with a positive offset of two hours."""
    template = "Time after 2 hours is: {% now 'UTC' + 'hours=2' %}"
    builder = PromptBuilder(template=template)

    # The output has one-second resolution: sample the clock before and after
    # rendering so a tick in between does not make the test fail.
    before = arrow.now("UTC").shift(hours=+2).strftime("%Y-%m-%d %H:%M:%S")
    result = builder.run()["prompt"]
    after = arrow.now("UTC").shift(hours=+2).strftime("%Y-%m-%d %H:%M:%S")

    assert result in (f"Time after 2 hours is: {before}", f"Time after 2 hours is: {after}")
def test_date_with_substraction_offset(self) -> None:
    """Render ``{% now %}`` with a negative offset of twelve days."""
    template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
    builder = PromptBuilder(template=template)

    # The output has one-second resolution: sample the clock before and after
    # rendering so a tick in between does not make the test fail.
    before = arrow.now("UTC").shift(days=-12).strftime("%Y-%m-%d %H:%M:%S")
    result = builder.run()["prompt"]
    after = arrow.now("UTC").shift(days=-12).strftime("%Y-%m-%d %H:%M:%S")

    assert result in (f"Time after 12 days is: {before}", f"Time after 12 days is: {after}")
def test_invalid_timezone(self) -> None:
    """An unknown timezone name must surface as a ValueError when the prompt is rendered."""
    prompt_builder = PromptBuilder(template="Current time is: {% now 'Invalid/Timezone' %}")

    # Building the component succeeds; only running it is expected to raise
    with pytest.raises(ValueError, match="Invalid timezone"):
        prompt_builder.run()
def test_invalid_offset(self) -> None:
    """An offset that is not in 'interval=value' form must surface as a ValueError on render."""
    prompt_builder = PromptBuilder(template="Time after invalid offset is: {% now 'UTC' + 'invalid_offset' %}")

    # Building the component succeeds; only running it is expected to raise
    with pytest.raises(ValueError, match="Invalid offset or operator"):
        prompt_builder.run()

View File

@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from jinja2 import Environment
import arrow
from haystack.utils import Jinja2TimeExtension
class TestJinja2TimeExtension:
    """Tests for ``Jinja2TimeExtension``: direct ``_get_datetime`` calls and ``{% now %}`` template rendering."""

    @pytest.fixture
    def jinja_env(self) -> Environment:
        return Environment(extensions=[Jinja2TimeExtension])

    @pytest.fixture
    def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension:
        return Jinja2TimeExtension(jinja_env)

    @staticmethod
    def _assert_matches_clock(
        produce, timezone: str = "UTC", datetime_format: str = "%Y-%m-%d %H:%M:%S", **shift
    ) -> None:
        """
        Call ``produce`` and check its result against the clock sampled before and after the call.

        The formatted values have at most one-second resolution, so comparing against a single clock
        reading taken after the call fails whenever the clock ticks in between. Accepting either
        reading removes that race.

        :param produce: Zero-argument callable returning the formatted datetime under test.
        :param timezone: Timezone used to build the expected values.
        :param datetime_format: ``strftime`` format used to build the expected values.
        :param shift: Keyword arguments forwarded to ``Arrow.shift`` for the expected values.
        """

        def clock() -> str:
            moment = arrow.now(timezone)
            if shift:
                moment = moment.shift(**shift)
            return moment.strftime(datetime_format)

        before = clock()
        actual = produce()
        after = clock()
        assert actual in (before, after)

    def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime(
            "UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"
        )
        assert isinstance(result, str)
        assert len(result) == 19

    def test_parse_valid_expression(self, jinja_env: Environment) -> None:
        template = "{% now 'UTC' + 'hours=2', '%Y-%m-%d %H:%M:%S' %}"
        result = jinja_env.from_string(template).render()
        assert isinstance(result, str)
        assert len(result) == 19

    def test_get_datetime_no_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
        self._assert_matches_clock(lambda: jinja_extension._get_datetime("UTC"))

    def test_get_datetime_with_offset_add(self, jinja_extension: Jinja2TimeExtension) -> None:
        self._assert_matches_clock(
            lambda: jinja_extension._get_datetime("UTC", operator="+", offset="hours=1"), hours=1
        )

    def test_get_datetime_with_offset_subtract(self, jinja_extension: Jinja2TimeExtension) -> None:
        self._assert_matches_clock(
            lambda: jinja_extension._get_datetime("UTC", operator="-", offset="days=1"), days=-1
        )

    def test_get_datetime_with_offset_subtract_days_hours(self, jinja_extension: Jinja2TimeExtension) -> None:
        self._assert_matches_clock(
            lambda: jinja_extension._get_datetime("UTC", operator="-", offset="days=1, hours=2"),
            days=-1,
            hours=-2,
        )

    def test_get_datetime_with_custom_format(self, jinja_extension: Jinja2TimeExtension) -> None:
        self._assert_matches_clock(
            lambda: jinja_extension._get_datetime("UTC", datetime_format="%d-%m-%Y"), datetime_format="%d-%m-%Y"
        )

    def test_get_datetime_new_york_timezone(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'America/New_York' %}")
        self._assert_matches_clock(template.render, timezone="America/New_York")

    def test_parse_no_operator(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' %}")
        self._assert_matches_clock(template.render)

    def test_parse_with_add(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' + 'hours=2' %}")
        self._assert_matches_clock(template.render, hours=2)

    def test_parse_with_subtract(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' - 'days=1' %}")
        self._assert_matches_clock(template.render, days=-1)

    def test_parse_with_custom_format(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC', '%d-%m-%Y' %}")
        self._assert_matches_clock(template.render, datetime_format="%d-%m-%Y")

    def test_default_format(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC'%}")
        # No format in the template: the helper's default matches the extension's default format
        self._assert_matches_clock(template.render)

    def test_invalid_timezone(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid timezone"):
            jinja_extension._get_datetime("Invalid/Timezone")

    def test_invalid_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid offset or operator"):
            jinja_extension._get_datetime("UTC", operator="+", offset="invalid_format")

    def test_invalid_operator(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid offset or operator"):
            jinja_extension._get_datetime("UTC", operator="*", offset="hours=2")