feat: Add memory sharing between different instances of InMemoryDocumentStore (#7781)

* Add memory sharing between different instances of InMemoryDocumentStore

* Fix FilterRetriever tests

* Fix InMemoryBM25Retriever tests
This commit is contained in:
Silvano Cerza 2024-05-31 16:44:14 +02:00 committed by GitHub
parent d81af81fbb
commit 854c4173f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 89 additions and 6 deletions

View File

@ -4,6 +4,7 @@
import math
import re
import uuid
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
@ -42,6 +43,11 @@ class BM25DocumentStats:
doc_len: int
# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
# Instances created with the same index name look up the same inner dict, which is
# how documents end up shared between them.
# Inner dict maps a string key to a Document (presumably the document id; confirm against the writer).
_STORAGES: Dict[str, Dict[str, Document]] = {}
# Per-document BM25 statistics, indexed by the same index name as _STORAGES.
# Inner dict maps a string key to its BM25DocumentStats entry.
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
class InMemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
@ -53,6 +59,7 @@ class InMemoryDocumentStore:
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
bm25_parameters: Optional[Dict] = None,
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
index: Optional[str] = None,
):
"""
Initializes the DocumentStore.
@ -66,11 +73,19 @@ class InMemoryDocumentStore:
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
One of "dot_product" (default) or "cosine".
To choose the most appropriate function, look for information about your embedding model.
:param index: If specified uses a specific index to store the documents. If not specified, a random UUID is used.
Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
"""
self.storage: Dict[str, Document] = {}
self.bm25_tokenization_regex = bm25_tokenization_regex
self.tokenizer = re.compile(bm25_tokenization_regex).findall
if index is None:
index = str(uuid.uuid4())
self.index = index
if self.index not in _STORAGES:
_STORAGES[self.index] = {}
self.bm25_algorithm = bm25_algorithm
self.bm25_algorithm_inst = self._dispatch_bm25()
self.bm25_parameters = bm25_parameters or {}
@ -81,7 +96,19 @@ class InMemoryDocumentStore:
self._freq_vocab_for_idf: Counter = Counter()
# Per-document statistics
self._bm25_attr: Dict[str, BM25DocumentStats] = {}
if self.index not in _BM25_STATS_STORAGES:
_BM25_STATS_STORAGES[self.index] = {}
@property
def storage(self) -> Dict[str, Document]:
    """
    Documents registered in the module-level `_STORAGES` mapping under this instance's index.

    :returns:
        The dict stored for `self.index`. When no entry exists for that index, a new empty
        dict is returned instead; that fallback dict is not inserted into `_STORAGES`.
    """
    try:
        return _STORAGES[self.index]
    except KeyError:
        # Unknown index: hand back an empty, unregistered dict rather than raising.
        return {}
@property
def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
    """
    BM25 statistics registered in `_BM25_STATS_STORAGES` under this instance's index.

    :returns:
        The dict stored for `self.index`, or a new empty dict (not inserted into
        `_BM25_STATS_STORAGES`) when that index has no entry.
    """
    if self.index in _BM25_STATS_STORAGES:
        return _BM25_STATS_STORAGES[self.index]
    # Unknown index: fall back to an empty, unregistered dict.
    return {}
def _dispatch_bm25(self):
"""
@ -281,6 +308,7 @@ class InMemoryDocumentStore:
bm25_algorithm=self.bm25_algorithm,
bm25_parameters=self.bm25_parameters,
embedding_similarity_function=self.embedding_similarity_function,
index=self.index,
)
@classmethod

View File

@ -0,0 +1,19 @@
---
features:
- |
Add memory sharing between different instances of InMemoryDocumentStore.
Passing the same `index` argument to several instances makes them share the same in-memory storage.
For example:
```python
index = "my_personal_index"
document_store_1 = InMemoryDocumentStore(index=index)
document_store_2 = InMemoryDocumentStore(index=index)
assert document_store_1.count_documents() == 0
assert document_store_2.count_documents() == 0
document_store_1.write_documents([Document(content="Hello world")])
assert document_store_1.count_documents() == 1
assert document_store_2.count_documents() == 1
```

View File

@ -63,10 +63,13 @@ class TestFilterRetriever:
}
def test_to_dict_with_custom_init_parameters(self):
ds = InMemoryDocumentStore()
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
serialized_ds = ds.to_dict()
component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"})
component = FilterRetriever(
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
filters={"lang": "en"},
)
data = component.to_dict()
assert data == {
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",

View File

@ -60,11 +60,14 @@ class TestMemoryBM25Retriever:
}
def test_to_dict_with_custom_init_parameters(self):
ds = InMemoryDocumentStore()
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
serialized_ds = ds.to_dict()
component = InMemoryBM25Retriever(
document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
filters={"name": "test.txt"},
top_k=5,
scale_score=True,
)
data = component.to_dict()
assert data == {

View File

@ -32,6 +32,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": store.index,
},
}
@ -41,6 +42,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
bm25_algorithm="BM25Plus",
bm25_parameters={"key": "value"},
embedding_similarity_function="cosine",
index="my_cool_index",
)
data = store.to_dict()
assert data == {
@ -50,6 +52,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_algorithm": "BM25Plus",
"bm25_parameters": {"key": "value"},
"embedding_similarity_function": "cosine",
"index": "my_cool_index",
},
}
@ -61,6 +64,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_tokenization_regex": "custom_regex",
"bm25_algorithm": "BM25Plus",
"bm25_parameters": {"key": "value"},
"index": "my_cool_index",
},
}
store = InMemoryDocumentStore.from_dict(data)
@ -68,6 +72,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
assert store.tokenizer
assert store.bm25_algorithm == "BM25Plus"
assert store.bm25_parameters == {"key": "value"}
assert store.index == "my_cool_index"
def test_invalid_bm25_algorithm(self):
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
@ -414,3 +419,28 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
)
assert scores == [0.1, 0.4]
def test_multiple_document_stores_using_same_index(self):
    """Two stores built with the same index must see each other's writes and deletes."""
    shared_index = "test_multiple_document_stores_using_same_index"
    first_store = InMemoryDocumentStore(index=shared_index)
    second_store = InMemoryDocumentStore(index=shared_index)

    def counts():
        # Document counts as reported by (first_store, second_store).
        return first_store.count_documents(), second_store.count_documents()

    def contents():
        # Unfiltered documents as reported by (first_store, second_store).
        return first_store.filter_documents(), second_store.filter_documents()

    # Both stores start out empty.
    assert counts() == (0, 0)

    # A write through the first store is visible from the second.
    hello = Document(content="Hello world")
    first_store.write_documents([hello])
    assert counts() == (1, 1)
    assert contents() == ([hello], [hello])

    # A write through the second store is visible from the first.
    hello_again = Document(content="Hello another world")
    second_store.write_documents([hello_again])
    assert counts() == (2, 2)
    assert contents() == ([hello, hello_again], [hello, hello_again])

    # Deletes propagate in both directions as well.
    first_store.delete_documents([hello_again.id])
    assert counts() == (1, 1)

    second_store.delete_documents([hello.id])
    assert counts() == (0, 0)