feat: Add FileExtensionClassifier to previews (#5514)

* Add FileExtensionClassifier preview component

* Add release note

* PR feedback
This commit is contained in:
Vladimir Blagojevic 2023-08-15 15:58:55 +02:00 committed by GitHub
parent bb7af3827d
commit 8652d00b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 0 deletions

View File

@ -1,3 +1,4 @@
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.components.file_converters import TextFileToDocument from haystack.preview.components.file_converters import TextFileToDocument
from haystack.preview.components.classifiers import FileExtensionClassifier

View File

@ -0,0 +1 @@
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier

View File

@ -0,0 +1,82 @@
import logging
import mimetypes
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Optional
from haystack.preview import component
logger = logging.getLogger(__name__)
@component
class FileExtensionClassifier:
"""
A component that classifies files based on their MIME types read from their file extensions. This component
does not read the file contents, but rather uses the file extension to determine the MIME type of the file.
The FileExtensionClassifier takes a list of file paths and groups them by their MIME types.
The list of MIME types to consider is provided during the initialization of the component.
This component is particularly useful when working with a large number of files, and you
want to categorize them based on their MIME types.
"""
def __init__(self, mime_types: List[str]):
"""
Initialize the FileExtensionClassifier.
:param mime_types: A list of file mime types to consider when classifying
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
"""
if not mime_types:
raise ValueError("The list of mime types cannot be empty.")
all_known_mime_types = all(self.is_valid_mime_type_format(mime_type) for mime_type in mime_types)
if not all_known_mime_types:
raise ValueError(f"The list of mime types contains unknown mime types: {mime_types}")
# save the init parameters for serialization
self.init_parameters = {"mime_types": mime_types}
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types
def run(self, paths: List[Union[str, Path]]):
"""
Run the FileExtensionClassifier.
This method takes the input data, iterates through the provided file paths, checks the file
mime type of each file, and groups the file paths by their mime types.
:param paths: The input data containing the file paths to classify.
:return: The output data containing the classified file paths.
"""
mime_types = defaultdict(list)
for path in paths:
if isinstance(path, str):
path = Path(path)
mime_type = self.get_mime_type(path)
if mime_type in self.mime_types:
mime_types[mime_type].append(path)
else:
mime_types["unclassified"].append(path)
return mime_types
def get_mime_type(self, path: Path) -> Optional[str]:
"""
Get the MIME type of the provided file path.
:param path: The file path to get the MIME type for.
:return: The MIME type of the provided file path, or None if the MIME type cannot be determined.
"""
return mimetypes.guess_type(path.as_posix())[0]
def is_valid_mime_type_format(self, mime_type: str) -> bool:
"""
Check if the provided MIME type is in valid format
:param mime_type: The MIME type to check.
:return: True if the provided MIME type is a valid MIME type format, False otherwise.
"""
return mime_type in mimetypes.types_map.values()

View File

@ -0,0 +1,4 @@
---
features:
- |
Adds FileExtensionClassifier to preview components.

View File

@ -0,0 +1,89 @@
import sys
import pytest
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
from test.preview.components.base import BaseTestComponent
from test.conftest import preview_samples_path
@pytest.mark.skipif(
sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier(BaseTestComponent):
@pytest.mark.unit
def test_save_load(self, tmp_path):
self.assert_can_be_saved_and_loaded_in_pipeline(
FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path
)
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""
Test if the component runs correctly in the simplest happy path.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "txt" / "doc_2.txt",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "images" / "apple.jpg",
]
classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = classifier.run(paths=file_paths)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert not output["unclassified"]
@pytest.mark.unit
def test_no_files(self):
"""
Test that the component runs correctly when no files are provided.
"""
classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = classifier.run(paths=[])
assert not output
@pytest.mark.unit
def test_unlisted_extensions(self, preview_samples_path):
"""
Test that the component correctly handles files with non specified mime types.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "audio" / "ignored.mp3",
preview_samples_path / "audio" / "this is the content of the document.wav",
]
classifier = FileExtensionClassifier(mime_types=["text/plain"])
output = classifier.run(paths=file_paths)
assert len(output["text/plain"]) == 1
assert "mp3" not in output
assert len(output["unclassified"]) == 2
assert str(output["unclassified"][0]).endswith("ignored.mp3")
assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")
@pytest.mark.unit
def test_no_extension(self, preview_samples_path):
"""
Test that the component ignores files with no extension.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "txt" / "doc_2",
preview_samples_path / "txt" / "doc_2.txt",
]
classifier = FileExtensionClassifier(mime_types=["text/plain"])
output = classifier.run(paths=file_paths)
assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1
@pytest.mark.unit
def test_unknown_mime_type(self):
"""
Test that the component handles files with unknown mime types.
"""
with pytest.raises(ValueError, match="The list of mime types"):
FileExtensionClassifier(mime_types=["type_invalid"])