mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-29 07:59:50 +00:00
fix: Make race condition between channel open and RPC less likely to occur (#5514)
Right now we rely on opening the channel to associate a ClientId with an entry on the gateway side. This causes a race when the channel is being opened in the background while an RPC (e.g. MyAgent.register()) is invoked. If the RPC is processed first, the gateway rejects it due to "invalid" clientId. This fix makes this condition less likely to trigger, but there is still a piece of the puzzle that needs to be solved on the Gateway side.
This commit is contained in:
parent
f49f159a43
commit
676b611064
2
.github/workflows/checks.yml
vendored
2
.github/workflows/checks.yml
vendored
@ -183,7 +183,7 @@ jobs:
|
||||
|
||||
codecov:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test]
|
||||
needs: [test, test-grpc]
|
||||
strategy:
|
||||
matrix:
|
||||
package:
|
||||
|
||||
@ -1,225 +1,225 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Distributed Agent Runtime\n",
|
||||
"\n",
|
||||
"```{attention}\n",
|
||||
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
|
||||
"to the API.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
|
||||
"across process boundaries.\n",
|
||||
"It consists of a host service and at least one worker runtime.\n",
|
||||
"\n",
|
||||
"The host service maintains connections to all active worker runtimes,\n",
|
||||
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
|
||||
"A worker runtime processes application code (agents) and connects to the host service.\n",
|
||||
"It also advertises the agents which they support to the host service,\n",
|
||||
"so the host service can deliver messages to the correct worker.\n",
|
||||
"\n",
|
||||
"````{note}\n",
|
||||
"The distributed agent runtime requires extra dependencies, install them using:\n",
|
||||
"```bash\n",
|
||||
"pip install \"autogen-ext[grpc]\"\n",
|
||||
"```\n",
|
||||
"````\n",
|
||||
"\n",
|
||||
"We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
|
||||
"\n",
|
||||
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
|
||||
"host.start() # Start a host service in the background."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above code starts the host service in the background and accepts\n",
|
||||
"worker connections on port 50051.\n",
|
||||
"\n",
|
||||
"Before running worker runtimes, let's define our agent.\n",
|
||||
"The agent will publish a new message on every message it receives.\n",
|
||||
"It also keeps track of how many messages it has published, and \n",
|
||||
"stops publishing new messages once it has published 5 messages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class MyMessage:\n",
|
||||
" content: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@default_subscription\n",
|
||||
"class MyAgent(RoutedAgent):\n",
|
||||
" def __init__(self, name: str) -> None:\n",
|
||||
" super().__init__(\"My agent\")\n",
|
||||
" self._name = name\n",
|
||||
" self._counter = 0\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
|
||||
" self._counter += 1\n",
|
||||
" if self._counter > 5:\n",
|
||||
" return\n",
|
||||
" content = f\"{self._name}: Hello x {self._counter}\"\n",
|
||||
" print(content)\n",
|
||||
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can set up the worker agent runtimes.\n",
|
||||
"We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n",
|
||||
"We set up two worker runtimes. Each runtime hosts one agent.\n",
|
||||
"All agents publish and subscribe to the default topic, so they can see all\n",
|
||||
"messages being published.\n",
|
||||
"\n",
|
||||
"To run the agents, we publishes a message from a worker."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"worker1: Hello x 1\n",
|
||||
"worker2: Hello x 1\n",
|
||||
"worker2: Hello x 2\n",
|
||||
"worker1: Hello x 2\n",
|
||||
"worker1: Hello x 3\n",
|
||||
"worker2: Hello x 3\n",
|
||||
"worker2: Hello x 4\n",
|
||||
"worker1: Hello x 4\n",
|
||||
"worker1: Hello x 5\n",
|
||||
"worker2: Hello x 5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
|
||||
"\n",
|
||||
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker1.start()\n",
|
||||
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
|
||||
"\n",
|
||||
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker2.start()\n",
|
||||
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
|
||||
"\n",
|
||||
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
|
||||
"\n",
|
||||
"# Let the agents run for a while.\n",
|
||||
"await asyncio.sleep(5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can see each agent published exactly 5 messages.\n",
|
||||
"\n",
|
||||
"To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await worker1.stop()\n",
|
||||
"await worker2.stop()\n",
|
||||
"\n",
|
||||
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
|
||||
"# await worker1.stop_when_signal()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n",
|
||||
"to stop the host service."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await host.stop()\n",
|
||||
"\n",
|
||||
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
|
||||
"# await host.stop_when_signal()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Cross-Language Runtimes\n",
|
||||
"The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n",
|
||||
"\n",
|
||||
"# Next Steps\n",
|
||||
"To see complete examples of using distributed runtime, please take a look at the following samples:\n",
|
||||
"\n",
|
||||
"- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n",
|
||||
"- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n",
|
||||
"- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "agnext",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Distributed Agent Runtime\n",
|
||||
"\n",
|
||||
"```{attention}\n",
|
||||
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
|
||||
"to the API.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
|
||||
"across process boundaries.\n",
|
||||
"It consists of a host service and at least one worker runtime.\n",
|
||||
"\n",
|
||||
"The host service maintains connections to all active worker runtimes,\n",
|
||||
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
|
||||
"A worker runtime processes application code (agents) and connects to the host service.\n",
|
||||
"It also advertises the agents which they support to the host service,\n",
|
||||
"so the host service can deliver messages to the correct worker.\n",
|
||||
"\n",
|
||||
"````{note}\n",
|
||||
"The distributed agent runtime requires extra dependencies, install them using:\n",
|
||||
"```bash\n",
|
||||
"pip install \"autogen-ext[grpc]\"\n",
|
||||
"```\n",
|
||||
"````\n",
|
||||
"\n",
|
||||
"We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
|
||||
"\n",
|
||||
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
|
||||
"host.start() # Start a host service in the background."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above code starts the host service in the background and accepts\n",
|
||||
"worker connections on port 50051.\n",
|
||||
"\n",
|
||||
"Before running worker runtimes, let's define our agent.\n",
|
||||
"The agent will publish a new message on every message it receives.\n",
|
||||
"It also keeps track of how many messages it has published, and \n",
|
||||
"stops publishing new messages once it has published 5 messages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class MyMessage:\n",
|
||||
" content: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@default_subscription\n",
|
||||
"class MyAgent(RoutedAgent):\n",
|
||||
" def __init__(self, name: str) -> None:\n",
|
||||
" super().__init__(\"My agent\")\n",
|
||||
" self._name = name\n",
|
||||
" self._counter = 0\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
|
||||
" self._counter += 1\n",
|
||||
" if self._counter > 5:\n",
|
||||
" return\n",
|
||||
" content = f\"{self._name}: Hello x {self._counter}\"\n",
|
||||
" print(content)\n",
|
||||
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can set up the worker agent runtimes.\n",
|
||||
"We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n",
|
||||
"We set up two worker runtimes. Each runtime hosts one agent.\n",
|
||||
"All agents publish and subscribe to the default topic, so they can see all\n",
|
||||
"messages being published.\n",
|
||||
"\n",
|
||||
"To run the agents, we publishes a message from a worker."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"worker1: Hello x 1\n",
|
||||
"worker2: Hello x 1\n",
|
||||
"worker2: Hello x 2\n",
|
||||
"worker1: Hello x 2\n",
|
||||
"worker1: Hello x 3\n",
|
||||
"worker2: Hello x 3\n",
|
||||
"worker2: Hello x 4\n",
|
||||
"worker1: Hello x 4\n",
|
||||
"worker1: Hello x 5\n",
|
||||
"worker2: Hello x 5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
|
||||
"\n",
|
||||
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"await worker1.start()\n",
|
||||
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
|
||||
"\n",
|
||||
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"await worker2.start()\n",
|
||||
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
|
||||
"\n",
|
||||
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
|
||||
"\n",
|
||||
"# Let the agents run for a while.\n",
|
||||
"await asyncio.sleep(5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can see each agent published exactly 5 messages.\n",
|
||||
"\n",
|
||||
"To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await worker1.stop()\n",
|
||||
"await worker2.stop()\n",
|
||||
"\n",
|
||||
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
|
||||
"# await worker1.stop_when_signal()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n",
|
||||
"to stop the host service."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await host.stop()\n",
|
||||
"\n",
|
||||
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
|
||||
"# await host.stop_when_signal()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Cross-Language Runtimes\n",
|
||||
"The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n",
|
||||
"\n",
|
||||
"# Next Steps\n",
|
||||
"To see complete examples of using distributed runtime, please take a look at the following samples:\n",
|
||||
"\n",
|
||||
"- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n",
|
||||
"- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n",
|
||||
"- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@ -132,7 +132,9 @@ class HostConnection:
|
||||
return [("client-id", self._client_id)]
|
||||
|
||||
@classmethod
|
||||
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
|
||||
async def from_host_address(
|
||||
cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG
|
||||
) -> Self:
|
||||
logger.info("Connecting to %s", host_address)
|
||||
# Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config
|
||||
merged_options = [
|
||||
@ -145,9 +147,11 @@ class HostConnection:
|
||||
)
|
||||
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
|
||||
instance = cls(channel, stub)
|
||||
instance._connection_task = asyncio.create_task(
|
||||
instance._connect(stub, instance._send_queue, instance._recv_queue, instance._client_id)
|
||||
|
||||
instance._connection_task = await instance._connect(
|
||||
stub, instance._send_queue, instance._recv_queue, instance._client_id
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
async def close(self) -> None:
|
||||
@ -162,23 +166,28 @@ class HostConnection:
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
) -> Task[None]:
|
||||
from grpc.aio import StreamStreamCall
|
||||
|
||||
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
|
||||
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
||||
stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
||||
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
|
||||
)
|
||||
|
||||
while True:
|
||||
logger.info("Waiting for message from host")
|
||||
message = cast(agent_worker_pb2.Message, await recv_stream.read()) # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
break
|
||||
logger.info(f"Received a message from host: {message}")
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in receive queue")
|
||||
await stream.wait_for_connection()
|
||||
|
||||
async def read_loop() -> None:
|
||||
while True:
|
||||
logger.info("Waiting for message from host")
|
||||
message = cast(agent_worker_pb2.Message, await stream.read()) # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
break
|
||||
logger.info(f"Received a message from host: {message}")
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in receive queue")
|
||||
|
||||
return asyncio.create_task(read_loop())
|
||||
|
||||
async def send(self, message: agent_worker_pb2.Message) -> None:
|
||||
logger.info(f"Send message to host: {message}")
|
||||
@ -248,12 +257,12 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
self._payload_serialization_format = payload_serialization_format
|
||||
|
||||
def start(self) -> None:
|
||||
async def start(self) -> None:
|
||||
"""Start the runtime in a background task."""
|
||||
if self._running:
|
||||
raise ValueError("Runtime is already running.")
|
||||
logger.info(f"Connecting to host: {self._host_address}")
|
||||
self._host_connection = HostConnection.from_host_address(
|
||||
self._host_connection = await HostConnection.from_host_address(
|
||||
self._host_address, extra_grpc_config=self._extra_grpc_config
|
||||
)
|
||||
logger.info("Connection established")
|
||||
|
||||
@ -42,7 +42,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
|
||||
host.start()
|
||||
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
await worker.start()
|
||||
|
||||
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
@ -65,9 +65,9 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None:
|
||||
host.start()
|
||||
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
await worker1.start()
|
||||
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
await worker2.start()
|
||||
|
||||
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
@ -91,7 +91,7 @@ async def test_register_receives_publish() -> None:
|
||||
host.start()
|
||||
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
await worker1.start()
|
||||
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker1.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
@ -99,7 +99,7 @@ async def test_register_receives_publish() -> None:
|
||||
await worker1.add_subscription(TypeSubscription("default", "name1"))
|
||||
|
||||
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
await worker2.start()
|
||||
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker2.register_factory(
|
||||
type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
@ -137,7 +137,7 @@ async def test_register_doesnt_receive_after_removing_subscription() -> None:
|
||||
host.start()
|
||||
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
await worker1.start()
|
||||
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker1.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
@ -177,7 +177,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None:
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
|
||||
num_agents = 5
|
||||
num_initial_messages = 5
|
||||
@ -228,14 +228,14 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
workers.append(runtime)
|
||||
|
||||
# Publish messages
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
publisher.start()
|
||||
await publisher.start()
|
||||
for _ in range(num_initial_messages):
|
||||
await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
|
||||
|
||||
@ -259,10 +259,10 @@ async def test_default_subscription() -> None:
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
await worker.start()
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
await publisher.start()
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())
|
||||
|
||||
@ -294,10 +294,10 @@ async def test_default_subscription_other_source() -> None:
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
await publisher.start()
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())
|
||||
|
||||
@ -329,10 +329,10 @@ async def test_type_subscription() -> None:
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
await worker.start()
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
await publisher.start()
|
||||
|
||||
@type_subscription("Other")
|
||||
class LoopbackAgentWithSubscription(LoopbackAgent): ...
|
||||
@ -369,10 +369,10 @@ async def test_duplicate_subscription() -> None:
|
||||
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
host.start()
|
||||
try:
|
||||
worker1.start()
|
||||
await worker1.start()
|
||||
await NoopAgent.register(worker1, "worker1", lambda: NoopAgent())
|
||||
|
||||
worker1_2.start()
|
||||
await worker1_2.start()
|
||||
|
||||
# Note: This passes because worker1 is still running
|
||||
with pytest.raises(Exception, match="Agent type worker1 already registered"):
|
||||
@ -411,7 +411,7 @@ async def test_disconnected_agent() -> None:
|
||||
return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage]
|
||||
|
||||
try:
|
||||
worker1.start()
|
||||
await worker1.start()
|
||||
await LoopbackAgentWithDefaultSubscription.register(
|
||||
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
|
||||
)
|
||||
@ -439,7 +439,7 @@ async def test_disconnected_agent() -> None:
|
||||
assert len(recipients2) == 0
|
||||
await asyncio.sleep(1)
|
||||
|
||||
worker1_2.start()
|
||||
await worker1_2.start()
|
||||
await LoopbackAgentWithDefaultSubscription.register(
|
||||
worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
|
||||
)
|
||||
@ -480,12 +480,12 @@ async def test_proto_payloads() -> None:
|
||||
receiver_runtime = GrpcWorkerAgentRuntime(
|
||||
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
|
||||
)
|
||||
receiver_runtime.start()
|
||||
await receiver_runtime.start()
|
||||
publisher_runtime = GrpcWorkerAgentRuntime(
|
||||
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
|
||||
)
|
||||
publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
|
||||
publisher_runtime.start()
|
||||
await publisher_runtime.start()
|
||||
|
||||
await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent)
|
||||
|
||||
@ -517,6 +517,7 @@ async def test_proto_payloads() -> None:
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Fix flakiness")
|
||||
async def test_grpc_max_message_size() -> None:
|
||||
default_max_size = 2**22
|
||||
new_max_size = default_max_size * 2
|
||||
@ -535,9 +536,9 @@ async def test_grpc_max_message_size() -> None:
|
||||
|
||||
try:
|
||||
host.start()
|
||||
worker1.start()
|
||||
worker2.start()
|
||||
worker3.start()
|
||||
await worker1.start()
|
||||
await worker2.start()
|
||||
await worker3.start()
|
||||
await LoopbackAgentWithDefaultSubscription.register(
|
||||
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
|
||||
)
|
||||
|
||||
@ -110,3 +110,8 @@ cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-core/tests/pr
|
||||
|
||||
[[tool.poe.tasks.gen-test-proto.sequence]]
|
||||
cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-ext/tests/protos --grpc_python_out=./packages/autogen-ext/tests/protos --mypy_out=./packages/autogen-ext/tests/protos --mypy_grpc_out=./packages/autogen-ext/tests/protos --proto_path ./packages/autogen-core/tests/protos serialization_test.proto"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"grpc: tests invoking gRPC functionality",
|
||||
]
|
||||
@ -20,7 +20,7 @@ async def main(config: AppConfig):
|
||||
editor_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
await asyncio.sleep(4)
|
||||
Console().print(Markdown("Starting **`Editor Agent`**"))
|
||||
editor_agent_runtime.start()
|
||||
await editor_agent_runtime.start()
|
||||
editor_agent_type = await BaseGroupChatAgent.register(
|
||||
editor_agent_runtime,
|
||||
config.editor_agent.topic_type,
|
||||
|
||||
@ -23,7 +23,7 @@ async def main(config: AppConfig):
|
||||
group_chat_manager_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
await asyncio.sleep(1)
|
||||
Console().print(Markdown("Starting **`Group Chat Manager`**"))
|
||||
group_chat_manager_runtime.start()
|
||||
await group_chat_manager_runtime.start()
|
||||
set_all_log_levels(logging.ERROR)
|
||||
|
||||
group_chat_manager_type = await GroupChatManager.register(
|
||||
|
||||
@ -41,7 +41,7 @@ async def main(config: AppConfig):
|
||||
ui_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
|
||||
Console().print(Markdown("Starting **`UI Agent`**"))
|
||||
ui_agent_runtime.start()
|
||||
await ui_agent_runtime.start()
|
||||
set_all_log_levels(logging.ERROR)
|
||||
|
||||
ui_agent_type = await UIAgent.register(
|
||||
|
||||
@ -21,7 +21,7 @@ async def main(config: AppConfig) -> None:
|
||||
await asyncio.sleep(3)
|
||||
Console().print(Markdown("Starting **`Writer Agent`**"))
|
||||
|
||||
writer_agent_runtime.start()
|
||||
await writer_agent_runtime.start()
|
||||
writer_agent_type = await BaseGroupChatAgent.register(
|
||||
writer_agent_runtime,
|
||||
config.writer_agent.topic_type,
|
||||
|
||||
@ -6,7 +6,7 @@ from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
async def main() -> None:
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
|
||||
await runtime.publish_message(CascadingMessage(round=1), topic_id=DefaultTopicId())
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
@ -8,7 +8,7 @@ from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
async def main() -> None:
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
|
||||
await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds=3))
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
@ -73,7 +73,7 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(t))
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
|
||||
await ReceiveAgent.register(
|
||||
runtime,
|
||||
|
||||
@ -80,7 +80,7 @@ async def output_result(
|
||||
async def run_workers():
|
||||
agent_runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
|
||||
agent_runtime.start()
|
||||
await agent_runtime.start()
|
||||
|
||||
# Create the agents
|
||||
await WorkerAgent.register(agent_runtime, "finance", lambda: WorkerAgent("finance_agent"))
|
||||
|
||||
@ -26,7 +26,7 @@ agnext_logger = logging.getLogger("autogen_core")
|
||||
|
||||
async def main() -> None:
|
||||
load_dotenv()
|
||||
agentHost = os.getenv("AGENT_HOST") or "localhost:53072"
|
||||
agentHost = os.getenv("AGENT_HOST") or "http://localhost:50673"
|
||||
# grpc python bug - can only use the hostname, not prefix - if hostname has a prefix we have to remove it:
|
||||
if agentHost.startswith("http://"):
|
||||
agentHost = agentHost[7:]
|
||||
@ -37,7 +37,7 @@ async def main() -> None:
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
|
||||
|
||||
agnext_logger.info("1")
|
||||
runtime.start()
|
||||
await runtime.start()
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(NewMessageReceived))
|
||||
|
||||
agnext_logger.info("2")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user