haystack/test/core/test_serialization.py
Massimiliano Pippi 00e1dd6eb8
chore: rearrange the core package, move tests and clean up (#6427)
* rearrange code

* fix tests

* relnote

* merge test modules

* remove extra

* rearrange draw tests

* forgot

* remove unused import
2023-11-28 09:58:56 +01:00

80 lines
2.9 KiB
Python

import sys
from unittest.mock import Mock
import pytest
from haystack.core.pipeline import Pipeline
from haystack.core.component import component
from haystack.core.errors import DeserializationError
from haystack.testing import factory
from haystack.core.serialization import default_to_dict, default_from_dict
def test_default_component_to_dict():
    """Serializing a component without init parameters yields its type and an empty parameter dict."""
    component_cls = factory.component_class("MyComponent")
    expected = {"type": "haystack.testing.factory.MyComponent", "init_parameters": {}}

    serialized = default_to_dict(component_cls())

    assert serialized == expected
def test_default_component_to_dict_with_init_parameters():
    """Keyword arguments given to ``default_to_dict`` are emitted under ``init_parameters``."""
    component_cls = factory.component_class("MyComponent")
    instance = component_cls()

    serialized = default_to_dict(instance, some_key="some_value")

    assert serialized == {
        "type": "haystack.testing.factory.MyComponent",
        "init_parameters": {"some_key": "some_value"},
    }
def test_default_component_from_dict():
    """``default_from_dict`` returns an instance whose ``__init__`` received the ``init_parameters``."""

    def init_with_param(self, some_param):
        # Store the argument so the test can check it was passed through.
        self.some_param = some_param

    MyComponent = factory.component_class("MyComponent", extra_fields={"__init__": init_with_param})
    payload = {"type": "haystack.testing.factory.MyComponent", "init_parameters": {"some_param": 10}}

    restored = default_from_dict(MyComponent, payload)

    assert isinstance(restored, MyComponent)
    assert restored.some_param == 10
def test_default_component_from_dict_without_type():
    """Data lacking a ``type`` entry is rejected with a ``DeserializationError``."""
    data_without_type = {}
    expected_message = "Missing 'type' in serialization data"

    with pytest.raises(DeserializationError, match=expected_message):
        default_from_dict(Mock, data_without_type)
def test_default_component_from_dict_unregistered_component(request):
    """A ``type`` naming something other than the target class raises ``DeserializationError``."""
    # The running test's own name serves as the component name: the registry is
    # global, so an arbitrary name could already be registered by another test.
    component_name = request.node.name
    expected_message = f"Class '{component_name}' can't be deserialized as 'Mock'"
    serialized = {"type": component_name}

    with pytest.raises(DeserializationError, match=expected_message):
        default_from_dict(Mock, serialized)
def test_from_dict_import_type():
    """``Pipeline.from_dict`` builds a ``Greet`` component after its module and registry entry are removed.

    The test drops ``Greet`` from the component registry and its module from
    ``sys.modules`` before deserializing, then checks that the pipeline still
    ends up holding a ``Greet`` instance.
    """
    pipeline_dict = {
        "metadata": {},
        "max_loops_allowed": 100,
        "components": {
            "greeter": {
                "type": "haystack.testing.sample_components.greet.Greet",
                "init_parameters": {
                    "message": "\nGreeting component says: Hi! The value is {value}\n",
                    "log_level": "INFO",
                },
            }
        },
        "connections": [],
    }

    # remove the target component from the registry if already there
    component.registry.pop("haystack.testing.sample_components.greet.Greet", None)
    # remove the module from sys.modules if already there
    sys.modules.pop("haystack.testing.sample_components.greet", None)

    p = Pipeline.from_dict(pipeline_dict)

    # Imported only now: importing earlier would put the module back into
    # sys.modules before from_dict runs.
    from haystack.testing.sample_components.greet import Greet

    # isinstance instead of comparing type objects with ==, matching the
    # instance check used by the other tests in this module.
    assert isinstance(p.get_component("greeter"), Greet)