feat: Rename FileExtensionRouter to FileTypeRouter, handle ByteStream(s) (#5998)

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Vladimir Blagojevic 2023-10-10 09:14:04 +02:00 committed by GitHub
parent 07048791aa
commit 98215aec0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 38 deletions

View File

@ -0,0 +1,4 @@
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter
__all__ = ["FileTypeRouter", "MetadataRouter"]

View File

@ -5,26 +5,27 @@ from pathlib import Path
from typing import List, Union, Optional, Dict, Any from typing import List, Union, Optional, Dict, Any
from haystack.preview import component, default_from_dict, default_to_dict from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import ByteStream
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@component @component
class FileExtensionRouter: class FileTypeRouter:
""" """
A component that routes files based on their MIME types read from their file extensions. This component FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
does not read the file contents, but rather uses the file extension to determine the MIME type of the file. MIME types. For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
are determined from the provided metadata.
The FileExtensionRouter takes a list of file paths and groups them by their MIME types. The set of MIME types to consider is specified during the initialization of the component.
The list of MIME types to consider is provided during the initialization of the component.
This component is particularly useful when working with a large number of files, and you This component is invaluable when categorizing a large collection of files or data streams by their MIME
want to categorize them based on their MIME types. types and routing them to different components for further processing.
""" """
def __init__(self, mime_types: List[str]): def __init__(self, mime_types: List[str]):
""" """
Initialize the FileExtensionRouter. Initialize the FileTypeRouter.
:param mime_types: A list of file mime types to consider when routing :param mime_types: A list of file mime types to consider when routing
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]). files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
@ -48,31 +49,36 @@ class FileExtensionRouter:
return default_to_dict(self, mime_types=self.mime_types) return default_to_dict(self, mime_types=self.mime_types)
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionRouter": def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
return default_from_dict(cls, data) return default_from_dict(cls, data)
def run(self, paths: List[Union[str, Path]]): def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
""" """
Run the FileExtensionRouter. Categorizes the provided data sources by their MIME types.
This method takes the input data, iterates through the provided file paths, checks the file :param sources: A list of file paths or byte streams to categorize.
mime type of each file, and groups the file paths by their mime types. :return: A dictionary where keys are MIME types and values are lists of data sources.
:param paths: The input data containing the file paths to route.
:return: The output data containing the routed file paths.
""" """
mime_types = defaultdict(list) mime_types = defaultdict(list)
for path in paths: for source in sources:
if isinstance(path, str): if isinstance(source, str):
path = Path(path) source = Path(source)
mime_type = self.get_mime_type(path)
if mime_type in self.mime_types: if isinstance(source, Path):
mime_types[mime_type].append(path) mime_type = self.get_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.metadata.get("content_type")
else: else:
mime_types["unclassified"].append(path) raise ValueError(f"Unsupported data source type: {type(source)}")
if mime_type in self.mime_types:
mime_types[mime_type].append(source)
else:
mime_types["unclassified"].append(source)
return mime_types return mime_types

View File

@ -0,0 +1,5 @@
---
preview:
- |
Renamed `FileExtensionRouter` to `FileTypeRouter` and added `ByteStream` handling: byte streams are routed by the
`content_type` entry in their metadata. The `run` method's `paths` parameter is now called `sources`.

View File

