From e2fb82b14815bdb210bb445089b200408b485536 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Thu, 19 Jan 2023 11:17:06 +0100
Subject: [PATCH] refactor: Move invocation_context from meta to own pipeline variable (#3888)

---
 haystack/nodes/prompt/prompt_node.py | 13 ++--
 test/nodes/test_prompt_node.py       | 94 ++++++++++++++++++++++++++--
 2 files changed, 95 insertions(+), 12 deletions(-)

diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index bd685f791..9b1f003d3 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -851,6 +851,7 @@ class PromptNode(BaseComponent):
         labels: Optional[MultiLabel] = None,
         documents: Optional[List[Document]] = None,
         meta: Optional[dict] = None,
+        invocation_context: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Dict, str]:
         """
         Runs the PromptNode on these input parameters. Returns the output of the prompt model.
@@ -864,25 +865,23 @@ class PromptNode(BaseComponent):
         prompt template.
         :param documents: The documents to be used for the prompt.
         :param meta: The meta to be used for the prompt. Usually not used.
+        :param invocation_context: The invocation context to be used for the prompt; a dictionary passed between pipeline nodes.
         """
-        if not meta:
-            meta = {}
         # invocation_context is a dictionary that is passed from one pipeline node to the next and can be used
         # to pass results from a pipeline node to any other downstream pipeline node.
-        if "invocation_context" not in meta:
-            meta["invocation_context"] = {}
+        invocation_context = invocation_context or {}
 
         results = self(
             query=query,
             labels=labels,
             documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [],
-            **meta["invocation_context"],
+            **invocation_context,
         )
 
         if self.output_variable:
-            meta["invocation_context"][self.output_variable] = results
-        return {"results": results, "meta": {**meta}}, "output_1"
+            invocation_context[self.output_variable] = results
+        return {"results": results, "invocation_context": invocation_context}, "output_1"
 
     def run_batch(
         self,
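The contract change above, in short: run() now receives invocation_context as a first-class input instead of unpacking it from meta, and returns it as a top-level output key next to results. What follows is a minimal sketch (not part of this patch) of a custom node written against the new contract; the class name QuestionCleaner and the "questions" key are illustrative only and assume an upstream PromptNode configured with output_variable, as in the tests below.

    from typing import Any, Dict, Optional, Tuple

    from haystack import BaseComponent


    class QuestionCleaner(BaseComponent):
        """Hypothetical node that post-processes values a prior node stored in the invocation context."""

        outgoing_edges = 1

        def run(self, invocation_context: Optional[Dict[str, Any]] = None) -> Tuple[Dict, str]:
            invocation_context = invocation_context or {}
            # "questions" is whatever an upstream PromptNode stored via its output_variable;
            # it is now read from the pipeline variable, not from meta["invocation_context"].
            questions = invocation_context.get("questions", [])
            invocation_context["questions"] = [q.strip() for q in questions]
            # Returning the dict as a top-level key keeps it flowing to downstream nodes.
            return {"invocation_context": invocation_context}, "output_1"

        def run_batch(self, **kwargs):
            raise NotImplementedError("Sketch only: batch mode is not needed for this illustration.")

A node that declares invocation_context in its run() signature receives the current dictionary; the InBetweenNode test below additionally checks that nodes which ignore the variable entirely do not break its propagation.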
diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py
index 03cadd5f7..3184c1ed9 100644
--- a/test/nodes/test_prompt_node.py
+++ b/test/nodes/test_prompt_node.py
@@ -1,9 +1,10 @@
 import os
+from typing import Optional, Union, List, Dict, Any, Tuple
 
 import pytest
 import torch
 
-from haystack import Document, Pipeline
+from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 
@@ -365,7 +366,8 @@ def test_complex_pipeline_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
@@ -402,7 +404,8 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
@@ -448,7 +451,87 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
+
+
+def test_complex_pipeline_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
+    # Test that we can put an arbitrary node between two PromptNodes and everything still works.
+    # Specifically, we want to ensure that invocation_context is still populated correctly and propagated.
+    class InBetweenNode(BaseComponent):
+        outgoing_edges = 1
+
+        def run(
+            self,
+            query: Optional[str] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[MultiLabel] = None,
+            documents: Optional[List[Document]] = None,
+            meta: Optional[dict] = None,
+        ) -> Tuple[Dict, str]:
+            return {}, "output_1"
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            return {}, "output_1"
+
+    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: ignore
+            components:
+            - name: in_between
+              type: InBetweenNode
+            - name: pmodel
+              type: PromptModel
+              params:
+                model_name_or_path: google/flan-t5-small
+                model_kwargs:
+                  torch_dtype: torch.bfloat16
+            - name: question_generation_template
+              type: PromptTemplate
+              params:
+                name: question-generation-new
+                prompt_text: "Given the context please generate a question. Context: $documents; Question:"
+            - name: p1
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question_generation_template
+                output_variable: questions
+              type: PromptNode
+            - name: p2
+              params:
+                model_name_or_path: pmodel
+                default_prompt_template: question-answering
+              type: PromptNode
+            pipelines:
+            - name: query
+              nodes:
+              - name: p1
+                inputs:
+                - Query
+              - name: in_between
+                inputs:
+                - p1
+              - name: p2
+                inputs:
+                - in_between
+        """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
+    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
+    response = result["results"][0]
+    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
 
 
 @pytest.mark.skipif(
@@ -507,4 +590,5 @@ def test_complex_pipeline_with_all_features(tmp_path):
     result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
     response = result["results"][0]
     assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
-    assert len(result["meta"]["invocation_context"]) > 0
+    assert len(result["invocation_context"]) > 0
+    assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
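Taken together, the user-visible effect of the patch is that intermediate results move from result["meta"]["invocation_context"] to result["invocation_context"] at the top level of the pipeline output. Below is a Python-API sketch of the same two-PromptNode setup the YAML tests build, shown only to illustrate the new access pattern; the model name, template names, and output_variable are copied from the tests above, and the snippet itself is not part of the diff.

    from haystack import Document, Pipeline
    from haystack.nodes.prompt import PromptNode

    # p1 writes its generated questions into the invocation context under the
    # key "questions" (its output_variable); p2 consumes them via its template.
    p1 = PromptNode(
        "google/flan-t5-small",
        default_prompt_template="question-generation",
        output_variable="questions",
    )
    p2 = PromptNode("google/flan-t5-small", default_prompt_template="question-answering")

    pipeline = Pipeline()
    pipeline.add_node(component=p1, name="p1", inputs=["Query"])
    pipeline.add_node(component=p2, name="p2", inputs=["p1"])

    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])

    # Before this patch: result["meta"]["invocation_context"]["questions"]
    # After this patch:
    questions = result["invocation_context"]["questions"]
    print(questions, result["results"])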