diff --git a/haystack/nodes/file_classifier/file_type.py b/haystack/nodes/file_classifier/file_type.py index 306c5d12f..ec6fd171d 100644 --- a/haystack/nodes/file_classifier/file_type.py +++ b/haystack/nodes/file_classifier/file_type.py @@ -1,40 +1,72 @@ +from multiprocessing.sharedctypes import Value from typing import List, Union from pathlib import Path from haystack.nodes.base import BaseComponent +DEFAULT_TYPES = ["txt", "pdf", "md", "docx", "html"] + + class FileTypeClassifier(BaseComponent): """ Route files in an Indexing Pipeline to corresponding file converters. """ - outgoing_edges = 5 + outgoing_edges = 10 - def __init__(self): - self.set_config() + def __init__(self, supported_types: List[str] = DEFAULT_TYPES): + """ + Node that sends out files on a different output edge depending on their extension. - def _get_files_extension(self, file_paths: list) -> set: + :param supported_types: the file types that this node can distinguish. + Note that it's limited to a maximum of 10 outgoing edges, each of which + corresponds to a file extension. Such extensions are, by default + `txt`, `pdf`, `md`, `docx`, `html`. Lists containing more than 10 + elements will not be allowed. Lists with duplicate elements will + also be rejected. """ - Return the file extensions - :param file_paths: - :return: set - """ - return {file_path.suffix.lstrip(".") for file_path in file_paths} + if len(supported_types) > 10: + raise ValueError("supported_types can't have more than 10 values.") + if len(set(supported_types)) != len(supported_types): + raise ValueError("supported_types can't contain duplicate values.") - def run(self, file_paths: Union[Path, List[Path]]): # type: ignore + self.set_config(supported_types=supported_types) + self.supported_types = supported_types + + def _get_extension(self, file_paths: List[Path]) -> str: """ - Return the output based on file extension + Return the extension found in the given list of files. 
+ Also makes sure that all files have the same extension. + If this is not true, it raises a ValueError. + + :param file_paths: the paths to extract the extension from + :return: the extension shared by all the files, without the leading dot """ - if isinstance(file_paths, Path): + extension = file_paths[0].suffix + + for path in file_paths: + if path.suffix != extension: + raise ValueError(f"Multiple file types are not allowed at once.") + + return extension.lstrip(".") + + def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Path, str]]]): # type: ignore + """ + Sends out files on a different output edge depending on their extension. + + :param file_paths: paths to route on different edges. + """ + if not isinstance(file_paths, list): file_paths = [file_paths] - extension: set = self._get_files_extension(file_paths) - if len(extension) > 1: - raise ValueError(f"Multiple files types are not allowed at once.") + paths = [Path(path) for path in file_paths] - output = {"file_paths": file_paths} - ext: str = extension.pop() + output = {"file_paths": paths} + extension = self._get_extension(paths) try: - index = ["txt", "pdf", "md", "docx", "html"].index(ext) + 1 - return output, f"output_{index}" + index = self.supported_types.index(extension) + 1 except ValueError: - raise Exception(f"Files with an extension '{ext}' are not supported.") + raise ValueError(f"Files of type '{extension}' are not supported. " + f"The supported types are: {self.supported_types}. 
" + "Consider using the 'supported_types' parameter to " + "change the types accepted by this node.") + return output, f"output_{index}" diff --git a/test/conftest.py b/test/conftest.py index 46548144d..41179c0aa 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -141,6 +141,14 @@ def pytest_collection_modifyitems(config,items): item.add_marker(skip_docstore) +@pytest.fixture +def tmpdir(tmpdir): + """ + Makes pytest's tmpdir fixture fully compatible with pathlib + """ + return Path(tmpdir) + + @pytest.fixture(scope="function", autouse=True) def gc_cleanup(request): """ diff --git a/test/test_filetype_classifier.py b/test/test_filetype_classifier.py new file mode 100644 index 000000000..086ea4214 --- /dev/null +++ b/test/test_filetype_classifier.py @@ -0,0 +1,57 @@ +import pytest +from pathlib import Path +from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES + + +def test_filetype_classifier_single_file(tmpdir): + node = FileTypeClassifier() + test_files = [tmpdir/f"test.{extension}" for extension in DEFAULT_TYPES] + + for edge_index, test_file in enumerate(test_files): + output, edge = node.run(test_file) + assert edge == f"output_{edge_index+1}" + assert output == {"file_paths": [test_file]} + + +def test_filetype_classifier_many_files(tmpdir): + node = FileTypeClassifier() + + for edge_index, extension in enumerate(DEFAULT_TYPES): + test_files = [tmpdir/f"test_{idx}.{extension}" for idx in range(10)] + + output, edge = node.run(test_files) + assert edge == f"output_{edge_index+1}" + assert output == {"file_paths": test_files} + + +def test_filetype_classifier_many_files_mixed_extensions(tmpdir): + node = FileTypeClassifier() + test_files = [tmpdir/f"test.{extension}" for extension in DEFAULT_TYPES] + + with pytest.raises(ValueError): + node.run(test_files) + + +def test_filetype_classifier_unsupported_extension(tmpdir): + node = FileTypeClassifier() + test_file = tmpdir/f"test.really_weird_extension" + with 
pytest.raises(ValueError): + node.run(test_file) + + +def test_filetype_classifier_custom_extensions(tmpdir): + node = FileTypeClassifier(supported_types=["my_extension"]) + test_file = tmpdir/f"test.my_extension" + output, edge = node.run(test_file) + assert edge == f"output_1" + assert output == {"file_paths": [test_file]} + + +def test_filetype_classifier_too_many_custom_extensions(): + with pytest.raises(ValueError): + FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)]) + + +def test_filetype_classifier_duplicate_custom_extensions(): + with pytest.raises(ValueError): + FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])