mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 14:38:36 +00:00
chore: rearrange the core package, move tests and clean up (#6427)
* rearrange code * fix tests * relnote * merge test modules * remove extra * rearrange draw tests * forgot * remove unused import
This commit is contained in:
parent
9a7fd6f2ce
commit
00e1dd6eb8
4
.github/workflows/license_compliance.yml
vendored
4
.github/workflows/license_compliance.yml
vendored
@ -27,10 +27,10 @@ jobs:
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Get direct dependencies, all extras
|
||||
- name: Get direct dependencies
|
||||
run: |
|
||||
pip install toml
|
||||
python .github/utils/pyproject_to_requirements.py pyproject.toml --extra all > ${{ env.REQUIREMENTS_FILE }}
|
||||
python .github/utils/pyproject_to_requirements.py pyproject.toml > ${{ env.REQUIREMENTS_FILE }}
|
||||
|
||||
- name: Check Licenses
|
||||
id: license_check_report
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -156,3 +156,6 @@ haystack/json-schemas
|
||||
|
||||
# http cache (requests-cache)
|
||||
**/http_cache.sqlite
|
||||
|
||||
# ruff
|
||||
.ruff_cache
|
||||
|
||||
@ -14,7 +14,7 @@ from pydantic import BaseModel, ValidationError
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger("canals.pipeline.pipeline").setLevel(logging.DEBUG)
|
||||
logging.getLogger("haystack.core.pipeline.pipeline").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
# Let's define a simple schema for the data we want to extract from a passsage via the LLM
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from canals import component
|
||||
from canals.serialization import default_from_dict, default_to_dict
|
||||
from canals.errors import DeserializationError, ComponentError
|
||||
from haystack.core.component import component
|
||||
from haystack.core.serialization import default_from_dict, default_to_dict
|
||||
from haystack.core.errors import DeserializationError, ComponentError
|
||||
from haystack.pipeline import Pipeline
|
||||
from haystack.dataclasses import Document, Answer, GeneratedAnswer, ExtractedAnswer
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from math import inf
|
||||
from typing import List, Optional
|
||||
from canals.component.types import Variadic
|
||||
from haystack.core.component.types import Variadic
|
||||
|
||||
from haystack import component, Document
|
||||
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals.__about__ import __version__
|
||||
|
||||
from canals.component import component, Component
|
||||
from canals.pipeline.pipeline import Pipeline
|
||||
|
||||
__all__ = ["component", "Component", "Pipeline"]
|
||||
@ -1,11 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals.pipeline.pipeline import Pipeline
|
||||
from canals.errors import (
|
||||
PipelineError,
|
||||
PipelineRuntimeError,
|
||||
PipelineValidationError,
|
||||
PipelineConnectError,
|
||||
PipelineMaxLoops,
|
||||
)
|
||||
@ -1,116 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from canals import component, Component
|
||||
from canals.serialization import default_to_dict, default_from_dict
|
||||
|
||||
|
||||
def component_class(
|
||||
name: str,
|
||||
input_types: Optional[Dict[str, Any]] = None,
|
||||
output_types: Optional[Dict[str, Any]] = None,
|
||||
output: Optional[Dict[str, Any]] = None,
|
||||
bases: Optional[Tuple[type, ...]] = None,
|
||||
extra_fields: Optional[Dict[str, Any]] = None,
|
||||
) -> Type[Component]:
|
||||
"""
|
||||
Utility class to create a Component class with the given name and input and output types.
|
||||
|
||||
If `output` is set but `output_types` is not, `output_types` will be set to the types of the values in `output`.
|
||||
Though if `output_types` is set but `output` is not the component's `run` method will return a dictionary
|
||||
of the same keys as `output_types` all with a value of None.
|
||||
|
||||
### Usage
|
||||
|
||||
Create a component class with default input and output types:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory("MyFakeComponent")
|
||||
component = MyFakeComponent()
|
||||
output = component.run(value=1)
|
||||
assert output == {"value": None}
|
||||
```
|
||||
|
||||
Create a component class with an "value" input of type `int` and with a "value" output of `10`:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
input_types={"value": int},
|
||||
output={"value": 10}
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
output = component.run(value=1)
|
||||
assert output == {"value": 10}
|
||||
```
|
||||
|
||||
Create a component class with a custom base class:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
bases=(MyBaseClass,)
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
assert isinstance(component, MyBaseClass)
|
||||
```
|
||||
|
||||
Create a component class with an extra field `my_field`:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
extra_fields={"my_field": 10}
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
assert component.my_field == 10
|
||||
```
|
||||
|
||||
Args:
|
||||
name: Name of the component class
|
||||
input_types: Dictionary of string and type that defines the inputs of the component,
|
||||
if set to None created component will expect a single input "value" of Any type.
|
||||
Defaults to None.
|
||||
output_types: Dictionary of string and type that defines the outputs of the component,
|
||||
if set to None created component will return a single output "value" of NoneType and None value.
|
||||
Defaults to None.
|
||||
output: Actual output dictionary returned by the created component run,
|
||||
is set to None it will return a dictionary of string and None values.
|
||||
Keys will be the same as the keys of output_types. Defaults to None.
|
||||
bases: Base classes for this component, if set to None only base is object. Defaults to None.
|
||||
extra_fields: Extra fields for the Component, defaults to None.
|
||||
|
||||
:return: A class definition that can be used as a component.
|
||||
"""
|
||||
if input_types is None:
|
||||
input_types = {"value": Any}
|
||||
if output_types is None and output is not None:
|
||||
output_types = {key: type(value) for key, value in output.items()}
|
||||
elif output_types is None:
|
||||
output_types = {"value": type(None)}
|
||||
|
||||
def init(self):
|
||||
component.set_input_types(self, **input_types)
|
||||
component.set_output_types(self, **output_types)
|
||||
|
||||
# Both arguments are necessary to correctly define
|
||||
# run but pylint doesn't like that we don't use them.
|
||||
# It's fine ignoring the warning here.
|
||||
def run(self, **kwargs): # pylint: disable=unused-argument
|
||||
if output is not None:
|
||||
return output
|
||||
return {name: None for name in output_types.keys()}
|
||||
|
||||
def to_dict(self):
|
||||
return default_to_dict(self)
|
||||
|
||||
def from_dict(cls, data: Dict[str, Any]):
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
fields = {"__init__": init, "run": run, "to_dict": to_dict, "from_dict": classmethod(from_dict)}
|
||||
if extra_fields is not None:
|
||||
fields = {**fields, **extra_fields}
|
||||
|
||||
if bases is None:
|
||||
bases = (object,)
|
||||
|
||||
cls = type(name, bases, fields)
|
||||
return component(cls)
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals.component.component import component, Component
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.component.component import component, Component
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
|
||||
__all__ = ["component", "Component", "InputSocket", "OutputSocket"]
|
||||
@ -35,7 +35,7 @@
|
||||
|
||||
_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
|
||||
|
||||
The `__init__` must be extrememly lightweight, because it's a frequent operation during the construction and
|
||||
The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
|
||||
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
|
||||
the `warm_up()` method.
|
||||
|
||||
@ -74,8 +74,8 @@ from typing import Protocol, runtime_checkable, Any
|
||||
from types import new_class
|
||||
from copy import deepcopy
|
||||
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from canals.errors import ComponentError
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.errors import ComponentError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -2,9 +2,9 @@ import itertools
|
||||
from typing import Optional, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from canals.type_utils import _type_name, _types_are_compatible
|
||||
from canals.errors import PipelineConnectError
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.type_utils import _type_name, _types_are_compatible
|
||||
from haystack.core.errors import PipelineConnectError
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -5,7 +5,7 @@ from typing import get_args, List, Type
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from canals.component.types import CANALS_VARIADIC_ANNOTATION
|
||||
from haystack.core.component.types import CANALS_VARIADIC_ANNOTATION
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1,4 +1,6 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals.pipeline.draw.draw import _draw, _convert, _convert_for_debug, RenderingEngines
|
||||
from haystack.core.pipeline.pipeline import Pipeline
|
||||
|
||||
__all__ = ["Pipeline"]
|
||||
@ -6,8 +6,8 @@ import logging
|
||||
|
||||
import networkx # type:ignore
|
||||
|
||||
from canals.type_utils import _type_name
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.type_utils import _type_name
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -8,10 +8,10 @@ from pathlib import Path
|
||||
|
||||
import networkx # type:ignore
|
||||
|
||||
from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||
from canals.pipeline.draw.graphviz import _to_agraph
|
||||
from canals.pipeline.draw.mermaid import _to_mermaid_image, _to_mermaid_text
|
||||
from canals.type_utils import _type_name
|
||||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||
from haystack.core.pipeline.draw.graphviz import _to_agraph
|
||||
from haystack.core.pipeline.draw.mermaid import _to_mermaid_image, _to_mermaid_text
|
||||
from haystack.core.type_utils import _type_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"]
|
||||
@ -7,8 +7,8 @@ import base64
|
||||
import requests
|
||||
import networkx # type:ignore
|
||||
|
||||
from canals.errors import PipelineDrawingError
|
||||
from canals.type_utils import _type_name
|
||||
from haystack.core.errors import PipelineDrawingError
|
||||
from haystack.core.type_utils import _type_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,20 +14,20 @@ from collections import defaultdict
|
||||
|
||||
import networkx # type:ignore
|
||||
|
||||
from canals.component import component, Component, InputSocket, OutputSocket
|
||||
from canals.errors import (
|
||||
from haystack.core.component import component, Component, InputSocket, OutputSocket
|
||||
from haystack.core.errors import (
|
||||
PipelineError,
|
||||
PipelineConnectError,
|
||||
PipelineMaxLoops,
|
||||
PipelineRuntimeError,
|
||||
PipelineValidationError,
|
||||
)
|
||||
from canals.pipeline.descriptions import find_pipeline_outputs
|
||||
from canals.pipeline.draw import _draw, _convert_for_debug, RenderingEngines
|
||||
from canals.pipeline.validation import validate_pipeline_input, find_pipeline_inputs
|
||||
from canals.component.connection import Connection, parse_connect_string
|
||||
from canals.type_utils import _type_name
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.core.pipeline.descriptions import find_pipeline_outputs
|
||||
from haystack.core.pipeline.draw.draw import _draw, _convert_for_debug, RenderingEngines
|
||||
from haystack.core.pipeline.validation import validate_pipeline_input, find_pipeline_inputs
|
||||
from haystack.core.component.connection import Connection, parse_connect_string
|
||||
from haystack.core.type_utils import _type_name
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -388,7 +388,7 @@ class Pipeline:
|
||||
def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
|
||||
"""
|
||||
Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
|
||||
Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.
|
||||
Run `pip install graphviz` or `pip install mermaid` to install missing dependencies.
|
||||
|
||||
Args:
|
||||
path: where to save the diagram.
|
||||
@ -6,9 +6,9 @@ import logging
|
||||
|
||||
import networkx # type:ignore
|
||||
|
||||
from canals.errors import PipelineValidationError
|
||||
from canals.component.sockets import InputSocket
|
||||
from canals.pipeline.descriptions import find_pipeline_inputs, describe_pipeline_inputs_as_string
|
||||
from haystack.core.errors import PipelineValidationError
|
||||
from haystack.core.component.sockets import InputSocket
|
||||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, describe_pipeline_inputs_as_string
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1,42 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components.concatenate import Concatenate
|
||||
from sample_components.subtract import Subtract
|
||||
from sample_components.parity import Parity
|
||||
from sample_components.remainder import Remainder
|
||||
from sample_components.accumulate import Accumulate
|
||||
from sample_components.threshold import Threshold
|
||||
from sample_components.add_value import AddFixedValue
|
||||
from sample_components.repeat import Repeat
|
||||
from sample_components.sum import Sum
|
||||
from sample_components.greet import Greet
|
||||
from sample_components.double import Double
|
||||
from sample_components.joiner import StringJoiner, StringListJoiner, FirstIntSelector
|
||||
from sample_components.hello import Hello
|
||||
from sample_components.text_splitter import TextSplitter
|
||||
from sample_components.merge_loop import MergeLoop
|
||||
from sample_components.self_loop import SelfLoop
|
||||
from sample_components.fstring import FString
|
||||
|
||||
__all__ = [
|
||||
"Concatenate",
|
||||
"Subtract",
|
||||
"Parity",
|
||||
"Remainder",
|
||||
"Accumulate",
|
||||
"Threshold",
|
||||
"AddFixedValue",
|
||||
"MergeLoop",
|
||||
"Repeat",
|
||||
"Sum",
|
||||
"Greet",
|
||||
"Double",
|
||||
"StringJoiner",
|
||||
"Hello",
|
||||
"TextSplitter",
|
||||
"StringListJoiner",
|
||||
"FirstIntSelector",
|
||||
"SelfLoop",
|
||||
"FString",
|
||||
]
|
||||
@ -4,7 +4,7 @@
|
||||
import inspect
|
||||
from typing import Type, Dict, Any
|
||||
|
||||
from canals.errors import DeserializationError, SerializationError
|
||||
from haystack.core.errors import DeserializationError, SerializationError
|
||||
|
||||
|
||||
def component_to_dict(obj: Any) -> Dict[str, Any]:
|
||||
@ -1,13 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from test.sample_components.test_accumulate import Accumulate
|
||||
from test.sample_components.test_add_value import AddFixedValue
|
||||
from test.sample_components.test_double import Double
|
||||
from test.sample_components.test_parity import Parity
|
||||
from test.sample_components.test_greet import Greet
|
||||
from test.sample_components.test_remainder import Remainder
|
||||
from test.sample_components.test_repeat import Repeat
|
||||
from test.sample_components.test_subtract import Subtract
|
||||
from test.sample_components.test_sum import Sum
|
||||
from test.sample_components.test_threshold import Threshold
|
||||
@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from canals import component
|
||||
from canals.errors import ComponentError
|
||||
from canals.testing import factory
|
||||
|
||||
|
||||
def test_component_class_default():
|
||||
MyComponent = factory.component_class("MyComponent")
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(value="something")
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(non_existing_input=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_is_registered():
|
||||
MyComponent = factory.component_class("MyComponent")
|
||||
assert component.registry["canals.testing.factory.MyComponent"] == MyComponent
|
||||
|
||||
|
||||
def test_component_class_with_input_types():
|
||||
MyComponent = factory.component_class("MyComponent", input_types={"value": int})
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(value="something")
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_with_output_types():
|
||||
MyComponent = factory.component_class("MyComponent", output_types={"value": int})
|
||||
comp = MyComponent()
|
||||
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_with_output():
|
||||
MyComponent = factory.component_class("MyComponent", output={"value": 100})
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": 100}
|
||||
|
||||
|
||||
def test_component_class_with_output_and_output_types():
|
||||
MyComponent = factory.component_class("MyComponent", output_types={"value": str}, output={"value": 100})
|
||||
comp = MyComponent()
|
||||
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": 100}
|
||||
|
||||
|
||||
def test_component_class_with_bases():
|
||||
MyComponent = factory.component_class("MyComponent", bases=(Exception,))
|
||||
comp = MyComponent()
|
||||
assert isinstance(comp, Exception)
|
||||
|
||||
|
||||
def test_component_class_with_extra_fields():
|
||||
MyComponent = factory.component_class("MyComponent", extra_fields={"my_field": 10})
|
||||
comp = MyComponent()
|
||||
assert comp.my_field == 10
|
||||
@ -2,8 +2,8 @@ from typing import Any, Dict, Optional, Union, TextIO
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
import logging
|
||||
import canals
|
||||
|
||||
from haystack.core.pipeline import Pipeline as _pipeline
|
||||
from haystack.telemetry import pipeline_running
|
||||
from haystack.marshal import Marshaller, YamlMarshaller
|
||||
|
||||
@ -12,7 +12,7 @@ DEFAULT_MARSHALLER = YamlMarshaller()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pipeline(canals.Pipeline):
|
||||
class Pipeline(_pipeline):
|
||||
def __init__(
|
||||
self,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Type, List, Union
|
||||
|
||||
from haystack import default_to_dict, default_from_dict
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.document_stores import document_store, DocumentStore, DuplicatePolicy
|
||||
from haystack.core.component import component, Component
|
||||
from haystack.core.serialization import default_to_dict, default_from_dict
|
||||
|
||||
|
||||
def document_store_class(
|
||||
@ -117,3 +118,112 @@ def document_store_class(
|
||||
|
||||
cls = type(name, bases, fields)
|
||||
return document_store(cls)
|
||||
|
||||
|
||||
def component_class(
|
||||
name: str,
|
||||
input_types: Optional[Dict[str, Any]] = None,
|
||||
output_types: Optional[Dict[str, Any]] = None,
|
||||
output: Optional[Dict[str, Any]] = None,
|
||||
bases: Optional[Tuple[type, ...]] = None,
|
||||
extra_fields: Optional[Dict[str, Any]] = None,
|
||||
) -> Type[Component]:
|
||||
"""
|
||||
Utility class to create a Component class with the given name and input and output types.
|
||||
|
||||
If `output` is set but `output_types` is not, `output_types` will be set to the types of the values in `output`.
|
||||
Though if `output_types` is set but `output` is not the component's `run` method will return a dictionary
|
||||
of the same keys as `output_types` all with a value of None.
|
||||
|
||||
### Usage
|
||||
|
||||
Create a component class with default input and output types:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory("MyFakeComponent")
|
||||
component = MyFakeComponent()
|
||||
output = component.run(value=1)
|
||||
assert output == {"value": None}
|
||||
```
|
||||
|
||||
Create a component class with an "value" input of type `int` and with a "value" output of `10`:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
input_types={"value": int},
|
||||
output={"value": 10}
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
output = component.run(value=1)
|
||||
assert output == {"value": 10}
|
||||
```
|
||||
|
||||
Create a component class with a custom base class:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
bases=(MyBaseClass,)
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
assert isinstance(component, MyBaseClass)
|
||||
```
|
||||
|
||||
Create a component class with an extra field `my_field`:
|
||||
```python
|
||||
MyFakeComponent = component_class_factory(
|
||||
"MyFakeComponent",
|
||||
extra_fields={"my_field": 10}
|
||||
)
|
||||
component = MyFakeComponent()
|
||||
assert component.my_field == 10
|
||||
```
|
||||
|
||||
Args:
|
||||
name: Name of the component class
|
||||
input_types: Dictionary of string and type that defines the inputs of the component,
|
||||
if set to None created component will expect a single input "value" of Any type.
|
||||
Defaults to None.
|
||||
output_types: Dictionary of string and type that defines the outputs of the component,
|
||||
if set to None created component will return a single output "value" of NoneType and None value.
|
||||
Defaults to None.
|
||||
output: Actual output dictionary returned by the created component run,
|
||||
is set to None it will return a dictionary of string and None values.
|
||||
Keys will be the same as the keys of output_types. Defaults to None.
|
||||
bases: Base classes for this component, if set to None only base is object. Defaults to None.
|
||||
extra_fields: Extra fields for the Component, defaults to None.
|
||||
|
||||
:return: A class definition that can be used as a component.
|
||||
"""
|
||||
if input_types is None:
|
||||
input_types = {"value": Any}
|
||||
if output_types is None and output is not None:
|
||||
output_types = {key: type(value) for key, value in output.items()}
|
||||
elif output_types is None:
|
||||
output_types = {"value": type(None)}
|
||||
|
||||
def init(self):
|
||||
component.set_input_types(self, **input_types)
|
||||
component.set_output_types(self, **output_types)
|
||||
|
||||
# Both arguments are necessary to correctly define
|
||||
# run but pylint doesn't like that we don't use them.
|
||||
# It's fine ignoring the warning here.
|
||||
def run(self, **kwargs): # pylint: disable=unused-argument
|
||||
if output is not None:
|
||||
return output
|
||||
return {name: None for name in output_types.keys()}
|
||||
|
||||
def to_dict(self):
|
||||
return default_to_dict(self)
|
||||
|
||||
def from_dict(cls, data: Dict[str, Any]):
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
fields = {"__init__": init, "run": run, "to_dict": to_dict, "from_dict": classmethod(from_dict)}
|
||||
if extra_fields is not None:
|
||||
fields = {**fields, **extra_fields}
|
||||
|
||||
if bases is None:
|
||||
bases = (object,)
|
||||
|
||||
cls = type(name, bases, fields)
|
||||
return component(cls)
|
||||
|
||||
42
haystack/testing/sample_components/__init__.py
Normal file
42
haystack/testing/sample_components/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from haystack.testing.sample_components.concatenate import Concatenate
|
||||
from haystack.testing.sample_components.subtract import Subtract
|
||||
from haystack.testing.sample_components.parity import Parity
|
||||
from haystack.testing.sample_components.remainder import Remainder
|
||||
from haystack.testing.sample_components.accumulate import Accumulate
|
||||
from haystack.testing.sample_components.threshold import Threshold
|
||||
from haystack.testing.sample_components.add_value import AddFixedValue
|
||||
from haystack.testing.sample_components.repeat import Repeat
|
||||
from haystack.testing.sample_components.sum import Sum
|
||||
from haystack.testing.sample_components.greet import Greet
|
||||
from haystack.testing.sample_components.double import Double
|
||||
from haystack.testing.sample_components.joiner import StringJoiner, StringListJoiner, FirstIntSelector
|
||||
from haystack.testing.sample_components.hello import Hello
|
||||
from haystack.testing.sample_components.text_splitter import TextSplitter
|
||||
from haystack.testing.sample_components.merge_loop import MergeLoop
|
||||
from haystack.testing.sample_components.self_loop import SelfLoop
|
||||
from haystack.testing.sample_components.fstring import FString
|
||||
|
||||
__all__ = [
|
||||
"Concatenate",
|
||||
"Subtract",
|
||||
"Parity",
|
||||
"Remainder",
|
||||
"Accumulate",
|
||||
"Threshold",
|
||||
"AddFixedValue",
|
||||
"MergeLoop",
|
||||
"Repeat",
|
||||
"Sum",
|
||||
"Greet",
|
||||
"Double",
|
||||
"StringJoiner",
|
||||
"Hello",
|
||||
"TextSplitter",
|
||||
"StringListJoiner",
|
||||
"FirstIntSelector",
|
||||
"SelfLoop",
|
||||
"FString",
|
||||
]
|
||||
@ -6,9 +6,9 @@ import sys
|
||||
import builtins
|
||||
from importlib import import_module
|
||||
|
||||
from canals.serialization import default_to_dict
|
||||
from canals.component import component
|
||||
from canals.errors import ComponentDeserializationError
|
||||
from haystack.core.serialization import default_to_dict
|
||||
from haystack.core.component import component
|
||||
from haystack.core.errors import ComponentDeserializationError
|
||||
|
||||
|
||||
def _default_function(first: int, second: int) -> int:
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Union, List
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -4,7 +4,7 @@
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -3,8 +3,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
from canals import component
|
||||
from canals.component.types import Variadic
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.types import Variadic
|
||||
|
||||
|
||||
@component
|
||||
@ -4,9 +4,9 @@
|
||||
from typing import List, Any, Optional, Dict
|
||||
import sys
|
||||
|
||||
from canals import component
|
||||
from canals.errors import DeserializationError
|
||||
from canals.serialization import default_to_dict
|
||||
from haystack.core.component import component
|
||||
from haystack.core.errors import DeserializationError
|
||||
from haystack.core.serialization import default_to_dict
|
||||
|
||||
|
||||
@component
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -1,5 +1,5 @@
|
||||
from canals import component
|
||||
from canals.component.types import Variadic
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.types import Variadic
|
||||
|
||||
|
||||
@component
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from canals import component
|
||||
from canals.component.types import Variadic
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.types import Variadic
|
||||
|
||||
|
||||
@component
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
from canals import component
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@component
|
||||
@ -46,8 +46,6 @@ classifiers = [
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
dependencies = [
|
||||
"canals==0.11.0",
|
||||
"requests", # needed by canals
|
||||
"pandas",
|
||||
"rank_bm25",
|
||||
"tqdm",
|
||||
@ -58,6 +56,8 @@ dependencies = [
|
||||
"posthog", # telemetry
|
||||
"pyyaml",
|
||||
"more-itertools", # TextDocumentSplitter
|
||||
"networkx", # Pipeline graphs
|
||||
"typing_extensions", # typing support for Python 3.8
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
Any import from `canals` should be rewritten to import from `haystack.core`
|
||||
enhancements:
|
||||
- |
|
||||
Use the code formerly in `canals` from the `haystack.core` package across the whole codebase.
|
||||
@ -3,10 +3,10 @@ from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from canals import component
|
||||
from canals.component.descriptions import find_component_inputs, find_component_outputs
|
||||
from canals.errors import ComponentError
|
||||
from canals.component import InputSocket, OutputSocket, Component
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.descriptions import find_component_inputs, find_component_outputs
|
||||
from haystack.core.errors import ComponentError
|
||||
from haystack.core.component import InputSocket, OutputSocket, Component
|
||||
|
||||
|
||||
def test_correct_declaration():
|
||||
@ -82,7 +82,7 @@ def test_correct_declaration_with_additional_writable_property():
|
||||
|
||||
|
||||
def test_missing_run():
|
||||
with pytest.raises(ComponentError, match="must have a 'run\(\)' method"):
|
||||
with pytest.raises(ComponentError, match=r"must have a 'run\(\)' method"):
|
||||
|
||||
@component
|
||||
class MockComponent:
|
||||
@ -1,6 +1,6 @@
|
||||
from canals.component.connection import Connection
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from canals.errors import PipelineConnectError
|
||||
from haystack.core.component.connection import Connection
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.errors import PipelineConnectError
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,7 +15,7 @@ def mock_mermaid_request(test_files):
|
||||
"""
|
||||
Prevents real requests to https://mermaid.ink/
|
||||
"""
|
||||
with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get:
|
||||
with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read()
|
||||
@ -5,8 +5,8 @@ from pathlib import Path
|
||||
from pprint import pprint
|
||||
import logging
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import (
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import (
|
||||
Accumulate,
|
||||
AddFixedValue,
|
||||
Greet,
|
||||
@ -4,8 +4,9 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals import Pipeline, component
|
||||
from sample_components import AddFixedValue, Sum
|
||||
from haystack.core.component import component
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Sum
|
||||
|
||||
import logging
|
||||
|
||||
@ -20,7 +21,6 @@ class WithDefault:
|
||||
|
||||
|
||||
def test_pipeline(tmp_path):
|
||||
# https://github.com/deepset-ai/canals/issues/105
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("with_defaults", WithDefault())
|
||||
pipeline.draw(tmp_path / "default_value.png")
|
||||
@ -1,12 +1,11 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, MergeLoop, Remainder, FirstIntSelector
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, MergeLoop, Remainder, FirstIntSelector
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
from canals import Pipeline
|
||||
from sample_components import FString, Hello, TextSplitter
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import FString, Hello, TextSplitter
|
||||
|
||||
|
||||
def test_pipeline(tmp_path):
|
||||
@ -5,8 +5,8 @@ import logging
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Parity, Double, Subtract
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Parity, Double, Subtract
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Parity, Double
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Parity, Double
|
||||
|
||||
import logging
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Subtract
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Subtract
|
||||
|
||||
import logging
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import StringJoiner, StringListJoiner, Hello, TextSplitter
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import StringJoiner, StringListJoiner, Hello, TextSplitter
|
||||
|
||||
import logging
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Double
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Double
|
||||
|
||||
import logging
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import Accumulate, AddFixedValue, Threshold, Sum, FirstIntSelector, MergeLoop
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import Accumulate, AddFixedValue, Threshold, Sum, FirstIntSelector, MergeLoop
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop, FirstIntSelector
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop, FirstIntSelector
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
from canals import Pipeline, component
|
||||
from sample_components import StringListJoiner
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.component import component
|
||||
from haystack.testing.sample_components import StringListJoiner
|
||||
|
||||
|
||||
@component
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Repeat, Double
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Repeat, Double
|
||||
|
||||
import logging
|
||||
|
||||
@ -5,9 +5,9 @@ from typing import Optional
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals import component
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, SelfLoop
|
||||
from haystack.core.component import component
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, SelfLoop
|
||||
|
||||
import logging
|
||||
|
||||
@ -5,8 +5,8 @@ import logging
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Remainder, Double, Sum
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Remainder, Double, Sum
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Remainder, Double
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Remainder, Double
|
||||
|
||||
import logging
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from sample_components import AddFixedValue, Sum
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.sample_components import AddFixedValue, Sum
|
||||
|
||||
import logging
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
__version__ = "0.11.0"
|
||||
@ -8,11 +8,11 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from canals import Pipeline
|
||||
from canals.errors import PipelineConnectError
|
||||
from canals.testing import factory
|
||||
from canals.component.connection import parse_connect_string
|
||||
from sample_components import AddFixedValue
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.errors import PipelineConnectError
|
||||
from haystack.testing import factory
|
||||
from haystack.core.component.connection import parse_connect_string
|
||||
from haystack.testing.sample_components import AddFixedValue
|
||||
|
||||
|
||||
class Class1:
|
||||
@ -2,32 +2,19 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
import sys
|
||||
import filecmp
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from canals.pipeline.draw import _draw, _convert
|
||||
from canals.errors import PipelineDrawingError
|
||||
from sample_components import Double, AddFixedValue
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform.lower().startswith("darwin"), reason="the available graphviz version is too recent")
|
||||
@pytest.mark.skipif(sys.platform.lower().startswith("win"), reason="pygraphviz is not really available in Windows")
|
||||
def test_draw_pygraphviz(tmp_path, test_files):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("comp1", Double())
|
||||
pipe.add_component("comp2", Double())
|
||||
pipe.connect("comp1", "comp2")
|
||||
|
||||
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="graphviz")
|
||||
assert os.path.exists(tmp_path / "test_pipe.jpg")
|
||||
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg")
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.pipeline.draw.draw import _draw, _convert
|
||||
from haystack.core.errors import PipelineDrawingError
|
||||
from haystack.testing.sample_components import Double, AddFixedValue
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_draw_mermaid_image(tmp_path, test_files):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("comp1", Double())
|
||||
@ -40,6 +27,7 @@ def test_draw_mermaid_image(tmp_path, test_files):
|
||||
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_draw_mermaid_img_failing_request(tmp_path):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("comp1", Double())
|
||||
@ -47,7 +35,7 @@ def test_draw_mermaid_img_failing_request(tmp_path):
|
||||
pipe.connect("comp1", "comp2")
|
||||
pipe.connect("comp2", "comp1")
|
||||
|
||||
with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get:
|
||||
with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get:
|
||||
|
||||
def raise_for_status(self):
|
||||
raise requests.HTTPError()
|
||||
@ -62,6 +50,7 @@ def test_draw_mermaid_img_failing_request(tmp_path):
|
||||
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_draw_mermaid_text(tmp_path):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("comp1", AddFixedValue(add=3))
|
||||
26
test/core/pipeline/unit/test_draw_graphviz.py
Normal file
26
test/core/pipeline/unit/test_draw_graphviz.py
Normal file
@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
import filecmp
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.pipeline.draw.draw import _draw
|
||||
from haystack.testing.sample_components import Double
|
||||
|
||||
|
||||
pygraphviz = pytest.importorskip("pygraphviz")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_draw_pygraphviz(tmp_path, test_files):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("comp1", Double())
|
||||
pipe.add_component("comp2", Double())
|
||||
pipe.connect("comp1", "comp2")
|
||||
|
||||
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="graphviz")
|
||||
assert os.path.exists(tmp_path / "test_pipe.jpg")
|
||||
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg")
|
||||
@ -6,11 +6,11 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from canals import Pipeline
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from canals.errors import PipelineMaxLoops, PipelineError, PipelineRuntimeError
|
||||
from sample_components import AddFixedValue, Threshold, Double, Sum
|
||||
from canals.testing.factory import component_class
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.errors import PipelineMaxLoops, PipelineError, PipelineRuntimeError
|
||||
from haystack.testing.sample_components import AddFixedValue, Threshold, Double, Sum
|
||||
from haystack.testing.factory import component_class
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
@ -56,9 +56,15 @@ def test_to_dict():
|
||||
"metadata": {"test": "test"},
|
||||
"max_loops_allowed": 42,
|
||||
"components": {
|
||||
"add_two": {"type": "sample_components.add_value.AddFixedValue", "init_parameters": {"add": 2}},
|
||||
"add_default": {"type": "sample_components.add_value.AddFixedValue", "init_parameters": {"add": 1}},
|
||||
"double": {"type": "sample_components.double.Double", "init_parameters": {}},
|
||||
"add_two": {
|
||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||
"init_parameters": {"add": 2},
|
||||
},
|
||||
"add_default": {
|
||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||
"init_parameters": {"add": 1},
|
||||
},
|
||||
"double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}},
|
||||
},
|
||||
"connections": [
|
||||
{"sender": "add_two.result", "receiver": "double.value"},
|
||||
@ -73,9 +79,15 @@ def test_from_dict():
|
||||
"metadata": {"test": "test"},
|
||||
"max_loops_allowed": 101,
|
||||
"components": {
|
||||
"add_two": {"type": "sample_components.add_value.AddFixedValue", "init_parameters": {"add": 2}},
|
||||
"add_default": {"type": "sample_components.add_value.AddFixedValue", "init_parameters": {"add": 1}},
|
||||
"double": {"type": "sample_components.double.Double", "init_parameters": {}},
|
||||
"add_two": {
|
||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||
"init_parameters": {"add": 2},
|
||||
},
|
||||
"add_default": {
|
||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||
"init_parameters": {"add": 1},
|
||||
},
|
||||
"double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}},
|
||||
},
|
||||
"connections": [
|
||||
{"sender": "add_two.result", "receiver": "double.value"},
|
||||
@ -153,7 +165,7 @@ def test_from_dict_with_components_instances():
|
||||
"components": {
|
||||
"add_two": {},
|
||||
"add_default": {},
|
||||
"double": {"type": "sample_components.double.Double", "init_parameters": {}},
|
||||
"double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}},
|
||||
},
|
||||
"connections": [
|
||||
{"sender": "add_two.result", "receiver": "double.value"},
|
||||
@ -335,8 +347,6 @@ def test_describe_output_multiple_possible():
|
||||
pipe.add_component("b", B())
|
||||
pipe.connect("a.output_b", "b.input_b")
|
||||
|
||||
# waiting for https://github.com/deepset-ai/canals/pull/148 to be merged
|
||||
# then this unit test will pass
|
||||
assert pipe.outputs() == {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}}
|
||||
|
||||
|
||||
@ -3,12 +3,12 @@ from typing import Optional
|
||||
import pytest
|
||||
import inspect
|
||||
|
||||
from canals.pipeline import Pipeline
|
||||
from canals.component.types import Variadic
|
||||
from canals.errors import PipelineValidationError
|
||||
from canals.component.sockets import InputSocket, OutputSocket
|
||||
from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||
from sample_components import Double, AddFixedValue, Sum, Parity
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.component.types import Variadic
|
||||
from haystack.core.errors import PipelineValidationError
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||
from haystack.testing.sample_components import Double, AddFixedValue, Sum, Parity
|
||||
|
||||
|
||||
def test_find_pipeline_input_no_input():
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components.accumulate import Accumulate, _default_function
|
||||
from haystack.testing.sample_components.accumulate import Accumulate, _default_function
|
||||
|
||||
|
||||
def my_subtract(first, second):
|
||||
@ -12,8 +12,8 @@ def test_to_dict():
|
||||
accumulate = Accumulate()
|
||||
res = accumulate.to_dict()
|
||||
assert res == {
|
||||
"type": "sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "sample_components.accumulate._default_function"},
|
||||
"type": "haystack.testing.sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "haystack.testing.sample_components.accumulate._default_function"},
|
||||
}
|
||||
|
||||
|
||||
@ -21,21 +21,21 @@ def test_to_dict_with_custom_function():
|
||||
accumulate = Accumulate(function=my_subtract)
|
||||
res = accumulate.to_dict()
|
||||
assert res == {
|
||||
"type": "sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"},
|
||||
"type": "haystack.testing.sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "test_accumulate.my_subtract"},
|
||||
}
|
||||
|
||||
|
||||
def test_from_dict():
|
||||
data = {"type": "sample_components.accumulate.Accumulate", "init_parameters": {}}
|
||||
data = {"type": "haystack.testing.sample_components.accumulate.Accumulate", "init_parameters": {}}
|
||||
accumulate = Accumulate.from_dict(data)
|
||||
assert accumulate.function == _default_function
|
||||
|
||||
|
||||
def test_from_dict_with_default_function():
|
||||
data = {
|
||||
"type": "sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "sample_components.accumulate._default_function"},
|
||||
"type": "haystack.testing.sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "haystack.testing.sample_components.accumulate._default_function"},
|
||||
}
|
||||
accumulate = Accumulate.from_dict(data)
|
||||
assert accumulate.function == _default_function
|
||||
@ -43,8 +43,8 @@ def test_from_dict_with_default_function():
|
||||
|
||||
def test_from_dict_with_custom_function():
|
||||
data = {
|
||||
"type": "sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"},
|
||||
"type": "haystack.testing.sample_components.accumulate.Accumulate",
|
||||
"init_parameters": {"function": "test_accumulate.my_subtract"},
|
||||
}
|
||||
accumulate = Accumulate.from_dict(data)
|
||||
assert accumulate.function == my_subtract
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import AddFixedValue
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import AddFixedValue
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_run():
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import Concatenate
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Concatenate
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_input_lists():
|
||||
@ -2,8 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from sample_components import Double
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Double
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_double_default():
|
||||
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from sample_components import FString
|
||||
from haystack.testing.sample_components import FString
|
||||
|
||||
|
||||
def test_fstring_with_one_var():
|
||||
@ -3,8 +3,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
|
||||
from sample_components import Greet
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Greet
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_greet_message(caplog):
|
||||
@ -5,16 +5,16 @@ from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from canals.errors import DeserializationError
|
||||
from haystack.core.errors import DeserializationError
|
||||
|
||||
from sample_components import MergeLoop
|
||||
from haystack.testing.sample_components import MergeLoop
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
component = MergeLoop(expected_type=int, inputs=["first", "second"])
|
||||
res = component.to_dict()
|
||||
assert res == {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]},
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ def test_to_dict_with_typing_class():
|
||||
component = MergeLoop(expected_type=Dict, inputs=["first", "second"])
|
||||
res = component.to_dict()
|
||||
assert res == {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "typing.Dict", "inputs": ["first", "second"]},
|
||||
}
|
||||
|
||||
@ -32,14 +32,17 @@ def test_to_dict_with_custom_class():
|
||||
component = MergeLoop(expected_type=MergeLoop, inputs=["first", "second"])
|
||||
res = component.to_dict()
|
||||
assert res == {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "sample_components.merge_loop.MergeLoop", "inputs": ["first", "second"]},
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {
|
||||
"expected_type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"inputs": ["first", "second"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_from_dict():
|
||||
data = {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]},
|
||||
}
|
||||
component = MergeLoop.from_dict(data)
|
||||
@ -49,7 +52,7 @@ def test_from_dict():
|
||||
|
||||
def test_from_dict_with_typing_class():
|
||||
data = {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "typing.Dict", "inputs": ["first", "second"]},
|
||||
}
|
||||
component = MergeLoop.from_dict(data)
|
||||
@ -59,16 +62,19 @@ def test_from_dict_with_typing_class():
|
||||
|
||||
def test_from_dict_with_custom_class():
|
||||
data = {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "sample_components.merge_loop.MergeLoop", "inputs": ["first", "second"]},
|
||||
}
|
||||
component = MergeLoop.from_dict(data)
|
||||
assert component.expected_type == "sample_components.merge_loop.MergeLoop"
|
||||
assert component.expected_type == "haystack.testing.sample_components.merge_loop.MergeLoop"
|
||||
assert component.inputs == ["first", "second"]
|
||||
|
||||
|
||||
def test_from_dict_without_expected_type():
|
||||
data = {"type": "sample_components.merge_loop.MergeLoop", "init_parameters": {"inputs": ["first", "second"]}}
|
||||
data = {
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"inputs": ["first", "second"]},
|
||||
}
|
||||
with pytest.raises(DeserializationError) as exc:
|
||||
MergeLoop.from_dict(data)
|
||||
|
||||
@ -77,7 +83,7 @@ def test_from_dict_without_expected_type():
|
||||
|
||||
def test_from_dict_without_inputs():
|
||||
data = {
|
||||
"type": "sample_components.merge_loop.MergeLoop",
|
||||
"type": "haystack.testing.sample_components.merge_loop.MergeLoop",
|
||||
"init_parameters": {"expected_type": "sample_components.merge_loop.MergeLoop"},
|
||||
}
|
||||
with pytest.raises(DeserializationError) as exc:
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import Parity
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Parity
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_parity():
|
||||
@ -3,8 +3,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
|
||||
from sample_components import Remainder
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Remainder
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_remainder_default():
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import Repeat
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Repeat
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_repeat_default():
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import Subtract
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Subtract
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_subtract():
|
||||
@ -2,8 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from sample_components import Sum
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Sum
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_sum_receives_no_values():
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from sample_components import Threshold
|
||||
from canals.serialization import component_to_dict, component_from_dict
|
||||
from haystack.testing.sample_components import Threshold
|
||||
from haystack.core.serialization import component_to_dict, component_from_dict
|
||||
|
||||
|
||||
def test_threshold():
|
||||
|
Before Width: | Height: | Size: 2.9 KiB After Width: | Height: | Size: 2.9 KiB |
|
Before Width: | Height: | Size: 9.1 KiB After Width: | Height: | Size: 9.1 KiB |
@ -3,24 +3,25 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from canals import Pipeline, component
|
||||
from canals.errors import DeserializationError
|
||||
from canals.testing import factory
|
||||
from canals.serialization import default_to_dict, default_from_dict
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.component import component
|
||||
from haystack.core.errors import DeserializationError
|
||||
from haystack.testing import factory
|
||||
from haystack.core.serialization import default_to_dict, default_from_dict
|
||||
|
||||
|
||||
def test_default_component_to_dict():
|
||||
MyComponent = factory.component_class("MyComponent")
|
||||
comp = MyComponent()
|
||||
res = default_to_dict(comp)
|
||||
assert res == {"type": "canals.testing.factory.MyComponent", "init_parameters": {}}
|
||||
assert res == {"type": "haystack.testing.factory.MyComponent", "init_parameters": {}}
|
||||
|
||||
|
||||
def test_default_component_to_dict_with_init_parameters():
|
||||
MyComponent = factory.component_class("MyComponent")
|
||||
comp = MyComponent()
|
||||
res = default_to_dict(comp, some_key="some_value")
|
||||
assert res == {"type": "canals.testing.factory.MyComponent", "init_parameters": {"some_key": "some_value"}}
|
||||
assert res == {"type": "haystack.testing.factory.MyComponent", "init_parameters": {"some_key": "some_value"}}
|
||||
|
||||
|
||||
def test_default_component_from_dict():
|
||||
@ -30,7 +31,7 @@ def test_default_component_from_dict():
|
||||
extra_fields = {"__init__": custom_init}
|
||||
MyComponent = factory.component_class("MyComponent", extra_fields=extra_fields)
|
||||
comp = default_from_dict(
|
||||
MyComponent, {"type": "canals.testing.factory.MyComponent", "init_parameters": {"some_param": 10}}
|
||||
MyComponent, {"type": "haystack.testing.factory.MyComponent", "init_parameters": {"some_param": 10}}
|
||||
)
|
||||
assert isinstance(comp, MyComponent)
|
||||
assert comp.some_param == 10
|
||||
@ -56,7 +57,7 @@ def test_from_dict_import_type():
|
||||
"max_loops_allowed": 100,
|
||||
"components": {
|
||||
"greeter": {
|
||||
"type": "sample_components.greet.Greet",
|
||||
"type": "haystack.testing.sample_components.greet.Greet",
|
||||
"init_parameters": {
|
||||
"message": "\nGreeting component says: Hi! The value is {value}\n",
|
||||
"log_level": "INFO",
|
||||
@ -67,12 +68,12 @@ def test_from_dict_import_type():
|
||||
}
|
||||
|
||||
# remove the target component from the registry if already there
|
||||
component.registry.pop("sample_components.greet.Greet", None)
|
||||
component.registry.pop("haystack.testing.sample_components.greet.Greet", None)
|
||||
# remove the module from sys.modules if already there
|
||||
sys.modules.pop("sample_components.greet", None)
|
||||
sys.modules.pop("haystack.testing.sample_components.greet", None)
|
||||
|
||||
p = Pipeline.from_dict(pipeline_dict)
|
||||
|
||||
from sample_components.greet import Greet
|
||||
from haystack.testing.sample_components.greet import Greet
|
||||
|
||||
assert type(p.get_component("greeter")) == Greet
|
||||
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from canals.type_utils import _type_name
|
||||
from haystack.core.type_utils import _type_name
|
||||
|
||||
|
||||
class Class1:
|
||||
@ -1,8 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.testing.factory import document_store_class
|
||||
from haystack.testing.factory import document_store_class, component_class
|
||||
from haystack.document_stores.decorator import document_store
|
||||
from haystack.core.component import component
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@ -67,4 +68,67 @@ def test_document_store_class_with_bases():
|
||||
def test_document_store_class_with_extra_fields():
|
||||
MyStore = document_store_class("MyStore", extra_fields={"my_field": 10})
|
||||
store = MyStore()
|
||||
assert store.my_field == 10
|
||||
assert store.my_field == 10 # type: ignore
|
||||
|
||||
|
||||
def test_component_class_default():
|
||||
MyComponent = component_class("MyComponent")
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(value="something")
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(non_existing_input=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_is_registered():
|
||||
MyComponent = component_class("MyComponent")
|
||||
assert component.registry["haystack.testing.factory.MyComponent"] == MyComponent
|
||||
|
||||
|
||||
def test_component_class_with_input_types():
|
||||
MyComponent = component_class("MyComponent", input_types={"value": int})
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
res = comp.run(value="something")
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_with_output_types():
|
||||
MyComponent = component_class("MyComponent", output_types={"value": int})
|
||||
comp = MyComponent()
|
||||
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": None}
|
||||
|
||||
|
||||
def test_component_class_with_output():
|
||||
MyComponent = component_class("MyComponent", output={"value": 100})
|
||||
comp = MyComponent()
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": 100}
|
||||
|
||||
|
||||
def test_component_class_with_output_and_output_types():
|
||||
MyComponent = component_class("MyComponent", output_types={"value": str}, output={"value": 100})
|
||||
comp = MyComponent()
|
||||
|
||||
res = comp.run(value=1)
|
||||
assert res == {"value": 100}
|
||||
|
||||
|
||||
def test_component_class_with_bases():
|
||||
MyComponent = component_class("MyComponent", bases=(Exception,))
|
||||
comp = MyComponent()
|
||||
assert isinstance(comp, Exception)
|
||||
|
||||
|
||||
def test_component_class_with_extra_fields():
|
||||
MyComponent = component_class("MyComponent", extra_fields={"my_field": 10})
|
||||
comp = MyComponent()
|
||||
assert comp.my_field == 10 # type: ignore
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user