mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 23:49:13 +00:00
migrate name, desc, subs to metadata (#83)
* migrate name, desc, subs to metadata * fix quote in f string * remove file * add metadata func to runtime * format
This commit is contained in:
parent
40701a5a00
commit
89f1133831
@ -40,7 +40,7 @@ class MyAgent(TypeRoutedAgent):
|
||||
await self._publish_message(
|
||||
TextMessage(
|
||||
content=f"I received a message from {message.source}. Message received #{self._received_count}",
|
||||
source=self.name,
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -108,7 +108,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
return
|
||||
else:
|
||||
# Publish user input and exit handler.
|
||||
await self._publish_message(TextMessage(content=user_input, source=self.name))
|
||||
await self._publish_message(TextMessage(content=user_input, source=self.metadata["name"]))
|
||||
return
|
||||
|
||||
|
||||
|
||||
@ -71,13 +71,13 @@ Use the following JSON format to provide your thought on the latest message and
|
||||
# Get a response from the model.
|
||||
raw_response = await self._client.create(
|
||||
self._system_messages
|
||||
+ convert_messages_to_llm_messages(await self._memory.get_messages(), self_name=self.name),
|
||||
+ convert_messages_to_llm_messages(await self._memory.get_messages(), self_name=self.metadata["name"]),
|
||||
json_output=True,
|
||||
)
|
||||
assert isinstance(raw_response.content, str)
|
||||
|
||||
# Save the response to memory.
|
||||
await self._memory.add_message(ChatRoomMessage(source=self.name, content=raw_response.content))
|
||||
await self._memory.add_message(ChatRoomMessage(source=self.metadata["name"], content=raw_response.content))
|
||||
|
||||
# Parse the response.
|
||||
data = json.loads(raw_response.content)
|
||||
@ -86,8 +86,8 @@ Use the following JSON format to provide your thought on the latest message and
|
||||
|
||||
# Publish the response if needed.
|
||||
if respond is True or str(respond).lower().strip() == "true":
|
||||
await self._publish_message(ChatRoomMessage(source=self.name, content=str(response)))
|
||||
print(f"{sep}\n{self._color}{self.name}:{Style.RESET_ALL}\n{response}")
|
||||
await self._publish_message(ChatRoomMessage(source=self.metadata["name"], content=str(response)))
|
||||
print(f"{sep}\n{self._color}{self.metadata['name']}:{Style.RESET_ALL}\n{response}")
|
||||
|
||||
|
||||
# Define a chat room with participants -- the runtime is the chat room.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents import ChatCompletionAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
|
||||
@ -18,7 +18,7 @@ class Inner(TypeRoutedAgent): # type: ignore
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
|
||||
return MessageType(body=f"Inner: {message.body}", sender=self.name)
|
||||
return MessageType(body=f"Inner: {message.body}", sender=self.metadata["name"])
|
||||
|
||||
|
||||
class Outer(TypeRoutedAgent): # type: ignore
|
||||
@ -31,7 +31,7 @@ class Outer(TypeRoutedAgent): # type: ignore
|
||||
inner_response = self._send_message(message, self._inner)
|
||||
inner_message = await inner_response
|
||||
assert isinstance(inner_message, MessageType)
|
||||
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)
|
||||
return MessageType(body=f"Outer: {inner_message.body}", sender=self.metadata["name"])
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@ -60,20 +60,20 @@ class LoggingHandler(DefaultInterventionHandler): # type: ignore
|
||||
@override
|
||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: # type: ignore
|
||||
if sender is None:
|
||||
print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
|
||||
print(f"{self.send_color}Sending message to {recipient.metadata['name']}:{self.reset_color} {message}")
|
||||
else:
|
||||
print(
|
||||
f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
|
||||
f"{self.send_color}Sending message from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}"
|
||||
)
|
||||
return message
|
||||
|
||||
@override
|
||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: # type: ignore
|
||||
if recipient is None:
|
||||
print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
|
||||
print(f"{self.response_color}Received response from {sender.metadata['name']}:{self.reset_color} {message}")
|
||||
else:
|
||||
print(
|
||||
f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
|
||||
f"{self.response_color}Received response from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}"
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Dict, List, Mapping, Set
|
||||
|
||||
from ..core import Agent, AgentRuntime, CancellationToken
|
||||
from ..core import Agent, AgentMetadata, AgentRuntime, CancellationToken
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
|
||||
@ -53,11 +53,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._before_send = before_send
|
||||
|
||||
def add_agent(self, agent: Agent) -> None:
|
||||
agent_names = {agent.name for agent in self._agents}
|
||||
if agent.name in agent_names:
|
||||
raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.")
|
||||
agent_names = {agent.metadata["name"] for agent in self._agents}
|
||||
if agent.metadata["name"] in agent_names:
|
||||
raise ValueError(f"Agent with name {agent.metadata['name']} already exists. Agent names must be unique.")
|
||||
|
||||
for message_type in agent.subscriptions:
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
if message_type not in self._per_type_subscribers:
|
||||
self._per_type_subscribers[message_type] = []
|
||||
self._per_type_subscribers[message_type].append(agent)
|
||||
@ -85,7 +85,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
|
||||
logger.info(
|
||||
f"Sending message of type {type(message).__name__} to {recipient.metadata['name']}: {message.__dict__}"
|
||||
)
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
@ -150,21 +152,21 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent in self._agents:
|
||||
state[agent.name] = dict(agent.save_state())
|
||||
state[agent.metadata["name"]] = dict(agent.save_state())
|
||||
return state
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent in self._agents:
|
||||
agent.load_state(state[agent.name])
|
||||
agent.load_state(state[agent.metadata["name"]])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
recipient = message_envelope.recipient
|
||||
assert recipient in self._agents
|
||||
|
||||
try:
|
||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
||||
sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
f"Calling message handler for {recipient.metadata['name']} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
@ -195,12 +197,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
|
||||
if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
|
||||
if (
|
||||
message_envelope.sender is not None
|
||||
and agent.metadata["name"] == message_envelope.sender.metadata["name"]
|
||||
):
|
||||
continue
|
||||
|
||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
||||
sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
f"Calling message handler for {agent.metadata['name']} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
@ -227,14 +232,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
|
||||
recipient_name = (
|
||||
message_envelope.recipient.metadata["name"] if message_envelope.recipient is not None else "Unknown"
|
||||
)
|
||||
content = (
|
||||
message_envelope.message.__dict__
|
||||
if hasattr(message_envelope.message, "__dict__")
|
||||
else message_envelope.message
|
||||
)
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}"
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.metadata['name']}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
@ -292,3 +299,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def agent_metadata(self, agent: Agent) -> AgentMetadata:
|
||||
return agent.metadata
|
||||
|
||||
@ -65,8 +65,8 @@ class MessageEvent:
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else sender.name
|
||||
self.kwargs["receiver"] = None if receiver is None else receiver.name
|
||||
self.kwargs["sender"] = None if sender is None else sender.metadata["name"]
|
||||
self.kwargs["receiver"] = None if receiver is None else receiver.metadata["name"]
|
||||
self.kwargs["kind"] = kind
|
||||
self.kwargs["delivery_stage"] = delivery_stage
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
@ -170,7 +170,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
# Get a response from the model.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.name),
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
@ -185,14 +185,14 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
):
|
||||
# Send a function call message to itself.
|
||||
response = await self._send_message(
|
||||
message=FunctionCallMessage(content=response.content, source=self.name),
|
||||
message=FunctionCallMessage(content=response.content, source=self.metadata["name"]),
|
||||
recipient=self,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Make an assistant message from the response.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.name),
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
@ -200,10 +200,10 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
final_response: Message
|
||||
if isinstance(response.content, str):
|
||||
# If the response is a string, return a text message.
|
||||
final_response = TextMessage(content=response.content, source=self.name)
|
||||
final_response = TextMessage(content=response.content, source=self.metadata["name"])
|
||||
elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
|
||||
# If the response is a list of function calls, return a function call message.
|
||||
final_response = FunctionCallMessage(content=response.content, source=self.name)
|
||||
final_response = FunctionCallMessage(content=response.content, source=self.metadata["name"])
|
||||
else:
|
||||
raise ValueError(f"Unexpected response: {response.content}")
|
||||
|
||||
@ -249,7 +249,6 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"description": self.description,
|
||||
"memory": self._memory.save_state(),
|
||||
"system_messages": self._system_messages,
|
||||
}
|
||||
@ -257,4 +256,3 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state(state["memory"])
|
||||
self._system_messages = state["system_messages"]
|
||||
self._description = state["description"]
|
||||
|
||||
@ -123,16 +123,14 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
|
||||
raise ValueError(f"Expected text content in the last message: {last_message_content}")
|
||||
|
||||
# TODO: handle multiple text content.
|
||||
return TextMessage(content=text_content[0].text.value, source=self.name)
|
||||
return TextMessage(content=text_content[0].text.value, source=self.metadata["name"])
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"description": self.description,
|
||||
"assistant_id": self._assistant_id,
|
||||
"thread_id": self._thread_id,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._description = state["description"]
|
||||
self._assistant_id = state["assistant_id"]
|
||||
self._thread_id = state["thread_id"]
|
||||
|
||||
@ -24,7 +24,7 @@ class UserProxyAgent(TypeRoutedAgent):
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
|
||||
user_input = await self.get_user_input(self._user_input_prompt)
|
||||
await self._publish_message(TextMessage(content=user_input, source=self.name))
|
||||
await self._publish_message(TextMessage(content=user_input, source=self.metadata["name"]))
|
||||
|
||||
async def get_user_input(self, prompt: str) -> str:
|
||||
"""Get user input from the console. Override this method to customize how user input is retrieved."""
|
||||
|
||||
@ -30,11 +30,6 @@ class GroupChat(TypeRoutedAgent):
|
||||
self._output = output
|
||||
super().__init__(name, description, runtime)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
agent_sublists = [agent.subscriptions for agent in self._participants]
|
||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
self._history.clear()
|
||||
|
||||
@ -55,14 +55,14 @@ class GroupChatManager(TypeRoutedAgent):
|
||||
for key, value in transitions.items():
|
||||
if not value:
|
||||
# Make sure no empty transitions are provided.
|
||||
raise ValueError(f"Empty transition list provided for {key.name}.")
|
||||
raise ValueError(f"Empty transition list provided for {key.metadata['name']}.")
|
||||
if key not in participants:
|
||||
# Make sure all keys are in the list of participants.
|
||||
raise ValueError(f"Transition key {key.name} not found in participants.")
|
||||
raise ValueError(f"Transition key {key.metadata['name']} not found in participants.")
|
||||
for v in value:
|
||||
if v not in participants:
|
||||
# Make sure all values are in the list of participants.
|
||||
raise ValueError(f"Transition value {v.name} not found in participants.")
|
||||
raise ValueError(f"Transition value {v.metadata['name']} not found in participants.")
|
||||
self._tranistiions = transitions
|
||||
self._on_message_received = on_message_received
|
||||
|
||||
@ -89,7 +89,9 @@ class GroupChatManager(TypeRoutedAgent):
|
||||
|
||||
# Get the last speaker.
|
||||
last_speaker_name = message.source
|
||||
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
|
||||
last_speaker_index = next(
|
||||
(i for i, p in enumerate(self._participants) if p.metadata["name"] == last_speaker_name), None
|
||||
)
|
||||
|
||||
# Select speaker.
|
||||
if self._client is None:
|
||||
|
||||
@ -21,10 +21,10 @@ async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agent
|
||||
history = "\n".join(history_messages)
|
||||
|
||||
# Construct agent roles.
|
||||
roles = "\n".join([f"{agent.name}: {agent.description}".strip() for agent in agents])
|
||||
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
|
||||
|
||||
# Construct agent list.
|
||||
participants = str([agent.name for agent in agents])
|
||||
participants = str([agent.metadata["name"] for agent in agents])
|
||||
|
||||
# Select the next speaker.
|
||||
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
|
||||
@ -42,7 +42,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
if len(mentions) != 1:
|
||||
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
|
||||
agent_name = list(mentions.keys())[0]
|
||||
agent = next((agent for agent in agents if agent.name == agent_name), None)
|
||||
agent = next((agent for agent in agents if agent.metadata["name"] == agent_name), None)
|
||||
assert agent is not None
|
||||
return agent
|
||||
|
||||
@ -65,16 +65,17 @@ def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict[str, int
|
||||
for agent in agents:
|
||||
# Finds agent mentions, taking word boundaries into account,
|
||||
# accommodates escaping underscores and underscores as spaces
|
||||
name = agent.metadata["name"]
|
||||
regex = (
|
||||
r"(?<=\W)("
|
||||
+ re.escape(agent.name)
|
||||
+ re.escape(name)
|
||||
+ r"|"
|
||||
+ re.escape(agent.name.replace("_", " "))
|
||||
+ re.escape(name.replace("_", " "))
|
||||
+ r"|"
|
||||
+ re.escape(agent.name.replace("_", r"\_"))
|
||||
+ re.escape(name.replace("_", r"\_"))
|
||||
+ r")(?=\W)"
|
||||
)
|
||||
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
|
||||
if count > 0:
|
||||
mentions[agent.name] = count
|
||||
mentions[name] = count
|
||||
return mentions
|
||||
|
||||
@ -31,7 +31,11 @@ class OrchestratorChat(TypeRoutedAgent):
|
||||
|
||||
@property
|
||||
def children(self) -> Sequence[str]:
|
||||
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
|
||||
return (
|
||||
[agent.metadata["name"] for agent in self._specialists]
|
||||
+ [self._orchestrator.metadata["name"]]
|
||||
+ [self._planner.metadata["name"]]
|
||||
)
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(
|
||||
@ -73,7 +77,7 @@ Some additional points to consider:
|
||||
|
||||
# Send the task specs to the orchestrator and specialists.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
await self._send_message(TextMessage(content=task_specs, source=self.name), agent)
|
||||
await self._send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent)
|
||||
|
||||
# Inner loop.
|
||||
stalled_turns = 0
|
||||
@ -85,7 +89,7 @@ Some additional points to consider:
|
||||
if data["is_request_satisfied"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
|
||||
source=self.name,
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
# Update stalled turns.
|
||||
@ -111,7 +115,7 @@ Some additional points to consider:
|
||||
if educated_guess["has_educated_guesses"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
|
||||
source=self.name,
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
# Come up with a new plan.
|
||||
@ -128,13 +132,15 @@ Some additional points to consider:
|
||||
# Update agents.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
_ = await self._send_message(
|
||||
TextMessage(content=subtask, source=self.name),
|
||||
TextMessage(content=subtask, source=self.metadata["name"]),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Find the speaker.
|
||||
try:
|
||||
speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
|
||||
speaker = next(
|
||||
agent for agent in self._specialists if agent.metadata["name"] == data["next_speaker"]["answer"]
|
||||
)
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
||||
|
||||
@ -157,7 +163,7 @@ Some additional points to consider:
|
||||
|
||||
return TextMessage(
|
||||
content="The task was not addressed. The maximum number of turns was reached.",
|
||||
source=self.name,
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
|
||||
@ -165,8 +171,8 @@ Some additional points to consider:
|
||||
await self._send_message(Reset(), self._planner)
|
||||
|
||||
# A reusable description of the team.
|
||||
team = "\n".join([agent.name + ": " + agent.description for agent in self._specialists])
|
||||
names = ", ".join([agent.name for agent in self._specialists])
|
||||
team = "\n".join([agent.metadata["name"] + ": " + agent.metadata["description"] for agent in self._specialists])
|
||||
names = ", ".join([agent.metadata["name"] for agent in self._specialists])
|
||||
|
||||
# A place to store relevant facts.
|
||||
facts = ""
|
||||
|
||||
@ -176,12 +176,8 @@ class TypeRoutedAgent(BaseAgent):
|
||||
message_handler = cast(MessageHandler[Any, Any], handler)
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type] = message_handler
|
||||
|
||||
super().__init__(name, description, runtime)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Type[Any]]:
|
||||
return list(self._handlers.keys())
|
||||
subscriptions = list(self._handlers.keys())
|
||||
super().__init__(name, description, subscriptions, runtime)
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
|
||||
@ -3,9 +3,10 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon
|
||||
"""
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
__all__ = ["Agent", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"]
|
||||
__all__ = ["Agent", "AgentMetadata", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"]
|
||||
|
||||
@ -1,29 +1,14 @@
|
||||
from typing import Any, Mapping, Protocol, Sequence, runtime_checkable
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Name of the agent.
|
||||
|
||||
Note:
|
||||
This name should be unique within the runtime.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the agent.
|
||||
|
||||
A human-readable description of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
"""Types of messages that this agent can receive."""
|
||||
def metadata(self) -> AgentMetadata:
|
||||
"""Metadata of the agent."""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
|
||||
7
src/agnext/core/_agent_metadata.py
Normal file
7
src/agnext/core/_agent_metadata.py
Normal file
@ -0,0 +1,7 @@
|
||||
from typing import Sequence, TypedDict
|
||||
|
||||
|
||||
class AgentMetadata(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
subscriptions: Sequence[type]
|
||||
@ -1,8 +1,9 @@
|
||||
from asyncio import Future
|
||||
from typing import Any, Mapping, Protocol
|
||||
|
||||
from agnext.core._agent import Agent
|
||||
from agnext.core._cancellation_token import CancellationToken
|
||||
from ._agent import Agent
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
@ -41,3 +42,5 @@ class AgentRuntime(Protocol):
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
def agent_metadata(self, agent: Agent) -> AgentMetadata: ...
|
||||
|
||||
@ -4,6 +4,7 @@ from asyncio import Future
|
||||
from typing import Any, Mapping, Sequence, TypeVar
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
@ -15,24 +16,20 @@ OtherProducesT = TypeVar("OtherProducesT")
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
def __init__(self, name: str, description: str, router: AgentRuntime) -> None:
|
||||
def __init__(self, name: str, description: str, subscriptions: Sequence[type], router: AgentRuntime) -> None:
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._router = router
|
||||
self._subscriptions = subscriptions
|
||||
router.add_agent(self)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
def metadata(self) -> AgentMetadata:
|
||||
return AgentMetadata(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
subscriptions=self._subscriptions,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
|
||||
|
||||
@ -6,8 +6,8 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken
|
||||
|
||||
|
||||
class NoopAgent(BaseAgent): # type: ignore
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore
|
||||
super().__init__(name, "A no op agent", router)
|
||||
def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore
|
||||
super().__init__(name, "A no op agent", [], runtime)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
@ -19,13 +19,13 @@ class NoopAgent(BaseAgent): # type: ignore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
router = SingleThreadedAgentRuntime()
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
_agent1 = NoopAgent("name1", router)
|
||||
_agent1 = NoopAgent("name1", runtime)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent1_again = NoopAgent("name1", router)
|
||||
_agent1_again = NoopAgent("name1", runtime)
|
||||
|
||||
_agent3 = NoopAgent("name3", router)
|
||||
_agent3 = NoopAgent("name3", runtime)
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken
|
||||
|
||||
class StatefulAgent(BaseAgent): # type: ignore
|
||||
def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore
|
||||
super().__init__(name, "A stateful agent", runtime)
|
||||
super().__init__(name, "A stateful agent", [], runtime)
|
||||
self.state = 0
|
||||
|
||||
@property
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user