Add top_k_join parameter to JoinDocuments.run (#2065)

* add top_k_join parameter to JoinDocuments.run

* test JoinDocuments concatenate with top_k_join parameter

* test two different top_k_join parameters
This commit is contained in:
Adrien Wald 2022-01-26 16:30:16 +00:00 committed by GitHub
parent 5b7e906e85
commit 2edc421a09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 3 deletions

View File

@ -39,7 +39,7 @@ class JoinDocuments(BaseComponent):
self.weights = [float(i)/sum(weights) for i in weights] if weights else None
self.top_k_join = top_k_join
def run(self, inputs: List[dict]): # type: ignore
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
if self.join_mode == "concatenate":
document_map = {}
for input_from_node in inputs:
@ -63,7 +63,10 @@ class JoinDocuments(BaseComponent):
documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
if self.top_k_join:
documents = documents[: self.top_k_join]
if top_k_join is None:
top_k_join = self.top_k_join
if top_k_join:
documents = documents[: top_k_join]
output = {"documents": documents, "labels": inputs[0].get("labels", None)}
return output, "output_1"

View File

@ -167,6 +167,17 @@ def test_join_document_pipeline(document_store_dot_product_with_docs, reader):
results = p.run(query=query)
assert len(results["documents"]) == 3
# test concatenate with top_k_join parameter
join_node = JoinDocuments(join_mode="concatenate")
p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
one_result = p.run(query=query, params={ 'Join': { 'top_k_join': 1 } })
two_results = p.run(query=query, params={ 'Join': { 'top_k_join': 2 } })
assert len(one_result["documents"]) == 1
assert len(two_results["documents"]) == 2
# test join_node with reader
join_node = JoinDocuments()
p = Pipeline()