2024-05-15 12:31:13 -04:00
|
|
|
import asyncio
|
|
|
|
from asyncio import Future
|
|
|
|
from dataclasses import dataclass
|
2024-05-26 08:45:02 -04:00
|
|
|
from typing import Any, Awaitable, Dict, List, Set
|
2024-05-15 12:31:13 -04:00
|
|
|
|
2024-05-20 13:32:08 -06:00
|
|
|
from agnext.core.cancellation_token import CancellationToken
|
2024-05-20 17:30:45 -06:00
|
|
|
from agnext.core.exceptions import MessageDroppedException
|
|
|
|
from agnext.core.intervention import DropMessage, InterventionHandler
|
2024-05-20 13:32:08 -06:00
|
|
|
|
2024-05-17 11:09:59 -04:00
|
|
|
from ..core.agent import Agent
|
|
|
|
from ..core.agent_runtime import AgentRuntime
|
2024-05-15 12:31:13 -04:00
|
|
|
|
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
@dataclass(kw_only=True)
class PublishMessageEnvelope:
    """Queue entry for a message broadcast to subscribed agents.

    The runtime delivers the payload to every agent registered as a
    subscriber for the payload's exact type.
    """

    # Payload to broadcast; subscribers are looked up by ``type(message)``.
    message: Any
    # Token handed to each subscriber's ``on_message`` call.
    cancellation_token: CancellationToken
    # Agent that published the message, or None when no sender was given.
    sender: Agent | None
|
2024-05-15 12:31:13 -04:00
|
|
|
|
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
@dataclass(kw_only=True)
class SendMessageEnvelope:
    """Queue entry for a message addressed to one specific agent.

    The attached future is how the caller of ``send_message`` learns the
    outcome: it is completed with the recipient's response, or with an
    exception if delivery fails or the message is dropped.
    """

    # Payload to deliver to the recipient.
    message: Any
    # Agent that sent the message, or None when no sender was given.
    sender: Agent | None
    # Agent whose ``on_message`` will be invoked with the payload.
    recipient: Agent
    # Future returned to the caller of ``send_message``.
    future: Future[Any]
    # Token handed to the recipient's ``on_message`` call.
    cancellation_token: CancellationToken
|
2024-05-15 12:31:13 -04:00
|
|
|
|
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
    """Queue entry carrying the reply to a previously sent message."""

    # Value returned by the recipient's ``on_message`` for the original send.
    message: Any
    # Future of the original send; completed with ``message`` on delivery.
    future: Future[Any]
    # Agent that produced the response (the recipient of the original send).
    sender: Agent
    # Agent the response goes back to (the original sender); may be None.
    recipient: Agent | None
