Migrate to using default sub/topic (#403)

This commit is contained in:
Jack Gerrits 2024-08-26 10:30:28 -04:00 committed by GitHub
parent d7ae2038fb
commit dbb35fc335
23 changed files with 491 additions and 526 deletions

View File

@ -295,18 +295,15 @@
"outputs": [],
"source": [
"from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import RoutedAgent, message_handler\n",
"from agnext.core import MessageContext, TopicId\n",
"from agnext.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler\n",
"from agnext.core import MessageContext\n",
"\n",
"\n",
"class BroadcastingAgent(RoutedAgent):\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n",
" # Publish a message to all agents in the same namespace.\n",
" assert ctx.topic_id is not None\n",
" await self.publish_message(\n",
" Message(f\"Publishing a message: {message.content}!\"), topic_id=TopicId(\"deafult\", self.id.key)\n",
" )\n",
" await self.publish_message(Message(f\"Publishing a message: {message.content}!\"), topic_id=DefaultTopicId())\n",
"\n",
"\n",
"class ReceivingAgent(RoutedAgent):\n",
@ -337,13 +334,11 @@
}
],
"source": [
"from agnext.components import TypeSubscription\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
"await runtime.register(\n",
" \"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"), lambda: [DefaultSubscription()]\n",
")\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"), lambda: [DefaultSubscription()])\n",
"runtime.start()\n",
"await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n",
"await runtime.stop()"
@ -376,12 +371,12 @@
"# Replace send_message with publish_message in the above example.\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n",
"await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n",
"await runtime.register(\n",
" \"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"), lambda: [DefaultSubscription()]\n",
")\n",
"await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"), lambda: [DefaultSubscription()])\n",
"runtime.start()\n",
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=TopicId(\"default\", \"default\"))\n",
"await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=DefaultTopicId())\n",
"await runtime.stop_when_idle()"
]
},

View File

@ -3,6 +3,7 @@ import json
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
from agnext.components import (
DefaultTopicId,
FunctionCall,
RoutedAgent,
message_handler,
@ -110,9 +111,8 @@ class ChatCompletionAgent(RoutedAgent):
# Generate a response.
response = await self._generate_response(message.response_format, ctx)
assert ctx.topic_id is not None
# Publish the response.
await self.publish_message(response, topic_id=ctx.topic_id)
await self.publish_message(response, topic_id=DefaultTopicId())
@message_handler()
async def on_tool_call_message(

View File

@ -2,6 +2,7 @@ from typing import Literal
import openai
from agnext.components import (
DefaultTopicId,
Image,
RoutedAgent,
message_handler,
@ -57,8 +58,7 @@ class ImageGenerationAgent(RoutedAgent):
image is published as a MultiModalMessage."""
response = await self._generate_response(ctx.cancellation_token)
assert ctx.topic_id is not None
await self.publish_message(response, topic_id=ctx.topic_id)
await self.publish_message(response, topic_id=DefaultTopicId())
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
messages = await self._memory.get_messages()

View File

@ -1,7 +1,7 @@
from typing import Any, Callable, List, Mapping
import openai
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import (
CancellationToken,
MessageContext, # type: ignore
@ -80,8 +80,7 @@ class OpenAIAssistantAgent(RoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method generates a response and publishes it."""
response = await self._generate_response(message.response_format, ctx.cancellation_token)
assert ctx.topic_id is not None
await self.publish_message(response, ctx.topic_id)
await self.publish_message(response, DefaultTopicId())
async def _generate_response(
self,

View File

@ -1,6 +1,6 @@
import asyncio
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import MessageContext
from ..types import PublishNow, TextMessage
@ -23,8 +23,9 @@ class UserProxyAgent(RoutedAgent):
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
user_input = await self.get_user_input(self._user_input_prompt)
assert ctx.topic_id is not None
await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id)
await self.publish_message(
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=DefaultTopicId()
)
async def get_user_input(self, prompt: str) -> str:
"""Get user input from the console. Override this method to customize how user input is retrieved."""

View File

@ -17,7 +17,7 @@ from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
@ -74,7 +74,7 @@ class ChatCompletionAgent(RoutedAgent):
if ctx.topic_id is not None:
await self.publish_message(
Message(content=response.content, source=self.metadata["type"]), topic_id=ctx.topic_id
Message(content=response.content, source=self.metadata["type"]), topic_id=DefaultTopicId()
)

View File

@ -12,7 +12,7 @@ from typing import List
import aiofiles
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.core import AgentId, AgentRuntime, MessageContext
from openai import AsyncAssistantEventHandler
from openai.types.beta.thread import ToolResources
@ -109,9 +109,8 @@ class UserProxyAgent(RoutedAgent):
return
else:
# Publish user input and exit handler.
assert ctx.topic_id is not None
await self.publish_message(
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id
TextMessage(content=user_input, source=self.metadata["type"]), topic_id=DefaultTopicId()
)
return

View File

@ -6,7 +6,8 @@ import os
import sys
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._default_subscription import DefaultSubscription
from agnext.components.memory import ChatMemory
from agnext.components.models import ChatCompletionClient, SystemMessage
from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime
@ -76,9 +77,8 @@ Use the following JSON format to provide your thought on the latest message and
# Publish the response if needed.
if respond is True or str(respond).lower().strip() == "true":
assert ctx.topic_id is not None
await self.publish_message(
TextMessage(source=self.metadata["type"], content=str(response)), topic_id=ctx.topic_id
TextMessage(source=self.metadata["type"], content=str(response)), topic_id=DefaultTopicId()
)
@ -98,6 +98,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
description="The user in the chat room.",
app=app,
),
lambda: [DefaultSubscription()],
)
await runtime.register(
"Alice",
@ -108,6 +109,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
lambda: [DefaultSubscription()],
)
alice = AgentProxy(AgentId("Alice", "default"), runtime)
await runtime.register(
@ -119,6 +121,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
lambda: [DefaultSubscription()],
)
bob = AgentProxy(AgentId("Bob", "default"), runtime)
await runtime.register(
@ -130,6 +133,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=10),
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
lambda: [DefaultSubscription()],
)
charlie = AgentProxy(AgentId("Charlie", "default"), runtime)
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:

