mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat: Add FileExtensionClassifier to previews (#5514)
* Add FileExtensionClassifier preview component * Add release note * PR feedback
This commit is contained in:
parent
bb7af3827d
commit
8652d00b54
@ -1,3 +1,4 @@
|
|||||||
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
|
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
|
||||||
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
|
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
|
||||||
from haystack.preview.components.file_converters import TextFileToDocument
|
from haystack.preview.components.file_converters import TextFileToDocument
|
||||||
|
from haystack.preview.components.classifiers import FileExtensionClassifier
|
||||||
|
1
haystack/preview/components/classifiers/__init__.py
Normal file
1
haystack/preview/components/classifiers/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
|
82
haystack/preview/components/classifiers/file_classifier.py
Normal file
82
haystack/preview/components/classifiers/file_classifier.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Union, Optional
|
||||||
|
|
||||||
|
from haystack.preview import component
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@component
class FileExtensionClassifier:
    """
    A component that classifies files based on their MIME types read from their file extensions. This component
    does not read the file contents, but rather uses the file extension to determine the MIME type of the file.

    The FileExtensionClassifier takes a list of file paths and groups them by their MIME types.
    The list of MIME types to consider is provided during the initialization of the component.

    This component is particularly useful when working with a large number of files, and you
    want to categorize them based on their MIME types.
    """

    def __init__(self, mime_types: List[str]):
        """
        Initialize the FileExtensionClassifier.

        :param mime_types: A list of file mime types to consider when classifying
            files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
        :raises ValueError: If `mime_types` is empty, or if any entry is not a MIME type known to
            the `mimetypes` module.
        """
        if not mime_types:
            raise ValueError("The list of mime types cannot be empty.")

        # Collect only the offending entries so the error message points at the actual problem
        # instead of echoing the whole (mostly valid) list back to the caller.
        unknown_mime_types = [mime_type for mime_type in mime_types if not self.is_valid_mime_type_format(mime_type)]
        if unknown_mime_types:
            raise ValueError(f"The list of mime types contains unknown mime types: {unknown_mime_types}")

        # save the init parameters for serialization
        self.init_parameters = {"mime_types": mime_types}

        # One output per requested MIME type, plus a catch-all "unclassified" output.
        component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
        self.mime_types = mime_types

    def run(self, paths: List[Union[str, Path]]):
        """
        Run the FileExtensionClassifier.

        This method takes the input data, iterates through the provided file paths, checks the file
        mime type of each file, and groups the file paths by their mime types.

        :param paths: The input data containing the file paths to classify. String entries are
            converted to `Path` objects.
        :return: A `defaultdict` mapping each matched MIME type to the list of paths with that type.
            Paths whose MIME type is not in the configured list (or cannot be determined) are
            collected under the "unclassified" key. Only keys that received at least one path are
            present.
        """
        mime_types = defaultdict(list)
        for path in paths:
            if isinstance(path, str):
                path = Path(path)
            mime_type = self.get_mime_type(path)
            # A `None` MIME type (undeterminable extension) is never in self.mime_types,
            # so such paths fall through to "unclassified".
            if mime_type in self.mime_types:
                mime_types[mime_type].append(path)
            else:
                mime_types["unclassified"].append(path)

        return mime_types

    def get_mime_type(self, path: Path) -> Optional[str]:
        """
        Get the MIME type of the provided file path.

        :param path: The file path to get the MIME type for.
        :return: The MIME type of the provided file path, or None if the MIME type cannot be determined.
        """
        return mimetypes.guess_type(path.as_posix())[0]

    def is_valid_mime_type_format(self, mime_type: str) -> bool:
        """
        Check if the provided MIME type is known to the `mimetypes` module.

        Despite its name, this method does not validate the syntax of the MIME type; it checks
        whether the value appears in `mimetypes.types_map`.

        :param mime_type: The MIME type to check.
        :return: True if the provided MIME type is one of the known MIME types, False otherwise.
        """
        # `mimetypes.guess_type` (used by `get_mime_type`) lazily initializes the module's database,
        # which rebinds `mimetypes.types_map`. Initialize it here too, so validation is checked
        # against the same table regardless of whether `guess_type` has already been called.
        if not mimetypes.inited:
            mimetypes.init()
        return mime_type in mimetypes.types_map.values()
|
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Adds FileExtensionClassifier to preview components.
|
0
test/preview/components/classifiers/__init__.py
Normal file
0
test/preview/components/classifiers/__init__.py
Normal file
89
test/preview/components/classifiers/test_file_classifier.py
Normal file
89
test/preview/components/classifiers/test_file_classifier.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
|
||||||
|
from test.preview.components.base import BaseTestComponent
|
||||||
|
from test.conftest import preview_samples_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.platform in ["win32", "cygwin"],
    reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier(BaseTestComponent):
    """
    Unit tests for FileExtensionClassifier: serialization, grouping by MIME type,
    handling of unlisted or missing extensions, and constructor validation.
    """

    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        """
        Test that the component can be saved and loaded as part of a pipeline.
        """
        self.assert_can_be_saved_and_loaded_in_pipeline(
            FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path
        )

    @pytest.mark.unit
    def test_run(self, preview_samples_path):
        """
        Test if the component runs correctly in the simplest happy path.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2.txt",
            preview_samples_path / "audio" / "the context for this answer is here.wav",
            preview_samples_path / "images" / "apple.jpg",
        ]

        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=file_paths)
        assert output
        assert len(output["text/plain"]) == 2
        assert len(output["audio/x-wav"]) == 1
        assert len(output["image/jpeg"]) == 1
        assert not output["unclassified"]

    @pytest.mark.unit
    def test_no_files(self):
        """
        Test that the component runs correctly when no files are provided.
        """
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=[])
        assert not output

    @pytest.mark.unit
    def test_unlisted_extensions(self, preview_samples_path):
        """
        Test that the component correctly handles files with non specified mime types.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "audio" / "ignored.mp3",
            preview_samples_path / "audio" / "this is the content of the document.wav",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 1
        # Output keys are MIME types, so check that the unlisted types did not get their own group.
        assert "audio/mpeg" not in output
        assert "audio/x-wav" not in output
        assert len(output["unclassified"]) == 2
        assert str(output["unclassified"][0]).endswith("ignored.mp3")
        assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")

    @pytest.mark.unit
    def test_no_extension(self, preview_samples_path):
        """
        Test that files with no extension are reported as unclassified.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2",
            preview_samples_path / "txt" / "doc_2.txt",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 2
        assert len(output["unclassified"]) == 1

    @pytest.mark.unit
    def test_unknown_mime_type(self):
        """
        Test that the component refuses to be initialized with an unknown mime type.
        """
        with pytest.raises(ValueError, match="The list of mime types"):
            FileExtensionClassifier(mime_types=["type_invalid"])

    @pytest.mark.unit
    def test_empty_mime_types(self):
        """
        Test that the component refuses to be initialized with an empty list of mime types.
        """
        with pytest.raises(ValueError, match="cannot be empty"):
            FileExtensionClassifier(mime_types=[])
|
Loading…
x
Reference in New Issue
Block a user