refactor!: rename MemoryDocumentStore and related Retrievers (#6076)

* rename doc store and retrievers

* release note

* fix patch
This commit is contained in:
Stefano Fiorucci 2023-10-17 16:15:16 +02:00 committed by GitHub
parent ec9f898cd6
commit 4e4af99a5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 134 additions and 128 deletions

View File

@ -1,15 +1,15 @@
import json
from haystack.preview import Pipeline, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.components.retrievers import MemoryBM25Retriever
from haystack.preview.document_stores import InMemoryDocumentStore
from haystack.preview.components.retrievers import InMemoryBM25Retriever
from haystack.preview.components.readers import ExtractiveReader
def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader")

View File

@ -3,9 +3,9 @@ import json
import pytest
from haystack.preview import Pipeline, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import InMemoryDocumentStore
from haystack.preview.components.writers import DocumentWriter
from haystack.preview.components.retrievers import MemoryBM25Retriever, MemoryEmbeddingRetriever
from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.builders.answer_builder import AnswerBuilder
@ -28,7 +28,7 @@ def test_bm25_rag_pipeline(tmp_path):
\nAnswer:
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
@ -99,7 +99,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
name="text_embedder",
)
rag_pipeline.add_component(
instance=MemoryEmbeddingRetriever(document_store=MemoryDocumentStore()), name="retriever"
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
)
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")

View File

