diff --git a/haystack/nodes/other/shaper.py b/haystack/nodes/other/shaper.py index 3ab147475..c6b2ee9f9 100644 --- a/haystack/nodes/other/shaper.py +++ b/haystack/nodes/other/shaper.py @@ -246,6 +246,7 @@ class Shaper(BaseComponent): outputs: List[str], inputs: Optional[Dict[str, Union[List[str], str]]] = None, params: Optional[Dict[str, Any]] = None, + publish_outputs: Union[bool, List[str]] = True, ): """ Initializes the Shaper component. @@ -319,14 +320,38 @@ class Shaper(BaseComponent): You can use params to provide fallback values for arguments of `run` that you're not sure exist. So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query` is not passed to this node by the pipeline. - :param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match + :param outputs: The key to store the outputs in the invocation context. The length of the outputs must match the number of outputs produced by the function invoked. + :param publish_outputs: Controls whether to publish the outputs to the pipeline's output. + Set `True` (default value) to publishes all outputs or `False` to publish None. + E.g. if `outputs = ["documents"]` result for `publish_outputs = True` looks like + ```python + { + "invocation_context": { + "documents": [...] + }, + "documents": [...] + } + ``` + For `publish_outputs = False` result looks like + ```python + { + "invocation_context": { + "documents": [...] + }, + } + ``` + If you want to have finer-grained control, pass a list of the outputs you want to publish. """ super().__init__() self.function = REGISTERED_FUNCTIONS[func] self.outputs = outputs self.inputs = inputs or {} self.params = params or {} + if isinstance(publish_outputs, bool): + self.publish_outputs = self.outputs if publish_outputs else [] + else: + self.publish_outputs = publish_outputs def run( # type: ignore self, @@ -404,7 +429,7 @@ class Shaper(BaseComponent): results = {} for output_key, output_value in zip(self.outputs, output_values): invocation_context[output_key] = output_value - if output_key in ["query", "file_paths", "labels", "documents", "meta"]: + if output_key in self.publish_outputs: results[output_key] = output_value results["invocation_context"] = invocation_context diff --git a/test/nodes/test_shaper.py b/test/nodes/test_shaper.py index e72e2a658..76e7caac3 100644 --- a/test/nodes/test_shaper.py +++ b/test/nodes/test_shaper.py @@ -3,7 +3,9 @@ import logging import haystack from haystack import Pipeline, Document, Answer +from haystack.document_stores.memory import InMemoryDocumentStore from haystack.nodes.other.shaper import Shaper +from haystack.nodes.retriever.sparse import BM25Retriever @pytest.fixture @@ -340,6 +342,37 @@ def test_join_documents(): documents=[Document(content="first"), Document(content="second"), Document(content="third")] ) assert results["invocation_context"]["documents"] == [Document(content="first | second | third")] + assert results["documents"] == [Document(content="first | second | third")] + + +def test_join_documents_without_publish_outputs(): + shaper = Shaper( + func="join_documents", + inputs={"documents": "documents"}, + params={"delimiter": " | "}, + outputs=["documents"], + publish_outputs=False, + ) + results, _ = shaper.run( + documents=[Document(content="first"), Document(content="second"), Document(content="third")] + ) + assert results["invocation_context"]["documents"] == [Document(content="first | second | third")] + assert "documents" not in results + + +def test_join_documents_with_publish_outputs_as_list(): + shaper = Shaper( + func="join_documents", + inputs={"documents": "documents"}, + params={"delimiter": " | "}, + outputs=["documents"], + publish_outputs=["documents"], + ) + results, _ = shaper.run( + documents=[Document(content="first"), Document(content="second"), Document(content="third")] + ) + assert results["invocation_context"]["documents"] == [Document(content="first | second | third")] + assert results["documents"] == [Document(content="first | second | third")] def test_join_documents_default_delimiter(): @@ -457,6 +490,11 @@ def test_strings_to_answers_yaml(tmp_path): Answer(answer="b", type="generative"), Answer(answer="c", type="generative"), ] + assert result["answers"] == [ + Answer(answer="a", type="generative"), + Answer(answer="b", type="generative"), + Answer(answer="c", type="generative"), + ] # @@ -1116,3 +1154,19 @@ def test_join_query_and_documents_convert_into_documents_yaml(tmp_path): assert result["invocation_context"]["query_and_docs"] assert len(result["invocation_context"]["query_and_docs"]) == 4 assert isinstance(result["invocation_context"]["query_and_docs"][0], Document) + + +def test_shaper_publishes_unknown_arg_does_not_break_pipeline(): + documents = [Document(content="test query")] + shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["unknown_by_retriever"], publish_outputs=True) + document_store = InMemoryDocumentStore(use_bm25=True) + document_store.write_documents(documents) + retriever = BM25Retriever(document_store=document_store) + pipeline = Pipeline() + pipeline.add_node(component=shaper, name="shaper", inputs=["Query"]) + pipeline.add_node(component=retriever, name="retriever", inputs=["shaper"]) + + result = pipeline.run(query="test query") + assert result["invocation_context"]["unknown_by_retriever"] == "test query" + assert result["unknown_by_retriever"] == "test query" + assert len(result["documents"]) == 1