Initial impl of topics and subscriptions (#350)

* initial impl of topics and subscriptions

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add topic in context

* migrate

* migrate code for topics

* migrate team one

* edit notebooks

* formatting

* fix imports

* Build proto

* Fix circular import

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits 2024-08-20 14:41:24 -04:00 committed by GitHub
parent 4ba7e84721
commit e1a823fb6d
71 changed files with 685 additions and 495 deletions

View File

@ -2,7 +2,6 @@ syntax = "proto3";
package agents;
// TODO: update
message AgentId {
string name = 1;
string namespace = 2;
@ -25,10 +24,11 @@ message RpcResponse {
}
message Event {
string namespace = 1;
string type = 2;
string data = 3;
map<string, string> metadata = 4;
string topic_type = 1;
string topic_source = 2;
string data_type = 3;
string data = 4;
map<string, string> metadata = 5;
}
message RegisterAgentType {

View File

@ -45,7 +45,7 @@
"\n",
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"from agnext.core import AgentId, MessageContext\n",
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from langchain_core.tools import tool # pyright: ignore\n",
@ -195,7 +195,7 @@
"outputs": [],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"langgraph_tool_use_agent\",\n",
" lambda: LangGraphToolUseAgent(\n",
" \"Tool use agent\",\n",
@ -214,7 +214,8 @@
" # ),\n",
" [get_weather],\n",
" ),\n",
")"
")\n",
"agent = AgentId(\"langgraph_tool_use_agent\", key=\"default\")"
]
},
{

View File

@ -44,7 +44,7 @@
"\n",
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"from agnext.core import AgentId, MessageContext\n",
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"from llama_index.core import Settings\n",
"from llama_index.core.agent import ReActAgent\n",
@ -221,7 +221,7 @@
"outputs": [],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"chat_agent\",\n",
" lambda: LlamaIndexAgent(\n",
" description=\"Llama Index Agent\",\n",
@ -233,7 +233,8 @@
" verbose=True,\n",
" ),\n",
" ),\n",
")"
")\n",
"agent = AgentId(\"chat_agent\", \"default\")"
]
},
{

View File

@ -105,7 +105,7 @@
"\n",
"import aiofiles\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"from agnext.core import AgentId, MessageContext\n",
"from openai import AsyncAssistantEventHandler, AsyncClient\n",
"from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch\n",
"\n",
@ -390,7 +390,7 @@
"from agnext.application import SingleThreadedAgentRuntime\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"assistant\",\n",
" lambda: OpenAIAssistantAgent(\n",
" description=\"OpenAI Assistant Agent\",\n",
@ -399,7 +399,8 @@
" thread_id=thread.id,\n",
" assistant_event_handler_factory=lambda: EventHandler(),\n",
" ),\n",
")"
")\n",
"agent = AgentId(\"assistant\", \"default\")"
]
},
{

View File

@ -57,7 +57,7 @@
"source": [
"from dataclasses import dataclass\n",
"\n",
"from agnext.core import BaseAgent, MessageContext\n",
"from agnext.core import AgentId, BaseAgent, MessageContext\n",
"\n",
"\n",
"@dataclass\n",
@ -67,7 +67,7 @@
"\n",
"class MyAgent(BaseAgent):\n",
" def __init__(self) -> None:\n",
" super().__init__(\"MyAgent\", subscriptions=[\"MyMessage\"])\n",
" super().__init__(\"MyAgent\")\n",
"\n",
" async def on_message(self, message: MyMessage, ctx: MessageContext) -> None:\n",
" print(f\"Received message: {message.content}\")"
@ -133,7 +133,7 @@
}
],
"source": [
"agent_id = await runtime.get(\"my_agent\")\n",
"agent_id = AgentId(\"my_agent\", \"default\")\n",
"run_context = runtime.start() # Start processing messages in the background.\n",
"await runtime.send_message(MyMessage(content=\"Hello, World!\"), agent_id)\n",
"await run_context.stop() # Stop processing messages in the background."

View File

@ -83,7 +83,7 @@
"source": [
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"from agnext.core import AgentId, MessageContext\n",
"\n",
"\n",
"class MyAgent(TypeRoutedAgent):\n",
@ -110,7 +110,8 @@
"outputs": [],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\"my_agent\", lambda: MyAgent(\"My Agent\"))"
"await runtime.register(\"my_agent\", lambda: MyAgent(\"My Agent\"))\n",
"agent = AgentId(\"my_agent\", \"default\")"
]
},
{
@ -185,7 +186,7 @@
"\n",
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import AgentId, MessageContext\n",
"from agnext.core import MessageContext\n",
"\n",
"\n",
"@dataclass\n",
@ -200,9 +201,9 @@
"\n",
"\n",
"class OuterAgent(TypeRoutedAgent):\n",
" def __init__(self, description: str, inner_agent_id: AgentId):\n",
" def __init__(self, description: str, inner_agent_type: str):\n",
" super().__init__(description)\n",
" self.inner_agent_id = inner_agent_id\n",
" self.inner_agent_id = AgentId(inner_agent_type, self.id.key)\n",
"\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
@ -238,9 +239,10 @@
],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"inner = await runtime.register_and_get(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n",
"outer = await runtime.register_and_get(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", inner))\n",
"await runtime.register(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n",
"await runtime.register(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", \"InnerAgent\"))\n",
"run_context = runtime.start()\n",
"outer = AgentId(\"outer_agent\", \"default\")\n",
"await runtime.send_message(Message(content=\"Hello, World!\"), outer)\n",
"await run_context.stop_when_idle()"
]
@ -294,14 +296,17 @@
"source": [
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"from agnext.core import MessageContext, TopicId\n",
"\n",
"\n",
"class BroadcastingAgent(TypeRoutedAgent):\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
" # Publish a message to all agents in the same namespace.\n",
" await self.publish_message(Message(f\"Publishing a message: {message.content}!\"))\n",
" assert ctx.topic_id is not None\n",
" await self.publish_message(\n",
" Message(f\"Publishing a message: {message.content}!\"), topic_id=TopicId(\"deafult\", self.id.key)\n",
" )\n",
"\n",
"\n",
"class ReceivingAgent(TypeRoutedAgent):\n",
@ -332,11 +337,15 @@
}
],
"source": [
"from agnext.components import TypeSubscription\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
"run_context = runtime.start()\n",
"await runtime.send_message(Message(\"Hello, World!\"), broadcaster)\n",
"await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n",
"await run_context.stop_when_idle()"
]
},
@ -367,10 +376,12 @@
"# Replace send_message with publish_message in the above example.\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
"run_context = runtime.start()\n",
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), namespace=\"default\")\n",
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=TopicId(\"default\", \"default\"))\n",
"await run_context.stop_when_idle()"
]
},

View File

@ -318,8 +318,10 @@
],
"source": [
"# Create the runtime and register the agent.\n",
"from agnext.core import AgentId\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"simple-agent\",\n",
" lambda: SimpleAgent(\n",
" OpenAIChatCompletionClient(\n",
@ -332,7 +334,7 @@
"run_context = runtime.start()\n",
"# Send a message to the agent and get the response.\n",
"message = Message(\"Hello, what are some fun things to do in Seattle?\")\n",
"response = await runtime.send_message(message, agent)\n",
"response = await runtime.send_message(message, AgentId(\"simple-agent\", \"default\"))\n",
"print(response.content)\n",
"# Stop the runtime processing messages.\n",
"await run_context.stop()"

View File

@ -131,7 +131,7 @@
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from agnext.core import MessageContext"
"from agnext.core import MessageContext, TopicId"
]
},
{
@ -201,7 +201,7 @@
" # Store the code review task in the session memory.\n",
" self._session_memory[session_id].append(code_review_task)\n",
" # Publish a code review task.\n",
" await self.publish_message(code_review_task)\n",
" await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n",
"\n",
" @message_handler\n",
" async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n",
@ -220,7 +220,8 @@
" code=review_request.code,\n",
" task=review_request.code_writing_task,\n",
" review=message.review,\n",
" )\n",
" ),\n",
" topic_id=TopicId(\"default\", self.id.key),\n",
" )\n",
" print(\"Code Writing Result:\")\n",
" print(\"-\" * 80)\n",
@ -259,7 +260,7 @@
" # Store the code review task in the session memory.\n",
" self._session_memory[message.session_id].append(code_review_task)\n",
" # Publish a new code review task.\n",
" await self.publish_message(code_review_task)\n",
" await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n",
"\n",
" def _extract_code_block(self, markdown_text: str) -> Union[str, None]:\n",
" pattern = r\"```(\\w+)\\n(.*?)\\n```\"\n",
@ -360,7 +361,7 @@
" # Store the review result in the session memory.\n",
" self._session_memory[message.session_id].append(result)\n",
" # Publish the review result.\n",
" await self.publish_message(result)"
" await self.publish_message(result, topic_id=TopicId(\"default\", self.id.key))"
]
},
{
@ -494,6 +495,7 @@
],
"source": [
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components._type_subscription import TypeSubscription\n",
"from agnext.components.models import OpenAIChatCompletionClient\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
@ -501,14 +503,16 @@
" \"ReviewerAgent\",\n",
" lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n",
")\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"CoderAgent\"))\n",
"await runtime.register(\n",
" \"CoderAgent\",\n",
" lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n",
")\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"ReviewerAgent\"))\n",
"run_context = runtime.start()\n",
"await runtime.publish_message(\n",
" message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n",
" namespace=\"default\",\n",
" topic_id=TopicId(\"default\", \"default\"),\n",
")\n",
"\n",
"# Keep processing messages until idle.\n",

View File

@ -148,7 +148,7 @@
")\n",
"from agnext.components.tool_agent import ToolAgent, ToolException\n",
"from agnext.components.tools import FunctionTool, Tool, ToolSchema\n",
"from agnext.core import AgentId, MessageContext\n",
"from agnext.core import AgentId, AgentInstantiationContext, MessageContext\n",
"\n",
"\n",
"@dataclass\n",
@ -239,19 +239,19 @@
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"tool_executor_agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"tool-executor-agent\",\n",
" lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n",
" tools=tools,\n",
" ),\n",
")\n",
"tool_use_agent = await runtime.register_and_get(\n",
"await runtime.register(\n",
" \"tool-use-agent\",\n",
" lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[tool.schema for tool in tools],\n",
" tool_agent=tool_executor_agent,\n",
" tool_agent=AgentId(\"tool-executor-agent\", AgentInstantiationContext.current_agent_id().key),\n",
" ),\n",
")"
]
@ -282,6 +282,7 @@
"# Start processing messages.\n",
"run_context = runtime.start()\n",
"# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool-use-agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",

