mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 07:29:06 +00:00
refactor!: rename MemoryDocumentStore and related Retrievers (#6076)
* rename doc store and retrievers * release note * fix patch
This commit is contained in:
parent
ec9f898cd6
commit
4e4af99a5e
@ -1,15 +1,15 @@
|
||||
import json
|
||||
|
||||
from haystack.preview import Pipeline, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.components.retrievers import MemoryBM25Retriever
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore
|
||||
from haystack.preview.components.retrievers import InMemoryBM25Retriever
|
||||
from haystack.preview.components.readers import ExtractiveReader
|
||||
|
||||
|
||||
def test_extractive_qa_pipeline(tmp_path):
|
||||
# Create the pipeline
|
||||
qa_pipeline = Pipeline()
|
||||
qa_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
|
||||
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.connect("retriever", "reader")
|
||||
|
||||
|
||||
@ -3,9 +3,9 @@ import json
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore
|
||||
from haystack.preview.components.writers import DocumentWriter
|
||||
from haystack.preview.components.retrievers import MemoryBM25Retriever, MemoryEmbeddingRetriever
|
||||
from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
|
||||
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
|
||||
from haystack.preview.components.generators.openai.gpt import GPTGenerator
|
||||
from haystack.preview.components.builders.answer_builder import AnswerBuilder
|
||||
@ -28,7 +28,7 @@ def test_bm25_rag_pipeline(tmp_path):
|
||||
\nAnswer:
|
||||
"""
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
|
||||
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
@ -99,7 +99,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
|
||||
name="text_embedder",
|
||||
)
|
||||
rag_pipeline.add_component(
|
||||
instance=MemoryEmbeddingRetriever(document_store=MemoryDocumentStore()), name="retriever"
|
||||
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
|
||||
)
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
|
||||
@ -19,7 +19,7 @@ class DocumentCleaner:
|
||||
Example usage in an indexing pipeline:
|
||||
|
||||
```python
|
||||
document_store = MemoryDocumentStore()
|
||||
document_store = InMemoryDocumentStore()
|
||||
p = Pipeline()
|
||||
p.add_component(instance=TextFileToDocument(), name="text_file_converter")
|
||||
p.add_component(instance=DocumentCleaner(), name="cleaner")
|
||||
|
||||
@ -21,10 +21,10 @@ class TextLanguageClassifier:
|
||||
Example usage in a retrieval pipeline that passes only English language queries to the retriever:
|
||||
|
||||
```python
|
||||
document_store = MemoryDocumentStore()
|
||||
document_store = InMemoryDocumentStore()
|
||||
p = Pipeline()
|
||||
p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
|
||||
p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
|
||||
p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
|
||||
p.connect("text_language_classifier.en", "retriever.query")
|
||||
p.run({"text_language_classifier": {"text": "What's your query?"}})
|
||||
```
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from haystack.preview.components.retrievers.memory_bm25_retriever import MemoryBM25Retriever
|
||||
from haystack.preview.components.retrievers.memory_embedding_retriever import MemoryEmbeddingRetriever
|
||||
from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever
|
||||
from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever
|
||||
|
||||
__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]
|
||||
__all__ = ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"]
|
||||
|
||||
@ -1,36 +1,36 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
|
||||
from haystack.preview.document_stores import MemoryDocumentStore, document_store
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore, document_store
|
||||
|
||||
|
||||
@component
|
||||
class MemoryBM25Retriever:
|
||||
class InMemoryBM25Retriever:
|
||||
"""
|
||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||
A component for retrieving documents from a InMemoryDocumentStore using the BM25 algorithm.
|
||||
|
||||
Needs to be connected to a MemoryDocumentStore to run.
|
||||
Needs to be connected to a InMemoryDocumentStore to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
document_store: MemoryDocumentStore,
|
||||
document_store: InMemoryDocumentStore,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
scale_score: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a MemoryBM25Retriever component.
|
||||
Create a InMemoryBM25Retriever component.
|
||||
|
||||
:param document_store: An instance of MemoryDocumentStore.
|
||||
:param document_store: An instance of InMemoryDocumentStore.
|
||||
:param filters: A dictionary with filters to narrow down the search space. Default is None.
|
||||
:param top_k: The maximum number of documents to retrieve. Default is 10.
|
||||
:param scale_score: Whether to scale the BM25 score or not. Default is True.
|
||||
|
||||
:raises ValueError: If the specified top_k is not > 0.
|
||||
"""
|
||||
if not isinstance(document_store, MemoryDocumentStore):
|
||||
raise ValueError("document_store must be an instance of MemoryDocumentStore")
|
||||
if not isinstance(document_store, InMemoryDocumentStore):
|
||||
raise ValueError("document_store must be an instance of InMemoryDocumentStore")
|
||||
|
||||
self.document_store = document_store
|
||||
|
||||
@ -57,7 +57,7 @@ class MemoryBM25Retriever:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "InMemoryBM25Retriever":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
@ -83,7 +83,7 @@ class MemoryBM25Retriever:
|
||||
scale_score: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Run the MemoryBM25Retriever on the given input data.
|
||||
Run the InMemoryBM25Retriever on the given input data.
|
||||
|
||||
:param query: The query string for the retriever.
|
||||
:param filters: A dictionary with filters to narrow down the search space.
|
||||
@ -91,7 +91,7 @@ class MemoryBM25Retriever:
|
||||
:param scale_score: Whether to scale the BM25 scores or not.
|
||||
:return: The retrieved documents.
|
||||
|
||||
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
||||
:raises ValueError: If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
|
||||
"""
|
||||
if filters is None:
|
||||
filters = self.filters
|
||||
@ -1,29 +1,29 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
|
||||
from haystack.preview.document_stores import MemoryDocumentStore, document_store
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore, document_store
|
||||
|
||||
|
||||
@component
|
||||
class MemoryEmbeddingRetriever:
|
||||
class InMemoryEmbeddingRetriever:
|
||||
"""
|
||||
A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.
|
||||
A component for retrieving documents from a InMemoryDocumentStore using a vector similarity metric.
|
||||
|
||||
Needs to be connected to a MemoryDocumentStore to run.
|
||||
Needs to be connected to a InMemoryDocumentStore to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
document_store: MemoryDocumentStore,
|
||||
document_store: InMemoryDocumentStore,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
scale_score: bool = True,
|
||||
return_embedding: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a MemoryEmbeddingRetriever component.
|
||||
Create a InMemoryEmbeddingRetriever component.
|
||||
|
||||
:param document_store: An instance of MemoryDocumentStore.
|
||||
:param document_store: An instance of InMemoryDocumentStore.
|
||||
:param filters: A dictionary with filters to narrow down the search space. Default is None.
|
||||
:param top_k: The maximum number of documents to retrieve. Default is 10.
|
||||
:param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
|
||||
@ -31,8 +31,8 @@ class MemoryEmbeddingRetriever:
|
||||
|
||||
:raises ValueError: If the specified top_k is not > 0.
|
||||
"""
|
||||
if not isinstance(document_store, MemoryDocumentStore):
|
||||
raise ValueError("document_store must be an instance of MemoryDocumentStore")
|
||||
if not isinstance(document_store, InMemoryDocumentStore):
|
||||
raise ValueError("document_store must be an instance of InMemoryDocumentStore")
|
||||
|
||||
self.document_store = document_store
|
||||
|
||||
@ -65,7 +65,7 @@ class MemoryEmbeddingRetriever:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryEmbeddingRetriever":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
@ -92,7 +92,7 @@ class MemoryEmbeddingRetriever:
|
||||
return_embedding: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Run the MemoryEmbeddingRetriever on the given input data.
|
||||
Run the InMemoryEmbeddingRetriever on the given input data.
|
||||
|
||||
:param query_embedding: Embedding of the query.
|
||||
:param filters: A dictionary with filters to narrow down the search space.
|
||||
@ -101,7 +101,7 @@ class MemoryEmbeddingRetriever:
|
||||
:param return_embedding: Whether to return the embedding of the retrieved Documents.
|
||||
:return: The retrieved documents.
|
||||
|
||||
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
|
||||
:raises ValueError: If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
|
||||
"""
|
||||
if filters is None:
|
||||
filters = self.filters
|
||||
@ -1,12 +1,12 @@
|
||||
from haystack.preview.document_stores.protocols import DocumentStore, DuplicatePolicy
|
||||
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
||||
from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore
|
||||
from haystack.preview.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError
|
||||
from haystack.preview.document_stores.decorator import document_store
|
||||
|
||||
__all__ = [
|
||||
"DocumentStore",
|
||||
"DuplicatePolicy",
|
||||
"MemoryDocumentStore",
|
||||
"InMemoryDocumentStore",
|
||||
"DocumentStoreError",
|
||||
"DuplicateDocumentError",
|
||||
"MissingDocumentError",
|
||||
|
||||
3
haystack/preview/document_stores/in_memory/__init__.py
Normal file
3
haystack/preview/document_stores/in_memory/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore
|
||||
|
||||
__all__ = ["InMemoryDocumentStore"]
|
||||
@ -28,7 +28,7 @@ DOT_PRODUCT_SCALING_FACTOR = 100
|
||||
|
||||
|
||||
@document_store
|
||||
class MemoryDocumentStore:
|
||||
class InMemoryDocumentStore:
|
||||
"""
|
||||
Stores data in-memory. It's ephemeral and cannot be saved to disk.
|
||||
"""
|
||||
@ -199,7 +199,7 @@ class MemoryDocumentStore:
|
||||
:param object_ids: The object_ids to delete.
|
||||
"""
|
||||
for doc_id in document_ids:
|
||||
if not doc_id in self.storage.keys():
|
||||
if doc_id not in self.storage.keys():
|
||||
raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.")
|
||||
del self.storage[doc_id]
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
||||
|
||||
__all__ = ["MemoryDocumentStore"]
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Rename `MemoryDocumentStore` to `InMemoryDocumentStore`
|
||||
Rename `MemoryBM25Retriever` to `InMemoryBM25Retriever`
|
||||
Rename `MemoryEmbeddingRetriever` to `InMemoryEmbeddingRetriever`
|
||||
@ -1,10 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document, DeserializationError
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.document_stores.memory import MemoryDocumentStore
|
||||
from haystack.preview.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.preview.components.caching.url_cache_checker import UrlCacheChecker
|
||||
|
||||
|
||||
@ -72,7 +70,7 @@ class TestUrlCacheChecker:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
documents = [
|
||||
Document(text="doc1", metadata={"url": "https://example.com/1"}),
|
||||
Document(text="doc2", metadata={"url": "https://example.com/2"}),
|
||||
|
||||
@ -4,9 +4,9 @@ import pytest
|
||||
|
||||
from haystack.preview import Pipeline, DeserializationError
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.components.retrievers.memory_bm25_retriever import MemoryBM25Retriever
|
||||
from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -23,14 +23,16 @@ def mock_docs():
|
||||
class TestMemoryBM25Retriever:
|
||||
@pytest.mark.unit
|
||||
def test_init_default(self):
|
||||
retriever = MemoryBM25Retriever(MemoryDocumentStore())
|
||||
retriever = InMemoryBM25Retriever(InMemoryDocumentStore())
|
||||
assert retriever.filters is None
|
||||
assert retriever.top_k == 10
|
||||
assert retriever.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_parameters(self):
|
||||
retriever = MemoryBM25Retriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
|
||||
retriever = InMemoryBM25Retriever(
|
||||
InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||
)
|
||||
assert retriever.filters == {"name": "test.txt"}
|
||||
assert retriever.top_k == 5
|
||||
assert not retriever.scale_score
|
||||
@ -38,18 +40,18 @@ class TestMemoryBM25Retriever:
|
||||
@pytest.mark.unit
|
||||
def test_init_with_invalid_top_k_parameter(self):
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||
MemoryBM25Retriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
InMemoryBM25Retriever(InMemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryBM25Retriever(document_store=document_store)
|
||||
component = InMemoryBM25Retriever(document_store=document_store)
|
||||
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryBM25Retriever",
|
||||
"type": "InMemoryBM25Retriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": None,
|
||||
@ -60,15 +62,15 @@ class TestMemoryBM25Retriever:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryBM25Retriever(
|
||||
component = InMemoryBM25Retriever(
|
||||
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryBM25Retriever",
|
||||
"type": "InMemoryBM25Retriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
@ -79,49 +81,49 @@ class TestMemoryBM25Retriever:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
data = {
|
||||
"type": "MemoryBM25Retriever",
|
||||
"type": "InMemoryBM25Retriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
},
|
||||
}
|
||||
component = MemoryBM25Retriever.from_dict(data)
|
||||
assert isinstance(component.document_store, MemoryDocumentStore)
|
||||
component = InMemoryBM25Retriever.from_dict(data)
|
||||
assert isinstance(component.document_store, InMemoryDocumentStore)
|
||||
assert component.filters == {"name": "test.txt"}
|
||||
assert component.top_k == 5
|
||||
assert component.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "MemoryBM25Retriever", "init_parameters": {}}
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
MemoryBM25Retriever.from_dict(data)
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "MemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||
MemoryBM25Retriever.from_dict(data)
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "MemoryBM25Retriever",
|
||||
"type": "InMemoryBM25Retriever",
|
||||
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
|
||||
MemoryBM25Retriever.from_dict(data)
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_retriever_valid_run(self, mock_docs):
|
||||
top_k = 5
|
||||
ds = MemoryDocumentStore()
|
||||
ds = InMemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
|
||||
retriever = MemoryBM25Retriever(ds, top_k=top_k)
|
||||
retriever = InMemoryBM25Retriever(ds, top_k=top_k)
|
||||
result = retriever.run(query="PHP")
|
||||
|
||||
assert "documents" in result
|
||||
@ -131,8 +133,8 @@ class TestMemoryBM25Retriever:
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_wrong_store_type(self):
|
||||
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
|
||||
MemoryBM25Retriever(SomeOtherDocumentStore())
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
|
||||
InMemoryBM25Retriever(SomeOtherDocumentStore())
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
@ -143,9 +145,9 @@ class TestMemoryBM25Retriever:
|
||||
],
|
||||
)
|
||||
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
||||
ds = MemoryDocumentStore()
|
||||
ds = InMemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
retriever = MemoryBM25Retriever(ds)
|
||||
retriever = InMemoryBM25Retriever(ds)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("retriever", retriever)
|
||||
@ -167,9 +169,9 @@ class TestMemoryBM25Retriever:
|
||||
],
|
||||
)
|
||||
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
||||
ds = MemoryDocumentStore()
|
||||
ds = InMemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
retriever = MemoryBM25Retriever(ds)
|
||||
retriever = InMemoryBM25Retriever(ds)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("retriever", retriever)
|
||||
@ -4,23 +4,23 @@ import pytest
|
||||
|
||||
from haystack.preview import Pipeline, DeserializationError
|
||||
from haystack.preview.testing.factory import document_store_class
|
||||
from haystack.preview.components.retrievers.memory_embedding_retriever import MemoryEmbeddingRetriever
|
||||
from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore
|
||||
|
||||
|
||||
class TestMemoryEmbeddingRetriever:
|
||||
@pytest.mark.unit
|
||||
def test_init_default(self):
|
||||
retriever = MemoryEmbeddingRetriever(MemoryDocumentStore())
|
||||
retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
|
||||
assert retriever.filters is None
|
||||
assert retriever.top_k == 10
|
||||
assert retriever.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_parameters(self):
|
||||
retriever = MemoryEmbeddingRetriever(
|
||||
MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||
retriever = InMemoryEmbeddingRetriever(
|
||||
InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
|
||||
)
|
||||
assert retriever.filters == {"name": "test.txt"}
|
||||
assert retriever.top_k == 5
|
||||
@ -29,18 +29,18 @@ class TestMemoryEmbeddingRetriever:
|
||||
@pytest.mark.unit
|
||||
def test_init_with_invalid_top_k_parameter(self):
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||
MemoryEmbeddingRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2, scale_score=False)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryEmbeddingRetriever(document_store=document_store)
|
||||
component = InMemoryEmbeddingRetriever(document_store=document_store)
|
||||
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryEmbeddingRetriever",
|
||||
"type": "InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": None,
|
||||
@ -52,10 +52,10 @@ class TestMemoryEmbeddingRetriever:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
document_store = MyFakeStore()
|
||||
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
|
||||
component = MemoryEmbeddingRetriever(
|
||||
component = InMemoryEmbeddingRetriever(
|
||||
document_store=document_store,
|
||||
filters={"name": "test.txt"},
|
||||
top_k=5,
|
||||
@ -64,7 +64,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryEmbeddingRetriever",
|
||||
"type": "InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
@ -76,46 +76,46 @@ class TestMemoryEmbeddingRetriever:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict(self):
|
||||
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
|
||||
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
|
||||
data = {
|
||||
"type": "MemoryEmbeddingRetriever",
|
||||
"type": "InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {
|
||||
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
},
|
||||
}
|
||||
component = MemoryEmbeddingRetriever.from_dict(data)
|
||||
assert isinstance(component.document_store, MemoryDocumentStore)
|
||||
component = InMemoryEmbeddingRetriever.from_dict(data)
|
||||
assert isinstance(component.document_store, InMemoryDocumentStore)
|
||||
assert component.filters == {"name": "test.txt"}
|
||||
assert component.top_k == 5
|
||||
assert component.scale_score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "MemoryEmbeddingRetriever", "init_parameters": {}}
|
||||
data = {"type": "InMemoryEmbeddingRetriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
MemoryEmbeddingRetriever.from_dict(data)
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "MemoryEmbeddingRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
data = {"type": "InMemoryEmbeddingRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||
MemoryEmbeddingRetriever.from_dict(data)
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "MemoryEmbeddingRetriever",
|
||||
"type": "InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
|
||||
MemoryEmbeddingRetriever.from_dict(data)
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_run(self):
|
||||
top_k = 3
|
||||
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docs = [
|
||||
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
@ -123,7 +123,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
]
|
||||
ds.write_documents(docs)
|
||||
|
||||
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||
retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||
result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
|
||||
|
||||
assert "documents" in result
|
||||
@ -133,12 +133,12 @@ class TestMemoryEmbeddingRetriever:
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_wrong_store_type(self):
|
||||
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
|
||||
MemoryEmbeddingRetriever(SomeOtherDocumentStore())
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
|
||||
InMemoryEmbeddingRetriever(SomeOtherDocumentStore())
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_run_with_pipeline(self):
|
||||
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
top_k = 2
|
||||
docs = [
|
||||
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
@ -146,7 +146,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
Document(text="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
|
||||
]
|
||||
ds.write_documents(docs)
|
||||
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||
retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("retriever", retriever)
|
||||
@ -6,7 +6,7 @@ import pandas as pd
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.document_stores import DocumentStore, MemoryDocumentStore, DocumentStoreError
|
||||
from haystack.preview.document_stores import DocumentStore, InMemoryDocumentStore, DocumentStoreError
|
||||
|
||||
|
||||
from haystack.preview.testing.document_store import DocumentStoreBaseTests
|
||||
@ -14,19 +14,19 @@ from haystack.preview.testing.document_store import DocumentStoreBaseTests
|
||||
|
||||
class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
"""
|
||||
Test MemoryDocumentStore's specific features
|
||||
Test InMemoryDocumentStore's specific features
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def docstore(self) -> MemoryDocumentStore:
|
||||
return MemoryDocumentStore()
|
||||
def docstore(self) -> InMemoryDocumentStore:
|
||||
return InMemoryDocumentStore()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict(self):
|
||||
store = MemoryDocumentStore()
|
||||
store = InMemoryDocumentStore()
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryDocumentStore",
|
||||
"type": "InMemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
|
||||
"bm25_algorithm": "BM25Okapi",
|
||||
@ -37,7 +37,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
store = MemoryDocumentStore(
|
||||
store = InMemoryDocumentStore(
|
||||
bm25_tokenization_regex="custom_regex",
|
||||
bm25_algorithm="BM25Plus",
|
||||
bm25_parameters={"key": "value"},
|
||||
@ -45,7 +45,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
)
|
||||
data = store.to_dict()
|
||||
assert data == {
|
||||
"type": "MemoryDocumentStore",
|
||||
"type": "InMemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
@ -55,17 +55,17 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.document_stores.memory.document_store.re")
|
||||
@patch("haystack.preview.document_stores.in_memory.document_store.re")
|
||||
def test_from_dict(self, mock_regex):
|
||||
data = {
|
||||
"type": "MemoryDocumentStore",
|
||||
"type": "InMemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "custom_regex",
|
||||
"bm25_algorithm": "BM25Plus",
|
||||
"bm25_parameters": {"key": "value"},
|
||||
},
|
||||
}
|
||||
store = MemoryDocumentStore.from_dict(data)
|
||||
store = InMemoryDocumentStore.from_dict(data)
|
||||
mock_regex.compile.assert_called_with("custom_regex")
|
||||
assert store.tokenizer
|
||||
assert store.bm25_algorithm.__name__ == "BM25Plus"
|
||||
@ -73,7 +73,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval(self, docstore: DocumentStore):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
# Tests if the bm25_retrieval method returns the correct document based on the input query.
|
||||
docs = [Document(text="Hello world"), Document(text="Haystack supports multiple languages")]
|
||||
docstore.write_documents(docs)
|
||||
@ -253,7 +253,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval(self):
|
||||
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
# Tests if the embedding retrieval method returns the correct document based on the input query embedding.
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
@ -268,7 +268,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_invalid_query(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
|
||||
docstore.embedding_retrieval(query_embedding=[])
|
||||
with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
|
||||
@ -277,7 +277,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_no_embeddings(self, caplog):
|
||||
caplog.set_level(logging.WARNING)
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [Document(text="Hello world"), Document(text="Haystack supports multiple languages")]
|
||||
docstore.write_documents(docs)
|
||||
results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])
|
||||
@ -287,7 +287,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_some_documents_wo_embeddings(self, caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="Haystack supports multiple languages"),
|
||||
@ -298,7 +298,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_documents_different_embedding_sizes(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0]),
|
||||
@ -310,7 +310,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_query_documents_different_embedding_sizes(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])]
|
||||
docstore.write_documents(docs)
|
||||
|
||||
@ -322,7 +322,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_with_different_top_k(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
@ -338,7 +338,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_with_scale_score(self):
|
||||
docstore = MemoryDocumentStore()
|
||||
docstore = InMemoryDocumentStore()
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
@ -356,7 +356,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedding_retrieval_return_embedding(self):
|
||||
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docs = [
|
||||
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
|
||||
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
@ -371,7 +371,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_compute_cosine_similarity_scores(self):
|
||||
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
|
||||
docs = [
|
||||
Document(text="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
|
||||
Document(text="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
@ -384,7 +384,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_compute_dot_product_similarity_scores(self):
|
||||
docstore = MemoryDocumentStore(embedding_similarity_function="dot_product")
|
||||
docstore = InMemoryDocumentStore(embedding_similarity_function="dot_product")
|
||||
docs = [
|
||||
Document(text="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
|
||||
Document(text="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
|
||||
Loading…
x
Reference in New Issue
Block a user