Fix JoinAnswer/JoinNode (#2612)

* fix join nodes

* Update Documentation & Code Style

* fix unused import

* change arg order

* Update Documentation & Code Style

* fix kwargs check

* add warning when there is only one input node

* Update Documentation & Code Style

* fix type hint

* fix wrong import order

* Update Documentation & Code Style

* undo kwargs

* add accidentally deleted newline

* fix type hint

* fix type hint

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-06-17 16:29:15 +02:00 committed by GitHub
parent a26c042994
commit 964e6cdafb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 10 deletions

View File

@ -23,7 +23,7 @@ This ensures that your output is in a compatible format.
## JoinDocuments
```python
class JoinDocuments(BaseComponent)
class JoinDocuments(JoinNode)
```
A node to join documents outputted by multiple retriever nodes.
@ -61,7 +61,7 @@ to each retriever score. This param is not compatible with the `concatenate` joi
## JoinAnswers
```python
class JoinAnswers(BaseComponent)
class JoinAnswers(JoinNode)
```
A node to join `Answer`s produced by multiple `Reader` nodes.

View File

@ -2,3 +2,4 @@ from haystack.nodes.other.docs2answers import Docs2Answers
from haystack.nodes.other.join_docs import JoinDocuments
from haystack.nodes.other.route_documents import RouteDocuments
from haystack.nodes.other.join_answers import JoinAnswers
from haystack.nodes.other.join import JoinNode

View File

@ -0,0 +1,76 @@
from abc import abstractmethod
from typing import Optional, List, Tuple, Dict, Union, Any
import warnings
from haystack import MultiLabel, Document, Answer
from haystack.nodes.base import BaseComponent
class JoinNode(BaseComponent):
    """
    Common base for nodes that join the outputs of several upstream nodes.

    `run` and `run_batch` accept either an accumulated list of input dicts (`inputs`) or the
    individual keyword arguments of a single input. In both cases the work is delegated to
    `run_accumulated` / `run_batch_accumulated`, which subclasses must implement.
    """

    def run(  # type: ignore
        self,
        inputs: Optional[List[dict]] = None,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        answers: Optional[List[Answer]] = None,
        top_k_join: Optional[int] = None,
    ) -> Tuple[Dict, str]:
        """
        Dispatch to `run_accumulated`.

        If `inputs` is falsy (`None` or an empty list), a warning is emitted and the remaining
        keyword arguments are packed into one dict that is passed on as the only input.

        :param inputs: Accumulated input dicts, one per incoming edge.
        :param query: Forwarded under the key `"query"` in the single-input case.
        :param file_paths: Forwarded under the key `"file_paths"` in the single-input case.
        :param labels: Forwarded under the key `"labels"` in the single-input case.
        :param documents: Forwarded under the key `"documents"` in the single-input case.
        :param meta: Forwarded under the key `"meta"` in the single-input case.
        :param answers: Forwarded under the key `"answers"` in the single-input case.
        :param top_k_join: Passed through unchanged to `run_accumulated`.
        :return: Whatever `run_accumulated` returns.
        """
        if not inputs:
            warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
            # Wrap the individual arguments so the subclass only ever sees a list of input dicts.
            single_input = {
                "query": query,
                "file_paths": file_paths,
                "labels": labels,
                "documents": documents,
                "meta": meta,
                "answers": answers,
            }
            return self.run_accumulated(inputs=[single_input], top_k_join=top_k_join)

        return self.run_accumulated(inputs, top_k_join=top_k_join)

    @abstractmethod
    def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
        """
        Join the given list of input dicts. Must be implemented by subclasses.

        :param inputs: List of input dicts to join.
        :param top_k_join: Optional value handed over from `run`.
        """

    def run_batch(  # type: ignore
        self,
        inputs: Optional[List[dict]] = None,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
        answers: Optional[List[Answer]] = None,
        top_k_join: Optional[int] = None,
    ) -> Tuple[Dict, str]:
        """
        Dispatch to `run_batch_accumulated`.

        If `inputs` is falsy (`None` or an empty list), a warning is emitted and the remaining
        keyword arguments are packed into one dict that is passed on as the only input.

        :param inputs: Accumulated input dicts, one per incoming edge.
        :param queries: Forwarded under the key `"queries"` in the single-input case.
        :param file_paths: Forwarded under the key `"file_paths"` in the single-input case.
        :param labels: Forwarded under the key `"labels"` in the single-input case.
        :param documents: Forwarded under the key `"documents"` in the single-input case.
        :param meta: Forwarded under the key `"meta"` in the single-input case.
        :param params: Forwarded under the key `"params"` in the single-input case.
        :param debug: Forwarded under the key `"debug"` in the single-input case.
        :param answers: Forwarded under the key `"answers"` in the single-input case.
        :param top_k_join: Passed through unchanged to `run_batch_accumulated`.
        :return: Whatever `run_batch_accumulated` returns.
        """
        if not inputs:
            warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
            # Wrap the individual arguments so the subclass only ever sees a list of input dicts.
            single_input = {
                "queries": queries,
                "file_paths": file_paths,
                "labels": labels,
                "documents": documents,
                "meta": meta,
                "params": params,
                "debug": debug,
                "answers": answers,
            }
            return self.run_batch_accumulated(inputs=[single_input], top_k_join=top_k_join)

        return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)

    @abstractmethod
    def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
        """
        Batch counterpart of `run_accumulated`. Must be implemented by subclasses.

        :param inputs: List of input dicts to join.
        :param top_k_join: Optional value handed over from `run_batch`.
        """

View File

@ -1,10 +1,10 @@
from typing import Optional, List, Dict, Tuple
from haystack.schema import Answer
from haystack.nodes.base import BaseComponent
from haystack.nodes.other.join import JoinNode
class JoinAnswers(BaseComponent):
class JoinAnswers(JoinNode):
"""
A node to join `Answer`s produced by multiple `Reader` nodes.
"""
@ -40,7 +40,7 @@ class JoinAnswers(BaseComponent):
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
reader_results = [inp["answers"] for inp in inputs]
if not top_k_join:
@ -61,7 +61,7 @@ class JoinAnswers(BaseComponent):
else:
raise ValueError(f"Invalid join_mode: {self.join_mode}")
def run_batch(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
def run_batch_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
output_ans = []
incoming_edges = [inp["answers"] for inp in inputs]
# At each idx, we find predicted answers for the same query from different Readers

View File

@ -3,10 +3,10 @@ from collections import defaultdict
from typing import Optional, List
from haystack.schema import Document
from haystack.nodes.base import BaseComponent
from haystack.nodes.other.join import JoinNode
class JoinDocuments(BaseComponent):
class JoinDocuments(JoinNode):
"""
A node to join documents outputted by multiple retriever nodes.
@ -47,7 +47,7 @@ class JoinDocuments(BaseComponent):
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
self.top_k_join = top_k_join
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
results = [inp["documents"] for inp in inputs]
document_map = {doc.id: doc for result in results for doc in result}
@ -77,7 +77,7 @@ class JoinDocuments(BaseComponent):
return output, "output_1"
def run_batch(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)