mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-03 23:19:20 +00:00
MemoryEmbeddingRetriever
(2.0) (#5726)
* MemoryDocumentStore - Embedding retrieval draft * add release notes * fix mypy * better comment * improve return_embeddings handling * MemoryEmbeddingRetriever - first draft * address PR comments * release note * update docstrings * update docstrings * incorporated feeback * add return_embedding to __init__ * rm leftover docstring --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
d860a5c604
commit
2edf85f739
@ -1,3 +1,3 @@
|
|||||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
|
||||||
|
|
||||||
__all__ = ["MemoryRetriever"]
|
__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]
|
||||||
|
@ -5,7 +5,7 @@ from haystack.preview.document_stores import MemoryDocumentStore, document_store
|
|||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class MemoryRetriever:
|
class MemoryBM25Retriever:
|
||||||
"""
|
"""
|
||||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||||
|
|
||||||
@ -20,12 +20,12 @@ class MemoryRetriever:
|
|||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a MemoryRetriever component.
|
Create a MemoryBM25Retriever component.
|
||||||
|
|
||||||
:param document_store: An instance of MemoryDocumentStore.
|
:param document_store: An instance of MemoryDocumentStore.
|
||||||
:param filters: A dictionary with filters to narrow down the search space (default is None).
|
:param filters: A dictionary with filters to narrow down the search space. Default is None.
|
||||||
:param top_k: The maximum number of documents to retrieve (default is 10).
|
:param top_k: The maximum number of documents to retrieve. Default is 10.
|
||||||
:param scale_score: Whether to scale the BM25 score or not (default is True).
|
:param scale_score: Whether to scale the BM25 score or not. Default is True.
|
||||||
|
|
||||||
:raises ValueError: If the specified top_k is not > 0.
|
:raises ValueError: If the specified top_k is not > 0.
|
||||||
"""
|
"""
|
||||||
@ -51,7 +51,7 @@ class MemoryRetriever:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
|
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
|
||||||
"""
|
"""
|
||||||
Deserialize this component from a dictionary.
|
Deserialize this component from a dictionary.
|
||||||
"""
|
"""
|
||||||
@ -77,13 +77,12 @@ class MemoryRetriever:
|
|||||||
scale_score: Optional[bool] = None,
|
scale_score: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run the MemoryRetriever on the given input data.
|
Run the MemoryBM25Retriever on the given input data.
|
||||||
|
|
||||||
:param query: The query string for the retriever.
|
:param query: The query string for the retriever.
|
||||||
:param filters: A dictionary with filters to narrow down the search space.
|
:param filters: A dictionary with filters to narrow down the search space.
|
||||||
:param top_k: The maximum number of documents to return.
|
:param top_k: The maximum number of documents to return.
|
||||||
:param scale_score: Whether to scale the BM25 scores or not.
|
:param scale_score: Whether to scale the BM25 scores or not.
|
||||||
:param document_stores: A dictionary mapping DocumentStore names to instances.
|
|
||||||
:return: The retrieved documents.
|
:return: The retrieved documents.
|
||||||
|
|
||||||
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
||||||
@ -101,3 +100,119 @@ class MemoryRetriever:
|
|||||||
self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score)
|
self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score)
|
||||||
)
|
)
|
||||||
return {"documents": docs}
|
return {"documents": docs}
|
||||||
|
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MemoryEmbeddingRetriever:
|
||||||
|
"""
|
||||||
|
A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.
|
||||||
|
|
||||||
|
Needs to be connected to a MemoryDocumentStore to run.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
document_store: MemoryDocumentStore,
|
||||||
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
scale_score: bool = True,
|
||||||
|
return_embedding: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a MemoryEmbeddingRetriever component.
|
||||||
|
|
||||||
|
:param document_store: An instance of MemoryDocumentStore.
|
||||||
|
:param filters: A dictionary with filters to narrow down the search space. Default is None.
|
||||||
|
:param top_k: The maximum number of documents to retrieve. Default is 10.
|
||||||
|
:param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
|
||||||
|
:param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
|
||||||
|
|
||||||
|
:raises ValueError: If the specified top_k is not > 0.
|
||||||
|
"""
|
||||||
|
if not isinstance(document_store, MemoryDocumentStore):
|
||||||
|
raise ValueError("document_store must be an instance of MemoryDocumentStore")
|
||||||
|
|
||||||
|
self.document_store = document_store
|
||||||
|
|
||||||
|
if top_k <= 0:
|
||||||
|
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||||
|
|
||||||
|
self.filters = filters
|
||||||
|
self.top_k = top_k
|
||||||
|
self.scale_score = scale_score
|
||||||
|
self.return_embedding = return_embedding
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serialize this component to a dictionary.
|
||||||
|
"""
|
||||||
|
docstore = self.document_store.to_dict()
|
||||||
|
return default_to_dict(
|
||||||
|
self,
|
||||||
|
document_store=docstore,
|
||||||
|
filters=self.filters,
|
||||||
|
top_k=self.top_k,
|
||||||
|
scale_score=self.scale_score,
|
||||||
|
return_embedding=self.return_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
|
||||||
|
"""
|
||||||
|
Deserialize this component from a dictionary.
|
||||||
|
"""
|
||||||
|
init_params = data.get("init_parameters", {})
|
||||||
|
if "document_store" not in init_params:
|
||||||
|
raise DeserializationError("Missing 'document_store' in serialization data")
|
||||||
|
if "type" not in init_params["document_store"]:
|
||||||
|
raise DeserializationError("Missing 'type' in document store's serialization data")
|
||||||
|
if init_params["document_store"]["type"] not in document_store.registry:
|
||||||
|
raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")
|
||||||
|
|
||||||
|
docstore_class = document_store.registry[init_params["document_store"]["type"]]
|
||||||
|
docstore = docstore_class.from_dict(init_params["document_store"])
|
||||||
|
data["init_parameters"]["document_store"] = docstore
|
||||||
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
@component.output_types(documents=List[List[Document]])
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
queries_embeddings: List[List[float]],
|
||||||
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
scale_score: Optional[bool] = None,
|
||||||
|
return_embedding: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run the MemoryEmbeddingRetriever on the given input data.
|
||||||
|
|
||||||
|
:param queries_embeddings: Embeddings of the queries.
|
||||||
|
:param filters: A dictionary with filters to narrow down the search space.
|
||||||
|
:param top_k: The maximum number of documents to return.
|
||||||
|
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
||||||
|
:param return_embedding: Whether to return the embedding of the retrieved Documents.
|
||||||
|
:return: The retrieved documents.
|
||||||
|
|
||||||
|
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
||||||
|
"""
|
||||||
|
if filters is None:
|
||||||
|
filters = self.filters
|
||||||
|
if top_k is None:
|
||||||
|
top_k = self.top_k
|
||||||
|
if scale_score is None:
|
||||||
|
scale_score = self.scale_score
|
||||||
|
if return_embedding is None:
|
||||||
|
return_embedding = self.return_embedding
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for query_embedding in queries_embeddings:
|
||||||
|
docs.append(
|
||||||
|
self.document_store.embedding_retrieval(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
filters=filters,
|
||||||
|
top_k=top_k,
|
||||||
|
scale_score=scale_score,
|
||||||
|
return_embedding=return_embedding,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"documents": docs}
|
||||||
|
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
preview:
|
||||||
|
- |
|
||||||
|
Rename `MemoryRetriever` to `MemoryBM25Retriever`
|
||||||
|
Add `MemoryEmbeddingRetriever`, which takes as input a query embedding and
|
||||||
|
retrieves the most relevant Documents from a `MemoryDocumentStore`.
|
@ -1,11 +1,10 @@
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.preview import Pipeline, DeserializationError
|
from haystack.preview import Pipeline, DeserializationError
|
||||||
from haystack.preview.testing.factory import document_store_class
|
from haystack.preview.testing.factory import document_store_class
|
||||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever
|
||||||
from haystack.preview.dataclasses import Document
|
from haystack.preview.dataclasses import Document
|
||||||
from haystack.preview.document_stores import MemoryDocumentStore
|
from haystack.preview.document_stores import MemoryDocumentStore
|
||||||
|
|
||||||
@ -21,36 +20,39 @@ def mock_docs():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestMemoryRetriever:
|
class TestMemoryRetrievers:
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_default(self):
|
def test_init_default(self, retriever_cls):
|
||||||
retriever = MemoryRetriever(MemoryDocumentStore())
|
retriever = retriever_cls(MemoryDocumentStore())
|
||||||
assert retriever.filters is None
|
assert retriever.filters is None
|
||||||
assert retriever.top_k == 10
|
assert retriever.top_k == 10
|
||||||
assert retriever.scale_score
|
assert retriever.scale_score
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_with_parameters(self):
|
def test_init_with_parameters(self, retriever_cls):
|
||||||
retriever = MemoryRetriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
|
retriever = retriever_cls(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
|
||||||
assert retriever.filters == {"name": "test.txt"}
|
assert retriever.filters == {"name": "test.txt"}
|
||||||
assert retriever.top_k == 5
|
assert retriever.top_k == 5
|
||||||
assert not retriever.scale_score
|
assert not retriever.scale_score
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_with_invalid_top_k_parameter(self):
|
def test_init_with_invalid_top_k_parameter(self, retriever_cls):
|
||||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||||
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
retriever_cls(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict(self):
|
def test_bm25_retriever_to_dict(self):
|
||||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||||
document_store = MyFakeStore()
|
document_store = MyFakeStore()
|
||||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||||
component = MemoryRetriever(document_store=document_store)
|
component = MemoryBM25Retriever(document_store=document_store)
|
||||||
|
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "MemoryRetriever",
|
"type": "MemoryBM25Retriever",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||||
"filters": None,
|
"filters": None,
|
||||||
@ -60,16 +62,35 @@ class TestMemoryRetriever:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict_with_custom_init_parameters(self):
|
def test_embedding_retriever_to_dict(self):
|
||||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||||
document_store = MyFakeStore()
|
document_store = MyFakeStore()
|
||||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||||
component = MemoryRetriever(
|
component = MemoryEmbeddingRetriever(document_store=document_store)
|
||||||
|
|
||||||
|
data = component.to_dict()
|
||||||
|
assert data == {
|
||||||
|
"type": "MemoryEmbeddingRetriever",
|
||||||
|
"init_parameters": {
|
||||||
|
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||||
|
"filters": None,
|
||||||
|
"top_k": 10,
|
||||||
|
"scale_score": True,
|
||||||
|
"return_embedding": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_bm25_retriever_to_dict_with_custom_init_parameters(self):
|
||||||
|
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||||
|
document_store = MyFakeStore()
|
||||||
|
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||||
|
component = MemoryBM25Retriever(
|
||||||
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
|
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||||
)
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "MemoryRetriever",
|
"type": "MemoryBM25Retriever",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||||
"filters": {"name": "test.txt"},
|
"filters": {"name": "test.txt"},
|
||||||
@ -79,50 +100,78 @@ class TestMemoryRetriever:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict(self):
|
def test_embedding_retriever_to_dict_with_custom_init_parameters(self):
|
||||||
|
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||||
|
document_store = MyFakeStore()
|
||||||
|
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||||
|
component = MemoryEmbeddingRetriever(
|
||||||
|
document_store=document_store,
|
||||||
|
filters={"name": "test.txt"},
|
||||||
|
top_k=5,
|
||||||
|
scale_score=False,
|
||||||
|
return_embedding=True,
|
||||||
|
)
|
||||||
|
data = component.to_dict()
|
||||||
|
assert data == {
|
||||||
|
"type": "MemoryEmbeddingRetriever",
|
||||||
|
"init_parameters": {
|
||||||
|
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||||
|
"filters": {"name": "test.txt"},
|
||||||
|
"top_k": 5,
|
||||||
|
"scale_score": False,
|
||||||
|
"return_embedding": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_from_dict(self, retriever_cls):
|
||||||
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||||
data = {
|
data = {
|
||||||
"type": "MemoryRetriever",
|
"type": retriever_cls.__name__,
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||||
"filters": {"name": "test.txt"},
|
"filters": {"name": "test.txt"},
|
||||||
"top_k": 5,
|
"top_k": 5,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
component = MemoryRetriever.from_dict(data)
|
component = retriever_cls.from_dict(data)
|
||||||
assert isinstance(component.document_store, MemoryDocumentStore)
|
assert isinstance(component.document_store, MemoryDocumentStore)
|
||||||
assert component.filters == {"name": "test.txt"}
|
assert component.filters == {"name": "test.txt"}
|
||||||
assert component.top_k == 5
|
assert component.top_k == 5
|
||||||
assert component.scale_score
|
assert component.scale_score
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict_without_docstore(self):
|
def test_from_dict_without_docstore(self, retriever_cls):
|
||||||
data = {"type": "MemoryRetriever", "init_parameters": {}}
|
data = {"type": retriever_cls.__name__, "init_parameters": {}}
|
||||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||||
MemoryRetriever.from_dict(data)
|
retriever_cls.from_dict(data)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict_without_docstore_type(self):
|
def test_from_dict_without_docstore_type(self, retriever_cls):
|
||||||
data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
data = {"type": retriever_cls.__name__, "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||||
MemoryRetriever.from_dict(data)
|
retriever_cls.from_dict(data)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict_nonexisting_docstore(self):
|
def test_from_dict_nonexisting_docstore(self, retriever_cls):
|
||||||
data = {
|
data = {
|
||||||
"type": "MemoryRetriever",
|
"type": retriever_cls.__name__,
|
||||||
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
|
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
|
||||||
}
|
}
|
||||||
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
|
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
|
||||||
MemoryRetriever.from_dict(data)
|
retriever_cls.from_dict(data)
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_valid_run(self, mock_docs):
|
def test_bm25_retriever_valid_run(self, mock_docs):
|
||||||
top_k = 5
|
top_k = 5
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
|
|
||||||
retriever = MemoryRetriever(ds, top_k=top_k)
|
retriever = MemoryBM25Retriever(ds, top_k=top_k)
|
||||||
result = retriever.run(queries=["PHP", "Java"])
|
result = retriever.run(queries=["PHP", "Java"])
|
||||||
|
|
||||||
assert "documents" in result
|
assert "documents" in result
|
||||||
@ -133,10 +182,32 @@ class TestMemoryRetriever:
|
|||||||
assert result["documents"][1][0].content == "Java is a popular programming language"
|
assert result["documents"][1][0].content == "Java is a popular programming language"
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_invalid_run_wrong_store_type(self):
|
def test_embedding_retriever_valid_run(self):
|
||||||
|
top_k = 3
|
||||||
|
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||||
|
docs = [
|
||||||
|
Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||||
|
Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||||
|
Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
|
||||||
|
]
|
||||||
|
ds.write_documents(docs)
|
||||||
|
|
||||||
|
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||||
|
result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
|
||||||
|
|
||||||
|
assert "documents" in result
|
||||||
|
assert len(result["documents"]) == 2
|
||||||
|
assert len(result["documents"][0]) == top_k
|
||||||
|
assert len(result["documents"][1]) == top_k
|
||||||
|
assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
|
||||||
|
assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_invalid_run_wrong_store_type(self, retriever_cls):
|
||||||
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
|
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
|
||||||
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
|
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
|
||||||
MemoryRetriever(SomeOtherDocumentStore())
|
retriever_cls(SomeOtherDocumentStore())
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -146,10 +217,10 @@ class TestMemoryRetriever:
|
|||||||
("Java", "Java is a popular programming language"),
|
("Java", "Java is a popular programming language"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
def test_run_bm25_retriever_with_pipeline(self, mock_docs, query: str, query_result: str):
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
retriever = MemoryRetriever(ds)
|
retriever = MemoryBM25Retriever(ds)
|
||||||
|
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.add_component("retriever", retriever)
|
pipeline.add_component("retriever", retriever)
|
||||||
@ -161,6 +232,38 @@ class TestMemoryRetriever:
|
|||||||
assert results_docs
|
assert results_docs
|
||||||
assert results_docs[0][0].content == query_result
|
assert results_docs[0][0].content == query_result
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_run_embedding_retriever_with_pipeline(self):
|
||||||
|
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||||
|
top_k = 2
|
||||||
|
docs = [
|
||||||
|
Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||||
|
Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||||
|
Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
|
||||||
|
]
|
||||||
|
ds.write_documents(docs)
|
||||||
|
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||||
|
|
||||||
|
pipeline = Pipeline()
|
||||||
|
pipeline.add_component("retriever", retriever)
|
||||||
|
result: Dict[str, Any] = pipeline.run(
|
||||||
|
data={
|
||||||
|
"retriever": {
|
||||||
|
"queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
|
||||||
|
"return_embedding": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result
|
||||||
|
assert "retriever" in result
|
||||||
|
results_docs = result["retriever"]["documents"]
|
||||||
|
assert results_docs
|
||||||
|
assert len(results_docs[0]) == top_k
|
||||||
|
assert len(results_docs[1]) == top_k
|
||||||
|
assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
|
||||||
|
assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"query, query_result, top_k",
|
"query, query_result, top_k",
|
||||||
@ -170,10 +273,10 @@ class TestMemoryRetriever:
|
|||||||
("Ruby", "Ruby is a popular programming language", 3),
|
("Ruby", "Ruby is a popular programming language", 3),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
def test_run_bm25_retriever_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
retriever = MemoryRetriever(ds)
|
retriever = MemoryBM25Retriever(ds)
|
||||||
|
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.add_component("retriever", retriever)
|
pipeline.add_component("retriever", retriever)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user