mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-11 01:02:30 +00:00
feat: Add memory sharing between different instances of InMemoryDocumentStore
(#7781)
* Add memory sharing between different instances of InMemoryDocumentStore * Fix FilterRetriever tests * Fix InMemoryBM25Retriever tests
This commit is contained in:
parent
d81af81fbb
commit
854c4173f2
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
||||||
@ -42,6 +43,11 @@ class BM25DocumentStats:
|
|||||||
doc_len: int
|
doc_len: int
|
||||||
|
|
||||||
|
|
||||||
|
# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
|
||||||
|
_STORAGES: Dict[str, Dict[str, Document]] = {}
|
||||||
|
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
|
||||||
|
|
||||||
|
|
||||||
class InMemoryDocumentStore:
|
class InMemoryDocumentStore:
|
||||||
"""
|
"""
|
||||||
Stores data in-memory. It's ephemeral and cannot be saved to disk.
|
Stores data in-memory. It's ephemeral and cannot be saved to disk.
|
||||||
@ -53,6 +59,7 @@ class InMemoryDocumentStore:
|
|||||||
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
|
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
|
||||||
bm25_parameters: Optional[Dict] = None,
|
bm25_parameters: Optional[Dict] = None,
|
||||||
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
|
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
|
||||||
|
index: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the DocumentStore.
|
Initializes the DocumentStore.
|
||||||
@ -66,11 +73,19 @@ class InMemoryDocumentStore:
|
|||||||
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
|
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
|
||||||
One of "dot_product" (default) or "cosine".
|
One of "dot_product" (default) or "cosine".
|
||||||
To choose the most appropriate function, look for information about your embedding model.
|
To choose the most appropriate function, look for information about your embedding model.
|
||||||
|
:param index: If specified, the documents are stored under this index name. If not specified, a random UUID is used.
|
||||||
|
Using the same index in multiple InMemoryDocumentStore instances makes them share the same stored documents.
|
||||||
"""
|
"""
|
||||||
self.storage: Dict[str, Document] = {}
|
|
||||||
self.bm25_tokenization_regex = bm25_tokenization_regex
|
self.bm25_tokenization_regex = bm25_tokenization_regex
|
||||||
self.tokenizer = re.compile(bm25_tokenization_regex).findall
|
self.tokenizer = re.compile(bm25_tokenization_regex).findall
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
index = str(uuid.uuid4())
|
||||||
|
|
||||||
|
self.index = index
|
||||||
|
if self.index not in _STORAGES:
|
||||||
|
_STORAGES[self.index] = {}
|
||||||
|
|
||||||
self.bm25_algorithm = bm25_algorithm
|
self.bm25_algorithm = bm25_algorithm
|
||||||
self.bm25_algorithm_inst = self._dispatch_bm25()
|
self.bm25_algorithm_inst = self._dispatch_bm25()
|
||||||
self.bm25_parameters = bm25_parameters or {}
|
self.bm25_parameters = bm25_parameters or {}
|
||||||
@ -81,7 +96,19 @@ class InMemoryDocumentStore:
|
|||||||
self._freq_vocab_for_idf: Counter = Counter()
|
self._freq_vocab_for_idf: Counter = Counter()
|
||||||
|
|
||||||
# Per-document statistics
|
# Per-document statistics
|
||||||
self._bm25_attr: Dict[str, BM25DocumentStats] = {}
|
if self.index not in _BM25_STATS_STORAGES:
|
||||||
|
_BM25_STATS_STORAGES[self.index] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage(self) -> Dict[str, Document]:
|
||||||
|
"""
|
||||||
|
Utility property that returns the storage used by this instance of InMemoryDocumentStore.
|
||||||
|
"""
|
||||||
|
return _STORAGES.get(self.index, {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
|
||||||
|
return _BM25_STATS_STORAGES.get(self.index, {})
|
||||||
|
|
||||||
def _dispatch_bm25(self):
|
def _dispatch_bm25(self):
|
||||||
"""
|
"""
|
||||||
@ -281,6 +308,7 @@ class InMemoryDocumentStore:
|
|||||||
bm25_algorithm=self.bm25_algorithm,
|
bm25_algorithm=self.bm25_algorithm,
|
||||||
bm25_parameters=self.bm25_parameters,
|
bm25_parameters=self.bm25_parameters,
|
||||||
embedding_similarity_function=self.embedding_similarity_function,
|
embedding_similarity_function=self.embedding_similarity_function,
|
||||||
|
index=self.index,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Add memory sharing between different instances of InMemoryDocumentStore.
|
||||||
|
Creating an instance with the same `index` argument as another instance makes both instances share the same in-memory storage.
|
||||||
|
e.g.
|
||||||
|
```python
|
||||||
|
index = "my_personal_index"
|
||||||
|
document_store_1 = InMemoryDocumentStore(index=index)
|
||||||
|
document_store_2 = InMemoryDocumentStore(index=index)
|
||||||
|
|
||||||
|
assert document_store_1.count_documents() == 0
|
||||||
|
assert document_store_2.count_documents() == 0
|
||||||
|
|
||||||
|
document_store_1.write_documents([Document(content="Hello world")])
|
||||||
|
|
||||||
|
assert document_store_1.count_documents() == 1
|
||||||
|
assert document_store_2.count_documents() == 1
|
||||||
|
```
|
@ -63,10 +63,13 @@ class TestFilterRetriever:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self):
|
def test_to_dict_with_custom_init_parameters(self):
|
||||||
ds = InMemoryDocumentStore()
|
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
|
||||||
serialized_ds = ds.to_dict()
|
serialized_ds = ds.to_dict()
|
||||||
|
|
||||||
component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"})
|
component = FilterRetriever(
|
||||||
|
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
|
||||||
|
filters={"lang": "en"},
|
||||||
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
|
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
|
||||||
|
@ -60,11 +60,14 @@ class TestMemoryBM25Retriever:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self):
|
def test_to_dict_with_custom_init_parameters(self):
|
||||||
ds = InMemoryDocumentStore()
|
ds = InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters")
|
||||||
serialized_ds = ds.to_dict()
|
serialized_ds = ds.to_dict()
|
||||||
|
|
||||||
component = InMemoryBM25Retriever(
|
component = InMemoryBM25Retriever(
|
||||||
document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
|
document_store=InMemoryDocumentStore(index="test_to_dict_with_custom_init_parameters"),
|
||||||
|
filters={"name": "test.txt"},
|
||||||
|
top_k=5,
|
||||||
|
scale_score=True,
|
||||||
)
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
|
@ -32,6 +32,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
"bm25_algorithm": "BM25L",
|
"bm25_algorithm": "BM25L",
|
||||||
"bm25_parameters": {},
|
"bm25_parameters": {},
|
||||||
"embedding_similarity_function": "dot_product",
|
"embedding_similarity_function": "dot_product",
|
||||||
|
"index": store.index,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,6 +42,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
bm25_algorithm="BM25Plus",
|
bm25_algorithm="BM25Plus",
|
||||||
bm25_parameters={"key": "value"},
|
bm25_parameters={"key": "value"},
|
||||||
embedding_similarity_function="cosine",
|
embedding_similarity_function="cosine",
|
||||||
|
index="my_cool_index",
|
||||||
)
|
)
|
||||||
data = store.to_dict()
|
data = store.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
@ -50,6 +52,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
"bm25_algorithm": "BM25Plus",
|
"bm25_algorithm": "BM25Plus",
|
||||||
"bm25_parameters": {"key": "value"},
|
"bm25_parameters": {"key": "value"},
|
||||||
"embedding_similarity_function": "cosine",
|
"embedding_similarity_function": "cosine",
|
||||||
|
"index": "my_cool_index",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,6 +64,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
"bm25_tokenization_regex": "custom_regex",
|
"bm25_tokenization_regex": "custom_regex",
|
||||||
"bm25_algorithm": "BM25Plus",
|
"bm25_algorithm": "BM25Plus",
|
||||||
"bm25_parameters": {"key": "value"},
|
"bm25_parameters": {"key": "value"},
|
||||||
|
"index": "my_cool_index",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
store = InMemoryDocumentStore.from_dict(data)
|
store = InMemoryDocumentStore.from_dict(data)
|
||||||
@ -68,6 +72,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
assert store.tokenizer
|
assert store.tokenizer
|
||||||
assert store.bm25_algorithm == "BM25Plus"
|
assert store.bm25_algorithm == "BM25Plus"
|
||||||
assert store.bm25_parameters == {"key": "value"}
|
assert store.bm25_parameters == {"key": "value"}
|
||||||
|
assert store.index == "my_cool_index"
|
||||||
|
|
||||||
def test_invalid_bm25_algorithm(self):
|
def test_invalid_bm25_algorithm(self):
|
||||||
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
|
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
|
||||||
@ -414,3 +419,28 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
|
|||||||
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
|
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
|
||||||
)
|
)
|
||||||
assert scores == [0.1, 0.4]
|
assert scores == [0.1, 0.4]
|
||||||
|
|
||||||
|
def test_multiple_document_stores_using_same_index(self):
|
||||||
|
index = "test_multiple_document_stores_using_same_index"
|
||||||
|
document_store_1 = InMemoryDocumentStore(index=index)
|
||||||
|
document_store_2 = InMemoryDocumentStore(index=index)
|
||||||
|
|
||||||
|
assert document_store_1.count_documents() == document_store_2.count_documents() == 0
|
||||||
|
|
||||||
|
doc_1 = Document(content="Hello world")
|
||||||
|
document_store_1.write_documents([doc_1])
|
||||||
|
assert document_store_1.count_documents() == document_store_2.count_documents() == 1
|
||||||
|
|
||||||
|
assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1]
|
||||||
|
|
||||||
|
doc_2 = Document(content="Hello another world")
|
||||||
|
document_store_2.write_documents([doc_2])
|
||||||
|
assert document_store_1.count_documents() == document_store_2.count_documents() == 2
|
||||||
|
|
||||||
|
assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1, doc_2]
|
||||||
|
|
||||||
|
document_store_1.delete_documents([doc_2.id])
|
||||||
|
assert document_store_1.count_documents() == document_store_2.count_documents() == 1
|
||||||
|
|
||||||
|
document_store_2.delete_documents([doc_1.id])
|
||||||
|
assert document_store_1.count_documents() == document_store_2.count_documents() == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user