mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-16 17:48:46 +00:00
Activate deactivate agents (#4800)
* Instantiate and call activate/deactivate on agents * autoformatting * remove activate. Rename deactivate to close * remove unneeded import * create close fn in runtime * change runtime close behavior * uv.lock --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
5635ea397f
commit
71a3b238e7
@ -190,3 +190,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state. Default implementation for stateless agents."""
|
||||
BaseState.model_validate(state)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Called when the runtime is closed"""
|
||||
pass
|
||||
|
||||
@ -64,3 +64,7 @@ class ChatAgent(TaskRunner, Protocol):
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Called when the runtime is stopped or any stop method is called"""
|
||||
...
|
||||
|
||||
@ -45,3 +45,7 @@ class Agent(Protocol):
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Called when the runtime is closed"""
|
||||
...
|
||||
|
||||
@ -152,6 +152,9 @@ class BaseAgent(ABC, Agent):
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls,
|
||||
|
||||
@ -309,6 +309,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
)
|
||||
)
|
||||
recipient_agent = await self._get_agent(recipient)
|
||||
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=None,
|
||||
@ -589,10 +590,21 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
raise RuntimeError("Runtime is already started")
|
||||
self._run_context = RunContext(self)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Calls :meth:`stop` if applicable and the :meth:`Agent.close` method on all instantiated agents"""
|
||||
# stop the runtime if it hasn't been stopped yet
|
||||
if self._run_context is not None:
|
||||
await self.stop()
|
||||
# close all the agents that have been instantiated
|
||||
for agent_id in self._instantiated_agents:
|
||||
agent = await self._get_agent(agent_id)
|
||||
await agent.close()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
|
||||
await self._run_context.stop()
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
@ -603,6 +615,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when_idle()
|
||||
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
@ -623,6 +636,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when(condition)
|
||||
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
|
||||
@ -86,6 +86,8 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
|
||||
"autogen publish default.(default)-T",
|
||||
]
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
|
||||
@ -107,6 +109,8 @@ async def test_register_receives_publish_with_construction(caplog: pytest.LogCap
|
||||
# Check if logger has the exception.
|
||||
assert any("Error constructing agent" in e.message for e in caplog.records)
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade() -> None:
|
||||
@ -137,6 +141,8 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_explicit_name() -> None:
|
||||
@ -162,6 +168,8 @@ async def test_register_factory_explicit_name() -> None:
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
@ -185,6 +193,8 @@ async def test_default_subscription() -> None:
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_subscription() -> None:
|
||||
@ -208,6 +218,8 @@ async def test_type_subscription() -> None:
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription_publish_to_other_source() -> None:
|
||||
@ -229,3 +241,5 @@ async def test_default_subscription_publish_to_other_source() -> None:
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
||||
await runtime.close()
|
||||
|
||||
@ -179,6 +179,7 @@ class HostConnection:
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
# TODO: Needs to handle agent close() call
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user