mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
PromptHub integration in PromptNode (#4879)
* initial integration
* upgrade of prompthub
* fix get_prompt_template
* feedback
* add prompthub-py to dependencies
* tests
* mypy
* stray changes
* review feedback
* missing init
* fix test
* move logic in prompttemplate
* linting
* bugfixes
* fix unit tests
* fix cache
* simplify prompttemplate init
* remove unused function
* removing wrong params
* try remove all instances of prompt names
* more tests
* fix agent tests
* more tests
* fix tests
* pylint
* comma
* black
* fix test
* docstring
* review feedback
* review feedback
* fix mocks
* mypy
* fix mocks
* fix reference to missing templates
* feedback
* remove direct references to default template var
* tests
* Update haystack/nodes/prompt/prompt_node.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent 9e4feb6bed
commit 949b1b63b3
@@ -89,13 +89,12 @@ class OpenAIAnswerGenerator(BaseGenerator):
If not supplied, the default prompt template is:
```python
PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
"Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
)
```
To learn how variables, such as'{context}', are substituted in the `prompt_text`, see
To learn how variables, such as'{context}', are substituted in the prompt text, see
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
:param context_join_str: The separation string used to join the input documents to create the context
used by the PromptTemplate.
@@ -114,10 +113,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
stop_words = ["\n", "<|endoftext|>"]
if prompt_template is None:
prompt_template = PromptTemplate(
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
"===\nContext: {context}\n===\n{query}"
)
else:
# Check for required prompts

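With this change, the generator's fallback template above is built from the prompt text alone. A minimal sketch of supplying a custom template under the new signature; the API key and prompt wording are placeholders, not part of the diff:

```python
from haystack.nodes import OpenAIAnswerGenerator
from haystack.nodes.prompt import PromptTemplate

# No `name=` / `prompt_text=` keyword arguments anymore: the first positional
# argument is the prompt text itself (it still needs {context} and {query}).
custom_template = PromptTemplate(
    "Please answer the question according to the above context."
    "\n===\nContext: {context}\n===\n{query}"
)

# "YOUR_API_KEY" is a placeholder; all other OpenAIAnswerGenerator defaults are left as-is.
generator = OpenAIAnswerGenerator(api_key="YOUR_API_KEY", prompt_template=custom_template)
```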
@@ -1,19 +1,15 @@
from collections import defaultdict
import copy
import logging
import re
from typing import Dict, List, Optional, Tuple, Union, Any

import torch
import yaml

from haystack.nodes.base import BaseComponent

from haystack.schema import Document, MultiLabel
from haystack.telemetry import send_event
from haystack.nodes.prompt.shapers import BaseOutputParser
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.prompt_template import PromptTemplate, get_predefined_prompt_templates
from haystack.nodes.prompt.prompt_template import PromptTemplate

logger = logging.getLogger(__name__)

@@ -97,8 +93,12 @@ class PromptNode(BaseComponent):
},
)
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self._prompt_templates_cache: Dict[str, PromptTemplate] = {}

# If we don't set _default_template here Pylint fails with error W0201 because it can't see that
# it's set in default_prompt_template, so we set it explicitly to None to avoid it failing
self._default_template = None
self.default_prompt_template = default_prompt_template
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
self.prompt_model: PromptModel
@@ -106,15 +106,6 @@ class PromptNode(BaseComponent):
self.top_k: int = top_k
self.debug = debug

if isinstance(self.default_prompt_template, str) and not self.is_supported_template(
self.default_prompt_template
):
raise ValueError(
f"Prompt template {self.default_prompt_template} is not supported. "
f"Select one of: {self.get_prompt_template_names()} "
f"or register a new prompt template first using the add_prompt_template() method."
)

if isinstance(model_name_or_path, str):
self.prompt_model = PromptModel(
model_name_or_path=model_name_or_path,
@@ -187,65 +178,18 @@ class PromptNode(BaseComponent):
results.extend(output)
return results

def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
"""
Adds a prompt template to the list of supported prompt templates.
:param prompt_template: The PromptTemplate object to be added.
:return: None
"""
if prompt_template.name in self.prompt_templates:
raise ValueError(
f"Prompt template {prompt_template.name} already exists. "
f"Select a different name for this prompt template."
)
@property
def default_prompt_template(self):
return self._default_template

self.prompt_templates[prompt_template.name] = prompt_template # type: ignore

def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
"""
Removes a prompt template from the list of supported prompt templates.
:param prompt_template: Name of the prompt template to be removed.
:return: PromptTemplate object that was removed.
"""
if prompt_template not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template} does not exist")

return self.prompt_templates.pop(prompt_template)

def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode":
@default_prompt_template.setter
def default_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]):
"""
Sets the default prompt template for the node.
:param prompt_template: The prompt template to be set as default.
:return: The current PromptNode object.
"""
if not self.is_supported_template(prompt_template):
raise ValueError(f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()}")

self.default_prompt_template = prompt_template
return self

def get_prompt_templates(self) -> List[PromptTemplate]:
"""
Returns the list of supported prompt templates.
:return: List of supported prompt templates.
"""
return list(self.prompt_templates.values())

def get_prompt_template_names(self) -> List[str]:
"""
Returns the list of supported prompt template names.
:return: List of supported prompt template names.
"""
return list(self.prompt_templates.keys())

def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) -> bool:
"""
Checks if a prompt template is supported.
:param prompt_template: The prompt template to be checked.
:return: True if the prompt template is supported, False otherwise.
"""
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in self.prompt_templates
self._default_template = self.get_prompt_template(prompt_template)

def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None] = None) -> Optional[PromptTemplate]:
"""
@@ -261,34 +205,25 @@ class PromptNode(BaseComponent):

:return: The prompt template object.
"""
prompt_template = prompt_template or self.default_prompt_template
# None means we're asking for the default prompt template
prompt_template = prompt_template or self._default_template
if prompt_template is None:
return None

# PromptTemplate instances simply go through
if isinstance(prompt_template, PromptTemplate):
return prompt_template

if isinstance(prompt_template, str) and prompt_template in self.prompt_templates:
return self.prompt_templates[prompt_template]
# If it's the name of a template that was used already
if prompt_template in self._prompt_templates_cache:
return self._prompt_templates_cache[prompt_template]

# if it's not a string or looks like a prompt template name
if not isinstance(prompt_template, str) or re.fullmatch(r"[-a-zA-Z0-9_]+", prompt_template):
raise ValueError(
f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()} or pass a PromptTemplate instance for prompting."
)

if "prompt_text:" in prompt_template:
prompt_template_parsed = yaml.safe_load(prompt_template)
if isinstance(prompt_template_parsed, dict):
return PromptTemplate(**prompt_template_parsed)

# it's a prompt_text
prompt_text = prompt_template
output_parser: Optional[BaseOutputParser] = None
default_prompt_template = self.get_prompt_template()
if default_prompt_template:
output_parser = default_prompt_template.output_parser
return PromptTemplate(name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser)
output_parser = None
if self.default_prompt_template:
output_parser = self.default_prompt_template.output_parser
template = PromptTemplate(prompt_template, output_parser=output_parser)
self._prompt_templates_cache[prompt_template] = template
return template

def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
@@ -296,10 +231,10 @@ class PromptNode(BaseComponent):
:param prompt_template: The name of the prompt template.
:return: The list of parameters for the prompt template.
"""
if not self.is_supported_template(prompt_template):
raise ValueError(f"{prompt_template} not supported, select one of: {self.get_prompt_template_names()}")

return list(self.prompt_templates[prompt_template].prompt_params)
template = self.get_prompt_template(prompt_template)
if template:
return list(template.prompt_params)
return []

def run(
self,

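The reworked `get_prompt_template` above accepts several shapes of input: `None` falls back to the default template, `PromptTemplate` instances pass through, previously used names come from `_prompt_templates_cache`, and any other string is turned into a new `PromptTemplate`. A small sketch of those paths; the hub name and prompt strings are invented, and the model plus the hub fetch are patched out (mirroring the test fixtures further down) so nothing is downloaded:

```python
from unittest.mock import patch
from haystack.nodes.prompt import PromptTemplate, PromptNode

with patch("haystack.nodes.prompt.prompt_node.PromptModel"), patch(
    "haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub",
    return_value="Answer the question: {query}",
):
    node = PromptNode()

    assert node.get_prompt_template() is None  # no default template configured

    first = node.get_prompt_template("deepset/test-prompt")   # fetched from the hub once...
    second = node.get_prompt_template("deepset/test-prompt")  # ...then served from the cache
    assert first is second

    template = PromptTemplate("Summarize this: {documents}")
    assert node.get_prompt_template(template) is template     # instances pass through untouched

    ad_hoc = node.get_prompt_template("Just answer: {query}")  # free-form prompt text
    assert ad_hoc.name == "custom-at-query-time"
```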
@@ -1,11 +1,18 @@
from typing import Optional, List, Union, Tuple, Dict, Iterator, Any
import logging
import re
import os
import ast
import json
from pathlib import Path
from abc import ABC
from uuid import uuid4

import yaml
import tenacity
import prompthub
from requests import HTTPError, RequestException, JSONDecodeError

from haystack.errors import NodeError
from haystack.environment import HAYSTACK_PROMPT_TEMPLATE_ALLOWED_FUNCTIONS
from haystack.nodes.base import BaseComponent
@@ -19,6 +26,11 @@ from haystack.nodes.prompt.shapers import ( # pylint: disable=unused-import
format_string,
)
from haystack.schema import Document, MultiLabel
from haystack.environment import (
HAYSTACK_REMOTE_API_TIMEOUT_SEC,
HAYSTACK_REMOTE_API_BACKOFF_SEC,
HAYSTACK_REMOTE_API_MAX_RETRIES,
)

logger = logging.getLogger(__name__)

@@ -29,6 +41,14 @@ PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS = {"new_line": "\n", "tab": "\t", "double_quo
PROMPT_TEMPLATE_STRIPS = ["'", '"']
PROMPT_TEMPLATE_STR_REPLACE = {'"': "'"}

PROMPTHUB_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30.0))
PROMPTHUB_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10.0))
PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))


class PromptNotFoundError(Exception):
...


class BasePromptTemplate(BaseComponent):
outgoing_edges = 1
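The three `PROMPTHUB_*` constants above are read once, at import time, from the standard Haystack remote-API environment variables, so hub timeouts and retries can be tuned without code changes. A hedged sketch: the numeric values are arbitrary, and the variable names assume the constants in `haystack.environment` resolve to strings of the same name.

```python
import os

# Set these before importing haystack.nodes.prompt.prompt_template,
# because the module evaluates them when it is first imported.
os.environ["HAYSTACK_REMOTE_API_TIMEOUT_SEC"] = "10"   # prompthub.fetch timeout (seconds)
os.environ["HAYSTACK_REMOTE_API_BACKOFF_SEC"] = "2"    # exponential backoff multiplier
os.environ["HAYSTACK_REMOTE_API_MAX_RETRIES"] = "3"    # tenacity stop_after_attempt

from haystack.nodes.prompt import PromptTemplate  # noqa: E402
```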
@@ -163,9 +183,7 @@ class PromptTemplate(BasePromptTemplate, ABC):
PromptTemplate is a template for the prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what a prompt template may look like:

```python
PromptTemplate(name="sentiment-analysis",
prompt_text="Give a sentiment for this context. Answer with positive, negative"
"or neutral. Context: {documents}; Answer:")
PromptTemplate("Give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:")
```

Optionally, you can declare prompt parameters using f-string syntax in the PromptTemplate. Prompt parameters are input parameters that need to be filled in
@@ -183,14 +201,12 @@ class PromptTemplate(BasePromptTemplate, ABC):
[PromptTemplates](https://docs.haystack.deepset.ai/docs/prompt_node#prompttemplates).
"""

def __init__(
self, name: str, prompt_text: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None
):
def __init__(self, prompt: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None):
"""
Creates a PromptTemplate instance.

:param name: The name of the prompt template (for example, "sentiment-analysis", "question-generation"). You can specify your own name but it must be unique.
:param prompt_text: The prompt text, including prompt parameters.
:param prompt: The name of the prompt template on the PromptHub (for example, "sentiment-analysis",
"question-generation"), a Path to a local file, or the text of a new prompt, including its parameters.
:param output_parser: A parser that applied to the model output.
For example, to convert the model output to an Answer object, you can use `AnswerParser`.
Instead of BaseOutputParser instances, you can also pass dictionaries defining the output parsers. For example:
@@ -199,6 +215,36 @@ class PromptTemplate(BasePromptTemplate, ABC):
```
"""
super().__init__()
name, prompt_text = "", ""

try:
# if it looks like a prompt template name
if re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
name = prompt
prompt_text = self._fetch_from_prompthub(prompt)

# if it's a path to a YAML file
elif Path(prompt).exists():
with open(prompt, "r", encoding="utf-8") as yaml_file:
prompt_template_parsed = yaml.safe_load(yaml_file.read())
if not isinstance(prompt_template_parsed, dict):
raise ValueError("The prompt loaded is not a prompt YAML file.")
name = prompt_template_parsed["name"]
prompt_text = prompt_template_parsed["prompt_text"]

# Otherwise it's a on-the-fly prompt text
else:
prompt_text = prompt
name = "custom-at-query-time"

except OSError as exc:
logger.info(
"There was an error checking whether this prompt is a file (%s). Haystack will assume it's not.",
str(exc),
)
# In case of errors, let's directly assume this is a text prompt
prompt_text = prompt
name = "custom-at-query-time"

# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
for strip in PROMPT_TEMPLATE_STRIPS:
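In short, the single `prompt` argument is resolved in three ways: a PromptHub name is fetched remotely, an existing file path is parsed as a prompt YAML, and anything else becomes an on-the-fly template named "custom-at-query-time". A minimal sketch of the three forms; the hub fetch is patched so the snippet stays offline, and the file path is hypothetical:

```python
from unittest.mock import patch
from haystack.nodes.prompt import PromptTemplate

# 1. A PromptHub name (anything matching [-a-zA-Z0-9_/]+) is fetched from the hub.
with patch.object(PromptTemplate, "_fetch_from_prompthub", return_value="Answer: {query}"):
    from_hub = PromptTemplate("deepset/question-answering")
    assert from_hub.name == "deepset/question-answering"

# 2. A path to a local prompt YAML file with `name` and `prompt_text` keys
#    (see the test further down for a full file layout); "my_prompt.yml" is hypothetical.
# from_file = PromptTemplate("my_prompt.yml")

# 3. Any other string is used directly as the prompt text.
on_the_fly = PromptTemplate("Give a sentiment for this context: {documents}; Answer:")
assert on_the_fly.name == "custom-at-query-time"
```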
@@ -241,6 +287,26 @@ class PromptTemplate(BasePromptTemplate, ABC):
def output_variable(self) -> Optional[str]:
return self.output_parser.output_variable if self.output_parser else None

@tenacity.retry(
reraise=True,
retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
)
def _fetch_from_prompthub(self, name) -> str:
"""
Looks for the given prompt in the PromptHub if the prompt is not in the local cache.

Raises PromptNotFoundError if the prompt is not present in the hub.
"""
try:
prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
except HTTPError as http_error:
if http_error.response.status_code != 404:
raise http_error
raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
return prompt_data.text

def prepare(self, *args, **kwargs) -> Dict[str, Any]:
"""
Prepares and verifies the PromtpTemplate with input parameters.
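A 404 from the hub is translated into `PromptNotFoundError` instead of surfacing the raw `HTTPError`, which makes a fallback straightforward. A hedged sketch: the hub name is invented, and this snippet does reach out to the Prompt Hub if actually run.

```python
from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_template import PromptNotFoundError

try:
    # "my-org/does-not-exist" is an invented name used only to show the failure path.
    template = PromptTemplate("my-org/does-not-exist")
except PromptNotFoundError:
    # Fall back to an inline prompt when the hub does not know the name.
    template = PromptTemplate("Answer the question: {query}")
```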
@@ -336,110 +402,3 @@ class PromptTemplate(BasePromptTemplate, ABC):

def __repr__(self):
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"


def get_predefined_prompt_templates() -> List[PromptTemplate]:
return [
PromptTemplate(
name="question-answering",
prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-per-document",
prompt_text="Given the context please answer the question. Context: {documents}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="question-answering-with-references",
prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
"If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. "
"If the documents do not contain the answer to the question, say that ‘answering is not possible given the available information.’\n"
"{join(documents, delimiter=new_line, pattern=new_line+'Document[$idx]: $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n Question: {query}; Answer: ",
output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
),
PromptTemplate(
name="question-answering-with-document-scores",
prompt_text="Answer the following question using the paragraphs below as sources. "
"An answer should be short, a few words at most.\n"
"Paragraphs:\n{documents}\n"
"Question: {query}\n\n"
"Instructions: Consider all the paragraphs above and their corresponding scores to generate "
"the answer. While a single paragraph may have a high score, it's important to consider all "
"paragraphs for the same answer candidate to answer accurately.\n\n"
"After having considered all possibilities, the final answer is:\n",
),
PromptTemplate(
name="question-generation",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
),
PromptTemplate(
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: {documents}; Answer: {answers}; Question:",
),
PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
PromptTemplate(
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question? "
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
),
PromptTemplate(
name="multiple-choice-question-answering",
prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
name="topic-classification",
prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
),
PromptTemplate(
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: {documents}; Answer:",
),
PromptTemplate(
name="translation",
prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
),
PromptTemplate(
name="zero-shot-react",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: pick one of {tool_names} \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {transcript}",
),
PromptTemplate(
name="conversational-agent",
prompt_text="The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:",
),
PromptTemplate(
name="conversational-summary",
prompt_text="Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:",
),
]

@@ -58,6 +58,7 @@ dependencies = [
"rank_bm25",
"scikit-learn>=1.0.0", # TF-IDF, SklearnQueryClassifier and metrics
"generalimport", # Optional imports
"prompthub-py",

# Utils
"dill", # pickle extension for (de-)serialization

@@ -177,7 +177,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
assert result == "Paris" or result == "Madrid"

# PromptNode as a Tool
pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
pt = PromptTemplate("Here is a question: {query}, Answer:")
pn = PromptNode(default_prompt_template=pt)

t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
@@ -212,12 +212,11 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
country_finder = PromptNode(
model_name_or_path=prompt_model,
default_prompt_template=PromptTemplate(
name="country_finder",
prompt_text="When I give you a name of the city, respond with the country where the city is located.\n"
"When I give you a name of the city, respond with the country where the city is located.\n"
"City: Rome\nCountry: Italy\n"
"City: Berlin\nCountry: Germany\n"
"City: Belgrade\nCountry: Serbia\n"
"City: {query}?\nCountry: ",
"City: {query}?\nCountry: "
),
)


@@ -1,5 +1,5 @@
import pytest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from haystack.agents.conversational import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
@@ -8,8 +8,11 @@ from haystack.nodes import PromptNode

@pytest.mark.unit
def test_init():
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)

# Test normal case
assert isinstance(agent.memory, ConversationMemory)
assert callable(agent.prompt_parameters_resolver)
@@ -19,21 +22,33 @@ def test_init():
# ConversationalAgent doesn't have tools
assert not agent.tm.tools


@pytest.mark.unit
def test_init_with_summary_memory():
prompt_node = PromptNode(default_prompt_template="this is a test")
# Test with summary memory
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
assert isinstance(agent.memory, ConversationSummaryMemory)

# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)

@pytest.mark.unit
def test_init_with_no_memory():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)


@pytest.mark.unit
def test_run():
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)

# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")
# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")

@@ -9,9 +9,7 @@ from haystack.agents.memory import ConversationSummaryMemory
@pytest.fixture
def mocked_prompt_node():
mock_prompt_node = MagicMock(spec=PromptNode)
mock_prompt_node.default_prompt_template = PromptTemplate(
"conversational-summary", "Summarize the conversation: {chat_transcript}"
)
mock_prompt_node.default_prompt_template = PromptTemplate("Summarize the conversation: {chat_transcript}")
mock_prompt_node.prompt.return_value = ["This is a summary."]
return mock_prompt_node

@@ -127,7 +125,7 @@ def test_conversation_summary_is_accumulating(mocked_prompt_node):

@pytest.mark.unit
def test_conversation_summary_memory_with_template(mocked_prompt_node):
pt = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}")
pt = PromptTemplate("Summarize the conversation: {chat_transcript}")
summary_mem = ConversationSummaryMemory(mocked_prompt_node, prompt_template=pt)

data1: Dict[str, Any] = {"input": "Hello", "output": "Hi there"}

@@ -398,9 +398,8 @@ class MockPromptNode(PromptNode):

def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]) -> Optional[PromptTemplate]:
if prompt_template == "think-step-by-step":
return PromptTemplate(
name="think-step-by-step",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
p = PromptTemplate(
"You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
@@ -417,10 +416,11 @@ class MockPromptNode(PromptNode):
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to {generated_text}",
"Thought: Let's think step-by-step, I first need to {generated_text}"
)
p.name = "think-step-by-step"
else:
return PromptTemplate(name="", prompt_text="")
return PromptTemplate("test prompt")


@pytest.fixture

@@ -158,10 +158,8 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
pytest.skip("No API key found, skipping test")

lfqa_prompt = PromptTemplate(
name="lfqa",
prompt_text="""
Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
\n===\Paragraphs: {context}\n===\n{query}""",
"""Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and
the given question.\n===\Paragraphs: {context}\n===\n{query}"""
)
node = OpenAIAnswerGenerator(
api_key=haystack_openai_config["api_key"],

@@ -845,8 +845,7 @@ def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
- name: prompt_template_raw_qa_per_document
type: PromptTemplate
params:
name: raw-question-answering-per-document
prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
prompt: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'

- name: prompt_node_raw_qa
type: PromptNode

@@ -11,6 +11,15 @@ from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, DefaultTokenStreamingHandler


@pytest.fixture
def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = (
lambda name: "This is a test prompt. Use your knowledge to answer this question: {question}"
)
yield mock_prompthub


def skip_test_for_invalid_key(prompt_model):
if prompt_model.api_key is not None and prompt_model.api_key == "KEY_NOT_FOUND":
pytest.skip("No API key found, skipping test")
@@ -24,59 +33,6 @@ def get_api_key(request):
return os.environ.get("AZURE_OPENAI_API_KEY", None)


@pytest.mark.unit
def test_add_and_remove_template():
with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
node = PromptNode()
total_count = 16
# Verifies default
assert len(node.get_prompt_template_names()) == total_count

# Add a fake template
fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
node.add_prompt_template(fake_template)
assert len(node.get_prompt_template_names()) == total_count + 1
assert "fake-template" in node.get_prompt_template_names()

# Verify that adding the same template throws an expection
with pytest.raises(ValueError) as e:
node.add_prompt_template(fake_template)
assert e.match(
"Prompt template fake-template already exists. Select a different name for this prompt template."
)

# Verify template is correctly removed
assert node.remove_prompt_template("fake-template")
assert len(node.get_prompt_template_names()) == total_count
assert "fake-template" not in node.get_prompt_template_names()

# Verify that removing the same template throws an expection
with pytest.raises(ValueError) as e:
node.remove_prompt_template("fake-template")
assert e.match("Prompt template fake-template does not exist")


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_after_adding_template(mock_model):
# Make model always return something positive on invoke
mock_model.return_value.invoke.return_value = ["positive"]

# Create a template
template = PromptTemplate(
name="fake-sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
)

# Execute prompt
node = PromptNode()
node.add_prompt_template(template)
result = node.prompt("fake-sentiment-analysis", documents=["Berlin is an amazing city."])

assert result == ["positive"]


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_passing_template(mock_model):
@@ -85,9 +41,8 @@ def test_prompt_passing_template(mock_model):

# Create a template
template = PromptTemplate(
name="fake-sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
"Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:"
)

# Execute prompt
@@ -137,69 +92,92 @@ def test_prompt_call_with_custom_kwargs_and_template(mock_model, mocked_prompt):

@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_without_default_template(mock_model):
def test_get_prompt_template_no_default_template(mock_model):
node = PromptNode()
assert node.get_prompt_template() is None

template = node.get_prompt_template("question-answering")
assert template.name == "question-answering"

template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
assert template.name == "fake-template"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model, mock_prompthub):
node = PromptNode()
node.default_prompt_template = "deepset/test-prompt"

with pytest.raises(ValueError) as e:
node.get_prompt_template("some-unsupported-template")
assert e.match("some-unsupported-template not supported, select one of:")
template = node.get_prompt_template()
assert template.name == "deepset/test-prompt"

fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "fake-yaml-template"

fake_yaml_prompt = "- prompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "custom-at-query-time"
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_name_from_hub(mock_model, mock_prompthub):
node = PromptNode()
template = node.get_prompt_template("deepset/test-prompt")
assert template.name == "deepset/test-prompt"


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_local_file(mock_model, tmp_path, mock_prompthub):
with open(tmp_path / "local_prompt_template.yml", "w") as ptf:
ptf.write(
"""
name: my_prompts/question-answering
prompt_text: |
Given the context please answer the question. Context: {join(documents)};
Question: {query};
Answer:
description: A simple prompt to answer a question given a set of documents
tags:
- question-answering
meta:
authors:
- vblagoje
version: v0.1.1
"""
)
node = PromptNode()
template = node.get_prompt_template(str(tmp_path / "local_prompt_template.yml"))
assert template.name == "my_prompts/question-answering"
assert "Given the context" in template.prompt_text


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_object(mock_model, mock_prompthub):
node = PromptNode()
original_template = PromptTemplate("fake-template")
template = node.get_prompt_template(original_template)
assert template == original_template


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_wrong_template_name(mock_model):
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mock_prompthub:

def not_found(*a, **k):
raise ValueError("'some-unsupported-template' not supported!")

mock_prompthub.fetch.side_effect = not_found
node = PromptNode()
with pytest.raises(ValueError, match="not supported") as e:
node.get_prompt_template("some-unsupported-template")


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_only_template_text(mock_model, mock_prompthub):
node = PromptNode()
template = node.get_prompt_template("some prompt")
assert template.name == "custom-at-query-time"


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model):
def test_invalid_template_params(mock_model, mock_prompthub):
node = PromptNode()
node.set_default_prompt_template("question-answering")

template = node.get_prompt_template()
assert template.name == "question-answering"

template = node.get_prompt_template("sentiment-analysis")
assert template.name == "sentiment-analysis"

template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
assert template.name == "fake-template"

with pytest.raises(ValueError) as e:
node.get_prompt_template("some-unsupported-template")
assert e.match("some-unsupported-template not supported, select one of:")

fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "fake-yaml-template"

fake_yaml_prompt = "- prompt_text: fake prompt text"
template = node.get_prompt_template(fake_yaml_prompt)
assert template.name == "custom-at-query-time"

template = node.get_prompt_template("some prompt")
assert template.name == "custom-at-query-time"


@pytest.mark.integration
def test_invalid_template_params():
# TODO: This can be a PromptTemplate unit test
node = PromptNode("google/flan-t5-small", devices=["cpu"])
with pytest.raises(ValueError, match="Expected prompt parameters"):
node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})
node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")


@pytest.mark.integration
@@ -282,7 +260,10 @@ def test_stop_words(prompt_model):
node = PromptNode(prompt_model, stop_words=["capital"])

# with default prompt template and stop words set in PN
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"

# test stop words for both HF and OpenAI
@@ -290,11 +271,18 @@ def test_stop_words(prompt_model):
node = PromptNode(prompt_model, stop_words=["capital", "Germany"])

# with default prompt template and stop words set in PN
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"

# with default prompt template and stop words set in kwargs (overrides PN stop words)
r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."], stop_words=None)
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
stop_words=None,
)
assert "capital" in r[0] or "Germany" in r[0]

# simple prompting
@@ -310,10 +298,7 @@ def test_stop_words(prompt_model):
)
assert "capital" in r[0] or "Germany" in r[0]

tt = PromptTemplate(
name="question-generation-copy",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
)
tt = PromptTemplate("Given the context please generate a question. Context: {documents}; Question:")
# with custom prompt template
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
assert r[0] == "What is the" or r[0] == "What city is the"
@@ -520,7 +505,7 @@ def test_pipeline_with_qa_with_references(prompt_model):
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_text_at_query_time(prompt_model):
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
node = PromptNode(prompt_model, default_prompt_template="test prompt template text", top_k=1)

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@@ -609,8 +594,7 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
- name: template_with_nested_shaper
type: PromptTemplate
params:
name: custom-template-with-nested-shaper
prompt_text: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
prompt: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
output_parser:
type: AnswerParser
- name: p1
@@ -674,8 +658,7 @@ def test_complex_pipeline_with_qa(prompt_model):
skip_test_for_invalid_key(prompt_model)

prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {documents}; Question: {query}; Answer:"
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)

@@ -874,8 +857,7 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
@@ -954,8 +936,7 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel
@@ -1028,8 +1009,7 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
- name: question_generation_template
type: PromptTemplate
params:
name: question-generation-new
prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
prompt: "Given the context please generate a question. Context: {{documents}}; Question:"
- name: p1
params:
model_name_or_path: pmodel_openai
@@ -1104,9 +1084,7 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
class TestTokenLimit:
@pytest.mark.integration
def test_hf_token_limit_warning(self, caplog):
prompt_template = PromptTemplate(
name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
)
prompt_template = PromptTemplate("Repeating text" * 200 + "Docs: {documents}; Answer:")
with caplog.at_level(logging.WARNING):
node = PromptNode("google/flan-t5-small", devices=["cpu"])
node.prompt(prompt_template, documents=["Berlin is an amazing city."])
@@ -1119,7 +1097,7 @@ class TestTokenLimit:
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_token_limit_warning(self, caplog):
tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:")
tt = PromptTemplate("Repeating text" * 200 + "Docs: {documents}; Answer:")
prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
with caplog.at_level(logging.WARNING):
_ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
@@ -1133,7 +1111,10 @@ class TestRunBatch:
def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
skip_test_for_invalid_key(prompt_model)

node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
node = PromptNode(
prompt_model,
default_prompt_template="Please give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:",
)

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@@ -1151,7 +1132,11 @@ class TestRunBatch:
def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
skip_test_for_invalid_key(prompt_model)

node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
node = PromptNode(
prompt_model,
default_prompt_template="Please give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:",
output_variable="out",
)

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
@@ -1174,8 +1159,7 @@ class TestRunBatch:
skip_test_for_invalid_key(prompt_model)

prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {documents}; Question: {query}; Answer:"
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)


@@ -1,7 +1,9 @@
from typing import Set, Type, List
from unittest.mock import patch
import textwrap
from unittest.mock import patch, MagicMock

import pytest
import prompthub

from haystack.nodes.prompt import PromptTemplate
from haystack.nodes.prompt.prompt_node import PromptNode
@@ -11,29 +13,81 @@ from haystack.pipelines.base import Pipeline
from haystack.schema import Answer, Document


def mock_prompthub():
with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [
("deepset/test-prompt", "This is a test prompt. Use your knowledge to answer this question: {question}")
]
yield mock_prompthub


@pytest.mark.unit
def test_prompt_templates():
p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
def test_prompt_templates_from_hub():
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mock_prompthub:
PromptTemplate("deepset/question-answering")
mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)


@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
path = tmp_path / "test-prompt.yml"
with open(path, "a") as yamlfile:
yamlfile.write(
textwrap.dedent(
"""
name: deepset/question-answering
prompt_text: |
Given the context please answer the question. Context: {join(documents)};
Question: {query};
Answer:
description: A simple prompt to answer a question given a set of documents
tags:
- question-answering
meta:
authors:
- vblagoje
version: v0.1.1
"""
)
)
p = PromptTemplate(str(path.absolute()))
assert p.name == "deepset/question-answering"
assert "Given the context please answer the question" in p.prompt_text


@pytest.mark.unit
def test_prompt_templates_on_the_fly():
with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml:
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mocked_ph:
p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
assert p.name == "custom-at-query-time"
mocked_ph.fetch.assert_not_called()
mocked_yaml.safe_load.assert_not_called()


@pytest.mark.unit
def test_custom_prompt_templates():
p = PromptTemplate("Here is some fake template with variable {foo}")
assert set(p.prompt_params) == {"foo"}

p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
p = PromptTemplate("Here is some fake template with variable {foo} and {bar}")
assert set(p.prompt_params) == {"foo", "bar"}

p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
p = PromptTemplate("Here is some fake template with variable {foo1} and {bar2}")
assert set(p.prompt_params) == {"foo1", "bar2"}

p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
p = PromptTemplate("Here is some fake template with variable {foo_1} and {bar_2}")
assert set(p.prompt_params) == {"foo_1", "bar_2"}

p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
p = PromptTemplate("Here is some fake template with variable {Foo_1} and {Bar_2}")
assert set(p.prompt_params) == {"Foo_1", "Bar_2"}

p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
p = PromptTemplate("'Here is some fake template with variable {baz}'")
assert set(p.prompt_params) == {"baz"}
# strip single quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable {baz}"

p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
p = PromptTemplate('"Here is some fake template with variable {baz}"')
assert set(p.prompt_params) == {"baz"}
# strip double quotes, happens in YAML as we need to use single quotes for the template string
assert p.prompt_text == "Here is some fake template with variable {baz}"
@@ -41,7 +95,7 @@ def test_prompt_templates():

@pytest.mark.unit
def test_missing_prompt_template_params():
template = PromptTemplate("missing_params", "Here is some fake template with variable {foo} and {bar}")
template = PromptTemplate("Here is some fake template with variable {foo} and {bar}")

# both params provided - ok
template.prepare(foo="foo", bar="bar")
@@ -62,8 +116,10 @@ def test_missing_prompt_template_params():

@pytest.mark.unit
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable {baz}")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
p = PromptTemplate("Here is variable {baz}")
desired_repr = (
"PromptTemplate(name=custom-at-query-time, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
)
assert repr(p) == desired_repr
assert str(p) == desired_repr

@@ -72,8 +128,7 @@ def test_prompt_template_repr():
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_template_deserialization(mock_prompt_model):
custom_prompt_template = PromptTemplate(
name="custom-question-answering",
prompt_text="Given the context please answer the question. Context: {context}; Question: {query}; Answer:",
"Given the context please answer the question. Context: {context}; Question: {query}; Answer:",
output_parser=AnswerParser(),
)

@@ -88,7 +143,6 @@ def test_prompt_template_deserialization(mock_prompt_model):
loaded_generator = loaded_pipe.get_node("Generator")
assert isinstance(loaded_generator, PromptNode)
assert isinstance(loaded_generator.default_prompt_template, PromptTemplate)
assert loaded_generator.default_prompt_template.name == "custom-question-answering"
assert (
loaded_generator.default_prompt_template.prompt_text
== "Given the context please answer the question. Context: {context}; Question: {query}; Answer:"
@@ -135,7 +189,7 @@ class TestPromptTemplateSyntax:
def test_prompt_template_syntax_parser(
self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
assert set(prompt_template.prompt_params) == expected_prompt_params
assert set(prompt_template._used_functions) == expected_used_functions

@@ -210,13 +264,13 @@ class TestPromptTemplateSyntax:
"how?",
["context: doc1 d question: what!"],
),
("context", None, None, ["context"]),
("context: ", None, None, ["context: "]),
],
)
def test_prompt_template_syntax_fill(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts

@@ -243,7 +297,7 @@ class TestPromptTemplateSyntax:
],
)
def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts

@@ -276,7 +330,7 @@ class TestPromptTemplateSyntax:
],
)
def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
assert prompts == expected_prompts

@@ -300,7 +354,7 @@ class TestPromptTemplateSyntax:
self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
):
with pytest.raises(exc_type, match=expected_exc_match):
PromptTemplate(name="test", prompt_text=prompt_text)
PromptTemplate(prompt_text)

@pytest.mark.unit
@pytest.mark.parametrize(
@@ -316,7 +370,7 @@ class TestPromptTemplateSyntax:
expected_exc_match: str,
):
with pytest.raises(exc_type, match=expected_exc_match):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
next(prompt_template.fill(documents=documents, query=query))

@pytest.mark.unit
@@ -337,6 +391,6 @@ class TestPromptTemplateSyntax:
def test_prompt_template_syntax_fill_ignores_dangerous_input(
self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
):
prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts
