feat: extend pipeline.add_component to support stores (#5261)

* add protocol and adapt pipeline

* change API in pipeline.add_component

* adapt pipeline tests

* adapt memoryretriever

* additional checks

* separate protocol and mixin

* review feedback & update tests

* pylint

* Update haystack/preview/document_stores/protocols.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update haystack/preview/document_stores/memory/document_store.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* docstring of Store

* adapt memorydocumentstore

* fix tests

* remove direct inheritance

* pylint

* Update haystack/preview/document_stores/mixins.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* test names

* revert suggestion

* private self._stores

* move asserts out

* remove protocols

* review feedback

* review feedback

* fix tests

* mypy

* review feedback

* fix tests & other details

* naming

* mypy

* fix tests

* typing

* partial review feedback

* move .store to input dataclass

* Revert "move .store to input dataclass"

This reverts commit 53f624b99f3414c89d5134711725b31bd94ef77a.

* disable reusing components with stores

* disable sharing components with docstores

* Update mixins.py

* black

* upgrade canals & fix tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
ZanSara 2023-07-17 15:06:19 +02:00 committed by GitHub
parent adfabdd648
commit 8f3fe85878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 538 additions and 123 deletions

View File

@ -1,72 +1,67 @@
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin
@component
class MemoryRetriever:
class MemoryRetriever(StoreAwareMixin):
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
Needs to be connected to a MemoryDocumentStore to run.
"""
class Input:
"""
Input data for the MemoryRetriever component.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param stores: A dictionary mapping document store names to instances.
"""
queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool
stores: Dict[str, Any]
class Output:
"""
Output data from the MemoryRetriever component.
:param documents: The retrieved documents.
"""
documents: List[List[Document]]
supported_stores = [MemoryDocumentStore]
@component.input
def input(self): # type: ignore
return MemoryRetriever.Input
class Input:
"""
Input data for the MemoryRetriever component.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param stores: A dictionary mapping document store names to instances.
"""
queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool
return Input
@component.output
def output(self): # type: ignore
return MemoryRetriever.Output
class Output:
"""
Output data from the MemoryRetriever component.
def __init__(
self,
document_store_name: str,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
):
:param documents: The retrieved documents.
"""
documents: List[List[Document]]
return Output
def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
"""
Create a MemoryRetriever component.
:param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param scale_score: Whether to scale the BM25 score or not (default is True).
:raises ValueError: If the specified top_k is not > 0.
"""
self.document_store_name = document_store_name
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}
def run(self, data: Input) -> Output:
def run(self, data):
"""
Run the MemoryRetriever on the given input data.
@ -75,19 +70,14 @@ class MemoryRetriever:
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
"""
if self.document_store_name not in data.stores:
raise ValueError(
f"MemoryRetriever's document store '{self.document_store_name}' not found "
f"in input stores {list(data.stores.keys())}"
)
document_store = data.stores[self.document_store_name]
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
self.store: MemoryDocumentStore
if not self.store:
raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute")
docs = []
for query in data.queries:
docs.append(
document_store.bm25_retrieval(
self.store.bm25_retrieval(
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
)
)

View File

@ -1,3 +1,4 @@
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
from haystack.preview.document_stores.mixins import StoreAwareMixin
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError

View File

@ -0,0 +1,31 @@
from typing import List, Optional, Type
from haystack.preview.document_stores.protocols import Store
class StoreAwareMixin:
"""
Adds the capability of a component to use a single document store from the `self.store` property.
To use this mixin you must specify which document stores to support by setting a value to `supported_stores`.
To support any document store, set it to `[Store]`.
"""
_store: Optional[Store] = None
supported_stores: List[Type[Store]] # type: ignore # (see https://github.com/python/mypy/issues/4717)
@property
def store(self) -> Optional[Store]:
return self._store
@store.setter
def store(self, store: Store):
if not isinstance(store, Store):
raise ValueError("'store' does not respect the Store Protocol.")
if not any(isinstance(store, type_) for type_ in type(self).supported_stores):
raise ValueError(
f"Store type '{type(store).__name__}' is not compatible with this component. "
f"Compatible store types: {[type_.__name__ for type_ in type(self).supported_stores]}"
)
self._store = store

View File

@ -1,4 +1,4 @@
from typing import Protocol, Optional, Dict, Any, List
from typing import Protocol, Optional, Dict, Any, List, runtime_checkable
import logging
from enum import Enum
@ -15,6 +15,7 @@ class DuplicatePolicy(Enum):
FAIL = "fail"
@runtime_checkable
class Store(Protocol):
"""
Stores Documents to be used by the components of a Pipeline.

View File

@ -8,9 +8,13 @@ from canals.pipeline import (
load_pipelines as load_canals_pipelines,
save_pipelines as save_canals_pipelines,
)
from canals.pipeline.sockets import find_input_sockets
from haystack.preview.document_stores.protocols import Store
from haystack.preview.document_stores.mixins import StoreAwareMixin
class NotAStoreError(PipelineError):
pass
class NoSuchStoreError(PipelineError):
@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline):
def __init__(self):
super().__init__()
self.stores: Dict[str, Store] = {}
self._stores: Dict[str, Store] = {}
def add_store(self, name: str, store: Store) -> None:
"""
@ -34,7 +38,12 @@ class Pipeline(CanalsPipeline):
:param store: the store object.
:returns: None
"""
self.stores[name] = store
if not isinstance(store, Store):
raise NotAStoreError(
f"This object ({store}) does not respect the Store Protocol, "
"so it can't be added to the pipeline with Pipeline.add_store()."
)
self._stores[name] = store
def list_stores(self) -> List[str]:
"""
@ -42,7 +51,7 @@ class Pipeline(CanalsPipeline):
:returns: a dictionary with all the stores attached to this Pipeline.
"""
return list(self.stores.keys())
return list(self._stores.keys())
def get_store(self, name: str) -> Store:
"""
@ -52,33 +61,49 @@ class Pipeline(CanalsPipeline):
:returns: the store
"""
try:
return self.stores[name]
return self._stores[name]
except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None:
"""
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
Make this component available to the pipeline. Components are not connected to anything by default:
use `Pipeline.connect()` to connect components together.
:params data: the inputs to give to the input components of the Pipeline.
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
:params debug: whether to collect and return debug information.
:returns A dictionary with the outputs of the output components of the Pipeline.
Component names must be unique, but component instances can be reused if needed.
If `store` has a value, the pipeline will also connect this component to the requested document store.
Note that only components that inherit from StoreAwareMixin can be connected to stores.
:param name: the name of the component.
:param instance: the component instance.
:param store: the store this component needs access to, if any.
:raises ValueError: if:
- a component with the same name already exists
- a component requiring a store didn't receive it
- a component that didn't expect a store received it
:raises PipelineValidationError: if the given instance is not a component
:raises NoSuchStoreError: if the given store name is not known to the pipeline
"""
# Get all nodes in this pipelines instance
for node_name in self.graph.nodes:
# Get node inputs
node = self.graph.nodes[node_name]["instance"]
input_params = find_input_sockets(node)
if isinstance(instance, StoreAwareMixin):
if not store:
raise ValueError(f"Component '{name}' needs a store.")
# If the node needs a store, adds the list of stores to its default inputs
if "stores" in input_params:
if not hasattr(node, "defaults"):
setattr(node, "defaults", {})
node.defaults["stores"] = self.stores
if store not in self._stores:
raise NoSuchStoreError(
f"Store named '{store}' not found. "
f"Add it with 'pipeline.add_store('{store}', <the docstore instance>)'."
)
# Run the pipeline
return super().run(data=data, debug=debug)
if instance.store:
raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.")
instance.store = self._stores[store]
elif store:
raise ValueError(f"Component '{name}' doesn't support stores.")
super().add_component(name, instance)
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):

View File

@ -79,7 +79,7 @@ dependencies = [
"jsonschema",
# Preview
"canals>=0.3,<0.4",
"canals==0.3.2",
# Agent events
"events",

View File

@ -1,14 +1,16 @@
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
import pytest
from haystack.preview import Pipeline
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import Store, MemoryDocumentStore
from test.preview.components.base import BaseTestComponent
from haystack.preview.document_stores.protocols import DuplicatePolicy
@pytest.fixture()
def mock_docs():
@ -21,43 +23,39 @@ def mock_docs():
]
class Test_MemoryRetriever(BaseTestComponent):
class TestMemoryRetriever(BaseTestComponent):
@pytest.mark.unit
def test_save_load(self, tmp_path):
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(document_store_name="memory"), tmp_path)
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(), tmp_path)
@pytest.mark.unit
def test_save_load_with_parameters(self, tmp_path):
self.assert_can_be_saved_and_loaded_in_pipeline(
MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False), tmp_path
)
self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path)
@pytest.mark.unit
def test_init_default(self):
retriever = MemoryRetriever(document_store_name="memory")
assert retriever.document_store_name == "memory"
retriever = MemoryRetriever()
assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True}
@pytest.mark.unit
def test_init_with_parameters(self):
retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False)
assert retriever.document_store_name == "memory-test"
retriever = MemoryRetriever(top_k=5, scale_score=False)
assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False}
@pytest.mark.unit
def test_init_with_invalid_top_k_parameter(self):
with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False)
MemoryRetriever(top_k=-2, scale_score=False)
@pytest.mark.unit
def test_valid_run(self, mock_docs):
top_k = 5
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
mr = MemoryRetriever(document_store_name="memory", top_k=top_k)
result: MemoryRetriever.Output = mr.run(
data=MemoryRetriever.Input(queries=["PHP", "Java"], stores={"memory": ds})
)
retriever = MemoryRetriever(top_k=top_k)
retriever.store = ds
result = retriever.run(data=retriever.input(queries=["PHP", "Java"]))
assert getattr(result, "documents")
assert len(result.documents) == 2
@ -67,26 +65,42 @@ class Test_MemoryRetriever(BaseTestComponent):
assert result.documents[1][0].content == "Java is a popular programming language"
@pytest.mark.unit
def test_invalid_run_wrong_store_name(self):
# Test invalid run with wrong store name
ds = MemoryDocumentStore()
mr = MemoryRetriever(document_store_name="memory")
with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"):
invalid_input_data = MemoryRetriever.Input(
queries=["test"], top_k=10, scale_score=True, stores={"invalid_store": ds}
)
mr.run(invalid_input_data)
def test_invalid_run_no_store(self):
retriever = MemoryRetriever()
with pytest.raises(
ValueError, match="MemoryRetriever needs a store to run: set the store instance to the self.store attribute"
):
retriever.run(retriever.input(queries=["test"]))
@pytest.mark.unit
def test_invalid_run_not_a_store(self):
class MockStore:
...
retriever = MemoryRetriever()
with pytest.raises(ValueError, match="does not respect the Store Protocol"):
retriever.store = MockStore()
@pytest.mark.unit
def test_invalid_run_wrong_store_type(self):
# Test invalid run with wrong store type
ds = MemoryDocumentStore()
mr = MemoryRetriever(document_store_name="memory")
with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."):
invalid_input_data = MemoryRetriever.Input(
queries=["test"], top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"}
)
mr.run(invalid_input_data)
class MockStore:
def count_documents(self) -> int:
return 0
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
return []
def write_documents(
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL
) -> None:
return None
def delete_documents(self, document_ids: List[str]) -> None:
return None
retriever = MemoryRetriever()
with pytest.raises(ValueError, match="is not compatible with this component"):
retriever.store = MockStore()
@pytest.mark.integration
@pytest.mark.parametrize(
@ -99,12 +113,12 @@ class Test_MemoryRetriever(BaseTestComponent):
def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
mr = MemoryRetriever(document_store_name="memory")
retriever = MemoryRetriever()
pipeline = Pipeline()
pipeline.add_component("retriever", mr)
pipeline.add_store("memory", ds)
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query])})
pipeline.add_component("retriever", retriever, store="memory")
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query])})
assert result
assert "retriever" in result
@ -124,12 +138,12 @@ class Test_MemoryRetriever(BaseTestComponent):
def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
ds = MemoryDocumentStore()
ds.write_documents(mock_docs)
mr = MemoryRetriever(document_store_name="memory")
retriever = MemoryRetriever()
pipeline = Pipeline()
pipeline.add_component("retriever", mr)
pipeline.add_store("memory", ds)
result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query], top_k=top_k)})
pipeline.add_component("retriever", retriever, store="memory")
result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query], top_k=top_k)})
assert result
assert "retriever" in result

View File

@ -1,16 +1,49 @@
from typing import Dict, Any
from typing import Any, Optional, Dict, List
import pytest
from haystack.preview import Pipeline, component, NoSuchStoreError
from haystack.preview import Pipeline, component, NoSuchStoreError, Document
from haystack.preview.pipeline import NotAStoreError
from haystack.preview.document_stores import StoreAwareMixin, DuplicatePolicy, Store
# Note: we're using a real class instead of a mock because mocks don't play too well with protocols.
class MockStore:
...
def count_documents(self) -> int:
return 0
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
return []
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
return None
def delete_documents(self, document_ids: List[str]) -> None:
return None
@pytest.mark.unit
def test_pipeline_store_api():
def test_add_store():
store_1 = MockStore()
store_2 = MockStore()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
assert pipe._stores.get("first_store") == store_1
assert pipe._stores.get("second_store") == store_2
@pytest.mark.unit
def test_add_store_wrong_object():
pipe = Pipeline()
with pytest.raises(NotAStoreError, match="does not respect the Store Protocol"):
pipe.add_store(name="store", store="I'm surely not a Store object!")
@pytest.mark.unit
def test_list_stores():
store_1 = MockStore()
store_2 = MockStore()
pipe = Pipeline()
@ -20,22 +53,81 @@ def test_pipeline_store_api():
assert pipe.list_stores() == ["first_store", "second_store"]
@pytest.mark.unit
def test_get_store():
store_1 = MockStore()
store_2 = MockStore()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
assert pipe.get_store("first_store") == store_1
assert pipe.get_store("second_store") == store_2
@pytest.mark.unit
def test_get_store_wrong_name():
store_1 = MockStore()
pipe = Pipeline()
with pytest.raises(NoSuchStoreError):
pipe.get_store("first_store")
pipe.add_store(name="first_store", store=store_1)
assert pipe.get_store("first_store") == store_1
with pytest.raises(NoSuchStoreError):
pipe.get_store("third_store")
@pytest.mark.unit
def test_pipeline_stores_in_params():
def test_add_component_store_aware_component_receives_one_docstore():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent:
class MockComponent(StoreAwareMixin):
supported_stores = [Store]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
mock = MockComponent()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
pipe.add_component("component", mock, store="first_store")
assert mock.store == store_1
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
@pytest.mark.unit
def test_add_component_store_aware_component_receives_no_docstore():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent(StoreAwareMixin):
supported_stores = [Store]
class Input:
value: int
stores: Dict[str, Any]
class Output:
value: int
@ -49,13 +141,274 @@ def test_pipeline_stores_in_params():
return MockComponent.Output
def run(self, data: Input) -> Output:
assert data.stores == {"first_store": store_1, "second_store": store_2}
return MockComponent.Output(value=data.value)
pipe = Pipeline()
pipe.add_component("component", MockComponent())
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}
with pytest.raises(ValueError, match="Component 'component' needs a store."):
pipe.add_component("component", MockComponent())
@pytest.mark.unit
def test_non_store_aware_component_receives_one_docstore():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent:
supported_stores = [Store]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
with pytest.raises(ValueError, match="Component 'component' doesn't support stores."):
pipe.add_component("component", MockComponent(), store="first_store")
@pytest.mark.unit
def test_add_component_store_aware_component_receives_wrong_docstore_name():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent(StoreAwareMixin):
supported_stores = [Store]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
with pytest.raises(NoSuchStoreError, match="Store named 'wrong_store' not found."):
pipe.add_component("component", MockComponent(), store="wrong_store")
@pytest.mark.unit
def test_add_component_store_aware_component_receives_correct_docstore_type():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent(StoreAwareMixin):
supported_stores = [MockStore]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
mock = MockComponent()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
pipe.add_component("component", mock, store="second_store")
assert mock.store == store_2
@pytest.mark.unit
def test_add_component_store_aware_component_is_reused():
store_1 = MockStore()
store_2 = MockStore()
@component
class MockComponent(StoreAwareMixin):
supported_stores = [MockStore]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
mock = MockComponent()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
pipe.add_component("component", mock, store="second_store")
with pytest.raises(ValueError, match="Reusing components with stores is not supported"):
pipe.add_component("component2", mock, store="second_store")
with pytest.raises(ValueError, match="Reusing components with stores is not supported"):
pipe.add_component("component2", mock, store="first_store")
assert mock.store == store_2
@pytest.mark.unit
def test_add_component_store_aware_component_receives_subclass_of_correct_docstore_type():
class MockStoreSubclass(MockStore):
...
store_1 = MockStoreSubclass()
store_2 = MockStore()
@component
class MockComponent(StoreAwareMixin):
supported_stores = [MockStore]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
mock = MockComponent()
mock2 = MockComponent()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
pipe.add_component("component", mock, store="first_store")
assert mock.store == store_1
pipe.add_component("component2", mock2, store="second_store")
assert mock2.store == store_2
@pytest.mark.unit
def test_add_component_store_aware_component_does_not_check_supported_stores():
class SomethingElse:
...
@component
class MockComponent(StoreAwareMixin):
supported_stores = [SomethingElse]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
MockComponent()
@pytest.mark.unit
def test_add_component_store_aware_component_receives_wrong_docstore_type():
store_1 = MockStore()
store_2 = MockStore()
class MockStore2:
def count_documents(self) -> int:
return 0
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
return []
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
return None
def delete_documents(self, document_ids: List[str]) -> None:
return None
@component
class MockComponent(StoreAwareMixin):
supported_stores = [MockStore2]
class Input:
value: int
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
return MockComponent.Output(value=data.value)
mock = MockComponent()
pipe = Pipeline()
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
with pytest.raises(ValueError, match="is not compatible with this component"):
pipe.add_component("component", mock, store="second_store")