Activate deactivate agents (#4800)

* Instantiate and call activate/deactivate on agents

* autoformatting

* remove activate. Rename deactivate to close

* remove unneeded import

* create close fn in runtime

* change runtime close behavior

* uv.lock

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
peterychang 2025-01-07 16:37:02 -05:00 committed by GitHub
parent 5635ea397f
commit 71a3b238e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 44 additions and 0 deletions

View File

@ -190,3 +190,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def load_state(self, state: Mapping[str, Any]) -> None: async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state. Default implementation for stateless agents.""" """Restore agent from saved state. Default implementation for stateless agents."""
BaseState.model_validate(state) BaseState.model_validate(state)
async def close(self) -> None:
"""Called when the runtime is closed"""
pass

View File

@ -64,3 +64,7 @@ class ChatAgent(TaskRunner, Protocol):
async def load_state(self, state: Mapping[str, Any]) -> None: async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state""" """Restore agent from saved state"""
... ...
async def close(self) -> None:
"""Called when the runtime is stopped or any stop method is called"""
...

View File

@ -45,3 +45,7 @@ class Agent(Protocol):
""" """
... ...
async def close(self) -> None:
"""Called when the runtime is closed"""
...

View File

@ -152,6 +152,9 @@ class BaseAgent(ABC, Agent):
warnings.warn("load_state not implemented", stacklevel=2) warnings.warn("load_state not implemented", stacklevel=2)
pass pass
async def close(self) -> None:
pass
@classmethod @classmethod
async def register( async def register(
cls, cls,

View File

@ -309,6 +309,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
) )
) )
recipient_agent = await self._get_agent(recipient) recipient_agent = await self._get_agent(recipient)
message_context = MessageContext( message_context = MessageContext(
sender=message_envelope.sender, sender=message_envelope.sender,
topic_id=None, topic_id=None,
@ -589,10 +590,21 @@ class SingleThreadedAgentRuntime(AgentRuntime):
raise RuntimeError("Runtime is already started") raise RuntimeError("Runtime is already started")
self._run_context = RunContext(self) self._run_context = RunContext(self)
async def close(self) -> None:
"""Calls :meth:`stop` if applicable and the :meth:`Agent.close` method on all instantiated agents"""
# stop the runtime if it hasn't been stopped yet
if self._run_context is not None:
await self.stop()
# close all the agents that have been instantiated
for agent_id in self._instantiated_agents:
agent = await self._get_agent(agent_id)
await agent.close()
async def stop(self) -> None: async def stop(self) -> None:
"""Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded.""" """Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded."""
if self._run_context is None: if self._run_context is None:
raise RuntimeError("Runtime is not started") raise RuntimeError("Runtime is not started")
await self._run_context.stop() await self._run_context.stop()
self._run_context = None self._run_context = None
self._message_queue = Queue() self._message_queue = Queue()
@ -603,6 +615,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if self._run_context is None: if self._run_context is None:
raise RuntimeError("Runtime is not started") raise RuntimeError("Runtime is not started")
await self._run_context.stop_when_idle() await self._run_context.stop_when_idle()
self._run_context = None self._run_context = None
self._message_queue = Queue() self._message_queue = Queue()
@ -623,6 +636,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if self._run_context is None: if self._run_context is None:
raise RuntimeError("Runtime is not started") raise RuntimeError("Runtime is not started")
await self._run_context.stop_when(condition) await self._run_context.stop_when(condition)
self._run_context = None self._run_context = None
self._message_queue = Queue() self._message_queue = Queue()

View File

@ -86,6 +86,8 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
"autogen publish default.(default)-T", "autogen publish default.(default)-T",
] ]
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None: async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
@ -107,6 +109,8 @@ async def test_register_receives_publish_with_construction(caplog: pytest.LogCap
# Check if logger has the exception. # Check if logger has the exception.
assert any("Error constructing agent" in e.message for e in caplog.records) assert any("Error constructing agent" in e.message for e in caplog.records)
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None: async def test_register_receives_publish_cascade() -> None:
@ -137,6 +141,8 @@ async def test_register_receives_publish_cascade() -> None:
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected assert agent.num_calls == total_num_calls_expected
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_factory_explicit_name() -> None: async def test_register_factory_explicit_name() -> None:
@ -162,6 +168,8 @@ async def test_register_factory_explicit_name() -> None:
) )
assert other_long_running_agent.num_calls == 0 assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_subscription() -> None: async def test_default_subscription() -> None:
@ -185,6 +193,8 @@ async def test_default_subscription() -> None:
) )
assert other_long_running_agent.num_calls == 0 assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_type_subscription() -> None: async def test_type_subscription() -> None:
@ -208,6 +218,8 @@ async def test_type_subscription() -> None:
) )
assert other_long_running_agent.num_calls == 0 assert other_long_running_agent.num_calls == 0
await runtime.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_subscription_publish_to_other_source() -> None: async def test_default_subscription_publish_to_other_source() -> None:
@ -229,3 +241,5 @@ async def test_default_subscription_publish_to_other_source() -> None:
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
) )
assert other_long_running_agent.num_calls == 1 assert other_long_running_agent.num_calls == 1
await runtime.close()

View File

@ -179,6 +179,7 @@ class HostConnection:
class GrpcWorkerAgentRuntime(AgentRuntime): class GrpcWorkerAgentRuntime(AgentRuntime):
# TODO: Needs to handle agent close() call
def __init__( def __init__(
self, self,
host_address: str, host_address: str,