mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 02:39:51 +00:00
Fix JoinAnswer/JoinNode (#2612)
* fix join nodes
* Update Documentation & Code Style
* fix unused import
* change arg order
* Update Documentation & Code Style
* fix kwargs check
* add warning when there is only one input node
* Update Documentation & Code Style
* fix type hint
* fix wrong import order
* Update Documentation & Code Style
* undo kwargs
* add accidentally deleted newline
* fix type hint
* fix type hint

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
a26c042994
commit
964e6cdafb
@ -23,7 +23,7 @@ This ensures that your output is in a compatible format.
|
||||
## JoinDocuments
|
||||
|
||||
```python
|
||||
class JoinDocuments(BaseComponent)
|
||||
class JoinDocuments(JoinNode)
|
||||
```
|
||||
|
||||
A node to join documents outputted by multiple retriever nodes.
|
||||
@ -61,7 +61,7 @@ to each retriever score. This param is not compatible with the `concatenate` joi
|
||||
## JoinAnswers
|
||||
|
||||
```python
|
||||
class JoinAnswers(BaseComponent)
|
||||
class JoinAnswers(JoinNode)
|
||||
```
|
||||
|
||||
A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||
|
||||
@ -2,3 +2,4 @@ from haystack.nodes.other.docs2answers import Docs2Answers
|
||||
from haystack.nodes.other.join_docs import JoinDocuments
|
||||
from haystack.nodes.other.route_documents import RouteDocuments
|
||||
from haystack.nodes.other.join_answers import JoinAnswers
|
||||
from haystack.nodes.other.join import JoinNode
|
||||
|
||||
76
haystack/nodes/other/join.py
Normal file
76
haystack/nodes/other/join.py
Normal file
@ -0,0 +1,76 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, List, Tuple, Dict, Union, Any
|
||||
import warnings
|
||||
|
||||
from haystack import MultiLabel, Document, Answer
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
class JoinNode(BaseComponent):
    """
    Common base for nodes that join the outputs of several upstream nodes.

    `run` and `run_batch` normalise their arguments into a list of input dicts
    and hand that list to `run_accumulated` / `run_batch_accumulated`, which
    subclasses must implement.
    """

    def run(  # type: ignore
        self,
        inputs: Optional[List[dict]] = None,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
        answers: Optional[List[Answer]] = None,
        top_k_join: Optional[int] = None,
    ) -> Tuple[Dict, str]:
        """
        Join the given inputs by delegating to `run_accumulated`.

        :param inputs: One dict per incoming node. When this is missing or empty,
                       the remaining keyword arguments are packed into a single
                       input dict instead and a warning is emitted.
        :param query: Used only for the single-input fallback.
        :param file_paths: Used only for the single-input fallback.
        :param labels: Used only for the single-input fallback.
        :param documents: Used only for the single-input fallback.
        :param meta: Used only for the single-input fallback.
        :param answers: Used only for the single-input fallback.
        :param top_k_join: Forwarded unchanged to `run_accumulated`.
        :return: Whatever `run_accumulated` returns.
        """
        if not inputs:
            # No accumulated inputs were passed: build the one-element list
            # ourselves from the individual keyword arguments.
            warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
            lone_input = {
                "query": query,
                "file_paths": file_paths,
                "labels": labels,
                "documents": documents,
                "meta": meta,
                "answers": answers,
            }
            return self.run_accumulated(inputs=[lone_input], top_k_join=top_k_join)

        return self.run_accumulated(inputs, top_k_join=top_k_join)

    @abstractmethod
    def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
        """
        Join a list of input dicts. Must be implemented by subclasses.

        :param inputs: The list of input dicts assembled by `run`.
        :param top_k_join: The value that was passed to `run`.
        """

    def run_batch(  # type: ignore
        self,
        inputs: Optional[List[dict]] = None,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
        answers: Optional[List[Answer]] = None,
        top_k_join: Optional[int] = None,
    ) -> Tuple[Dict, str]:
        """
        Batch counterpart of `run`; delegates to `run_batch_accumulated`.

        :param inputs: One dict per incoming node. When this is missing or empty,
                       the remaining keyword arguments are packed into a single
                       input dict instead and a warning is emitted.
        :param queries: Used only for the single-input fallback.
        :param file_paths: Used only for the single-input fallback.
        :param labels: Used only for the single-input fallback.
        :param documents: Used only for the single-input fallback.
        :param meta: Used only for the single-input fallback.
        :param params: Used only for the single-input fallback.
        :param debug: Used only for the single-input fallback.
        :param answers: Used only for the single-input fallback.
        :param top_k_join: Forwarded unchanged to `run_batch_accumulated`.
        :return: Whatever `run_batch_accumulated` returns.
        """
        if not inputs:
            # Same fallback as in `run`, but with the batch-style keys.
            warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
            lone_input = {
                "queries": queries,
                "file_paths": file_paths,
                "labels": labels,
                "documents": documents,
                "meta": meta,
                "params": params,
                "debug": debug,
                "answers": answers,
            }
            return self.run_batch_accumulated(inputs=[lone_input], top_k_join=top_k_join)

        return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)

    @abstractmethod
    def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
        """
        Join a list of batched input dicts. Must be implemented by subclasses.

        :param inputs: The list of input dicts assembled by `run_batch`.
        :param top_k_join: The value that was passed to `run_batch`.
        """
|
||||
@ -1,10 +1,10 @@
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
from haystack.schema import Answer
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.other.join import JoinNode
|
||||
|
||||
|
||||
class JoinAnswers(BaseComponent):
|
||||
class JoinAnswers(JoinNode):
|
||||
"""
|
||||
A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||
"""
|
||||
@ -40,7 +40,7 @@ class JoinAnswers(BaseComponent):
|
||||
self.top_k_join = top_k_join
|
||||
self.sort_by_score = sort_by_score
|
||||
|
||||
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
reader_results = [inp["answers"] for inp in inputs]
|
||||
|
||||
if not top_k_join:
|
||||
@ -61,7 +61,7 @@ class JoinAnswers(BaseComponent):
|
||||
else:
|
||||
raise ValueError(f"Invalid join_mode: {self.join_mode}")
|
||||
|
||||
def run_batch(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
def run_batch_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
output_ans = []
|
||||
incoming_edges = [inp["answers"] for inp in inputs]
|
||||
# At each idx, we find predicted answers for the same query from different Readers
|
||||
|
||||
@ -3,10 +3,10 @@ from collections import defaultdict
|
||||
from typing import Optional, List
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.other.join import JoinNode
|
||||
|
||||
|
||||
class JoinDocuments(BaseComponent):
|
||||
class JoinDocuments(JoinNode):
|
||||
"""
|
||||
A node to join documents outputted by multiple retriever nodes.
|
||||
|
||||
@ -47,7 +47,7 @@ class JoinDocuments(BaseComponent):
|
||||
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
||||
self.top_k_join = top_k_join
|
||||
|
||||
def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||
results = [inp["documents"] for inp in inputs]
|
||||
document_map = {doc.id: doc for result in results for doc in result}
|
||||
|
||||
@ -77,7 +77,7 @@ class JoinDocuments(BaseComponent):
|
||||
|
||||
return output, "output_1"
|
||||
|
||||
def run_batch(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
|
||||
# Join single document lists
|
||||
if isinstance(inputs[0]["documents"][0], Document):
|
||||
return self.run(inputs=inputs, top_k_join=top_k_join)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user