feat: Rename FileExtensionRouter to FileTypeRouter, handle ByteStream(s) (#5998)

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Vladimir Blagojevic 2023-10-10 09:14:04 +02:00 committed by GitHub
parent 07048791aa
commit 98215aec0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 38 deletions

View File

@ -0,0 +1,4 @@
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter
__all__ = ["FileTypeRouter", "MetadataRouter"]

View File

@ -5,26 +5,27 @@ from pathlib import Path
from typing import List, Union, Optional, Dict, Any
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import ByteStream
logger = logging.getLogger(__name__)
@component
class FileExtensionRouter:
class FileTypeRouter:
"""
A component that routes files based on their MIME types read from their file extensions. This component
does not read the file contents, but rather uses the file extension to determine the MIME type of the file.
FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
MIME types. For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
are determined from the provided metadata.
The FileExtensionRouter takes a list of file paths and groups them by their MIME types.
The list of MIME types to consider is provided during the initialization of the component.
The set of MIME types to consider is specified during the initialization of the component.
This component is particularly useful when working with a large number of files, and you
want to categorize them based on their MIME types.
This component is invaluable when categorizing a large collection of files or data streams by their MIME
types and routing them to different components for further processing.
"""
def __init__(self, mime_types: List[str]):
"""
Initialize the FileExtensionRouter.
Initialize the FileTypeRouter.
:param mime_types: A list of file mime types to consider when routing
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
@ -48,31 +49,36 @@ class FileExtensionRouter:
return default_to_dict(self, mime_types=self.mime_types)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionRouter":
def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
"""
Deserialize this component from a dictionary.

:param data: The dictionary to deserialize from; it is passed unchanged to `default_from_dict`.
:return: The component instance returned by `default_from_dict`.
"""
return default_from_dict(cls, data)
def run(self, paths: List[Union[str, Path]]):
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
"""
Run the FileExtensionRouter.
Categorizes the provided data sources by their MIME types.
This method takes the input data, iterates through the provided file paths, checks the file
mime type of each file, and groups the file paths by their mime types.
:param paths: The input data containing the file paths to route.
:return: The output data containing the routed file paths.
:param sources: A list of file paths or byte streams to categorize. String paths are converted to
`Path` objects before routing; for `ByteStream` objects the MIME type is read from the
`content_type` entry of their metadata.
:return: A dictionary where keys are MIME types and values are lists of data sources. Sources whose
MIME type is not one of the MIME types configured on this component are collected under the
`"unclassified"` key.
:raises ValueError: If a source is neither a `str`, a `Path`, nor a `ByteStream`.
"""
mime_types = defaultdict(list)
for path in paths:
if isinstance(path, str):
path = Path(path)
mime_type = self.get_mime_type(path)
if mime_type in self.mime_types:
mime_types[mime_type].append(path)
for source in sources:
if isinstance(source, str):
source = Path(source)
if isinstance(source, Path):
mime_type = self.get_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.metadata.get("content_type")
else:
mime_types["unclassified"].append(path)
raise ValueError(f"Unsupported data source type: {type(source)}")
if mime_type in self.mime_types:
mime_types[mime_type].append(source)
else:
mime_types["unclassified"].append(source)
return mime_types

View File

@ -0,0 +1,5 @@
---
preview:
- |
Renamed `FileExtensionRouter` to `FileTypeRouter` and added support for `ByteStream` inputs,
which are routed by the `content_type` entry in their metadata.

View File

@ -2,30 +2,31 @@ import sys
import pytest
from haystack.preview.components.routers.file_router import FileExtensionRouter
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.dataclasses import ByteStream
@pytest.mark.skipif(
sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionRouter:
class TestFileTypeRouter:
@pytest.mark.unit
def test_to_dict(self):
component = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
data = component.to_dict()
assert data == {
"type": "FileExtensionRouter",
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "FileExtensionRouter",
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
component = FileExtensionRouter.from_dict(data)
component = FileTypeRouter.from_dict(data)
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
@pytest.mark.unit
@ -40,21 +41,78 @@ class TestFileExtensionRouter:
preview_samples_path / "images" / "apple.jpg",
]
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=file_paths)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert not output["unclassified"]
@pytest.mark.unit
def test_run_with_bytestreams(self, preview_samples_path):
    """
    Test if the component runs correctly with ByteStream inputs.

    Each sample file is wrapped in a ByteStream whose `content_type` metadata entry carries the
    MIME type; one extra stream with a MIME type the router was not configured with must end up
    in the "unclassified" bucket.
    """
    # Pair every sample file with the MIME type stored in its ByteStream metadata.
    samples = [
        (preview_samples_path / "txt" / "doc_1.txt", "text/plain"),
        (preview_samples_path / "txt" / "doc_2.txt", "text/plain"),
        (preview_samples_path / "audio" / "the context for this answer is here.wav", "audio/x-wav"),
        (preview_samples_path / "images" / "apple.jpg", "image/jpeg"),
    ]

    # Convert file paths to ByteStream objects and set metadata
    byte_streams = []
    for path, mime_type in samples:
        stream = ByteStream(path.read_bytes())
        stream.metadata["content_type"] = mime_type
        byte_streams.append(stream)

    # add unclassified ByteStream
    bs = ByteStream(b"unclassified content")
    bs.metadata["content_type"] = "unknown_type"
    byte_streams.append(bs)

    router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
    output = router.run(sources=byte_streams)

    assert output
    assert len(output["text/plain"]) == 2
    assert len(output["audio/x-wav"]) == 1
    assert len(output["image/jpeg"]) == 1
    # Index like the other tests in this class do: `len(output.get(...))` would raise a
    # TypeError instead of failing the assertion if the key were missing.
    assert len(output["unclassified"]) == 1
@pytest.mark.unit
def test_run_with_bytestreams_and_file_paths(self, preview_samples_path):
    """
    Test that the component routes a mix of file paths and ByteStream objects in a single call.
    """
    samples = [
        (preview_samples_path / "txt" / "doc_1.txt", "text/plain"),
        (preview_samples_path / "audio" / "the context for this answer is here.wav", "audio/x-wav"),
        (preview_samples_path / "txt" / "doc_2.txt", "text/plain"),
        (preview_samples_path / "images" / "apple.jpg", "image/jpeg"),
    ]

    # Wrap every sample in a ByteStream that carries its MIME type in the metadata.
    byte_stream_sources = []
    for sample_path, content_type in samples:
        stream = ByteStream(sample_path.read_bytes())
        stream.metadata["content_type"] = content_type
        byte_stream_sources.append(stream)

    # The first two samples are passed as plain paths, the last two as byte streams.
    path_sources = [sample_path for sample_path, _ in samples[:2]]
    mixed_sources = path_sources + byte_stream_sources[2:]

    router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
    output = router.run(sources=mixed_sources)

    assert len(output["text/plain"]) == 2
    assert len(output["audio/x-wav"]) == 1
    assert len(output["image/jpeg"]) == 1
@pytest.mark.unit
def test_no_files(self):
"""
Test that the component runs correctly when no files are provided.
"""
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=[])
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=[])
assert not output
@pytest.mark.unit
@ -67,8 +125,8 @@ class TestFileExtensionRouter:
preview_samples_path / "audio" / "ignored.mp3",
preview_samples_path / "audio" / "this is the content of the document.wav",
]
router = FileExtensionRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 1
assert "mp3" not in output
assert len(output["unclassified"]) == 2
@ -85,8 +143,8 @@ class TestFileExtensionRouter:
preview_samples_path / "txt" / "doc_2",
preview_samples_path / "txt" / "doc_2.txt",
]
router = FileExtensionRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1
@ -96,4 +154,4 @@ class TestFileExtensionRouter:
Test that the component handles files with unknown mime types.
"""
with pytest.raises(ValueError, match="Unknown mime type:"):
FileExtensionRouter(mime_types=["type_invalid"])
FileTypeRouter(mime_types=["type_invalid"])