mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 12:37:27 +00:00
feat: More flexible routing for RouteDocuments node (#4690)
* Added warning messages for documents that are skipped by RouteDocuments. Begun adding support for new option return_remaining and List of List support for metadata value splitting. * Simplify _split_by_content_type * Added new unit test and updated _calculate_outgoing_edges * Added some TODOs and turned assert into raising an error. * Update logging messages and make new fixture in tests * Update _split_by_metadata_values to work with return_remaining * Remove unneeded code * Documentation * Add proper support for list of lists * Fix mypy errors * Added assert to make mypy happy * Update haystack/nodes/other/route_documents.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * PR comments * Remove check for logging level * make mypy happy * Update docstring of metadata_values * Removed duplicate check. Make explicit check for metadata_values --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
b06821b311
commit
8c4176bdb2
@ -1,9 +1,12 @@
|
||||
from typing import Any, List, Tuple, Dict, Optional, Union
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouteDocuments(BaseComponent):
|
||||
"""
|
||||
@ -14,7 +17,12 @@ class RouteDocuments(BaseComponent):
|
||||
# By default (split_by == "content_type"), the node has two outgoing edges.
|
||||
outgoing_edges = 2
|
||||
|
||||
def __init__(self, split_by: str = "content_type", metadata_values: Optional[List[str]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
split_by: str = "content_type",
|
||||
metadata_values: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
return_remaining: bool = False,
|
||||
):
|
||||
"""
|
||||
:param split_by: Field to split the documents by, either `"content_type"` or a metadata field name.
|
||||
If this parameter is set to `"content_type"`, the list of `Document`s will be split into a list containing
|
||||
@ -22,59 +30,117 @@ class RouteDocuments(BaseComponent):
|
||||
type `"table"` (will be routed to `"output_2"`).
|
||||
If this parameter is set to a metadata field name, you need to specify the parameter `metadata_values` as
|
||||
well.
|
||||
:param metadata_values: If the parameter `split_by` is set to a metadata field name, you need to provide a list
|
||||
of values to group the `Document`s to. `Document`s whose metadata field is equal to the first value of the
|
||||
provided list will be routed to `"output_1"`, `Document`s whose metadata field is equal to the second
|
||||
value of the provided list will be routed to `"output_2"`, etc.
|
||||
:param metadata_values: A list of values to group `Document`s by metadata field. If the parameter `split_by`
|
||||
is set to a metadata field name, you must provide a list of values (or a list of lists of values) to
|
||||
group the `Document`s by.
|
||||
If `metadata_values` is a list of strings, then the `Document`s whose metadata field is equal to the
|
||||
corresponding value will be routed to the output with the same index.
|
||||
If `metadata_values` is a list of lists, then the `Document`s whose metadata field is equal to the first
|
||||
value of the provided sublist will be routed to `"output_1"`, the `Document`s whose metadata field is equal
|
||||
to the second value of the provided sublist will be routed to `"output_2"`, and so on.
|
||||
:param return_remaining: Whether to return all remaining documents that don't match the `split_by` or
|
||||
`metadata_values` into an additional output route. This additional output route will be indexed to plus one
|
||||
of the previous last output route. For example, if there would normally be `"output_1"` and `"output_2"`
|
||||
when return_remaining is False, then when return_remaining is True the additional output route would be
|
||||
`"output_3"`.
|
||||
"""
|
||||
|
||||
if split_by != "content_type" and metadata_values is None:
|
||||
raise ValueError(
|
||||
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
||||
"to group the documents to."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.split_by = split_by
|
||||
self.metadata_values = metadata_values
|
||||
self.return_remaining = return_remaining
|
||||
|
||||
if self.split_by != "content_type":
|
||||
if self.metadata_values is None or len(self.metadata_values) == 0:
|
||||
raise ValueError(
|
||||
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
|
||||
"a list of Documents by a metadata field."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||
split_by = component_params.get("split_by", "content_type")
|
||||
metadata_values = component_params.get("metadata_values", None)
|
||||
return_remaining = component_params.get("return_remaining", False)
|
||||
|
||||
# If we split list of Documents by a metadata field, number of outgoing edges might change
|
||||
if split_by != "content_type" and metadata_values is not None:
|
||||
return len(metadata_values)
|
||||
return 2
|
||||
num_edges = len(metadata_values)
|
||||
else:
|
||||
num_edges = 2
|
||||
|
||||
if return_remaining:
|
||||
num_edges += 1
|
||||
return num_edges
|
||||
|
||||
def _split_by_content_type(self, documents: List[Document]) -> Dict[str, List[Document]]:
|
||||
mapping = {"text": "output_1", "table": "output_2"}
|
||||
split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": [], "output_3": []}
|
||||
for doc in documents:
|
||||
output_route = mapping.get(doc.content_type, "output_3")
|
||||
split_documents[output_route].append(doc)
|
||||
|
||||
if not self.return_remaining:
|
||||
# Used to avoid unnecessarily calculating other_content_types depending on logging level
|
||||
if logger.isEnabledFor(logging.WARNING) and len(split_documents["output_3"]) > 0:
|
||||
other_content_types = {x.content_type for x in split_documents["output_3"]}
|
||||
logger.warning(
|
||||
"%s document(s) were skipped because they have content type(s) %s. Only the content "
|
||||
"types 'text' and 'table' are routed.",
|
||||
len(split_documents["output_3"]),
|
||||
other_content_types,
|
||||
)
|
||||
del split_documents["output_3"]
|
||||
|
||||
return split_documents
|
||||
|
||||
def _get_metadata_values_index(self, metadata_values: Union[List[str], List[List[str]]], value: str) -> int:
|
||||
for idx, item in enumerate(metadata_values):
|
||||
if isinstance(item, list):
|
||||
if value in item:
|
||||
return idx
|
||||
else:
|
||||
if value == item:
|
||||
return idx
|
||||
return len(metadata_values)
|
||||
|
||||
def _split_by_metadata_values(
|
||||
self, metadata_values: Union[List, List[List]], documents: List[Document]
|
||||
) -> Dict[str, List[Document]]:
|
||||
# We need also to keep track of the excluded documents so we add 2 to the number of metadata_values
|
||||
output_keys = [f"output_{i}" for i in range(1, len(metadata_values) + 2)]
|
||||
split_documents: Dict[str, List[Document]] = {k: [] for k in output_keys}
|
||||
# This is the key used for excluded documents
|
||||
remaining_key = output_keys[-1]
|
||||
|
||||
for doc in documents:
|
||||
current_metadata_value = doc.meta.get(self.split_by, remaining_key)
|
||||
index = self._get_metadata_values_index(metadata_values, current_metadata_value)
|
||||
output = output_keys[index]
|
||||
split_documents[output].append(doc)
|
||||
|
||||
if not self.return_remaining:
|
||||
if len(split_documents[remaining_key]) > 0:
|
||||
logger.warning(
|
||||
"%s documents were skipped because they were either missing the metadata field '%s' or the"
|
||||
" corresponding metadata value is not included in `metadata_values`.",
|
||||
len(split_documents[remaining_key]),
|
||||
self.split_by,
|
||||
)
|
||||
del split_documents[remaining_key]
|
||||
|
||||
return split_documents
|
||||
|
||||
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
|
||||
if self.split_by == "content_type":
|
||||
split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": []}
|
||||
|
||||
for doc in documents:
|
||||
if doc.content_type == "text":
|
||||
split_documents["output_1"].append(doc)
|
||||
elif doc.content_type == "table":
|
||||
split_documents["output_2"].append(doc)
|
||||
|
||||
split_documents = self._split_by_content_type(documents)
|
||||
elif self.metadata_values:
|
||||
split_documents = self._split_by_metadata_values(self.metadata_values, documents)
|
||||
else:
|
||||
assert isinstance(self.metadata_values, list), (
|
||||
"You need to provide metadata_values if you want to split" " a list of Documents by a metadata field."
|
||||
raise ValueError(
|
||||
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
|
||||
"a list of Documents by a metadata field."
|
||||
)
|
||||
split_documents = {f"output_{i+1}": [] for i in range(len(self.metadata_values))}
|
||||
for doc in documents:
|
||||
current_metadata_value = doc.meta.get(self.split_by, None)
|
||||
# Disregard current document if it does not contain the provided metadata field
|
||||
if current_metadata_value is not None:
|
||||
try:
|
||||
index = self.metadata_values.index(current_metadata_value)
|
||||
except ValueError:
|
||||
# Disregard current document if current_metadata_value is not in the provided metadata_values
|
||||
continue
|
||||
|
||||
split_documents[f"output_{index+1}"].append(doc)
|
||||
|
||||
return split_documents, "split"
|
||||
|
||||
def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[Dict, str]: # type: ignore
|
||||
@ -86,5 +152,4 @@ class RouteDocuments(BaseComponent):
|
||||
results, _ = self.run(documents=doc_list) # type: ignore
|
||||
for key in results:
|
||||
split_documents[key].append(results[key])
|
||||
|
||||
return split_documents, "split"
|
||||
|
||||
@ -5,23 +5,52 @@ from haystack.schema import Document
|
||||
from haystack.nodes import RouteDocuments
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_content_type():
|
||||
docs = [
|
||||
@pytest.fixture
|
||||
def docs_diff_types():
|
||||
return [
|
||||
Document(content="text document", content_type="text"),
|
||||
Document(
|
||||
content=pd.DataFrame(columns=["col 1", "col 2"], data=[["row 1", "row 1"], ["row 2", "row 2"]]),
|
||||
content_type="table",
|
||||
),
|
||||
Document(content="image/path", content_type="image"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def docs_with_meta():
|
||||
return [
|
||||
Document(content="text document 1", content_type="text", meta={"meta_field": "test1"}),
|
||||
Document(content="text document 2", content_type="text", meta={"meta_field": "test2"}),
|
||||
Document(content="text document 3", content_type="text", meta={"meta_field": "test3"}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_content_type(docs_diff_types):
|
||||
route_documents = RouteDocuments()
|
||||
result, _ = route_documents.run(documents=docs)
|
||||
result, _ = route_documents.run(documents=docs_diff_types)
|
||||
assert route_documents.outgoing_edges == 2
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert "output_3" not in result
|
||||
assert result["output_1"][0].content_type == "text"
|
||||
assert result["output_2"][0].content_type == "table"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_content_type_return_remaining(docs_diff_types):
|
||||
route_documents = RouteDocuments(return_remaining=True)
|
||||
result, _ = route_documents.run(documents=docs_diff_types)
|
||||
assert route_documents.outgoing_edges == 3
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert len(result["output_3"]) == 1
|
||||
assert result["output_1"][0].content_type == "text"
|
||||
assert result["output_2"][0].content_type == "table"
|
||||
assert result["output_3"][0].content_type == "image"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_metafield(docs):
|
||||
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
|
||||
@ -30,6 +59,40 @@ def test_routedocuments_by_metafield(docs):
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert len(result["output_3"]) == 1
|
||||
assert "output_4" not in result
|
||||
assert result["output_1"][0].meta["meta_field"] == "test1"
|
||||
assert result["output_2"][0].meta["meta_field"] == "test3"
|
||||
assert result["output_3"][0].meta["meta_field"] == "test5"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_metafield_return_remaning(docs):
|
||||
route_documents = RouteDocuments(
|
||||
split_by="meta_field", metadata_values=["test1", "test3", "test5"], return_remaining=True
|
||||
)
|
||||
assert route_documents.outgoing_edges == 4
|
||||
result, _ = route_documents.run(docs)
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert len(result["output_3"]) == 1
|
||||
assert len(result["output_4"]) == 2
|
||||
assert result["output_1"][0].meta["meta_field"] == "test1"
|
||||
assert result["output_2"][0].meta["meta_field"] == "test3"
|
||||
assert result["output_3"][0].meta["meta_field"] == "test5"
|
||||
assert result["output_4"][0].meta["meta_field"] == "test2"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_routedocuments_by_metafield_list_of_lists(docs):
|
||||
route_documents = RouteDocuments(
|
||||
split_by="meta_field", metadata_values=[["test1", "test3"], "test5"], return_remaining=True
|
||||
)
|
||||
assert route_documents.outgoing_edges == 3
|
||||
result, _ = route_documents.run(docs)
|
||||
assert len(result["output_1"]) == 2
|
||||
assert len(result["output_2"]) == 1
|
||||
assert len(result["output_3"]) == 2
|
||||
assert result["output_1"][0].meta["meta_field"] == "test1"
|
||||
assert result["output_1"][1].meta["meta_field"] == "test3"
|
||||
assert result["output_2"][0].meta["meta_field"] == "test5"
|
||||
assert result["output_3"][0].meta["meta_field"] == "test2"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user