refactor: utility function for docstore deserialization (#8226)

* refactor docstore deserialization

* more tests

* reno; headers

* expose key
This commit is contained in:
Stefano Fiorucci 2024-08-14 13:29:27 +02:00 committed by GitHub
parent 109e98aa44
commit bcc4104729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 155 additions and 72 deletions

View File

@ -4,9 +4,9 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,21 +71,8 @@ class CacheChecker:
:returns: :returns:
Deserialized component. Deserialized component.
""" """
init_params = data.get("init_parameters", {}) # deserialize the document store
if "document_store" not in init_params: data = deserialize_document_store_in_init_parameters(data)
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
if hasattr(doc_store_class, "from_dict"):
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
else:
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -4,9 +4,9 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,21 +77,8 @@ class FilterRetriever:
:returns: :returns:
The deserialized component. The deserialized component.
""" """
init_params = data.get("init_parameters", {}) # deserialize the document store
if "document_store" not in init_params: data = deserialize_document_store_in_init_parameters(data)
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
if hasattr(doc_store_class, "from_dict"):
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
else:
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -4,9 +4,9 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict from haystack import Document, component, default_from_dict, default_to_dict
from haystack.core.serialization import import_class_by_name
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters
@component @component
@ -117,24 +117,8 @@ class SentenceWindowRetriever:
:returns: :returns:
Deserialized component. Deserialized component.
""" """
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
# deserialize the document store # deserialize the document store
doc_store_data = data["init_parameters"]["document_store"] data = deserialize_document_store_in_init_parameters(data)
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
if hasattr(doc_store_class, "from_dict"):
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
else:
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
# deserialize the component # deserialize the component
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -4,9 +4,9 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack.document_stores.types import DocumentStore, DuplicatePolicy from haystack.document_stores.types import DocumentStore, DuplicatePolicy
from haystack.utils import deserialize_document_store_in_init_parameters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,21 +73,9 @@ class DocumentWriter:
:raises DeserializationError: :raises DeserializationError:
If the document store is not properly specified in the serialization data or its type cannot be imported. If the document store is not properly specified in the serialization data or its type cannot be imported.
""" """
init_params = data.get("init_parameters", {}) # deserialize the document store
if "document_store" not in init_params: data = deserialize_document_store_in_init_parameters(data)
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
if hasattr(doc_store_class, "from_dict"):
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
else:
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]] data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -5,6 +5,7 @@
from .auth import Secret, deserialize_secrets_inplace from .auth import Secret, deserialize_secrets_inplace
from .callable_serialization import deserialize_callable, serialize_callable from .callable_serialization import deserialize_callable, serialize_callable
from .device import ComponentDevice, Device, DeviceMap, DeviceType from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .docstore_deserialization import deserialize_document_store_in_init_parameters
from .expit import expit from .expit import expit
from .filters import document_matches_filter from .filters import document_matches_filter
from .jupyter import is_in_jupyter from .jupyter import is_in_jupyter
@ -26,4 +27,5 @@ __all__ = [
"deserialize_callable", "deserialize_callable",
"serialize_type", "serialize_type",
"deserialize_type", "deserialize_type",
"deserialize_document_store_in_init_parameters",
] ]

View File

@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict
from haystack import DeserializationError
from haystack.core.serialization import default_from_dict, import_class_by_name
def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str = "document_store") -> Dict[str, Any]:
    """
    Swap the serialized document store held in a component's init parameters for a deserialized instance.

    The entry at `data["init_parameters"][key]` is overwritten in place with the object built from it,
    and the same `data` dictionary is handed back to the caller.

    :param data:
        The serialized component dictionary; the document store is looked up under its `init_parameters`.
    :param key:
        The name of the init parameter under which the serialized document store is stored.
    :returns:
        The `data` dictionary that was passed in, now carrying the deserialized document store.
    :raises DeserializationError:
        If `key` is absent from the init parameters, if the document store entry has no `type`,
        or if the class named by `type` cannot be imported.
    """
    init_params = data.get("init_parameters", {})
    if key not in init_params:
        raise DeserializationError(f"Missing '{key}' in serialization data")

    serialized_store = init_params[key]
    if "type" not in serialized_store:
        raise DeserializationError(f"Missing 'type' in {key} serialization data")

    type_name = serialized_store["type"]
    try:
        store_class = import_class_by_name(type_name)
    except ImportError as e:
        raise DeserializationError(f"Class '{type_name}' not correctly imported") from e

    # Use the class's own `from_dict` when it defines one; otherwise fall back to the generic deserializer.
    if hasattr(store_class, "from_dict"):
        store = store_class.from_dict(serialized_store)
    else:
        store = default_from_dict(store_class, serialized_store)

    init_params[key] = store
    return data

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Introduce a utility function to deserialize a generic Document Store
from the init_parameters of a serialized component.

