mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-03 18:36:04 +00:00
MemoryDocumentStore - Embedding retrieval (2.0) (#5715)
* MemoryDocumentStore - Embedding retrieval draft * add release notes * fix mypy * better comment * improve return_embeddings handling * address PR comments * update docstrings * incorporated feedback --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
71852c7b06
commit
b7bea3ae9c
@ -12,17 +12,19 @@ from haystack.preview.document_stores.decorator import document_store
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores.protocols import DuplicatePolicy, DocumentStore
|
||||
from haystack.preview.document_stores.memory._filters import match
|
||||
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
|
||||
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError, DocumentStoreError
|
||||
from haystack.preview.utils import expit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
|
||||
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A
|
||||
# larger SCALING_FACTOR decreases scaled scores. For example, an input of 10 is scaled to 0.99 with SCALING_FACTOR=2
|
||||
# but to 0.78 with SCALING_FACTOR=8 (default). The default was chosen empirically. Increase the default if most
|
||||
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a scaling factor
|
||||
# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method).
|
||||
# A larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with BM25_SCALING_FACTOR=2
|
||||
# but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. Increase the default if most
|
||||
# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1.
|
||||
SCALING_FACTOR = 8
|
||||
BM25_SCALING_FACTOR = 8
|
||||
DOT_PRODUCT_SCALING_FACTOR = 100
|
||||
|
||||
|
||||
@document_store
|
||||
@ -36,9 +38,20 @@ class MemoryDocumentStore:
|
||||
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
|
||||
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi",
|
||||
bm25_parameters: Optional[Dict] = None,
|
||||
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
|
||||
):
|
||||
"""
|
||||
Initializes the DocumentStore.
|
||||
|
||||
:param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
|
||||
:param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
|
||||
:param bm25_parameters: Parameters for BM25 implementation in a dictionary format.
|
||||
For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25}
|
||||
You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25.
|
||||
By default, no parameters are set.
|
||||
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
|
||||
One of "dot_product" (default) or "cosine".
|
||||
To choose the most appropriate function, look for information about your embedding model.
|
||||
"""
|
||||
self.storage: Dict[str, Document] = {}
|
||||
self._bm25_tokenization_regex = bm25_tokenization_regex
|
||||
@ -48,6 +61,7 @@ class MemoryDocumentStore:
|
||||
raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
|
||||
self.bm25_algorithm = algorithm_class
|
||||
self.bm25_parameters = bm25_parameters or {}
|
||||
self.embedding_similarity_function = embedding_similarity_function
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -58,6 +72,7 @@ class MemoryDocumentStore:
|
||||
bm25_tokenization_regex=self._bm25_tokenization_regex,
|
||||
bm25_algorithm=self.bm25_algorithm.__name__,
|
||||
bm25_parameters=self.bm25_parameters,
|
||||
embedding_similarity_function=self.embedding_similarity_function,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -141,8 +156,8 @@ class MemoryDocumentStore:
|
||||
}
|
||||
```
|
||||
|
||||
:param filters: the filters to apply to the document list.
|
||||
:return: a list of Documents that match the given filters.
|
||||
:param filters: The filters to apply to the document list.
|
||||
:return: A list of Documents that match the given filters.
|
||||
"""
|
||||
if filters:
|
||||
return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)]
|
||||
@ -152,12 +167,12 @@ class MemoryDocumentStore:
|
||||
"""
|
||||
Writes (or overwrites) documents into the DocumentStore.
|
||||
|
||||
:param documents: a list of documents.
|
||||
:param policy: documents with the same ID count as duplicates. When duplicates are met,
|
||||
:param documents: A list of documents.
|
||||
:param policy: Documents with the same ID count as duplicates. When duplicates are met,
|
||||
the DocumentStore can:
|
||||
- skip: keep the existing document and ignore the new one.
|
||||
- overwrite: remove the old document and write the new one.
|
||||
- fail: an error is raised
|
||||
- fail: an error is raised.
|
||||
:raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL`
|
||||
:return: None
|
||||
"""
|
||||
@ -178,10 +193,10 @@ class MemoryDocumentStore:
|
||||
|
||||
def delete_documents(self, document_ids: List[str]) -> None:
|
||||
"""
|
||||
Deletes all documents with a matching document_ids from the DocumentStore.
|
||||
Deletes all documents with matching document_ids from the DocumentStore.
|
||||
Fails with `MissingDocumentError` if no document with this id is present in the DocumentStore.
|
||||
|
||||
:param object_ids: the object_ids to delete
|
||||
:param object_ids: The object_ids to delete.
|
||||
"""
|
||||
for doc_id in document_ids:
|
||||
if not doc_id in self.storage.keys():
|
||||
@ -198,7 +213,7 @@ class MemoryDocumentStore:
|
||||
:param filters: A dictionary with filters to narrow down the search space.
|
||||
:param top_k: The number of top documents to retrieve. Default is 10.
|
||||
:param scale_score: Whether to scale the scores of the retrieved documents. Default is True.
|
||||
:return: A list of the top 'k' documents most relevant to the query.
|
||||
:return: A list of the top_k documents most relevant to the query.
|
||||
"""
|
||||
if not query:
|
||||
raise ValueError("Query should be a non-empty string")
|
||||
@ -218,9 +233,6 @@ class MemoryDocumentStore:
|
||||
filters = {**filters, "content_type": ["text", "table"]}
|
||||
all_documents = self.filter_documents(filters=filters)
|
||||
|
||||
# FIXME: remove this guard after resolving https://github.com/deepset-ai/canals/issues/33
|
||||
top_k = top_k if top_k is not None else 10
|
||||
|
||||
# Lowercase all documents
|
||||
lower_case_documents = []
|
||||
for doc in all_documents:
|
||||
@ -246,7 +258,7 @@ class MemoryDocumentStore:
|
||||
# get scores for the query against the corpus
|
||||
docs_scores = bm25_scorer.get_scores(tokenized_query)
|
||||
if scale_score:
|
||||
docs_scores = [expit(float(score / SCALING_FACTOR)) for score in docs_scores]
|
||||
docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
|
||||
# get the last top_k indexes and reverse them
|
||||
top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]
|
||||
|
||||
@ -259,3 +271,106 @@ class MemoryDocumentStore:
|
||||
return_document = Document(**doc_fields)
|
||||
return_documents.append(return_document)
|
||||
return return_documents
|
||||
|
||||
def embedding_retrieval(
    self,
    query_embedding: List[float],
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 10,
    scale_score: bool = True,
    return_embedding: bool = False,
) -> List[Document]:
    """
    Retrieves documents that are most similar to the query embedding using a vector similarity metric.

    :param query_embedding: Embedding of the query. Must be a non-empty sequence of numbers;
        integer components are accepted and converted to floats.
    :param filters: A dictionary with filters to narrow down the search space.
    :param top_k: The number of top documents to retrieve. Default is 10.
    :param scale_score: Whether to scale the scores of the retrieved Documents. Default is True.
    :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
    :raises ValueError: If `query_embedding` is empty or contains a non-numeric element.
    :return: A list of the top_k documents most relevant to the query.
    """
    # Validate every component, not only the first one: a stray non-numeric value further down the
    # list would otherwise surface later as an obscure NumPy error instead of this clear message.
    if len(query_embedding) == 0 or not all(isinstance(value, (int, float)) for value in query_embedding):
        raise ValueError("query_embedding should be a non-empty list of floats.")
    # Normalize to plain floats so the similarity computation never sees an integer-typed vector.
    query_embedding = [float(value) for value in query_embedding]

    all_documents = self.filter_documents(filters=filters or {})

    documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None]
    if not documents_with_embeddings:
        logger.warning(
            "No Documents found with embeddings. Returning empty list. "
            "To generate embeddings, use a DocumentEmbedder."
        )
        return []
    if len(documents_with_embeddings) < len(all_documents):
        logger.info(
            "Skipping some Documents that don't have an embedding. "
            "To generate embeddings, use a DocumentEmbedder."
        )

    scores = self._compute_query_embedding_similarity_scores(
        embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score
    )

    # Rank by similarity (highest first) and rebuild the top_k Documents with their score attached.
    ranked = sorted(zip(documents_with_embeddings, scores), key=lambda pair: pair[1], reverse=True)
    top_documents = []
    for doc, score in ranked[:top_k]:
        doc_fields = doc.to_dict()
        doc_fields["score"] = score
        if not return_embedding:
            # Drop the embedding from the returned copy only; the stored Document is untouched.
            doc_fields["embedding"] = None
        top_documents.append(Document(**doc_fields))

    return top_documents
