mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-02 01:49:53 +00:00
Fix bug in register_factory for worker runtime (#563)
This commit is contained in:
parent
1edf5cfe9c
commit
fd021db91c
10
CHANGELOG.md
10
CHANGELOG.md
@ -1,10 +0,0 @@
|
||||
# Change Log
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.1.0] - 2024-08-16
|
||||
|
||||
### Changed
|
||||
|
||||
- Change `cancellation_token : CancellationToken` to `ctx: MessageContext` in
|
||||
agent's message handler signature.
|
||||
@ -291,7 +291,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
data_type = self._serialization_registry.type_name(message)
|
||||
with self._trace_helper.trace_block(
|
||||
"create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)}
|
||||
"create", recipient, parent=None, extraAttributes={"message_type": data_type}
|
||||
):
|
||||
# create a new future for the result
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
@ -555,6 +555,11 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
if type.type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
@ -569,6 +574,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type.type))
|
||||
await self._host_connection.send(message)
|
||||
|
||||
return type
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
|
||||
@ -24,7 +24,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||
self._agent_type_to_client_id: Dict[str, int] = {}
|
||||
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
|
||||
@ -65,7 +65,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
# Clean up the client connection.
|
||||
del self._send_queues[client_id]
|
||||
# Cancel pending requests sent to this client.
|
||||
for future in self._pending_requests.pop(client_id, {}).values():
|
||||
for future in self._pending_responses.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
# Remove the client id from the agent type to client id mapping.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
@ -137,7 +137,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
|
||||
# Create a future to wait for the response from the target.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._pending_requests.setdefault(target_client_id, {})[request.request_id] = future
|
||||
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
|
||||
|
||||
# Create a task to wait for the response and send it back to the client.
|
||||
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
||||
@ -156,7 +156,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
||||
# Setting the result of the future will send the response back to the original sender.
|
||||
future = self._pending_requests[client_id].pop(response.request_id)
|
||||
future = self._pending_responses[client_id].pop(response.request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
|
||||
@ -87,8 +87,7 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
|
||||
runtime.start()
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
@ -14,7 +15,6 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50051"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
@ -45,7 +45,6 @@ async def test_agent_names_must_be_unique() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50052"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
@ -79,12 +78,10 @@ async def test_register_receives_publish() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50053"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime.start()
|
||||
|
||||
@ -97,15 +94,14 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
|
||||
# Publish messages
|
||||
for _ in range(num_initial_messages):
|
||||
await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
|
||||
|
||||
# Let the agents run for a bit.
|
||||
await asyncio.sleep(5)
|
||||
# Wait for all agents to finish.
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
@ -116,9 +112,54 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Fix flakiness")
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
host_address = "localhost:50057"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
# TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail.
|
||||
num_agents = 2
|
||||
num_initial_messages = 1
|
||||
max_rounds = 1
|
||||
total_num_calls_expected = 0
|
||||
for i in range(0, max_rounds):
|
||||
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
|
||||
|
||||
# Run multiple workers one for each agent.
|
||||
workers = []
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime.start()
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
workers.append(runtime)
|
||||
|
||||
# Publish messages
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
publisher.start()
|
||||
for _ in range(num_initial_messages):
|
||||
await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
|
||||
|
||||
await asyncio.sleep(20)
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
agent = await workers[i].try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
for worker in workers:
|
||||
await worker.stop()
|
||||
await publisher.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50054"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
@ -148,7 +189,6 @@ async def test_default_subscription() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_default_default_subscription() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
@ -178,7 +218,6 @@ async def test_non_default_default_subscription() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_publish_to_other_source() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50056"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user