Fix bug in register_factory for worker runtime (#563)

This commit is contained in:
Eric Zhu 2024-09-18 19:08:35 -07:00 committed by GitHub
parent 1edf5cfe9c
commit fd021db91c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 64 additions and 28 deletions

View File

@ -1,10 +0,0 @@
# Change Log
## [Unreleased]
## [0.1.0] - 2024-08-16
### Changed
- Change `cancellation_token : CancellationToken` to `ctx: MessageContext` in
agent's message handler signature.

View File

@ -291,7 +291,7 @@ class WorkerAgentRuntime(AgentRuntime):
raise RuntimeError("Host connection is not set.")
data_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)}
"create", recipient, parent=None, extraAttributes={"message_type": data_type}
):
# create a new future for the result
future = asyncio.get_event_loop().create_future()
@ -555,6 +555,11 @@ class WorkerAgentRuntime(AgentRuntime):
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
if type.type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
@ -569,6 +574,9 @@ class WorkerAgentRuntime(AgentRuntime):
self._agent_factories[type.type] = factory_wrapper
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type.type))
await self._host_connection.send(message)
return type
async def _invoke_agent_factory(

View File

@ -24,7 +24,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, int] = {}
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
@ -65,7 +65,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
# Clean up the client connection.
del self._send_queues[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_requests.pop(client_id, {}).values():
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
async with self._agent_type_to_client_id_lock:
@ -137,7 +137,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
self._pending_requests.setdefault(target_client_id, {})[request.request_id] = future
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
@ -156,7 +156,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_requests[client_id].pop(response.request_id)
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:

View File

@ -87,8 +87,7 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
runtime.start()

View File

@ -1,4 +1,5 @@
import asyncio
import logging
import pytest
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
@ -14,7 +15,6 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
# Keep it unique to this test only.
host_address = "localhost:50051"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
@ -45,7 +45,6 @@ async def test_agent_names_must_be_unique() -> None:
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
# Keep it unique to this test only.
host_address = "localhost:50052"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
@ -79,12 +78,10 @@ async def test_register_receives_publish() -> None:
@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
# Keep it unique to this test only.
host_address = "localhost:50053"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime.start()
@ -97,15 +94,14 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
# Publish messages
for _ in range(num_initial_messages):
await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
# Let the agents run for a bit.
await asyncio.sleep(5)
# Wait for all agents to finish.
await asyncio.sleep(10)
# Check that each agent received the correct number of messages.
for i in range(num_agents):
@ -116,9 +112,54 @@ async def test_register_receives_publish_cascade() -> None:
await host.stop()
@pytest.mark.skip(reason="Fix flakiness")
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_multiple_workers() -> None:
logging.basicConfig(level=logging.DEBUG)
host_address = "localhost:50057"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
# TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail.
num_agents = 2
num_initial_messages = 1
max_rounds = 1
total_num_calls_expected = 0
for i in range(0, max_rounds):
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
# Run multiple workers one for each agent.
workers = []
# Register agents
for i in range(num_agents):
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime.start()
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
workers.append(runtime)
# Publish messages
publisher = WorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
publisher.start()
for _ in range(num_initial_messages):
await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
await asyncio.sleep(20)
# Check that each agent received the correct number of messages.
for i in range(num_agents):
agent = await workers[i].try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected
for worker in workers:
await worker.stop()
await publisher.stop()
await host.stop()
@pytest.mark.asyncio
async def test_default_subscription() -> None:
# Keep it unique to this test only.
host_address = "localhost:50054"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
@ -148,7 +189,6 @@ async def test_default_subscription() -> None:
@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
# Keep it unique to this test only.
host_address = "localhost:50055"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
@ -178,7 +218,6 @@ async def test_non_default_default_subscription() -> None:
@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
# Keep it unique to this test only.
host_address = "localhost:50056"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()