mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-05 04:09:32 +00:00
refactor: utility function for docstore deserialization (#8226)
* refactor docstore deserialization * more tests * reno; headers * expose key
This commit is contained in:
parent
109e98aa44
commit
bcc4104729
@ -4,9 +4,9 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
|
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.core.serialization import import_class_by_name
|
|
||||||
from haystack.document_stores.types import DocumentStore
|
from haystack.document_stores.types import DocumentStore
|
||||||
|
from haystack.utils import deserialize_document_store_in_init_parameters
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -71,21 +71,8 @@ class CacheChecker:
|
|||||||
:returns:
|
:returns:
|
||||||
Deserialized component.
|
Deserialized component.
|
||||||
"""
|
"""
|
||||||
init_params = data.get("init_parameters", {})
|
# deserialize the document store
|
||||||
if "document_store" not in init_params:
|
data = deserialize_document_store_in_init_parameters(data)
|
||||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
|
||||||
if "type" not in init_params["document_store"]:
|
|
||||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
|
||||||
|
|
||||||
doc_store_data = data["init_parameters"]["document_store"]
|
|
||||||
try:
|
|
||||||
doc_store_class = import_class_by_name(doc_store_data["type"])
|
|
||||||
except ImportError as e:
|
|
||||||
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
|
|
||||||
if hasattr(doc_store_class, "from_dict"):
|
|
||||||
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
|
|
||||||
else:
|
|
||||||
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
|
|
||||||
|
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
|||||||
@ -4,9 +4,9 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
|
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.core.serialization import import_class_by_name
|
|
||||||
from haystack.document_stores.types import DocumentStore
|
from haystack.document_stores.types import DocumentStore
|
||||||
|
from haystack.utils import deserialize_document_store_in_init_parameters
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -77,21 +77,8 @@ class FilterRetriever:
|
|||||||
:returns:
|
:returns:
|
||||||
The deserialized component.
|
The deserialized component.
|
||||||
"""
|
"""
|
||||||
init_params = data.get("init_parameters", {})
|
# deserialize the document store
|
||||||
if "document_store" not in init_params:
|
data = deserialize_document_store_in_init_parameters(data)
|
||||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
|
||||||
if "type" not in init_params["document_store"]:
|
|
||||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
|
||||||
|
|
||||||
doc_store_data = data["init_parameters"]["document_store"]
|
|
||||||
try:
|
|
||||||
doc_store_class = import_class_by_name(doc_store_data["type"])
|
|
||||||
except ImportError as e:
|
|
||||||
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
|
|
||||||
if hasattr(doc_store_class, "from_dict"):
|
|
||||||
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
|
|
||||||
else:
|
|
||||||
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
|
|
||||||
|
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
|||||||
@ -4,9 +4,9 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
|
from haystack import Document, component, default_from_dict, default_to_dict
|
||||||
from haystack.core.serialization import import_class_by_name
|
|
||||||
from haystack.document_stores.types import DocumentStore
|
from haystack.document_stores.types import DocumentStore
|
||||||
|
from haystack.utils import deserialize_document_store_in_init_parameters
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
@ -117,24 +117,8 @@ class SentenceWindowRetriever:
|
|||||||
:returns:
|
:returns:
|
||||||
Deserialized component.
|
Deserialized component.
|
||||||
"""
|
"""
|
||||||
init_params = data.get("init_parameters", {})
|
|
||||||
|
|
||||||
if "document_store" not in init_params:
|
|
||||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
|
||||||
if "type" not in init_params["document_store"]:
|
|
||||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
|
||||||
|
|
||||||
# deserialize the document store
|
# deserialize the document store
|
||||||
doc_store_data = data["init_parameters"]["document_store"]
|
data = deserialize_document_store_in_init_parameters(data)
|
||||||
try:
|
|
||||||
doc_store_class = import_class_by_name(doc_store_data["type"])
|
|
||||||
except ImportError as e:
|
|
||||||
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
|
|
||||||
|
|
||||||
if hasattr(doc_store_class, "from_dict"):
|
|
||||||
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
|
|
||||||
else:
|
|
||||||
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
|
|
||||||
|
|
||||||
# deserialize the component
|
# deserialize the component
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|||||||
@ -4,9 +4,9 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
|
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.core.serialization import import_class_by_name
|
|
||||||
from haystack.document_stores.types import DocumentStore, DuplicatePolicy
|
from haystack.document_stores.types import DocumentStore, DuplicatePolicy
|
||||||
|
from haystack.utils import deserialize_document_store_in_init_parameters
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -73,21 +73,9 @@ class DocumentWriter:
|
|||||||
:raises DeserializationError:
|
:raises DeserializationError:
|
||||||
If the document store is not properly specified in the serialization data or its type cannot be imported.
|
If the document store is not properly specified in the serialization data or its type cannot be imported.
|
||||||
"""
|
"""
|
||||||
init_params = data.get("init_parameters", {})
|
# deserialize the document store
|
||||||
if "document_store" not in init_params:
|
data = deserialize_document_store_in_init_parameters(data)
|
||||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
|
||||||
if "type" not in init_params["document_store"]:
|
|
||||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
|
||||||
|
|
||||||
doc_store_data = data["init_parameters"]["document_store"]
|
|
||||||
try:
|
|
||||||
doc_store_class = import_class_by_name(doc_store_data["type"])
|
|
||||||
except ImportError as e:
|
|
||||||
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
|
|
||||||
if hasattr(doc_store_class, "from_dict"):
|
|
||||||
data["init_parameters"]["document_store"] = doc_store_class.from_dict(doc_store_data)
|
|
||||||
else:
|
|
||||||
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
|
|
||||||
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
|
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
|
||||||
|
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
from .auth import Secret, deserialize_secrets_inplace
|
from .auth import Secret, deserialize_secrets_inplace
|
||||||
from .callable_serialization import deserialize_callable, serialize_callable
|
from .callable_serialization import deserialize_callable, serialize_callable
|
||||||
from .device import ComponentDevice, Device, DeviceMap, DeviceType
|
from .device import ComponentDevice, Device, DeviceMap, DeviceType
|
||||||
|
from .docstore_deserialization import deserialize_document_store_in_init_parameters
|
||||||
from .expit import expit
|
from .expit import expit
|
||||||
from .filters import document_matches_filter
|
from .filters import document_matches_filter
|
||||||
from .jupyter import is_in_jupyter
|
from .jupyter import is_in_jupyter
|
||||||
@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"deserialize_callable",
|
"deserialize_callable",
|
||||||
"serialize_type",
|
"serialize_type",
|
||||||
"deserialize_type",
|
"deserialize_type",
|
||||||
|
"deserialize_document_store_in_init_parameters",
|
||||||
]
|
]
|
||||||
|
|||||||
41
haystack/utils/docstore_deserialization.py
Normal file
41
haystack/utils/docstore_deserialization.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from haystack import DeserializationError
|
||||||
|
from haystack.core.serialization import default_from_dict, import_class_by_name
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str = "document_store") -> Dict[str, Any]:
    """
    Deserializes a generic document store from the init_parameters of a serialized component.

    The serialized document store found at `data["init_parameters"][key]` is replaced in place
    with the deserialized instance; the same `data` dictionary is then returned.

    :param data:
        The dictionary to deserialize from.
    :param key:
        The key in the `data["init_parameters"]` dictionary where the document store is specified.
    :returns:
        The dictionary, with the document store deserialized.

    :raises DeserializationError:
        If the document store is not properly specified in the serialization data or its type cannot be imported.
    """
    init_params = data.get("init_parameters", {})
    if key not in init_params:
        raise DeserializationError(f"Missing '{key}' in serialization data")

    doc_store_data = init_params[key]
    # A non-dict value (e.g. None or a plain string) would otherwise escape as a TypeError from the
    # membership test or the ["type"] lookup below, instead of the documented DeserializationError.
    if not isinstance(doc_store_data, dict) or "type" not in doc_store_data:
        raise DeserializationError(f"Missing 'type' in {key} serialization data")

    try:
        doc_store_class = import_class_by_name(doc_store_data["type"])
    except ImportError as e:
        raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e

    # Prefer the document store's own deserializer when it defines one.
    if hasattr(doc_store_class, "from_dict"):
        data["init_parameters"][key] = doc_store_class.from_dict(doc_store_data)
    else:
        data["init_parameters"][key] = default_from_dict(doc_store_class, doc_store_data)

    return data
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Introduce a utility function to deserialize a generic Document Store
|
||||||
|
from the init_parameters of a serialized component.
|
||||||
@ -60,7 +60,7 @@ class TestCacheChecker:
|
|||||||
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
|
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
|
||||||
"init_parameters": {"document_store": {"init_parameters": {}}},
|
"init_parameters": {"document_store": {"init_parameters": {}}},
|
||||||
}
|
}
|
||||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
with pytest.raises(DeserializationError):
|
||||||
CacheChecker.from_dict(data)
|
CacheChecker.from_dict(data)
|
||||||
|
|
||||||
def test_from_dict_nonexisting_docstore(self):
|
def test_from_dict_nonexisting_docstore(self):
|
||||||
|
|||||||
@ -87,7 +87,7 @@ class TestSentenceWindowRetriever:
|
|||||||
|
|
||||||
def test_from_dict_without_docstore_type(self):
|
def test_from_dict_without_docstore_type(self):
|
||||||
data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
with pytest.raises(DeserializationError):
|
||||||
SentenceWindowRetriever.from_dict(data)
|
SentenceWindowRetriever.from_dict(data)
|
||||||
|
|
||||||
def test_from_dict_non_existing_docstore(self):
|
def test_from_dict_non_existing_docstore(self):
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class TestDocumentWriter:
|
|||||||
|
|
||||||
def test_from_dict_without_docstore_type(self):
|
def test_from_dict_without_docstore_type(self):
|
||||||
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
with pytest.raises(DeserializationError):
|
||||||
DocumentWriter.from_dict(data)
|
DocumentWriter.from_dict(data)
|
||||||
|
|
||||||
def test_from_dict_nonexisting_docstore(self):
|
def test_from_dict_nonexisting_docstore(self):
|
||||||
|
|||||||
89
test/utils/test_docstore_deserialization.py
Normal file
89
test/utils/test_docstore_deserialization.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore
|
||||||
|
from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_parameters
|
||||||
|
from haystack.core.errors import DeserializationError
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDocumentStore:
    """Bare stand-in document store that deliberately defines no `from_dict` method."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_deserialize_document_store_in_init_parameters():
    """A serialized InMemoryDocumentStore is replaced by an actual instance."""
    docstore_payload = {
        "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
        "init_parameters": {},
    }
    component_payload = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {"document_store": docstore_payload},
    }

    deserialized = deserialize_document_store_in_init_parameters(component_payload)

    docstore = deserialized["init_parameters"]["document_store"]
    assert isinstance(docstore, InMemoryDocumentStore)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_is_called():
    """The document store's own `from_dict` is used when the class defines one."""
    component_payload = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {
            "document_store": {
                "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                "init_parameters": {},
            }
        },
    }
    # Built separately so the expectation does not alias the dict handed to the function under test.
    expected_docstore_payload = {
        "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
        "init_parameters": {},
    }

    with patch.object(InMemoryDocumentStore, "from_dict") as from_dict_spy:
        deserialize_document_store_in_init_parameters(component_payload)

        from_dict_spy.assert_called_once_with(expected_docstore_payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_from_dict_is_called():
    """Without a `from_dict` on the document store class, `default_from_dict` is the fallback."""
    component_payload = {
        "type": "haystack.components.writers.document_writer.DocumentWriter",
        "init_parameters": {
            "document_store": {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}}
        },
    }
    # Built separately so the expectation does not alias the dict handed to the function under test.
    expected_docstore_payload = {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}}

    with patch("haystack.utils.docstore_deserialization.default_from_dict") as default_from_dict_spy:
        deserialize_document_store_in_init_parameters(component_payload)

        default_from_dict_spy.assert_called_once_with(FakeDocumentStore, expected_docstore_payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_document_store_key():
    """Init parameters without a document store entry are rejected."""
    payload_without_docstore = {"init_parameters": {"policy": "SKIP"}}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(payload_without_docstore)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_type_key_in_document_store():
    """A document store entry that lacks its "type" field is rejected."""
    untyped_docstore = {"init_parameters": {}}
    payload = {"init_parameters": {"document_store": untyped_docstore, "policy": "SKIP"}}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_class_import():
    """A document store type that cannot be imported is reported as a DeserializationError."""
    unimportable_docstore = {"type": "invalid.module.InvalidClass", "init_parameters": {}}
    payload = {"init_parameters": {"document_store": unimportable_docstore, "policy": "SKIP"}}

    with pytest.raises(DeserializationError):
        deserialize_document_store_in_init_parameters(payload)
|
||||||
Loading…
x
Reference in New Issue
Block a user