2023-09-19 19:21:49 +02:00
|
|
|
from typing import Dict, Any
|
|
|
|
|
|
|
|
import pytest
|
2023-10-31 12:44:04 +01:00
|
|
|
import numpy as np
|
2023-09-19 19:21:49 +02:00
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack import Pipeline, DeserializationError
|
|
|
|
from haystack.testing.factory import document_store_class
|
|
|
|
from haystack.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever
|
|
|
|
from haystack.dataclasses import Document
|
|
|
|
from haystack.document_stores import InMemoryDocumentStore
|
2023-09-19 19:21:49 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TestMemoryEmbeddingRetriever:
    """Tests for ``InMemoryEmbeddingRetriever``: construction, (de)serialization and retrieval."""

    def test_init_default(self):
        """A retriever built from just a document store exposes the default settings."""
        default_retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())

        assert default_retriever.filters is None
        assert default_retriever.top_k == 10
        assert default_retriever.scale_score is False

    def test_init_with_parameters(self):
        """Explicit ``filters``, ``top_k`` and ``scale_score`` are stored on the instance."""
        configured = InMemoryEmbeddingRetriever(
            InMemoryDocumentStore(),
            filters={"name": "test.txt"},
            top_k=5,
            scale_score=True,
        )

        assert configured.filters == {"name": "test.txt"}
        assert configured.top_k == 5
        assert configured.scale_score

    def test_init_with_invalid_top_k_parameter(self):
        """A negative ``top_k`` is rejected with ``ValueError`` at construction time."""
        with pytest.raises(ValueError):
            InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2)

    def test_to_dict(self):
        """Serializing a default retriever emits its type, the store payload and the defaults."""
        fake_store_cls = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        fake_store = fake_store_cls()

        def stubbed_to_dict():
            # Canned payload so the assertion does not depend on the real store serialization.
            return {"type": "test_module.MyFakeStore", "init_parameters": {}}

        fake_store.to_dict = stubbed_to_dict
        retriever = InMemoryEmbeddingRetriever(document_store=fake_store)

        serialized = retriever.to_dict()

        expected = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {
                "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
                "filters": None,
                "top_k": 10,
                "scale_score": False,
                "return_embedding": False,
            },
        }
        assert serialized == expected

    def test_to_dict_with_custom_init_parameters(self):
        """Serializing a customized retriever reflects every non-default init parameter."""
        fake_store_cls = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        fake_store = fake_store_cls()

        def stubbed_to_dict():
            # Canned payload so the assertion does not depend on the real store serialization.
            return {"type": "test_module.MyFakeStore", "init_parameters": {}}

        fake_store.to_dict = stubbed_to_dict
        retriever = InMemoryEmbeddingRetriever(
            document_store=fake_store,
            filters={"name": "test.txt"},
            top_k=5,
            scale_score=True,
            return_embedding=True,
        )

        serialized = retriever.to_dict()

        expected = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {
                "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
                "filters": {"name": "test.txt"},
                "top_k": 5,
                "scale_score": True,
                "return_embedding": True,
            },
        }
        assert serialized == expected

    def test_from_dict(self):
        """Deserialization rebuilds the store and keeps given parameters; ``scale_score`` stays False."""
        store_payload = {
            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
            "init_parameters": {},
        }
        payload = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {
                "document_store": store_payload,
                "filters": {"name": "test.txt"},
                "top_k": 5,
            },
        }

        restored = InMemoryEmbeddingRetriever.from_dict(payload)

        assert isinstance(restored.document_store, InMemoryDocumentStore)
        assert restored.filters == {"name": "test.txt"}
        assert restored.top_k == 5
        assert restored.scale_score is False

    def test_from_dict_without_docstore(self):
        """A payload with no ``document_store`` entry raises ``DeserializationError``."""
        payload = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {},
        }

        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            InMemoryEmbeddingRetriever.from_dict(payload)

    def test_from_dict_without_docstore_type(self):
        """A ``document_store`` entry lacking a ``type`` key raises ``DeserializationError``."""
        payload = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {"document_store": {"init_parameters": {}}},
        }

        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(payload)

    def test_from_dict_nonexisting_docstore(self):
        """The ``Nonexisting.Docstore`` store type raises ``DeserializationError``."""
        payload = {
            "type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }

        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(payload)

    def test_valid_run(self):
        """Running the retriever directly returns ``top_k`` documents with the expected top hit."""
        top_k = 3
        store = InMemoryDocumentStore(embedding_similarity_function="cosine")
        store.write_documents(
            [
                Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
                Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
                Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
            ]
        )

        retriever = InMemoryEmbeddingRetriever(store, top_k=top_k)
        output = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)

        assert "documents" in output
        assert len(output["documents"]) == top_k
        # The uniform query vector points the same way as the all-ones embedding.
        assert np.array_equal(output["documents"][0].embedding, [1.0, 1.0, 1.0, 1.0])

    def test_invalid_run_wrong_store_type(self):
        """Passing a store that is not an ``InMemoryDocumentStore`` raises ``ValueError``."""
        other_store_cls = document_store_class("SomeOtherDocumentStore")

        with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
            InMemoryEmbeddingRetriever(other_store_cls())

    @pytest.mark.integration
    def test_run_with_pipeline(self):
        """Running the retriever inside a ``Pipeline`` yields ``top_k`` documents under its name."""
        store = InMemoryDocumentStore(embedding_similarity_function="cosine")
        top_k = 2
        store.write_documents(
            [
                Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
                Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
                Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
            ]
        )
        retriever = InMemoryEmbeddingRetriever(store, top_k=top_k)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        retriever_inputs = {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}
        output: Dict[str, Any] = pipeline.run(data={"retriever": retriever_inputs})

        assert output
        assert "retriever" in output
        retrieved = output["retriever"]["documents"]
        assert retrieved
        assert len(retrieved) == top_k
        # The uniform query vector points the same way as the all-ones embedding.
        assert np.array_equal(retrieved[0].embedding, [1.0, 1.0, 1.0, 1.0])
|