mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-28 08:18:52 +00:00
feat: extend pipeline.add_component to support stores (#5261)
* add protocol and adapt pipeline * change API in pipeline.add_component * adapt pipeline tests * adapt memoryretriever * additional checks * separate protocol and mixin * review feedback & update tests * pylint * Update haystack/preview/document_stores/protocols.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update haystack/preview/document_stores/memory/document_store.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * docstring of Store * adapt memorydocumentstore * fix tests * remove direct inheritance * pylint * Update haystack/preview/document_stores/mixins.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * test names * revert suggestion * private self._stores * move asserts out * remove protocols * review feedback * review feedback * fix tests * mypy * review feedback * fix tests & other details * naming * mypy * fix tests * typing * partial review feedback * move .store to input dataclass * Revert "move .store to input dataclass" This reverts commit 53f624b99f3414c89d5134711725b31bd94ef77a. * disable reusing components with stores * disable sharing components with docstores * Update mixins.py * black * upgrade canals & fix tests --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
adfabdd648
commit
8f3fe85878
@ -1,72 +1,67 @@
|
|||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
from haystack.preview import component, Document
|
from haystack.preview import component, Document
|
||||||
from haystack.preview.document_stores import MemoryDocumentStore
|
from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class MemoryRetriever:
|
class MemoryRetriever(StoreAwareMixin):
|
||||||
"""
|
"""
|
||||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||||
|
|
||||||
|
Needs to be connected to a MemoryDocumentStore to run.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Input:
|
supported_stores = [MemoryDocumentStore]
|
||||||
"""
|
|
||||||
Input data for the MemoryRetriever component.
|
|
||||||
|
|
||||||
:param query: The query string for the retriever.
|
|
||||||
:param filters: A dictionary with filters to narrow down the search space.
|
|
||||||
:param top_k: The maximum number of documents to return.
|
|
||||||
:param scale_score: Whether to scale the BM25 scores or not.
|
|
||||||
:param stores: A dictionary mapping document store names to instances.
|
|
||||||
"""
|
|
||||||
|
|
||||||
queries: List[str]
|
|
||||||
filters: Dict[str, Any]
|
|
||||||
top_k: int
|
|
||||||
scale_score: bool
|
|
||||||
stores: Dict[str, Any]
|
|
||||||
|
|
||||||
class Output:
|
|
||||||
"""
|
|
||||||
Output data from the MemoryRetriever component.
|
|
||||||
|
|
||||||
:param documents: The retrieved documents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
documents: List[List[Document]]
|
|
||||||
|
|
||||||
@component.input
|
@component.input
|
||||||
def input(self): # type: ignore
|
def input(self): # type: ignore
|
||||||
return MemoryRetriever.Input
|
class Input:
|
||||||
|
"""
|
||||||
|
Input data for the MemoryRetriever component.
|
||||||
|
|
||||||
|
:param query: The query string for the retriever.
|
||||||
|
:param filters: A dictionary with filters to narrow down the search space.
|
||||||
|
:param top_k: The maximum number of documents to return.
|
||||||
|
:param scale_score: Whether to scale the BM25 scores or not.
|
||||||
|
:param stores: A dictionary mapping document store names to instances.
|
||||||
|
"""
|
||||||
|
|
||||||
|
queries: List[str]
|
||||||
|
filters: Dict[str, Any]
|
||||||
|
top_k: int
|
||||||
|
scale_score: bool
|
||||||
|
|
||||||
|
return Input
|
||||||
|
|
||||||
@component.output
|
@component.output
|
||||||
def output(self): # type: ignore
|
def output(self): # type: ignore
|
||||||
return MemoryRetriever.Output
|
class Output:
|
||||||
|
"""
|
||||||
|
Output data from the MemoryRetriever component.
|
||||||
|
|
||||||
def __init__(
|
:param documents: The retrieved documents.
|
||||||
self,
|
"""
|
||||||
document_store_name: str,
|
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
documents: List[List[Document]]
|
||||||
top_k: int = 10,
|
|
||||||
scale_score: bool = True,
|
return Output
|
||||||
):
|
|
||||||
|
def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
|
||||||
"""
|
"""
|
||||||
Create a MemoryRetriever component.
|
Create a MemoryRetriever component.
|
||||||
|
|
||||||
:param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
|
|
||||||
:param filters: A dictionary with filters to narrow down the search space (default is None).
|
:param filters: A dictionary with filters to narrow down the search space (default is None).
|
||||||
:param top_k: The maximum number of documents to retrieve (default is 10).
|
:param top_k: The maximum number of documents to retrieve (default is 10).
|
||||||
:param scale_score: Whether to scale the BM25 score or not (default is True).
|
:param scale_score: Whether to scale the BM25 score or not (default is True).
|
||||||
|
|
||||||
:raises ValueError: If the specified top_k is not > 0.
|
:raises ValueError: If the specified top_k is not > 0.
|
||||||
"""
|
"""
|
||||||
self.document_store_name = document_store_name
|
|
||||||
if top_k <= 0:
|
if top_k <= 0:
|
||||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||||
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}
|
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}
|
||||||
|
|
||||||
def run(self, data: Input) -> Output:
|
def run(self, data):
|
||||||
"""
|
"""
|
||||||
Run the MemoryRetriever on the given input data.
|
Run the MemoryRetriever on the given input data.
|
||||||
|
|
||||||
@ -75,19 +70,14 @@ class MemoryRetriever:
|
|||||||
|
|
||||||
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
|
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
|
||||||
"""
|
"""
|
||||||
if self.document_store_name not in data.stores:
|
self.store: MemoryDocumentStore
|
||||||
raise ValueError(
|
|
||||||
f"MemoryRetriever's document store '{self.document_store_name}' not found "
|
|
||||||
f"in input stores {list(data.stores.keys())}"
|
|
||||||
)
|
|
||||||
document_store = data.stores[self.document_store_name]
|
|
||||||
if not isinstance(document_store, MemoryDocumentStore):
|
|
||||||
raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
|
|
||||||
|
|
||||||
|
if not self.store:
|
||||||
|
raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute")
|
||||||
docs = []
|
docs = []
|
||||||
for query in data.queries:
|
for query in data.queries:
|
||||||
docs.append(
|
docs.append(
|
||||||
document_store.bm25_retrieval(
|
self.store.bm25_retrieval(
|
||||||
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
|
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
|
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
|
||||||
|
from haystack.preview.document_stores.mixins import StoreAwareMixin
|
||||||
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
||||||
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
|
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
|
||||||
|
|||||||
31
haystack/preview/document_stores/mixins.py
Normal file
31
haystack/preview/document_stores/mixins.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
|
|
||||||
|
from haystack.preview.document_stores.protocols import Store
|
||||||
|
|
||||||
|
|
||||||
|
class StoreAwareMixin:
|
||||||
|
"""
|
||||||
|
Adds the capability of a component to use a single document store from the `self.store` property.
|
||||||
|
|
||||||
|
To use this mixin you must specify which document stores to support by setting a value to `supported_stores`.
|
||||||
|
To support any document store, set it to `[Store]`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_store: Optional[Store] = None
|
||||||
|
supported_stores: List[Type[Store]] # type: ignore # (see https://github.com/python/mypy/issues/4717)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def store(self) -> Optional[Store]:
|
||||||
|
return self._store
|
||||||
|
|
||||||
|
@store.setter
|
||||||
|
def store(self, store: Store):
|
||||||
|
if not isinstance(store, Store):
|
||||||
|
raise ValueError("'store' does not respect the Store Protocol.")
|
||||||
|
if not any(isinstance(store, type_) for type_ in type(self).supported_stores):
|
||||||
|
raise ValueError(
|
||||||
|
f"Store type '{type(store).__name__}' is not compatible with this component. "
|
||||||
|
f"Compatible store types: {[type_.__name__ for type_ in type(self).supported_stores]}"
|
||||||
|
)
|
||||||
|
self._store = store
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Protocol, Optional, Dict, Any, List
|
from typing import Protocol, Optional, Dict, Any, List, runtime_checkable
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -15,6 +15,7 @@ class DuplicatePolicy(Enum):
|
|||||||
FAIL = "fail"
|
FAIL = "fail"
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Store(Protocol):
|
class Store(Protocol):
|
||||||
"""
|
"""
|
||||||
Stores Documents to be used by the components of a Pipeline.
|
Stores Documents to be used by the components of a Pipeline.
|
||||||
|
|||||||
@ -8,9 +8,13 @@ from canals.pipeline import (
|
|||||||
load_pipelines as load_canals_pipelines,
|
load_pipelines as load_canals_pipelines,
|
||||||
save_pipelines as save_canals_pipelines,
|
save_pipelines as save_canals_pipelines,
|
||||||
)
|
)
|
||||||
from canals.pipeline.sockets import find_input_sockets
|
|
||||||
|
|
||||||
from haystack.preview.document_stores.protocols import Store
|
from haystack.preview.document_stores.protocols import Store
|
||||||
|
from haystack.preview.document_stores.mixins import StoreAwareMixin
|
||||||
|
|
||||||
|
|
||||||
|
class NotAStoreError(PipelineError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NoSuchStoreError(PipelineError):
|
class NoSuchStoreError(PipelineError):
|
||||||
@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stores: Dict[str, Store] = {}
|
self._stores: Dict[str, Store] = {}
|
||||||
|
|
||||||
def add_store(self, name: str, store: Store) -> None:
|
def add_store(self, name: str, store: Store) -> None:
|
||||||
"""
|
"""
|
||||||
@ -34,7 +38,12 @@ class Pipeline(CanalsPipeline):
|
|||||||
:param store: the store object.
|
:param store: the store object.
|
||||||
:returns: None
|
:returns: None
|
||||||
"""
|
"""
|
||||||
self.stores[name] = store
|
if not isinstance(store, Store):
|
||||||
|
raise NotAStoreError(
|
||||||
|
f"This object ({store}) does not respect the Store Protocol, "
|
||||||
|
"so it can't be added to the pipeline with Pipeline.add_store()."
|
||||||
|
)
|
||||||
|
self._stores[name] = store
|
||||||
|
|
||||||
def list_stores(self) -> List[str]:
|
def list_stores(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
@ -42,7 +51,7 @@ class Pipeline(CanalsPipeline):
|
|||||||
|
|
||||||
:returns: a dictionary with all the stores attached to this Pipeline.
|
:returns: a dictionary with all the stores attached to this Pipeline.
|
||||||
"""
|
"""
|
||||||
return list(self.stores.keys())
|
return list(self._stores.keys())
|
||||||
|
|
||||||
def get_store(self, name: str) -> Store:
|
def get_store(self, name: str) -> Store:
|
||||||
"""
|
"""
|
||||||
@ -52,33 +61,49 @@ class Pipeline(CanalsPipeline):
|
|||||||
:returns: the store
|
:returns: the store
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.stores[name]
|
return self._stores[name]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
|
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
|
||||||
|
|
||||||
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
|
def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
|
Make this component available to the pipeline. Components are not connected to anything by default:
|
||||||
|
use `Pipeline.connect()` to connect components together.
|
||||||
|
|
||||||
:params data: the inputs to give to the input components of the Pipeline.
|
Component names must be unique, but component instances can be reused if needed.
|
||||||
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
|
|
||||||
:params debug: whether to collect and return debug information.
|
If `store` has a value, the pipeline will also connect this component to the requested document store.
|
||||||
:returns A dictionary with the outputs of the output components of the Pipeline.
|
Note that only components that inherit from StoreAwareMixin can be connected to stores.
|
||||||
|
|
||||||
|
:param name: the name of the component.
|
||||||
|
:param instance: the component instance.
|
||||||
|
:param store: the store this component needs access to, if any.
|
||||||
|
:raises ValueError: if:
|
||||||
|
- a component with the same name already exists
|
||||||
|
- a component requiring a store didn't receive it
|
||||||
|
- a component that didn't expect a store received it
|
||||||
|
:raises PipelineValidationError: if the given instance is not a component
|
||||||
|
:raises NoSuchStoreError: if the given store name is not known to the pipeline
|
||||||
"""
|
"""
|
||||||
# Get all nodes in this pipelines instance
|
if isinstance(instance, StoreAwareMixin):
|
||||||
for node_name in self.graph.nodes:
|
if not store:
|
||||||
# Get node inputs
|
raise ValueError(f"Component '{name}' needs a store.")
|
||||||
node = self.graph.nodes[node_name]["instance"]
|
|
||||||
input_params = find_input_sockets(node)
|
|
||||||
|
|
||||||
# If the node needs a store, adds the list of stores to its default inputs
|
if store not in self._stores:
|
||||||
if "stores" in input_params:
|
raise NoSuchStoreError(
|
||||||
if not hasattr(node, "defaults"):
|
f"Store named '{store}' not found. "
|
||||||
setattr(node, "defaults", {})
|
f"Add it with 'pipeline.add_store('{store}', <the docstore instance>)'."
|
||||||
node.defaults["stores"] = self.stores
|
)
|
||||||
|
|
||||||
# Run the pipeline
|
if instance.store:
|
||||||
return super().run(data=data, debug=debug)
|
raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.")
|
||||||
|
|
||||||
|
instance.store = self._stores[store]
|
||||||
|
|
||||||
|
elif store:
|
||||||
|
raise ValueError(f"Component '{name}' doesn't support stores.")
|
||||||
|
|
||||||
|
super().add_component(name, instance)
|
||||||
|
|
||||||
|
|
||||||
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):
|
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):
|
||||||
|
|||||||
@ -79,7 +79,7 @@ dependencies = [
|
|||||||
"jsonschema",
|
"jsonschema",
|
||||||
|
|
||||||
# Preview
|
# Preview
|
||||||
"canals>=0.3,<0.4",
|
"canals==0.3.2",
|
||||||
|
|
||||||
# Agent events
|
# Agent events
|
||||||
"events",
|
"events",
|
||||||
|
|||||||
@ -1,14 +1,16 @@
|
|||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.preview import Pipeline
|
from haystack.preview import Pipeline
|
||||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
||||||
from haystack.preview.dataclasses import Document
|
from haystack.preview.dataclasses import Document
|
||||||
from haystack.preview.document_stores import MemoryDocumentStore
|
from haystack.preview.document_stores import Store, MemoryDocumentStore
|
||||||
|
|
||||||
from test.preview.components.base import BaseTestComponent
|
from test.preview.components.base import BaseTestComponent
|
||||||
|
|
||||||
|
from haystack.preview.document_stores.protocols import DuplicatePolicy
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def mock_docs():
|
def mock_docs():
|
||||||
@ -21,43 +23,39 @@ def mock_docs():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Test_MemoryRetriever(BaseTestComponent):
|
class TestMemoryRetriever(BaseTestComponent):
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_save_load(self, tmp_path):
|
def test_save_load(self, tmp_path):
|
||||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(document_store_name="memory"), tmp_path)
|
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(), tmp_path)
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_save_load_with_parameters(self, tmp_path):
|
def test_save_load_with_parameters(self, tmp_path):
|
||||||
self.assert_can_be_saved_and_loaded_in_pipeline(
|
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path)
|
||||||
MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False), tmp_path
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_default(self):
|
def test_init_default(self):
|
||||||
retriever = MemoryRetriever(document_store_name="memory")
|
retriever = MemoryRetriever()
|
||||||
assert retriever.document_store_name == "memory"
|
|
||||||
assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True}
|
assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_with_parameters(self):
|
def test_init_with_parameters(self):
|
||||||
retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False)
|
retriever = MemoryRetriever(top_k=5, scale_score=False)
|
||||||
assert retriever.document_store_name == "memory-test"
|
|
||||||
assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False}
|
assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_with_invalid_top_k_parameter(self):
|
def test_init_with_invalid_top_k_parameter(self):
|
||||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||||
MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False)
|
MemoryRetriever(top_k=-2, scale_score=False)
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_valid_run(self, mock_docs):
|
def test_valid_run(self, mock_docs):
|
||||||
top_k = 5
|
top_k = 5
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
mr = MemoryRetriever(document_store_name="memory", top_k=top_k)
|
|
||||||
result: MemoryRetriever.Output = mr.run(
|
retriever = MemoryRetriever(top_k=top_k)
|
||||||
data=MemoryRetriever.Input(queries=["PHP", "Java"], stores={"memory": ds})
|
retriever.store = ds
|
||||||
)
|
result = retriever.run(data=retriever.input(queries=["PHP", "Java"]))
|
||||||
|
|
||||||
assert getattr(result, "documents")
|
assert getattr(result, "documents")
|
||||||
assert len(result.documents) == 2
|
assert len(result.documents) == 2
|
||||||
@ -67,26 +65,42 @@ class Test_MemoryRetriever(BaseTestComponent):
|
|||||||
assert result.documents[1][0].content == "Java is a popular programming language"
|
assert result.documents[1][0].content == "Java is a popular programming language"
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_invalid_run_wrong_store_name(self):
|
def test_invalid_run_no_store(self):
|
||||||
# Test invalid run with wrong store name
|
retriever = MemoryRetriever()
|
||||||
ds = MemoryDocumentStore()
|
with pytest.raises(
|
||||||
mr = MemoryRetriever(document_store_name="memory")
|
ValueError, match="MemoryRetriever needs a store to run: set the store instance to the self.store attribute"
|
||||||
with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"):
|
):
|
||||||
invalid_input_data = MemoryRetriever.Input(
|
retriever.run(retriever.input(queries=["test"]))
|
||||||
queries=["test"], top_k=10, scale_score=True, stores={"invalid_store": ds}
|
|
||||||
)
|
@pytest.mark.unit
|
||||||
mr.run(invalid_input_data)
|
def test_invalid_run_not_a_store(self):
|
||||||
|
class MockStore:
|
||||||
|
...
|
||||||
|
|
||||||
|
retriever = MemoryRetriever()
|
||||||
|
with pytest.raises(ValueError, match="does not respect the Store Protocol"):
|
||||||
|
retriever.store = MockStore()
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_invalid_run_wrong_store_type(self):
|
def test_invalid_run_wrong_store_type(self):
|
||||||
# Test invalid run with wrong store type
|
class MockStore:
|
||||||
ds = MemoryDocumentStore()
|
def count_documents(self) -> int:
|
||||||
mr = MemoryRetriever(document_store_name="memory")
|
return 0
|
||||||
with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."):
|
|
||||||
invalid_input_data = MemoryRetriever.Input(
|
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||||
queries=["test"], top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"}
|
return []
|
||||||
)
|
|
||||||
mr.run(invalid_input_data)
|
def write_documents(
|
||||||
|
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_documents(self, document_ids: List[str]) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
retriever = MemoryRetriever()
|
||||||
|
with pytest.raises(ValueError, match="is not compatible with this component"):
|
||||||
|
retriever.store = MockStore()
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -99,12 +113,12 @@ class Test_MemoryRetriever(BaseTestComponent):
|
|||||||
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
mr = MemoryRetriever(document_store_name="memory")
|
retriever = MemoryRetriever()
|
||||||
|
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.add_component("retriever", mr)
|
|
||||||
pipeline.add_store("memory", ds)
|
pipeline.add_store("memory", ds)
|
||||||
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query])})
|
pipeline.add_component("retriever", retriever, store="memory")
|
||||||
|
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query])})
|
||||||
|
|
||||||
assert result
|
assert result
|
||||||
assert "retriever" in result
|
assert "retriever" in result
|
||||||
@ -124,12 +138,12 @@ class Test_MemoryRetriever(BaseTestComponent):
|
|||||||
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
||||||
ds = MemoryDocumentStore()
|
ds = MemoryDocumentStore()
|
||||||
ds.write_documents(mock_docs)
|
ds.write_documents(mock_docs)
|
||||||
mr = MemoryRetriever(document_store_name="memory")
|
retriever = MemoryRetriever()
|
||||||
|
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.add_component("retriever", mr)
|
|
||||||
pipeline.add_store("memory", ds)
|
pipeline.add_store("memory", ds)
|
||||||
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query], top_k=top_k)})
|
pipeline.add_component("retriever", retriever, store="memory")
|
||||||
|
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query], top_k=top_k)})
|
||||||
|
|
||||||
assert result
|
assert result
|
||||||
assert "retriever" in result
|
assert "retriever" in result
|
||||||
|
|||||||
@ -1,16 +1,49 @@
|
|||||||
from typing import Dict, Any
|
from typing import Any, Optional, Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.preview import Pipeline, component, NoSuchStoreError
|
from haystack.preview import Pipeline, component, NoSuchStoreError, Document
|
||||||
|
from haystack.preview.pipeline import NotAStoreError
|
||||||
|
from haystack.preview.document_stores import StoreAwareMixin, DuplicatePolicy, Store
|
||||||
|
|
||||||
|
|
||||||
|
# Note: we're using a real class instead of a mock because mocks don't play too well with protocols.
|
||||||
class MockStore:
|
class MockStore:
|
||||||
...
|
def count_documents(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_documents(self, document_ids: List[str]) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_pipeline_store_api():
|
def test_add_store():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
pipe = Pipeline()
|
||||||
|
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
assert pipe._stores.get("first_store") == store_1
|
||||||
|
assert pipe._stores.get("second_store") == store_2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_store_wrong_object():
|
||||||
|
pipe = Pipeline()
|
||||||
|
|
||||||
|
with pytest.raises(NotAStoreError, match="does not respect the Store Protocol"):
|
||||||
|
pipe.add_store(name="store", store="I'm surely not a Store object!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_list_stores():
|
||||||
store_1 = MockStore()
|
store_1 = MockStore()
|
||||||
store_2 = MockStore()
|
store_2 = MockStore()
|
||||||
pipe = Pipeline()
|
pipe = Pipeline()
|
||||||
@ -20,22 +53,81 @@ def test_pipeline_store_api():
|
|||||||
|
|
||||||
assert pipe.list_stores() == ["first_store", "second_store"]
|
assert pipe.list_stores() == ["first_store", "second_store"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_get_store():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
pipe = Pipeline()
|
||||||
|
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
assert pipe.get_store("first_store") == store_1
|
assert pipe.get_store("first_store") == store_1
|
||||||
assert pipe.get_store("second_store") == store_2
|
assert pipe.get_store("second_store") == store_2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_get_store_wrong_name():
|
||||||
|
store_1 = MockStore()
|
||||||
|
pipe = Pipeline()
|
||||||
|
|
||||||
|
with pytest.raises(NoSuchStoreError):
|
||||||
|
pipe.get_store("first_store")
|
||||||
|
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
assert pipe.get_store("first_store") == store_1
|
||||||
|
|
||||||
with pytest.raises(NoSuchStoreError):
|
with pytest.raises(NoSuchStoreError):
|
||||||
pipe.get_store("third_store")
|
pipe.get_store("third_store")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_pipeline_stores_in_params():
|
def test_add_component_store_aware_component_receives_one_docstore():
|
||||||
store_1 = MockStore()
|
store_1 = MockStore()
|
||||||
store_2 = MockStore()
|
store_2 = MockStore()
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class MockComponent:
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [Store]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
mock = MockComponent()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
pipe.add_component("component", mock, store="first_store")
|
||||||
|
assert mock.store == store_1
|
||||||
|
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_receives_no_docstore():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [Store]
|
||||||
|
|
||||||
class Input:
|
class Input:
|
||||||
value: int
|
value: int
|
||||||
stores: Dict[str, Any]
|
|
||||||
|
|
||||||
class Output:
|
class Output:
|
||||||
value: int
|
value: int
|
||||||
@ -49,13 +141,274 @@ def test_pipeline_stores_in_params():
|
|||||||
return MockComponent.Output
|
return MockComponent.Output
|
||||||
|
|
||||||
def run(self, data: Input) -> Output:
|
def run(self, data: Input) -> Output:
|
||||||
assert data.stores == {"first_store": store_1, "second_store": store_2}
|
|
||||||
return MockComponent.Output(value=data.value)
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
pipe = Pipeline()
|
pipe = Pipeline()
|
||||||
pipe.add_component("component", MockComponent())
|
|
||||||
|
|
||||||
pipe.add_store(name="first_store", store=store_1)
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
pipe.add_store(name="second_store", store=store_2)
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
|
with pytest.raises(ValueError, match="Component 'component' needs a store."):
|
||||||
|
pipe.add_component("component", MockComponent())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_non_store_aware_component_receives_one_docstore():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent:
|
||||||
|
supported_stores = [Store]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Component 'component' doesn't support stores."):
|
||||||
|
pipe.add_component("component", MockComponent(), store="first_store")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_receives_wrong_docstore_name():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [Store]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
with pytest.raises(NoSuchStoreError, match="Store named 'wrong_store' not found."):
|
||||||
|
pipe.add_component("component", MockComponent(), store="wrong_store")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_receives_correct_docstore_type():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [MockStore]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
mock = MockComponent()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
pipe.add_component("component", mock, store="second_store")
|
||||||
|
assert mock.store == store_2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_is_reused():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [MockStore]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
mock = MockComponent()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
pipe.add_component("component", mock, store="second_store")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Reusing components with stores is not supported"):
|
||||||
|
pipe.add_component("component2", mock, store="second_store")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Reusing components with stores is not supported"):
|
||||||
|
pipe.add_component("component2", mock, store="first_store")
|
||||||
|
|
||||||
|
assert mock.store == store_2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_receives_subclass_of_correct_docstore_type():
|
||||||
|
class MockStoreSubclass(MockStore):
|
||||||
|
...
|
||||||
|
|
||||||
|
store_1 = MockStoreSubclass()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [MockStore]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
mock = MockComponent()
|
||||||
|
mock2 = MockComponent()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
pipe.add_component("component", mock, store="first_store")
|
||||||
|
assert mock.store == store_1
|
||||||
|
pipe.add_component("component2", mock2, store="second_store")
|
||||||
|
assert mock2.store == store_2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_does_not_check_supported_stores():
|
||||||
|
class SomethingElse:
|
||||||
|
...
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [SomethingElse]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
MockComponent()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_add_component_store_aware_component_receives_wrong_docstore_type():
|
||||||
|
store_1 = MockStore()
|
||||||
|
store_2 = MockStore()
|
||||||
|
|
||||||
|
class MockStore2:
|
||||||
|
def count_documents(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_documents(self, document_ids: List[str]) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@component
|
||||||
|
class MockComponent(StoreAwareMixin):
|
||||||
|
supported_stores = [MockStore2]
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
value: int
|
||||||
|
|
||||||
|
@component.input
|
||||||
|
def input(self):
|
||||||
|
return MockComponent.Input
|
||||||
|
|
||||||
|
@component.output
|
||||||
|
def output(self):
|
||||||
|
return MockComponent.Output
|
||||||
|
|
||||||
|
def run(self, data: Input) -> Output:
|
||||||
|
return MockComponent.Output(value=data.value)
|
||||||
|
|
||||||
|
mock = MockComponent()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_store(name="first_store", store=store_1)
|
||||||
|
pipe.add_store(name="second_store", store=store_2)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="is not compatible with this component"):
|
||||||
|
pipe.add_component("component", mock, store="second_store")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user