Make ChatAgent an ABC (#5129)

This commit is contained in:
Jack Gerrits 2025-01-21 20:08:53 -05:00 committed by GitHub
parent da1c2bf12e
commit 226b37d07b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Mapping, Protocol, Sequence, runtime_checkable from typing import Any, AsyncGenerator, Mapping, Sequence
from autogen_core import CancellationToken from autogen_core import CancellationToken
@ -19,17 +20,18 @@ class Response:
or :class:`ChatMessage`.""" or :class:`ChatMessage`."""
@runtime_checkable class ChatAgent(ABC, TaskRunner):
class ChatAgent(TaskRunner, Protocol):
"""Protocol for a chat agent.""" """Protocol for a chat agent."""
@property @property
@abstractmethod
def name(self) -> str: def name(self) -> str:
"""The name of the agent. This is used by team to uniquely identify """The name of the agent. This is used by team to uniquely identify
the agent. It should be unique within the team.""" the agent. It should be unique within the team."""
... ...
@property @property
@abstractmethod
def description(self) -> str: def description(self) -> str:
"""The description of the agent. This is used by team to """The description of the agent. This is used by team to
make decisions about which agents to use. The description should make decisions about which agents to use. The description should
@ -37,15 +39,18 @@ class ChatAgent(TaskRunner, Protocol):
... ...
@property @property
@abstractmethod
def produced_message_types(self) -> Sequence[type[ChatMessage]]: def produced_message_types(self) -> Sequence[type[ChatMessage]]:
"""The types of messages that the agent produces in the """The types of messages that the agent produces in the
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types.""" :attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
... ...
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response.""" """Handles incoming messages and returns a response."""
... ...
@abstractmethod
def on_messages_stream( def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]: ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
@ -53,18 +58,22 @@ class ChatAgent(TaskRunner, Protocol):
and the final item is the response.""" and the final item is the response."""
... ...
@abstractmethod
async def on_reset(self, cancellation_token: CancellationToken) -> None: async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Resets the agent to its initialization state.""" """Resets the agent to its initialization state."""
... ...
@abstractmethod
async def save_state(self) -> Mapping[str, Any]: async def save_state(self) -> Mapping[str, Any]:
"""Save agent state for later restoration""" """Save agent state for later restoration"""
... ...
@abstractmethod
async def load_state(self, state: Mapping[str, Any]) -> None: async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state""" """Restore agent from saved state"""
... ...
@abstractmethod
async def close(self) -> None: async def close(self) -> None:
"""Called when the runtime is stopped or any stop method is called""" """Called when the runtime is stopped or any stop method is called"""
... ...