Fix JoinAnswer/JoinNode (#2612)

* fix join nodes

* Update Documentation & Code Style

* fix unused import

* change arg order

* Update Documentation & Code Style

* fix kwargs check

* add warning when there is only one input node

* Update Documentation & Code Style

* fix type hint

* fix wrong import order

* Update Documentation & Code Style

* undo kwargs

* add accidentally deleted newline#

* fix type hint

* fix type hint

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-06-17 16:29:15 +02:00 committed by GitHub
parent a26c042994
commit 964e6cdafb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 10 deletions

View File

@ -23,7 +23,7 @@ This ensures that your output is in a compatible format.
## JoinDocuments ## JoinDocuments
```python ```python
class JoinDocuments(BaseComponent) class JoinDocuments(JoinNode)
``` ```
A node to join documents outputted by multiple retriever nodes. A node to join documents outputted by multiple retriever nodes.
@ -61,7 +61,7 @@ to each retriever score. This param is not compatible with the `concatenate` joi
## JoinAnswers ## JoinAnswers
```python ```python
class JoinAnswers(BaseComponent) class JoinAnswers(JoinNode)
``` ```
A node to join `Answer`s produced by multiple `Reader` nodes. A node to join `Answer`s produced by multiple `Reader` nodes.

View File

@ -2,3 +2,4 @@ from haystack.nodes.other.docs2answers import Docs2Answers
from haystack.nodes.other.join_docs import JoinDocuments from haystack.nodes.other.join_docs import JoinDocuments
from haystack.nodes.other.route_documents import RouteDocuments from haystack.nodes.other.route_documents import RouteDocuments
from haystack.nodes.other.join_answers import JoinAnswers from haystack.nodes.other.join_answers import JoinAnswers
from haystack.nodes.other.join import JoinNode

View File

@ -0,0 +1,76 @@
from abc import abstractmethod
from typing import Optional, List, Tuple, Dict, Union, Any
import warnings
from haystack import MultiLabel, Document, Answer
from haystack.nodes.base import BaseComponent
class JoinNode(BaseComponent):
def run( # type: ignore
self,
inputs: Optional[List[dict]] = None,
query: Optional[str] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[MultiLabel] = None,
documents: Optional[List[Document]] = None,
meta: Optional[dict] = None,
answers: Optional[List[Answer]] = None,
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_accumulated(inputs, top_k_join=top_k_join)
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_accumulated(
inputs=[
{
"query": query,
"file_paths": file_paths,
"labels": labels,
"documents": documents,
"meta": meta,
"answers": answers,
}
],
top_k_join=top_k_join,
)
@abstractmethod
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
pass
def run_batch( # type: ignore
self,
inputs: Optional[List[dict]] = None,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
answers: Optional[List[Answer]] = None,
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_batch_accumulated(
inputs=[
{
"queries": queries,
"file_paths": file_paths,
"labels": labels,
"documents": documents,
"meta": meta,
"params": params,
"debug": debug,
"answers": answers,
}
],
top_k_join=top_k_join
)
@abstractmethod
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
pass

View File

@ -1,10 +1,10 @@
from typing import Optional, List, Dict, Tuple from typing import Optional, List, Dict, Tuple
from haystack.schema import Answer from haystack.schema import Answer
from haystack.nodes.base import BaseComponent from haystack.nodes.other.join import JoinNode
class JoinAnswers(BaseComponent): class JoinAnswers(JoinNode):
""" """
A node to join `Answer`s produced by multiple `Reader` nodes. A node to join `Answer`s produced by multiple `Reader` nodes.
""" """
@ -40,7 +40,7 @@ class JoinAnswers(BaseComponent):
self.top_k_join = top_k_join self.top_k_join = top_k_join
self.sort_by_score = sort_by_score self.sort_by_score = sort_by_score
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
reader_results = [inp["answers"] for inp in inputs] reader_results = [inp["answers"] for inp in inputs]
if not top_k_join: if not top_k_join:
@ -61,7 +61,7 @@ class JoinAnswers(BaseComponent):
else: else:
raise ValueError(f"Invalid join_mode: {self.join_mode}") raise ValueError(f"Invalid join_mode: {self.join_mode}")
def run_batch(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore def run_batch_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
output_ans = [] output_ans = []
incoming_edges = [inp["answers"] for inp in inputs] incoming_edges = [inp["answers"] for inp in inputs]
# At each idx, we find predicted answers for the same query from different Readers # At each idx, we find predicted answers for the same query from different Readers

View File

@ -3,10 +3,10 @@ from collections import defaultdict
from typing import Optional, List from typing import Optional, List
from haystack.schema import Document from haystack.schema import Document
from haystack.nodes.base import BaseComponent from haystack.nodes.other.join import JoinNode
class JoinDocuments(BaseComponent): class JoinDocuments(JoinNode):
""" """
A node to join documents outputted by multiple retriever nodes. A node to join documents outputted by multiple retriever nodes.
@ -47,7 +47,7 @@ class JoinDocuments(BaseComponent):
self.weights = [float(i) / sum(weights) for i in weights] if weights else None self.weights = [float(i) / sum(weights) for i in weights] if weights else None
self.top_k_join = top_k_join self.top_k_join = top_k_join
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
results = [inp["documents"] for inp in inputs] results = [inp["documents"] for inp in inputs]
document_map = {doc.id: doc for result in results for doc in result} document_map = {doc.id: doc for result in results for doc in result}
@ -77,7 +77,7 @@ class JoinDocuments(BaseComponent):
return output, "output_1" return output, "output_1"
def run_batch(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
# Join single document lists # Join single document lists
if isinstance(inputs[0]["documents"][0], Document): if isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join) return self.run(inputs=inputs, top_k_join=top_k_join)