From 8652d00b54922871d190003f5ae59e34e83c7f78 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 15 Aug 2023 15:58:55 +0200 Subject: [PATCH] feat: Add FileExtensionClassifier to previews (#5514) * Add FileExtensionClassifier preview component * Add release note * PR feedback --- haystack/preview/components/__init__.py | 1 + .../components/classifiers/__init__.py | 1 + .../components/classifiers/file_classifier.py | 82 +++++++++++++++++ ...n-classifier-preview-40f31c27bbd7cff9.yaml | 4 + .../components/classifiers/__init__.py | 0 .../classifiers/test_file_classifier.py | 89 +++++++++++++++++++ 6 files changed, 177 insertions(+) create mode 100644 haystack/preview/components/classifiers/__init__.py create mode 100644 haystack/preview/components/classifiers/file_classifier.py create mode 100644 releasenotes/notes/add-file-extension-classifier-preview-40f31c27bbd7cff9.yaml create mode 100644 test/preview/components/classifiers/__init__.py create mode 100644 test/preview/components/classifiers/test_file_classifier.py diff --git a/haystack/preview/components/__init__.py b/haystack/preview/components/__init__.py index aef8809c9..5831a72d6 100644 --- a/haystack/preview/components/__init__.py +++ b/haystack/preview/components/__init__.py @@ -1,3 +1,4 @@ from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber from haystack.preview.components.file_converters import TextFileToDocument +from haystack.preview.components.classifiers import FileExtensionClassifier diff --git a/haystack/preview/components/classifiers/__init__.py b/haystack/preview/components/classifiers/__init__.py new file mode 100644 index 000000000..c0f52300b --- /dev/null +++ b/haystack/preview/components/classifiers/__init__.py @@ -0,0 +1 @@ +from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier diff --git a/haystack/preview/components/classifiers/file_classifier.py b/haystack/preview/components/classifiers/file_classifier.py new file mode 100644 index 000000000..5ec06c79b --- /dev/null +++ b/haystack/preview/components/classifiers/file_classifier.py @@ -0,0 +1,82 @@ +import logging +import mimetypes +from collections import defaultdict +from pathlib import Path +from typing import List, Union, Optional + +from haystack.preview import component + +logger = logging.getLogger(__name__) + + +@component +class FileExtensionClassifier: + """ + A component that classifies files based on their MIME types read from their file extensions. This component + does not read the file contents, but rather uses the file extension to determine the MIME type of the file. + + The FileExtensionClassifier takes a list of file paths and groups them by their MIME types. + The list of MIME types to consider is provided during the initialization of the component. + + This component is particularly useful when working with a large number of files, and you + want to categorize them based on their MIME types. + """ + + def __init__(self, mime_types: List[str]): + """ + Initialize the FileExtensionClassifier. + + :param mime_types: A list of file mime types to consider when classifying + files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]). + """ + if not mime_types: + raise ValueError("The list of mime types cannot be empty.") + + all_known_mime_types = all(self.is_valid_mime_type_format(mime_type) for mime_type in mime_types) + if not all_known_mime_types: + raise ValueError(f"The list of mime types contains unknown mime types: {mime_types}") + + # save the init parameters for serialization + self.init_parameters = {"mime_types": mime_types} + + component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types}) + self.mime_types = mime_types + + def run(self, paths: List[Union[str, Path]]): + """ + Run the FileExtensionClassifier. + + This method takes the input data, iterates through the provided file paths, checks the file + mime type of each file, and groups the file paths by their mime types. + + :param paths: The input data containing the file paths to classify. + :return: The output data containing the classified file paths. + """ + mime_types = defaultdict(list) + for path in paths: + if isinstance(path, str): + path = Path(path) + mime_type = self.get_mime_type(path) + if mime_type in self.mime_types: + mime_types[mime_type].append(path) + else: + mime_types["unclassified"].append(path) + + return mime_types + + def get_mime_type(self, path: Path) -> Optional[str]: + """ + Get the MIME type of the provided file path. + + :param path: The file path to get the MIME type for. + :return: The MIME type of the provided file path, or None if the MIME type cannot be determined. + """ + return mimetypes.guess_type(path.as_posix())[0] + + def is_valid_mime_type_format(self, mime_type: str) -> bool: + """ + Check if the provided MIME type is in valid format + :param mime_type: The MIME type to check. + :return: True if the provided MIME type is a valid MIME type format, False otherwise. + """ + return mime_type in mimetypes.types_map.values() diff --git a/releasenotes/notes/add-file-extension-classifier-preview-40f31c27bbd7cff9.yaml b/releasenotes/notes/add-file-extension-classifier-preview-40f31c27bbd7cff9.yaml new file mode 100644 index 000000000..a6eed88eb --- /dev/null +++ b/releasenotes/notes/add-file-extension-classifier-preview-40f31c27bbd7cff9.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adds FileExtensionClassifier to preview components. diff --git a/test/preview/components/classifiers/__init__.py b/test/preview/components/classifiers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/preview/components/classifiers/test_file_classifier.py b/test/preview/components/classifiers/test_file_classifier.py new file mode 100644 index 000000000..094d4edc3 --- /dev/null +++ b/test/preview/components/classifiers/test_file_classifier.py @@ -0,0 +1,89 @@ +import sys + +import pytest + +from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier +from test.preview.components.base import BaseTestComponent +from test.conftest import preview_samples_path + + +@pytest.mark.skipif( + sys.platform in ["win32", "cygwin"], + reason="Can't run on Windows Github CI, need access to registry to get mime types", +) +class TestFileExtensionClassifier(BaseTestComponent): + @pytest.mark.unit + def test_save_load(self, tmp_path): + self.assert_can_be_saved_and_loaded_in_pipeline( + FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path + ) + + @pytest.mark.unit + def test_run(self, preview_samples_path): + """ + Test if the component runs correctly in the simplest happy path. + """ + file_paths = [ + preview_samples_path / "txt" / "doc_1.txt", + preview_samples_path / "txt" / "doc_2.txt", + preview_samples_path / "audio" / "the context for this answer is here.wav", + preview_samples_path / "images" / "apple.jpg", + ] + + classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) + output = classifier.run(paths=file_paths) + assert output + assert len(output["text/plain"]) == 2 + assert len(output["audio/x-wav"]) == 1 + assert len(output["image/jpeg"]) == 1 + assert not output["unclassified"] + + @pytest.mark.unit + def test_no_files(self): + """ + Test that the component runs correctly when no files are provided. + """ + classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) + output = classifier.run(paths=[]) + assert not output + + @pytest.mark.unit + def test_unlisted_extensions(self, preview_samples_path): + """ + Test that the component correctly handles files with non specified mime types. + """ + file_paths = [ + preview_samples_path / "txt" / "doc_1.txt", + preview_samples_path / "audio" / "ignored.mp3", + preview_samples_path / "audio" / "this is the content of the document.wav", + ] + classifier = FileExtensionClassifier(mime_types=["text/plain"]) + output = classifier.run(paths=file_paths) + assert len(output["text/plain"]) == 1 + assert "mp3" not in output + assert len(output["unclassified"]) == 2 + assert str(output["unclassified"][0]).endswith("ignored.mp3") + assert str(output["unclassified"][1]).endswith("this is the content of the document.wav") + + @pytest.mark.unit + def test_no_extension(self, preview_samples_path): + """ + Test that the component ignores files with no extension. + """ + file_paths = [ + preview_samples_path / "txt" / "doc_1.txt", + preview_samples_path / "txt" / "doc_2", + preview_samples_path / "txt" / "doc_2.txt", + ] + classifier = FileExtensionClassifier(mime_types=["text/plain"]) + output = classifier.run(paths=file_paths) + assert len(output["text/plain"]) == 2 + assert len(output["unclassified"]) == 1 + + @pytest.mark.unit + def test_unknown_mime_type(self): + """ + Test that the component handles files with unknown mime types. + """ + with pytest.raises(ValueError, match="The list of mime types"): + FileExtensionClassifier(mime_types=["type_invalid"])