feat: Improve performance and add default media support in FileTypeClassifier (#5083)

* feat: add media outgoing edge to FileTypeClassifier

* Add release note

* Update language

---------

Co-authored-by: Daniel Bichuetti <daniel.bichuetti@gmail.com>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2023-08-08 15:51:07 +02:00 committed by GitHub
parent d46c84bb61
commit 84ed954c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 15 deletions

View File

@ -14,7 +14,9 @@ with LazyImport() as magic_import:
import magic
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"]
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html", "media"]
DEFAULT_MEDIA_TYPES = ["mp3", "mp4", "mpeg", "m4a", "wav", "webm"]
class FileTypeClassifier(BaseComponent):
@ -24,15 +26,20 @@ class FileTypeClassifier(BaseComponent):
outgoing_edges = len(DEFAULT_TYPES)
def __init__(self, supported_types: Optional[List[str]] = None):
def __init__(self, supported_types: Optional[List[str]] = None, full_analysis: bool = False):
"""
Node that sends out files on a different output edge depending on their extension.
:param supported_types: The file types that this node can distinguish between.
If no value is provided, the value created by default comprises: `txt`, `pdf`, `md`, `docx`, and `html`.
Lists with duplicate elements are not allowed.
:param supported_types: The file types this node distinguishes. Optional.
If you don't provide any value, the default is: `txt`, `pdf`, `md`, `docx`, `html`, and `media` (where `media` covers `mp3`, `mp4`, `mpeg`, `m4a`, `wav`, and `webm`).
You can't use lists with duplicate elements.
:param full_analysis: If True, the whole file is analyzed to determine the file type.
If False, only the first 2049 bytes are analyzed.
"""
self.full_analysis = full_analysis
self._default_types = False
if supported_types is None:
self._default_types = True
supported_types = DEFAULT_TYPES
if len(set(supported_types)) != len(supported_types):
duplicates = supported_types
@ -56,9 +63,17 @@ class FileTypeClassifier(BaseComponent):
:param file_path: the path to extract the extension from
"""
try:
magic_import.check()
extension = magic.from_file(str(file_path), mime=True)
return mimetypes.guess_extension(extension) or ""
with open(file_path, "rb") as f:
if self.full_analysis:
buffer = f.read()
else:
buffer = f.read(2049)
extension = magic.from_buffer(buffer, mime=True)
real_extension = mimetypes.guess_extension(extension) or ""
real_extension = real_extension.lstrip(".")
if self._default_types and real_extension in DEFAULT_MEDIA_TYPES:
return "media"
return real_extension or ""
except (NameError, ImportError):
logger.error(
"The type of '%s' could not be guessed, probably because 'python-magic' is not installed. Ignoring this error."
@ -76,18 +91,19 @@ class FileTypeClassifier(BaseComponent):
:param file_paths: the paths to extract the extension from
:return: a set of strings with all the extensions (without duplicates), the extension will be guessed if the file has none
"""
extension = file_paths[0].suffix.lower()
if extension == "":
extension = file_paths[0].suffix.lower().lstrip(".")
if extension == "" or (self._default_types and extension in DEFAULT_MEDIA_TYPES):
extension = self._estimate_extension(file_paths[0])
for path in file_paths:
path_suffix = path.suffix.lower()
if path_suffix == "":
path_suffix = path.suffix.lower().lstrip(".")
if path_suffix == "" or (self._default_types and path_suffix in DEFAULT_MEDIA_TYPES):
path_suffix = self._estimate_extension(path)
if path_suffix != extension:
raise ValueError("Multiple file types are not allowed at once.")
raise ValueError("Multiple non-default file types are not allowed at once.")
return extension.lstrip(".")
return extension
def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore
"""

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enhance FileTypeClassifier to detect media file types (mp3, mp4, mpeg, m4a, wav, and webm) and, when the default types are used, route them to a dedicated `media` output edge.

View File

@ -1,10 +1,13 @@
import logging
import os
import platform
import shutil
from pathlib import Path
import pytest
import haystack
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES
@pytest.mark.unit
@ -101,3 +104,67 @@ def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch,
with caplog.at_level(logging.ERROR):
node.run(samples_path / "extensionless_files" / f"pdf_file")
assert "'python-magic' is not installed" in caplog.text
@pytest.mark.unit
def test_filetype_classifier_media_extensions_positive(tmp_path):
    """
    Check that every media extension is routed to its own output edge.

    The node is built with `DEFAULT_MEDIA_TYPES` as its supported types, and the test expects the
    n-th type in that list to come out of `output_<n>` (counting from 1).
    """
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)

    for edge_number, media_type in enumerate(DEFAULT_MEDIA_TYPES, start=1):
        # The file is never created on disk; only its path is handed to the node.
        test_file = tmp_path / f"test.{media_type}"
        output, edge = node.run(test_file)
        assert edge == f"output_{edge_number}"
        assert output == {"file_paths": [test_file]}
@pytest.mark.unit
def test_filetype_classifier_media_extensions_negative(tmp_path):
    """
    Check that a file whose extension is not among the supported types is rejected.

    The node only supports `DEFAULT_MEDIA_TYPES`, so a `.txt` path is expected to raise a
    `ValueError` that mentions the offending type.
    """
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)
    test_file = tmp_path / "test.txt"
    with pytest.raises(ValueError, match="Files of type 'txt'"):
        node.run(test_file)
@pytest.mark.unit
@pytest.mark.skipif(platform.system() in ["Windows", "Darwin"], reason="python-magic not available")
def test_filetype_classifier_estimate_media_extensions(tmp_path):
    """
    Check that a media file without an extension is classified and routed to the `wav` edge.

    A WAV sample is copied to a file name that has no suffix, and the test expects the node to
    send it to the output edge that corresponds to `wav`.
    """
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)

    # NOTE(review): this sample path is relative to the current working directory, so the test
    # presumably has to be started from the repository root - confirm.
    test_file = "test/samples/audio/answer.wav"
    new_file_path = os.path.join(tmp_path, "test_wav_no_extension")
    shutil.copy(test_file, new_file_path)

    # Derive the expected edge from the position of `wav` instead of hard-coding `output_5`,
    # so the test keeps working if `DEFAULT_MEDIA_TYPES` is reordered or extended.
    expected_edge = f"output_{DEFAULT_MEDIA_TYPES.index('wav') + 1}"

    # The path goes in as a plain string; the assertion expects it back as a `Path`.
    output, edge = node.run(new_file_path)
    assert edge == expected_edge
    assert output == {"file_paths": [Path(new_file_path)]}
@pytest.mark.unit
def test_filetype_classifier_batched_various_media_extensions(tmp_path):
    """
    Check that a batch mixing several media extensions is rejected.

    One path per entry of `DEFAULT_MEDIA_TYPES` is passed to `run_batch` in a single call, which
    is expected to raise a `ValueError`.
    """
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)
    test_files = [tmp_path / f"test.{media_type}" for media_type in DEFAULT_MEDIA_TYPES]

    # we can't classify a list of files with different media extensions
    with pytest.raises(ValueError, match="Multiple non-default file types are not allowed at once."):
        node.run_batch(test_files)
@pytest.mark.unit
def test_filetype_classifier_batched_same_media_extensions(tmp_path):
    """
    Check that a batch of files sharing one media extension is accepted.

    All paths use the first entry of `DEFAULT_MEDIA_TYPES`, so the whole batch is expected to
    leave through `output_1` unchanged.
    """
    batch_size = 5
    media_type = DEFAULT_MEDIA_TYPES[0]
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)
    test_files = [tmp_path / f"test-{idx}.{media_type}" for idx in range(batch_size)]

    # we should be able to pass a list of files with the same extension
    output, edge = node.run_batch(test_files)
    assert edge == "output_1"
    assert output == {"file_paths": test_files}