mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-08 13:54:31 +00:00
Fix JoinAnswer/JoinNode (#2612)
* fix join nodes * Update Documentation & Code Style * fix unused import * change arg order * Update Documentation & Code Style * fix kwargs check * add warning when there is only one input node * Update Documentation & Code Style * fix type hint * fix wrong import order * Update Documentation & Code Style * undo kwargs * add accidentally deleted newline# * fix type hint * fix type hint Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
a26c042994
commit
964e6cdafb
@ -23,7 +23,7 @@ This ensures that your output is in a compatible format.
|
|||||||
## JoinDocuments
|
## JoinDocuments
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class JoinDocuments(BaseComponent)
|
class JoinDocuments(JoinNode)
|
||||||
```
|
```
|
||||||
|
|
||||||
A node to join documents outputted by multiple retriever nodes.
|
A node to join documents outputted by multiple retriever nodes.
|
||||||
@ -61,7 +61,7 @@ to each retriever score. This param is not compatible with the `concatenate` joi
|
|||||||
## JoinAnswers
|
## JoinAnswers
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class JoinAnswers(BaseComponent)
|
class JoinAnswers(JoinNode)
|
||||||
```
|
```
|
||||||
|
|
||||||
A node to join `Answer`s produced by multiple `Reader` nodes.
|
A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||||
|
|||||||
@ -2,3 +2,4 @@ from haystack.nodes.other.docs2answers import Docs2Answers
|
|||||||
from haystack.nodes.other.join_docs import JoinDocuments
|
from haystack.nodes.other.join_docs import JoinDocuments
|
||||||
from haystack.nodes.other.route_documents import RouteDocuments
|
from haystack.nodes.other.route_documents import RouteDocuments
|
||||||
from haystack.nodes.other.join_answers import JoinAnswers
|
from haystack.nodes.other.join_answers import JoinAnswers
|
||||||
|
from haystack.nodes.other.join import JoinNode
|
||||||
|
|||||||
76
haystack/nodes/other/join.py
Normal file
76
haystack/nodes/other/join.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Optional, List, Tuple, Dict, Union, Any
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from haystack import MultiLabel, Document, Answer
|
||||||
|
from haystack.nodes.base import BaseComponent
|
||||||
|
|
||||||
|
|
||||||
|
class JoinNode(BaseComponent):
|
||||||
|
def run( # type: ignore
|
||||||
|
self,
|
||||||
|
inputs: Optional[List[dict]] = None,
|
||||||
|
query: Optional[str] = None,
|
||||||
|
file_paths: Optional[List[str]] = None,
|
||||||
|
labels: Optional[MultiLabel] = None,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
meta: Optional[dict] = None,
|
||||||
|
answers: Optional[List[Answer]] = None,
|
||||||
|
top_k_join: Optional[int] = None,
|
||||||
|
) -> Tuple[Dict, str]:
|
||||||
|
if inputs:
|
||||||
|
return self.run_accumulated(inputs, top_k_join=top_k_join)
|
||||||
|
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||||
|
return self.run_accumulated(
|
||||||
|
inputs=[
|
||||||
|
{
|
||||||
|
"query": query,
|
||||||
|
"file_paths": file_paths,
|
||||||
|
"labels": labels,
|
||||||
|
"documents": documents,
|
||||||
|
"meta": meta,
|
||||||
|
"answers": answers,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
top_k_join=top_k_join,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run_batch( # type: ignore
|
||||||
|
self,
|
||||||
|
inputs: Optional[List[dict]] = None,
|
||||||
|
queries: Optional[Union[str, List[str]]] = None,
|
||||||
|
file_paths: Optional[List[str]] = None,
|
||||||
|
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
|
||||||
|
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
|
||||||
|
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
|
debug: Optional[bool] = None,
|
||||||
|
answers: Optional[List[Answer]] = None,
|
||||||
|
top_k_join: Optional[int] = None,
|
||||||
|
) -> Tuple[Dict, str]:
|
||||||
|
if inputs:
|
||||||
|
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
|
||||||
|
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||||
|
return self.run_batch_accumulated(
|
||||||
|
inputs=[
|
||||||
|
{
|
||||||
|
"queries": queries,
|
||||||
|
"file_paths": file_paths,
|
||||||
|
"labels": labels,
|
||||||
|
"documents": documents,
|
||||||
|
"meta": meta,
|
||||||
|
"params": params,
|
||||||
|
"debug": debug,
|
||||||
|
"answers": answers,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
top_k_join=top_k_join
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
|
||||||
|
pass
|
||||||
@ -1,10 +1,10 @@
|
|||||||
from typing import Optional, List, Dict, Tuple
|
from typing import Optional, List, Dict, Tuple
|
||||||
|
|
||||||
from haystack.schema import Answer
|
from haystack.schema import Answer
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.other.join import JoinNode
|
||||||
|
|
||||||
|
|
||||||
class JoinAnswers(BaseComponent):
|
class JoinAnswers(JoinNode):
|
||||||
"""
|
"""
|
||||||
A node to join `Answer`s produced by multiple `Reader` nodes.
|
A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||||
"""
|
"""
|
||||||
@ -40,7 +40,7 @@ class JoinAnswers(BaseComponent):
|
|||||||
self.top_k_join = top_k_join
|
self.top_k_join = top_k_join
|
||||||
self.sort_by_score = sort_by_score
|
self.sort_by_score = sort_by_score
|
||||||
|
|
||||||
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||||
reader_results = [inp["answers"] for inp in inputs]
|
reader_results = [inp["answers"] for inp in inputs]
|
||||||
|
|
||||||
if not top_k_join:
|
if not top_k_join:
|
||||||
@ -61,7 +61,7 @@ class JoinAnswers(BaseComponent):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid join_mode: {self.join_mode}")
|
raise ValueError(f"Invalid join_mode: {self.join_mode}")
|
||||||
|
|
||||||
def run_batch(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
def run_batch_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||||
output_ans = []
|
output_ans = []
|
||||||
incoming_edges = [inp["answers"] for inp in inputs]
|
incoming_edges = [inp["answers"] for inp in inputs]
|
||||||
# At each idx, we find predicted answers for the same query from different Readers
|
# At each idx, we find predicted answers for the same query from different Readers
|
||||||
|
|||||||
@ -3,10 +3,10 @@ from collections import defaultdict
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.other.join import JoinNode
|
||||||
|
|
||||||
|
|
||||||
class JoinDocuments(BaseComponent):
|
class JoinDocuments(JoinNode):
|
||||||
"""
|
"""
|
||||||
A node to join documents outputted by multiple retriever nodes.
|
A node to join documents outputted by multiple retriever nodes.
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ class JoinDocuments(BaseComponent):
|
|||||||
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
||||||
self.top_k_join = top_k_join
|
self.top_k_join = top_k_join
|
||||||
|
|
||||||
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||||
results = [inp["documents"] for inp in inputs]
|
results = [inp["documents"] for inp in inputs]
|
||||||
document_map = {doc.id: doc for result in results for doc in result}
|
document_map = {doc.id: doc for result in results for doc in result}
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class JoinDocuments(BaseComponent):
|
|||||||
|
|
||||||
return output, "output_1"
|
return output, "output_1"
|
||||||
|
|
||||||
def run_batch(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||||
# Join single document lists
|
# Join single document lists
|
||||||
if isinstance(inputs[0]["documents"][0], Document):
|
if isinstance(inputs[0]["documents"][0], Document):
|
||||||
return self.run(inputs=inputs, top_k_join=top_k_join)
|
return self.run(inputs=inputs, top_k_join=top_k_join)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user