PromptHub integration in PromptNode (#4879)

* initial integration

* upgrade of prompthub

* fix get_prompt_template

* feedback

* add prompthub-py to dependencies

* tests

* mypy

* stray changes

* review feedback

* missing init

* fix test

* move logic in prompttemplate

* linting

* bugfixes

* fix unit tests

* fix cache

* simplify prompttemplate init

* remove unused function

* removing wrong params

* try remove all instances of prompt names

* more tests

* fix agent tests

* more tests

* fix tests

* pylint

* comma

* black

* fix test

* docstring

* review feedback

* review feedback

* fix mocks

* mypy

* fix mocks

* fix reference to missing templates

* feedback

* remove direct references to default template var

* tests

* Update haystack/nodes/prompt/prompt_node.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
ZanSara 2023-05-23 15:22:58 +02:00 committed by GitHub
parent 9e4feb6bed
commit 949b1b63b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 340 additions and 400 deletions

View File

@ -89,13 +89,12 @@ class OpenAIAnswerGenerator(BaseGenerator):
If not supplied, the default prompt template is:
```python
PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
"Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
)
```
To learn how variables, such as'{context}', are substituted in the `prompt_text`, see
To learn how variables, such as'{context}', are substituted in the prompt text, see
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
:param context_join_str: The separation string used to join the input documents to create the context
used by the PromptTemplate.
@ -114,10 +113,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
stop_words = ["\n", "<|endoftext|>"]
if prompt_template is None:
prompt_template = PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
"===\nContext: {context}\n===\n{query}"
)
else:
# Check for required prompts

View File

