mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-14 10:42:13 +00:00
refactor: Move invocation_context from meta to own pipeline variable (#3888)
This commit is contained in:
parent
34b7db0209
commit
e2fb82b148
@ -851,6 +851,7 @@ class PromptNode(BaseComponent):
|
||||
labels: Optional[MultiLabel] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
invocation_context: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Dict, str]:
|
||||
"""
|
||||
Runs the PromptNode on these inputs parameters. Returns the output of the prompt model.
|
||||
@ -864,25 +865,23 @@ class PromptNode(BaseComponent):
|
||||
prompt template.
|
||||
:param documents: The documents to be used for the prompt.
|
||||
:param meta: The meta to be used for the prompt. Usually not used.
|
||||
:param invocation_context: The invocation context to be used for the prompt.
|
||||
"""
|
||||
|
||||
if not meta:
|
||||
meta = {}
|
||||
# invocation_context is a dictionary that is passed from a pipeline node to a pipeline node and can be used
|
||||
# to pass results from a pipeline node to any other downstream pipeline node.
|
||||
if "invocation_context" not in meta:
|
||||
meta["invocation_context"] = {}
|
||||
invocation_context = invocation_context or {}
|
||||
|
||||
results = self(
|
||||
query=query,
|
||||
labels=labels,
|
||||
documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
|
||||
**meta["invocation_context"],
|
||||
**invocation_context,
|
||||
)
|
||||
|
||||
if self.output_variable:
|
||||
meta["invocation_context"][self.output_variable] = results
|
||||
return {"results": results, "meta": {**meta}}, "output_1"
|
||||
invocation_context[self.output_variable] = results
|
||||
return {"results": results, "invocation_context": invocation_context}, "output_1"
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
|
@ -1,9 +1,10 @@
|
||||
import os
|
||||
from typing import Optional, Union, List, Dict, Any, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from haystack import Document, Pipeline
|
||||
from haystack import Document, Pipeline, BaseComponent, MultiLabel
|
||||
from haystack.errors import OpenAIError
|
||||
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
|
||||
|
||||
@ -365,7 +366,8 @@ def test_complex_pipeline_yaml(tmp_path):
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["meta"]["invocation_context"]) > 0
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
|
||||
|
||||
def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
|
||||
@ -402,7 +404,8 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["meta"]["invocation_context"]) > 0
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
|
||||
|
||||
def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
|
||||
@ -448,7 +451,87 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["meta"]["invocation_context"]) > 0
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
|
||||
|
||||
def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
|
||||
# test that we can stick some random node in between prompt nodes and that everything still works
|
||||
# most specifically, we want to ensure that invocation_context is still populated correctly and propagated
|
||||
class InBetweenNode(BaseComponent):
|
||||
outgoing_edges = 1
|
||||
|
||||
def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
labels: Optional[MultiLabel] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
) -> Tuple[Dict, str]:
|
||||
return {}, "output_1"
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
queries: Optional[Union[str, List[str]]] = None,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
|
||||
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
|
||||
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
params: Optional[dict] = None,
|
||||
debug: Optional[bool] = None,
|
||||
):
|
||||
return {}, "output_1"
|
||||
|
||||
with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: in_between
|
||||
type: InBetweenNode
|
||||
- name: pmodel
|
||||
type: PromptModel
|
||||
params:
|
||||
model_name_or_path: google/flan-t5-small
|
||||
model_kwargs:
|
||||
torch_dtype: torch.bfloat16
|
||||
- name: question_generation_template
|
||||
type: PromptTemplate
|
||||
params:
|
||||
name: question-generation-new
|
||||
prompt_text: "Given the context please generate a question. Context: $documents; Question:"
|
||||
- name: p1
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question_generation_template
|
||||
output_variable: questions
|
||||
type: PromptNode
|
||||
- name: p2
|
||||
params:
|
||||
model_name_or_path: pmodel
|
||||
default_prompt_template: question-answering
|
||||
type: PromptNode
|
||||
pipelines:
|
||||
- name: query
|
||||
nodes:
|
||||
- name: p1
|
||||
inputs:
|
||||
- Query
|
||||
- name: in_between
|
||||
inputs:
|
||||
- p1
|
||||
- name: p2
|
||||
inputs:
|
||||
- in_between
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
response = result["results"][0]
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@ -507,4 +590,5 @@ def test_complex_pipeline_with_all_features(tmp_path):
|
||||
result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
|
||||
response = result["results"][0]
|
||||
assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
|
||||
assert len(result["meta"]["invocation_context"]) > 0
|
||||
assert len(result["invocation_context"]) > 0
|
||||
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user