mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat: Add FileExtensionClassifier to previews (#5514)
* Add FileExtensionClassifier preview component * Add release note * PR feedback
This commit is contained in:
parent
bb7af3827d
commit
8652d00b54
@ -1,3 +1,4 @@
|
||||
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
|
||||
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
|
||||
from haystack.preview.components.file_converters import TextFileToDocument
|
||||
from haystack.preview.components.classifiers import FileExtensionClassifier
|
||||
|
1
haystack/preview/components/classifiers/__init__.py
Normal file
1
haystack/preview/components/classifiers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
|
82
haystack/preview/components/classifiers/file_classifier.py
Normal file
82
haystack/preview/components/classifiers/file_classifier.py
Normal file
@ -0,0 +1,82 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from haystack.preview import component
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
class FileExtensionClassifier:
    """
    A component that classifies files based on their MIME types read from their file extensions. This component
    does not read the file contents, but rather uses the file extension to determine the MIME type of the file.

    The FileExtensionClassifier takes a list of file paths and groups them by their MIME types.
    The list of MIME types to consider is provided during the initialization of the component.
    Paths whose MIME type is not in that list, or cannot be guessed at all, are grouped under "unclassified".

    This component is particularly useful when working with a large number of files, and you
    want to categorize them based on their MIME types.
    """

    def __init__(self, mime_types: List[str]):
        """
        Initialize the FileExtensionClassifier.

        :param mime_types: A list of file mime types to consider when classifying
            files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
        :raises ValueError: If the list is empty, or if it contains a mime type that is
            not known to the `mimetypes` module.
        """
        if not mime_types:
            raise ValueError("The list of mime types cannot be empty.")

        # Work on a private copy so that later mutation of the caller's list cannot
        # desynchronise the declared outputs, the init parameters and `self.mime_types`.
        mime_types = list(mime_types)

        # Collect only the offending entries so the error message points at the actual problem.
        unknown_mime_types = [
            mime_type for mime_type in mime_types if not self.is_valid_mime_type_format(mime_type)
        ]
        if unknown_mime_types:
            raise ValueError(f"The list of mime types contains unknown mime types: {unknown_mime_types}")

        # save the init parameters for serialization
        self.init_parameters = {"mime_types": mime_types}

        # One output per requested mime type, plus a catch-all for everything else.
        component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
        self.mime_types = mime_types

    def run(self, paths: List[Union[str, Path]]):
        """
        Run the FileExtensionClassifier.

        This method takes the input data, iterates through the provided file paths, checks the file
        mime type of each file, and groups the file paths by their mime types.

        :param paths: The input data containing the file paths to classify.
        :return: A `defaultdict(list)` mapping each matched mime type (or "unclassified") to the
            list of `Path` objects in that group. Only groups that received at least one path are
            present, so the mapping is empty when `paths` is empty.
        """
        # Build the lookup once; membership is then O(1) per path instead of scanning the list.
        accepted_mime_types = set(self.mime_types)

        mime_types = defaultdict(list)
        for path in paths:
            if isinstance(path, str):
                path = Path(path)
            mime_type = self.get_mime_type(path)
            # `mime_type` is None when the extension is missing or unknown; None is never accepted.
            group = mime_type if mime_type in accepted_mime_types else "unclassified"
            mime_types[group].append(path)

        return mime_types

    def get_mime_type(self, path: Path) -> Optional[str]:
        """
        Get the MIME type of the provided file path.

        The type is guessed from the file name only; the file is not opened.

        :param path: The file path to get the MIME type for.
        :return: The MIME type of the provided file path, or None if the MIME type cannot be determined.
        """
        return mimetypes.guess_type(path.as_posix())[0]

    def is_valid_mime_type_format(self, mime_type: str) -> bool:
        """
        Check if the provided MIME type is known to the `mimetypes` module.

        :param mime_type: The MIME type to check.
        :return: True if the provided MIME type is one of the types `mimetypes` can map an
            extension to, False otherwise.
        """
        # `guess_type()` initialises the MIME database lazily. Initialise it here as well so that
        # validation sees the same table that `get_mime_type()` will use later, regardless of
        # whether anything has called into `mimetypes` before.
        if not mimetypes.inited:
            mimetypes.init()
        return mime_type in mimetypes.types_map.values()
|
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Adds FileExtensionClassifier to preview components. The component groups file paths by the MIME type guessed from their file extension.
|
0
test/preview/components/classifiers/__init__.py
Normal file
0
test/preview/components/classifiers/__init__.py
Normal file
89
test/preview/components/classifiers/test_file_classifier.py
Normal file
89
test/preview/components/classifiers/test_file_classifier.py
Normal file
@ -0,0 +1,89 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
|
||||
from test.preview.components.base import BaseTestComponent
|
||||
from test.conftest import preview_samples_path
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform in ["win32", "cygwin"],
    reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier(BaseTestComponent):
    """Unit tests for FileExtensionClassifier: serialization, grouping and constructor validation."""

    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        """
        Test that the component can be saved and loaded as part of a pipeline.
        """
        self.assert_can_be_saved_and_loaded_in_pipeline(
            FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path
        )

    @pytest.mark.unit
    def test_run(self, preview_samples_path):
        """
        Test if the component runs correctly in the simplest happy path.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2.txt",
            preview_samples_path / "audio" / "the context for this answer is here.wav",
            preview_samples_path / "images" / "apple.jpg",
        ]

        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=file_paths)
        assert output
        assert len(output["text/plain"]) == 2
        assert len(output["audio/x-wav"]) == 1
        assert len(output["image/jpeg"]) == 1
        assert not output["unclassified"]

    @pytest.mark.unit
    def test_no_files(self):
        """
        Test that the component runs correctly when no files are provided.
        """
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=[])
        assert not output

    @pytest.mark.unit
    def test_unlisted_extensions(self, preview_samples_path):
        """
        Test that files whose mime type was not requested end up in "unclassified".
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "audio" / "ignored.mp3",
            preview_samples_path / "audio" / "this is the content of the document.wav",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 1
        # Output keys are mime types, so check that no group was created for the unrequested ones.
        assert "audio/mpeg" not in output
        assert "audio/x-wav" not in output
        assert len(output["unclassified"]) == 2
        assert str(output["unclassified"][0]).endswith("ignored.mp3")
        assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")

    @pytest.mark.unit
    def test_no_extension(self, preview_samples_path):
        """
        Test that files with no extension are reported as "unclassified".
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2",
            preview_samples_path / "txt" / "doc_2.txt",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 2
        assert len(output["unclassified"]) == 1
        assert str(output["unclassified"][0]).endswith("doc_2")

    @pytest.mark.unit
    def test_unknown_mime_type(self):
        """
        Test that the component cannot be created with a mime type it does not know.
        """
        with pytest.raises(ValueError, match="The list of mime types"):
            FileExtensionClassifier(mime_types=["type_invalid"])

    @pytest.mark.unit
    def test_empty_mime_types(self):
        """
        Test that the component cannot be created with an empty list of mime types.
        """
        with pytest.raises(ValueError, match="The list of mime types cannot be empty"):
            FileExtensionClassifier(mime_types=[])
|
Loading…
x
Reference in New Issue
Block a user