mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat: allow DocumentJoiner to accept top_k parameter in run method (#7709)
* feat: allow DocumentJoiner to accept top_k parameter in run method
* Added release note for DocumentJoiner top_k fix
This commit is contained in:
parent
482f60ec99
commit
badb05b3ab
@ -74,12 +74,14 @@ class DocumentJoiner:
|
||||
self.sort_by_score = sort_by_score
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, documents: Variadic[List[Document]]):
|
||||
def run(self, documents: Variadic[List[Document]], top_k: Optional[int] = None):
|
||||
"""
|
||||
Joins multiple lists of Documents into a single list depending on the `join_mode` parameter.
|
||||
|
||||
:param documents:
|
||||
List of lists of Documents to be merged.
|
||||
:param top_k:
|
||||
The maximum number of Documents to return. Overrides the instance's `top_k` if provided.
|
||||
|
||||
:returns:
|
||||
A dictionary with the following keys:
|
||||
@ -103,8 +105,11 @@ class DocumentJoiner:
|
||||
"score, so those with score=None were sorted as if they had a score of -infinity."
|
||||
)
|
||||
|
||||
if self.top_k:
|
||||
if top_k:
|
||||
output_documents = output_documents[:top_k]
|
||||
elif self.top_k:
|
||||
output_documents = output_documents[: self.top_k]
|
||||
|
||||
return {"documents": output_documents}
|
||||
|
||||
def _concatenate(self, document_lists):
|
||||
|
@ -0,0 +1,5 @@
|
||||
---
|
||||
|
||||
enhancements:
|
||||
- |
|
||||
The `DocumentJoiner` component's `run` method now accepts a `top_k` parameter, allowing users to specify the maximum number of documents to return at query time. This fixes issue #7702.
|
@ -115,6 +115,14 @@ class TestDocumentJoiner:
|
||||
]
|
||||
assert all(doc.id in expected_document_ids for doc in output["documents"])
|
||||
|
||||
def test_run_with_top_k_in_run_method(self):
    """Check that a `top_k` passed to `run` caps the number of joined Documents."""
    first_batch = [Document(content=text) for text in ("a", "b", "c")]
    second_batch = [Document(content=text) for text in ("d", "e", "f")]
    limit = 4
    result = DocumentJoiner().run([first_batch, second_batch], top_k=limit)
    # Six Documents go in; only `limit` of them should come back.
    assert len(result["documents"]) == limit
|
||||
|
||||
def test_sort_by_score_without_scores(self, caplog):
|
||||
joiner = DocumentJoiner()
|
||||
with caplog.at_level(logging.INFO):
|
||||
|
Loading…
x
Reference in New Issue
Block a user