mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 02:16:32 +00:00
Make FileTypeClassifier more flexible (#2101)
* Make FileTypeClassifier more flexible * Make supported_types an init parameter * Add tests and fix a couple of bugs * Formatting * Fix mypy * Implement feedback
This commit is contained in:
parent
767f0025c6
commit
3a6e64b2a3
@ -1,40 +1,72 @@
|
|||||||
|
from multiprocessing.sharedctypes import Value
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.base import BaseComponent
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"]
|
||||||
|
|
||||||
|
|
||||||
class FileTypeClassifier(BaseComponent):
|
class FileTypeClassifier(BaseComponent):
|
||||||
"""
|
"""
|
||||||
Route files in an Indexing Pipeline to corresponding file converters.
|
Route files in an Indexing Pipeline to corresponding file converters.
|
||||||
"""
|
"""
|
||||||
outgoing_edges = 5
|
outgoing_edges = 10
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
|
||||||
self.set_config()
|
"""
|
||||||
|
Node that sends out files on a different output edge depending on their extension.
|
||||||
|
|
||||||
def _get_files_extension(self, file_paths: list) -> set:
|
:param supported_types: the file types that this node can distinguish.
|
||||||
|
Note that it's limited to a maximum of 10 outgoing edges, which
|
||||||
|
each correspond to a file extension. Such extensions are, by default
|
||||||
|
`txt`, `pdf`, `md`, `docx`, `html`. Lists containing more than 10
|
||||||
|
elements will not be allowed. Lists with duplicate elements will
|
||||||
|
also be rejected.
|
||||||
"""
|
"""
|
||||||
Return the file extensions
|
if len(supported_types) > 10:
|
||||||
:param file_paths:
|
raise ValueError("supported_types can't have more than 10 values.")
|
||||||
:return: set
|
if len(set(supported_types)) != len(supported_types):
|
||||||
"""
|
raise ValueError("supported_types can't contain duplicate values.")
|
||||||
return {file_path.suffix.lstrip(".") for file_path in file_paths}
|
|
||||||
|
|
||||||
def run(self, file_paths: Union[Path, List[Path]]): # type: ignore
|
self.set_config(supported_types=supported_types)
|
||||||
|
self.supported_types = supported_types
|
||||||
|
|
||||||
|
def _get_extension(self, file_paths: List[Path]) -> str:
|
||||||
"""
|
"""
|
||||||
Return the output based on file extension
|
Return the extension found in the given list of files.
|
||||||
|
Also makes sure that all files have the same extension.
|
||||||
|
If this is not true, it raises a ValueError.
|
||||||
|
|
||||||
|
:param file_paths: the paths to extract the extension from
|
||||||
|
:return: the extension shared by all the given files, as a string without the leading dot
|
||||||
"""
|
"""
|
||||||
if isinstance(file_paths, Path):
|
extension = file_paths[0].suffix
|
||||||
|
|
||||||
|
for path in file_paths:
|
||||||
|
if path.suffix != extension:
|
||||||
|
raise ValueError(f"Multiple file types are not allowed at once.")
|
||||||
|
|
||||||
|
return extension.lstrip(".")
|
||||||
|
|
||||||
|
def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore
|
||||||
|
"""
|
||||||
|
Sends out files on a different output edge depending on their extension.
|
||||||
|
|
||||||
|
:param file_paths: paths to route on different edges.
|
||||||
|
"""
|
||||||
|
if not isinstance(file_paths, list):
|
||||||
file_paths = [file_paths]
|
file_paths = [file_paths]
|
||||||
|
|
||||||
extension: set = self._get_files_extension(file_paths)
|
paths = [Path(path) for path in file_paths]
|
||||||
if len(extension) > 1:
|
|
||||||
raise ValueError(f"Multiple files types are not allowed at once.")
|
|
||||||
|
|
||||||
output = {"file_paths": file_paths}
|
output = {"file_paths": paths}
|
||||||
ext: str = extension.pop()
|
extension = self._get_extension(paths)
|
||||||
try:
|
try:
|
||||||
index = ["txt", "pdf", "md", "docx", "html"].index(ext) + 1
|
index = self.supported_types.index(extension) + 1
|
||||||
return output, f"output_{index}"
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise Exception(f"Files with an extension '{ext}' are not supported.")
|
raise ValueError(f"Files of type '{extension}' are not supported. "
|
||||||
|
f"The supported types are: {self.supported_types}. "
|
||||||
|
"Consider using the 'supported_types' parameter to "
|
||||||
|
"change the types accepted by this node.")
|
||||||
|
return output, f"output_{index}"
|
||||||
|
@ -141,6 +141,14 @@ def pytest_collection_modifyitems(config,items):
|
|||||||
item.add_marker(skip_docstore)
|
item.add_marker(skip_docstore)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def tmpdir(tmpdir):
    """
    Override pytest's built-in ``tmpdir`` fixture so tests receive a ``pathlib.Path``.

    :param tmpdir: the temporary directory object supplied by pytest's own fixture
    :return: the same directory, wrapped in a ``pathlib.Path``
    """
    as_path = Path(tmpdir)
    return as_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def gc_cleanup(request):
|
def gc_cleanup(request):
|
||||||
"""
|
"""
|
||||||
|
57
test/test_filetype_classifier.py
Normal file
57
test/test_filetype_classifier.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_single_file(tmpdir):
    """
    A single path with a default extension is routed to the edge matching that
    extension's position in DEFAULT_TYPES, and comes back wrapped in a list.
    """
    classifier = FileTypeClassifier()

    # Edges are numbered from 1, in the same order as DEFAULT_TYPES.
    for position, extension in enumerate(DEFAULT_TYPES, start=1):
        path = tmpdir / f"test.{extension}"
        result, edge_name = classifier.run(path)
        assert edge_name == f"output_{position}"
        assert result == {"file_paths": [path]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_many_files(tmpdir):
    """
    A list of paths sharing one default extension is routed, unchanged, to the
    edge matching that extension's position in DEFAULT_TYPES.
    """
    classifier = FileTypeClassifier()

    # Edges are numbered from 1, in the same order as DEFAULT_TYPES.
    for position, extension in enumerate(DEFAULT_TYPES, start=1):
        paths = [tmpdir / f"test_{number}.{extension}" for number in range(10)]

        result, edge_name = classifier.run(paths)
        assert edge_name == f"output_{position}"
        assert result == {"file_paths": paths}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_many_files_mixed_extensions(tmpdir):
    """
    Passing paths with different extensions in a single call raises a ValueError.
    """
    classifier = FileTypeClassifier()
    # One path per default extension, so the batch is guaranteed to be mixed.
    mixed_paths = [tmpdir / f"test.{extension}" for extension in DEFAULT_TYPES]

    with pytest.raises(ValueError):
        classifier.run(mixed_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_unsupported_extension(tmpdir):
    """
    A file whose extension is not among the node's default supported types
    is rejected with a ValueError.
    """
    node = FileTypeClassifier()
    # Plain string literal: nothing is interpolated, so no f-prefix is needed.
    test_file = tmpdir / "test.really_weird_extension"
    with pytest.raises(ValueError):
        node.run(test_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_custom_extensions(tmpdir):
    """
    A node configured with a custom ``supported_types`` list routes a file with
    that extension to the first output edge, with the path wrapped in a list.
    """
    node = FileTypeClassifier(supported_types=["my_extension"])
    # Plain string literals: nothing is interpolated, so no f-prefix is needed.
    test_file = tmpdir / "test.my_extension"
    output, edge = node.run(test_file)
    assert edge == "output_1"
    assert output == {"file_paths": [test_file]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_too_many_custom_extensions():
    """
    Constructing the node with 20 supported types raises a ValueError.
    """
    # Built outside the ``raises`` block so only the constructor is under test.
    oversized_types = [f"my_extension_{idx}" for idx in range(20)]
    with pytest.raises(ValueError):
        FileTypeClassifier(supported_types=oversized_types)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filetype_classifier_duplicate_custom_extensions():
    """
    Constructing the node with a repeated entry in ``supported_types`` raises a ValueError.
    """
    with pytest.raises(ValueError):
        # Plain string literals: nothing is interpolated, so no f-prefix is needed.
        FileTypeClassifier(supported_types=["my_extension", "my_extension"])
|
Loading…
x
Reference in New Issue
Block a user