View File

@ -60,7 +60,7 @@ class TestCacheChecker:
"type": "haystack.components.caching.cache_checker.UrlCacheChecker", "type": "haystack.components.caching.cache_checker.UrlCacheChecker",
"init_parameters": {"document_store": {"init_parameters": {}}}, "init_parameters": {"document_store": {"init_parameters": {}}},
} }
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): with pytest.raises(DeserializationError):
CacheChecker.from_dict(data) CacheChecker.from_dict(data)
def test_from_dict_nonexisting_docstore(self): def test_from_dict_nonexisting_docstore(self):

View File

@ -87,7 +87,7 @@ class TestSentenceWindowRetriever:
def test_from_dict_without_docstore_type(self): def test_from_dict_without_docstore_type(self):
data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}} data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): with pytest.raises(DeserializationError):
SentenceWindowRetriever.from_dict(data) SentenceWindowRetriever.from_dict(data)
def test_from_dict_non_existing_docstore(self): def test_from_dict_non_existing_docstore(self):

View File

@ -57,7 +57,7 @@ class TestDocumentWriter:
def test_from_dict_without_docstore_type(self): def test_from_dict_without_docstore_type(self):
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}} data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): with pytest.raises(DeserializationError):
DocumentWriter.from_dict(data) DocumentWriter.from_dict(data)
def test_from_dict_nonexisting_docstore(self): def test_from_dict_nonexisting_docstore(self):

View File

@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch
import pytest
from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore
from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_parameters
from haystack.core.errors import DeserializationError
class FakeDocumentStore:
    """Empty stand-in document store that deliberately defines no `from_dict` method."""
def test_deserialize_document_store_in_init_parameters():
    """A serialized InMemoryDocumentStore entry is replaced by an InMemoryDocumentStore instance."""
    serialized_store = {
        "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
        "init_parameters": {},
    }
    component_data = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {"document_store": serialized_store},
    }

    deserialized = deserialize_document_store_in_init_parameters(component_data)

    document_store = deserialized["init_parameters"]["document_store"]
    assert isinstance(document_store, InMemoryDocumentStore)
def test_from_dict_is_called():
    """The document store's own `from_dict` is invoked, once, with the serialized document store."""
    store_type = "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
    component_data = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {"document_store": {"type": store_type, "init_parameters": {}}},
    }
    expected_payload = {"type": store_type, "init_parameters": {}}

    with patch.object(InMemoryDocumentStore, "from_dict") as from_dict_mock:
        deserialize_document_store_in_init_parameters(component_data)

        from_dict_mock.assert_called_once_with(expected_payload)
def test_default_from_dict_is_called():
    """Without a `from_dict` on the document store class, `default_from_dict` is invoked instead."""
    store_type = "test_docstore_deserialization.FakeDocumentStore"
    component_data = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {"document_store": {"type": store_type, "init_parameters": {}}},
    }
    expected_payload = {"type": store_type, "init_parameters": {}}

    # Patch the name as it is bound inside the module under test.
    with patch("haystack.utils.docstore_deserialization.default_from_dict") as default_from_dict_mock:
        deserialize_document_store_in_init_parameters(component_data)

        default_from_dict_mock.assert_called_once_with(FakeDocumentStore, expected_payload)
def test_missing_document_store_key():
    """Init parameters lacking the document store key trigger a DeserializationError."""
    init_parameters = {"policy": "SKIP"}
    component_data = {"init_parameters": init_parameters}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(component_data)
def test_missing_type_key_in_document_store():
    """A document store entry lacking `type` triggers a DeserializationError."""
    untyped_store = {"init_parameters": {}}
    component_data = {"init_parameters": {"document_store": untyped_store, "policy": "SKIP"}}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(component_data)
def test_invalid_class_import():
    """A `type` naming a class that cannot be imported triggers a DeserializationError."""
    unimportable_store = {"type": "invalid.module.InvalidClass", "init_parameters": {}}
    component_data = {"init_parameters": {"document_store": unimportable_store, "policy": "SKIP"}}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(component_data)