mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-06 21:05:33 +00:00
feat: InMemoryDocumentStore serialization (#7888)
* Add: InMemoryDocumentStore serialization * Add: additional check to test if path exists * Fix: failing test
This commit is contained in:
parent
9c45203a76
commit
08104e0042
@ -2,11 +2,13 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -339,6 +341,42 @@ class InMemoryDocumentStore:
|
||||
"""
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def save_to_disk(self, path: str) -> None:
    """
    Write the document store and its data to disk as a JSON file.

    The file contains the output of `to_dict()` plus a `documents` entry that lists every
    stored document serialized with `to_dict(flatten=False)`.

    :param path: The path to the JSON file.
    """
    data: Dict[str, Any] = self.to_dict()
    data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
    # Pin the encoding so the written file does not depend on the platform's locale default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f)
|
||||
|
||||
@classmethod
def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
    """
    Load the document store and its data from a JSON file on disk.

    :param path: The path to the JSON file.
    :returns: The loaded InMemoryDocumentStore.
    :raises FileNotFoundError: If `path` does not exist.
    :raises Exception: If the file cannot be read or parsed as JSON. The original error is
        attached as the cause.
    """
    # Guard clause: fail early so the happy path below stays flat.
    if not Path(path).exists():
        raise FileNotFoundError(f"File {path} not found.")

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Chain the original error so the traceback keeps the real cause.
        raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}") from e

    # `documents` is the extra entry added by `save_to_disk`; take it out before handing
    # the remaining data to `default_from_dict`.
    documents = data.pop("documents")
    cls_object = default_from_dict(cls, data)
    cls_object.write_documents(
        documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
    )
    return cls_object
|
||||
|
||||
def count_documents(self) -> int:
|
||||
"""
|
||||
Returns the number of how many documents are present in the DocumentStore.
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Added serialization methods save_to_disk and load_from_disk to InMemoryDocumentStore.
|
||||
@ -6,6 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
|
||||
@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
Test InMemoryDocumentStore's specific features
|
||||
"""
|
||||
|
||||
@pytest.fixture
def tmp_dir(self):
    """
    Yield the path of a temporary directory that is cleaned up after the test.
    """
    with tempfile.TemporaryDirectory() as directory:
        yield directory
|
||||
|
||||
@pytest.fixture
def document_store(self) -> InMemoryDocumentStore:
    """
    Provide an InMemoryDocumentStore configured with the BM25L algorithm.
    """
    store = InMemoryDocumentStore(bm25_algorithm="BM25L")
    return store
|
||||
@ -74,6 +80,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
assert store.bm25_parameters == {"key": "value"}
|
||||
assert store.index == "my_cool_index"
|
||||
|
||||
def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
    """
    A store written with `save_to_disk` must be restored unchanged by `load_from_disk`.
    """
    from pathlib import Path  # local import: only this test needs it

    docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    document_store = InMemoryDocumentStore()
    document_store.write_documents(docs)
    # Build the file path portably and without overwriting the `tmp_dir` fixture argument.
    file_path = str(Path(tmp_dir) / "document_store.json")
    document_store.save_to_disk(file_path)
    document_store_loaded = InMemoryDocumentStore.load_from_disk(file_path)

    assert document_store_loaded.count_documents() == 2
    assert list(document_store_loaded.storage.values()) == docs
    assert document_store_loaded.to_dict() == document_store.to_dict()
|
||||
|
||||
def test_invalid_bm25_algorithm(self):
    """
    Constructing the store with an unsupported `bm25_algorithm` must raise a ValueError.
    """
    expected_message = "BM25 algorithm 'invalid' is not supported"
    with pytest.raises(ValueError, match=expected_message):
        InMemoryDocumentStore(bm25_algorithm="invalid")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user