MemoryEmbeddingRetriever (2.0) (#5726)

* MemoryDocumentStore - Embedding retrieval draft

* add release notes

* fix mypy

* better comment

* improve return_embeddings handling

* MemoryEmbeddingRetriever - first draft

* address PR comments

* release note

* update docstrings

* update docstrings

* incorporated feeback

* add return_embedding to __init__

* rm leftover docstring

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Stefano Fiorucci 2023-09-08 15:52:48 +02:00 committed by GitHub
parent d860a5c604
commit 2edf85f739
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 269 additions and 45 deletions

View File

@ -1,3 +1,3 @@
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
__all__ = ["MemoryRetriever"]
__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]

View File

@ -5,7 +5,7 @@ from haystack.preview.document_stores import MemoryDocumentStore, document_store
@component
class MemoryRetriever:
class MemoryBM25Retriever:
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
@ -20,12 +20,12 @@ class MemoryRetriever:
scale_score: bool = True,
):
"""
Create a MemoryRetriever component.
Create a MemoryBM25Retriever component.
:param document_store: An instance of MemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param scale_score: Whether to scale the BM25 score or not (default is True).
:param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the BM25 score or not. Default is True.
:raises ValueError: If the specified top_k is not > 0.
"""
@ -51,7 +51,7 @@ class MemoryRetriever:
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
"""
Deserialize this component from a dictionary.
"""
@ -77,13 +77,12 @@ class MemoryRetriever:
scale_score: Optional[bool] = None,
):
"""
Run the MemoryRetriever on the given input data.
Run the MemoryBM25Retriever on the given input data.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param document_stores: A dictionary mapping DocumentStore names to instances.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
@ -101,3 +100,119 @@ class MemoryRetriever:
self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score)
)
return {"documents": docs}
@component
class MemoryEmbeddingRetriever:
"""
A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.
Needs to be connected to a MemoryDocumentStore to run.
"""
def __init__(
self,
document_store: MemoryDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
return_embedding: bool = False,
):
"""
Create a MemoryEmbeddingRetriever component.
:param document_store: An instance of MemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
:param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
:raises ValueError: If the specified top_k is not > 0.
"""
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("document_store must be an instance of MemoryDocumentStore")
self.document_store = document_store
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.filters = filters
self.top_k = top_k
self.scale_score = scale_score
self.return_embedding = return_embedding
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
docstore = self.document_store.to_dict()
return default_to_dict(
self,
document_store=docstore,
filters=self.filters,
top_k=self.top_k,
scale_score=self.scale_score,
return_embedding=self.return_embedding,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")
docstore_class = document_store.registry[init_params["document_store"]["type"]]
docstore = docstore_class.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore
return default_from_dict(cls, data)
@component.output_types(documents=List[List[Document]])
def run(
self,
queries_embeddings: List[List[float]],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
scale_score: Optional[bool] = None,
return_embedding: Optional[bool] = None,
):
"""
Run the MemoryEmbeddingRetriever on the given input data.
:param queries_embeddings: Embeddings of the queries.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the scores of the retrieved documents or not.
:param return_embedding: Whether to return the embedding of the retrieved Documents.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters
if top_k is None:
top_k = self.top_k
if scale_score is None:
scale_score = self.scale_score
if return_embedding is None:
return_embedding = self.return_embedding
docs = []
for query_embedding in queries_embeddings:
docs.append(
self.document_store.embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
scale_score=scale_score,
return_embedding=return_embedding,
)
)
return {"documents": docs}

View File

@ -0,0 +1,6 @@
---
preview:
- |
Rename `MemoryRetriever` to `MemoryBM25Retriever`
Add `MemoryEmbeddingRetriever`, which takes as input a query embedding and
retrieves the most relevant Documents from a `MemoryDocumentStore`.

View File

