mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-27 15:59:14 +00:00
fix: add option to not override results by Shaper (#4231)
* add option to shaper and support answers * remove publish restrictions on outputs * support list
This commit is contained in:
parent
262c9771f4
commit
32b2abf9d5
@ -246,6 +246,7 @@ class Shaper(BaseComponent):
|
||||
outputs: List[str],
|
||||
inputs: Optional[Dict[str, Union[List[str], str]]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
publish_outputs: Union[bool, List[str]] = True,
|
||||
):
|
||||
"""
|
||||
Initializes the Shaper component.
|
||||
@ -319,14 +320,38 @@ class Shaper(BaseComponent):
|
||||
You can use params to provide fallback values for arguments of `run` that you're not sure exist.
|
||||
So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query`
|
||||
is not passed to this node by the pipeline.
|
||||
:param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match
|
||||
:param outputs: The key to store the outputs in the invocation context. The length of the outputs must match
|
||||
the number of outputs produced by the function invoked.
|
||||
:param publish_outputs: Controls whether to publish the outputs to the pipeline's output.
|
||||
Set `True` (default value) to publishes all outputs or `False` to publish None.
|
||||
E.g. if `outputs = ["documents"]` result for `publish_outputs = True` looks like
|
||||
```python
|
||||
{
|
||||
"invocation_context": {
|
||||
"documents": [...]
|
||||
},
|
||||
"documents": [...]
|
||||
}
|
||||
```
|
||||
For `publish_outputs = False` result looks like
|
||||
```python
|
||||
{
|
||||
"invocation_context": {
|
||||
"documents": [...]
|
||||
},
|
||||
}
|
||||
```
|
||||
If you want to have finer-grained control, pass a list of the outputs you want to publish.
|
||||
"""
|
||||
super().__init__()
|
||||
self.function = REGISTERED_FUNCTIONS[func]
|
||||
self.outputs = outputs
|
||||
self.inputs = inputs or {}
|
||||
self.params = params or {}
|
||||
if isinstance(publish_outputs, bool):
|
||||
self.publish_outputs = self.outputs if publish_outputs else []
|
||||
else:
|
||||
self.publish_outputs = publish_outputs
|
||||
|
||||
def run( # type: ignore
|
||||
self,
|
||||
@ -404,7 +429,7 @@ class Shaper(BaseComponent):
|
||||
results = {}
|
||||
for output_key, output_value in zip(self.outputs, output_values):
|
||||
invocation_context[output_key] = output_value
|
||||
if output_key in ["query", "file_paths", "labels", "documents", "meta"]:
|
||||
if output_key in self.publish_outputs:
|
||||
results[output_key] = output_value
|
||||
results["invocation_context"] = invocation_context
|
||||
|
||||
|
||||
@ -3,7 +3,9 @@ import logging
|
||||
|
||||
import haystack
|
||||
from haystack import Pipeline, Document, Answer
|
||||
from haystack.document_stores.memory import InMemoryDocumentStore
|
||||
from haystack.nodes.other.shaper import Shaper
|
||||
from haystack.nodes.retriever.sparse import BM25Retriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -340,6 +342,37 @@ def test_join_documents():
|
||||
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
|
||||
assert results["documents"] == [Document(content="first | second | third")]
|
||||
|
||||
|
||||
def test_join_documents_without_publish_outputs():
|
||||
shaper = Shaper(
|
||||
func="join_documents",
|
||||
inputs={"documents": "documents"},
|
||||
params={"delimiter": " | "},
|
||||
outputs=["documents"],
|
||||
publish_outputs=False,
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
|
||||
assert "documents" not in results
|
||||
|
||||
|
||||
def test_join_documents_with_publish_outputs_as_list():
|
||||
shaper = Shaper(
|
||||
func="join_documents",
|
||||
inputs={"documents": "documents"},
|
||||
params={"delimiter": " | "},
|
||||
outputs=["documents"],
|
||||
publish_outputs=["documents"],
|
||||
)
|
||||
results, _ = shaper.run(
|
||||
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
|
||||
)
|
||||
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
|
||||
assert results["documents"] == [Document(content="first | second | third")]
|
||||
|
||||
|
||||
def test_join_documents_default_delimiter():
|
||||
@ -457,6 +490,11 @@ def test_strings_to_answers_yaml(tmp_path):
|
||||
Answer(answer="b", type="generative"),
|
||||
Answer(answer="c", type="generative"),
|
||||
]
|
||||
assert result["answers"] == [
|
||||
Answer(answer="a", type="generative"),
|
||||
Answer(answer="b", type="generative"),
|
||||
Answer(answer="c", type="generative"),
|
||||
]
|
||||
|
||||
|
||||
#
|
||||
@ -1116,3 +1154,19 @@ def test_join_query_and_documents_convert_into_documents_yaml(tmp_path):
|
||||
assert result["invocation_context"]["query_and_docs"]
|
||||
assert len(result["invocation_context"]["query_and_docs"]) == 4
|
||||
assert isinstance(result["invocation_context"]["query_and_docs"][0], Document)
|
||||
|
||||
|
||||
def test_shaper_publishes_unknown_arg_does_not_break_pipeline():
|
||||
documents = [Document(content="test query")]
|
||||
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["unknown_by_retriever"], publish_outputs=True)
|
||||
document_store = InMemoryDocumentStore(use_bm25=True)
|
||||
document_store.write_documents(documents)
|
||||
retriever = BM25Retriever(document_store=document_store)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=shaper, name="shaper", inputs=["Query"])
|
||||
pipeline.add_node(component=retriever, name="retriever", inputs=["shaper"])
|
||||
|
||||
result = pipeline.run(query="test query")
|
||||
assert result["invocation_context"]["unknown_by_retriever"] == "test query"
|
||||
assert result["unknown_by_retriever"] == "test query"
|
||||
assert len(result["documents"]) == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user