feat: Add memory sharing between different instances of InMemoryDocumentStore (#7781)

* Add memory sharing between different instances of InMemoryDocumentStore

* Fix FilterRetriever tests

* Fix InMemoryBM25Retriever tests
This commit is contained in:
Silvano Cerza 2024-05-31 16:44:14 +02:00 committed by GitHub
parent d81af81fbb
commit 854c4173f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 89 additions and 6 deletions

View File

@ -4,6 +4,7 @@
import math import math
import re import re
import uuid
from collections import Counter from collections import Counter
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
@ -42,6 +43,11 @@ class BM25DocumentStats:
doc_len: int doc_len: int
# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
_STORAGES: Dict[str, Dict[str, Document]] = {}
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
class InMemoryDocumentStore: class InMemoryDocumentStore:
""" """
Stores data in-memory. It's ephemeral and cannot be saved to disk. Stores data in-memory. It's ephemeral and cannot be saved to disk.
@ -53,6 +59,7 @@ class InMemoryDocumentStore:
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L", bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
bm25_parameters: Optional[Dict] = None, bm25_parameters: Optional[Dict] = None,
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product", embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
index: Optional[str] = None,
): ):
""" """
Initializes the DocumentStore. Initializes the DocumentStore.
@ -66,11 +73,19 @@ class InMemoryDocumentStore:
:param embedding_similarity_function: The similarity function used to compare Documents embeddings. :param embedding_similarity_function: The similarity function used to compare Documents embeddings.
One of "dot_product" (default) or "cosine". One of "dot_product" (default) or "cosine".
To choose the most appropriate function, look for information about your embedding model. To choose the most appropriate function, look for information about your embedding model.
:param index: If specified uses a specific index to store the documents. If not specified, a random UUID is used.
Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
""" """
self.storage: Dict[str, Document] = {}
self.bm25_tokenization_regex = bm25_tokenization_regex self.bm25_tokenization_regex = bm25_tokenization_regex
self.tokenizer = re.compile(bm25_tokenization_regex).findall self.tokenizer = re.compile(bm25_tokenization_regex).findall
if index is None:
index = str(uuid.uuid4())
self.index = index
if self.index not in _STORAGES:
_STORAGES[self.index] = {}
self.bm25_algorithm = bm25_algorithm self.bm25_algorithm = bm25_algorithm
self.bm25_algorithm_inst = self._dispatch_bm25() self.bm25_algorithm_inst = self._dispatch_bm25()
self.bm25_parameters = bm25_parameters or {} self.bm25_parameters = bm25_parameters or {}
@ -81,7 +96,19 @@ class InMemoryDocumentStore:
self._freq_vocab_for_idf: Counter = Counter() self._freq_vocab_for_idf: Counter = Counter()
# Per-document statistics # Per-document statistics
self._bm25_attr: Dict[str, BM25DocumentStats] = {} if self.index not in _BM25_STATS_STORAGES:
_BM25_STATS_STORAGES[self.index] = {}
@property
def storage(self) -> Dict[str, Document]:
"""
Utility property that returns the storage used by this instance of InMemoryDocumentStore.
"""
return _STORAGES.get(self.index, {})
@property
def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
return _BM25_STATS_STORAGES.get(self.index, {})
def _dispatch_bm25(self): def _dispatch_bm25(self):
""" """
@ -281,6 +308,7 @@ class InMemoryDocumentStore:
bm25_algorithm=self.bm25_algorithm, bm25_algorithm=self.bm25_algorithm,
bm25_parameters=self.bm25_parameters, bm25_parameters=self.bm25_parameters,
embedding_similarity_function=self.embedding_similarity_function, embedding_similarity_function=self.embedding_similarity_function,
index=self.index,
) )
@classmethod @classmethod

View File

@ -0,0 +1,19 @@
---
features:
- |
Add memory sharing between different instances of InMemoryDocumentStore.
Setting the same `index` argument as another instance will make sure that the memory is shared.
e.g.
```python
index = "my_personal_index"
document_store_1 = InMemoryDocumentStore(index=index)
document_store_2 = InMemoryDocumentStore(index=index)
assert document_store_1.count_documents() == 0
assert document_store_2.count_documents() == 0
document_store_1.write_documents([Document(content="Hello world")])
assert document_store_1.count_documents() == 1
assert document_store_2.count_documents() == 1
```

