feat: InMemoryDocumentStore serialization (#7888)

* Add: InMemoryDocumentStore serialization

* Add: additional chek to test if path exists

* Fix: failing test
This commit is contained in:
David Berenstein 2024-06-21 16:45:25 +02:00 committed by GitHub
parent 9c45203a76
commit 08104e0042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 0 deletions

View File

@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0
import json
import math
import re
import uuid
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
import numpy as np
@ -339,6 +341,42 @@ class InMemoryDocumentStore:
"""
return default_from_dict(cls, data)
def save_to_disk(self, path: str) -> None:
"""
Write the database and its' data to disk as a JSON file.
:param path: The path to the JSON file.
"""
data: Dict[str, Any] = self.to_dict()
data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
with open(path, "w") as f:
json.dump(data, f)
@classmethod
def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
"""
Load the database and its' data from disk as a JSON file.
:param path: The path to the JSON file.
:returns: The loaded InMemoryDocumentStore.
"""
if Path(path).exists():
try:
with open(path, "r") as f:
data = json.load(f)
except Exception as e:
raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}")
documents = data.pop("documents")
cls_object = default_from_dict(cls, data)
cls_object.write_documents(
documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
)
return cls_object
else:
raise FileNotFoundError(f"File {path} not found.")
def count_documents(self) -> int:
"""
Returns the number of how many documents are present in the DocumentStore.

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Added serialization methods save_to_disk and write_to_disk to InMemoryDocumentStore.

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
import pandas as pd
import pytest
import tempfile
from haystack import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
Test InMemoryDocumentStore's specific features
"""
@pytest.fixture
def tmp_dir(self):
with tempfile.TemporaryDirectory() as tmp_dir:
yield tmp_dir
@pytest.fixture
def document_store(self) -> InMemoryDocumentStore:
return InMemoryDocumentStore(bm25_algorithm="BM25L")
@ -74,6 +80,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
assert store.bm25_parameters == {"key": "value"}
assert store.index == "my_cool_index"
def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
document_store = InMemoryDocumentStore()
document_store.write_documents(docs)
tmp_dir = tmp_dir + "/document_store.json"
document_store.save_to_disk(tmp_dir)
document_store_loaded = InMemoryDocumentStore.load_from_disk(tmp_dir)
assert document_store_loaded.count_documents() == 2
assert list(document_store_loaded.storage.values()) == docs
assert document_store_loaded.to_dict() == document_store.to_dict()
def test_invalid_bm25_algorithm(self):
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
InMemoryDocumentStore(bm25_algorithm="invalid")