mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
Add top_k_join parameter to JoinDocuments.run (#2065)
* add top_k_join parameter to JoinDocuments.run
* test JoinDocuments concatenate with top_k_join parameter
* test two different top_k_join parameters
This commit is contained in:
parent
5b7e906e85
commit
2edc421a09
@@ -39,7 +39,7 @@ class JoinDocuments(BaseComponent):
|
||||
self.weights = [float(i)/sum(weights) for i in weights] if weights else None
|
||||
self.top_k_join = top_k_join
|
||||
|
||||
def run(self, inputs: List[dict]): # type: ignore
|
||||
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||
if self.join_mode == "concatenate":
|
||||
document_map = {}
|
||||
for input_from_node in inputs:
|
||||
@@ -63,7 +63,10 @@ class JoinDocuments(BaseComponent):
|
||||
|
||||
documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
|
||||
|
||||
if self.top_k_join:
|
||||
documents = documents[: self.top_k_join]
|
||||
if top_k_join is None:
|
||||
top_k_join = self.top_k_join
|
||||
|
||||
if top_k_join:
|
||||
documents = documents[: top_k_join]
|
||||
output = {"documents": documents, "labels": inputs[0].get("labels", None)}
|
||||
return output, "output_1"
|
||||
|
||||
@@ -167,6 +167,17 @@ def test_join_document_pipeline(document_store_dot_product_with_docs, reader):
|
||||
results = p.run(query=query)
|
||||
assert len(results["documents"]) == 3
|
||||
|
||||
# test concatenate with top_k_join parameter
|
||||
join_node = JoinDocuments(join_mode="concatenate")
|
||||
p = Pipeline()
|
||||
p.add_node(component=es, name="R1", inputs=["Query"])
|
||||
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
||||
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
||||
one_result = p.run(query=query, params={ 'Join': { 'top_k_join': 1 } })
|
||||
two_results = p.run(query=query, params={ 'Join': { 'top_k_join': 2 } })
|
||||
assert len(one_result["documents"]) == 1
|
||||
assert len(two_results["documents"]) == 2
|
||||
|
||||
# test join_node with reader
|
||||
join_node = JoinDocuments()
|
||||
p = Pipeline()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user