mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
398 lines
16 KiB
Python
398 lines
16 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import io
|
|
import sys
|
|
from unittest.mock import mock_open, patch
|
|
|
|
import pytest
|
|
|
|
from haystack.components.routers.file_type_router import FileTypeRouter
|
|
from haystack.components.converters import TextFileToDocument, PyPDFToDocument
|
|
from haystack.dataclasses import ByteStream
|
|
from haystack import Pipeline
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
sys.platform in ["win32", "cygwin"],
|
|
reason="Can't run on Windows Github CI, need access to registry to get mime types",
|
|
)
|
|
class TestFileTypeRouter:
|
|
def test_init(self):
|
|
"""
|
|
Test that the component initializes correctly.
|
|
"""
|
|
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
|
assert router.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
|
|
assert router._additional_mimetypes is None
|
|
|
|
router = FileTypeRouter(
|
|
mime_types=["text/plain"],
|
|
additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"},
|
|
)
|
|
assert router.mime_types == ["text/plain"]
|
|
assert router._additional_mimetypes == {
|
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"
|
|
}
|
|
|
|
def test_init_fail_wo_mime_types(self):
|
|
"""
|
|
Test that the component raises an error if no mime types are provided.
|
|
"""
|
|
with pytest.raises(ValueError):
|
|
FileTypeRouter(mime_types=[])
|
|
|
|
def test_to_dict(self):
|
|
router = FileTypeRouter(
|
|
mime_types=["text/plain", "audio/x-wav", "image/jpeg"],
|
|
additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"},
|
|
)
|
|
expected_dict = {
|
|
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
|
"init_parameters": {
|
|
"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"],
|
|
"additional_mimetypes": {
|
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"
|
|
},
|
|
},
|
|
}
|
|
assert router.to_dict() == expected_dict
|
|
|
|
def test_from_dict(self):
|
|
router_dict = {
|
|
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
|
"init_parameters": {
|
|
"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"],
|
|
"additional_mimetypes": {
|
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"
|
|
},
|
|
},
|
|
}
|
|
loaded_router = FileTypeRouter.from_dict(router_dict)
|
|
|
|
expected_router = FileTypeRouter(
|
|
mime_types=["text/plain", "audio/x-wav", "image/jpeg"],
|
|
additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"},
|
|
)
|
|
|
|
assert loaded_router.mime_types == expected_router.mime_types
|
|
assert loaded_router._additional_mimetypes == expected_router._additional_mimetypes
|
|
|
|
def test_run(self, test_files_path):
|
|
"""
|
|
Test if the component runs correctly in the simplest happy path.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "txt" / "doc_2.txt",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
test_files_path / "images" / "apple.jpg",
|
|
test_files_path / "msg" / "sample.msg",
|
|
]
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg", "application/vnd.ms-outlook"])
|
|
output = router.run(sources=file_paths)
|
|
assert output
|
|
assert len(output[r"text/plain"]) == 2
|
|
assert len(output[r"audio/x-wav"]) == 1
|
|
assert len(output[r"image/jpeg"]) == 1
|
|
assert len(output["application/vnd.ms-outlook"]) == 1
|
|
assert not output.get("unclassified")
|
|
|
|
def test_run_with_single_meta(self, test_files_path):
|
|
"""
|
|
Test if the component runs correctly when a single metadata dictionary is provided.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "txt" / "doc_2.txt",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
]
|
|
|
|
meta = {"meta_field": "meta_value"}
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav"])
|
|
output = router.run(sources=file_paths, meta=meta)
|
|
assert output
|
|
|
|
assert len(output[r"text/plain"]) == 2
|
|
assert len(output[r"audio/x-wav"]) == 1
|
|
assert not output.get("unclassified")
|
|
|
|
for elements in output.values():
|
|
for el in elements:
|
|
assert isinstance(el, ByteStream)
|
|
assert el.meta["meta_field"] == "meta_value"
|
|
|
|
def test_run_with_meta_list(self, test_files_path):
|
|
"""
|
|
Test if the component runs correctly when a list of metadata dictionaries is provided.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "images" / "apple.jpg",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
]
|
|
|
|
meta = [{"key1": "value1"}, {"key2": "value2"}, {"key3": "value3"}]
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
|
output = router.run(sources=file_paths, meta=meta)
|
|
assert output
|
|
|
|
assert len(output[r"text/plain"]) == 1
|
|
assert len(output[r"audio/x-wav"]) == 1
|
|
assert len(output[r"image/jpeg"]) == 1
|
|
assert not output.get("unclassified")
|
|
|
|
for i, elements in enumerate(output.values()):
|
|
for el in elements:
|
|
assert isinstance(el, ByteStream)
|
|
|
|
expected_meta_key, expected_meta_value = list(meta[i].items())[0]
|
|
assert el.meta[expected_meta_key] == expected_meta_value
|
|
|
|
def test_run_with_meta_and_bytestreams(self):
|
|
"""
|
|
Test if the component runs correctly with ByteStream inputs and meta.
|
|
The original meta is preserved and the new meta is added.
|
|
"""
|
|
|
|
bs = ByteStream.from_string("Haystack!", mime_type="text/plain", meta={"foo": "bar"})
|
|
|
|
meta = {"another_key": "another_value"}
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain"])
|
|
|
|
output = router.run(sources=[bs], meta=meta)
|
|
|
|
assert output
|
|
assert len(output[r"text/plain"]) == 1
|
|
assert not output.get("unclassified")
|
|
|
|
assert isinstance(output[r"text/plain"][0], ByteStream)
|
|
assert output[r"text/plain"][0].meta["foo"] == "bar"
|
|
assert output[r"text/plain"][0].meta["another_key"] == "another_value"
|
|
|
|
def test_run_fails_if_meta_length_does_not_match_sources(self, test_files_path):
|
|
"""
|
|
Test that the component raises an error if the length of the metadata list does not match the number of sources.
|
|
"""
|
|
file_paths = [test_files_path / "txt" / "doc_1.txt"]
|
|
|
|
meta = [{"key1": "value1"}, {"key2": "value2"}, {"key3": "value3"}]
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain"])
|
|
|
|
with pytest.raises(ValueError):
|
|
router.run(sources=file_paths, meta=meta)
|
|
|
|
def test_run_with_bytestreams(self, test_files_path):
|
|
"""
|
|
Test if the component runs correctly with ByteStream inputs.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "txt" / "doc_2.txt",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
test_files_path / "images" / "apple.jpg",
|
|
]
|
|
mime_types = [r"text/plain", r"text/plain", r"audio/x-wav", r"image/jpeg"]
|
|
# Convert file paths to ByteStream objects and set metadata
|
|
byte_streams = []
|
|
for path, mime_type in zip(file_paths, mime_types):
|
|
stream = ByteStream(path.read_bytes())
|
|
stream.mime_type = mime_type
|
|
byte_streams.append(stream)
|
|
|
|
# add unclassified ByteStream
|
|
bs = ByteStream(b"unclassified content")
|
|
bs.meta["content_type"] = "unknown_type"
|
|
byte_streams.append(bs)
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
|
output = router.run(sources=byte_streams)
|
|
assert output
|
|
assert len(output[r"text/plain"]) == 2
|
|
assert len(output[r"audio/x-wav"]) == 1
|
|
assert len(output[r"image/jpeg"]) == 1
|
|
assert len(output.get("unclassified")) == 1
|
|
|
|
def test_run_with_bytestreams_and_file_paths(self, test_files_path):
|
|
"""
|
|
Test if the component raises an error for unsupported data source types.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
test_files_path / "txt" / "doc_2.txt",
|
|
test_files_path / "images" / "apple.jpg",
|
|
test_files_path / "markdown" / "sample.md",
|
|
]
|
|
mime_types = [r"text/plain", r"audio/x-wav", r"text/plain", r"image/jpeg", r"text/markdown"]
|
|
byte_stream_sources = []
|
|
for path, mime_type in zip(file_paths, mime_types):
|
|
stream = ByteStream(path.read_bytes())
|
|
stream.mime_type = mime_type
|
|
byte_stream_sources.append(stream)
|
|
|
|
mixed_sources = file_paths[:2] + byte_stream_sources[2:]
|
|
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg", r"text/markdown"])
|
|
output = router.run(sources=mixed_sources)
|
|
assert len(output[r"text/plain"]) == 2
|
|
assert len(output[r"audio/x-wav"]) == 1
|
|
assert len(output[r"image/jpeg"]) == 1
|
|
assert len(output[r"text/markdown"]) == 1
|
|
|
|
def test_no_files(self):
|
|
"""
|
|
Test that the component runs correctly when no files are provided.
|
|
"""
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
|
output = router.run(sources=[])
|
|
assert not output
|
|
|
|
def test_unlisted_extensions(self, test_files_path):
|
|
"""
|
|
Test that the component correctly handles files with non specified mime types.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "audio" / "ignored.mp3",
|
|
test_files_path / "audio" / "this is the content of the document.wav",
|
|
]
|
|
router = FileTypeRouter(mime_types=[r"text/plain"])
|
|
output = router.run(sources=file_paths)
|
|
assert len(output[r"text/plain"]) == 1
|
|
assert "mp3" not in output
|
|
assert len(output.get("unclassified")) == 2
|
|
|
|
def test_no_extension(self, test_files_path):
|
|
"""
|
|
Test that the component ignores files with no extension.
|
|
"""
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "txt" / "doc_2",
|
|
test_files_path / "txt" / "doc_2.txt",
|
|
]
|
|
router = FileTypeRouter(mime_types=[r"text/plain"])
|
|
output = router.run(sources=file_paths)
|
|
assert len(output[r"text/plain"]) == 2
|
|
assert len(output.get("unclassified")) == 1
|
|
|
|
def test_unsupported_source_type(self):
|
|
"""
|
|
Test if the component raises an error for unsupported data source types.
|
|
"""
|
|
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
|
with pytest.raises(ValueError, match="Unsupported data source type:"):
|
|
router.run(sources=[{"unsupported": "type"}])
|
|
|
|
def test_invalid_regex_pattern(self):
|
|
"""
|
|
Test that the component raises a ValueError for invalid regex patterns.
|
|
"""
|
|
with pytest.raises(ValueError, match="Invalid regex pattern"):
|
|
FileTypeRouter(mime_types=["[Invalid-Regex"])
|
|
|
|
def test_regex_mime_type_matching(self, test_files_path):
|
|
"""
|
|
Test if the component correctly matches mime types using regex.
|
|
"""
|
|
router = FileTypeRouter(mime_types=[r"text\/.*", r"audio\/.*", r"image\/.*"])
|
|
file_paths = [
|
|
test_files_path / "txt" / "doc_1.txt",
|
|
test_files_path / "audio" / "the context for this answer is here.wav",
|
|
test_files_path / "images" / "apple.jpg",
|
|
]
|
|
output = router.run(sources=file_paths)
|
|
assert len(output[r"text\/.*"]) == 1, "Failed to match text file with regex"
|
|
assert len(output[r"audio\/.*"]) == 1, "Failed to match audio file with regex"
|
|
assert len(output[r"image\/.*"]) == 1, "Failed to match image file with regex"
|
|
|
|
@patch("pathlib.Path.open", new_callable=mock_open, read_data=b"Mock file content.")
|
|
def test_exact_mime_type_matching(self, mock_file):
|
|
"""
|
|
Test if the component correctly matches mime types exactly, without regex patterns.
|
|
"""
|
|
txt_stream = ByteStream(io.BytesIO(b"Text file content").read())
|
|
txt_stream.mime_type = "text/plain"
|
|
jpg_stream = ByteStream(io.BytesIO(b"JPEG file content").read())
|
|
jpg_stream.mime_type = "image/jpeg"
|
|
mp3_stream = ByteStream(io.BytesIO(b"MP3 file content").read())
|
|
mp3_stream.mime_type = "audio/mpeg"
|
|
|
|
byte_streams = [txt_stream, jpg_stream, mp3_stream]
|
|
|
|
router = FileTypeRouter(mime_types=["text/plain", "image/jpeg"])
|
|
|
|
output = router.run(sources=byte_streams)
|
|
|
|
assert len(output["text/plain"]) == 1, "Failed to match 'text/plain' MIME type exactly"
|
|
assert txt_stream in output["text/plain"], "'doc_1.txt' ByteStream not correctly classified as 'text/plain'"
|
|
|
|
assert len(output["image/jpeg"]) == 1, "Failed to match 'image/jpeg' MIME type exactly"
|
|
assert jpg_stream in output["image/jpeg"], "'apple.jpg' ByteStream not correctly classified as 'image/jpeg'"
|
|
|
|
assert len(output.get("unclassified")) == 1, "Failed to handle unclassified file types"
|
|
assert mp3_stream in output["unclassified"], "'sound.mp3' ByteStream should be unclassified but is not"
|
|
|
|
def test_serde_in_pipeline(self):
|
|
"""
|
|
Test if a pipeline containing the component can be serialized and deserialized without errors.
|
|
"""
|
|
|
|
file_type_router = FileTypeRouter(mime_types=["text/plain", "application/pdf"])
|
|
|
|
pipeline = Pipeline()
|
|
pipeline.add_component(instance=file_type_router, name="file_type_router")
|
|
|
|
pipeline_dict = pipeline.to_dict()
|
|
|
|
assert pipeline_dict == {
|
|
"metadata": {},
|
|
"max_runs_per_component": 100,
|
|
"connection_type_validation": True,
|
|
"components": {
|
|
"file_type_router": {
|
|
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
|
"init_parameters": {"mime_types": ["text/plain", "application/pdf"], "additional_mimetypes": None},
|
|
}
|
|
},
|
|
"connections": [],
|
|
}
|
|
|
|
pipeline_yaml = pipeline.dumps()
|
|
|
|
new_pipeline = Pipeline.loads(pipeline_yaml)
|
|
assert new_pipeline == pipeline
|
|
|
|
@pytest.mark.integration
|
|
def test_pipeline_with_converters(self, test_files_path):
|
|
"""
|
|
Test if the component runs correctly in a pipeline with converters and passes metadata correctly.
|
|
"""
|
|
file_type_router = FileTypeRouter(
|
|
mime_types=["text/plain", "application/pdf"],
|
|
additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"},
|
|
)
|
|
|
|
pipe = Pipeline()
|
|
pipe.add_component(instance=file_type_router, name="file_type_router")
|
|
pipe.add_component(instance=TextFileToDocument(), name="text_file_converter")
|
|
pipe.add_component(instance=PyPDFToDocument(), name="pypdf_converter")
|
|
pipe.connect("file_type_router.text/plain", "text_file_converter.sources")
|
|
pipe.connect("file_type_router.application/pdf", "pypdf_converter.sources")
|
|
|
|
file_paths = [test_files_path / "txt" / "doc_1.txt", test_files_path / "pdf" / "sample_pdf_1.pdf"]
|
|
|
|
meta = [{"meta_field_1": "meta_value_1"}, {"meta_field_2": "meta_value_2"}]
|
|
|
|
output = pipe.run(data={"file_type_router": {"sources": file_paths, "meta": meta}})
|
|
|
|
assert output["text_file_converter"]["documents"][0].meta["meta_field_1"] == "meta_value_1"
|
|
assert output["pypdf_converter"]["documents"][0].meta["meta_field_2"] == "meta_value_2"
|