mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 23:48:53 +00:00
Rework MemoryRetriever (#5582)
* Remove DocumentStoreAwareMixin from MemoryRetriever * Add release notes * Update an article --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
011baf492f
commit
4bc68cbc2f
@ -1,29 +1,39 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore, DocumentStoreAwareMixin
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
|
||||
@component
|
||||
class MemoryRetriever(DocumentStoreAwareMixin):
|
||||
class MemoryRetriever:
|
||||
"""
|
||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||
|
||||
Needs to be connected to a MemoryDocumentStore to run.
|
||||
"""
|
||||
|
||||
supported_document_stores = [MemoryDocumentStore]
|
||||
|
||||
def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
document_store: MemoryDocumentStore,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
scale_score: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a MemoryRetriever component.
|
||||
|
||||
:param document_store: An instance of MemoryDocumentStore.
|
||||
:param filters: A dictionary with filters to narrow down the search space (default is None).
|
||||
:param top_k: The maximum number of documents to retrieve (default is 10).
|
||||
:param scale_score: Whether to scale the BM25 score or not (default is True).
|
||||
|
||||
:raises ValueError: If the specified top_k is not > 0.
|
||||
"""
|
||||
if not isinstance(document_store, MemoryDocumentStore):
|
||||
raise ValueError("document_store must be an instance of MemoryDocumentStore")
|
||||
|
||||
self.document_store = document_store
|
||||
|
||||
if top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
|
||||
@ -51,12 +61,6 @@ class MemoryRetriever(DocumentStoreAwareMixin):
|
||||
|
||||
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
||||
"""
|
||||
self.document_store: MemoryDocumentStore
|
||||
if not self.document_store:
|
||||
raise ValueError(
|
||||
"MemoryRetriever needs a DocumentStore to run: set the DocumentStore instance to the self.document_store attribute"
|
||||
)
|
||||
|
||||
if filters is None:
|
||||
filters = self.filters
|
||||
if top_k is None:
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- Rework `MemoryRetriever` to remove `DocumentStoreAwareMixin`.
|
||||
Now we require a `MemoryDocumentStore` when initialisating the retriever.
|
||||
@ -1,16 +1,15 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
from test.preview.components.base import BaseTestComponent
|
||||
|
||||
from haystack.preview.document_stores.protocols import DuplicatePolicy
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_docs():
|
||||
@ -24,24 +23,26 @@ def mock_docs():
|
||||
|
||||
|
||||
class TestMemoryRetriever(BaseTestComponent):
|
||||
@pytest.mark.unit
|
||||
def test_save_load(self, tmp_path):
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(), tmp_path)
|
||||
# TODO: We're going to rework these tests when we'll remove BaseTestComponent.
|
||||
# We also need to implement `to_dict` and `from_dict` to test this properly.
|
||||
# @pytest.mark.unit
|
||||
# def test_save_load(self, tmp_path):
|
||||
# self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(MemoryDocumentStore()), tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_load_with_parameters(self, tmp_path):
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path)
|
||||
# @pytest.mark.unit
|
||||
# def test_save_load_with_parameters(self, tmp_path):
|
||||
# self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_default(self):
|
||||
retriever = MemoryRetriever()
|
||||
retriever = MemoryRetriever(MemoryDocumentStore())
|
||||
assert retriever.filters is None
|
||||
assert retriever.top_k == 10
|
||||
assert retriever.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_parameters(self):
|
||||
retriever = MemoryRetriever(filters={"name": "test.txt"}, top_k=5, scale_score=False)
|
||||
retriever = MemoryRetriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
|
||||
assert retriever.filters == {"name": "test.txt"}
|
||||
assert retriever.top_k == 5
|
||||
assert not retriever.scale_score
|
||||
@ -49,7 +50,7 @@ class TestMemoryRetriever(BaseTestComponent):
|
||||
@pytest.mark.unit
|
||||
def test_init_with_invalid_top_k_parameter(self):
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||
MemoryRetriever(top_k=-2, scale_score=False)
|
||||
MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_run(self, mock_docs):
|
||||
@ -57,8 +58,7 @@ class TestMemoryRetriever(BaseTestComponent):
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
|
||||
retriever = MemoryRetriever(top_k=top_k)
|
||||
retriever.document_store = ds
|
||||
retriever = MemoryRetriever(ds, top_k=top_k)
|
||||
result = retriever.run(queries=["PHP", "Java"])
|
||||
|
||||
assert "documents" in result
|
||||
@ -68,44 +68,11 @@ class TestMemoryRetriever(BaseTestComponent):
|
||||
assert result["documents"][0][0].content == "PHP is a popular programming language"
|
||||
assert result["documents"][1][0].content == "Java is a popular programming language"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_no_store(self):
|
||||
retriever = MemoryRetriever()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MemoryRetriever needs a DocumentStore to run: set the DocumentStore instance to the self.document_store attribute",
|
||||
):
|
||||
retriever.run(queries=["test"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_not_a_store(self):
|
||||
class MockStore:
|
||||
...
|
||||
|
||||
retriever = MemoryRetriever()
|
||||
with pytest.raises(ValueError, match="'MockStore' is not decorate with @document_store."):
|
||||
retriever.document_store = MockStore()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_wrong_store_type(self):
|
||||
class MockStore:
|
||||
def count_documents(self) -> int:
|
||||
return 0
|
||||
|
||||
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||
return []
|
||||
|
||||
def write_documents(
|
||||
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
def delete_documents(self, document_ids: List[str]) -> None:
|
||||
return None
|
||||
|
||||
retriever = MemoryRetriever()
|
||||
with pytest.raises(ValueError, match="'MockStore' is not decorate with @document_store."):
|
||||
retriever.document_store = MockStore()
|
||||
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
|
||||
MemoryRetriever(SomeOtherDocumentStore())
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
@ -118,11 +85,10 @@ class TestMemoryRetriever(BaseTestComponent):
|
||||
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
retriever = MemoryRetriever()
|
||||
retriever = MemoryRetriever(ds)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_store("memory", ds)
|
||||
pipeline.add_component("retriever", retriever, document_store="memory")
|
||||
pipeline.add_component("retriever", retriever)
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": {"queries": [query]}})
|
||||
|
||||
assert result
|
||||
@ -143,11 +109,10 @@ class TestMemoryRetriever(BaseTestComponent):
|
||||
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
retriever = MemoryRetriever()
|
||||
retriever = MemoryRetriever(ds)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_store("memory", ds)
|
||||
pipeline.add_component("retriever", retriever, document_store="memory")
|
||||
pipeline.add_component("retriever", retriever)
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": {"queries": [query], "top_k": top_k}})
|
||||
|
||||
assert result
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user