mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-16 20:51:38 +00:00
Implement try_get_underlying_agent_instance (#249)
This commit is contained in:
parent
a52d3bab53
commit
3ba7a48b13
@ -9,7 +9,7 @@ from collections import defaultdict
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||||
|
|
||||||
from ..core import (
|
from ..core import (
|
||||||
MESSAGE_TYPE_REGISTRY,
|
MESSAGE_TYPE_REGISTRY,
|
||||||
@ -482,7 +482,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
return self._instantiated_agents[agent_id]
|
return self._instantiated_agents[agent_id]
|
||||||
|
|
||||||
if agent_id.name not in self._agent_factories:
|
if agent_id.name not in self._agent_factories:
|
||||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
raise LookupError(f"Agent with name {agent_id.name} not found.")
|
||||||
|
|
||||||
agent_factory = self._agent_factories[agent_id.name]
|
agent_factory = self._agent_factories[agent_id.name]
|
||||||
|
|
||||||
@ -499,6 +499,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
id = await self.get(name, namespace=namespace)
|
id = await self.get(name, namespace=namespace)
|
||||||
return AgentProxy(id, self)
|
return AgentProxy(id, self)
|
||||||
|
|
||||||
|
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||||
|
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||||
|
if id.name not in self._agent_factories:
|
||||||
|
raise LookupError(f"Agent with name {id.name} not found.")
|
||||||
|
|
||||||
|
# TODO: check if remote
|
||||||
|
agent_instance = await self._get_agent(id)
|
||||||
|
|
||||||
|
if not isinstance(agent_instance, type):
|
||||||
|
raise TypeError(f"Agent with name {id.name} is not of type {type.__name__}")
|
||||||
|
|
||||||
|
return agent_instance
|
||||||
|
|
||||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||||
# to ensure message type subscriptions are set up.
|
# to ensure message type subscriptions are set up.
|
||||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||||
|
|
||||||
from ._agent import Agent
|
from ._agent import Agent
|
||||||
from ._agent_id import AgentId
|
from ._agent_id import AgentId
|
||||||
@ -146,6 +146,26 @@ class AgentRuntime(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||||
|
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||||
|
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||||
|
|
||||||
|
If the underlying agent is not accessible, this will raise an exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (AgentId): The agent id.
|
||||||
|
type (Type[T], optional): The expected type of the agent. Defaults to Agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
T: The concrete agent instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LookupError: If the agent is not found.
|
||||||
|
NotAccessibleError: If the agent is not accessible, for example if it is located remotely.
|
||||||
|
TypeError: If the agent is not of the expected type.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def register_and_get(
|
async def register_and_get(
|
||||||
self,
|
self,
|
||||||
|
@ -15,3 +15,7 @@ class UndeliverableException(Exception):
|
|||||||
|
|
||||||
class MessageDroppedException(Exception):
|
class MessageDroppedException(Exception):
|
||||||
"""Raised when a message is dropped."""
|
"""Raised when a message is dropped."""
|
||||||
|
|
||||||
|
|
||||||
|
class NotAccessibleError(Exception):
|
||||||
|
"""Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally."""
|
||||||
|
@ -21,6 +21,7 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
Set,
|
Set,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -410,6 +411,10 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
id = await self.get(name, namespace=namespace)
|
id = await self.get(name, namespace=namespace)
|
||||||
return AgentProxy(id, self)
|
return AgentProxy(id, self)
|
||||||
|
|
||||||
|
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||||
|
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||||
|
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
|
||||||
|
|
||||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||||
# to ensure message type subscriptions are set up.
|
# to ensure message type subscriptions are set up.
|
||||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||||
|
@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
|
|||||||
await response
|
await response
|
||||||
|
|
||||||
assert response.done()
|
assert response.done()
|
||||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||||
assert long_running_agent.called
|
assert long_running_agent.called
|
||||||
assert long_running_agent.cancelled
|
assert long_running_agent.cancelled
|
||||||
|
|
||||||
@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
|||||||
await response
|
await response
|
||||||
|
|
||||||
assert response.done()
|
assert response.done()
|
||||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||||
assert nested_agent.called
|
assert nested_agent.called
|
||||||
assert nested_agent.cancelled
|
assert nested_agent.cancelled
|
||||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||||
assert long_running_agent.called is False
|
assert long_running_agent.called is False
|
||||||
assert long_running_agent.cancelled is False
|
assert long_running_agent.cancelled is False
|
||||||
|
|
||||||
@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
|
|||||||
await response
|
await response
|
||||||
|
|
||||||
assert response.done()
|
assert response.done()
|
||||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||||
assert nested_agent.called
|
assert nested_agent.called
|
||||||
assert nested_agent.cancelled
|
assert nested_agent.cancelled
|
||||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||||
assert long_running_agent.called
|
assert long_running_agent.called
|
||||||
assert long_running_agent.cancelled
|
assert long_running_agent.cancelled
|
||||||
|
@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
|
|||||||
await run_context.stop()
|
await run_context.stop()
|
||||||
|
|
||||||
assert handler.num_messages == 1
|
assert handler.num_messages == 1
|
||||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||||
assert loopback_agent.num_calls == 1
|
assert loopback_agent.num_calls == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
|
|||||||
|
|
||||||
await run_context.stop()
|
await run_context.stop()
|
||||||
|
|
||||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||||
assert loopback_agent.num_calls == 0
|
assert loopback_agent.num_calls == 0
|
||||||
|
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
|||||||
|
|
||||||
await run_context.stop()
|
await run_context.stop()
|
||||||
|
|
||||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||||
assert long_running_agent.num_calls == 0
|
assert long_running_agent.num_calls == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -115,5 +115,5 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
|||||||
|
|
||||||
await run_context.stop()
|
await run_context.stop()
|
||||||
|
|
||||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||||
assert long_running_agent.num_calls == 1
|
assert long_running_agent.num_calls == 1
|
||||||
|
@ -34,11 +34,11 @@ async def test_register_receives_publish() -> None:
|
|||||||
await run_context.stop_when_idle()
|
await run_context.stop_when_idle()
|
||||||
|
|
||||||
# Agent in default namespace should have received the message
|
# Agent in default namespace should have received the message
|
||||||
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
|
long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent)
|
||||||
assert long_running_agent.num_calls == 1
|
assert long_running_agent.num_calls == 1
|
||||||
|
|
||||||
# Agent in other namespace should not have received the message
|
# Agent in other namespace should not have received the message
|
||||||
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
|
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent)
|
||||||
assert other_long_running_agent.num_calls == 0
|
assert other_long_running_agent.num_calls == 0
|
||||||
|
|
||||||
|
|
||||||
@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
|
|||||||
|
|
||||||
# Check that each agent received the correct number of messages.
|
# Check that each agent received the correct number of messages.
|
||||||
for i in range(num_agents):
|
for i in range(num_agents):
|
||||||
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
|
agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent)
|
||||||
assert agent.num_calls == total_num_calls_expected
|
assert agent.num_calls == total_num_calls_expected
|
||||||
|
@ -29,7 +29,7 @@ async def test_agent_can_save_state() -> None:
|
|||||||
runtime = SingleThreadedAgentRuntime()
|
runtime = SingleThreadedAgentRuntime()
|
||||||
|
|
||||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||||
assert agent1.state == 0
|
assert agent1.state == 0
|
||||||
agent1.state = 1
|
agent1.state = 1
|
||||||
assert agent1.state == 1
|
assert agent1.state == 1
|
||||||
@ -47,7 +47,7 @@ async def test_runtime_can_save_state() -> None:
|
|||||||
runtime = SingleThreadedAgentRuntime()
|
runtime = SingleThreadedAgentRuntime()
|
||||||
|
|
||||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||||
assert agent1.state == 0
|
assert agent1.state == 0
|
||||||
agent1.state = 1
|
agent1.state = 1
|
||||||
assert agent1.state == 1
|
assert agent1.state == 1
|
||||||
@ -56,7 +56,7 @@ async def test_runtime_can_save_state() -> None:
|
|||||||
|
|
||||||
runtime2 = SingleThreadedAgentRuntime()
|
runtime2 = SingleThreadedAgentRuntime()
|
||||||
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
|
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
|
||||||
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
|
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
|
||||||
|
|
||||||
await runtime2.load_state(runtime_state)
|
await runtime2.load_state(runtime_state)
|
||||||
assert agent2.state == 1
|
assert agent2.state == 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user