feat: current_datetime shaper function (#5195)

* current_datetime shaper

* explicitly add current_datetime to the functions allowed in a prompt template
This commit is contained in:
ZanSara 2023-06-23 10:33:34 +02:00 committed by GitHub
parent 612c5cd005
commit 36192eca72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 3 deletions

View File

@ -1,6 +1,7 @@
from functools import reduce
import inspect
import re
from datetime import datetime
from string import Template
from typing import Literal, Optional, List, Dict, Any, Tuple, Union, Callable
@ -26,6 +27,19 @@ def rename(value: Any) -> Any:
return value
def current_datetime(format: str = "%H:%M:%S %d/%m/%y") -> str:
    """
    Function that outputs the current time and/or date formatted according to the parameters.

    The value comes from `datetime.now()` called without a timezone, so it is the naive local
    time of the machine running the code.

    :param format: A `strftime` format string. Defaults to `"%H:%M:%S %d/%m/%y"`. The name shadows
        the `format` builtin, but it is kept because callers pass it by keyword.
    :return: The current local date and/or time rendered according to `format`.

    Example (evaluated on 1 January 2023 at 12:30:10 local time):

    ```python
    assert current_datetime("%d.%m.%Y %H:%M:%S") == "01.01.2023 12:30:10"
    assert current_datetime("%d.%m.%y %H:%M:%S") == "01.01.23 12:30:10"
    ```
    """
    return datetime.now().strftime(format)
def value_to_list(value: Any, target_list: List[Any]) -> List[Any]:
"""
Transforms a value into a list containing this value as many times as the length of the target list.
@ -544,6 +558,7 @@ def documents_to_strings(
REGISTERED_FUNCTIONS: Dict[str, Callable[..., Any]] = {
"rename": rename,
"current_datetime": current_datetime,
"value_to_list": value_to_list,
"join_lists": join_lists,
"join_strings": join_strings,

View File

@ -22,6 +22,7 @@ from haystack.nodes.prompt.shapers import ( # pylint: disable=unused-import
BaseOutputParser,
AnswerParser,
to_strings,
current_datetime,
join, # used as shaping function
format_document,
format_answer,
@ -37,7 +38,10 @@ from haystack.environment import (
logger = logging.getLogger(__name__)
PROMPT_TEMPLATE_ALLOWED_FUNCTIONS = json.loads(
os.environ.get(HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS, '["join", "to_strings", "replace", "enumerate", "str"]')
os.environ.get(
HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS,
'["join", "to_strings", "replace", "enumerate", "str", "current_datetime"]',
)
)
PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS = {"new_line": "\n", "tab": "\t", "double_quote": '"', "carriage_return": "\r"}
PROMPT_TEMPLATE_STRIPS = ["'", '"']
@ -289,7 +293,8 @@ class _FstringParamsTransformer(ast.NodeTransformer):
def visit_FormattedValue(self, node: ast.FormattedValue) -> Optional[ast.AST]:
"""
Replaces the f-string expression with a unique ID and stores the corresponding expression in a dictionary.
If the expression is the raw `documents` variable, it is encapsulated into a call to `documents_to_strings` to ensure that the documents get rendered correctly.
If the expression is the raw `documents` variable, it is encapsulated into a call to `documents_to_strings`
to ensure that the documents get rendered correctly.
"""
super().generic_visit(node)

View File

@ -4,6 +4,7 @@ from haystack.schema import Answer, Document
from haystack.nodes.other.shaper import ( # pylint: disable=unused-import
Shaper,
current_datetime, # used as shaping function
join_documents_to_string as join, # used as shaping function
format_document,
format_answer,

View File

@ -1,6 +1,8 @@
import pytest
from datetime import datetime
import logging
import pytest
import haystack
from haystack import Pipeline, Document, Answer
from haystack.document_stores.memory import InMemoryDocumentStore
@ -191,6 +193,46 @@ def test_rename_yaml(tmp_path):
assert result["invocation_context"]["questions"] == "test query"
#
# current_datetime
#
@pytest.mark.unit
def test_current_datetime():
    """
    A Shaper configured with `func="current_datetime"` stores today's date, rendered with the
    requested format, under the declared output name in the invocation context.
    """
    date_format = "%y-%m-%d"
    shaper = Shaper(func="current_datetime", inputs={}, outputs=["date_time"], params={"format": date_format})

    # Read the clock on both sides of the call: if the run happens to straddle midnight, the
    # shaper's value matches one of the two samples instead of making the test fail spuriously.
    expected_before = datetime.now().strftime(date_format)
    results, _ = shaper.run()
    expected_after = datetime.now().strftime(date_format)

    assert results["invocation_context"]["date_time"] in {expected_before, expected_after}
# Checks the `current_datetime` shaper function when the Shaper is declared in a YAML pipeline
# config: the config is written to `tmp_path`, loaded with `Pipeline.load_from_yaml`, run, and the
# `date_time` entry of the resulting invocation context is compared with today's date.
@pytest.mark.unit
def test_current_datetime_yaml(tmp_path):
# Write a pipeline config with a single Shaper node that calls `current_datetime` with the
# format "%y-%m-%d" and exposes the result as `date_time`.
# NOTE(review): the literal below is an f-string but contains no placeholders.
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: current_datetime
params:
format: "%y-%m-%d"
outputs:
- date_time
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
# Build the pipeline from the file that was just written and run it without arguments.
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run()
# NOTE(review): the expected value is read from the clock after the run, so the comparison can
# fail if the run crosses midnight.
assert result["invocation_context"]["date_time"] == datetime.now().strftime("%y-%m-%d")
#
# value_to_list
#