View File

@ -10,7 +10,7 @@ from typing import Any, Callable, List, Literal
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MessageContext
from agnext.core import AgentId, MessageContext
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool # pyright: ignore
from langchain_openai import ChatOpenAI
@ -110,7 +110,7 @@ async def main() -> None:
# Create runtime.
runtime = SingleThreadedAgentRuntime()
# Register the agent.
agent = await runtime.register_and_get(
await runtime.register(
"langgraph_tool_use_agent",
lambda: LangGraphToolUseAgent(
"Tool use agent",
@ -118,6 +118,7 @@ async def main() -> None:
[get_weather],
),
)
agent = AgentId("langgraph_tool_use_agent", key="default")
# Start the runtime.
run_context = runtime.start()
# Send a message to the agent and get a response.

View File

@ -9,7 +9,7 @@ from typing import List, Optional
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MessageContext
from agnext.core import AgentId, MessageContext
from llama_index.core import Settings
from llama_index.core.agent import ReActAgent
from llama_index.core.agent.runner.base import AgentRunner
@ -119,10 +119,11 @@ async def main() -> None:
tools=[wikipedia_tool], llm=llm, max_iterations=8, memory=memory, verbose=True
)
agent = await runtime.register_and_get(
await runtime.register(
"chat_agent",
lambda: LlamaIndexAgent("Chat agent", llama_index_agent=llama_index_agent),
)
agent = AgentId("chat_agent", key="default")
run_context = runtime.start()

View File

@ -110,8 +110,9 @@ class ChatCompletionAgent(TypeRoutedAgent):
# Generate a response.
response = await self._generate_response(message.response_format, ctx)
assert ctx.topic_id is not None
# Publish the response.
await self.publish_message(response)
await self.publish_message(response, topic_id=ctx.topic_id)
@message_handler()
async def on_tool_call_message(

View File

@ -7,8 +7,7 @@ from agnext.components import (
message_handler,
)
from agnext.components.memory import ChatMemory
from agnext.core import MessageContext
from agnext.core._cancellation_token import CancellationToken
from agnext.core import CancellationToken, MessageContext
from ..types import (
Message,
@ -58,7 +57,8 @@ class ImageGenerationAgent(TypeRoutedAgent):
image is published as a MultiModalMessage."""
response = await self._generate_response(ctx.cancellation_token)
await self.publish_message(response)
assert ctx.topic_id is not None
await self.publish_message(response, topic_id=ctx.topic_id)
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
messages = await self._memory.get_messages()

View File

@ -80,7 +80,8 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method generates a response and publishes it."""
response = await self._generate_response(message.response_format, ctx.cancellation_token)
await self.publish_message(response)
assert ctx.topic_id is not None
await self.publish_message(response, ctx.topic_id)
async def _generate_response(
self,

View File

@ -23,7 +23,8 @@ class UserProxyAgent(TypeRoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
user_input = await self.get_user_input(self._user_input_prompt)
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]))
assert ctx.topic_id is not None
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id)
async def get_user_input(self, prompt: str) -> str:
"""Get user input from the console. Override this method to customize how user input is retrieved."""

View File

@ -12,7 +12,7 @@ from dataclasses import dataclass
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentId, MessageContext
from agnext.core import AgentId, AgentInstantiationContext, MessageContext
@dataclass
@ -45,8 +45,9 @@ class Outer(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
inner = await runtime.register_and_get("inner", Inner)
outer = await runtime.register_and_get("outer", lambda: Outer(inner))
await runtime.register("inner", Inner)
await runtime.register("outer", lambda: Outer(AgentId("outer", AgentInstantiationContext.current_agent_id().key)))
outer = AgentId("outer", "default")
run_context = runtime.start()

View File

@ -17,6 +17,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import AgentId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -45,10 +46,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
agent = await runtime.register_and_get(
await runtime.register(
"chat_agent",
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-4o-mini")),
)
agent = AgentId("chat_agent", "default")
run_context = runtime.start()

View File

@ -18,6 +18,7 @@ from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@ -25,6 +26,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import AgentId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -69,7 +71,11 @@ class ChatCompletionAgent(TypeRoutedAgent):
llm_messages.append(UserMessage(content=m.content, source=m.source))
response = await self._model_client.create(self._system_messages + llm_messages)
assert isinstance(response.content, str)
await self.publish_message(Message(content=response.content, source=self.metadata["type"]))
if ctx.topic_id is not None:
await self.publish_message(
Message(content=response.content, source=self.metadata["type"]), topic_id=ctx.topic_id
)
async def main() -> None:
@ -77,7 +83,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register the agents.
jack = await runtime.register_and_get(
await runtime.register(
"Jack",
lambda: ChatCompletionAgent(
description="Jack a comedian",
@ -88,7 +94,8 @@ async def main() -> None:
termination_word="TERMINATE",
),
)
await runtime.register_and_get(
await runtime.add_subscription(TypeSubscription("default", "Jack"))
await runtime.register(
"Cathy",
lambda: ChatCompletionAgent(
description="Cathy a poet",
@ -99,12 +106,13 @@ async def main() -> None:
termination_word="TERMINATE",
),
)
await runtime.add_subscription(TypeSubscription("default", "Cathy"))
run_context = runtime.start()
# Send a message to Jack to start the conversation.
message = Message(content="Can you tell me something fun about SF?", source="User")
await runtime.send_message(message, jack)
await runtime.send_message(message, AgentId("jack", "default"))
# Process messages.
await run_context.stop_when_idle()

View File

@ -13,7 +13,7 @@ import aiofiles
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentId, AgentRuntime, CancellationToken
from agnext.core import AgentId, AgentRuntime, MessageContext
from openai import AsyncAssistantEventHandler
from openai.types.beta.thread import ToolResources
from openai.types.beta.threads import Message, Text, TextDelta
@ -22,6 +22,7 @@ from typing_extensions import override
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agnext.core import AgentInstantiationContext
from common.agents import OpenAIAssistantAgent
from common.memory import BufferedChatMemory
from common.patterns._group_chat_manager import GroupChatManager
@ -30,7 +31,7 @@ from common.types import PublishNow, TextMessage
sep = "-" * 50
class UserProxyAgent(TypeRoutedAgent): # type: ignore
class UserProxyAgent(TypeRoutedAgent):
def __init__( # type: ignore
self,
client: openai.AsyncClient, # type: ignore
@ -47,7 +48,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
self._vector_store_id = vector_store_id
@message_handler() # type: ignore
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:
# TODO: render image if message has image.
# print(f"{message.source}: {message.content}")
pass
@ -57,7 +58,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
return await loop.run_in_executor(None, input, prompt)
@message_handler() # type: ignore
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: # type: ignore
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
while True:
user_input = await self._get_user_input(f"\n{sep}\nYou: ")
# Parse upload file command '[upload code_interpreter | file_search filename]'.
@ -108,7 +109,10 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore
return
else:
# Publish user input and exit handler.
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]))
assert ctx.topic_id is not None
await self.publish_message(
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id
)
return
@ -166,7 +170,7 @@ class EventHandler(AsyncAssistantEventHandler):
print("\n".join(citations))
async def assistant_chat(runtime: AgentRuntime) -> AgentId:
async def assistant_chat(runtime: AgentRuntime) -> str:
oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
description="An AI assistant that helps with everyday tasks.",
@ -177,7 +181,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
thread = openai.beta.threads.create(
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
assistant = await runtime.register_and_get(
await runtime.register(
"Assistant",
lambda: OpenAIAssistantAgent(
description="An AI assistant that helps with everyday tasks.",
@ -188,7 +192,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
),
)
user = await runtime.register_and_get(
await runtime.register(
"User",
lambda: UserProxyAgent(
client=openai.AsyncClient(),
@ -203,10 +207,13 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId:
lambda: GroupChatManager(
description="A group chat manager.",
memory=BufferedChatMemory(buffer_size=10),
participants=[assistant, user],
participants=[
AgentId("Assistant", AgentInstantiationContext.current_agent_id().key),
AgentId("User", AgentInstantiationContext.current_agent_id().key),
],
),
)
return user
return "User"
async def main() -> None:
@ -229,7 +236,7 @@ Type "exit" to exit the chat.
_run_context = runtime.start()
print(usage)
# Request the user to start the conversation.
await runtime.send_message(PublishNow(), user)
await runtime.send_message(PublishNow(), AgentId(user, "default"))
# TODO: have a way to exit the loop.

View File

@ -9,7 +9,7 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.memory import ChatMemory
from agnext.components.models import ChatCompletionClient, SystemMessage
from agnext.core import AgentInstantiationContext, AgentRuntime
from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -76,7 +76,10 @@ Use the following JSON format to provide your thought on the latest message and
# Publish the response if needed.
if respond is True or str(respond).lower().strip() == "true":
await self.publish_message(TextMessage(source=self.metadata["type"], content=str(response)))
assert ctx.topic_id is not None
await self.publish_message(
TextMessage(source=self.metadata["type"], content=str(response)), topic_id=ctx.topic_id
)
class ChatRoomUserAgent(TextualUserAgent):
@ -96,7 +99,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
app=app,
),
)
alice = await runtime.register_and_get_proxy(
await runtime.register(
"Alice",
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().type,
@ -106,7 +109,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
bob = await runtime.register_and_get_proxy(
alice = AgentProxy(AgentId("Alice", "default"), runtime)
await runtime.register(
"Bob",
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().type,
@ -116,7 +120,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
charlie = await runtime.register_and_get_proxy(
bob = AgentProxy(AgentId("Bob", "default"), runtime)
await runtime.register(
"Charlie",
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().type,
@ -126,6 +131,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
charlie = AgentProxy(AgentId("Charlie", "default"), runtime)
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
1. 👧 {alice.id.type}: {(await alice.metadata)['description']}
2. 👱🏼 {bob.id.type}: {(await bob.metadata)['description']}

View File

@ -10,14 +10,16 @@ import sys
from typing import Annotated, Literal
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from agnext.core import AgentInstantiationContext, AgentRuntime, TopicId
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
from chess import piece_name as get_piece_name
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agnext.core import AgentId
from common.agents._chat_completion_agent import ChatCompletionAgent
from common.memory import BufferedChatMemory
from common.patterns._group_chat_manager import GroupChatManager
@ -156,7 +158,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
),
]
black = await runtime.register_and_get(
await runtime.register(
"PlayerBlack",
lambda: ChatCompletionAgent(
description="Player playing black.",
@ -173,7 +175,8 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
tools=black_tools,
),
)
white = await runtime.register_and_get(
await runtime.add_subscription(TypeSubscription("default", "PlayerBlack"))
await runtime.register(
"PlayerWhite",
lambda: ChatCompletionAgent(
description="Player playing white.",
@ -190,6 +193,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
tools=white_tools,
),
)
await runtime.add_subscription(TypeSubscription("default", "PlayerWhite"))
# Create a group chat manager for the chess game to orchestrate a turn-based
# conversation between the two agents.
await runtime.register(
@ -197,7 +201,10 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
lambda: GroupChatManager(
description="A chess game between two agents.",
memory=BufferedChatMemory(buffer_size=10),
participants=[white, black], # white goes first
participants=[
AgentId("PlayerWhite", AgentInstantiationContext.current_agent_id().key),
AgentId("PlayerBlack", AgentInstantiationContext.current_agent_id().key),
], # white goes first
),
)
@ -207,7 +214,9 @@ async def main() -> None:
await chess_game(runtime)
run_context = runtime.start()
# Publish an initial message to trigger the group chat manager to start orchestration.
await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default")
await runtime.publish_message(
TextMessage(content="Game started.", source="System"), topic_id=TopicId("default", "default")
)
await run_context.stop_when_idle()

View File

@ -7,11 +7,12 @@ import sys
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.components.models import SystemMessage
from agnext.core import AgentRuntime
from agnext.core import AgentInstantiationContext, AgentRuntime
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import AgentId, AgentProxy
from common.agents import ChatCompletionAgent, ImageGenerationAgent
from common.memory import BufferedChatMemory
from common.patterns._group_chat_manager import GroupChatManager
@ -27,7 +28,7 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
app=app,
),
)
descriptor = await runtime.register_and_get_proxy(
await runtime.register(
"Descriptor",
lambda: ChatCompletionAgent(
description="An AI agent that provides a description of the image.",
@ -46,7 +47,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500),
),
)
illustrator = await runtime.register_and_get_proxy(
descriptor = AgentProxy(AgentId("Descriptor", "default"), runtime)
await runtime.register(
"Illustrator",
lambda: ImageGenerationAgent(
description="An AI agent that generates images.",
@ -55,7 +57,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
memory=BufferedChatMemory(buffer_size=1),
),
)
critic = await runtime.register_and_get_proxy(
illustrator = AgentProxy(AgentId("Illustrator", "default"), runtime)
await runtime.register(
"Critic",
lambda: ChatCompletionAgent(
description="An AI agent that provides feedback on images given user's requirements.",
@ -74,12 +77,17 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
critic = AgentProxy(AgentId("Critic", "default"), runtime)
await runtime.register(
"GroupChatManager",
lambda: GroupChatManager(
description="A chat manager that handles group chat.",
memory=BufferedChatMemory(buffer_size=5),
participants=[illustrator.id, critic.id, descriptor.id],
participants=[
AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key),
AgentId("Descriptor", AgentInstantiationContext.current_agent_id().key),
AgentId("Critic", AgentInstantiationContext.current_agent_id().key),
],
termination_word="APPROVE",
),
)

View File

@ -19,7 +19,7 @@ import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.components.models import SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from agnext.core import AgentInstantiationContext, AgentRuntime
from markdownify import markdownify # type: ignore
from tqdm import tqdm
from typing_extensions import Annotated
@ -27,6 +27,7 @@ from typing_extensions import Annotated
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import AgentId
from common.agents import ChatCompletionAgent
from common.memory import HeadAndTailChatMemory
from common.patterns._group_chat_manager import GroupChatManager
@ -106,14 +107,14 @@ async def create_image(
async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
user_agent = await runtime.register_and_get(
await runtime.register(
"Customer",
lambda: TextualUserAgent(
description="A customer looking for help.",
app=app,
),
)
developer = await runtime.register_and_get(
await runtime.register(
"Developer",
lambda: ChatCompletionAgent(
description="A Python software developer.",
@ -149,11 +150,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
FunctionTool(list_files, name="list_files", description="List files in a directory."),
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
],
tool_approver=user_agent,
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
),
)
product_manager = await runtime.register_and_get(
await runtime.register(
"ProductManager",
lambda: ChatCompletionAgent(
description="A product manager. "
@ -179,10 +180,10 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
FunctionTool(list_files, name="list_files", description="List files in a directory."),
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
],
tool_approver=user_agent,
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
),
)
ux_designer = await runtime.register_and_get(
await runtime.register(
"UserExperienceDesigner",
lambda: ChatCompletionAgent(
description="A user experience designer for creating user interfaces.",
@ -211,11 +212,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
),
FunctionTool(list_files, name="list_files", description="List files in a directory."),
],
tool_approver=user_agent,
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
),
)
illustrator = await runtime.register_and_get(
await runtime.register(
"Illustrator",
lambda: ChatCompletionAgent(
description="An illustrator for creating images.",
@ -237,7 +238,7 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
description="Create an image from a description.",
),
],
tool_approver=user_agent,
tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
),
)
await runtime.register(
@ -246,7 +247,13 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No
description="A group chat manager.",
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
participants=[developer, product_manager, ux_designer, illustrator, user_agent],
participants=[
AgentId("Developer", AgentInstantiationContext.current_agent_id().key),
AgentId("ProductManager", AgentInstantiationContext.current_agent_id().key),
AgentId("UserExperienceDesigner", AgentInstantiationContext.current_agent_id().key),
AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key),
AgentId("Customer", AgentInstantiationContext.current_agent_id().key),
],
),
)
art = r"""

View File

@ -13,6 +13,7 @@ from textual_imageview.viewer import ImageViewer
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agnext.core import TopicId
from common.types import (
MultiModalMessage,
PublishNow,
@ -135,7 +136,9 @@ class TextualChatApp(App): # type: ignore
chat_messages.query("#typing").remove()
# Publish the user message to the runtime.
await self._runtime.publish_message(
TextMessage(source=self._user_name, content=user_input), namespace="default"
# TODO fix hard coded topic_id
TextMessage(source=self._user_name, content=user_input),
topic_id=TopicId("default", "default"),
)
async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore

View File

@ -1,5 +1,6 @@
import os
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import AzureOpenAIChatCompletionClient
from agnext.core import AgentRuntime
from auditor import AuditAgent
@ -28,7 +29,6 @@ async def build_app(runtime: AgentRuntime) -> None:
)
await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
await runtime.add_subscription(TypeSubscription("default", "GraphicDesigner"))
await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
await runtime.get("GraphicDesigner")
await runtime.get("Auditor")
await runtime.add_subscription(TypeSubscription("default", "Auditor"))

View File

@ -30,4 +30,7 @@ class AuditAgent(TypeRoutedAgent):
assert isinstance(completion.content, str)
if "NOTFORME" in completion.content:
return
await self.publish_message(AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content))
assert ctx.topic_id is not None
await self.publish_message(
AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=ctx.topic_id
)

View File

@ -33,6 +33,9 @@ class GraphicDesignerAgent(TypeRoutedAgent):
image_uri = response.data[0].url
logger.info(f"Generated image for article. Got response: '{image_uri}'")
await self.publish_message(GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri))
assert ctx.topic_id is not None
await self.publish_message(
GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=ctx.topic_id
)
except Exception as e:
logger.error(f"Failed to generate image for article. Error: {e}")

View File

@ -3,7 +3,7 @@ import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import Image, TypeRoutedAgent, message_handler
from agnext.core import MessageContext
from agnext.core import MessageContext, TopicId
from app import build_app
from dotenv import load_dotenv
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
@ -34,13 +34,15 @@ async def main() -> None:
ctx = runtime.start()
topic_id = TopicId("default", "default")
await runtime.publish_message(
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), namespace="default"
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=topic_id
)
await runtime.publish_message(
ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"),
namespace="default",
topic_id=topic_id,
)
await ctx.stop_when_idle()

View File

@ -2,7 +2,7 @@ import asyncio
import logging
from agnext.application import WorkerAgentRuntime
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
from agnext.core import MESSAGE_TYPE_REGISTRY
from app import build_app
from dotenv import load_dotenv
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated

View File

@ -22,6 +22,7 @@ from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
from agnext.components.models import (
AssistantMessage,
@ -30,6 +31,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -100,10 +102,12 @@ Reply "TERMINATE" in the end when everything is done."""
AssistantMessage(content=response.content, source=self.metadata["type"])
)
assert ctx.topic_id is not None
# Publish the code execution task.
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
)
@message_handler
@ -120,8 +124,11 @@ Reply "TERMINATE" in the end when everything is done."""
if "TERMINATE" in response.content:
# If the task is completed, publish a message with the completion content.
assert ctx.topic_id is not None
await self.publish_message(
TaskCompletion(content=response.content), cancellation_token=ctx.cancellation_token
TaskCompletion(content=response.content),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
)
print("--------------------")
print("Task completed:")
@ -129,9 +136,11 @@ Reply "TERMINATE" in the end when everything is done."""
return
# Publish the code execution task.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
)
@ -148,11 +157,13 @@ class Executor(TypeRoutedAgent):
code_blocks = self._extract_code_blocks(message.content)
if not code_blocks:
# If no code block is found, publish a message with an error.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
)
return
# Execute code blocks.
@ -160,9 +171,11 @@ class Executor(TypeRoutedAgent):
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
)
# Publish the code execution result.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
)
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
@ -185,10 +198,12 @@ async def main(task: str, temp_dir: str) -> None:
"coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))
)
await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
await runtime.add_subscription(TypeSubscription("default", "coder"))
await runtime.add_subscription(TypeSubscription("default", "executor"))
run_context = runtime.start()
# Publish the task message.
await runtime.publish_message(TaskMessage(content=task), namespace="default")
await runtime.publish_message(TaskMessage(content=task), topic_id=TopicId("default", "default"))
await run_context.stop_when_idle()

View File

@ -22,6 +22,7 @@ from typing import Dict, List, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@ -29,6 +30,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -110,12 +112,14 @@ Please review the code and provide feedback.
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
approved = review["approval"].lower().strip() == "approve"
# Publish the review result.
assert ctx.topic_id is not None
await self.publish_message(
CodeReviewResult(
review=review_text,
approved=approved,
session_id=message.session_id,
)
),
topic_id=ctx.topic_id,
)
@ -179,7 +183,11 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[session_id].append(code_review_task)
# Publish a code review task.
await self.publish_message(code_review_task)
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
)
@message_handler
async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:
@ -193,12 +201,14 @@ Code: <Your code>
# Check if the code is approved.
if message.approved:
# Publish the code writing result.
assert ctx.topic_id is not None
await self.publish_message(
CodeWritingResult(
code=review_request.code,
task=review_request.code_writing_task,
review=message.review,
)
),
topic_id=ctx.topic_id,
)
print("Code Writing Result:")
print("-" * 80)
@ -237,7 +247,11 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[message.session_id].append(code_review_task)
# Publish a new code review task.
await self.publish_message(code_review_task)
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
)
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
pattern = r"```(\w+)\n(.*?)\n```"
@ -258,6 +272,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReviewerAgent"))
await runtime.register(
"CoderAgent",
lambda: CoderAgent(
@ -265,12 +280,13 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
),
)
await runtime.add_subscription(TypeSubscription("default", "CoderAgent"))
run_context = runtime.start()
await runtime.publish_message(
message=CodeWritingTask(
task="Write a function to find the directory with the largest number of files using multi-processing."
),
namespace="default",
topic_id=TopicId("default", "default"),
)
# Keep processing messages until idle.

