Migrate to using default sub/topic (#403)

This commit is contained in:
Jack Gerrits 2024-08-26 10:30:28 -04:00 committed by GitHub
parent d7ae2038fb
commit dbb35fc335
23 changed files with 491 additions and 526 deletions

View File

@ -295,18 +295,15 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import RoutedAgent, message_handler\n", "from agnext.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler\n",
"from agnext.core import MessageContext, TopicId\n", "from agnext.core import MessageContext\n",
"\n", "\n",
"\n", "\n",
"class BroadcastingAgent(RoutedAgent):\n", "class BroadcastingAgent(RoutedAgent):\n",
" @message_handler\n", " @message_handler\n",
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
" # Publish a message to all agents in the same namespace.\n", " # Publish a message to all agents in the same namespace.\n",
" assert ctx.topic_id is not None\n", " await self.publish_message(Message(f\"Publishing a message: {message.content}!\"), topic_id=DefaultTopicId())\n",
" await self.publish_message(\n",
" Message(f\"Publishing a message: {message.content}!\"), topic_id=TopicId(\"deafult\", self.id.key)\n",
" )\n",
"\n", "\n",
"\n", "\n",
"class ReceivingAgent(RoutedAgent):\n", "class ReceivingAgent(RoutedAgent):\n",
@ -337,13 +334,11 @@
} }
], ],
"source": [ "source": [
"from agnext.components import TypeSubscription\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n", "runtime = SingleThreadedAgentRuntime()\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", "await runtime.register(\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n", " \"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"), lambda: [DefaultSubscription()]\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n", ")\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n", "await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"), lambda: [DefaultSubscription()])\n",
"runtime.start()\n", "runtime.start()\n",
"await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n", "await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n",
"await runtime.stop()" "await runtime.stop()"
@ -376,12 +371,12 @@
"# Replace send_message with publish_message in the above example.\n", "# Replace send_message with publish_message in the above example.\n",
"\n", "\n",
"runtime = SingleThreadedAgentRuntime()\n", "runtime = SingleThreadedAgentRuntime()\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", "await runtime.register(\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n", " \"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"), lambda: [DefaultSubscription()]\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n", ")\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n", "await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"), lambda: [DefaultSubscription()])\n",
"runtime.start()\n", "runtime.start()\n",
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=TopicId(\"default\", \"default\"))\n", "await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=DefaultTopicId())\n",
"await runtime.stop_when_idle()" "await runtime.stop_when_idle()"
] ]
}, },

View File

