feat: PromptTemplate extensions (#4378)

* use outputshapers in prompttemplate

* fix pylint

* first iteration on regex

* implement new promptnode syntax based on f-strings

* finish fstring implementation

* add additional tests

* add security tests

* fix mypy

* fix pylint

* fix test_prompt_templates

* fix test_prompt_template_repr

* fix test_prompt_node_with_custom_invocation_layer

* fix test_invalid_template

* more security tests

* fix test_complex_pipeline_with_all_features

* fix agent tests

* refactor get_prompt_template

* fix test_prompt_template_syntax_parser

* fix test_complex_pipeline_with_all_features

* allow functions in comprehensions

* break out of fstring test

* fix additional tests

* mark new tests as unit tests

* fix agents tests

* convert missing templates

* proper use of get_prompt_template

* refactor and add docstrings

* fix tests

* fix pylint

* fix agents test

* fix tests

* refactor globals

* make allowed functions configurable via env variable

* better dummy variable

* fix special alias

* don't replace special char variables

* more special chars, better docstrings

* cherrypick fix audio tests

* fix test

* rework shapers

* fix pylint

* fix tests

* add new templates

* add reference parsing

* add more shaper tests

* add tests for join and to_string

* fix pylint

* fix pylint

* fix pylint for real

* auto fill shaper function params

* fix reference parsing for multiple references

* fix output variable inference

* consolidate qa prompt template output and make shaper work per-document

* fix types after merge

* introduce output_parser

* fix tests

* better docstring

* rename RegexAnswerParser to AnswerParser

* better docstrings
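
Taken together, these changes switch prompt templates from `$variable` substitution to f-string expressions with a restricted set of callable functions, and let a template post-process the raw model output through an output parser. A minimal sketch of the resulting usage, based on the code in the diffs below; the model name, document content, and import path are placeholder assumptions:

```python
from haystack.nodes.prompt.prompt_node import AnswerParser, PromptNode, PromptTemplate
from haystack.schema import Document

# f-string style prompt: {documents} and {query} are prompt parameters,
# join() is one of the allowed in-template functions, and AnswerParser
# converts the raw model strings into Answer objects.
qa_template = PromptTemplate(
    name="my-question-answering",
    prompt_text="Given the context please answer the question. "
    "Context: {join(documents)}; Question: {query}; Answer:",
    output_parser=AnswerParser(),
)

node = PromptNode("google/flan-t5-base", default_prompt_template=qa_template)
answers = node.prompt(
    qa_template,
    documents=[Document(content="Berlin is the capital of Germany.")],
    query="What is the capital of Germany?",
)  # a list of Answer objects rather than plain strings
```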
tstadel 2023-03-27 12:14:11 +02:00 committed by GitHub
parent 9518bcb7a8
commit 382ca8094e
10 changed files with 1593 additions and 301 deletions

View File

@ -163,9 +163,7 @@ class Agent:
)
)
self.prompt_node = prompt_node
self.prompt_template = (
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
)
self.prompt_template = prompt_node.get_prompt_template(prompt_template)
self.tools = {tool.name: tool for tool in tools} if tools else {}
self.tool_names = ", ".join(self.tools.keys())
self.tool_names_with_descriptions = "\n".join(

View File

@ -15,6 +15,9 @@ HAYSTACK_REMOTE_API_BACKOFF_SEC = "HAYSTACK_REMOTE_API_BACKOFF_SEC"
HAYSTACK_REMOTE_API_MAX_RETRIES = "HAYSTACK_REMOTE_API_MAX_RETRIES"
HAYSTACK_REMOTE_API_TIMEOUT_SEC = "HAYSTACK_REMOTE_API_TIMEOUT_SEC"
HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS = "HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS"
env_meta_data: Dict[str, Any] = {}
logger = logging.getLogger(__name__)

View File

@ -92,12 +92,11 @@ class OpenAIAnswerGenerator(BaseGenerator):
PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"\n===\nContext: $examples_context\n===\n$examples\n\n"
"===\nContext: $context\n===\n$query",
prompt_params=["examples_context", "examples", "context", "query"],
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
)
```
To learn how variables, such as '$context', are substituted in the `prompt_text`, see
To learn how variables, such as '{context}', are substituted in the `prompt_text`, see
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
:param context_join_str: The separation string used to join the input documents to create the context
used by the PromptTemplate.
@ -118,9 +117,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
prompt_template = PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"\n===\nContext: $examples_context\n===\n$examples\n\n"
"===\nContext: $context\n===\n$query",
prompt_params=["examples_context", "examples", "context", "query"],
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
)
else:
# Check for required prompts

View File

@ -1,4 +1,8 @@
from typing import Optional, List, Dict, Any, Tuple, Union, Callable
from functools import reduce
import inspect
import re
from string import Template
from typing import Literal, Optional, List, Dict, Any, Tuple, Union, Callable
import logging
@ -9,49 +13,49 @@ from haystack.schema import Document, Answer, MultiLabel
logger = logging.getLogger(__name__)
def rename(value: Any) -> Tuple[Any]:
def rename(value: Any) -> Any:
"""
Identity function. Can be used to rename values in the invocation context without changing them.
Example:
```python
assert rename(1) == (1, )
assert rename(1) == 1
```
"""
return (value,)
return value
def value_to_list(value: Any, target_list: List[Any]) -> Tuple[List[Any]]:
def value_to_list(value: Any, target_list: List[Any]) -> List[Any]:
"""
Transforms a value into a list containing this value as many times as the length of the target list.
Example:
```python
assert value_to_list(value=1, target_list=list(range(5))) == ([1, 1, 1, 1, 1], )
assert value_to_list(value=1, target_list=list(range(5))) == [1, 1, 1, 1, 1]
```
"""
return ([value] * len(target_list),)
return [value] * len(target_list)
def join_lists(lists: List[List[Any]]) -> Tuple[List[Any]]:
def join_lists(lists: List[List[Any]]) -> List[Any]:
"""
Joins the passed lists into a single one.
Example:
```python
assert join_lists(lists=[[1, 2, 3], [4, 5]]) == ([1, 2, 3, 4, 5], )
assert join_lists(lists=[[1, 2, 3], [4, 5]]) == [1, 2, 3, 4, 5]
```
"""
merged_list = []
for inner_list in lists:
merged_list += inner_list
return (merged_list,)
return merged_list
def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
def join_strings(strings: List[str], delimiter: str = " ", str_replace: Optional[Dict[str, str]] = None) -> str:
"""
Transforms a list of strings into a single string. The content of this string
is the content of all original strings separated by the delimiter you specify.
@ -59,16 +63,42 @@ def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
Example:
```python
assert join_strings(strings=["first", "second", "third"], delimiter=" - ") == ("first - second - third", )
assert join_strings(strings=["first", "second", "third"], delimiter=" - ", str_replace={"r": "R"}) == "fiRst - second - thiRd"
```
"""
return (delimiter.join(strings),)
str_replace = str_replace or {}
return delimiter.join([format_string(string, str_replace) for string in strings])
def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[List[Document]]:
def format_string(string: str, str_replace: Optional[Dict[str, str]] = None) -> str:
"""
Transforms a string using a substitution dict.
Example:
```python
assert format_string(string="first", str_replace={"r": "R"}) == "fiRst"
```
"""
str_replace = str_replace or {}
return reduce(lambda s, kv: s.replace(*kv), str_replace.items(), string)
def join_documents(
documents: List[Document],
delimiter: str = " ",
pattern: Optional[str] = None,
str_replace: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
Transforms a list of documents into a list containing a single Document. The content of this list
is the content of all original documents separated by the delimiter you specify.
is the joined result of all original documents separated by the delimiter you specify.
How each document is represented is controlled by the pattern parameter.
You can use the following placeholders:
- $content: the content of the document
- $idx: the index of the document in the list
- $id: the id of the document
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
All metadata is dropped. (TODO: fix)
@ -81,31 +111,333 @@ def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[Lis
Document(content="second"),
Document(content="third")
],
delimiter=" - "
) == ([Document(content="first - second - third")], )
delimiter=" - ",
pattern="[$idx] $content",
str_replace={"r": "R"}
) == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
```
"""
return ([Document(content=delimiter.join([d.content for d in documents]))],)
return [Document(content=join_documents_to_string(documents, delimiter, pattern, str_replace))]
def strings_to_answers(strings: List[str]) -> Tuple[List[Answer]]:
def format_document(
document: Document,
pattern: Optional[str] = None,
str_replace: Optional[Dict[str, str]] = None,
idx: Optional[int] = None,
) -> str:
"""
Transforms a list of strings into a list of Answers.
Transforms a document into a single string.
How the document is represented is controlled by the pattern parameter.
You can use the following placeholders:
- $content: the content of the document
- $idx: the index of the document in the list
- $id: the id of the document
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
Example:
```python
assert strings_to_answers(strings=["first", "second", "third"]) == ([
Answer(answer="first"),
Answer(answer="second"),
Answer(answer="third"),
], )
assert format_document(
document=Document(content="first"),
pattern="prefix [$idx] $content",
str_replace={"r": "R"},
idx=1,
) == "prefix [1] fiRst"
```
"""
return ([Answer(answer=string, type="generative") for string in strings],)
str_replace = str_replace or {}
pattern = pattern or "$content"
template = Template(pattern)
pattern_params = [
match.groupdict().get("named", match.groupdict().get("braced"))
for match in template.pattern.finditer(template.template)
]
meta_params = [param for param in pattern_params if param and param not in ["content", "idx", "id"]]
content = template.substitute(
{
"idx": idx,
"content": reduce(lambda content, kv: content.replace(*kv), str_replace.items(), document.content),
"id": reduce(lambda id, kv: id.replace(*kv), str_replace.items(), document.id),
**{
k: reduce(lambda val, kv: val.replace(*kv), str_replace.items(), document.meta.get(k, ""))
for k in meta_params
},
}
)
return content
def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
def format_answer(
answer: Answer,
pattern: Optional[str] = None,
str_replace: Optional[Dict[str, str]] = None,
idx: Optional[int] = None,
) -> str:
"""
Transforms an answer into a single string.
How the answer is represented is controlled by the pattern parameter.
You can use the following placeholders:
- $answer: the answer text of the answer
- $idx: the index of the answer in the list
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
Example:
```python
assert format_answer(
answer=Answer(answer="first"),
pattern="prefix [$idx] $answer",
str_replace={"r": "R"},
idx=1,
) == "prefix [1] fiRst"
```
"""
str_replace = str_replace or {}
pattern = pattern or "$answer"
template = Template(pattern)
pattern_params = [
match.groupdict().get("named", match.groupdict().get("braced"))
for match in template.pattern.finditer(template.template)
]
meta_params = [param for param in pattern_params if param and param not in ["answer", "idx"]]
meta = answer.meta or {}
content = template.substitute(
{
"idx": idx,
"answer": reduce(lambda content, kv: content.replace(*kv), str_replace.items(), answer.answer),
**{k: reduce(lambda val, kv: val.replace(*kv), str_replace.items(), meta.get(k, "")) for k in meta_params},
}
)
return content
def join_documents_to_string(
documents: List[Document],
delimiter: str = " ",
pattern: Optional[str] = None,
str_replace: Optional[Dict[str, str]] = None,
) -> str:
"""
Transforms a list of documents into a single string. The content of this string
is the joined result of all original documents separated by the delimiter you specify.
How each document is represented is controlled by the pattern parameter.
You can use the following placeholders:
- $content: the content of the document
- $idx: the index of the document in the list
- $id: the id of the document
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
Example:
```python
assert join_documents_to_string(
documents=[
Document(content="first"),
Document(content="second"),
Document(content="third")
],
delimiter=" - ",
pattern="[$idx] $content",
str_replace={"r": "R"}
) == "[1] fiRst - [2] second - [3] thiRd"
```
"""
content = delimiter.join(
format_document(doc, pattern, str_replace, idx=idx) for idx, doc in enumerate(documents, start=1)
)
return content
def strings_to_answers(
strings: List[str],
prompts: Optional[List[Union[str, List[Dict[str, str]]]]] = None,
documents: Optional[List[Document]] = None,
pattern: Optional[str] = None,
reference_pattern: Optional[str] = None,
reference_mode: Literal["index", "id", "meta"] = "index",
reference_meta_field: Optional[str] = None,
) -> List[Answer]:
"""
Transforms a list of strings into a list of Answers.
Specify `reference_pattern` to populate the answer's `document_ids` by extracting document references from the strings.
:param strings: The list of strings to transform.
:param prompts: The prompts used to generate the answers.
:param documents: The documents used to generate the answers.
:param pattern: The regex pattern to use for parsing the answer.
Examples:
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
:param reference_pattern: The regex pattern to use for parsing the document references.
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
If None, no parsing is done and all documents are referenced.
:param reference_mode: The mode used to reference documents. Supported modes are:
- index: the document references are the one-based index of the document in the list of documents.
Example: "this is an answer[1]" will reference the first document in the list of documents.
- id: the document references are the document ids.
Example: "this is an answer[123]" will reference the document with id "123".
- meta: the document references are the value of a metadata field of the document.
Example: "this is an answer[123]" will reference the document with the value "123" in the metadata field specified by reference_meta_field.
:param reference_meta_field: The name of the metadata field to use for document references in reference_mode "meta".
:return: The list of answers.
Examples:
Without reference parsing:
```python
assert strings_to_answers(strings=["first", "second", "third"], prompt="prompt", documents=[Document(id="123", content="content")]) == [
Answer(answer="first", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
Answer(answer="second", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
Answer(answer="third", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
]
```
With reference parsing:
```python
assert strings_to_answers(strings=["first[1]", "second[2]", "third[1][3]"], prompt="prompt",
documents=[Document(id="123", content="content"), Document(id="456", content="content"), Document(id="789", content="content")],
reference_pattern=r"\\[(\\d+)\\]",
reference_mode="index"
) == [
Answer(answer="first", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
Answer(answer="second", type="generative", document_ids=["456"], meta={"prompt": "prompt"}),
Answer(answer="third", type="generative", document_ids=["123", "789"], meta={"prompt": "prompt"}),
]
```
"""
if prompts:
if len(prompts) == 1:
# one prompt for all strings/documents
documents_per_string: List[Optional[List[Document]]] = [documents] * len(strings)
prompt_per_string: List[Optional[Union[str, List[Dict[str, str]]]]] = [prompts[0]] * len(strings)
elif len(prompts) > 1 and len(strings) % len(prompts) == 0:
# one prompt per string/document
if documents is not None and len(documents) != len(prompts):
raise ValueError("The number of documents must match the number of prompts")
string_multiplier = len(strings) // len(prompts)
documents_per_string = (
[[doc] for doc in documents for _ in range(string_multiplier)] if documents else [None] * len(strings)
)
prompt_per_string = [prompt for prompt in prompts for _ in range(string_multiplier)]
else:
raise ValueError("The number of prompts must be one or a multiple of the number of strings")
else:
documents_per_string = [documents] * len(strings)
prompt_per_string = [None] * len(strings)
answers = []
for string, prompt, _documents in zip(strings, prompt_per_string, documents_per_string):
answer = string_to_answer(
string=string,
prompt=prompt,
documents=_documents,
pattern=pattern,
reference_pattern=reference_pattern,
reference_mode=reference_mode,
reference_meta_field=reference_meta_field,
)
answers.append(answer)
return answers
def string_to_answer(
string: str,
prompt: Optional[Union[str, List[Dict[str, str]]]],
documents: Optional[List[Document]],
pattern: Optional[str] = None,
reference_pattern: Optional[str] = None,
reference_mode: Literal["index", "id", "meta"] = "index",
reference_meta_field: Optional[str] = None,
) -> Answer:
"""
Transforms a string into an Answer.
Specify `reference_pattern` to populate the answer's `document_ids` by extracting document references from the string.
:param string: The string to transform.
:param prompt: The prompt used to generate the answer.
:param documents: The documents used to generate the answer.
:param pattern: The regex pattern to use for parsing the answer.
Examples:
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
:param reference_pattern: The regex pattern to use for parsing the document references.
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
If None, no parsing is done and all documents are referenced.
:param reference_mode: The mode used to reference documents. Supported modes are:
- index: the document references are the one-based index of the document in the list of documents.
Example: "this is an answer[1]" will reference the first document in the list of documents.
- id: the document references are the document ids.
Example: "this is an answer[123]" will reference the document with id "123".
- meta: the document references are the value of a metadata field of the document.
Example: "this is an answer[123]" will reference the document with the value "123" in the metadata field specified by reference_meta_field.
:param reference_meta_field: The name of the metadata field to use for document references in reference_mode "meta".
:return: The answer
"""
if reference_mode == "index":
candidates = {str(idx): doc.id for idx, doc in enumerate(documents, start=1)} if documents else {}
elif reference_mode == "id":
candidates = {doc.id: doc.id for doc in documents} if documents else {}
elif reference_mode == "meta":
if not reference_meta_field:
raise ValueError("reference_meta_field must be specified when reference_mode is 'meta'")
candidates = (
{doc.meta[reference_meta_field]: doc.id for doc in documents if doc.meta.get(reference_meta_field)}
if documents
else {}
)
else:
raise ValueError(f"Invalid document_id_mode: {reference_mode}")
if pattern:
match = re.search(pattern, string)
if match:
if not match.lastindex:
# no group in pattern -> take the whole match
string = match.group(0)
elif match.lastindex == 1:
# one group in pattern -> take the group
string = match.group(1)
else:
# more than one group in pattern -> raise error
raise ValueError(f"Pattern must have at most one group: {pattern}")
else:
string = ""
document_ids = parse_references(string=string, reference_pattern=reference_pattern, candidates=candidates)
answer = Answer(answer=string, type="generative", document_ids=document_ids, meta={"prompt": prompt})
return answer
def parse_references(
string: str, reference_pattern: Optional[str] = None, candidates: Optional[Dict[str, str]] = None
) -> Optional[List[str]]:
"""
Parses an answer string for document references and returns the document ids of the referenced documents.
:param string: The string to parse.
:param reference_pattern: The regex pattern to use for parsing the document references.
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
If None, no parsing is done and all candidate document ids are returned.
:param candidates: A dictionary of candidates to choose from. The keys are the reference strings and the values are the document ids.
If None, no parsing is done and None is returned.
:return: A list of document ids.
"""
if not candidates:
return None
if not reference_pattern:
return list(candidates.values())
document_idxs = re.findall(reference_pattern, string)
return [candidates[idx] for idx in document_idxs if idx in candidates]
def answers_to_strings(
answers: List[Answer], pattern: Optional[str] = None, str_replace: Optional[Dict[str, str]] = None
) -> List[str]:
"""
Extracts the answer field of Answers and returns a list of strings.
@ -117,18 +449,20 @@ def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
Answer(answer="first"),
Answer(answer="second"),
Answer(answer="third")
]
) == (["first", "second", "third"],)
],
pattern="[$idx] $answer",
str_replace={"r": "R"}
) == ["[1] fiRst", "[2] second", "[3] thiRd"]
```
"""
return ([answer.answer for answer in answers],)
return [format_answer(answer, pattern, str_replace, idx) for idx, answer in enumerate(answers, start=1)]
def strings_to_documents(
strings: List[str],
meta: Union[List[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> Tuple[List[Document]]:
) -> List[Document]:
"""
Transforms a list of strings into a list of Documents. If you pass the metadata in a single
dictionary, all Documents get the same metadata. If you pass the metadata as a list, the length of this list
@ -142,11 +476,11 @@ def strings_to_documents(
strings=["first", "second", "third"],
meta=[{"position": i} for i in range(3)],
id_hash_keys=['content', 'meta']
) == ([
) == [
Document(content="first", metadata={"position": 1}, id_hash_keys=['content', 'meta']),
Document(content="second", metadata={"position": 2}, id_hash_keys=['content', 'meta']),
Document(content="third", metadata={"position": 3}, id_hash_keys=['content', 'meta'])
], )
]
```
"""
all_metadata: List[Optional[Dict[str, Any]]]
@ -161,10 +495,12 @@ def strings_to_documents(
else:
all_metadata = [None] * len(strings)
return ([Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)],)
return [Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)]
def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
def documents_to_strings(
documents: List[Document], pattern: Optional[str] = None, str_replace: Optional[Dict[str, str]] = None
) -> List[str]:
"""
Extracts the content field of Documents and returns a list of strings.
@ -176,14 +512,16 @@ def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
Document(content="first"),
Document(content="second"),
Document(content="third")
]
) == (["first", "second", "third"],)
],
pattern="[$idx] $content",
str_replace={"r": "R"}
) == ["[1] fiRst", "[2] second", "[3] thiRd"]
```
"""
return ([doc.content for doc in documents],)
return [format_document(doc, pattern, str_replace, idx) for idx, doc in enumerate(documents, start=1)]
REGISTERED_FUNCTIONS: Dict[str, Callable[..., Tuple[Any]]] = {
REGISTERED_FUNCTIONS: Dict[str, Callable[..., Any]] = {
"rename": rename,
"value_to_list": value_to_list,
"join_lists": join_lists,
@ -389,6 +727,16 @@ class Shaper(BaseComponent):
if value in invocation_context.keys() and value is not None:
input_values[key] = invocation_context[value]
# auto fill in input values if there's an invocation context value with the same name
function_params = inspect.signature(self.function).parameters
for parameter in function_params.values():
if (
parameter.name not in input_values.keys()
and parameter.name not in self.params.keys()
and parameter.name in invocation_context.keys()
):
input_values[parameter.name] = invocation_context[parameter.name]
input_values = {**self.params, **input_values}
try:
logger.debug(
@ -397,6 +745,8 @@ class Shaper(BaseComponent):
", ".join([f"{key}={value}" for key, value in input_values.items()]),
)
output_values = self.function(**input_values)
if not isinstance(output_values, tuple):
output_values = (output_values,)
except TypeError as e:
raise ValueError(
"Shaper couldn't apply the function to your inputs and parameters. "

View File

@ -1,17 +1,27 @@
import ast
from collections import defaultdict
import copy
import logging
import re
from abc import ABC
from string import Template
import os
from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type, overload
from uuid import uuid4
import torch
from haystack import MultiLabel
from haystack.environment import HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS
from haystack.errors import NodeError
from haystack.nodes.base import BaseComponent
from haystack.nodes.other.shaper import ( # pylint: disable=unused-import
Shaper,
join_documents_to_string as join, # used as shaping function
format_document,
format_answer,
format_string,
)
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, instruction_following_models
from haystack.schema import Document
from haystack.schema import Answer, Document
from haystack.telemetry_2 import send_event
logger = logging.getLogger(__name__)
@ -43,6 +53,169 @@ class BasePromptTemplate(BaseComponent):
raise NotImplementedError("This method should never be implemented in the derived class")
PROMPT_TEMPLATE_ALLOWED_FUNCTIONS = ast.literal_eval(
os.environ.get(HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS, '["join", "to_strings", "replace", "enumerate", "str"]')
)
PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS = {"new_line": "\n", "tab": "\t", "double_quote": '"', "carriage_return": "\r"}
PROMPT_TEMPLATE_STRIPS = ["'", '"']
PROMPT_TEMPLATE_STR_REPLACE = {'"': "'"}
def to_strings(items: List[Union[str, Document, Answer]], pattern=None, str_replace=None) -> List[str]:
results = []
for idx, item in enumerate(items, start=1):
if isinstance(item, str):
results.append(format_string(item, str_replace=str_replace))
elif isinstance(item, Document):
results.append(format_document(document=item, pattern=pattern, str_replace=str_replace, idx=idx))
elif isinstance(item, Answer):
results.append(format_answer(answer=item, pattern=pattern, str_replace=str_replace, idx=idx))
else:
raise ValueError(f"Unsupported item type: {type(item)}")
return results
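
For reference, here is roughly what the `to_strings` helper defined just above returns for each supported item type; the values follow from the `format_string`/`format_document`/`format_answer` examples in the Shaper module and are illustrative:

```python
from haystack.schema import Answer, Document

# Plain strings pass through format_string, Documents through format_document,
# Answers through format_answer; idx is the item's 1-based position in the list.
assert to_strings(["plain text"]) == ["plain text"]
assert to_strings(
    [Document(content="first"), Document(content="second")], pattern="[$idx] $content"
) == ["[1] first", "[2] second"]
assert to_strings([Answer(answer="first")], pattern="[$idx] $answer") == ["[1] first"]
```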
class PromptTemplateValidationError(NodeError):
"""
Error raised when a prompt template is invalid.
"""
pass
class _ValidationVisitor(ast.NodeVisitor):
"""
This class is used to validate the prompt text for a prompt template.
It checks that the prompt text is a valid f-string and that it only uses allowed functions.
Useful information extracted from the AST is stored in the class attributes (e.g. `prompt_params` and `used_functions`)
"""
def __init__(self, prompt_template_name: str):
self.used_names: List[str] = []
self.comprehension_targets: List[str] = []
self.used_functions: List[str] = []
self.prompt_template_name = prompt_template_name
@property
def prompt_params(self) -> List[str]:
"""
The names of the variables used in the prompt text.
E.g. for the prompt text `f"Hello {name}"`, the prompt_params would be `["name"]`
"""
return list(set(self.used_names) - set(self.used_functions) - set(self.comprehension_targets))
def visit_Name(self, node: ast.Name) -> None:
"""
Stores the name of the variable used in the prompt text. This also includes function and method names.
E.g. for the prompt text `f"Hello {func(name)}"`, the used_names would be `["func", "name"]`
"""
self.used_names.append(node.id)
def visit_comprehension(self, node: ast.comprehension) -> None:
"""
Stores the name of the variable used in comprehensions.
E.g. for the prompt text `f"Hello {[name for name in names]}"`, the comprehension_targets would be `["name"]`
"""
super().generic_visit(node)
if isinstance(node.target, ast.Name):
self.comprehension_targets.append(node.target.id)
elif isinstance(node.target, ast.Tuple):
self.comprehension_targets.extend([elt.id for elt in node.target.elts if isinstance(elt, ast.Name)])
def visit_Call(self, node: ast.Call) -> None:
"""
Stores the name of functions and methods used in the prompt text and validates that only allowed functions are used.
E.g. for the prompt text `f"Hello {func(name)}"`, the used_functions would be `["func"]`
raises: PromptTemplateValidationError if an invalid function is used in the prompt text
"""
super().generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS:
# functions: func(args, kwargs)
self.used_functions.append(node.func.id)
elif isinstance(node.func, ast.Attribute) and node.func.attr in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS:
# methods: instance.method(args, kwargs)
self.used_functions.append(node.func.attr)
else:
raise PromptTemplateValidationError(
f"Invalid function in prompt text for prompt template {self.prompt_template_name}. "
f"Allowed functions are {PROMPT_TEMPLATE_ALLOWED_FUNCTIONS}."
)
class _FstringParamsTransformer(ast.NodeTransformer):
"""
This class is used to transform an AST for f-strings into a format that can be used by the PromptTemplate.
It replaces all f-string expressions with a unique id and stores the corresponding expression in a dictionary.
The stored expressions can be evaluated with `eval`, given the `prompt_params` (see `_ValidationVisitor`).
PromptTemplate determines the number of prompts to generate and renders them using the evaluated expressions.
"""
def __init__(self):
self.prompt_params_functions: Dict[str, ast.Expression] = {}
def visit_FormattedValue(self, node: ast.FormattedValue) -> Optional[ast.AST]:
"""
Replaces the f-string expression with a unique id and stores the corresponding expression in a dictionary.
If the expression is the raw `documents` or `answers` variable, it is wrapped in a call to `to_strings` so that the items get rendered correctly.
"""
super().generic_visit(node)
# Keep special char variables as is. They are available via globals.
if isinstance(node.value, ast.Name) and node.value.id in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS:
return node
id = uuid4().hex
if isinstance(node.value, ast.Name) and node.value.id in ["documents", "answers"]:
call = ast.Call(func=ast.Name(id="to_strings", ctx=ast.Load()), args=[node.value], keywords=[])
self.prompt_params_functions[id] = ast.fix_missing_locations(ast.Expression(body=call))
else:
self.prompt_params_functions[id] = ast.fix_missing_locations(ast.Expression(body=node.value))
return ast.FormattedValue(
value=ast.Name(id=id, ctx=ast.Load()), conversion=node.conversion, format_spec=node.format_spec
)
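
This validator backs the security tests listed in the commit message: an f-string expression in a prompt can only call functions named in `PROMPT_TEMPLATE_ALLOWED_FUNCTIONS` (overridable through the `HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS` environment variable), and anything else is rejected when the template is constructed. A sketch, assuming the default allowed-function list and the classes defined in this module:

```python
# Allowed by default: join, to_strings, replace, enumerate, str.
PromptTemplate(name="ok", prompt_text="Context: {join(documents)}; Question: {query}; Answer:")

# Any other call is rejected while the template is parsed, before a model is ever invoked.
try:
    PromptTemplate(name="not-ok", prompt_text="{__import__('os').system('ls')}")
except PromptTemplateValidationError as exc:
    print(exc)  # Invalid function in prompt text for prompt template not-ok. ...
```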
class BaseOutputParser(Shaper):
"""
An output parser defines how a `PromptTemplate` parses the model output and converts it into Haystack primitives.
BaseOutputParser is the base class for output parser implementations.
"""
@property
def output_variable(self) -> Optional[str]:
return self.outputs[0]
class AnswerParser(BaseOutputParser):
"""
AnswerParser parses the model output with regex patterns and wraps the extracted answer in a proper `Answer` object.
It enriches the `Answer` object with the prompts used and the document_ids of the documents that were used to generate the answer.
You can pass a reference_pattern to extract the document_ids of the answer from the model output.
"""
def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None):
"""
:param pattern: The regex pattern to use for parsing the answer.
Examples:
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
:param reference_pattern: The regex pattern to use for parsing the document references.
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
If None, no parsing is done and all documents are referenced.
"""
self.pattern = pattern
self.reference_pattern = reference_pattern
super().__init__(
func="strings_to_answers",
inputs={"strings": "results"},
outputs=["answers"],
params={"pattern": pattern, "reference_pattern": reference_pattern},
)
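
A sketch of what `AnswerParser` produces when run on an invocation context; the document id, prompt text, and reference pattern below are illustrative, and `documents` and `prompts` are picked up automatically by the Shaper auto-fill logic (`Document` comes from `haystack.schema`):

```python
parser = AnswerParser(reference_pattern=r"\[(\d+)\]")

results, _ = parser.run(
    invocation_context={
        "results": ["Berlin[1]"],
        "documents": [Document(id="d1", content="Berlin is the capital of Germany.")],
        "prompts": ["Given the context please answer the question. ..."],
    }
)
answer = results["invocation_context"]["answers"][0]
# answer.answer == "Berlin[1]" (no `pattern` given, so the whole string is kept),
# answer.document_ids == ["d1"], and answer.meta["prompt"] records the prompt used.
```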
class PromptTemplate(BasePromptTemplate, ABC):
"""
PromptTemplate is a template for a prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what such a prompt template may look like:
@ -50,54 +223,67 @@ class PromptTemplate(BasePromptTemplate, ABC):
```python
PromptTemplate(name="sentiment-analysis",
prompt_text="Give a sentiment for this context. Answer with positive, negative
or neutral. Context: $documents; Answer:")
or neutral. Context: {documents}; Answer:")
```
Optionally, you can declare prompt parameters in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
the prompt_text for the model to perform the task. For example, in the template above, there's one prompt parameter, `documents`. You declare prompt parameters by adding variables to the prompt text. These variables should be in the format: `$variable`. In the template above, the variable is `$documents`.
Optionally, you can declare prompt parameters using f-string syntax in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
the prompt_text for the model to perform the task. For example, in the template above, there's one prompt parameter, `documents`. You declare prompt parameters by adding variables to the prompt text. These variables should be in the format: `{variable}`. In the template above, the variable is `{documents}`.
At runtime, these variables are filled in with arguments passed to the `fill()` method of the PromptTemplate. So in the example above, the `$documents` variable will be filled with the Documents whose sentiment you want the model to analyze.
At runtime, these variables are filled in with arguments passed to the `fill()` method of the PromptTemplate. So in the example above, the `{documents}` variable will be filled with the Documents whose sentiment you want the model to analyze.
Note that, unlike strict f-string syntax, you can safely use the backslash characters `\n`, `\t`, and `\r` in the text parts of the prompt text.
If you want to use them inside f-string expressions, use `new_line`, `tab`, and `carriage_return` instead.
Double quotes (`"`) are automatically replaced with single quotes (`'`) in the prompt text. If you want to keep double quotes in the prompt text, use `{double_quote}` instead.
For more details on how to use PromptTemplate, see
[PromptNode](https://docs.haystack.deepset.ai/docs/prompt_node).
"""
def __init__(self, name: str, prompt_text: str, prompt_params: Optional[List[str]] = None):
def __init__(self, name: str, prompt_text: str, output_parser: Optional[BaseOutputParser] = None):
"""
Creates a PromptTemplate instance.
:param name: The name of the prompt template (for example, sentiment-analysis, question-generation). You can specify your own name but it must be unique.
:param prompt_text: The prompt text, including prompt parameters.
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
:param output_parser: A parser that will be applied to the model output.
For example, if you want to convert the model output to an Answer object, you can use `AnswerParser`.
"""
super().__init__()
if prompt_params:
self.prompt_params = prompt_params
else:
# Define the regex pattern to match the strings after the $ character
pattern = r"\$([a-zA-Z0-9_]+)"
self.prompt_params = re.findall(pattern, prompt_text)
if prompt_text.count("$") != len(self.prompt_params):
raise ValueError(
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
f"does not match the number of specified parameters {self.prompt_params}."
)
# when a PromptTemplate is loaded from a YAML file, the prompt text may start and end with quotes that need to be stripped
prompt_text = prompt_text.strip("'").strip('"')
for strip in PROMPT_TEMPLATE_STRIPS:
prompt_text = prompt_text.strip(strip)
replacements = {
**{v: "{" + k + "}" for k, v in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS.items()},
**PROMPT_TEMPLATE_STR_REPLACE,
}
for old, new in replacements.items():
prompt_text = prompt_text.replace(old, new)
t = Template(prompt_text)
try:
t.substitute(**{param: "" for param in self.prompt_params})
except KeyError as e:
raise ValueError(
f"Invalid parameter {e} in prompt text "
f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
)
self._ast_expression = ast.parse(f'f"{prompt_text}"', mode="eval")
ast_validator = _ValidationVisitor(prompt_template_name=name)
ast_validator.visit(self._ast_expression)
ast_transformer = _FstringParamsTransformer()
self._ast_expression = ast.fix_missing_locations(ast_transformer.visit(self._ast_expression))
self._prompt_params_functions = ast_transformer.prompt_params_functions
self._used_functions = ast_validator.used_functions
self.name = name
self.prompt_text = prompt_text
self.prompt_params = sorted(
param for param in ast_validator.prompt_params if param not in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS
)
self.globals = {
**{k: v for k, v in globals().items() if k in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS},
**PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS,
}
self.output_parser = output_parser
@property
def output_variable(self) -> Optional[str]:
return self.output_parser.output_variable if self.output_parser else None
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
"""
@ -107,7 +293,7 @@ class PromptTemplate(BasePromptTemplate, ABC):
:param kwargs: Keyword arguments to fill the parameters in the prompt text of a PromptTemplate.
:return: A dictionary with the prompt text and the prompt parameters.
"""
template_dict = {}
params_dict = {}
# attempt to resolve args first
if args:
if len(args) != len(self.prompt_params):
@ -119,28 +305,49 @@ class PromptTemplate(BasePromptTemplate, ABC):
args,
)
for prompt_param, arg in zip(self.prompt_params, args):
template_dict[prompt_param] = [arg] if isinstance(arg, str) else arg
params_dict[prompt_param] = [arg] if isinstance(arg, str) else arg
# then attempt to resolve kwargs
if kwargs:
for param in self.prompt_params:
if param in kwargs:
template_dict[param] = kwargs[param]
params_dict[param] = kwargs[param]
if set(template_dict.keys()) != set(self.prompt_params):
available_params = set(list(template_dict.keys()) + list(set(kwargs.keys())))
if set(params_dict.keys()) != set(self.prompt_params):
available_params = set(list(params_dict.keys()) + list(set(kwargs.keys())))
raise ValueError(f"Expected prompt parameters {self.prompt_params} but got {list(available_params)}.")
template_dict = {"_at_least_one_prompt": True}
for id, call in self._prompt_params_functions.items():
template_dict[id] = eval( # pylint: disable=eval-used
compile(call, filename="<string>", mode="eval"), self.globals, params_dict
)
return template_dict
def post_process(self, prompt_output: List[str], **kwargs) -> List[Any]:
"""
Post-processes the output of the prompt template.
:param prompt_output: The model output to post-process.
:param kwargs: Keyword arguments to use for post-processing the prompt output.
:return: The post-processed output.
"""
if self.output_parser:
invocation_context = kwargs
invocation_context["results"] = prompt_output
self.output_parser.run(invocation_context=invocation_context)
return invocation_context[self.output_parser.outputs[0]]
else:
return prompt_output
def fill(self, *args, **kwargs) -> Iterator[str]:
"""
Fills the parameters defined in the prompt text with the arguments passed to it and returns the iterator prompt text.
You can pass non-keyword (args) or keyword (kwargs) arguments to this method. If you pass non-keyword arguments, their order must match the left-to-right
order of appearance of the parameters in the prompt text. For example, if the prompt text is:
`Come up with a question for the given context and the answer. Context: $documents;
Answer: $answers; Question:`, then the first non-keyword argument fills the `$documents` variable
and the second non-keyword argument fills the `$answers` variable.
`Come up with a question for the given context and the answer. Context: {documents};
Answer: {answers}; Question:`, then the first non-keyword argument fills the `{documents}` variable
and the second non-keyword argument fills the `{answers}` variable.
If you pass keyword arguments, the order of the arguments doesn't matter. Variables in the
prompt text are filled with the corresponding keyword argument.
@ -150,12 +357,20 @@ class PromptTemplate(BasePromptTemplate, ABC):
:return: An iterator of prompt texts.
"""
template_dict = self.prepare(*args, **kwargs)
template = Template(self.prompt_text)
# the prompt context values should all be lists, as they are iterated in parallel to build one prompt per element
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in template_dict.items()}
max_len = max(len(v) for v in prompt_context_copy.values())
if max_len > 1:
for key, value in prompt_context_copy.items():
if len(value) == 1:
prompt_context_copy[key] = value * max_len
for prompt_context_values in zip(*prompt_context_copy.values()):
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
prompt_prepared: str = template.substitute(template_input)
prompt_prepared: str = eval( # pylint: disable=eval-used
compile(self._ast_expression, filename="<string>", mode="eval"), self.globals, template_input
)
yield prompt_prepared
def __repr__(self):
@ -311,66 +526,85 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
return [
PromptTemplate(
name="question-answering",
prompt_text="Given the context please answer the question. Context: $documents; Question: "
"$questions; Answer:",
prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-per-document",
prompt_text="Given the context please answer the question. Context: {documents}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-with-references",
prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like as stated in Document[number], Document[number], etc.. "
"If the documents do not contain the answer to the question, say that answering is not possible given the available information.\n"
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
),
PromptTemplate(
name="question-generation",
prompt_text="Given the context please generate a question. Context: $documents; Question:",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
),
PromptTemplate(
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: $documents; Answer: $answers; Question:",
"Context: {documents}; Answer: {answers}; Question:",
),
PromptTemplate(name="summarization", prompt_text="Summarize this document: $documents Summary:"),
PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
PromptTemplate(
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question? "
"Context: $documents; Question: $questions; Please answer yes or no! Answer:",
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: $documents; Answer:",
"negative or neutral. Context: {documents}; Answer:",
),
PromptTemplate(
name="multiple-choice-question-answering",
prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
"Options: $options; Answer:",
prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="topic-classification",
prompt_text="Categories: $options; What category best describes: $documents; Answer:",
prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
),
PromptTemplate(
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: $documents; Answer:",
"name of the language. Context: {documents}; Answer:",
),
PromptTemplate(
name="translation",
prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
),
PromptTemplate(
name="zero-shot-react",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"$tool_names_with_descriptions\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: pick one of $tool_names \n"
"Tool: pick one of {tool_names} \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: $query\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to ",
),
]
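
A sketch of how one of these reworked templates is consumed through `PromptNode.run`; the model name and document content are placeholders, and the concrete answer text of course depends on the model:

```python
node = PromptNode(
    "google/flan-t5-base", default_prompt_template="question-answering-with-references"
)

result, _ = node.run(
    query="What is the capital of Germany?",
    documents=[Document(id="d1", content="Berlin is the capital of Germany.")],
)

# The template's AnswerParser sets the output variable to "answers", so the node
# returns Answer objects that carry the prompt used and the cited document ids.
answers = result["answers"]
```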
@ -430,6 +664,7 @@ class PromptNode(BaseComponent):
:param model_name_or_path: The name of the model to use or an instance of the PromptModel.
:param default_prompt_template: The default prompt template to use for the model.
:param output_variable: The name of the output variable in which you want to store the inference results.
If not set, PromptNode uses the PromptTemplate's output_variable. If that is not set either, the default name is `results`.
:param max_length: The maximum length of the generated text output.
:param api_key: The API key to use for the model.
:param use_auth_token: The authentication token to use for the model.
@ -449,7 +684,7 @@ class PromptNode(BaseComponent):
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self.output_variable: str = output_variable or "results"
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
self.prompt_model: PromptModel
self.stop_words: Optional[List[str]] = stop_words
@ -480,7 +715,7 @@ class PromptNode(BaseComponent):
else:
raise ValueError("model_name_or_path must be either a string or a PromptModel object")
def __call__(self, *args, **kwargs) -> List[str]:
def __call__(self, *args, **kwargs) -> List[Any]:
"""
This method is invoked when the component is called directly, for example:
```python
@ -496,7 +731,7 @@ class PromptNode(BaseComponent):
else:
return self.prompt(self.default_prompt_template, *args, **kwargs)
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[Any]:
"""
Prompts the model and represents the central API for the PromptNode. It takes a prompt template,
a list of non-keyword and keyword arguments, and returns a list of strings - the responses from the underlying model.
@ -520,12 +755,7 @@ class PromptNode(BaseComponent):
kwargs = {**self._prepare_model_kwargs(), **kwargs}
prompt_template_used = prompt_template or self.default_prompt_template
if prompt_template_used:
if isinstance(prompt_template_used, PromptTemplate):
template_to_fill = prompt_template_used
elif isinstance(prompt_template_used, str):
template_to_fill = self.get_prompt_template(prompt_template_used)
else:
raise ValueError(f"{prompt_template_used} with args {args} , and kwargs {kwargs} not supported")
template_to_fill = self.get_prompt_template(prompt_template_used)
# prompt template used, yield prompts from inputs args
for prompt in template_to_fill.fill(*args, **kwargs):
@ -536,6 +766,9 @@ class PromptNode(BaseComponent):
logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
output = self.prompt_model.invoke(prompt, **kwargs_copy)
results.extend(output)
kwargs["prompts"] = prompt_collector
results = template_to_fill.post_process(results, **kwargs)
else:
# straightforward prompt, no templates used
for prompt in list(args):
@ -607,15 +840,19 @@ class PromptNode(BaseComponent):
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in self.prompt_templates
def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> PromptTemplate:
"""
Returns a prompt template by name.
:param prompt_template_name: The name of the prompt template to be returned.
:return: The prompt template object.
"""
if prompt_template_name not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template_name} not supported")
return self.prompt_templates[prompt_template_name]
if isinstance(prompt_template, PromptTemplate):
return prompt_template
if not isinstance(prompt_template, str) or prompt_template not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template} not supported")
return self.prompt_templates[prompt_template]
def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
@ -674,18 +911,14 @@ class PromptNode(BaseComponent):
if meta and "meta" not in invocation_context.keys():
invocation_context["meta"] = meta
if "documents" in invocation_context.keys():
for doc in invocation_context.get("documents", []):
if not isinstance(doc, str) and not isinstance(doc.content, str):
raise ValueError("PromptNode only accepts text documents.")
invocation_context["documents"] = [
doc.content if isinstance(doc, Document) else doc for doc in invocation_context.get("documents", [])
]
results = self(prompt_collector=prompt_collector, **invocation_context)
invocation_context[self.output_variable] = results
final_result: Dict[str, Any] = {self.output_variable: results, "invocation_context": invocation_context}
prompt_template = self.get_prompt_template(self.default_prompt_template)
output_variable = self.output_variable or prompt_template.output_variable or "results"
invocation_context[output_variable] = results
invocation_context["prompts"] = prompt_collector
final_result: Dict[str, Any] = {output_variable: results, "invocation_context": invocation_context}
if self.debug:
final_result["_debug"] = {"prompts_used": prompt_collector}
@ -717,15 +950,19 @@ class PromptNode(BaseComponent):
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
:param invocation_contexts: List of invocation contexts.
"""
prompt_template = self.get_prompt_template(self.default_prompt_template)
output_variable = self.output_variable or prompt_template.output_variable or "results"
inputs = PromptNode._flatten_inputs(queries, documents, invocation_contexts)
all_results: Dict[str, List] = {self.output_variable: [], "invocation_contexts": [], "_debug": []}
all_results: Dict[str, List] = defaultdict(list)
for query, docs, invocation_context in zip(
inputs["queries"], inputs["documents"], inputs["invocation_contexts"]
):
results = self.run(query=query, documents=docs, invocation_context=invocation_context)[0]
all_results[self.output_variable].append(results[self.output_variable])
all_results["invocation_contexts"].append(all_results["invocation_contexts"])
all_results["_debug"].append(all_results["_debug"])
all_results[output_variable].append(results[output_variable])
all_results["invocation_contexts"].append(results["invocation_context"])
if self.debug:
all_results["_debug"].append(results["_debug"])
return all_results, "output_1"
def _prepare_model_kwargs(self):

View File

@ -214,7 +214,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
assert result == "Paris" or result == "Madrid"
# PromptNode as a Tool
pt = PromptTemplate("test", "Here is a question: $query, Answer:")
pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
pn = PromptNode(default_prompt_template=pt)
t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
@ -253,8 +253,7 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
"Word: {query}?\nLength: ",
),
)
@ -304,8 +303,7 @@ def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
"Word: {query}\nLength: ",
),
)

View File

@ -378,28 +378,28 @@ class MockPromptNode(PromptNode):
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
return [""]
def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
if prompt_template_name == "think-step-by-step":
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> PromptTemplate:
if prompt_template == "think-step-by-step":
return PromptTemplate(
name="think-step-by-step",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"$tool_names_with_descriptions\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: [$tool_names]\n"
"Tool: [{tool_names}]\n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: $query\n"
"Thought: Let's think step-by-step, I first need to $generated_text",
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {generated_text}",
)
else:
return PromptTemplate(name="", prompt_text="")

View File

@ -140,8 +140,7 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
name="lfqa",
prompt_text="""
Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
\n===\Paragraphs: $context\n===\n$query""",
prompt_params=["context", "query"],
\n===\Paragraphs: {context}\n===\n{query}""",
)
node = OpenAIAnswerGenerator(
api_key=haystack_openai_config["api_key"],

View File

@ -11,7 +11,7 @@ from haystack.nodes.retriever.sparse import BM25Retriever
@pytest.fixture
def mock_function(monkeypatch):
monkeypatch.setattr(
haystack.nodes.other.shaper, "REGISTERED_FUNCTIONS", {"test_function": lambda a, b: ([a] * len(b),)}
haystack.nodes.other.shaper, "REGISTERED_FUNCTIONS", {"test_function": lambda a, b: [a] * len(b)}
)
@ -293,6 +293,17 @@ def test_join_strings_default_delimiter():
assert results["invocation_context"]["single_string"] == "first second"
@pytest.mark.unit
def test_join_strings_with_str_replace():
shaper = Shaper(
func="join_strings",
params={"strings": ["first", "second", "third"], "delimiter": " - ", "str_replace": {"r": "R"}},
outputs=["single_string"],
)
results, _ = shaper.run()
assert results["invocation_context"]["single_string"] == "fiRst - second - thiRd"
@pytest.mark.unit
def test_join_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@ -351,6 +362,38 @@ def test_join_strings_default_delimiter_yaml(tmp_path):
assert result["invocation_context"]["single_string"] == "first second third"
@pytest.mark.unit
def test_join_strings_with_str_replace_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: join_strings
inputs:
strings: documents
outputs:
- single_string
params:
delimiter: ' - '
str_replace:
r: R
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(documents=["first", "second", "third"])
assert result["invocation_context"]["single_string"] == "fiRst - second - thiRd"
#
# join_documents
#
@ -407,6 +450,20 @@ def test_join_documents_default_delimiter():
assert results["invocation_context"]["documents"] == [Document(content="first second third")]
@pytest.mark.unit
def test_join_documents_with_pattern_and_str_replace():
shaper = Shaper(
func="join_documents",
inputs={"documents": "documents"},
outputs=["documents"],
params={"delimiter": " - ", "pattern": "[$idx] $content", "str_replace": {"r": "R"}},
)
results, _ = shaper.run(
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
@pytest.mark.unit
def test_join_documents_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@ -472,19 +529,213 @@ def test_join_documents_default_delimiter_yaml(tmp_path):
assert result["invocation_context"]["documents"] == [Document(content="first second third")]
@pytest.mark.unit
def test_join_documents_with_pattern_and_str_replace_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: join_documents
inputs:
documents: documents
outputs:
- documents
params:
delimiter: ' - '
pattern: '[$idx] $content'
str_replace:
r: R
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(
query="test query", documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert result["invocation_context"]["documents"] == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
#
# strings_to_answers
#
@pytest.mark.unit
def test_strings_to_answers_no_meta_no_hashkeys():
def test_strings_to_answers_simple():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"]})
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative"),
Answer(answer="second", type="generative"),
Answer(answer="third", type="generative"),
Answer(answer="first", type="generative", meta={"prompt": None}),
Answer(answer="second", type="generative", meta={"prompt": None}),
Answer(answer="third", type="generative", meta={"prompt": None}),
]
@pytest.mark.unit
def test_strings_to_answers_with_prompt():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"], "prompts": ["test prompt"]})
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": "test prompt"}),
Answer(answer="second", type="generative", meta={"prompt": "test prompt"}),
Answer(answer="third", type="generative", meta={"prompt": "test prompt"}),
]
@pytest.mark.unit
def test_strings_to_answers_with_documents():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(
invocation_context={
"responses": ["first", "second", "third"],
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="second", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="third", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
]
@pytest.mark.unit
def test_strings_to_answers_with_prompt_per_document():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(
invocation_context={
"responses": ["first", "second"],
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
"prompts": ["prompt1", "prompt2"],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
Answer(answer="second", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
]
@pytest.mark.unit
def test_strings_to_answers_with_prompt_per_document_multiple_results():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(
invocation_context={
"responses": ["first", "second", "third", "fourth"],
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
"prompts": ["prompt1", "prompt2"],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
Answer(answer="second", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
Answer(answer="third", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
Answer(answer="fourth", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
]
@pytest.mark.unit
def test_strings_to_answers_with_pattern_group():
shaper = Shaper(
func="strings_to_answers",
inputs={"strings": "responses"},
outputs=["answers"],
params={"pattern": r"Answer: (.*)"},
)
results, _ = shaper.run(invocation_context={"responses": ["Answer: first", "Answer: second", "Answer: third"]})
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": None}),
Answer(answer="second", type="generative", meta={"prompt": None}),
Answer(answer="third", type="generative", meta={"prompt": None}),
]
@pytest.mark.unit
def test_strings_to_answers_with_pattern_no_group():
shaper = Shaper(
func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"], params={"pattern": r"[^\n]+$"}
)
results, _ = shaper.run(invocation_context={"responses": ["Answer\nfirst", "Answer\nsecond", "Answer\n\nthird"]})
assert results["invocation_context"]["answers"] == [
Answer(answer="first", type="generative", meta={"prompt": None}),
Answer(answer="second", type="generative", meta={"prompt": None}),
Answer(answer="third", type="generative", meta={"prompt": None}),
]
@pytest.mark.unit
def test_strings_to_answers_with_references_index():
shaper = Shaper(
func="strings_to_answers",
inputs={"strings": "responses", "documents": "documents"},
outputs=["answers"],
params={"reference_pattern": r"\[(\d+)\]"},
)
results, _ = shaper.run(
invocation_context={
"responses": ["first[1]", "second[2]", "third[1][2]", "fourth"],
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first[1]", type="generative", meta={"prompt": None}, document_ids=["123"]),
Answer(answer="second[2]", type="generative", meta={"prompt": None}, document_ids=["456"]),
Answer(answer="third[1][2]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
]
@pytest.mark.unit
def test_strings_to_answers_with_references_id():
shaper = Shaper(
func="strings_to_answers",
inputs={"strings": "responses", "documents": "documents"},
outputs=["answers"],
params={"reference_pattern": r"\[(\d+)\]", "reference_mode": "id"},
)
results, _ = shaper.run(
invocation_context={
"responses": ["first[123]", "second[456]", "third[123][456]", "fourth"],
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first[123]", type="generative", meta={"prompt": None}, document_ids=["123"]),
Answer(answer="second[456]", type="generative", meta={"prompt": None}, document_ids=["456"]),
Answer(answer="third[123][456]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
]
@pytest.mark.unit
def test_strings_to_answers_with_references_meta():
shaper = Shaper(
func="strings_to_answers",
inputs={"strings": "responses", "documents": "documents"},
outputs=["answers"],
params={"reference_pattern": r"\[([^\]]+)\]", "reference_mode": "meta", "reference_meta_field": "file_id"},
)
results, _ = shaper.run(
invocation_context={
"responses": ["first[123.txt]", "second[456.txt]", "third[123.txt][456.txt]", "fourth"],
"documents": [
Document(id="123", content="test", meta={"file_id": "123.txt"}),
Document(id="456", content="test", meta={"file_id": "456.txt"}),
],
}
)
assert results["invocation_context"]["answers"] == [
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
]
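The reference handling exercised by the tests above boils down to regex matching plus a lookup against the input documents. A simplified, self-contained sketch of the "index" mode (not the actual Shaper implementation):

```python
import re

documents = [{"id": "123"}, {"id": "456"}]
response = "third[1][2]"

# reference_pattern captures the 1-based document index; each match maps to a document id.
matches = re.findall(r"\[(\d+)\]", response)                   # ['1', '2']
document_ids = [documents[int(m) - 1]["id"] for m in matches]  # ['123', '456']
assert document_ids == ["123", "456"]
```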
@ -514,17 +765,145 @@ def test_strings_to_answers_yaml(tmp_path):
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run()
assert result["invocation_context"]["answers"] == [
Answer(answer="a", type="generative"),
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
Answer(answer="a", type="generative", meta={"prompt": None}),
Answer(answer="b", type="generative", meta={"prompt": None}),
Answer(answer="c", type="generative", meta={"prompt": None}),
]
assert result["answers"] == [
Answer(answer="a", type="generative"),
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
Answer(answer="a", type="generative", meta={"prompt": None}),
Answer(answer="b", type="generative", meta={"prompt": None}),
Answer(answer="c", type="generative", meta={"prompt": None}),
]
@pytest.mark.unit
def test_strings_to_answers_with_reference_meta_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: strings_to_answers
inputs:
documents: documents
params:
reference_meta_field: file_id
reference_mode: meta
reference_pattern: \[([^\]]+)\]
strings: ['first[123.txt]', 'second[456.txt]', 'third[123.txt][456.txt]', 'fourth']
outputs:
- answers
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(
documents=[
Document(id="123", content="test", meta={"file_id": "123.txt"}),
Document(id="456", content="test", meta={"file_id": "456.txt"}),
]
)
assert result["invocation_context"]["answers"] == [
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
]
assert result["answers"] == [
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
]
@pytest.mark.integration
def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: prompt_model
type: PromptModel
- name: prompt_template_raw_qa_per_document
type: PromptTemplate
params:
name: raw-question-answering-per-document
prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
- name: prompt_node_raw_qa
type: PromptNode
params:
model_name_or_path: prompt_model
default_prompt_template: prompt_template_raw_qa_per_document
top_k: 2
- name: prompt_node_question_generation
type: PromptNode
params:
model_name_or_path: prompt_model
default_prompt_template: question-generation
output_variable: query
- name: shaper
type: Shaper
params:
func: strings_to_answers
inputs:
strings: results
outputs:
- answers
pipelines:
- name: query
nodes:
- name: prompt_node_question_generation
inputs:
- Query
- name: prompt_node_raw_qa
inputs:
- prompt_node_question_generation
- name: shaper
inputs:
- prompt_node_raw_qa
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(
query="What's Berlin like?",
documents=[
Document("Berlin is an amazing city.", id="123"),
Document("Berlin is a cool city in Germany.", id="456"),
],
)
results = result["answers"]
assert len(results) == 4
assert any([True for r in results if "Berlin" in r.answer])
for answer in results[:2]:
assert answer.document_ids == ["123"]
assert (
answer.meta["prompt"]
== f"Given the context please answer the question. Context: Berlin is an amazing city.; Question: {result['query'][0]}; Answer:"
)
for answer in results[2:]:
assert answer.document_ids == ["456"]
assert (
answer.meta["prompt"]
== f"Given the context please answer the question. Context: Berlin is a cool city in Germany.; Question: {result['query'][1]}; Answer:"
)
#
# answers_to_strings
#
@ -537,6 +916,18 @@ def test_answers_to_strings():
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
@pytest.mark.unit
def test_answers_to_strings_with_pattern_and_str_replace():
shaper = Shaper(
func="answers_to_strings",
inputs={"answers": "documents"},
outputs=["strings"],
params={"pattern": "[$idx] $answer", "str_replace": {"r": "R"}},
)
results, _ = shaper.run(documents=[Answer(answer="first"), Answer(answer="second"), Answer(answer="third")])
assert results["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
@pytest.mark.unit
def test_answers_to_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@ -565,6 +956,38 @@ def test_answers_to_strings_yaml(tmp_path):
assert result["invocation_context"]["strings"] == ["a", "b", "c"]
@pytest.mark.unit
def test_answers_to_strings_with_pattern_and_str_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: answers_to_strings
inputs:
answers: documents
outputs:
- strings
params:
pattern: '[$idx] $answer'
str_replace:
r: R
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(documents=[Answer(answer="first"), Answer(answer="second"), Answer(answer="third")])
assert result["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
#
# strings_to_documents
#
@ -722,6 +1145,20 @@ def test_documents_to_strings():
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
@pytest.mark.unit
def test_documents_to_strings_with_pattern_and_str_replace():
shaper = Shaper(
func="documents_to_strings",
inputs={"documents": "documents"},
outputs=["strings"],
params={"pattern": "[$idx] $content", "str_replace": {"r": "R"}},
)
results, _ = shaper.run(
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
@pytest.mark.unit
def test_documents_to_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@ -750,6 +1187,38 @@ def test_documents_to_strings_yaml(tmp_path):
assert result["invocation_context"]["strings"] == ["a", "b", "c"]
@pytest.mark.unit
def test_documents_to_strings_with_pattern_and_str_replace_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: shaper
type: Shaper
params:
func: documents_to_strings
inputs:
documents: documents
outputs:
- strings
params:
pattern: '[$idx] $content'
str_replace:
r: R
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(documents=[Document(content="first"), Document(content="second"), Document(content="third")])
assert result["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
#
# Chaining and real-world usage
#
@ -923,32 +1392,19 @@ def test_with_prompt_node(tmp_path):
- name: prompt_model
type: PromptModel
- name: shaper
type: Shaper
params:
func: value_to_list
inputs:
value: query
target_list: documents
outputs:
- questions
- name: prompt_node
type: PromptNode
params:
output_variable: answers
model_name_or_path: prompt_model
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
- name: prompt_node
inputs:
- shaper
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
@ -957,10 +1413,8 @@ def test_with_prompt_node(tmp_path):
documents=[Document("Berlin is an amazing city."), Document("Berlin is a cool city in Germany.")],
)
assert len(result["answers"]) == 2
assert any(word for word in ["berlin", "germany", "cool", "city", "amazing"] if word in result["answers"])
assert len(result["invocation_context"]) > 0
assert len(result["invocation_context"]["questions"]) == 2
raw_answers = [answer.answer for answer in result["answers"]]
assert any(word for word in ["berlin", "germany", "cool", "city", "amazing"] if word in raw_answers)
@pytest.mark.integration
@ -973,15 +1427,6 @@ def test_with_multiple_prompt_nodes(tmp_path):
- name: prompt_model
type: PromptModel
- name: shaper
type: Shaper
params:
func: value_to_list
inputs:
value: query
target_list: documents
outputs: [questions]
- name: renamer
type: Shaper
params:
@ -989,13 +1434,13 @@ def test_with_multiple_prompt_nodes(tmp_path):
inputs:
value: new-questions
outputs:
- questions
- query
- name: prompt_node
type: PromptNode
params:
model_name_or_path: prompt_model
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
- name: prompt_node_second
type: PromptNode
@ -1007,19 +1452,15 @@ def test_with_multiple_prompt_nodes(tmp_path):
- name: prompt_node_third
type: PromptNode
params:
output_variable: answers
model_name_or_path: google/flan-t5-small
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
pipelines:
- name: query
nodes:
- name: shaper
inputs:
- Query
- name: prompt_node
inputs:
- shaper
- Query
- name: prompt_node_second
inputs:
- prompt_node
@ -1038,7 +1479,7 @@ def test_with_multiple_prompt_nodes(tmp_path):
)
results = result["answers"]
assert len(results) == 2
assert any([True for r in results if "Berlin" in r])
assert any([True for r in results if "Berlin" in r.answer])
@pytest.mark.unit

View File

@ -1,6 +1,6 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from typing import Optional, Set, Type, Union, List, Dict, Any, Tuple
import pytest
import torch
@ -9,7 +9,9 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.prompt_node import PromptTemplateValidationError
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
from haystack.schema import Answer
def skip_test_for_invalid_key(prompt_model):
@ -57,47 +59,36 @@ def get_api_key(request):
@pytest.mark.unit
def test_prompt_templates():
p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"])
p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
assert set(p.prompt_params) == {"foo"}
with pytest.raises(ValueError, match="The number of parameters in prompt text"):
PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo"])
p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
assert set(p.prompt_params) == {"foo", "bar"}
with pytest.raises(ValueError, match="Invalid parameter"):
PromptTemplate("t2", "Here is some fake template with variable $footur", ["foo"])
p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
assert set(p.prompt_params) == {"foo1", "bar2"}
with pytest.raises(ValueError, match="The number of parameters in prompt text"):
PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo", "bar", "baz"])
p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
assert set(p.prompt_params) == {"foo_1", "bar_2"}
p = PromptTemplate("t3", "Here is some fake template with variable $for and $bar", ["for", "bar"])
p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
assert set(p.prompt_params) == {"Foo_1", "Bar_2"}
# last parameter: "prompt_params" can be omitted
p = PromptTemplate("t4", "Here is some fake template with variable $foo and $bar")
assert p.prompt_params == ["foo", "bar"]
p = PromptTemplate("t4", "Here is some fake template with variable $foo1 and $bar2")
assert p.prompt_params == ["foo1", "bar2"]
p = PromptTemplate("t4", "Here is some fake template with variable $foo_1 and $bar_2")
assert p.prompt_params == ["foo_1", "bar_2"]
p = PromptTemplate("t4", "Here is some fake template with variable $Foo_1 and $Bar_2")
assert p.prompt_params == ["Foo_1", "Bar_2"]
p = PromptTemplate("t4", "'Here is some fake template with variable $baz'")
assert p.prompt_params == ["baz"]
p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
assert set(p.prompt_params) == {"baz"}
# strip single quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable $baz"
assert p.prompt_text == "Here is some fake template with variable {baz}"
p = PromptTemplate("t4", '"Here is some fake template with variable $baz"')
assert p.prompt_params == ["baz"]
p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
assert set(p.prompt_params) == {"baz"}
# strip double quotes, happens in YAML as we need to use double quotes for the template string
assert p.prompt_text == "Here is some fake template with variable $baz"
assert p.prompt_text == "Here is some fake template with variable {baz}"
@pytest.mark.unit
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable $baz")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable $baz, prompt_params=['baz'])"
p = PromptTemplate("t", "Here is variable {baz}")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
assert repr(p) == desired_repr
assert str(p) == desired_repr
@ -178,9 +169,7 @@ def test_create_prompt_node():
@pytest.mark.integration
def test_add_and_remove_template(prompt_node):
num_default_tasks = len(prompt_node.get_prompt_template_names())
custom_task = PromptTemplate(
name="custom-task", prompt_text="Custom task: $param1, $param2", prompt_params=["param1", "param2"]
)
custom_task = PromptTemplate(name="custom-task", prompt_text="Custom task: {param1}, {param2}")
prompt_node.add_prompt_template(custom_task)
assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
assert "custom-task" in prompt_node.get_prompt_template_names()
@ -189,24 +178,12 @@ def test_add_and_remove_template(prompt_node):
assert "custom-task" not in prompt_node.get_prompt_template_names()
@pytest.mark.unit
def test_invalid_template():
with pytest.raises(ValueError, match="Invalid parameter"):
PromptTemplate(
name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"]
)
with pytest.raises(ValueError, match="The number of parameters in prompt text"):
PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"])
@pytest.mark.integration
def test_add_template_and_invoke(prompt_node):
tt = PromptTemplate(
name="sentiment-analysis-new",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: $documents; Answer:",
prompt_params=["documents"],
"negative or neutral. Context: {documents}; Answer:",
)
prompt_node.add_prompt_template(tt)
@ -219,8 +196,7 @@ def test_on_the_fly_prompt(prompt_node):
prompt_template = PromptTemplate(
name="sentiment-analysis-temp",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: $documents; Answer:",
prompt_params=["documents"],
"negative or neutral. Context: {documents}; Answer:",
)
r = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
assert r[0].casefold() == "positive"
@ -250,12 +226,12 @@ def test_question_generation(prompt_node):
@pytest.mark.integration
def test_template_selection(prompt_node):
qa = prompt_node.set_default_prompt_template("question-answering")
qa = prompt_node.set_default_prompt_template("question-answering-per-document")
r = qa(
["Berlin is the capital of Germany.", "Paris is the capital of France."],
["What is the capital of Germany?", "What is the capital of France"],
)
assert r[0].casefold() == "berlin" and r[1].casefold() == "paris"
assert r[0].answer.casefold() == "berlin" and r[1].answer.casefold() == "paris"
@pytest.mark.integration
@ -266,14 +242,14 @@ def test_has_supported_template_names(prompt_node):
@pytest.mark.integration
def test_invalid_template_params(prompt_node):
with pytest.raises(ValueError, match="Expected prompt parameters"):
prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."})
prompt_node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})
@pytest.mark.integration
def test_wrong_template_params(prompt_node):
with pytest.raises(ValueError, match="Expected prompt parameters"):
# we don't have the options param that multiple choice QA has
prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."])
prompt_node.prompt("question-answering-per-document", options=["Berlin is the capital of Germany."])
@pytest.mark.integration
@ -298,7 +274,7 @@ def test_invalid_state_ops(prompt_node):
with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
prompt_node.remove_prompt_template("no_such_task_exists")
# remove default task
prompt_node.remove_prompt_template("question-answering")
prompt_node.remove_prompt_template("question-answering-per-document")
@pytest.mark.integration
@ -365,7 +341,7 @@ def test_stop_words(prompt_model):
tt = PromptTemplate(
name="question-generation-copy",
prompt_text="Given the context please generate a question. Context: $documents; Question:",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
)
# with custom prompt template
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
@ -441,15 +417,15 @@ def test_simple_pipeline(prompt_model):
def test_complex_pipeline(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="questions")
node2 = PromptNode(prompt_model, default_prompt_template="question-answering")
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query")
node2 = PromptNode(prompt_model, default_prompt_template="question-answering-per-document")
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
assert "berlin" in result["results"][0].casefold()
assert "berlin" in result["answers"][0].answer.casefold()
@pytest.mark.integration
@ -457,13 +433,70 @@ def test_complex_pipeline(prompt_model):
def test_simple_pipeline_with_topk(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-generation", top_k=2)
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query", top_k=2)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
assert len(result["results"]) == 2
assert len(result["query"]) == 2
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_pipeline_with_standard_qa(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-answering", top_k=1)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
documents=[
Document("My name is Carla and I live in Berlin", id="1"),
Document("My name is Christelle and I live in Paris", id="2"),
],
)
assert len(result["answers"]) == 1
assert "carla" in result["answers"][0].answer.casefold()
assert result["answers"][0].document_ids == ["1", "2"]
assert (
result["answers"][0].meta["prompt"]
== "Given the context please answer the question. Context: My name is Carla and I live in Berlin My name is Christelle and I live in Paris; "
"Question: Who lives in Berlin?; Answer:"
)
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_qa_with_references(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
documents=[
Document("My name is Carla and I live in Berlin", id="1"),
Document("My name is Christelle and I live in Paris", id="2"),
],
)
assert len(result["answers"]) == 1
assert "carla, as stated in document[1]" in result["answers"][0].answer.casefold()
assert result["answers"][0].document_ids == ["1"]
assert (
result["answers"][0].meta["prompt"]
== "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. "
"You must only use information from the given documents. Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like as stated in Document[number], Document[number], etc.. If the documents do not contain the answer to the question, "
"say that answering is not possible given the available information.\n\nDocument[1]: My name is Carla and I live in Berlin\n\nDocument[2]: My name is Christelle and I live in Paris \n "
"Question: Who lives in Berlin?; Answer: "
)
@pytest.mark.integration
@ -501,8 +534,7 @@ def test_complex_pipeline_with_qa(prompt_model):
prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
prompt_params=["documents", "query"],
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
@ -517,7 +549,7 @@ def test_complex_pipeline_with_qa(prompt_model):
debug=True, # so we can verify that the constructed prompt is returned in debug
)
assert len(result["results"]) == 1
assert len(result["results"]) == 2
assert "carla" in result["results"][0].casefold()
# also verify that the PromptNode has included its constructed prompt LLM model input in the returned debug
@ -531,17 +563,15 @@ def test_complex_pipeline_with_qa(prompt_model):
@pytest.mark.integration
def test_complex_pipeline_with_shared_model():
model = PromptModel()
node = PromptNode(
model_name_or_path=model, default_prompt_template="question-generation", output_variable="questions"
)
node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering")
node = PromptNode(model_name_or_path=model, default_prompt_template="question-generation", output_variable="query")
node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering-per-document")
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
assert result["results"][0] == "Berlin"
assert result["answers"][0].answer == "Berlin"
@pytest.mark.integration
@ -606,11 +636,11 @@ def test_complex_pipeline_yaml(tmp_path):
- name: p1
params:
default_prompt_template: question-generation
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -625,11 +655,11 @@ def test_complex_pipeline_yaml(tmp_path):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
@pytest.mark.integration
@ -645,12 +675,12 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
params:
model_name_or_path: pmodel
default_prompt_template: question-generation
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -665,11 +695,11 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
@pytest.mark.integration
@ -689,17 +719,17 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
default_prompt_template: question_generation_template
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -714,11 +744,11 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
@pytest.mark.integration
@ -767,17 +797,17 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
default_prompt_template: question_generation_template
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -795,11 +825,11 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True)
@ -839,17 +869,17 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel_openai
default_prompt_template: question_generation_template
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -864,11 +894,11 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
@pytest.mark.integration
@ -882,15 +912,14 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
- name: p1
params:
default_prompt_template: question-generation
output_variable: questions
type: PromptNode
- name: p2
params:
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
- name: p3
params:
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -914,9 +943,7 @@ class TestTokenLimit:
@pytest.mark.integration
def test_hf_token_limit_warning(self, prompt_node, caplog):
prompt_template = PromptTemplate(
name="too-long-temp",
prompt_text="Repeating text" * 200 + "Docs: $documents; Answer:",
prompt_params=["documents"],
name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
)
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
@ -929,11 +956,7 @@ class TestTokenLimit:
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_token_limit_warning(self, caplog):
tt = PromptTemplate(
name="too-long-temp",
prompt_text="Repeating text" * 200 + "Docs: $documents; Answer:",
prompt_params=["documents"],
)
tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:")
prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
@ -989,8 +1012,7 @@ class TestRunBatch:
prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
prompt_params=["documents", "query"],
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
@ -1015,6 +1037,252 @@ def test_HFLocalInvocationLayer_supports():
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
class TestPromptTemplateSyntax:
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, expected_prompt_params, expected_used_functions",
[
("{documents}", {"documents"}, set()),
("Please answer the question: {documents} Question: how?", {"documents"}, set()),
("Please answer the question: {documents} Question: {query}", {"documents", "query"}, set()),
("Please answer the question: {documents} {{Question}}: {query}", {"documents", "query"}, set()),
(
"Please answer the question: {join(documents)} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(documents, 'delim', {'{': '('})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
'Please answer the question: {join(documents, "delim", {"{": "("})} Question: {query.replace("A", "a")}',
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(documents, 'delim', {'a': {'b': 'c'}})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(document=documents, delimiter='delim', str_replace={'{': '('})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
],
)
def test_prompt_template_syntax_parser(
self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
assert set(prompt_template.prompt_params) == expected_prompt_params
assert set(prompt_template._used_functions) == expected_used_functions
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, expected_prompts",
[
("{documents}", [Document("doc1"), Document("doc2")], None, ["doc1", "doc2"]),
(
"context: {documents} question: how?",
[Document("doc1"), Document("doc2")],
None,
["context: doc1 question: how?", "context: doc2 question: how?"],
),
(
"context: {' '.join([d.content for d in documents])} question: how?",
[Document("doc1"), Document("doc2")],
None,
["context: doc1 doc2 question: how?"],
),
(
"context: {documents} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 question: how?", "context: doc2 question: how?"],
),
(
"context: {documents} {{question}}: {query}",
[Document("doc1")],
"how?",
["context: doc1 {question}: how?"],
),
(
"context: {join(documents)} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 doc2 question: how?"],
),
(
"Please answer the question: {join(documents, ' delim ', '[$idx] $content', {'{': '('})} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"Please answer the question: {join(documents=documents, delimiter=' delim ', pattern='[$idx] $content', str_replace={'{': '('})} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"Please answer the question: {' delim '.join(['['+str(idx+1)+'] '+d.content.replace('{', '(') for idx, d in enumerate(documents)])} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
'Please answer the question: {join(documents, " delim ", "[$idx] $content", {"{": "("})} question: {query}',
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"context: {join(documents)} question: {query.replace('how', 'what')}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 doc2 question: what?"],
),
(
"context: {join(documents)[:6]} question: {query.replace('how', 'what').replace('?', '!')}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 d question: what!"],
),
("context", None, None, ["context"]),
],
)
def test_prompt_template_syntax_fill(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts
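Written out directly (mirroring one of the parametrized cases above), the same fill() behaviour looks like this:

```python
from haystack import Document
from haystack.nodes.prompt import PromptTemplate

template = PromptTemplate(name="qa", prompt_text="context: {join(documents)} question: {query}")
prompts = list(template.fill(documents=[Document("doc1"), Document("doc2")], query="how?"))
assert prompts == ["context: doc1 doc2 question: how?"]
```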
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, expected_prompts",
[
("{join(documents)}", [Document("doc1"), Document("doc2")], ["doc1 doc2"]),
(
"{join(documents, ' delim ', '[$idx] $content', {'c': 'C'})}",
[Document("doc1"), Document("doc2")],
["[1] doC1 delim [2] doC2"],
),
(
"{join(documents, ' delim ', '[$id] $content', {'c': 'C'})}",
[Document("doc1", id="123"), Document("doc2", id="456")],
["[123] doC1 delim [456] doC2"],
),
(
"{join(documents, ' delim ', '[$file_id] $content', {'c': 'C'})}",
[Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
["[123.txt] doC1 delim [456.txt] doC2"],
),
],
)
def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, expected_prompts",
[
("{to_strings(documents)}", [Document("doc1"), Document("doc2")], ["doc1", "doc2"]),
(
"{to_strings(documents, '[$idx] $content', {'c': 'C'})}",
[Document("doc1"), Document("doc2")],
["[1] doC1", "[2] doC2"],
),
(
"{to_strings(documents, '[$id] $content', {'c': 'C'})}",
[Document("doc1", id="123"), Document("doc2", id="456")],
["[123] doC1", "[456] doC2"],
),
(
"{to_strings(documents, '[$file_id] $content', {'c': 'C'})}",
[Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
["[123.txt] doC1", "[456.txt] doC2"],
),
("{to_strings(documents, '[$file_id] $content', {'c': 'C'})}", ["doc1", "doc2"], ["doC1", "doC2"]),
(
"{to_strings(documents, '[$idx] $answer', {'c': 'C'})}",
[Answer("doc1"), Answer("doc2")],
["[1] doC1", "[2] doC2"],
),
],
)
def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, exc_type, expected_exc_match",
[
("{__import__('os').listdir('.')}", PromptTemplateValidationError, "Invalid function in prompt text"),
("{__import__('os')}", PromptTemplateValidationError, "Invalid function in prompt text"),
(
"{requests.get('https://haystack.deepset.ai/')}",
PromptTemplateValidationError,
"Invalid function in prompt text",
),
("{join(__import__('os').listdir('.'))}", PromptTemplateValidationError, "Invalid function in prompt text"),
("{for}", SyntaxError, "invalid syntax"),
("This is an invalid {variable .", SyntaxError, "f-string: expecting '}'"),
],
)
def test_prompt_template_syntax_init_raises(
self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
):
with pytest.raises(exc_type, match=expected_exc_match):
PromptTemplate(name="test", prompt_text=prompt_text)
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, exc_type, expected_exc_match",
[("{join}", None, None, ValueError, "Expected prompt parameters")],
)
def test_prompt_template_syntax_fill_raises(
self,
prompt_text: str,
documents: List[Document],
query: str,
exc_type: Type[BaseException],
expected_exc_match: str,
):
with pytest.raises(exc_type, match=expected_exc_match):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
next(prompt_template.fill(documents=documents, query=query))
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, expected_prompts",
[
("__import__('os').listdir('.')", None, None, ["__import__('os').listdir('.')"]),
(
"requests.get('https://haystack.deepset.ai/')",
None,
None,
["requests.get('https://haystack.deepset.ai/')"],
),
("{query}", None, print, ["<built-in function print>"]),
("\b\b__import__('os').listdir('.')", None, None, ["\x08\x08__import__('os').listdir('.')"]),
],
)
def test_prompt_template_syntax_fill_ignores_dangerous_input(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts
@pytest.mark.integration
def test_chatgpt_direct_prompting(chatgpt_prompt_model):
skip_test_for_invalid_key(chatgpt_prompt_model)