mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-27 07:48:43 +00:00
feat: extend pipeline.add_component to support stores (#5261)
* add protocol and adapt pipeline * change API in pipeline.add_component * adapt pipeline tests * adapt memoryretriever * additional checks * separate protocol and mixin * review feedback & update tests * pylint * Update haystack/preview/document_stores/protocols.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update haystack/preview/document_stores/memory/document_store.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * docstring of Store * adapt memorydocumentstore * fix tests * remove direct inheritance * pylint * Update haystack/preview/document_stores/mixins.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update test/preview/components/retrievers/test_memory_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * test names * revert suggestion * private self._stores * move asserts out * remove protocols * review feedback * review feedback * fix tests * mypy * review feedback * fix tests & other details * naming * mypy * fix tests * typing * partial review feedback * move .store to input dataclass * Revert "move .store to input dataclass" This reverts commit 53f624b99f3414c89d5134711725b31bd94ef77a. 
* disable reusing components with stores * disable sharing components with docstores * Update mixins.py * black * upgrade canals & fix tests --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
adfabdd648
commit
8f3fe85878
@ -1,72 +1,67 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin
|
||||
|
||||
|
||||
@component
|
||||
class MemoryRetriever:
|
||||
class MemoryRetriever(StoreAwareMixin):
|
||||
"""
|
||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||
|
||||
Needs to be connected to a MemoryDocumentStore to run.
|
||||
"""
|
||||
|
||||
class Input:
|
||||
"""
|
||||
Input data for the MemoryRetriever component.
|
||||
|
||||
:param query: The query string for the retriever.
|
||||
:param filters: A dictionary with filters to narrow down the search space.
|
||||
:param top_k: The maximum number of documents to return.
|
||||
:param scale_score: Whether to scale the BM25 scores or not.
|
||||
:param stores: A dictionary mapping document store names to instances.
|
||||
"""
|
||||
|
||||
queries: List[str]
|
||||
filters: Dict[str, Any]
|
||||
top_k: int
|
||||
scale_score: bool
|
||||
stores: Dict[str, Any]
|
||||
|
||||
class Output:
|
||||
"""
|
||||
Output data from the MemoryRetriever component.
|
||||
|
||||
:param documents: The retrieved documents.
|
||||
"""
|
||||
|
||||
documents: List[List[Document]]
|
||||
supported_stores = [MemoryDocumentStore]
|
||||
|
||||
@component.input
|
||||
def input(self): # type: ignore
|
||||
return MemoryRetriever.Input
|
||||
class Input:
|
||||
"""
|
||||
Input data for the MemoryRetriever component.
|
||||
|
||||
:param query: The query string for the retriever.
|
||||
:param filters: A dictionary with filters to narrow down the search space.
|
||||
:param top_k: The maximum number of documents to return.
|
||||
:param scale_score: Whether to scale the BM25 scores or not.
|
||||
:param stores: A dictionary mapping document store names to instances.
|
||||
"""
|
||||
|
||||
queries: List[str]
|
||||
filters: Dict[str, Any]
|
||||
top_k: int
|
||||
scale_score: bool
|
||||
|
||||
return Input
|
||||
|
||||
@component.output
|
||||
def output(self): # type: ignore
|
||||
return MemoryRetriever.Output
|
||||
class Output:
|
||||
"""
|
||||
Output data from the MemoryRetriever component.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
document_store_name: str,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
scale_score: bool = True,
|
||||
):
|
||||
:param documents: The retrieved documents.
|
||||
"""
|
||||
|
||||
documents: List[List[Document]]
|
||||
|
||||
return Output
|
||||
|
||||
def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
|
||||
"""
|
||||
Create a MemoryRetriever component.
|
||||
|
||||
:param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
|
||||
:param filters: A dictionary with filters to narrow down the search space (default is None).
|
||||
:param top_k: The maximum number of documents to retrieve (default is 10).
|
||||
:param scale_score: Whether to scale the BM25 score or not (default is True).
|
||||
|
||||
:raises ValueError: If the specified top_k is not > 0.
|
||||
"""
|
||||
self.document_store_name = document_store_name
|
||||
if top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}
|
||||
|
||||
def run(self, data: Input) -> Output:
|
||||
def run(self, data):
|
||||
"""
|
||||
Run the MemoryRetriever on the given input data.
|
||||
|
||||
@ -75,19 +70,14 @@ class MemoryRetriever:
|
||||
|
||||
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
|
||||
"""
|
||||
if self.document_store_name not in data.stores:
|
||||
raise ValueError(
|
||||
f"MemoryRetriever's document store '{self.document_store_name}' not found "
|
||||
f"in input stores {list(data.stores.keys())}"
|
||||
)
|
||||
document_store = data.stores[self.document_store_name]
|
||||
if not isinstance(document_store, MemoryDocumentStore):
|
||||
raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
|
||||
self.store: MemoryDocumentStore
|
||||
|
||||
if not self.store:
|
||||
raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute")
|
||||
docs = []
|
||||
for query in data.queries:
|
||||
docs.append(
|
||||
document_store.bm25_retrieval(
|
||||
self.store.bm25_retrieval(
|
||||
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
|
||||
from haystack.preview.document_stores.mixins import StoreAwareMixin
|
||||
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
||||
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
|
||||
|
||||
31
haystack/preview/document_stores/mixins.py
Normal file
31
haystack/preview/document_stores/mixins.py
Normal file
@ -0,0 +1,31 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
|
||||
from haystack.preview.document_stores.protocols import Store
|
||||
|
||||
|
||||
class StoreAwareMixin:
    """
    Mixin giving a component a single `store` attribute that is validated on assignment.

    Classes using this mixin are expected to set `supported_stores` to the list of store
    classes they accept. Use `[Store]` to accept any object satisfying the Store protocol.
    """

    # Backing field for the `store` property; None until a store is assigned.
    _store: Optional[Store] = None
    supported_stores: List[Type[Store]]  # type: ignore # (see https://github.com/python/mypy/issues/4717)

    @property
    def store(self) -> Optional[Store]:
        """Return the store assigned to this component, or None if none was assigned."""
        return self._store

    @store.setter
    def store(self, store: Store):
        """
        Assign the store after validating it.

        :param store: the store instance to attach to this component.
        :raises ValueError: if `store` is not an instance of the Store protocol, or if it is
            not an instance of any class listed in `supported_stores`.
        """
        if not isinstance(store, Store):
            raise ValueError("'store' does not respect the Store Protocol.")

        # Accept the store as soon as one of the supported classes matches it.
        for accepted in type(self).supported_stores:
            if isinstance(store, accepted):
                self._store = store
                return

        compatible_names = [accepted.__name__ for accepted in type(self).supported_stores]
        raise ValueError(
            f"Store type '{type(store).__name__}' is not compatible with this component. "
            f"Compatible store types: {compatible_names}"
        )
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Protocol, Optional, Dict, Any, List
|
||||
from typing import Protocol, Optional, Dict, Any, List, runtime_checkable
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
@ -15,6 +15,7 @@ class DuplicatePolicy(Enum):
|
||||
FAIL = "fail"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Store(Protocol):
|
||||
"""
|
||||
Stores Documents to be used by the components of a Pipeline.
|
||||
|
||||
@ -8,9 +8,13 @@ from canals.pipeline import (
|
||||
load_pipelines as load_canals_pipelines,
|
||||
save_pipelines as save_canals_pipelines,
|
||||
)
|
||||
from canals.pipeline.sockets import find_input_sockets
|
||||
|
||||
from haystack.preview.document_stores.protocols import Store
|
||||
from haystack.preview.document_stores.mixins import StoreAwareMixin
|
||||
|
||||
|
||||
class NotAStoreError(PipelineError):
    """Raised when an object that does not respect the Store protocol is passed to Pipeline.add_store()."""
|
||||
|
||||
|
||||
class NoSuchStoreError(PipelineError):
|
||||
@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stores: Dict[str, Store] = {}
|
||||
self._stores: Dict[str, Store] = {}
|
||||
|
||||
def add_store(self, name: str, store: Store) -> None:
|
||||
"""
|
||||
@ -34,7 +38,12 @@ class Pipeline(CanalsPipeline):
|
||||
:param store: the store object.
|
||||
:returns: None
|
||||
"""
|
||||
self.stores[name] = store
|
||||
if not isinstance(store, Store):
|
||||
raise NotAStoreError(
|
||||
f"This object ({store}) does not respect the Store Protocol, "
|
||||
"so it can't be added to the pipeline with Pipeline.add_store()."
|
||||
)
|
||||
self._stores[name] = store
|
||||
|
||||
def list_stores(self) -> List[str]:
|
||||
"""
|
||||
@ -42,7 +51,7 @@ class Pipeline(CanalsPipeline):
|
||||
|
||||
:returns: a dictionary with all the stores attached to this Pipeline.
|
||||
"""
|
||||
return list(self.stores.keys())
|
||||
return list(self._stores.keys())
|
||||
|
||||
def get_store(self, name: str) -> Store:
|
||||
"""
|
||||
@ -52,33 +61,49 @@ class Pipeline(CanalsPipeline):
|
||||
:returns: the store
|
||||
"""
|
||||
try:
|
||||
return self.stores[name]
|
||||
return self._stores[name]
|
||||
except KeyError as e:
|
||||
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
|
||||
|
||||
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
|
||||
def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None:
|
||||
"""
|
||||
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
|
||||
Make this component available to the pipeline. Components are not connected to anything by default:
|
||||
use `Pipeline.connect()` to connect components together.
|
||||
|
||||
:params data: the inputs to give to the input components of the Pipeline.
|
||||
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
|
||||
:params debug: whether to collect and return debug information.
|
||||
:returns A dictionary with the outputs of the output components of the Pipeline.
|
||||
Component names must be unique, but component instances can be reused if needed.
|
||||
|
||||
If `store` has a value, the pipeline will also connect this component to the requested document store.
|
||||
Note that only components that inherit from StoreAwareMixin can be connected to stores.
|
||||
|
||||
:param name: the name of the component.
|
||||
:param instance: the component instance.
|
||||
:param store: the store this component needs access to, if any.
|
||||
:raises ValueError: if:
|
||||
- a component with the same name already exists
|
||||
- a component requiring a store didn't receive it
|
||||
- a component that didn't expect a store received it
|
||||
:raises PipelineValidationError: if the given instance is not a component
|
||||
:raises NoSuchStoreError: if the given store name is not known to the pipeline
|
||||
"""
|
||||
# Get all nodes in this pipelines instance
|
||||
for node_name in self.graph.nodes:
|
||||
# Get node inputs
|
||||
node = self.graph.nodes[node_name]["instance"]
|
||||
input_params = find_input_sockets(node)
|
||||
if isinstance(instance, StoreAwareMixin):
|
||||
if not store:
|
||||
raise ValueError(f"Component '{name}' needs a store.")
|
||||
|
||||
# If the node needs a store, adds the list of stores to its default inputs
|
||||
if "stores" in input_params:
|
||||
if not hasattr(node, "defaults"):
|
||||
setattr(node, "defaults", {})
|
||||
node.defaults["stores"] = self.stores
|
||||
if store not in self._stores:
|
||||
raise NoSuchStoreError(
|
||||
f"Store named '{store}' not found. "
|
||||
f"Add it with 'pipeline.add_store('{store}', <the docstore instance>)'."
|
||||
)
|
||||
|
||||
# Run the pipeline
|
||||
return super().run(data=data, debug=debug)
|
||||
if instance.store:
|
||||
raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.")
|
||||
|
||||
instance.store = self._stores[store]
|
||||
|
||||
elif store:
|
||||
raise ValueError(f"Component '{name}' doesn't support stores.")
|
||||
|
||||
super().add_component(name, instance)
|
||||
|
||||
|
||||
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):
|
||||
|
||||
@ -79,7 +79,7 @@ dependencies = [
|
||||
"jsonschema",
|
||||
|
||||
# Preview
|
||||
"canals>=0.3,<0.4",
|
||||
"canals==0.3.2",
|
||||
|
||||
# Agent events
|
||||
"events",
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline
|
||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import Store, MemoryDocumentStore
|
||||
|
||||
from test.preview.components.base import BaseTestComponent
|
||||
|
||||
from haystack.preview.document_stores.protocols import DuplicatePolicy
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_docs():
|
||||
@ -21,43 +23,39 @@ def mock_docs():
|
||||
]
|
||||
|
||||
|
||||
class Test_MemoryRetriever(BaseTestComponent):
|
||||
class TestMemoryRetriever(BaseTestComponent):
|
||||
@pytest.mark.unit
|
||||
def test_save_load(self, tmp_path):
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(document_store_name="memory"), tmp_path)
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(), tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_load_with_parameters(self, tmp_path):
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(
|
||||
MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False), tmp_path
|
||||
)
|
||||
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_default(self):
|
||||
retriever = MemoryRetriever(document_store_name="memory")
|
||||
assert retriever.document_store_name == "memory"
|
||||
retriever = MemoryRetriever()
|
||||
assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_parameters(self):
|
||||
retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False)
|
||||
assert retriever.document_store_name == "memory-test"
|
||||
retriever = MemoryRetriever(top_k=5, scale_score=False)
|
||||
assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_invalid_top_k_parameter(self):
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
|
||||
MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False)
|
||||
MemoryRetriever(top_k=-2, scale_score=False)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_run(self, mock_docs):
|
||||
top_k = 5
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
mr = MemoryRetriever(document_store_name="memory", top_k=top_k)
|
||||
result: MemoryRetriever.Output = mr.run(
|
||||
data=MemoryRetriever.Input(queries=["PHP", "Java"], stores={"memory": ds})
|
||||
)
|
||||
|
||||
retriever = MemoryRetriever(top_k=top_k)
|
||||
retriever.store = ds
|
||||
result = retriever.run(data=retriever.input(queries=["PHP", "Java"]))
|
||||
|
||||
assert getattr(result, "documents")
|
||||
assert len(result.documents) == 2
|
||||
@ -67,26 +65,42 @@ class Test_MemoryRetriever(BaseTestComponent):
|
||||
assert result.documents[1][0].content == "Java is a popular programming language"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_wrong_store_name(self):
|
||||
# Test invalid run with wrong store name
|
||||
ds = MemoryDocumentStore()
|
||||
mr = MemoryRetriever(document_store_name="memory")
|
||||
with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"):
|
||||
invalid_input_data = MemoryRetriever.Input(
|
||||
queries=["test"], top_k=10, scale_score=True, stores={"invalid_store": ds}
|
||||
)
|
||||
mr.run(invalid_input_data)
|
||||
def test_invalid_run_no_store(self):
    """Running a retriever that has no store attached must raise a ValueError."""
    storeless = MemoryRetriever()
    expected = "MemoryRetriever needs a store to run: set the store instance to the self.store attribute"
    with pytest.raises(ValueError, match=expected):
        storeless.run(storeless.input(queries=["test"]))
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_not_a_store(self):
    """Assigning an object that defines none of the Store methods to `store` must be rejected."""

    class MockStore:
        # Deliberately empty: it implements nothing of the Store protocol.
        pass

    not_a_store = MockStore()
    retriever = MemoryRetriever()
    with pytest.raises(ValueError, match="does not respect the Store Protocol"):
        retriever.store = not_a_store
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_run_wrong_store_type(self):
|
||||
# Test invalid run with wrong store type
|
||||
ds = MemoryDocumentStore()
|
||||
mr = MemoryRetriever(document_store_name="memory")
|
||||
with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."):
|
||||
invalid_input_data = MemoryRetriever.Input(
|
||||
queries=["test"], top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"}
|
||||
)
|
||||
mr.run(invalid_input_data)
|
||||
class MockStore:
|
||||
def count_documents(self) -> int:
|
||||
return 0
|
||||
|
||||
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||
return []
|
||||
|
||||
def write_documents(
|
||||
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
def delete_documents(self, document_ids: List[str]) -> None:
|
||||
return None
|
||||
|
||||
retriever = MemoryRetriever()
|
||||
with pytest.raises(ValueError, match="is not compatible with this component"):
|
||||
retriever.store = MockStore()
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
@ -99,12 +113,12 @@ class Test_MemoryRetriever(BaseTestComponent):
|
||||
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
mr = MemoryRetriever(document_store_name="memory")
|
||||
retriever = MemoryRetriever()
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("retriever", mr)
|
||||
pipeline.add_store("memory", ds)
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query])})
|
||||
pipeline.add_component("retriever", retriever, store="memory")
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query])})
|
||||
|
||||
assert result
|
||||
assert "retriever" in result
|
||||
@ -124,12 +138,12 @@ class Test_MemoryRetriever(BaseTestComponent):
|
||||
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
|
||||
ds = MemoryDocumentStore()
|
||||
ds.write_documents(mock_docs)
|
||||
mr = MemoryRetriever(document_store_name="memory")
|
||||
retriever = MemoryRetriever()
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("retriever", mr)
|
||||
pipeline.add_store("memory", ds)
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query], top_k=top_k)})
|
||||
pipeline.add_component("retriever", retriever, store="memory")
|
||||
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query], top_k=top_k)})
|
||||
|
||||
assert result
|
||||
assert "retriever" in result
|
||||
|
||||
@ -1,16 +1,49 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline, component, NoSuchStoreError
|
||||
from haystack.preview import Pipeline, component, NoSuchStoreError, Document
|
||||
from haystack.preview.pipeline import NotAStoreError
|
||||
from haystack.preview.document_stores import StoreAwareMixin, DuplicatePolicy, Store
|
||||
|
||||
|
||||
# Note: we're using a real class instead of a mock because mocks don't play too well with protocols.
|
||||
class MockStore:
|
||||
...
|
||||
def count_documents(self) -> int:
|
||||
return 0
|
||||
|
||||
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||
return []
|
||||
|
||||
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
|
||||
return None
|
||||
|
||||
def delete_documents(self, document_ids: List[str]) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_pipeline_store_api():
|
||||
def test_add_store():
|
||||
store_1 = MockStore()
|
||||
store_2 = MockStore()
|
||||
pipe = Pipeline()
|
||||
|
||||
pipe.add_store(name="first_store", store=store_1)
|
||||
pipe.add_store(name="second_store", store=store_2)
|
||||
assert pipe._stores.get("first_store") == store_1
|
||||
assert pipe._stores.get("second_store") == store_2
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_store_wrong_object():
    """add_store() must refuse an object that does not implement the Store protocol."""
    pipeline = Pipeline()
    not_a_store = "I'm surely not a Store object!"
    with pytest.raises(NotAStoreError, match="does not respect the Store Protocol"):
        pipeline.add_store(name="store", store=not_a_store)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list_stores():
|
||||
store_1 = MockStore()
|
||||
store_2 = MockStore()
|
||||
pipe = Pipeline()
|
||||
@ -20,22 +53,81 @@ def test_pipeline_store_api():
|
||||
|
||||
assert pipe.list_stores() == ["first_store", "second_store"]
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_get_store():
    """get_store() hands back each store under the name it was registered with."""
    registered = {"first_store": MockStore(), "second_store": MockStore()}
    pipe = Pipeline()
    for store_name, store in registered.items():
        pipe.add_store(name=store_name, store=store)

    for store_name, store in registered.items():
        assert pipe.get_store(store_name) == store
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_get_store_wrong_name():
    """get_store() raises NoSuchStoreError for names that were never registered."""
    only_store = MockStore()
    pipeline = Pipeline()

    # Nothing has been registered yet, so even the name used below is unknown.
    with pytest.raises(NoSuchStoreError):
        pipeline.get_store("first_store")

    pipeline.add_store(name="first_store", store=only_store)
    assert pipeline.get_store("first_store") == only_store

    # A name that was never registered keeps failing after another store was added.
    with pytest.raises(NoSuchStoreError):
        pipeline.get_store("third_store")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_pipeline_stores_in_params():
|
||||
def test_add_component_store_aware_component_receives_one_docstore():
|
||||
store_1 = MockStore()
|
||||
store_2 = MockStore()
|
||||
|
||||
@component
|
||||
class MockComponent:
|
||||
class MockComponent(StoreAwareMixin):
|
||||
supported_stores = [Store]
|
||||
|
||||
class Input:
|
||||
value: int
|
||||
|
||||
class Output:
|
||||
value: int
|
||||
|
||||
@component.input
|
||||
def input(self):
|
||||
return MockComponent.Input
|
||||
|
||||
@component.output
|
||||
def output(self):
|
||||
return MockComponent.Output
|
||||
|
||||
def run(self, data: Input) -> Output:
|
||||
return MockComponent.Output(value=data.value)
|
||||
|
||||
mock = MockComponent()
|
||||
pipe = Pipeline()
|
||||
pipe.add_store(name="first_store", store=store_1)
|
||||
pipe.add_store(name="second_store", store=store_2)
|
||||
pipe.add_component("component", mock, store="first_store")
|
||||
assert mock.store == store_1
|
||||
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_component_store_aware_component_receives_no_docstore():
|
||||
store_1 = MockStore()
|
||||
store_2 = MockStore()
|
||||
|
||||
@component
|
||||
class MockComponent(StoreAwareMixin):
|
||||
supported_stores = [Store]
|
||||
|
||||
class Input:
|
||||
value: int
|
||||
stores: Dict[str, Any]
|
||||
|
||||
class Output:
|
||||
value: int
|
||||
@ -49,13 +141,274 @@ def test_pipeline_stores_in_params():
|
||||
return MockComponent.Output
|
||||
|
||||
def run(self, data: Input) -> Output:
|
||||
assert data.stores == {"first_store": store_1, "second_store": store_2}
|
||||
return MockComponent.Output(value=data.value)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("component", MockComponent())
|
||||
|
||||
pipe.add_store(name="first_store", store=store_1)
|
||||
pipe.add_store(name="second_store", store=store_2)
|
||||
|
||||
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
|
||||
with pytest.raises(ValueError, match="Component 'component' needs a store."):
|
||||
pipe.add_component("component", MockComponent())
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_non_store_aware_component_receives_one_docstore():
    """A component that does not inherit from StoreAwareMixin is rejected when a store name is given."""

    @component
    class MockComponent:
        # Setting supported_stores alone does not make the component store-aware.
        supported_stores = [Store]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    pipe = Pipeline()
    for store_name in ("first_store", "second_store"):
        pipe.add_store(name=store_name, store=MockStore())

    with pytest.raises(ValueError, match="Component 'component' doesn't support stores."):
        pipe.add_component("component", MockComponent(), store="first_store")
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_receives_wrong_docstore_name():
    """Asking for a store name the pipeline does not know raises NoSuchStoreError."""

    @component
    class MockComponent(StoreAwareMixin):
        supported_stores = [Store]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    pipe = Pipeline()
    for store_name in ("first_store", "second_store"):
        pipe.add_store(name=store_name, store=MockStore())

    # "wrong_store" was never registered with add_store().
    with pytest.raises(NoSuchStoreError, match="Store named 'wrong_store' not found."):
        pipe.add_component("component", MockComponent(), store="wrong_store")
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_receives_correct_docstore_type():
    """A component whose supported_stores lists MockStore gets the requested MockStore attached."""

    @component
    class MockComponent(StoreAwareMixin):
        supported_stores = [MockStore]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    first, second = MockStore(), MockStore()
    pipe = Pipeline()
    pipe.add_store(name="first_store", store=first)
    pipe.add_store(name="second_store", store=second)

    mock = MockComponent()
    pipe.add_component("component", mock, store="second_store")

    # The component must hold the store registered under the requested name.
    assert mock.store == second
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_is_reused():
    """Adding an instance that already holds a store a second time is rejected, whatever store is asked for."""

    @component
    class MockComponent(StoreAwareMixin):
        supported_stores = [MockStore]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    first, second = MockStore(), MockStore()
    pipe = Pipeline()
    pipe.add_store(name="first_store", store=first)
    pipe.add_store(name="second_store", store=second)

    mock = MockComponent()
    pipe.add_component("component", mock, store="second_store")

    # Re-adding the same instance fails both for the same store and for a different one.
    for requested_store in ("second_store", "first_store"):
        with pytest.raises(ValueError, match="Reusing components with stores is not supported"):
            pipe.add_component("component2", mock, store=requested_store)

    # The failed attempts must leave the originally attached store in place.
    assert mock.store == second
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_receives_subclass_of_correct_docstore_type():
    """A subclass of a supported store class is accepted just like the supported class itself."""

    class MockStoreSubclass(MockStore):
        pass

    @component
    class MockComponent(StoreAwareMixin):
        supported_stores = [MockStore]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    subclass_store = MockStoreSubclass()
    plain_store = MockStore()
    pipe = Pipeline()
    pipe.add_store(name="first_store", store=subclass_store)
    pipe.add_store(name="second_store", store=plain_store)

    # One component instance per store.
    mock = MockComponent()
    mock2 = MockComponent()

    pipe.add_component("component", mock, store="first_store")
    assert mock.store == subclass_store

    pipe.add_component("component2", mock2, store="second_store")
    assert mock2.store == plain_store
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_does_not_check_supported_stores():
    """Defining and instantiating a component must not fail when supported_stores lists an arbitrary class."""

    class SomethingElse:
        # An empty class that implements none of the store methods.
        pass

    @component
    class MockComponent(StoreAwareMixin):
        supported_stores = [SomethingElse]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    # The test passes as long as instantiation raises nothing.
    MockComponent()
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_add_component_store_aware_component_receives_wrong_docstore_type():
    """A store whose class is not listed in supported_stores is rejected by add_component()."""

    class MockStore2:
        # Offers the same four methods as MockStore but is an unrelated class.
        def count_documents(self) -> int:
            return 0

        def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
            return []

        def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
            return None

        def delete_documents(self, document_ids: List[str]) -> None:
            return None

    @component
    class MockComponent(StoreAwareMixin):
        # Only MockStore2 instances are accepted; the pipeline below holds MockStore instances.
        supported_stores = [MockStore2]

        class Input:
            value: int

        class Output:
            value: int

        @component.input
        def input(self):
            return MockComponent.Input

        @component.output
        def output(self):
            return MockComponent.Output

        def run(self, data: Input) -> Output:
            return MockComponent.Output(value=data.value)

    pipe = Pipeline()
    for store_name in ("first_store", "second_store"):
        pipe.add_store(name=store_name, store=MockStore())

    mock = MockComponent()
    with pytest.raises(ValueError, match="is not compatible with this component"):
        pipe.add_component("component", mock, store="second_store")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user