View File

@ -10,10 +10,11 @@ import sys
from typing import Annotated, Literal
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import DefaultTopicId
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentInstantiationContext, AgentRuntime, TopicId
from agnext.core import AgentInstantiationContext, AgentRuntime
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
from chess import piece_name as get_piece_name
@ -214,9 +215,7 @@ async def main() -> None:
await chess_game(runtime)
runtime.start()
# Publish an initial message to trigger the group chat manager to start orchestration.
await runtime.publish_message(
TextMessage(content="Game started.", source="System"), topic_id=TopicId("default", "default")
)
await runtime.publish_message(TextMessage(content="Game started.", source="System"), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

View File

@ -4,7 +4,7 @@ import random
import sys
from asyncio import Future
from agnext.components import Image, RoutedAgent, message_handler
from agnext.components import DefaultTopicId, Image, RoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from textual.app import App, ComposeResult
from textual.containers import ScrollableContainer
@ -13,7 +13,6 @@ from textual_imageview.viewer import ImageViewer
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agnext.core import TopicId
from common.types import (
MultiModalMessage,
PublishNow,
@ -136,9 +135,7 @@ class TextualChatApp(App): # type: ignore
chat_messages.query("#typing").remove()
# Publish the user message to the runtime.
await self._runtime.publish_message(
# TODO fix hard coded topic_id
TextMessage(source=self._user_name, content=user_input),
topic_id=TopicId("default", "default"),
TextMessage(source=self._user_name, content=user_input), topic_id=DefaultTopicId()
)
async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore

View File

@ -1,4 +1,4 @@
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient
from agnext.components.models._types import SystemMessage
from agnext.core import MessageContext
@ -30,7 +30,6 @@ class AuditAgent(RoutedAgent):
assert isinstance(completion.content, str)
if "NOTFORME" in completion.content:
return
assert ctx.topic_id is not None
await self.publish_message(
AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=ctx.topic_id
AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=DefaultTopicId()
)

View File

@ -3,6 +3,7 @@ from typing import Literal
import openai
from agnext.components import (
DefaultTopicId,
RoutedAgent,
message_handler,
)
@ -33,9 +34,8 @@ class GraphicDesignerAgent(RoutedAgent):
image_uri = response.data[0].url
logger.info(f"Generated image for article. Got response: '{image_uri}'")
assert ctx.topic_id is not None
await self.publish_message(
GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=ctx.topic_id
GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=DefaultTopicId()
)
except Exception as e:
logger.error(f"Failed to generate image for article. Error: {e}")

View File

@ -2,8 +2,8 @@ import asyncio
import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import Image, RoutedAgent, message_handler
from agnext.core import MessageContext, TopicId
from agnext.components import DefaultTopicId, Image, RoutedAgent, message_handler
from agnext.core import MessageContext
from app import build_app
from dotenv import load_dotenv
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
@ -34,15 +34,13 @@ async def main() -> None:
runtime.start()
topic_id = TopicId("default", "default")
await runtime.publish_message(
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=topic_id
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=DefaultTopicId()
)
await runtime.publish_message(
ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"),
topic_id=topic_id,
topic_id=DefaultTopicId(),
)
await runtime.stop_when_idle()

View File

@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
from agnext.components.models import (
@ -102,12 +102,11 @@ Reply "TERMINATE" in the end when everything is done."""
AssistantMessage(content=response.content, source=self.metadata["type"])
)
assert ctx.topic_id is not None
# Publish the code execution task.
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@message_handler
@ -124,11 +123,10 @@ Reply "TERMINATE" in the end when everything is done."""
if "TERMINATE" in response.content:
# If the task is completed, publish a message with the completion content.
assert ctx.topic_id is not None
await self.publish_message(
TaskCompletion(content=response.content),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
print("--------------------")
print("Task completed:")
@ -136,11 +134,10 @@ Reply "TERMINATE" in the end when everything is done."""
return
# Publish the code execution task.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@ -157,13 +154,12 @@ class Executor(RoutedAgent):
code_blocks = self._extract_code_blocks(message.content)
if not code_blocks:
# If no code block is found, publish a message with an error.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
return
# Execute code blocks.
@ -171,11 +167,10 @@ class Executor(RoutedAgent):
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
)
# Publish the code execution result.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:

View File

@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
@ -112,14 +112,13 @@ Please review the code and provide feedback.
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
approved = review["approval"].lower().strip() == "approve"
# Publish the review result.
assert ctx.topic_id is not None
await self.publish_message(
CodeReviewResult(
review=review_text,
approved=approved,
session_id=message.session_id,
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@ -183,10 +182,9 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[session_id].append(code_review_task)
# Publish a code review task.
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@message_handler
@ -201,14 +199,13 @@ Code: <Your code>
# Check if the code is approved.
if message.approved:
# Publish the code writing result.
assert ctx.topic_id is not None
await self.publish_message(
CodeWritingResult(
code=review_request.code,
task=review_request.code_writing_task,
review=message.review,
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
print("Code Writing Result:")
print("-" * 80)
@ -247,10 +244,9 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[message.session_id].append(code_review_task)
# Publish a new code review task.
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:

View File

@ -18,7 +18,7 @@ from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@ -69,8 +69,7 @@ class RoundRobinGroupChatManager(RoutedAgent):
self._round_count += 1
if self._round_count > self._num_rounds * len(self._participants):
# End the conversation after the specified number of rounds.
assert ctx.topic_id is not None
await self.publish_message(Termination(), ctx.topic_id)
await self.publish_message(Termination(), DefaultTopicId())
return
# Send a request to speak message to the selected speaker.
await self.send_message(RequestToSpeak(), speaker)
@ -107,8 +106,7 @@ class GroupChatParticipant(RoutedAgent):
assert isinstance(response.content, str)
speech = Message(content=response.content, source=self.metadata["type"])
self._memory.append(speech)
assert ctx.topic_id is not None
await self.publish_message(speech, topic_id=ctx.topic_id)
await self.publish_message(speech, topic_id=DefaultTopicId())
async def main() -> None:

View File

@ -15,7 +15,7 @@ from dataclasses import dataclass
from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import MessageContext
@ -68,8 +68,7 @@ class ReferenceAgent(RoutedAgent):
response = await self._model_client.create(self._system_messages + [task_message])
assert isinstance(response.content, str)
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
await self.publish_message(task_result, topic_id=DefaultTopicId())
class AggregatorAgent(RoutedAgent):
@ -93,8 +92,7 @@ class AggregatorAgent(RoutedAgent):
"""Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
assert ctx.topic_id is not None
await self.publish_message(ref_task, topic_id=ctx.topic_id)
await self.publish_message(ref_task, topic_id=DefaultTopicId())
@message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
@ -108,8 +106,7 @@ class AggregatorAgent(RoutedAgent):
)
assert isinstance(response.content, str)
task_result = AggregatorTaskResult(result=response.content)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
await self.publish_message(task_result, topic_id=DefaultTopicId())
self._session_results.pop(message.session_id)
print(f"Aggregator result: {response.content}")

View File

@ -40,7 +40,7 @@ from dataclasses import dataclass
from typing import Dict, List, Tuple
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
@ -165,11 +165,10 @@ class MathSolver(RoutedAgent):
answer = match.group(1)
# Increment the counter.
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
assert ctx.topic_id is not None
if self._counters[message.session_id] == self._max_round:
# If the counter reaches the maximum round, publishes a final response.
await self.publish_message(
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=DefaultTopicId()
)
else:
# Publish intermediate response.
@ -181,7 +180,7 @@ class MathSolver(RoutedAgent):
session_id=message.session_id,
round=self._counters[message.session_id],
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@ -199,9 +198,8 @@ class MathAggregator(RoutedAgent):
"in the form of {{answer}}, at the end of your response."
)
session_id = str(uuid.uuid4())
assert ctx.topic_id is not None
await self.publish_message(
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=DefaultTopicId()
)
@message_handler
@ -212,8 +210,7 @@ class MathAggregator(RoutedAgent):
answers = [resp.answer for resp in self._responses[message.session_id]]
majority_answer = max(set(answers), key=answers.count)
# Publish the aggregated response.
assert ctx.topic_id is not None
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
await self.publish_message(Answer(content=majority_answer), topic_id=DefaultTopicId())
# Clear the responses.
self._responses.pop(message.session_id)
print(f"Aggregated answer: {majority_answer}")

View File

@ -20,7 +20,7 @@ from dataclasses import dataclass
from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, RoutedAgent, message_handler
from agnext.components import DefaultTopicId, FunctionCall, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
@ -90,8 +90,7 @@ class ToolExecutorAgent(RoutedAgent):
session_id=message.session_id,
result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id),
)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
await self.publish_message(task_result, topic_id=DefaultTopicId())
class ToolUseAgent(RoutedAgent):
@ -129,8 +128,7 @@ class ToolUseAgent(RoutedAgent):
if isinstance(response.content, str):
# If the response is a string, just publish the response.
response_message = AgentResponse(content=response.content)
assert ctx.topic_id is not None
await self.publish_message(response_message, topic_id=ctx.topic_id)
await self.publish_message(response_message, topic_id=DefaultTopicId())
print(f"AI Response: {response.content}")
return
@ -143,8 +141,7 @@ class ToolUseAgent(RoutedAgent):
for function_call in response.content:
task = ToolExecutionTask(session_id=session_id, function_call=function_call)
self._tool_counter[session_id] += 1
assert ctx.topic_id is not None
await self.publish_message(task, topic_id=ctx.topic_id)
await self.publish_message(task, topic_id=DefaultTopicId())
@message_handler
async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None:
@ -170,11 +167,10 @@ class ToolUseAgent(RoutedAgent):
self._sessions[message.session_id].append(
AssistantMessage(content=response.content, source=self.metadata["type"])
)
assert ctx.topic_id is not None
# If the response is a string, just publish the response.
if isinstance(response.content, str):
response_message = AgentResponse(content=response.content)
await self.publish_message(response_message, topic_id=ctx.topic_id)
await self.publish_message(response_message, topic_id=DefaultTopicId())
self._tool_results.pop(message.session_id)
self._tool_counter.pop(message.session_id)
print(f"AI Response: {response.content}")
@ -185,7 +181,7 @@ class ToolUseAgent(RoutedAgent):
for function_call in response.content:
task = ToolExecutionTask(session_id=message.session_id, function_call=function_call)
self._tool_counter[message.session_id] += 1
await self.publish_message(task, topic_id=ctx.topic_id)
await self.publish_message(task, topic_id=DefaultTopicId())
async def main() -> None:

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId
@ -40,13 +40,11 @@ class ReceiveAgent(RoutedAgent):
@message_handler
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=ctx.topic_id)
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=DefaultTopicId())
@message_handler
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=DefaultTopicId())
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")
@ -58,13 +56,11 @@ class GreeterAgent(RoutedAgent):
@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=DefaultTopicId())
@message_handler
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=DefaultTopicId())
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime
from agnext.components import RoutedAgent, TypeSubscription, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, TypeSubscription, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
@ -47,8 +47,7 @@ class GreeterAgent(RoutedAgent):
@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=DefaultTopicId())
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")

View File

@ -6,7 +6,7 @@ from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._routed_agent import RoutedAgent, message_handler, TypeRoutedAgent
from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler
from ._type_subscription import TypeSubscription
from ._types import FunctionCall

View File

@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import Any
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId
from agnext.core import BaseAgent
from agnext.core import MessageContext
@ -38,8 +39,7 @@ class CascadingAgent(RoutedAgent):
self.num_calls += 1
if message.round == self.max_rounds:
return
assert ctx.topic_id is not None
await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=ctx.topic_id)
await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=DefaultTopicId())
class NoopAgent(BaseAgent):
def __init__(self) -> None: