mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-29 16:59:47 +00:00
extract extension based on file's content (#2330)
* extract extension based on file's content * Add python-magic dependency * fix the _estimate_extension function and lowercase the file extensions * check if the FileTypeClassifier can be imported * add test and new file types * fix typing * import Optional * revert Optional and make sure a string is always returned * fix test so that it skips markdown files * Emulate Code & Docs action * Generate schemas * Tidy up test code & extensioness files * Improve error messages * Revert schema changes * Emulate black and docs CI again
This commit is contained in:
parent
ae712fe6bf
commit
b94d9effaf
@ -1 +1,5 @@
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier
|
||||
from haystack.utils.import_utils import safe_import
|
||||
|
||||
FileTypeClassifier = safe_import(
|
||||
"haystack.nodes.file_classifier.file_type", "FileTypeClassifier", "preprocessing"
|
||||
) # Has optional dependencies
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import mimetypes
|
||||
from multiprocessing.sharedctypes import Value
|
||||
from typing import List, Union
|
||||
from pathlib import Path
|
||||
import magic
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
@ -28,12 +30,24 @@ class FileTypeClassifier(BaseComponent):
|
||||
if len(supported_types) > 10:
|
||||
raise ValueError("supported_types can't have more than 10 values.")
|
||||
if len(set(supported_types)) != len(supported_types):
|
||||
raise ValueError("supported_types can't contain duplicate values.")
|
||||
duplicates = supported_types
|
||||
for item in set(supported_types):
|
||||
duplicates.remove(item)
|
||||
raise ValueError(f"supported_types can't contain duplicate values ({duplicates}).")
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.supported_types = supported_types
|
||||
|
||||
def _estimate_extension(self, file_path: Path) -> str:
|
||||
"""
|
||||
Return the extension found based on the contents of the given file
|
||||
|
||||
:param file_path: the path to extract the extension from
|
||||
"""
|
||||
extension = magic.from_file(str(file_path), mime=True)
|
||||
return mimetypes.guess_extension(extension) or ""
|
||||
|
||||
def _get_extension(self, file_paths: List[Path]) -> str:
|
||||
"""
|
||||
Return the extension found in the given list of files.
|
||||
@ -41,12 +55,17 @@ class FileTypeClassifier(BaseComponent):
|
||||
If this is not true, it throws an exception.
|
||||
|
||||
:param file_paths: the paths to extract the extension from
|
||||
:return: a set of strings with all the extensions (without duplicates)
|
||||
:return: a set of strings with all the extensions (without duplicates), the extension will be guessed if the file has none
|
||||
"""
|
||||
extension = file_paths[0].suffix
|
||||
extension = file_paths[0].suffix.lower()
|
||||
if extension == "":
|
||||
extension = self._estimate_extension(file_paths[0])
|
||||
|
||||
for path in file_paths:
|
||||
if path.suffix != extension:
|
||||
path_suffix = path.suffix.lower()
|
||||
if path_suffix == "":
|
||||
path_suffix = self._estimate_extension(path)
|
||||
if path_suffix != extension:
|
||||
raise ValueError(f"Multiple file types are not allowed at once.")
|
||||
|
||||
return extension.lstrip(".")
|
||||
@ -68,7 +87,7 @@ class FileTypeClassifier(BaseComponent):
|
||||
index = self.supported_types.index(extension) + 1
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Files of type '{extension}' are not supported. "
|
||||
f"Files of type '{extension}' ({paths[0]}) are not supported. "
|
||||
f"The supported types are: {self.supported_types}. "
|
||||
"Consider using the 'supported_types' parameter to "
|
||||
"change the types accepted by this node."
|
||||
|
||||
@ -159,6 +159,7 @@ crawler =
|
||||
preprocessing =
|
||||
beautifulsoup4
|
||||
markdown
|
||||
python-magic # Depends on libmagic: https://pypi.org/project/python-magic/
|
||||
ocr =
|
||||
pytesseract==0.3.7
|
||||
pillow
|
||||
|
||||
BIN
test/samples/extensionless_files/docx_file
Normal file
BIN
test/samples/extensionless_files/docx_file
Normal file
Binary file not shown.
BIN
test/samples/extensionless_files/gif_file
Normal file
BIN
test/samples/extensionless_files/gif_file
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 64 B |
2
test/samples/extensionless_files/html_file
Normal file
2
test/samples/extensionless_files/html_file
Normal file
@ -0,0 +1,2 @@
|
||||
<!DOCTYPE html>
|
||||
<a>sample</a>
|
||||
BIN
test/samples/extensionless_files/jpg_file
Normal file
BIN
test/samples/extensionless_files/jpg_file
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.3 KiB |
BIN
test/samples/extensionless_files/mp3_file
Normal file
BIN
test/samples/extensionless_files/mp3_file
Normal file
Binary file not shown.
BIN
test/samples/extensionless_files/odt_file
Normal file
BIN
test/samples/extensionless_files/odt_file
Normal file
Binary file not shown.
BIN
test/samples/extensionless_files/pdf_file
Normal file
BIN
test/samples/extensionless_files/pdf_file
Normal file
Binary file not shown.
BIN
test/samples/extensionless_files/png_file
Normal file
BIN
test/samples/extensionless_files/png_file
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.6 KiB |
BIN
test/samples/extensionless_files/pptx_file
Normal file
BIN
test/samples/extensionless_files/pptx_file
Normal file
Binary file not shown.
1
test/samples/extensionless_files/txt_file
Normal file
1
test/samples/extensionless_files/txt_file
Normal file
@ -0,0 +1 @@
|
||||
Sample
|
||||
BIN
test/samples/extensionless_files/wav_file
Normal file
BIN
test/samples/extensionless_files/wav_file
Normal file
Binary file not shown.
BIN
test/samples/extensionless_files/zip_file
Normal file
BIN
test/samples/extensionless_files/zip_file
Normal file
Binary file not shown.
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
|
||||
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
def test_filetype_classifier_single_file(tmp_path):
|
||||
node = FileTypeClassifier()
|
||||
@ -55,3 +57,25 @@ def test_filetype_classifier_too_many_custom_extensions():
|
||||
def test_filetype_classifier_duplicate_custom_extensions():
|
||||
with pytest.raises(ValueError):
|
||||
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
|
||||
|
||||
|
||||
def test_filetype_classifier_text_files_without_extension():
|
||||
tested_types = ["docx", "html", "odt", "pdf", "pptx", "txt"]
|
||||
node = FileTypeClassifier(supported_types=tested_types)
|
||||
test_files = [SAMPLES_PATH / "extensionless_files" / f"{type_name}_file" for type_name in tested_types]
|
||||
|
||||
for edge_index, test_file in enumerate(test_files):
|
||||
output, edge = node.run(test_file)
|
||||
assert edge == f"output_{edge_index+1}"
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
def test_filetype_classifier_other_files_without_extension():
|
||||
tested_types = ["gif", "jpg", "mp3", "png", "wav", "zip"]
|
||||
node = FileTypeClassifier(supported_types=tested_types)
|
||||
test_files = [SAMPLES_PATH / "extensionless_files" / f"{type_name}_file" for type_name in tested_types]
|
||||
|
||||
for edge_index, test_file in enumerate(test_files):
|
||||
output, edge = node.run(test_file)
|
||||
assert edge == f"output_{edge_index+1}"
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user