View File

@ -26,7 +26,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import AgentId
from agnext.core import AgentId, AgentInstantiationContext, TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -69,7 +69,8 @@ class RoundRobinGroupChatManager(TypeRoutedAgent):
self._round_count += 1
if self._round_count > self._num_rounds * len(self._participants):
# End the conversation after the specified number of rounds.
await self.publish_message(Termination())
assert ctx.topic_id is not None
await self.publish_message(Termination(), ctx.topic_id)
return
# Send a request to speak message to the selected speaker.
await self.send_message(RequestToSpeak(), speaker)
@ -104,9 +105,10 @@ class GroupChatParticipant(TypeRoutedAgent):
llm_messages.append(UserMessage(content=m.content, source=m.source))
response = await self._model_client.create(self._system_messages + llm_messages)
assert isinstance(response.content, str)
speach = Message(content=response.content, source=self.metadata["type"])
self._memory.append(speach)
await self.publish_message(speach)
speech = Message(content=response.content, source=self.metadata["type"])
self._memory.append(speech)
assert ctx.topic_id is not None
await self.publish_message(speech, topic_id=ctx.topic_id)
async def main() -> None:
@ -114,7 +116,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register the participants.
agent1 = await runtime.register_and_get(
await runtime.register(
"DataScientist",
lambda: GroupChatParticipant(
description="A data scientist",
@ -122,7 +124,8 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
),
)
agent2 = await runtime.register_and_get(
await runtime.register(
"Engineer",
lambda: GroupChatParticipant(
description="An engineer",
@ -130,7 +133,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
),
)
agent3 = await runtime.register_and_get(
await runtime.register(
"Artist",
lambda: GroupChatParticipant(
description="An artist",
@ -144,7 +147,11 @@ async def main() -> None:
"GroupChatManager",
lambda: RoundRobinGroupChatManager(
description="A group chat manager",
participants=[agent1, agent2, agent3],
participants=[
AgentId("DataScientist", AgentInstantiationContext.current_agent_id().key),
AgentId("Engineer", AgentInstantiationContext.current_agent_id().key),
AgentId("Artist", AgentInstantiationContext.current_agent_id().key),
],
num_rounds=3,
),
)
@ -153,7 +160,9 @@ async def main() -> None:
run_context = runtime.start()
# Start the conversation.
await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")
await runtime.publish_message(
Message(content="Hello, everyone!", source="Moderator"), topic_id=TopicId("default", "default")
)
await run_context.stop_when_idle()

View File

@ -16,11 +16,13 @@ from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import MessageContext
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import TopicId
from common.utils import get_chat_completion_client_from_envs
@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent):
response = await self._model_client.create(self._system_messages + [task_message])
assert isinstance(response.content, str)
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
await self.publish_message(task_result)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
class AggregatorAgent(TypeRoutedAgent):
@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent):
"""Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
await self.publish_message(ref_task)
assert ctx.topic_id is not None
await self.publish_message(ref_task, topic_id=ctx.topic_id)
@message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent):
)
assert isinstance(response.content, str)
task_result = AggregatorTaskResult(result=response.content)
await self.publish_message(task_result)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
self._session_results.pop(message.session_id)
print(f"Aggregator result: {response.content}")
@ -120,6 +125,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1"))
await runtime.register(
"ReferenceAgent2",
lambda: ReferenceAgent(
@ -128,6 +134,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2"))
await runtime.register(
"ReferenceAgent3",
lambda: ReferenceAgent(
@ -136,6 +143,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3"))
await runtime.register(
"AggregatorAgent",
lambda: AggregatorAgent(
@ -149,8 +157,11 @@ async def main() -> None:
num_references=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent"))
run_context = runtime.start()
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
await runtime.publish_message(
AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default")
)
# Keep processing messages.
await run_context.stop_when_idle()

View File

@ -41,6 +41,7 @@ from typing import Dict, List, Tuple
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@ -48,6 +49,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent):
answer = match.group(1)
# Increment the counter.
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
assert ctx.topic_id is not None
if self._counters[message.session_id] == self._max_round:
# If the counter reaches the maximum round, publishes a final response.
await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id))
await self.publish_message(
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
)
else:
# Publish intermediate response.
await self.publish_message(
@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent):
answer=answer,
session_id=message.session_id,
round=self._counters[message.session_id],
)
),
topic_id=ctx.topic_id,
)
@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent):
"in the form of {{answer}}, at the end of your response."
)
session_id = str(uuid.uuid4())
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
assert ctx.topic_id is not None
await self.publish_message(
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
)
@message_handler
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent):
answers = [resp.answer for resp in self._responses[message.session_id]]
majority_answer = max(set(answers), key=answers.count)
# Publish the aggregated response.
await self.publish_message(Answer(content=majority_answer))
assert ctx.topic_id is not None
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
# Clear the responses.
self._responses.pop(message.session_id)
print(f"Aggregated answer: {majority_answer}")
@ -223,6 +233,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver1"))
await runtime.register(
"MathSolver2",
lambda: MathSolver(
@ -231,6 +242,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver2"))
await runtime.register(
"MathSolver3",
lambda: MathSolver(
@ -239,6 +251,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver3"))
await runtime.register(
"MathSolver4",
lambda: MathSolver(
@ -247,13 +260,14 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver4"))
# Register the aggregator agent.
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
run_context = runtime.start()
# Send a math problem to the aggregator agent.
await runtime.publish_message(Question(content=question), namespace="default")
await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default"))
await run_context.stop_when_idle()

View File

@ -30,7 +30,7 @@ from agnext.components.models import (
)
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema
from agnext.core import AgentId
from agnext.core import AgentId, AgentInstantiationContext
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -107,21 +107,21 @@ async def main() -> None:
)
]
# Register agents.
tool_executor_agent = await runtime.register_and_get(
await runtime.register(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
await runtime.register(
"tool_enabled_agent",
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
),
)
@ -129,7 +129,7 @@ async def main() -> None:
# Send a task to the tool user.
response = await runtime.send_message(
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default")
)
print(response.content)

View File

@ -16,7 +16,7 @@ from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import SystemMessage
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import AgentId
from agnext.core import AgentId, AgentInstantiationContext
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -48,21 +48,21 @@ async def main() -> None:
)
]
# Register agents.
tool_executor_agent = await runtime.register_and_get(
await runtime.register(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
await runtime.register(
"tool_enabled_agent",
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
),
)
@ -70,7 +70,7 @@ async def main() -> None:
# Send a task to the tool user.
response = await runtime.send_message(
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default")
)
print(response.content)

View File

@ -21,6 +21,7 @@ from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
AssistantMessage,
@ -32,6 +33,7 @@ from agnext.components.models import (
UserMessage,
)
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -88,7 +90,8 @@ class ToolExecutorAgent(TypeRoutedAgent):
session_id=message.session_id,
result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id),
)
await self.publish_message(task_result)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
class ToolUseAgent(TypeRoutedAgent):
@ -126,7 +129,8 @@ class ToolUseAgent(TypeRoutedAgent):
if isinstance(response.content, str):
# If the response is a string, just publish the response.
response_message = AgentResponse(content=response.content)
await self.publish_message(response_message)
assert ctx.topic_id is not None
await self.publish_message(response_message, topic_id=ctx.topic_id)
print(f"AI Response: {response.content}")
return
@ -139,7 +143,8 @@ class ToolUseAgent(TypeRoutedAgent):
for function_call in response.content:
task = ToolExecutionTask(session_id=session_id, function_call=function_call)
self._tool_counter[session_id] += 1
await self.publish_message(task)
assert ctx.topic_id is not None
await self.publish_message(task, topic_id=ctx.topic_id)
@message_handler
async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None:
@ -165,10 +170,11 @@ class ToolUseAgent(TypeRoutedAgent):
self._sessions[message.session_id].append(
AssistantMessage(content=response.content, source=self.metadata["type"])
)
assert ctx.topic_id is not None
# If the response is a string, just publish the response.
if isinstance(response.content, str):
response_message = AgentResponse(content=response.content)
await self.publish_message(response_message)
await self.publish_message(response_message, topic_id=ctx.topic_id)
self._tool_results.pop(message.session_id)
self._tool_counter.pop(message.session_id)
print(f"AI Response: {response.content}")
@ -179,7 +185,7 @@ class ToolUseAgent(TypeRoutedAgent):
for function_call in response.content:
task = ToolExecutionTask(session_id=message.session_id, function_call=function_call)
self._tool_counter[message.session_id] += 1
await self.publish_message(task)
await self.publish_message(task, topic_id=ctx.topic_id)
async def main() -> None:
@ -192,6 +198,7 @@ async def main() -> None:
]
# Register agents.
await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
await runtime.add_subscription(TypeSubscription("default", "tool_executor"))
await runtime.register(
"tool_use_agent",
lambda: ToolUseAgent(
@ -201,12 +208,13 @@ async def main() -> None:
tools=tools,
),
)
await runtime.add_subscription(TypeSubscription("default", "tool_use_agent"))
run_context = runtime.start()
# Publish a task.
await runtime.publish_message(
UserRequest("Run the following Python code: print('Hello, World!')"), namespace="default"
UserRequest("Run the following Python code: print('Hello, World!')"), topic_id=TopicId("default", "default")
)
await run_context.stop_when_idle()

View File

@ -15,11 +15,13 @@ from agnext.components.models import (
)
from agnext.components.tool_agent import ToolAgent
from agnext.components.tools import FunctionTool, Tool
from agnext.core import AgentInstantiationContext
from typing_extensions import Annotated
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__))))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import AgentId
from coding_direct import Message, ToolUseAgent
from common.utils import get_chat_completion_client_from_envs
@ -42,23 +44,24 @@ async def main() -> None:
)
]
# Register agents.
tool_executor_agent = await runtime.register_and_get(
await runtime.register(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
await runtime.register(
"tool_enabled_agent",
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key),
),
)
tool_use_agent = AgentId("tool_enabled_agent", "default")
run_context = runtime.start()

View File

@ -4,7 +4,8 @@ from dataclasses import dataclass
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
from agnext.components._type_subscription import TypeSubscription
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId
@dataclass
@ -38,11 +39,14 @@ class ReceiveAgent(TypeRoutedAgent):
@message_handler
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"))
assert ctx.topic_id is not None
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=ctx.topic_id)
@message_handler
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"))
assert ctx.topic_id is not None
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
class GreeterAgent(TypeRoutedAgent):
@ -51,11 +55,15 @@ class GreeterAgent(TypeRoutedAgent):
@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
await self.publish_message(Greeting(f"Hello, {message.content}!"))
assert ctx.topic_id is not None
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
@message_handler
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
await self.publish_message(Feedback(f"Feedback: {message.content}"))
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
async def main() -> None:
@ -68,9 +76,11 @@ async def main() -> None:
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("reciever", lambda: ReceiveAgent())
await runtime.add_subscription(TypeSubscription("default", "reciever"))
await runtime.register("greeter", lambda: GreeterAgent())
await runtime.add_subscription(TypeSubscription("default", "greeter"))
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
# Just to keep the runtime running
try:

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
@dataclass
@ -43,7 +43,8 @@ class GreeterAgent(TypeRoutedAgent):
@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
await self.publish_message(Feedback(f"Feedback: {response.content}"))
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
async def main() -> None:
@ -54,10 +55,11 @@ async def main() -> None:
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("reciever", lambda: ReceiveAgent())
reciever = await runtime.get("reciever")
await runtime.register("greeter", lambda: GreeterAgent(reciever))
await runtime.register(
"greeter", lambda: GreeterAgent(AgentId("reciever", AgentInstantiationContext.current_agent_id().key))
)
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
# Just to keep the runtime running
try:

View File

@ -12,13 +12,13 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from agnext.core import Subscription, TopicId
from ..core import (
MESSAGE_TYPE_REGISTRY,
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentProxy,
AgentRuntime,
CancellationToken,
MessageContext,
@ -38,7 +38,7 @@ class PublishMessageEnvelope:
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
namespace: str
topic_id: TopicId
@dataclass(kw_only=True)
@ -124,16 +124,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
self._subscriptions: List[Subscription] = []
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
@property
def unprocessed_messages(
self,
@ -177,8 +179,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if sender is not None and sender.key != recipient.key:
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
await self._process_seen_namespace(recipient.key)
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
@ -199,8 +199,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
@ -219,26 +219,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
# )
# )
if sender is None and namespace is None:
raise ValueError("Namespace must be provided if sender is not provided.")
sender_namespace = sender.key if sender is not None else None
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
raise ValueError(
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
)
assert explicit_namespace is not None or sender_namespace is not None
namespace = cast(str, explicit_namespace or sender_namespace)
await self._process_seen_namespace(namespace)
self._message_queue.append(
PublishMessageEnvelope(
message=message,
cancellation_token=cancellation_token,
sender=sender,
namespace=namespace,
message=message, cancellation_token=cancellation_token, sender=sender, topic_id=topic_id
)
)
@ -300,12 +283,13 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._outstanding_tasks.decrement()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
self._build_for_new_topic(message_envelope.topic_id)
responses: List[Awaitable[Any]] = []
target_namespace = message_envelope.namespace
for agent_id in self._per_type_subscribers[
(target_namespace, MESSAGE_TYPE_REGISTRY.type_name(message_envelope.message))
]:
if message_envelope.sender is not None and agent_id.type == message_envelope.sender.type:
recipients = self._subscribed_recipients[message_envelope.topic_id]
for agent_id in recipients:
# Avoid sending the message back to the sender
if message_envelope.sender is not None and agent_id == message_envelope.sender:
continue
sender_agent = (
@ -326,8 +310,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
# )
message_context = MessageContext(
sender=message_envelope.sender,
# TODO: topic_id
topic_id=None,
topic_id=message_envelope.topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
)
@ -460,16 +443,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def register(
self,
name: str,
type: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
await self._get_agent(AgentId(type=name, key=namespace))
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
self._agent_factories[type] = agent_factory
async def _invoke_agent_factory(
self,
@ -496,7 +475,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent
async def _get_agent(self, agent_id: AgentId) -> Agent:
await self._process_seen_namespace(agent_id.key)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
@ -504,20 +482,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
raise LookupError(f"Agent with name {agent_id.type} not found.")
agent_factory = self._agent_factories[agent_id.type]
agent = await self._invoke_agent_factory(agent_factory, agent_id)
for message_type in agent.metadata["subscriptions"]:
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
self._instantiated_agents[agent_id] = agent
return agent
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
return (await self._get_agent(AgentId(type=name, key=namespace))).id
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
if id.type not in self._agent_factories:
@ -531,12 +499,40 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent_instance
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
async def _process_seen_namespace(self, namespace: str) -> None:
if namespace in self._known_namespaces:
async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub.id == subscription.id for sub in self._subscriptions):
raise ValueError("Subscription already exists")
if len(self._seen_topics) > 0:
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
self._subscriptions.append(subscription)
async def remove_subscription(self, id: str) -> None:
# Check if the subscription exists
if not any(sub.id == id for sub in self._subscriptions):
raise ValueError("Subscription does not exist")
def is_not_sub(x: Subscription) -> bool:
return x.id != id
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
# Rebuild the subscriptions
self._rebuild_subscriptions(self._seen_topics)
# TODO: optimize this...
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
self._subscribed_recipients.clear()
for topic in topics:
self._build_for_new_topic(topic)
def _build_for_new_topic(self, topic: TopicId) -> None:
if topic in self._seen_topics:
return
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
await self._get_agent(AgentId(type=name, key=namespace))
self._seen_topics.add(topic)
for subscription in self._subscriptions:
if subscription.is_match(topic):
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))

View File

@ -28,9 +28,9 @@ import grpc
from grpc.aio import StreamStreamCall
from typing_extensions import Self
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken
from .protos import AgentId as AgentIdProto
from .protos import (
AgentRpcStub,
@ -153,6 +153,9 @@ class WorkerAgentRuntime(AgentRuntime):
self._next_request_id = 0
self._host_connection: HostConnection | None = None
self._background_tasks: Set[Task[Any]] = set()
self._subscriptions: List[Subscription] = []
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
async def start(self, host_connection_string: str) -> None:
if self._running:
@ -245,29 +248,25 @@ class WorkerAgentRuntime(AgentRuntime):
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
assert self._host_connection is not None
sender_namespace = sender.key if sender is not None else None
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
raise ValueError(
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
)
assert explicit_namespace is not None or sender_namespace is not None
actual_namespace = cast(str, explicit_namespace or sender_namespace)
await self._process_seen_namespace(actual_namespace)
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
task = asyncio.create_task(self._host_connection.send(message))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
message = Message(
event=Event(
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
)
)
async def write_message() -> None:
assert self._host_connection is not None
await self._host_connection.send(message)
await asyncio.create_task(write_message())
async def save_state(self) -> Mapping[str, Any]:
raise NotImplementedError("Saving state is not yet implemented.")
@ -284,26 +283,6 @@ class WorkerAgentRuntime(AgentRuntime):
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")
async def register(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
if not self._running:
raise ValueError("Runtime must be running when registering agent.")
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
await self._get_agent(AgentId(type=name, key=namespace))
assert self._host_connection is not None
message = Message(registerAgentType=RegisterAgentType(type=name))
await self._host_connection.send(message)
logger.info("Sent registerAgentType message for %s", name)
async def _process_request(self, request: RpcRequest) -> None:
assert self._host_connection is not None
target = AgentId(request.target.name, request.target.namespace)
@ -347,27 +326,41 @@ class WorkerAgentRuntime(AgentRuntime):
future.set_result(response.result)
async def _process_event(self, event: Event) -> None:
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
namespace = event.namespace
responses: List[Awaitable[Any]] = []
for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]:
# TODO: skip the sender?
message_context = MessageContext(
sender=None,
topic_id=None,
is_rpc=False,
cancellation_token=CancellationToken(),
)
agent = await self._get_agent(agent_id)
future = agent.on_message(message, ctx=message_context)
responses.append(future)
...
# message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
try:
_ = await asyncio.gather(*responses)
except BaseException as e:
if isinstance(e, asyncio.CancelledError):
return
event_logger.error("Error handling event message", exc_info=e)
# for agent_id in self._per_type_subscribers[
# (namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
# ]:
# agent = await self._get_agent(agent_id)
# message_context = MessageContext(
# # TODO: should sender be in the proto even for published events?
# sender=None,
# # TODO: topic_id
# topic_id=None,
# is_rpc=False,
# cancellation_token=CancellationToken(),
# )
# try:
# await agent.on_message(message, ctx=message_context)
# logger.info("%s handled event %s", agent_id, message)
# except Exception as e:
# event_logger.error("Error handling message", exc_info=e)
async def register(
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
) -> None:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
self._agent_factories[type] = agent_factory
assert self._host_connection is not None
message = Message(registerAgentType=RegisterAgentType(type=type))
await self._host_connection.send(message)
logger.info("Sent registerAgentType message for %s", type)
async def _invoke_agent_factory(
self,
@ -394,7 +387,6 @@ class WorkerAgentRuntime(AgentRuntime):
return agent
async def _get_agent(self, agent_id: AgentId) -> Agent:
await self._process_seen_namespace(agent_id.key)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
@ -402,32 +394,16 @@ class WorkerAgentRuntime(AgentRuntime):
raise ValueError(f"Agent with name {agent_id.type} not found.")
agent_factory = self._agent_factories[agent_id.type]
agent = await self._invoke_agent_factory(agent_factory, agent_id)
for message_type in agent.metadata["subscriptions"]:
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
self._instantiated_agents[agent_id] = agent
return agent
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
return (await self._get_agent(AgentId(type=name, key=namespace))).id
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
async def _process_seen_namespace(self, namespace: str) -> None:
if namespace in self._known_namespaces:
return
async def add_subscription(self, subscription: Subscription) -> None:
raise NotImplementedError("Subscriptions are not yet implemented.")
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
await self._get_agent(AgentId(type=name, key=namespace))
async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions are not yet implemented.")

View File

@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x96\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb2\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12\x11\n\tdata_type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -38,13 +38,13 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=257
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=304
_globals['_EVENT']._serialized_start=476
_globals['_EVENT']._serialized_end=626
_globals['_EVENT']._serialized_end=654
_globals['_EVENT_METADATAENTRY']._serialized_start=257
_globals['_EVENT_METADATAENTRY']._serialized_end=304
_globals['_REGISTERAGENTTYPE']._serialized_start=628
_globals['_REGISTERAGENTTYPE']._serialized_end=661
_globals['_MESSAGE']._serialized_start=664
_globals['_MESSAGE']._serialized_end=852
_globals['_AGENTRPC']._serialized_start=854
_globals['_AGENTRPC']._serialized_end=917
_globals['_REGISTERAGENTTYPE']._serialized_start=656
_globals['_REGISTERAGENTTYPE']._serialized_end=689
_globals['_MESSAGE']._serialized_start=692
_globals['_MESSAGE']._serialized_end=880
_globals['_AGENTRPC']._serialized_start=882
_globals['_AGENTRPC']._serialized_end=945
# @@protoc_insertion_point(module_scope)

View File

@ -14,8 +14,6 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class AgentId(google.protobuf.message.Message):
"""TODO: update"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
@ -143,24 +141,27 @@ class Event(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
NAMESPACE_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
TOPIC_TYPE_FIELD_NUMBER: builtins.int
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
namespace: builtins.str
type: builtins.str
topic_type: builtins.str
topic_source: builtins.str
data_type: builtins.str
data: builtins.str
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
*,
namespace: builtins.str = ...,
type: builtins.str = ...,
topic_type: builtins.str = ...,
topic_source: builtins.str = ...,
data_type: builtins.str = ...,
data: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "metadata", b"metadata", "namespace", b"namespace", "type", b"type"]) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
global___Event = Event

View File

@ -72,7 +72,6 @@ class ClosureAgent(Agent):
key=self._id.key,
type=self._id.type,
description=self._description,
subscriptions=self._subscriptions,
)
@property