@ -19,7 +19,7 @@ class DocumentCleaner:
Example usage in an indexing pipeline:
```python
document_store = MemoryDocumentStore()
document_store = InMemoryDocumentStore()
p = Pipeline()
p.add_component(instance=TextFileToDocument(), name="text_file_converter")
p.add_component(instance=DocumentCleaner(), name="cleaner")

View File

@ -21,10 +21,10 @@ class TextLanguageClassifier:
Example usage in a retrieval pipeline that passes only English language queries to the retriever:
```python
document_store = MemoryDocumentStore()
document_store = InMemoryDocumentStore()
p = Pipeline()
p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
p.connect("text_language_classifier.en", "retriever.query")
p.run({"text_language_classifier": {"text": "What's your query?"}})
```

View File

@ -1,4 +1,4 @@
from haystack.preview.components.retrievers.memory_bm25_retriever import MemoryBM25Retriever
from haystack.preview.components.retrievers.memory_embedding_retriever import MemoryEmbeddingRetriever
from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever
from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever
__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]
__all__ = ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"]

View File

@ -1,36 +1,36 @@
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.preview.document_stores import MemoryDocumentStore, document_store
from haystack.preview.document_stores import InMemoryDocumentStore, document_store
@component
class MemoryBM25Retriever:
class InMemoryBM25Retriever:
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
A component for retrieving documents from a InMemoryDocumentStore using the BM25 algorithm.
Needs to be connected to a MemoryDocumentStore to run.
Needs to be connected to a InMemoryDocumentStore to run.
"""
def __init__(
self,
document_store: MemoryDocumentStore,
document_store: InMemoryDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
):
"""
Create a MemoryBM25Retriever component.
Create a InMemoryBM25Retriever component.
:param document_store: An instance of MemoryDocumentStore.
:param document_store: An instance of InMemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the BM25 score or not. Default is True.
:raises ValueError: If the specified top_k is not > 0.
"""
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("document_store must be an instance of MemoryDocumentStore")
if not isinstance(document_store, InMemoryDocumentStore):
raise ValueError("document_store must be an instance of InMemoryDocumentStore")
self.document_store = document_store
@ -57,7 +57,7 @@ class MemoryBM25Retriever:
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
def from_dict(cls, data: Dict[str, Any]) -> "InMemoryBM25Retriever":
"""
Deserialize this component from a dictionary.
"""
@ -83,7 +83,7 @@ class MemoryBM25Retriever:
scale_score: Optional[bool] = None,
):
"""
Run the MemoryBM25Retriever on the given input data.
Run the InMemoryBM25Retriever on the given input data.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
@ -91,7 +91,7 @@ class MemoryBM25Retriever:
:param scale_score: Whether to scale the BM25 scores or not.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
:raises ValueError: If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters

View File

@ -1,29 +1,29 @@
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.preview.document_stores import MemoryDocumentStore, document_store
from haystack.preview.document_stores import InMemoryDocumentStore, document_store
@component
class MemoryEmbeddingRetriever:
class InMemoryEmbeddingRetriever:
"""
A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.
A component for retrieving documents from a InMemoryDocumentStore using a vector similarity metric.
Needs to be connected to a MemoryDocumentStore to run.
Needs to be connected to a InMemoryDocumentStore to run.
"""
def __init__(
self,
document_store: MemoryDocumentStore,
document_store: InMemoryDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
return_embedding: bool = False,
):
"""
Create a MemoryEmbeddingRetriever component.
Create a InMemoryEmbeddingRetriever component.
:param document_store: An instance of MemoryDocumentStore.
:param document_store: An instance of InMemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
@ -31,8 +31,8 @@ class MemoryEmbeddingRetriever:
:raises ValueError: If the specified top_k is not > 0.
"""
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("document_store must be an instance of MemoryDocumentStore")
if not isinstance(document_store, InMemoryDocumentStore):
raise ValueError("document_store must be an instance of InMemoryDocumentStore")
self.document_store = document_store
@ -65,7 +65,7 @@ class MemoryEmbeddingRetriever:
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryEmbeddingRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
"""
Deserialize this component from a dictionary.
"""
@ -92,7 +92,7 @@ class MemoryEmbeddingRetriever:
return_embedding: Optional[bool] = None,
):
"""
Run the MemoryEmbeddingRetriever on the given input data.
Run the InMemoryEmbeddingRetriever on the given input data.
:param query_embedding: Embedding of the query.
:param filters: A dictionary with filters to narrow down the search space.
@ -101,7 +101,7 @@ class MemoryEmbeddingRetriever:
:param return_embedding: Whether to return the embedding of the retrieved Documents.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
:raises ValueError: If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters

View File

@ -1,12 +1,12 @@
from haystack.preview.document_stores.protocols import DocumentStore, DuplicatePolicy
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore
from haystack.preview.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError
from haystack.preview.document_stores.decorator import document_store
__all__ = [
"DocumentStore",
"DuplicatePolicy",
"MemoryDocumentStore",
"InMemoryDocumentStore",
"DocumentStoreError",
"DuplicateDocumentError",
"MissingDocumentError",

View File

@ -0,0 +1,3 @@
from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore
__all__ = ["InMemoryDocumentStore"]

View File

@ -28,7 +28,7 @@ DOT_PRODUCT_SCALING_FACTOR = 100
@document_store
class MemoryDocumentStore:
class InMemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
"""
@ -199,7 +199,7 @@ class MemoryDocumentStore:
:param object_ids: The object_ids to delete.
"""
for doc_id in document_ids:
if not doc_id in self.storage.keys():
if doc_id not in self.storage.keys():
raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.")
del self.storage[doc_id]

View File

@ -1,3 +0,0 @@
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
__all__ = ["MemoryDocumentStore"]

View File

@ -0,0 +1,6 @@
---
preview:
- |
Rename `MemoryDocumentStore` to `InMemoryDocumentStore`
Rename `MemoryBM25Retriever` to `InMemoryBM25Retriever`
Rename `MemoryEmbeddingRetriever` to `InMemoryEmbeddingRetriever`

View File

@ -1,10 +1,8 @@
from unittest.mock import MagicMock
import pytest
from haystack.preview import Document, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.document_stores.memory import MemoryDocumentStore
from haystack.preview.document_stores.in_memory import InMemoryDocumentStore
from haystack.preview.components.caching.url_cache_checker import UrlCacheChecker
@ -72,7 +70,7 @@ class TestUrlCacheChecker:
@pytest.mark.unit
def test_run(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
documents = [
Document(text="doc1", metadata={"url": "https://example.com/1"}),
Document(text="doc2", metadata={"url": "https://example.com/2"}),

View File

@ -4,9 +4,9 @@ import pytest
from haystack.preview import Pipeline, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.retrievers.memory_bm25_retriever import MemoryBM25Retriever
from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import InMemoryDocumentStore
@pytest.fixture()
@ -23,14 +23,16 @@ def mock_docs():
class TestMemoryBM25Retriever:
@pytest.mark.unit
def test_init_default(self):
retriever = MemoryBM25Retriever(MemoryDocumentStore())
retriever = InMemoryBM25Retriever(InMemoryDocumentStore())
assert retriever.filters is None
assert retriever.top_k == 10
assert retriever.scale_score
@pytest.mark.unit
def test_init_with_parameters(self):
retriever = MemoryBM25Retriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False)
retriever = InMemoryBM25Retriever(
InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
)
assert retriever.filters == {"name": "test.txt"}
assert retriever.top_k == 5
assert not retriever.scale_score
@ -38,18 +40,18 @@ class TestMemoryBM25Retriever:
@pytest.mark.unit
def test_init_with_invalid_top_k_parameter(self):
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryBM25Retriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
InMemoryBM25Retriever(InMemoryDocumentStore(), top_k=-2, scale_score=False)
@pytest.mark.unit
def test_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryBM25Retriever(document_store=document_store)
component = InMemoryBM25Retriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryBM25Retriever",
"type": "InMemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
@ -60,15 +62,15 @@ class TestMemoryBM25Retriever:
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryBM25Retriever(
component = InMemoryBM25Retriever(
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
)
data = component.to_dict()
assert data == {
"type": "MemoryBM25Retriever",
"type": "InMemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
@ -79,49 +81,49 @@ class TestMemoryBM25Retriever:
@pytest.mark.unit
def test_from_dict(self):
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
data = {
"type": "MemoryBM25Retriever",
"type": "InMemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
},
}
component = MemoryBM25Retriever.from_dict(data)
assert isinstance(component.document_store, MemoryDocumentStore)
component = InMemoryBM25Retriever.from_dict(data)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score
@pytest.mark.unit
def test_from_dict_without_docstore(self):
data = {"type": "MemoryBM25Retriever", "init_parameters": {}}
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
MemoryBM25Retriever.from_dict(data)
InMemoryBM25Retriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_without_docstore_type(self):
data = {"type": "MemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
MemoryBM25Retriever.from_dict(data)
InMemoryBM25Retriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "MemoryBM25Retriever",
"type": "InMemoryBM25Retriever",
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
MemoryBM25Retriever.from_dict(data)
InMemoryBM25Retriever.from_dict(data)
@pytest.mark.unit
def test_retriever_valid_run(self, mock_docs):
top_k = 5
ds = MemoryDocumentStore()
ds = InMemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryBM25Retriever(ds, top_k=top_k)
retriever = InMemoryBM25Retriever(ds, top_k=top_k)
result = retriever.run(query="PHP")
assert "documents" in result
@ -131,8 +133,8 @@ class TestMemoryBM25Retriever:
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self):
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
MemoryBM25Retriever(SomeOtherDocumentStore())
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
InMemoryBM25Retriever(SomeOtherDocumentStore())
@pytest.mark.integration
@pytest.mark.parametrize(
@ -143,9 +145,9 @@ class TestMemoryBM25Retriever:
],
)
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
ds = MemoryDocumentStore()
ds = InMemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryBM25Retriever(ds)
retriever = InMemoryBM25Retriever(ds)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
@ -167,9 +169,9 @@ class TestMemoryBM25Retriever:
],
)
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
ds = MemoryDocumentStore()
ds = InMemoryDocumentStore()
ds.write_documents(mock_docs)
retriever = MemoryBM25Retriever(ds)
retriever = InMemoryBM25Retriever(ds)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)

View File

@ -4,23 +4,23 @@ import pytest
from haystack.preview import Pipeline, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.retrievers.memory_embedding_retriever import MemoryEmbeddingRetriever
from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import InMemoryDocumentStore
class TestMemoryEmbeddingRetriever:
@pytest.mark.unit
def test_init_default(self):
retriever = MemoryEmbeddingRetriever(MemoryDocumentStore())
retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
assert retriever.filters is None
assert retriever.top_k == 10
assert retriever.scale_score
@pytest.mark.unit
def test_init_with_parameters(self):
retriever = MemoryEmbeddingRetriever(
MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
retriever = InMemoryEmbeddingRetriever(
InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False
)
assert retriever.filters == {"name": "test.txt"}
assert retriever.top_k == 5
@ -29,18 +29,18 @@ class TestMemoryEmbeddingRetriever:
@pytest.mark.unit
def test_init_with_invalid_top_k_parameter(self):
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryEmbeddingRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2, scale_score=False)
@pytest.mark.unit
def test_to_dict(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryEmbeddingRetriever(document_store=document_store)
component = InMemoryEmbeddingRetriever(document_store=document_store)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"type": "InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": None,
@ -52,10 +52,10 @@ class TestMemoryEmbeddingRetriever:
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
component = MemoryEmbeddingRetriever(
component = InMemoryEmbeddingRetriever(
document_store=document_store,
filters={"name": "test.txt"},
top_k=5,
@ -64,7 +64,7 @@ class TestMemoryEmbeddingRetriever:
)
data = component.to_dict()
assert data == {
"type": "MemoryEmbeddingRetriever",
"type": "InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
@ -76,46 +76,46 @@ class TestMemoryEmbeddingRetriever:
@pytest.mark.unit
def test_from_dict(self):
document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
data = {
"type": "MemoryEmbeddingRetriever",
"type": "InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"filters": {"name": "test.txt"},
"top_k": 5,
},
}
component = MemoryEmbeddingRetriever.from_dict(data)
assert isinstance(component.document_store, MemoryDocumentStore)
component = InMemoryEmbeddingRetriever.from_dict(data)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score
@pytest.mark.unit
def test_from_dict_without_docstore(self):
data = {"type": "MemoryEmbeddingRetriever", "init_parameters": {}}
data = {"type": "InMemoryEmbeddingRetriever", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
MemoryEmbeddingRetriever.from_dict(data)
InMemoryEmbeddingRetriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_without_docstore_type(self):
data = {"type": "MemoryEmbeddingRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
data = {"type": "InMemoryEmbeddingRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
MemoryEmbeddingRetriever.from_dict(data)
InMemoryEmbeddingRetriever.from_dict(data)
@pytest.mark.unit
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "MemoryEmbeddingRetriever",
"type": "InMemoryEmbeddingRetriever",
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
MemoryEmbeddingRetriever.from_dict(data)
InMemoryEmbeddingRetriever.from_dict(data)
@pytest.mark.unit
def test_valid_run(self):
top_k = 3
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
@ -123,7 +123,7 @@ class TestMemoryEmbeddingRetriever:
]
ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)
result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
assert "documents" in result
@ -133,12 +133,12 @@ class TestMemoryEmbeddingRetriever:
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self):
SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"):
MemoryEmbeddingRetriever(SomeOtherDocumentStore())
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
InMemoryEmbeddingRetriever(SomeOtherDocumentStore())
@pytest.mark.integration
def test_run_with_pipeline(self):
ds = MemoryDocumentStore(embedding_similarity_function="cosine")
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
top_k = 2
docs = [
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
@ -146,7 +146,7 @@ class TestMemoryEmbeddingRetriever:
Document(text="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)

View File

@ -6,7 +6,7 @@ import pandas as pd
import pytest
from haystack.preview import Document
from haystack.preview.document_stores import DocumentStore, MemoryDocumentStore, DocumentStoreError
from haystack.preview.document_stores import DocumentStore, InMemoryDocumentStore, DocumentStoreError
from haystack.preview.testing.document_store import DocumentStoreBaseTests
@ -14,19 +14,19 @@ from haystack.preview.testing.document_store import DocumentStoreBaseTests
class TestMemoryDocumentStore(DocumentStoreBaseTests):
"""
Test MemoryDocumentStore's specific features
Test InMemoryDocumentStore's specific features
"""
@pytest.fixture
def docstore(self) -> MemoryDocumentStore:
return MemoryDocumentStore()
def docstore(self) -> InMemoryDocumentStore:
return InMemoryDocumentStore()
@pytest.mark.unit
def test_to_dict(self):
store = MemoryDocumentStore()
store = InMemoryDocumentStore()
data = store.to_dict()
assert data == {
"type": "MemoryDocumentStore",
"type": "InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": r"(?u)\b\w\w+\b",
"bm25_algorithm": "BM25Okapi",
@ -37,7 +37,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
store = MemoryDocumentStore(
store = InMemoryDocumentStore(
bm25_tokenization_regex="custom_regex",
bm25_algorithm="BM25Plus",
bm25_parameters={"key": "value"},
@ -45,7 +45,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
)
data = store.to_dict()
assert data == {
"type": "MemoryDocumentStore",
"type": "InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "custom_regex",
"bm25_algorithm": "BM25Plus",
@ -55,17 +55,17 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
}
@pytest.mark.unit
@patch("haystack.preview.document_stores.memory.document_store.re")
@patch("haystack.preview.document_stores.in_memory.document_store.re")
def test_from_dict(self, mock_regex):
data = {
"type": "MemoryDocumentStore",
"type": "InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "custom_regex",
"bm25_algorithm": "BM25Plus",
"bm25_parameters": {"key": "value"},
},
}
store = MemoryDocumentStore.from_dict(data)
store = InMemoryDocumentStore.from_dict(data)
mock_regex.compile.assert_called_with("custom_regex")
assert store.tokenizer
assert store.bm25_algorithm.__name__ == "BM25Plus"
@ -73,7 +73,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_bm25_retrieval(self, docstore: DocumentStore):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
# Tests if the bm25_retrieval method returns the correct document based on the input query.
docs = [Document(text="Hello world"), Document(text="Haystack supports multiple languages")]
docstore.write_documents(docs)
@ -253,7 +253,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval(self):
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
# Tests if the embedding retrieval method returns the correct document based on the input query embedding.
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
@ -268,7 +268,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_invalid_query(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
docstore.embedding_retrieval(query_embedding=[])
with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
@ -277,7 +277,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_no_embeddings(self, caplog):
caplog.set_level(logging.WARNING)
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [Document(text="Hello world"), Document(text="Haystack supports multiple languages")]
docstore.write_documents(docs)
results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])
@ -287,7 +287,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_some_documents_wo_embeddings(self, caplog):
caplog.set_level(logging.INFO)
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="Haystack supports multiple languages"),
@ -298,7 +298,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_documents_different_embedding_sizes(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0]),
@ -310,7 +310,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_query_documents_different_embedding_sizes(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])]
docstore.write_documents(docs)
@ -322,7 +322,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_with_different_top_k(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
@ -338,7 +338,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_with_scale_score(self):
docstore = MemoryDocumentStore()
docstore = InMemoryDocumentStore()
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
@ -356,7 +356,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_embedding_retrieval_return_embedding(self):
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(text="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
@ -371,7 +371,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_compute_cosine_similarity_scores(self):
docstore = MemoryDocumentStore(embedding_similarity_function="cosine")
docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(text="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
Document(text="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
@ -384,7 +384,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.mark.unit
def test_compute_dot_product_similarity_scores(self):
docstore = MemoryDocumentStore(embedding_similarity_function="dot_product")
docstore = InMemoryDocumentStore(embedding_similarity_function="dot_product")
docs = [
Document(text="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
Document(text="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),