Make FileTypeClassifier more flexible (#2101)

* Make FileTypeClassifier more flexible

* Make supported_types an init parameter

* Add tests and fix a couple of bugs

* Formatting

* Fix mypy

* Implement feedback
This commit is contained in:
Sara Zan 2022-02-02 17:51:04 +01:00 committed by GitHub
parent 767f0025c6
commit 3a6e64b2a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 117 additions and 20 deletions

View File

@ -1,40 +1,72 @@
from multiprocessing.sharedctypes import Value
from typing import List, Union
from pathlib import Path
from haystack.nodes.base import BaseComponent
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"]
class FileTypeClassifier(BaseComponent):
    """
    Route files in an Indexing Pipeline to corresponding file converters.
    """

    # Upper bound on the number of distinct file types a single node can route.
    outgoing_edges = 10

    def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
        """
        Node that sends out files on a different output edge depending on their extension.

        :param supported_types: the file types that this node can distinguish.
            Note that it's limited to a maximum of 10 outgoing edges, each of which
            corresponds to a file extension. Such extensions are, by default,
            `txt`, `pdf`, `md`, `docx`, `html`. Lists containing more than 10
            elements will not be allowed. Lists with duplicate elements will
            also be rejected.
        :raises ValueError: if `supported_types` has more entries than there are
            outgoing edges, or if it contains duplicates.
        """
        if len(supported_types) > self.outgoing_edges:
            raise ValueError(f"supported_types can't have more than {self.outgoing_edges} values.")
        if len(set(supported_types)) != len(supported_types):
            raise ValueError("supported_types can't contain duplicate values.")

        # Work on a private copy: the default value is the module-level DEFAULT_TYPES
        # list, and storing it directly would let a mutation on one instance change
        # the defaults of every instance created afterwards.
        supported_types = list(supported_types)

        self.set_config(supported_types=supported_types)
        self.supported_types = supported_types

    def _get_extension(self, file_paths: List[Path]) -> str:
        """
        Return the extension found in the given list of files.
        Also makes sure that all files have the same extension.
        If this is not true, it throws an exception.

        :param file_paths: the paths to extract the extension from
        :return: the extension shared by all the files, without the leading dot
        :raises ValueError: if no path is given or if the paths have different extensions
        """
        if not file_paths:
            # Without this guard the lookup below fails with an unhelpful IndexError.
            raise ValueError("At least one file path is required to detect the file type.")

        extension = file_paths[0].suffix
        for path in file_paths:
            if path.suffix != extension:
                raise ValueError("Multiple file types are not allowed at once.")

        return extension.lstrip(".")

    def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]):  # type: ignore
        """
        Sends out files on a different output edge depending on their extension.

        :param file_paths: paths to route on different edges.
        :return: a tuple made of the output dictionary (`{"file_paths": [...]}`, with
            every path converted to `Path`) and the name of the selected edge
            (`output_<n>`, where `n` is the 1-based position of the extension
            in `supported_types`).
        :raises ValueError: if the files have mixed extensions or an unsupported one.
        """
        if not isinstance(file_paths, list):
            file_paths = [file_paths]

        paths = [Path(path) for path in file_paths]

        output = {"file_paths": paths}
        extension = self._get_extension(paths)
        try:
            # Edges are numbered from 1, list positions from 0.
            index = self.supported_types.index(extension) + 1
        except ValueError:
            # The "x is not in list" error raised by list.index adds nothing to the
            # message below, so it is dropped from the traceback.
            raise ValueError(
                f"Files of type '{extension}' are not supported. "
                f"The supported types are: {self.supported_types}. "
                "Consider using the 'supported_types' parameter to "
                "change the types accepted by this node."
            ) from None
        return output, f"output_{index}"

View File

@ -141,6 +141,14 @@ def pytest_collection_modifyitems(config,items):
item.add_marker(skip_docstore)
@pytest.fixture
def tmpdir(tmpdir):
    """
    Override of pytest's tmpdir fixture that hands the temporary
    directory to the tests as a pathlib `Path` object.
    """
    pathlib_tmpdir = Path(tmpdir)
    return pathlib_tmpdir
@pytest.fixture(scope="function", autouse=True)
def gc_cleanup(request):
"""

View File

@ -0,0 +1,57 @@
import pytest
from pathlib import Path
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
def test_filetype_classifier_single_file(tmpdir):
    """
    A single path of each default type is routed to the edge matching the
    position of its extension in DEFAULT_TYPES.
    """
    classifier = FileTypeClassifier()
    for edge_number, file_type in enumerate(DEFAULT_TYPES, start=1):
        single_path = tmpdir / f"test.{file_type}"
        result, edge_name = classifier.run(single_path)
        assert edge_name == f"output_{edge_number}"
        assert result == {"file_paths": [single_path]}
def test_filetype_classifier_many_files(tmpdir):
    """
    A list of paths sharing one default extension is routed, as a whole,
    to the edge matching that extension.
    """
    classifier = FileTypeClassifier()
    for edge_number, file_type in enumerate(DEFAULT_TYPES, start=1):
        same_type_paths = []
        for file_number in range(10):
            same_type_paths.append(tmpdir / f"test_{file_number}.{file_type}")
        result, edge_name = classifier.run(same_type_paths)
        assert edge_name == f"output_{edge_number}"
        assert result == {"file_paths": same_type_paths}
def test_filetype_classifier_many_files_mixed_extensions(tmpdir):
    """
    Passing paths with different extensions in one call raises a ValueError.
    """
    classifier = FileTypeClassifier()
    mixed_paths = []
    for file_type in DEFAULT_TYPES:
        mixed_paths.append(tmpdir / f"test.{file_type}")
    with pytest.raises(ValueError):
        classifier.run(mixed_paths)
def test_filetype_classifier_unsupported_extension(tmpdir):
    """
    A path whose extension is not among the supported types raises a ValueError.
    """
    unsupported_path = tmpdir / f"test.really_weird_extension"
    classifier = FileTypeClassifier()
    with pytest.raises(ValueError):
        classifier.run(unsupported_path)
def test_filetype_classifier_custom_extensions(tmpdir):
    """
    A node built with a single custom type routes matching files to the first edge.
    """
    custom_path = tmpdir / f"test.my_extension"
    classifier = FileTypeClassifier(supported_types=["my_extension"])
    result, edge_name = classifier.run(custom_path)
    assert edge_name == f"output_1"
    assert result == {"file_paths": [custom_path]}
def test_filetype_classifier_too_many_custom_extensions():
    """
    Building a node with 20 custom types raises a ValueError.
    """
    too_many_types = []
    for type_number in range(20):
        too_many_types.append(f"my_extension_{type_number}")
    with pytest.raises(ValueError):
        FileTypeClassifier(supported_types=too_many_types)
def test_filetype_classifier_duplicate_custom_extensions():
    """
    Building a node with the same custom type listed twice raises a ValueError.
    """
    duplicated_types = [f"my_extension", "my_extension"]
    with pytest.raises(ValueError):
        FileTypeClassifier(supported_types=duplicated_types)