mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-10 06:43:58 +00:00
feat: Rename FileExtensionRouter to FileTypeRouter, handle ByteStream(s) (#5998)
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
07048791aa
commit
98215aec0d
4
haystack/preview/components/routers/__init__.py
Normal file
4
haystack/preview/components/routers/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from haystack.preview.components.routers.file_type_router import FileTypeRouter
|
||||||
|
from haystack.preview.components.routers.metadata_router import MetadataRouter
|
||||||
|
|
||||||
|
__all__ = ["FileTypeRouter", "MetadataRouter"]
|
||||||
@ -5,26 +5,27 @@ from pathlib import Path
|
|||||||
from typing import List, Union, Optional, Dict, Any
|
from typing import List, Union, Optional, Dict, Any
|
||||||
|
|
||||||
from haystack.preview import component, default_from_dict, default_to_dict
|
from haystack.preview import component, default_from_dict, default_to_dict
|
||||||
|
from haystack.preview.dataclasses import ByteStream
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class FileExtensionRouter:
|
class FileTypeRouter:
|
||||||
"""
|
"""
|
||||||
A component that routes files based on their MIME types read from their file extensions. This component
|
FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
|
||||||
does not read the file contents, but rather uses the file extension to determine the MIME type of the file.
|
MIME types. For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
|
||||||
|
are determined from the provided metadata.
|
||||||
|
|
||||||
The FileExtensionRouter takes a list of file paths and groups them by their MIME types.
|
The set of MIME types to consider is specified during the initialization of the component.
|
||||||
The list of MIME types to consider is provided during the initialization of the component.
|
|
||||||
|
|
||||||
This component is particularly useful when working with a large number of files, and you
|
This component is invaluable when categorizing a large collection of files or data streams by their MIME
|
||||||
want to categorize them based on their MIME types.
|
types and routing them to different components for further processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mime_types: List[str]):
|
def __init__(self, mime_types: List[str]):
|
||||||
"""
|
"""
|
||||||
Initialize the FileExtensionRouter.
|
Initialize the FileTypeRouter.
|
||||||
|
|
||||||
:param mime_types: A list of file mime types to consider when routing
|
:param mime_types: A list of file mime types to consider when routing
|
||||||
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
|
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
|
||||||
@ -48,31 +49,36 @@ class FileExtensionRouter:
|
|||||||
return default_to_dict(self, mime_types=self.mime_types)
|
return default_to_dict(self, mime_types=self.mime_types)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionRouter":
|
def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
|
||||||
"""
|
"""
|
||||||
Deserialize this component from a dictionary.
|
Deserialize this component from a dictionary.
|
||||||
"""
|
"""
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
def run(self, paths: List[Union[str, Path]]):
|
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
|
||||||
"""
|
"""
|
||||||
Run the FileExtensionRouter.
|
Categorizes the provided data sources by their MIME types.
|
||||||
|
|
||||||
This method takes the input data, iterates through the provided file paths, checks the file
|
:param sources: A list of file paths or byte streams to categorize.
|
||||||
mime type of each file, and groups the file paths by their mime types.
|
:return: A dictionary where keys are MIME types and values are lists of data sources.
|
||||||
|
|
||||||
:param paths: The input data containing the file paths to route.
|
|
||||||
:return: The output data containing the routed file paths.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mime_types = defaultdict(list)
|
mime_types = defaultdict(list)
|
||||||
for path in paths:
|
for source in sources:
|
||||||
if isinstance(path, str):
|
if isinstance(source, str):
|
||||||
path = Path(path)
|
source = Path(source)
|
||||||
mime_type = self.get_mime_type(path)
|
|
||||||
if mime_type in self.mime_types:
|
if isinstance(source, Path):
|
||||||
mime_types[mime_type].append(path)
|
mime_type = self.get_mime_type(source)
|
||||||
|
elif isinstance(source, ByteStream):
|
||||||
|
mime_type = source.metadata.get("content_type")
|
||||||
else:
|
else:
|
||||||
mime_types["unclassified"].append(path)
|
raise ValueError(f"Unsupported data source type: {type(source)}")
|
||||||
|
|
||||||
|
if mime_type in self.mime_types:
|
||||||
|
mime_types[mime_type].append(source)
|
||||||
|
else:
|
||||||
|
mime_types["unclassified"].append(source)
|
||||||
|
|
||||||
return mime_types
|
return mime_types
|
||||||
|
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
preview:
|
||||||
|
- |
|
||||||
|
Enhanced file routing capabilities with the introduction of `ByteStream` handling, and
|
||||||
|
improved clarity by renaming the router to `FileTypeRouter`.
|
||||||
@ -2,30 +2,31 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.preview.components.routers.file_router import FileExtensionRouter
|
from haystack.preview.components.routers.file_type_router import FileTypeRouter
|
||||||
|
from haystack.preview.dataclasses import ByteStream
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform in ["win32", "cygwin"],
|
sys.platform in ["win32", "cygwin"],
|
||||||
reason="Can't run on Windows Github CI, need access to registry to get mime types",
|
reason="Can't run on Windows Github CI, need access to registry to get mime types",
|
||||||
)
|
)
|
||||||
class TestFileExtensionRouter:
|
class TestFileTypeRouter:
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
component = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "FileExtensionRouter",
|
"type": "FileTypeRouter",
|
||||||
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
|
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict(self):
|
def test_from_dict(self):
|
||||||
data = {
|
data = {
|
||||||
"type": "FileExtensionRouter",
|
"type": "FileTypeRouter",
|
||||||
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
|
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
|
||||||
}
|
}
|
||||||
component = FileExtensionRouter.from_dict(data)
|
component = FileTypeRouter.from_dict(data)
|
||||||
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
|
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -40,21 +41,78 @@ class TestFileExtensionRouter:
|
|||||||
preview_samples_path / "images" / "apple.jpg",
|
preview_samples_path / "images" / "apple.jpg",
|
||||||
]
|
]
|
||||||
|
|
||||||
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||||
output = router.run(paths=file_paths)
|
output = router.run(sources=file_paths)
|
||||||
assert output
|
assert output
|
||||||
assert len(output["text/plain"]) == 2
|
assert len(output["text/plain"]) == 2
|
||||||
assert len(output["audio/x-wav"]) == 1
|
assert len(output["audio/x-wav"]) == 1
|
||||||
assert len(output["image/jpeg"]) == 1
|
assert len(output["image/jpeg"]) == 1
|
||||||
assert not output["unclassified"]
|
assert not output["unclassified"]
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_run_with_bytestreams(self, preview_samples_path):
|
||||||
|
"""
|
||||||
|
Test if the component runs correctly with ByteStream inputs.
|
||||||
|
"""
|
||||||
|
file_paths = [
|
||||||
|
preview_samples_path / "txt" / "doc_1.txt",
|
||||||
|
preview_samples_path / "txt" / "doc_2.txt",
|
||||||
|
preview_samples_path / "audio" / "the context for this answer is here.wav",
|
||||||
|
preview_samples_path / "images" / "apple.jpg",
|
||||||
|
]
|
||||||
|
mime_types = ["text/plain", "text/plain", "audio/x-wav", "image/jpeg"]
|
||||||
|
# Convert file paths to ByteStream objects and set metadata
|
||||||
|
byte_streams = []
|
||||||
|
for path, mime_type in zip(file_paths, mime_types):
|
||||||
|
stream = ByteStream(path.read_bytes())
|
||||||
|
|
||||||
|
stream.metadata["content_type"] = mime_type
|
||||||
|
|
||||||
|
byte_streams.append(stream)
|
||||||
|
|
||||||
|
# add unclassified ByteStream
|
||||||
|
bs = ByteStream(b"unclassified content")
|
||||||
|
bs.metadata["content_type"] = "unknown_type"
|
||||||
|
byte_streams.append(bs)
|
||||||
|
|
||||||
|
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||||
|
output = router.run(sources=byte_streams)
|
||||||
|
assert output
|
||||||
|
assert len(output["text/plain"]) == 2
|
||||||
|
assert len(output["audio/x-wav"]) == 1
|
||||||
|
assert len(output["image/jpeg"]) == 1
|
||||||
|
assert len(output.get("unclassified")) == 1
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_run_with_bytestreams_and_file_paths(self, preview_samples_path):
|
||||||
|
file_paths = [
|
||||||
|
preview_samples_path / "txt" / "doc_1.txt",
|
||||||
|
preview_samples_path / "audio" / "the context for this answer is here.wav",
|
||||||
|
preview_samples_path / "txt" / "doc_2.txt",
|
||||||
|
preview_samples_path / "images" / "apple.jpg",
|
||||||
|
]
|
||||||
|
mime_types = ["text/plain", "audio/x-wav", "text/plain", "image/jpeg"]
|
||||||
|
byte_stream_sources = []
|
||||||
|
for path, mime_type in zip(file_paths, mime_types):
|
||||||
|
stream = ByteStream(path.read_bytes())
|
||||||
|
stream.metadata["content_type"] = mime_type
|
||||||
|
byte_stream_sources.append(stream)
|
||||||
|
|
||||||
|
mixed_sources = file_paths[:2] + byte_stream_sources[2:]
|
||||||
|
|
||||||
|
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||||
|
output = router.run(sources=mixed_sources)
|
||||||
|
assert len(output["text/plain"]) == 2
|
||||||
|
assert len(output["audio/x-wav"]) == 1
|
||||||
|
assert len(output["image/jpeg"]) == 1
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_no_files(self):
|
def test_no_files(self):
|
||||||
"""
|
"""
|
||||||
Test that the component runs correctly when no files are provided.
|
Test that the component runs correctly when no files are provided.
|
||||||
"""
|
"""
|
||||||
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||||
output = router.run(paths=[])
|
output = router.run(sources=[])
|
||||||
assert not output
|
assert not output
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -67,8 +125,8 @@ class TestFileExtensionRouter:
|
|||||||
preview_samples_path / "audio" / "ignored.mp3",
|
preview_samples_path / "audio" / "ignored.mp3",
|
||||||
preview_samples_path / "audio" / "this is the content of the document.wav",
|
preview_samples_path / "audio" / "this is the content of the document.wav",
|
||||||
]
|
]
|
||||||
router = FileExtensionRouter(mime_types=["text/plain"])
|
router = FileTypeRouter(mime_types=["text/plain"])
|
||||||
output = router.run(paths=file_paths)
|
output = router.run(sources=file_paths)
|
||||||
assert len(output["text/plain"]) == 1
|
assert len(output["text/plain"]) == 1
|
||||||
assert "mp3" not in output
|
assert "mp3" not in output
|
||||||
assert len(output["unclassified"]) == 2
|
assert len(output["unclassified"]) == 2
|
||||||
@ -85,8 +143,8 @@ class TestFileExtensionRouter:
|
|||||||
preview_samples_path / "txt" / "doc_2",
|
preview_samples_path / "txt" / "doc_2",
|
||||||
preview_samples_path / "txt" / "doc_2.txt",
|
preview_samples_path / "txt" / "doc_2.txt",
|
||||||
]
|
]
|
||||||
router = FileExtensionRouter(mime_types=["text/plain"])
|
router = FileTypeRouter(mime_types=["text/plain"])
|
||||||
output = router.run(paths=file_paths)
|
output = router.run(sources=file_paths)
|
||||||
assert len(output["text/plain"]) == 2
|
assert len(output["text/plain"]) == 2
|
||||||
assert len(output["unclassified"]) == 1
|
assert len(output["unclassified"]) == 1
|
||||||
|
|
||||||
@ -96,4 +154,4 @@ class TestFileExtensionRouter:
|
|||||||
Test that the component handles files with unknown mime types.
|
Test that the component handles files with unknown mime types.
|
||||||
"""
|
"""
|
||||||
with pytest.raises(ValueError, match="Unknown mime type:"):
|
with pytest.raises(ValueError, match="Unknown mime type:"):
|
||||||
FileExtensionRouter(mime_types=["type_invalid"])
|
FileTypeRouter(mime_types=["type_invalid"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user