refactor: Rename deserialize_document_store_in_init_parameters (#8302)

* 8259

* update function name

* rename and update docstring

* fix linting

* add a release note
This commit is contained in:
Alper 2024-09-02 11:42:23 +02:00 committed by GitHub
parent 7dbc51a3e7
commit e614fa0c62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 26 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List
from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters from haystack.utils import deserialize_document_store_in_init_params_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +73,7 @@ class CacheChecker:
Deserialized component. Deserialized component.
""" """
# deserialize the document store # deserialize the document store
data = deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters from haystack.utils import deserialize_document_store_in_init_params_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -78,7 +78,7 @@ class FilterRetriever:
The deserialized component. The deserialized component.
""" """
# deserialize the document store # deserialize the document store
data = deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List
from haystack import Document, component, default_from_dict, default_to_dict from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import DocumentStore from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters from haystack.utils import deserialize_document_store_in_init_params_inplace
@component @component
@ -131,7 +131,7 @@ class SentenceWindowRetriever:
Deserialized component. Deserialized component.
""" """
# deserialize the document store # deserialize the document store
data = deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
# deserialize the component # deserialize the component
return default_from_dict(cls, data) return default_from_dict(cls, data)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore, DuplicatePolicy from haystack.document_stores.types import DocumentStore, DuplicatePolicy
from haystack.utils import deserialize_document_store_in_init_parameters from haystack.utils import deserialize_document_store_in_init_params_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -74,7 +74,7 @@ class DocumentWriter:
If the document store is not properly specified in the serialization data or its type cannot be imported. If the document store is not properly specified in the serialization data or its type cannot be imported.
""" """
# deserialize the document store # deserialize the document store
data = deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]] data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]

View File

@ -5,7 +5,7 @@
from .auth import Secret, deserialize_secrets_inplace from .auth import Secret, deserialize_secrets_inplace
from .callable_serialization import deserialize_callable, serialize_callable from .callable_serialization import deserialize_callable, serialize_callable
from .device import ComponentDevice, Device, DeviceMap, DeviceType from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .docstore_deserialization import deserialize_document_store_in_init_parameters from .docstore_deserialization import deserialize_document_store_in_init_params_inplace
from .expit import expit from .expit import expit
from .filters import document_matches_filter from .filters import document_matches_filter
from .jupyter import is_in_jupyter from .jupyter import is_in_jupyter
@ -27,5 +27,5 @@ __all__ = [
"deserialize_callable", "deserialize_callable",
"serialize_type", "serialize_type",
"deserialize_type", "deserialize_type",
"deserialize_document_store_in_init_parameters", "deserialize_document_store_in_init_params_inplace",
] ]

View File

@ -8,9 +8,9 @@ from haystack import DeserializationError
from haystack.core.serialization import default_from_dict, import_class_by_name from haystack.core.serialization import default_from_dict, import_class_by_name
def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str = "document_store") -> Dict[str, Any]: def deserialize_document_store_in_init_params_inplace(data: Dict[str, Any], key: str = "document_store"):
""" """
Deserializes a generic document store from the init_parameters of a serialized component. Deserializes a generic document store from the init_parameters of a serialized component in place.
:param data: :param data:
The dictionary to deserialize from. The dictionary to deserialize from.
@ -37,5 +37,3 @@ def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str
data["init_parameters"][key] = doc_store_class.from_dict(doc_store_data) data["init_parameters"][key] = doc_store_class.from_dict(doc_store_data)
else: else:
data["init_parameters"][key] = default_from_dict(doc_store_class, doc_store_data) data["init_parameters"][key] = default_from_dict(doc_store_class, doc_store_data)
return data

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Refactor deserialize_document_store_in_init_parameters so that new function name
indicates that the operation occurs in place, with no return value.

View File

@ -6,7 +6,7 @@ from unittest.mock import patch
import pytest import pytest
from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore
from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_parameters from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_params_inplace
from haystack.core.errors import DeserializationError from haystack.core.errors import DeserializationError
@ -14,7 +14,7 @@ class FakeDocumentStore:
pass pass
def test_deserialize_document_store_in_init_parameters(): def test_deserialize_document_store_in_init_params_inplace():
data = { data = {
"type": "haystack.components.writers.document_writer.DocumentWriter", "type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": { "init_parameters": {
@ -25,8 +25,8 @@ def test_deserialize_document_store_in_init_parameters():
}, },
} }
result = deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
assert isinstance(result["init_parameters"]["document_store"], InMemoryDocumentStore) assert isinstance(data["init_parameters"]["document_store"], InMemoryDocumentStore)
def test_from_dict_is_called(): def test_from_dict_is_called():
@ -42,7 +42,7 @@ def test_from_dict_is_called():
} }
with patch.object(InMemoryDocumentStore, "from_dict") as mock_from_dict: with patch.object(InMemoryDocumentStore, "from_dict") as mock_from_dict:
deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
mock_from_dict.assert_called_once_with( mock_from_dict.assert_called_once_with(
{"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", "init_parameters": {}} {"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", "init_parameters": {}}
@ -59,7 +59,7 @@ def test_default_from_dict_is_called():
} }
with patch("haystack.utils.docstore_deserialization.default_from_dict") as mock_default_from_dict: with patch("haystack.utils.docstore_deserialization.default_from_dict") as mock_default_from_dict:
deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
mock_default_from_dict.assert_called_once_with( mock_default_from_dict.assert_called_once_with(
FakeDocumentStore, {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}} FakeDocumentStore, {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}}
@ -69,13 +69,13 @@ def test_default_from_dict_is_called():
def test_missing_document_store_key(): def test_missing_document_store_key():
data = {"init_parameters": {"policy": "SKIP"}} data = {"init_parameters": {"policy": "SKIP"}}
with pytest.raises(DeserializationError): with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
def test_missing_type_key_in_document_store(): def test_missing_type_key_in_document_store():
data = {"init_parameters": {"document_store": {"init_parameters": {}}, "policy": "SKIP"}} data = {"init_parameters": {"document_store": {"init_parameters": {}}, "policy": "SKIP"}}
with pytest.raises(DeserializationError): with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)
def test_invalid_class_import(): def test_invalid_class_import():
@ -86,4 +86,4 @@ def test_invalid_class_import():
} }
} }
with pytest.raises(DeserializationError): with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data) deserialize_document_store_in_init_params_inplace(data)