|
||||
|
||||
def _compute_query_embedding_similarity_scores(
|
||||
self, embedding: List[float], documents: List[Document], scale_score: bool = True
|
||||
) -> List[float]:
|
||||
"""
|
||||
Computes the similarity scores between the query embedding and the embeddings of the documents.
|
||||
|
||||
:param embedding: Embedding of the query.
|
||||
:param documents: A list of Documents.
|
||||
:param scale_score: Whether to scale the scores of the Documents. Default is True.
|
||||
:return: A list of scores.
|
||||
"""
|
||||
|
||||
query_embedding = np.array(embedding)
|
||||
if query_embedding.ndim == 1:
|
||||
query_embedding = np.expand_dims(a=query_embedding, axis=0)
|
||||
|
||||
try:
|
||||
document_embeddings = np.array([doc.embedding for doc in documents])
|
||||
except ValueError as e:
|
||||
if "inhomogeneous shape" in str(e):
|
||||
raise DocumentStoreError(
|
||||
"The embedding size of all Documents should be the same. "
|
||||
"Please make sure that the Documents have been embedded with the same model."
|
||||
) from e
|
||||
raise e
|
||||
if document_embeddings.ndim == 1:
|
||||
document_embeddings = np.expand_dims(a=document_embeddings, axis=0)
|
||||
|
||||
if self.embedding_similarity_function == "cosine":
|
||||
# cosine similarity is a normed dot product
|
||||
query_embedding /= np.linalg.norm(x=query_embedding, axis=1, keepdims=True)
|
||||
document_embeddings /= np.linalg.norm(x=document_embeddings, axis=1, keepdims=True)
|
||||
|
||||
try:
|
||||
scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist()
|
||||
except ValueError as e:
|
||||
if "shapes" in str(e) and "not aligned" in str(e):
|
||||
raise DocumentStoreError(
|
||||
"The embedding size of the query should be the same as the embedding size of the Documents. "
|
||||
"Please make sure that the query has been embedded with the same model as the Documents."
|
||||
) from e
|
||||
raise e
|
||||
|
||||
if scale_score:
|
||||
if self.embedding_similarity_function == "dot_product":
|
||||
scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores]
|
||||
elif self.embedding_similarity_function == "cosine":
|
||||
scores = [(score + 1) / 2 for score in scores]
|
||||
|
||||
return scores
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Add `embedding_retrieval` method to `MemoryDocumentStore`,
|
||||
which allows retrieving the relevant Documents, given a query embedding.
|
||||
It will be called by the `MemoryEmbeddingRetriever`.
|
||||
@ -5,7 +5,8 @@ import pandas as pd
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.document_stores import DocumentStore, MemoryDocumentStore
|
||||
from haystack.preview.document_stores import DocumentStore, MemoryDocumentStore, DocumentStoreError
|
||||
|
||||
|
||||
from haystack.preview.testing.document_store import DocumentStoreBaseTests
|
||||
|
||||
@ -29,13 +30,17 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
|
||||
"bm25_algorithm": "BM25Okapi",
|
||||
"bm25_parameters": {},
|
||||
"embedding_similarity_function": "dot_product",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
store = MemoryDocumentStore(
|
||||
bm25_tokenization_regex="custom_regex", bm25_algorithm="BM25Plus", bm25_parameters={"key": "value"}
|
||||
bm25_tokenization_regex="custom_regex",
|
||||
bm25_algorithm="BM25Plus",
|
||||
bm25_parameters={"key": "value"},
|
||||
embedding_similarity_function="cosine",
|
||||
)
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
@ -44,6 +49,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
"embedding_similarity_function": "cosine",
|
||||
},
|
||||
}
|
||||
|
||||
@ -251,3 +257,148 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
df = results[0].content
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert df.equals(table_content)
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval(self):
    """The Document whose embedding best matches the query embedding must be the one returned."""
    store = MemoryDocumentStore(embedding_similarity_function="cosine")
    store.write_documents(
        [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]
    )

    retrieved = store.embedding_retrieval(
        query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters={}, scale_score=False
    )

    # Exactly one hit, and it is the Document pointing in the same direction as the query.
    assert len(retrieved) == 1
    best_match = retrieved[0]
    assert best_match.content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_invalid_query(self):
    """Empty or non-float query embeddings are rejected with a ValueError."""
    store = MemoryDocumentStore()
    expected_message = "query_embedding should be a non-empty list of floats"

    # First an empty embedding, then one made of strings: both must fail the same way.
    for bad_query in ([], ["invalid", "list", "of", "strings"]):
        with pytest.raises(ValueError, match=expected_message):
            store.embedding_retrieval(query_embedding=bad_query)
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_no_embeddings(self, caplog):
    """When no stored Document has an embedding, nothing is returned and a warning is logged."""
    caplog.set_level(logging.WARNING)
    store = MemoryDocumentStore()
    store.write_documents(
        [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    )

    retrieved = store.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])

    assert len(retrieved) == 0
    expected_warning = "No Documents found with embeddings. Returning empty list."
    assert expected_warning in caplog.text
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_some_documents_wo_embeddings(self, caplog):
    """Documents lacking an embedding are skipped, and the skip is reported at INFO level."""
    caplog.set_level(logging.INFO)
    store = MemoryDocumentStore()
    embedded_doc = Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])
    plain_doc = Document(content="Haystack supports multiple languages")
    store.write_documents([embedded_doc, plain_doc])

    store.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])

    expected_info = "Skipping some Documents that don't have an embedding."
    assert expected_info in caplog.text
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_documents_different_embedding_sizes(self):
    """Stored Documents with embeddings of different sizes make retrieval fail with DocumentStoreError."""
    store = MemoryDocumentStore()
    four_dim_doc = Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])
    two_dim_doc = Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0])
    store.write_documents([four_dim_doc, two_dim_doc])

    expected_message = "The embedding size of all Documents should be the same."
    with pytest.raises(DocumentStoreError, match=expected_message):
        store.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_query_documents_different_embedding_sizes(self):
    """A query embedding whose size differs from the stored embeddings raises DocumentStoreError."""
    store = MemoryDocumentStore()
    store.write_documents([Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])])

    expected_message = "The embedding size of the query should be the same as the embedding size of the Documents."
    # The stored embedding has four components; the query only has two.
    with pytest.raises(DocumentStoreError, match=expected_message):
        store.embedding_retrieval(query_embedding=[0.1, 0.1])
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_with_different_top_k(self):
    """The number of returned Documents follows the requested top_k."""
    store = MemoryDocumentStore()
    store.write_documents(
        [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]),
        ]
    )

    # Same query each time; only top_k changes.
    for requested in (2, 3):
        retrieved = store.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=requested)
        assert len(retrieved) == requested
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_with_scale_score(self):
    """Scaled scores fall within [0, 1] and differ from the raw, unscaled scores."""
    store = MemoryDocumentStore()
    store.write_documents(
        [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]),
        ]
    )
    query = [0.1, 0.1, 0.1, 0.1]

    scaled_hits = store.embedding_retrieval(query_embedding=query, top_k=1, scale_score=True)
    scaled_score = scaled_hits[0].score
    # Confirm that score is scaled between 0 and 1
    assert 0 <= scaled_score <= 1

    # Same query without scaling: the raw score must not coincide with the scaled one.
    raw_hits = store.embedding_retrieval(query_embedding=query, top_k=1, scale_score=False)
    assert raw_hits[0].score != scaled_score