@ -3,6 +3,7 @@ import json
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
from agnext.components import ( from agnext.components import (
DefaultTopicId,
FunctionCall, FunctionCall,
RoutedAgent, RoutedAgent,
message_handler, message_handler,
@ -110,9 +111,8 @@ class ChatCompletionAgent(RoutedAgent):
# Generate a response. # Generate a response.
response = await self._generate_response(message.response_format, ctx) response = await self._generate_response(message.response_format, ctx)
assert ctx.topic_id is not None
# Publish the response. # Publish the response.
await self.publish_message(response, topic_id=ctx.topic_id) await self.publish_message(response, topic_id=DefaultTopicId())
@message_handler() @message_handler()
async def on_tool_call_message( async def on_tool_call_message(

View File

@ -2,6 +2,7 @@ from typing import Literal
import openai import openai
from agnext.components import ( from agnext.components import (
DefaultTopicId,
Image, Image,
RoutedAgent, RoutedAgent,
message_handler, message_handler,
@ -57,8 +58,7 @@ class ImageGenerationAgent(RoutedAgent):
image is published as a MultiModalMessage.""" image is published as a MultiModalMessage."""
response = await self._generate_response(ctx.cancellation_token) response = await self._generate_response(ctx.cancellation_token)
assert ctx.topic_id is not None await self.publish_message(response, topic_id=DefaultTopicId())
await self.publish_message(response, topic_id=ctx.topic_id)
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage: async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
messages = await self._memory.get_messages() messages = await self._memory.get_messages()

View File

@ -1,7 +1,7 @@
from typing import Any, Callable, List, Mapping from typing import Any, Callable, List, Mapping
import openai import openai
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import ( from agnext.core import (
CancellationToken, CancellationToken,
MessageContext, # type: ignore MessageContext, # type: ignore
@ -80,8 +80,7 @@ class OpenAIAssistantAgent(RoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method generates a response and publishes it.""" """Handle a publish now message. This method generates a response and publishes it."""
response = await self._generate_response(message.response_format, ctx.cancellation_token) response = await self._generate_response(message.response_format, ctx.cancellation_token)
assert ctx.topic_id is not None await self.publish_message(response, DefaultTopicId())
await self.publish_message(response, ctx.topic_id)
async def _generate_response( async def _generate_response(
self, self,

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import MessageContext from agnext.core import MessageContext
from ..types import PublishNow, TextMessage from ..types import PublishNow, TextMessage
@ -23,8 +23,9 @@ class UserProxyAgent(RoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method prompts the user for input, then publishes it.""" """Handle a publish now message. This method prompts the user for input, then publishes it."""
user_input = await self.get_user_input(self._user_input_prompt) user_input = await self.get_user_input(self._user_input_prompt)
assert ctx.topic_id is not None await self.publish_message(
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id) TextMessage(content=user_input, source=self.metadata["type"]), topic_id=DefaultTopicId()
)
async def get_user_input(self, prompt: str) -> str: async def get_user_input(self, prompt: str) -> str:
"""Get user input from the console. Override this method to customize how user input is retrieved.""" """Get user input from the console. Override this method to customize how user input is retrieved."""

View File

@ -17,7 +17,7 @@ from dataclasses import dataclass
from typing import List from typing import List
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ( from agnext.components.models import (
AssistantMessage, AssistantMessage,
@ -74,7 +74,7 @@ class ChatCompletionAgent(RoutedAgent):
if ctx.topic_id is not None: if ctx.topic_id is not None:
await self.publish_message( await self.publish_message(
Message(content=response.content, source=self.metadata["type"]), topic_id=ctx.topic_id Message(content=response.content, source=self.metadata["type"]), topic_id=DefaultTopicId()
) )

View File

@ -12,7 +12,7 @@ from typing import List
import aiofiles import aiofiles
import openai import openai
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import AgentId, AgentRuntime, MessageContext from agnext.core import AgentId, AgentRuntime, MessageContext
from openai import AsyncAssistantEventHandler from openai import AsyncAssistantEventHandler
from openai.types.beta.thread import ToolResources from openai.types.beta.thread import ToolResources
@ -109,9 +109,8 @@ class UserProxyAgent(RoutedAgent):
return return
else: else:
# Publish user input and exit handler. # Publish user input and exit handler.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id TextMessage(content=user_input, source=self.metadata["type"]), topic_id=DefaultTopicId()
) )
return return

View File

@ -6,7 +6,8 @@ import os
import sys import sys
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._default_subscription import DefaultSubscription
from agnext.components.memory import ChatMemory from agnext.components.memory import ChatMemory
from agnext.components.models import ChatCompletionClient, SystemMessage from agnext.components.models import ChatCompletionClient, SystemMessage
from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime
@ -76,9 +77,8 @@ Use the following JSON format to provide your thought on the latest message and
# Publish the response if needed. # Publish the response if needed.
if respond is True or str(respond).lower().strip() == "true": if respond is True or str(respond).lower().strip() == "true":
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
TextMessage(source=self.metadata["type"], content=str(response)), topic_id=ctx.topic_id TextMessage(source=self.metadata["type"], content=str(response)), topic_id=DefaultTopicId()
) )
@ -98,6 +98,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
description="The user in the chat room.", description="The user in the chat room.",
app=app, app=app,
), ),
lambda: [DefaultSubscription()],
) )
await runtime.register( await runtime.register(
"Alice", "Alice",
@ -108,6 +109,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10), memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
), ),
lambda: [DefaultSubscription()],
) )
alice = AgentProxy(AgentId("Alice", "default"), runtime) alice = AgentProxy(AgentId("Alice", "default"), runtime)
await runtime.register( await runtime.register(
@ -119,6 +121,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10), memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
), ),
lambda: [DefaultSubscription()],
) )
bob = AgentProxy(AgentId("Bob", "default"), runtime) bob = AgentProxy(AgentId("Bob", "default"), runtime)
await runtime.register( await runtime.register(
@ -130,6 +133,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10), memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
), ),
lambda: [DefaultSubscription()],
) )
charlie = AgentProxy(AgentId("Charlie", "default"), runtime) charlie = AgentProxy(AgentId("Charlie", "default"), runtime)
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants: app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:

