fix: add option to not override results by Shaper (#4231)

* add  option to shaper and support answers

* remove publish restrictions on outputs

* support list
This commit is contained in:
tstadel 2023-02-22 14:36:58 +01:00 committed by GitHub
parent 262c9771f4
commit 32b2abf9d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 2 deletions

View File

@ -246,6 +246,7 @@ class Shaper(BaseComponent):
outputs: List[str],
inputs: Optional[Dict[str, Union[List[str], str]]] = None,
params: Optional[Dict[str, Any]] = None,
publish_outputs: Union[bool, List[str]] = True,
):
"""
Initializes the Shaper component.
@ -319,14 +320,38 @@ class Shaper(BaseComponent):
You can use params to provide fallback values for arguments of `run` that you're not sure exist.
So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query`
is not passed to this node by the pipeline.
:param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match
:param outputs: The key to store the outputs in the invocation context. The length of the outputs must match
the number of outputs produced by the function invoked.
:param publish_outputs: Controls whether to publish the outputs to the pipeline's output.
Set to `True` (the default) to publish all outputs, or to `False` to publish none.
For example, if `outputs = ["documents"]`, the result for `publish_outputs = True` looks like:
```python
{
"invocation_context": {
"documents": [...]
},
"documents": [...]
}
```
For `publish_outputs = False` result looks like
```python
{
"invocation_context": {
"documents": [...]
},
}
```
If you want to have finer-grained control, pass a list of the outputs you want to publish.
"""
super().__init__()
self.function = REGISTERED_FUNCTIONS[func]
self.outputs = outputs
self.inputs = inputs or {}
self.params = params or {}
if isinstance(publish_outputs, bool):
self.publish_outputs = self.outputs if publish_outputs else []
else:
self.publish_outputs = publish_outputs
def run( # type: ignore
self,
@ -404,7 +429,7 @@ class Shaper(BaseComponent):
results = {}
for output_key, output_value in zip(self.outputs, output_values):
invocation_context[output_key] = output_value
if output_key in ["query", "file_paths", "labels", "documents", "meta"]:
if output_key in self.publish_outputs:
results[output_key] = output_value
results["invocation_context"] = invocation_context

View File

@ -3,7 +3,9 @@ import logging
import haystack
from haystack import Pipeline, Document, Answer
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes.other.shaper import Shaper
from haystack.nodes.retriever.sparse import BM25Retriever
@pytest.fixture
@ -340,6 +342,37 @@ def test_join_documents():
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
assert results["documents"] == [Document(content="first | second | third")]
def test_join_documents_without_publish_outputs():
    """Check that `publish_outputs=False` keeps the joined documents out of the top-level results."""
    joiner = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        params={"delimiter": " | "},
        outputs=["documents"],
        publish_outputs=False,
    )
    input_docs = [Document(content="first"), Document(content="second"), Document(content="third")]

    results, _ = joiner.run(documents=input_docs)

    # The joined document is still written to the invocation context ...
    context = results["invocation_context"]
    assert context["documents"] == [Document(content="first | second | third")]
    # ... but is not published as a top-level result key.
    assert "documents" not in results
def test_join_documents_with_publish_outputs_as_list():
    """Check that an output named in the `publish_outputs` list appears in the top-level results."""
    joiner = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        params={"delimiter": " | "},
        outputs=["documents"],
        publish_outputs=["documents"],
    )
    input_docs = [Document(content="first"), Document(content="second"), Document(content="third")]

    results, _ = joiner.run(documents=input_docs)

    # The joined document is available in the invocation context ...
    context = results["invocation_context"]
    assert context["documents"] == [Document(content="first | second | third")]
    # ... and is also published under its own key.
    assert results["documents"] == [Document(content="first | second | third")]
def test_join_documents_default_delimiter():
@ -457,6 +490,11 @@ def test_strings_to_answers_yaml(tmp_path):
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
]
assert result["answers"] == [
Answer(answer="a", type="generative"),
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
]
#
@ -1116,3 +1154,19 @@ def test_join_query_and_documents_convert_into_documents_yaml(tmp_path):
assert result["invocation_context"]["query_and_docs"]
assert len(result["invocation_context"]["query_and_docs"]) == 4
assert isinstance(result["invocation_context"]["query_and_docs"][0], Document)
def test_shaper_publishes_unknown_arg_does_not_break_pipeline():
    """
    Run a Shaper -> BM25Retriever pipeline where the Shaper publishes an output key
    (`unknown_by_retriever`) and check that the pipeline still returns documents.
    """
    docs = [Document(content="test query")]
    rename_shaper = Shaper(
        func="rename",
        inputs={"value": "query"},
        outputs=["unknown_by_retriever"],
        publish_outputs=True,
    )
    store = InMemoryDocumentStore(use_bm25=True)
    store.write_documents(docs)
    bm25 = BM25Retriever(document_store=store)

    pipe = Pipeline()
    pipe.add_node(component=rename_shaper, name="shaper", inputs=["Query"])
    pipe.add_node(component=bm25, name="retriever", inputs=["shaper"])

    result = pipe.run(query="test query")

    # The renamed value is present both in the invocation context and as a published key.
    assert result["invocation_context"]["unknown_by_retriever"] == "test query"
    assert result["unknown_by_retriever"] == "test query"
    # The retriever still ran and found the single stored document.
    assert len(result["documents"]) == 1