feat: ChatGenerator protocol (#9074)

* feat: ChatGenerator protocol

* move protocol to better location
This commit is contained in:
Stefano Fiorucci 2025-03-20 11:58:09 +01:00 committed by GitHub
parent 2d974ab4ad
commit db50579bbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 3 deletions

View File

@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from .protocol import ChatGenerator
# Names re-exported from this package as its public API.
__all__ = ["ChatGenerator"]

View File

@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Protocol, TypeVar
from haystack.dataclasses import ChatMessage
# Ellipses are needed to define the Protocol, but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319.
# pylint: disable=unnecessary-ellipsis
# Type variable bound to the protocol, used so that `from_dict` can be annotated as returning
# an instance of the class it is called on.
T = TypeVar("T", bound="ChatGenerator")
class ChatGenerator(Protocol):
    """
    Structural interface for Chat Generators.

    This is the smallest set of methods a component has to provide in order to be treated as a
    Chat Generator. A Chat Generator receives a list of `ChatMessage` objects, produces a response
    with a Language Model, and hands the result back as a dictionary.
    """

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ChatGenerator into its dictionary representation.

        :returns:
            A dictionary holding the serialized ChatGenerator.
        """
        ...

    @classmethod
    def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
        """
        Rebuild a ChatGenerator from its dictionary representation.

        :param data:
            The dictionary describing the ChatGenerator to restore.
        :returns:
            An instance of the concrete class this method is called on.
        """
        ...

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """
        Produce messages with the underlying Language Model.

        Only `messages` is required by the protocol. Implementations are free to add further
        optional parameters to their own `run` method, for instance:
        `def run(self, messages: List[ChatMessage], param_a="default", param_b="another_default")`.

        :param messages:
            The input conversation, given as a list of ChatMessage instances.
        :returns:
            A dictionary.
        """
        ...

View File

@ -76,7 +76,9 @@ from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, runtime_checkable
from typing_extensions import ParamSpec
from haystack import logging
from haystack.core.errors import ComponentError
@ -86,6 +88,9 @@ from .types import InputSocket, OutputSocket, _empty
logger = logging.getLogger(__name__)
# Typing helpers for decorators that wrap a callable without changing its signature:
# `P` stands for the wrapped callable's parameters, `R` for its return value (bound to a dictionary).
P = ParamSpec("P")
R = TypeVar("R", bound=Dict[str, Any])
@dataclass
class PreInitHookPayload:
@ -442,7 +447,7 @@ class _Component:
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
)
def output_types(self, **types):
def output_types(self, **types: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator factory that specifies the output types of a component.
@ -457,7 +462,7 @@ class _Component:
```
"""
def output_types_decorator(run_method):
def output_types_decorator(run_method: Callable[P, R]) -> Callable[P, R]:
"""
Decorator that sets the output types of the decorated method.

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Introduce a `ChatGenerator` Protocol to qualify `ChatGenerator` components from a static type-checking perspective.
It defines the minimal interface that Chat Generators must implement.
This will especially help to standardize the integration of Chat Generators into other more complex components.