View File

@ -10,10 +10,11 @@ import sys
from typing import Annotated, Literal from typing import Annotated, Literal
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import DefaultTopicId
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import SystemMessage from agnext.components.models import SystemMessage
from agnext.components.tools import FunctionTool from agnext.components.tools import FunctionTool
from agnext.core import AgentInstantiationContext, AgentRuntime, TopicId from agnext.core import AgentInstantiationContext, AgentRuntime
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
from chess import piece_name as get_piece_name from chess import piece_name as get_piece_name
@ -214,9 +215,7 @@ async def main() -> None:
await chess_game(runtime) await chess_game(runtime)
runtime.start() runtime.start()
# Publish an initial message to trigger the group chat manager to start orchestration. # Publish an initial message to trigger the group chat manager to start orchestration.
await runtime.publish_message( await runtime.publish_message(TextMessage(content="Game started.", source="System"), topic_id=DefaultTopicId())
TextMessage(content="Game started.", source="System"), topic_id=TopicId("default", "default")
)
await runtime.stop_when_idle() await runtime.stop_when_idle()

View File

@ -4,7 +4,7 @@ import random
import sys import sys
from asyncio import Future from asyncio import Future
from agnext.components import Image, RoutedAgent, message_handler from agnext.components import DefaultTopicId, Image, RoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken from agnext.core import AgentRuntime, CancellationToken
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
@ -13,7 +13,6 @@ from textual_imageview.viewer import ImageViewer
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agnext.core import TopicId
from common.types import ( from common.types import (
MultiModalMessage, MultiModalMessage,
PublishNow, PublishNow,
@ -136,9 +135,7 @@ class TextualChatApp(App): # type: ignore
chat_messages.query("#typing").remove() chat_messages.query("#typing").remove()
# Publish the user message to the runtime. # Publish the user message to the runtime.
await self._runtime.publish_message( await self._runtime.publish_message(
# TODO fix hard coded topic_id TextMessage(source=self._user_name, content=user_input), topic_id=DefaultTopicId()
TextMessage(source=self._user_name, content=user_input),
topic_id=TopicId("default", "default"),
) )
async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore

View File

@ -1,4 +1,4 @@
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient from agnext.components.models import ChatCompletionClient
from agnext.components.models._types import SystemMessage from agnext.components.models._types import SystemMessage
from agnext.core import MessageContext from agnext.core import MessageContext
@ -30,7 +30,6 @@ class AuditAgent(RoutedAgent):
assert isinstance(completion.content, str) assert isinstance(completion.content, str)
if "NOTFORME" in completion.content: if "NOTFORME" in completion.content:
return return
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=ctx.topic_id AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=DefaultTopicId()
) )

View File

@ -3,6 +3,7 @@ from typing import Literal
import openai import openai
from agnext.components import ( from agnext.components import (
DefaultTopicId,
RoutedAgent, RoutedAgent,
message_handler, message_handler,
) )
@ -33,9 +34,8 @@ class GraphicDesignerAgent(RoutedAgent):
image_uri = response.data[0].url image_uri = response.data[0].url
logger.info(f"Generated image for article. Got response: '{image_uri}'") logger.info(f"Generated image for article. Got response: '{image_uri}'")
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=ctx.topic_id GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=DefaultTopicId()
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to generate image for article. Error: {e}") logger.error(f"Failed to generate image for article. Error: {e}")

