diff --git a/haystack/preview/document_stores/decorator.py b/haystack/preview/document_stores/decorator.py
index 01e85eda9..9831dbef8 100644
--- a/haystack/preview/document_stores/decorator.py
+++ b/haystack/preview/document_stores/decorator.py
@@ -1,5 +1,8 @@
+from typing import Dict, Any, Type
 import logging
 
+from haystack.preview.document_stores.protocols import Store
+from haystack.preview.document_stores.errors import StoreDeserializationError
 
 logger = logging.getLogger(__name__)
 
@@ -27,6 +30,9 @@ class _Store:
         self.registry[cls.__name__] = cls
         logger.debug("Registered Store %s", cls)
 
+        cls.to_dict = _default_store_to_dict  # NOTE(review): overwrites a class-defined to_dict - intended?
+        cls.from_dict = classmethod(_default_store_from_dict)  # NOTE(review): same for from_dict
+
         return cls
 
     def __call__(self, cls=None):
@@ -37,3 +43,28 @@ class _Store:
 
 
 store = _Store()
+
+
+def _default_store_to_dict(store_: Store) -> Dict[str, Any]:
+    """
+    Default store serializer, installed as `to_dict` on classes decorated with `@store`.
+    Returns a dict with "hash" (`id()` of the store), "type" (its class name) and "init_parameters" (default `{}`).
+    """
+    return {
+        "hash": id(store_),
+        "type": store_.__class__.__name__,
+        "init_parameters": getattr(store_, "init_parameters", {}),
+    }
+
+
+def _default_store_from_dict(cls: Type[Store], data: Dict[str, Any]) -> Store:
+    """
+    Default store deserializer: returns `cls(**init_parameters)`, taking "init_parameters" from `data` (default `{}`).
+    Raises `StoreDeserializationError` if `data` has no "type" or its "type" is not the name of `cls`.
+    """
+    init_params = data.get("init_parameters", {})
+    if "type" not in data:
+        raise StoreDeserializationError("Missing 'type' in store serialization data")
+    if data["type"] != cls.__name__:
+        raise StoreDeserializationError(f"Store '{data['type']}' can't be deserialized as '{cls.__name__}'")
+    return cls(**init_params)
diff --git a/haystack/preview/document_stores/errors.py b/haystack/preview/document_stores/errors.py
index 8d231c212..d9a663d29 100644
--- a/haystack/preview/document_stores/errors.py
+++ b/haystack/preview/document_stores/errors.py
@@ -12,3 +12,7 @@ class DuplicateDocumentError(StoreError):
 
 class MissingDocumentError(StoreError):
     pass
+
+
+class StoreDeserializationError(StoreError):
+    pass
diff --git a/haystack/preview/document_stores/memory/document_store.py b/haystack/preview/document_stores/memory/document_store.py
index 27becd9ae..ee67a3f21 100644
--- a/haystack/preview/document_stores/memory/document_store.py
+++ b/haystack/preview/document_stores/memory/document_store.py
@@ -47,6 +47,13 @@ class MemoryDocumentStore:
         self.bm25_algorithm = algorithm_class
         self.bm25_parameters = bm25_parameters or {}
 
+        # Constructor arguments kept for serialization; the default `to_dict` emits them as "init_parameters"
+        self.init_parameters = {
+            "bm25_tokenization_regex": bm25_tokenization_regex,
+            "bm25_algorithm": bm25_algorithm,
+            "bm25_parameters": self.bm25_parameters,
+        }
+
     def count_documents(self) -> int:
         """
         Returns the number of how many documents are present in the document store.
diff --git a/haystack/preview/document_stores/protocols.py b/haystack/preview/document_stores/protocols.py
index 4e8753198..54f160bff 100644
--- a/haystack/preview/document_stores/protocols.py
+++ b/haystack/preview/document_stores/protocols.py
@@ -25,6 +25,17 @@ class Store(Protocol):
     you're using.
     """
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes this store to a dictionary.
+        """
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Store":
+        """
+        Deserializes the store from a dictionary.
+        """
+
     def count_documents(self) -> int:
         """
         Returns the number of documents stored.
