Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-18 19:13:50 +00:00
Prompt node/run batch (#4072)
* Starting to implement first pass at run_batch
* Started to add _flatten_input function
* First pass at run_batch method
* Fixed bug
* Adding tests for run_batch
* Update doc strings
* Pylint and mypy
* Pylint
* Fixing mypy
* Restructuring of run_batch tests
* Add minor lg updates
* Adding more tests
* Update dev comments and call static method differently
* Fixed the setting of output variable
* Set output_variable in __init__ of PromptNode
* Make a one-liner

---------

Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent 83d615a32b
commit d129598203
@@ -741,7 +741,7 @@ class PromptNode(BaseComponent):
         super().__init__()
         self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()}  # type: ignore
         self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
-        self.output_variable: Optional[str] = output_variable
+        self.output_variable: str = output_variable or "results"
         self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
         self.prompt_model: PromptModel
         self.stop_words: Optional[List[str]] = stop_words
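With this change, `output_variable` always resolves to a concrete key ("results" by default, or a custom name), so callers no longer need a None check before indexing the node's output. A minimal sketch of the resulting behavior; the model name and documents are illustrative, not taken from this diff:

from haystack import Document
from haystack.nodes import PromptNode

# Default key: generated strings land under "results".
node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")
output, _ = node.run(documents=[Document("Berlin is an amazing city.")])
print(output["results"])

# Custom key: the same results appear under "out" instead.
node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis", output_variable="out")
output, _ = node.run(documents=[Document("Berlin is an amazing city.")])
print(output["out"])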
@@ -924,8 +924,10 @@ class PromptNode(BaseComponent):
         invocation_context: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Dict, str]:
         """
         Runs the PromptNode on these input parameters. Returns the output of the prompt model.
-        Parameters `file_paths`, `labels`, and `meta` are usually ignored.
+        The parameters `query`, `file_paths`, `labels`, `documents`, and `meta` are added to the invocation context
+        before invoking the prompt model. PromptNode uses these variables only if they are present as
+        parameters in the PromptTemplate.

         :param query: The PromptNode usually ignores the query, unless it's used as a parameter in the
         prompt template.
@@ -934,7 +936,8 @@ class PromptNode(BaseComponent):
         :param labels: The PromptNode usually ignores the labels, unless they're used as a parameter in the
         prompt template.
         :param documents: The documents to be used for the prompt.
-        :param meta: The meta to be used for the prompt. Usually not used.
+        :param meta: PromptNode usually ignores meta information, unless it's used as a parameter in the
+        PromptTemplate.
         :param invocation_context: The invocation context to be used for the prompt.
         """
         # prompt_collector is an empty list, it's passed to the PromptNode that will fill it with the rendered prompts,
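As the updated docstring explains, run() only consumes `query`, `documents`, and the other inputs when the active PromptTemplate names them as parameters. A hedged sketch of that mechanism; the template name and model are invented for illustration:

from haystack import Document
from haystack.nodes import PromptNode, PromptTemplate

# $documents and $query below are what make run() pull those inputs
# from the invocation context; unnamed inputs are ignored.
qa_template = PromptTemplate(
    name="qa-sketch",  # hypothetical template name
    prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
)
node = PromptNode("google/flan-t5-base", default_prompt_template=qa_template)
result, _ = node.run(
    query="Who lives in Berlin?",
    documents=[Document("My name is Carla and I live in Berlin")],
)
print(result["results"])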
@@ -967,29 +970,128 @@ class PromptNode(BaseComponent):

         results = self(prompt_collector=prompt_collector, **invocation_context)

-        final_result: Dict[str, Any] = {}
-        output_variable = self.output_variable or "results"
-        if output_variable:
-            invocation_context[output_variable] = results
-            final_result[output_variable] = results
-
-        final_result["invocation_context"] = invocation_context
-        final_result["_debug"] = {"prompts_used": prompt_collector}
+        invocation_context[self.output_variable] = results
+        final_result: Dict[str, Any] = {
+            self.output_variable: results,
+            "invocation_context": invocation_context,
+            "_debug": {"prompts_used": prompt_collector},
+        }
         return final_result, "output_1"

-    def run_batch(
+    def run_batch(  # type: ignore
         self,
-        queries: Optional[Union[str, List[str]]] = None,
-        file_paths: Optional[List[str]] = None,
-        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        queries: Optional[List[str]] = None,
         documents: Optional[Union[List[Document], List[List[Document]]]] = None,
-        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-        params: Optional[dict] = None,
-        debug: Optional[bool] = None,
+        invocation_contexts: Optional[List[Dict[str, Any]]] = None,
     ):
-        raise NotImplementedError("run_batch is not implemented for PromptNode.")
+        """
+        Runs PromptNode in batch mode.
+
+        - If you provide a list containing a single query (and/or invocation context)...
+            - ... and a single list of Documents, the query is applied to each Document individually.
+            - ... and a list of lists of Documents, the query is applied to each list of Documents and the results
+              are aggregated per Document list.
+
+        - If you provide a list of multiple queries (and/or multiple invocation contexts)...
+            - ... and a single list of Documents, each query (and/or invocation context) is applied to each Document individually.
+            - ... and a list of lists of Documents, each query (and/or invocation context) is applied to its corresponding list of Documents
+              and the results are aggregated per query-Document pair.
+
+        - If you provide no Documents, then each query (and/or invocation context) is applied directly to the PromptTemplate.
+
+        :param queries: List of queries.
+        :param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
+        :param invocation_contexts: List of invocation contexts.
+        """
+        inputs = PromptNode._flatten_inputs(queries, documents, invocation_contexts)
+        all_results: Dict[str, List] = {self.output_variable: [], "invocation_contexts": [], "_debug": []}
+        for query, docs, invocation_context in zip(
+            inputs["queries"], inputs["documents"], inputs["invocation_contexts"]
+        ):
+            results = self.run(query=query, documents=docs, invocation_context=invocation_context)[0]
+            all_results[self.output_variable].append(results[self.output_variable])
+            all_results["invocation_contexts"].append(results["invocation_context"])
+            all_results["_debug"].append(results["_debug"])
+        return all_results, "output_1"

     def _prepare_model_kwargs(self):
         # these are the parameters from PromptNode level
         # that are passed to the prompt model invocation layer
         return {"stop_words": self.stop_words}
+
+    @staticmethod
+    def _flatten_inputs(
+        queries: Optional[List[str]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        invocation_contexts: Optional[List[Dict[str, Any]]] = None,
+    ) -> Dict[str, List]:
+        """Flatten and copy the queries, documents, and invocation contexts into lists of equal length.
+
+        - If you provide a list containing a single query (and/or invocation context)...
+            - ... and a single list of Documents, the query is applied to each Document individually.
+            - ... and a list of lists of Documents, the query is applied to each list of Documents and the results
+              are aggregated per Document list.
+
+        - If you provide a list of multiple queries (and/or multiple invocation contexts)...
+            - ... and a single list of Documents, each query (and/or invocation context) is applied to each Document individually.
+            - ... and a list of lists of Documents, each query (and/or invocation context) is applied to its corresponding list of Documents
+              and the results are aggregated per query-Document pair.
+
+        - If you provide no Documents, then each query (and/or invocation context) is applied to the PromptTemplate.
+
+        :param queries: List of queries.
+        :param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
+        :param invocation_contexts: List of invocation contexts.
+        """
+        # Check that queries and invocation_contexts are of the same length if both are provided
+        input_queries: List[Any]
+        input_invocation_contexts: List[Any]
+        if queries is not None and invocation_contexts is not None:
+            if len(queries) != len(invocation_contexts):
+                raise ValueError("The input variables queries and invocation_contexts should have the same length.")
+            input_queries = queries
+            input_invocation_contexts = invocation_contexts
+        elif queries is not None and invocation_contexts is None:
+            input_queries = queries
+            input_invocation_contexts = [None] * len(queries)
+        elif queries is None and invocation_contexts is not None:
+            input_queries = [None] * len(invocation_contexts)
+            input_invocation_contexts = invocation_contexts
+        else:
+            input_queries = [None]
+            input_invocation_contexts = [None]
+
+        multi_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], list)
+        single_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], Document)
+
+        # Docs case 1: single list of Documents
+        # -> apply each query (and invocation context) to all Documents
+        inputs: Dict[str, List] = {"queries": [], "invocation_contexts": [], "documents": []}
+        if documents is not None:
+            if single_docs_list:
+                for query, invocation_context in zip(input_queries, input_invocation_contexts):
+                    for doc in documents:
+                        inputs["queries"].append(query)
+                        inputs["invocation_contexts"].append(invocation_context)
+                        inputs["documents"].append([doc])
+            # Docs case 2: list of lists of Documents
+            # -> apply each query (and invocation_context) to the corresponding list of Documents;
+            #    if queries contains only one query, apply it to each list of Documents
+            elif multi_docs_list:
+                total_queries = input_queries.copy()
+                total_invocation_contexts = input_invocation_contexts.copy()
+                if len(total_queries) == 1 and len(total_invocation_contexts) == 1:
+                    total_queries = input_queries * len(documents)
+                    total_invocation_contexts = input_invocation_contexts * len(documents)
+                if len(total_queries) != len(documents) or len(total_invocation_contexts) != len(documents):
+                    raise ValueError("Number of queries must be equal to number of provided Document lists.")
+                for query, invocation_context, cur_docs in zip(total_queries, total_invocation_contexts, documents):
+                    inputs["queries"].append(query)
+                    inputs["invocation_contexts"].append(invocation_context)
+                    inputs["documents"].append(cur_docs)
+        elif queries is not None or invocation_contexts is not None:
+            for query, invocation_context in zip(input_queries, input_invocation_contexts):
+                inputs["queries"].append(query)
+                inputs["invocation_contexts"].append(invocation_context)
+                inputs["documents"].append([None])
+        return inputs
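Before the test changes below, a small sketch of the broadcasting rules `_flatten_inputs` implements (the inputs are illustrative): a single query paired with a list of Document lists is repeated so that the three flattened lists zip together one entry per Document list.

from haystack import Document
from haystack.nodes import PromptNode

# One query, two Document lists: "Docs case 2" repeats the query per list.
inputs = PromptNode._flatten_inputs(
    queries=["Who lives in Berlin?"],
    documents=[
        [Document("My name is Carla and I live in Berlin")],
        [Document("My name is Christelle and I live in Paris")],
    ],
)
assert inputs["queries"] == ["Who lives in Berlin?"] * 2
assert len(inputs["documents"]) == 2
assert inputs["invocation_contexts"] == [None, None]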
@@ -319,12 +319,12 @@ def test_simple_pipeline(prompt_model):
     if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
         pytest.skip("No API key found for OpenAI, skipping test")

-    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
+    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

     pipe = Pipeline()
     pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
     result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
-    assert result["results"][0].casefold() == "positive"
+    assert result["out"][0].casefold() == "positive"


 @pytest.mark.integration
@@ -748,6 +748,78 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
     assert pipeline is not None


+class TestRunBatch:
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=None, documents=[Document("Berlin is an amazing city."), Document("I am not feeling well.")]
+        )
+        assert isinstance(result["results"], list)
+        assert isinstance(result["results"][0], list)
+        assert isinstance(result["results"][0][0], str)
+        assert "positive" in result["results"][0][0].casefold()
+        assert "negative" in result["results"][1][0].casefold()
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=None,
+            documents=[
+                [Document("Berlin is an amazing city."), Document("Paris is an amazing city.")],
+                [Document("I am not feeling well.")],
+            ],
+        )
+        assert isinstance(result["out"], list)
+        assert isinstance(result["out"][0], list)
+        assert isinstance(result["out"][0][0], str)
+        assert all("positive" in x.casefold() for x in result["out"][0])
+        assert "negative" in result["out"][1][0].casefold()
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
+    def test_simple_pipeline_batch_query_multiple_doc_list(self, prompt_model):
+        if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
+            pytest.skip("No API key found for OpenAI, skipping test")
+
+        prompt_template = PromptTemplate(
+            name="question-answering-new",
+            prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
+            prompt_params=["documents", "query"],
+        )
+        node = PromptNode(prompt_model, default_prompt_template=prompt_template)
+
+        pipe = Pipeline()
+        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
+        result = pipe.run_batch(
+            queries=["Who lives in Berlin?"],
+            documents=[
+                [Document("My name is Carla and I live in Berlin"), Document("My name is James and I live in London")],
+                [Document("My name is Christelle and I live in Paris")],
+            ],
+            debug=True,
+        )
+        assert isinstance(result["results"], list)
+        assert isinstance(result["results"][0], list)
+        assert isinstance(result["results"][0][0], str)
+        # TODO Finish
+
+
 def test_HFLocalInvocationLayer_supports():
     assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
     assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
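The tests above exercise Pipeline.run_batch; calling run_batch directly on the node returns the same structure together with the outgoing edge name. A hedged sketch, with an illustrative model and documents (not part of this diff):

from haystack import Document
from haystack.nodes import PromptNode

node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")
all_results, edge = node.run_batch(
    documents=[
        [Document("Berlin is an amazing city.")],
        [Document("I am not feeling well.")],
    ]
)
# edge == "output_1"; all_results["results"] holds one list of generated strings
# per input Document list, e.g. [["positive"], ["negative"]], alongside
# all_results["invocation_contexts"] and all_results["_debug"] per run.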