@ -1,19 +1,15 @@
from collections import defaultdict
import copy
import logging
import re
from typing import Dict, List, Optional, Tuple, Union, Any
import torch
import yaml
from haystack.nodes.base import BaseComponent
from haystack.schema import Document, MultiLabel
from haystack.telemetry import send_event
from haystack.nodes.prompt.shapers import BaseOutputParser
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.prompt_template import PromptTemplate, get_predefined_prompt_templates
from haystack.nodes.prompt.prompt_template import PromptTemplate
logger = logging.getLogger(__name__)
@ -97,8 +93,12 @@ class PromptNode(BaseComponent):
},
)
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self._prompt_templates_cache: Dict[str, PromptTemplate] = {}
# If we don't set _default_template here Pylint fails with error W0201 because it can't see that
# it's set in default_prompt_template, so we set it explicitly to None to avoid it failing
self._default_template = None
self.default_prompt_template = default_prompt_template
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
self.prompt_model: PromptModel
@ -106,15 +106,6 @@ class PromptNode(BaseComponent):
self.top_k: int = top_k
self.debug = debug
if isinstance(self.default_prompt_template, str) and not self.is_supported_template(
self.default_prompt_template
):
raise ValueError(
f"Prompt template {self.default_prompt_template} is not supported. "
f"Select one of: {self.get_prompt_template_names()} "
f"or register a new prompt template first using the add_prompt_template() method."
)
if isinstance(model_name_or_path, str):
self.prompt_model = PromptModel(
model_name_or_path=model_name_or_path,
@ -187,65 +178,18 @@ class PromptNode(BaseComponent):
results.extend(output)
return results
def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
"""
Adds a prompt template to the list of supported prompt templates.
:param prompt_template: The PromptTemplate object to be added.
:return: None
"""
if prompt_template.name in self.prompt_templates:
raise ValueError(
f"Prompt template {prompt_template.name} already exists. "
f"Select a different name for this prompt template."
)
@property
def default_prompt_template(self):
return self._default_template
self.prompt_templates[prompt_template.name] = prompt_template # type: ignore
def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
"""
Removes a prompt template from the list of supported prompt templates.
:param prompt_template: Name of the prompt template to be removed.
:return: PromptTemplate object that was removed.
"""
if prompt_template not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template} does not exist")
return self.prompt_templates.pop(prompt_template)
def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode":
@default_prompt_template.setter
def default_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]):
"""
Sets the default prompt template for the node.
:param prompt_template: The prompt template to be set as default.
:return: The current PromptNode object.
"""
if not self.is_supported_template(prompt_template):
raise ValueError(f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()}")
self.default_prompt_template = prompt_template
return self
def get_prompt_templates(self) -> List[PromptTemplate]:
"""
Returns the list of supported prompt templates.
:return: List of supported prompt templates.
"""
return list(self.prompt_templates.values())
def get_prompt_template_names(self) -> List[str]:
"""
Returns the list of supported prompt template names.
:return: List of supported prompt template names.
"""
return list(self.prompt_templates.keys())
def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) -> bool:
"""
Checks if a prompt template is supported.
:param prompt_template: The prompt template to be checked.
:return: True if the prompt template is supported, False otherwise.
"""
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in self.prompt_templates
self._default_template = self.get_prompt_template(prompt_template)
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None] = None) -> Optional[PromptTemplate]:
"""
@ -261,34 +205,25 @@ class PromptNode(BaseComponent):
:return: The prompt template object.
"""
prompt_template = prompt_template or self.default_prompt_template
# None means we're asking for the default prompt template
prompt_template = prompt_template or self._default_template
if prompt_template is None:
return None
# PromptTemplate instances simply go through
if isinstance(prompt_template, PromptTemplate):
return prompt_template
if isinstance(prompt_template, str) and prompt_template in self.prompt_templates:
return self.prompt_templates[prompt_template]
# If it's the name of a template that was used already
if prompt_template in self._prompt_templates_cache:
return self._prompt_templates_cache[prompt_template]
# if it's not a string or looks like a prompt template name
if not isinstance(prompt_template, str) or re.fullmatch(r"[-a-zA-Z0-9_]+", prompt_template):
raise ValueError(
f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()} or pass a PromptTemplate instance for prompting."
)
if "prompt_text:" in prompt_template:
prompt_template_parsed = yaml.safe_load(prompt_template)
if isinstance(prompt_template_parsed, dict):
return PromptTemplate(**prompt_template_parsed)
# it's a prompt_text
prompt_text = prompt_template
output_parser: Optional[BaseOutputParser] = None
default_prompt_template = self.get_prompt_template()
if default_prompt_template:
output_parser = default_prompt_template.output_parser
return PromptTemplate(name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser)
output_parser = None
if self.default_prompt_template:
output_parser = self.default_prompt_template.output_parser
template = PromptTemplate(prompt_template, output_parser=output_parser)
self._prompt_templates_cache[prompt_template] = template
return template
def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
@ -296,10 +231,10 @@ class PromptNode(BaseComponent):
:param prompt_template: The name of the prompt template.
:return: The list of parameters for the prompt template.
"""
if not self.is_supported_template(prompt_template):
raise ValueError(f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()}")
return list(self.prompt_templates[prompt_template].prompt_params)
template = self.get_prompt_template(prompt_template)
if template:
return list(template.prompt_params)
return []
def run(
self,

View File

@ -1,11 +1,18 @@
from typing import Optional, List, Union, Tuple, Dict, Iterator, Any
import logging
import re
import os
import ast
import json
from pathlib import Path
from abc import ABC
from uuid import uuid4
import yaml
import tenacity
import prompthub
from requests import HTTPError, RequestException, JSONDecodeError
from haystack.errors import NodeError
from haystack.environment import HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS
from haystack.nodes.base import BaseComponent
@ -19,6 +26,11 @@ from haystack.nodes.prompt.shapers import ( # pylint: disable=unused-import
format_string,
)
from haystack.schema import Document, MultiLabel
from haystack.environment import (
HAYSTACK_REMOTE_API_TIMEOUT_SEC,
HAYSTACK_REMOTE_API_BACKOFF_SEC,
HAYSTACK_REMOTE_API_MAX_RETRIES,
)
logger = logging.getLogger(__name__)
@ -29,6 +41,14 @@ PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS = {"new_line": "\n", "tab": "\t", "double_quo
PROMPT_TEMPLATE_STRIPS = ["'", '"']
PROMPT_TEMPLATE_STR_REPLACE = {'"': "'"}
PROMPTHUB_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30.0))
PROMPTHUB_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10.0))
PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
class PromptNotFoundError(Exception):
...
class BasePromptTemplate(BaseComponent):
outgoing_edges = 1
@ -163,9 +183,7 @@ class PromptTemplate(BasePromptTemplate, ABC):
PromptTemplate is a template for the prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what a prompt template may look like:
```python
PromptTemplate(name="sentiment-analysis",
prompt_text="Give a sentiment for this context. Answer with positive, negative"
"or neutral. Context: {documents}; Answer:")
PromptTemplate("Give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:")
```
Optionally, you can declare prompt parameters using f-string syntax in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
@ -183,14 +201,12 @@ class PromptTemplate(BasePromptTemplate, ABC):
[PromptTemplates](https://docs.haystack.deepset.ai/docs/prompt_node#prompttemplates).
"""
def __init__(
self, name: str, prompt_text: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None
):
def __init__(self, prompt: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None):
"""
Creates a PromptTemplate instance.
:param name: The name of the prompt template (for example, "sentiment-analysis", "question-generation"). You can specify your own name but it must be unique.
:param prompt_text: The prompt text, including prompt parameters.
:param prompt: The name of the prompt template on the PromptHub (for example, "sentiment-analysis",
"question-generation"), a Path to a local file, or the text of a new prompt, including its parameters.
:param output_parser: A parser that applied to the model output.
For example, to convert the model output to an Answer object, you can use `AnswerParser`.
Instead of BaseOutputParser instances, you can also pass dictionaries defining the output parsers. For example:
@ -199,6 +215,36 @@ class PromptTemplate(BasePromptTemplate, ABC):
```
"""
super().__init__()
name, prompt_text = "", ""
try:
# if it looks like a prompt template name
if re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
prompt_text = self._fetch_from_prompthub(prompt)
# if it's a path to a YAML file
elif Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]
# Otherwise it's a on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"
except OSError as exc:
logger.info(
"There was an error checking whether this prompt is a file (%s). Haystack will assume it's not.",
str(exc),
)
# In case of errors, let's directly assume this is a text prompt
prompt_text = prompt
name = "custom-at-query-time"
# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
for strip in PROMPT_TEMPLATE_STRIPS:
@ -241,6 +287,26 @@ class PromptTemplate(BasePromptTemplate, ABC):
def output_variable(self) -> Optional[str]:
return self.output_parser.output_variable if self.output_parser else None
@tenacity.retry(
reraise=True,
retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
)
def _fetch_from_prompthub(self, name) -> str:
"""
Looks for the given prompt in the PromptHub if the prompt is not in the local cache.
Raises PromptNotFoundError if the prompt is not present in the hub.
"""
try:
prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
except HTTPError as http_error:
if http_error.response.status_code != 404:
raise http_error
raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
return prompt_data.text
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
"""
Prepares and verifies the PromtpTemplate with input parameters.
@ -336,110 +402,3 @@ class PromptTemplate(BasePromptTemplate, ABC):
def __repr__(self):
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"
def get_predefined_prompt_templates() -> List[PromptTemplate]:
return [
PromptTemplate(
name="question-answering",
prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-per-document",
prompt_text="Given the context please answer the question. Context: {documents}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-with-references",
prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like as stated in Document[number], Document[number], etc.. "
"If the documents do not contain the answer to the question, say that answering is not possible given the available information.\n"
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
),
PromptTemplate(
name="question-answering-with-document-scores",
prompt_text="Answer the following question using the paragraphs below as sources. "
"An answer should be short, a few words at most.\n"
"Paragraphs:\n{documents}\n"
"Question: {query}\n\n"
"Instructions: Consider all the paragraphs above and their corresponding scores to generate "
"the answer. While a single paragraph may have a high score, it's important to consider all "
"paragraphs for the same answer candidate to answer accurately.\n\n"
"After having considered all possibilities, the final answer is:\n",
),
PromptTemplate(
name="question-generation",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
),
PromptTemplate(
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: {documents}; Answer: {answers}; Question:",
),
PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
PromptTemplate(
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question? "
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
),
PromptTemplate(
name="multiple-choice-question-answering",
prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="topic-classification",
prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
),
PromptTemplate(
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: {documents}; Answer:",
),
PromptTemplate(
name="translation",
prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
),
PromptTemplate(
name="zero-shot-react",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: pick one of {tool_names} \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {transcript}",
),
PromptTemplate(
name="conversational-agent",
prompt_text="The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:",
),
PromptTemplate(
name="conversational-summary",
prompt_text="Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:",
),
]

View File

@ -58,6 +58,7 @@ dependencies = [
"rank_bm25",
"scikit-learn>=1.0.0", # TF-IDF, SklearnQueryClassifier and metrics
"generalimport", # Optional imports
"prompthub-py",
# Utils
"dill", # pickle extension for (de-)serialization

View File

@ -177,7 +177,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
assert result == "Paris" or result == "Madrid"
# PromptNode as a Tool
pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
pt = PromptTemplate("Here is a question: {query}, Answer:")
pn = PromptNode(default_prompt_template=pt)
t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
@ -212,12 +212,11 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
country_finder = PromptNode(
model_name_or_path=prompt_model,
default_prompt_template=PromptTemplate(
name="country_finder",
prompt_text="When I give you a name of the city, respond with the country where the city is located.\n"
"When I give you a name of the city, respond with the country where the city is located.\n"
"City: Rome\nCountry: Italy\n"
"City: Berlin\nCountry: Germany\n"
"City: Belgrade\nCountry: Serbia\n"
"City: {query}?\nCountry: ",
"City: {query}?\nCountry: "
),
)

View File

@ -1,5 +1,5 @@
import pytest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
from haystack.agents.conversational import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
@ -8,8 +8,11 @@ from haystack.nodes import PromptNode
@pytest.mark.unit
def test_init():
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
# Test normal case
assert isinstance(agent.memory, ConversationMemory)
assert callable(agent.prompt_parameters_resolver)
@ -19,21 +22,33 @@ def test_init():
# ConversationalAgent doesn't have tools
assert not agent.tm.tools
@pytest.mark.unit
def test_init_with_summary_memory():
prompt_node = PromptNode(default_prompt_template="this is a test")
# Test with summary memory
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
assert isinstance(agent.memory, ConversationSummaryMemory)
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)
@pytest.mark.unit
def test_init_with_no_memory():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)
@pytest.mark.unit
def test_run():
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")
# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")

View File

@ -9,9 +9,7 @@ from haystack.agents.memory import ConversationSummaryMemory
@pytest.fixture
def mocked_prompt_node():
mock_prompt_node = MagicMock(spec=PromptNode)
mock_prompt_node.default_prompt_template = PromptTemplate(
"conversational-summary", "Summarize the conversation: {chat_transcript}"
)
mock_prompt_node.default_prompt_template = PromptTemplate("Summarize the conversation: {chat_transcript}")
mock_prompt_node.prompt.return_value = ["This is a summary."]
return mock_prompt_node
@ -127,7 +125,7 @@ def test_conversation_summary_is_accumulating(mocked_prompt_node):
@pytest.mark.unit
def test_conversation_summary_memory_with_template(mocked_prompt_node):
pt = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}")
pt = PromptTemplate("Summarize the conversation: {chat_transcript}")
summary_mem = ConversationSummaryMemory(mocked_prompt_node, prompt_template=pt)
data1: Dict[str, Any] = {"input": "Hello", "output": "Hi there"}

View File

@ -398,9 +398,8 @@ class MockPromptNode(PromptNode):
def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> Optional[PromptTemplate]:
if prompt_template == "think-step-by-step":
return PromptTemplate(
name="think-step-by-step",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
p = PromptTemplate(
"You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
@ -417,10 +416,11 @@ class MockPromptNode(PromptNode):
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {generated_text}",
"Thought: Let's think step-by-step, I first need to {generated_text}"
)
p.name = "think-step-by-step"
else:
return PromptTemplate(name="", prompt_text="")
return PromptTemplate("test prompt")
@pytest.fixture

View File

@ -158,10 +158,8 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
pytest.skip("No API key found, skipping test")
lfqa_prompt = PromptTemplate(
name="lfqa",
prompt_text="""
Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
\n===\Paragraphs: {context}\n===\n{query}""",
"""Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and
the given question.\n===\Paragraphs: {context}\n===\n{query}"""
)
node = OpenAIAnswerGenerator(
api_key=haystack_openai_config["api_key"],

View File

@ -845,8 +845,7 @@ def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
- name: prompt_template_raw_qa_per_document
type: PromptTemplate
params:
name: raw-question-answering-per-document
prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
prompt: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
- name: prompt_node_raw_qa
type: PromptNode

View File

@ -11,6 +11,15 @@ from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler
@pytest.fixture
def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = (
lambda name: "This is a test prompt. Use your knowledge to answer this question: {question}"
)
yield mock_prompthub
def skip_test_for_invalid_key(prompt_model):
if prompt_model.api_key is not None and prompt_model.api_key == "KEY_NOT_FOUND":
pytest.skip("No API key found, skipping test")
@ -24,59 +33,6 @@ def get_api_key(request):
return os.environ.get("AZURE_OPENAI_API_KEY", None)
@pytest.mark.unit
def test_add_and_remove_template():
with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
node = PromptNode()
total_count = 16
# Verifies default
assert len(node.get_prompt_template_names()) == total_count
# Add a fake template
fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
node.add_prompt_template(fake_template)
assert len(node.get_prompt_template_names()) == total_count + 1
assert "fake-template" in node.get_prompt_template_names()
# Verify that adding the same template throws an expection
with pytest.raises(ValueError) as e:
node.add_prompt_template(fake_template)
assert e.match(
"Prompt template fake-template already exists. Select a different name for this prompt template."
)
# Verify template is correctly removed
assert node.remove_prompt_template("fake-template")
assert len(node.get_prompt_template_names()) == total_count
assert "fake-template" not in node.get_prompt_template_names()
# Verify that removing the same template throws an expection
with pytest.raises(ValueError) as e:
node.remove_prompt_template("fake-template")
assert e.match("Prompt template fake-template does not exist")
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_after_adding_template(mock_model):
# Make model always return something positive on invoke
mock_model.return_value.invoke.return_value = ["positive"]
# Create a template
template = PromptTemplate(
name="fake-sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
)
# Execute prompt
node = PromptNode()
node.add_prompt_template(template)
result = node.prompt("fake-sentiment-analysis", documents=["Berlin is an amazing city."])
assert result == ["positive"]
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_passing_template(mock_model):
@ -85,9 +41,8 @@ def test_prompt_passing_template(mock_model):
# Create a template
template = PromptTemplate(
name="fake-sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
"Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:"
)
# Execute prompt
@ -137,69 +92,92 @@ def test_prompt_call_with_custom_kwargs_and_template(mock_model, mocked_prompt):
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_without_default_template(mock_model):
def test_get_prompt_template_no_default_template(mock_model):
node = PromptNode()
assert node.get_prompt_template() is None
template = node.get_prompt_template("question-answering")
assert template.name == "question-answering"
template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
assert template.name == "fake-template"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model, mock_prompthub):
node = PromptNode()
node.default_prompt_template = "deepset/test-prompt"
with pytest.raises(ValueError) as e:
node.get_prompt_template("some-unsupported-template")
assert e.match("some-unsupported-template not supported, select one of:")
template = node.get_prompt_template()
assert template.name == "deepset/test-prompt"
fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "fake-yaml-template"
fake_yaml_prompt = "- prompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "custom-at-query-time"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_name_from_hub(mock_model, mock_prompthub):
node = PromptNode()
template = node.get_prompt_template("deepset/test-prompt")
assert template.name == "deepset/test-prompt"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_local_file(mock_model, tmp_path, mock_prompthub):
with open(tmp_path / "local_prompt_template.yml", "w") as ptf:
ptf.write(
"""
name: my_prompts/question-answering
prompt_text: |
Given the context please answer the question. Context: {join(documents)};
Question: {query};
Answer:
description: A simple prompt to answer a question given a set of documents
tags:
- question-answering
meta:
authors:
- vblagoje
version: v0.1.1
"""
)
node = PromptNode()
template = node.get_prompt_template(str(tmp_path / "local_prompt_template.yml"))
assert template.name == "my_prompts/question-answering"
assert "Given the context" in template.prompt_text
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_object(mock_model, mock_prompthub):
node = PromptNode()
original_template = PromptTemplate("fake-template")
template = node.get_prompt_template(original_template)
assert template == original_template
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_wrong_template_name(mock_model):
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mock_prompthub:
def not_found(*a, **k):
raise ValueError("'some-unsupported-template' not supported!")
mock_prompthub.fetch.side_effect = not_found
node = PromptNode()
with pytest.raises(ValueError, match="not supported") as e:
node.get_prompt_template("some-unsupported-template")
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_only_template_text(mock_model, mock_prompthub):
node = PromptNode()
template = node.get_prompt_template("some prompt")
assert template.name == "custom-at-query-time"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model):
def test_invalid_template_params(mock_model, mock_prompthub):
node = PromptNode()
node.set_default_prompt_template("question-answering")
template = node.get_prompt_template()
assert template.name == "question-answering"
template = node.get_prompt_template("sentiment-analysis")
assert template.name == "sentiment-analysis"
template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
assert template.name == "fake-template"
with pytest.raises(ValueError) as e:
node.get_prompt_template("some-unsupported-template")
assert e.match("some-unsupported-template not supported, select one of:")
fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "fake-yaml-template"
fake_yaml_prompt = "- prompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "custom-at-query-time"
template = node.get_prompt_template("some prompt")
assert template.name == "custom-at-query-time"
@pytest.mark.integration
def test_invalid_template_params():
# TODO: This can be a PromptTemplate unit test
node = PromptNode("google/flan-t5-small", devices=["cpu"])
with pytest.raises(ValueError, match="Expected prompt parameters"):
node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})
node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")
@pytest.mark.integration
@ -282,7 +260,10 @@ def test_stop_words(prompt_model):
node = PromptNode(prompt_model, stop_words=["capital"])
# with default prompt template and stop words set in PN
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"
# test stop words for both HF and OpenAI
@ -290,11 +271,18 @@ def test_stop_words(prompt_model):
node = PromptNode(prompt_model, stop_words=["capital", "Germany"])
# with default prompt template and stop words set in PN
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"
# with default prompt template and stop words set in kwargs (overrides PN stop words)
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."], stop_words=None)
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
stop_words=None,
)
assert "capital" in r[0] or "Germany" in r[0]
# simple prompting
@ -310,10 +298,7 @@ def test_stop_words(prompt_model):
)
assert "capital" in r[0] or "Germany" in r[0]
tt = PromptTemplate(
name="question-generation-copy",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
)
tt = PromptTemplate("Given the context please generate a question. Context: {documents}; Question:")
# with custom prompt template
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
assert r[0] == "What is the" or r[0] == "What city is the"
@ -520,7 +505,7 @@ def test_pipeline_with_qa_with_references(prompt_model):
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_text_at_query_time(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
node = PromptNode(prompt_model, default_prompt_template="test prompt template text", top_k=1)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@ -609,8 +594,7 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
- name: template_with_nested_shaper
type: PromptTemplate
params:
name: custom-template-with-nested-shaper
prompt_text: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
prompt: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
output_parser:
type: AnswerParser
- name: p1
@ -674,8 +658,7 @@ def test_complex_pipeline_with_qa(prompt_model):
skip_test_for_invalid_key(prompt_model)
prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {documents}; Question: {query}; Answer:"
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
@ -874,8 +857,7 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
@ -954,8 +936,7 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
@ -1028,8 +1009,7 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel_openai
@ -1104,9 +1084,7 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
class TestTokenLimit:
@pytest.mark.integration
def test_hf_token_limit_warning(self, caplog):
prompt_template = PromptTemplate(
name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
)
prompt_template = PromptTemplate("Repeating text" * 200 + "Docs: {documents}; Answer:")
with caplog.at_level(logging.WARNING):
node = PromptNode("google/flan-t5-small", devices=["cpu"])
node.prompt(prompt_template, documents=["Berlin is an amazing city."])
@ -1119,7 +1097,7 @@ class TestTokenLimit:
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_token_limit_warning(self, caplog):
tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:")
tt = PromptTemplate("Repeating text" * 200 + "Docs: {documents}; Answer:")
prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
@ -1133,7 +1111,10 @@ class TestRunBatch:
def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
node = PromptNode(
prompt_model,
default_prompt_template="Please give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:",
)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@ -1151,7 +1132,11 @@ class TestRunBatch:
def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
node = PromptNode(
prompt_model,
default_prompt_template="Please give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:",
output_variable="out",
)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@ -1174,8 +1159,7 @@ class TestRunBatch:
skip_test_for_invalid_key(prompt_model)
prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {documents}; Question: {query}; Answer:"
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)

View File

@ -1,7 +1,9 @@
from typing import Set, Type, List
from unittest.mock import patch
import textwrap
from unittest.mock import patch, MagicMock
import pytest
import prompthub
from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_node import PromptNode
@ -11,29 +13,81 @@ from haystack.pipelines.base import Pipeline
from haystack.schema import Answer, Document
def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [
("deepset/test-prompt", "This is a test prompt. Use your knowledge to answer this question: {question}")
]
yield mock_prompthub
@pytest.mark.unit
def test_prompt_templates():
p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
def test_prompt_templates_from_hub():
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mock_prompthub:
PromptTemplate("deepset/question-answering")
mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)
@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
path = tmp_path / "test-prompt.yml"
with open(path, "a") as yamlfile:
yamlfile.write(
textwrap.dedent(
"""
name: deepset/question-answering
prompt_text: |
Given the context please answer the question. Context: {join(documents)};
Question: {query};
Answer:
description: A simple prompt to answer a question given a set of documents
tags:
- question-answering
meta:
authors:
- vblagoje
version: v0.1.1
"""
)
)
p = PromptTemplate(str(path.absolute()))
assert p.name == "deepset/question-answering"
assert "Given the context please answer the question" in p.prompt_text
@pytest.mark.unit
def test_prompt_templates_on_the_fly():
with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml:
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mocked_ph:
p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
assert p.name == "custom-at-query-time"
mocked_ph.fetch.assert_not_called()
mocked_yaml.safe_load.assert_not_called()
@pytest.mark.unit
def test_custom_prompt_templates():
p = PromptTemplate("Here is some fake template with variable {foo}")
assert set(p.prompt_params) == {"foo"}
p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
p = PromptTemplate("Here is some fake template with variable {foo} and {bar}")
assert set(p.prompt_params) == {"foo", "bar"}
p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
p = PromptTemplate("Here is some fake template with variable {foo1} and {bar2}")
assert set(p.prompt_params) == {"foo1", "bar2"}
p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
p = PromptTemplate("Here is some fake template with variable {foo_1} and {bar_2}")
assert set(p.prompt_params) == {"foo_1", "bar_2"}
p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
p = PromptTemplate("Here is some fake template with variable {Foo_1} and {Bar_2}")
assert set(p.prompt_params) == {"Foo_1", "Bar_2"}
p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
p = PromptTemplate("'Here is some fake template with variable {baz}'")
assert set(p.prompt_params) == {"baz"}
# strip single quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable {baz}"
p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
p = PromptTemplate('"Here is some fake template with variable {baz}"')
assert set(p.prompt_params) == {"baz"}
# strip double quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable {baz}"
@ -41,7 +95,7 @@ def test_prompt_templates():
@pytest.mark.unit
def test_missing_prompt_template_params():
template = PromptTemplate("missing_params", "Here is some fake template with variable {foo} and {bar}")
template = PromptTemplate("Here is some fake template with variable {foo} and {bar}")
# both params provided - ok
template.prepare(foo="foo", bar="bar")
@ -62,8 +116,10 @@ def test_missing_prompt_template_params():
@pytest.mark.unit
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable {baz}")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
p = PromptTemplate("Here is variable {baz}")
desired_repr = (
"PromptTemplate(name=custom-at-query-time, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
)
assert repr(p) == desired_repr
assert str(p) == desired_repr
@ -72,8 +128,7 @@ def test_prompt_template_repr():
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_template_deserialization(mock_prompt_model):
custom_prompt_template = PromptTemplate(
name="custom-question-answering",
prompt_text="Given the context please answer the question. Context: {context}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {context}; Question: {query}; Answer:",
output_parser=AnswerParser(),
)
@ -88,7 +143,6 @@ def test_prompt_template_deserialization(mock_prompt_model):
loaded_generator = loaded_pipe.get_node("Generator")
assert isinstance(loaded_generator, PromptNode)
assert isinstance(loaded_generator.default_prompt_template, PromptTemplate)
assert loaded_generator.default_prompt_template.name == "custom-question-answering"
assert (
loaded_generator.default_prompt_template.prompt_text
== "Given the context please answer the question. Context: {context}; Question: {query}; Answer:"
@ -135,7 +189,7 @@ class TestPromptTemplateSyntax:
def test_prompt_template_syntax_parser(
self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
assert set(prompt_template.prompt_params) == expected_prompt_params
assert set(prompt_template._used_functions) == expected_used_functions
@ -210,13 +264,13 @@ class TestPromptTemplateSyntax:
"how?",
["context: doc1 d question: what!"],
),
("context", None, None, ["context"]),
("context: ", None, None, ["context: "]),
],
)
def test_prompt_template_syntax_fill(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts
@ -243,7 +297,7 @@ class TestPromptTemplateSyntax:
],
)
def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts
@ -276,7 +330,7 @@ class TestPromptTemplateSyntax:
],
)
def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts
@ -300,7 +354,7 @@ class TestPromptTemplateSyntax:
self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
):
with pytest.raises(exc_type, match=expected_exc_match):
PromptTemplate(name="test", prompt_text=prompt_text)
PromptTemplate(prompt_text)
@pytest.mark.unit
@pytest.mark.parametrize(
@ -316,7 +370,7 @@ class TestPromptTemplateSyntax:
expected_exc_match: str,
):
with pytest.raises(exc_type, match=expected_exc_match):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
next(prompt_template.fill(documents=documents, query=query))
@pytest.mark.unit
@ -337,6 +391,6 @@ class TestPromptTemplateSyntax:
def test_prompt_template_syntax_fill_ignores_dangerous_input(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts