feat: More flexible routing for RouteDocuments node (#4690)

* Added warning messages for documents that are skipped by RouteDocuments. Begun adding support for new option return_remaining and List of List support for metadata value splitting.

* Simplify _split_by_content_type

* Added new unit test and updated _calculate_outgoing_edges

* Added some TODOs and turned assert into raising an error.

* Update logging messages and make new fixture in tests

* Update _split_by_metadata_values to work with return_remaining

* Remove unneeded code

* Documentation

* Add proper support for list of lists

* Fix mypy errors

* Added assert to make mypy happy

* Update haystack/nodes/other/route_documents.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* PR comments

* Remove check for logging level

* make mypy happy

* Update docstring of metadata_values

* Removed duplicate check. Make explicit check for metadata_values

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
Sebastian 2023-04-18 15:18:13 +02:00 committed by GitHub
parent b06821b311
commit 8c4176bdb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 170 additions and 42 deletions

View File

@ -1,9 +1,12 @@
from typing import Any, List, Tuple, Dict, Optional, Union
from collections import defaultdict
import logging
from haystack.nodes.base import BaseComponent
from haystack.schema import Document
logger = logging.getLogger(__name__)
class RouteDocuments(BaseComponent):
"""
@ -14,7 +17,12 @@ class RouteDocuments(BaseComponent):
# By default (split_by == "content_type"), the node has two outgoing edges.
outgoing_edges = 2
def __init__(self, split_by: str = "content_type", metadata_values: Optional[List[str]] = None):
def __init__(
self,
split_by: str = "content_type",
metadata_values: Optional[Union[List[str], List[List[str]]]] = None,
return_remaining: bool = False,
):
"""
:param split_by: Field to split the documents by, either `"content_type"` or a metadata field name.
If this parameter is set to `"content_type"`, the list of `Document`s will be split into a list containing
@ -22,59 +30,117 @@ class RouteDocuments(BaseComponent):
type `"table"` (will be routed to `"output_2"`).
If this parameter is set to a metadata field name, you need to specify the parameter `metadata_values` as
well.
:param metadata_values: If the parameter `split_by` is set to a metadata field name, you need to provide a list
of values to group the `Document`s to. `Document`s whose metadata field is equal to the first value of the
provided list will be routed to `"output_1"`, `Document`s whose metadata field is equal to the second
value of the provided list will be routed to `"output_2"`, etc.
:param metadata_values: A list of values to group `Document`s by metadata field. If the parameter `split_by`
is set to a metadata field name, you must provide a list of values (or a list of lists of values) to
group the `Document`s by.
If `metadata_values` is a list of strings, then the `Document`s whose metadata field is equal to the
corresponding value will be routed to the output with the same index.
If `metadata_values` is a list of lists, then the `Document`s whose metadata field is equal to the first
value of the provided sublist will be routed to `"output_1"`, the `Document`s whose metadata field is equal
to the second value of the provided sublist will be routed to `"output_2"`, and so on.
:param return_remaining: Whether to return all remaining documents that don't match the `split_by` or
`metadata_values` into an additional output route. This additional output route will be indexed to plus one
of the previous last output route. For example, if there would normally be `"output_1"` and `"output_2"`
when return_remaining is False, then when return_remaining is True the additional output route would be
`"output_3"`.
"""
if split_by != "content_type" and metadata_values is None:
raise ValueError(
"If split_by is set to the name of a metadata field, you must provide metadata_values "
"to group the documents to."
)
super().__init__()
self.split_by = split_by
self.metadata_values = metadata_values
self.return_remaining = return_remaining
if self.split_by != "content_type":
if self.metadata_values is None or len(self.metadata_values) == 0:
raise ValueError(
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
"a list of Documents by a metadata field."
)
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
split_by = component_params.get("split_by", "content_type")
metadata_values = component_params.get("metadata_values", None)
return_remaining = component_params.get("return_remaining", False)
# If we split list of Documents by a metadata field, number of outgoing edges might change
if split_by != "content_type" and metadata_values is not None:
return len(metadata_values)
return 2
num_edges = len(metadata_values)
else:
num_edges = 2
if return_remaining:
num_edges += 1
return num_edges
def _split_by_content_type(self, documents: List[Document]) -> Dict[str, List[Document]]:
mapping = {"text": "output_1", "table": "output_2"}
split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": [], "output_3": []}
for doc in documents:
output_route = mapping.get(doc.content_type, "output_3")
split_documents[output_route].append(doc)
if not self.return_remaining:
# Used to avoid unnecessarily calculating other_content_types depending on logging level
if logger.isEnabledFor(logging.WARNING) and len(split_documents["output_3"]) > 0:
other_content_types = {x.content_type for x in split_documents["output_3"]}
logger.warning(
"%s document(s) were skipped because they have content type(s) %s. Only the content "
"types 'text' and 'table' are routed.",
len(split_documents["output_3"]),
other_content_types,
)
del split_documents["output_3"]
return split_documents
def _get_metadata_values_index(self, metadata_values: Union[List[str], List[List[str]]], value: str) -> int:
for idx, item in enumerate(metadata_values):
if isinstance(item, list):
if value in item:
return idx
else:
if value == item:
return idx
return len(metadata_values)
def _split_by_metadata_values(
self, metadata_values: Union[List, List[List]], documents: List[Document]
) -> Dict[str, List[Document]]:
# We need also to keep track of the excluded documents so we add 2 to the number of metadata_values
output_keys = [f"output_{i}" for i in range(1, len(metadata_values) + 2)]
split_documents: Dict[str, List[Document]] = {k: [] for k in output_keys}
# This is the key used for excluded documents
remaining_key = output_keys[-1]
for doc in documents:
current_metadata_value = doc.meta.get(self.split_by, remaining_key)
index = self._get_metadata_values_index(metadata_values, current_metadata_value)
output = output_keys[index]
split_documents[output].append(doc)
if not self.return_remaining:
if len(split_documents[remaining_key]) > 0:
logger.warning(
"%s documents were skipped because they were either missing the metadata field '%s' or the"
" corresponding metadata value is not included in `metadata_values`.",
len(split_documents[remaining_key]),
self.split_by,
)
del split_documents[remaining_key]
return split_documents
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
if self.split_by == "content_type":
split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": []}
for doc in documents:
if doc.content_type == "text":
split_documents["output_1"].append(doc)
elif doc.content_type == "table":
split_documents["output_2"].append(doc)
split_documents = self._split_by_content_type(documents)
elif self.metadata_values:
split_documents = self._split_by_metadata_values(self.metadata_values, documents)
else:
assert isinstance(self.metadata_values, list), (
"You need to provide metadata_values if you want to split" " a list of Documents by a metadata field."
raise ValueError(
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
"a list of Documents by a metadata field."
)
split_documents = {f"output_{i+1}": [] for i in range(len(self.metadata_values))}
for doc in documents:
current_metadata_value = doc.meta.get(self.split_by, None)
# Disregard current document if it does not contain the provided metadata field
if current_metadata_value is not None:
try:
index = self.metadata_values.index(current_metadata_value)
except ValueError:
# Disregard current document if current_metadata_value is not in the provided metadata_values
continue
split_documents[f"output_{index+1}"].append(doc)
return split_documents, "split"
def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[Dict, str]: # type: ignore
@ -86,5 +152,4 @@ class RouteDocuments(BaseComponent):
results, _ = self.run(documents=doc_list) # type: ignore
for key in results:
split_documents[key].append(results[key])
return split_documents, "split"

View File

@ -5,23 +5,52 @@ from haystack.schema import Document
from haystack.nodes import RouteDocuments
@pytest.mark.unit
def test_routedocuments_by_content_type():
docs = [
@pytest.fixture
def docs_diff_types():
return [
Document(content="text document", content_type="text"),
Document(
content=pd.DataFrame(columns=["col 1", "col 2"], data=[["row 1", "row 1"], ["row 2", "row 2"]]),
content_type="table",
),
Document(content="image/path", content_type="image"),
]
@pytest.fixture
def docs_with_meta():
return [
Document(content="text document 1", content_type="text", meta={"meta_field": "test1"}),
Document(content="text document 2", content_type="text", meta={"meta_field": "test2"}),
Document(content="text document 3", content_type="text", meta={"meta_field": "test3"}),
]
@pytest.mark.unit
def test_routedocuments_by_content_type(docs_diff_types):
route_documents = RouteDocuments()
result, _ = route_documents.run(documents=docs)
result, _ = route_documents.run(documents=docs_diff_types)
assert route_documents.outgoing_edges == 2
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert "output_3" not in result
assert result["output_1"][0].content_type == "text"
assert result["output_2"][0].content_type == "table"
@pytest.mark.unit
def test_routedocuments_by_content_type_return_remaining(docs_diff_types):
route_documents = RouteDocuments(return_remaining=True)
result, _ = route_documents.run(documents=docs_diff_types)
assert route_documents.outgoing_edges == 3
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert len(result["output_3"]) == 1
assert result["output_1"][0].content_type == "text"
assert result["output_2"][0].content_type == "table"
assert result["output_3"][0].content_type == "image"
@pytest.mark.unit
def test_routedocuments_by_metafield(docs):
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
@ -30,6 +59,40 @@ def test_routedocuments_by_metafield(docs):
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert len(result["output_3"]) == 1
assert "output_4" not in result
assert result["output_1"][0].meta["meta_field"] == "test1"
assert result["output_2"][0].meta["meta_field"] == "test3"
assert result["output_3"][0].meta["meta_field"] == "test5"
@pytest.mark.unit
def test_routedocuments_by_metafield_return_remaning(docs):
route_documents = RouteDocuments(
split_by="meta_field", metadata_values=["test1", "test3", "test5"], return_remaining=True
)
assert route_documents.outgoing_edges == 4
result, _ = route_documents.run(docs)
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert len(result["output_3"]) == 1
assert len(result["output_4"]) == 2
assert result["output_1"][0].meta["meta_field"] == "test1"
assert result["output_2"][0].meta["meta_field"] == "test3"
assert result["output_3"][0].meta["meta_field"] == "test5"
assert result["output_4"][0].meta["meta_field"] == "test2"
@pytest.mark.unit
def test_routedocuments_by_metafield_list_of_lists(docs):
route_documents = RouteDocuments(
split_by="meta_field", metadata_values=[["test1", "test3"], "test5"], return_remaining=True
)
assert route_documents.outgoing_edges == 3
result, _ = route_documents.run(docs)
assert len(result["output_1"]) == 2
assert len(result["output_2"]) == 1
assert len(result["output_3"]) == 2
assert result["output_1"][0].meta["meta_field"] == "test1"
assert result["output_1"][1].meta["meta_field"] == "test3"
assert result["output_2"][0].meta["meta_field"] == "test5"
assert result["output_3"][0].meta["meta_field"] == "test2"