feat: allow DocumentJoiner to accept top_k parameter in run method (#7709)

* feat: allow DocumentJoiner to accept top_k parameter in run method

* Added release note for DocumentJoiner top_k fix
This commit is contained in:
Varun Krishnan 2024-05-23 10:03:26 -04:00 committed by GitHub
parent 482f60ec99
commit badb05b3ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 2 deletions

View File

@ -74,12 +74,14 @@ class DocumentJoiner:
self.sort_by_score = sort_by_score self.sort_by_score = sort_by_score
@component.output_types(documents=List[Document]) @component.output_types(documents=List[Document])
def run(self, documents: Variadic[List[Document]]): def run(self, documents: Variadic[List[Document]], top_k: Optional[int] = None):
""" """
Joins multiple lists of Documents into a single list depending on the `join_mode` parameter. Joins multiple lists of Documents into a single list depending on the `join_mode` parameter.
:param documents: :param documents:
List of list of Documents to be merged. List of list of Documents to be merged.
:param top_k:
The maximum number of Documents to return. Overrides the instance's `top_k` if provided.
:returns: :returns:
A dictionary with the following keys: A dictionary with the following keys:
@ -103,8 +105,11 @@ class DocumentJoiner:
"score, so those with score=None were sorted as if they had a score of -infinity." "score, so those with score=None were sorted as if they had a score of -infinity."
) )
if self.top_k: if top_k:
output_documents = output_documents[:top_k]
elif self.top_k:
output_documents = output_documents[: self.top_k] output_documents = output_documents[: self.top_k]
return {"documents": output_documents} return {"documents": output_documents}
def _concatenate(self, document_lists): def _concatenate(self, document_lists):

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
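The `DocumentJoiner` component's `run` method now accepts an optional `top_k` parameter, allowing users to specify the maximum number of documents to return at query time. When provided, it takes precedence over the `top_k` configured on the component instance; otherwise the instance's value is used. This fixes issue #7702.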

View File

@ -115,6 +115,14 @@ class TestDocumentJoiner:
] ]
assert all(doc.id in expected_document_ids for doc in output["documents"]) assert all(doc.id in expected_document_ids for doc in output["documents"])
def test_run_with_top_k_in_run_method(self):
    """Passing `top_k` to `run` caps the number of joined documents returned."""
    # Two lists of three documents each: six candidates in total.
    first_batch = [Document(content=text) for text in ("a", "b", "c")]
    second_batch = [Document(content=text) for text in ("d", "e", "f")]
    limit = 4
    result = DocumentJoiner().run([first_batch, second_batch], top_k=limit)
    # The limit is below the six available documents, so exactly `limit` come back.
    assert len(result["documents"]) == limit
def test_sort_by_score_without_scores(self, caplog): def test_sort_by_score_without_scores(self, caplog):
joiner = DocumentJoiner() joiner = DocumentJoiner()
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):