mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 14:38:50 +00:00
Initial impl of topics and subscriptions (#350)
* initial impl of topics and subscriptions * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add topic in context * migrate * migrate code for topics * migrate team one * edit notebooks * formatting * fix imports * Build proto * Fix circular import --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
4ba7e84721
commit
e1a823fb6d
@ -2,7 +2,6 @@ syntax = "proto3";
|
||||
|
||||
package agents;
|
||||
|
||||
// TODO: update
|
||||
message AgentId {
|
||||
string name = 1;
|
||||
string namespace = 2;
|
||||
@ -25,10 +24,11 @@ message RpcResponse {
|
||||
}
|
||||
|
||||
message Event {
|
||||
string namespace = 1;
|
||||
string type = 2;
|
||||
string data = 3;
|
||||
map<string, string> metadata = 4;
|
||||
string topic_type = 1;
|
||||
string topic_source = 2;
|
||||
string data_type = 3;
|
||||
string data = 4;
|
||||
map<string, string> metadata = 5;
|
||||
}
|
||||
|
||||
message RegisterAgentType {
|
||||
|
||||
@ -45,7 +45,7 @@
|
||||
"\n",
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
|
||||
"from langchain_core.messages import HumanMessage, SystemMessage\n",
|
||||
"from langchain_core.tools import tool # pyright: ignore\n",
|
||||
@ -195,7 +195,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"langgraph_tool_use_agent\",\n",
|
||||
" lambda: LangGraphToolUseAgent(\n",
|
||||
" \"Tool use agent\",\n",
|
||||
@ -214,7 +214,8 @@
|
||||
" # ),\n",
|
||||
" [get_weather],\n",
|
||||
" ),\n",
|
||||
")"
|
||||
")\n",
|
||||
"agent = AgentId(\"langgraph_tool_use_agent\", key=\"default\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -44,7 +44,7 @@
|
||||
"\n",
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
|
||||
"from llama_index.core import Settings\n",
|
||||
"from llama_index.core.agent import ReActAgent\n",
|
||||
@ -221,7 +221,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"chat_agent\",\n",
|
||||
" lambda: LlamaIndexAgent(\n",
|
||||
" description=\"Llama Index Agent\",\n",
|
||||
@ -233,7 +233,8 @@
|
||||
" verbose=True,\n",
|
||||
" ),\n",
|
||||
" ),\n",
|
||||
")"
|
||||
")\n",
|
||||
"agent = AgentId(\"chat_agent\", \"default\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -105,7 +105,7 @@
|
||||
"\n",
|
||||
"import aiofiles\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"from openai import AsyncAssistantEventHandler, AsyncClient\n",
|
||||
"from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch\n",
|
||||
"\n",
|
||||
@ -390,7 +390,7 @@
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"assistant\",\n",
|
||||
" lambda: OpenAIAssistantAgent(\n",
|
||||
" description=\"OpenAI Assistant Agent\",\n",
|
||||
@ -399,7 +399,8 @@
|
||||
" thread_id=thread.id,\n",
|
||||
" assistant_event_handler_factory=lambda: EventHandler(),\n",
|
||||
" ),\n",
|
||||
")"
|
||||
")\n",
|
||||
"agent = AgentId(\"assistant\", \"default\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -57,7 +57,7 @@
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from agnext.core import BaseAgent, MessageContext\n",
|
||||
"from agnext.core import AgentId, BaseAgent, MessageContext\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
@ -67,7 +67,7 @@
|
||||
"\n",
|
||||
"class MyAgent(BaseAgent):\n",
|
||||
" def __init__(self) -> None:\n",
|
||||
" super().__init__(\"MyAgent\", subscriptions=[\"MyMessage\"])\n",
|
||||
" super().__init__(\"MyAgent\")\n",
|
||||
"\n",
|
||||
" async def on_message(self, message: MyMessage, ctx: MessageContext) -> None:\n",
|
||||
" print(f\"Received message: {message.content}\")"
|
||||
@ -133,7 +133,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_id = await runtime.get(\"my_agent\")\n",
|
||||
"agent_id = AgentId(\"my_agent\", \"default\")\n",
|
||||
"run_context = runtime.start() # Start processing messages in the background.\n",
|
||||
"await runtime.send_message(MyMessage(content=\"Hello, World!\"), agent_id)\n",
|
||||
"await run_context.stop() # Stop processing messages in the background."
|
||||
|
||||
@ -83,7 +83,7 @@
|
||||
"source": [
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyAgent(TypeRoutedAgent):\n",
|
||||
@ -110,7 +110,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"agent = await runtime.register_and_get(\"my_agent\", lambda: MyAgent(\"My Agent\"))"
|
||||
"await runtime.register(\"my_agent\", lambda: MyAgent(\"My Agent\"))\n",
|
||||
"agent = AgentId(\"my_agent\", \"default\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -185,7 +186,7 @@
|
||||
"\n",
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
@ -200,9 +201,9 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"class OuterAgent(TypeRoutedAgent):\n",
|
||||
" def __init__(self, description: str, inner_agent_id: AgentId):\n",
|
||||
" def __init__(self, description: str, inner_agent_type: str):\n",
|
||||
" super().__init__(description)\n",
|
||||
" self.inner_agent_id = inner_agent_id\n",
|
||||
" self.inner_agent_id = AgentId(inner_agent_type, self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
|
||||
@ -238,9 +239,10 @@
|
||||
],
|
||||
"source": [
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"inner = await runtime.register_and_get(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n",
|
||||
"outer = await runtime.register_and_get(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", inner))\n",
|
||||
"await runtime.register(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n",
|
||||
"await runtime.register(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", \"InnerAgent\"))\n",
|
||||
"run_context = runtime.start()\n",
|
||||
"outer = AgentId(\"outer_agent\", \"default\")\n",
|
||||
"await runtime.send_message(Message(content=\"Hello, World!\"), outer)\n",
|
||||
"await run_context.stop_when_idle()"
|
||||
]
|
||||
@ -294,14 +296,17 @@
|
||||
"source": [
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components import TypeRoutedAgent, message_handler\n",
|
||||
"from agnext.core import MessageContext\n",
|
||||
"from agnext.core import MessageContext, TopicId\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class BroadcastingAgent(TypeRoutedAgent):\n",
|
||||
" @message_handler\n",
|
||||
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
|
||||
" # Publish a message to all agents in the same namespace.\n",
|
||||
" await self.publish_message(Message(f\"Publishing a message: {message.content}!\"))\n",
|
||||
" assert ctx.topic_id is not None\n",
|
||||
" await self.publish_message(\n",
|
||||
" Message(f\"Publishing a message: {message.content}!\"), topic_id=TopicId(\"deafult\", self.id.key)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ReceivingAgent(TypeRoutedAgent):\n",
|
||||
@ -332,11 +337,15 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from agnext.components import TypeSubscription\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
|
||||
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
|
||||
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
|
||||
"run_context = runtime.start()\n",
|
||||
"await runtime.send_message(Message(\"Hello, World!\"), broadcaster)\n",
|
||||
"await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n",
|
||||
"await run_context.stop_when_idle()"
|
||||
]
|
||||
},
|
||||
@ -367,10 +376,12 @@
|
||||
"# Replace send_message with publish_message in the above example.\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
|
||||
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
|
||||
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
|
||||
"run_context = runtime.start()\n",
|
||||
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), namespace=\"default\")\n",
|
||||
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=TopicId(\"default\", \"default\"))\n",
|
||||
"await run_context.stop_when_idle()"
|
||||
]
|
||||
},
|
||||
|
||||
@ -318,8 +318,10 @@
|
||||
],
|
||||
"source": [
|
||||
"# Create the runtime and register the agent.\n",
|
||||
"from agnext.core import AgentId\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"simple-agent\",\n",
|
||||
" lambda: SimpleAgent(\n",
|
||||
" OpenAIChatCompletionClient(\n",
|
||||
@ -332,7 +334,7 @@
|
||||
"run_context = runtime.start()\n",
|
||||
"# Send a message to the agent and get the response.\n",
|
||||
"message = Message(\"Hello, what are some fun things to do in Seattle?\")\n",
|
||||
"response = await runtime.send_message(message, agent)\n",
|
||||
"response = await runtime.send_message(message, AgentId(\"simple-agent\", \"default\"))\n",
|
||||
"print(response.content)\n",
|
||||
"# Stop the runtime processing messages.\n",
|
||||
"await run_context.stop()"
|
||||
|
||||
@ -131,7 +131,7 @@
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from agnext.core import MessageContext"
|
||||
"from agnext.core import MessageContext, TopicId"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -201,7 +201,7 @@
|
||||
" # Store the code review task in the session memory.\n",
|
||||
" self._session_memory[session_id].append(code_review_task)\n",
|
||||
" # Publish a code review task.\n",
|
||||
" await self.publish_message(code_review_task)\n",
|
||||
" await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n",
|
||||
@ -220,7 +220,8 @@
|
||||
" code=review_request.code,\n",
|
||||
" task=review_request.code_writing_task,\n",
|
||||
" review=message.review,\n",
|
||||
" )\n",
|
||||
" ),\n",
|
||||
" topic_id=TopicId(\"default\", self.id.key),\n",
|
||||
" )\n",
|
||||
" print(\"Code Writing Result:\")\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
@ -259,7 +260,7 @@
|
||||
" # Store the code review task in the session memory.\n",
|
||||
" self._session_memory[message.session_id].append(code_review_task)\n",
|
||||
" # Publish a new code review task.\n",
|
||||
" await self.publish_message(code_review_task)\n",
|
||||
" await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n",
|
||||
"\n",
|
||||
" def _extract_code_block(self, markdown_text: str) -> Union[str, None]:\n",
|
||||
" pattern = r\"```(\\w+)\\n(.*?)\\n```\"\n",
|
||||
@ -360,7 +361,7 @@
|
||||
" # Store the review result in the session memory.\n",
|
||||
" self._session_memory[message.session_id].append(result)\n",
|
||||
" # Publish the review result.\n",
|
||||
" await self.publish_message(result)"
|
||||
" await self.publish_message(result, topic_id=TopicId(\"default\", self.id.key))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -494,6 +495,7 @@
|
||||
],
|
||||
"source": [
|
||||
"from agnext.application import SingleThreadedAgentRuntime\n",
|
||||
"from agnext.components._type_subscription import TypeSubscription\n",
|
||||
"from agnext.components.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
@ -501,14 +503,16 @@
|
||||
" \"ReviewerAgent\",\n",
|
||||
" lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n",
|
||||
")\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"CoderAgent\"))\n",
|
||||
"await runtime.register(\n",
|
||||
" \"CoderAgent\",\n",
|
||||
" lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n",
|
||||
")\n",
|
||||
"await runtime.add_subscription(TypeSubscription(\"default\", \"ReviewerAgent\"))\n",
|
||||
"run_context = runtime.start()\n",
|
||||
"await runtime.publish_message(\n",
|
||||
" message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n",
|
||||
" namespace=\"default\",\n",
|
||||
" topic_id=TopicId(\"default\", \"default\"),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Keep processing messages until idle.\n",
|
||||
|
||||
@ -148,7 +148,7 @@
|
||||
")\n",
|
||||
"from agnext.components.tool_agent import ToolAgent, ToolException\n",
|
||||
"from agnext.components.tools import FunctionTool, Tool, ToolSchema\n",
|
||||
"from agnext.core import AgentId, MessageContext\n",
|
||||
"from agnext.core import AgentId, AgentInstantiationContext, MessageContext\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
@ -239,19 +239,19 @@
|
||||
"# Create the tools.\n",
|
||||
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
|
||||
"# Register the agents.\n",
|
||||
"tool_executor_agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"tool-executor-agent\",\n",
|
||||
" lambda: ToolAgent(\n",
|
||||
" description=\"Tool Executor Agent\",\n",
|
||||
" tools=tools,\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"tool_use_agent = await runtime.register_and_get(\n",
|
||||
"await runtime.register(\n",
|
||||
" \"tool-use-agent\",\n",
|
||||
" lambda: ToolUseAgent(\n",
|
||||
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
|
||||
" tool_schema=[tool.schema for tool in tools],\n",
|
||||
" tool_agent=tool_executor_agent,\n",
|
||||
" tool_agent=AgentId(\"tool-executor-agent\", AgentInstantiationContext.current_agent_id().key),\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
@ -282,6 +282,7 @@
|
||||
"# Start processing messages.\n",
|
||||
"run_context = runtime.start()\n",
|
||||
"# Send a direct message to the tool agent.\n",
|
||||
"tool_use_agent = AgentId(\"tool-use-agent\", \"default\")\n",
|
||||
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
|
||||
"print(response.content)\n",
|
||||
"# Stop processing messages.\n",
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Any, Callable, List, Literal
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import AgentId, MessageContext
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool # pyright: ignore
|
||||
from langchain_openai import ChatOpenAI
|
||||
@ -110,7 +110,7 @@ async def main() -> None:
|
||||
# Create runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
# Register the agent.
|
||||
agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"langgraph_tool_use_agent",
|
||||
lambda: LangGraphToolUseAgent(
|
||||
"Tool use agent",
|
||||
@ -118,6 +118,7 @@ async def main() -> None:
|
||||
[get_weather],
|
||||
),
|
||||
)
|
||||
agent = AgentId("langgraph_tool_use_agent", key="default")
|
||||
# Start the runtime.
|
||||
run_context = runtime.start()
|
||||
# Send a message to the agent and get a response.
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import List, Optional
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import AgentId, MessageContext
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.agent import ReActAgent
|
||||
from llama_index.core.agent.runner.base import AgentRunner
|
||||
@ -119,10 +119,11 @@ async def main() -> None:
|
||||
tools=[wikipedia_tool], llm=llm, max_iterations=8, memory=memory, verbose=True
|
||||
)
|
||||
|
||||
agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"chat_agent",
|
||||
lambda: LlamaIndexAgent("Chat agent", llama_index_agent=llama_index_agent),
|
||||
)
|
||||
agent = AgentId("chat_agent", key="default")
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
||||
@ -110,8 +110,9 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
# Generate a response.
|
||||
response = await self._generate_response(message.response_format, ctx)
|
||||
|
||||
assert ctx.topic_id is not None
|
||||
# Publish the response.
|
||||
await self.publish_message(response)
|
||||
await self.publish_message(response, topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler()
|
||||
async def on_tool_call_message(
|
||||
|
||||
@ -7,8 +7,7 @@ from agnext.components import (
|
||||
message_handler,
|
||||
)
|
||||
from agnext.components.memory import ChatMemory
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core._cancellation_token import CancellationToken
|
||||
from agnext.core import CancellationToken, MessageContext
|
||||
|
||||
from ..types import (
|
||||
Message,
|
||||
@ -58,7 +57,8 @@ class ImageGenerationAgent(TypeRoutedAgent):
|
||||
image is published as a MultiModalMessage."""
|
||||
|
||||
response = await self._generate_response(ctx.cancellation_token)
|
||||
await self.publish_message(response)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(response, topic_id=ctx.topic_id)
|
||||
|
||||
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
|
||||
messages = await self._memory.get_messages()
|
||||
|
||||
@ -80,7 +80,8 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
|
||||
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
|
||||
"""Handle a publish now message. This method generates a response and publishes it."""
|
||||
response = await self._generate_response(message.response_format, ctx.cancellation_token)
|
||||
await self.publish_message(response)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(response, ctx.topic_id)
|
||||
|
||||
async def _generate_response(
|
||||
self,
|
||||
|
||||
@ -23,7 +23,8 @@ class UserProxyAgent(TypeRoutedAgent):
|
||||
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
|
||||
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
|
||||
user_input = await self.get_user_input(self._user_input_prompt)
|
||||
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id)
|
||||
|
||||
async def get_user_input(self, prompt: str) -> str:
|
||||
"""Get user input from the console. Override this method to customize how user input is retrieved."""
|
||||
|
||||
@ -12,7 +12,7 @@ from dataclasses import dataclass
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, MessageContext
|
||||
from agnext.core import AgentId, AgentInstantiationContext, MessageContext
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -45,8 +45,9 @@ class Outer(TypeRoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
inner = await runtime.register_and_get("inner", Inner)
|
||||
outer = await runtime.register_and_get("outer", lambda: Outer(inner))
|
||||
await runtime.register("inner", Inner)
|
||||
await runtime.register("outer", lambda: Outer(AgentId("outer", AgentInstantiationContext.current_agent_id().key)))
|
||||
outer = AgentId("outer", "default")
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import AgentId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -45,10 +46,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"chat_agent",
|
||||
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-4o-mini")),
|
||||
)
|
||||
agent = AgentId("chat_agent", "default")
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from typing import List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@ -25,6 +26,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import AgentId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -69,7 +71,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
llm_messages.append(UserMessage(content=m.content, source=m.source))
|
||||
response = await self._model_client.create(self._system_messages + llm_messages)
|
||||
assert isinstance(response.content, str)
|
||||
await self.publish_message(Message(content=response.content, source=self.metadata["type"]))
|
||||
|
||||
if ctx.topic_id is not None:
|
||||
await self.publish_message(
|
||||
Message(content=response.content, source=self.metadata["type"]), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -77,7 +83,7 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the agents.
|
||||
jack = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Jack",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Jack a comedian",
|
||||
@ -88,7 +94,8 @@ async def main() -> None:
|
||||
termination_word="TERMINATE",
|
||||
),
|
||||
)
|
||||
await runtime.register_and_get(
|
||||
await runtime.add_subscription(TypeSubscription("default", "Jack"))
|
||||
await runtime.register(
|
||||
"Cathy",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Cathy a poet",
|
||||
@ -99,12 +106,13 @@ async def main() -> None:
|
||||
termination_word="TERMINATE",
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "Cathy"))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a message to Jack to start the conversation.
|
||||
message = Message(content="Can you tell me something fun about SF?", source="User")
|
||||
await runtime.send_message(message, jack)
|
||||
await runtime.send_message(message, AgentId("jack", "default"))
|
||||
|
||||
# Process messages.
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
@ -13,7 +13,7 @@ import aiofiles
|
||||
import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, AgentRuntime, CancellationToken
|
||||
from agnext.core import AgentId, AgentRuntime, MessageContext
|
||||
from openai import AsyncAssistantEventHandler
|
||||
from openai.types.beta.thread import ToolResources
|
||||
from openai.types.beta.threads import Message, Text, TextDelta
|
||||
@ -22,6 +22,7 @@ from typing_extensions import override
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from agnext.core import AgentInstantiationContext
|
||||
from common.agents import OpenAIAssistantAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
@ -30,7 +31,7 @@ from common.types import PublishNow, TextMessage
|
||||
sep = "-" * 50
|
||||
|
||||
|
||||
class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
class UserProxyAgent(TypeRoutedAgent):
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
client: openai.AsyncClient, # type: ignore
|
||||
@ -47,7 +48,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
self._vector_store_id = vector_store_id
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:
|
||||
# TODO: render image if message has image.
|
||||
# print(f"{message.source}: {message.content}")
|
||||
pass
|
||||
@ -57,7 +58,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
return await loop.run_in_executor(None, input, prompt)
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
|
||||
while True:
|
||||
user_input = await self._get_user_input(f"\n{sep}\nYou: ")
|
||||
# Parse upload file command '[upload code_interpreter | file_search filename]'.
|
||||
@ -108,7 +109,10 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
return
|
||||
else:
|
||||
# Publish user input and exit handler.
|
||||
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@ -166,7 +170,7 @@ class EventHandler(AsyncAssistantEventHandler):
|
||||
print("\n".join(citations))
|
||||
|
||||
|
||||
async def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
async def assistant_chat(runtime: AgentRuntime) -> str:
|
||||
oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-4-turbo",
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
@ -177,7 +181,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
thread = openai.beta.threads.create(
|
||||
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
|
||||
)
|
||||
assistant = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Assistant",
|
||||
lambda: OpenAIAssistantAgent(
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
@ -188,7 +192,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
),
|
||||
)
|
||||
|
||||
user = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"User",
|
||||
lambda: UserProxyAgent(
|
||||
client=openai.AsyncClient(),
|
||||
@ -203,10 +207,13 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
participants=[assistant, user],
|
||||
participants=[
|
||||
AgentId("Assistant", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("User", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
),
|
||||
)
|
||||
return user
|
||||
return "User"
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -229,7 +236,7 @@ Type "exit" to exit the chat.
|
||||
_run_context = runtime.start()
|
||||
print(usage)
|
||||
# Request the user to start the conversation.
|
||||
await runtime.send_message(PublishNow(), user)
|
||||
await runtime.send_message(PublishNow(), AgentId(user, "default"))
|
||||
|
||||
# TODO: have a way to exit the loop.
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.memory import ChatMemory
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage
|
||||
from agnext.core import AgentInstantiationContext, AgentRuntime
|
||||
from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
@ -76,7 +76,10 @@ Use the following JSON format to provide your thought on the latest message and
|
||||
|
||||
# Publish the response if needed.
|
||||
if respond is True or str(respond).lower().strip() == "true":
|
||||
await self.publish_message(TextMessage(source=self.metadata["type"], content=str(response)))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
TextMessage(source=self.metadata["type"], content=str(response)), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
|
||||
class ChatRoomUserAgent(TextualUserAgent):
|
||||
@ -96,7 +99,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
alice = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"Alice",
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().type,
|
||||
@ -106,7 +109,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
bob = await runtime.register_and_get_proxy(
|
||||
alice = AgentProxy(AgentId("Alice", "default"), runtime)
|
||||
await runtime.register(
|
||||
"Bob",
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().type,
|
||||
@ -116,7 +120,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
charlie = await runtime.register_and_get_proxy(
|
||||
bob = AgentProxy(AgentId("Bob", "default"), runtime)
|
||||
await runtime.register(
|
||||
"Charlie",
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().type,
|
||||
@ -126,6 +131,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
charlie = AgentProxy(AgentId("Charlie", "default"), runtime)
|
||||
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
|
||||
1. 👧 {alice.id.type}: {(await alice.metadata)['description']}
|
||||
2. 👱🏼♂️ {bob.id.type}: {(await bob.metadata)['description']}
|
||||
|
||||
@ -10,14 +10,16 @@ import sys
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import SystemMessage
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import AgentRuntime
|
||||
from agnext.core import AgentInstantiationContext, AgentRuntime, TopicId
|
||||
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
|
||||
from chess import piece_name as get_piece_name
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from agnext.core import AgentId
|
||||
from common.agents._chat_completion_agent import ChatCompletionAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
@ -156,7 +158,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
),
|
||||
]
|
||||
|
||||
black = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"PlayerBlack",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing black.",
|
||||
@ -173,7 +175,8 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
tools=black_tools,
|
||||
),
|
||||
)
|
||||
white = await runtime.register_and_get(
|
||||
await runtime.add_subscription(TypeSubscription("default", "PlayerBlack"))
|
||||
await runtime.register(
|
||||
"PlayerWhite",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing white.",
|
||||
@ -190,6 +193,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
tools=white_tools,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "PlayerWhite"))
|
||||
# Create a group chat manager for the chess game to orchestrate a turn-based
|
||||
# conversation between the two agents.
|
||||
await runtime.register(
|
||||
@ -197,7 +201,10 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
lambda: GroupChatManager(
|
||||
description="A chess game between two agents.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
participants=[white, black], # white goes first
|
||||
participants=[
|
||||
AgentId("PlayerWhite", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("PlayerBlack", AgentInstantiationContext.current_agent_id().key),
|
||||
], # white goes first
|
||||
),
|
||||
)
|
||||
|
||||
@ -207,7 +214,9 @@ async def main() -> None:
|
||||
await chess_game(runtime)
|
||||
run_context = runtime.start()
|
||||
# Publish an initial message to trigger the group chat manager to start orchestration.
|
||||
await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
TextMessage(content="Game started.", source="System"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
|
||||
@ -7,11 +7,12 @@ import sys
|
||||
import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components.models import SystemMessage
|
||||
from agnext.core import AgentRuntime
|
||||
from agnext.core import AgentInstantiationContext, AgentRuntime
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from common.agents import ChatCompletionAgent, ImageGenerationAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
@ -27,7 +28,7 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
descriptor = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"Descriptor",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides a description of the image.",
|
||||
@ -46,7 +47,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500),
|
||||
),
|
||||
)
|
||||
illustrator = await runtime.register_and_get_proxy(
|
||||
descriptor = AgentProxy(AgentId("Descriptor", "default"), runtime)
|
||||
await runtime.register(
|
||||
"Illustrator",
|
||||
lambda: ImageGenerationAgent(
|
||||
description="An AI agent that generates images.",
|
||||
@ -55,7 +57,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
|
||||
memory=BufferedChatMemory(buffer_size=1),
|
||||
),
|
||||
)
|
||||
critic = await runtime.register_and_get_proxy(
|
||||
illustrator = AgentProxy(AgentId("Illustrator", "default"), runtime)
|
||||
await runtime.register(
|
||||
"Critic",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides feedback on images given user's requirements.",
|
||||
@ -74,12 +77,17 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
critic = AgentProxy(AgentId("Critic", "default"), runtime)
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A chat manager that handles group chat.",
|
||||
memory=BufferedChatMemory(buffer_size=5),
|
||||
participants=[illustrator.id, critic.id, descriptor.id],
|
||||
participants=[
|
||||
AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Descriptor", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Critic", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
termination_word="APPROVE",
|
||||
),
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components.models import SystemMessage
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import AgentRuntime
|
||||
from agnext.core import AgentInstantiationContext, AgentRuntime
|
||||
from markdownify import markdownify # type: ignore
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import Annotated
|
||||
@ -27,6 +27,7 @@ from typing_extensions import Annotated
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import AgentId
|
||||
from common.agents import ChatCompletionAgent
|
||||
from common.memory import HeadAndTailChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
@ -106,14 +107,14 @@ async def create_image(
|
||||
|
||||
|
||||
async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
|
||||
user_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Customer",
|
||||
lambda: TextualUserAgent(
|
||||
description="A customer looking for help.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
developer = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Developer",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A Python software developer.",
|
||||
@ -149,11 +150,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
|
||||
product_manager = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"ProductManager",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A product manager. "
|
||||
@ -179,10 +180,10 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
ux_designer = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"UserExperienceDesigner",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A user experience designer for creating user interfaces.",
|
||||
@ -211,11 +212,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
|
||||
),
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
|
||||
illustrator = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Illustrator",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An illustrator for creating images.",
|
||||
@ -237,7 +238,7 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
|
||||
description="Create an image from a description.",
|
||||
),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
await runtime.register(
|
||||
@ -246,7 +247,13 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
|
||||
description="A group chat manager.",
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
participants=[developer, product_manager, ux_designer, illustrator, user_agent],
|
||||
participants=[
|
||||
AgentId("Developer", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("ProductManager", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("UserExperienceDesigner", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
),
|
||||
)
|
||||
art = r"""
|
||||
|
||||
@ -13,6 +13,7 @@ from textual_imageview.viewer import ImageViewer
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from agnext.core import TopicId
|
||||
from common.types import (
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
@ -135,7 +136,9 @@ class TextualChatApp(App): # type: ignore
|
||||
chat_messages.query("#typing").remove()
|
||||
# Publish the user message to the runtime.
|
||||
await self._runtime.publish_message(
|
||||
TextMessage(source=self._user_name, content=user_input), namespace="default"
|
||||
# TODO fix hard coded topic_id
|
||||
TextMessage(source=self._user_name, content=user_input),
|
||||
topic_id=TopicId("default", "default"),
|
||||
)
|
||||
|
||||
async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import AzureOpenAIChatCompletionClient
|
||||
from agnext.core import AgentRuntime
|
||||
from auditor import AuditAgent
|
||||
@ -28,7 +29,6 @@ async def build_app(runtime: AgentRuntime) -> None:
|
||||
)
|
||||
|
||||
await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
|
||||
await runtime.add_subscription(TypeSubscription("default", "GraphicDesigner"))
|
||||
await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
|
||||
|
||||
await runtime.get("GraphicDesigner")
|
||||
await runtime.get("Auditor")
|
||||
await runtime.add_subscription(TypeSubscription("default", "Auditor"))
|
||||
|
||||
@ -30,4 +30,7 @@ class AuditAgent(TypeRoutedAgent):
|
||||
assert isinstance(completion.content, str)
|
||||
if "NOTFORME" in completion.content:
|
||||
return
|
||||
await self.publish_message(AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
@ -33,6 +33,9 @@ class GraphicDesignerAgent(TypeRoutedAgent):
|
||||
image_uri = response.data[0].url
|
||||
logger.info(f"Generated image for article. Got response: '{image_uri}'")
|
||||
|
||||
await self.publish_message(GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=ctx.topic_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate image for article. Error: {e}")
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import Image, TypeRoutedAgent, message_handler
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import MessageContext, TopicId
|
||||
from app import build_app
|
||||
from dotenv import load_dotenv
|
||||
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
||||
@ -34,13 +34,15 @@ async def main() -> None:
|
||||
|
||||
ctx = runtime.start()
|
||||
|
||||
topic_id = TopicId("default", "default")
|
||||
|
||||
await runtime.publish_message(
|
||||
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), namespace="default"
|
||||
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=topic_id
|
||||
)
|
||||
|
||||
await runtime.publish_message(
|
||||
ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"),
|
||||
namespace="default",
|
||||
topic_id=topic_id,
|
||||
)
|
||||
|
||||
await ctx.stop_when_idle()
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY
|
||||
from app import build_app
|
||||
from dotenv import load_dotenv
|
||||
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
||||
|
||||
@ -22,6 +22,7 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@ -30,6 +31,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -100,10 +102,12 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
AssistantMessage(content=response.content, source=self.metadata["type"])
|
||||
)
|
||||
|
||||
assert ctx.topic_id is not None
|
||||
# Publish the code execution task.
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
@message_handler
|
||||
@ -120,8 +124,11 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
|
||||
if "TERMINATE" in response.content:
|
||||
# If the task is completed, publish a message with the completion content.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
TaskCompletion(content=response.content), cancellation_token=ctx.cancellation_token
|
||||
TaskCompletion(content=response.content),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
print("--------------------")
|
||||
print("Task completed:")
|
||||
@ -129,9 +136,11 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
return
|
||||
|
||||
# Publish the code execution task.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@ -148,11 +157,13 @@ class Executor(TypeRoutedAgent):
|
||||
code_blocks = self._extract_code_blocks(message.content)
|
||||
if not code_blocks:
|
||||
# If no code block is found, publish a message with an error.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(
|
||||
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
|
||||
),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
return
|
||||
# Execute code blocks.
|
||||
@ -160,9 +171,11 @@ class Executor(TypeRoutedAgent):
|
||||
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
# Publish the code execution result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
|
||||
@ -185,10 +198,12 @@ async def main(task: str, temp_dir: str) -> None:
|
||||
"coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))
|
||||
)
|
||||
await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
|
||||
await runtime.add_subscription(TypeSubscription("default", "coder"))
|
||||
await runtime.add_subscription(TypeSubscription("default", "executor"))
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish the task message.
|
||||
await runtime.publish_message(TaskMessage(content=task), namespace="default")
|
||||
await runtime.publish_message(TaskMessage(content=task), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from typing import Dict, List, Union
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@ -29,6 +30,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -110,12 +112,14 @@ Please review the code and provide feedback.
|
||||
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
|
||||
approved = review["approval"].lower().strip() == "approve"
|
||||
# Publish the review result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeReviewResult(
|
||||
review=review_text,
|
||||
approved=approved,
|
||||
session_id=message.session_id,
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@ -179,7 +183,11 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[session_id].append(code_review_task)
|
||||
# Publish a code review task.
|
||||
await self.publish_message(code_review_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
@message_handler
|
||||
async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:
|
||||
@ -193,12 +201,14 @@ Code: <Your code>
|
||||
# Check if the code is approved.
|
||||
if message.approved:
|
||||
# Publish the code writing result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeWritingResult(
|
||||
code=review_request.code,
|
||||
task=review_request.code_writing_task,
|
||||
review=message.review,
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
print("Code Writing Result:")
|
||||
print("-" * 80)
|
||||
@ -237,7 +247,11 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[message.session_id].append(code_review_task)
|
||||
# Publish a new code review task.
|
||||
await self.publish_message(code_review_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
|
||||
pattern = r"```(\w+)\n(.*?)\n```"
|
||||
@ -258,6 +272,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReviewerAgent"))
|
||||
await runtime.register(
|
||||
"CoderAgent",
|
||||
lambda: CoderAgent(
|
||||
@ -265,12 +280,13 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "CoderAgent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(
|
||||
message=CodeWritingTask(
|
||||
task="Write a function to find the directory with the largest number of files using multi-processing."
|
||||
),
|
||||
namespace="default",
|
||||
topic_id=TopicId("default", "default"),
|
||||
)
|
||||
|
||||
# Keep processing messages until idle.
|
||||
|
||||
@ -26,7 +26,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import AgentId
|
||||
from agnext.core import AgentId, AgentInstantiationContext, TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -69,7 +69,8 @@ class RoundRobinGroupChatManager(TypeRoutedAgent):
|
||||
self._round_count += 1
|
||||
if self._round_count > self._num_rounds * len(self._participants):
|
||||
# End the conversation after the specified number of rounds.
|
||||
await self.publish_message(Termination())
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Termination(), ctx.topic_id)
|
||||
return
|
||||
# Send a request to speak message to the selected speaker.
|
||||
await self.send_message(RequestToSpeak(), speaker)
|
||||
@ -104,9 +105,10 @@ class GroupChatParticipant(TypeRoutedAgent):
|
||||
llm_messages.append(UserMessage(content=m.content, source=m.source))
|
||||
response = await self._model_client.create(self._system_messages + llm_messages)
|
||||
assert isinstance(response.content, str)
|
||||
speach = Message(content=response.content, source=self.metadata["type"])
|
||||
self._memory.append(speach)
|
||||
await self.publish_message(speach)
|
||||
speech = Message(content=response.content, source=self.metadata["type"])
|
||||
self._memory.append(speech)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(speech, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -114,7 +116,7 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the participants.
|
||||
agent1 = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"DataScientist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="A data scientist",
|
||||
@ -122,7 +124,8 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
agent2 = await runtime.register_and_get(
|
||||
|
||||
await runtime.register(
|
||||
"Engineer",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An engineer",
|
||||
@ -130,7 +133,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
agent3 = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Artist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An artist",
|
||||
@ -144,7 +147,11 @@ async def main() -> None:
|
||||
"GroupChatManager",
|
||||
lambda: RoundRobinGroupChatManager(
|
||||
description="A group chat manager",
|
||||
participants=[agent1, agent2, agent3],
|
||||
participants=[
|
||||
AgentId("DataScientist", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Engineer", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Artist", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
num_rounds=3,
|
||||
),
|
||||
)
|
||||
@ -153,7 +160,9 @@ async def main() -> None:
|
||||
run_context = runtime.start()
|
||||
|
||||
# Start the conversation.
|
||||
await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
Message(content="Hello, everyone!", source="Moderator"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@ -16,11 +16,13 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
|
||||
from agnext.core import MessageContext
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import TopicId
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
||||
@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent):
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
class AggregatorAgent(TypeRoutedAgent):
|
||||
@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
"""Handle a task message. This method publishes the task to the reference agents."""
|
||||
session_id = str(uuid.uuid4())
|
||||
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
|
||||
await self.publish_message(ref_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(ref_task, topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
|
||||
@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
task_result = AggregatorTaskResult(result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
self._session_results.pop(message.session_id)
|
||||
print(f"Aggregator result: {response.content}")
|
||||
|
||||
@ -120,6 +125,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
@ -128,6 +134,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
@ -136,6 +143,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3"))
|
||||
await runtime.register(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
@ -149,8 +157,11 @@ async def main() -> None:
|
||||
num_references=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
# Keep processing messages.
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
@ -41,6 +41,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@ -48,6 +49,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer = match.group(1)
|
||||
# Increment the counter.
|
||||
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
|
||||
assert ctx.topic_id is not None
|
||||
if self._counters[message.session_id] == self._max_round:
|
||||
# If the counter reaches the maximum round, publishes a final response.
|
||||
await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id))
|
||||
await self.publish_message(
|
||||
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
|
||||
)
|
||||
else:
|
||||
# Publish intermediate response.
|
||||
await self.publish_message(
|
||||
@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer=answer,
|
||||
session_id=message.session_id,
|
||||
round=self._counters[message.session_id],
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent):
|
||||
"in the form of {{answer}}, at the end of your response."
|
||||
)
|
||||
session_id = str(uuid.uuid4())
|
||||
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
@message_handler
|
||||
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
|
||||
@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent):
|
||||
answers = [resp.answer for resp in self._responses[message.session_id]]
|
||||
majority_answer = max(set(answers), key=answers.count)
|
||||
# Publish the aggregated response.
|
||||
await self.publish_message(Answer(content=majority_answer))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
|
||||
# Clear the responses.
|
||||
self._responses.pop(message.session_id)
|
||||
print(f"Aggregated answer: {majority_answer}")
|
||||
@ -223,6 +233,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver1"))
|
||||
await runtime.register(
|
||||
"MathSolver2",
|
||||
lambda: MathSolver(
|
||||
@ -231,6 +242,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver2"))
|
||||
await runtime.register(
|
||||
"MathSolver3",
|
||||
lambda: MathSolver(
|
||||
@ -239,6 +251,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver3"))
|
||||
await runtime.register(
|
||||
"MathSolver4",
|
||||
lambda: MathSolver(
|
||||
@ -247,13 +260,14 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver4"))
|
||||
# Register the aggregator agent.
|
||||
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a math problem to the aggregator agent.
|
||||
await runtime.publish_message(Question(content=question), namespace="default")
|
||||
await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from agnext.components.models import (
|
||||
)
|
||||
from agnext.components.tool_agent import ToolAgent, ToolException
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema
|
||||
from agnext.core import AgentId
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -107,21 +107,21 @@ async def main() -> None:
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
|
||||
@ -129,7 +129,7 @@ async def main() -> None:
|
||||
|
||||
# Send a task to the tool user.
|
||||
response = await runtime.send_message(
|
||||
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
|
||||
Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default")
|
||||
)
|
||||
print(response.content)
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import SystemMessage
|
||||
from agnext.components.tool_agent import ToolAgent, ToolException
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
||||
from agnext.core import AgentId
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
@ -48,21 +48,21 @@ async def main() -> None:
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
|
||||
@ -70,7 +70,7 @@ async def main() -> None:
|
||||
|
||||
# Send a task to the tool user.
|
||||
response = await runtime.send_message(
|
||||
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
|
||||
Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default")
|
||||
)
|
||||
print(response.content)
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@ -32,6 +33,7 @@ from agnext.components.models import (
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@ -88,7 +90,8 @@ class ToolExecutorAgent(TypeRoutedAgent):
|
||||
session_id=message.session_id,
|
||||
result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id),
|
||||
)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
class ToolUseAgent(TypeRoutedAgent):
|
||||
@ -126,7 +129,8 @@ class ToolUseAgent(TypeRoutedAgent):
|
||||
if isinstance(response.content, str):
|
||||
# If the response is a string, just publish the response.
|
||||
response_message = AgentResponse(content=response.content)
|
||||
await self.publish_message(response_message)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(response_message, topic_id=ctx.topic_id)
|
||||
print(f"AI Response: {response.content}")
|
||||
return
|
||||
|
||||
@ -139,7 +143,8 @@ class ToolUseAgent(TypeRoutedAgent):
|
||||
for function_call in response.content:
|
||||
task = ToolExecutionTask(session_id=session_id, function_call=function_call)
|
||||
self._tool_counter[session_id] += 1
|
||||
await self.publish_message(task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task, topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None:
|
||||
@ -165,10 +170,11 @@ class ToolUseAgent(TypeRoutedAgent):
|
||||
self._sessions[message.session_id].append(
|
||||
AssistantMessage(content=response.content, source=self.metadata["type"])
|
||||
)
|
||||
assert ctx.topic_id is not None
|
||||
# If the response is a string, just publish the response.
|
||||
if isinstance(response.content, str):
|
||||
response_message = AgentResponse(content=response.content)
|
||||
await self.publish_message(response_message)
|
||||
await self.publish_message(response_message, topic_id=ctx.topic_id)
|
||||
self._tool_results.pop(message.session_id)
|
||||
self._tool_counter.pop(message.session_id)
|
||||
print(f"AI Response: {response.content}")
|
||||
@ -179,7 +185,7 @@ class ToolUseAgent(TypeRoutedAgent):
|
||||
for function_call in response.content:
|
||||
task = ToolExecutionTask(session_id=message.session_id, function_call=function_call)
|
||||
self._tool_counter[message.session_id] += 1
|
||||
await self.publish_message(task)
|
||||
await self.publish_message(task, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -192,6 +198,7 @@ async def main() -> None:
|
||||
]
|
||||
# Register agents.
|
||||
await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
|
||||
await runtime.add_subscription(TypeSubscription("default", "tool_executor"))
|
||||
await runtime.register(
|
||||
"tool_use_agent",
|
||||
lambda: ToolUseAgent(
|
||||
@ -201,12 +208,13 @@ async def main() -> None:
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "tool_use_agent"))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish a task.
|
||||
await runtime.publish_message(
|
||||
UserRequest("Run the following Python code: print('Hello, World!')"), namespace="default"
|
||||
UserRequest("Run the following Python code: print('Hello, World!')"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
@ -15,11 +15,13 @@ from agnext.components.models import (
|
||||
)
|
||||
from agnext.components.tool_agent import ToolAgent
|
||||
from agnext.components.tools import FunctionTool, Tool
|
||||
from agnext.core import AgentInstantiationContext
|
||||
from typing_extensions import Annotated
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import AgentId
|
||||
from coding_direct import Message, ToolUseAgent
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
@ -42,23 +44,24 @@ async def main() -> None:
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
|
||||
),
|
||||
)
|
||||
tool_use_agent = AgentId("tool_enabled_agent", "default")
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ from dataclasses import dataclass
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -38,11 +39,14 @@ class ReceiveAgent(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
|
||||
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"))
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
class GreeterAgent(TypeRoutedAgent):
|
||||
@ -51,11 +55,15 @@ class GreeterAgent(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||
await self.publish_message(Greeting(f"Hello, {message.content}!"))
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
|
||||
await self.publish_message(Feedback(f"Feedback: {message.content}"))
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -68,9 +76,11 @@ async def main() -> None:
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
await runtime.add_subscription(TypeSubscription("default", "reciever"))
|
||||
await runtime.register("greeter", lambda: GreeterAgent())
|
||||
await runtime.add_subscription(TypeSubscription("default", "greeter"))
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
|
||||
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -43,7 +43,8 @@ class GreeterAgent(TypeRoutedAgent):
|
||||
@message_handler
|
||||
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
|
||||
await self.publish_message(Feedback(f"Feedback: {response.content}"))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@ -54,10 +55,11 @@ async def main() -> None:
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
reciever = await runtime.get("reciever")
|
||||
await runtime.register("greeter", lambda: GreeterAgent(reciever))
|
||||
await runtime.register(
|
||||
"greeter", lambda: GreeterAgent(AgentId("reciever", AgentInstantiationContext.current_agent_id().key))
|
||||
)
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
|
||||
@ -12,13 +12,13 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from agnext.core import Subscription, TopicId
|
||||
|
||||
from ..core import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
@ -38,7 +38,7 @@ class PublishMessageEnvelope:
|
||||
message: Any
|
||||
cancellation_token: CancellationToken
|
||||
sender: AgentId | None
|
||||
namespace: str
|
||||
topic_id: TopicId
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -124,16 +124,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
self,
|
||||
@ -177,8 +179,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if sender is not None and sender.key != recipient.key:
|
||||
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
|
||||
|
||||
await self._process_seen_namespace(recipient.key)
|
||||
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
|
||||
|
||||
@ -199,8 +199,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
@ -219,26 +219,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# )
|
||||
# )
|
||||
|
||||
if sender is None and namespace is None:
|
||||
raise ValueError("Namespace must be provided if sender is not provided.")
|
||||
|
||||
sender_namespace = sender.key if sender is not None else None
|
||||
explicit_namespace = namespace
|
||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||
raise ValueError(
|
||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||
)
|
||||
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
await self._process_seen_namespace(namespace)
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
namespace=namespace,
|
||||
message=message, cancellation_token=cancellation_token, sender=sender, topic_id=topic_id
|
||||
)
|
||||
)
|
||||
|
||||
@ -300,12 +283,13 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._outstanding_tasks.decrement()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
self._build_for_new_topic(message_envelope.topic_id)
|
||||
responses: List[Awaitable[Any]] = []
|
||||
target_namespace = message_envelope.namespace
|
||||
for agent_id in self._per_type_subscribers[
|
||||
(target_namespace, MESSAGE_TYPE_REGISTRY.type_name(message_envelope.message))
|
||||
]:
|
||||
if message_envelope.sender is not None and agent_id.type == message_envelope.sender.type:
|
||||
|
||||
recipients = self._subscribed_recipients[message_envelope.topic_id]
|
||||
for agent_id in recipients:
|
||||
# Avoid sending the message back to the sender
|
||||
if message_envelope.sender is not None and agent_id == message_envelope.sender:
|
||||
continue
|
||||
|
||||
sender_agent = (
|
||||
@ -326,8 +310,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# )
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
# TODO: topic_id
|
||||
topic_id=None,
|
||||
topic_id=message_envelope.topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
@ -460,16 +443,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
await self._get_agent(AgentId(type=name, key=namespace))
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
self._agent_factories[type] = agent_factory
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
@ -496,7 +475,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return agent
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
await self._process_seen_namespace(agent_id.key)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
@ -504,20 +482,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
raise LookupError(f"Agent with name {agent_id.type} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.type]
|
||||
|
||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return (await self._get_agent(AgentId(type=name, key=namespace))).id
|
||||
|
||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = await self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
if id.type not in self._agent_factories:
|
||||
@ -531,12 +499,40 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
return agent_instance
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||
if namespace in self._known_namespaces:
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub.id == subscription.id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
if len(self._seen_topics) > 0:
|
||||
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
# Check if the subscription exists
|
||||
if not any(sub.id == id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription does not exist")
|
||||
|
||||
def is_not_sub(x: Subscription) -> bool:
|
||||
return x.id != id
|
||||
|
||||
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
|
||||
|
||||
# Rebuild the subscriptions
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
# TODO: optimize this...
|
||||
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
|
||||
self._subscribed_recipients.clear()
|
||||
for topic in topics:
|
||||
self._build_for_new_topic(topic)
|
||||
|
||||
def _build_for_new_topic(self, topic: TopicId) -> None:
|
||||
if topic in self._seen_topics:
|
||||
return
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
await self._get_agent(AgentId(type=name, key=namespace))
|
||||
self._seen_topics.add(topic)
|
||||
for subscription in self._subscriptions:
|
||||
if subscription.is_match(topic):
|
||||
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
|
||||
|
||||
@ -28,9 +28,9 @@ import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId
|
||||
|
||||
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken
|
||||
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken
|
||||
from .protos import AgentId as AgentIdProto
|
||||
from .protos import (
|
||||
AgentRpcStub,
|
||||
@ -153,6 +153,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
self._next_request_id = 0
|
||||
self._host_connection: HostConnection | None = None
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
async def start(self, host_connection_string: str) -> None:
|
||||
if self._running:
|
||||
@ -245,29 +248,25 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when publishing message.")
|
||||
assert self._host_connection is not None
|
||||
sender_namespace = sender.key if sender is not None else None
|
||||
explicit_namespace = namespace
|
||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||
raise ValueError(
|
||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||
)
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
actual_namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
await self._process_seen_namespace(actual_namespace)
|
||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
||||
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
|
||||
task = asyncio.create_task(self._host_connection.send(message))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
message = Message(
|
||||
event=Event(
|
||||
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
|
||||
)
|
||||
)
|
||||
|
||||
async def write_message() -> None:
|
||||
assert self._host_connection is not None
|
||||
await self._host_connection.send(message)
|
||||
|
||||
await asyncio.create_task(write_message())
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
raise NotImplementedError("Saving state is not yet implemented.")
|
||||
@ -284,26 +283,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
raise NotImplementedError("Agent load_state is not yet implemented.")
|
||||
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when registering agent.")
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
await self._get_agent(AgentId(type=name, key=namespace))
|
||||
|
||||
assert self._host_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=name))
|
||||
await self._host_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", name)
|
||||
|
||||
async def _process_request(self, request: RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
target = AgentId(request.target.name, request.target.namespace)
|
||||
@ -347,27 +326,41 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
future.set_result(response.result)
|
||||
|
||||
async def _process_event(self, event: Event) -> None:
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||
namespace = event.namespace
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]:
|
||||
# TODO: skip the sender?
|
||||
message_context = MessageContext(
|
||||
sender=None,
|
||||
topic_id=None,
|
||||
is_rpc=False,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(message, ctx=message_context)
|
||||
responses.append(future)
|
||||
...
|
||||
# message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
|
||||
|
||||
try:
|
||||
_ = await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
return
|
||||
event_logger.error("Error handling event message", exc_info=e)
|
||||
# for agent_id in self._per_type_subscribers[
|
||||
# (namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
|
||||
# ]:
|
||||
|
||||
# agent = await self._get_agent(agent_id)
|
||||
# message_context = MessageContext(
|
||||
# # TODO: should sender be in the proto even for published events?
|
||||
# sender=None,
|
||||
# # TODO: topic_id
|
||||
# topic_id=None,
|
||||
# is_rpc=False,
|
||||
# cancellation_token=CancellationToken(),
|
||||
# )
|
||||
# try:
|
||||
# await agent.on_message(message, ctx=message_context)
|
||||
# logger.info("%s handled event %s", agent_id, message)
|
||||
# except Exception as e:
|
||||
# event_logger.error("Error handling message", exc_info=e)
|
||||
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
) -> None:
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
self._agent_factories[type] = agent_factory
|
||||
|
||||
assert self._host_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=type))
|
||||
await self._host_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", type)
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
@ -394,7 +387,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
return agent
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
await self._process_seen_namespace(agent_id.key)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
@ -402,32 +394,16 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
raise ValueError(f"Agent with name {agent_id.type} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.type]
|
||||
|
||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
|
||||
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return (await self._get_agent(AgentId(type=name, key=namespace))).id
|
||||
|
||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = await self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||
if namespace in self._known_namespaces:
|
||||
return
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
raise NotImplementedError("Subscriptions are not yet implemented.")
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
await self._get_agent(AgentId(type=name, key=namespace))
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
raise NotImplementedError("Subscriptions are not yet implemented.")
|
||||
|
||||
@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x96\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb2\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12\x11\n\tdata_type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
@ -38,13 +38,13 @@ if not _descriptor._USE_C_DESCRIPTORS:
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=257
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=304
|
||||
_globals['_EVENT']._serialized_start=476
|
||||
_globals['_EVENT']._serialized_end=626
|
||||
_globals['_EVENT']._serialized_end=654
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=257
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=304
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=628
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=661
|
||||
_globals['_MESSAGE']._serialized_start=664
|
||||
_globals['_MESSAGE']._serialized_end=852
|
||||
_globals['_AGENTRPC']._serialized_start=854
|
||||
_globals['_AGENTRPC']._serialized_end=917
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=656
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=689
|
||||
_globals['_MESSAGE']._serialized_start=692
|
||||
_globals['_MESSAGE']._serialized_end=880
|
||||
_globals['_AGENTRPC']._serialized_start=882
|
||||
_globals['_AGENTRPC']._serialized_end=945
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -14,8 +14,6 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class AgentId(google.protobuf.message.Message):
|
||||
"""TODO: update"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
NAME_FIELD_NUMBER: builtins.int
|
||||
@ -143,24 +141,27 @@ class Event(google.protobuf.message.Message):
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
NAMESPACE_FIELD_NUMBER: builtins.int
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_TYPE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
namespace: builtins.str
|
||||
type: builtins.str
|
||||
topic_type: builtins.str
|
||||
topic_source: builtins.str
|
||||
data_type: builtins.str
|
||||
data: builtins.str
|
||||
@property
|
||||
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
namespace: builtins.str = ...,
|
||||
type: builtins.str = ...,
|
||||
topic_type: builtins.str = ...,
|
||||
topic_source: builtins.str = ...,
|
||||
data_type: builtins.str = ...,
|
||||
data: builtins.str = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "metadata", b"metadata", "namespace", b"namespace", "type", b"type"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
|
||||
global___Event = Event
|
||||
|
||||
|
||||
@ -72,7 +72,6 @@ class ClosureAgent(Agent):
|
||||
key=self._id.key,
|
||||
type=self._id.type,
|
||||
description=self._description,
|
||||
subscriptions=self._subscriptions,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -142,12 +142,12 @@ class TypeRoutedAgent(BaseAgent):
|
||||
message_handler = cast(MessageHandler[Any, Any], handler)
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type] = message_handler
|
||||
subscriptions = list(self._handlers.keys())
|
||||
|
||||
for message_type in self._handlers.keys():
|
||||
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
|
||||
MESSAGE_TYPE_REGISTRY.add_type(message_type)
|
||||
subscriptions_str = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscriptions]
|
||||
super().__init__(description, subscriptions_str)
|
||||
|
||||
super().__init__(description)
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
|
||||
@ -42,5 +42,4 @@ class TypeSubscription(Subscription):
|
||||
if not self.is_match(topic_id):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
# TODO: Update agentid to reflect agent type and key
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
@ -12,7 +12,7 @@ from ._agent_runtime import AgentRuntime
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
@ -32,4 +32,5 @@ __all__ = [
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
"MessageContext",
|
||||
"Serialization",
|
||||
]
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from typing import Sequence, TypedDict
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class AgentMetadata(TypedDict):
|
||||
type: str
|
||||
key: str
|
||||
description: str
|
||||
subscriptions: Sequence[str]
|
||||
|
||||
@ -5,8 +5,9 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
@ -45,8 +46,8 @@ class AgentRuntime(Protocol):
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
@ -56,23 +57,24 @@ class AgentRuntime(Protocol):
|
||||
|
||||
Args:
|
||||
message (Any): The message to publish.
|
||||
namespace (str | None, optional): The namespace to publish to. Defaults to None.
|
||||
topic (TopicId): The topic to publish the message to.
|
||||
sender (AgentId | None, optional): The agent which sent the message. Defaults to None.
|
||||
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
|
||||
|
||||
Raises:
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
"""
|
||||
...
|
||||
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
|
||||
|
||||
@ -93,30 +95,6 @@ class AgentRuntime(Protocol):
|
||||
|
||||
...
|
||||
|
||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
"""Get an agent by name and namespace.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
AgentId: The agent id.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
"""Get a proxy for an agent by name and namespace.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
AgentProxy: The agent proxy.
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||
@ -137,46 +115,6 @@ class AgentRuntime(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId:
|
||||
"""Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
AgentId: The agent id.
|
||||
"""
|
||||
await self.register(name, agent_factory)
|
||||
return await self.get(name, namespace=namespace)
|
||||
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy:
|
||||
"""Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
AgentProxy: The agent proxy.
|
||||
"""
|
||||
await self.register(name, agent_factory)
|
||||
return await self.get_proxy(name, namespace=namespace)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
|
||||
|
||||
@ -227,3 +165,22 @@ class AgentRuntime(Protocol):
|
||||
state (Mapping[str, Any]): The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
"""Add a new subscription that the runtime should fulfill when processing published messages
|
||||
|
||||
Args:
|
||||
subscription (Subscription): The subscription to add
|
||||
"""
|
||||
...
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
"""Remove a subscription from the runtime
|
||||
|
||||
Args:
|
||||
id (str): id of the subscription to remove
|
||||
|
||||
Raises:
|
||||
LookupError: If the subscription does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping, Sequence
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@ -9,20 +9,16 @@ from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(
|
||||
key=self._id.key,
|
||||
type=self._id.type,
|
||||
description=self._description,
|
||||
subscriptions=self._subscriptions,
|
||||
)
|
||||
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
|
||||
|
||||
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
|
||||
def __init__(self, description: str) -> None:
|
||||
try:
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
@ -36,7 +32,6 @@ class BaseAgent(ABC, Agent):
|
||||
if not isinstance(description, str):
|
||||
raise ValueError("Agent description must be a string")
|
||||
self._description = description
|
||||
self._subscriptions = subscriptions
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
@ -74,10 +69,11 @@ class BaseAgent(ABC, Agent):
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
await self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
|
||||
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Protocol
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from agnext.core._agent_id import AgentId
|
||||
from agnext.core import AgentId
|
||||
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Subscription(Protocol):
|
||||
"""Subscriptions define the topics that an agent is interested in."""
|
||||
|
||||
@ -19,6 +20,20 @@ class Subscription(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check if two subscriptions are equal.
|
||||
|
||||
Args:
|
||||
other (object): Other subscription to compare against.
|
||||
|
||||
Returns:
|
||||
bool: True if the subscriptions are equal, False otherwise.
|
||||
"""
|
||||
if not isinstance(other, Subscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
"""Check if a given topic_id matches the subscription.
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class TopicId:
|
||||
type: str
|
||||
"""Type of the event that this topic_id contains. Adhere's to the cloud event spec.
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from team_one.agents.coder import Coder, Executor
|
||||
from team_one.agents.orchestrator import LedgerOrchestrator
|
||||
from team_one.agents.user_proxy import UserProxy
|
||||
@ -15,18 +16,22 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register agents.
|
||||
coder = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=create_completion_client_from_env()),
|
||||
)
|
||||
coder = AgentProxy(AgentId("Coder", "default"), runtime)
|
||||
|
||||
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
|
||||
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
|
||||
executor = AgentProxy(AgentId("Executor", "default"), runtime)
|
||||
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(description="The current user interacting with you."),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
|
||||
# TODO: doesn't work for more than default key
|
||||
await runtime.register(
|
||||
"orchestrator",
|
||||
lambda: LedgerOrchestrator(
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from team_one.agents.coder import Coder, Executor
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.agents.user_proxy import UserProxy
|
||||
@ -15,17 +16,20 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register agents.
|
||||
coder = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=create_completion_client_from_env()),
|
||||
)
|
||||
coder = AgentProxy(AgentId("Coder", "default"), runtime)
|
||||
|
||||
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
|
||||
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
|
||||
executor = AgentProxy(AgentId("Executor", "default"), runtime)
|
||||
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(),
|
||||
lambda: UserProxy(description="The current user interacting with you."),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from team_one.agents.file_surfer import FileSurfer
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.agents.user_proxy import UserProxy
|
||||
@ -18,14 +19,17 @@ async def main() -> None:
|
||||
client = create_completion_client_from_env()
|
||||
|
||||
# Register agents.
|
||||
file_surfer = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"file_surfer",
|
||||
lambda: FileSurfer(model_client=client),
|
||||
)
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
file_surfer = AgentProxy(AgentId("file_surfer", "default"), runtime)
|
||||
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.components.models import UserMessage
|
||||
from agnext.core import AgentId, AgentProxy, TopicId
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.agents.reflex_agents import ReflexAgent
|
||||
from team_one.messages import BroadcastMessage
|
||||
@ -13,14 +14,19 @@ from team_one.utils import LogHandler
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
|
||||
fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
|
||||
fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
|
||||
await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
|
||||
await runtime.register("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
|
||||
fake1 = AgentProxy(AgentId("fake_agent_1", "default"), runtime)
|
||||
await runtime.register("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
|
||||
fake2 = AgentProxy(AgentId("fake_agent_2", "default"), runtime)
|
||||
|
||||
await runtime.register("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
|
||||
fake3 = AgentProxy(AgentId("fake_agent_3", "default"), runtime)
|
||||
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
|
||||
|
||||
task_message = UserMessage(content="Test Message", source="User")
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(BroadcastMessage(task_message), namespace="default")
|
||||
await runtime.publish_message(BroadcastMessage(task_message), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
# from typing import Any, Dict, List, Tuple, Union
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from team_one.agents.coder import Coder
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.agents.user_proxy import UserProxy
|
||||
@ -19,14 +20,17 @@ async def main() -> None:
|
||||
client = create_completion_client_from_env()
|
||||
|
||||
# Register agents.
|
||||
coder = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=client),
|
||||
)
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
coder = AgentProxy(AgentId("Coder", "default"), runtime)
|
||||
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.core import AgentId, AgentProxy
|
||||
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.agents.user_proxy import UserProxy
|
||||
@ -21,15 +22,17 @@ async def main() -> None:
|
||||
client = create_completion_client_from_env()
|
||||
|
||||
# Register agents.
|
||||
web_surfer = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"WebSurfer",
|
||||
lambda: MultimodalWebSurfer(),
|
||||
)
|
||||
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
|
||||
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from agnext.components.models import (
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import CancellationToken, MessageContext
|
||||
from agnext.core import CancellationToken, MessageContext, TopicId
|
||||
|
||||
from team_one.messages import (
|
||||
BroadcastMessage,
|
||||
@ -45,7 +45,8 @@ class BaseWorker(TeamOneBaseAgent):
|
||||
self._chat_history.append(assistant_message)
|
||||
|
||||
user_message = UserMessage(content=response, source=self.metadata["type"])
|
||||
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
|
||||
"""Returns (request_halt, response_message)"""
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
from agnext.core import AgentProxy
|
||||
from agnext.core import AgentProxy, TopicId
|
||||
|
||||
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
|
||||
from .base_orchestrator import BaseOrchestrator, logger
|
||||
@ -248,8 +248,10 @@ class LedgerOrchestrator(BaseOrchestrator):
|
||||
synthesized_prompt = self._get_synthesize_prompt(
|
||||
self._task, self._team_description, self._facts, self._plan
|
||||
)
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
await self.publish_message(
|
||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
|
||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
||||
topic_id=topic_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@ -319,14 +321,17 @@ class LedgerOrchestrator(BaseOrchestrator):
|
||||
|
||||
# Reset everyone, then rebroadcast the new plan
|
||||
self._chat_history = [self._chat_history[0]]
|
||||
await self.publish_message(ResetMessage())
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
await self.publish_message(ResetMessage(), topic_id=topic_id)
|
||||
|
||||
# Send everyone the NEW plan
|
||||
synthesized_prompt = self._get_synthesize_prompt(
|
||||
self._task, self._team_description, self._facts, self._plan
|
||||
)
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
await self.publish_message(
|
||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
|
||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
||||
topic_id=topic_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@ -351,8 +356,10 @@ class LedgerOrchestrator(BaseOrchestrator):
|
||||
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
|
||||
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
|
||||
self._chat_history.append(assistant_message) # My copy
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
await self.publish_message(
|
||||
BroadcastMessage(content=user_message, request_halt=False)
|
||||
BroadcastMessage(content=user_message, request_halt=False),
|
||||
topic_id=topic_id,
|
||||
) # Send to everyone else
|
||||
return agent
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import UserMessage
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import MessageContext, TopicId
|
||||
|
||||
from ..messages import BroadcastMessage, RequestReplyMessage
|
||||
|
||||
@ -22,5 +22,6 @@ class ReflexAgent(TypeRoutedAgent):
|
||||
content=f"Hello, world from {name}!",
|
||||
source=name,
|
||||
)
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
|
||||
await self.publish_message(BroadcastMessage(response_message))
|
||||
await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id)
|
||||
|
||||
@ -7,11 +7,14 @@ from math import ceil
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from agnext.core import AgentId
|
||||
from agnext.core import AgentProxy
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
from json import dumps
|
||||
|
||||
from team_one.utils import (
|
||||
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
|
||||
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
|
||||
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON,
|
||||
create_completion_client_from_env
|
||||
)
|
||||
@ -96,13 +99,14 @@ async def test_web_surfer() -> None:
|
||||
# Register agents.
|
||||
|
||||
# Register agents.
|
||||
web_surfer = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"WebSurfer",
|
||||
lambda: MultimodalWebSurfer(),
|
||||
)
|
||||
web_surfer = AgentId("WebSurfer", "default")
|
||||
run_context = runtime.start()
|
||||
|
||||
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
|
||||
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer, MultimodalWebSurfer)
|
||||
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
|
||||
|
||||
# Test some basic navigations
|
||||
@ -138,7 +142,7 @@ async def test_web_surfer() -> None:
|
||||
tool_resp = await make_browser_request(actual_surfer, TOOL_PAGE_DOWN)
|
||||
assert (
|
||||
f"The viewport shows {viewport_percentage}% of the webpage, and is positioned at the bottom of the page" in tool_resp
|
||||
)
|
||||
)
|
||||
|
||||
# Test Q&A and summarization -- we don't have a key so we expect it to fail #(but it means the code path is correct)
|
||||
with pytest.raises(AuthenticationError):
|
||||
@ -160,15 +164,17 @@ async def test_web_surfer_oai() -> None:
|
||||
client = create_completion_client_from_env()
|
||||
|
||||
# Register agents.
|
||||
web_surfer = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"WebSurfer",
|
||||
lambda: MultimodalWebSurfer(),
|
||||
)
|
||||
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
|
||||
|
||||
user_proxy = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"UserProxy",
|
||||
lambda: UserProxy(),
|
||||
)
|
||||
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
|
||||
run_context = runtime.start()
|
||||
|
||||
@ -220,10 +226,12 @@ async def test_web_surfer_bing() -> None:
|
||||
# Register agents.
|
||||
|
||||
# Register agents.
|
||||
web_surfer = await runtime.register_and_get_proxy(
|
||||
await runtime.register(
|
||||
"WebSurfer",
|
||||
lambda: MultimodalWebSurfer(),
|
||||
)
|
||||
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
|
||||
|
||||
run_context = runtime.start()
|
||||
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
|
||||
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
|
||||
@ -235,7 +243,7 @@ async def test_web_surfer_bing() -> None:
|
||||
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:url"]
|
||||
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:title"]
|
||||
assert f"I typed '{BING_QUERY}' into the browser search bar." in tool_resp.replace("\\","")
|
||||
|
||||
|
||||
tool_resp = await make_browser_request(actual_surfer, TOOL_WEB_SEARCH, {"query": BING_QUERY + " Wikipedia"})
|
||||
markdown = await actual_surfer._get_page_markdown() # type: ignore
|
||||
assert "https://en.wikipedia.org/wiki/" in markdown
|
||||
|
||||
@ -6,16 +6,18 @@ from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import AgentInstantiationContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
...
|
||||
class MessageType: ...
|
||||
|
||||
|
||||
# Note for future reader:
|
||||
# To do cancellation, only the token should be interacted with as a user
|
||||
# If you cancel a future, it may not work as you expect.
|
||||
|
||||
|
||||
class LongRunningAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A long running agent")
|
||||
@ -34,6 +36,7 @@ class LongRunningAgent(TypeRoutedAgent):
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
|
||||
class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
def __init__(self, nested_agent: AgentId) -> None:
|
||||
super().__init__("A nesting long running agent")
|
||||
@ -58,9 +61,10 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
async def test_cancellation_with_token() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
agent_id = AgentId("long_running", key="default")
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
@ -74,21 +78,25 @@ async def test_cancellation_with_token() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent)
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_only_outer_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
|
||||
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
await runtime.register(
|
||||
"nested",
|
||||
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
|
||||
)
|
||||
|
||||
long_running_id = AgentId("long_running", key="default")
|
||||
nested_id = AgentId("nested", key="default")
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
@ -101,22 +109,29 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
|
||||
assert long_running_agent.called is False
|
||||
assert long_running_agent.cancelled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_inner_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = await runtime.register_and_get("long_running", LongRunningAgent )
|
||||
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
await runtime.register(
|
||||
"nested",
|
||||
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
|
||||
)
|
||||
|
||||
long_running_id = AgentId("long_running", key="default")
|
||||
nested_id = AgentId("nested", key="default")
|
||||
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
@ -131,9 +146,9 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.core import AgentRuntime, AgentId
|
||||
|
||||
from agnext.components import ClosureAgent
|
||||
@ -13,6 +14,7 @@ from agnext.components import ClosureAgent
|
||||
import asyncio
|
||||
|
||||
from agnext.core import MessageContext
|
||||
from agnext.core import TopicId
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
@ -30,11 +32,15 @@ async def test_register_receives_publish() -> None:
|
||||
key = id.key
|
||||
await queue.put((key, message.content))
|
||||
|
||||
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
||||
await runtime.register("name", lambda: ClosureAgent("my_agent", log_message))
|
||||
await runtime.add_subscription(TypeSubscription("default", "name"))
|
||||
topic_id = TopicId("default", "default")
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(Message("first message"), namespace="default")
|
||||
await runtime.publish_message(Message("second message"), namespace="default")
|
||||
await runtime.publish_message(Message("third message"), namespace="default")
|
||||
|
||||
await runtime.publish_message(Message("first message"), topic_id=topic_id)
|
||||
await runtime.publish_message(Message("second message"), topic_id=topic_id)
|
||||
await runtime.publish_message(Message("third message"), topic_id=topic_id)
|
||||
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ async def test_intervention_count_messages() -> None:
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
run_context = runtime.start()
|
||||
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
@ -40,7 +41,8 @@ async def test_intervention_drop_send() -> None:
|
||||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
@ -62,7 +64,8 @@ async def test_intervention_drop_response() -> None:
|
||||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = await runtime.register_and_get("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
@ -84,15 +87,16 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = await runtime.register_and_get("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
_response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -108,12 +112,13 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = await runtime.register_and_get("name", LoopbackAgent)
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
run_context = runtime.start()
|
||||
with pytest.raises(InterventionException):
|
||||
_response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
from agnext.core import TopicId
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
@ -15,13 +17,12 @@ async def test_agent_names_must_be_unique() -> None:
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
agent1 = await runtime.register_and_get("name1", agent_factory)
|
||||
assert agent1 == AgentId("name1", "default")
|
||||
await runtime.register("name1", agent_factory)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent1 = await runtime.register_and_get("name1", NoopAgent)
|
||||
await runtime.register("name1", NoopAgent)
|
||||
|
||||
_agent1 = await runtime.register_and_get("name3", NoopAgent)
|
||||
await runtime.register("name3", NoopAgent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -30,16 +31,19 @@ async def test_register_receives_publish() -> None:
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(MessageType(), namespace="default")
|
||||
await runtime.add_subscription(TypeSubscription("default", "name"))
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent)
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@ -56,17 +60,19 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish messages
|
||||
topic_id = TopicId("default", "default")
|
||||
for _ in range(num_initial_messages):
|
||||
await runtime.publish_message(CascadingMessageType(round=1), namespace="default")
|
||||
await runtime.publish_message(CascadingMessageType(round=1), topic_id)
|
||||
|
||||
# Process until idle.
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent)
|
||||
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from agnext.core._serialization import Serialization
|
||||
from agnext.core import Serialization
|
||||
|
||||
class PydanticMessage(BaseModel):
|
||||
message: str
|
||||
|
||||
@ -1,19 +1,16 @@
|
||||
from typing import Any, Mapping, Sequence
|
||||
from typing import Any, Mapping
|
||||
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import BaseAgent, MessageContext
|
||||
from agnext.core import AgentId
|
||||
|
||||
|
||||
class StatefulAgent(BaseAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A stateful agent", [])
|
||||
super().__init__("A stateful agent")
|
||||
self.state = 0
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -28,7 +25,8 @@ class StatefulAgent(BaseAgent):
|
||||
async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
await runtime.register("name1", StatefulAgent)
|
||||
agent1_id = AgentId("name1", key="default")
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
@ -46,7 +44,8 @@ async def test_agent_can_save_state() -> None:
|
||||
async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
await runtime.register("name1", StatefulAgent)
|
||||
agent1_id = AgentId("name1", key="default")
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
@ -55,7 +54,8 @@ async def test_runtime_can_save_state() -> None:
|
||||
runtime_state = await runtime.save_state()
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
|
||||
await runtime2.register("name1", StatefulAgent)
|
||||
agent2_id = AgentId("name1", key="default")
|
||||
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
|
||||
|
||||
await runtime2.load_state(runtime_state)
|
||||
|
||||
@ -13,6 +13,7 @@ from agnext.components.tool_agent import (
|
||||
)
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import CancellationToken
|
||||
from agnext.core import AgentId
|
||||
|
||||
|
||||
def _pass_function(input: str) -> str:
|
||||
@ -31,7 +32,7 @@ async def _async_sleep_function(input: str) -> str:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_agent() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"tool_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool agent",
|
||||
@ -42,6 +43,7 @@ async def test_tool_agent() -> None:
|
||||
],
|
||||
),
|
||||
)
|
||||
agent = AgentId("tool_agent", "default")
|
||||
run = runtime.start()
|
||||
|
||||
# Test pass function
|
||||
|
||||
@ -38,11 +38,12 @@ class CascadingAgent(TypeRoutedAgent):
|
||||
self.num_calls += 1
|
||||
if message.round == self.max_rounds:
|
||||
return
|
||||
await self.publish_message(CascadingMessageType(round=message.round + 1))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=ctx.topic_id)
|
||||
|
||||
class NoopAgent(BaseAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A no op agent", [])
|
||||
super().__init__("A no op agent")
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user