fix: Make race condition between channel open and RPC less likely to occur (#5514)

Right now we rely on opening the channel to associate a ClientId with an
entry on the gateway side. This causes a race when the channel is being
opened in the background while an RPC (e.g. MyAgent.register()) is
invoked.

If the RPC is processed first, the gateway rejects it due to "invalid"
clientId.

This fix makes this condition less likely to trigger, but there is still
a piece of the puzzle that needs to be solved on the Gateway side.
This commit is contained in:
Jacob Alber 2025-02-12 16:40:52 -05:00 committed by GitHub
parent f49f159a43
commit 676b611064
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 290 additions and 275 deletions

View File

@ -183,7 +183,7 @@ jobs:
codecov:
runs-on: ubuntu-latest
needs: [test]
needs: [test, test-grpc]
strategy:
matrix:
package:

View File

@ -1,225 +1,225 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed Agent Runtime\n",
"\n",
"```{attention}\n",
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
"to the API.\n",
"```\n",
"\n",
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
"across process boundaries.\n",
"It consists of a host service and at least one worker runtime.\n",
"\n",
"The host service maintains connections to all active worker runtimes,\n",
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
"A worker runtime processes application code (agents) and connects to the host service.\n",
"It also advertises the agents which they support to the host service,\n",
"so the host service can deliver messages to the correct worker.\n",
"\n",
"````{note}\n",
"The distributed agent runtime requires extra dependencies, install them using:\n",
"```bash\n",
"pip install \"autogen-ext[grpc]\"\n",
"```\n",
"````\n",
"\n",
"We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
"\n",
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
"host.start() # Start a host service in the background."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above code starts the host service in the background and accepts\n",
"worker connections on port 50051.\n",
"\n",
"Before running worker runtimes, let's define our agent.\n",
"The agent will publish a new message on every message it receives.\n",
"It also keeps track of how many messages it has published, and \n",
"stops publishing new messages once it has published 5 messages."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
"\n",
"\n",
"@dataclass\n",
"class MyMessage:\n",
" content: str\n",
"\n",
"\n",
"@default_subscription\n",
"class MyAgent(RoutedAgent):\n",
" def __init__(self, name: str) -> None:\n",
" super().__init__(\"My agent\")\n",
" self._name = name\n",
" self._counter = 0\n",
"\n",
" @message_handler\n",
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
" self._counter += 1\n",
" if self._counter > 5:\n",
" return\n",
" content = f\"{self._name}: Hello x {self._counter}\"\n",
" print(content)\n",
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can set up the worker agent runtimes.\n",
"We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n",
"We set up two worker runtimes. Each runtime hosts one agent.\n",
"All agents publish and subscribe to the default topic, so they can see all\n",
"messages being published.\n",
"\n",
"To run the agents, we publishes a message from a worker."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"worker1: Hello x 1\n",
"worker2: Hello x 1\n",
"worker2: Hello x 2\n",
"worker1: Hello x 2\n",
"worker1: Hello x 3\n",
"worker2: Hello x 3\n",
"worker2: Hello x 4\n",
"worker1: Hello x 4\n",
"worker1: Hello x 5\n",
"worker2: Hello x 5\n"
]
}
],
"source": [
"import asyncio\n",
"\n",
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
"\n",
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker1.start()\n",
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
"\n",
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker2.start()\n",
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
"\n",
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
"\n",
"# Let the agents run for a while.\n",
"await asyncio.sleep(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see each agent published exactly 5 messages.\n",
"\n",
"To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"await worker1.stop()\n",
"await worker2.stop()\n",
"\n",
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
"# await worker1.stop_when_signal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n",
"to stop the host service."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"await host.stop()\n",
"\n",
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
"# await host.stop_when_signal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cross-Language Runtimes\n",
"The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n",
"\n",
"# Next Steps\n",
"To see complete examples of using distributed runtime, please take a look at the following samples:\n",
"\n",
"- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n",
"- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n",
"- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "agnext",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed Agent Runtime\n",
"\n",
"```{attention}\n",
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
"to the API.\n",
"```\n",
"\n",
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
"across process boundaries.\n",
"It consists of a host service and at least one worker runtime.\n",
"\n",
"The host service maintains connections to all active worker runtimes,\n",
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
"A worker runtime processes application code (agents) and connects to the host service.\n",
"It also advertises the agents which they support to the host service,\n",
"so the host service can deliver messages to the correct worker.\n",
"\n",
"````{note}\n",
"The distributed agent runtime requires extra dependencies, install them using:\n",
"```bash\n",
"pip install \"autogen-ext[grpc]\"\n",
"```\n",
"````\n",
"\n",
"We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
"\n",
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
"host.start() # Start a host service in the background."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above code starts the host service in the background and accepts\n",
"worker connections on port 50051.\n",
"\n",
"Before running worker runtimes, let's define our agent.\n",
"The agent will publish a new message on every message it receives.\n",
"It also keeps track of how many messages it has published, and \n",
"stops publishing new messages once it has published 5 messages."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
"\n",
"\n",
"@dataclass\n",
"class MyMessage:\n",
" content: str\n",
"\n",
"\n",
"@default_subscription\n",
"class MyAgent(RoutedAgent):\n",
" def __init__(self, name: str) -> None:\n",
" super().__init__(\"My agent\")\n",
" self._name = name\n",
" self._counter = 0\n",
"\n",
" @message_handler\n",
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
" self._counter += 1\n",
" if self._counter > 5:\n",
" return\n",
" content = f\"{self._name}: Hello x {self._counter}\"\n",
" print(content)\n",
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can set up the worker agent runtimes.\n",
"We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n",
"We set up two worker runtimes. Each runtime hosts one agent.\n",
"All agents publish and subscribe to the default topic, so they can see all\n",
"messages being published.\n",
"\n",
"To run the agents, we publishes a message from a worker."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"worker1: Hello x 1\n",
"worker2: Hello x 1\n",
"worker2: Hello x 2\n",
"worker1: Hello x 2\n",
"worker1: Hello x 3\n",
"worker2: Hello x 3\n",
"worker2: Hello x 4\n",
"worker1: Hello x 4\n",
"worker1: Hello x 5\n",
"worker2: Hello x 5\n"
]
}
],
"source": [
"import asyncio\n",
"\n",
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
"\n",
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"await worker1.start()\n",
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
"\n",
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"await worker2.start()\n",
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
"\n",
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
"\n",
"# Let the agents run for a while.\n",
"await asyncio.sleep(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see each agent published exactly 5 messages.\n",
"\n",
"To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"await worker1.stop()\n",
"await worker2.stop()\n",
"\n",
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
"# await worker1.stop_when_signal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n",
"to stop the host service."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"await host.stop()\n",
"\n",
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
"# await host.stop_when_signal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cross-Language Runtimes\n",
"The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n",
"\n",
"# Next Steps\n",
"To see complete examples of using distributed runtime, please take a look at the following samples:\n",
"\n",
"- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n",
"- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n",
"- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -132,7 +132,9 @@ class HostConnection:
return [("client-id", self._client_id)]
@classmethod
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
async def from_host_address(
cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG
) -> Self:
logger.info("Connecting to %s", host_address)
# Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config
merged_options = [
@ -145,9 +147,11 @@ class HostConnection:
)
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
instance = cls(channel, stub)
instance._connection_task = asyncio.create_task(
instance._connect(stub, instance._send_queue, instance._recv_queue, instance._client_id)
instance._connection_task = await instance._connect(
stub, instance._send_queue, instance._recv_queue, instance._client_id
)
return instance
async def close(self) -> None:
@ -162,23 +166,28 @@ class HostConnection:
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> None:
) -> Task[None]:
from grpc.aio import StreamStreamCall
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
)
while True:
logger.info("Waiting for message from host")
message = cast(agent_worker_pb2.Message, await recv_stream.read()) # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
await stream.wait_for_connection()
async def read_loop() -> None:
while True:
logger.info("Waiting for message from host")
message = cast(agent_worker_pb2.Message, await stream.read()) # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
return asyncio.create_task(read_loop())
async def send(self, message: agent_worker_pb2.Message) -> None:
logger.info(f"Send message to host: {message}")
@ -248,12 +257,12 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
self._payload_serialization_format = payload_serialization_format
def start(self) -> None:
async def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
raise ValueError("Runtime is already running.")
logger.info(f"Connecting to host: {self._host_address}")
self._host_connection = HostConnection.from_host_address(
self._host_connection = await HostConnection.from_host_address(
self._host_address, extra_grpc_config=self._extra_grpc_config
)
logger.info("Connection established")

View File

@ -42,7 +42,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
@ -65,9 +65,9 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None:
host.start()
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker2.start()
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
@ -91,7 +91,7 @@ async def test_register_receives_publish() -> None:
host.start()
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
@ -99,7 +99,7 @@ async def test_register_receives_publish() -> None:
await worker1.add_subscription(TypeSubscription("default", "name1"))
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker2.start()
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker2.register_factory(
type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
@ -137,7 +137,7 @@ async def test_register_doesnt_receive_after_removing_subscription() -> None:
host.start()
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
@ -177,7 +177,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()
num_agents = 5
num_initial_messages = 5
@ -228,14 +228,14 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
# Register agents
for i in range(num_agents):
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
workers.append(runtime)
# Publish messages
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
publisher.start()
await publisher.start()
for _ in range(num_initial_messages):
await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
@ -259,10 +259,10 @@ async def test_default_subscription() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()
await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())
@ -294,10 +294,10 @@ async def test_default_subscription_other_source() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()
await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())
@ -329,10 +329,10 @@ async def test_type_subscription() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()
@type_subscription("Other")
class LoopbackAgentWithSubscription(LoopbackAgent): ...
@ -369,10 +369,10 @@ async def test_duplicate_subscription() -> None:
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
host.start()
try:
worker1.start()
await worker1.start()
await NoopAgent.register(worker1, "worker1", lambda: NoopAgent())
worker1_2.start()
await worker1_2.start()
# Note: This passes because worker1 is still running
with pytest.raises(Exception, match="Agent type worker1 already registered"):
@ -411,7 +411,7 @@ async def test_disconnected_agent() -> None:
return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage]
try:
worker1.start()
await worker1.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)
@ -439,7 +439,7 @@ async def test_disconnected_agent() -> None:
assert len(recipients2) == 0
await asyncio.sleep(1)
worker1_2.start()
await worker1_2.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)
@ -480,12 +480,12 @@ async def test_proto_payloads() -> None:
receiver_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
receiver_runtime.start()
await receiver_runtime.start()
publisher_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
publisher_runtime.start()
await publisher_runtime.start()
await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent)
@ -517,6 +517,7 @@ async def test_proto_payloads() -> None:
@pytest.mark.grpc
@pytest.mark.asyncio
@pytest.mark.skip(reason="Fix flakiness")
async def test_grpc_max_message_size() -> None:
default_max_size = 2**22
new_max_size = default_max_size * 2
@ -535,9 +536,9 @@ async def test_grpc_max_message_size() -> None:
try:
host.start()
worker1.start()
worker2.start()
worker3.start()
await worker1.start()
await worker2.start()
await worker3.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)

