mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-08 07:41:43 +00:00
feat: Add memory sharing between different instances of InMemoryDocumentStore (#7781)
* Add memory sharing between different instances of InMemoryDocumentStore
* Fix FilterRetriever tests
* Fix InMemoryBM25Retriever tests
This commit is contained in:
parent
d81af81fbb
commit
854c4173f2
@ -4,6 +4,7 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
||||
@ -42,6 +43,11 @@ class BM25DocumentStats:
|
||||
doc_len: int
|
||||
|
||||
|
||||
# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
# Instances constructed with the same index name look up the same inner dict here,
# which is how they end up sharing their contents.
|
||||
# index name -> that index's documents (inner keys are presumably document ids -- TODO confirm).
_STORAGES: Dict[str, Dict[str, Document]] = {}
|
||||
# index name -> that index's per-document BM25 statistics (inner keys are presumably document ids -- TODO confirm).
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
|
||||
|
||||
|
||||
class InMemoryDocumentStore:
|
||||
"""
|
||||
Stores data in-memory. It's ephemeral and cannot be saved to disk.
|
||||
@ -53,6 +59,7 @@ class InMemoryDocumentStore:
|
||||
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
|
||||
bm25_parameters: Optional[Dict] = None,
|
||||
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
|
||||
index: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the DocumentStore.
|
||||
@ -66,11 +73,19 @@ class InMemoryDocumentStore:
|
||||
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
|
||||
One of "dot_product" (default) or "cosine".
|
||||
To choose the most appropriate function, look for information about your embedding model.
|
||||
:param index: If specified uses a specific index to store the documents. If not specified, a random UUID is used.
|
||||
Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
|
||||
"""
|
||||
self.storage: Dict[str, Document] = {}
|
||||
self.bm25_tokenization_regex = bm25_tokenization_regex
|
||||
self.tokenizer = re.compile(bm25_tokenization_regex).findall
|
||||
|
||||
if index is None:
|
||||
index = str(uuid.uuid4())
|
||||
|
||||
self.index = index
|
||||
if self.index not in _STORAGES:
|
||||
_STORAGES[self.index] = {}
|
||||
|
||||
self.bm25_algorithm = bm25_algorithm
|
||||
self.bm25_algorithm_inst = self._dispatch_bm25()
|
||||
self.bm25_parameters = bm25_parameters or {}
|
||||
@ -81,7 +96,19 @@ class InMemoryDocumentStore:
|
||||
self._freq_vocab_for_idf: Counter = Counter()
|
||||
|
||||
# Per-document statistics
|
||||
self._bm25_attr: Dict[str, BM25DocumentStats] = {}
|
||||
if self.index not in _BM25_STATS_STORAGES:
|
||||
_BM25_STATS_STORAGES[self.index] = {}
|
||||
|
||||
@property
def storage(self) -> Dict[str, Document]:
    """
    Utility property that returns the storage used by this instance of InMemoryDocumentStore.

    The storage lives in the module-level `_STORAGES` registry under `self.index`, so every instance
    that uses the same index gets the very same dictionary back.

    :returns: The dictionary registered in `_STORAGES` for `self.index`.
    """
    # `setdefault` rather than `get(self.index, {})`: if the entry is missing, a plain `get` hands back
    # a throwaway dict, and anything written to it would be silently lost.
    return _STORAGES.setdefault(self.index, {})
|
||||
|
||||
@property
def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
    """
    Utility property that returns the per-document BM25 statistics used by this instance.

    The statistics live in the module-level `_BM25_STATS_STORAGES` registry under `self.index`, so every
    instance that uses the same index gets the very same dictionary back.

    :returns: The dictionary registered in `_BM25_STATS_STORAGES` for `self.index`.
    """
    # `setdefault` rather than `get(self.index, {})`: if the entry is missing, a plain `get` hands back
    # a throwaway dict, and statistics written to it would be silently lost.
    return _BM25_STATS_STORAGES.setdefault(self.index, {})
|
||||
|
||||
def _dispatch_bm25(self):
|
||||
"""
|
||||
@ -281,6 +308,7 @@ class InMemoryDocumentStore:
|
||||
bm25_algorithm=self.bm25_algorithm,
|
||||
bm25_parameters=self.bm25_parameters,
|
||||
embedding_similarity_function=self.embedding_similarity_function,
|
||||
index=self.index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -0,0 +1,19 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add memory sharing between different instances of InMemoryDocumentStore.
|
||||
Setting the same `index` argument as another instance will make sure that the memory is shared.
|
||||
e.g.
|
||||
```python
|
||||
index = "my_personal_index"
|
||||
document_store_1 = InMemoryDocumentStore(index=index)
|
||||
document_store_2 = InMemoryDocumentStore(index=index)
|
||||
|
||||
assert document_store_1.count_documents() == 0
|
||||
assert document_store_2.count_documents() == 0
|
||||
|
||||
document_store_1.write_documents([Document(content="Hello world")])
|
||||
|
||||
assert document_store_1.count_documents() == 1
|
||||
assert document_store_2.count_documents() == 1
|
||||
```
|
@ -63,10 +63,13 @@ class TestFilterRetriever:
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
ds = InMemoryDocumentStore()
|
||||
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
|
||||
serialized_ds = ds.to_dict()
|
||||
|
||||
component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"})
|
||||
component = FilterRetriever(
|
||||
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
|
||||
filters={"lang": "en"},
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
|
||||
|
@ -60,11 +60,14 @@ class TestMemoryBM25Retriever:
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
ds = InMemoryDocumentStore()
|
||||
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
|
||||
serialized_ds = ds.to_dict()
|
||||
|
||||
component = InMemoryBM25Retriever(
|
||||
document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
|
||||
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
|
||||
filters={"name": "test.txt"},
|
||||
top_k=5,
|
||||
scale_score=True,
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
|
@ -32,6 +32,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
"bm25_algorithm": "BM25L",
|
||||
"bm25_parameters": {},
|
||||
"embedding_similarity_function": "dot_product",
|
||||
"index": store.index,
|
||||
},
|
||||
}
|
||||
|
||||
@ -41,6 +42,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
bm25_algorithm="BM25Plus",
|
||||
bm25_parameters={"key": "value"},
|
||||
embedding_similarity_function="cosine",
|
||||
index="my_cool_index",
|
||||
)
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
@ -50,6 +52,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
"embedding_similarity_function": "cosine",
|
||||
"index": "my_cool_index",
|
||||
},
|
||||
}
|
||||
|
||||
@ -61,6 +64,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
"index": "my_cool_index",
|
||||
},
|
||||
}
|
||||
store = InMemoryDocumentStore.from_dict(data)
|
||||
@ -68,6 +72,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
assert store.tokenizer
|
||||
assert store.bm25_algorithm == "BM25Plus"
|
||||
assert store.bm25_parameters == {"key": "value"}
|
||||
assert store.index == "my_cool_index"
|
||||
|
||||
def test_invalid_bm25_algorithm(self):
|
||||
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
|
||||
@ -414,3 +419,28 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
||||
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
|
||||
)
|
||||
assert scores == [0.1, 0.4]
|
||||
|
||||
def test_multiple_document_stores_using_same_index(self):
    """Stores created with the same index name must observe each other's writes and deletes."""
    shared_index = "test_multiple_document_stores_using_same_index"
    first_store = InMemoryDocumentStore(index=shared_index)
    second_store = InMemoryDocumentStore(index=shared_index)
    both_stores = (first_store, second_store)

    # Nothing has been written yet, so both views are empty.
    for store in both_stores:
        assert store.count_documents() == 0

    # A write through the first store is visible through both.
    hello = Document(content="Hello world")
    first_store.write_documents([hello])
    for store in both_stores:
        assert store.count_documents() == 1
    for store in both_stores:
        assert store.filter_documents() == [hello]

    # A write through the second store is visible through both as well.
    hello_again = Document(content="Hello another world")
    second_store.write_documents([hello_again])
    for store in both_stores:
        assert store.count_documents() == 2
    for store in both_stores:
        assert store.filter_documents() == [hello, hello_again]

    # Deletes propagate the same way, whichever store issues them.
    first_store.delete_documents([hello_again.id])
    for store in both_stores:
        assert store.count_documents() == 1

    second_store.delete_documents([hello.id])
    for store in both_stores:
        assert store.count_documents() == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user