haystack/test/nodes/test_route_documents.py
Sebastian 8c4176bdb2
feat: More flexible routing for RouteDocuments node (#4690)
* Added warning messages for documents that are skipped by RouteDocuments. Began adding support for the new option return_remaining and list-of-lists support for metadata value splitting.

* Simplify _split_by_content_type

* Added new unit test and updated _calculate_outgoing_edges

* Added some TODOs and turned assert into raising an error.

* Update logging messages and make new fixture in tests

* Update _split_by_metadata_values to work with return_remaining

* Remove unneeded code

* Documentation

* Add proper support for list of lists

* Fix mypy errors

* Added assert to make mypy happy

* Update haystack/nodes/other/route_documents.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* PR comments

* Remove check for logging level

* make mypy happy

* Update docstring of metadata_values

* Removed duplicate check. Make explicit check for metadata_values

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
2023-04-18 15:18:13 +02:00

99 lines
3.7 KiB
Python

import pytest
import pandas as pd
from haystack.schema import Document
from haystack.nodes import RouteDocuments
@pytest.fixture
def docs_diff_types():
    """Return three documents, one each with content type "text", "table" and "image"."""
    text_doc = Document(content="text document", content_type="text")
    # The table document carries a small two-column, two-row DataFrame as its content.
    table_content = pd.DataFrame(columns=["col 1", "col 2"], data=[["row 1", "row 1"], ["row 2", "row 2"]])
    table_doc = Document(content=table_content, content_type="table")
    image_doc = Document(content="image/path", content_type="image")
    return [text_doc, table_doc, image_doc]
@pytest.fixture
def docs_with_meta():
    """Return three text documents, each tagged with a distinct ``meta_field`` value."""
    content_and_tag = (
        ("text document 1", "test1"),
        ("text document 2", "test2"),
        ("text document 3", "test3"),
    )
    return [
        Document(content=content, content_type="text", meta={"meta_field": tag}) for content, tag in content_and_tag
    ]
@pytest.mark.unit
def test_routedocuments_by_content_type(docs_diff_types):
    """With default settings, expect two edges: text on output_1, table on output_2, and no output_3."""
    router = RouteDocuments()
    outputs, _ = router.run(documents=docs_diff_types)

    assert router.outgoing_edges == 2

    edges = ("output_1", "output_2")
    assert [len(outputs[edge]) for edge in edges] == [1, 1]
    assert "output_3" not in outputs
    assert [outputs[edge][0].content_type for edge in edges] == ["text", "table"]
@pytest.mark.unit
def test_routedocuments_by_content_type_return_remaining(docs_diff_types):
    """With ``return_remaining=True``, expect a third edge that carries the image document."""
    router = RouteDocuments(return_remaining=True)
    outputs, _ = router.run(documents=docs_diff_types)

    assert router.outgoing_edges == 3

    edges = ("output_1", "output_2", "output_3")
    assert [len(outputs[edge]) for edge in edges] == [1, 1, 1]
    assert [outputs[edge][0].content_type for edge in edges] == ["text", "table", "image"]
@pytest.mark.unit
def test_routedocuments_by_metafield(docs):
    """
    Split on ``meta_field`` with three single values and expect one matching document per edge.

    NOTE(review): the ``docs`` fixture is not defined in this file; presumably it is provided by a
    shared conftest -- confirm.
    """
    router = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
    assert router.outgoing_edges == 3

    outputs, _ = router.run(docs)

    edges = ("output_1", "output_2", "output_3")
    assert [len(outputs[edge]) for edge in edges] == [1, 1, 1]
    assert "output_4" not in outputs
    assert [outputs[edge][0].meta["meta_field"] for edge in edges] == ["test1", "test3", "test5"]
@pytest.mark.unit
def test_routedocuments_by_metafield_return_remaning(docs):
    """
    Split on ``meta_field`` with ``return_remaining=True`` and expect a fourth edge for unmatched documents.

    NOTE(review): the ``docs`` fixture is not defined in this file; presumably it is provided by a
    shared conftest -- confirm.
    """
    router = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"], return_remaining=True)
    assert router.outgoing_edges == 4

    outputs, _ = router.run(docs)

    edges = ("output_1", "output_2", "output_3", "output_4")
    assert [len(outputs[edge]) for edge in edges] == [1, 1, 1, 2]
    # Only the first document on each edge is inspected, including the two-document remainder edge.
    assert [outputs[edge][0].meta["meta_field"] for edge in edges] == ["test1", "test3", "test5", "test2"]
@pytest.mark.unit
def test_routedocuments_by_metafield_list_of_lists(docs):
    """
    Split on ``meta_field`` where one entry of ``metadata_values`` is itself a list of values.

    The grouped values ("test1", "test3") are expected on a single edge, "test5" on the next,
    and the remainder on a third edge.

    NOTE(review): the ``docs`` fixture is not defined in this file; presumably it is provided by a
    shared conftest -- confirm.
    """
    router = RouteDocuments(split_by="meta_field", metadata_values=[["test1", "test3"], "test5"], return_remaining=True)
    assert router.outgoing_edges == 3

    outputs, _ = router.run(docs)

    assert [len(outputs[edge]) for edge in ("output_1", "output_2", "output_3")] == [2, 1, 2]
    # output_1 holds exactly two documents (checked above), so this compares both of them in order.
    assert [doc.meta["meta_field"] for doc in outputs["output_1"]] == ["test1", "test3"]
    assert outputs["output_2"][0].meta["meta_field"] == "test5"
    assert outputs["output_3"][0].meta["meta_field"] == "test2"