View File

@ -142,12 +142,12 @@ class TypeRoutedAgent(BaseAgent):
message_handler = cast(MessageHandler[Any, Any], handler)
for target_type in message_handler.target_types:
self._handlers[target_type] = message_handler
subscriptions = list(self._handlers.keys())
for message_type in self._handlers.keys():
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
MESSAGE_TYPE_REGISTRY.add_type(message_type)
subscriptions_str = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscriptions]
super().__init__(description, subscriptions_str)
super().__init__(description)
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
key_type: Type[Any] = type(message) # type: ignore

View File

@ -42,5 +42,4 @@ class TypeSubscription(Subscription):
if not self.is_match(topic_id):
raise CantHandleException("TopicId does not match the subscription")
# TODO: Update agentid to reflect agent type and key
return AgentId(type=self._agent_type, key=topic_id.source)

View File

@ -12,7 +12,7 @@ from ._agent_runtime import AgentRuntime
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
from ._subscription import Subscription
from ._topic import TopicId
@ -32,4 +32,5 @@ __all__ = [
"TopicId",
"Subscription",
"MessageContext",
"Serialization",
]

View File

@ -1,8 +1,7 @@
from typing import Sequence, TypedDict
from typing import TypedDict
class AgentMetadata(TypedDict):
type: str
key: str
description: str
subscriptions: Sequence[str]

