mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 02:16:32 +00:00
Make FileTypeClassifier more flexible (#2101)
* Make FileTypeClassifier more flexible * Make supported_types a init parameter * Add tests and fix a couple of bugs * Formatting * Fix mypy * Implement feedback
This commit is contained in:
parent
767f0025c6
commit
3a6e64b2a3
@ -1,40 +1,72 @@
|
||||
from multiprocessing.sharedctypes import Value
|
||||
from typing import List, Union
|
||||
from pathlib import Path
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"]
|
||||
|
||||
|
||||
class FileTypeClassifier(BaseComponent):
|
||||
"""
|
||||
Route files in an Indexing Pipeline to corresponding file converters.
|
||||
"""
|
||||
outgoing_edges = 5
|
||||
outgoing_edges = 10
|
||||
|
||||
def __init__(self):
|
||||
self.set_config()
|
||||
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
|
||||
"""
|
||||
Node that sends out files on a different output edge depending on their extension.
|
||||
|
||||
def _get_files_extension(self, file_paths: list) -> set:
|
||||
:param supported_types: the file types that this node can distinguish.
|
||||
Note that it's limited to a maximum of 10 outgoing edges, which
|
||||
correspond each to a file extension. Such extension are, by default
|
||||
`txt`, `pdf`, `md`, `docx`, `html`. Lists containing more than 10
|
||||
elements will not be allowed. Lists with duplicate elements will
|
||||
also be rejected.
|
||||
"""
|
||||
Return the file extensions
|
||||
:param file_paths:
|
||||
:return: set
|
||||
"""
|
||||
return {file_path.suffix.lstrip(".") for file_path in file_paths}
|
||||
if len(supported_types) > 10:
|
||||
raise ValueError("supported_types can't have more than 10 values.")
|
||||
if len(set(supported_types)) != len(supported_types):
|
||||
raise ValueError("supported_types can't contain duplicate values.")
|
||||
|
||||
def run(self, file_paths: Union[Path, List[Path]]): # type: ignore
|
||||
self.set_config(supported_types=supported_types)
|
||||
self.supported_types = supported_types
|
||||
|
||||
def _get_extension(self, file_paths: List[Path]) -> str:
|
||||
"""
|
||||
Return the output based on file extension
|
||||
Return the extension found in the given list of files.
|
||||
Also makes sure that all files have the same extension.
|
||||
If this is not true, it throws an exception.
|
||||
|
||||
:param file_paths: the paths to extract the extension from
|
||||
:return: a set of strings with all the extensions (without duplicates)
|
||||
"""
|
||||
if isinstance(file_paths, Path):
|
||||
extension = file_paths[0].suffix
|
||||
|
||||
for path in file_paths:
|
||||
if path.suffix != extension:
|
||||
raise ValueError(f"Multiple file types are not allowed at once.")
|
||||
|
||||
return extension.lstrip(".")
|
||||
|
||||
def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore
|
||||
"""
|
||||
Sends out files on a different output edge depending on their extension.
|
||||
|
||||
:param file_paths: paths to route on different edges.
|
||||
"""
|
||||
if not isinstance(file_paths, list):
|
||||
file_paths = [file_paths]
|
||||
|
||||
extension: set = self._get_files_extension(file_paths)
|
||||
if len(extension) > 1:
|
||||
raise ValueError(f"Multiple files types are not allowed at once.")
|
||||
paths = [Path(path) for path in file_paths]
|
||||
|
||||
output = {"file_paths": file_paths}
|
||||
ext: str = extension.pop()
|
||||
output = {"file_paths": paths}
|
||||
extension = self._get_extension(paths)
|
||||
try:
|
||||
index = ["txt", "pdf", "md", "docx", "html"].index(ext) + 1
|
||||
return output, f"output_{index}"
|
||||
index = self.supported_types.index(extension) + 1
|
||||
except ValueError:
|
||||
raise Exception(f"Files with an extension '{ext}' are not supported.")
|
||||
raise ValueError(f"Files of type '{extension}' are not supported. "
|
||||
f"The supported types are: {self.supported_types}. "
|
||||
"Consider using the 'supported_types' parameter to "
|
||||
"change the types accepted by this node.")
|
||||
return output, f"output_{index}"
|
||||
|
@ -141,6 +141,14 @@ def pytest_collection_modifyitems(config,items):
|
||||
item.add_marker(skip_docstore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmpdir(tmpdir):
|
||||
"""
|
||||
Makes pytest's tmpdir fixture fully compatible with pathlib
|
||||
"""
|
||||
return Path(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def gc_cleanup(request):
|
||||
"""
|
||||
|
57
test/test_filetype_classifier.py
Normal file
57
test/test_filetype_classifier.py
Normal file
@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
|
||||
|
||||
|
||||
def test_filetype_classifier_single_file(tmpdir):
|
||||
node = FileTypeClassifier()
|
||||
test_files = [tmpdir/f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
|
||||
for edge_index, test_file in enumerate(test_files):
|
||||
output, edge = node.run(test_file)
|
||||
assert edge == f"output_{edge_index+1}"
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
def test_filetype_classifier_many_files(tmpdir):
|
||||
node = FileTypeClassifier()
|
||||
|
||||
for edge_index, extension in enumerate(DEFAULT_TYPES):
|
||||
test_files = [tmpdir/f"test_{idx}.{extension}" for idx in range(10)]
|
||||
|
||||
output, edge = node.run(test_files)
|
||||
assert edge == f"output_{edge_index+1}"
|
||||
assert output == {"file_paths": test_files}
|
||||
|
||||
|
||||
def test_filetype_classifier_many_files_mixed_extensions(tmpdir):
|
||||
node = FileTypeClassifier()
|
||||
test_files = [tmpdir/f"test.{extension}" for extension in DEFAULT_TYPES]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
node.run(test_files)
|
||||
|
||||
|
||||
def test_filetype_classifier_unsupported_extension(tmpdir):
|
||||
node = FileTypeClassifier()
|
||||
test_file = tmpdir/f"test.really_weird_extension"
|
||||
with pytest.raises(ValueError):
|
||||
node.run(test_file)
|
||||
|
||||
|
||||
def test_filetype_classifier_custom_extensions(tmpdir):
|
||||
node = FileTypeClassifier(supported_types=["my_extension"])
|
||||
test_file = tmpdir/f"test.my_extension"
|
||||
output, edge = node.run(test_file)
|
||||
assert edge == f"output_1"
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
def test_filetype_classifier_too_many_custom_extensions():
|
||||
with pytest.raises(ValueError):
|
||||
FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)])
|
||||
|
||||
|
||||
def test_filetype_classifier_duplicate_custom_extensions():
|
||||
with pytest.raises(ValueError):
|
||||
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
|
Loading…
x
Reference in New Issue
Block a user