diff --git a/releasenotes/notes/stores-serialisation-a09398d158b01ae6.yaml b/releasenotes/notes/stores-serialisation-a09398d158b01ae6.yaml
new file mode 100644
index 000000000..da00f4873
--- /dev/null
+++ b/releasenotes/notes/stores-serialisation-a09398d158b01ae6.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - Add `from_dict` and `to_dict` methods to `Store` `Protocol`
+  - Add default `from_dict` and `to_dict` implementations to classes decorated with `@store`
diff --git a/test/preview/document_stores/test_decorator.py b/test/preview/document_stores/test_decorator.py
new file mode 100644
index 000000000..3896f0d39
--- /dev/null
+++ b/test/preview/document_stores/test_decorator.py
@@ -0,0 +1,59 @@
+from unittest.mock import Mock
+
+import pytest
+
+from haystack.preview.testing.factory import store_class
+from haystack.preview.document_stores.decorator import _default_store_to_dict, _default_store_from_dict
+from haystack.preview.document_stores.errors import StoreDeserializationError
+
+
+@pytest.mark.unit
+def test_default_store_to_dict():
+    MyStore = store_class("MyStore")
+    comp = MyStore()
+    res = _default_store_to_dict(comp)
+    assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {}}
+
+
+@pytest.mark.unit
+def test_default_store_to_dict_with_custom_init_parameters():
+    extra_fields = {"init_parameters": {"custom_param": True}}
+    MyStore = store_class("MyStore", extra_fields=extra_fields)
+    comp = MyStore()
+    res = _default_store_to_dict(comp)
+    assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {"custom_param": True}}
+
+
+@pytest.mark.unit
+def test_default_store_from_dict():
+    MyStore = store_class("MyStore")
+    comp = _default_store_from_dict(MyStore, {"type": "MyStore"})
+    assert isinstance(comp, MyStore)
+
+
+@pytest.mark.unit
+def test_default_store_from_dict_with_custom_init_parameters():
+    def store_init(self, custom_param: int):
+        self.custom_param = custom_param
+
+    extra_fields = {"__init__": store_init}
+    MyStore = store_class("MyStore", extra_fields=extra_fields)
+    comp = _default_store_from_dict(MyStore, {"type": "MyStore", "init_parameters": {"custom_param": 100}})
+    assert isinstance(comp, MyStore)
+    assert comp.custom_param == 100
+
+
+@pytest.mark.unit
+def test_default_store_from_dict_without_type():
+    with pytest.raises(StoreDeserializationError, match="Missing 'type' in store serialization data"):
+        _default_store_from_dict(Mock, {})
+
+
+@pytest.mark.unit
+def test_default_store_from_dict_unregistered_store(request):
+    # We use the test function name as store name to make sure it's not registered.
+    # Since the registry is global, a store with the same name might otherwise be registered by another test.
+    store_name = request.node.name
+
+    with pytest.raises(StoreDeserializationError, match=f"Store '{store_name}' can't be deserialized as 'Mock'"):
+        _default_store_from_dict(Mock, {"type": store_name})
diff --git a/test/preview/document_stores/test_memory.py b/test/preview/document_stores/test_memory.py
index c2db35ba4..cae31041a 100644
--- a/test/preview/document_stores/test_memory.py
+++ b/test/preview/document_stores/test_memory.py
@@ -1,11 +1,11 @@
 import logging
+from unittest.mock import patch
 
 import pandas as pd
 import pytest
 
 from haystack.preview import Document
 from haystack.preview.document_stores import Store, MemoryDocumentStore
-
 from haystack.testing.preview.document_store import DocumentStoreBaseTests
 
 
@@ -18,6 +18,53 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
     def docstore(self) -> MemoryDocumentStore:
         return MemoryDocumentStore()
 
+    @pytest.mark.unit
+    def test_to_dict(self):
+        store = MemoryDocumentStore()
+        data = store.to_dict()
+        assert data == {
+            "hash": id(store),
+            "type": "MemoryDocumentStore",
+            "init_parameters": {
+                "bm25_tokenization_regex": r"(?u)\b\w\w+\b",
+                "bm25_algorithm": "BM25Okapi",
+                "bm25_parameters": {},
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        store = MemoryDocumentStore(
+            bm25_tokenization_regex="custom_regex", bm25_algorithm="BM25Plus", bm25_parameters={"key": "value"}
+        )
+        data = store.to_dict()
+        assert data == {
+            "hash": id(store),
+            "type": "MemoryDocumentStore",
+            "init_parameters": {
+                "bm25_tokenization_regex": "custom_regex",
+                "bm25_algorithm": "BM25Plus",
+                "bm25_parameters": {"key": "value"},
+            },
+        }
+
+    @pytest.mark.unit
+    @patch("haystack.preview.document_stores.memory.document_store.re")
+    def test_from_dict(self, mock_regex):
+        data = {
+            "type": "MemoryDocumentStore",
+            "init_parameters": {
+                "bm25_tokenization_regex": "custom_regex",
+                "bm25_algorithm": "BM25Plus",
+                "bm25_parameters": {"key": "value"},
+            },
+        }
+        store = MemoryDocumentStore.from_dict(data)
+        mock_regex.compile.assert_called_with("custom_regex")
+        assert store.tokenizer
+        assert store.bm25_algorithm.__name__ == "BM25Plus"
+        assert store.bm25_parameters == {"key": "value"}
+
     @pytest.mark.unit
     def test_bm25_retrieval(self, docstore: Store):
         docstore = MemoryDocumentStore()