mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-06 21:05:33 +00:00
feat: InMemoryDocumentStore serialization (#7888)
* Add: InMemoryDocumentStore serialization * Add: additional check to test if path exists * Fix: failing test
This commit is contained in:
parent
9c45203a76
commit
08104e0042
@ -2,11 +2,13 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -339,6 +341,42 @@ class InMemoryDocumentStore:
|
|||||||
"""
|
"""
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
def save_to_disk(self, path: str) -> None:
    """
    Write the document store and its data to disk as a JSON file.

    The file holds the output of `self.to_dict()` plus a `documents` list containing
    `to_dict(flatten=False)` of every value in `self.storage`.

    :param path: The path to the JSON file.
    """
    data: Dict[str, Any] = self.to_dict()
    data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
    # Pin the encoding so the written file does not depend on the platform's locale default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f)
|
||||||
|
|
||||||
|
@classmethod
def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
    """
    Load a document store and its data from a JSON file on disk.

    :param path: The path to the JSON file.
    :returns: The loaded InMemoryDocumentStore.
    :raises FileNotFoundError: If no file exists at `path`.
    :raises Exception: If the file exists but cannot be read or parsed as JSON.
    """
    # Open directly instead of checking for existence first: this avoids a race between
    # the check and the read, and keeps the happy path un-nested.
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"File {path} not found.") from e
    except Exception as e:
        # Chain the original error so the root cause stays visible in the traceback.
        raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}") from e

    # The serialized documents are not init parameters, so detach them before rebuilding the store.
    documents = data.pop("documents")
    cls_object = default_from_dict(cls, data)
    # NOTE(review): assumes Document(**doc) accepts the dicts written by save_to_disk,
    # including any nested fields - confirm.
    cls_object.write_documents(
        documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
    )
    return cls_object
|
||||||
|
|
||||||
def count_documents(self) -> int:
|
def count_documents(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of how many documents are present in the DocumentStore.
|
Returns the number of how many documents are present in the DocumentStore.
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Added serialization methods save_to_disk and load_from_disk to InMemoryDocumentStore.
|
||||||
@ -6,6 +6,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
|
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
|
||||||
@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
Test InMemoryDocumentStore's specific features
|
Test InMemoryDocumentStore's specific features
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
def tmp_dir(self):
    """Yield the path of a temporary directory that is cleaned up when the test finishes."""
    with tempfile.TemporaryDirectory() as directory:
        yield directory
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def document_store(self) -> InMemoryDocumentStore:
|
def document_store(self) -> InMemoryDocumentStore:
|
||||||
return InMemoryDocumentStore(bm25_algorithm="BM25L")
|
return InMemoryDocumentStore(bm25_algorithm="BM25L")
|
||||||
@ -74,6 +80,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
assert store.bm25_parameters == {"key": "value"}
|
assert store.bm25_parameters == {"key": "value"}
|
||||||
assert store.index == "my_cool_index"
|
assert store.index == "my_cool_index"
|
||||||
|
|
||||||
|
def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
    """A store saved to disk and loaded back holds the same documents and configuration."""
    docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    original_store = InMemoryDocumentStore()
    original_store.write_documents(docs)

    # Keep the fixture value intact and build the target file path under it.
    json_path = f"{tmp_dir}/document_store.json"
    original_store.save_to_disk(json_path)
    restored_store = InMemoryDocumentStore.load_from_disk(json_path)

    assert restored_store.count_documents() == 2
    assert list(restored_store.storage.values()) == docs
    assert restored_store.to_dict() == original_store.to_dict()
|
||||||
|
|
||||||
def test_invalid_bm25_algorithm(self):
|
def test_invalid_bm25_algorithm(self):
|
||||||
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
|
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
|
||||||
InMemoryDocumentStore(bm25_algorithm="invalid")
|
InMemoryDocumentStore(bm25_algorithm="invalid")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user