mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
feat: Improve performance and add default media support in FileTypeClassifier (#5083)
* feat: add media outgoing edge to FileTypeClassifier * Add release note * Update language --------- Co-authored-by: Daniel Bichuetti <daniel.bichuetti@gmail.com> Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent
d46c84bb61
commit
84ed954c8c
@ -14,7 +14,9 @@ with LazyImport() as magic_import:
|
||||
import magic
|
||||
|
||||
|
||||
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"]
|
||||
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html", "media"]
|
||||
|
||||
DEFAULT_MEDIA_TYPES = ["mp3", "mp4", "mpeg", "m4a", "wav", "webm"]
|
||||
|
||||
|
||||
class FileTypeClassifier(BaseComponent):
|
||||
@ -24,15 +26,20 @@ class FileTypeClassifier(BaseComponent):
|
||||
|
||||
outgoing_edges = len(DEFAULT_TYPES)
|
||||
|
||||
def __init__(self, supported_types: Optional[List[str]] = None):
|
||||
def __init__(self, supported_types: Optional[List[str]] = None, full_analysis: bool = False):
|
||||
"""
|
||||
Node that sends out files on a different output edge depending on their extension.
|
||||
|
||||
:param supported_types: The file types that this node can distinguish between.
|
||||
If no value is provided, the value created by default comprises: `txt`, `pdf`, `md`, `docx`, and `html`.
|
||||
Lists with duplicate elements are not allowed.
|
||||
:param supported_types: The file types this node distinguishes. Optional.
|
||||
If you don't provide any value, the default is: `txt`, `pdf`, `md`, `docx`, `html`, and `media`.
|
||||
You can't use lists with duplicate elements.
|
||||
:param full_analysis: If True, the whole file is analyzed to determine the file type.
|
||||
If False, only the first 2049 bytes are analyzed.
|
||||
"""
|
||||
self.full_analysis = full_analysis
|
||||
self._default_types = False
|
||||
if supported_types is None:
|
||||
self._default_types = True
|
||||
supported_types = DEFAULT_TYPES
|
||||
if len(set(supported_types)) != len(supported_types):
|
||||
duplicates = supported_types
|
||||
@ -56,9 +63,17 @@ class FileTypeClassifier(BaseComponent):
|
||||
:param file_path: the path to extract the extension from
|
||||
"""
|
||||
try:
|
||||
magic_import.check()
|
||||
extension = magic.from_file(str(file_path), mime=True)
|
||||
return mimetypes.guess_extension(extension) or ""
|
||||
with open(file_path, "rb") as f:
|
||||
if self.full_analysis:
|
||||
buffer = f.read()
|
||||
else:
|
||||
buffer = f.read(2049)
|
||||
extension = magic.from_buffer(buffer, mime=True)
|
||||
real_extension = mimetypes.guess_extension(extension) or ""
|
||||
real_extension = real_extension.lstrip(".")
|
||||
if self._default_types and real_extension in DEFAULT_MEDIA_TYPES:
|
||||
return "media"
|
||||
return real_extension or ""
|
||||
except (NameError, ImportError):
|
||||
logger.error(
|
||||
"The type of '%s' could not be guessed, probably because 'python-magic' is not installed. Ignoring this error."
|
||||
@ -76,18 +91,19 @@ class FileTypeClassifier(BaseComponent):
|
||||
:param file_paths: the paths to extract the extension from
|
||||
:return: the single extension shared by all the files, without the leading dot; the extension is guessed from the file content if a file has none
|
||||
"""
|
||||
extension = file_paths[0].suffix.lower()
|
||||
if extension == "":
|
||||
extension = file_paths[0].suffix.lower().lstrip(".")
|
||||
|
||||
if extension == "" or (self._default_types and extension in DEFAULT_MEDIA_TYPES):
|
||||
extension = self._estimate_extension(file_paths[0])
|
||||
|
||||
for path in file_paths:
|
||||
path_suffix = path.suffix.lower()
|
||||
if path_suffix == "":
|
||||
path_suffix = path.suffix.lower().lstrip(".")
|
||||
if path_suffix == "" or (self._default_types and path_suffix in DEFAULT_MEDIA_TYPES):
|
||||
path_suffix = self._estimate_extension(path)
|
||||
if path_suffix != extension:
|
||||
raise ValueError("Multiple file types are not allowed at once.")
|
||||
raise ValueError("Multiple non-default file types are not allowed at once.")
|
||||
|
||||
return extension.lstrip(".")
|
||||
return extension
|
||||
|
||||
def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore
|
||||
"""
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Enhance FileTypeClassifier to detect media file types like mp3, mp4, mpeg, m4a, and similar.
|
||||
@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import haystack
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@ -101,3 +104,67 @@ def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch,
|
||||
with caplog.at_level(logging.ERROR):
|
||||
node.run(samples_path / "extensionless_files" / f"pdf_file")
|
||||
assert "'python-magic' is not installed" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_filetype_classifier_media_extensions_positive(tmp_path):
    """
    Each media extension is routed to the edge matching its position in `supported_types`.

    The files are never created on disk: every path already carries an extension and
    explicit `supported_types` are passed, so the suffix alone decides the edge.
    """
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)

    # The expected edge number is the 1-based position of the type in the list.
    for edge_number, media_type in enumerate(DEFAULT_MEDIA_TYPES, start=1):
        test_file = tmp_path / f"test.{media_type}"
        output, edge = node.run(test_file)
        assert edge == f"output_{edge_number}"
        assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_filetype_classifier_media_extensions_negative(tmp_path):
    """A file whose extension is not among the supported media types is rejected."""
    unsupported_file = tmp_path / "test.txt"
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)

    with pytest.raises(ValueError, match="Files of type 'txt'"):
        node.run(unsupported_file)
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.skipif(platform.system() in ["Windows", "Darwin"], reason="python-magic not available")
def test_filetype_classifier_estimate_media_extensions(tmp_path):
    """A WAV sample copied without an extension is still routed to the `wav` edge."""
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)

    # NOTE(review): the sample path is relative, so this presumably relies on the
    # test run starting from the repository root -- confirm.
    source_sample = "test/samples/audio/answer.wav"
    extensionless_copy = os.path.join(tmp_path, "test_wav_no_extension")
    shutil.copy(source_sample, extensionless_copy)

    output, edge = node.run(extensionless_copy)

    # `wav` is the fifth entry of DEFAULT_MEDIA_TYPES, hence edge 5.
    assert edge == "output_5"
    assert output == {"file_paths": [Path(extensionless_copy)]}
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_filetype_classifier_batched_various_media_extensions(tmp_path):
    """A batch that mixes different media extensions is rejected."""
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)
    # One path per supported media type, so the batch is heterogeneous by construction.
    test_files = [tmp_path / f"test.{media_type}" for media_type in DEFAULT_MEDIA_TYPES]

    # we can't classify a list of files with different media extensions
    with pytest.raises(ValueError, match="Multiple non-default file types are not allowed at once."):
        node.run_batch(test_files)
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_filetype_classifier_batched_same_media_extensions(tmp_path):
    """A batch of files sharing one media extension is routed together to a single edge."""
    batch_size = 5
    # All files use the first supported media type, so the expected edge is `output_1`.
    media_type = DEFAULT_MEDIA_TYPES[0]
    node = FileTypeClassifier(supported_types=DEFAULT_MEDIA_TYPES)
    test_files = [tmp_path / f"test-{idx}.{media_type}" for idx in range(batch_size)]

    # we should be able to pass a list of files with the same extension
    output, edge = node.run_batch(test_files)
    assert edge == "output_1"
    assert output == {"file_paths": test_files}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user