From 84ed954c8c7d45ba8683a8aeede7e95bdf8ee57e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 8 Aug 2023 15:51:07 +0200 Subject: [PATCH] feat: Improve performance and add default media support in FileTypeClassifier (#5083) * feat: add media outgoing edge to FileTypeClassifier * Add release note * Update language --------- Co-authored-by: Daniel Bichuetti Co-authored-by: Massimiliano Pippi Co-authored-by: agnieszka-m --- haystack/nodes/file_classifier/file_type.py | 44 ++++++++---- ...-media-files-support-e970524f726dd844.yaml | 4 ++ test/nodes/test_filetype_classifier.py | 69 ++++++++++++++++++- 3 files changed, 102 insertions(+), 15 deletions(-) create mode 100644 releasenotes/notes/file-classifier-add-media-files-support-e970524f726dd844.yaml diff --git a/haystack/nodes/file_classifier/file_type.py b/haystack/nodes/file_classifier/file_type.py index ed05f2fd1..3a91a89de 100644 --- a/haystack/nodes/file_classifier/file_type.py +++ b/haystack/nodes/file_classifier/file_type.py @@ -14,7 +14,9 @@ with LazyImport() as magic_import: import magic -DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"] +DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html", "media"] + +DEFAULT_MEDIA_TYPES = ["mp3", "mp4", "mpeg", "m4a", "wav", "webm"] class FileTypeClassifier(BaseComponent): @@ -24,15 +26,20 @@ class FileTypeClassifier(BaseComponent): outgoing_edges = len(DEFAULT_TYPES) - def __init__(self, supported_types: Optional[List[str]] = None): + def __init__(self, supported_types: Optional[List[str]] = None, full_analysis: bool = False): """ Node that sends out files on a different output edge depending on their extension. - :param supported_types: The file types that this node can distinguish between. - If no value is provided, the value created by default comprises: `txt`, `pdf`, `md`, `docx`, and `html`. - Lists with duplicate elements are not allowed. + :param supported_types: The file types this node distinguishes. Optional. + If you don't provide any value, the default is: `txt`, `pdf`, `md`, `docx`, and `html`. + You can't use lists with duplicate elements. + :param full_analysis: If True, the whole file is analyzed to determine the file type. + If False, only the first 2049 bytes are analyzed. """ + self.full_analysis = full_analysis + self._default_types = False if supported_types is None: + self._default_types = True supported_types = DEFAULT_TYPES if len(set(supported_types)) != len(supported_types): duplicates = supported_types @@ -56,9 +63,17 @@ class FileTypeClassifier(BaseComponent): :param file_path: the path to extract the extension from """ try: - magic_import.check() - extension = magic.from_file(str(file_path), mime=True) - return mimetypes.guess_extension(extension) or "" + with open(file_path, "rb") as f: + if self.full_analysis: + buffer = f.read() + else: + buffer = f.read(2049) + extension = magic.from_buffer(buffer, mime=True) + real_extension = mimetypes.guess_extension(extension) or "" + real_extension = real_extension.lstrip(".") + if self._default_types and real_extension in DEFAULT_MEDIA_TYPES: + return "media" + return real_extension or "" except (NameError, ImportError): logger.error( "The type of '%s' could not be guessed, probably because 'python-magic' is not installed. Ignoring this error." @@ -76,18 +91,19 @@ class FileTypeClassifier(BaseComponent): :param file_paths: the paths to extract the extension from :return: a set of strings with all the extensions (without duplicates), the extension will be guessed if the file has none """ - extension = file_paths[0].suffix.lower() - if extension == "": + extension = file_paths[0].suffix.lower().lstrip(".") + + if extension == "" or (self._default_types and extension in DEFAULT_MEDIA_TYPES): extension = self._estimate_extension(file_paths[0]) for path in file_paths: - path_suffix = path.suffix.lower() - if path_suffix == "": + path_suffix = path.suffix.lower().lstrip(".") + if path_suffix == "" or (self._default_types and path_suffix in DEFAULT_MEDIA_TYPES): path_suffix = self._estimate_extension(path) if path_suffix != extension: - raise ValueError("Multiple file types are not allowed at once.") + raise ValueError("Multiple non-default file types are not allowed at once.") - return extension.lstrip(".") + return extension def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore """ diff --git a/releasenotes/notes/file-classifier-add-media-files-support-e970524f726dd844.yaml b/releasenotes/notes/file-classifier-add-media-files-support-e970524f726dd844.yaml new file mode 100644 index 000000000..97a42ee98 --- /dev/null +++ b/releasenotes/notes/file-classifier-add-media-files-support-e970524f726dd844.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Enhance FileTypeClassifier to detect media file types like mp3, mp4, mpeg, m4a, and similar. diff --git a/test/nodes/test_filetype_classifier.py b/test/nodes/test_filetype_classifier.py index b5e9bffe6..15c9a70d3 100644 --- a/test/nodes/test_filetype_classifier.py +++ b/test/nodes/test_filetype_classifier.py @@ -1,10 +1,13 @@ import logging +import os import platform +import shutil +from pathlib import Path import pytest import haystack -from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES +from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES @pytest.mark.unit @@ -101,3 +104,67 @@ def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch, with caplog.at_level(logging.ERROR): node.run(samples_path / "extensionless_files" / f"pdf_file") assert "'python-magic' is not installed" in caplog.text + + +@pytest.mark.unit +def test_filetype_classifier_media_extensions_positive(tmp_path): + node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES) + for idx in range(len(DEFAULT_MEDIA_TYPES)): + test_file = tmp_path / f"test.{DEFAULT_MEDIA_TYPES[idx]}" + output, edge = node.run(test_file) + assert edge == f"output_{idx+1}" + assert output == {"file_paths": [test_file]} + + +@pytest.mark.unit +def test_filetype_classifier_media_extensions_negative(tmp_path): + node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES) + + test_file = tmp_path / f"test.txt" + with pytest.raises(ValueError, match="Files of type 'txt'"): + node.run(test_file) + + +@pytest.mark.unit +@pytest.mark.skipif(platform.system() in ["Windows", "Darwin"], reason="python-magic not available") +def test_filetype_classifier_estimate_media_extensions(tmp_path): + node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES) + + test_file = "test/samples/audio/answer.wav" + new_file_name = "test_wav_no_extension" + new_file_path = os.path.join(tmp_path, new_file_name) + + shutil.copy(test_file, new_file_path) + + output, edge = node.run(new_file_path) + assert edge == f"output_5" + assert output == {"file_paths": [Path(new_file_path)]} + + +@pytest.mark.unit +def test_filetype_classifier_batched_various_media_extensions(tmp_path): + test_files = [] + node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES) + for idx in range(len(DEFAULT_MEDIA_TYPES)): + test_file = tmp_path / f"test.{DEFAULT_MEDIA_TYPES[idx]}" + test_files.append(test_file) + + # we can't classify a list of files with different media extensions + with pytest.raises(ValueError, match="Multiple non-default file types are not allowed at once."): + node.run_batch(test_files) + + +@pytest.mark.unit +def test_filetype_classifier_batched_same_media_extensions(tmp_path): + test_files = [] + batch_size = 5 + file_index = 0 + node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES) + for idx in range(batch_size): + test_file = tmp_path / f"test-{idx}.{DEFAULT_MEDIA_TYPES[file_index]}" + test_files.append(test_file) + + # we should be able to pass a list of files with the same extension + output, edge = node.run_batch(test_files) + assert edge == f"output_1" + assert output == {"file_paths": test_files}