From bcc4104729bd2201ca4fa998bcc2dd0d7932b98c Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 14 Aug 2024 13:29:27 +0200 Subject: [PATCH] refactor: utility function for docstore deserialization (#8226) * refactor docstore deserialization * more tests * reno; headers * expose key --- haystack/components/caching/cache_checker.py | 21 +---- .../components/retrievers/filter_retriever.py | 21 +---- .../retrievers/sentence_window_retriever.py | 22 +---- .../components/writers/document_writer.py | 20 +---- haystack/utils/__init__.py | 2 + haystack/utils/docstore_deserialization.py | 41 +++++++++ ...ation-in-init-params-a123a39d5fbc957f.yaml | 5 ++ .../caching/test_url_cache_checker.py | 2 +- .../test_sentence_window_retriever.py | 2 +- .../writers/test_document_writer.py | 2 +- test/utils/test_docstore_deserialization.py | 89 +++++++++++++++++++ 11 files changed, 155 insertions(+), 72 deletions(-) create mode 100644 haystack/utils/docstore_deserialization.py create mode 100644 releasenotes/notes/docstore-deserialization-in-init-params-a123a39d5fbc957f.yaml create mode 100644 test/utils/test_docstore_deserialization.py diff --git a/haystack/components/caching/cache_checker.py b/haystack/components/caching/cache_checker.py index 48751ee6d..9aaf428bd 100644 --- a/haystack/components/caching/cache_checker.py +++ b/haystack/components/caching/cache_checker.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List -from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging -from haystack.core.serialization import import_class_by_name +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.document_stores.types import DocumentStore +from haystack.utils import deserialize_document_store_in_init_parameters logger = logging.getLogger(__name__) @@ -71,21 +71,8 @@ class CacheChecker: :returns: Deserialized component. """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - - doc_store_data = data["init_parameters"]["document_store"] - try: - doc_store_class = import_class_by_name(doc_store_data["type"]) - except ImportError as e: - raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e - if hasattr(doc_store_class, "from_dict"): - data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data) - else: - data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data) + # deserialize the document store + data = deserialize_document_store_in_init_parameters(data) return default_from_dict(cls, data) diff --git a/haystack/components/retrievers/filter_retriever.py b/haystack/components/retrievers/filter_retriever.py index a2a612fee..0f69770a3 100644 --- a/haystack/components/retrievers/filter_retriever.py +++ b/haystack/components/retrievers/filter_retriever.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional -from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging -from haystack.core.serialization import import_class_by_name +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.document_stores.types import DocumentStore +from haystack.utils import deserialize_document_store_in_init_parameters logger = logging.getLogger(__name__) @@ -77,21 +77,8 @@ class FilterRetriever: :returns: The deserialized component. """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - - doc_store_data = data["init_parameters"]["document_store"] - try: - doc_store_class = import_class_by_name(doc_store_data["type"]) - except ImportError as e: - raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e - if hasattr(doc_store_class, "from_dict"): - data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data) - else: - data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data) + # deserialize the document store + data = deserialize_document_store_in_init_parameters(data) return default_from_dict(cls, data) diff --git a/haystack/components/retrievers/sentence_window_retriever.py b/haystack/components/retrievers/sentence_window_retriever.py index 3e9a7e9ae..a96b4adb2 100644 --- a/haystack/components/retrievers/sentence_window_retriever.py +++ b/haystack/components/retrievers/sentence_window_retriever.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List -from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict -from haystack.core.serialization import import_class_by_name +from haystack import Document, component, default_from_dict, default_to_dict from haystack.document_stores.types import DocumentStore +from haystack.utils import deserialize_document_store_in_init_parameters @component @@ -117,24 +117,8 @@ class SentenceWindowRetriever: :returns: Deserialized component. """ - init_params = data.get("init_parameters", {}) - - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - # deserialize the document store - doc_store_data = data["init_parameters"]["document_store"] - try: - doc_store_class = import_class_by_name(doc_store_data["type"]) - except ImportError as e: - raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e - - if hasattr(doc_store_class, "from_dict"): - data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data) - else: - data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data) + data = deserialize_document_store_in_init_parameters(data) # deserialize the component return default_from_dict(cls, data) diff --git a/haystack/components/writers/document_writer.py b/haystack/components/writers/document_writer.py index 72dd3a34e..bd1fa3138 100644 --- a/haystack/components/writers/document_writer.py +++ b/haystack/components/writers/document_writer.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional -from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging -from haystack.core.serialization import import_class_by_name +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.document_stores.types import DocumentStore, DuplicatePolicy +from haystack.utils import deserialize_document_store_in_init_parameters logger = logging.getLogger(__name__) @@ -73,21 +73,9 @@ class DocumentWriter: :raises DeserializationError: If the document store is not properly specified in the serialization data or its type cannot be imported. """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") + # deserialize the document store + data = deserialize_document_store_in_init_parameters(data) - doc_store_data = data["init_parameters"]["document_store"] - try: - doc_store_class = import_class_by_name(doc_store_data["type"]) - except ImportError as e: - raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e - if hasattr(doc_store_class, "from_dict"): - data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data) - else: - data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data) data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]] return default_from_dict(cls, data) diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 906067b24..baefe8fa0 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -5,6 +5,7 @@ from .auth import Secret, deserialize_secrets_inplace from .callable_serialization import deserialize_callable, serialize_callable from .device import ComponentDevice, Device, DeviceMap, DeviceType +from .docstore_deserialization import deserialize_document_store_in_init_parameters from .expit import expit from .filters import document_matches_filter from .jupyter import is_in_jupyter @@ -26,4 +27,5 @@ __all__ = [ "deserialize_callable", "serialize_type", "deserialize_type", + "deserialize_document_store_in_init_parameters", ] diff --git a/haystack/utils/docstore_deserialization.py b/haystack/utils/docstore_deserialization.py new file mode 100644 index 000000000..419a2fa0c --- /dev/null +++ b/haystack/utils/docstore_deserialization.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict + +from haystack import DeserializationError +from haystack.core.serialization import default_from_dict, import_class_by_name + + +def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str = "document_store") -> Dict[str, Any]: + """ + Deserializes a generic document store from the init_parameters of a serialized component. + + :param data: + The dictionary to deserialize from. + :param key: + The key in the `data["init_parameters"]` dictionary where the document store is specified. + :returns: + The dictionary, with the document store deserialized. + + :raises DeserializationError: + If the document store is not properly specified in the serialization data or its type cannot be imported. + """ + init_params = data.get("init_parameters", {}) + if key not in init_params: + raise DeserializationError(f"Missing '{key}' in serialization data") + if "type" not in init_params[key]: + raise DeserializationError(f"Missing 'type' in {key} serialization data") + + doc_store_data = data["init_parameters"][key] + try: + doc_store_class = import_class_by_name(doc_store_data["type"]) + except ImportError as e: + raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e + if hasattr(doc_store_class, "from_dict"): + data["init_parameters"][key] = doc_store_class.from_dict(doc_store_data) + else: + data["init_parameters"][key] = default_from_dict(doc_store_class, doc_store_data) + + return data diff --git a/releasenotes/notes/docstore-deserialization-in-init-params-a123a39d5fbc957f.yaml b/releasenotes/notes/docstore-deserialization-in-init-params-a123a39d5fbc957f.yaml new file mode 100644 index 000000000..1582b9b74 --- /dev/null +++ b/releasenotes/notes/docstore-deserialization-in-init-params-a123a39d5fbc957f.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Introduce an utility function to deserialize a generic Document Store + from the init_parameters of a serialized component. diff --git a/test/components/caching/test_url_cache_checker.py b/test/components/caching/test_url_cache_checker.py index 9bf1e12c4..60b72ff89 100644 --- a/test/components/caching/test_url_cache_checker.py +++ b/test/components/caching/test_url_cache_checker.py @@ -60,7 +60,7 @@ class TestCacheChecker: "type": "haystack.components.caching.cache_checker.UrlCacheChecker", "init_parameters": {"document_store": {"init_parameters": {}}}, } - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): + with pytest.raises(DeserializationError): CacheChecker.from_dict(data) def test_from_dict_nonexisting_docstore(self): diff --git a/test/components/retrievers/test_sentence_window_retriever.py b/test/components/retrievers/test_sentence_window_retriever.py index d1752c4b3..6d6e84782 100644 --- a/test/components/retrievers/test_sentence_window_retriever.py +++ b/test/components/retrievers/test_sentence_window_retriever.py @@ -87,7 +87,7 @@ class TestSentenceWindowRetriever: def test_from_dict_without_docstore_type(self): data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}} - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): + with pytest.raises(DeserializationError): SentenceWindowRetriever.from_dict(data) def test_from_dict_non_existing_docstore(self): diff --git a/test/components/writers/test_document_writer.py b/test/components/writers/test_document_writer.py index f1af39947..44a4f40d2 100644 --- a/test/components/writers/test_document_writer.py +++ b/test/components/writers/test_document_writer.py @@ -57,7 +57,7 @@ class TestDocumentWriter: def test_from_dict_without_docstore_type(self): data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}} - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): + with pytest.raises(DeserializationError): DocumentWriter.from_dict(data) def test_from_dict_nonexisting_docstore(self): diff --git a/test/utils/test_docstore_deserialization.py b/test/utils/test_docstore_deserialization.py new file mode 100644 index 000000000..7a70ad37e --- /dev/null +++ b/test/utils/test_docstore_deserialization.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch +import pytest + +from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore +from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_parameters +from haystack.core.errors import DeserializationError + + +class FakeDocumentStore: + pass + + +def test_deserialize_document_store_in_init_parameters(): + data = { + "type": "haystack.components.writers.document_writer.DocumentWriter", + "init_parameters": { + "document_store": { + "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": {}, + } + }, + } + + result = deserialize_document_store_in_init_parameters(data) + assert isinstance(result["init_parameters"]["document_store"], InMemoryDocumentStore) + + +def test_from_dict_is_called(): + """If the document store provides a from_dict method, it should be called.""" + data = { + "type": "haystack.components.writers.document_writer.DocumentWriter", + "init_parameters": { + "document_store": { + "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": {}, + } + }, + } + + with patch.object(InMemoryDocumentStore, "from_dict") as mock_from_dict: + deserialize_document_store_in_init_parameters(data) + + mock_from_dict.assert_called_once_with( + {"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", "init_parameters": {}} + ) + + +def test_default_from_dict_is_called(): + """If the document store does not provide a from_dict method, default_from_dict should be called.""" + data = { + "type": "haystack.components.writers.document_writer.DocumentWriter", + "init_parameters": { + "document_store": {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}} + }, + } + + with patch("haystack.utils.docstore_deserialization.default_from_dict") as mock_default_from_dict: + deserialize_document_store_in_init_parameters(data) + + mock_default_from_dict.assert_called_once_with( + FakeDocumentStore, {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}} + ) + + +def test_missing_document_store_key(): + data = {"init_parameters": {"policy": "SKIP"}} + with pytest.raises(DeserializationError): + deserialize_document_store_in_init_parameters(data) + + +def test_missing_type_key_in_document_store(): + data = {"init_parameters": {"document_store": {"init_parameters": {}}, "policy": "SKIP"}} + with pytest.raises(DeserializationError): + deserialize_document_store_in_init_parameters(data) + + +def test_invalid_class_import(): + data = { + "init_parameters": { + "document_store": {"type": "invalid.module.InvalidClass", "init_parameters": {}}, + "policy": "SKIP", + } + } + with pytest.raises(DeserializationError): + deserialize_document_store_in_init_parameters(data)