autogen/src/agnext/application/_single_threaded_agent_runtime.py

235 lines
8.7 KiB
Python
Raw Normal View History

2024-05-15 12:31:13 -04:00
import asyncio
2024-06-04 10:17:04 -04:00
import logging
2024-05-15 12:31:13 -04:00
from asyncio import Future
2024-06-04 10:17:04 -04:00
from collections.abc import Sequence
2024-05-15 12:31:13 -04:00
from dataclasses import dataclass
2024-05-27 20:25:25 -04:00
from typing import Any, Awaitable, Dict, List, Mapping, Set
2024-05-15 12:31:13 -04:00
2024-05-27 17:10:56 -04:00
from ..core import Agent, AgentRuntime, CancellationToken
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
2024-05-15 12:31:13 -04:00
2024-06-04 10:17:04 -04:00
# Module logger; every runtime event in this file is reported on the "agnext" logger.
logger = logging.getLogger(name="agnext")
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(init=True, repr=True, eq=True, kw_only=True)
class PublishMessageEnvelope:
    """Queued broadcast: deliver ``message`` to every agent subscribed to its type."""

    message: Any  # payload; subscribers are looked up by type(message)
    cancellation_token: CancellationToken  # handed to each subscriber's on_message
    sender: Agent | None  # publishing agent, or None when published from outside any agent
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(init=True, repr=True, eq=True, kw_only=True)
class SendMessageEnvelope:
    """Queued direct message addressed to a single ``recipient`` agent."""

    message: Any  # payload passed to recipient.on_message
    sender: Agent | None  # originating agent, or None when sent from outside any agent
    recipient: Agent  # the one agent that handles this message
    future: Future[Any]  # completed with the recipient's response or with an exception
    cancellation_token: CancellationToken  # handed to the recipient's on_message
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(init=True, repr=True, eq=True, kw_only=True)
class ResponseMessageEnvelope:
    """Queued reply produced by an agent's handler for an earlier direct message."""

    message: Any  # the handler's return value
    future: Future[Any]  # the original send's future, resolved with ``message``
    sender: Agent  # agent whose handler produced the reply
    recipient: Agent | None  # sender of the original message, or None if it had none
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
self._per_type_subscribers: Dict[type, List[Agent]] = {}
self._agents: Set[Agent] = set()
2024-05-20 17:30:45 -06:00
self._before_send = before_send
2024-05-15 12:31:13 -04:00
def add_agent(self, agent: Agent) -> None:
agent_names = {agent.name for agent in self._agents}
if agent.name in agent_names:
raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.")
for message_type in agent.subscriptions:
if message_type not in self._per_type_subscribers:
self._per_type_subscribers[message_type] = []
self._per_type_subscribers[message_type].append(agent)
2024-05-15 12:31:13 -04:00
self._agents.add(agent)
2024-06-04 10:17:04 -04:00
@property
def agents(self) -> Sequence[Agent]:
return list(self._agents)
@property
def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
2024-05-15 12:31:13 -04:00
# Returns the response of the message
def send_message(
2024-05-20 17:30:45 -06:00
self,
message: Any,
recipient: Agent,
2024-05-20 17:30:45 -06:00
*,
sender: Agent | None = None,
2024-05-20 17:30:45 -06:00
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = asyncio.get_event_loop().create_future()
if recipient not in self._agents:
future.set_exception(Exception("Recipient not found"))
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
)
)
2024-05-15 12:31:13 -04:00
return future
def publish_message(
self,
message: Any,
*,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
2024-05-20 17:30:45 -06:00
self._message_queue.append(
PublishMessageEnvelope(
message=message,
cancellation_token=cancellation_token,
sender=sender,
2024-05-20 17:30:45 -06:00
)
)
future = asyncio.get_event_loop().create_future()
future.set_result(None)
2024-05-15 12:31:13 -04:00
return future
2024-05-27 20:25:25 -04:00
def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent in self._agents:
state[agent.name] = dict(agent.save_state())
return state
def load_state(self, state: Mapping[str, Any]) -> None:
for agent in self._agents:
agent.load_state(state[agent.name])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
2024-05-20 17:30:45 -06:00
recipient = message_envelope.recipient
assert recipient in self._agents
2024-05-15 12:31:13 -04:00
try:
2024-06-04 10:17:04 -04:00
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
)
response = await recipient.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
)
except BaseException as e:
message_envelope.future.set_exception(e)
return
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
)
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
responses: List[Awaitable[Any]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
2024-05-28 16:21:40 -04:00
if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
continue
2024-06-04 10:17:04 -04:00
logger.info(
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
)
future = agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
)
responses.append(future)
try:
_all_responses = await asyncio.gather(*responses)
except BaseException:
2024-06-04 10:17:04 -04:00
logger.error("Error processing publish message", exc_info=True)
return
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
2024-06-04 10:17:04 -04:00
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
logger.info(
f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
)
message_envelope.future.set_result(message_envelope.message)
2024-05-15 12:31:13 -04:00
async def process_next(self) -> None:
if len(self._message_queue) == 0:
2024-05-15 12:31:13 -04:00
# Yield control to the event loop to allow other tasks to run
await asyncio.sleep(0)
return
message_envelope = self._message_queue.pop(0)
2024-05-15 12:31:13 -04:00
match message_envelope:
2024-05-20 17:30:45 -06:00
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
2024-05-20 17:30:45 -06:00
asyncio.create_task(self._process_send(message_envelope))
case PublishMessageEnvelope(
2024-05-20 17:30:45 -06:00
message=message,
sender=sender,
):
if self._before_send is not None:
temp_message = await self._before_send.on_publish(message, sender=sender)
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped
2024-05-20 17:30:45 -06:00
return
message_envelope.message = temp_message
2024-05-20 17:30:45 -06:00
asyncio.create_task(self._process_publish(message_envelope))
2024-05-20 17:30:45 -06:00
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
2024-05-20 17:30:45 -06:00
asyncio.create_task(self._process_response(message_envelope))
# Yield control to the message loop to allow other tasks to run
2024-05-15 12:31:13 -04:00
await asyncio.sleep(0)