View File

@ -2,8 +2,8 @@ import asyncio
import os import os
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import Image, RoutedAgent, message_handler from agnext.components import DefaultTopicId, Image, RoutedAgent, message_handler
from agnext.core import MessageContext, TopicId from agnext.core import MessageContext
from app import build_app from app import build_app
from dotenv import load_dotenv from dotenv import load_dotenv
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
@ -34,15 +34,13 @@ async def main() -> None:
runtime.start() runtime.start()
topic_id = TopicId("default", "default")
await runtime.publish_message( await runtime.publish_message(
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=topic_id AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=DefaultTopicId()
) )
await runtime.publish_message( await runtime.publish_message(
ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"), ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"),
topic_id=topic_id, topic_id=DefaultTopicId(),
) )
await runtime.stop_when_idle() await runtime.stop_when_idle()

View File

@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
from agnext.components.models import ( from agnext.components.models import (
@ -102,12 +102,11 @@ Reply "TERMINATE" in the end when everything is done."""
AssistantMessage(content=response.content, source=self.metadata["type"]) AssistantMessage(content=response.content, source=self.metadata["type"])
) )
assert ctx.topic_id is not None
# Publish the code execution task. # Publish the code execution task.
await self.publish_message( await self.publish_message(
CodeExecutionTask(content=response.content, session_id=session_id), CodeExecutionTask(content=response.content, session_id=session_id),
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
@message_handler @message_handler
@ -124,11 +123,10 @@ Reply "TERMINATE" in the end when everything is done."""
if "TERMINATE" in response.content: if "TERMINATE" in response.content:
# If the task is completed, publish a message with the completion content. # If the task is completed, publish a message with the completion content.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
TaskCompletion(content=response.content), TaskCompletion(content=response.content),
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
print("--------------------") print("--------------------")
print("Task completed:") print("Task completed:")
@ -136,11 +134,10 @@ Reply "TERMINATE" in the end when everything is done."""
return return
# Publish the code execution task. # Publish the code execution task.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
CodeExecutionTask(content=response.content, session_id=message.session_id), CodeExecutionTask(content=response.content, session_id=message.session_id),
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
@ -157,13 +154,12 @@ class Executor(RoutedAgent):
code_blocks = self._extract_code_blocks(message.content) code_blocks = self._extract_code_blocks(message.content)
if not code_blocks: if not code_blocks:
# If no code block is found, publish a message with an error. # If no code block is found, publish a message with an error.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
CodeExecutionTaskResult( CodeExecutionTaskResult(
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
), ),
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
return return
# Execute code blocks. # Execute code blocks.
@ -171,11 +167,10 @@ class Executor(RoutedAgent):
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
) )
# Publish the code execution result. # Publish the code execution result.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id), CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]: def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:

View File

@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List, Union from typing import Dict, List, Union
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ( from agnext.components.models import (
AssistantMessage, AssistantMessage,
@ -112,14 +112,13 @@ Please review the code and provide feedback.
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()]) review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
approved = review["approval"].lower().strip() == "approve" approved = review["approval"].lower().strip() == "approve"
# Publish the review result. # Publish the review result.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
CodeReviewResult( CodeReviewResult(
review=review_text, review=review_text,
approved=approved, approved=approved,
session_id=message.session_id, session_id=message.session_id,
), ),
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
@ -183,10 +182,9 @@ Code: <Your code>
# Store the code review task in the session memory. # Store the code review task in the session memory.
self._session_memory[session_id].append(code_review_task) self._session_memory[session_id].append(code_review_task)
# Publish a code review task. # Publish a code review task.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
code_review_task, code_review_task,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
@message_handler @message_handler
@ -201,14 +199,13 @@ Code: <Your code>
# Check if the code is approved. # Check if the code is approved.
if message.approved: if message.approved:
# Publish the code writing result. # Publish the code writing result.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
CodeWritingResult( CodeWritingResult(
code=review_request.code, code=review_request.code,
task=review_request.code_writing_task, task=review_request.code_writing_task,
review=message.review, review=message.review,
), ),
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
print("Code Writing Result:") print("Code Writing Result:")
print("-" * 80) print("-" * 80)
@ -247,10 +244,9 @@ Code: <Your code>
# Store the code review task in the session memory. # Store the code review task in the session memory.
self._session_memory[message.session_id].append(code_review_task) self._session_memory[message.session_id].append(code_review_task)
# Publish a new code review task. # Publish a new code review task.
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
code_review_task, code_review_task,
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
def _extract_code_block(self, markdown_text: str) -> Union[str, None]: def _extract_code_block(self, markdown_text: str) -> Union[str, None]:

View File

@ -18,7 +18,7 @@ from dataclasses import dataclass
from typing import List from typing import List
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components.models import ( from agnext.components.models import (
AssistantMessage, AssistantMessage,
ChatCompletionClient, ChatCompletionClient,
@ -69,8 +69,7 @@ class RoundRobinGroupChatManager(RoutedAgent):
self._round_count += 1 self._round_count += 1
if self._round_count > self._num_rounds * len(self._participants): if self._round_count > self._num_rounds * len(self._participants):
# End the conversation after the specified number of rounds. # End the conversation after the specified number of rounds.
assert ctx.topic_id is not None await self.publish_message(Termination(), DefaultTopicId())
await self.publish_message(Termination(), ctx.topic_id)
return return
# Send a request to speak message to the selected speaker. # Send a request to speak message to the selected speaker.
await self.send_message(RequestToSpeak(), speaker) await self.send_message(RequestToSpeak(), speaker)
@ -107,8 +106,7 @@ class GroupChatParticipant(RoutedAgent):
assert isinstance(response.content, str) assert isinstance(response.content, str)
speech = Message(content=response.content, source=self.metadata["type"]) speech = Message(content=response.content, source=self.metadata["type"])
self._memory.append(speech) self._memory.append(speech)
assert ctx.topic_id is not None await self.publish_message(speech, topic_id=DefaultTopicId())
await self.publish_message(speech, topic_id=ctx.topic_id)
async def main() -> None: async def main() -> None:

View File

@ -15,7 +15,7 @@ from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import MessageContext from agnext.core import MessageContext
@ -68,8 +68,7 @@ class ReferenceAgent(RoutedAgent):
response = await self._model_client.create(self._system_messages + [task_message]) response = await self._model_client.create(self._system_messages + [task_message])
assert isinstance(response.content, str) assert isinstance(response.content, str)
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content) task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
assert ctx.topic_id is not None await self.publish_message(task_result, topic_id=DefaultTopicId())
await self.publish_message(task_result, topic_id=ctx.topic_id)
class AggregatorAgent(RoutedAgent): class AggregatorAgent(RoutedAgent):
@ -93,8 +92,7 @@ class AggregatorAgent(RoutedAgent):
"""Handle a task message. This method publishes the task to the reference agents.""" """Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task) ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
assert ctx.topic_id is not None await self.publish_message(ref_task, topic_id=DefaultTopicId())
await self.publish_message(ref_task, topic_id=ctx.topic_id)
@message_handler @message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None: async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
@ -108,8 +106,7 @@ class AggregatorAgent(RoutedAgent):
) )
assert isinstance(response.content, str) assert isinstance(response.content, str)
task_result = AggregatorTaskResult(result=response.content) task_result = AggregatorTaskResult(result=response.content)
assert ctx.topic_id is not None await self.publish_message(task_result, topic_id=DefaultTopicId())
await self.publish_message(task_result, topic_id=ctx.topic_id)
self._session_results.pop(message.session_id) self._session_results.pop(message.session_id)
print(f"Aggregator result: {response.content}") print(f"Aggregator result: {response.content}")