View File

@ -110,3 +110,8 @@ cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-core/tests/pr
[[tool.poe.tasks.gen-test-proto.sequence]]
cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-ext/tests/protos --grpc_python_out=./packages/autogen-ext/tests/protos --mypy_out=./packages/autogen-ext/tests/protos --mypy_grpc_out=./packages/autogen-ext/tests/protos --proto_path ./packages/autogen-core/tests/protos serialization_test.proto"
[tool.pytest.ini_options]
markers = [
"grpc: tests invoking gRPC functionality",
]

View File

@ -20,7 +20,7 @@ async def main(config: AppConfig):
editor_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(4)
Console().print(Markdown("Starting **`Editor Agent`**"))
editor_agent_runtime.start()
await editor_agent_runtime.start()
editor_agent_type = await BaseGroupChatAgent.register(
editor_agent_runtime,
config.editor_agent.topic_type,

View File

@ -23,7 +23,7 @@ async def main(config: AppConfig):
group_chat_manager_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(1)
Console().print(Markdown("Starting **`Group Chat Manager`**"))
group_chat_manager_runtime.start()
await group_chat_manager_runtime.start()
set_all_log_levels(logging.ERROR)
group_chat_manager_type = await GroupChatManager.register(

View File

@ -41,7 +41,7 @@ async def main(config: AppConfig):
ui_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
Console().print(Markdown("Starting **`UI Agent`**"))
ui_agent_runtime.start()
await ui_agent_runtime.start()
set_all_log_levels(logging.ERROR)
ui_agent_type = await UIAgent.register(

View File

@ -21,7 +21,7 @@ async def main(config: AppConfig) -> None:
await asyncio.sleep(3)
Console().print(Markdown("Starting **`Writer Agent`**"))
writer_agent_runtime.start()
await writer_agent_runtime.start()
writer_agent_type = await BaseGroupChatAgent.register(
writer_agent_runtime,
config.writer_agent.topic_type,

View File

@ -6,7 +6,7 @@ from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
runtime.start()
await runtime.start()
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
await runtime.publish_message(CascadingMessage(round=1), topic_id=DefaultTopicId())
await runtime.stop_when_signal()

View File

@ -8,7 +8,7 @@ from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
runtime.start()
await runtime.start()
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds=3))
await runtime.stop_when_signal()

View File

@ -73,7 +73,7 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
await runtime.start()
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
runtime.add_message_serializer(try_get_known_serializers_for_type(t))

View File

@ -54,7 +54,7 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
await runtime.start()
await ReceiveAgent.register(
runtime,

View File

@ -80,7 +80,7 @@ async def output_result(
async def run_workers():
agent_runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
agent_runtime.start()
await agent_runtime.start()
# Create the agents
await WorkerAgent.register(agent_runtime, "finance", lambda: WorkerAgent("finance_agent"))

View File

@ -26,7 +26,7 @@ agnext_logger = logging.getLogger("autogen_core")
async def main() -> None:
load_dotenv()
agentHost = os.getenv("AGENT_HOST") or "localhost:53072"
agentHost = os.getenv("AGENT_HOST") or "http://localhost:50673"
# grpc python bug - can only use the hostname, not prefix - if hostname has a prefix we have to remove it:
if agentHost.startswith("http://"):
agentHost = agentHost[7:]
@ -37,7 +37,7 @@ async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
agnext_logger.info("1")
runtime.start()
await runtime.start()
runtime.add_message_serializer(try_get_known_serializers_for_type(NewMessageReceived))
agnext_logger.info("2")