MemoryEmbeddingRetriever (2.0) (#5726)

* MemoryDocumentStore - Embedding retrieval draft

* add release notes

* fix mypy

* better comment

* improve return_embeddings handling

* MemoryEmbeddingRetriever - first draft

* address PR comments

* release note

* update docstrings

* update docstrings

* incorporated feedback

* add return_embedding to __init__

* rm leftover docstring

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Stefano Fiorucci 2023-09-08 15:52:48 +02:00 committed by GitHub
parent d860a5c604
commit 2edf85f739
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 269 additions and 45 deletions

View File

@ -1,3 +1,3 @@
from haystack.preview.components.retrievers.memory import MemoryRetriever from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
__all__ = ["MemoryRetriever"] __all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]

View File

@ -5,7 +5,7 @@ from haystack.preview.document_stores import MemoryDocumentStore, document_store
@component @component
class MemoryRetriever: class MemoryBM25Retriever:
""" """
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm. A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
@ -20,12 +20,12 @@ class MemoryRetriever:
scale_score: bool = True, scale_score: bool = True,
): ):
""" """
Create a MemoryRetriever component. Create a MemoryBM25Retriever component.
:param document_store: An instance of MemoryDocumentStore. :param document_store: An instance of MemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space (default is None). :param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve (default is 10). :param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the BM25 score or not (default is True). :param scale_score: Whether to scale the BM25 score or not. Default is True.
:raises ValueError: If the specified top_k is not > 0. :raises ValueError: If the specified top_k is not > 0.
""" """
@ -51,7 +51,7 @@ class MemoryRetriever:
) )
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever": def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
@ -77,13 +77,12 @@ class MemoryRetriever:
scale_score: Optional[bool] = None, scale_score: Optional[bool] = None,
): ):
""" """
Run the MemoryRetriever on the given input data. Run the MemoryBM25Retriever on the given input data.
:param query: The query string for the retriever. :param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space. :param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return. :param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not. :param scale_score: Whether to scale the BM25 scores or not.
:param document_stores: A dictionary mapping DocumentStore names to instances.
:return: The retrieved documents. :return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance. :raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
@ -101,3 +100,119 @@ class MemoryRetriever:
self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score) self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score)
) )
return {"documents": docs} return {"documents": docs}
@component
class MemoryEmbeddingRetriever:
    """
    A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.

    Needs to be connected to a MemoryDocumentStore to run.
    """

    def __init__(
        self,
        document_store: MemoryDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = True,
        return_embedding: bool = False,
    ):
        """
        Create a MemoryEmbeddingRetriever component.

        :param document_store: An instance of MemoryDocumentStore.
        :param filters: A dictionary with filters to narrow down the search space. Default is None.
        :param top_k: The maximum number of documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.

        :raises ValueError: If document_store is not a MemoryDocumentStore instance,
            or if the specified top_k is not > 0.
        """
        if not isinstance(document_store, MemoryDocumentStore):
            raise ValueError("document_store must be an instance of MemoryDocumentStore")

        self.document_store = document_store

        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

        self.filters = filters
        self.top_k = top_k
        self.scale_score = scale_score
        self.return_embedding = return_embedding

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The document store is serialized through its own `to_dict` and nested under
        the `document_store` init parameter.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(
            self,
            document_store=docstore,
            filters=self.filters,
            top_k=self.top_k,
            scale_score=self.scale_score,
            return_embedding=self.return_embedding,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryEmbeddingRetriever":
        """
        Deserialize this component from a dictionary.

        :param data: The serialized component. Its `init_parameters` must contain a
            `document_store` entry with a `type` registered in the document store registry.
        :return: The deserialized MemoryEmbeddingRetriever.

        :raises DeserializationError: If the document store data is missing, has no `type`,
            or its `type` is not registered.
        """
        init_params = data.get("init_parameters", {})
        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")

        docstore_data = init_params["document_store"]
        if "type" not in docstore_data:
            raise DeserializationError("Missing 'type' in document store's serialization data")
        if docstore_data["type"] not in document_store.registry:
            raise DeserializationError(f"DocumentStore type '{docstore_data['type']}' not found")

        docstore_class = document_store.registry[docstore_data["type"]]
        docstore = docstore_class.from_dict(docstore_data)

        # Build a fresh payload instead of writing the live store back into `data`,
        # so the caller's dictionary stays serializable and can be deserialized again.
        resolved_data = {**data, "init_parameters": {**init_params, "document_store": docstore}}
        return default_from_dict(cls, resolved_data)

    @component.output_types(documents=List[List[Document]])
    def run(
        self,
        queries_embeddings: List[List[float]],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        scale_score: Optional[bool] = None,
        return_embedding: Optional[bool] = None,
    ):
        """
        Run the MemoryEmbeddingRetriever on the given input data.

        Any argument left as None falls back to the value given at initialization.

        :param queries_embeddings: Embeddings of the queries.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The maximum number of documents to return.
        :param scale_score: Whether to scale the scores of the retrieved documents or not.
        :param return_embedding: Whether to return the embedding of the retrieved Documents.
        :return: A dictionary whose `documents` key holds one list of retrieved documents
            per query embedding, in the same order as `queries_embeddings`.
        """
        if filters is None:
            filters = self.filters
        if top_k is None:
            top_k = self.top_k
        if scale_score is None:
            scale_score = self.scale_score
        if return_embedding is None:
            return_embedding = self.return_embedding

        docs = [
            self.document_store.embedding_retrieval(
                query_embedding=query_embedding,
                filters=filters,
                top_k=top_k,
                scale_score=scale_score,
                return_embedding=return_embedding,
            )
            for query_embedding in queries_embeddings
        ]
        return {"documents": docs}

View File

@ -0,0 +1,6 @@
---
preview:
- |
Rename `MemoryRetriever` to `MemoryBM25Retriever`.
Add `MemoryEmbeddingRetriever`, which takes as input a list of query embeddings and
retrieves the most relevant Documents from a `MemoryDocumentStore` for each of them.

View File

@ -1,11 +1,10 @@
from typing import Dict, Any from typing import Dict, Any
from unittest.mock import MagicMock, patch
import pytest import pytest
from haystack.preview import Pipeline, DeserializationError from haystack.preview import Pipeline, DeserializationError
from haystack.preview.testing.factory import document_store_class from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.retrievers.memory import MemoryRetriever from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
from haystack.preview.dataclasses import Document from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore from haystack.preview.document_stores import MemoryDocumentStore
@ -21,36 +20,39 @@ def mock_docs():
] ]
class TestMemoryRetriever: class TestMemoryRetrievers:
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_init_default(self): def test_init_default(self, retriever_cls):
retriever = MemoryRetriever(MemoryDocumentStore()) retriever = retriever_cls(MemoryDocumentStore())
assert retriever.filters is None assert retriever.filters is None
assert retriever.top_k == 10 assert retriever.top_k == 10
assert retriever.scale_score assert retriever.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_init_with_parameters(self): def test_init_with_parameters(self, retriever_cls):
retriever = MemoryRetriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False) retriever = retriever_cls(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
assert retriever.filters == {"name": "test.txt"} assert retriever.filters == {"name": "test.txt"}
assert retriever.top_k == 5 assert retriever.top_k == 5
assert not retriever.scale_score assert not retriever.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_init_with_invalid_top_k_parameter(self): def test_init_with_invalid_top_k_parameter(self, retriever_cls):
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"): with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False) retriever_cls(MemoryDocumentStore(), top_k=-2, scale_score=False)
@pytest.mark.unit @pytest.mark.unit
def test_to_dict(self): def test_bm25_retriever_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore() document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever(document_store=document_store) component = MemoryBM25Retriever(document_store=document_store)
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "MemoryRetriever", "type": "MemoryBM25Retriever",
"init_parameters": { "init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}}, "document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None, "filters": None,
@ -60,16 +62,35 @@ class TestMemoryRetriever:
} }
@pytest.mark.unit @pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self): def test_embedding_retriever_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore() document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryRetriever( component = MemoryEmbeddingRetriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
"top_k": 10,
"scale_score": True,
"return_embedding": False,
},
}
@pytest.mark.unit
def test_bm25_retriever_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryBM25Retriever(
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
) )
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "MemoryRetriever", "type": "MemoryBM25Retriever",
"init_parameters": { "init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}}, "document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"}, "filters": {"name": "test.txt"},
@ -79,50 +100,78 @@ class TestMemoryRetriever:
} }
@pytest.mark.unit @pytest.mark.unit
def test_from_dict(self): def test_embedding_retriever_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryEmbeddingRetriever(
document_store=document_store,
filters={"name": "test.txt"},
top_k=5,
scale_score=False,
return_embedding=True,
)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": False,
"return_embedding": True,
},
}
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_from_dict(self, retriever_cls):
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
data = { data = {
"type": "MemoryRetriever", "type": retriever_cls.__name__,
"init_parameters": { "init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}}, "document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"}, "filters": {"name": "test.txt"},
"top_k": 5, "top_k": 5,
}, },
} }
component = MemoryRetriever.from_dict(data) component = retriever_cls.from_dict(data)
assert isinstance(component.document_store, MemoryDocumentStore) assert isinstance(component.document_store, MemoryDocumentStore)
assert component.filters == {"name": "test.txt"} assert component.filters == {"name": "test.txt"}
assert component.top_k == 5 assert component.top_k == 5
assert component.scale_score assert component.scale_score
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_from_dict_without_docstore(self): def test_from_dict_without_docstore(self, retriever_cls):
data = {"type": "MemoryRetriever", "init_parameters": {}} data = {"type": retriever_cls.__name__, "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
MemoryRetriever.from_dict(data) retriever_cls.from_dict(data)
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_from_dict_without_docstore_type(self): def test_from_dict_without_docstore_type(self, retriever_cls):
data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}} data = {"type": retriever_cls.__name__, "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
MemoryRetriever.from_dict(data) retriever_cls.from_dict(data)
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
def test_from_dict_nonexisting_docstore(self): def test_from_dict_nonexisting_docstore(self, retriever_cls):
data = { data = {
"type": "MemoryRetriever", "type": retriever_cls.__name__,
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}}, "init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
} }
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"): with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
MemoryRetriever.from_dict(data) retriever_cls.from_dict(data)
@pytest.mark.unit @pytest.mark.unit
def test_valid_run(self, mock_docs): def test_bm25_retriever_valid_run(self, mock_docs):
top_k = 5 top_k = 5
ds = MemoryDocumentStore() ds = MemoryDocumentStore()
ds.write_documents(mock_docs) ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds, top_k=top_k) retriever = MemoryBM25Retriever(ds, top_k=top_k)
result = retriever.run(queries=["PHP", "Java"]) result = retriever.run(queries=["PHP", "Java"])
assert "documents" in result assert "documents" in result
@ -133,10 +182,32 @@ class TestMemoryRetriever:
assert result["documents"][1][0].content == "Java is a popular programming language" assert result["documents"][1][0].content == "Java is a popular programming language"
@pytest.mark.unit @pytest.mark.unit
def test_invalid_run_wrong_store_type(self): def test_embedding_retriever_valid_run(self):
top_k = 3
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
assert "documents" in result
assert len(result["documents"]) == 2
assert len(result["documents"][0]) == top_k
assert len(result["documents"][1]) == top_k
assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self, retriever_cls):
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore") SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"): with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
MemoryRetriever(SomeOtherDocumentStore()) retriever_cls(SomeOtherDocumentStore())
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -146,10 +217,10 @@ class TestMemoryRetriever:
("Java", "Java is a popular programming language"), ("Java", "Java is a popular programming language"),
], ],
) )
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): def test_run_bm25_retriever_with_pipeline(self, mock_docs, query: str, query_result: str):
ds = MemoryDocumentStore() ds = MemoryDocumentStore()
ds.write_documents(mock_docs) ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds) retriever = MemoryBM25Retriever(ds)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_component("retriever", retriever) pipeline.add_component("retriever", retriever)
@ -161,6 +232,38 @@ class TestMemoryRetriever:
assert results_docs assert results_docs
assert results_docs[0][0].content == query_result assert results_docs[0][0].content == query_result
@pytest.mark.integration
def test_run_embedding_retriever_with_pipeline(self):
    """
    Run a MemoryEmbeddingRetriever inside a Pipeline with two query embeddings.

    Checks that each query yields `top_k` Documents and that the first Document
    returned for each query carries the expected embedding.
    """
    top_k = 2
    query_embeddings = [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]]

    store = MemoryDocumentStore(embedding_similarity_function="cosine")
    store.write_documents(
        [
            Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
        ]
    )

    pipeline = Pipeline()
    pipeline.add_component("retriever", MemoryEmbeddingRetriever(store, top_k=top_k))

    retriever_inputs = {"queries_embeddings": query_embeddings, "return_embedding": True}
    result: Dict[str, Any] = pipeline.run(data={"retriever": retriever_inputs})

    assert result
    assert "retriever" in result
    results_docs = result["retriever"]["documents"]
    assert results_docs

    # One result list per query, each truncated to top_k.
    assert [len(results_docs[idx]) for idx in (0, 1)] == [top_k, top_k]
    # return_embedding=True was requested, so the best match exposes its embedding.
    assert [results_docs[idx][0].embedding for idx in (0, 1)] == [
        [0.1, 0.2, 0.3, 0.4],
        [1.0, 1.0, 1.0, 1.0],
    ]
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, query_result, top_k", "query, query_result, top_k",
@ -170,10 +273,10 @@ class TestMemoryRetriever:
("Ruby", "Ruby is a popular programming language", 3), ("Ruby", "Ruby is a popular programming language", 3),
], ],
) )
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): def test_run_bm25_retriever_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
ds = MemoryDocumentStore() ds = MemoryDocumentStore()
ds.write_documents(mock_docs) ds.write_documents(mock_docs)
retriever = MemoryRetriever(ds) retriever = MemoryBM25Retriever(ds)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_component("retriever", retriever) pipeline.add_component("retriever", retriever)