feat: Add support for matching mime types using regex (#7303)

* feat: Add support for matching mime types using regex
---------

Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
Yudhajit Sinha 2024-03-11 19:28:08 +05:30 committed by GitHub
parent 38b3472bb2
commit 41dbbdb3fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 151 additions and 58 deletions

View File

@ -1,4 +1,5 @@
import mimetypes
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
@ -12,28 +13,39 @@ logger = logging.getLogger(__name__)
@component
class FileTypeRouter:
"""
FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
MIME types.
FileTypeRouter groups a list of data sources (file paths or byte streams) by their MIME types, allowing
for flexible routing of files to different components based on their content type. It supports both exact MIME type
matching and pattern matching using regular expressions.
For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
are determined from the provided metadata.
For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types are determined from
the provided metadata. This enables the router to classify a diverse collection of files and data streams for
specialized processing.
The set of MIME types to consider is specified during the initialization of the component.
This component is useful when you need to classify a large collection of files or data streams according to their
MIME types and route them to different components for further processing.
The router's flexibility is enhanced by the support for regex patterns in the `mime_types` parameter, allowing users
to specify broad categories (e.g., 'audio/*' or 'text/*') or more specific types with regex patterns. This feature
is designed to be backward compatible, treating MIME types without regex patterns as exact matches.
Usage example:
```python
from haystack.components.routers import FileTypeRouter
from pathlib import Path
router = FileTypeRouter(mime_types=["text/plain"])
# For exact MIME type matching
router = FileTypeRouter(mime_types=["text/plain", "application/pdf"])
print(router.run(sources=["text_file.txt", "pdf_file.pdf"]))
# For flexible matching using regex, to handle all audio types
router_with_regex = FileTypeRouter(mime_types=[r"audio/.*", r"text/plain"])
# defaultdict(<class 'list'>, {'text/plain': [PosixPath('text_file.txt')],
# 'unclassified': [PosixPath('pdf_file.pdf')]})
sources = [Path("file.txt"), Path("document.pdf"), Path("song.mp3")]
print(router.run(sources=sources))
print(router_with_regex.run(sources=sources))
# Expected output:
# {'text/plain': [PosixPath('file.txt')], 'application/pdf': [PosixPath('document.pdf')], 'unclassified': [PosixPath('song.mp3')]}
# {'audio/.*': [PosixPath('song.mp3')], 'text/plain': [PosixPath('file.txt')], 'unclassified': [PosixPath('document.pdf')]}
```
:param mime_types: A list of MIME types or regex patterns to classify the incoming files or data streams.
"""
def __init__(self, mime_types: List[str]):
@ -44,11 +56,12 @@ class FileTypeRouter:
if not mime_types:
raise ValueError("The list of mime types cannot be empty.")
self.mime_type_patterns = []
for mime_type in mime_types:
if not self._is_valid_mime_type_format(mime_type):
raise ValueError(
f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
)
raise ValueError(f"Invalid mime type or regex pattern: '{mime_type}'.")
pattern = re.compile(mime_type)
self.mime_type_patterns.append(pattern)
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types
@ -66,20 +79,24 @@ class FileTypeRouter:
for source in sources:
if isinstance(source, str):
source = Path(source)
if isinstance(source, Path):
mime_type = self._get_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.meta.get("content_type")
mime_type = source.meta.get("content_type", None)
else:
raise ValueError(f"Unsupported data source type: {type(source)}")
raise ValueError(f"Unsupported data source type: {type(source).__name__}")
if mime_type in self.mime_types:
mime_types[mime_type].append(source)
else:
matched = False
if mime_type:
for pattern in self.mime_type_patterns:
if pattern.fullmatch(mime_type):
mime_types[pattern.pattern].append(source)
matched = True
break
if not matched:
mime_types["unclassified"].append(source)
return mime_types
return dict(mime_types)
def _get_mime_type(self, path: Path) -> Optional[str]:
"""
@ -96,13 +113,17 @@ class FileTypeRouter:
def _is_valid_mime_type_format(self, mime_type: str) -> bool:
"""
Check if the provided MIME type is in valid format
Checks if the provided MIME type string is a valid regex pattern.
:param mime_type: The MIME type to check.
:returns: `True` if the provided MIME type is a valid MIME type format, `False` otherwise.
:param mime_type: The MIME type or regex pattern to validate.
:raises ValueError: If the mime_type is not a valid regex pattern.
:returns: Always True because a ValueError is raised for invalid patterns.
"""
return mime_type in mimetypes.types_map.values() or mime_type in self._get_custom_mime_mappings().values()
try:
re.compile(mime_type)
return True
except re.error:
raise ValueError(f"Invalid regex pattern '{mime_type}'.")
@staticmethod
def _get_custom_mime_mappings() -> Dict[str, str]:

View File

@ -0,0 +1,24 @@
---
enhancements:
- |
Enhanced FileTypeRouter with Regex Pattern Support for MIME Types: This introduces a significant enhancement to the `FileTypeRouter`, now featuring support for regex pattern matching for MIME types. This powerful addition allows for more granular control and flexibility in routing files based on their MIME types, enabling the handling of broad categories or specific MIME type patterns with ease. This feature is particularly beneficial for applications requiring sophisticated file classification and routing logic.
Usage example:
```python
from haystack.components.routers import FileTypeRouter
router = FileTypeRouter(mime_types=[r"text/.*", r"application/(pdf|json)"])
# Example files to classify
file_paths = [
Path("document.pdf"),
Path("report.json"),
Path("notes.txt"),
Path("image.png"),
]
result = router.run(sources=file_paths)
for mime_type, files in result.items():
print(f"MIME Type: {mime_type}, Files: {[str(file) for file in files]}")
```

View File

@ -1,4 +1,6 @@
import io
import sys
from unittest.mock import mock_open, patch
import pytest
@ -22,13 +24,13 @@ class TestFileTypeRouter:
test_files_path / "images" / "apple.jpg",
]
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
output = router.run(sources=file_paths)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert not output["unclassified"]
assert len(output[r"text/plain"]) == 2
assert len(output[r"audio/x-wav"]) == 1
assert len(output[r"image/jpeg"]) == 1
assert not output.get("unclassified")
def test_run_with_bytestreams(self, test_files_path):
"""
@ -40,14 +42,12 @@ class TestFileTypeRouter:
test_files_path / "audio" / "the context for this answer is here.wav",
test_files_path / "images" / "apple.jpg",
]
mime_types = ["text/plain", "text/plain", "audio/x-wav", "image/jpeg"]
mime_types = [r"text/plain", r"text/plain", r"audio/x-wav", r"image/jpeg"]
# Convert file paths to ByteStream objects and set metadata
byte_streams = []
for path, mime_type in zip(file_paths, mime_types):
stream = ByteStream(path.read_bytes())
stream.meta["content_type"] = mime_type
byte_streams.append(stream)
# add unclassified ByteStream
@ -55,15 +55,18 @@ class TestFileTypeRouter:
bs.meta["content_type"] = "unknown_type"
byte_streams.append(bs)
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
output = router.run(sources=byte_streams)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert len(output[r"text/plain"]) == 2
assert len(output[r"audio/x-wav"]) == 1
assert len(output[r"image/jpeg"]) == 1
assert len(output.get("unclassified")) == 1
def test_run_with_bytestreams_and_file_paths(self, test_files_path):
"""
Test if the component raises an error for unsupported data source types.
"""
file_paths = [
test_files_path / "txt" / "doc_1.txt",
test_files_path / "audio" / "the context for this answer is here.wav",
@ -71,7 +74,7 @@ class TestFileTypeRouter:
test_files_path / "images" / "apple.jpg",
test_files_path / "markdown" / "sample.md",
]
mime_types = ["text/plain", "audio/x-wav", "text/plain", "image/jpeg", "text/markdown"]
mime_types = [r"text/plain", r"audio/x-wav", r"text/plain", r"image/jpeg", r"text/markdown"]
byte_stream_sources = []
for path, mime_type in zip(file_paths, mime_types):
stream = ByteStream(path.read_bytes())
@ -80,18 +83,18 @@ class TestFileTypeRouter:
mixed_sources = file_paths[:2] + byte_stream_sources[2:]
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg", "text/markdown"])
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg", r"text/markdown"])
output = router.run(sources=mixed_sources)
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert len(output["text/markdown"]) == 1
assert len(output[r"text/plain"]) == 2
assert len(output[r"audio/x-wav"]) == 1
assert len(output[r"image/jpeg"]) == 1
assert len(output[r"text/markdown"]) == 1
def test_no_files(self):
"""
Test that the component runs correctly when no files are provided.
"""
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
output = router.run(sources=[])
assert not output
@ -104,13 +107,11 @@ class TestFileTypeRouter:
test_files_path / "audio" / "ignored.mp3",
test_files_path / "audio" / "this is the content of the document.wav",
]
router = FileTypeRouter(mime_types=["text/plain"])
router = FileTypeRouter(mime_types=[r"text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 1
assert len(output[r"text/plain"]) == 1
assert "mp3" not in output
assert len(output["unclassified"]) == 2
assert str(output["unclassified"][0]).endswith("ignored.mp3")
assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")
assert len(output.get("unclassified")) == 2
def test_no_extension(self, test_files_path):
"""
@ -121,14 +122,61 @@ class TestFileTypeRouter:
test_files_path / "txt" / "doc_2",
test_files_path / "txt" / "doc_2.txt",
]
router = FileTypeRouter(mime_types=["text/plain"])
router = FileTypeRouter(mime_types=[r"text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1
assert len(output[r"text/plain"]) == 2
assert len(output.get("unclassified")) == 1
def test_unknown_mime_type(self):
def test_unsupported_source_type(self):
"""
Test that the component handles files with unknown mime types.
Test if the component raises an error for unsupported data source types.
"""
with pytest.raises(ValueError, match="Unknown mime type:"):
FileTypeRouter(mime_types=["type_invalid"])
router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"])
with pytest.raises(ValueError, match="Unsupported data source type:"):
router.run(sources=[{"unsupported": "type"}])
def test_invalid_regex_pattern(self):
"""
Test that the component raises a ValueError for invalid regex patterns.
"""
with pytest.raises(ValueError, match="Invalid regex pattern"):
FileTypeRouter(mime_types=["[Invalid-Regex"])
def test_regex_mime_type_matching(self, test_files_path):
"""
Test if the component correctly matches mime types using regex.
"""
router = FileTypeRouter(mime_types=[r"text\/.*", r"audio\/.*", r"image\/.*"])
file_paths = [
test_files_path / "txt" / "doc_1.txt",
test_files_path / "audio" / "the context for this answer is here.wav",
test_files_path / "images" / "apple.jpg",
]
output = router.run(sources=file_paths)
assert len(output[r"text\/.*"]) == 1, "Failed to match text file with regex"
assert len(output[r"audio\/.*"]) == 1, "Failed to match audio file with regex"
assert len(output[r"image\/.*"]) == 1, "Failed to match image file with regex"
@patch("pathlib.Path.open", new_callable=mock_open, read_data=b"Mock file content.")
def test_exact_mime_type_matching(self, mock_file):
"""
Test if the component correctly matches mime types exactly, without regex patterns.
"""
txt_stream = ByteStream(io.BytesIO(b"Text file content"), meta={"content_type": "text/plain"})
jpg_stream = ByteStream(io.BytesIO(b"JPEG file content"), meta={"content_type": "image/jpeg"})
mp3_stream = ByteStream(io.BytesIO(b"MP3 file content"), meta={"content_type": "audio/mpeg"})
byte_streams = [txt_stream, jpg_stream, mp3_stream]
router = FileTypeRouter(mime_types=["text/plain", "image/jpeg"])
output = router.run(sources=byte_streams)
assert len(output["text/plain"]) == 1, "Failed to match 'text/plain' MIME type exactly"
assert txt_stream in output["text/plain"], "'doc_1.txt' ByteStream not correctly classified as 'text/plain'"
assert len(output["image/jpeg"]) == 1, "Failed to match 'image/jpeg' MIME type exactly"
assert jpg_stream in output["image/jpeg"], "'apple.jpg' ByteStream not correctly classified as 'image/jpeg'"
assert len(output.get("unclassified")) == 1, "Failed to handle unclassified file types"
assert mp3_stream in output["unclassified"], "'sound.mp3' ByteStream should be unclassified but is not"