Port changes from agexplore (#47)

This commit is contained in:
Jack Gerrits 2024-06-04 10:17:04 -04:00 committed by GitHub
parent 69627aeee6
commit 19570fdd98
5 changed files with 42 additions and 11 deletions

View File

@ -1,5 +1,7 @@
import asyncio import asyncio
import logging
from asyncio import Future from asyncio import Future
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Mapping, Set from typing import Any, Awaitable, Dict, List, Mapping, Set
@ -7,6 +9,8 @@ from ..core import Agent, AgentRuntime, CancellationToken
from ..core.exceptions import MessageDroppedException from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler from ..core.intervention import DropMessage, InterventionHandler
logger = logging.getLogger("agnext")
@dataclass(kw_only=True) @dataclass(kw_only=True)
class PublishMessageEnvelope: class PublishMessageEnvelope:
@ -58,6 +62,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._per_type_subscribers[message_type].append(agent) self._per_type_subscribers[message_type].append(agent)
self._agents.add(agent) self._agents.add(agent)
@property
def agents(self) -> Sequence[Agent]:
return list(self._agents)
@property
def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
# Returns the response of the message # Returns the response of the message
def send_message( def send_message(
self, self,
@ -123,6 +135,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
assert recipient in self._agents assert recipient in self._agents
try: try:
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
)
response = await recipient.on_message( response = await recipient.on_message(
message_envelope.message, message_envelope.message,
cancellation_token=message_envelope.cancellation_token, cancellation_token=message_envelope.cancellation_token,
@ -145,6 +161,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
if message_envelope.sender is not None and agent.name == message_envelope.sender.name: if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
continue continue
logger.info(
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
)
future = agent.on_message( future = agent.on_message(
message_envelope.message, message_envelope.message,
cancellation_token=message_envelope.cancellation_token, cancellation_token=message_envelope.cancellation_token,
@ -154,12 +175,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
try: try:
_all_responses = await asyncio.gather(*responses) _all_responses = await asyncio.gather(*responses)
except BaseException: except BaseException:
# TODO log error logger.error("Error processing publish message", exc_info=True)
return return
# TODO if responses are given for a publish # TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
logger.info(
f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
)
message_envelope.future.set_result(message_envelope.message) message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None: async def process_next(self) -> None:

View File

@ -30,6 +30,10 @@ class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
self._max_retry_attempts_before_educated_guess = max_retry_attempts self._max_retry_attempts_before_educated_guess = max_retry_attempts
@property
def children(self) -> Sequence[str]:
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
@message_handler(TextMessage) @message_handler(TextMessage)
async def on_text_message( async def on_text_message(
self, self,

View File

@ -3,8 +3,9 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon
""" """
from ._agent import Agent from ._agent import Agent
from ._agent_props import AgentChildren
from ._agent_runtime import AgentRuntime from ._agent_runtime import AgentRuntime
from ._base_agent import BaseAgent from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken from ._cancellation_token import CancellationToken
__all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken"] __all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"]

View File

@ -0,0 +1,9 @@
from typing import Protocol, Sequence, runtime_checkable
@runtime_checkable
class AgentChildren(Protocol):
@property
def children(self) -> Sequence[str]:
"""Names of the children of the agent."""
...

View File

@ -1,4 +1,4 @@
from typing import Any, Awaitable, Callable, Protocol, Sequence, final from typing import Any, Awaitable, Callable, Protocol, final
from agnext.core import Agent from agnext.core import Agent
@ -21,9 +21,6 @@ class InterventionHandler(Protocol):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ... async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ... async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler): class DefaultInterventionHandler(InterventionHandler):
@ -35,8 +32,3 @@ class DefaultInterventionHandler(InterventionHandler):
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
return message return message
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]:
return message