refactor: Move invocation_context from meta to own pipeline variable (#3888)

Vladimir Blagojevic 2023-01-19 11:17:06 +01:00 committed by GitHub
parent 34b7db0209
commit e2fb82b148
2 changed files with 95 additions and 12 deletions
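
In short: before this commit, PromptNode.run() smuggled intermediate results between pipeline nodes inside meta["invocation_context"]; after it, invocation_context is accepted and returned as a dedicated pipeline variable. A minimal before/after sketch of the resulting pipeline output (not part of the diff; it assumes a pipeline whose first PromptNode sets output_variable="questions", as the tests below do):

# Before this commit: generated values ride along inside meta
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
questions = result["meta"]["invocation_context"]["questions"]

# After this commit: invocation_context is its own top-level output key
result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
questions = result["invocation_context"]["questions"]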

View File

@@ -851,6 +851,7 @@ class PromptNode(BaseComponent):
         labels: Optional[MultiLabel] = None,
         documents: Optional[List[Document]] = None,
         meta: Optional[dict] = None,
+        invocation_context: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Dict, str]:
         """
         Runs the PromptNode on these input parameters. Returns the output of the prompt model.
@@ -864,25 +865,23 @@ class PromptNode(BaseComponent):
             prompt template.
         :param documents: The documents to be used for the prompt.
         :param meta: The meta to be used for the prompt. Usually not used.
+        :param invocation_context: The invocation context to be used for the prompt.
         """
-        if not meta:
-            meta = {}
         # invocation_context is a dictionary that is passed from one pipeline node to the next and can be used
         # to pass results from a pipeline node to any other downstream pipeline node.
-        if "invocation_context" not in meta:
-            meta["invocation_context"] = {}
+        invocation_context = invocation_context or {}
         results = self(
             query=query,
             labels=labels,
             documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
-            **meta["invocation_context"],
+            **invocation_context,
         )
         if self.output_variable:
-            meta["invocation_context"][self.output_variable] = results
-        return {"results": results, "meta": {**meta}}, "output_1"
+            invocation_context[self.output_variable] = results
+        return {"results": results, "invocation_context": invocation_context}, "output_1"

     def run_batch(
         self,

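Because run() now threads invocation_context explicitly, a downstream PromptNode's template can consume whatever an upstream node published via output_variable. A short sketch of two chained PromptNodes, modeled on the YAML tests below; the out-of-the-box template names question-generation and question-answering are assumptions carried over from those tests:

from haystack import Document, Pipeline
from haystack.nodes.prompt import PromptModel, PromptNode

model = PromptModel(model_name_or_path="google/flan-t5-small")
# p1 stores its generated questions in invocation_context["questions"] ...
p1 = PromptNode(model, default_prompt_template="question-generation", output_variable="questions")
# ... and p2's question-answering template reads $questions from the propagated invocation_context
p2 = PromptNode(model, default_prompt_template="question-answering")

pipe = Pipeline()
pipe.add_node(component=p1, name="p1", inputs=["Query"])
pipe.add_node(component=p2, name="p2", inputs=["p1"])

result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
assert "questions" in result["invocation_context"]
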
View File

@@ -1,9 +1,10 @@
 import os
+from typing import Optional, Union, List, Dict, Any, Tuple

 import pytest
 import torch

-from haystack import Document, Pipeline
+from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
@@ -365,7 +366,8 @@ def test_complex_pipeline_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0


 def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
@@ -402,7 +404,8 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0


 def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
@@ -448,7 +451,87 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
+
+
+def test_complex_pipeline_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
+    # test that we can stick some random node in between prompt nodes and that everything still works;
+    # specifically, we want to ensure that invocation_context is still populated correctly and propagated
+    # through a node that does not touch it
+    class InBetweenNode(BaseComponent):
+        outgoing_edges = 1
+
+        def run(
+            self,
+            query: Optional[str] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[MultiLabel] = None,
+            documents: Optional[List[Document]] = None,
+            meta: Optional[dict] = None,
+        ) -> Tuple[Dict, str]:
+            return {}, "output_1"
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            return {}, "output_1"
+
+    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: ignore
+            components:
+            - name: in_between
+              type: InBetweenNode
+            - name: pmodel
+              type: PromptModel
+              params:
+                model_name_or_path: google/flan-t5-small
+                model_kwargs:
+                  torch_dtype: torch.bfloat16
+            - name: question_generation_template
+              type: PromptTemplate
+              params:
+                name: question-generation-new
+                prompt_text: "Given the context please generate a question. Context: $documents; Question:"
+            - name: p1
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question_generation_template
+                output_variable: questions
+              type: PromptNode
+            - name: p2
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question-answering
+              type: PromptNode
+            pipelines:
+            - name: query
+              nodes:
+              - name: p1
+                inputs:
+                - Query
+              - name: in_between
+                inputs:
+                - p1
+              - name: p2
+                inputs:
+                - in_between
+            """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
+    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
+    response = result["results"][0]
+    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0


 @pytest.mark.skipif(
@@ -507,4 +590,5 @@ def test_complex_pipeline_with_all_features(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0