2023-08-18 16:33:35 +02:00
|
|
|
from typing import Dict, Any
|
2023-06-27 17:42:23 +02:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack.preview import Pipeline
|
2023-08-18 16:33:35 +02:00
|
|
|
from haystack.preview.testing.factory import document_store_class
|
2023-06-27 17:42:23 +02:00
|
|
|
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
|
|
|
from haystack.preview.dataclasses import Document
|
2023-07-26 09:32:23 +02:00
|
|
|
from haystack.preview.document_stores import MemoryDocumentStore
|
2023-06-27 17:42:23 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture()
def mock_docs():
    """Provide five one-sentence documents, one per programming language, for the retrieval tests."""
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    # Build each Document through from_dict, exactly as a serialized payload would be loaded.
    return [Document.from_dict({"content": text}) for text in contents]
|
|
|
|
|
|
|
|
|
2023-08-23 17:03:37 +02:00
|
|
|
class TestMemoryRetriever:
    """Tests for ``MemoryRetriever`` used standalone and inside a ``Pipeline``."""

    @staticmethod
    def _retriever_over(docs, **retriever_kwargs):
        """Return a ``MemoryRetriever`` over a fresh ``MemoryDocumentStore`` pre-filled with ``docs``.

        ``retriever_kwargs`` are forwarded verbatim to the ``MemoryRetriever`` constructor.
        """
        store = MemoryDocumentStore()
        store.write_documents(docs)
        return MemoryRetriever(store, **retriever_kwargs)

    @staticmethod
    def _retrieved_via_pipeline(retriever, retriever_inputs: Dict[str, Any]):
        """Run ``retriever`` as the sole ``Pipeline`` component and return its ``documents`` output.

        Asserts along the way that the pipeline result is non-empty, contains a ``retriever``
        entry, and that the returned documents are non-empty.
        """
        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        outputs: Dict[str, Any] = pipeline.run(data={"retriever": retriever_inputs})

        assert outputs
        assert "retriever" in outputs
        retrieved = outputs["retriever"]["documents"]
        assert retrieved
        return retrieved

    @pytest.mark.unit
    def test_init_default(self):
        """A retriever built with only a store uses the default filters, top_k and scaling."""
        default_retriever = MemoryRetriever(MemoryDocumentStore())

        assert default_retriever.filters is None
        assert default_retriever.top_k == 10
        assert default_retriever.scale_score

    @pytest.mark.unit
    def test_init_with_parameters(self):
        """Explicit constructor arguments are stored on the retriever as given."""
        configured = MemoryRetriever(
            MemoryDocumentStore(),
            filters={"name": "test.txt"},
            top_k=5,
            scale_score=False,
        )

        assert configured.filters == {"name": "test.txt"}
        assert configured.top_k == 5
        assert not configured.scale_score

    @pytest.mark.unit
    def test_init_with_invalid_top_k_parameter(self):
        """A non-positive ``top_k`` is rejected at construction time."""
        with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
            MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)

    @pytest.mark.unit
    def test_valid_run(self, mock_docs):
        """``run`` returns one ranked list per query, each of length ``top_k``, best match first."""
        top_k = 5
        retriever = self._retriever_over(mock_docs, top_k=top_k)

        result = retriever.run(queries=["PHP", "Java"])

        assert "documents" in result
        per_query = result["documents"]
        assert len(per_query) == 2
        assert [len(docs) for docs in per_query] == [top_k, top_k]
        assert [docs[0].content for docs in per_query] == [
            "PHP is a popular programming language",
            "Java is a popular programming language",
        ]

    @pytest.mark.unit
    def test_invalid_run_wrong_store_type(self):
        """Passing a document store that is not a ``MemoryDocumentStore`` raises ``ValueError``."""
        foreign_store_cls = document_store_class("SomeOtherDocumentStore")

        with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
            MemoryRetriever(foreign_store_cls())

    @pytest.mark.integration
    @pytest.mark.parametrize(
        ("query", "query_result"),
        [
            ("Javascript", "Javascript is a popular programming language"),
            ("Java", "Java is a popular programming language"),
        ],
    )
    def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
        """With default settings, the document matching ``query`` ranks first in a pipeline run."""
        retriever = self._retriever_over(mock_docs)

        retrieved = self._retrieved_via_pipeline(retriever, {"queries": [query]})

        assert retrieved[0][0].content == query_result

    @pytest.mark.integration
    @pytest.mark.parametrize(
        ("query", "query_result", "top_k"),
        [
            ("Javascript", "Javascript is a popular programming language", 1),
            ("Java", "Java is a popular programming language", 2),
            ("Ruby", "Ruby is a popular programming language", 3),
        ],
    )
    def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
        """A run-time ``top_k`` caps the result size, and the matching document still ranks first."""
        retriever = self._retriever_over(mock_docs)

        retrieved = self._retrieved_via_pipeline(retriever, {"queries": [query], "top_k": top_k})

        first_query_docs = retrieved[0]
        assert len(first_query_docs) == top_k
        assert first_query_docs[0].content == query_result