@ -2,30 +2,31 @@ import sys
import pytest import pytest
from haystack.preview.components.routers.file_router import FileExtensionRouter from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.dataclasses import ByteStream
@pytest.mark.skipif( @pytest.mark.skipif(
sys.platform in ["win32", "cygwin"], sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access to registry to get mime types", reason="Can't run on Windows Github CI, need access to registry to get mime types",
) )
class TestFileExtensionRouter: class TestFileTypeRouter:
@pytest.mark.unit @pytest.mark.unit
def test_to_dict(self): def test_to_dict(self):
component = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "FileExtensionRouter", "type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]}, "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
} }
@pytest.mark.unit @pytest.mark.unit
def test_from_dict(self): def test_from_dict(self):
data = { data = {
"type": "FileExtensionRouter", "type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]}, "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
} }
component = FileExtensionRouter.from_dict(data) component = FileTypeRouter.from_dict(data)
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"] assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
@pytest.mark.unit @pytest.mark.unit
@ -40,21 +41,78 @@ class TestFileExtensionRouter:
preview_samples_path / "images" / "apple.jpg", preview_samples_path / "images" / "apple.jpg",
] ]
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=file_paths) output = router.run(sources=file_paths)
assert output assert output
assert len(output["text/plain"]) == 2 assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1 assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1 assert len(output["image/jpeg"]) == 1
assert not output["unclassified"] assert not output["unclassified"]
@pytest.mark.unit
def test_run_with_bytestreams(self, preview_samples_path):
    """
    Check that ByteStream inputs are grouped by the `content_type` stored in their metadata.
    """
    # Each sample file is paired with the MIME type its byte stream is tagged with.
    samples = [
        (preview_samples_path / "txt" / "doc_1.txt", "text/plain"),
        (preview_samples_path / "txt" / "doc_2.txt", "text/plain"),
        (preview_samples_path / "audio" / "the context for this answer is here.wav", "audio/x-wav"),
        (preview_samples_path / "images" / "apple.jpg", "image/jpeg"),
    ]

    byte_streams = []
    for sample_path, content_type in samples:
        tagged = ByteStream(sample_path.read_bytes())
        tagged.metadata["content_type"] = content_type
        byte_streams.append(tagged)

    # One extra stream whose content type is not among the router's MIME types.
    unknown = ByteStream(b"unclassified content")
    unknown.metadata["content_type"] = "unknown_type"
    byte_streams.append(unknown)

    router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
    output = router.run(sources=byte_streams)

    assert output
    expected_counts = {"text/plain": 2, "audio/x-wav": 1, "image/jpeg": 1}
    for mime_type, count in expected_counts.items():
        assert len(output[mime_type]) == count
    # `.get` is kept here so a missing "unclassified" group fails instead of being auto-created.
    assert len(output.get("unclassified")) == 1
@pytest.mark.unit
def test_run_with_bytestreams_and_file_paths(self, preview_samples_path):
    """
    Check that one call can route a mix of file paths and ByteStream objects.
    """
    # Each sample file is paired with the MIME type its byte stream is tagged with.
    samples = [
        (preview_samples_path / "txt" / "doc_1.txt", "text/plain"),
        (preview_samples_path / "audio" / "the context for this answer is here.wav", "audio/x-wav"),
        (preview_samples_path / "txt" / "doc_2.txt", "text/plain"),
        (preview_samples_path / "images" / "apple.jpg", "image/jpeg"),
    ]
    file_paths = [sample_path for sample_path, _ in samples]

    byte_stream_sources = []
    for sample_path, content_type in samples:
        tagged = ByteStream(sample_path.read_bytes())
        tagged.metadata["content_type"] = content_type
        byte_stream_sources.append(tagged)

    # The first two sources stay as paths; the last two are passed as byte streams.
    mixed_sources = file_paths[:2] + byte_stream_sources[2:]

    router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
    output = router.run(sources=mixed_sources)

    expected_counts = {"text/plain": 2, "audio/x-wav": 1, "image/jpeg": 1}
    for mime_type, count in expected_counts.items():
        assert len(output[mime_type]) == count
@pytest.mark.unit @pytest.mark.unit
def test_no_files(self): def test_no_files(self):
""" """
Test that the component runs correctly when no files are provided. Test that the component runs correctly when no files are provided.
""" """
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=[]) output = router.run(sources=[])
assert not output assert not output
@pytest.mark.unit @pytest.mark.unit
@ -67,8 +125,8 @@ class TestFileExtensionRouter:
preview_samples_path / "audio" / "ignored.mp3", preview_samples_path / "audio" / "ignored.mp3",
preview_samples_path / "audio" / "this is the content of the document.wav", preview_samples_path / "audio" / "this is the content of the document.wav",
] ]
router = FileExtensionRouter(mime_types=["text/plain"]) router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths) output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 1 assert len(output["text/plain"]) == 1
assert "mp3" not in output assert "mp3" not in output
assert len(output["unclassified"]) == 2 assert len(output["unclassified"]) == 2
@ -85,8 +143,8 @@ class TestFileExtensionRouter:
preview_samples_path / "txt" / "doc_2", preview_samples_path / "txt" / "doc_2",
preview_samples_path / "txt" / "doc_2.txt", preview_samples_path / "txt" / "doc_2.txt",
] ]
router = FileExtensionRouter(mime_types=["text/plain"]) router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths) output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 2 assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1 assert len(output["unclassified"]) == 1
@ -96,4 +154,4 @@ class TestFileExtensionRouter:
Test that the component handles files with unknown mime types. Test that the component handles files with unknown mime types.
""" """
with pytest.raises(ValueError, match="Unknown mime type:"): with pytest.raises(ValueError, match="Unknown mime type:"):
FileExtensionRouter(mime_types=["type_invalid"]) FileTypeRouter(mime_types=["type_invalid"])