haystack/test/prompt/test_prompt_template.py
ZanSara 949b1b63b3
PromptHub integration in PromptNode (#4879)
* initial integration

* upgrade of prompthub

* fix get_prompt_template

* feedback

* add prompthub-py to dependencies

* tests

* mypy

* stray changes

* review feedback

* missing init

* fix test

* move logic in prompttemplate

* linting

* bugfixes

* fix unit tests

* fix cache

* simplify prompttemplate init

* remove unused function

* removing wrong params

* try remove all instances of prompt names

* more tests

* fix agent tests

* more tests

* fix tests

* pylint

* comma

* black

* fix test

* docstring

* review feedback

* review feedback

* fix mocks

* mypy

* fix mocks

* fix reference to missing templates

* feedback

* remove direct references to default template var

* tests

* Update haystack/nodes/prompt/prompt_node.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
2023-05-23 15:22:58 +02:00

397 lines
16 KiB
Python

from typing import Set, Type, List
import textwrap
from unittest.mock import patch, MagicMock
import pytest
import prompthub
from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_node import PromptNode
from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError
from haystack.nodes.prompt.shapers import AnswerParser
from haystack.pipelines.base import Pipeline
from haystack.schema import Answer, Document
def mock_prompthub():
    """
    Generator that patches ``PromptTemplate._fetch_from_prompthub`` and yields the resulting mock.

    While the generator is suspended at its ``yield`` the patch is active, and the mock's
    ``side_effect`` is a one-item list holding a canned ``(name, prompt_text)`` pair.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function, so pytest will
    not inject it into tests by name - confirm whether that is intended.
    """
    fetch_target = "haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub"
    canned_response = (
        "deepset/test-prompt",
        "This is a test prompt. Use your knowledge to answer this question: {question}",
    )
    with patch(fetch_target) as fetch_mock:
        fetch_mock.side_effect = [canned_response]
        yield fetch_mock
@pytest.mark.unit
def test_prompt_templates_from_hub():
    """Building a template from a hub-style name calls ``prompthub.fetch`` with that name."""
    hub_prompt_name = "deepset/question-answering"
    with patch("haystack.nodes.prompt.prompt_template.prompthub") as hub_module:
        PromptTemplate(hub_prompt_name)
        # The most recent fetch must have used exactly this name and timeout=30.
        hub_module.fetch.assert_called_with(hub_prompt_name, timeout=30)
@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
    """
    A path to a YAML file is loaded as a prompt template.

    The YAML document is written into pytest's ``tmp_path`` and its absolute path is handed to
    ``PromptTemplate``; the test then checks the ``name`` and ``prompt_text`` read from the file.

    :param tmp_path: pytest's built-in per-test temporary directory (a ``pathlib.Path``).
    """
    # The nesting below is significant: the lines under ``prompt_text: |`` must be indented to
    # belong to the literal block scalar, and the ``tags`` / ``meta`` children must be indented
    # under their keys. ``textwrap.dedent`` only removes the common leading whitespace, so the
    # relative indentation survives.
    yaml_text = textwrap.dedent(
        """
        name: deepset/question-answering
        prompt_text: |
            Given the context please answer the question. Context: {join(documents)};
            Question: {query};
            Answer:
        description: A simple prompt to answer a question given a set of documents
        tags:
          - question-answering
        meta:
          authors:
            - vblagoje
        version: v0.1.1
        """
    )
    path = tmp_path / "test-prompt.yml"
    # Explicit encoding keeps the fixture independent of the platform default.
    path.write_text(yaml_text, encoding="utf-8")

    template = PromptTemplate(str(path.absolute()))
    assert template.name == "deepset/question-answering"
    assert "Given the context please answer the question" in template.prompt_text
@pytest.mark.unit
def test_prompt_templates_on_the_fly():
    """A literal prompt text becomes an ad-hoc template without using YAML loading or the hub."""
    inline_prompt = "This is a test prompt. Use your knowledge to answer this question: {question}"
    # Both module-level collaborators are replaced so that any use of them would be recorded.
    with patch("haystack.nodes.prompt.prompt_template.yaml") as yaml_module, patch(
        "haystack.nodes.prompt.prompt_template.prompthub"
    ) as hub_module:
        template = PromptTemplate(inline_prompt)
        assert template.name == "custom-at-query-time"
        hub_module.fetch.assert_not_called()
        yaml_module.safe_load.assert_not_called()
@pytest.mark.unit
def test_custom_prompt_templates():
    """Placeholder names are extracted from inline templates and wrapping quotes are stripped."""
    # Each entry pairs a template text with the placeholder names expected in ``prompt_params``.
    plain_cases = [
        ("Here is some fake template with variable {foo}", {"foo"}),
        ("Here is some fake template with variable {foo} and {bar}", {"foo", "bar"}),
        ("Here is some fake template with variable {foo1} and {bar2}", {"foo1", "bar2"}),
        ("Here is some fake template with variable {foo_1} and {bar_2}", {"foo_1", "bar_2"}),
        ("Here is some fake template with variable {Foo_1} and {Bar_2}", {"Foo_1", "Bar_2"}),
    ]
    for template_text, expected_params in plain_cases:
        assert set(PromptTemplate(template_text).prompt_params) == expected_params

    # A template string may arrive wrapped in single or double quotes (this happens with YAML,
    # where the template string has to be quoted); the wrapping quotes must be stripped.
    quoted_cases = [
        "'Here is some fake template with variable {baz}'",
        '"Here is some fake template with variable {baz}"',
    ]
    for quoted_text in quoted_cases:
        template = PromptTemplate(quoted_text)
        assert set(template.prompt_params) == {"baz"}
        assert template.prompt_text == "Here is some fake template with variable {baz}"
@pytest.mark.unit
def test_missing_prompt_template_params():
    """``prepare`` accepts extra keyword arguments but raises ``ValueError`` when a placeholder is missing."""
    tpl = PromptTemplate("Here is some fake template with variable {foo} and {bar}")

    # Exactly the expected parameters: accepted.
    tpl.prepare(foo="foo", bar="bar")

    # Only one of the two parameters: the message names what was expected and what was received.
    only_foo_pattern = r".*parameters \['bar', 'foo'\] to be provided but got only \['foo'\].*"
    with pytest.raises(ValueError, match=only_foo_pattern):
        tpl.prepare(foo="foo")

    # Neither parameter (only an unrelated keyword): the message uses a different wording.
    none_given_pattern = r".*parameters \['bar', 'foo'\] to be provided but got none of these parameters.*"
    with pytest.raises(ValueError, match=none_given_pattern):
        tpl.prepare(lets="go")

    # A superset of the expected parameters: also accepted.
    tpl.prepare(foo="foo", bar="bar", lets="go")
@pytest.mark.unit
def test_prompt_template_repr():
    """``repr`` and ``str`` of a template both give the same name/text/params summary."""
    template = PromptTemplate("Here is variable {baz}")
    expected = "PromptTemplate(name=custom-at-query-time, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
    # Render lazily so that ``repr`` is checked before ``str`` is evaluated.
    for render in (repr, str):
        assert render(template) == expected
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_template_deserialization(mock_prompt_model):
    """
    A PromptNode's default template and its output parser survive a pipeline config round trip.

    :param mock_prompt_model: mock injected by the ``@patch`` decorator in place of ``PromptModel``
        in the ``prompt_node`` module; it is not inspected by the test.
    """
    qa_prompt_text = "Given the context please answer the question. Context: {context}; Question: {query}; Answer:"
    node = PromptNode(default_prompt_template=PromptTemplate(qa_prompt_text, output_parser=AnswerParser()))

    original_pipeline = Pipeline()
    original_pipeline.add_node(component=node, name="Generator", inputs=["Query"])

    # Export the pipeline to a config and rebuild a new pipeline from that config.
    restored_pipeline = Pipeline.load_from_config(original_pipeline.get_config())
    restored_node = restored_pipeline.get_node("Generator")

    assert isinstance(restored_node, PromptNode)
    restored_template = restored_node.default_prompt_template
    assert isinstance(restored_template, PromptTemplate)
    assert restored_template.prompt_text == qa_prompt_text
    assert isinstance(restored_template.output_parser, AnswerParser)
class TestPromptTemplateSyntax:
    """Tests for the expression syntax accepted inside ``PromptTemplate`` placeholders."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, expected_prompt_params, expected_used_functions",
        [
            ("{documents}", {"documents"}, set()),
            ("Please answer the question: {documents} Question: how?", {"documents"}, set()),
            ("Please answer the question: {documents} Question: {query}", {"documents", "query"}, set()),
            ("Please answer the question: {documents} {{Question}}: {query}", {"documents", "query"}, set()),
            (
                "Please answer the question: {join(documents)} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                "Please answer the question: {join(documents, 'delim', {'{': '('})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                'Please answer the question: {join(documents, "delim", {"{": "("})} Question: {query.replace("A", "a")}',
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                "Please answer the question: {join(documents, 'delim', {'a': {'b': 'c'}})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                "Please answer the question: {join(document=documents, delimiter='delim', str_replace={'{': '('})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
        ],
    )
    def test_prompt_template_syntax_parser(
        self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
    ):
        """Parsing a template exposes the expected ``prompt_params`` and ``_used_functions``."""
        template = PromptTemplate(prompt_text)
        assert set(template.prompt_params) == expected_prompt_params
        assert set(template._used_functions) == expected_used_functions

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, expected_prompts",
        [
            ("{documents}", [Document("doc1"), Document("doc2")], None, ["doc1", "doc2"]),
            (
                "context: {documents} question: how?",
                [Document("doc1"), Document("doc2")],
                None,
                ["context: doc1 question: how?", "context: doc2 question: how?"],
            ),
            (
                "context: {' '.join([d.content for d in documents])} question: how?",
                [Document("doc1"), Document("doc2")],
                None,
                ["context: doc1 doc2 question: how?"],
            ),
            (
                "context: {documents} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 question: how?", "context: doc2 question: how?"],
            ),
            (
                "context: {documents} {{question}}: {query}",
                [Document("doc1")],
                "how?",
                ["context: doc1 {question}: how?"],
            ),
            (
                "context: {join(documents)} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 doc2 question: how?"],
            ),
            (
                "Please answer the question: {join(documents, ' delim ', '[$idx] $content', {'{': '('})} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "Please answer the question: {join(documents=documents, delimiter=' delim ', pattern='[$idx] $content', str_replace={'{': '('})} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "Please answer the question: {' delim '.join(['['+str(idx+1)+'] '+d.content.replace('{', '(') for idx, d in enumerate(documents)])} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                'Please answer the question: {join(documents, " delim ", "[$idx] $content", {"{": "("})} question: {query}',
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "context: {join(documents)} question: {query.replace('how', 'what')}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 doc2 question: what?"],
            ),
            (
                "context: {join(documents)[:6]} question: {query.replace('how', 'what').replace('?', '!')}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 d question: what!"],
            ),
            ("context: ", None, None, ["context: "]),
        ],
    )
    def test_prompt_template_syntax_fill(
        self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
    ):
        """Filling a template with documents and a query yields exactly the expected prompts."""
        filled = list(PromptTemplate(prompt_text).fill(documents=documents, query=query))
        assert filled == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, expected_prompts",
        [
            ("{join(documents)}", [Document("doc1"), Document("doc2")], ["doc1 doc2"]),
            (
                "{join(documents, ' delim ', '[$idx] $content', {'c': 'C'})}",
                [Document("doc1"), Document("doc2")],
                ["[1] doC1 delim [2] doC2"],
            ),
            (
                "{join(documents, ' delim ', '[$id] $content', {'c': 'C'})}",
                [Document("doc1", id="123"), Document("doc2", id="456")],
                ["[123] doC1 delim [456] doC2"],
            ),
            (
                "{join(documents, ' delim ', '[$file_id] $content', {'c': 'C'})}",
                [Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
                ["[123.txt] doC1 delim [456.txt] doC2"],
            ),
        ],
    )
    def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
        """The ``join`` template function merges all documents into a single prompt."""
        filled = list(PromptTemplate(prompt_text).fill(documents=documents))
        assert filled == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, expected_prompts",
        [
            ("{to_strings(documents)}", [Document("doc1"), Document("doc2")], ["doc1", "doc2"]),
            (
                "{to_strings(documents, '[$idx] $content', {'c': 'C'})}",
                [Document("doc1"), Document("doc2")],
                ["[1] doC1", "[2] doC2"],
            ),
            (
                "{to_strings(documents, '[$id] $content', {'c': 'C'})}",
                [Document("doc1", id="123"), Document("doc2", id="456")],
                ["[123] doC1", "[456] doC2"],
            ),
            (
                "{to_strings(documents, '[$file_id] $content', {'c': 'C'})}",
                [Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
                ["[123.txt] doC1", "[456.txt] doC2"],
            ),
            ("{to_strings(documents, '[$file_id] $content', {'c': 'C'})}", ["doc1", "doc2"], ["doC1", "doC2"]),
            (
                "{to_strings(documents, '[$idx] $answer', {'c': 'C'})}",
                [Answer("doc1"), Answer("doc2")],
                ["[1] doC1", "[2] doC2"],
            ),
        ],
    )
    def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
        """The ``to_strings`` template function yields one prompt per input item."""
        filled = list(PromptTemplate(prompt_text).fill(documents=documents))
        assert filled == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, exc_type, expected_exc_match",
        [
            ("{__import__('os').listdir('.')}", PromptTemplateValidationError, "Invalid function in prompt text"),
            ("{__import__('os')}", PromptTemplateValidationError, "Invalid function in prompt text"),
            (
                "{requests.get('https://haystack.deepset.ai/')}",
                PromptTemplateValidationError,
                "Invalid function in prompt text",
            ),
            ("{join(__import__('os').listdir('.'))}", PromptTemplateValidationError, "Invalid function in prompt text"),
            ("{for}", SyntaxError, "invalid syntax"),
            ("This is an invalid {variable .", SyntaxError, "f-string: expecting '}'"),
        ],
    )
    def test_prompt_template_syntax_init_raises(
        self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
    ):
        """Templates with disallowed calls or malformed syntax are rejected at construction time."""
        with pytest.raises(exc_type, match=expected_exc_match):
            PromptTemplate(prompt_text)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, exc_type, expected_exc_match",
        [("{join}", None, None, ValueError, "Expected prompt parameters")],
    )
    def test_prompt_template_syntax_fill_raises(
        self,
        prompt_text: str,
        documents: List[Document],
        query: str,
        exc_type: Type[BaseException],
        expected_exc_match: str,
    ):
        """Constructing and then filling the template raises the expected exception."""
        with pytest.raises(exc_type, match=expected_exc_match):
            # Construction happens inside the context too, so a matching error from either
            # step satisfies the check.
            template = PromptTemplate(prompt_text)
            next(template.fill(documents=documents, query=query))

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, expected_prompts",
        [
            ("__import__('os').listdir('.')", None, None, ["__import__('os').listdir('.')"]),
            (
                "requests.get('https://haystack.deepset.ai/')",
                None,
                None,
                ["requests.get('https://haystack.deepset.ai/')"],
            ),
            ("{query}", None, print, ["<built-in function print>"]),
            ("\b\b__import__('os').listdir('.')", None, None, ["\x08\x08__import__('os').listdir('.')"]),
        ],
    )
    def test_prompt_template_syntax_fill_ignores_dangerous_input(
        self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
    ):
        """Code-like text outside placeholders, or passed in as a value, comes back as plain text."""
        filled = list(PromptTemplate(prompt_text).fill(documents=documents, query=query))
        assert filled == expected_prompts