mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 05:58:57 +00:00
chore: migrate to canals==0.7.0 (#5647)
* add default_to_dict and default_from_dict placeholders to ease migration to canals 0.7.0 * canals==0.7.0 * whisper components * add to_dict/from_dict stubs * import serialization methods in init to hide canals imports * reno * export deserializationerror too * Update haystack/preview/__init__.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * serialization methods for LocalWhisperTranscriber (#5648) * chore: serialization methods for `FileExtensionClassifier` (#5651) * serialization methods for FileExtensionClassifier * Update test_file_classifier.py * chore: serialization methods for `SentenceTransformersDocumentEmbedder` (#5652) * serialization methods for SentenceTransformersDocumentEmbedder * fix device management * serialization methods for SentenceTransformersTextEmbedder (#5653) * serialization methods for TextFileToDocument (#5654) * chore: serialization methods for `RemoteWhisperTranscriber` (#5650) * serialization methods for RemoteWhisperTranscriber * remove patches * Add default to_dict and from_dict in document stores built with factory (#5674) * fix tests (#5671) * chore: simplify serialization methods for `MemoryDocumentStore` (#5667) * simplify serialization for MemoryDocumentStore * remove redundant tests * pylint * chore: serialization methods for `MemoryRetriever` (#5663) * serialization method for MemoryRetriever * more tests * remove hash from default_document_store_to_dict * remove diff in factory.py * chore: serialization methods for `DocumentWriter` (#5661) * serialization methods for DocumentWriter * more tests * use factory * black --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
a613b1b7f5
commit
b1daa7c647
@ -1,2 +1,4 @@
|
||||
from canals import component, Pipeline
|
||||
from canals.serialization import default_from_dict, default_to_dict
|
||||
from canals.errors import DeserializationError
|
||||
from haystack.preview.dataclasses import *
|
||||
|
||||
@ -6,7 +6,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import whisper
|
||||
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -55,6 +55,21 @@ class LocalWhisperTranscriber:
|
||||
if not self._model:
|
||||
self._model = whisper.load_model(self.model_name, device=self.device)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this transcriber.

    The model name, the device (converted to its string form) and the
    Whisper parameters are handed to `default_to_dict` as init parameters.
    """
    init_kwargs = {
        "model_name_or_path": self.model_name,
        # the device attribute is stringified before being serialized
        "device": str(self.device),
        "whisper_params": self.whisper_params,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
    """
    Rebuild a transcriber from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
|
||||
@ -6,7 +6,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
from haystack.preview.utils import request_with_retry
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -49,17 +49,29 @@ class RemoteWhisperTranscriber:
|
||||
if not api_key:
|
||||
raise ValueError("API key is None.")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.whisper_params = whisper_params or {}
|
||||
|
||||
self.model_name = model_name
|
||||
self.init_parameters = {
|
||||
"api_key": self.api_key,
|
||||
"model_name": self.model_name,
|
||||
"api_base": self.api_base,
|
||||
"whisper_params": self.whisper_params,
|
||||
}
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this transcriber.

    Model name, API key, API base and Whisper parameters are handed to
    `default_to_dict` as init parameters.
    """
    init_kwargs = {
        "model_name": self.model_name,
        # NOTE: the API key is passed on exactly as stored on the instance
        "api_key": self.api_key,
        "api_base": self.api_base,
        "whisper_params": self.whisper_params,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
    """
    Rebuild a transcriber from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
|
||||
|
||||
@ -2,9 +2,9 @@ import logging
|
||||
import mimetypes
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Union, Optional, Dict, Any
|
||||
|
||||
from haystack.preview import component
|
||||
from haystack.preview import component, default_from_dict, default_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -38,12 +38,22 @@ class FileExtensionClassifier:
|
||||
f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
|
||||
)
|
||||
|
||||
# save the init parameters for serialization
|
||||
self.init_parameters = {"mime_types": mime_types}
|
||||
|
||||
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
|
||||
self.mime_types = mime_types
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this classifier.

    The configured mime types are the only init parameter handed to
    `default_to_dict`.
    """
    init_kwargs = {"mime_types": self.mime_types}
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionClassifier":
    """
    Rebuild a classifier from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
def run(self, paths: List[Union[str, Path]]):
|
||||
"""
|
||||
Run the FileExtensionClassifier.
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from haystack.preview import component
|
||||
from haystack.preview import Document
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
@ -42,7 +41,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
# TODO: remove device parameter and use Haystack's device management once migrated
|
||||
self.device = device
|
||||
self.device = device or "cpu"
|
||||
self.use_auth_token = use_auth_token
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
@ -50,6 +49,29 @@ class SentenceTransformersDocumentEmbedder:
|
||||
self.metadata_fields_to_embed = metadata_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this embedder.

    Each attribute listed below is handed to `default_to_dict` under the
    init-parameter name of the same spelling.
    """
    init_kwargs = {
        "model_name_or_path": self.model_name_or_path,
        "device": self.device,
        "use_auth_token": self.use_auth_token,
        "batch_size": self.batch_size,
        "progress_bar": self.progress_bar,
        "normalize_embeddings": self.normalize_embeddings,
        "metadata_fields_to_embed": self.metadata_fields_to_embed,
        "embedding_separator": self.embedding_separator,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
    """
    Rebuild an embedder from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
Load the embedding backend.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from haystack.preview import component
|
||||
from haystack.preview import component, default_to_dict, default_from_dict
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
@ -40,7 +40,7 @@ class SentenceTransformersTextEmbedder:
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
# TODO: remove device parameter and use Haystack's device management once migrated
|
||||
self.device = device
|
||||
self.device = device or "cpu"
|
||||
self.use_auth_token = use_auth_token
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
@ -48,6 +48,29 @@ class SentenceTransformersTextEmbedder:
|
||||
self.progress_bar = progress_bar
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this embedder.

    Each attribute listed below is handed to `default_to_dict` under the
    init-parameter name of the same spelling.
    """
    init_kwargs = {
        "model_name_or_path": self.model_name_or_path,
        "device": self.device,
        "use_auth_token": self.use_auth_token,
        "prefix": self.prefix,
        "suffix": self.suffix,
        "batch_size": self.batch_size,
        "progress_bar": self.progress_bar,
        "normalize_embeddings": self.normalize_embeddings,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
    """
    Rebuild an embedder from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
Load the embedding backend.
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Dict
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
|
||||
from canals.errors import PipelineRuntimeError
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.preview.lazy_imports import LazyImport
|
||||
from haystack.preview import Document, component
|
||||
from haystack.preview import Document, component, default_to_dict, default_from_dict
|
||||
|
||||
with LazyImport("Run 'pip install farm-haystack[preprocessing]'") as langdetect_import:
|
||||
import langdetect
|
||||
@ -61,6 +61,27 @@ class TextFileToDocument:
|
||||
self.id_hash_keys = id_hash_keys or []
|
||||
self.progress_bar = progress_bar
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this converter.

    Each attribute listed below is handed to `default_to_dict` under the
    init-parameter name of the same spelling.
    """
    init_kwargs = {
        "encoding": self.encoding,
        "remove_numeric_tables": self.remove_numeric_tables,
        "numeric_row_threshold": self.numeric_row_threshold,
        "valid_languages": self.valid_languages,
        "id_hash_keys": self.id_hash_keys,
        "progress_bar": self.progress_bar,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
    """
    Rebuild a converter from its dictionary representation.

    The whole job is delegated to `default_from_dict`.
    """
    instance = default_from_dict(cls, data)
    return instance
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
|
||||
from haystack.preview.document_stores import MemoryDocumentStore, document_store
|
||||
|
||||
|
||||
@component
|
||||
@ -41,6 +41,33 @@ class MemoryRetriever:
|
||||
self.top_k = top_k
|
||||
self.scale_score = scale_score
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this retriever.

    The document store is serialized through its own `to_dict` and nested
    under the `document_store` init parameter; filters, top_k and
    scale_score are passed through as they are.
    """
    init_kwargs = {
        "document_store": self.document_store.to_dict(),
        "filters": self.filters,
        "top_k": self.top_k,
        "scale_score": self.scale_score,
    }
    return default_to_dict(self, **init_kwargs)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
    """
    Deserialize this component from a dictionary.

    The nested `document_store` entry is resolved through
    `document_store.registry`, rebuilt with that class's `from_dict`, and
    only then is the retriever itself created via `default_from_dict`.

    The `data` dictionary passed in is left untouched, so the same
    serialized form can be deserialized more than once.

    :param data: Serialized component. Its `init_parameters` must contain a
        `document_store` dictionary with a registered `type`.
    :return: The deserialized retriever.
    :raises DeserializationError: If `document_store` is missing, carries no
        `type`, or its `type` is not present in the registry.
    """
    init_params = data.get("init_parameters", {})
    if "document_store" not in init_params:
        raise DeserializationError("Missing 'document_store' in serialization data")

    docstore_data = init_params["document_store"]
    if "type" not in docstore_data:
        raise DeserializationError("Missing 'type' in document store's serialization data")

    docstore_type = docstore_data["type"]
    if docstore_type not in document_store.registry:
        raise DeserializationError(f"DocumentStore type '{docstore_type}' not found")

    docstore = document_store.registry[docstore_type].from_dict(docstore_data)

    # Substitute the live store on shallow copies instead of writing it back
    # into the caller's dictionary: overwriting the serialized store there
    # would make a second from_dict() call on the same data fail.
    resolved_params = {**init_params, "document_store": docstore}
    return default_from_dict(cls, {**data, "init_parameters": resolved_params})
|
||||
|
||||
@component.output_types(documents=List[List[Document]])
|
||||
def run(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy
|
||||
from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError
|
||||
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy, document_store
|
||||
|
||||
|
||||
@component
|
||||
@ -19,6 +19,31 @@ class DocumentWriter:
|
||||
self.document_store = document_store
|
||||
self.policy = policy
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this writer.

    The document store is serialized through its own `to_dict`; the policy
    is stored by its `name`.
    """
    serialized_store = self.document_store.to_dict()
    policy_name = self.policy.name
    return default_to_dict(self, document_store=serialized_store, policy=policy_name)
|
||||
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
    """
    Deserialize this component from a dictionary.

    The nested `document_store` entry is resolved through
    `document_store.registry` and rebuilt with that class's `from_dict`.
    The `policy` entry, when present, is looked up by name in
    `DuplicatePolicy`. The writer is then created via `default_from_dict`.

    The `data` dictionary passed in is left untouched, so the same
    serialized form can be deserialized more than once.

    :param data: Serialized component. Its `init_parameters` must contain a
        `document_store` dictionary with a registered `type`.
    :return: The deserialized writer.
    :raises DeserializationError: If `document_store` is missing, carries no
        `type`, or its `type` is not present in the registry.
    :raises KeyError: If `policy` is present but is not a `DuplicatePolicy`
        member name.
    """
    init_params = data.get("init_parameters", {})
    if "document_store" not in init_params:
        raise DeserializationError("Missing 'document_store' in serialization data")

    docstore_data = init_params["document_store"]
    if "type" not in docstore_data:
        raise DeserializationError("Missing 'type' in document store's serialization data")

    docstore_type = docstore_data["type"]
    if docstore_type not in document_store.registry:
        raise DeserializationError(f"DocumentStore of type '{docstore_type}' not found.")

    docstore = document_store.registry[docstore_type].from_dict(docstore_data)

    # Substitute the live objects on shallow copies instead of writing them
    # back into the caller's dictionary: overwriting the serialized values
    # there would make a second from_dict() call on the same data fail.
    resolved_params = {**init_params, "document_store": docstore}
    if "policy" in init_params:
        resolved_params["policy"] = DuplicatePolicy[init_params["policy"]]
    # NOTE(review): when "policy" is absent it is simply not passed on, which
    # assumes the constructor supplies a default policy - confirm.
    return default_from_dict(cls, {**data, "init_parameters": resolved_params})
|
||||
|
||||
def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None):
|
||||
"""
|
||||
Run DocumentWriter on the given input data.
|
||||
|
||||
@ -1,9 +1,5 @@
|
||||
from typing import Dict, Any, Type
|
||||
import logging
|
||||
|
||||
from haystack.preview.document_stores.protocols import DocumentStore
|
||||
from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -40,30 +36,3 @@ class _DocumentStore:
|
||||
|
||||
|
||||
document_store = _DocumentStore()
|
||||
|
||||
|
||||
def default_document_store_to_dict(store_: DocumentStore) -> Dict[str, Any]:
    """
    Default DocumentStore serializer.

    Produces a dictionary with three entries: `hash` (the `id()` of the
    instance), `type` (the name of the instance's class) and
    `init_parameters` (the instance's `init_parameters` attribute, or an
    empty dictionary when the attribute does not exist).
    """
    serialized: Dict[str, Any] = {"hash": id(store_)}
    serialized["type"] = store_.__class__.__name__
    serialized["init_parameters"] = getattr(store_, "init_parameters", {})
    return serialized
|
||||
|
||||
|
||||
def default_document_store_from_dict(cls: Type[DocumentStore], data: Dict[str, Any]) -> DocumentStore:
    """
    Default DocumentStore deserializer.

    Instantiates `cls` with the `init_parameters` found in `data` (no
    arguments when that key is absent). The `type` entry of `data` has to be
    present and equal to the name of `cls`; otherwise a
    `DocumentStoreDeserializationError` is raised.
    """
    init_params = data.get("init_parameters", {})
    if "type" not in data:
        raise DocumentStoreDeserializationError("Missing 'type' in DocumentStore serialization data")

    declared_name = data["type"]
    expected_name = cls.__name__
    if declared_name == expected_name:
        return cls(**init_params)

    raise DocumentStoreDeserializationError(
        f"DocumentStore '{declared_name}' can't be deserialized as '{expected_name}'"
    )
|
||||
|
||||
@ -12,7 +12,3 @@ class DuplicateDocumentError(DocumentStoreError):
|
||||
|
||||
class MissingDocumentError(DocumentStoreError):
    """A `DocumentStoreError` subclass that adds no behaviour of its own."""
|
||||
|
||||
|
||||
class DocumentStoreDeserializationError(DocumentStoreError):
    """A `DocumentStoreError` subclass that adds no behaviour of its own."""
|
||||
|
||||
@ -7,11 +7,8 @@ import numpy as np
|
||||
import rank_bm25
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from haystack.preview.document_stores.decorator import (
|
||||
document_store,
|
||||
default_document_store_to_dict,
|
||||
default_document_store_from_dict,
|
||||
)
|
||||
from haystack.preview import default_from_dict, default_to_dict
|
||||
from haystack.preview.document_stores.decorator import document_store
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores.protocols import DuplicatePolicy, DocumentStore
|
||||
from haystack.preview.document_stores.memory._filters import match
|
||||
@ -44,6 +41,7 @@ class MemoryDocumentStore:
|
||||
Initializes the DocumentStore.
|
||||
"""
|
||||
self.storage: Dict[str, Document] = {}
|
||||
self._bm25_tokenization_regex = bm25_tokenization_regex
|
||||
self.tokenizer = re.compile(bm25_tokenization_regex).findall
|
||||
algorithm_class = getattr(rank_bm25, bm25_algorithm)
|
||||
if algorithm_class is None:
|
||||
@ -51,25 +49,23 @@ class MemoryDocumentStore:
|
||||
self.bm25_algorithm = algorithm_class
|
||||
self.bm25_parameters = bm25_parameters or {}
|
||||
|
||||
# Used to convert this instance to a dictionary for serialization
|
||||
self.init_parameters = {
|
||||
"bm25_tokenization_regex": bm25_tokenization_regex,
|
||||
"bm25_algorithm": bm25_algorithm,
|
||||
"bm25_parameters": self.bm25_parameters,
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this store to a dictionary.
|
||||
"""
|
||||
return default_document_store_to_dict(self)
|
||||
return default_to_dict(
|
||||
self,
|
||||
bm25_tokenization_regex=self._bm25_tokenization_regex,
|
||||
bm25_algorithm=self.bm25_algorithm.__name__,
|
||||
bm25_parameters=self.bm25_parameters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore":
|
||||
"""
|
||||
Deserializes the store from a dictionary.
|
||||
"""
|
||||
return default_document_store_from_dict(cls, data)
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def count_documents(self) -> int:
|
||||
"""
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Type, List, Union
|
||||
|
||||
from haystack.preview import default_to_dict, default_from_dict
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import document_store, DocumentStore, DuplicatePolicy
|
||||
|
||||
@ -96,11 +97,16 @@ def document_store_class(
|
||||
def delete_documents(self, document_ids: List[str]) -> None:
    """Do nothing: the given ids are ignored and None is returned."""
    return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """Serialize this store via `default_to_dict`, passing no init parameters."""
    serialized = default_to_dict(self)
    return serialized
|
||||
|
||||
fields = {
|
||||
"count_documents": count_documents,
|
||||
"filter_documents": filter_documents,
|
||||
"write_documents": write_documents,
|
||||
"delete_documents": delete_documents,
|
||||
"to_dict": to_dict,
|
||||
"from_dict": classmethod(default_from_dict),
|
||||
}
|
||||
|
||||
if extra_fields is not None:
|
||||
|
||||
@ -79,7 +79,7 @@ dependencies = [
|
||||
"jsonschema",
|
||||
|
||||
# Preview
|
||||
"canals==0.5.0",
|
||||
"canals==0.7.0",
|
||||
|
||||
# Agent events
|
||||
"events",
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
preview:
|
||||
- Migrate all components to Canals==0.7.0
|
||||
- Add serialization and deserialization methods for all Haystack components
|
||||
@ -26,6 +26,47 @@ class TestLocalWhisperTranscriber:
|
||||
with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"):
|
||||
LocalWhisperTranscriber(model_name_or_path="whisper-1")
|
||||
|
||||
@pytest.mark.unit
def test_to_dict(self):
    # A transcriber created without arguments is compared against this fixed serialized form.
    expected_params = {"model_name_or_path": "large", "device": "cpu", "whisper_params": {}}
    serialized = LocalWhisperTranscriber().to_dict()
    assert serialized == {"type": "LocalWhisperTranscriber", "init_parameters": expected_params}
|
||||
|
||||
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    # Explicitly passed init parameters are expected back unchanged in the serialized form.
    expected = {
        "type": "LocalWhisperTranscriber",
        "init_parameters": {
            "model_name_or_path": "tiny",
            "device": "cuda",
            "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        },
    }
    transcriber = LocalWhisperTranscriber(
        model_name_or_path="tiny",
        device="cuda",
        whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    )
    assert transcriber.to_dict() == expected
|
||||
|
||||
@pytest.mark.unit
def test_from_dict(self):
    # Each serialized init parameter is checked on the rebuilt instance.
    init_parameters = {
        "model_name_or_path": "tiny",
        "device": "cuda",
        "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    }
    serialized = {"type": "LocalWhisperTranscriber", "init_parameters": init_parameters}
    transcriber = LocalWhisperTranscriber.from_dict(serialized)
    assert transcriber.model_name == "tiny"
    # the device is compared as a torch.device, not as the serialized string
    assert transcriber.device == torch.device("cuda")
    assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_warmup(self):
|
||||
with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -24,6 +25,56 @@ class TestRemoteWhisperTranscriber:
|
||||
with pytest.raises(ValueError, match="API key is None"):
|
||||
RemoteWhisperTranscriber(api_key=None)
|
||||
|
||||
@pytest.mark.unit
def test_to_dict(self):
    # Only the API key is passed; the remaining init parameters are compared against fixed values.
    expected_params = {
        "model_name": "whisper-1",
        "api_key": "test",
        "api_base": "https://api.openai.com/v1",
        "whisper_params": {},
    }
    serialized = RemoteWhisperTranscriber(api_key="test").to_dict()
    assert serialized == {"type": "RemoteWhisperTranscriber", "init_parameters": expected_params}
|
||||
|
||||
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    # Explicitly passed init parameters are expected back unchanged in the serialized form.
    expected = {
        "type": "RemoteWhisperTranscriber",
        "init_parameters": {
            "model_name": "whisper-1",
            "api_key": "test",
            "api_base": "https://my.api.base/something_else/v3",
            "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        },
    }
    transcriber = RemoteWhisperTranscriber(
        api_key="test",
        model_name="whisper-1",
        api_base="https://my.api.base/something_else/v3",
        whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    )
    assert transcriber.to_dict() == expected
|
||||
|
||||
@pytest.mark.unit
def test_from_dict(self):
    # Each serialized init parameter is checked on the rebuilt instance.
    init_parameters = {
        "model_name": "whisper-1",
        "api_key": "test",
        "api_base": "https://my.api.base/something_else/v3",
        "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    }
    serialized = {"type": "RemoteWhisperTranscriber", "init_parameters": init_parameters}
    transcriber = RemoteWhisperTranscriber.from_dict(serialized)
    assert transcriber.model_name == "whisper-1"
    assert transcriber.api_key == "test"
    assert transcriber.api_base == "https://my.api.base/something_else/v3"
    assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run_with_path(self, preview_samples_path):
|
||||
mock_response = MagicMock()
|
||||
|
||||
@ -10,6 +10,24 @@ from haystack.preview.components.classifiers.file_classifier import FileExtensio
|
||||
reason="Can't run on Windows Github CI, need access to registry to get mime types",
|
||||
)
|
||||
class TestFileExtensionClassifier:
|
||||
@pytest.mark.unit
def test_to_dict(self):
    # The mime types given to the constructor are expected back in the serialized form.
    classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
    expected = {
        "type": "FileExtensionClassifier",
        "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
    }
    assert classifier.to_dict() == expected
|
||||
|
||||
@pytest.mark.unit
def test_from_dict(self):
    # The serialized mime types are checked on the rebuilt instance.
    serialized = {
        "type": "FileExtensionClassifier",
        "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
    }
    classifier = FileExtensionClassifier.from_dict(serialized)
    assert classifier.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run(self, preview_samples_path):
|
||||
"""
|
||||
|
||||
@ -13,28 +13,104 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
def test_init_default(self):
|
||||
embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device is None
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.use_auth_token is None
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.normalize_embeddings is False
|
||||
assert embedder.metadata_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_parameters(self):
|
||||
embedder = SentenceTransformersDocumentEmbedder(
|
||||
model_name_or_path="model",
|
||||
device="cpu",
|
||||
device="cuda",
|
||||
use_auth_token=True,
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
normalize_embeddings=True,
|
||||
metadata_fields_to_embed=["test_field"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.device == "cuda"
|
||||
assert embedder.use_auth_token is True
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.normalize_embeddings is True
|
||||
assert embedder.metadata_fields_to_embed == ["test_field"]
|
||||
assert embedder.embedding_separator == " | "
|
||||
|
||||
@pytest.mark.unit
def test_to_dict(self):
    # Only the model name is passed; the remaining init parameters are compared against fixed values.
    expected_params = {
        "model_name_or_path": "model",
        "device": "cpu",
        "use_auth_token": None,
        "batch_size": 32,
        "progress_bar": True,
        "normalize_embeddings": False,
        "embedding_separator": "\n",
        "metadata_fields_to_embed": [],
    }
    serialized = SentenceTransformersDocumentEmbedder(model_name_or_path="model").to_dict()
    assert serialized == {"type": "SentenceTransformersDocumentEmbedder", "init_parameters": expected_params}
|
||||
|
||||
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    # Explicitly passed init parameters are expected back unchanged in the serialized form.
    expected = {
        "type": "SentenceTransformersDocumentEmbedder",
        "init_parameters": {
            "model_name_or_path": "model",
            "device": "cuda",
            "use_auth_token": "the-token",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "embedding_separator": " - ",
            "metadata_fields_to_embed": ["meta_field"],
        },
    }
    embedder = SentenceTransformersDocumentEmbedder(
        model_name_or_path="model",
        device="cuda",
        use_auth_token="the-token",
        batch_size=64,
        progress_bar=False,
        normalize_embeddings=True,
        metadata_fields_to_embed=["meta_field"],
        embedding_separator=" - ",
    )
    assert embedder.to_dict() == expected
|
||||
|
||||
@pytest.mark.unit
def test_from_dict(self):
    # Each serialized init parameter is checked on the rebuilt instance.
    init_parameters = {
        "model_name_or_path": "model",
        "device": "cuda",
        "use_auth_token": "the-token",
        "batch_size": 64,
        "progress_bar": False,
        "normalize_embeddings": False,
        "embedding_separator": " - ",
        "metadata_fields_to_embed": ["meta_field"],
    }
    serialized = {"type": "SentenceTransformersDocumentEmbedder", "init_parameters": init_parameters}
    embedder = SentenceTransformersDocumentEmbedder.from_dict(serialized)
    assert embedder.model_name_or_path == "model"
    assert embedder.device == "cuda"
    assert embedder.use_auth_token == "the-token"
    assert embedder.batch_size == 64
    assert embedder.progress_bar is False
    assert embedder.normalize_embeddings is False
    assert embedder.metadata_fields_to_embed == ["meta_field"]
    assert embedder.embedding_separator == " - "
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch(
|
||||
@ -45,7 +121,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model_name_or_path="model", device=None, use_auth_token=None
|
||||
model_name_or_path="model", device="cpu", use_auth_token=None
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@ -11,7 +11,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
def test_init_default(self):
|
||||
embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device is None
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.use_auth_token is None
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
@ -23,7 +23,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
def test_init_with_parameters(self):
|
||||
embedder = SentenceTransformersTextEmbedder(
|
||||
model_name_or_path="model",
|
||||
device="cpu",
|
||||
device="cuda",
|
||||
use_auth_token=True,
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
@ -32,7 +32,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
normalize_embeddings=True,
|
||||
)
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.device == "cuda"
|
||||
assert embedder.use_auth_token is True
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
@ -40,6 +40,76 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.normalize_embeddings is True
|
||||
|
||||
@pytest.mark.unit
def test_to_dict(self):
    # Only the model name is passed; the remaining init parameters are compared against fixed values.
    expected_params = {
        "model_name_or_path": "model",
        "device": "cpu",
        "use_auth_token": None,
        "prefix": "",
        "suffix": "",
        "batch_size": 32,
        "progress_bar": True,
        "normalize_embeddings": False,
    }
    serialized = SentenceTransformersTextEmbedder(model_name_or_path="model").to_dict()
    assert serialized == {"type": "SentenceTransformersTextEmbedder", "init_parameters": expected_params}
|
||||
|
||||
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    # Explicitly passed init parameters are expected back unchanged in the serialized form.
    expected = {
        "type": "SentenceTransformersTextEmbedder",
        "init_parameters": {
            "model_name_or_path": "model",
            "device": "cuda",
            "use_auth_token": True,
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
        },
    }
    embedder = SentenceTransformersTextEmbedder(
        model_name_or_path="model",
        device="cuda",
        use_auth_token=True,
        prefix="prefix",
        suffix="suffix",
        batch_size=64,
        progress_bar=False,
        normalize_embeddings=True,
    )
    assert embedder.to_dict() == expected
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"type": "SentenceTransformersTextEmbedder",
|
||||
"init_parameters": {
|
||||
"model_name_or_path": "model",
|
||||
"device": "cuda",
|
||||
"use_auth_token": True,
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"normalize_embeddings": True,
|
||||
},
|
||||
}
|
||||
component = SentenceTransformersTextEmbedder.from_dict(data)
|
||||
assert component.model_name_or_path == "model"
|
||||
assert component.device == "cuda"
|
||||
assert component.use_auth_token is True
|
||||
assert component.prefix == "prefix"
|
||||
assert component.suffix == "suffix"
|
||||
assert component.batch_size == 64
|
||||
assert component.progress_bar is False
|
||||
assert component.normalize_embeddings is True
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch(
|
||||
"haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
@ -49,7 +119,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model_name_or_path="model", device=None, use_auth_token=None
|
||||
model_name_or_path="model", device="cpu", use_auth_token=None
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@ -11,6 +11,66 @@ from haystack.preview.components.file_converters.txt import TextFileToDocument
|
||||
|
||||
|
||||
class TestTextfileToDocument:
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
component = TextFileToDocument()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "TextFileToDocument",
|
||||
"init_parameters": {
|
||||
"encoding": "utf-8",
|
||||
"remove_numeric_tables": False,
|
||||
"numeric_row_threshold": 0.4,
|
||||
"valid_languages": [],
|
||||
"id_hash_keys": [],
|
||||
"progress_bar": True,
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
component = TextFileToDocument(
|
||||
encoding="latin-1",
|
||||
remove_numeric_tables=True,
|
||||
numeric_row_threshold=0.7,
|
||||
valid_languages=["en", "de"],
|
||||
id_hash_keys=["name"],
|
||||
progress_bar=False,
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "TextFileToDocument",
|
||||
"init_parameters": {
|
||||
"encoding": "latin-1",
|
||||
"remove_numeric_tables": True,
|
||||
"numeric_row_threshold": 0.7,
|
||||
"valid_languages": ["en", "de"],
|
||||
"id_hash_keys": ["name"],
|
||||
"progress_bar": False,
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"type": "TextFileToDocument",
|
||||
"init_parameters": {
|
||||
"encoding": "latin-1",
|
||||
"remove_numeric_tables": True,
|
||||
"numeric_row_threshold": 0.7,
|
||||
"valid_languages": ["en", "de"],
|
||||
"id_hash_keys": ["name"],
|
||||
"progress_bar": False,
|
||||
},
|
||||
}
|
||||
component = TextFileToDocument.from_dict(data)
|
||||
assert component.encoding == "latin-1"
|
||||
assert component.remove_numeric_tables
|
||||
assert component.numeric_row_threshold == 0.7
|
||||
assert component.valid_languages == ["en", "de"]
|
||||
assert component.id_hash_keys == ["name"]
|
||||
assert not component.progress_bar
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run(self, preview_samples_path):
|
||||
"""
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline
|
||||
from haystack.preview import Pipeline, DeserializationError
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
@ -40,6 +41,81 @@ class TestMemoryRetriever:
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryRetriever(document_store=document_store)
|
||||
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": None,
|
||||
"top_k": 10,
|
||||
"scale_score": True,
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryRetriever(
|
||||
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
"scale_score": False,
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
data = {
|
||||
"type": "MemoryRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
},
|
||||
}
|
||||
component = MemoryRetriever.from_dict(data)
|
||||
assert isinstance(component.document_store, MemoryDocumentStore)
|
||||
assert component.filters == {"name": "test.txt"}
|
||||
assert component.top_k == 5
|
||||
assert component.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "MemoryRetriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
MemoryRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||
MemoryRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "MemoryRetriever",
|
||||
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
|
||||
MemoryRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_run(self, mock_docs):
|
||||
top_k = 5
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.components.writers.document_writer import DocumentWriter
|
||||
from haystack.preview.document_stores import DuplicatePolicy
|
||||
|
||||
|
||||
class TestDocumentWriter:
|
||||
@pytest.mark.unit
|
||||
def test_run(self):
|
||||
mocked_document_store = MagicMock()
|
||||
writer = DocumentWriter(mocked_document_store)
|
||||
documents = [
|
||||
Document(content="This is the text of a document."),
|
||||
Document(content="This is the text of another document."),
|
||||
]
|
||||
|
||||
writer.run(documents=documents)
|
||||
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
|
||||
83
test/preview/components/writers/test_document_writer.py
Normal file
83
test/preview/components/writers/test_document_writer.py
Normal file
@ -0,0 +1,83 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document, DeserializationError
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.components.writers.document_writer import DocumentWriter
|
||||
from haystack.preview.document_stores import DuplicatePolicy
|
||||
|
||||
|
||||
class TestDocumentWriter:
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
mocked_docstore_class = document_store_class("MockedDocumentStore")
|
||||
component = DocumentWriter(document_store=mocked_docstore_class())
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "DocumentWriter",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
|
||||
"policy": "FAIL",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
mocked_docstore_class = document_store_class("MockedDocumentStore")
|
||||
component = DocumentWriter(document_store=mocked_docstore_class(), policy=DuplicatePolicy.SKIP)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "DocumentWriter",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
|
||||
"policy": "SKIP",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
mocked_docstore_class = document_store_class("MockedDocumentStore")
|
||||
data = {
|
||||
"type": "DocumentWriter",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
|
||||
"policy": "SKIP",
|
||||
},
|
||||
}
|
||||
component = DocumentWriter.from_dict(data)
|
||||
assert isinstance(component.document_store, mocked_docstore_class)
|
||||
assert component.policy == DuplicatePolicy.SKIP
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "DocumentWriter", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
DocumentWriter.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||
DocumentWriter.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "DocumentWriter",
|
||||
"init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."):
|
||||
DocumentWriter.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run(self):
|
||||
mocked_document_store = MagicMock()
|
||||
writer = DocumentWriter(mocked_document_store)
|
||||
documents = [
|
||||
Document(content="This is the text of a document."),
|
||||
Document(content="This is the text of another document."),
|
||||
]
|
||||
|
||||
writer.run(documents=documents)
|
||||
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
|
||||
@ -1,61 +0,0 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.document_stores.decorator import default_document_store_to_dict, default_document_store_from_dict
|
||||
from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_to_dict():
|
||||
MyStore = document_store_class("MyStore")
|
||||
comp = MyStore()
|
||||
res = default_document_store_to_dict(comp)
|
||||
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {}}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_to_dict_with_custom_init_parameters():
|
||||
extra_fields = {"init_parameters": {"custom_param": True}}
|
||||
MyStore = document_store_class("MyStore", extra_fields=extra_fields)
|
||||
comp = MyStore()
|
||||
res = default_document_store_to_dict(comp)
|
||||
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {"custom_param": True}}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict():
|
||||
MyStore = document_store_class("MyStore")
|
||||
comp = default_document_store_from_dict(MyStore, {"type": "MyStore"})
|
||||
assert isinstance(comp, MyStore)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_with_custom_init_parameters():
|
||||
def store_init(self, custom_param: int):
|
||||
self.custom_param = custom_param
|
||||
|
||||
extra_fields = {"__init__": store_init}
|
||||
MyStore = document_store_class("MyStore", extra_fields=extra_fields)
|
||||
comp = default_document_store_from_dict(MyStore, {"type": "MyStore", "init_parameters": {"custom_param": 100}})
|
||||
assert isinstance(comp, MyStore)
|
||||
assert comp.custom_param == 100
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_without_type():
|
||||
with pytest.raises(DocumentStoreDeserializationError, match="Missing 'type' in DocumentStore serialization data"):
|
||||
default_document_store_from_dict(Mock, {})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_unregistered_store(request):
|
||||
# We use the test function name as store name to make sure it's not registered.
|
||||
# Since the registry is global we risk to have a store with the same name registered in another test.
|
||||
store_name = request.node.name
|
||||
|
||||
with pytest.raises(
|
||||
DocumentStoreDeserializationError, match=f"DocumentStore '{store_name}' can't be deserialized as 'Mock'"
|
||||
):
|
||||
default_document_store_from_dict(Mock, {"type": store_name})
|
||||
@ -24,7 +24,6 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
store = MemoryDocumentStore()
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"hash": id(store),
|
||||
"type": "MemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
|
||||
@ -40,7 +39,6 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
)
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"hash": id(store),
|
||||
"type": "MemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
|
||||
@ -13,6 +13,15 @@ def test_document_store_class_default():
|
||||
assert store.filter_documents() == []
|
||||
assert store.write_documents([]) is None
|
||||
assert store.delete_documents([]) is None
|
||||
assert store.to_dict() == {"type": "MyStore", "init_parameters": {}}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_document_store_from_dict():
|
||||
MyStore = document_store_class("MyStore")
|
||||
|
||||
store = MyStore.from_dict({"type": "MyStore", "init_parameters": {}})
|
||||
assert isinstance(store, MyStore)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user