fix: callables can be deserialized from fully qualified import path (#8788)

* fix: callables can be deserialized from fully qualified import path

* fix: license header

* fix: format

* fix: types

* fix? types

* test: extend test case

* format

* add release notes
This commit is contained in:
mathislucka 2025-02-03 12:35:37 +01:00 committed by GitHub
parent 379711f63e
commit 1a91365cc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 15 deletions

View File

@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
def callable_to_deserialize(hello: str) -> str:
"""
A function to test callable deserialization.
"""
return f"{hello}, world!"

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Callable
from typing import Any, Callable
from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.type_serialization import thread_safe_import
@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable:
:return: The callable
:raises DeserializationError: If the callable cannot be found
"""
module_name, *attribute_chain = callable_handle.split(".")
parts = callable_handle.split(".")
try:
current = thread_safe_import(module_name)
except Exception as e:
raise DeserializationError(f"Could not locate the module: {module_name}") from e
for attr in attribute_chain:
for i in range(len(parts), 0, -1):
module_name = ".".join(parts[:i])
try:
attr_value = getattr(current, attr)
except AttributeError as e:
raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e
mod: Any = thread_safe_import(module_name)
except Exception:
# keep reducing i until we find a valid module import
continue
attr_value = mod
for part in parts[i:]:
try:
attr_value = getattr(attr_value, part)
except AttributeError as e:
raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e
# when the attribute is a classmethod, we need the underlying function
if isinstance(attr_value, (classmethod, staticmethod)):
attr_value = attr_value.__func__
current = attr_value
if not callable(attr_value):
raise DeserializationError(f"The final attribute is not callable: {attr_value}")
if not callable(current):
raise DeserializationError(f"The final attribute is not callable: {current}")
return attr_value
return current
# Fallback if we never find anything
raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.")

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Callable deserialization now works for all fully qualified import paths.

View File

@ -1,10 +1,12 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests
from haystack.core.errors import DeserializationError, SerializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.testing.callable_serialization.random_callable import callable_to_deserialize
from haystack.utils import serialize_callable, deserialize_callable
@ -40,6 +42,13 @@ def test_callable_serialization_non_local():
assert result == "requests.api.get"
def test_fully_qualified_import_deserialization():
func = deserialize_callable("haystack.testing.callable_serialization.random_callable.callable_to_deserialize")
assert func is callable_to_deserialize
assert func("Hello") == "Hello, world!"
def test_callable_serialization_instance_methods_fail():
with pytest.raises(SerializationError):
serialize_callable(TestClass.my_method)