diff --git a/haystack/components/joiners/document_joiner.py b/haystack/components/joiners/document_joiner.py index 9810a529f..2f6c044d1 100644 --- a/haystack/components/joiners/document_joiner.py +++ b/haystack/components/joiners/document_joiner.py @@ -74,12 +74,14 @@ class DocumentJoiner: self.sort_by_score = sort_by_score @component.output_types(documents=List[Document]) - def run(self, documents: Variadic[List[Document]]): + def run(self, documents: Variadic[List[Document]], top_k: Optional[int] = None): """ Joins multiple lists of Documents into a single list depending on the `join_mode` parameter. :param documents: List of list of Documents to be merged. + :param top_k: + The maximum number of Documents to return. Overrides the instance's `top_k` if provided. :returns: A dictionary with the following keys: @@ -103,8 +105,11 @@ class DocumentJoiner: "score, so those with score=None were sorted as if they had a score of -infinity." ) - if self.top_k: + if top_k: + output_documents = output_documents[:top_k] + elif self.top_k: output_documents = output_documents[: self.top_k] + return {"documents": output_documents} def _concatenate(self, document_lists): diff --git a/releasenotes/notes/fix-documentjoiner-topk-173141a894e5c093.yaml b/releasenotes/notes/fix-documentjoiner-topk-173141a894e5c093.yaml new file mode 100644 index 000000000..665186022 --- /dev/null +++ b/releasenotes/notes/fix-documentjoiner-topk-173141a894e5c093.yaml @@ -0,0 +1,5 @@ +--- + +enhancements: + - | + The `DocumentJoiner` component's `run` method now accepts an optional `top_k` parameter, allowing users to specify the maximum number of documents to return at query time. When provided, it overrides the instance's `top_k`. This fixes issue #7702. 
diff --git a/test/components/joiners/test_document_joiner.py b/test/components/joiners/test_document_joiner.py index 310ade4b6..48fa5b230 100644 --- a/test/components/joiners/test_document_joiner.py +++ b/test/components/joiners/test_document_joiner.py @@ -115,6 +115,14 @@ class TestDocumentJoiner: ] assert all(doc.id in expected_document_ids for doc in output["documents"]) + def test_run_with_top_k_in_run_method(self): + joiner = DocumentJoiner() + documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")] + documents_2 = [Document(content="d"), Document(content="e"), Document(content="f")] + top_k = 4 + output = joiner.run([documents_1, documents_2], top_k=top_k) + assert len(output["documents"]) == top_k + def test_sort_by_score_without_scores(self, caplog): joiner = DocumentJoiner() with caplog.at_level(logging.INFO):