feat: PromptTemplate extensions (#4378)
* use outputshapers in prompttemplate
* fix pylint
* first iteration on regex
* implement new promptnode syntax based on f-strings
* finish fstring implementation
* add additional tests
* add security tests
* fix mypy
* fix pylint
* fix test_prompt_templates
* fix test_prompt_template_repr
* fix test_prompt_node_with_custom_invocation_layer
* fix test_invalid_template
* more security tests
* fix test_complex_pipeline_with_all_features
* fix agent tests
* refactor get_prompt_template
* fix test_prompt_template_syntax_parser
* fix test_complex_pipeline_with_all_features
* allow functions in comprehensions
* break out of fstring test
* fix additional tests
* mark new tests as unit tests
* fix agents tests
* convert missing templates
* proper use of get_prompt_template
* refactor and add docstrings
* fix tests
* fix pylint
* fix agents test
* fix tests
* refactor globals
* make allowed functions configurable via env variable
* better dummy variable
* fix special alias
* don't replace special char variables
* more special chars, better docstrings
* cherrypick fix audio tests
* fix test
* rework shapers
* fix pylint
* fix tests
* add new templates
* add reference parsing
* add more shaper tests
* add tests for join and to_string
* fix pylint
* fix pylint
* fix pylint for real
* auto fill shaper function params
* fix reference parsing for multiple references
* fix output variable inference
* consolidate qa prompt template output and make shaper work per-document
* fix types after merge
* introduce output_parser
* fix tests
* better docstring
* rename RegexAnswerParser to AnswerParser
* better docstrings
This commit is contained in:
parent
9518bcb7a8
commit
382ca8094e
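For orientation, a minimal sketch of the syntax this commit introduces (the import path and the model-free usage are assumptions based on the files changed below; the template mirrors the predefined `question-answering` template in this diff):

```python
from haystack.nodes.prompt.prompt_node import AnswerParser, PromptTemplate
from haystack.schema import Document

# Prompt variables now use f-string syntax ({documents}, {query}) instead of the
# old $documents / $query placeholders, and an output_parser can turn the raw
# model output into Answer objects.
template = PromptTemplate(
    name="question-answering",
    prompt_text="Given the context please answer the question. "
    "Context: {join(documents)}; Question: {query}; Answer:",
    output_parser=AnswerParser(),
)

docs = [Document(content="Berlin is the capital of Germany.")]
for prompt in template.fill(documents=docs, query="What is the capital of Germany?"):
    print(prompt)
```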
@ -163,9 +163,7 @@ class Agent:
|
||||
)
|
||||
)
|
||||
self.prompt_node = prompt_node
|
||||
self.prompt_template = (
|
||||
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
|
||||
)
|
||||
self.prompt_template = prompt_node.get_prompt_template(prompt_template)
|
||||
self.tools = {tool.name: tool for tool in tools} if tools else {}
|
||||
self.tool_names = ", ".join(self.tools.keys())
|
||||
self.tool_names_with_descriptions = "\n".join(
|
||||
|
||||
@ -15,6 +15,9 @@ HAYSTACK_REMOTE_API_BACKOFF_SEC = "HAYSTACK_REMOTE_API_BACKOFF_SEC"
|
||||
HAYSTACK_REMOTE_API_MAX_RETRIES = "HAYSTACK_REMOTE_API_MAX_RETRIES"
|
||||
HAYSTACK_REMOTE_API_TIMEOUT_SEC = "HAYSTACK_REMOTE_API_TIMEOUT_SEC"
|
||||
|
||||
HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS = "HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS"
|
||||
|
||||
|
||||
env_meta_data: Dict[str, Any] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
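The new variable is read with `ast.literal_eval` in `prompt_node.py` (see further below), so its value must be a Python-literal list of function names. A hedged sketch of overriding the default allow-list; it has to be set before the prompt module is imported:

```python
import os

# Hypothetical override: only allow `join` and `str` inside prompt-text expressions.
os.environ["HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS"] = '["join", "str"]'
```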
|
||||
|
||||
@ -92,12 +92,11 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
PromptTemplate(
|
||||
name="question-answering-with-examples",
|
||||
prompt_text="Please answer the question according to the above context."
|
||||
"\n===\nContext: $examples_context\n===\n$examples\n\n"
|
||||
"===\nContext: $context\n===\n$query",
|
||||
prompt_params=["examples_context", "examples", "context", "query"],
|
||||
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
|
||||
"===\nContext: {context}\n===\n{query}",
|
||||
)
|
||||
```
|
||||
To learn how variables, such as '$context', are substituted in the `prompt_text`, see
To learn how variables, such as '{context}', are substituted in the `prompt_text`, see
|
||||
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
|
||||
:param context_join_str: The separation string used to join the input documents to create the context
|
||||
used by the PromptTemplate.
|
||||
@ -118,9 +117,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
prompt_template = PromptTemplate(
|
||||
name="question-answering-with-examples",
|
||||
prompt_text="Please answer the question according to the above context."
|
||||
"\n===\nContext: $examples_context\n===\n$examples\n\n"
|
||||
"===\nContext: $context\n===\n$query",
|
||||
prompt_params=["examples_context", "examples", "context", "query"],
|
||||
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
|
||||
"===\nContext: {context}\n===\n{query}",
|
||||
)
|
||||
else:
|
||||
# Check for required prompts
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union, Callable
|
||||
from functools import reduce
|
||||
import inspect
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Literal, Optional, List, Dict, Any, Tuple, Union, Callable
|
||||
|
||||
import logging
|
||||
|
||||
@ -9,49 +13,49 @@ from haystack.schema import Document, Answer, MultiLabel
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def rename(value: Any) -> Tuple[Any]:
|
||||
def rename(value: Any) -> Any:
|
||||
"""
|
||||
Identity function. Can be used to rename values in the invocation context without changing them.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert rename(1) == (1, )
|
||||
assert rename(1) == 1
|
||||
```
|
||||
"""
|
||||
return (value,)
|
||||
return value
|
||||
|
||||
|
||||
def value_to_list(value: Any, target_list: List[Any]) -> Tuple[List[Any]]:
|
||||
def value_to_list(value: Any, target_list: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Transforms a value into a list containing this value as many times as the length of the target list.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert value_to_list(value=1, target_list=list(range(5))) == ([1, 1, 1, 1, 1], )
|
||||
assert value_to_list(value=1, target_list=list(range(5))) == [1, 1, 1, 1, 1]
|
||||
```
|
||||
"""
|
||||
return ([value] * len(target_list),)
|
||||
return [value] * len(target_list)
|
||||
|
||||
|
||||
def join_lists(lists: List[List[Any]]) -> Tuple[List[Any]]:
|
||||
def join_lists(lists: List[List[Any]]) -> List[Any]:
|
||||
"""
|
||||
Joins the passed lists into a single one.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert join_lists(lists=[[1, 2, 3], [4, 5]]) == ([1, 2, 3, 4, 5], )
|
||||
assert join_lists(lists=[[1, 2, 3], [4, 5]]) == [1, 2, 3, 4, 5]
|
||||
```
|
||||
"""
|
||||
merged_list = []
|
||||
for inner_list in lists:
|
||||
merged_list += inner_list
|
||||
return (merged_list,)
|
||||
return merged_list
|
||||
|
||||
|
||||
def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
|
||||
def join_strings(strings: List[str], delimiter: str = " ", str_replace: Optional[Dict[str, str]] = None) -> str:
|
||||
"""
|
||||
Transforms a list of strings into a single string. The content of this string
|
||||
is the content of all original strings separated by the delimiter you specify.
|
||||
@ -59,16 +63,42 @@ def join_strings(strings: List[str], delimiter: str = " ") -> Tuple[str]:
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert join_strings(strings=["first", "second", "third"], delimiter=" - ") == ("first - second - third", )
|
||||
assert join_strings(strings=["first", "second", "third"], delimiter=" - ", str_replace={"r": "R"}) == "fiRst - second - thiRd"
|
||||
```
|
||||
"""
|
||||
return (delimiter.join(strings),)
|
||||
str_replace = str_replace or {}
|
||||
return delimiter.join([format_string(string, str_replace) for string in strings])
|
||||
|
||||
|
||||
def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[List[Document]]:
|
||||
def format_string(string: str, str_replace: Optional[Dict[str, str]] = None) -> str:
|
||||
"""
|
||||
Transforms a string using a substitution dict.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert format_string(string="first", str_replace={"r": "R"}) == "fiRst"
|
||||
```
|
||||
"""
|
||||
str_replace = str_replace or {}
|
||||
return reduce(lambda s, kv: s.replace(*kv), str_replace.items(), string)
|
||||
|
||||
|
||||
def join_documents(
|
||||
documents: List[Document],
|
||||
delimiter: str = " ",
|
||||
pattern: Optional[str] = None,
|
||||
str_replace: Optional[Dict[str, str]] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Transforms a list of documents into a list containing a single Document. The content of this list
|
||||
is the content of all original documents separated by the delimiter you specify.
|
||||
is the joined result of all original documents separated by the delimiter you specify.
|
||||
How each document is represented is controlled by the pattern parameter.
|
||||
You can use the following placeholders:
|
||||
- $content: the content of the document
|
||||
- $idx: the index of the document in the list
|
||||
- $id: the id of the document
|
||||
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
|
||||
|
||||
All metadata is dropped. (TODO: fix)
|
||||
|
||||
@ -81,31 +111,333 @@ def join_documents(documents: List[Document], delimiter: str = " ") -> Tuple[Lis
|
||||
Document(content="second"),
|
||||
Document(content="third")
|
||||
],
|
||||
delimiter=" - "
|
||||
) == ([Document(content="first - second - third")], )
|
||||
delimiter=" - ",
|
||||
pattern="[$idx] $content",
|
||||
str_replace={"r": "R"}
|
||||
) == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
|
||||
```
|
||||
"""
|
||||
return ([Document(content=delimiter.join([d.content for d in documents]))],)
|
||||
return [Document(content=join_documents_to_string(documents, delimiter, pattern, str_replace))]
|
||||
|
||||
|
||||
def strings_to_answers(strings: List[str]) -> Tuple[List[Answer]]:
|
||||
def format_document(
|
||||
document: Document,
|
||||
pattern: Optional[str] = None,
|
||||
str_replace: Optional[Dict[str, str]] = None,
|
||||
idx: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Transforms a list of strings into a list of Answers.
|
||||
Transforms a document into a single string.
|
||||
How the document is represented is controlled by the pattern parameter.
|
||||
You can use the following placeholders:
|
||||
- $content: the content of the document
|
||||
- $idx: the index of the document in the list
|
||||
- $id: the id of the document
|
||||
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert strings_to_answers(strings=["first", "second", "third"]) == ([
|
||||
Answer(answer="first"),
|
||||
Answer(answer="second"),
|
||||
Answer(answer="third"),
|
||||
], )
|
||||
assert format_document(
|
||||
document=Document(content="first"),
|
||||
pattern="prefix [$idx] $content",
|
||||
str_replace={"r": "R"},
|
||||
idx=1,
|
||||
) == "prefix [1] fiRst"
|
||||
```
|
||||
"""
|
||||
return ([Answer(answer=string, type="generative") for string in strings],)
|
||||
str_replace = str_replace or {}
|
||||
pattern = pattern or "$content"
|
||||
|
||||
template = Template(pattern)
|
||||
pattern_params = [
|
||||
match.groupdict().get("named", match.groupdict().get("braced"))
|
||||
for match in template.pattern.finditer(template.template)
|
||||
]
|
||||
meta_params = [param for param in pattern_params if param and param not in ["content", "idx", "id"]]
|
||||
content = template.substitute(
|
||||
{
|
||||
"idx": idx,
|
||||
"content": reduce(lambda content, kv: content.replace(*kv), str_replace.items(), document.content),
|
||||
"id": reduce(lambda id, kv: id.replace(*kv), str_replace.items(), document.id),
|
||||
**{
|
||||
k: reduce(lambda val, kv: val.replace(*kv), str_replace.items(), document.meta.get(k, ""))
|
||||
for k in meta_params
|
||||
},
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
|
||||
def format_answer(
|
||||
answer: Answer,
|
||||
pattern: Optional[str] = None,
|
||||
str_replace: Optional[Dict[str, str]] = None,
|
||||
idx: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Transforms an answer into a single string.
|
||||
How the answer is represented is controlled by the pattern parameter.
|
||||
You can use the following placeholders:
|
||||
- $answer: the answer text of the answer
|
||||
- $idx: the index of the answer in the list
|
||||
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert format_answer(
|
||||
answer=Answer(answer="first"),
|
||||
pattern="prefix [$idx] $answer",
|
||||
str_replace={"r": "R"},
|
||||
idx=1,
|
||||
) == "prefix [1] fiRst"
|
||||
```
|
||||
"""
|
||||
str_replace = str_replace or {}
|
||||
pattern = pattern or "$answer"
|
||||
|
||||
template = Template(pattern)
|
||||
pattern_params = [
|
||||
match.groupdict().get("named", match.groupdict().get("braced"))
|
||||
for match in template.pattern.finditer(template.template)
|
||||
]
|
||||
meta_params = [param for param in pattern_params if param and param not in ["answer", "idx"]]
|
||||
meta = answer.meta or {}
|
||||
content = template.substitute(
|
||||
{
|
||||
"idx": idx,
|
||||
"answer": reduce(lambda content, kv: content.replace(*kv), str_replace.items(), answer.answer),
|
||||
**{k: reduce(lambda val, kv: val.replace(*kv), str_replace.items(), meta.get(k, "")) for k in meta_params},
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def join_documents_to_string(
|
||||
documents: List[Document],
|
||||
delimiter: str = " ",
|
||||
pattern: Optional[str] = None,
|
||||
str_replace: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Transforms a list of documents into a single string. The content of this string
|
||||
is the joined result of all original documents separated by the delimiter you specify.
|
||||
How each document is represented is controlled by the pattern parameter.
|
||||
You can use the following placeholders:
|
||||
- $content: the content of the document
|
||||
- $idx: the index of the document in the list
|
||||
- $id: the id of the document
|
||||
- $META_FIELD: the value of the metadata field of name 'META_FIELD'
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
assert join_documents_to_string(
|
||||
documents=[
|
||||
Document(content="first"),
|
||||
Document(content="second"),
|
||||
Document(content="third")
|
||||
],
|
||||
delimiter=" - ",
|
||||
pattern="[$idx] $content",
|
||||
str_replace={"r": "R"}
|
||||
) == "[1] fiRst - [2] second - [3] thiRd"
|
||||
```
|
||||
"""
|
||||
content = delimiter.join(
|
||||
format_document(doc, pattern, str_replace, idx=idx) for idx, doc in enumerate(documents, start=1)
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def strings_to_answers(
|
||||
strings: List[str],
|
||||
prompts: Optional[List[Union[str, List[Dict[str, str]]]]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
pattern: Optional[str] = None,
|
||||
reference_pattern: Optional[str] = None,
|
||||
reference_mode: Literal["index", "id", "meta"] = "index",
|
||||
reference_meta_field: Optional[str] = None,
|
||||
) -> List[Answer]:
|
||||
"""
|
||||
Transforms a list of strings into a list of Answers.
|
||||
Specify `reference_pattern` to populate the answer's `document_ids` by extracting document references from the strings.
|
||||
|
||||
:param strings: The list of strings to transform.
|
||||
:param prompts: The prompts used to generate the answers.
|
||||
:param documents: The documents used to generate the answers.
|
||||
:param pattern: The regex pattern to use for parsing the answer.
|
||||
Examples:
|
||||
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
|
||||
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
|
||||
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
|
||||
:param reference_pattern: The regex pattern to use for parsing the document references.
|
||||
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
|
||||
If None, no parsing is done and all documents are referenced.
|
||||
:param reference_mode: The mode used to reference documents. Supported modes are:
|
||||
- index: the document references are the one-based index of the document in the list of documents.
|
||||
Example: "this is an answer[1]" will reference the first document in the list of documents.
|
||||
- id: the document references are the document ids.
|
||||
Example: "this is an answer[123]" will reference the document with id "123".
|
||||
- meta: the document references are the value of a metadata field of the document.
|
||||
Example: "this is an answer[123]" will reference the document with the value "123" in the metadata field specified by reference_meta_field.
|
||||
:param reference_meta_field: The name of the metadata field to use for document references in reference_mode "meta".
|
||||
:return: The list of answers.
|
||||
|
||||
Examples:
|
||||
|
||||
Without reference parsing:
|
||||
```python
|
||||
assert strings_to_answers(strings=["first", "second", "third"], prompt="prompt", documents=[Document(id="123", content="content")]) == [
|
||||
Answer(answer="first", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
|
||||
Answer(answer="second", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
|
||||
Answer(answer="third", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
|
||||
]
|
||||
```
|
||||
|
||||
With reference parsing:
|
||||
```python
|
||||
assert strings_to_answers(strings=["first[1]", "second[2]", "third[1][3]"], prompt="prompt",
|
||||
documents=[Document(id="123", content="content"), Document(id="456", content="content"), Document(id="789", content="content")],
|
||||
reference_pattern=r"\\[(\\d+)\\]",
|
||||
reference_mode="index"
|
||||
) == [
|
||||
Answer(answer="first", type="generative", document_ids=["123"], meta={"prompt": "prompt"}),
|
||||
Answer(answer="second", type="generative", document_ids=["456"], meta={"prompt": "prompt"}),
|
||||
Answer(answer="third", type="generative", document_ids=["123", "789"], meta={"prompt": "prompt"}),
|
||||
]
|
||||
```
|
||||
"""
|
||||
if prompts:
|
||||
if len(prompts) == 1:
|
||||
# one prompt for all strings/documents
|
||||
documents_per_string: List[Optional[List[Document]]] = [documents] * len(strings)
|
||||
prompt_per_string: List[Optional[Union[str, List[Dict[str, str]]]]] = [prompts[0]] * len(strings)
|
||||
elif len(prompts) > 1 and len(strings) % len(prompts) == 0:
|
||||
# one prompt per string/document
|
||||
if documents is not None and len(documents) != len(prompts):
|
||||
raise ValueError("The number of documents must match the number of prompts")
|
||||
string_multiplier = len(strings) // len(prompts)
|
||||
documents_per_string = (
|
||||
[[doc] for doc in documents for _ in range(string_multiplier)] if documents else [None] * len(strings)
|
||||
)
|
||||
prompt_per_string = [prompt for prompt in prompts for _ in range(string_multiplier)]
|
||||
else:
|
||||
raise ValueError("The number of prompts must be one or a multiple of the number of strings")
|
||||
else:
|
||||
documents_per_string = [documents] * len(strings)
|
||||
prompt_per_string = [None] * len(strings)
|
||||
|
||||
answers = []
|
||||
for string, prompt, _documents in zip(strings, prompt_per_string, documents_per_string):
|
||||
answer = string_to_answer(
|
||||
string=string,
|
||||
prompt=prompt,
|
||||
documents=_documents,
|
||||
pattern=pattern,
|
||||
reference_pattern=reference_pattern,
|
||||
reference_mode=reference_mode,
|
||||
reference_meta_field=reference_meta_field,
|
||||
)
|
||||
answers.append(answer)
|
||||
return answers
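One case the docstring examples above do not show: when one prompt was generated per document, each answer is mapped back to the document whose prompt produced it. A short sketch, assuming the module context of this file:

```python
from haystack.schema import Document

docs = [Document(id="a", content="A"), Document(id="b", content="B")]
answers = strings_to_answers(
    strings=["answer one", "answer two"],
    prompts=["prompt for A", "prompt for B"],
    documents=docs,
)
# Each answer references only the document whose prompt produced it.
assert [a.document_ids for a in answers] == [["a"], ["b"]]
```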
|
||||
|
||||
|
||||
def string_to_answer(
|
||||
string: str,
|
||||
prompt: Optional[Union[str, List[Dict[str, str]]]],
|
||||
documents: Optional[List[Document]],
|
||||
pattern: Optional[str] = None,
|
||||
reference_pattern: Optional[str] = None,
|
||||
reference_mode: Literal["index", "id", "meta"] = "index",
|
||||
reference_meta_field: Optional[str] = None,
|
||||
) -> Answer:
|
||||
"""
|
||||
Transforms a string into an Answer.
|
||||
Specify `reference_pattern` to populate the answer's `document_ids` by extracting document references from the string.
|
||||
|
||||
:param string: The string to transform.
|
||||
:param prompt: The prompt used to generate the answer.
|
||||
:param documents: The documents used to generate the answer.
|
||||
:param pattern: The regex pattern to use for parsing the answer.
|
||||
Examples:
|
||||
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
|
||||
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
|
||||
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
|
||||
:param reference_pattern: The regex pattern to use for parsing the document references.
|
||||
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
|
||||
If None, no parsing is done and all documents are referenced.
|
||||
:param reference_mode: The mode used to reference documents. Supported modes are:
|
||||
- index: the document references are the one-based index of the document in the list of documents.
|
||||
Example: "this is an answer[1]" will reference the first document in the list of documents.
|
||||
- id: the document references are the document ids.
|
||||
Example: "this is an answer[123]" will reference the document with id "123".
|
||||
- meta: the document references are the value of a metadata field of the document.
|
||||
Example: "this is an answer[123]" will reference the document with the value "123" in the metadata field specified by reference_meta_field.
|
||||
:param reference_meta_field: The name of the metadata field to use for document references in reference_mode "meta".
|
||||
:return: The answer
|
||||
"""
|
||||
if reference_mode == "index":
|
||||
candidates = {str(idx): doc.id for idx, doc in enumerate(documents, start=1)} if documents else {}
|
||||
elif reference_mode == "id":
|
||||
candidates = {doc.id: doc.id for doc in documents} if documents else {}
|
||||
elif reference_mode == "meta":
|
||||
if not reference_meta_field:
|
||||
raise ValueError("reference_meta_field must be specified when reference_mode is 'meta'")
|
||||
candidates = (
|
||||
{doc.meta[reference_meta_field]: doc.id for doc in documents if doc.meta.get(reference_meta_field)}
|
||||
if documents
|
||||
else {}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid document_id_mode: {reference_mode}")
|
||||
|
||||
if pattern:
|
||||
match = re.search(pattern, string)
|
||||
if match:
|
||||
if not match.lastindex:
|
||||
# no group in pattern -> take the whole match
|
||||
string = match.group(0)
|
||||
elif match.lastindex == 1:
|
||||
# one group in pattern -> take the group
|
||||
string = match.group(1)
|
||||
else:
|
||||
# more than one group in pattern -> raise error
|
||||
raise ValueError(f"Pattern must have at most one group: {pattern}")
|
||||
else:
|
||||
string = ""
|
||||
document_ids = parse_references(string=string, reference_pattern=reference_pattern, candidates=candidates)
|
||||
answer = Answer(answer=string, type="generative", document_ids=document_ids, meta={"prompt": prompt})
|
||||
return answer
|
||||
|
||||
|
||||
def parse_references(
|
||||
string: str, reference_pattern: Optional[str] = None, candidates: Optional[Dict[str, str]] = None
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Parses an answer string for document references and returns the document ids of the referenced documents.
|
||||
|
||||
:param string: The string to parse.
|
||||
:param reference_pattern: The regex pattern to use for parsing the document references.
|
||||
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
|
||||
If None, no parsing is done and all candidate document ids are returned.
|
||||
:param candidates: A dictionary of candidates to choose from. The keys are the reference strings and the values are the document ids.
|
||||
If None, no parsing is done and None is returned.
|
||||
:return: A list of document ids.
|
||||
"""
|
||||
if not candidates:
|
||||
return None
|
||||
if not reference_pattern:
|
||||
return list(candidates.values())
|
||||
|
||||
document_idxs = re.findall(reference_pattern, string)
|
||||
return [candidates[idx] for idx in document_idxs if idx in candidates]
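A brief illustration of the parsing contract above, assuming the module context of this file:

```python
candidates = {"1": "doc-a", "2": "doc-b"}

# A matched reference resolves to the candidate's document id.
assert parse_references("the answer[2]", reference_pattern=r"\[(\d+)\]", candidates=candidates) == ["doc-b"]
# No match yields an empty list; no reference_pattern references every candidate.
assert parse_references("no references here", reference_pattern=r"\[(\d+)\]", candidates=candidates) == []
assert parse_references("anything", candidates=candidates) == ["doc-a", "doc-b"]
```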
|
||||
|
||||
|
||||
def answers_to_strings(
|
||||
answers: List[Answer], pattern: Optional[str] = None, str_replace: Optional[Dict[str, str]] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extracts the answer field of Answers and returns a list of strings.
|
||||
|
||||
@ -117,18 +449,20 @@ def answers_to_strings(answers: List[Answer]) -> Tuple[List[str]]:
|
||||
Answer(answer="first"),
|
||||
Answer(answer="second"),
|
||||
Answer(answer="third")
|
||||
]
|
||||
) == (["first", "second", "third"],)
|
||||
],
|
||||
pattern="[$idx] $answer",
|
||||
str_replace={"r": "R"}
|
||||
) == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
```
|
||||
"""
|
||||
return ([answer.answer for answer in answers],)
|
||||
return [format_answer(answer, pattern, str_replace, idx) for idx, answer in enumerate(answers, start=1)]
|
||||
|
||||
|
||||
def strings_to_documents(
|
||||
strings: List[str],
|
||||
meta: Union[List[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[List[Document]]:
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Transforms a list of strings into a list of Documents. If you pass the metadata in a single
|
||||
dictionary, all Documents get the same metadata. If you pass the metadata as a list, the length of this list
|
||||
@ -142,11 +476,11 @@ def strings_to_documents(
|
||||
strings=["first", "second", "third"],
|
||||
meta=[{"position": i} for i in range(3)],
|
||||
id_hash_keys=['content', 'meta']
) == ([
) == [
Document(content="first", meta={"position": 0}, id_hash_keys=['content', 'meta']),
Document(content="second", meta={"position": 1}, id_hash_keys=['content', 'meta']),
Document(content="third", meta={"position": 2}, id_hash_keys=['content', 'meta'])
], )
]
|
||||
```
|
||||
"""
|
||||
all_metadata: List[Optional[Dict[str, Any]]]
|
||||
@ -161,10 +495,12 @@ def strings_to_documents(
|
||||
else:
|
||||
all_metadata = [None] * len(strings)
|
||||
|
||||
return ([Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)],)
|
||||
return [Document(content=string, meta=m, id_hash_keys=id_hash_keys) for string, m in zip(strings, all_metadata)]
|
||||
|
||||
|
||||
def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
|
||||
def documents_to_strings(
|
||||
documents: List[Document], pattern: Optional[str] = None, str_replace: Optional[Dict[str, str]] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extracts the content field of Documents and returns a list of strings.
|
||||
|
||||
@ -176,14 +512,16 @@ def documents_to_strings(documents: List[Document]) -> Tuple[List[str]]:
|
||||
Document(content="first"),
|
||||
Document(content="second"),
|
||||
Document(content="third")
|
||||
]
|
||||
) == (["first", "second", "third"],)
|
||||
],
|
||||
pattern="[$idx] $content",
|
||||
str_replace={"r": "R"}
|
||||
) == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
```
|
||||
"""
|
||||
return ([doc.content for doc in documents],)
|
||||
return [format_document(doc, pattern, str_replace, idx) for idx, doc in enumerate(documents, start=1)]
|
||||
|
||||
|
||||
REGISTERED_FUNCTIONS: Dict[str, Callable[..., Tuple[Any]]] = {
|
||||
REGISTERED_FUNCTIONS: Dict[str, Callable[..., Any]] = {
|
||||
"rename": rename,
|
||||
"value_to_list": value_to_list,
|
||||
"join_lists": join_lists,
|
||||
@ -389,6 +727,16 @@ class Shaper(BaseComponent):
|
||||
if value in invocation_context.keys() and value is not None:
|
||||
input_values[key] = invocation_context[value]
|
||||
|
||||
# auto fill in input values if there's an invocation context value with the same name
|
||||
function_params = inspect.signature(self.function).parameters
|
||||
for parameter in function_params.values():
|
||||
if (
|
||||
parameter.name not in input_values.keys()
|
||||
and parameter.name not in self.params.keys()
|
||||
and parameter.name in invocation_context.keys()
|
||||
):
|
||||
input_values[parameter.name] = invocation_context[parameter.name]
|
||||
|
||||
input_values = {**self.params, **input_values}
|
||||
try:
|
||||
logger.debug(
|
||||
@ -397,6 +745,8 @@ class Shaper(BaseComponent):
|
||||
", ".join([f"{key}={value}" for key, value in input_values.items()]),
|
||||
)
|
||||
output_values = self.function(**input_values)
|
||||
if not isinstance(output_values, tuple):
|
||||
output_values = (output_values,)
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
"Shaper couldn't apply the function to your inputs and parameters. "
|
||||
|
||||
@ -1,17 +1,27 @@
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC
|
||||
from string import Template
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type, overload
|
||||
from uuid import uuid4
|
||||
|
||||
import torch
|
||||
|
||||
from haystack import MultiLabel
|
||||
from haystack.environment import HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS
|
||||
from haystack.errors import NodeError
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.other.shaper import ( # pylint: disable=unused-import
|
||||
Shaper,
|
||||
join_documents_to_string as join, # used as shaping function
|
||||
format_document,
|
||||
format_answer,
|
||||
format_string,
|
||||
)
|
||||
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, instruction_following_models
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.schema import Answer, Document
|
||||
from haystack.telemetry_2 import send_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -43,6 +53,169 @@ class BasePromptTemplate(BaseComponent):
|
||||
raise NotImplementedError("This method should never be implemented in the derived class")
|
||||
|
||||
|
||||
PROMPT_TEMPLATE_ALLOWED_FUNCTIONS = ast.literal_eval(
|
||||
os.environ.get(HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS, '["join", "to_strings", "replace", "enumerate", "str"]')
|
||||
)
|
||||
PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS = {"new_line": "\n", "tab": "\t", "double_quote": '"', "carriage_return": "\r"}
|
||||
PROMPT_TEMPLATE_STRIPS = ["'", '"']
|
||||
PROMPT_TEMPLATE_STR_REPLACE = {'"': "'"}
|
||||
|
||||
|
||||
def to_strings(items: List[Union[str, Document, Answer]], pattern: Optional[str] = None, str_replace: Optional[Dict[str, str]] = None) -> List[str]:
|
||||
results = []
|
||||
for idx, item in enumerate(items, start=1):
|
||||
if isinstance(item, str):
|
||||
results.append(format_string(item, str_replace=str_replace))
|
||||
elif isinstance(item, Document):
|
||||
results.append(format_document(document=item, pattern=pattern, str_replace=str_replace, idx=idx))
|
||||
elif isinstance(item, Answer):
|
||||
results.append(format_answer(answer=item, pattern=pattern, str_replace=str_replace, idx=idx))
|
||||
else:
|
||||
raise ValueError(f"Unsupported item type: {type(item)}")
|
||||
return results
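A short sketch of how `to_strings` normalizes the item types it accepts (it backs the automatic rendering of raw `{documents}` and `{answers}` variables; assumes the module context of this file):

```python
from haystack.schema import Answer, Document

assert to_strings(["plain text"]) == ["plain text"]
assert to_strings([Document(content="first")], pattern="[$idx] $content") == ["[1] first"]
assert to_strings([Answer(answer="42")]) == ["42"]
```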
|
||||
|
||||
|
||||
class PromptTemplateValidationError(NodeError):
|
||||
"""
|
||||
Error raised when a prompt template is invalid.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _ValidationVisitor(ast.NodeVisitor):
|
||||
"""
|
||||
This class is used to validate the prompt text for a prompt template.
|
||||
It checks that the prompt text is a valid f-string and that it only uses allowed functions.
|
||||
Useful information extracted from the AST is stored in the class attributes (e.g. `prompt_params` and `used_functions`)
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_template_name: str):
|
||||
self.used_names: List[str] = []
|
||||
self.comprehension_targets: List[str] = []
|
||||
self.used_functions: List[str] = []
|
||||
self.prompt_template_name = prompt_template_name
|
||||
|
||||
@property
|
||||
def prompt_params(self) -> List[str]:
|
||||
"""
|
||||
The names of the variables used in the prompt text.
|
||||
E.g. for the prompt text `f"Hello {name}"`, the prompt_params would be `["name"]`
|
||||
"""
|
||||
return list(set(self.used_names) - set(self.used_functions) - set(self.comprehension_targets))
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> None:
|
||||
"""
|
||||
Stores the name of the variable used in the prompt text. This also includes function and method names.
|
||||
E.g. for the prompt text `f"Hello {func(name)}"`, the used_names would be `["func", "name"]`
|
||||
"""
|
||||
self.used_names.append(node.id)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> None:
|
||||
"""
|
||||
Stores the name of the variable used in comprehensions.
|
||||
E.g. for the prompt text `f"Hello {[name for name in names]}"`, the comprehension_targets would be `["name"]`
|
||||
"""
|
||||
super().generic_visit(node)
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.comprehension_targets.append(node.target.id)
|
||||
elif isinstance(node.target, ast.Tuple):
|
||||
self.comprehension_targets.extend([elt.id for elt in node.target.elts if isinstance(elt, ast.Name)])
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
"""
|
||||
Stores the name of functions and methods used in the prompt text and validates that only allowed functions are used.
|
||||
E.g. for the prompt text `f"Hello {func(name)}"`, the used_functions would be `["func"]`
|
||||
|
||||
raises: PromptTemplateValidationError if an invalid function is used in the prompt text
|
||||
"""
|
||||
super().generic_visit(node)
|
||||
if isinstance(node.func, ast.Name) and node.func.id in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS:
|
||||
# functions: func(args, kwargs)
|
||||
self.used_functions.append(node.func.id)
|
||||
elif isinstance(node.func, ast.Attribute) and node.func.attr in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS:
|
||||
# methods: instance.method(args, kwargs)
|
||||
self.used_functions.append(node.func.attr)
|
||||
else:
|
||||
raise PromptTemplateValidationError(
|
||||
f"Invalid function in prompt text for prompt template {self.prompt_template_name}. "
|
||||
f"Allowed functions are {PROMPT_TEMPLATE_ALLOWED_FUNCTIONS}."
|
||||
)
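To make the validation path concrete, a hedged sketch (it assumes `PromptTemplate` wires `_ValidationVisitor` into its constructor, as shown later in this diff):

```python
from haystack.nodes.prompt.prompt_node import PromptTemplate, PromptTemplateValidationError

try:
    # os.system is not in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS, so the template is rejected.
    PromptTemplate(name="unsafe", prompt_text="{os.system('ls')}")
except PromptTemplateValidationError as error:
    print(error)
```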
|
||||
|
||||
|
||||
class _FstringParamsTransformer(ast.NodeTransformer):
|
||||
"""
|
||||
This class is used to transform an AST for f-strings into a format that can be used by the PromptTemplate.
|
||||
It replaces all f-string expressions with a unique id and stores the corresponding expression in a dictionary.
|
||||
|
||||
The stored expressions can be evaluated using the `eval` function, given the `prompt_params` (see `_ValidationVisitor`).
|
||||
PromptTemplate determines the number of prompts to generate and renders them using the evaluated expressions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.prompt_params_functions: Dict[str, ast.Expression] = {}
|
||||
|
||||
def visit_FormattedValue(self, node: ast.FormattedValue) -> Optional[ast.AST]:
|
||||
"""
|
||||
Replaces the f-string expression with a unique id and stores the corresponding expression in a dictionary.
|
||||
If the expression is the raw `documents` or `answers` variable, it is wrapped in a call to `to_strings` to ensure that the items get rendered correctly.
|
||||
"""
|
||||
super().generic_visit(node)
|
||||
|
||||
# Keep special char variables as is. They are available via globals.
|
||||
if isinstance(node.value, ast.Name) and node.value.id in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS:
|
||||
return node
|
||||
|
||||
id = uuid4().hex
|
||||
if isinstance(node.value, ast.Name) and node.value.id in ["documents", "answers"]:
|
||||
call = ast.Call(func=ast.Name(id="to_strings", ctx=ast.Load()), args=[node.value], keywords=[])
|
||||
self.prompt_params_functions[id] = ast.fix_missing_locations(ast.Expression(body=call))
|
||||
else:
|
||||
self.prompt_params_functions[id] = ast.fix_missing_locations(ast.Expression(body=node.value))
|
||||
return ast.FormattedValue(
|
||||
value=ast.Name(id=id, ctx=ast.Load()), conversion=node.conversion, format_spec=node.format_spec
|
||||
)
|
||||
|
||||
|
||||
class BaseOutputParser(Shaper):
|
||||
"""
|
||||
An output parser defines, within a `PromptTemplate`, how to parse the model output and convert it into Haystack primitives.
|
||||
BaseOutputParser is the base class for output parser implementations.
|
||||
"""
|
||||
|
||||
@property
|
||||
def output_variable(self) -> Optional[str]:
|
||||
return self.outputs[0]
|
||||
|
||||
|
||||
class AnswerParser(BaseOutputParser):
|
||||
"""
|
||||
AnswerParser is used to parse the model output to extract the answer into a proper `Answer` object using regex patterns.
|
||||
AnswerParser enriches the `Answer` object with the used prompts and the document_ids of the documents that were used to generate the answer.
|
||||
You can pass a reference_pattern to extract the document_ids of the answer from the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None):
|
||||
"""
|
||||
:param pattern: The regex pattern to use for parsing the answer.
|
||||
Examples:
|
||||
`[^\\n]+$` will find "this is an answer" in string "this is an argument.\nthis is an answer".
|
||||
`Answer: (.*)` will find "this is an answer" in string "this is an argument. Answer: this is an answer".
|
||||
If None, the whole string is used as the answer. If not None, the first group of the regex is used as the answer. If there is no group, the whole match is used as the answer.
|
||||
:param reference_pattern: The regex pattern to use for parsing the document references.
|
||||
Example: `\\[(\\d+)\\]` will find "1" in string "this is an answer[1]".
|
||||
If None, no parsing is done and all documents are referenced.
|
||||
"""
|
||||
self.pattern = pattern
|
||||
self.reference_pattern = reference_pattern
|
||||
super().__init__(
|
||||
func="strings_to_answers",
|
||||
inputs={"strings": "results"},
|
||||
outputs=["answers"],
|
||||
params={"pattern": pattern, "reference_pattern": reference_pattern},
|
||||
)
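A sketch of how `AnswerParser` is meant to be attached to a template; the template text and patterns are illustrative, modelled on the predefined `question-answering-with-references` template further below:

```python
qa_template = PromptTemplate(
    name="qa-with-references",
    prompt_text="Answer the question using the documents. Cite them as Document[number]. "
    "Documents: {join(documents)}; Question: {query}; Answer:",
    # Keep only the text after 'Answer:' and resolve 'Document[1]'-style citations
    # to the ids of the documents passed at runtime.
    output_parser=AnswerParser(pattern=r"Answer: (.*)", reference_pattern=r"Document\[(\d+)\]"),
)
```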
|
||||
|
||||
|
||||
class PromptTemplate(BasePromptTemplate, ABC):
|
||||
"""
|
||||
PromptTemplate is a template for a prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what such a prompt template may look like:
|
||||
@ -50,54 +223,67 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
```python
|
||||
PromptTemplate(name="sentiment-analysis",
|
||||
prompt_text="Give a sentiment for this context. Answer with positive, negative
|
||||
or neutral. Context: $documents; Answer:")
|
||||
or neutral. Context: {documents}; Answer:")
|
||||
```
|
||||
|
||||
Optionally, you can declare prompt parameters in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
|
||||
the prompt_text for the model to perform the task. For example, in the template above, there's one prompt parameter, `documents`. You declare prompt parameters by adding variables to the prompt text. These variables should be in the format: `$variable`. In the template above, the variable is `$documents`.
|
||||
Optionally, you can declare prompt parameters using f-string syntax in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
|
||||
the prompt_text for the model to perform the task. For example, in the template above, there's one prompt parameter, `documents`. You declare prompt parameters by adding variables to the prompt text. These variables should be in the format: `{variable}`. In the template above, the variable is `{documents}`.
|
||||
|
||||
At runtime, these variables are filled in with arguments passed to the `fill()` method of the PromptTemplate. So in the example above, the `$documents` variable will be filled with the Documents whose sentiment you want the model to analyze.
|
||||
At runtime, these variables are filled in with arguments passed to the `fill()` method of the PromptTemplate. So in the example above, the `{documents}` variable will be filled with the Documents whose sentiment you want the model to analyze.
|
||||
|
||||
Note that, unlike in strict f-string syntax, you can safely use the backslash characters `\n`, `\t`, and `\r` in the text parts of the prompt text.
|
||||
If you want to use them in f-string expressions, use `new_line`, `tab`, `carriage_return` instead.
|
||||
Double quotes (e.g. `"`) will be automatically replaced with single quotes (e.g. `'`) in the prompt text. If you want to use double quotes in the prompt text, use `{double_quote}` instead.
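For example, a hypothetical template that joins documents line by line using the `new_line` alias inside an f-string expression:

```python
PromptTemplate(
    name="bullet-summary",
    prompt_text="Summarize each document on its own line:"
    "{new_line}{join(documents, delimiter=new_line, pattern='- $content')}",
)
```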
|
||||
|
||||
For more details on how to use PromptTemplate, see
|
||||
[PromptNode](https://docs.haystack.deepset.ai/docs/prompt_node).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, prompt_text: str, prompt_params: Optional[List[str]] = None):
|
||||
def __init__(self, name: str, prompt_text: str, output_parser: Optional[BaseOutputParser] = None):
|
||||
"""
|
||||
Creates a PromptTemplate instance.
|
||||
|
||||
:param name: The name of the prompt template (for example, sentiment-analysis, question-generation). You can specify your own name but it must be unique.
|
||||
:param prompt_text: The prompt text, including prompt parameters.
|
||||
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
|
||||
:param output_parser: A parser that will be applied to the model output.
|
||||
For example, if you want to convert the model output to an Answer object, you can use `AnswerParser`.
|
||||
"""
|
||||
super().__init__()
|
||||
if prompt_params:
|
||||
self.prompt_params = prompt_params
|
||||
else:
|
||||
# Define the regex pattern to match the strings after the $ character
|
||||
pattern = r"\$([a-zA-Z0-9_]+)"
|
||||
self.prompt_params = re.findall(pattern, prompt_text)
|
||||
|
||||
if prompt_text.count("$") != len(self.prompt_params):
|
||||
raise ValueError(
|
||||
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
|
||||
f"does not match the number of specified parameters {self.prompt_params}."
|
||||
)
|
||||
|
||||
# When a PromptTemplate is loaded from a YAML file, the prompt text may start and end with quotes that need to be stripped
|
||||
prompt_text = prompt_text.strip("'").strip('"')
|
||||
for strip in PROMPT_TEMPLATE_STRIPS:
|
||||
prompt_text = prompt_text.strip(strip)
|
||||
replacements = {
|
||||
**{v: "{" + k + "}" for k, v in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS.items()},
|
||||
**PROMPT_TEMPLATE_STR_REPLACE,
|
||||
}
|
||||
for old, new in replacements.items():
|
||||
prompt_text = prompt_text.replace(old, new)
|
||||
|
||||
t = Template(prompt_text)
|
||||
try:
|
||||
t.substitute(**{param: "" for param in self.prompt_params})
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Invalid parameter {e} in prompt text "
|
||||
f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
|
||||
)
|
||||
self._ast_expression = ast.parse(f'f"{prompt_text}"', mode="eval")
|
||||
|
||||
ast_validator = _ValidationVisitor(prompt_template_name=name)
|
||||
ast_validator.visit(self._ast_expression)
|
||||
|
||||
ast_transformer = _FstringParamsTransformer()
|
||||
self._ast_expression = ast.fix_missing_locations(ast_transformer.visit(self._ast_expression))
|
||||
self._prompt_params_functions = ast_transformer.prompt_params_functions
|
||||
self._used_functions = ast_validator.used_functions
|
||||
|
||||
self.name = name
|
||||
self.prompt_text = prompt_text
|
||||
self.prompt_params = sorted(
|
||||
param for param in ast_validator.prompt_params if param not in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS
|
||||
)
|
||||
self.globals = {
|
||||
**{k: v for k, v in globals().items() if k in PROMPT_TEMPLATE_ALLOWED_FUNCTIONS},
|
||||
**PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS,
|
||||
}
|
||||
self.output_parser = output_parser
|
||||
|
||||
@property
|
||||
def output_variable(self) -> Optional[str]:
|
||||
return self.output_parser.output_variable if self.output_parser else None
|
||||
|
||||
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -107,7 +293,7 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
:param kwargs: Keyword arguments to fill the parameters in the prompt text of a PromptTemplate.
|
||||
:return: A dictionary with the prompt text and the prompt parameters.
|
||||
"""
|
||||
template_dict = {}
|
||||
params_dict = {}
|
||||
# attempt to resolve args first
|
||||
if args:
|
||||
if len(args) != len(self.prompt_params):
|
||||
@ -119,28 +305,49 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
args,
|
||||
)
|
||||
for prompt_param, arg in zip(self.prompt_params, args):
|
||||
template_dict[prompt_param] = [arg] if isinstance(arg, str) else arg
|
||||
params_dict[prompt_param] = [arg] if isinstance(arg, str) else arg
|
||||
# then attempt to resolve kwargs
|
||||
if kwargs:
|
||||
for param in self.prompt_params:
|
||||
if param in kwargs:
|
||||
template_dict[param] = kwargs[param]
|
||||
params_dict[param] = kwargs[param]
|
||||
|
||||
if set(template_dict.keys()) != set(self.prompt_params):
|
||||
available_params = set(list(template_dict.keys()) + list(set(kwargs.keys())))
|
||||
if set(params_dict.keys()) != set(self.prompt_params):
|
||||
available_params = set(list(params_dict.keys()) + list(set(kwargs.keys())))
|
||||
raise ValueError(f"Expected prompt parameters {self.prompt_params} but got {list(available_params)}.")
|
||||
|
||||
template_dict = {"_at_least_one_prompt": True}
|
||||
for id, call in self._prompt_params_functions.items():
|
||||
template_dict[id] = eval( # pylint: disable=eval-used
|
||||
compile(call, filename="<string>", mode="eval"), self.globals, params_dict
|
||||
)
|
||||
|
||||
return template_dict
|
||||
|
||||
def post_process(self, prompt_output: List[str], **kwargs) -> List[Any]:
|
||||
"""
|
||||
Post-processes the output of the prompt template.
|
||||
:param prompt_output: The raw model output strings to post-process.
:param kwargs: Keyword arguments to use for post-processing the prompt output.
:return: The post-processed output as a list.
|
||||
"""
|
||||
if self.output_parser:
|
||||
invocation_context = kwargs
|
||||
invocation_context["results"] = prompt_output
|
||||
self.output_parser.run(invocation_context=invocation_context)
|
||||
return invocation_context[self.output_parser.outputs[0]]
|
||||
else:
|
||||
return prompt_output
|
||||
|
||||
def fill(self, *args, **kwargs) -> Iterator[str]:
|
||||
"""
|
||||
Fills the parameters defined in the prompt text with the arguments passed to it and returns the iterator prompt text.
|
||||
|
||||
You can pass non-keyword (args) or keyword (kwargs) arguments to this method. If you pass non-keyword arguments, their order must match the left-to-right
|
||||
order of appearance of the parameters in the prompt text. For example, if the prompt text is:
|
||||
`Come up with a question for the given context and the answer. Context: $documents;
|
||||
Answer: $answers; Question:`, then the first non-keyword argument fills the `$documents` variable
|
||||
and the second non-keyword argument fills the `$answers` variable.
|
||||
`Come up with a question for the given context and the answer. Context: {documents};
|
||||
Answer: {answers}; Question:`, then the first non-keyword argument fills the `{documents}` variable
|
||||
and the second non-keyword argument fills the `{answers}` variable.
|
||||
|
||||
If you pass keyword arguments, the order of the arguments doesn't matter. Variables in the
|
||||
prompt text are filled with the corresponding keyword argument.
|
||||
@ -150,12 +357,20 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
:return: An iterator of prompt texts.
|
||||
"""
|
||||
template_dict = self.prepare(*args, **kwargs)
|
||||
template = Template(self.prompt_text)
|
||||
|
||||
# the prompt context values should all be lists, as they will be split as one
|
||||
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in template_dict.items()}
|
||||
max_len = max(len(v) for v in prompt_context_copy.values())
|
||||
if max_len > 1:
|
||||
for key, value in prompt_context_copy.items():
|
||||
if len(value) == 1:
|
||||
prompt_context_copy[key] = value * max_len
|
||||
|
||||
for prompt_context_values in zip(*prompt_context_copy.values()):
|
||||
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
|
||||
prompt_prepared: str = template.substitute(template_input)
|
||||
prompt_prepared: str = eval( # pylint: disable=eval-used
|
||||
compile(self._ast_expression, filename="<string>", mode="eval"), self.globals, template_input
|
||||
)
|
||||
yield prompt_prepared
|
||||
|
||||
def __repr__(self):
|
||||
@ -311,66 +526,85 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
|
||||
return [
|
||||
PromptTemplate(
|
||||
name="question-answering",
|
||||
prompt_text="Given the context please answer the question. Context: $documents; Question: "
|
||||
"$questions; Answer:",
|
||||
prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
|
||||
"{query}; Answer:",
|
||||
output_parser=AnswerParser(),
|
||||
),
|
||||
PromptTemplate(
|
||||
name="question-answering-per-document",
|
||||
prompt_text="Given the context please answer the question. Context: {documents}; Question: "
|
||||
"{query}; Answer:",
|
||||
output_parser=AnswerParser(),
|
||||
),
|
||||
PromptTemplate(
|
||||
name="question-answering-with-references",
|
||||
prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
|
||||
"based solely on the given documents. You must only use information from the given documents. "
|
||||
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
|
||||
"If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. "
|
||||
"If the documents do not contain the answer to the question, say that ‘answering is not possible given the available information.’\n"
|
||||
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
|
||||
output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
|
||||
),
|
||||
PromptTemplate(
|
||||
name="question-generation",
|
||||
prompt_text="Given the context please generate a question. Context: $documents; Question:",
|
||||
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="conditioned-question-generation",
|
||||
prompt_text="Please come up with a question for the given context and the answer. "
|
||||
"Context: $documents; Answer: $answers; Question:",
|
||||
"Context: {documents}; Answer: {answers}; Question:",
|
||||
),
|
||||
PromptTemplate(name="summarization", prompt_text="Summarize this document: $documents Summary:"),
|
||||
PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
|
||||
PromptTemplate(
|
||||
name="question-answering-check",
|
||||
prompt_text="Does the following context contain the answer to the question? "
|
||||
"Context: $documents; Question: $questions; Please answer yes or no! Answer:",
|
||||
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
|
||||
output_parser=AnswerParser(),
|
||||
),
|
||||
PromptTemplate(
|
||||
name="sentiment-analysis",
|
||||
prompt_text="Please give a sentiment for this context. Answer with positive, "
|
||||
"negative or neutral. Context: $documents; Answer:",
|
||||
"negative or neutral. Context: {documents}; Answer:",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="multiple-choice-question-answering",
|
||||
prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
|
||||
"Options: $options; Answer:",
|
||||
prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
|
||||
"Options: {options}; Answer:",
|
||||
output_parser=AnswerParser(),
|
||||
),
|
||||
PromptTemplate(
|
||||
name="topic-classification",
|
||||
prompt_text="Categories: $options; What category best describes: $documents; Answer:",
|
||||
prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="language-detection",
|
||||
prompt_text="Detect the language in the following context and answer with the "
|
||||
"name of the language. Context: $documents; Answer:",
|
||||
"name of the language. Context: {documents}; Answer:",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="translation",
|
||||
prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
|
||||
prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="zero-shot-react",
|
||||
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
|
||||
"correctly, you have access to the following tools:\n\n"
|
||||
"$tool_names_with_descriptions\n\n"
|
||||
"{tool_names_with_descriptions}\n\n"
|
||||
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
|
||||
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
|
||||
"for a final answer, respond with the `Final Answer:`\n\n"
|
||||
"Use the following format:\n\n"
|
||||
"Question: the question to be answered\n"
|
||||
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
|
||||
"Tool: pick one of $tool_names \n"
|
||||
"Tool: pick one of {tool_names} \n"
|
||||
"Tool Input: the input for the tool\n"
|
||||
"Observation: the tool will respond with the result\n"
|
||||
"...\n"
|
||||
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
|
||||
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
|
||||
"---\n\n"
|
||||
"Question: $query\n"
|
||||
"Question: {query}\n"
|
||||
"Thought: Let's think step-by-step, I first need to ",
|
||||
),
|
||||
]
|
||||
@ -430,6 +664,7 @@ class PromptNode(BaseComponent):
:param model_name_or_path: The name of the model to use or an instance of the PromptModel.
:param default_prompt_template: The default prompt template to use for the model.
:param output_variable: The name of the output variable in which you want to store the inference results.
If not set, PromptNode uses PromptTemplate's output_variable. If PromptTemplate's output_variable is not set, default name is `results`.
:param max_length: The maximum length of the generated text output.
:param api_key: The API key to use for the model.
:param use_auth_token: The authentication token to use for the model.
@ -449,7 +684,7 @@ class PromptNode(BaseComponent):
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self.output_variable: str = output_variable or "results"
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
self.prompt_model: PromptModel
self.stop_words: Optional[List[str]] = stop_words
@ -480,7 +715,7 @@ class PromptNode(BaseComponent):
else:
raise ValueError("model_name_or_path must be either a string or a PromptModel object")

def __call__(self, *args, **kwargs) -> List[str]:
def __call__(self, *args, **kwargs) -> List[Any]:
"""
This method is invoked when the component is called directly, for example:
```python
@ -496,7 +731,7 @@ class PromptNode(BaseComponent):
else:
return self.prompt(self.default_prompt_template, *args, **kwargs)

def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[Any]:
"""
Prompts the model and represents the central API for the PromptNode. It takes a prompt template,
a list of non-keyword and keyword arguments, and returns a list of strings - the responses from the underlying model.
@ -520,12 +755,7 @@ class PromptNode(BaseComponent):
kwargs = {**self._prepare_model_kwargs(), **kwargs}
prompt_template_used = prompt_template or self.default_prompt_template
if prompt_template_used:
if isinstance(prompt_template_used, PromptTemplate):
template_to_fill = prompt_template_used
elif isinstance(prompt_template_used, str):
template_to_fill = self.get_prompt_template(prompt_template_used)
else:
raise ValueError(f"{prompt_template_used} with args {args} , and kwargs {kwargs} not supported")
template_to_fill = self.get_prompt_template(prompt_template_used)

# prompt template used, yield prompts from inputs args
for prompt in template_to_fill.fill(*args, **kwargs):
@ -536,6 +766,9 @@ class PromptNode(BaseComponent):
logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
output = self.prompt_model.invoke(prompt, **kwargs_copy)
results.extend(output)

kwargs["prompts"] = prompt_collector
results = template_to_fill.post_process(results, **kwargs)
else:
# straightforward prompt, no templates used
for prompt in list(args):
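In the reworked `prompt()` above, template resolution is delegated to `get_prompt_template()` and the raw model output is passed through the template's `post_process()` hook. A rough usage sketch under those assumptions (model choice and prompt texts are illustrative, not part of this diff):

```python
from haystack.nodes.prompt import PromptNode, PromptTemplate

node = PromptNode(model_name_or_path="google/flan-t5-small")  # illustrative model choice

# 1) A registered template name is resolved via get_prompt_template().
questions = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])

# 2) A PromptTemplate instance is used as-is.
qa_template = PromptTemplate(name="qa-sketch", prompt_text="Context: {documents}; Question: {query}; Answer:")
answers = node.prompt(qa_template, documents=["Berlin is the capital of Germany."], query="What is the capital of Germany?")

# 3) With no template, each positional string is sent to the model unchanged.
raw_output = node.prompt(None, "What is the capital of Germany?")
```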
@ -607,15 +840,19 @@ class PromptNode(BaseComponent):
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in self.prompt_templates

def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> PromptTemplate:
"""
Returns a prompt template by name.
:param prompt_template_name: The name of the prompt template to be returned.
:return: The prompt template object.
"""
if prompt_template_name not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template_name} not supported")
return self.prompt_templates[prompt_template_name]
if isinstance(prompt_template, PromptTemplate):
return prompt_template

if not isinstance(prompt_template, str) or prompt_template not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template} not supported")

return self.prompt_templates[prompt_template]
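The refactored `get_prompt_template()` accepts either a registered template name or a `PromptTemplate` instance, returning the instance unchanged in the latter case. A small sketch of the behaviour this hunk implies (model and template values are assumptions for illustration):

```python
node = PromptNode(model_name_or_path="google/flan-t5-small")

# Passing a registered name returns the stored template object.
qa_template = node.get_prompt_template("question-answering-per-document")

# Passing a PromptTemplate instance returns that same instance.
custom = PromptTemplate(name="custom", prompt_text="Echo: {query}")
assert node.get_prompt_template(custom) is custom

# Anything else (an unknown name, or None) raises a ValueError.
```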
def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
@ -674,18 +911,14 @@ class PromptNode(BaseComponent):
if meta and "meta" not in invocation_context.keys():
invocation_context["meta"] = meta

if "documents" in invocation_context.keys():
for doc in invocation_context.get("documents", []):
if not isinstance(doc, str) and not isinstance(doc.content, str):
raise ValueError("PromptNode only accepts text documents.")
invocation_context["documents"] = [
doc.content if isinstance(doc, Document) else doc for doc in invocation_context.get("documents", [])
]

results = self(prompt_collector=prompt_collector, **invocation_context)

invocation_context[self.output_variable] = results
final_result: Dict[str, Any] = {self.output_variable: results, "invocation_context": invocation_context}
prompt_template = self.get_prompt_template(self.default_prompt_template)
output_variable = self.output_variable or prompt_template.output_variable or "results"

invocation_context[output_variable] = results
invocation_context["prompts"] = prompt_collector
final_result: Dict[str, Any] = {output_variable: results, "invocation_context": invocation_context}

if self.debug:
final_result["_debug"] = {"prompts_used": prompt_collector}
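The `run()` changes above resolve the output key lazily: an explicit `output_variable` on the node wins, then the template's `output_variable`, then the literal `results`. A condensed, standalone sketch of that precedence chain (values are hypothetical):

```python
def resolve_output_variable(node_output_variable, template_output_variable):
    # Mirrors the fallback chain introduced in run()/run_batch():
    # node-level setting > template-level setting > "results"
    return node_output_variable or template_output_variable or "results"

assert resolve_output_variable(None, None) == "results"
assert resolve_output_variable(None, "answers") == "answers"
assert resolve_output_variable("query", "answers") == "query"
```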
@ -717,15 +950,19 @@ class PromptNode(BaseComponent):
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
:param invocation_contexts: List of invocation contexts.
"""
prompt_template = self.get_prompt_template(self.default_prompt_template)
output_variable = self.output_variable or prompt_template.output_variable or "results"

inputs = PromptNode._flatten_inputs(queries, documents, invocation_contexts)
all_results: Dict[str, List] = {self.output_variable: [], "invocation_contexts": [], "_debug": []}
all_results: Dict[str, List] = defaultdict(list)
for query, docs, invocation_context in zip(
inputs["queries"], inputs["documents"], inputs["invocation_contexts"]
):
results = self.run(query=query, documents=docs, invocation_context=invocation_context)[0]
all_results[self.output_variable].append(results[self.output_variable])
all_results["invocation_contexts"].append(all_results["invocation_contexts"])
all_results["_debug"].append(all_results["_debug"])
all_results[output_variable].append(results[output_variable])
all_results["invocation_contexts"].append(results["invocation_context"])
if self.debug:
all_results["_debug"].append(results["_debug"])
return all_results, "output_1"

def _prepare_model_kwargs(self):

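Switching `all_results` to a `defaultdict(list)` lets `run_batch()` accumulate per-query results under whatever output variable was resolved, and append `_debug` entries only when debugging is enabled. A tiny standalone sketch of that accumulation pattern, with hypothetical data:

```python
from collections import defaultdict

all_results: dict = defaultdict(list)
batch = [
    {"answers": ["Berlin"], "invocation_context": {"query": "capital of Germany?"}},
    {"answers": ["Paris"], "invocation_context": {"query": "capital of France?"}},
]

for result in batch:
    # Keys spring into existence as empty lists on first use.
    all_results["answers"].append(result["answers"])
    all_results["invocation_contexts"].append(result["invocation_context"])

# {'answers': [['Berlin'], ['Paris']], 'invocation_contexts': [...]}
print(dict(all_results))
```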
@ -214,7 +214,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
assert result == "Paris" or result == "Madrid"

# PromptNode as a Tool
pt = PromptTemplate("test", "Here is a question: $query, Answer:")
pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
pn = PromptNode(default_prompt_template=pt)

t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
@ -253,8 +253,7 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
"Word: {query}?\nLength: ",
),
)

@ -304,8 +303,7 @@ def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
"Word: {query}\nLength: ",
),
)

@ -378,28 +378,28 @@ class MockPromptNode(PromptNode):
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
return [""]

def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
if prompt_template_name == "think-step-by-step":
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> PromptTemplate:
if prompt_template == "think-step-by-step":
return PromptTemplate(
name="think-step-by-step",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"$tool_names_with_descriptions\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: [$tool_names]\n"
"Tool: [{tool_names}]\n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: $query\n"
"Thought: Let's think step-by-step, I first need to $generated_text",
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {generated_text}",
)
else:
return PromptTemplate(name="", prompt_text="")

@ -140,8 +140,7 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
name="lfqa",
prompt_text="""
Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
\n===\Paragraphs: $context\n===\n$query""",
prompt_params=["context", "query"],
\n===\Paragraphs: {context}\n===\n{query}""",
)
node = OpenAIAnswerGenerator(
api_key=haystack_openai_config["api_key"],

@ -11,7 +11,7 @@ from haystack.nodes.retriever.sparse import BM25Retriever
@pytest.fixture
def mock_function(monkeypatch):
monkeypatch.setattr(
haystack.nodes.other.shaper, "REGISTERED_FUNCTIONS", {"test_function": lambda a, b: ([a] * len(b),)}
haystack.nodes.other.shaper, "REGISTERED_FUNCTIONS", {"test_function": lambda a, b: [a] * len(b)}
)

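The fixture change above reflects that functions in the Shaper registry now return their value directly instead of wrapping it in a one-element tuple. A minimal sketch of the before/after contract (the registry key `test_function` mirrors the fixture; everything else is illustrative):

```python
# Old style: shaper functions wrapped their result in a tuple.
old_style = lambda a, b: ([a] * len(b),)

# New style: the value is returned directly and mapped to the declared output.
new_style = lambda a, b: [a] * len(b)

assert old_style("hello", ["doc1", "doc2"]) == (["hello", "hello"],)
assert new_style("hello", ["doc1", "doc2"]) == ["hello", "hello"]
```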
@ -293,6 +293,17 @@ def test_join_strings_default_delimiter():
|
||||
assert results["invocation_context"]["single_string"] == "first second"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_strings_with_str_replace():
|
||||
shaper = Shaper(
|
||||
func="join_strings",
|
||||
params={"strings": ["first", "second", "third"], "delimiter": " - ", "str_replace": {"r": "R"}},
|
||||
outputs=["single_string"],
|
||||
)
|
||||
results, _ = shaper.run()
|
||||
assert results["invocation_context"]["single_string"] == "fiRst - second - thiRd"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_strings_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -351,6 +362,38 @@ def test_join_strings_default_delimiter_yaml(tmp_path):
|
||||
assert result["invocation_context"]["single_string"] == "first second third"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_strings_with_str_replace_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: join_strings
|
||||
inputs:
|
||||
strings: documents
|
||||
outputs:
|
||||
- single_string
|
||||
params:
|
||||
delimiter: ' - '
|
||||
str_replace:
|
||||
r: R
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(documents=["first", "second", "third"])
|
||||
assert result["invocation_context"]["single_string"] == "fiRst - second - thiRd"
|
||||
|
||||
|
||||
#
|
||||
# join_documents
|
||||
#
|
||||
@ -407,6 +450,20 @@ def test_join_documents_default_delimiter():
|
||||
assert results["invocation_context"]["documents"] == [Document(content="first second third")]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_documents_with_pattern_and_str_replace():
|
||||
shaper = Shaper(
|
||||
func="join_documents",
|
||||
inputs={"documents": "documents"},
|
||||
outputs=["documents"],
|
||||
params={"delimiter": " - ", "pattern": "[$idx] $content", "str_replace": {"r": "R"}},
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert results["invocation_context"]["documents"] == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_documents_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -472,19 +529,213 @@ def test_join_documents_default_delimiter_yaml(tmp_path):
|
||||
assert result["invocation_context"]["documents"] == [Document(content="first second third")]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_join_documents_with_pattern_and_str_replace_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: join_documents
|
||||
inputs:
|
||||
documents: documents
|
||||
outputs:
|
||||
- documents
|
||||
params:
|
||||
delimiter: ' - '
|
||||
pattern: '[$idx] $content'
|
||||
str_replace:
|
||||
r: R
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(
|
||||
query="test query", documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert result["invocation_context"]["documents"] == [Document(content="[1] fiRst - [2] second - [3] thiRd")]
|
||||
|
||||
|
||||
#
|
||||
# strings_to_answers
|
||||
#
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_no_meta_no_hashkeys():
|
||||
def test_strings_to_answers_simple():
|
||||
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
|
||||
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"]})
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative"),
|
||||
Answer(answer="second", type="generative"),
|
||||
Answer(answer="third", type="generative"),
|
||||
Answer(answer="first", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="second", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="third", type="generative", meta={"prompt": None}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_prompt():
|
||||
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
|
||||
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"], "prompts": ["test prompt"]})
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": "test prompt"}),
|
||||
Answer(answer="second", type="generative", meta={"prompt": "test prompt"}),
|
||||
Answer(answer="third", type="generative", meta={"prompt": "test prompt"}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_documents():
|
||||
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first", "second", "third"],
|
||||
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="second", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="third", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_prompt_per_document():
|
||||
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first", "second"],
|
||||
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
|
||||
"prompts": ["prompt1", "prompt2"],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
|
||||
Answer(answer="second", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_prompt_per_document_multiple_results():
|
||||
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first", "second", "third", "fourth"],
|
||||
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
|
||||
"prompts": ["prompt1", "prompt2"],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
|
||||
Answer(answer="second", type="generative", meta={"prompt": "prompt1"}, document_ids=["123"]),
|
||||
Answer(answer="third", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": "prompt2"}, document_ids=["456"]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_pattern_group():
|
||||
shaper = Shaper(
|
||||
func="strings_to_answers",
|
||||
inputs={"strings": "responses"},
|
||||
outputs=["answers"],
|
||||
params={"pattern": r"Answer: (.*)"},
|
||||
)
|
||||
results, _ = shaper.run(invocation_context={"responses": ["Answer: first", "Answer: second", "Answer: third"]})
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="second", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="third", type="generative", meta={"prompt": None}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_pattern_no_group():
|
||||
shaper = Shaper(
|
||||
func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"], params={"pattern": r"[^\n]+$"}
|
||||
)
|
||||
results, _ = shaper.run(invocation_context={"responses": ["Answer\nfirst", "Answer\nsecond", "Answer\n\nthird"]})
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="second", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="third", type="generative", meta={"prompt": None}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_references_index():
|
||||
shaper = Shaper(
|
||||
func="strings_to_answers",
|
||||
inputs={"strings": "responses", "documents": "documents"},
|
||||
outputs=["answers"],
|
||||
params={"reference_pattern": r"\[(\d+)\]"},
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first[1]", "second[2]", "third[1][2]", "fourth"],
|
||||
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first[1]", type="generative", meta={"prompt": None}, document_ids=["123"]),
|
||||
Answer(answer="second[2]", type="generative", meta={"prompt": None}, document_ids=["456"]),
|
||||
Answer(answer="third[1][2]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_references_id():
|
||||
shaper = Shaper(
|
||||
func="strings_to_answers",
|
||||
inputs={"strings": "responses", "documents": "documents"},
|
||||
outputs=["answers"],
|
||||
params={"reference_pattern": r"\[(\d+)\]", "reference_mode": "id"},
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first[123]", "second[456]", "third[123][456]", "fourth"],
|
||||
"documents": [Document(id="123", content="test"), Document(id="456", content="test")],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first[123]", type="generative", meta={"prompt": None}, document_ids=["123"]),
|
||||
Answer(answer="second[456]", type="generative", meta={"prompt": None}, document_ids=["456"]),
|
||||
Answer(answer="third[123][456]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_references_meta():
|
||||
shaper = Shaper(
|
||||
func="strings_to_answers",
|
||||
inputs={"strings": "responses", "documents": "documents"},
|
||||
outputs=["answers"],
|
||||
params={"reference_pattern": r"\[([^\]]+)\]", "reference_mode": "meta", "reference_meta_field": "file_id"},
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
invocation_context={
|
||||
"responses": ["first[123.txt]", "second[456.txt]", "third[123.txt][456.txt]", "fourth"],
|
||||
"documents": [
|
||||
Document(id="123", content="test", meta={"file_id": "123.txt"}),
|
||||
Document(id="456", content="test", meta={"file_id": "456.txt"}),
|
||||
],
|
||||
}
|
||||
)
|
||||
assert results["invocation_context"]["answers"] == [
|
||||
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
|
||||
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
|
||||
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
|
||||
]
|
||||
|
||||
|
||||
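The three tests above cover the `reference_mode` options for `strings_to_answers`: `index` (the default) maps `[1]`-style references to document positions, `id` matches document IDs, and `meta` matches a chosen metadata field. A small standalone sketch of the index-based resolution (simplified; the real shaper also propagates prompts and metadata as the tests show):

```python
import re
from typing import List

def resolve_reference_indices(answer_text: str, document_ids: List[str], reference_pattern: str = r"\[(\d+)\]") -> List[str]:
    # Collect the document ids referenced as [1], [2], ... in the generated answer.
    referenced = []
    for match in re.findall(reference_pattern, answer_text):
        position = int(match) - 1  # references are 1-based
        if 0 <= position < len(document_ids):
            referenced.append(document_ids[position])
    return referenced

# "third[1][2]" references both documents, mirroring the expected document_ids above.
print(resolve_reference_indices("third[1][2]", ["123", "456"]))  # ['123', '456']
```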
@ -514,17 +765,145 @@ def test_strings_to_answers_yaml(tmp_path):
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run()
|
||||
assert result["invocation_context"]["answers"] == [
|
||||
Answer(answer="a", type="generative"),
|
||||
Answer(answer="b", type="generative"),
|
||||
Answer(answer="c", type="generative"),
|
||||
Answer(answer="a", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="b", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="c", type="generative", meta={"prompt": None}),
|
||||
]
|
||||
assert result["answers"] == [
|
||||
Answer(answer="a", type="generative"),
|
||||
Answer(answer="b", type="generative"),
|
||||
Answer(answer="c", type="generative"),
|
||||
Answer(answer="a", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="b", type="generative", meta={"prompt": None}),
|
||||
Answer(answer="c", type="generative", meta={"prompt": None}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_strings_to_answers_with_reference_meta_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: strings_to_answers
|
||||
inputs:
|
||||
documents: documents
|
||||
params:
|
||||
reference_meta_field: file_id
|
||||
reference_mode: meta
|
||||
reference_pattern: \[([^\]]+)\]
|
||||
strings: ['first[123.txt]', 'second[456.txt]', 'third[123.txt][456.txt]', 'fourth']
|
||||
outputs:
|
||||
- answers
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(
|
||||
documents=[
|
||||
Document(id="123", content="test", meta={"file_id": "123.txt"}),
|
||||
Document(id="456", content="test", meta={"file_id": "456.txt"}),
|
||||
]
|
||||
)
|
||||
assert result["invocation_context"]["answers"] == [
|
||||
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
|
||||
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
|
||||
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
|
||||
]
|
||||
assert result["answers"] == [
|
||||
Answer(answer="first[123.txt]", type="generative", meta={"prompt": None}, document_ids=["123"]),
|
||||
Answer(answer="second[456.txt]", type="generative", meta={"prompt": None}, document_ids=["456"]),
|
||||
Answer(answer="third[123.txt][456.txt]", type="generative", meta={"prompt": None}, document_ids=["123", "456"]),
|
||||
Answer(answer="fourth", type="generative", meta={"prompt": None}, document_ids=[]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: prompt_model
|
||||
type: PromptModel
|
||||
|
||||
- name: prompt_template_raw_qa_per_document
|
||||
type: PromptTemplate
|
||||
params:
|
||||
name: raw-question-answering-per-document
|
||||
prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
|
||||
|
||||
- name: prompt_node_raw_qa
|
||||
type: PromptNode
|
||||
params:
|
||||
model_name_or_path: prompt_model
|
||||
default_prompt_template: prompt_template_raw_qa_per_document
|
||||
top_k: 2
|
||||
|
||||
- name: prompt_node_question_generation
|
||||
type: PromptNode
|
||||
params:
|
||||
model_name_or_path: prompt_model
|
||||
default_prompt_template: question-generation
|
||||
output_variable: query
|
||||
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: strings_to_answers
|
||||
inputs:
|
||||
strings: results
|
||||
outputs:
|
||||
- answers
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: prompt_node_question_generation
|
||||
inputs:
|
||||
- Query
|
||||
- name: prompt_node_raw_qa
|
||||
inputs:
|
||||
- prompt_node_question_generation
|
||||
- name: shaper
|
||||
inputs:
|
||||
- prompt_node_raw_qa
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(
|
||||
query="What's Berlin like?",
|
||||
documents=[
|
||||
Document("Berlin is an amazing city.", id="123"),
|
||||
Document("Berlin is a cool city in Germany.", id="456"),
|
||||
],
|
||||
)
|
||||
results = result["answers"]
|
||||
assert len(results) == 4
|
||||
assert any([True for r in results if "Berlin" in r.answer])
|
||||
for answer in results[:2]:
|
||||
assert answer.document_ids == ["123"]
|
||||
assert (
|
||||
answer.meta["prompt"]
|
||||
== f"Given the context please answer the question. Context: Berlin is an amazing city.; Question: {result['query'][0]}; Answer:"
|
||||
)
|
||||
for answer in results[2:]:
|
||||
assert answer.document_ids == ["456"]
|
||||
assert (
|
||||
answer.meta["prompt"]
|
||||
== f"Given the context please answer the question. Context: Berlin is a cool city in Germany.; Question: {result['query'][1]}; Answer:"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# answers_to_strings
|
||||
#
|
||||
@ -537,6 +916,18 @@ def test_answers_to_strings():
|
||||
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_answers_to_strings_with_pattern_and_str_replace():
|
||||
shaper = Shaper(
|
||||
func="answers_to_strings",
|
||||
inputs={"answers": "documents"},
|
||||
outputs=["strings"],
|
||||
params={"pattern": "[$idx] $answer", "str_replace": {"r": "R"}},
|
||||
)
|
||||
results, _ = shaper.run(documents=[Answer(answer="first"), Answer(answer="second"), Answer(answer="third")])
|
||||
assert results["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_answers_to_strings_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -565,6 +956,38 @@ def test_answers_to_strings_yaml(tmp_path):
|
||||
assert result["invocation_context"]["strings"] == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_answers_to_strings_with_pattern_and_str_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: answers_to_strings
|
||||
inputs:
|
||||
answers: documents
|
||||
outputs:
|
||||
- strings
|
||||
params:
|
||||
pattern: '[$idx] $answer'
|
||||
str_replace:
|
||||
r: R
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(documents=[Answer(answer="first"), Answer(answer="second"), Answer(answer="third")])
|
||||
assert result["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
|
||||
|
||||
#
|
||||
# strings_to_documents
|
||||
#
|
||||
@ -722,6 +1145,20 @@ def test_documents_to_strings():
|
||||
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_documents_to_strings_with_pattern_and_str_replace():
|
||||
shaper = Shaper(
|
||||
func="documents_to_strings",
|
||||
inputs={"documents": "documents"},
|
||||
outputs=["strings"],
|
||||
params={"pattern": "[$idx] $content", "str_replace": {"r": "R"}},
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert results["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_documents_to_strings_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -750,6 +1187,38 @@ def test_documents_to_strings_yaml(tmp_path):
|
||||
assert result["invocation_context"]["strings"] == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_documents_to_strings_with_pattern_and_str_replace_yaml(tmp_path):
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: documents_to_strings
|
||||
inputs:
|
||||
documents: documents
|
||||
outputs:
|
||||
- strings
|
||||
params:
|
||||
pattern: '[$idx] $content'
|
||||
str_replace:
|
||||
r: R
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(documents=[Document(content="first"), Document(content="second"), Document(content="third")])
|
||||
assert result["invocation_context"]["strings"] == ["[1] fiRst", "[2] second", "[3] thiRd"]
|
||||
|
||||
|
||||
#
|
||||
# Chaining and real-world usage
|
||||
#
|
||||
@ -923,32 +1392,19 @@ def test_with_prompt_node(tmp_path):
|
||||
- name: prompt_model
|
||||
type: PromptModel
|
||||
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: value_to_list
|
||||
inputs:
|
||||
value: query
|
||||
target_list: documents
|
||||
outputs:
|
||||
- questions
|
||||
|
||||
- name: prompt_node
|
||||
type: PromptNode
|
||||
params:
|
||||
output_variable: answers
|
||||
model_name_or_path: prompt_model
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
- name: prompt_node
|
||||
inputs:
|
||||
- shaper
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
@ -957,10 +1413,8 @@ def test_with_prompt_node(tmp_path):
|
||||
documents=[Document("Berlin is an amazing city."), Document("Berlin is a cool city in Germany.")],
|
||||
)
|
||||
assert len(result["answers"]) == 2
|
||||
assert any(word for word in ["berlin", "germany", "cool", "city", "amazing"] if word in result["answers"])
|
||||
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert len(result["invocation_context"]["questions"]) == 2
|
||||
raw_answers = [answer.answer for answer in result["answers"]]
|
||||
assert any(word for word in ["berlin", "germany", "cool", "city", "amazing"] if word in raw_answers)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -973,15 +1427,6 @@ def test_with_multiple_prompt_nodes(tmp_path):
|
||||
- name: prompt_model
|
||||
type: PromptModel
|
||||
|
||||
- name: shaper
|
||||
type: Shaper
|
||||
params:
|
||||
func: value_to_list
|
||||
inputs:
|
||||
value: query
|
||||
target_list: documents
|
||||
outputs: [questions]
|
||||
|
||||
- name: renamer
|
||||
type: Shaper
|
||||
params:
|
||||
@ -989,13 +1434,13 @@ def test_with_multiple_prompt_nodes(tmp_path):
|
||||
inputs:
|
||||
value: new-questions
|
||||
outputs:
|
||||
- questions
|
||||
- query
|
||||
|
||||
- name: prompt_node
|
||||
type: PromptNode
|
||||
params:
|
||||
model_name_or_path: prompt_model
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
|
||||
- name: prompt_node_second
|
||||
type: PromptNode
|
||||
@ -1007,19 +1452,15 @@ def test_with_multiple_prompt_nodes(tmp_path):
|
||||
- name: prompt_node_third
|
||||
type: PromptNode
|
||||
params:
|
||||
output_variable: answers
|
||||
model_name_or_path: google/flan-t5-small
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: shaper
|
||||
inputs:
|
||||
- Query
|
||||
- name: prompt_node
|
||||
inputs:
|
||||
- shaper
|
||||
- Query
|
||||
- name: prompt_node_second
|
||||
inputs:
|
||||
- prompt_node
|
||||
@ -1038,7 +1479,7 @@ def test_with_multiple_prompt_nodes(tmp_path):
|
||||
)
|
||||
results = result["answers"]
|
||||
assert len(results) == 2
|
||||
assert any([True for r in results if "Berlin" in r])
|
||||
assert any([True for r in results if "Berlin" in r.answer])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@ -1,6 +1,6 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from typing import Optional, Set, Type, Union, List, Dict, Any, Tuple

import pytest
import torch
@ -9,7 +9,9 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.prompt_node import PromptTemplateValidationError
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
from haystack.schema import Answer


def skip_test_for_invalid_key(prompt_model):
@ -57,47 +59,36 @@ def get_api_key(request):

@pytest.mark.unit
def test_prompt_templates():
p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"])
p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
assert set(p.prompt_params) == {"foo"}

with pytest.raises(ValueError, match="The number of parameters in prompt text"):
PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo"])
p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
assert set(p.prompt_params) == {"foo", "bar"}

with pytest.raises(ValueError, match="Invalid parameter"):
PromptTemplate("t2", "Here is some fake template with variable $footur", ["foo"])
p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
assert set(p.prompt_params) == {"foo1", "bar2"}

with pytest.raises(ValueError, match="The number of parameters in prompt text"):
PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo", "bar", "baz"])
p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
assert set(p.prompt_params) == {"foo_1", "bar_2"}

p = PromptTemplate("t3", "Here is some fake template with variable $for and $bar", ["for", "bar"])
p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
assert set(p.prompt_params) == {"Foo_1", "Bar_2"}

# last parameter: "prompt_params" can be omitted
p = PromptTemplate("t4", "Here is some fake template with variable $foo and $bar")
assert p.prompt_params == ["foo", "bar"]

p = PromptTemplate("t4", "Here is some fake template with variable $foo1 and $bar2")
assert p.prompt_params == ["foo1", "bar2"]

p = PromptTemplate("t4", "Here is some fake template with variable $foo_1 and $bar_2")
assert p.prompt_params == ["foo_1", "bar_2"]

p = PromptTemplate("t4", "Here is some fake template with variable $Foo_1 and $Bar_2")
assert p.prompt_params == ["Foo_1", "Bar_2"]

p = PromptTemplate("t4", "'Here is some fake template with variable $baz'")
assert p.prompt_params == ["baz"]
p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
assert set(p.prompt_params) == {"baz"}
# strip single quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable $baz"
assert p.prompt_text == "Here is some fake template with variable {baz}"

p = PromptTemplate("t4", '"Here is some fake template with variable $baz"')
assert p.prompt_params == ["baz"]
p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
assert set(p.prompt_params) == {"baz"}
# strip double quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable $baz"
assert p.prompt_text == "Here is some fake template with variable {baz}"


@pytest.mark.unit
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable $baz")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable $baz, prompt_params=['baz'])"
p = PromptTemplate("t", "Here is variable {baz}")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
assert repr(p) == desired_repr
assert str(p) == desired_repr
@ -178,9 +169,7 @@ def test_create_prompt_node():
|
||||
@pytest.mark.integration
|
||||
def test_add_and_remove_template(prompt_node):
|
||||
num_default_tasks = len(prompt_node.get_prompt_template_names())
|
||||
custom_task = PromptTemplate(
|
||||
name="custom-task", prompt_text="Custom task: $param1, $param2", prompt_params=["param1", "param2"]
|
||||
)
|
||||
custom_task = PromptTemplate(name="custom-task", prompt_text="Custom task: {param1}, {param2}")
|
||||
prompt_node.add_prompt_template(custom_task)
|
||||
assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
|
||||
assert "custom-task" in prompt_node.get_prompt_template_names()
|
||||
@ -189,24 +178,12 @@ def test_add_and_remove_template(prompt_node):
|
||||
assert "custom-task" not in prompt_node.get_prompt_template_names()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_template():
|
||||
with pytest.raises(ValueError, match="Invalid parameter"):
|
||||
PromptTemplate(
|
||||
name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="The number of parameters in prompt text"):
|
||||
PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_add_template_and_invoke(prompt_node):
|
||||
tt = PromptTemplate(
|
||||
name="sentiment-analysis-new",
|
||||
prompt_text="Please give a sentiment for this context. Answer with positive, "
|
||||
"negative or neutral. Context: $documents; Answer:",
|
||||
prompt_params=["documents"],
|
||||
"negative or neutral. Context: {documents}; Answer:",
|
||||
)
|
||||
prompt_node.add_prompt_template(tt)
|
||||
|
||||
@ -219,8 +196,7 @@ def test_on_the_fly_prompt(prompt_node):
|
||||
prompt_template = PromptTemplate(
|
||||
name="sentiment-analysis-temp",
|
||||
prompt_text="Please give a sentiment for this context. Answer with positive, "
|
||||
"negative or neutral. Context: $documents; Answer:",
|
||||
prompt_params=["documents"],
|
||||
"negative or neutral. Context: {documents}; Answer:",
|
||||
)
|
||||
r = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
|
||||
assert r[0].casefold() == "positive"
|
||||
@ -250,12 +226,12 @@ def test_question_generation(prompt_node):
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_template_selection(prompt_node):
|
||||
qa = prompt_node.set_default_prompt_template("question-answering")
|
||||
qa = prompt_node.set_default_prompt_template("question-answering-per-document")
|
||||
r = qa(
|
||||
["Berlin is the capital of Germany.", "Paris is the capital of France."],
|
||||
["What is the capital of Germany?", "What is the capital of France"],
|
||||
)
|
||||
assert r[0].casefold() == "berlin" and r[1].casefold() == "paris"
|
||||
assert r[0].answer.casefold() == "berlin" and r[1].answer.casefold() == "paris"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -266,14 +242,14 @@ def test_has_supported_template_names(prompt_node):
|
||||
@pytest.mark.integration
|
||||
def test_invalid_template_params(prompt_node):
|
||||
with pytest.raises(ValueError, match="Expected prompt parameters"):
|
||||
prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."})
|
||||
prompt_node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_wrong_template_params(prompt_node):
|
||||
with pytest.raises(ValueError, match="Expected prompt parameters"):
|
||||
# with don't have options param, multiple choice QA has
|
||||
prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."])
|
||||
prompt_node.prompt("question-answering-per-document", options=["Berlin is the capital of Germany."])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -298,7 +274,7 @@ def test_invalid_state_ops(prompt_node):
|
||||
with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
|
||||
prompt_node.remove_prompt_template("no_such_task_exists")
|
||||
# remove default task
|
||||
prompt_node.remove_prompt_template("question-answering")
|
||||
prompt_node.remove_prompt_template("question-answering-per-document")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -365,7 +341,7 @@ def test_stop_words(prompt_model):
|
||||
|
||||
tt = PromptTemplate(
|
||||
name="question-generation-copy",
|
||||
prompt_text="Given the context please generate a question. Context: $documents; Question:",
|
||||
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
|
||||
)
|
||||
# with custom prompt template
|
||||
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
|
||||
@ -441,15 +417,15 @@ def test_simple_pipeline(prompt_model):
|
||||
def test_complex_pipeline(prompt_model):
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="questions")
|
||||
node2 = PromptNode(prompt_model, default_prompt_template="question-answering")
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query")
|
||||
node2 = PromptNode(prompt_model, default_prompt_template="question-answering-per-document")
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
|
||||
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
|
||||
|
||||
assert "berlin" in result["results"][0].casefold()
|
||||
assert "berlin" in result["answers"][0].answer.casefold()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -457,13 +433,70 @@ def test_complex_pipeline(prompt_model):
|
||||
def test_simple_pipeline_with_topk(prompt_model):
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-generation", top_k=2)
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query", top_k=2)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
|
||||
|
||||
assert len(result["results"]) == 2
|
||||
assert len(result["query"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
|
||||
def test_pipeline_with_standard_qa(prompt_model):
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-answering", top_k=1)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
result = pipe.run(
|
||||
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
|
||||
documents=[
|
||||
Document("My name is Carla and I live in Berlin", id="1"),
|
||||
Document("My name is Christelle and I live in Paris", id="2"),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(result["answers"]) == 1
|
||||
assert "carla" in result["answers"][0].answer.casefold()
|
||||
|
||||
assert result["answers"][0].document_ids == ["1", "2"]
|
||||
assert (
|
||||
result["answers"][0].meta["prompt"]
|
||||
== "Given the context please answer the question. Context: My name is Carla and I live in Berlin My name is Christelle and I live in Paris; "
|
||||
"Question: Who lives in Berlin?; Answer:"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
|
||||
def test_pipeline_with_qa_with_references(prompt_model):
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
result = pipe.run(
|
||||
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
|
||||
documents=[
|
||||
Document("My name is Carla and I live in Berlin", id="1"),
|
||||
Document("My name is Christelle and I live in Paris", id="2"),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(result["answers"]) == 1
|
||||
assert "carla, as stated in document[1]" in result["answers"][0].answer.casefold()
|
||||
|
||||
assert result["answers"][0].document_ids == ["1"]
|
||||
assert (
|
||||
result["answers"][0].meta["prompt"]
|
||||
== "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. "
|
||||
"You must only use information from the given documents. Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
|
||||
"If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. If the documents do not contain the answer to the question, "
|
||||
"say that ‘answering is not possible given the available information.’\n\nDocument[1]: My name is Carla and I live in Berlin\n\nDocument[2]: My name is Christelle and I live in Paris \n "
|
||||
"Question: Who lives in Berlin?; Answer: "
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -501,8 +534,7 @@ def test_complex_pipeline_with_qa(prompt_model):
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
name="question-answering-new",
|
||||
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
|
||||
prompt_params=["documents", "query"],
|
||||
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
|
||||
)
|
||||
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
|
||||
|
||||
@ -517,7 +549,7 @@ def test_complex_pipeline_with_qa(prompt_model):
|
||||
debug=True, # so we can verify that the constructed prompt is returned in debug
|
||||
)
|
||||
|
||||
assert len(result["results"]) == 1
|
||||
assert len(result["results"]) == 2
|
||||
assert "carla" in result["results"][0].casefold()
|
||||
|
||||
# also verify that the PromptNode has included its constructed prompt LLM model input in the returned debug
|
||||
@ -531,17 +563,15 @@ def test_complex_pipeline_with_qa(prompt_model):
|
||||
@pytest.mark.integration
|
||||
def test_complex_pipeline_with_shared_model():
|
||||
model = PromptModel()
|
||||
node = PromptNode(
|
||||
model_name_or_path=model, default_prompt_template="question-generation", output_variable="questions"
|
||||
)
|
||||
node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering")
|
||||
node = PromptNode(model_name_or_path=model, default_prompt_template="question-generation", output_variable="query")
|
||||
node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering-per-document")
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
|
||||
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
|
||||
|
||||
assert result["results"][0] == "Berlin"
|
||||
assert result["answers"][0].answer == "Berlin"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -606,11 +636,11 @@ def test_complex_pipeline_yaml(tmp_path):
|
||||
- name: p1
|
||||
params:
|
||||
default_prompt_template: question-generation
|
||||
output_variable: questions
|
||||
output_variable: query
|
||||
type: PromptNode
|
||||
- name: p2
|
||||
params:
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
type: PromptNode
|
||||
pipelines:
|
||||
- name: query
|
||||
@ -625,11 +655,11 @@ def test_complex_pipeline_yaml(tmp_path):
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
response = result["answers"][0].answer
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert len(result["questions"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
assert len(result["query"]) > 0
|
||||
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -645,12 +675,12 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question-generation
|
||||
output_variable: questions
|
||||
output_variable: query
|
||||
type: PromptNode
|
||||
- name: p2
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
type: PromptNode
|
||||
pipelines:
|
||||
- name: query
|
||||
@ -665,11 +695,11 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
response = result["answers"][0].answer
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert len(result["questions"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
assert len(result["query"]) > 0
|
||||
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@ -689,17 +719,17 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
|
||||
type: PromptTemplate
|
||||
params:
|
||||
name: question-generation-new
|
||||
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
|
||||
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
|
||||
- name: p1
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question_generation_template
|
||||
output_variable: questions
|
||||
output_variable: query
|
||||
type: PromptNode
|
||||
- name: p2
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question-answering
|
||||
default_prompt_template: question-answering-per-document
|
||||
type: PromptNode
|
||||
pipelines:
|
||||
- name: query
|
||||
@ -714,11 +744,11 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
response = result["answers"][0].answer
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert len(result["questions"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
assert len(result["query"]) > 0
|
||||
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.integration
@ -767,17 +797,17 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
default_prompt_template: question_generation_template
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -795,11 +825,11 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0

@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True)
@ -839,17 +869,17 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel_openai
default_prompt_template: question_generation_template
output_variable: questions
output_variable: query
type: PromptNode
- name: p2
params:
model_name_or_path: pmodel
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -864,11 +894,11 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
response = result["results"][0]
response = result["answers"][0].answer
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
assert len(result["query"]) > 0
assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0

@pytest.mark.integration
@ -882,15 +912,14 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
- name: p1
params:
default_prompt_template: question-generation
output_variable: questions
type: PromptNode
- name: p2
params:
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
- name: p3
params:
default_prompt_template: question-answering
default_prompt_template: question-answering-per-document
type: PromptNode
pipelines:
- name: query
@ -914,9 +943,7 @@ class TestTokenLimit:
@pytest.mark.integration
def test_hf_token_limit_warning(self, prompt_node, caplog):
prompt_template = PromptTemplate(
name="too-long-temp",
prompt_text="Repeating text" * 200 + "Docs: $documents; Answer:",
prompt_params=["documents"],
name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
)
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
@ -929,11 +956,7 @@ class TestTokenLimit:
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_token_limit_warning(self, caplog):
tt = PromptTemplate(
name="too-long-temp",
prompt_text="Repeating text" * 200 + "Docs: $documents; Answer:",
prompt_params=["documents"],
)
tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:")
prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
@ -989,8 +1012,7 @@ class TestRunBatch:

prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
prompt_params=["documents", "query"],
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)

@ -1015,6 +1037,252 @@ def test_HFLocalInvocationLayer_supports():
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")

class TestPromptTemplateSyntax:
@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, expected_prompt_params, expected_used_functions",
[
("{documents}", {"documents"}, set()),
("Please answer the question: {documents} Question: how?", {"documents"}, set()),
("Please answer the question: {documents} Question: {query}", {"documents", "query"}, set()),
("Please answer the question: {documents} {{Question}}: {query}", {"documents", "query"}, set()),
(
"Please answer the question: {join(documents)} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(documents, 'delim', {'{': '('})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
'Please answer the question: {join(documents, "delim", {"{": "("})} Question: {query.replace("A", "a")}',
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(documents, 'delim', {'a': {'b': 'c'}})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
(
"Please answer the question: {join(document=documents, delimiter='delim', str_replace={'{': '('})} Question: {query.replace('A', 'a')}",
{"documents", "query"},
{"join", "replace"},
),
],
)
def test_prompt_template_syntax_parser(
self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
assert set(prompt_template.prompt_params) == expected_prompt_params
assert set(prompt_template._used_functions) == expected_used_functions

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, expected_prompts",
[
("{documents}", [Document("doc1"), Document("doc2")], None, ["doc1", "doc2"]),
(
"context: {documents} question: how?",
[Document("doc1"), Document("doc2")],
None,
["context: doc1 question: how?", "context: doc2 question: how?"],
),
(
"context: {' '.join([d.content for d in documents])} question: how?",
[Document("doc1"), Document("doc2")],
None,
["context: doc1 doc2 question: how?"],
),
(
"context: {documents} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 question: how?", "context: doc2 question: how?"],
),
(
"context: {documents} {{question}}: {query}",
[Document("doc1")],
"how?",
["context: doc1 {question}: how?"],
),
(
"context: {join(documents)} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 doc2 question: how?"],
),
(
"Please answer the question: {join(documents, ' delim ', '[$idx] $content', {'{': '('})} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"Please answer the question: {join(documents=documents, delimiter=' delim ', pattern='[$idx] $content', str_replace={'{': '('})} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"Please answer the question: {' delim '.join(['['+str(idx+1)+'] '+d.content.replace('{', '(') for idx, d in enumerate(documents)])} question: {query}",
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
'Please answer the question: {join(documents, " delim ", "[$idx] $content", {"{": "("})} question: {query}',
[Document("doc1"), Document("doc2")],
"how?",
["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
),
(
"context: {join(documents)} question: {query.replace('how', 'what')}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 doc2 question: what?"],
),
(
"context: {join(documents)[:6]} question: {query.replace('how', 'what').replace('?', '!')}",
[Document("doc1"), Document("doc2")],
"how?",
["context: doc1 d question: what!"],
),
("context", None, None, ["context"]),
],
)
def test_prompt_template_syntax_fill(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, expected_prompts",
[
("{join(documents)}", [Document("doc1"), Document("doc2")], ["doc1 doc2"]),
(
"{join(documents, ' delim ', '[$idx] $content', {'c': 'C'})}",
[Document("doc1"), Document("doc2")],
["[1] doC1 delim [2] doC2"],
),
(
"{join(documents, ' delim ', '[$id] $content', {'c': 'C'})}",
[Document("doc1", id="123"), Document("doc2", id="456")],
["[123] doC1 delim [456] doC2"],
),
(
"{join(documents, ' delim ', '[$file_id] $content', {'c': 'C'})}",
[Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
["[123.txt] doC1 delim [456.txt] doC2"],
),
],
)
def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, expected_prompts",
[
("{to_strings(documents)}", [Document("doc1"), Document("doc2")], ["doc1", "doc2"]),
(
"{to_strings(documents, '[$idx] $content', {'c': 'C'})}",
[Document("doc1"), Document("doc2")],
["[1] doC1", "[2] doC2"],
),
(
"{to_strings(documents, '[$id] $content', {'c': 'C'})}",
[Document("doc1", id="123"), Document("doc2", id="456")],
["[123] doC1", "[456] doC2"],
),
(
"{to_strings(documents, '[$file_id] $content', {'c': 'C'})}",
[Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
["[123.txt] doC1", "[456.txt] doC2"],
),
("{to_strings(documents, '[$file_id] $content', {'c': 'C'})}", ["doc1", "doc2"], ["doC1", "doC2"]),
(
"{to_strings(documents, '[$idx] $answer', {'c': 'C'})}",
[Answer("doc1"), Answer("doc2")],
["[1] doC1", "[2] doC2"],
),
],
)
def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, exc_type, expected_exc_match",
[
("{__import__('os').listdir('.')}", PromptTemplateValidationError, "Invalid function in prompt text"),
("{__import__('os')}", PromptTemplateValidationError, "Invalid function in prompt text"),
(
"{requests.get('https://haystack.deepset.ai/')}",
PromptTemplateValidationError,
"Invalid function in prompt text",
),
("{join(__import__('os').listdir('.'))}", PromptTemplateValidationError, "Invalid function in prompt text"),
("{for}", SyntaxError, "invalid syntax"),
("This is an invalid {variable .", SyntaxError, "f-string: expecting '}'"),
],
)
def test_prompt_template_syntax_init_raises(
self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
):
with pytest.raises(exc_type, match=expected_exc_match):
PromptTemplate(name="test", prompt_text=prompt_text)

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, exc_type, expected_exc_match",
[("{join}", None, None, ValueError, "Expected prompt parameters")],
)
def test_prompt_template_syntax_fill_raises(
self,
prompt_text: str,
documents: List[Document],
query: str,
exc_type: Type[BaseException],
expected_exc_match: str,
):
with pytest.raises(exc_type, match=expected_exc_match):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
next(prompt_template.fill(documents=documents, query=query))

@pytest.mark.unit
@pytest.mark.parametrize(
"prompt_text, documents, query, expected_prompts",
[
("__import__('os').listdir('.')", None, None, ["__import__('os').listdir('.')"]),
(
"requests.get('https://haystack.deepset.ai/')",
None,
None,
["requests.get('https://haystack.deepset.ai/')"],
),
("{query}", None, print, ["<built-in function print>"]),
("\b\b__import__('os').listdir('.')", None, None, ["\x08\x08__import__('os').listdir('.')"]),
],
)
def test_prompt_template_syntax_fill_ignores_dangerous_input(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts

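For orientation, here is a minimal usage sketch (not part of the diff) of the f-string-style template syntax these unit tests exercise; the `PromptTemplate`, `Document`, `join`, and `fill` names are taken from the test code above, while the import paths are assumed from Haystack v1.x and may differ:

```python
from haystack.nodes import PromptTemplate
from haystack.schema import Document

# A template using the new brace-based syntax with the `join` helper,
# mirroring the cases covered by TestPromptTemplateSyntax.
template = PromptTemplate(
    name="question-answering-sketch",
    prompt_text="Context: {join(documents)} Question: {query} Answer:",
)

# fill() yields the rendered prompts; join() merges all documents into one prompt,
# so two documents produce a single prompt here (as in the test cases above).
prompts = list(template.fill(documents=[Document("doc1"), Document("doc2")], query="how?"))
# -> ["Context: doc1 doc2 Question: how? Answer:"]
```
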
@pytest.mark.integration
def test_chatgpt_direct_prompting(chatgpt_prompt_model):
skip_test_for_invalid_key(chatgpt_prompt_model)