Implement try_get_underlying_agent_instance (#249)

This commit is contained in:
Jack Gerrits 2024-07-23 16:38:37 -07:00 committed by GitHub
parent a52d3bab53
commit 3ba7a48b13
8 changed files with 60 additions and 18 deletions

View File

@ -9,7 +9,7 @@ from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from ..core import (
MESSAGE_TYPE_REGISTRY,
@ -482,7 +482,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return self._instantiated_agents[agent_id]
if agent_id.name not in self._agent_factories:
raise ValueError(f"Agent with name {agent_id.name} not found.")
raise LookupError(f"Agent with name {agent_id.name} not found.")
agent_factory = self._agent_factories[agent_id.name]
@ -499,6 +499,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
if id.name not in self._agent_factories:
raise LookupError(f"Agent with name {id.name} not found.")
# TODO: check if remote
agent_instance = await self._get_agent(id)
if not isinstance(agent_instance, type):
raise TypeError(f"Agent with name {id.name} is not of type {type.__name__}")
return agent_instance
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
async def _process_seen_namespace(self, namespace: str) -> None:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
@ -146,6 +146,26 @@ class AgentRuntime(Protocol):
"""
...
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
If the underlying agent is not accessible, this will raise an exception.
Args:
id (AgentId): The agent id.
type (Type[T], optional): The expected type of the agent. Defaults to Agent.
Returns:
T: The concrete agent instance.
Raises:
LookupError: If the agent is not found.
NotAccessibleError: If the agent is not accessible, for example if it is located remotely.
TypeError: If the agent is not of the expected type.
"""
...
@overload
async def register_and_get(
self,

View File

@ -15,3 +15,7 @@ class UndeliverableException(Exception):
class MessageDroppedException(Exception):
"""Raised when a message is dropped."""
class NotAccessibleError(Exception):
"""Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally."""

View File

@ -21,6 +21,7 @@ from typing import (
Mapping,
ParamSpec,
Set,
Type,
TypeVar,
cast,
)
@ -410,6 +411,10 @@ class WorkerAgentRuntime(AgentRuntime):
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
async def _process_seen_namespace(self, namespace: str) -> None:

View File

@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
await response
assert response.done()
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled
@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
assert long_running_agent.called is False
assert long_running_agent.cancelled is False
@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled

View File

@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
await run_context.stop()
assert handler.num_messages == 1
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert loopback_agent.num_calls == 1
@pytest.mark.asyncio
@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
await run_context.stop()
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert loopback_agent.num_calls == 0
@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
await run_context.stop()
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
assert long_running_agent.num_calls == 0
@pytest.mark.asyncio
@ -115,5 +115,5 @@ async def test_intervention_raise_exception_on_respond() -> None:
await run_context.stop()
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

View File

@ -34,11 +34,11 @@ async def test_register_receives_publish() -> None:
await run_context.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
# Check that each agent received the correct number of messages.
for i in range(num_agents):
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected

View File

@ -29,7 +29,7 @@ async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
@ -47,7 +47,7 @@ async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
@ -56,7 +56,7 @@ async def test_runtime_can_save_state() -> None:
runtime2 = SingleThreadedAgentRuntime()
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
await runtime2.load_state(runtime_state)
assert agent2.state == 1