mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 07:29:54 +00:00
Add ability to register Agent instances (#6131)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? Nice to have functionality ## Related issue number Closes #6060 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
c26d894c34
commit
9118f9b998
@ -1,9 +1,13 @@
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._message_context import MessageContext
|
||||
|
||||
# Forward declaration for type checking only
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@ -17,6 +21,15 @@ class Agent(Protocol):
|
||||
"""ID of the agent."""
|
||||
...
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
|
||||
"""Function used to bind an Agent instance to an `AgentRuntime`.
|
||||
|
||||
Args:
|
||||
agent_id (AgentId): ID of the agent.
|
||||
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
|
||||
"""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
|
||||
|
||||
|
||||
@ -118,3 +118,9 @@ class AgentInstantiationContext:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def is_in_factory_call(cls) -> bool:
|
||||
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -130,6 +130,60 @@ class AgentRuntime(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_agent_instance(
|
||||
self,
|
||||
agent_instance: Agent,
|
||||
agent_id: AgentId,
|
||||
) -> AgentId:
|
||||
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
|
||||
from autogen_core.models import UserMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("My core agent")
|
||||
|
||||
@event
|
||||
async def handler(self, message: UserMessage, context: MessageContext) -> None:
|
||||
print("Event received: ", message.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime: AgentRuntime = ... # type: ignore
|
||||
agent = MyAgent()
|
||||
await runtime.register_agent_instance(
|
||||
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
|
||||
)
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
agent_instance (Agent): A concrete instance of the agent.
|
||||
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||
|
||||
@ -21,6 +21,7 @@ from ._subscription import Subscription, UnboundSubscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
from ._type_prefix_subscription import TypePrefixSubscription
|
||||
from ._type_subscription import TypeSubscription
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
@ -82,20 +83,25 @@ class BaseAgent(ABC, Agent):
|
||||
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
try:
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
) from e
|
||||
|
||||
self._runtime: AgentRuntime = runtime
|
||||
self._id: AgentId = id
|
||||
if AgentInstantiationContext.is_in_factory_call():
|
||||
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
|
||||
self._id = AgentInstantiationContext.current_agent_id()
|
||||
if not isinstance(description, str):
|
||||
raise ValueError("Agent description must be a string")
|
||||
self._description = description
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
|
||||
if hasattr(self, "_id"):
|
||||
if self._id != id:
|
||||
raise RuntimeError("Agent is already bound to a different ID")
|
||||
|
||||
if hasattr(self, "_runtime"):
|
||||
if self._runtime != runtime:
|
||||
raise RuntimeError("Agent is already bound to a different runtime")
|
||||
|
||||
self._id = id
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.id.type
|
||||
@ -155,6 +161,56 @@ class BaseAgent(ABC, Agent):
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_instance(
|
||||
self,
|
||||
runtime: AgentRuntime,
|
||||
agent_id: AgentId,
|
||||
*,
|
||||
skip_class_subscriptions: bool = True,
|
||||
skip_direct_message_subscription: bool = False,
|
||||
) -> AgentId:
|
||||
"""
|
||||
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
|
||||
"""
|
||||
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)
|
||||
|
||||
id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
|
||||
await runtime.add_subscription(id_subscription)
|
||||
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in self._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
if not skip_direct_message_subscription:
|
||||
# Additionally adds a special prefix subscription for this agent to receive direct messages
|
||||
try:
|
||||
await runtime.add_subscription(
|
||||
TypePrefixSubscription(
|
||||
# The prefix MUST include ":" to avoid collisions with other agents
|
||||
topic_type_prefix=agent_id.type + ":",
|
||||
agent_type=agent_id.type,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
# We don't care if the subscription already exists
|
||||
pass
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in self._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_id
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls,
|
||||
|
||||
@ -266,6 +266,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
|
||||
self._background_exception: BaseException | None = None
|
||||
self._agent_instance_types: Dict[str, Type[Agent]] = {}
|
||||
|
||||
@property
|
||||
def unprocessed_messages_count(
|
||||
@ -909,6 +910,32 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
return type
|
||||
|
||||
async def register_agent_instance(
|
||||
self,
|
||||
agent_instance: Agent,
|
||||
agent_id: AgentId,
|
||||
) -> AgentId:
|
||||
def agent_factory() -> Agent:
|
||||
raise RuntimeError(
|
||||
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
|
||||
)
|
||||
|
||||
if agent_id in self._instantiated_agents:
|
||||
raise ValueError(f"Agent with id {agent_id} already exists.")
|
||||
|
||||
if agent_id.type not in self._agent_factories:
|
||||
self._agent_factories[agent_id.type] = agent_factory
|
||||
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
|
||||
else:
|
||||
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
|
||||
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
|
||||
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
|
||||
raise ValueError("Agent instances must be the same object type.")
|
||||
|
||||
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
|
||||
self._instantiated_agents[agent_id] = agent_instance
|
||||
return agent_id
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
@ -930,8 +957,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
agent = cast(T, await agent)
|
||||
return agent
|
||||
|
||||
except BaseException as e:
|
||||
|
||||
@ -9,7 +9,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
|
||||
runtime = mocker.Mock(spec=AgentRuntime)
|
||||
|
||||
# Shows how to set the context for the agent instantiation in a test context
|
||||
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
|
||||
agent = NoopAgent()
|
||||
assert agent.runtime == runtime
|
||||
assert agent.id == AgentId("name", "namespace")
|
||||
with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))):
|
||||
agent2 = NoopAgent()
|
||||
assert agent2.runtime == runtime
|
||||
assert agent2.id == AgentId("name2", "namespace2")
|
||||
|
||||
@ -82,6 +82,60 @@ async def test_agent_type_must_be_unique() -> None:
|
||||
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_instance() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent1_id = AgentId(type="name", key="default")
|
||||
agent2_id = AgentId(type="name", key="notdefault")
|
||||
agent1 = NoopAgent()
|
||||
agent1_dup = NoopAgent()
|
||||
agent2 = NoopAgent()
|
||||
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
|
||||
await agent2.register_instance(runtime=runtime, agent_id=agent2_id)
|
||||
|
||||
assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
|
||||
assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
|
||||
with pytest.raises(ValueError):
|
||||
await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_instance_different_types() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent_id1 = AgentId(type="name", key="noop")
|
||||
agent_id2 = AgentId(type="name", key="loopback")
|
||||
agent1 = NoopAgent()
|
||||
agent2 = LoopbackAgent()
|
||||
await agent1.register_instance(runtime=runtime, agent_id=agent_id1)
|
||||
with pytest.raises(ValueError):
|
||||
await agent2.register_instance(runtime=runtime, agent_id=agent_id2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_instance_publish_new_source() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
agent_id = AgentId(type="name", key="default")
|
||||
agent1 = LoopbackAgent()
|
||||
await agent1.register_instance(runtime=runtime, agent_id=agent_id)
|
||||
await runtime.add_subscription(TypeSubscription("notdefault", "name"))
|
||||
|
||||
runtime.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
|
||||
await runtime.stop_when_idle()
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_instance_factory() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent1_id = AgentId(type="name", key="default")
|
||||
agent1 = NoopAgent()
|
||||
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
|
||||
with pytest.raises(ValueError):
|
||||
await NoopAgent.register(runtime, "name", lambda: NoopAgent())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
|
||||
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
|
||||
|
||||
@ -251,6 +251,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
self._extra_grpc_config = extra_grpc_config or []
|
||||
self._agent_instance_types: Dict[str, Type[Agent]] = {}
|
||||
|
||||
if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
|
||||
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
|
||||
@ -701,6 +702,14 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
except BaseException as e:
|
||||
logger.error("Error handling event", exc_info=e)
|
||||
|
||||
async def _register_agent_type(self, agent_type: str) -> None:
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type)
|
||||
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
|
||||
message, metadata=self._host_connection.metadata
|
||||
)
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
type: str | AgentType,
|
||||
@ -729,14 +738,38 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
return agent_instance
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
# Send the registration request message to the host.
|
||||
message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type)
|
||||
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
|
||||
message, metadata=self._host_connection.metadata
|
||||
)
|
||||
await self._register_agent_type(type.type)
|
||||
|
||||
return type
|
||||
|
||||
async def register_agent_instance(
|
||||
self,
|
||||
agent_instance: Agent,
|
||||
agent_id: AgentId,
|
||||
) -> AgentId:
|
||||
def agent_factory() -> Agent:
|
||||
raise RuntimeError(
|
||||
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
|
||||
)
|
||||
|
||||
if agent_id in self._instantiated_agents:
|
||||
raise ValueError(f"Agent with id {agent_id} already exists.")
|
||||
|
||||
if agent_id.type not in self._agent_factories:
|
||||
self._agent_factories[agent_id.type] = agent_factory
|
||||
await self._register_agent_type(agent_id.type)
|
||||
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
|
||||
else:
|
||||
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
|
||||
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
|
||||
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
|
||||
raise ValueError("Agent instances must be the same object type.")
|
||||
|
||||
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
|
||||
self._instantiated_agents[agent_id] = agent_instance
|
||||
return agent_id
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
@ -757,7 +790,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
agent = cast(T, await agent)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, List, Type, Union
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken, FunctionCall, Image
|
||||
@ -570,7 +570,7 @@ def thought_with_tool_call_stream_client(monkeypatch: pytest.MonkeyPatch) -> Azu
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.close = MagicMock()
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
async def mock_complete(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("stream", False):
|
||||
|
||||
@ -577,6 +577,139 @@ async def test_grpc_max_message_size() -> None:
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_instance() -> None:
|
||||
host_address = "localhost:50051"
|
||||
agent1_id = AgentId(type="name", key="default")
|
||||
agentdup_id = AgentId(type="name", key="default")
|
||||
agent2_id = AgentId(type="name", key="notdefault")
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
agent1 = NoopAgent()
|
||||
agent2 = NoopAgent()
|
||||
agentdup = NoopAgent()
|
||||
await worker.start()
|
||||
|
||||
await worker.register_agent_instance(agent1, agent_id=agent1_id)
|
||||
await worker.register_agent_instance(agent2, agent_id=agent2_id)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await worker.register_agent_instance(agentdup, agent_id=agentdup_id)
|
||||
|
||||
assert await worker.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
|
||||
assert await worker.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_instance_different_types() -> None:
|
||||
host_address = "localhost:50051"
|
||||
agent1_id = AgentId(type="name", key="noop")
|
||||
agent2_id = AgentId(type="name", key="loopback")
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
agent1 = NoopAgent()
|
||||
agent2 = LoopbackAgent()
|
||||
await worker.start()
|
||||
|
||||
await worker.register_agent_instance(agent1, agent_id=agent1_id)
|
||||
with pytest.raises(ValueError):
|
||||
await worker.register_agent_instance(agent2, agent_id=agent2_id)
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_instance_factory() -> None:
|
||||
host_address = "localhost:50051"
|
||||
agent1_id = AgentId(type="name", key="default")
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
agent1 = NoopAgent()
|
||||
await worker.start()
|
||||
|
||||
await agent1.register_instance(runtime=worker, agent_id=agent1_id)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await NoopAgent.register(runtime=worker, type="name", factory=lambda: NoopAgent())
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
async def test_instance_factory_messaging() -> None:
|
||||
host_address = "localhost:50051"
|
||||
loopback_agent_id = AgentId(type="dm_agent", key="dm_agent")
|
||||
cascading_agent_id = AgentId(type="instance_agent", key="instance_agent")
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
cascading_agent = CascadingAgent(max_rounds=5)
|
||||
loopback_agent = LoopbackAgent()
|
||||
await worker.start()
|
||||
|
||||
await loopback_agent.register_instance(worker, agent_id=loopback_agent_id)
|
||||
resp = await worker.send_message(message=ContentMessage(content="Hello!"), recipient=loopback_agent_id)
|
||||
assert resp == ContentMessage(content="Hello!")
|
||||
|
||||
await cascading_agent.register_instance(worker, agent_id=cascading_agent_id)
|
||||
await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5))
|
||||
|
||||
# instance_agent will publish a message that factory_agent will pick up
|
||||
for i in range(5):
|
||||
await worker.publish_message(
|
||||
CascadingMessageType(round=i + 1), TopicId(type="instance_agent", source="instance_agent")
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
agent = await worker.try_get_underlying_agent_instance(AgentId("factory_agent", "default"), CascadingAgent)
|
||||
assert agent.num_calls == 4
|
||||
assert cascading_agent.num_calls == 5
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
# GrpcWorkerAgentRuntimeHost eats exceptions in the main loop
|
||||
# @pytest.mark.grpc
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_agent_type_register_instance_publish_new_source() -> None:
|
||||
# host_address = "localhost:50056"
|
||||
# agent_id = AgentId(type="name", key="default")
|
||||
# agent1 = LoopbackAgent()
|
||||
# host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
# host.start()
|
||||
# worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
# await worker.start()
|
||||
# publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
# publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
# await publisher.start()
|
||||
|
||||
# await agent1.register_instance(worker, agent_id=agent_id)
|
||||
# await worker.add_subscription(TypeSubscription("notdefault", "name"))
|
||||
|
||||
# with pytest.raises(RuntimeError):
|
||||
# await worker.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
|
||||
# await asyncio.sleep(2)
|
||||
|
||||
# await worker.stop()
|
||||
# await host.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["GRPC_VERBOSITY"] = "DEBUG"
|
||||
os.environ["GRPC_TRACE"] = "all"
|
||||
|
||||
@ -176,7 +176,7 @@ async def test_invalid_request(test_config: ComponentModel, test_server: None) -
|
||||
config.config["host"] = "fake"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)):
|
||||
await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user