View File

@ -40,7 +40,7 @@ from dataclasses import dataclass
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ( from agnext.components.models import (
AssistantMessage, AssistantMessage,
@ -165,11 +165,10 @@ class MathSolver(RoutedAgent):
answer = match.group(1) answer = match.group(1)
# Increment the counter. # Increment the counter.
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1 self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
assert ctx.topic_id is not None
if self._counters[message.session_id] == self._max_round: if self._counters[message.session_id] == self._max_round:
# If the counter reaches the maximum round, publishes a final response. # If the counter reaches the maximum round, publishes a final response.
await self.publish_message( await self.publish_message(
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=DefaultTopicId()
) )
else: else:
# Publish intermediate response. # Publish intermediate response.
@ -181,7 +180,7 @@ class MathSolver(RoutedAgent):
session_id=message.session_id, session_id=message.session_id,
round=self._counters[message.session_id], round=self._counters[message.session_id],
), ),
topic_id=ctx.topic_id, topic_id=DefaultTopicId(),
) )
@ -199,9 +198,8 @@ class MathAggregator(RoutedAgent):
"in the form of {{answer}}, at the end of your response." "in the form of {{answer}}, at the end of your response."
) )
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
assert ctx.topic_id is not None
await self.publish_message( await self.publish_message(
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=DefaultTopicId()
) )
@message_handler @message_handler
@ -212,8 +210,7 @@ class MathAggregator(RoutedAgent):
answers = [resp.answer for resp in self._responses[message.session_id]] answers = [resp.answer for resp in self._responses[message.session_id]]
majority_answer = max(set(answers), key=answers.count) majority_answer = max(set(answers), key=answers.count)
# Publish the aggregated response. # Publish the aggregated response.
assert ctx.topic_id is not None await self.publish_message(Answer(content=majority_answer), topic_id=DefaultTopicId())
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
# Clear the responses. # Clear the responses.
self._responses.pop(message.session_id) self._responses.pop(message.session_id)
print(f"Aggregated answer: {majority_answer}") print(f"Aggregated answer: {majority_answer}")