|
||||
|
||||
@pytest.mark.unit
def test_embedding_retrieval_return_embedding(self):
    """The `return_embedding` flag controls whether retrieved Documents carry their embedding."""
    store = MemoryDocumentStore(embedding_similarity_function="cosine")
    store.write_documents(
        [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]
    )
    query = [0.1, 0.1, 0.1, 0.1]

    without_embedding = store.embedding_retrieval(query_embedding=query, top_k=1, return_embedding=False)
    assert without_embedding[0].embedding is None

    with_embedding = store.embedding_retrieval(query_embedding=query, top_k=1, return_embedding=True)
    assert with_embedding[0].embedding == [1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
@pytest.mark.unit
def test_compute_cosine_similarity_scores(self):
    """Unscaled cosine scores match the hand-computed cosine similarities of the fixtures."""
    docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
    docs = [
        Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
        Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
    ]

    scores = docstore._compute_query_embedding_similarity_scores(
        embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
    )

    # The scores come out of floating-point normalization and a dot product, so compare with a
    # tolerance instead of exact equality to avoid failures caused by rounding differences.
    assert scores == pytest.approx([0.5, 1.0])
|
||||
|
||||
@pytest.mark.unit
def test_compute_dot_product_similarity_scores(self):
    """Unscaled dot-product scores match the hand-computed dot products of the fixtures."""
    docstore = MemoryDocumentStore(embedding_similarity_function="dot_product")
    docs = [
        Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
        Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
    ]

    scores = docstore._compute_query_embedding_similarity_scores(
        embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
    )

    # Summing 0.1 four times is subject to floating-point rounding, so compare with a tolerance
    # instead of exact equality.
    assert scores == pytest.approx([0.1, 0.4])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user