Activate deactivate agents (#4800)

* Instantiate and call activate/deactivate on agents

* autoformatting

* remove activate. Rename deactivate to close

* remove unneeded import

* create close fn in runtime

* change runtime close behavior

* uv.lock

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
peterychang 2025-01-07 16:37:02 -05:00 committed by GitHub
parent 5635ea397f
commit 71a3b238e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 44 additions and 0 deletions

View File

@ -190,3 +190,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state. Default implementation for stateless agents."""
BaseState.model_validate(state)
async def close(self) -> None:
"""Called when the runtime is closed"""
pass

View File

@ -64,3 +64,7 @@ class ChatAgent(TaskRunner, Protocol):
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state"""
...
async def close(self) -> None:
"""Called when the runtime is stopped or any stop method is called"""
...

View File

@ -45,3 +45,7 @@ class Agent(Protocol):
"""
...
async def close(self) -> None:
"""Called when the runtime is closed"""
...

View File

@ -152,6 +152,9 @@ class BaseAgent(ABC, Agent):
warnings.warn("load_state not implemented", stacklevel=2)
pass
async def close(self) -> None:
pass
@classmethod
async def register(
cls,

View File

@ -309,6 +309,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
)
)
recipient_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=None,
@ -589,10 +590,21 @@ class SingleThreadedAgentRuntime(AgentRuntime):
raise RuntimeError("Runtime is already started")
self._run_context = RunContext(self)
async def close(self) -> None:
"""Calls :meth:`stop` if applicable and the :meth:`Agent.close` method on all instantiated agents"""
# stop the runtime if it hasn't been stopped yet
if self._run_context is not None:
await self.stop()
# close all the agents that have been instantiated
for agent_id in self._instantiated_agents:
agent = await self._get_agent(agent_id)
await agent.close()
async def stop(self) -> None:
"""Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop()
self._run_context = None
self._message_queue = Queue()
@ -603,6 +615,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when_idle()
self._run_context = None
self._message_queue = Queue()
@ -623,6 +636,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when(condition)
self._run_context = None
self._message_queue = Queue()

View File

@ -86,6 +86,8 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
"autogen publish default.(default)-T",
]
await runtime.close()
@pytest.mark.asyncio
async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
@ -107,6 +109,8 @@ async def test_register_receives_publish_with_construction(caplog: pytest.LogCap
# Check if logger has the exception.
assert any("Error constructing agent" in e.message for e in caplog.records)
await runtime.close()
@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
@ -137,6 +141,8 @@ async def test_register_receives_publish_cascade() -> None:
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected
await runtime.close()
@pytest.mark.asyncio
async def test_register_factory_explicit_name() -> None:
@ -162,6 +168,8 @@ async def test_register_factory_explicit_name() -> None:
)
assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio
async def test_default_subscription() -> None:
@ -185,6 +193,8 @@ async def test_default_subscription() -> None:
)
assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio
async def test_type_subscription() -> None:
@ -208,6 +218,8 @@ async def test_type_subscription() -> None:
)
assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio
async def test_default_subscription_publish_to_other_source() -> None:
@ -229,3 +241,5 @@ async def test_default_subscription_publish_to_other_source() -> None:
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
)
assert other_long_running_agent.num_calls == 1
await runtime.close()

View File

@ -179,6 +179,7 @@ class HostConnection:
class GrpcWorkerAgentRuntime(AgentRuntime):
# TODO: Needs to handle agent close() call
def __init__(
self,
host_address: str,