diff --git a/haystack/document_stores/in_memory/document_store.py b/haystack/document_stores/in_memory/document_store.py index 62c3fd496..f35fe56a9 100644 --- a/haystack/document_stores/in_memory/document_store.py +++ b/haystack/document_stores/in_memory/document_store.py @@ -4,6 +4,7 @@ import math import re +import uuid from collections import Counter from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple @@ -42,6 +43,11 @@ class BM25DocumentStats: doc_len: int +# Global storage for all InMemoryDocumentStore instances, indexed by the index name. +_STORAGES: Dict[str, Dict[str, Document]] = {} +_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {} + + class InMemoryDocumentStore: """ Stores data in-memory. It's ephemeral and cannot be saved to disk. @@ -53,6 +59,7 @@ class InMemoryDocumentStore: bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L", bm25_parameters: Optional[Dict] = None, embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product", + index: Optional[str] = None, ): """ Initializes the DocumentStore. @@ -66,11 +73,19 @@ class InMemoryDocumentStore: :param embedding_similarity_function: The similarity function used to compare Documents embeddings. One of "dot_product" (default) or "cosine". To choose the most appropriate function, look for information about your embedding model. + :param index: If specified uses a specific index to store the documents. If not specified, a random UUID is used. + Using the same index allows you to store documents across multiple InMemoryDocumentStore instances. """ - self.storage: Dict[str, Document] = {} self.bm25_tokenization_regex = bm25_tokenization_regex self.tokenizer = re.compile(bm25_tokenization_regex).findall + if index is None: + index = str(uuid.uuid4()) + + self.index = index + if self.index not in _STORAGES: + _STORAGES[self.index] = {} + self.bm25_algorithm = bm25_algorithm self.bm25_algorithm_inst = self._dispatch_bm25() self.bm25_parameters = bm25_parameters or {} @@ -81,7 +96,19 @@ class InMemoryDocumentStore: self._freq_vocab_for_idf: Counter = Counter() # Per-document statistics - self._bm25_attr: Dict[str, BM25DocumentStats] = {} + if self.index not in _BM25_STATS_STORAGES: + _BM25_STATS_STORAGES[self.index] = {} + + @property + def storage(self) -> Dict[str, Document]: + """ + Utility property that returns the storage used by this instance of InMemoryDocumentStore. + """ + return _STORAGES.get(self.index, {}) + + @property + def _bm25_attr(self) -> Dict[str, BM25DocumentStats]: + return _BM25_STATS_STORAGES.get(self.index, {}) def _dispatch_bm25(self): """ @@ -281,6 +308,7 @@ class InMemoryDocumentStore: bm25_algorithm=self.bm25_algorithm, bm25_parameters=self.bm25_parameters, embedding_similarity_function=self.embedding_similarity_function, + index=self.index, ) @classmethod diff --git a/releasenotes/notes/in-memory-docstore-memory-share-82b75d018b3545fc.yaml b/releasenotes/notes/in-memory-docstore-memory-share-82b75d018b3545fc.yaml new file mode 100644 index 000000000..d60997a41 --- /dev/null +++ b/releasenotes/notes/in-memory-docstore-memory-share-82b75d018b3545fc.yaml @@ -0,0 +1,19 @@ +--- +features: + - | + Add memory sharing between different instances of InMemoryDocumentStore. + Setting the same `index` argument as another instance will make sure that the memory is shared. + e.g. + ```python + index = "my_personal_index" + document_store_1 = InMemoryDocumentStore(index=index) + document_store_2 = InMemoryDocumentStore(index=index) + + assert document_store_1.count_documents() == 0 + assert document_store_2.count_documents() == 0 + + document_store_1.write_documents([Document(content="Hello world")]) + + assert document_store_1.count_documents() == 1 + assert document_store_2.count_documents() == 1 + ``` diff --git a/test/components/retrievers/test_filter_retriever.py b/test/components/retrievers/test_filter_retriever.py index 20d3392f5..bef09e7a7 100644 --- a/test/components/retrievers/test_filter_retriever.py +++ b/test/components/retrievers/test_filter_retriever.py @@ -63,10 +63,13 @@ class TestFilterRetriever: } def test_to_dict_with_custom_init_parameters(self): - ds = InMemoryDocumentStore() + ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters") serialized_ds = ds.to_dict() - component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"}) + component = FilterRetriever( + document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"), + filters={"lang": "en"}, + ) data = component.to_dict() assert data == { "type": "haystack.components.retrievers.filter_retriever.FilterRetriever", diff --git a/test/components/retrievers/test_in_memory_bm25_retriever.py b/test/components/retrievers/test_in_memory_bm25_retriever.py index ddc6454b4..59a88a8e8 100644 --- a/test/components/retrievers/test_in_memory_bm25_retriever.py +++ b/test/components/retrievers/test_in_memory_bm25_retriever.py @@ -60,11 +60,14 @@ class TestMemoryBM25Retriever: } def test_to_dict_with_custom_init_parameters(self): - ds = InMemoryDocumentStore() + ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters") serialized_ds = ds.to_dict() component = InMemoryBM25Retriever( - document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True + document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"), + filters={"name": "test.txt"}, + top_k=5, + scale_score=True, ) data = component.to_dict() assert data == { diff --git a/test/document_stores/test_in_memory.py b/test/document_stores/test_in_memory.py index d3a78aeff..2a8679502 100644 --- a/test/document_stores/test_in_memory.py +++ b/test/document_stores/test_in_memory.py @@ -32,6 +32,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 "bm25_algorithm": "BM25L", "bm25_parameters": {}, "embedding_similarity_function": "dot_product", + "index": store.index, }, } @@ -41,6 +42,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 bm25_algorithm="BM25Plus", bm25_parameters={"key": "value"}, embedding_similarity_function="cosine", + index="my_cool_index", ) data = store.to_dict() assert data == { @@ -50,6 +52,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 "bm25_algorithm": "BM25Plus", "bm25_parameters": {"key": "value"}, "embedding_similarity_function": "cosine", + "index": "my_cool_index", }, } @@ -61,6 +64,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 "bm25_tokenization_regex": "custom_regex", "bm25_algorithm": "BM25Plus", "bm25_parameters": {"key": "value"}, + "index": "my_cool_index", }, } store = InMemoryDocumentStore.from_dict(data) @@ -68,6 +72,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 assert store.tokenizer assert store.bm25_algorithm == "BM25Plus" assert store.bm25_parameters == {"key": "value"} + assert store.index == "my_cool_index" def test_invalid_bm25_algorithm(self): with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"): @@ -414,3 +419,28 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False ) assert scores == [0.1, 0.4] + + def test_multiple_document_stores_using_same_index(self): + index = "test_multiple_document_stores_using_same_index" + document_store_1 = InMemoryDocumentStore(index=index) + document_store_2 = InMemoryDocumentStore(index=index) + + assert document_store_1.count_documents() == document_store_2.count_documents() == 0 + + doc_1 = Document(content="Hello world") + document_store_1.write_documents([doc_1]) + assert document_store_1.count_documents() == document_store_2.count_documents() == 1 + + assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1] + + doc_2 = Document(content="Hello another world") + document_store_2.write_documents([doc_2]) + assert document_store_1.count_documents() == document_store_2.count_documents() == 2 + + assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1, doc_2] + + document_store_1.delete_documents([doc_2.id]) + assert document_store_1.count_documents() == document_store_2.count_documents() == 1 + + document_store_2.delete_documents([doc_1.id]) + assert document_store_1.count_documents() == document_store_2.count_documents() == 0