mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-03 19:29:32 +00:00 
			
		
		
		
	feat: InMemoryDocumentStore serialization (#7888)
* Add: InMemoryDocumentStore serialization * Add: additional check to test if path exists * Fix: failing test
This commit is contained in:
		
							parent
							
								
									9c45203a76
								
							
						
					
					
						commit
						08104e0042
					
				@ -2,11 +2,13 @@
 | 
			
		||||
#
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import math
 | 
			
		||||
import re
 | 
			
		||||
import uuid
 | 
			
		||||
from collections import Counter
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
@ -339,6 +341,42 @@ class InMemoryDocumentStore:
 | 
			
		||||
        """
 | 
			
		||||
        return default_from_dict(cls, data)
 | 
			
		||||
 | 
			
		||||
    def save_to_disk(self, path: str) -> None:
        """
        Write the document store and its data to disk as a JSON file.

        The payload is the output of `to_dict()` plus a `"documents"` entry that holds every
        stored document serialized with `to_dict(flatten=False)`.

        :param path: The path to the JSON file.
        :raises TypeError: If the store's data contains a value that is not JSON serializable.
        """
        data: Dict[str, Any] = self.to_dict()
        data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
        # Serialize before opening the target: opening with "w" truncates the file immediately,
        # so a serialization failure must not leave an existing file empty or half-written.
        serialized = json.dumps(data)
        with open(path, "w", encoding="utf-8") as f:
            f.write(serialized)
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
        """
        Load the document store and its data from a JSON file on disk.

        The file is expected to have the layout produced by `save_to_disk`: the store's
        serialized configuration plus a `"documents"` list.

        :param path: The path to the JSON file.
        :returns: The loaded InMemoryDocumentStore.
        :raises FileNotFoundError: If nothing exists at `path`.
        :raises Exception: If the file exists but cannot be read or parsed as JSON.
        """
        if not Path(path).exists():
            raise FileNotFoundError(f"File {path} not found.")

        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            # Chain explicitly so the underlying I/O or JSON error stays attached as the cause.
            raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}") from e

        # The documents are not part of the store's init parameters, so split them off
        # before handing the remaining dictionary to the generic deserializer.
        documents = data.pop("documents")
        cls_object = default_from_dict(cls, data)
        cls_object.write_documents(
            # Document.from_dict is the counterpart of the Document.to_dict call used when
            # saving, so nested fields are rebuilt instead of being passed as raw JSON values.
            documents=[Document.from_dict(doc) for doc in documents],
            policy=DuplicatePolicy.OVERWRITE,
        )
        return cls_object
 | 
			
		||||
 | 
			
		||||
    def count_documents(self) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the number of how many documents are present in the DocumentStore.
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,4 @@
 | 
			
		||||
---
 | 
			
		||||
enhancements:
 | 
			
		||||
  - |
 | 
			
		||||
    Added serialization methods save_to_disk and load_from_disk to InMemoryDocumentStore.
 | 
			
		||||
@ -6,6 +6,7 @@ from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
import pandas as pd
 | 
			
		||||
import pytest
 | 
			
		||||
import tempfile
 | 
			
		||||
 | 
			
		||||
from haystack import Document
 | 
			
		||||
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
 | 
			
		||||
@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
 | 
			
		||||
    Test InMemoryDocumentStore's specific features
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
    def tmp_dir(self):
        """Yield the path of a temporary directory for the duration of a test."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            yield scratch_dir
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
    def document_store(self) -> InMemoryDocumentStore:
        """Provide an InMemoryDocumentStore configured with the BM25L algorithm."""
        store = InMemoryDocumentStore(bm25_algorithm="BM25L")
        return store
 | 
			
		||||
@ -74,6 +80,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
 | 
			
		||||
        assert store.bm25_parameters == {"key": "value"}
 | 
			
		||||
        assert store.index == "my_cool_index"
 | 
			
		||||
 | 
			
		||||
    def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
        """A store saved to disk and loaded back holds the same documents and configuration."""
        expected_docs = [
            Document(content="Hello world"),
            Document(content="Haystack supports multiple languages"),
        ]
        original_store = InMemoryDocumentStore()
        original_store.write_documents(expected_docs)

        json_path = f"{tmp_dir}/document_store.json"
        original_store.save_to_disk(json_path)
        restored_store = InMemoryDocumentStore.load_from_disk(json_path)

        assert restored_store.count_documents() == 2
        assert list(restored_store.storage.values()) == expected_docs
        assert restored_store.to_dict() == original_store.to_dict()
 | 
			
		||||
 | 
			
		||||
    def test_invalid_bm25_algorithm(self):
        """Constructing the store with an unknown BM25 algorithm name raises ValueError."""
        expected_message = "BM25 algorithm 'invalid' is not supported"
        with pytest.raises(ValueError, match=expected_message):
            InMemoryDocumentStore(bm25_algorithm="invalid")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user