Mirror of https://github.com/deepset-ai/haystack.git
feat: Expand LLM support with PromptModel, PromptNode, and PromptTemplate (#3667)
Co-authored-by: ZanSara <sarazanzo94@gmail.com>
This commit is contained in:
parent 559f6e0569, commit 9ebf164cfd
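For orientation, here is a minimal usage sketch of the API this commit introduces (PromptTemplate, PromptModel, PromptNode), assembled from the classes and predefined templates in the diff below; the model checkpoint and the example inputs are illustrative only and not part of the commit:

```python
from haystack.nodes.prompt import PromptModel, PromptNode, PromptTemplate

# One PromptModel can back several PromptNodes, so the underlying model
# (a local FLAN-T5 checkpoint here) is loaded only once.
model = PromptModel(model_name_or_path="google/flan-t5-base")
node = PromptNode(model_name_or_path=model, default_prompt_template="question-answering")

# Prompt with one of the predefined templates registered in prompt_node.py.
answers = node.prompt(
    "question-answering",
    documents=["Berlin is the capital of Germany."],
    questions=["What is the capital of Germany?"],
)

# Or register a custom template and use it by name afterwards.
PromptNode.add_prompt_template(
    PromptTemplate(
        name="sentiment-analysis-custom",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: $documents; Answer:",
        prompt_params=["documents"],
    )
)
sentiments = node.prompt("sentiment-analysis-custom", documents=["I love Berlin."])
```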
@@ -23,6 +23,7 @@ from haystack.nodes.file_converter import (
from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
1 haystack/nodes/prompt/__init__.py Normal file
@@ -0,0 +1 @@
from haystack.nodes.prompt.prompt_node import PromptNode, PromptTemplate, PromptModel
822 haystack/nodes/prompt/prompt_node.py Normal file
@@ -0,0 +1,822 @@
import json
import logging
import re
from abc import ABC, abstractmethod
from string import Template
from typing import Dict, List, Optional, Tuple, Union, Any, Type

import requests
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM

from haystack import MultiLabel
from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes.base import BaseComponent
from haystack.schema import Document

logger = logging.getLogger(__name__)


class BasePromptTemplate(BaseComponent):

    outgoing_edges = 1

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
    ) -> Tuple[Dict, str]:
        raise NotImplementedError("This method should never be implemented in the derived class")

    def run_batch(
        self,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ):
        raise NotImplementedError("This method should never be implemented in the derived class")

class PromptTemplate(BasePromptTemplate, ABC):
    """
    PromptTemplate represents a template for a prompt. For example, a prompt template for the sentiment
    analysis task might look like this:

    ```python
    PromptTemplate(name="sentiment-analysis",
                   prompt_text="Please give a sentiment for this context. Answer with positive, negative
                   or neutral. Context: $documents; Answer:",
                   prompt_params=["documents"])
    ```

    PromptTemplate declares prompt_params, which are the input parameters that need to be filled in the prompt_text.
    For example, in the above example, the prompt_params are ["documents"] and the prompt_text is
    "Please give a sentiment..."

    The prompt_text contains a placeholder $documents. This variable is filled in at runtime with the non-keyword
    or keyword argument `documents` passed to this PromptTemplate's fill() method.
    """

    def __init__(self, name: str, prompt_text: str, prompt_params: Optional[List[str]] = None):
        super().__init__()
        if not prompt_params:
            # Define the regex pattern to match the strings after the $ character
            pattern = r"\$([a-zA-Z0-9_]+)"
            prompt_params = re.findall(pattern, prompt_text)

        if prompt_text.count("$") != len(prompt_params):
            raise ValueError(
                f"Number of parameters in prompt text {prompt_text} for prompt template {name} "
                f"does not match number of specified parameters {prompt_params}"
            )

        # use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
        prompt_text = prompt_text.strip("'").strip('"')

        t = Template(prompt_text)
        try:
            t.substitute(**{param: "" for param in prompt_params})
        except KeyError as e:
            raise ValueError(
                f"Invalid parameter {e} in prompt text "
                f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}"
            )

        self.name = name
        self.prompt_text = prompt_text
        self.prompt_params = prompt_params

    def fill(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Prepares the prompt text parameters from non-keyword and keyword arguments.

        In the case of non-keyword arguments, the order of the arguments should match the left-to-right
        order of appearance of the parameters in the prompt text. For example, if the prompt text is:
        `Please come up with a question for the given context and the answer. Context: $documents;
        Answer: $answers; Question:` then the first non-keyword argument will fill the $documents placeholder
        and the second non-keyword argument will fill the $answers placeholder.

        In the case of keyword arguments, the order of the arguments does not matter. Placeholders in the
        prompt text are filled with the corresponding keyword argument.

        :param args: non-keyword arguments to use for filling the prompt text
        :param kwargs: keyword arguments to use for filling the prompt text
        :return: a dictionary with the prompt text and the prompt parameters
        """
        template_dict = {}
        # attempt to resolve args first
        if args:
            if len(args) != len(self.prompt_params):
                logger.warning(
                    f"For {self.name}, expected {self.prompt_params} arguments, instead "
                    f"got {len(args)} arguments {args}"
                )
            for prompt_param, arg in zip(self.prompt_params, args):
                template_dict[prompt_param] = [arg] if isinstance(arg, str) else arg
        # then attempt to resolve kwargs
        if kwargs:
            for param in self.prompt_params:
                if param in kwargs:
                    template_dict[param] = kwargs[param]

        if set(template_dict.keys()) != set(self.prompt_params):
            available_params = set(list(template_dict.keys()) + list(set(kwargs.keys())))
            raise ValueError(f"Expected prompt params {self.prompt_params} but got {list(available_params)}")

        template_dict["prompt_template"] = self.prompt_text
        return template_dict

PREDEFINED_PROMPT_TEMPLATES = [
    PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: $questions; Answer:",
        prompt_params=["documents", "questions"],
    ),
    PromptTemplate(
        name="question-generation",
        prompt_text="Given the context please generate a question. Context: $documents; Question:",
        prompt_params=["documents"],
    ),
    PromptTemplate(
        name="conditioned-question-generation",
        prompt_text="Please come up with a question for the given context and the answer. "
        "Context: $documents; Answer: $answers; Question:",
        prompt_params=["documents", "answers"],
    ),
    PromptTemplate(
        name="summarization", prompt_text="Summarize this document: $documents Summary:", prompt_params=["documents"]
    ),
    PromptTemplate(
        name="question-answering-check",
        prompt_text="Does the following context contain the answer to the question. "
        "Context: $documents; Question: $questions; Please answer yes or no! Answer:",
        prompt_params=["documents", "questions"],
    ),
    PromptTemplate(
        name="sentiment-analysis",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: $documents; Answer:",
        prompt_params=["documents"],
    ),
    PromptTemplate(
        name="multiple-choice-question-answering",
        prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
        "Options: $options; Answer:",
        prompt_params=["questions", "options"],
    ),
    PromptTemplate(
        name="topic-classification",
        prompt_text="Categories: $options; What category best describes: $documents; Answer:",
        prompt_params=["documents", "options"],
    ),
    PromptTemplate(
        name="language-detection",
        prompt_text="Detect the language in the following context and answer with the "
        "name of the language. Context: $documents; Answer:",
    ),
    PromptTemplate(
        name="translation",
        prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
    ),
]

class PromptModelInvocationLayer:
    """
    PromptModelInvocationLayer implementations execute a prompt on an underlying model.

    The implementation can be a simple invocation on the underlying model running in a local runtime, or
    could be even remote, for example, a call to a remote API endpoint.
    """

    def __init__(self, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
        if model_name_or_path is None or len(model_name_or_path) == 0:
            raise ValueError("model_name_or_path cannot be None or empty string")

        self.model_name_or_path = model_name_or_path
        self.max_length: Optional[int] = max_length

    @abstractmethod
    def invoke(self, *args, **kwargs):
        pass

    @classmethod
    def supports(cls, model_name_or_path: str) -> bool:
        return False

class HFLocalInvocationLayer(PromptModelInvocationLayer):
    """
    A subclass of the PromptModelInvocationLayer class. It loads a pre-trained model from Hugging Face and
    passes a prepared prompt into that model.

    Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
    as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
    """

    def __init__(
        self,
        model_name_or_path: str = "google/flan-t5-base",
        max_length: Optional[int] = 100,
        use_auth_token: Optional[Union[str, bool]] = None,
        use_gpu: Optional[bool] = True,
        devices: Optional[List[Union[str, torch.device]]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, max_length)
        self.use_auth_token = use_auth_token

        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
        if len(self.devices) > 1:
            logger.warning(
                f"Multiple devices are not supported in {self.__class__.__name__} inference, "
                f"using the first device {self.devices[0]}."
            )

        # Due to reflective construction of all invocation layers we might receive some
        # unknown kwargs, so we need to take only the relevant.
        # For more details refer to Hugging Face pipeline documentation
        # Do not use `device_map` AND `device` at the same time as they will conflict
        model_input_kwargs = {
            key: kwargs[key]
            for key in [
                "model_kwargs",
                "trust_remote_code",
                "revision",
                "feature_extractor",
                "tokenizer",
                "config",
                "use_fast",
                "torch_dtype",
                "device_map",
            ]
            if key in kwargs
        }
        # flatten model_kwargs one level
        if "model_kwargs" in model_input_kwargs:
            mkwargs = model_input_kwargs.pop("model_kwargs")
            model_input_kwargs.update(mkwargs)

        torch_dtype = model_input_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            if isinstance(torch_dtype, str):
                if "torch." not in torch_dtype:
                    raise ValueError(
                        f"torch_dtype should be a torch.dtype or a string with 'torch.' prefix, got {torch_dtype}"
                    )
                torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
            elif isinstance(torch_dtype, torch.dtype):
                torch_dtype_resolved = torch_dtype
            else:
                raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
            model_input_kwargs["torch_dtype"] = torch_dtype_resolved

        if len(model_input_kwargs) > 0:
            logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)

        self.pipe = pipeline(
            "text2text-generation",
            model=model_name_or_path,
            device=self.devices[0] if "device_map" not in model_input_kwargs else None,
            use_auth_token=self.use_auth_token,
            model_kwargs=model_input_kwargs,
        )

    def invoke(self, *args, **kwargs):
        """
        It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model.
        :return: A list of generated texts.
        """
        output = []
        if kwargs and "prompt" in kwargs:
            prompt = kwargs.pop("prompt")

            # We might have some uncleaned kwargs, so we need to take only the relevant.
            # For more details refer to Hugging Face Text2TextGenerationPipeline documentation
            model_input_kwargs = {
                key: kwargs[key]
                for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"]
                if key in kwargs
            }
            output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs)
        return [o["generated_text"] for o in output]

    @classmethod
    def supports(cls, model_name_or_path: str) -> bool:
        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
            return False

        try:
            # if it is google flan t5, load it; we'll use it anyway and also check if the model loads correctly
            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        except EnvironmentError:
            return False
        return True

class OpenAIInvocationLayer(PromptModelInvocationLayer):
    """
    PromptModelInvocationLayer implementation for OpenAI's GPT-3 InstructGPT models. Invocations are made via REST API.
    See [OpenAI GPT-3](https://beta.openai.com/docs/models/gpt-3) for more details.

    Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
    as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
    """

    def __init__(
        self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs
    ):
        super().__init__(model_name_or_path, max_length)
        if not isinstance(api_key, str) or len(api_key) == 0:
            raise OpenAIError(
                f"api_key {api_key} has to be a valid OpenAI key. Please visit https://beta.openai.com/ to get one."
            )
        self.api_key = api_key
        self.url = "https://api.openai.com/v1/completions"

        # Due to reflective construction of all invocation layers we might receive some
        # unknown kwargs, so we need to take only the relevant.
        # For more details refer to OpenAI documentation
        self.model_input_kwargs = {
            key: kwargs[key]
            for key in [
                "suffix",
                "max_tokens",
                "temperature",
                "top_p",
                "n",
                "logprobs",
                "echo",
                "stop",
                "presence_penalty",
                "frequency_penalty",
                "best_of",
                "logit_bias",
            ]
            if key in kwargs
        }

    def invoke(self, *args, **kwargs):
        """
        Invokes a prompt on the model. It takes in a prompt, and returns a list of responses using a REST invocation.

        :return: A list of model responses.
        """
        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError(
                f"No prompt provided. Model {self.model_name_or_path} requires a prompt. "
                f"Make sure to provide a prompt in kwargs."
            )

        kwargs_with_defaults = self.model_input_kwargs
        if kwargs:
            kwargs_with_defaults.update(kwargs)
        payload = {
            "model": self.model_name_or_path,
            "prompt": prompt,
            "suffix": kwargs_with_defaults.get("suffix", None),
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "temperature": kwargs_with_defaults.get("temperature", 0.7),
            "top_p": kwargs_with_defaults.get("top_p", 1),
            "n": kwargs_with_defaults.get("n", 1),
            "stream": False,  # no support for streaming
            "logprobs": kwargs_with_defaults.get("logprobs", None),
            "echo": kwargs_with_defaults.get("echo", False),
            "stop": kwargs_with_defaults.get("stop", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
            "best_of": kwargs.get("best_of", 1),
            "logit_bias": kwargs.get("logit_bias", {}),
        }
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
        res = json.loads(response.text)

        if response.status_code != 200:
            openai_error: OpenAIError
            if response.status_code == 429:
                openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
            else:
                openai_error = OpenAIError(
                    f"OpenAI returned an error.\n"
                    f"Status code: {response.status_code}\n"
                    f"Response body: {response.text}",
                    status_code=response.status_code,
                )
            raise openai_error

        responses = [ans["text"].strip() for ans in res["choices"]]
        return responses

    @classmethod
    def supports(cls, model_name_or_path: str) -> bool:
        return any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)

class PromptModel(BaseComponent):
    """
    The PromptModel class is a component that uses a pre-trained model to generate text based on a prompt. Out of
    the box, it supports two model invocation layers: Hugging Face transformers and OpenAI, with the ability to
    register additional custom invocation layers.

    Although it is possible to use PromptModel to make prompt invocations on the underlying model, please use
    PromptNode for interactions with the model. PromptModel instances are the practical approach for multiple
    PromptNode instances to use a single PromptModel and thus save computational resources.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "google/flan-t5-base",
        max_length: Optional[int] = 100,
        api_key: Optional[str] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        use_gpu: Optional[bool] = None,
        devices: Optional[List[Union[str, torch.device]]] = None,
        model_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self.api_key = api_key
        self.use_auth_token = use_auth_token
        self.use_gpu = use_gpu
        self.devices = devices

        self.model_kwargs = model_kwargs if model_kwargs else {}

        self.invocation_layers: List[Type[PromptModelInvocationLayer]] = []

        self.register(HFLocalInvocationLayer)  # pylint: disable=W0108
        self.register(OpenAIInvocationLayer)  # pylint: disable=W0108

        self.model_invocation_layer = self.create_invocation_layer()

    def create_invocation_layer(self) -> PromptModelInvocationLayer:
        kwargs = {
            "api_key": self.api_key,
            "use_auth_token": self.use_auth_token,
            "use_gpu": self.use_gpu,
            "devices": self.devices,
        }
        all_kwargs = {**self.model_kwargs, **kwargs}

        for invocation_layer in self.invocation_layers:
            if invocation_layer.supports(self.model_name_or_path):
                return invocation_layer(
                    model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                )
        raise ValueError(
            f"Model {self.model_name_or_path} is not supported - no invocation layer found. "
            f"Currently supported models are: {self.invocation_layers}. "
            f"Register a new invocation layer for {self.model_name_or_path} using the register method."
        )

    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
        """
        Registers an additional prompt model invocation layer: a class that implements the
        `PromptModelInvocationLayer` interface and decides via its `supports()` classmethod whether it can
        handle a given `model_name_or_path`.
        """
        self.invocation_layers.append(invocation_layer)

    def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
        """
        It takes in a prompt, and returns a list of responses using the underlying invocation layer.

        :param prompt: The prompt to use for the invocation, it could be a single prompt or a list of prompts
        :param kwargs: Additional keyword arguments to pass to the invocation layer
        :return: A list of model generated responses for the prompt or prompts
        """
        output = self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
        return output

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
    ) -> Tuple[Dict, str]:
        raise NotImplementedError("This method should never be implemented in the derived class")

    def run_batch(
        self,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ):
        raise NotImplementedError("This method should never be implemented in the derived class")

class PromptNode(BaseComponent):
    """
    The PromptNode class is the central abstraction in Haystack's large language model (LLM) support. PromptNode
    supports multiple NLP tasks out of the box. PromptNode allows users to perform multiple tasks, such as
    summarization, question answering, question generation etc., using a single, unified model within the Haystack
    framework.

    One of the benefits of PromptNode is that it allows users to define and add additional prompt templates
    that the model supports. Defining additional prompt templates enables users to extend the model's capabilities
    and use it for a broader range of NLP tasks within the Haystack ecosystem. Prompt engineers define templates
    for each NLP task and register them with PromptNode. The burden of defining templates for each task rests on
    the prompt engineers, not the users.

    Using an instance of PromptModel class, we can create multiple PromptNodes that share the same model, saving
    the memory and time required to load the model multiple times.

    PromptNode also supports multiple model invocation layers: Hugging Face transformers and OpenAI with an
    ability to register additional custom invocation layers.
    """

    outgoing_edges: int = 1
    prompt_templates: Dict[str, PromptTemplate] = {
        prompt_template.name: prompt_template for prompt_template in PREDEFINED_PROMPT_TEMPLATES  # type: ignore
    }

    def __init__(
        self,
        model_name_or_path: Union[str, PromptModel] = "google/flan-t5-base",
        default_prompt_template: Optional[Union[str, PromptTemplate]] = None,
        output_variable: Optional[str] = None,
        max_length: Optional[int] = 100,
        api_key: Optional[str] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        use_gpu: Optional[bool] = None,
        devices: Optional[List[Union[str, torch.device]]] = None,
    ):
        super().__init__()
        self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
        self.output_variable: Optional[str] = output_variable
        self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
        self.prompt_model: PromptModel
        if isinstance(self.default_prompt_template, str) and not self.is_supported_template(
            self.default_prompt_template
        ):
            raise ValueError(
                f"Prompt template {self.default_prompt_template} is not supported. "
                f"Select one of: {self.get_prompt_template_names()} "
                f"or first register a new prompt template using the add_prompt_template method."
            )

        if isinstance(model_name_or_path, str):
            self.prompt_model = PromptModel(
                model_name_or_path=model_name_or_path,
                max_length=max_length,
                api_key=api_key,
                use_auth_token=use_auth_token,
                use_gpu=use_gpu,
                devices=devices,
            )
        elif isinstance(model_name_or_path, PromptModel):
            self.prompt_model = model_name_or_path
        else:
            raise ValueError(f"model_name_or_path must be either a string or a PromptModel object")

    def __call__(self, *args, **kwargs) -> List[str]:
        """
        This method is invoked when the component is called directly, for example:
        ```python
        pn = PromptNode(...)
        sa = pn.set_default_prompt_template("sentiment-analysis")
        sa(documents=[Document("I am in love and I feel great!")])
        ```
        """
        if "prompt_template_name" in kwargs:
            prompt_template_name = kwargs["prompt_template_name"]
            kwargs.pop("prompt_template_name")
            return self.prompt(prompt_template_name, *args, **kwargs)
        else:
            return self.prompt(self.default_prompt_template, *args, **kwargs)

    def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
        """
        Prompts the model and represents the central API for the PromptNode. It takes a prompt template,
        a list of non-keyword and keyword arguments, and returns a list of strings - the responses from
        the underlying model.

        The optional prompt_template parameter, if specified, takes precedence over the default prompt
        template for this PromptNode.

        :param prompt_template: The name of the optional prompt template to use
        :return: A list of strings as model responses
        """
        results = []
        prompt_prepared: Dict[str, Any] = {}
        if isinstance(prompt_template, str) and not self.is_supported_template(prompt_template):
            raise ValueError(
                f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()} "
                f"or pass a PromptTemplate instance for prompting."
            )

        invoke_template = self.default_prompt_template if prompt_template is None else prompt_template
        if args and invoke_template is None:
            # create straightforward prompt on the input, no templates used
            prompt_prepared["prompt"] = list(args)
        else:
            template_to_fill: PromptTemplate
            if isinstance(prompt_template, PromptTemplate):
                template_to_fill = prompt_template
            elif isinstance(prompt_template, str):
                template_to_fill = self.get_prompt_template(prompt_template)
            else:
                raise ValueError(f"{prompt_template} with args {args}, and kwargs {kwargs} not supported")
            # we have potentially args and kwargs; task selected, so templating is needed
            prompt_prepared = template_to_fill.fill(*args, **kwargs)

        # straightforward prompt, no templates used
        if "prompt" in prompt_prepared:
            for prompt in prompt_prepared["prompt"]:
                output = self.prompt_model.invoke(prompt)
                for item in output:
                    results.append(item)
        # templated prompt
        # we have a prompt dictionary with prompt_template text and key/value pairs for template variables
        # where key is the variable name and value is a list of variable values
        # we invoke the model iterating through a list of prompt variable values replacing the variables
        # in the prompt template
        elif "prompt_template" in prompt_prepared:
            template = Template(prompt_prepared["prompt_template"])
            prompt_context_copy = prompt_prepared.copy()
            prompt_context_copy.pop("prompt_template")
            for prompt_context_values in zip(*prompt_context_copy.values()):
                template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
                template_prepared: str = template.substitute(template_input)
                # remove template keys from kwargs so we don't pass them to the model
                removed_keys = [kwargs.pop(key) for key in template_input.keys() if key in kwargs]
                output = self.prompt_model.invoke(template_prepared, **kwargs)
                for item in output:
                    results.append(item)
        return results

    @classmethod
    def add_prompt_template(cls, prompt_template: PromptTemplate) -> None:
        """
        Adds a prompt template to the list of supported prompt templates.
        :param prompt_template: PromptTemplate object to be added.
        :return: None
        """
        if prompt_template.name in cls.prompt_templates:
            raise ValueError(
                f"Prompt template {prompt_template.name} already exists. "
                f"Please select a different name to add this prompt template."
            )

        cls.prompt_templates[prompt_template.name] = prompt_template  # type: ignore

    @classmethod
    def remove_prompt_template(cls, prompt_template: str) -> PromptTemplate:
        """
        Removes a prompt template from the list of supported prompt templates.
        :param prompt_template: Name of the prompt template to be removed.
        :return: PromptTemplate object that was removed.
        """
        if prompt_template in [template.name for template in PREDEFINED_PROMPT_TEMPLATES]:
            raise ValueError(f"Cannot remove predefined prompt template {prompt_template}")
        if prompt_template not in cls.prompt_templates:
            raise ValueError(f"Prompt template {prompt_template} does not exist")

        return cls.prompt_templates.pop(prompt_template)

    def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode":
        """
        Sets the default prompt template for the node.
        :param prompt_template: the prompt template to be set as default.
        :return: the current PromptNode object
        """
        if not self.is_supported_template(prompt_template):
            raise ValueError(
                f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()}"
            )

        self.default_prompt_template = prompt_template
        return self

    @classmethod
    def get_prompt_templates(cls) -> List[PromptTemplate]:
        """
        Returns the list of supported prompt templates.
        :return: List of supported prompt templates.
        """
        return list(cls.prompt_templates.values())

    @classmethod
    def get_prompt_template_names(cls) -> List[str]:
        """
        Returns the list of supported prompt template names.
        :return: List of supported prompt template names.
        """
        return list(cls.prompt_templates.keys())

    @classmethod
    def is_supported_template(cls, prompt_template: Union[str, PromptTemplate]) -> bool:
        """
        Checks if a prompt template is supported.
        :param prompt_template: the prompt template to be checked.
        :return: True if the prompt template is supported, False otherwise.
        """
        template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
        return template_name in cls.prompt_templates

    @classmethod
    def get_prompt_template(cls, prompt_template_name: str) -> PromptTemplate:
        """
        Returns a prompt template by name.
        :param prompt_template_name: the name of the prompt template to be returned.
        :return: the prompt template object.
        """
        if prompt_template_name not in cls.prompt_templates:
            raise ValueError(f"Prompt template {prompt_template_name} not supported")
        return cls.prompt_templates[prompt_template_name]

    @classmethod
    def prompt_template_params(cls, prompt_template: str) -> List[str]:
        """
        Returns the list of parameters for a prompt template.
        :param prompt_template: the name of the prompt template.
        :return: the list of parameters for the prompt template.
        """
        if not cls.is_supported_template(prompt_template):
            raise ValueError(
                f"{prompt_template} not supported, please select one of: {cls.get_prompt_template_names()}"
            )

        return list(cls.prompt_templates[prompt_template].prompt_params)

    def __eq__(self, other):
        if isinstance(other, PromptNode):
            if self.default_prompt_template != other.default_prompt_template:
                return False
            return self.model_name_or_path == other.model_name_or_path
        return False

    def __hash__(self):
        return hash((self.default_prompt_template, self.model_name_or_path))

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
    ) -> Tuple[Dict, str]:
        """
        Runs the prompt node on these input parameters. Returns the output of the prompt model.
        Parameters file_paths, labels, and meta are usually ignored.

        :param query: the query is usually ignored by the prompt node unless it is used as a parameter in the
        prompt template.
        :param file_paths: the file paths are usually ignored by the prompt node unless they are used as a parameter
        in the prompt template.
        :param labels: the labels are usually ignored by the prompt node unless they are used as a parameter in the
        prompt template.
        :param documents: the documents to be used for the prompt.
        :param meta: the meta to be used for the prompt. Usually not used.
        """

        if not meta:
            meta = {}
        # invocation_context is a dictionary that is passed from one pipeline node to the next and can be used
        # to pass results from a pipeline node to any other downstream pipeline node.
        if "invocation_context" not in meta:
            meta["invocation_context"] = {}

        results = self(
            query=query,
            labels=labels,
            documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
            **meta["invocation_context"],
        )

        if self.output_variable:
            meta["invocation_context"][self.output_variable] = results
        return {"results": results, "meta": {**meta}}, "output_1"

    def run_batch(
        self,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ):
        pass
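PromptModel resolves its invocation layer reflectively: every registered PromptModelInvocationLayer reports via supports() whether it can handle a given model_name_or_path, and create_invocation_layer() instantiates the first match. A minimal sketch of that extension point follows; the EchoInvocationLayer class and the "my-echo-model" name are hypothetical, for illustration only:

```python
from typing import List

from haystack.nodes.prompt.prompt_node import PromptModel, PromptModelInvocationLayer


class EchoInvocationLayer(PromptModelInvocationLayer):
    """Hypothetical layer that simply echoes the prompt back, to show the required contract."""

    def invoke(self, *args, **kwargs) -> List[str]:
        prompt = kwargs.get("prompt", "")
        return [f"echo: {prompt}"]

    @classmethod
    def supports(cls, model_name_or_path: str) -> bool:
        # Claim only our made-up model name so the built-in HF and OpenAI layers are unaffected.
        return model_name_or_path == "my-echo-model"


# register() appends the class to an existing PromptModel's invocation_layers.
# Note that in this version the layer is chosen once in __init__, so a newly
# registered layer only takes effect when create_invocation_layer() runs again.
prompt_model = PromptModel("google/flan-t5-base")
prompt_model.register(EchoInvocationLayer)
```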
@@ -98,7 +98,8 @@ def read_pipeline_config_from_yaml(path: Path) -> Dict[str, Any]:
        return yaml.safe_load(stream)


JSON_FIELDS = ["custom_query"]  # ElasticsearchDocumentStore.custom_query
JSON_FIELDS = ["custom_query"]
SKIP_VALIDATION_KEYS = ["prompt_text"]  # PromptTemplate, PromptNode


def validate_config_strings(pipeline_config: Any, is_value: bool = False):
@@ -123,6 +124,8 @@ def validate_config_strings(pipeline_config: Any, is_value: bool = False):
                    json.loads(value)
                except json.decoder.JSONDecodeError as e:
                    raise PipelineConfigError(f"'{pipeline_config}' does not contain valid JSON.")
            elif key in SKIP_VALIDATION_KEYS:
                continue
            else:
                validate_config_strings(key)
                validate_config_strings(value, is_value=True)
@@ -64,6 +64,7 @@ from haystack.nodes import (
    QuestionGenerator,
)
from haystack.modeling.infer import Inferencer, QAInferencer
from haystack.nodes.prompt import PromptNode, PromptModel
from haystack.schema import Document
from haystack.utils.import_utils import _optional_component_not_installed

@@ -1048,3 +1049,19 @@ def bert_base_squad2(request):
        use_fast=True,  # TODO parametrize this to test slow as well
    )
    return model


@pytest.fixture
def prompt_node():
    return PromptNode("google/flan-t5-small", devices=["cpu"])


@pytest.fixture
def prompt_model(request):
    if request.param == "openai":
        api_key = os.environ.get("OPENAI_API_KEY", "KEY_NOT_FOUND")
        if api_key is None or api_key == "":
            api_key = "KEY_NOT_FOUND"
        return PromptModel("text-davinci-003", api_key=api_key)
    else:
        return PromptModel("google/flan-t5-base", devices=["cpu"])
477 test/nodes/test_prompt_node.py Normal file
@@ -0,0 +1,477 @@
import os

import pytest
import torch

from haystack import Document, Pipeline
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel


def is_openai_api_key_set(api_key: str):
    return len(api_key) > 0 and api_key != "KEY_NOT_FOUND"


def test_prompt_templates():
    p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"])

    with pytest.raises(ValueError, match="Number of parameters in"):
        PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo"])

    with pytest.raises(ValueError, match="Invalid parameter"):
        PromptTemplate("t2", "Here is some fake template with variable $footur", ["foo"])

    with pytest.raises(ValueError, match="Number of parameters in"):
        PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo", "bar", "baz"])

    p = PromptTemplate("t3", "Here is some fake template with variable $for and $bar", ["for", "bar"])

    # last parameter: "prompt_params" can be omitted
    p = PromptTemplate("t4", "Here is some fake template with variable $foo and $bar")
    assert p.prompt_params == ["foo", "bar"]

    p = PromptTemplate("t4", "Here is some fake template with variable $foo1 and $bar2")
    assert p.prompt_params == ["foo1", "bar2"]

    p = PromptTemplate("t4", "Here is some fake template with variable $foo_1 and $bar_2")
    assert p.prompt_params == ["foo_1", "bar_2"]

    p = PromptTemplate("t4", "Here is some fake template with variable $Foo_1 and $Bar_2")
    assert p.prompt_params == ["Foo_1", "Bar_2"]

    p = PromptTemplate("t4", "'Here is some fake template with variable $baz'")
    assert p.prompt_params == ["baz"]
    # strip single quotes, happens in YAML as we need to use single quotes for the template string
    assert p.prompt_text == "Here is some fake template with variable $baz"

    p = PromptTemplate("t4", '"Here is some fake template with variable $baz"')
    assert p.prompt_params == ["baz"]
    # strip double quotes, happens in YAML when double quotes are used for the template string
    assert p.prompt_text == "Here is some fake template with variable $baz"

def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"

    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        model = PromptModel("text-davinci-003")

    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map auto without accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
        assert model.model_name_or_path == "google/flan-t5-small"


def test_create_prompt_node():
    prompt_node = PromptNode()
    assert prompt_node is not None
    assert prompt_node.prompt_model is not None

    prompt_node = PromptNode("google/flan-t5-small")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "google/flan-t5-small"
    assert prompt_node.prompt_model is not None

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        prompt_node = PromptNode("text-davinci-003")

    prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "text-davinci-003"
    assert prompt_node.prompt_model is not None

    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
        # currently support only T5-Flan models
        prompt_node = PromptNode("vblagoje/bart_lfqa")

    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
        # currently support only T5-Flan models
        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptNode("some-random-model")

def test_add_and_remove_template(prompt_node):
    num_default_tasks = len(prompt_node.get_prompt_template_names())
    custom_task = PromptTemplate(
        name="custom-task", prompt_text="Custom task: $param1, $param2", prompt_params=["param1", "param2"]
    )
    prompt_node.add_prompt_template(custom_task)
    assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
    assert "custom-task" in prompt_node.get_prompt_template_names()

    assert prompt_node.remove_prompt_template("custom-task") is not None
    assert "custom-task" not in prompt_node.get_prompt_template_names()


def test_invalid_template(prompt_node):
    with pytest.raises(ValueError, match="Invalid parameter"):
        PromptTemplate(
            name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"]
        )

    with pytest.raises(ValueError, match="Number of parameters"):
        PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"])


def test_add_template_and_invoke(prompt_node):
    tt = PromptTemplate(
        name="sentiment-analysis-new",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: $documents; Answer:",
        prompt_params=["documents"],
    )
    prompt_node.add_prompt_template(tt)

    r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"


def test_on_the_fly_prompt(prompt_node):
    tt = PromptTemplate(
        name="sentiment-analysis-temp",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: $documents; Answer:",
        prompt_params=["documents"],
    )
    r = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"


def test_direct_prompting(prompt_node):
    r = prompt_node("What is the capital of Germany?")
    assert r[0].casefold() == "berlin"

    r = prompt_node("What is the capital of Germany?", "What is the secret of universe?")
    assert r[0].casefold() == "berlin"
    assert len(r[1]) > 0

    r = prompt_node("Capital of Germany is Berlin", task="question-generation")
    assert len(r[0]) > 10 and "Germany" in r[0]

    r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation")
    assert len(r) == 2


def test_question_generation(prompt_node):
    r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(r) == 1 and len(r[0]) > 0


def test_template_selection(prompt_node):
    qa = prompt_node.set_default_prompt_template("question-answering")
    r = qa(
        ["Berlin is the capital of Germany.", "Paris is the capital of France."],
        ["What is the capital of Germany?", "What is the capital of France"],
    )
    assert r[0].casefold() == "berlin" and r[1].casefold() == "paris"


def test_has_supported_template_names(prompt_node):
    assert len(prompt_node.get_prompt_template_names()) > 0


def test_invalid_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt params"):
        prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."})


def test_wrong_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt params"):
        # question-answering doesn't have an options param; multiple-choice QA does
        prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."])


def test_run_invalid_template(prompt_node):
    with pytest.raises(ValueError, match="invalid-task not supported"):
        prompt_node.prompt("invalid-task", {})


def test_invalid_prompting(prompt_node):
    with pytest.raises(ValueError, match="Hey there, what is the best city in the worl"):
        prompt_node.prompt(
            "Hey there, what is the best city in the world?" "Hey there, what is the best city in the world?"
        )

    with pytest.raises(ValueError, match="Hey there, what is the best city in the"):
        prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"])


def test_invalid_state_ops(prompt_node):
    with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
        prompt_node.remove_prompt_template("no_such_task_exists")
    # remove default task
    prompt_node.remove_prompt_template("question-answering")

@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_open_ai_prompt_with_params():
    pm = PromptModel("text-davinci-003", api_key=os.environ["OPENAI_API_KEY"])
    pn = PromptNode(pm)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5}
    r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_simple_pipeline(prompt_model):
    if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
        pytest.skip("No API key found for OpenAI, skipping test")

    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert result["results"][0].casefold() == "positive"


@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_complex_pipeline(prompt_model):
    if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
        pytest.skip("No API key found for OpenAI, skipping test")

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="questions")
    node2 = PromptNode(prompt_model, default_prompt_template="question-answering")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])

    assert "berlin" in result["results"][0].casefold()


def test_complex_pipeline_with_shared_model():
    model = PromptModel()
    node = PromptNode(
        model_name_or_path=model, default_prompt_template="question-generation", output_variable="questions"
    )
    node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])

    assert result["results"][0] == "Berlin"

def test_simple_pipeline_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              params:
                default_prompt_template: sentiment-analysis
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert result["results"][0] == "positive"


def test_complex_pipeline_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              params:
                default_prompt_template: question-generation
                output_variable: questions
              type: PromptNode
            - name: p2
              params:
                default_prompt_template: question-answering
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert result["results"][0] == "Berlin"
    assert len(result["meta"]["invocation_context"]) > 0


def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
            - name: p1
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-generation
                output_variable: questions
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert "Berlin" in result["results"][0]
    assert len(result["meta"]["invocation_context"]) > 0

def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
              params:
                model_name_or_path: google/flan-t5-small
                model_kwargs:
                  torch_dtype: torch.bfloat16
            - name: question_generation_template
              type: PromptTemplate
              params:
                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: $documents; Question:"
            - name: p1
              params:
                model_name_or_path: pmodel
                default_prompt_template: question_generation_template
                output_variable: questions
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert "Berlin" in result["results"][0]
    assert len(result["meta"]["invocation_context"]) > 0


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_complex_pipeline_with_all_features(tmp_path):
    api_key = os.environ.get("OPENAI_API_KEY", None)
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
              params:
                model_name_or_path: google/flan-t5-small
                model_kwargs:
                  torch_dtype: torch.bfloat16
            - name: pmodel_openai
              type: PromptModel
              params:
                model_name_or_path: text-davinci-003
                model_kwargs:
                  temperature: 0.9
                  max_tokens: 64
                api_key: {api_key}
            - name: question_generation_template
              type: PromptTemplate
              params:
                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: $documents; Question:"
            - name: p1
              params:
                model_name_or_path: pmodel_openai
                default_prompt_template: question_generation_template
                output_variable: questions
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
    assert "Berlin" in result["results"][0] or "Germany" in result["results"][0]
    assert len(result["meta"]["invocation_context"]) > 0