View File

@ -5,8 +5,9 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_proxy import AgentProxy
from ._cancellation_token import CancellationToken
from ._subscription import Subscription
from ._topic import TopicId
# Undeliverable - error
@ -45,8 +46,8 @@ class AgentRuntime(Protocol):
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
@ -56,23 +57,24 @@ class AgentRuntime(Protocol):
Args:
message (Any): The message to publish.
namespace (str | None, optional): The namespace to publish to. Defaults to None.
topic (TopicId): The topic to publish the message to.
sender (AgentId | None, optional): The agent which sent the message. Defaults to None.
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
Raises:
UndeliverableException: If the message cannot be delivered.
"""
...
async def register(
self,
name: str,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
Args:
name (str): The name of the type agent this factory creates.
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
@ -93,30 +95,6 @@ class AgentRuntime(Protocol):
...
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
"""Get an agent by name and namespace.
Args:
name (str): The name of the agent.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentId: The agent id.
"""
...
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
"""Get a proxy for an agent by name and namespace.
Args:
name (str): The name of the agent.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentProxy: The agent proxy.
"""
...
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
@ -137,46 +115,6 @@ class AgentRuntime(Protocol):
"""
...
async def register_and_get(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId:
"""Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentId: The agent id.
"""
await self.register(name, agent_factory)
return await self.get(name, namespace=namespace)
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy:
"""Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
AgentProxy: The agent proxy.
"""
await self.register(name, agent_factory)
return await self.get_proxy(name, namespace=namespace)
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
@ -227,3 +165,22 @@ class AgentRuntime(Protocol):
state (Mapping[str, Any]): The saved state.
"""
...
async def add_subscription(self, subscription: Subscription) -> None:
"""Add a new subscription that the runtime should fulfill when processing published messages
Args:
subscription (Subscription): The subscription to add
"""
...
async def remove_subscription(self, id: str) -> None:
"""Remove a subscription from the runtime
Args:
id (str): id of the subscription to remove
Raises:
LookupError: If the subscription does not exist
"""
...

View File

@ -1,6 +1,6 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Mapping, Sequence
from typing import Any, Mapping
from ._agent import Agent
from ._agent_id import AgentId
@ -9,20 +9,16 @@ from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._topic import TopicId
class BaseAgent(ABC, Agent):
@property
def metadata(self) -> AgentMetadata:
assert self._id is not None
return AgentMetadata(
key=self._id.key,
type=self._id.type,
description=self._description,
subscriptions=self._subscriptions,
)
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
def __init__(self, description: str) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
@ -36,7 +32,6 @@ class BaseAgent(ABC, Agent):
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description
self._subscriptions = subscriptions
@property
def type(self) -> str:
@ -74,10 +69,11 @@ class BaseAgent(ABC, Agent):
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
cancellation_token: CancellationToken | None = None,
) -> None:
await self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
def save_state(self) -> Mapping[str, Any]:
warnings.warn("save_state not implemented", stacklevel=2)

View File

@ -1,10 +1,11 @@
from typing import Protocol
from typing import Protocol, runtime_checkable
from agnext.core._agent_id import AgentId
from agnext.core import AgentId
from ._topic import TopicId
@runtime_checkable
class Subscription(Protocol):
"""Subscriptions define the topics that an agent is interested in."""
@ -19,6 +20,20 @@ class Subscription(Protocol):
"""
...
def __eq__(self, other: object) -> bool:
"""Check if two subscriptions are equal.
Args:
other (object): Other subscription to compare against.
Returns:
bool: True if the subscriptions are equal, False otherwise.
"""
if not isinstance(other, Subscription):
return False
return self.id == other.id
def is_match(self, topic_id: TopicId) -> bool:
"""Check if a given topic_id matches the subscription.

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
@dataclass
@dataclass(eq=True, frozen=True)
class TopicId:
type: str
"""Type of the event that this topic_id contains. Adhere's to the cloud event spec.

View File

@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import LedgerOrchestrator
from team_one.agents.user_proxy import UserProxy
@ -15,18 +16,22 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
coder = AgentProxy(AgentId("Coder", "default"), runtime)
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
executor = AgentProxy(AgentId("Executor", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(description="The current user interacting with you."),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
# TODO: doesn't work for more than default key
await runtime.register(
"orchestrator",
lambda: LedgerOrchestrator(

View File

@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@ -15,17 +16,20 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
coder = AgentProxy(AgentId("Coder", "default"), runtime)
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
executor = AgentProxy(AgentId("Executor", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
lambda: UserProxy(description="The current user interacting with you."),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))

View File

@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.file_surfer import FileSurfer
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@ -18,14 +19,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
file_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"file_surfer",
lambda: FileSurfer(model_client=client),
)
user_proxy = await runtime.register_and_get_proxy(
file_surfer = AgentProxy(AgentId("file_surfer", "default"), runtime)
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))

View File

@ -4,6 +4,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import UserMessage
from agnext.core import AgentId, AgentProxy, TopicId
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.reflex_agents import ReflexAgent
from team_one.messages import BroadcastMessage
@ -13,14 +14,19 @@ from team_one.utils import LogHandler
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
await runtime.register("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake1 = AgentProxy(AgentId("fake_agent_1", "default"), runtime)
await runtime.register("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake2 = AgentProxy(AgentId("fake_agent_2", "default"), runtime)
await runtime.register("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
fake3 = AgentProxy(AgentId("fake_agent_3", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
task_message = UserMessage(content="Test Message", source="User")
run_context = runtime.start()
await runtime.publish_message(BroadcastMessage(task_message), namespace="default")
await runtime.publish_message(BroadcastMessage(task_message), topic_id=TopicId("default", "default"))
await run_context.stop_when_idle()

View File

@ -4,6 +4,7 @@ import logging
# from typing import Any, Dict, List, Tuple, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@ -19,14 +20,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=client),
)
user_proxy = await runtime.register_and_get_proxy(
coder = AgentProxy(AgentId("Coder", "default"), runtime)
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))

View File

@ -4,6 +4,7 @@ import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@ -21,15 +22,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))

View File

@ -5,7 +5,7 @@ from agnext.components.models import (
LLMMessage,
UserMessage,
)
from agnext.core import CancellationToken, MessageContext
from agnext.core import CancellationToken, MessageContext, TopicId
from team_one.messages import (
BroadcastMessage,
@ -45,7 +45,8 @@ class BaseWorker(TeamOneBaseAgent):
self._chat_history.append(assistant_message)
user_message = UserMessage(content=response, source=self.metadata["type"])
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""

View File

@ -2,7 +2,7 @@ import json
from typing import Any, Dict, List, Optional
from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from agnext.core import AgentProxy
from agnext.core import AgentProxy, TopicId
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger
@ -248,8 +248,10 @@ class LedgerOrchestrator(BaseOrchestrator):
synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
)
logger.info(
@ -319,14 +321,17 @@ class LedgerOrchestrator(BaseOrchestrator):
# Reset everyone, then rebroadcast the new plan
self._chat_history = [self._chat_history[0]]
await self.publish_message(ResetMessage())
topic_id = TopicId("default", self.id.key)
await self.publish_message(ResetMessage(), topic_id=topic_id)
# Send everyone the NEW plan
synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
)
logger.info(
@ -351,8 +356,10 @@ class LedgerOrchestrator(BaseOrchestrator):
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
self._chat_history.append(assistant_message) # My copy
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=user_message, request_halt=False)
BroadcastMessage(content=user_message, request_halt=False),
topic_id=topic_id,
) # Send to everyone else
return agent

View File

@ -1,6 +1,6 @@
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import UserMessage
from agnext.core import MessageContext
from agnext.core import MessageContext, TopicId
from ..messages import BroadcastMessage, RequestReplyMessage
@ -22,5 +22,6 @@ class ReflexAgent(TypeRoutedAgent):
content=f"Hello, world from {name}!",
source=name,
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(response_message))
await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id)

View File

@ -7,11 +7,14 @@ from math import ceil
import asyncio
import pytest
from agnext.core import AgentId
from agnext.core import AgentProxy
pytest_plugins = ('pytest_asyncio',)
from json import dumps
from team_one.utils import (
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON,
create_completion_client_from_env
)
@ -96,13 +99,14 @@ async def test_web_surfer() -> None:
# Register agents.
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentId("WebSurfer", "default")
run_context = runtime.start()
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
# Test some basic navigations
@ -138,7 +142,7 @@ async def test_web_surfer() -> None:
tool_resp = await make_browser_request(actual_surfer, TOOL_PAGE_DOWN)
assert (
f"The viewport shows {viewport_percentage}% of the webpage, and is positioned at the bottom of the page" in tool_resp
)
)
# Test Q&A and summarization -- we don't have a key so we expect it to fail #(but it means the code path is correct)
with pytest.raises(AuthenticationError):
@ -160,15 +164,17 @@ async def test_web_surfer_oai() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
run_context = runtime.start()
@ -220,10 +226,12 @@ async def test_web_surfer_bing() -> None:
# Register agents.
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
run_context = runtime.start()
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
@ -235,7 +243,7 @@ async def test_web_surfer_bing() -> None:
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:url"]
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:title"]
assert f"I typed '{BING_QUERY}' into the browser search bar." in tool_resp.replace("\\","")
tool_resp = await make_browser_request(actual_surfer, TOOL_WEB_SEARCH, {"query": BING_QUERY + " Wikipedia"})
markdown = await actual_surfer._get_page_markdown() # type: ignore
assert "https://en.wikipedia.org/wiki/" in markdown

