Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-02 02:39:51 +00:00)
Prompt node/run batch (#4072)

* Starting to implement first pass at run_batch
* Started to add _flatten_input function
* First pass at run_batch method
* Fixed bug
* Adding tests for run_batch
* Update doc strings
* Pylint and mypy
* Pylint
* Fixing mypy
* Restructuring of run_batch tests
* Add minor lg updates
* Adding more tests
* Update dev comments and call static method differently
* Fixed the setting of output variable
* Set output_variable in __init__ of PromptNode
* Make a one-liner

Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent
83d615a32b
commit
d129598203
@@ -741,7 +741,7 @@ class PromptNode(BaseComponent):
         super().__init__()
         self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()}  # type: ignore
         self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
-        self.output_variable: Optional[str] = output_variable
+        self.output_variable: str = output_variable or "results"
         self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
         self.prompt_model: PromptModel
         self.stop_words: Optional[List[str]] = stop_words
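With this change, `output_variable` is always a concrete string after `__init__` ("results" unless the caller overrides it), so downstream code can index the output dict without a None check. A minimal sketch of the effect (the model name is illustrative, not part of this commit):

    from haystack.nodes import PromptNode

    # Default key: outputs land under "results"
    node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")
    assert node.output_variable == "results"

    # Custom key: outputs land under "out" instead
    named = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis", output_variable="out")
    assert named.output_variable == "out"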
@@ -924,8 +924,10 @@ class PromptNode(BaseComponent):
         invocation_context: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Dict, str]:
         """
         Runs the PromptNode on these input parameters. Returns the output of the prompt model.
-        Parameters `file_paths`, `labels`, and `meta` are usually ignored.
+        The parameters `query`, `file_paths`, `labels`, `documents`, and `meta` are added to the invocation context
+        before invoking the prompt model. PromptNode uses these variables only if they are present as
+        parameters in the PromptTemplate.

         :param query: The PromptNode usually ignores the query, unless it's used as a parameter in the
                       prompt template.
@@ -934,7 +936,8 @@ class PromptNode(BaseComponent):
         :param labels: The PromptNode usually ignores the labels, unless they're used as a parameter in the
                        prompt template.
         :param documents: The documents to be used for the prompt.
-        :param meta: The meta to be used for the prompt. Usually not used.
+        :param meta: PromptNode usually ignores meta information, unless it's used as a parameter in the
+                     PromptTemplate.
         :param invocation_context: The invocation context to be used for the prompt.
         """
         # prompt_collector is an empty list, it's passed to the PromptNode that will fill it with the rendered prompts,
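The reworded docstring describes how `run()` now feeds `query`, `documents`, and friends into the invocation context, where the PromptTemplate picks up only the parameters it declares. A hedged sketch (template text and model name are illustrative):

    from haystack import Document
    from haystack.nodes import PromptNode, PromptTemplate

    # Because this template declares "documents" and "query" as prompt_params,
    # run() pulls both out of the invocation context when rendering the prompt.
    template = PromptTemplate(
        name="qa-sketch",
        prompt_text="Context: $documents; Question: $query; Answer:",
        prompt_params=["documents", "query"],
    )
    node = PromptNode("google/flan-t5-base", default_prompt_template=template)
    result, _ = node.run(query="Who lives in Berlin?", documents=[Document("Carla lives in Berlin.")])
    # result["results"] holds the model output; result["invocation_context"] echoes the inputs.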
@@ -967,29 +970,128 @@

         results = self(prompt_collector=prompt_collector, **invocation_context)

-        final_result: Dict[str, Any] = {}
-        output_variable = self.output_variable or "results"
-        if output_variable:
-            invocation_context[output_variable] = results
-            final_result[output_variable] = results
-
-        final_result["invocation_context"] = invocation_context
-        final_result["_debug"] = {"prompts_used": prompt_collector}
+        invocation_context[self.output_variable] = results
+        final_result: Dict[str, Any] = {
+            self.output_variable: results,
+            "invocation_context": invocation_context,
+            "_debug": {"prompts_used": prompt_collector},
+        }
         return final_result, "output_1"

-    def run_batch(
+    def run_batch(  # type: ignore
         self,
-        queries: Optional[Union[str, List[str]]] = None,
-        file_paths: Optional[List[str]] = None,
-        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        queries: Optional[List[str]] = None,
         documents: Optional[Union[List[Document], List[List[Document]]]] = None,
-        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-        params: Optional[dict] = None,
-        debug: Optional[bool] = None,
+        invocation_contexts: Optional[List[Dict[str, Any]]] = None,
     ):
-        raise NotImplementedError("run_batch is not implemented for PromptNode.")
+        """
+        Runs PromptNode in batch mode.
+
+        - If you provide a list containing a single query (and/or invocation context)...
+            - ... and a single list of Documents, the query is applied to each Document individually.
+            - ... and a list of lists of Documents, the query is applied to each list of Documents and the results
+              are aggregated per Document list.
+
+        - If you provide a list of multiple queries (and/or multiple invocation contexts)...
+            - ... and a single list of Documents, each query (and/or invocation context) is applied to each Document
+              individually.
+            - ... and a list of lists of Documents, each query (and/or invocation context) is applied to its
+              corresponding list of Documents and the results are aggregated per query-Document pair.
+
+        - If you provide no Documents, then each query (and/or invocation context) is applied directly to the
+          PromptTemplate.
+
+        :param queries: List of queries.
+        :param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
+        :param invocation_contexts: List of invocation contexts.
+        """
+        inputs = PromptNode._flatten_inputs(queries, documents, invocation_contexts)
+        all_results: Dict[str, List] = {self.output_variable: [], "invocation_contexts": [], "_debug": []}
+        for query, docs, invocation_context in zip(
+            inputs["queries"], inputs["documents"], inputs["invocation_contexts"]
+        ):
+            results = self.run(query=query, documents=docs, invocation_context=invocation_context)[0]
+            all_results[self.output_variable].append(results[self.output_variable])
+            all_results["invocation_contexts"].append(results["invocation_context"])
+            all_results["_debug"].append(results["_debug"])
+        return all_results, "output_1"

     def _prepare_model_kwargs(self):
         # these are the parameters from PromptNode level
         # that are passed to the prompt model invocation layer
         return {"stop_words": self.stop_words}

+    @staticmethod
+    def _flatten_inputs(
+        queries: Optional[List[str]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        invocation_contexts: Optional[List[Dict[str, Any]]] = None,
+    ) -> Dict[str, List]:
+        """Flatten and copy the queries, documents, and invocation contexts into lists of equal length.
+
+        - If you provide a list containing a single query (and/or invocation context)...
+            - ... and a single list of Documents, the query is applied to each Document individually.
+            - ... and a list of lists of Documents, the query is applied to each list of Documents and the results
+              are aggregated per Document list.
+
+        - If you provide a list of multiple queries (and/or multiple invocation contexts)...
+            - ... and a single list of Documents, each query (and/or invocation context) is applied to each Document
+              individually.
+            - ... and a list of lists of Documents, each query (and/or invocation context) is applied to its
+              corresponding list of Documents and the results are aggregated per query-Document pair.
+
+        - If you provide no Documents, then each query (and/or invocation context) is applied to the PromptTemplate.
+
+        :param queries: List of queries.
+        :param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
+        :param invocation_contexts: List of invocation contexts.
+        """
+        # Check that queries and invocation_contexts are of the same length if provided
+        input_queries: List[Any]
+        input_invocation_contexts: List[Any]
+        if queries is not None and invocation_contexts is not None:
+            if len(queries) != len(invocation_contexts):
+                raise ValueError("The input variables queries and invocation_contexts should have the same length.")
+            input_queries = queries
+            input_invocation_contexts = invocation_contexts
+        elif queries is not None and invocation_contexts is None:
+            input_queries = queries
+            input_invocation_contexts = [None] * len(queries)
+        elif queries is None and invocation_contexts is not None:
+            input_queries = [None] * len(invocation_contexts)
+            input_invocation_contexts = invocation_contexts
+        else:
+            input_queries = [None]
+            input_invocation_contexts = [None]
+
+        multi_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], list)
+        single_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], Document)
+
+        inputs: Dict[str, List] = {"queries": [], "invocation_contexts": [], "documents": []}
+        if documents is not None:
+            # Docs case 1: single list of Documents
+            # -> apply each query (and invocation_context) to all Documents
+            if single_docs_list:
+                for query, invocation_context in zip(input_queries, input_invocation_contexts):
+                    for doc in documents:
+                        inputs["queries"].append(query)
+                        inputs["invocation_contexts"].append(invocation_context)
+                        inputs["documents"].append([doc])
+            # Docs case 2: list of lists of Documents
+            # -> apply each query (and invocation_context) to its corresponding list of Documents;
+            # if queries contains only one query, apply it to each list of Documents
+            elif multi_docs_list:
+                total_queries = input_queries.copy()
+                total_invocation_contexts = input_invocation_contexts.copy()
+                if len(total_queries) == 1 and len(total_invocation_contexts) == 1:
+                    total_queries = input_queries * len(documents)
+                    total_invocation_contexts = input_invocation_contexts * len(documents)
+                if len(total_queries) != len(documents) or len(total_invocation_contexts) != len(documents):
+                    raise ValueError("Number of queries must be equal to number of provided Document lists.")
+                for query, invocation_context, cur_docs in zip(total_queries, total_invocation_contexts, documents):
+                    inputs["queries"].append(query)
+                    inputs["invocation_contexts"].append(invocation_context)
+                    inputs["documents"].append(cur_docs)
+        elif queries is not None or invocation_contexts is not None:
+            for query, invocation_context in zip(input_queries, input_invocation_contexts):
+                inputs["queries"].append(query)
+                inputs["invocation_contexts"].append(invocation_context)
+                inputs["documents"].append([None])
+        return inputs
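Taken together, `run_batch` flattens its inputs, loops `run()` over the flattened triples, and collects the per-run outputs. A minimal usage sketch (model name is illustrative; actual generations vary by model):

    from haystack import Document
    from haystack.nodes import PromptNode

    node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")

    # A list of lists of Documents and no queries: one aggregated result per inner list.
    all_results, edge = node.run_batch(
        documents=[
            [Document("Berlin is an amazing city."), Document("Paris is an amazing city.")],
            [Document("I am not feeling well.")],
        ]
    )
    assert edge == "output_1"
    assert len(all_results["results"]) == 2  # one entry per Document list

The broadcasting itself can be checked without loading a model, since `_flatten_inputs` is a static method:

    flat = PromptNode._flatten_inputs(
        queries=["Who lives in Berlin?"],
        documents=[
            [Document("My name is Carla and I live in Berlin")],
            [Document("My name is Christelle and I live in Paris")],
        ],
    )
    # The single query is repeated once per Document list:
    assert flat["queries"] == ["Who lives in Berlin?"] * 2
    assert len(flat["documents"]) == 2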
@@ -319,12 +319,12 @@ def test_simple_pipeline(prompt_model):
     if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
         pytest.skip("No API key found for OpenAI, skipping test")

-    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
+    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

     pipe = Pipeline()
     pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
     result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
-    assert result["results"][0].casefold() == "positive"
+    assert result["out"][0].casefold() == "positive"


 @pytest.mark.integration
@@ -748,6 +748,78 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
     assert pipeline is not None


+class TestRunBatch:
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=None, documents=[Document("Berlin is an amazing city."), Document("I am not feeling well.")]
+        )
+        assert isinstance(result["results"], list)
+        assert isinstance(result["results"][0], list)
+        assert isinstance(result["results"][0][0], str)
+        assert "positive" in result["results"][0][0].casefold()
+        assert "negative" in result["results"][1][0].casefold()
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=None,
+            documents=[
+                [Document("Berlin is an amazing city."), Document("Paris is an amazing city.")],
+                [Document("I am not feeling well.")],
+            ],
+        )
+        assert isinstance(result["out"], list)
+        assert isinstance(result["out"][0], list)
+        assert isinstance(result["out"][0][0], str)
+        assert all("positive" in x.casefold() for x in result["out"][0])
+        assert "negative" in result["out"][1][0].casefold()
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_query_multiple_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        prompt_template = PromptTemplate(
+            name="question-answering-new",
+            prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
+            prompt_params=["documents", "query"],
+        )
+        node = PromptNode(prompt_model, default_prompt_template=prompt_template)
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=["Who lives in Berlin?"],
+            documents=[
+                [Document("My name is Carla and I live in Berlin"), Document("My name is James and I live in London")],
+                [Document("My name is Christelle and I live in Paris")],
+            ],
+            debug=True,
+        )
+        assert isinstance(result["results"], list)
+        assert isinstance(result["results"][0], list)
+        assert isinstance(result["results"][0][0], str)
+        # TODO Finish
+
+
 def test_HFLocalInvocationLayer_supports():
     assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
     assert HFLocalInvocationLayer.supports("bigscience/T0_3B")