Port changes from agexplore (#47)

This commit is contained in:
Jack Gerrits 2024-06-04 10:17:04 -04:00 committed by GitHub
parent 69627aeee6
commit 19570fdd98
5 changed files with 42 additions and 11 deletions

View File

@ -1,5 +1,7 @@
import asyncio
import logging
from asyncio import Future
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Mapping, Set
@ -7,6 +9,8 @@ from ..core import Agent, AgentRuntime, CancellationToken
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
logger = logging.getLogger("agnext")
@dataclass(kw_only=True)
class PublishMessageEnvelope:
@ -58,6 +62,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._per_type_subscribers[message_type].append(agent)
self._agents.add(agent)
@property
def agents(self) -> Sequence[Agent]:
return list(self._agents)
@property
def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
# Returns the response of the message
def send_message(
self,
@ -123,6 +135,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
assert recipient in self._agents
try:
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
)
response = await recipient.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
@ -145,6 +161,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
continue
logger.info(
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
)
future = agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
@ -154,12 +175,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
try:
_all_responses = await asyncio.gather(*responses)
except BaseException:
# TODO log error
logger.error("Error processing publish message", exc_info=True)
return
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
logger.info(
f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
)
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:

View File

@ -30,6 +30,10 @@ class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
self._max_retry_attempts_before_educated_guess = max_retry_attempts
@property
def children(self) -> Sequence[str]:
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
@message_handler(TextMessage)
async def on_text_message(
self,

View File

@ -3,8 +3,9 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon
"""
from ._agent import Agent
from ._agent_props import AgentChildren
from ._agent_runtime import AgentRuntime
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
__all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken"]
__all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"]

View File

@ -0,0 +1,9 @@
from typing import Protocol, Sequence, runtime_checkable
@runtime_checkable
class AgentChildren(Protocol):
@property
def children(self) -> Sequence[str]:
"""Names of the children of the agent."""
...

View File

@ -1,4 +1,4 @@
from typing import Any, Awaitable, Callable, Protocol, Sequence, final
from typing import Any, Awaitable, Callable, Protocol, final
from agnext.core import Agent
@ -21,9 +21,6 @@ class InterventionHandler(Protocol):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler):
@ -35,8 +32,3 @@ class DefaultInterventionHandler(InterventionHandler):
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
return message
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]:
return message