View File

@ -6,16 +6,18 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentId, CancellationToken
from agnext.core import MessageContext
from agnext.core import AgentInstantiationContext
@dataclass
class MessageType:
...
class MessageType: ...
# Note for future reader:
# To do cancellation, only the token should be interacted with as a user
# If you cancel a future, it may not work as you expect.
class LongRunningAgent(TypeRoutedAgent):
def __init__(self) -> None:
super().__init__("A long running agent")
@ -34,6 +36,7 @@ class LongRunningAgent(TypeRoutedAgent):
self.cancelled = True
raise
class NestingLongRunningAgent(TypeRoutedAgent):
def __init__(self, nested_agent: AgentId) -> None:
super().__init__("A nesting long running agent")
@ -58,9 +61,10 @@ class NestingLongRunningAgent(TypeRoutedAgent):
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
await runtime.register("long_running", LongRunningAgent)
agent_id = AgentId("long_running", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
@ -74,21 +78,25 @@ async def test_cancellation_with_token() -> None:
await response
assert response.done()
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled
@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
await runtime.register("long_running", LongRunningAgent)
await runtime.register(
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
long_running_id = AgentId("long_running", key="default")
nested_id = AgentId("nested", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
@ -101,22 +109,29 @@ async def test_nested_cancellation_only_outer_called() -> None:
await response
assert response.done()
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
assert long_running_agent.called is False
assert long_running_agent.cancelled is False
@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = await runtime.register_and_get("long_running", LongRunningAgent )
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
await runtime.register("long_running", LongRunningAgent)
await runtime.register(
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
long_running_id = AgentId("long_running", key="default")
nested_id = AgentId("nested", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
@ -131,9 +146,9 @@ async def test_nested_cancellation_inner_called() -> None:
await response
assert response.done()
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.core import AgentRuntime, AgentId
from agnext.components import ClosureAgent
@ -13,6 +14,7 @@ from agnext.components import ClosureAgent
import asyncio
from agnext.core import MessageContext
from agnext.core import TopicId
@dataclass
class Message:
@ -30,11 +32,15 @@ async def test_register_receives_publish() -> None:
key = id.key
await queue.put((key, message.content))
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
await runtime.register("name", lambda: ClosureAgent("my_agent", log_message))
await runtime.add_subscription(TypeSubscription("default", "name"))
topic_id = TopicId("default", "default")
run_context = runtime.start()
await runtime.publish_message(Message("first message"), namespace="default")
await runtime.publish_message(Message("second message"), namespace="default")
await runtime.publish_message(Message("third message"), namespace="default")
await runtime.publish_message(Message("first message"), topic_id=topic_id)
await runtime.publish_message(Message("second message"), topic_id=topic_id)
await runtime.publish_message(Message("third message"), topic_id=topic_id)
await run_context.stop_when_idle()

View File

@ -19,7 +19,8 @@ async def test_intervention_count_messages() -> None:
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = await runtime.register_and_get("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
loopback = AgentId("name", key="default")
run_context = runtime.start()
_response = await runtime.send_message(MessageType(), recipient=loopback)
@ -40,7 +41,8 @@ async def test_intervention_drop_send() -> None:
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = await runtime.register_and_get("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
loopback = AgentId("name", key="default")
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@ -62,7 +64,8 @@ async def test_intervention_drop_response() -> None:
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = await runtime.register_and_get("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
loopback = AgentId("name", key="default")
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@ -84,15 +87,16 @@ async def test_intervention_raise_exception_on_send() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = await runtime.register_and_get("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
loopback = AgentId("name", key="default")
run_context = runtime.start()
with pytest.raises(InterventionException):
_response = await runtime.send_message(MessageType(), recipient=long_running)
_response = await runtime.send_message(MessageType(), recipient=loopback)
await run_context.stop()
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert long_running_agent.num_calls == 0
@pytest.mark.asyncio
@ -108,12 +112,13 @@ async def test_intervention_raise_exception_on_respond() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = await runtime.register_and_get("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
loopback = AgentId("name", key="default")
run_context = runtime.start()
with pytest.raises(InterventionException):
_response = await runtime.send_message(MessageType(), recipient=long_running)
_response = await runtime.send_message(MessageType(), recipient=loopback)
await run_context.stop()
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

View File

@ -1,6 +1,8 @@
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.core import AgentId, AgentInstantiationContext
from agnext.core import TopicId
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
@ -15,13 +17,12 @@ async def test_agent_names_must_be_unique() -> None:
assert agent.id == id
return agent
agent1 = await runtime.register_and_get("name1", agent_factory)
assert agent1 == AgentId("name1", "default")
await runtime.register("name1", agent_factory)
with pytest.raises(ValueError):
_agent1 = await runtime.register_and_get("name1", NoopAgent)
await runtime.register("name1", NoopAgent)
_agent1 = await runtime.register_and_get("name3", NoopAgent)
await runtime.register("name3", NoopAgent)
@pytest.mark.asyncio
@ -30,16 +31,19 @@ async def test_register_receives_publish() -> None:
await runtime.register("name", LoopbackAgent)
run_context = runtime.start()
await runtime.publish_message(MessageType(), namespace="default")
await runtime.add_subscription(TypeSubscription("default", "name"))
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await run_context.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent)
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@ -56,17 +60,19 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
run_context = runtime.start()
# Publish messages
topic_id = TopicId("default", "default")
for _ in range(num_initial_messages):
await runtime.publish_message(CascadingMessageType(round=1), namespace="default")
await runtime.publish_message(CascadingMessageType(round=1), topic_id)
# Process until idle.
await run_context.stop_when_idle()
# Check that each agent received the correct number of messages.
for i in range(num_agents):
agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent)
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
import pytest
from agnext.core._serialization import Serialization
from agnext.core import Serialization
class PydanticMessage(BaseModel):
message: str

View File

@ -1,19 +1,16 @@
from typing import Any, Mapping, Sequence
from typing import Any, Mapping
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.core import BaseAgent, MessageContext
from agnext.core import AgentId
class StatefulAgent(BaseAgent):
def __init__(self) -> None:
super().__init__("A stateful agent", [])
super().__init__("A stateful agent")
self.state = 0
@property
def subscriptions(self) -> Sequence[type]:
return []
async def on_message(self, message: Any, ctx: MessageContext) -> None:
raise NotImplementedError
@ -28,7 +25,8 @@ class StatefulAgent(BaseAgent):
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
await runtime.register("name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
agent1.state = 1
@ -46,7 +44,8 @@ async def test_agent_can_save_state() -> None:
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
await runtime.register("name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
agent1.state = 1
@ -55,7 +54,8 @@ async def test_runtime_can_save_state() -> None:
runtime_state = await runtime.save_state()
runtime2 = SingleThreadedAgentRuntime()
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
await runtime2.register("name1", StatefulAgent)
agent2_id = AgentId("name1", key="default")
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
await runtime2.load_state(runtime_state)

View File

@ -13,6 +13,7 @@ from agnext.components.tool_agent import (
)
from agnext.components.tools import FunctionTool
from agnext.core import CancellationToken
from agnext.core import AgentId
def _pass_function(input: str) -> str:
@ -31,7 +32,7 @@ async def _async_sleep_function(input: str) -> str:
@pytest.mark.asyncio
async def test_tool_agent() -> None:
runtime = SingleThreadedAgentRuntime()
agent = await runtime.register_and_get(
await runtime.register(
"tool_agent",
lambda: ToolAgent(
description="Tool agent",
@ -42,6 +43,7 @@ async def test_tool_agent() -> None:
],
),
)
agent = AgentId("tool_agent", "default")
run = runtime.start()
# Test pass function

View File

@ -38,11 +38,12 @@ class CascadingAgent(TypeRoutedAgent):
self.num_calls += 1
if message.round == self.max_rounds:
return
await self.publish_message(CascadingMessageType(round=message.round + 1))
assert ctx.topic_id is not None
await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=ctx.topic_id)
class NoopAgent(BaseAgent):
def __init__(self) -> None:
super().__init__("A no op agent", [])
super().__init__("A no op agent")
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
raise NotImplementedError