From c44d67856ed1ab47732a48703d2e74462d18b296 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Jan 2023 18:31:15 +0100 Subject: [PATCH] Simplify PromptTemplate substitution in PromptNode (#3876) --- haystack/nodes/prompt/prompt_node.py | 96 +++++++++++++--------------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index cbee816a7..bd685f791 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -3,7 +3,7 @@ import logging import re from abc import ABC, abstractmethod from string import Template -from typing import Dict, List, Optional, Tuple, Union, Any, Type +from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator import requests import torch @@ -103,18 +103,9 @@ class PromptTemplate(BasePromptTemplate, ABC): self.prompt_text = prompt_text self.prompt_params = prompt_params - def fill(self, *args, **kwargs) -> Dict[str, Any]: + def prepare(self, *args, **kwargs) -> Dict[str, Any]: """ - Fills the prompt text parameters from non-keyword and keyword arguments. - - In the case of non-keyword arguments, the order of the arguments should match the left-to-right - order of appearance of the parameters in the prompt text. For example, if the prompt text is: - `Please come up with a question for the given context and the answer. Context: $documents; - Answer: $answers; Question:` then the first non-keyword argument fills the $documents placeholder - and the second non-keyword argument fills the $answers placeholder. - - In the case of keyword arguments, the order of the arguments doesn't matter. Placeholders in the - prompt text are filled with the corresponding keyword argument. + Prepares and verifies the prompt template with input parameters. :param args: Non-keyword arguments to use for filling the prompt text. :param kwargs: Keyword arguments to use for filling the prompt text. @@ -143,9 +134,34 @@ class PromptTemplate(BasePromptTemplate, ABC): available_params = set(list(template_dict.keys()) + list(set(kwargs.keys()))) raise ValueError(f"Expected prompt params {self.prompt_params} but got {list(available_params)}") - template_dict["prompt_template"] = self.prompt_text return template_dict + def fill(self, *args, **kwargs) -> Iterator[str]: + """ + Fills the prompt text parameters from non-keyword and keyword arguments and returns the iterator prompt text. + + In the case of non-keyword arguments, the order of the arguments should match the left-to-right + order of appearance of the parameters in the prompt text. For example, if the prompt text is: + `Please come up with a question for the given context and the answer. Context: $documents; + Answer: $answers; Question:` then the first non-keyword argument fills the $documents placeholder + and the second non-keyword argument fills the $answers placeholder. + + In the case of keyword arguments, the order of the arguments doesn't matter. Placeholders in the + prompt text are filled with the corresponding keyword argument. + + :param args: Non-keyword arguments to use for filling the prompt text. + :param kwargs: Keyword arguments to use for filling the prompt text. + :return: An iterator of prompt texts. + """ + template_dict = self.prepare(*args, **kwargs) + template = Template(self.prompt_text) + # the prompt context values should all be lists, as they will be split as one + prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in template_dict.items()} + for prompt_context_values in zip(*prompt_context_copy.values()): + template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())} + prompt_prepared: str = template.substitute(template_input) + yield prompt_prepared + class PromptModelInvocationLayer: """ @@ -710,53 +726,31 @@ class PromptNode(BaseComponent): :return: A list of strings as model responses. """ results = [] - prompt_prepared: Dict[str, Any] = {} if isinstance(prompt_template, str) and not self.is_supported_template(prompt_template): raise ValueError( f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()} " f"or pass a PromptTemplate instance for prompting." ) - invoke_template = self.default_prompt_template if prompt_template is None else prompt_template - if args and invoke_template is None: - # create straightforward prompt on the input, no templates used - prompt_prepared["prompt"] = list(args) - else: - template_to_fill: PromptTemplate - if isinstance(prompt_template, PromptTemplate): - template_to_fill = prompt_template - elif isinstance(prompt_template, str): - template_to_fill = self.get_prompt_template(prompt_template) + prompt_template_used = prompt_template or self.default_prompt_template + if prompt_template_used: + if isinstance(prompt_template_used, PromptTemplate): + template_to_fill = prompt_template_used + elif isinstance(prompt_template_used, str): + template_to_fill = self.get_prompt_template(prompt_template_used) else: - raise ValueError(f"{prompt_template} with args {args} , and kwargs {kwargs} not supported") - # we have potentially args and kwargs; task selected, so templating is needed - prompt_prepared = template_to_fill.fill(*args, **kwargs) + raise ValueError(f"{prompt_template_used} with args {args} , and kwargs {kwargs} not supported") - # straightforward prompt, no templates used - if "prompt" in prompt_prepared: - for prompt in prompt_prepared["prompt"]: + # prompt template used, yield prompts from inputs args + for prompt in template_to_fill.fill(*args, **kwargs): + # and pass the prepared prompt to the model + output = self.prompt_model.invoke(prompt, **kwargs) + results.extend(output) + else: + # straightforward prompt, no templates used + for prompt in list(args): output = self.prompt_model.invoke(prompt) - for item in output: - results.append(item) - # templated prompt - # we have a prompt dictionary with prompt_template text and key/value pairs for template variables - # where key is the variable name and value is a list of variable values - # we invoke the model iterating through a list of prompt variable values replacing the variables - # in the prompt template - elif "prompt_template" in prompt_prepared: - template = Template(prompt_prepared["prompt_template"]) - prompt_context_copy = prompt_prepared.copy() - prompt_context_copy.pop("prompt_template") - # the prompt context values should all be lists, as they will be split as one - prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in prompt_context_copy.items()} - for prompt_context_values in zip(*prompt_context_copy.values()): - template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())} - template_prepared: str = template.substitute(template_input) - # remove template keys from kwargs so we don't pass them to the model - kwargs = {key: value for key, value in kwargs.items() if key not in template_input.keys()} - output = self.prompt_model.invoke(template_prepared, **kwargs) - for item in output: - results.append(item) + results.extend(output) return results def add_prompt_template(self, prompt_template: PromptTemplate) -> None: