mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 06:58:35 +00:00
feat: Add support for matching mime types using regex (#7303)
* feat: Add support for matching mime types using regex --------- Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
parent
38b3472bb2
commit
41dbbdb3fc
@ -1,4 +1,5 @@
|
||||
import mimetypes
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
@ -12,28 +13,39 @@ logger = logging.getLogger(__name__)
|
||||
@component
|
||||
class FileTypeRouter:
|
||||
"""
|
||||
FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
|
||||
MIME types.
|
||||
FileTypeRouter groups a list of data sources (file paths or byte streams) by their MIME types, allowing
|
||||
for flexible routing of files to different components based on their content type. It supports both exact MIME type
|
||||
matching and pattern matching using regular expressions.
|
||||
|
||||
For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
|
||||
are determined from the provided metadata.
|
||||
For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types are determined from
|
||||
the provided metadata. This enables the router to classify a diverse collection of files and data streams for
|
||||
specialized processing.
|
||||
|
||||
The set of MIME types to consider is specified during the initialization of the component.
|
||||
|
||||
This component is useful when you need to classify a large collection of files or data streams according to their
|
||||
MIME types and route them to different components for further processing.
|
||||
The router's flexibility is enhanced by the support for regex patterns in the `mime_types` parameter, allowing users
|
||||
to specify broad categories (e.g., 'audio/*' or 'text/*') or more specific types with regex patterns. This feature
|
||||
is designed to be backward compatible, treating MIME types without regex patterns as exact matches.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.components.routers import FileTypeRouter
|
||||
from pathlib import Path
|
||||
|
||||
router = FileTypeRouter(mime_types=["text/plain"])
|
||||
# For exact MIME type matching
|
||||
router = FileTypeRouter(mime_types=["text/plain", "application/pdf"])
|
||||
|
||||
print(router.run(sources=["text_file.txt", "pdf_file.pdf"]))
|
||||
# For flexible matching using regex, to handle all audio types
|
||||
router_with_regex = FileTypeRouter(mime_types=[r"audio/.*", r"text/plain"])
|
||||
|
||||
# defaultdict(<class 'list'>, {'text/plain': [PosixPath('text_file.txt')],
|
||||
# 'unclassified': [PosixPath('pdf_file.pdf')]})
|
||||
sources = [Path("file.txt"), Path("document.pdf"), Path("song.mp3")]
|
||||
print(router.run(sources=sources))
|
||||
print(router_with_regex.run(sources=sources))
|
||||
|
||||
# Expected output:
|
||||
# {'text/plain': [PosixPath('file.txt')], 'application/pdf': [PosixPath('document.pdf')], 'unclassified': [PosixPath('song.mp3')]}
|
||||
# {'audio/.*': [PosixPath('song.mp3')], 'text/plain': [PosixPath('file.txt')], 'unclassified': [PosixPath('document.pdf')]}
|
||||
```
|
||||
|
||||
:param mime_types: A list of MIME types or regex patterns to classify the incoming files or data streams.
|
||||
"""
|
||||
|
||||
def __init__(self, mime_types: List[str]):
|
||||
@ -44,11 +56,12 @@ class FileTypeRouter:
|
||||
if not mime_types:
|
||||
raise ValueError("The list of mime types cannot be empty.")
|
||||
|
||||
self.mime_type_patterns = []
|
||||
for mime_type in mime_types:
|
||||
if not self._is_valid_mime_type_format(mime_type):
|
||||
raise ValueError(
|
||||
f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
|
||||
)
|
||||
raise ValueError(f"Invalid mime type or regex pattern: '{mime_type}'.")
|
||||
pattern = re.compile(mime_type)
|
||||
self.mime_type_patterns.append(pattern)
|
||||
|
||||
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
|
||||
self.mime_types = mime_types
|
||||
@ -66,20 +79,24 @@ class FileTypeRouter:
|
||||
for source in sources:
|
||||
if isinstance(source, str):
|
||||
source = Path(source)
|
||||
|
||||
if isinstance(source, Path):
|
||||
mime_type = self._get_mime_type(source)
|
||||
elif isinstance(source, ByteStream):
|
||||
mime_type = source.meta.get("content_type")
|
||||
mime_type = source.meta.get("content_type", None)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source type: {type(source)}")
|
||||
raise ValueError(f"Unsupported data source type: {type(source).__name__}")
|
||||
|
||||
if mime_type in self.mime_types:
|
||||
mime_types[mime_type].append(source)
|
||||
else:
|
||||
matched = False
|
||||
if mime_type:
|
||||
for pattern in self.mime_type_patterns:
|
||||
if pattern.fullmatch(mime_type):
|
||||
mime_types[pattern.pattern].append(source)
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
mime_types["unclassified"].append(source)
|
||||
|
||||
return mime_types
|
||||
return dict(mime_types)
|
||||
|
||||
def _get_mime_type(self, path: Path) -> Optional[str]:
|
||||
"""
|
||||
@ -96,13 +113,17 @@ class FileTypeRouter:
|
||||
|
||||
def _is_valid_mime_type_format(self, mime_type: str) -> bool:
|
||||
"""
|
||||
Check if the provided MIME type is in valid format
|
||||
Checks if the provided MIME type string is a valid regex pattern.
|
||||
|
||||
:param mime_type: The MIME type to check.
|
||||
|
||||
:returns: `True` if the provided MIME type is a valid MIME type format, `False` otherwise.
|
||||
:param mime_type: The MIME type or regex pattern to validate.
|
||||
:raises ValueError: If the mime_type is not a valid regex pattern.
|
||||
:returns: Always True because a ValueError is raised for invalid patterns.
|
||||
"""
|
||||
return mime_type in mimetypes.types_map.values() or mime_type in self._get_custom_mime_mappings().values()
|
||||
try:
|
||||
re.compile(mime_type)
|
||||
return True
|
||||
except re.error:
|
||||
raise ValueError(f"Invalid regex pattern '{mime_type}'.")
|
||||
|
||||
@staticmethod
|
||||
def _get_custom_mime_mappings() -> Dict[str, str]:
|
||||
|
||||
@ -0,0 +1,24 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Enhanced FileTypeRouter with Regex Pattern Support for MIME Types: This introduces a significant enhancement to the `FileTypeRouter`, now featuring support for regex pattern matching for MIME types. This powerful addition allows for more granular control and flexibility in routing files based on their MIME types, enabling the handling of broad categories or specific MIME type patterns with ease. This feature is particularly beneficial for applications requiring sophisticated file classification and routing logic.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.components.routers import FileTypeRouter
|
||||
|
||||
router = FileTypeRouter(mime_types=[r"text/.*", r"application/(pdf|json)"])
|
||||
|
||||
# Example files to classify
|
||||
file_paths = [
|
||||
Path("document.pdf"),
|
||||
Path("report.json"),
|
||||
Path("notes.txt"),
|
||||
Path("image.png"),
|
||||
]
|
||||
|
||||
result = router.run(sources=file_paths)
|
||||
|
||||
for mime_type, files in result.items():
|
||||
print(f"MIME Type: {mime_type}, Files: {[str(file) for file in files]}")
|
||||
```
|
||||
@ -1,4 +1,6 @@
|
||||
import io
|
||||
import sys
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -22,13 +24,13 @@ class TestFileTypeRouter:
|
||||
test_files_path / "images" / "apple.jpg",
|
||||
]
|
||||
|
||||
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
||||
output = router.run(sources=file_paths)
|
||||
assert output
|
||||
assert len(output["text/plain"]) == 2
|
||||
assert len(output["audio/x-wav"]) == 1
|
||||
assert len(output["image/jpeg"]) == 1
|
||||
assert not output["unclassified"]
|
||||
assert len(output[r"text/plain"]) == 2
|
||||
assert len(output[r"audio/x-wav"]) == 1
|
||||
assert len(output[r"image/jpeg"]) == 1
|
||||
assert not output.get("unclassified")
|
||||
|
||||
def test_run_with_bytestreams(self, test_files_path):
|
||||
"""
|
||||
@ -40,14 +42,12 @@ class TestFileTypeRouter:
|
||||
test_files_path / "audio" / "the context for this answer is here.wav",
|
||||
test_files_path / "images" / "apple.jpg",
|
||||
]
|
||||
mime_types = ["text/plain", "text/plain", "audio/x-wav", "image/jpeg"]
|
||||
mime_types = [r"text/plain", r"text/plain", r"audio/x-wav", r"image/jpeg"]
|
||||
# Convert file paths to ByteStream objects and set metadata
|
||||
byte_streams = []
|
||||
for path, mime_type in zip(file_paths, mime_types):
|
||||
stream = ByteStream(path.read_bytes())
|
||||
|
||||
stream.meta["content_type"] = mime_type
|
||||
|
||||
byte_streams.append(stream)
|
||||
|
||||
# add unclassified ByteStream
|
||||
@ -55,15 +55,18 @@ class TestFileTypeRouter:
|
||||
bs.meta["content_type"] = "unknown_type"
|
||||
byte_streams.append(bs)
|
||||
|
||||
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
||||
output = router.run(sources=byte_streams)
|
||||
assert output
|
||||
assert len(output["text/plain"]) == 2
|
||||
assert len(output["audio/x-wav"]) == 1
|
||||
assert len(output["image/jpeg"]) == 1
|
||||
assert len(output[r"text/plain"]) == 2
|
||||
assert len(output[r"audio/x-wav"]) == 1
|
||||
assert len(output[r"image/jpeg"]) == 1
|
||||
assert len(output.get("unclassified")) == 1
|
||||
|
||||
def test_run_with_bytestreams_and_file_paths(self, test_files_path):
|
||||
"""
|
||||
Test if the component raises an error for unsupported data source types.
|
||||
"""
|
||||
file_paths = [
|
||||
test_files_path / "txt" / "doc_1.txt",
|
||||
test_files_path / "audio" / "the context for this answer is here.wav",
|
||||
@ -71,7 +74,7 @@ class TestFileTypeRouter:
|
||||
test_files_path / "images" / "apple.jpg",
|
||||
test_files_path / "markdown" / "sample.md",
|
||||
]
|
||||
mime_types = ["text/plain", "audio/x-wav", "text/plain", "image/jpeg", "text/markdown"]
|
||||
mime_types = [r"text/plain", r"audio/x-wav", r"text/plain", r"image/jpeg", r"text/markdown"]
|
||||
byte_stream_sources = []
|
||||
for path, mime_type in zip(file_paths, mime_types):
|
||||
stream = ByteStream(path.read_bytes())
|
||||
@ -80,18 +83,18 @@ class TestFileTypeRouter:
|
||||
|
||||
mixed_sources = file_paths[:2] + byte_stream_sources[2:]
|
||||
|
||||
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg", "text/markdown"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg", r"text/markdown"])
|
||||
output = router.run(sources=mixed_sources)
|
||||
assert len(output["text/plain"]) == 2
|
||||
assert len(output["audio/x-wav"]) == 1
|
||||
assert len(output["image/jpeg"]) == 1
|
||||
assert len(output["text/markdown"]) == 1
|
||||
assert len(output[r"text/plain"]) == 2
|
||||
assert len(output[r"audio/x-wav"]) == 1
|
||||
assert len(output[r"image/jpeg"]) == 1
|
||||
assert len(output[r"text/markdown"]) == 1
|
||||
|
||||
def test_no_files(self):
|
||||
"""
|
||||
Test that the component runs correctly when no files are provided.
|
||||
"""
|
||||
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
||||
output = router.run(sources=[])
|
||||
assert not output
|
||||
|
||||
@ -104,13 +107,11 @@ class TestFileTypeRouter:
|
||||
test_files_path / "audio" / "ignored.mp3",
|
||||
test_files_path / "audio" / "this is the content of the document.wav",
|
||||
]
|
||||
router = FileTypeRouter(mime_types=["text/plain"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain"])
|
||||
output = router.run(sources=file_paths)
|
||||
assert len(output["text/plain"]) == 1
|
||||
assert len(output[r"text/plain"]) == 1
|
||||
assert "mp3" not in output
|
||||
assert len(output["unclassified"]) == 2
|
||||
assert str(output["unclassified"][0]).endswith("ignored.mp3")
|
||||
assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")
|
||||
assert len(output.get("unclassified")) == 2
|
||||
|
||||
def test_no_extension(self, test_files_path):
|
||||
"""
|
||||
@ -121,14 +122,61 @@ class TestFileTypeRouter:
|
||||
test_files_path / "txt" / "doc_2",
|
||||
test_files_path / "txt" / "doc_2.txt",
|
||||
]
|
||||
router = FileTypeRouter(mime_types=["text/plain"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain"])
|
||||
output = router.run(sources=file_paths)
|
||||
assert len(output["text/plain"]) == 2
|
||||
assert len(output["unclassified"]) == 1
|
||||
assert len(output[r"text/plain"]) == 2
|
||||
assert len(output.get("unclassified")) == 1
|
||||
|
||||
def test_unknown_mime_type(self):
|
||||
def test_unsupported_source_type(self):
|
||||
"""
|
||||
Test that the component handles files with unknown mime types.
|
||||
Test if the component raises an error for unsupported data source types.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="Unknown mime type:"):
|
||||
FileTypeRouter(mime_types=["type_invalid"])
|
||||
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
|
||||
with pytest.raises(ValueError, match="Unsupported data source type:"):
|
||||
router.run(sources=[{"unsupported": "type"}])
|
||||
|
||||
def test_invalid_regex_pattern(self):
|
||||
"""
|
||||
Test that the component raises a ValueError for invalid regex patterns.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="Invalid regex pattern"):
|
||||
FileTypeRouter(mime_types=["[Invalid-Regex"])
|
||||
|
||||
def test_regex_mime_type_matching(self, test_files_path):
|
||||
"""
|
||||
Test if the component correctly matches mime types using regex.
|
||||
"""
|
||||
router = FileTypeRouter(mime_types=[r"text\/.*", r"audio\/.*", r"image\/.*"])
|
||||
file_paths = [
|
||||
test_files_path / "txt" / "doc_1.txt",
|
||||
test_files_path / "audio" / "the context for this answer is here.wav",
|
||||
test_files_path / "images" / "apple.jpg",
|
||||
]
|
||||
output = router.run(sources=file_paths)
|
||||
assert len(output[r"text\/.*"]) == 1, "Failed to match text file with regex"
|
||||
assert len(output[r"audio\/.*"]) == 1, "Failed to match audio file with regex"
|
||||
assert len(output[r"image\/.*"]) == 1, "Failed to match image file with regex"
|
||||
|
||||
@patch("pathlib.Path.open", new_callable=mock_open, read_data=b"Mock file content.")
|
||||
def test_exact_mime_type_matching(self, mock_file):
|
||||
"""
|
||||
Test if the component correctly matches mime types exactly, without regex patterns.
|
||||
"""
|
||||
txt_stream = ByteStream(io.BytesIO(b"Text file content"), meta={"content_type": "text/plain"})
|
||||
jpg_stream = ByteStream(io.BytesIO(b"JPEG file content"), meta={"content_type": "image/jpeg"})
|
||||
mp3_stream = ByteStream(io.BytesIO(b"MP3 file content"), meta={"content_type": "audio/mpeg"})
|
||||
|
||||
byte_streams = [txt_stream, jpg_stream, mp3_stream]
|
||||
|
||||
router = FileTypeRouter(mime_types=["text/plain", "image/jpeg"])
|
||||
|
||||
output = router.run(sources=byte_streams)
|
||||
|
||||
assert len(output["text/plain"]) == 1, "Failed to match 'text/plain' MIME type exactly"
|
||||
assert txt_stream in output["text/plain"], "'doc_1.txt' ByteStream not correctly classified as 'text/plain'"
|
||||
|
||||
assert len(output["image/jpeg"]) == 1, "Failed to match 'image/jpeg' MIME type exactly"
|
||||
assert jpg_stream in output["image/jpeg"], "'apple.jpg' ByteStream not correctly classified as 'image/jpeg'"
|
||||
|
||||
assert len(output.get("unclassified")) == 1, "Failed to handle unclassified file types"
|
||||
assert mp3_stream in output["unclassified"], "'sound.mp3' ByteStream should be unclassified but is not"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user