diff --git a/docs/src/guides/type-routed-agent.md b/docs/src/guides/type-routed-agent.md index 8d4a31955..8c9ba64bb 100644 --- a/docs/src/guides/type-routed-agent.md +++ b/docs/src/guides/type-routed-agent.md @@ -40,7 +40,7 @@ class MyAgent(TypeRoutedAgent): await self._publish_message( TextMessage( content=f"I received a message from {message.source}. Message received #{self._received_count}", - source=self.name, + source=self.metadata["name"], ) ) diff --git a/examples/assistant.py b/examples/assistant.py index bf595798f..7a21ab02e 100644 --- a/examples/assistant.py +++ b/examples/assistant.py @@ -108,7 +108,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore return else: # Publish user input and exit handler. - await self._publish_message(TextMessage(content=user_input, source=self.name)) + await self._publish_message(TextMessage(content=user_input, source=self.metadata["name"])) return diff --git a/examples/chat_room.py b/examples/chat_room.py index 26f0477ca..06ac4f0ad 100644 --- a/examples/chat_room.py +++ b/examples/chat_room.py @@ -71,13 +71,13 @@ Use the following JSON format to provide your thought on the latest message and # Get a response from the model. raw_response = await self._client.create( self._system_messages - + convert_messages_to_llm_messages(await self._memory.get_messages(), self_name=self.name), + + convert_messages_to_llm_messages(await self._memory.get_messages(), self_name=self.metadata["name"]), json_output=True, ) assert isinstance(raw_response.content, str) # Save the response to memory. - await self._memory.add_message(ChatRoomMessage(source=self.name, content=raw_response.content)) + await self._memory.add_message(ChatRoomMessage(source=self.metadata["name"], content=raw_response.content)) # Parse the response. data = json.loads(raw_response.content) @@ -86,8 +86,8 @@ Use the following JSON format to provide your thought on the latest message and # Publish the response if needed. if respond is True or str(respond).lower().strip() == "true": - await self._publish_message(ChatRoomMessage(source=self.name, content=str(response))) - print(f"{sep}\n{self._color}{self.name}:{Style.RESET_ALL}\n{response}") + await self._publish_message(ChatRoomMessage(source=self.metadata["name"], content=str(response))) + print(f"{sep}\n{self._color}{self.metadata['name']}:{Style.RESET_ALL}\n{response}") # Define a chat room with participants -- the runtime is the chat room. diff --git a/examples/coder_reviewer.py b/examples/coder_reviewer.py index 671d197ea..9a38260d3 100644 --- a/examples/coder_reviewer.py +++ b/examples/coder_reviewer.py @@ -1,4 +1,5 @@ import asyncio + from agnext.application import SingleThreadedAgentRuntime from agnext.chat.agents import ChatCompletionAgent from agnext.chat.memory import BufferedChatMemory diff --git a/examples/futures.py b/examples/futures.py index 84325b4cb..f810d9490 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -18,7 +18,7 @@ class Inner(TypeRoutedAgent): # type: ignore @message_handler() # type: ignore async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore - return MessageType(body=f"Inner: {message.body}", sender=self.name) + return MessageType(body=f"Inner: {message.body}", sender=self.metadata["name"]) class Outer(TypeRoutedAgent): # type: ignore @@ -31,7 +31,7 @@ class Outer(TypeRoutedAgent): # type: ignore inner_response = self._send_message(message, self._inner) inner_message = await inner_response assert isinstance(inner_message, MessageType) - return MessageType(body=f"Outer: {inner_message.body}", sender=self.name) + return MessageType(body=f"Outer: {inner_message.body}", sender=self.metadata["name"]) async def main() -> None: diff --git a/examples/orchestrator.py b/examples/orchestrator.py index 1a161ef30..ab31f8cd9 100644 --- a/examples/orchestrator.py +++ b/examples/orchestrator.py @@ -60,20 +60,20 @@ class LoggingHandler(DefaultInterventionHandler): # type: ignore @override async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: # type: ignore if sender is None: - print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}") + print(f"{self.send_color}Sending message to {recipient.metadata['name']}:{self.reset_color} {message}") else: print( - f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}" + f"{self.send_color}Sending message from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}" ) return message @override async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: # type: ignore if recipient is None: - print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}") + print(f"{self.response_color}Received response from {sender.metadata['name']}:{self.reset_color} {message}") else: print( - f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}" + f"{self.response_color}Received response from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}" ) return message diff --git a/src/agnext/application/_single_threaded_agent_runtime.py b/src/agnext/application/_single_threaded_agent_runtime.py index 2ce46df2e..932e35fe7 100644 --- a/src/agnext/application/_single_threaded_agent_runtime.py +++ b/src/agnext/application/_single_threaded_agent_runtime.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Awaitable, Dict, List, Mapping, Set -from ..core import Agent, AgentRuntime, CancellationToken +from ..core import Agent, AgentMetadata, AgentRuntime, CancellationToken from ..core.exceptions import MessageDroppedException from ..core.intervention import DropMessage, InterventionHandler @@ -53,11 +53,11 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._before_send = before_send def add_agent(self, agent: Agent) -> None: - agent_names = {agent.name for agent in self._agents} - if agent.name in agent_names: - raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.") + agent_names = {agent.metadata["name"] for agent in self._agents} + if agent.metadata["name"] in agent_names: + raise ValueError(f"Agent with name {agent.metadata['name']} already exists. Agent names must be unique.") - for message_type in agent.subscriptions: + for message_type in agent.metadata["subscriptions"]: if message_type not in self._per_type_subscribers: self._per_type_subscribers[message_type] = [] self._per_type_subscribers[message_type].append(agent) @@ -85,7 +85,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): if cancellation_token is None: cancellation_token = CancellationToken() - logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}") + logger.info( + f"Sending message of type {type(message).__name__} to {recipient.metadata['name']}: {message.__dict__}" + ) # event_logger.info( # MessageEvent( @@ -150,21 +152,21 @@ class SingleThreadedAgentRuntime(AgentRuntime): def save_state(self) -> Mapping[str, Any]: state: Dict[str, Dict[str, Any]] = {} for agent in self._agents: - state[agent.name] = dict(agent.save_state()) + state[agent.metadata["name"]] = dict(agent.save_state()) return state def load_state(self, state: Mapping[str, Any]) -> None: for agent in self._agents: - agent.load_state(state[agent.name]) + agent.load_state(state[agent.metadata["name"]]) async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: recipient = message_envelope.recipient assert recipient in self._agents try: - sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown" + sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}" + f"Calling message handler for {recipient.metadata['name']} with message type {type(message_envelope.message).__name__} sent by {sender_name}" ) # event_logger.info( # MessageEvent( @@ -195,12 +197,15 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: responses: List[Awaitable[Any]] = [] for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore - if message_envelope.sender is not None and agent.name == message_envelope.sender.name: + if ( + message_envelope.sender is not None + and agent.metadata["name"] == message_envelope.sender.metadata["name"] + ): continue - sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown" + sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}" + f"Calling message handler for {agent.metadata['name']} with message type {type(message_envelope.message).__name__} published by {sender_name}" ) # event_logger.info( # MessageEvent( @@ -227,14 +232,16 @@ class SingleThreadedAgentRuntime(AgentRuntime): # TODO if responses are given for a publish async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: - recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown" + recipient_name = ( + message_envelope.recipient.metadata["name"] if message_envelope.recipient is not None else "Unknown" + ) content = ( message_envelope.message.__dict__ if hasattr(message_envelope.message, "__dict__") else message_envelope.message ) logger.info( - f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}" + f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.metadata['name']}: {content}" ) # event_logger.info( # MessageEvent( @@ -292,3 +299,6 @@ class SingleThreadedAgentRuntime(AgentRuntime): # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) + + def agent_metadata(self, agent: Agent) -> AgentMetadata: + return agent.metadata diff --git a/src/agnext/application/logging/_events.py b/src/agnext/application/logging/_events.py index cec85b43a..7fd270af9 100644 --- a/src/agnext/application/logging/_events.py +++ b/src/agnext/application/logging/_events.py @@ -65,8 +65,8 @@ class MessageEvent: ) -> None: self.kwargs = kwargs self.kwargs["payload"] = payload - self.kwargs["sender"] = None if sender is None else sender.name - self.kwargs["receiver"] = None if receiver is None else receiver.name + self.kwargs["sender"] = None if sender is None else sender.metadata["name"] + self.kwargs["receiver"] = None if receiver is None else receiver.metadata["name"] self.kwargs["kind"] = kind self.kwargs["delivery_stage"] = delivery_stage self.kwargs["type"] = "Message" diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index 1575f74c9..e9cd9e806 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -170,7 +170,7 @@ class ChatCompletionAgent(TypeRoutedAgent): # Get a response from the model. hisorical_messages = await self._memory.get_messages() response = await self._client.create( - self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.name), + self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]), tools=self._tools, json_output=response_format == ResponseFormat.json_object, ) @@ -185,14 +185,14 @@ class ChatCompletionAgent(TypeRoutedAgent): ): # Send a function call message to itself. response = await self._send_message( - message=FunctionCallMessage(content=response.content, source=self.name), + message=FunctionCallMessage(content=response.content, source=self.metadata["name"]), recipient=self, cancellation_token=cancellation_token, ) # Make an assistant message from the response. hisorical_messages = await self._memory.get_messages() response = await self._client.create( - self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.name), + self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]), tools=self._tools, json_output=response_format == ResponseFormat.json_object, ) @@ -200,10 +200,10 @@ class ChatCompletionAgent(TypeRoutedAgent): final_response: Message if isinstance(response.content, str): # If the response is a string, return a text message. - final_response = TextMessage(content=response.content, source=self.name) + final_response = TextMessage(content=response.content, source=self.metadata["name"]) elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content): # If the response is a list of function calls, return a function call message. - final_response = FunctionCallMessage(content=response.content, source=self.name) + final_response = FunctionCallMessage(content=response.content, source=self.metadata["name"]) else: raise ValueError(f"Unexpected response: {response.content}") @@ -249,7 +249,6 @@ class ChatCompletionAgent(TypeRoutedAgent): def save_state(self) -> Mapping[str, Any]: return { - "description": self.description, "memory": self._memory.save_state(), "system_messages": self._system_messages, } @@ -257,4 +256,3 @@ class ChatCompletionAgent(TypeRoutedAgent): def load_state(self, state: Mapping[str, Any]) -> None: self._memory.load_state(state["memory"]) self._system_messages = state["system_messages"] - self._description = state["description"] diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 423c33c18..6b405ef29 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -123,16 +123,14 @@ class OpenAIAssistantAgent(TypeRoutedAgent): raise ValueError(f"Expected text content in the last message: {last_message_content}") # TODO: handle multiple text content. - return TextMessage(content=text_content[0].text.value, source=self.name) + return TextMessage(content=text_content[0].text.value, source=self.metadata["name"]) def save_state(self) -> Mapping[str, Any]: return { - "description": self.description, "assistant_id": self._assistant_id, "thread_id": self._thread_id, } def load_state(self, state: Mapping[str, Any]) -> None: - self._description = state["description"] self._assistant_id = state["assistant_id"] self._thread_id = state["thread_id"] diff --git a/src/agnext/chat/agents/user_proxy.py b/src/agnext/chat/agents/user_proxy.py index d856fa566..e650d6648 100644 --- a/src/agnext/chat/agents/user_proxy.py +++ b/src/agnext/chat/agents/user_proxy.py @@ -24,7 +24,7 @@ class UserProxyAgent(TypeRoutedAgent): async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: """Handle a publish now message. This method prompts the user for input, then publishes it.""" user_input = await self.get_user_input(self._user_input_prompt) - await self._publish_message(TextMessage(content=user_input, source=self.name)) + await self._publish_message(TextMessage(content=user_input, source=self.metadata["name"])) async def get_user_input(self, prompt: str) -> str: """Get user input from the console. Override this method to customize how user input is retrieved.""" diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 55e8bd07d..70d086c0d 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -30,11 +30,6 @@ class GroupChat(TypeRoutedAgent): self._output = output super().__init__(name, description, runtime) - @property - def subscriptions(self) -> Sequence[type]: - agent_sublists = [agent.subscriptions for agent in self._participants] - return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] - @message_handler() async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: self._history.clear() diff --git a/src/agnext/chat/patterns/group_chat_manager.py b/src/agnext/chat/patterns/group_chat_manager.py index 502b0020b..caed484fc 100644 --- a/src/agnext/chat/patterns/group_chat_manager.py +++ b/src/agnext/chat/patterns/group_chat_manager.py @@ -55,14 +55,14 @@ class GroupChatManager(TypeRoutedAgent): for key, value in transitions.items(): if not value: # Make sure no empty transitions are provided. - raise ValueError(f"Empty transition list provided for {key.name}.") + raise ValueError(f"Empty transition list provided for {key.metadata['name']}.") if key not in participants: # Make sure all keys are in the list of participants. - raise ValueError(f"Transition key {key.name} not found in participants.") + raise ValueError(f"Transition key {key.metadata['name']} not found in participants.") for v in value: if v not in participants: # Make sure all values are in the list of participants. - raise ValueError(f"Transition value {v.name} not found in participants.") + raise ValueError(f"Transition value {v.metadata['name']} not found in participants.") self._tranistiions = transitions self._on_message_received = on_message_received @@ -89,7 +89,9 @@ class GroupChatManager(TypeRoutedAgent): # Get the last speaker. last_speaker_name = message.source - last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None) + last_speaker_index = next( + (i for i, p in enumerate(self._participants) if p.metadata["name"] == last_speaker_name), None + ) # Select speaker. if self._client is None: diff --git a/src/agnext/chat/patterns/group_chat_utils.py b/src/agnext/chat/patterns/group_chat_utils.py index 3b8a184da..fca967835 100644 --- a/src/agnext/chat/patterns/group_chat_utils.py +++ b/src/agnext/chat/patterns/group_chat_utils.py @@ -21,10 +21,10 @@ async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agent history = "\n".join(history_messages) # Construct agent roles. - roles = "\n".join([f"{agent.name}: {agent.description}".strip() for agent in agents]) + roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents]) # Construct agent list. - participants = str([agent.name for agent in agents]) + participants = str([agent.metadata["name"] for agent in agents]) # Select the next speaker. select_speaker_prompt = f"""You are in a role play game. The following roles are available: @@ -42,7 +42,7 @@ Read the above conversation. Then select the next role from {participants} to pl if len(mentions) != 1: raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}") agent_name = list(mentions.keys())[0] - agent = next((agent for agent in agents if agent.name == agent_name), None) + agent = next((agent for agent in agents if agent.metadata["name"] == agent_name), None) assert agent is not None return agent @@ -65,16 +65,17 @@ def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict[str, int for agent in agents: # Finds agent mentions, taking word boundaries into account, # accommodates escaping underscores and underscores as spaces + name = agent.metadata["name"] regex = ( r"(?<=\W)(" - + re.escape(agent.name) + + re.escape(name) + r"|" - + re.escape(agent.name.replace("_", " ")) + + re.escape(name.replace("_", " ")) + r"|" - + re.escape(agent.name.replace("_", r"\_")) + + re.escape(name.replace("_", r"\_")) + r")(?=\W)" ) count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching if count > 0: - mentions[agent.name] = count + mentions[name] = count return mentions diff --git a/src/agnext/chat/patterns/orchestrator_chat.py b/src/agnext/chat/patterns/orchestrator_chat.py index 7b129964a..95c90f16a 100644 --- a/src/agnext/chat/patterns/orchestrator_chat.py +++ b/src/agnext/chat/patterns/orchestrator_chat.py @@ -31,7 +31,11 @@ class OrchestratorChat(TypeRoutedAgent): @property def children(self) -> Sequence[str]: - return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name] + return ( + [agent.metadata["name"] for agent in self._specialists] + + [self._orchestrator.metadata["name"]] + + [self._planner.metadata["name"]] + ) @message_handler() async def on_text_message( @@ -73,7 +77,7 @@ Some additional points to consider: # Send the task specs to the orchestrator and specialists. for agent in [*self._specialists, self._orchestrator]: - await self._send_message(TextMessage(content=task_specs, source=self.name), agent) + await self._send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent) # Inner loop. stalled_turns = 0 @@ -85,7 +89,7 @@ Some additional points to consider: if data["is_request_satisfied"]["answer"]: return TextMessage( content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}", - source=self.name, + source=self.metadata["name"], ) # Update stalled turns. @@ -111,7 +115,7 @@ Some additional points to consider: if educated_guess["has_educated_guesses"]["answer"]: return TextMessage( content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}", - source=self.name, + source=self.metadata["name"], ) # Come up with a new plan. @@ -128,13 +132,15 @@ Some additional points to consider: # Update agents. for agent in [*self._specialists, self._orchestrator]: _ = await self._send_message( - TextMessage(content=subtask, source=self.name), + TextMessage(content=subtask, source=self.metadata["name"]), agent, ) # Find the speaker. try: - speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"]) + speaker = next( + agent for agent in self._specialists if agent.metadata["name"] == data["next_speaker"]["answer"] + ) except StopIteration as e: raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e @@ -157,7 +163,7 @@ Some additional points to consider: return TextMessage( content="The task was not addressed. The maximum number of turns was reached.", - source=self.name, + source=self.metadata["name"], ) async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]: @@ -165,8 +171,8 @@ Some additional points to consider: await self._send_message(Reset(), self._planner) # A reusable description of the team. - team = "\n".join([agent.name + ": " + agent.description for agent in self._specialists]) - names = ", ".join([agent.name for agent in self._specialists]) + team = "\n".join([agent.metadata["name"] + ": " + agent.metadata["description"] for agent in self._specialists]) + names = ", ".join([agent.metadata["name"] for agent in self._specialists]) # A place to store relevant facts. facts = "" diff --git a/src/agnext/components/_type_routed_agent.py b/src/agnext/components/_type_routed_agent.py index 658747efc..72e6252ad 100644 --- a/src/agnext/components/_type_routed_agent.py +++ b/src/agnext/components/_type_routed_agent.py @@ -176,12 +176,8 @@ class TypeRoutedAgent(BaseAgent): message_handler = cast(MessageHandler[Any, Any], handler) for target_type in message_handler.target_types: self._handlers[target_type] = message_handler - - super().__init__(name, description, runtime) - - @property - def subscriptions(self) -> Sequence[Type[Any]]: - return list(self._handlers.keys()) + subscriptions = list(self._handlers.keys()) + super().__init__(name, description, subscriptions, runtime) async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: key_type: Type[Any] = type(message) # type: ignore diff --git a/src/agnext/core/__init__.py b/src/agnext/core/__init__.py index 138c7bf73..605e33bf3 100644 --- a/src/agnext/core/__init__.py +++ b/src/agnext/core/__init__.py @@ -3,9 +3,10 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon """ from ._agent import Agent +from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken -__all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"] +__all__ = ["Agent", "AgentMetadata", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"] diff --git a/src/agnext/core/_agent.py b/src/agnext/core/_agent.py index c4189e683..39caa9acb 100644 --- a/src/agnext/core/_agent.py +++ b/src/agnext/core/_agent.py @@ -1,29 +1,14 @@ -from typing import Any, Mapping, Protocol, Sequence, runtime_checkable +from typing import Any, Mapping, Protocol, runtime_checkable +from ._agent_metadata import AgentMetadata from ._cancellation_token import CancellationToken @runtime_checkable class Agent(Protocol): @property - def name(self) -> str: - """Name of the agent. - - Note: - This name should be unique within the runtime. - """ - ... - - @property - def description(self) -> str: - """Description of the agent. - - A human-readable description of the agent.""" - ... - - @property - def subscriptions(self) -> Sequence[type]: - """Types of messages that this agent can receive.""" + def metadata(self) -> AgentMetadata: + """Metadata of the agent.""" ... async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: diff --git a/src/agnext/core/_agent_metadata.py b/src/agnext/core/_agent_metadata.py new file mode 100644 index 000000000..3d9b95e0d --- /dev/null +++ b/src/agnext/core/_agent_metadata.py @@ -0,0 +1,7 @@ +from typing import Sequence, TypedDict + + +class AgentMetadata(TypedDict): + name: str + description: str + subscriptions: Sequence[type] diff --git a/src/agnext/core/_agent_runtime.py b/src/agnext/core/_agent_runtime.py index 7d134fdde..2c8e6a569 100644 --- a/src/agnext/core/_agent_runtime.py +++ b/src/agnext/core/_agent_runtime.py @@ -1,8 +1,9 @@ from asyncio import Future from typing import Any, Mapping, Protocol -from agnext.core._agent import Agent -from agnext.core._cancellation_token import CancellationToken +from ._agent import Agent +from ._agent_metadata import AgentMetadata +from ._cancellation_token import CancellationToken # Undeliverable - error @@ -41,3 +42,5 @@ class AgentRuntime(Protocol): def save_state(self) -> Mapping[str, Any]: ... def load_state(self, state: Mapping[str, Any]) -> None: ... + + def agent_metadata(self, agent: Agent) -> AgentMetadata: ... diff --git a/src/agnext/core/_base_agent.py b/src/agnext/core/_base_agent.py index 7b3aa74f1..939b0313d 100644 --- a/src/agnext/core/_base_agent.py +++ b/src/agnext/core/_base_agent.py @@ -4,6 +4,7 @@ from asyncio import Future from typing import Any, Mapping, Sequence, TypeVar from ._agent import Agent +from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._cancellation_token import CancellationToken @@ -15,24 +16,20 @@ OtherProducesT = TypeVar("OtherProducesT") class BaseAgent(ABC, Agent): - def __init__(self, name: str, description: str, router: AgentRuntime) -> None: + def __init__(self, name: str, description: str, subscriptions: Sequence[type], router: AgentRuntime) -> None: self._name = name self._description = description self._router = router + self._subscriptions = subscriptions router.add_agent(self) @property - def name(self) -> str: - return self._name - - @property - def description(self) -> str: - return self._description - - @property - @abstractmethod - def subscriptions(self) -> Sequence[type]: - return [] + def metadata(self) -> AgentMetadata: + return AgentMetadata( + name=self._name, + description=self._description, + subscriptions=self._subscriptions, + ) @abstractmethod async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ... diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 5f2edaafd..319a59c99 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -6,8 +6,8 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken class NoopAgent(BaseAgent): # type: ignore - def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A no op agent", router) + def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore + super().__init__(name, "A no op agent", [], runtime) @property def subscriptions(self) -> Sequence[type]: @@ -19,13 +19,13 @@ class NoopAgent(BaseAgent): # type: ignore @pytest.mark.asyncio async def test_agent_names_must_be_unique() -> None: - router = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime() - _agent1 = NoopAgent("name1", router) + _agent1 = NoopAgent("name1", runtime) with pytest.raises(ValueError): - _agent1_again = NoopAgent("name1", router) + _agent1_again = NoopAgent("name1", runtime) - _agent3 = NoopAgent("name3", router) + _agent3 = NoopAgent("name3", runtime) diff --git a/tests/test_state.py b/tests/test_state.py index 63c170468..f72a71893 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,7 +7,7 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken class StatefulAgent(BaseAgent): # type: ignore def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A stateful agent", runtime) + super().__init__(name, "A stateful agent", [], runtime) self.state = 0 @property