From 3ba7a48b13f5ccbf097c6a4f9da3f77d365f5eaa Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 23 Jul 2024 16:38:37 -0700 Subject: [PATCH] Implement try_get_underlying_agent_instance (#249) --- .../_single_threaded_agent_runtime.py | 17 ++++++++++++-- python/src/agnext/core/_agent_runtime.py | 22 ++++++++++++++++++- python/src/agnext/core/exceptions.py | 4 ++++ python/src/agnext/worker/worker_runtime.py | 5 +++++ python/tests/test_cancellation.py | 10 ++++----- python/tests/test_intervention.py | 8 +++---- python/tests/test_runtime.py | 6 ++--- python/tests/test_state.py | 6 ++--- 8 files changed, 60 insertions(+), 18 deletions(-) diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 585a2b1ce..c4ed72a2b 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast +from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast from ..core import ( MESSAGE_TYPE_REGISTRY, @@ -482,7 +482,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): return self._instantiated_agents[agent_id] if agent_id.name not in self._agent_factories: - raise ValueError(f"Agent with name {agent_id.name} not found.") + raise LookupError(f"Agent with name {agent_id.name} not found.") agent_factory = self._agent_factories[agent_id.name] @@ -499,6 +499,19 @@ class SingleThreadedAgentRuntime(AgentRuntime): id = await self.get(name, namespace=namespace) return AgentProxy(id, self) + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + if id.name not in self._agent_factories: + raise LookupError(f"Agent with name {id.name} not found.") + + # TODO: check if remote + agent_instance = await self._get_agent(id) + + if not isinstance(agent_instance, type): + raise TypeError(f"Agent with name {id.name} is not of type {type.__name__}") + + return agent_instance + # Hydrate the agent instances in a namespace. The primary reason for this is # to ensure message type subscriptions are set up. async def _process_seen_namespace(self, namespace: str) -> None: diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index 1d423180e..a37772138 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable +from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable from ._agent import Agent from ._agent_id import AgentId @@ -146,6 +146,26 @@ class AgentRuntime(Protocol): """ ... + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + """Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases. + + If the underlying agent is not accessible, this will raise an exception. + + Args: + id (AgentId): The agent id. + type (Type[T], optional): The expected type of the agent. Defaults to Agent. + + Returns: + T: The concrete agent instance. + + Raises: + LookupError: If the agent is not found. + NotAccessibleError: If the agent is not accessible, for example if it is located remotely. + TypeError: If the agent is not of the expected type. + """ + ... + @overload async def register_and_get( self, diff --git a/python/src/agnext/core/exceptions.py b/python/src/agnext/core/exceptions.py index 4c4a70ded..f35c3fbfe 100644 --- a/python/src/agnext/core/exceptions.py +++ b/python/src/agnext/core/exceptions.py @@ -15,3 +15,7 @@ class UndeliverableException(Exception): class MessageDroppedException(Exception): """Raised when a message is dropped.""" + + +class NotAccessibleError(Exception): + """Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally.""" diff --git a/python/src/agnext/worker/worker_runtime.py b/python/src/agnext/worker/worker_runtime.py index e80102ce4..8d5f6810f 100644 --- a/python/src/agnext/worker/worker_runtime.py +++ b/python/src/agnext/worker/worker_runtime.py @@ -21,6 +21,7 @@ from typing import ( Mapping, ParamSpec, Set, + Type, TypeVar, cast, ) @@ -410,6 +411,10 @@ class WorkerAgentRuntime(AgentRuntime): id = await self.get(name, namespace=namespace) return AgentProxy(id, self) + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.") + # Hydrate the agent instances in a namespace. The primary reason for this is # to ensure message type subscriptions are set up. async def _process_seen_namespace(self, namespace: str) -> None: diff --git a/python/tests/test_cancellation.py b/python/tests/test_cancellation.py index f383d8de7..a077a2c9b 100644 --- a/python/tests/test_cancellation.py +++ b/python/tests/test_cancellation.py @@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None: await response assert response.done() - long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled @@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None: await response assert response.done() - nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore + nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) assert nested_agent.called assert nested_agent.cancelled - long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) assert long_running_agent.called is False assert long_running_agent.cancelled is False @@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None: await response assert response.done() - nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore + nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) assert nested_agent.called assert nested_agent.cancelled - long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled diff --git a/python/tests/test_intervention.py b/python/tests/test_intervention.py index 395305b22..62942bbb1 100644 --- a/python/tests/test_intervention.py +++ b/python/tests/test_intervention.py @@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None: await run_context.stop() assert handler.num_messages == 1 - loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore + loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert loopback_agent.num_calls == 1 @pytest.mark.asyncio @@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None: await run_context.stop() - loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore + loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert loopback_agent.num_calls == 0 @@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None: await run_context.stop() - long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent) assert long_running_agent.num_calls == 0 @pytest.mark.asyncio @@ -115,5 +115,5 @@ async def test_intervention_raise_exception_on_respond() -> None: await run_context.stop() - long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent) assert long_running_agent.num_calls == 1 diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index 8cef7b16c..d7e4f1a91 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -34,11 +34,11 @@ async def test_register_receives_publish() -> None: await run_context.stop_when_idle() # Agent in default namespace should have received the message - long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore + long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent) assert long_running_agent.num_calls == 1 # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore + other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent) assert other_long_running_agent.num_calls == 0 @@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None: # Check that each agent received the correct number of messages. for i in range(num_agents): - agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore + agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent) assert agent.num_calls == total_num_calls_expected diff --git a/python/tests/test_state.py b/python/tests/test_state.py index b73e244bc..b88f79b20 100644 --- a/python/tests/test_state.py +++ b/python/tests/test_state.py @@ -29,7 +29,7 @@ async def test_agent_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() agent1_id = await runtime.register_and_get("name1", StatefulAgent) - agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore + agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 @@ -47,7 +47,7 @@ async def test_runtime_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() agent1_id = await runtime.register_and_get("name1", StatefulAgent) - agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore + agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 @@ -56,7 +56,7 @@ async def test_runtime_can_save_state() -> None: runtime2 = SingleThreadedAgentRuntime() agent2_id = await runtime2.register_and_get("name1", StatefulAgent) - agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore + agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent) await runtime2.load_state(runtime_state) assert agent2.state == 1