haystack/test/preview/components/classifiers/test_file_classifier.py
ZanSara b1daa7c647
chore: migrate to canals==0.7.0 (#5647)
* add default_to_dict and default_from_dict placeholders to ease migration to canals 0.7.0

* canals==0.7.0

* whisper components

* add to_dict/from_dict stubs

* import serialization methods in init to hide canals imports

* reno

* export DeserializationError too

* Update haystack/preview/__init__.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* serialization methods for LocalWhisperTranscriber (#5648)

* chore: serialization methods for `FileExtensionClassifier` (#5651)

* serialization methods for FileExtensionClassifier

* Update test_file_classifier.py

* chore: serialization methods for `SentenceTransformersDocumentEmbedder` (#5652)

* serialization methods for SentenceTransformersDocumentEmbedder

* fix device management

* serialization methods for SentenceTransformersTextEmbedder (#5653)

* serialization methods for TextFileToDocument (#5654)

* chore: serialization methods for `RemoteWhisperTranscriber` (#5650)

* serialization methods for RemoteWhisperTranscriber

* remove patches

* Add default to_dict and from_dict in document stores built with factory (#5674)

* fix tests (#5671)

* chore: simplify serialization methods for `MemoryDocumentStore` (#5667)

* simplify serialization for MemoryDocumentStore

* remove redundant tests

* pylint

* chore: serialization methods for `MemoryRetriever` (#5663)

* serialization method for MemoryRetriever

* more tests

* remove hash from default_document_store_to_dict

* remove diff in factory.py

* chore: serialization methods for `DocumentWriter` (#5661)

* serialization methods for DocumentWriter

* more tests

* use factory

* black

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
2023-08-29 18:15:07 +02:00

100 lines
3.8 KiB
Python

import sys
import pytest
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
@pytest.mark.skipif(
    sys.platform in ("win32", "cygwin"),
    reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier:
    """
    Unit tests for FileExtensionClassifier: serialization and path classification.

    The whole class is skipped on win32/cygwin (see the skip reason above).
    """

    @pytest.mark.unit
    def test_to_dict(self):
        """
        Test that to_dict() returns the component type name and the init mime types.
        """
        expected = {
            "type": "FileExtensionClassifier",
            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
        }
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        serialized = classifier.to_dict()
        assert serialized == expected

    @pytest.mark.unit
    def test_from_dict(self):
        """
        Test that from_dict() builds a component exposing the serialized mime types.
        """
        expected_mime_types = ["text/plain", "audio/x-wav", "image/jpeg"]
        serialized = {
            "type": "FileExtensionClassifier",
            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
        }
        classifier = FileExtensionClassifier.from_dict(serialized)
        assert classifier.mime_types == expected_mime_types

    @pytest.mark.unit
    def test_run(self, preview_samples_path):
        """
        Test the happy path: every sample file is grouped under its expected mime type.
        """
        txt_dir = preview_samples_path / "txt"
        samples = [
            txt_dir / "doc_1.txt",
            txt_dir / "doc_2.txt",
            preview_samples_path / "audio" / "the context for this answer is here.wav",
            preview_samples_path / "images" / "apple.jpg",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=samples)
        assert output
        # Expected number of paths per requested mime type, checked in declaration order.
        expected_counts = {"text/plain": 2, "audio/x-wav": 1, "image/jpeg": 1}
        for mime_type, count in expected_counts.items():
            assert len(output[mime_type]) == count
        assert not output["unclassified"]

    @pytest.mark.unit
    def test_no_files(self):
        """
        Test that running with an empty list of paths yields a falsy (empty) output.
        """
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        assert not classifier.run(paths=[])

    @pytest.mark.unit
    def test_unlisted_extensions(self, preview_samples_path):
        """
        Test that files whose mime type was not requested are reported as unclassified.
        """
        audio_dir = preview_samples_path / "audio"
        samples = [
            preview_samples_path / "txt" / "doc_1.txt",
            audio_dir / "ignored.mp3",
            audio_dir / "this is the content of the document.wav",
        ]
        output = FileExtensionClassifier(mime_types=["text/plain"]).run(paths=samples)
        assert len(output["text/plain"]) == 1
        assert "mp3" not in output
        unclassified = output["unclassified"]
        assert len(unclassified) == 2
        # The unclassified entries are expected in the same order the paths were given.
        first_unclassified, second_unclassified = unclassified[0], unclassified[1]
        assert str(first_unclassified).endswith("ignored.mp3")
        assert str(second_unclassified).endswith("this is the content of the document.wav")

    @pytest.mark.unit
    def test_no_extension(self, preview_samples_path):
        """
        Test that, of three paths including one without an extension, two are
        classified as text/plain and one is reported as unclassified.
        """
        txt_dir = preview_samples_path / "txt"
        samples = [txt_dir / "doc_1.txt", txt_dir / "doc_2", txt_dir / "doc_2.txt"]
        output = FileExtensionClassifier(mime_types=["text/plain"]).run(paths=samples)
        assert len(output["text/plain"]) == 2
        assert len(output["unclassified"]) == 1

    @pytest.mark.unit
    def test_unknown_mime_type(self):
        """
        Test that constructing the component with an invalid mime type raises ValueError.
        """
        invalid_mime_types = ["type_invalid"]
        with pytest.raises(ValueError, match="Unknown mime type:"):
            FileExtensionClassifier(mime_types=invalid_mime_types)