@ -1,11 +1,10 @@
from typing import Dict, Any
from unittest.mock import MagicMock, patch
import pytest
from haystack.preview import Pipeline, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore
@ -21,36 +20,39 @@ def mock_docs():
]
class TestMemoryRetriever:
class TestMemoryRetrievers:
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_init_default(self):
retriever = MemoryRetriever(MemoryDocumentStore())
def test_init_default(self, retriever_cls):
retriever = retriever_cls(MemoryDocumentStore())
assert retriever.filters is None
assert retriever.top_k == 10
assert retriever.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_init_with_parameters(self):
retriever = MemoryRetriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
def test_init_with_parameters(self, retriever_cls):
retriever = retriever_cls(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
assert retriever.filters == {"name": "test.txt"}
assert retriever.top_k == 5
assert not retriever.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_init_with_invalid_top_k_parameter(self):
def test_init_with_invalid_top_k_parameter(self, retriever_cls):
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
retriever_cls(MemoryDocumentStore(), top_k=-2, scale_score=False)
@pytest.mark.unit
def test_to_dict(self):
def test_bm25_retriever_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever(document_store=document_store)
component = MemoryBM25Retriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryRetriever",
"type": "MemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
@ -60,16 +62,35 @@ class TestMemoryRetriever:
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
def test_embedding_retriever_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever(
component = MemoryEmbeddingRetriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
"top_k": 10,
"scale_score": True,
"return_embedding": False,
},
}
@pytest.mark.unit
def test_bm25_retriever_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryBM25Retriever(
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
)
data = component.to_dict()
assert data == {
"type": "MemoryRetriever",
"type": "MemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
@ -79,50 +100,78 @@ class TestMemoryRetriever:
}
@pytest.mark.unit
def test_from_dict(self):
def test_embedding_retriever_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryEmbeddingRetriever(
document_store=document_store,
filters={"name": "test.txt"},
top_k=5,
scale_score=False,
return_embedding=True,
)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": False,
"return_embedding": True,
},
}
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_from_dict(self, retriever_cls):
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
data = {
"type": "MemoryRetriever",
"type": retriever_cls.__name__,
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
},
}
component = MemoryRetriever.from_dict(data)
component = retriever_cls.from_dict(data)
assert isinstance(component.document_store, MemoryDocumentStore)
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_from_dict_without_docstore(self):
data = {"type": "MemoryRetriever", "init_parameters": {}}
def test_from_dict_without_docstore(self, retriever_cls):
data = {"type": retriever_cls.__name__, "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
MemoryRetriever.from_dict(data)
retriever_cls.from_dict(data)
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_from_dict_without_docstore_type(self):
data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
def test_from_dict_without_docstore_type(self, retriever_cls):
data = {"type": retriever_cls.__name__, "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
MemoryRetriever.from_dict(data)
retriever_cls.from_dict(data)
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_from_dict_nonexisting_docstore(self):
def test_from_dict_nonexisting_docstore(self, retriever_cls):
data = {
"type": "MemoryRetriever",
"type": retriever_cls.__name__,
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
MemoryRetriever.from_dict(data)
retriever_cls.from_dict(data)
@pytest.mark.unit
def test_valid_run(self, mock_docs):
def test_bm25_retriever_valid_run(self, mock_docs):
top_k = 5
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds, top_k=top_k)
retriever = MemoryBM25Retriever(ds, top_k=top_k)
result = retriever.run(queries=["PHP", "Java"])
assert "documents" in result
@ -133,10 +182,32 @@ class TestMemoryRetriever:
assert result["documents"][1][0].content == "Java is a popular programming language"
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self):
def test_embedding_retriever_valid_run(self):
top_k = 3
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
assert "documents" in result
assert len(result["documents"]) == 2
assert len(result["documents"][0]) == top_k
assert len(result["documents"][1]) == top_k
assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self, retriever_cls):
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
MemoryRetriever(SomeOtherDocumentStore())
retriever_cls(SomeOtherDocumentStore())
@pytest.mark.integration
@pytest.mark.parametrize(
@ -146,10 +217,10 @@ class TestMemoryRetriever:
("Java", "Java is a popular programming language"),
],
)
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
def test_run_bm25_retriever_with_pipeline(self, mock_docs, query: str, query_result: str):
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds)
retriever = MemoryBM25Retriever(ds)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
@ -161,6 +232,38 @@ class TestMemoryRetriever:
assert results_docs
assert results_docs[0][0].content == query_result
@pytest.mark.integration
def test_run_embedding_retriever_with_pipeline(self):
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
top_k = 2
docs = [
Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
result: Dict[str, Any] = pipeline.run(
data={
"retriever": {
"queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
"return_embedding": True,
}
}
)
assert result
assert "retriever" in result
results_docs = result["retriever"]["documents"]
assert results_docs
assert len(results_docs[0]) == top_k
assert len(results_docs[1]) == top_k
assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
@pytest.mark.integration
@pytest.mark.parametrize(
"query, query_result, top_k",
@ -170,10 +273,10 @@ class TestMemoryRetriever:
("Ruby", "Ruby is a popular programming language", 3),
],
)
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
def test_run_bm25_retriever_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds)
retriever = MemoryBM25Retriever(ds)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)