From fd021db91c67ca83981f5cf15a3bef22f42c85b6 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 18 Sep 2024 19:08:35 -0700 Subject: [PATCH] Fix bug in register_factory for worker runtime (#563) --- CHANGELOG.md | 10 --- .../application/_worker_runtime.py | 10 ++- .../_worker_runtime_host_servicer.py | 8 +-- .../autogen-core/tests/test_runtime.py | 3 +- .../autogen-core/tests/test_worker_runtime.py | 61 +++++++++++++++---- 5 files changed, 64 insertions(+), 28 deletions(-) delete mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index fba516ac8..000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,10 +0,0 @@ -# Change Log - -## [Unreleased] - -## [0.1.0] - 2024-08-16 - -### Changed - -- Change `cancellation_token : CancellationToken` to `ctx: MessageContext` in - agent's message handler signature. diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 96d00d9fe..db35d494e 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -291,7 +291,7 @@ class WorkerAgentRuntime(AgentRuntime): raise RuntimeError("Host connection is not set.") data_type = self._serialization_registry.type_name(message) with self._trace_helper.trace_block( - "create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)} + "create", recipient, parent=None, extraAttributes={"message_type": data_type} ): # create a new future for the result future = asyncio.get_event_loop().create_future() @@ -555,6 +555,11 @@ class WorkerAgentRuntime(AgentRuntime): agent_factory: Callable[[], T | Awaitable[T]], expected_class: type[T], ) -> AgentType: + if type.type in self._agent_factories: + raise ValueError(f"Agent with type {type} already exists.") + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + async def factory_wrapper() -> T: maybe_agent_instance = agent_factory() if inspect.isawaitable(maybe_agent_instance): @@ -569,6 +574,9 @@ class WorkerAgentRuntime(AgentRuntime): self._agent_factories[type.type] = factory_wrapper + message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type.type)) + await self._host_connection.send(message) + return type async def _invoke_agent_factory( diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py index 10ffe920e..4d65f9cd2 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py @@ -24,7 +24,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, int] = {} - self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {} + self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {} self._background_tasks: Set[Task[Any]] = set() self._subscription_manager = SubscriptionManager() @@ -65,7 +65,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): # Clean up the client connection. del self._send_queues[client_id] # Cancel pending requests sent to this client. - for future in self._pending_requests.pop(client_id, {}).values(): + for future in self._pending_responses.pop(client_id, {}).values(): future.cancel() # Remove the client id from the agent type to client id mapping. async with self._agent_type_to_client_id_lock: @@ -137,7 +137,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): # Create a future to wait for the response from the target. future = asyncio.get_event_loop().create_future() - self._pending_requests.setdefault(target_client_id, {})[request.request_id] = future + self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future # Create a task to wait for the response and send it back to the client. send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) @@ -156,7 +156,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: # Setting the result of the future will send the response back to the original sender. - future = self._pending_requests[client_id].pop(response.request_id) + future = self._pending_responses[client_id].pop(response.request_id) future.set_result(response) async def _process_event(self, event: agent_worker_pb2.Event) -> None: diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 60000484a..8d4e95a0b 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -87,8 +87,7 @@ async def test_register_receives_publish_cascade() -> None: # Register agents for i in range(num_agents): - await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds)) - await runtime.add_subscription(TypeSubscription("default", f"name{i}")) + await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()]) runtime.start() diff --git a/python/packages/autogen-core/tests/test_worker_runtime.py b/python/packages/autogen-core/tests/test_worker_runtime.py index d2db024d1..159268fa6 100644 --- a/python/packages/autogen-core/tests/test_worker_runtime.py +++ b/python/packages/autogen-core/tests/test_worker_runtime.py @@ -1,4 +1,5 @@ import asyncio +import logging import pytest from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost @@ -14,7 +15,6 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess @pytest.mark.asyncio async def test_agent_names_must_be_unique() -> None: - # Keep it unique to this test only. host_address = "localhost:50051" host = WorkerAgentRuntimeHost(address=host_address) host.start() @@ -45,7 +45,6 @@ async def test_agent_names_must_be_unique() -> None: @pytest.mark.asyncio async def test_register_receives_publish() -> None: - # Keep it unique to this test only. host_address = "localhost:50052" host = WorkerAgentRuntimeHost(address=host_address) host.start() @@ -79,12 +78,10 @@ async def test_register_receives_publish() -> None: @pytest.mark.asyncio async def test_register_receives_publish_cascade() -> None: - # Keep it unique to this test only. host_address = "localhost:50053" host = WorkerAgentRuntimeHost(address=host_address) host.start() runtime = WorkerAgentRuntime(host_address=host_address) - runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType)) runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType)) runtime.start() @@ -97,15 +94,14 @@ async def test_register_receives_publish_cascade() -> None: # Register agents for i in range(num_agents): - await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds)) - await runtime.add_subscription(TypeSubscription("default", f"name{i}")) + await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()]) # Publish messages for _ in range(num_initial_messages): await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId()) - # Let the agents run for a bit. - await asyncio.sleep(5) + # Wait for all agents to finish. + await asyncio.sleep(10) # Check that each agent received the correct number of messages. for i in range(num_agents): @@ -116,9 +112,54 @@ async def test_register_receives_publish_cascade() -> None: await host.stop() +@pytest.mark.skip(reason="Fix flakiness") +@pytest.mark.asyncio +async def test_register_receives_publish_cascade_multiple_workers() -> None: + logging.basicConfig(level=logging.DEBUG) + host_address = "localhost:50057" + host = WorkerAgentRuntimeHost(address=host_address) + host.start() + + # TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail. + num_agents = 2 + num_initial_messages = 1 + max_rounds = 1 + total_num_calls_expected = 0 + for i in range(0, max_rounds): + total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i) + + # Run multiple workers one for each agent. + workers = [] + # Register agents + for i in range(num_agents): + runtime = WorkerAgentRuntime(host_address=host_address) + runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType)) + runtime.start() + await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()]) + workers.append(runtime) + + # Publish messages + publisher = WorkerAgentRuntime(host_address=host_address) + publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType)) + publisher.start() + for _ in range(num_initial_messages): + await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId()) + + await asyncio.sleep(20) + + # Check that each agent received the correct number of messages. + for i in range(num_agents): + agent = await workers[i].try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) + assert agent.num_calls == total_num_calls_expected + + for worker in workers: + await worker.stop() + await publisher.stop() + await host.stop() + + @pytest.mark.asyncio async def test_default_subscription() -> None: - # Keep it unique to this test only. host_address = "localhost:50054" host = WorkerAgentRuntimeHost(address=host_address) host.start() @@ -148,7 +189,6 @@ async def test_default_subscription() -> None: @pytest.mark.asyncio async def test_non_default_default_subscription() -> None: - # Keep it unique to this test only. host_address = "localhost:50055" host = WorkerAgentRuntimeHost(address=host_address) host.start() @@ -178,7 +218,6 @@ async def test_non_default_default_subscription() -> None: @pytest.mark.asyncio async def test_non_publish_to_other_source() -> None: - # Keep it unique to this test only. host_address = "localhost:50056" host = WorkerAgentRuntimeHost(address=host_address) host.start()