|
2024-05-19 17:12:49 -06:00
|
|
|
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
class SingleThreadedAgentRuntime(AgentRuntime):
|
|
|
|
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
2024-05-26 08:45:02 -04:00
|
|
|
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
2024-05-23 16:00:05 -04:00
|
|
|
self._per_type_subscribers: Dict[type, List[Agent]] = {}
|
|
|
|
self._agents: Set[Agent] = set()
|
2024-05-20 17:30:45 -06:00
|
|
|
self._before_send = before_send
|
2024-05-15 12:31:13 -04:00
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
def add_agent(self, agent: Agent) -> None:
|
2024-05-27 16:33:28 -04:00
|
|
|
agent_names = {agent.name for agent in self._agents}
|
|
|
|
if agent.name in agent_names:
|
|
|
|
raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.")
|
|
|
|
|
2024-05-17 14:59:00 -07:00
|
|
|
for message_type in agent.subscriptions:
|
|
|
|
if message_type not in self._per_type_subscribers:
|
|
|
|
self._per_type_subscribers[message_type] = []
|
|
|
|
self._per_type_subscribers[message_type].append(agent)
|
2024-05-15 12:31:13 -04:00
|
|
|
self._agents.add(agent)
|
|
|
|
|
|
|
|
# Returns the response of the message
|
2024-05-20 13:32:08 -06:00
|
|
|
def send_message(
|
2024-05-20 17:30:45 -06:00
|
|
|
self,
|
2024-05-23 16:00:05 -04:00
|
|
|
message: Any,
|
|
|
|
recipient: Agent,
|
2024-05-20 17:30:45 -06:00
|
|
|
*,
|
2024-05-23 16:00:05 -04:00
|
|
|
sender: Agent | None = None,
|
2024-05-20 17:30:45 -06:00
|
|
|
cancellation_token: CancellationToken | None = None,
|
2024-05-23 16:00:05 -04:00
|
|
|
) -> Future[Any | None]:
|
2024-05-20 13:32:08 -06:00
|
|
|
if cancellation_token is None:
|
|
|
|
cancellation_token = CancellationToken()
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
future = asyncio.get_event_loop().create_future()
|
|
|
|
if recipient not in self._agents:
|
|
|
|
future.set_exception(Exception("Recipient not found"))
|
2024-05-15 12:31:13 -04:00
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
self._message_queue.append(
|
|
|
|
SendMessageEnvelope(
|
|
|
|
message=message,
|
|
|
|
recipient=recipient,
|
|
|
|
future=future,
|
|
|
|
cancellation_token=cancellation_token,
|
|
|
|
sender=sender,
|
|
|
|
)
|
|
|
|
)
|
2024-05-20 13:32:08 -06:00
|
|
|
|
2024-05-15 12:31:13 -04:00
|
|
|
return future
|
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
def publish_message(
|
2024-05-23 16:00:05 -04:00
|
|
|
self,
|
|
|
|
message: Any,
|
|
|
|
*,
|
|
|
|
sender: Agent | None = None,
|
|
|
|
cancellation_token: CancellationToken | None = None,
|
2024-05-26 08:45:02 -04:00
|
|
|
) -> Future[None]:
|
2024-05-20 13:32:08 -06:00
|
|
|
if cancellation_token is None:
|
|
|
|
cancellation_token = CancellationToken()
|
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
self._message_queue.append(
|
2024-05-26 08:45:02 -04:00
|
|
|
PublishMessageEnvelope(
|
2024-05-23 16:00:05 -04:00
|
|
|
message=message,
|
|
|
|
cancellation_token=cancellation_token,
|
|
|
|
sender=sender,
|
2024-05-20 17:30:45 -06:00
|
|
|
)
|
|
|
|
)
|
2024-05-23 16:00:05 -04:00
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
future = asyncio.get_event_loop().create_future()
|
|
|
|
future.set_result(None)
|
2024-05-15 12:31:13 -04:00
|
|
|
return future
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
2024-05-20 17:30:45 -06:00
|
|
|
recipient = message_envelope.recipient
|
2024-05-23 16:00:05 -04:00
|
|
|
assert recipient in self._agents
|
2024-05-15 12:31:13 -04:00
|
|
|
|
2024-05-20 13:32:08 -06:00
|
|
|
try:
|
|
|
|
response = await recipient.on_message(
|
2024-05-23 16:00:05 -04:00
|
|
|
message_envelope.message,
|
|
|
|
cancellation_token=message_envelope.cancellation_token,
|
2024-05-20 13:32:08 -06:00
|
|
|
)
|
|
|
|
except BaseException as e:
|
|
|
|
message_envelope.future.set_exception(e)
|
|
|
|
return
|
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
self._message_queue.append(
|
|
|
|
ResponseMessageEnvelope(
|
|
|
|
message=response,
|
|
|
|
future=message_envelope.future,
|
|
|
|
sender=message_envelope.recipient,
|
|
|
|
recipient=message_envelope.sender,
|
2024-05-23 16:00:05 -04:00
|
|
|
)
|
2024-05-26 08:45:02 -04:00
|
|
|
)
|
2024-05-23 16:00:05 -04:00
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
2024-05-23 16:00:05 -04:00
|
|
|
responses: List[Awaitable[Any]] = []
|
|
|
|
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
|
|
|
|
future = agent.on_message(
|
|
|
|
message_envelope.message,
|
|
|
|
cancellation_token=message_envelope.cancellation_token,
|
|
|
|
)
|
2024-05-19 17:12:49 -06:00
|
|
|
responses.append(future)
|
|
|
|
|
2024-05-20 13:32:08 -06:00
|
|
|
try:
|
2024-05-26 08:45:02 -04:00
|
|
|
_all_responses = await asyncio.gather(*responses)
|
|
|
|
except BaseException:
|
|
|
|
# TODO log error
|
2024-05-20 13:32:08 -06:00
|
|
|
return
|
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
# TODO if responses are given for a publish
|
2024-05-19 17:12:49 -06:00
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
2024-05-19 17:12:49 -06:00
|
|
|
message_envelope.future.set_result(message_envelope.message)
|
|
|
|
|
2024-05-15 12:31:13 -04:00
|
|
|
async def process_next(self) -> None:
|
2024-05-17 14:59:00 -07:00
|
|
|
if len(self._message_queue) == 0:
|
2024-05-15 12:31:13 -04:00
|
|
|
# Yield control to the event loop to allow other tasks to run
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
return
|
|
|
|
|
2024-05-17 14:59:00 -07:00
|
|
|
message_envelope = self._message_queue.pop(0)
|
2024-05-15 12:31:13 -04:00
|
|
|
|
2024-05-17 14:59:00 -07:00
|
|
|
match message_envelope:
|
2024-05-20 17:30:45 -06:00
|
|
|
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
|
|
|
if self._before_send is not None:
|
|
|
|
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
|
|
|
|
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
|
|
|
future.set_exception(MessageDroppedException())
|
|
|
|
return
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
message_envelope.message = temp_message
|
2024-05-20 17:30:45 -06:00
|
|
|
|
|
|
|
asyncio.create_task(self._process_send(message_envelope))
|
2024-05-26 08:45:02 -04:00
|
|
|
case PublishMessageEnvelope(
|
2024-05-20 17:30:45 -06:00
|
|
|
message=message,
|
|
|
|
sender=sender,
|
|
|
|
):
|
|
|
|
if self._before_send is not None:
|
2024-05-26 08:45:02 -04:00
|
|
|
temp_message = await self._before_send.on_publish(message, sender=sender)
|
2024-05-20 17:30:45 -06:00
|
|
|
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
2024-05-26 08:45:02 -04:00
|
|
|
# TODO log message dropped
|
2024-05-20 17:30:45 -06:00
|
|
|
return
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
message_envelope.message = temp_message
|
2024-05-20 17:30:45 -06:00
|
|
|
|
2024-05-26 08:45:02 -04:00
|
|
|
asyncio.create_task(self._process_publish(message_envelope))
|
2024-05-20 17:30:45 -06:00
|
|
|
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
|
|
|
if self._before_send is not None:
|
|
|
|
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
|
|
|
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
|
|
|
future.set_exception(MessageDroppedException())
|
|
|
|
return
|
|
|
|
|
2024-05-23 16:00:05 -04:00
|
|
|
message_envelope.message = temp_message
|
2024-05-20 17:30:45 -06:00
|
|
|
|
|
|
|
asyncio.create_task(self._process_response(message_envelope))
|
|
|
|
|
2024-05-17 14:59:00 -07:00
|
|
|
# Yield control to the message loop to allow other tasks to run
|
2024-05-15 12:31:13 -04:00
|
|
|
await asyncio.sleep(0)
|