mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-15 20:21:10 +00:00
Implement try_get_underlying_agent_instance (#249)
This commit is contained in:
parent
a52d3bab53
commit
3ba7a48b13
@ -9,7 +9,7 @@ from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from ..core import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
@ -482,7 +482,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
||||
raise LookupError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.name]
|
||||
|
||||
@ -499,6 +499,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
id = await self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
if id.name not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {id.name} not found.")
|
||||
|
||||
# TODO: check if remote
|
||||
agent_instance = await self._get_agent(id)
|
||||
|
||||
if not isinstance(agent_instance, type):
|
||||
raise TypeError(f"Agent with name {id.name} is not of type {type.__name__}")
|
||||
|
||||
return agent_instance
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@ -146,6 +146,26 @@ class AgentRuntime(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||
|
||||
If the underlying agent is not accessible, this will raise an exception.
|
||||
|
||||
Args:
|
||||
id (AgentId): The agent id.
|
||||
type (Type[T], optional): The expected type of the agent. Defaults to Agent.
|
||||
|
||||
Returns:
|
||||
T: The concrete agent instance.
|
||||
|
||||
Raises:
|
||||
LookupError: If the agent is not found.
|
||||
NotAccessibleError: If the agent is not accessible, for example if it is located remotely.
|
||||
TypeError: If the agent is not of the expected type.
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def register_and_get(
|
||||
self,
|
||||
|
@ -15,3 +15,7 @@ class UndeliverableException(Exception):
|
||||
|
||||
class MessageDroppedException(Exception):
|
||||
"""Raised when a message is dropped."""
|
||||
|
||||
|
||||
class NotAccessibleError(Exception):
|
||||
"""Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally."""
|
||||
|
@ -21,6 +21,7 @@ from typing import (
|
||||
Mapping,
|
||||
ParamSpec,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
@ -410,6 +411,10 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
id = await self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||
|
@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
assert long_running_agent.called is False
|
||||
assert long_running_agent.cancelled is False
|
||||
|
||||
@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
|
||||
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
|
||||
await run_context.stop()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
||||
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||
assert loopback_agent.num_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
|
||||
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||
assert loopback_agent.num_calls == 0
|
||||
|
||||
|
||||
@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -115,5 +115,5 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
@ -34,11 +34,11 @@ async def test_register_receives_publish() -> None:
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
|
||||
agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
@ -29,7 +29,7 @@ async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
@ -47,7 +47,7 @@ async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
@ -56,7 +56,7 @@ async def test_runtime_can_save_state() -> None:
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
|
||||
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
|
||||
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
|
||||
|
||||
await runtime2.load_state(runtime_state)
|
||||
assert agent2.state == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user