View File

@ -63,10 +63,13 @@ class TestFilterRetriever:
} }
def test_to_dict_with_custom_init_parameters(self): def test_to_dict_with_custom_init_parameters(self):
ds = InMemoryDocumentStore() ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
serialized_ds = ds.to_dict() serialized_ds = ds.to_dict()
component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"}) component = FilterRetriever(
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
filters={"lang": "en"},
)
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever", "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",

View File

@ -60,11 +60,14 @@ class TestMemoryBM25Retriever:
} }
def test_to_dict_with_custom_init_parameters(self): def test_to_dict_with_custom_init_parameters(self):
ds = InMemoryDocumentStore() ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
serialized_ds = ds.to_dict() serialized_ds = ds.to_dict()
component = InMemoryBM25Retriever( component = InMemoryBM25Retriever(
document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
filters={"name": "test.txt"},
top_k=5,
scale_score=True,
) )
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {

View File

@ -32,6 +32,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_algorithm": "BM25L", "bm25_algorithm": "BM25L",
"bm25_parameters": {}, "bm25_parameters": {},
"embedding_similarity_function": "dot_product", "embedding_similarity_function": "dot_product",
"index": store.index,
}, },
} }
@ -41,6 +42,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
bm25_algorithm="BM25Plus", bm25_algorithm="BM25Plus",
bm25_parameters={"key": "value"}, bm25_parameters={"key": "value"},
embedding_similarity_function="cosine", embedding_similarity_function="cosine",
index="my_cool_index",
) )
data = store.to_dict() data = store.to_dict()
assert data == { assert data == {
@ -50,6 +52,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_algorithm": "BM25Plus", "bm25_algorithm": "BM25Plus",
"bm25_parameters": {"key": "value"}, "bm25_parameters": {"key": "value"},
"embedding_similarity_function": "cosine", "embedding_similarity_function": "cosine",
"index": "my_cool_index",
}, },
} }
@ -61,6 +64,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
"bm25_tokenization_regex": "custom_regex", "bm25_tokenization_regex": "custom_regex",
"bm25_algorithm": "BM25Plus", "bm25_algorithm": "BM25Plus",
"bm25_parameters": {"key": "value"}, "bm25_parameters": {"key": "value"},
"index": "my_cool_index",
}, },
} }
store = InMemoryDocumentStore.from_dict(data) store = InMemoryDocumentStore.from_dict(data)
@ -68,6 +72,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
assert store.tokenizer assert store.tokenizer
assert store.bm25_algorithm == "BM25Plus" assert store.bm25_algorithm == "BM25Plus"
assert store.bm25_parameters == {"key": "value"} assert store.bm25_parameters == {"key": "value"}
assert store.index == "my_cool_index"
def test_invalid_bm25_algorithm(self): def test_invalid_bm25_algorithm(self):
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"): with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
@ -414,3 +419,28 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
) )
assert scores == [0.1, 0.4] assert scores == [0.1, 0.4]
def test_multiple_document_stores_using_same_index(self):
index = "test_multiple_document_stores_using_same_index"
document_store_1 = InMemoryDocumentStore(index=index)
document_store_2 = InMemoryDocumentStore(index=index)
assert document_store_1.count_documents() == document_store_2.count_documents() == 0
doc_1 = Document(content="Hello world")
document_store_1.write_documents([doc_1])
assert document_store_1.count_documents() == document_store_2.count_documents() == 1
assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1]
doc_2 = Document(content="Hello another world")
document_store_2.write_documents([doc_2])
assert document_store_1.count_documents() == document_store_2.count_documents() == 2
assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1, doc_2]
document_store_1.delete_documents([doc_2.id])
assert document_store_1.count_documents() == document_store_2.count_documents() == 1
document_store_2.delete_documents([doc_1.id])
assert document_store_1.count_documents() == document_store_2.count_documents() == 0