mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
fix: callables can be deserialized from fully qualified import path (#8788)
* fix: callables can be deserialized from fully qualified import path * fix: license header * fix: format * fix: types * fix? types * test: extend test case * format * add release notes
This commit is contained in:
parent
379711f63e
commit
1a91365cc8
10
haystack/testing/callable_serialization/random_callable.py
Normal file
10
haystack/testing/callable_serialization/random_callable.py
Normal file
@ -0,0 +1,10 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
def callable_to_deserialize(hello: str) -> str:
|
||||
"""
|
||||
A function to test callable deserialization.
|
||||
"""
|
||||
return f"{hello}, world!"
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
from haystack.core.errors import DeserializationError, SerializationError
|
||||
from haystack.utils.type_serialization import thread_safe_import
|
||||
@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable:
|
||||
:return: The callable
|
||||
:raises DeserializationError: If the callable cannot be found
|
||||
"""
|
||||
module_name, *attribute_chain = callable_handle.split(".")
|
||||
parts = callable_handle.split(".")
|
||||
|
||||
try:
|
||||
current = thread_safe_import(module_name)
|
||||
except Exception as e:
|
||||
raise DeserializationError(f"Could not locate the module: {module_name}") from e
|
||||
|
||||
for attr in attribute_chain:
|
||||
for i in range(len(parts), 0, -1):
|
||||
module_name = ".".join(parts[:i])
|
||||
try:
|
||||
attr_value = getattr(current, attr)
|
||||
except AttributeError as e:
|
||||
raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e
|
||||
mod: Any = thread_safe_import(module_name)
|
||||
except Exception:
|
||||
# keep reducing i until we find a valid module import
|
||||
continue
|
||||
|
||||
attr_value = mod
|
||||
for part in parts[i:]:
|
||||
try:
|
||||
attr_value = getattr(attr_value, part)
|
||||
except AttributeError as e:
|
||||
raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e
|
||||
|
||||
# when the attribute is a classmethod, we need the underlying function
|
||||
if isinstance(attr_value, (classmethod, staticmethod)):
|
||||
attr_value = attr_value.__func__
|
||||
|
||||
current = attr_value
|
||||
if not callable(attr_value):
|
||||
raise DeserializationError(f"The final attribute is not callable: {attr_value}")
|
||||
|
||||
if not callable(current):
|
||||
raise DeserializationError(f"The final attribute is not callable: {current}")
|
||||
return attr_value
|
||||
|
||||
return current
|
||||
# Fallback if we never find anything
|
||||
raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.")
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Callable deserialization now works for all fully qualified import paths.
|
||||
@ -1,10 +1,12 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from haystack.core.errors import DeserializationError, SerializationError
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.testing.callable_serialization.random_callable import callable_to_deserialize
|
||||
from haystack.utils import serialize_callable, deserialize_callable
|
||||
|
||||
|
||||
@ -40,6 +42,13 @@ def test_callable_serialization_non_local():
|
||||
assert result == "requests.api.get"
|
||||
|
||||
|
||||
def test_fully_qualified_import_deserialization():
|
||||
func = deserialize_callable("haystack.testing.callable_serialization.random_callable.callable_to_deserialize")
|
||||
|
||||
assert func is callable_to_deserialize
|
||||
assert func("Hello") == "Hello, world!"
|
||||
|
||||
|
||||
def test_callable_serialization_instance_methods_fail():
|
||||
with pytest.raises(SerializationError):
|
||||
serialize_callable(TestClass.my_method)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user