Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-14 18:52:11 +00:00)
refactor: Move invocation_context from meta to own pipeline variable (#3888)

parent 34b7db0209
commit e2fb82b148
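In short: PromptNode.run() gains an explicit invocation_context parameter, stops tunneling that dictionary through meta, and returns it as a top-level output key ("invocation_context" instead of meta["invocation_context"]). The prompt-node tests are updated accordingly, and a new test checks that the variable survives a pass through an unrelated node.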
@@ -851,6 +851,7 @@ class PromptNode(BaseComponent):
         labels: Optional[MultiLabel] = None,
         documents: Optional[List[Document]] = None,
         meta: Optional[dict] = None,
+        invocation_context: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Dict, str]:
         """
         Runs the PromptNode on these inputs parameters. Returns the output of the prompt model.
@@ -864,25 +865,23 @@ class PromptNode(BaseComponent):
         prompt template.
         :param documents: The documents to be used for the prompt.
         :param meta: The meta to be used for the prompt. Usually not used.
+        :param invocation_context: The invocation context to be used for the prompt.
         """
-        if not meta:
-            meta = {}
         # invocation_context is a dictionary that is passed from a pipeline node to a pipeline node and can be used
         # to pass results from a pipeline node to any other downstream pipeline node.
-        if "invocation_context" not in meta:
-            meta["invocation_context"] = {}
+        invocation_context = invocation_context or {}
 
         results = self(
             query=query,
             labels=labels,
             documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
-            **meta["invocation_context"],
+            **invocation_context,
         )
 
         if self.output_variable:
-            meta["invocation_context"][self.output_variable] = results
-        return {"results": results, "meta": {**meta}}, "output_1"
+            invocation_context[self.output_variable] = results
+        return {"results": results, "invocation_context": invocation_context}, "output_1"
 
     def run_batch(
         self,
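The practical effect of the hunks above: invocation_context is now an explicit argument to run() and an explicit key in its output, rather than an entry smuggled through meta. A minimal sketch of calling the refactored method directly, assuming a built-in "question-generation" prompt template exists and reusing the model name from the test YAML below (both are assumptions for illustration, not part of this diff):

from haystack import Document
from haystack.nodes.prompt import PromptNode

# Assumptions for illustration: the template name and model choice are
# not taken from this diff.
node = PromptNode(
    model_name_or_path="google/flan-t5-small",
    default_prompt_template="question-generation",
    output_variable="questions",
)

# invocation_context now travels as its own argument instead of
# meta["invocation_context"] ...
output, edge = node.run(
    query="not relevant",
    documents=[Document("Berlin is an amazing city.")],
    invocation_context={},
)

# ... and comes back as a top-level output key, with the node's results
# stored under its configured output_variable.
assert edge == "output_1"
assert output["invocation_context"]["questions"] == output["results"]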
@@ -1,9 +1,10 @@
 import os
+from typing import Optional, Union, List, Dict, Any, Tuple
 
 import pytest
 import torch
 
-from haystack import Document, Pipeline
+from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 
@@ -365,7 +366,8 @@ def test_complex_pipeline_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
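The "questions" key asserted here is not magic: it is the output_variable configured on the first PromptNode (p1) in these test YAML pipelines, which run() now writes into invocation_context[self.output_variable], as shown in the first hunk.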
@@ -402,7 +404,8 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
@@ -448,7 +451,87 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
+
+
+def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
+    # test that we can stick some random node in between prompt nodes and that everything still works
+    # most specifically, we want to ensure that invocation_context is still populated correctly and propagated
+    class InBetweenNode(BaseComponent):
+        outgoing_edges = 1
+
+        def run(
+            self,
+            query: Optional[str] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[MultiLabel] = None,
+            documents: Optional[List[Document]] = None,
+            meta: Optional[dict] = None,
+        ) -> Tuple[Dict, str]:
+            return {}, "output_1"
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            return {}, "output_1"
+
+    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: ignore
+            components:
+            - name: in_between
+              type: InBetweenNode
+            - name: pmodel
+              type: PromptModel
+              params:
+                model_name_or_path: google/flan-t5-small
+                model_kwargs:
+                  torch_dtype: torch.bfloat16
+            - name: question_generation_template
+              type: PromptTemplate
+              params:
+                name: question-generation-new
+                prompt_text: "Given the context please generate a question. Context: $documents; Question:"
+            - name: p1
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question_generation_template
+                output_variable: questions
+              type: PromptNode
+            - name: p2
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question-answering
+              type: PromptNode
+            pipelines:
+            - name: query
+              nodes:
+              - name: p1
+                inputs:
+                - Query
+              - name: in_between
+                inputs:
+                - p1
+              - name: p2
+                inputs:
+                - in_between
+        """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
+    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
+    response = result["results"][0]
+    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 @pytest.mark.skipif(
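Note what the new test establishes: InBetweenNode neither accepts an invocation_context argument nor returns one (its run() returns an empty dict), yet the final result still carries invocation_context with the questions produced by p1. The assertions thus suggest that the pipeline itself propagates invocation_context across nodes, rather than relying on every node to forward it.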
@@ -507,4 +590,5 @@ def test_complex_pipeline_with_all_features(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
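Taken together, the test edits pin down the new consumer-facing contract: intermediate outputs are read from result["invocation_context"] instead of result["meta"]["invocation_context"]. A minimal sketch against any pipeline shaped like the test YAMLs above (the config path is illustrative):

from haystack import Document, Pipeline

# Illustrative path; assumes a config like the test YAMLs above, with a
# PromptNode whose output_variable is "questions".
pipeline = Pipeline.load_from_yaml(path="pipeline.yml")
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])

# Before #3888: result["meta"]["invocation_context"]["questions"]
# After #3888:  a top-level key on the pipeline result.
questions = result["invocation_context"]["questions"]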