Add to_dict and from_dict methods for Stores (#5541)

* Add to_dict and from_dict methods for Stores

* Add release notes

* Add tests with custom init parameters
This commit is contained in:
Silvano Cerza 2023-08-11 14:45:56 +02:00 committed by GitHub
parent 094d8578bd
commit a7416bcf89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 1 deletions

View File

@ -1,5 +1,8 @@
from typing import Dict, Any, Type
import logging
from haystack.preview.document_stores.protocols import Store
from haystack.preview.document_stores.errors import StoreDeserializationError
# Module-level logger; used below to report each Store class as it gets registered.
logger = logging.getLogger(__name__)
@ -27,6 +30,9 @@ class _Store:
self.registry[cls.__name__] = cls
logger.debug("Registered Store %s", cls)
cls.to_dict = _default_store_to_dict
cls.from_dict = classmethod(_default_store_from_dict)
return cls
def __call__(self, cls=None):
@ -37,3 +43,28 @@ class _Store:
# Shared `_Store` instance. `_Store` defines `__call__`, so this object can be applied
# to store classes as the `@store` decorator.
store = _Store()
def _default_store_to_dict(store_: Store) -> Dict[str, Any]:
    """
    Default store serializer.

    Describes `store_` as a dictionary with three entries:

    - "hash": the `id()` of the instance
    - "type": the name of the instance's class
    - "init_parameters": the instance's `init_parameters` attribute, or an empty
      dictionary when the instance has no such attribute
    """
    store_type = store_.__class__.__name__
    init_parameters = getattr(store_, "init_parameters", {})
    return {"hash": id(store_), "type": store_type, "init_parameters": init_parameters}
def _default_store_from_dict(cls: Type[Store], data: Dict[str, Any]) -> Store:
    """
    Default store deserializer.

    The "type" field in `data` must match the class that is being deserialized into.

    :param cls: The store class to instantiate.
    :param data: The serialized store. It must contain a "type" entry equal to the name of `cls`;
        the "init_parameters" entry is optional and may be `None`.
    :return: A new instance of `cls`, created by passing the entries of "init_parameters"
        to its constructor as keyword arguments.
    :raises StoreDeserializationError: If "type" is missing from `data` or differs from the name of `cls`.
    """
    # Validate the payload before touching anything else in it.
    if "type" not in data:
        raise StoreDeserializationError("Missing 'type' in store serialization data")
    if data["type"] != cls.__name__:
        raise StoreDeserializationError(f"Store '{data['type']}' can't be deserialized as '{cls.__name__}'")

    # "init_parameters" can be absent, or present with a null value (for example an empty key in a
    # hand-written config). `cls(**None)` would fail with a bare TypeError, so both cases are
    # treated as "no parameters".
    init_params = data.get("init_parameters")
    if init_params is None:
        init_params = {}
    return cls(**init_params)

View File

@ -12,3 +12,7 @@ class DuplicateDocumentError(StoreError):
class MissingDocumentError(StoreError):
    """
    Store error related to a missing document.

    NOTE(review): the code raising this error is not visible here; presumably it is raised
    when a requested document can't be found in the store - confirm against the call sites.
    """
class StoreDeserializationError(StoreError):
    """
    Raised when a store can't be rebuilt from its serialized form.

    The default `from_dict` implementation raises it when the serialized data has no "type"
    entry, or when "type" doesn't match the name of the class being deserialized into.
    """

View File

@ -47,6 +47,13 @@ class MemoryDocumentStore:
self.bm25_algorithm = algorithm_class
self.bm25_parameters = bm25_parameters or {}
# Used to convert this instance to a dictionary for serialization
self.init_parameters = {
"bm25_tokenization_regex": bm25_tokenization_regex,
"bm25_algorithm": bm25_algorithm,
"bm25_parameters": self.bm25_parameters,
}
def count_documents(self) -> int:
"""
Returns the number of how many documents are present in the document store.

View File

@ -25,6 +25,17 @@ class Store(Protocol):
you're using.
"""
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this store to a dictionary.

    The protocol itself doesn't prescribe the content of the dictionary. The default
    implementation assigned to classes decorated with `@store` returns the keys "hash",
    "type" and "init_parameters".
    """
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Store":
    """
    Deserializes the store from a dictionary.

    The default implementation assigned to classes decorated with `@store` requires `data`
    to contain a "type" entry equal to the class name, and passes the optional
    "init_parameters" entry to the constructor as keyword arguments.

    :param data: The serialized store (presumably the output of `to_dict`; the exact
        schema is up to each implementation).
    :return: The deserialized store.
    """
def count_documents(self) -> int:
"""
Returns the number of documents stored.

View File

@ -0,0 +1,4 @@
---
features:
- Add `from_dict` and `to_dict` methods to the `Store` `Protocol`
- Add default `from_dict` and `to_dict` implementations to classes decorated with `@store`

View File

@ -0,0 +1,59 @@
from unittest.mock import Mock
import pytest
from haystack.preview.testing.factory import store_class
from haystack.preview.document_stores.decorator import _default_store_to_dict, _default_store_from_dict
from haystack.preview.document_stores.errors import StoreDeserializationError
@pytest.mark.unit
def test_default_store_to_dict():
    """The default serializer reports the instance id, the class name and empty init parameters."""
    store_cls = store_class("MyStore")
    instance = store_cls()
    expected = {"hash": id(instance), "type": "MyStore", "init_parameters": {}}
    assert _default_store_to_dict(instance) == expected
@pytest.mark.unit
def test_default_store_to_dict_with_custom_init_parameters():
    """A store created with an `init_parameters` extra field serializes those parameters."""
    custom_init_parameters = {"custom_param": True}
    store_cls = store_class("MyStore", extra_fields={"init_parameters": custom_init_parameters})
    instance = store_cls()
    serialized = _default_store_to_dict(instance)
    assert serialized == {"hash": id(instance), "type": "MyStore", "init_parameters": {"custom_param": True}}
@pytest.mark.unit
def test_default_store_from_dict():
    """Data holding only a matching "type" is deserialized into an instance of the class."""
    store_cls = store_class("MyStore")
    payload = {"type": "MyStore"}
    assert isinstance(_default_store_from_dict(store_cls, payload), store_cls)
@pytest.mark.unit
def test_default_store_from_dict_with_custom_init_parameters():
    """Entries of "init_parameters" reach the store constructor as keyword arguments."""

    def store_init(self, custom_param: int):
        self.custom_param = custom_param

    store_cls = store_class("MyStore", extra_fields={"__init__": store_init})
    payload = {"type": "MyStore", "init_parameters": {"custom_param": 100}}
    instance = _default_store_from_dict(store_cls, payload)
    assert isinstance(instance, store_cls)
    assert instance.custom_param == 100
@pytest.mark.unit
def test_default_store_from_dict_without_type():
    """Deserializing data that has no "type" entry raises `StoreDeserializationError`."""
    expected_message = "Missing 'type' in store serialization data"
    with pytest.raises(StoreDeserializationError, match=expected_message):
        _default_store_from_dict(Mock, {})
@pytest.mark.unit
def test_default_store_from_dict_unregistered_store(request):
    """A "type" that differs from the name of the target class raises `StoreDeserializationError`."""
    # The name of this test function doubles as the store name to make sure it's not registered:
    # the registry is global, so any other name risks clashing with a store registered by another test.
    store_name = request.node.name
    expected_message = f"Store '{store_name}' can't be deserialized as 'Mock'"
    with pytest.raises(StoreDeserializationError, match=expected_message):
        _default_store_from_dict(Mock, {"type": store_name})

View File

@ -1,11 +1,11 @@
import logging
from unittest.mock import patch
import pandas as pd
import pytest
from haystack.preview import Document
from haystack.preview.document_stores import Store, MemoryDocumentStore
from haystack.testing.preview.document_store import DocumentStoreBaseTests
@ -18,6 +18,53 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
def docstore(self) -> MemoryDocumentStore:
return MemoryDocumentStore()
@pytest.mark.unit
def test_to_dict(self):
    """A store built without arguments serializes the default BM25 settings."""
    document_store = MemoryDocumentStore()
    expected_init_parameters = {
        "bm25_tokenization_regex": r"(?u)\b\w\w+\b",
        "bm25_algorithm": "BM25Okapi",
        "bm25_parameters": {},
    }
    assert document_store.to_dict() == {
        "hash": id(document_store),
        "type": "MemoryDocumentStore",
        "init_parameters": expected_init_parameters,
    }
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    """Custom constructor arguments show up unchanged under "init_parameters"."""
    document_store = MemoryDocumentStore(
        bm25_tokenization_regex="custom_regex",
        bm25_algorithm="BM25Plus",
        bm25_parameters={"key": "value"},
    )
    expected_init_parameters = {
        "bm25_tokenization_regex": "custom_regex",
        "bm25_algorithm": "BM25Plus",
        "bm25_parameters": {"key": "value"},
    }
    assert document_store.to_dict() == {
        "hash": id(document_store),
        "type": "MemoryDocumentStore",
        "init_parameters": expected_init_parameters,
    }
@pytest.mark.unit
@patch("haystack.preview.document_stores.memory.document_store.re")
def test_from_dict(self, mock_regex):
    """`from_dict` builds the store from the serialized init parameters."""
    # `re` is patched in the document store module so the pattern handed to `re.compile`
    # can be checked below.
    init_parameters = {
        "bm25_tokenization_regex": "custom_regex",
        "bm25_algorithm": "BM25Plus",
        "bm25_parameters": {"key": "value"},
    }
    serialized = {"type": "MemoryDocumentStore", "init_parameters": init_parameters}

    document_store = MemoryDocumentStore.from_dict(serialized)

    mock_regex.compile.assert_called_with("custom_regex")
    assert document_store.tokenizer
    assert document_store.bm25_algorithm.__name__ == "BM25Plus"
    assert document_store.bm25_parameters == {"key": "value"}
@pytest.mark.unit
def test_bm25_retrieval(self, docstore: Store):
docstore = MemoryDocumentStore()