From db50579bbfdd89aeaf17b73d0dcd50b5ea93f1dc Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 20 Mar 2025 11:58:09 +0100 Subject: [PATCH] feat: ChatGenerator protocol (#9074) * feat: ChatGenerator protocol * move protocol to better location --- .../generators/chat/types/__init__.py | 7 +++ .../generators/chat/types/protocol.py | 56 +++++++++++++++++++ haystack/core/component/component.py | 11 +++- ...atgenerator-protocol-a5136f63e05b7210.yaml | 6 ++ 4 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 haystack/components/generators/chat/types/__init__.py create mode 100644 haystack/components/generators/chat/types/protocol.py create mode 100644 releasenotes/notes/chatgenerator-protocol-a5136f63e05b7210.yaml diff --git a/haystack/components/generators/chat/types/__init__.py b/haystack/components/generators/chat/types/__init__.py new file mode 100644 index 000000000..00424b148 --- /dev/null +++ b/haystack/components/generators/chat/types/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .protocol import ChatGenerator + +__all__ = ["ChatGenerator"] diff --git a/haystack/components/generators/chat/types/protocol.py b/haystack/components/generators/chat/types/protocol.py new file mode 100644 index 000000000..79d082fa0 --- /dev/null +++ b/haystack/components/generators/chat/types/protocol.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Protocol, TypeVar + +from haystack.dataclasses import ChatMessage + +# Ellipsis are needed to define the Protocol but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319. +# pylint: disable=unnecessary-ellipsis + +T = TypeVar("T", bound="ChatGenerator") + + +class ChatGenerator(Protocol): + """ + Protocol for Chat Generators. + + This protocol defines the minimal interface that Chat Generators must implement. + Chat Generators are components that process a list of `ChatMessage` objects as input and generate + responses using a Language Model. They return a dictionary. + """ + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this ChatGenerator to a dictionary. + + :returns: + The serialized ChatGenerator as a dictionary. + """ + ... + + @classmethod + def from_dict(cls: type[T], data: Dict[str, Any]) -> T: + """ + Deserialize this ChatGenerator from a dictionary. + + :param data: The dictionary representation of this ChatGenerator. + :returns: + An instance of the specific implementing class. + """ + ... + + def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: + """ + Generate messages using the underlying Language Model. + + Implementing classes may accept additional optional parameters in their run method. + For example: `def run (self, messages: List[ChatMessage], param_a="default", param_b="another_default")`. + + :param messages: + A list of ChatMessage instances representing the input messages. + :returns: + A dictionary. + """ + ... diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index d50bb6afd..2faa13a82 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -76,7 +76,9 @@ from contextvars import ContextVar from copy import deepcopy from dataclasses import dataclass from types import new_class -from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable +from typing import Any, Dict, Optional, Protocol, Type, TypeVar, runtime_checkable + +from typing_extensions import ParamSpec from haystack import logging from haystack.core.errors import ComponentError @@ -86,6 +88,9 @@ from .types import InputSocket, OutputSocket, _empty logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R", bound=Dict[str, Any]) + @dataclass class PreInitHookPayload: @@ -442,7 +447,7 @@ class _Component: instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket ) - def output_types(self, **types): + def output_types(self, **types: Any) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator factory that specifies the output types of a component. @@ -457,7 +462,7 @@ class _Component: ``` """ - def output_types_decorator(run_method): + def output_types_decorator(run_method: Callable[P, R]) -> Callable[P, R]: """ Decorator that sets the output types of the decorated method. diff --git a/releasenotes/notes/chatgenerator-protocol-a5136f63e05b7210.yaml b/releasenotes/notes/chatgenerator-protocol-a5136f63e05b7210.yaml new file mode 100644 index 000000000..cc5b6afbe --- /dev/null +++ b/releasenotes/notes/chatgenerator-protocol-a5136f63e05b7210.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Introduce a `ChatGenerator` Protocol to qualify `ChatGenerator` components from a static type-checking perspective. + It defines the minimal interface that Chat Generators must implement. + This will especially help to standardize the integration of Chat Generators into other more complex components.