View File

@ -20,7 +20,7 @@ from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, RoutedAgent, message_handler from agnext.components import DefaultTopicId, FunctionCall, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import LocalCommandLineCodeExecutor from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import ( from agnext.components.models import (
@ -90,8 +90,7 @@ class ToolExecutorAgent(RoutedAgent):
session_id=message.session_id, session_id=message.session_id,
result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id), result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id),
) )
assert ctx.topic_id is not None await self.publish_message(task_result, topic_id=DefaultTopicId())
await self.publish_message(task_result, topic_id=ctx.topic_id)
class ToolUseAgent(RoutedAgent): class ToolUseAgent(RoutedAgent):
@ -129,8 +128,7 @@ class ToolUseAgent(RoutedAgent):
if isinstance(response.content, str): if isinstance(response.content, str):
# If the response is a string, just publish the response. # If the response is a string, just publish the response.
response_message = AgentResponse(content=response.content) response_message = AgentResponse(content=response.content)
assert ctx.topic_id is not None await self.publish_message(response_message, topic_id=DefaultTopicId())
await self.publish_message(response_message, topic_id=ctx.topic_id)
print(f"AI Response: {response.content}") print(f"AI Response: {response.content}")
return return
@ -143,8 +141,7 @@ class ToolUseAgent(RoutedAgent):
for function_call in response.content: for function_call in response.content:
task = ToolExecutionTask(session_id=session_id, function_call=function_call) task = ToolExecutionTask(session_id=session_id, function_call=function_call)
self._tool_counter[session_id] += 1 self._tool_counter[session_id] += 1
assert ctx.topic_id is not None await self.publish_message(task, topic_id=DefaultTopicId())
await self.publish_message(task, topic_id=ctx.topic_id)
@message_handler @message_handler
async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None: async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None:
@ -170,11 +167,10 @@ class ToolUseAgent(RoutedAgent):
self._sessions[message.session_id].append( self._sessions[message.session_id].append(
AssistantMessage(content=response.content, source=self.metadata["type"]) AssistantMessage(content=response.content, source=self.metadata["type"])
) )
assert ctx.topic_id is not None
# If the response is a string, just publish the response. # If the response is a string, just publish the response.
if isinstance(response.content, str): if isinstance(response.content, str):
response_message = AgentResponse(content=response.content) response_message = AgentResponse(content=response.content)
await self.publish_message(response_message, topic_id=ctx.topic_id) await self.publish_message(response_message, topic_id=DefaultTopicId())
self._tool_results.pop(message.session_id) self._tool_results.pop(message.session_id)
self._tool_counter.pop(message.session_id) self._tool_counter.pop(message.session_id)
print(f"AI Response: {response.content}") print(f"AI Response: {response.content}")
@ -185,7 +181,7 @@ class ToolUseAgent(RoutedAgent):
for function_call in response.content: for function_call in response.content:
task = ToolExecutionTask(session_id=message.session_id, function_call=function_call) task = ToolExecutionTask(session_id=message.session_id, function_call=function_call)
self._tool_counter[message.session_id] += 1 self._tool_counter[message.session_id] += 1
await self.publish_message(task, topic_id=ctx.topic_id) await self.publish_message(task, topic_id=DefaultTopicId())
async def main() -> None: async def main() -> None:

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime from agnext.application import WorkerAgentRuntime
from agnext.components import RoutedAgent, message_handler from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription from agnext.components._type_subscription import TypeSubscription
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId
@ -40,13 +40,11 @@ class ReceiveAgent(RoutedAgent):
@message_handler @message_handler
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None: async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
assert ctx.topic_id is not None await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=DefaultTopicId())
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=ctx.topic_id)
@message_handler @message_handler
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None: async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
assert ctx.topic_id is not None await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=DefaultTopicId())
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}") print(f"Unhandled message: {message}")
@ -58,13 +56,11 @@ class GreeterAgent(RoutedAgent):
@message_handler @message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
assert ctx.topic_id is not None await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=DefaultTopicId())
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
@message_handler @message_handler
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None: async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
assert ctx.topic_id is not None await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=DefaultTopicId())
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}") print(f"Unhandled message: {message}")

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime from agnext.application import WorkerAgentRuntime
from agnext.components import RoutedAgent, TypeSubscription, message_handler from agnext.components import DefaultTopicId, RoutedAgent, TypeSubscription, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
@ -47,8 +47,7 @@ class GreeterAgent(RoutedAgent):
@message_handler @message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id) response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
assert ctx.topic_id is not None await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=DefaultTopicId())
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}") print(f"Unhandled message: {message}")

View File

@ -6,7 +6,7 @@ from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription from ._default_subscription import DefaultSubscription
from ._default_topic import DefaultTopicId from ._default_topic import DefaultTopicId
from ._image import Image from ._image import Image
from ._routed_agent import RoutedAgent, message_handler, TypeRoutedAgent from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler
from ._type_subscription import TypeSubscription from ._type_subscription import TypeSubscription
from ._types import FunctionCall from ._types import FunctionCall

View File

@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import Any from typing import Any
from agnext.components import RoutedAgent, message_handler from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId
from agnext.core import BaseAgent from agnext.core import BaseAgent
from agnext.core import MessageContext from agnext.core import MessageContext
@ -38,8 +39,7 @@ class CascadingAgent(RoutedAgent):
self.num_calls += 1 self.num_calls += 1
if message.round == self.max_rounds: if message.round == self.max_rounds:
return return
assert ctx.topic_id is not None await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=DefaultTopicId())
await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=ctx.topic_id)
class NoopAgent(BaseAgent): class NoopAgent(BaseAgent):
def __init__(self) -> None: def __init__(self) -> None: