mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 19:47:45 +00:00
Add to_dict and from_dict methods for Stores (#5541)
* Add to_dict and from_dict methods for Stores * Add release notes * Add tests with custom init parameters
This commit is contained in:
parent
094d8578bd
commit
a7416bcf89
@ -1,5 +1,8 @@
|
||||
from typing import Dict, Any, Type
|
||||
import logging
|
||||
|
||||
from haystack.preview.document_stores.protocols import Store
|
||||
from haystack.preview.document_stores.errors import StoreDeserializationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,6 +30,9 @@ class _Store:
|
||||
self.registry[cls.__name__] = cls
|
||||
logger.debug("Registered Store %s", cls)
|
||||
|
||||
cls.to_dict = _default_store_to_dict
|
||||
cls.from_dict = classmethod(_default_store_from_dict)
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(self, cls=None):
|
||||
@ -37,3 +43,28 @@ class _Store:
|
||||
|
||||
|
||||
store = _Store()
|
||||
|
||||
|
||||
def _default_store_to_dict(store_: Store) -> Dict[str, Any]:
|
||||
"""
|
||||
Default store serializer.
|
||||
Serializes a store to a dictionary.
|
||||
"""
|
||||
return {
|
||||
"hash": id(store_),
|
||||
"type": store_.__class__.__name__,
|
||||
"init_parameters": getattr(store_, "init_parameters", {}),
|
||||
}
|
||||
|
||||
|
||||
def _default_store_from_dict(cls: Type[Store], data: Dict[str, Any]) -> Store:
|
||||
"""
|
||||
Default store deserializer.
|
||||
The "type" field in `data` must match the class that is being deserialized into.
|
||||
"""
|
||||
init_params = data.get("init_parameters", {})
|
||||
if "type" not in data:
|
||||
raise StoreDeserializationError("Missing 'type' in store serialization data")
|
||||
if data["type"] != cls.__name__:
|
||||
raise StoreDeserializationError(f"Store '{data['type']}' can't be deserialized as '{cls.__name__}'")
|
||||
return cls(**init_params)
|
||||
|
||||
@ -12,3 +12,7 @@ class DuplicateDocumentError(StoreError):
|
||||
|
||||
class MissingDocumentError(StoreError):
|
||||
pass
|
||||
|
||||
|
||||
class StoreDeserializationError(StoreError):
|
||||
pass
|
||||
|
||||
@ -47,6 +47,13 @@ class MemoryDocumentStore:
|
||||
self.bm25_algorithm = algorithm_class
|
||||
self.bm25_parameters = bm25_parameters or {}
|
||||
|
||||
# Used to convert this instance to a dictionary for serialization
|
||||
self.init_parameters = {
|
||||
"bm25_tokenization_regex": bm25_tokenization_regex,
|
||||
"bm25_algorithm": bm25_algorithm,
|
||||
"bm25_parameters": self.bm25_parameters,
|
||||
}
|
||||
|
||||
def count_documents(self) -> int:
|
||||
"""
|
||||
Returns the number of how many documents are present in the document store.
|
||||
|
||||
@ -25,6 +25,17 @@ class Store(Protocol):
|
||||
you're using.
|
||||
"""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this store to a dictionary.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Store":
|
||||
"""
|
||||
Deserializes the store from a dictionary.
|
||||
"""
|
||||
|
||||
def count_documents(self) -> int:
|
||||
"""
|
||||
Returns the number of documents stored.
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- Add `from_dict` and `to_dict` methods to `Store` `Protocol`
|
||||
- Add default `from_dict` and `to_dict` implementations to classes decorated with `@store`
|
||||
59
test/preview/document_stores/test_decorator.py
Normal file
59
test/preview/document_stores/test_decorator.py
Normal file
@ -0,0 +1,59 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview.testing.factory import store_class
|
||||
from haystack.preview.document_stores.decorator import _default_store_to_dict, _default_store_from_dict
|
||||
from haystack.preview.document_stores.errors import StoreDeserializationError
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_to_dict():
|
||||
MyStore = store_class("MyStore")
|
||||
comp = MyStore()
|
||||
res = _default_store_to_dict(comp)
|
||||
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {}}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_to_dict_with_custom_init_parameters():
|
||||
extra_fields = {"init_parameters": {"custom_param": True}}
|
||||
MyStore = store_class("MyStore", extra_fields=extra_fields)
|
||||
comp = MyStore()
|
||||
res = _default_store_to_dict(comp)
|
||||
assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {"custom_param": True}}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict():
|
||||
MyStore = store_class("MyStore")
|
||||
comp = _default_store_from_dict(MyStore, {"type": "MyStore"})
|
||||
assert isinstance(comp, MyStore)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_with_custom_init_parameters():
|
||||
def store_init(self, custom_param: int):
|
||||
self.custom_param = custom_param
|
||||
|
||||
extra_fields = {"__init__": store_init}
|
||||
MyStore = store_class("MyStore", extra_fields=extra_fields)
|
||||
comp = _default_store_from_dict(MyStore, {"type": "MyStore", "init_parameters": {"custom_param": 100}})
|
||||
assert isinstance(comp, MyStore)
|
||||
assert comp.custom_param == 100
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_without_type():
|
||||
with pytest.raises(StoreDeserializationError, match="Missing 'type' in store serialization data"):
|
||||
_default_store_from_dict(Mock, {})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_store_from_dict_unregistered_store(request):
|
||||
# We use the test function name as store name to make sure it's not registered.
|
||||
# Since the registry is global we risk to have a store with the same name registered in another test.
|
||||
store_name = request.node.name
|
||||
|
||||
with pytest.raises(StoreDeserializationError, match=f"Store '{store_name}' can't be deserialized as 'Mock'"):
|
||||
_default_store_from_dict(Mock, {"type": store_name})
|
||||
@ -1,11 +1,11 @@
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.document_stores import Store, MemoryDocumentStore
|
||||
|
||||
from haystack.testing.preview.document_store import DocumentStoreBaseTests
|
||||
|
||||
|
||||
@ -18,6 +18,53 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
def docstore(self) -> MemoryDocumentStore:
|
||||
return MemoryDocumentStore()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
store = MemoryDocumentStore()
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"hash": id(store),
|
||||
"type": "MemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
|
||||
"bm25_algorithm": "BM25Okapi",
|
||||
"bm25_parameters": {},
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
store = MemoryDocumentStore(
|
||||
bm25_tokenization_regex="custom_regex", bm25_algorithm="BM25Plus", bm25_parameters={"key": "value"}
|
||||
)
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"hash": id(store),
|
||||
"type": "MemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.document_stores.memory.document_store.re")
|
||||
def test_from_dict(self, mock_regex):
|
||||
data = {
|
||||
"type": "MemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
},
|
||||
}
|
||||
store = MemoryDocumentStore.from_dict(data)
|
||||
mock_regex.compile.assert_called_with("custom_regex")
|
||||
assert store.tokenizer
|
||||
assert store.bm25_algorithm.__name__ == "BM25Plus"
|
||||
assert store.bm25_parameters == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval(self, docstore: Store):
|
||||
docstore = MemoryDocumentStore()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user