Simplify PromptTemplate substitution in PromptNode (#3876)

This commit is contained in:
Vladimir Blagojevic 2023-01-18 18:31:15 +01:00 committed by GitHub
parent eb57e1fc09
commit c44d67856e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,7 @@ import logging
import re
from abc import ABC, abstractmethod
from string import Template
from typing import Dict, List, Optional, Tuple, Union, Any, Type
from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
import requests
import torch
@ -103,18 +103,9 @@ class PromptTemplate(BasePromptTemplate, ABC):
self.prompt_text = prompt_text
self.prompt_params = prompt_params
def fill(self, *args, **kwargs) -> Dict[str, Any]:
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
"""
Fills the prompt text parameters from non-keyword and keyword arguments.
In the case of non-keyword arguments, the order of the arguments should match the left-to-right
order of appearance of the parameters in the prompt text. For example, if the prompt text is:
`Please come up with a question for the given context and the answer. Context: $documents;
Answer: $answers; Question:` then the first non-keyword argument fills the $documents placeholder
and the second non-keyword argument fills the $answers placeholder.
In the case of keyword arguments, the order of the arguments doesn't matter. Placeholders in the
prompt text are filled with the corresponding keyword argument.
Prepares and verifies the prompt template with input parameters.
:param args: Non-keyword arguments to use for filling the prompt text.
:param kwargs: Keyword arguments to use for filling the prompt text.
@ -143,9 +134,34 @@ class PromptTemplate(BasePromptTemplate, ABC):
available_params = set(list(template_dict.keys()) + list(set(kwargs.keys())))
raise ValueError(f"Expected prompt params {self.prompt_params} but got {list(available_params)}")
template_dict["prompt_template"] = self.prompt_text
return template_dict
def fill(self, *args, **kwargs) -> Iterator[str]:
"""
Fills the prompt text parameters from non-keyword and keyword arguments and returns the iterator prompt text.
In the case of non-keyword arguments, the order of the arguments should match the left-to-right
order of appearance of the parameters in the prompt text. For example, if the prompt text is:
`Please come up with a question for the given context and the answer. Context: $documents;
Answer: $answers; Question:` then the first non-keyword argument fills the $documents placeholder
and the second non-keyword argument fills the $answers placeholder.
In the case of keyword arguments, the order of the arguments doesn't matter. Placeholders in the
prompt text are filled with the corresponding keyword argument.
:param args: Non-keyword arguments to use for filling the prompt text.
:param kwargs: Keyword arguments to use for filling the prompt text.
:return: An iterator of prompt texts.
"""
template_dict = self.prepare(*args, **kwargs)
template = Template(self.prompt_text)
# the prompt context values should all be lists, as they will be split as one
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in template_dict.items()}
for prompt_context_values in zip(*prompt_context_copy.values()):
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
prompt_prepared: str = template.substitute(template_input)
yield prompt_prepared
class PromptModelInvocationLayer:
"""
@ -710,53 +726,31 @@ class PromptNode(BaseComponent):
:return: A list of strings as model responses.
"""
results = []
prompt_prepared: Dict[str, Any] = {}
if isinstance(prompt_template, str) and not self.is_supported_template(prompt_template):
raise ValueError(
f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()} "
f"or pass a PromptTemplate instance for prompting."
)
invoke_template = self.default_prompt_template if prompt_template is None else prompt_template
if args and invoke_template is None:
# create straightforward prompt on the input, no templates used
prompt_prepared["prompt"] = list(args)
else:
template_to_fill: PromptTemplate
if isinstance(prompt_template, PromptTemplate):
template_to_fill = prompt_template
elif isinstance(prompt_template, str):
template_to_fill = self.get_prompt_template(prompt_template)
prompt_template_used = prompt_template or self.default_prompt_template
if prompt_template_used:
if isinstance(prompt_template_used, PromptTemplate):
template_to_fill = prompt_template_used
elif isinstance(prompt_template_used, str):
template_to_fill = self.get_prompt_template(prompt_template_used)
else:
raise ValueError(f"{prompt_template} with args {args} , and kwargs {kwargs} not supported")
# we have potentially args and kwargs; task selected, so templating is needed
prompt_prepared = template_to_fill.fill(*args, **kwargs)
raise ValueError(f"{prompt_template_used} with args {args} , and kwargs {kwargs} not supported")
# straightforward prompt, no templates used
if "prompt" in prompt_prepared:
for prompt in prompt_prepared["prompt"]:
# prompt template used, yield prompts from inputs args
for prompt in template_to_fill.fill(*args, **kwargs):
# and pass the prepared prompt to the model
output = self.prompt_model.invoke(prompt, **kwargs)
results.extend(output)
else:
# straightforward prompt, no templates used
for prompt in list(args):
output = self.prompt_model.invoke(prompt)
for item in output:
results.append(item)
# templated prompt
# we have a prompt dictionary with prompt_template text and key/value pairs for template variables
# where key is the variable name and value is a list of variable values
# we invoke the model iterating through a list of prompt variable values replacing the variables
# in the prompt template
elif "prompt_template" in prompt_prepared:
template = Template(prompt_prepared["prompt_template"])
prompt_context_copy = prompt_prepared.copy()
prompt_context_copy.pop("prompt_template")
# the prompt context values should all be lists, as they will be split as one
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in prompt_context_copy.items()}
for prompt_context_values in zip(*prompt_context_copy.values()):
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
template_prepared: str = template.substitute(template_input)
# remove template keys from kwargs so we don't pass them to the model
kwargs = {key: value for key, value in kwargs.items() if key not in template_input.keys()}
output = self.prompt_model.invoke(template_prepared, **kwargs)
for item in output:
results.append(item)
results.extend(output)
return results
def add_prompt_template(self, prompt_template: PromptTemplate) -> None: