chore: migrate to canals==0.7.0 (#5647)

* add default_to_dict and default_from_dict placeholders to ease migration to canals 0.7.0

* canals==0.7.0

* whisper components

* add to_dict/from_dict stubs

* import serialization methods in init to hide canals imports

* reno

* export DeserializationError too

* Update haystack/preview/__init__.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* serialization methods for LocalWhisperTranscriber (#5648)

* chore: serialization methods for `FileExtensionClassifier` (#5651)

* serialization methods for FileExtensionClassifier

* Update test_file_classifier.py

* chore: serialization methods for `SentenceTransformersDocumentEmbedder` (#5652)

* serialization methods for SentenceTransformersDocumentEmbedder

* fix device management

* serialization methods for SentenceTransformersTextEmbedder (#5653)

* serialization methods for TextFileToDocument (#5654)

* chore: serialization methods for `RemoteWhisperTranscriber` (#5650)

* serialization methods for RemoteWhisperTranscriber

* remove patches

* Add default to_dict and from_dict in document stores built with factory (#5674)

* fix tests (#5671)

* chore: simplify serialization methods for `MemoryDocumentStore` (#5667)

* simplify serialization for MemoryDocumentStore

* remove redundant tests

* pylint

* chore: serialization methods for `MemoryRetriever` (#5663)

* serialization method for MemoryRetriever

* more tests

* remove hash from default_document_store_to_dict

* remove diff in factory.py

* chore: serialization methods for `DocumentWriter` (#5661)

* serialization methods for DocumentWriter

* more tests

* use factory

* black

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Commit b1daa7c647 (parent a613b1b7f5)
ZanSara, 2023-08-29 18:15:07 +02:00, committed by GitHub
27 changed files with 699 additions and 171 deletions

haystack/preview/__init__.py

@@ -1,2 +1,4 @@
from canals import component, Pipeline
from canals.serialization import default_from_dict, default_to_dict
from canals.errors import DeserializationError
from haystack.preview.dataclasses import *

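Note: the two helpers re-exported here define the serialization contract used by every component in this commit. `to_dict` emits `{"type": <class name>, "init_parameters": {...}}` and `from_dict` rebuilds the instance from that shape. The real implementations live in canals 0.7.0; the following is a minimal sketch of their behavior, inferred from the test expectations further down this diff, not a copy of the library code:

from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")

class DeserializationError(Exception):
    pass

def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
    # Record the class name plus whatever constructor arguments the caller
    # chooses to expose; nothing is introspected automatically.
    return {"type": obj.__class__.__name__, "init_parameters": init_parameters}

def default_from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
    # Refuse to rebuild an instance from data produced by a different class.
    if data.get("type") != cls.__name__:
        raise DeserializationError(f"Can't deserialize {data.get('type')!r} as '{cls.__name__}'")
    return cls(**data.get("init_parameters", {}))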
haystack/preview/components/audio/whisper_local.py

@@ -6,7 +6,7 @@ from pathlib import Path
import torch
import whisper
from haystack.preview import component, Document
from haystack.preview import component, Document, default_to_dict, default_from_dict
logger = logging.getLogger(__name__)
@@ -55,6 +55,21 @@ class LocalWhisperTranscriber:
if not self._model:
self._model = whisper.load_model(self.model_name, device=self.device)
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""

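With those helpers a component's `to_dict`/`from_dict` pair gives a plain round trip. A runnable sketch reusing the helpers above, with a stand-in class so it needs neither `torch` nor `whisper` (the attribute names mirror the transcriber's, everything else is illustrative):

class FakeTranscriber:
    def __init__(self, model_name_or_path: str = "large", device: str = "cpu"):
        self.model_name = model_name_or_path
        self.device = device

    def to_dict(self):
        return default_to_dict(self, model_name_or_path=self.model_name, device=self.device)

    @classmethod
    def from_dict(cls, data):
        return default_from_dict(cls, data)

original = FakeTranscriber(model_name_or_path="tiny", device="cuda")
restored = FakeTranscriber.from_dict(original.to_dict())
assert restored.model_name == "tiny" and restored.device == "cuda"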
haystack/preview/components/audio/whisper_remote.py

@@ -6,7 +6,7 @@ import logging
from pathlib import Path
from haystack.preview.utils import request_with_retry
from haystack.preview import component, Document
from haystack.preview import component, Document, default_to_dict, default_from_dict
logger = logging.getLogger(__name__)
@@ -49,17 +49,29 @@ class RemoteWhisperTranscriber:
if not api_key:
raise ValueError("API key is None.")
self.model_name = model_name
self.api_key = api_key
self.api_base = api_base
self.whisper_params = whisper_params or {}
self.model_name = model_name
self.init_parameters = {
"api_key": self.api_key,
"model_name": self.model_name,
"api_base": self.api_base,
"whisper_params": self.whisper_params,
}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name=self.model_name,
api_key=self.api_key,
api_base=self.api_base,
whisper_params=self.whisper_params,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):

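Two details worth noting in this hunk: the hand-maintained `init_parameters` bookkeeping dict disappears because `to_dict` now names its parameters explicitly, and `api_key` is serialized verbatim (the tests below assert `"api_key": "test"`), so a serialized pipeline containing this component must be handled like any other secret.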
haystack/preview/components/classifiers/file_classifier.py

@@ -2,9 +2,9 @@ import logging
import mimetypes
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Optional
from typing import List, Union, Optional, Dict, Any
from haystack.preview import component
from haystack.preview import component, default_from_dict, default_to_dict
logger = logging.getLogger(__name__)
@@ -38,12 +38,22 @@ class FileExtensionClassifier:
f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
)
# save the init parameters for serialization
self.init_parameters = {"mime_types": mime_types}
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, mime_types=self.mime_types)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionClassifier":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def run(self, paths: List[Union[str, Path]]):
"""
Run the FileExtensionClassifier.

haystack/preview/components/embedders/sentence_transformers_document_embedder.py

@@ -1,7 +1,6 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component
from haystack.preview import Document
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@@ -42,7 +41,7 @@ class SentenceTransformersDocumentEmbedder:
self.model_name_or_path = model_name_or_path
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device
self.device = device or "cpu"
self.use_auth_token = use_auth_token
self.batch_size = batch_size
self.progress_bar = progress_bar
@@ -50,6 +49,29 @@ class SentenceTransformersDocumentEmbedder:
self.metadata_fields_to_embed = metadata_fields_to_embed or []
self.embedding_separator = embedding_separator
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
device=self.device,
use_auth_token=self.use_auth_token,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
metadata_fields_to_embed=self.metadata_fields_to_embed,
embedding_separator=self.embedding_separator,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.

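The `device or "cpu"` fallback, applied to both sentence-transformers embedders in this commit, is resolved eagerly in `__init__` so that `to_dict` always emits a concrete device string: serializing `None` would leave the restored component's placement up to whichever machine happens to deserialize it.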
haystack/preview/components/embedders/sentence_transformers_text_embedder.py

@@ -1,6 +1,6 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@@ -40,7 +40,7 @@ class SentenceTransformersTextEmbedder:
self.model_name_or_path = model_name_or_path
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device
self.device = device or "cpu"
self.use_auth_token = use_auth_token
self.prefix = prefix
self.suffix = suffix
@@ -48,6 +48,29 @@ class SentenceTransformersTextEmbedder:
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
device=self.device,
use_auth_token=self.use_auth_token,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.

haystack/preview/components/file_converters/txt.py

@@ -1,12 +1,12 @@
import logging
from pathlib import Path
from typing import Optional, List, Union, Dict
from typing import Optional, List, Union, Dict, Any
from canals.errors import PipelineRuntimeError
from tqdm import tqdm
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import Document, component
from haystack.preview import Document, component, default_to_dict, default_from_dict
with LazyImport("Run 'pip install farm-haystack[preprocessing]'") as langdetect_import:
import langdetect
@@ -61,6 +61,27 @@ class TextFileToDocument:
self.id_hash_keys = id_hash_keys or []
self.progress_bar = progress_bar
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
encoding=self.encoding,
remove_numeric_tables=self.remove_numeric_tables,
numeric_row_threshold=self.numeric_row_threshold,
valid_languages=self.valid_languages,
id_hash_keys=self.id_hash_keys,
progress_bar=self.progress_bar,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(
self,

haystack/preview/components/retrievers/memory.py

@@ -1,7 +1,7 @@
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.preview.document_stores import MemoryDocumentStore, document_store
@component
@@ -41,6 +41,33 @@ class MemoryRetriever:
self.top_k = top_k
self.scale_score = scale_score
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
docstore = self.document_store.to_dict()
return default_to_dict(
self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")
docstore_class = document_store.registry[init_params["document_store"]["type"]]
docstore = docstore_class.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore
return default_from_dict(cls, data)
@component.output_types(documents=List[List[Document]])
def run(
self,

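Here the default helpers aren't enough on their own: the retriever owns a document store, and the nested object has to be rebuilt before `default_from_dict` can call the constructor. The pattern, repeated in `DocumentWriter` below, is to look the store's class up in a registry, deserialize it, swap it into `init_parameters`, then defer to the default. A self-contained sketch reusing the helpers above, with a plain dict standing in for `document_store.registry` and the error checks condensed:

registry: Dict[str, type] = {}  # stand-in for document_store.registry

class FakeStore:
    def to_dict(self):
        return default_to_dict(self)

    @classmethod
    def from_dict(cls, data):
        return default_from_dict(cls, data)

registry["FakeStore"] = FakeStore

class FakeRetriever:
    def __init__(self, document_store, top_k: int = 10):
        self.document_store = document_store
        self.top_k = top_k

    def to_dict(self):
        # The nested store serializes itself; the parent embeds the result.
        return default_to_dict(self, document_store=self.document_store.to_dict(), top_k=self.top_k)

    @classmethod
    def from_dict(cls, data):
        init_params = data.get("init_parameters", {})
        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")
        store_data = init_params["document_store"]
        if store_data.get("type") not in registry:
            raise DeserializationError(f"DocumentStore type {store_data.get('type')!r} not found")
        # Swap the nested dict for a live object, then fall back to the default.
        init_params["document_store"] = registry[store_data["type"]].from_dict(store_data)
        return default_from_dict(cls, data)

restored = FakeRetriever.from_dict(FakeRetriever(FakeStore(), top_k=5).to_dict())
assert isinstance(restored.document_store, FakeStore) and restored.top_k == 5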
haystack/preview/components/writers/document_writer.py

@@ -1,7 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Dict, Any
from haystack.preview import component, Document
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy
from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy, document_store
@component
@@ -19,6 +19,31 @@ class DocumentWriter:
self.document_store = document_store
self.policy = policy
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, document_store=self.document_store.to_dict(), policy=self.policy.name)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.")
docstore_class = document_store.registry[init_params["document_store"]["type"]]
docstore = docstore_class.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
return default_from_dict(cls, data)
def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None):
"""
Run DocumentWriter on the given input data.

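`DocumentWriter` adds one wrinkle on top of the same registry dance: `policy` is an enum, so `to_dict` stores `policy.name` and `from_dict` restores it with `DuplicatePolicy[...]`, which looks members up by name. In isolation (a hypothetical enum using the member names the tests exercise, not the real class):

from enum import Enum

class DuplicatePolicy(Enum):
    FAIL = "fail"
    SKIP = "skip"
    OVERWRITE = "overwrite"

serialized = DuplicatePolicy.SKIP.name  # "SKIP": a plain, JSON-friendly string
assert DuplicatePolicy[serialized] is DuplicatePolicy.SKIP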
haystack/preview/document_stores/decorator.py

@@ -1,9 +1,5 @@
from typing import Dict, Any, Type
import logging
from haystack.preview.document_stores.protocols import DocumentStore
from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
logger = logging.getLogger(__name__)
@@ -40,30 +36,3 @@ class _DocumentStore:
document_store = _DocumentStore()
def default_document_store_to_dict(store_: DocumentStore) -> Dict[str, Any]:
"""
Default DocumentStore serializer.
Serializes a DocumentStore to a dictionary.
"""
return {
"hash": id(store_),
"type": store_.__class__.__name__,
"init_parameters": getattr(store_, "init_parameters", {}),
}
def default_document_store_from_dict(cls: Type[DocumentStore], data: Dict[str, Any]) -> DocumentStore:
"""
Default DocumentStore deserializer.
The "type" field in `data` must match the class that is being deserialized into.
"""
init_params = data.get("init_parameters", {})
if "type" not in data:
raise DocumentStoreDeserializationError("Missing 'type' in DocumentStore serialization data")
if data["type"] != cls.__name__:
raise DocumentStoreDeserializationError(
f"DocumentStore '{data['type']}' can't be deserialized as '{cls.__name__}'"
)
return cls(**init_params)

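The deleted serializer is also where the `"hash": id(store_)` field lived (the `remove hash` bullet in the commit message above): `id()` is a memory address in CPython, stable only within a single process, so it could never round-trip through a file or across machines and had to go before the generic `default_to_dict` could take over.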
haystack/preview/document_stores/errors.py

@@ -12,7 +12,3 @@ class DuplicateDocumentError(DocumentStoreError):
class MissingDocumentError(DocumentStoreError):
pass
class DocumentStoreDeserializationError(DocumentStoreError):
pass

haystack/preview/document_stores/memory/document_store.py

@@ -7,11 +7,8 @@ import numpy as np
import rank_bm25
from tqdm.auto import tqdm
from haystack.preview.document_stores.decorator import (
document_store,
default_document_store_to_dict,
default_document_store_from_dict,
)
from haystack.preview import default_from_dict, default_to_dict
from haystack.preview.document_stores.decorator import document_store
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores.protocols import DuplicatePolicy, DocumentStore
from haystack.preview.document_stores.memory._filters import match
@@ -44,6 +41,7 @@ class MemoryDocumentStore:
Initializes the DocumentStore.
"""
self.storage: Dict[str, Document] = {}
self._bm25_tokenization_regex = bm25_tokenization_regex
self.tokenizer = re.compile(bm25_tokenization_regex).findall
algorithm_class = getattr(rank_bm25, bm25_algorithm)
if algorithm_class is None:
@@ -51,25 +49,23 @@
self.bm25_algorithm = algorithm_class
self.bm25_parameters = bm25_parameters or {}
# Used to convert this instance to a dictionary for serialization
self.init_parameters = {
"bm25_tokenization_regex": bm25_tokenization_regex,
"bm25_algorithm": bm25_algorithm,
"bm25_parameters": self.bm25_parameters,
}
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this store to a dictionary.
"""
return default_document_store_to_dict(self)
return default_to_dict(
self,
bm25_tokenization_regex=self._bm25_tokenization_regex,
bm25_algorithm=self.bm25_algorithm.__name__,
bm25_parameters=self.bm25_parameters,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore":
"""
Deserializes the store from a dictionary.
"""
return default_document_store_from_dict(cls, data)
return default_from_dict(cls, data)
def count_documents(self) -> int:
"""

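The new `_bm25_tokenization_regex` attribute exists purely for serialization: the constructor immediately turns its inputs into values that can't go into a dict (a compiled regex's bound `findall` and a `rank_bm25` class), so `to_dict` has to emit the raw inputs instead, keeping the regex string and recovering the algorithm name via `__name__`. The derivation in isolation:

import re

regex = r"(?u)\b\w\w+\b"
tokenizer = re.compile(regex).findall  # a bound method: not JSON-serializable
assert tokenizer("two words!") == ["two", "words"]
# The store keeps `regex` itself and re-derives `tokenizer` whenever an
# instance is rebuilt from a dict.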
haystack/preview/testing/factory.py

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional, Tuple, Type, List, Union
from haystack.preview import default_to_dict, default_from_dict
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import document_store, DocumentStore, DuplicatePolicy
@@ -96,11 +97,16 @@ def document_store_class(
def delete_documents(self, document_ids: List[str]) -> None:
return
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self)
fields = {
"count_documents": count_documents,
"filter_documents": filter_documents,
"write_documents": write_documents,
"delete_documents": delete_documents,
"to_dict": to_dict,
"from_dict": classmethod(default_from_dict),
}
if extra_fields is not None:

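With `to_dict`/`from_dict` baked into the generated class, every factory-made test double round-trips out of the box, which is what lets the component tests below mock stores with a single call. Usage, matching the assertions in the factory tests at the end of this diff (runnable wherever the preview package is installed):

from haystack.preview.testing.factory import document_store_class

MyStore = document_store_class("MyStore")
store = MyStore()
assert store.to_dict() == {"type": "MyStore", "init_parameters": {}}
assert isinstance(MyStore.from_dict(store.to_dict()), MyStore)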
pyproject.toml

@@ -79,7 +79,7 @@ dependencies = [
"jsonschema",
# Preview
"canals==0.5.0",
"canals==0.7.0",
# Agent events
"events",

releasenotes/notes/… (new release note file)

@@ -0,0 +1,4 @@
---
preview:
- Migrate all components to Canals==0.7.0
- Add serialization and deserialization methods for all Haystack components

test/preview/components/audio/test_whisper_local.py

@@ -26,6 +26,47 @@ class TestLocalWhisperTranscriber:
with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"):
LocalWhisperTranscriber(model_name_or_path="whisper-1")
@pytest.mark.unit
def test_to_dict(self):
transcriber = LocalWhisperTranscriber()
data = transcriber.to_dict()
assert data == {
"type": "LocalWhisperTranscriber",
"init_parameters": {"model_name_or_path": "large", "device": "cpu", "whisper_params": {}},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
transcriber = LocalWhisperTranscriber(
model_name_or_path="tiny",
device="cuda",
whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
)
data = transcriber.to_dict()
assert data == {
"type": "LocalWhisperTranscriber",
"init_parameters": {
"model_name_or_path": "tiny",
"device": "cuda",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "LocalWhisperTranscriber",
"init_parameters": {
"model_name_or_path": "tiny",
"device": "cuda",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
transcriber = LocalWhisperTranscriber.from_dict(data)
assert transcriber.model_name == "tiny"
assert transcriber.device == torch.device("cuda")
assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
@pytest.mark.unit
def test_warmup(self):
with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:

test/preview/components/audio/test_whisper_remote.py

@@ -1,3 +1,4 @@
from typing import Literal
from unittest.mock import MagicMock, patch
import pytest
@@ -24,6 +25,56 @@ class TestRemoteWhisperTranscriber:
with pytest.raises(ValueError, match="API key is None"):
RemoteWhisperTranscriber(api_key=None)
@pytest.mark.unit
def test_to_dict(self):
transcriber = RemoteWhisperTranscriber(api_key="test")
data = transcriber.to_dict()
assert data == {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://api.openai.com/v1",
"whisper_params": {},
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
transcriber = RemoteWhisperTranscriber(
api_key="test",
model_name="whisper-1",
api_base="https://my.api.base/something_else/v3",
whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
)
data = transcriber.to_dict()
assert data == {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://my.api.base/something_else/v3",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://my.api.base/something_else/v3",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
transcriber = RemoteWhisperTranscriber.from_dict(data)
assert transcriber.model_name == "whisper-1"
assert transcriber.api_key == "test"
assert transcriber.api_base == "https://my.api.base/something_else/v3"
assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
@pytest.mark.unit
def test_run_with_path(self, preview_samples_path):
mock_response = MagicMock()

test/preview/components/classifiers/test_file_classifier.py

@@ -10,6 +10,24 @@ from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier:
@pytest.mark.unit
def test_to_dict(self):
component = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
data = component.to_dict()
assert data == {
"type": "FileExtensionClassifier",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "FileExtensionClassifier",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
component = FileExtensionClassifier.from_dict(data)
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

test/preview/components/embedders/test_sentence_transformers_document_embedder.py

@@ -13,28 +13,104 @@ class TestSentenceTransformersDocumentEmbedder:
def test_init_default(self):
embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
assert embedder.model_name_or_path == "model"
assert embedder.device is None
assert embedder.device == "cpu"
assert embedder.use_auth_token is None
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.metadata_fields_to_embed == []
assert embedder.embedding_separator == "\n"
@pytest.mark.unit
def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
model_name_or_path="model",
device="cpu",
device="cuda",
use_auth_token=True,
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
metadata_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.model_name_or_path == "model"
assert embedder.device == "cpu"
assert embedder.device == "cuda"
assert embedder.use_auth_token is True
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.metadata_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
@pytest.mark.unit
def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
data = component.to_dict()
assert data == {
"type": "SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cpu",
"use_auth_token": None,
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"embedding_separator": "\n",
"metadata_fields_to_embed": [],
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersDocumentEmbedder(
model_name_or_path="model",
device="cuda",
use_auth_token="the-token",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
metadata_fields_to_embed=["meta_field"],
embedding_separator=" - ",
)
data = component.to_dict()
assert data == {
"type": "SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": "the-token",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"embedding_separator": " - ",
"metadata_fields_to_embed": ["meta_field"],
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": "the-token",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": False,
"embedding_separator": " - ",
"metadata_fields_to_embed": ["meta_field"],
},
}
component = SentenceTransformersDocumentEmbedder.from_dict(data)
assert component.model_name_or_path == "model"
assert component.device == "cuda"
assert component.use_auth_token == "the-token"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is False
assert component.metadata_fields_to_embed == ["meta_field"]
assert component.embedding_separator == " - "
@pytest.mark.unit
@patch(
@@ -45,7 +121,7 @@ class TestSentenceTransformersDocumentEmbedder:
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="model", device=None, use_auth_token=None
model_name_or_path="model", device="cpu", use_auth_token=None
)
@pytest.mark.unit

test/preview/components/embedders/test_sentence_transformers_text_embedder.py

@@ -11,7 +11,7 @@ class TestSentenceTransformersTextEmbedder:
def test_init_default(self):
embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
assert embedder.model_name_or_path == "model"
assert embedder.device is None
assert embedder.device == "cpu"
assert embedder.use_auth_token is None
assert embedder.prefix == ""
assert embedder.suffix == ""
@@ -23,7 +23,7 @@
def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
model_name_or_path="model",
device="cpu",
device="cuda",
use_auth_token=True,
prefix="prefix",
suffix="suffix",
@@ -32,7 +32,7 @@
normalize_embeddings=True,
)
assert embedder.model_name_or_path == "model"
assert embedder.device == "cpu"
assert embedder.device == "cuda"
assert embedder.use_auth_token is True
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
@@ -40,6 +40,76 @@
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
@pytest.mark.unit
def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model_name_or_path="model")
data = component.to_dict()
assert data == {
"type": "SentenceTransformersTextEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cpu",
"use_auth_token": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersTextEmbedder(
model_name_or_path="model",
device="cuda",
use_auth_token=True,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
)
data = component.to_dict()
assert data == {
"type": "SentenceTransformersTextEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": True,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "SentenceTransformersTextEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": True,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model_name_or_path == "model"
assert component.device == "cuda"
assert component.use_auth_token is True
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is True
@pytest.mark.unit
@patch(
"haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -49,7 +119,7 @@ class TestSentenceTransformersTextEmbedder:
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="model", device=None, use_auth_token=None
model_name_or_path="model", device="cpu", use_auth_token=None
)
@pytest.mark.unit

test/preview/components/file_converters/test_txt.py

@@ -11,6 +11,66 @@ from haystack.preview.components.file_converters.txt import TextFileToDocument
class TestTextfileToDocument:
@pytest.mark.unit
def test_to_dict(self):
component = TextFileToDocument()
data = component.to_dict()
assert data == {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "utf-8",
"remove_numeric_tables": False,
"numeric_row_threshold": 0.4,
"valid_languages": [],
"id_hash_keys": [],
"progress_bar": True,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = TextFileToDocument(
encoding="latin-1",
remove_numeric_tables=True,
numeric_row_threshold=0.7,
valid_languages=["en", "de"],
id_hash_keys=["name"],
progress_bar=False,
)
data = component.to_dict()
assert data == {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "latin-1",
"remove_numeric_tables": True,
"numeric_row_threshold": 0.7,
"valid_languages": ["en", "de"],
"id_hash_keys": ["name"],
"progress_bar": False,
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "latin-1",
"remove_numeric_tables": True,
"numeric_row_threshold": 0.7,
"valid_languages": ["en", "de"],
"id_hash_keys": ["name"],
"progress_bar": False,
},
}
component = TextFileToDocument.from_dict(data)
assert component.encoding == "latin-1"
assert component.remove_numeric_tables
assert component.numeric_row_threshold == 0.7
assert component.valid_languages == ["en", "de"]
assert component.id_hash_keys == ["name"]
assert not component.progress_bar
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

test/preview/components/retrievers/test_memory_retriever.py

@@ -1,8 +1,9 @@
from typing import Dict, Any
from unittest.mock import MagicMock, patch
import pytest
from haystack.preview import Pipeline
from haystack.preview import Pipeline, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.dataclasses import Document
@@ -40,6 +41,81 @@ class TestMemoryRetriever:
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
@pytest.mark.unit
def test_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
"top_k": 10,
"scale_score": True,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever(
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
)
data = component.to_dict()
assert data == {
"type": "MemoryRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": False,
},
}
@pytest.mark.unit
def test_from_dict(self):
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
data = {
"type": "MemoryRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
},
}
component = MemoryRetriever.from_dict(data)
assert isinstance(component.document_store, MemoryDocumentStore)
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score
@pytest.mark.unit
def test_from_dict_without_docstore(self):
data = {"type": "MemoryRetriever", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
MemoryRetriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_without_docstore_type(self):
data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
MemoryRetriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "MemoryRetriever",
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
MemoryRetriever.from_dict(data)
@pytest.mark.unit
def test_valid_run(self, mock_docs):
top_k = 5

test/preview/components/writers/test_document_writer.py

@@ -1,21 +0,0 @@
from unittest.mock import MagicMock
import pytest
from haystack.preview import Document
from haystack.preview.components.writers.document_writer import DocumentWriter
from haystack.preview.document_stores import DuplicatePolicy
class TestDocumentWriter:
@pytest.mark.unit
def test_run(self):
mocked_document_store = MagicMock()
writer = DocumentWriter(mocked_document_store)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]
writer.run(documents=documents)
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)

test/preview/components/writers/test_document_writer.py

@@ -0,0 +1,83 @@
from unittest.mock import MagicMock
import pytest
from haystack.preview import Document, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.writers.document_writer import DocumentWriter
from haystack.preview.document_stores import DuplicatePolicy
class TestDocumentWriter:
@pytest.mark.unit
def test_to_dict(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
component = DocumentWriter(document_store=mocked_docstore_class())
data = component.to_dict()
assert data == {
"type": "DocumentWriter",
"init_parameters": {
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
"policy": "FAIL",
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
component = DocumentWriter(document_store=mocked_docstore_class(), policy=DuplicatePolicy.SKIP)
data = component.to_dict()
assert data == {
"type": "DocumentWriter",
"init_parameters": {
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
"policy": "SKIP",
},
}
@pytest.mark.unit
def test_from_dict(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
data = {
"type": "DocumentWriter",
"init_parameters": {
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
"policy": "SKIP",
},
}
component = DocumentWriter.from_dict(data)
assert isinstance(component.document_store, mocked_docstore_class)
assert component.policy == DuplicatePolicy.SKIP
@pytest.mark.unit
def test_from_dict_without_docstore(self):
data = {"type": "DocumentWriter", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
DocumentWriter.from_dict(data)
@pytest.mark.unit
def test_from_dict_without_docstore_type(self):
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
DocumentWriter.from_dict(data)
@pytest.mark.unit
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "DocumentWriter",
"init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."):
DocumentWriter.from_dict(data)
@pytest.mark.unit
def test_run(self):
mocked_document_store = MagicMock()
writer = DocumentWriter(mocked_document_store)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]
writer.run(documents=documents)
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)

test/preview/document_stores/test_decorator.py

@@ -1,61 +0,0 @@
from unittest.mock import Mock
import pytest
from haystack.preview.testing.factory import document_store_class
from haystack.preview.document_stores.decorator import default_document_store_to_dict, default_document_store_from_dict
from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
@pytest.mark.unit
def test_default_store_to_dict():
MyStore = document_store_class("MyStore")
comp = MyStore()
res = default_document_store_to_dict(comp)
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {}}
@pytest.mark.unit
def test_default_store_to_dict_with_custom_init_parameters():
extra_fields = {"init_parameters": {"custom_param": True}}
MyStore = document_store_class("MyStore", extra_fields=extra_fields)
comp = MyStore()
res = default_document_store_to_dict(comp)
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {"custom_param": True}}
@pytest.mark.unit
def test_default_store_from_dict():
MyStore = document_store_class("MyStore")
comp = default_document_store_from_dict(MyStore, {"type": "MyStore"})
assert isinstance(comp, MyStore)
@pytest.mark.unit
def test_default_store_from_dict_with_custom_init_parameters():
def store_init(self, custom_param: int):
self.custom_param = custom_param
extra_fields = {"__init__": store_init}
MyStore = document_store_class("MyStore", extra_fields=extra_fields)
comp = default_document_store_from_dict(MyStore, {"type": "MyStore", "init_parameters": {"custom_param": 100}})
assert isinstance(comp, MyStore)
assert comp.custom_param == 100
@pytest.mark.unit
def test_default_store_from_dict_without_type():
with pytest.raises(DocumentStoreDeserializationError, match="Missing 'type' in DocumentStore serialization data"):
default_document_store_from_dict(Mock, {})
@pytest.mark.unit
def test_default_store_from_dict_unregistered_store(request):
# We use the test function name as store name to make sure it's not registered.
# Since the registry is global we risk to have a store with the same name registered in another test.
store_name = request.node.name
with pytest.raises(
DocumentStoreDeserializationError, match=f"DocumentStore '{store_name}' can't be deserialized as 'Mock'"
):
default_document_store_from_dict(Mock, {"type": store_name})

test/preview/document_stores/test_memory.py

@@ -24,7 +24,6 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
store = MemoryDocumentStore()
data = store.to_dict()
assert data == {
"hash": id(store),
"type": "MemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
@@ -40,7 +39,6 @@
)
data = store.to_dict()
assert data == {
"hash": id(store),
"type": "MemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "custom_regex",

test/preview/test_factory.py

@@ -13,6 +13,15 @@ def test_document_store_class_default():
assert store.filter_documents() == []
assert store.write_documents([]) is None
assert store.delete_documents([]) is None
assert store.to_dict() == {"type": "MyStore", "init_parameters": {}}
@pytest.mark.unit
def test_document_store_from_dict():
MyStore = document_store_class("MyStore")
store = MyStore.from_dict({"type": "MyStore", "init_parameters": {}})
assert isinstance(store, MyStore)
@pytest.mark.unit