Update send_message to be a single async operation. Add start helper to runtime to manage this (#165)

This commit is contained in:
Jack Gerrits 2024-07-01 11:53:45 -04:00 committed by GitHub
parent 28f11c726d
commit 766635394a
29 changed files with 170 additions and 124 deletions

View File

@ -93,10 +93,11 @@ The local embedded runtime {py:class}`~agnext.application.SingleThreadedAgentRun
can be called to process messages until there are no more messages to process.
```python
await runtime.process_until_idle()
run_context = runtime.start()
await run_context.stop_when_idle()
```
It can also be called to process a single message:
`runtime.start()` will start a background task to process messages. You can directly process messages without a background task using:
```python
await runtime.process_next()
@ -157,7 +158,7 @@ class MyAgent(TypeRoutedAgent):
@message_handler
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
print(f"Hello, {message.source}, you said {message.content}!")
@message_handler
async def on_image_message(self, message: ImageMessage, cancellation_token: CancellationToken) -> None:
print(f"Hello, {message.source}, you sent me {message.url}!")
@ -165,9 +166,10 @@ class MyAgent(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
agent = runtime.register_and_get("my_agent", lambda: MyAgent("My Agent"))
run_context = runtime.start()
await runtime.send_message(TextMessage(content="Hello, World!", source="User"), agent)
await runtime.send_message(ImageMessage(url="https://example.com/image.jpg", source="User"), agent)
await runtime.process_until_idle()
await run_context.stop_when_idle()
import asyncio
asyncio.run(main())
@ -189,7 +191,7 @@ Awaiting calls to these methods will return the return value of the
receiving agent's message handler.
```{note}
If the invoked agent raises an exception while the sender is awaiting,
If the invoked agent raises an exception while the sender is awaiting,
the exception will be propagated back to the sender.
```
@ -228,8 +230,9 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
inner = runtime.register_and_get("inner_agent", lambda: InnerAgent("InnerAgent"))
outer = runtime.register_and_get("outer_agent", lambda: OuterAgent("OuterAgent", inner))
run_context = runtime.start()
await runtime.send_message("Hello, World!", outer)
await runtime.process_until_idle()
await run_context.stop_when_idle()
import asyncio
asyncio.run(main())
@ -244,11 +247,6 @@ Received message: Hello, World!
Received inner response: Hello from inner, Hello from outer, Hello, World!
```
```{note}
To get the response after sending a message, the sender must await on the
response future. So you can also write `response = await await self.send_message(...)`.
```
#### Command/Notification
In many scenarios, an agent can commanded another agent to perform an action,
@ -278,7 +276,7 @@ When an agent publishes a message it is one way only, it cannot receive a respon
from any other agent, even if a receiving agent sends a response.
```{note}
An agent receiving a message does not know if it is handling a published or direct message.
An agent receiving a message does not know if it is handling a published or direct message.
So, if a response is given to a published message, it will be thrown away.
```
@ -313,8 +311,9 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
broadcaster = runtime.register_and_get("broadcasting_agent", lambda: BroadcastingAgent("Broadcasting Agent"))
runtime.register("receiving_agent", lambda: ReceivingAgent("Receiving Agent"))
run_context = runtime.start()
await runtime.send_message("Hello, World!", broadcaster)
await runtime.process_until_idle()
await run_context.stop_when_idle()
import asyncio
asyncio.run(main())

View File

@ -47,12 +47,12 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
inner = runtime.register_and_get("inner", Inner)
outer = runtime.register_and_get("outer", lambda: Outer(inner))
run_context = runtime.start()
response = await runtime.send_message(MessageType(body="Hello", sender="external"), outer)
while not response.done():
await runtime.process_next()
print(await response)
print(response)
await run_context.stop()
if __name__ == "__main__":

View File

@ -50,19 +50,19 @@ async def main() -> None:
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-3.5-turbo")),
)
run_context = runtime.start()
# Send a message to the agent.
message = Message(content="Can you tell me something fun about SF?")
result = await runtime.send_message(message, agent)
# Process messages until the agent responds.
while result.done() is False:
await runtime.process_next()
# Get the response from the agent.
response = await result
assert isinstance(response, Message)
print(response.content)
await run_context.stop()
if __name__ == "__main__":
import logging

View File

@ -100,12 +100,14 @@ async def main() -> None:
),
)
run_context = runtime.start()
# Send a message to Jack to start the conversation.
message = Message(content="Can you tell me something fun about SF?", source="User")
await runtime.send_message(message, jack)
# Process messages.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -226,13 +226,12 @@ Type "exit" to exit the chat.
"""
runtime = SingleThreadedAgentRuntime()
user = assistant_chat(runtime)
_run_context = runtime.start()
print(usage)
# Request the user to start the conversation.
await runtime.send_message(PublishNow(), user)
while True:
# TODO: have a way to exit the loop.
await runtime.process_next()
await asyncio.sleep(1)
# TODO: have a way to exit the loop.
if __name__ == "__main__":

View File

@ -17,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from common.memory import BufferedChatMemory
from common.types import Message, TextMessage
from common.utils import convert_messages_to_llm_messages, get_chat_completion_client_from_envs
from utils import TextualChatApp, TextualUserAgent, start_runtime
from utils import TextualChatApp, TextualUserAgent
# Define a custom agent that can handle chat room messages.
@ -140,7 +140,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
app = TextualChatApp(runtime, user_name="You")
chat_room(runtime, app)
asyncio.create_task(start_runtime(runtime))
_run_context = runtime.start()
await app.run_async()

View File

@ -16,7 +16,7 @@ from common.agents import ChatCompletionAgent, ImageGenerationAgent
from common.memory import BufferedChatMemory
from common.patterns._group_chat_manager import GroupChatManager
from common.utils import get_chat_completion_client_from_envs
from utils import TextualChatApp, TextualUserAgent, start_runtime
from utils import TextualChatApp, TextualUserAgent
def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
@ -98,7 +98,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
app = TextualChatApp(runtime, user_name="You")
illustrator_critics(runtime, app)
asyncio.create_task(start_runtime(runtime))
_run_context = runtime.start()
await app.run_async()

View File

@ -31,7 +31,7 @@ from common.agents import ChatCompletionAgent
from common.memory import HeadAndTailChatMemory
from common.patterns._group_chat_manager import GroupChatManager
from common.utils import get_chat_completion_client_from_envs
from utils import TextualChatApp, TextualUserAgent, start_runtime
from utils import TextualChatApp, TextualUserAgent
async def write_file(filename: str, content: str) -> str:
@ -281,7 +281,7 @@ async def main() -> None:
app = TextualChatApp(runtime, user_name="You")
software_consultancy(runtime, app)
# Start the runtime.
asyncio.create_task(start_runtime(runtime))
_run_context = runtime.start()
# Start the app.
await app.run_async()

View File

@ -4,7 +4,6 @@ import random
import sys
from asyncio import Future
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import Image, TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from textual.app import App, ComposeResult
@ -189,9 +188,3 @@ class TextualUserAgent(TypeRoutedAgent): # type: ignore
self, message: ToolApprovalRequest, cancellation_token: CancellationToken
) -> ToolApprovalResponse:
return await self._app.handle_tool_approval_request(message)
async def start_runtime(runtime: SingleThreadedAgentRuntime) -> None: # type: ignore
"""Run the runtime in a loop."""
while True:
await runtime.process_next()

View File

@ -182,12 +182,12 @@ async def main(task: str, temp_dir: str) -> None:
# Register the agents.
runtime.register("coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")))
runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
run_context = runtime.start()
# Publish the task message.
await runtime.publish_message(TaskMessage(content=task), namespace="default")
# Run the runtime until no more message.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -265,6 +265,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
),
)
run_context = runtime.start()
await runtime.publish_message(
message=CodeWritingTask(
task="Write a function to find the directory with the largest number of files using multi-processing."
@ -273,7 +274,7 @@ async def main() -> None:
)
# Keep processing messages until idle.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -148,11 +148,13 @@ async def main() -> None:
),
)
# Start the runtime.
run_context = runtime.start()
# Start the conversation.
await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")
# Run the runtime.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -149,10 +149,11 @@ async def main() -> None:
num_references=3,
),
)
run_context = runtime.start()
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
# Keep processing messages.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -246,11 +246,12 @@ async def main(question: str) -> None:
# Register the aggregator agent.
runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
run_context = runtime.start()
# Send a math problem to the aggregator agent.
await runtime.publish_message(Question(content=question), namespace="default")
# Run the runtime.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -130,10 +130,10 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore
runtime = SingleThreadedAgentRuntime()
chat = scenario(runtime)
run_context = runtime.start()
response = await runtime.send_message(TextMessage(content=message, source=user), chat.id)
while not response.done():
await runtime.process_next()
print((await response).content) # type: ignore
await run_context.stop()
if __name__ == "__main__":

View File

@ -140,14 +140,15 @@ async def main() -> None:
),
)
run_context = runtime.start()
# Send a task to the tool user.
result = await runtime.send_message(
UserRequest("Run the following Python code: print('Hello, World!')"), tool_agent
)
# Run the runtime until the task is completed.
while not result.done():
await runtime.process_next()
await run_context.stop()
# Print the result.
ai_response = result.result()

View File

@ -202,13 +202,14 @@ async def main() -> None:
),
)
run_context = runtime.start()
# Publish a task.
await runtime.publish_message(
UserRequest("Run the following Python code: print('Hello, World!')"), namespace="default"
)
# Run the runtime.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -49,12 +49,13 @@ async def main() -> None:
),
)
run_context = runtime.start()
# Send a task to the tool user.
result = await runtime.send_message(UserRequest("What is the stock price of NVDA on 2024/06/01"), tool_agent)
# Run the runtime until the task is completed.
while not result.done():
await runtime.process_next()
await run_context.stop()
# Print the result.
ai_response = result.result()

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
import inspect
import logging
import threading
from asyncio import CancelledError, Future
from asyncio import CancelledError, Future, Task
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
from ..core import (
@ -80,6 +83,40 @@ class Counter:
self.threadLock.release()
class RunContext:
class RunState(Enum):
RUNNING = 0
CANCELLED = 1
UNTIL_IDLE = 2
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_state = RunContext.RunState.RUNNING
self._run_task = asyncio.create_task(self._run())
self._lock = asyncio.Lock()
async def _run(self) -> None:
while True:
async with self._lock:
if self._run_state == RunContext.RunState.CANCELLED:
return
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
if self._runtime.idle:
return
await self._runtime.process_next()
async def stop(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.CANCELLED
await self._run_task
async def stop_when_idle(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.UNTIL_IDLE
await self._run_task
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
@ -90,6 +127,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
@property
def unprocessed_messages(
@ -113,7 +151,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
@ -149,7 +187,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
)
)
return future
cancellation_token.link_future(future)
return await future
async def publish_message(
self,
@ -334,7 +374,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_send(message_envelope))
task = asyncio.create_task(self._process_send(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case PublishMessageEnvelope(
message=message,
sender=sender,
@ -352,7 +394,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_publish(message_envelope))
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handler is not None:
try:
@ -369,16 +413,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_response(message_envelope))
task = asyncio.create_task(self._process_response(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
async def process_until_idle(self) -> None:
"""Process messages until there is no unprocessed message and no message currently being processed."""
@property
def idle(self) -> bool:
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
while len(self.unprocessed_messages) > 0 or self.outstanding_tasks > 0:
await self.process_next()
def start(self) -> RunContext:
return RunContext(self)
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return self._get_agent(agent).metadata

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from asyncio import Future
from typing import TYPE_CHECKING, Any, Mapping
from ._agent_id import AgentId
@ -32,7 +31,7 @@ class AgentProxy:
*,
sender: AgentId,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
) -> Any:
return await self._runtime.send_message(
message,
recipient=self._agent,

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from asyncio import Future
from contextvars import ContextVar
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
@ -19,8 +18,6 @@ agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextV
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
# Can raise CantHandleException
async def send_message(
self,
message: Any,
@ -28,17 +25,8 @@ class AgentRuntime(Protocol):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
"""Send a message to an agent and return a future that will resolve to the response.
The act of sending a message may be asynchronous, and the response to the message itself is also asynchronous. For example:
.. code-block:: python
response_future = await runtime.send_message(MyMessage("Hello"), recipient=agent_id)
response = await response_future
The returned future only needs to be awaited if the response is needed. If the response is not needed, the future can be ignored.
) -> Any:
"""Send a message to an agent and get a response.
Args:
message (Any): The message to send.
@ -49,14 +37,14 @@ class AgentRuntime(Protocol):
Raises:
CantHandleException: If the recipient cannot handle the message.
UndeliverableException: If the message cannot be delivered.
Other: Any other exception raised by the recipient.
Returns:
Future[Any]: A future that will resolve to the response of the message.
Any: The response from the agent.
"""
...
# No responses from publishing
async def publish_message(
self,
message: Any,

View File

@ -1,6 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from asyncio import Future
from typing import Any, Mapping, Sequence
from ._agent import Agent
@ -55,19 +54,17 @@ class BaseAgent(ABC, Agent):
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
) -> Any:
"""See :py:meth:`agnext.core.AgentRuntime.send_message` for more information."""
if cancellation_token is None:
cancellation_token = CancellationToken()
future = await self._runtime.send_message(
return await self._runtime.send_message(
message,
sender=self.id,
recipient=recipient,
cancellation_token=cancellation_token,
)
cancellation_token.link_future(future)
return future
async def publish_message(
self,

View File

@ -24,12 +24,14 @@ async def main() -> None:
task = input("Enter a task: ")
run_context = runtime.start()
await runtime.publish_message(
BroadcastMessage(content=UserMessage(content=task, source="human")), namespace="default"
)
# Run the runtime until the task is completed.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -18,12 +18,14 @@ async def main() -> None:
task = input(f"Enter a task for {file_surfer.name}: ")
msg = BroadcastMessage(content=UserMessage(content=task, source="human"))
run_context = runtime.start()
# Send a task to the tool user.
await runtime.publish_message(msg, namespace="default")
await runtime.publish_message(RequestReplyMessage(), namespace="default")
# Run the runtime until the task is completed.
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -16,9 +16,10 @@ async def main() -> None:
runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
task_message = UserMessage(content="Test Message", source="User")
run_context = runtime.start()
await runtime.publish_message(BroadcastMessage(task_message), namespace="default")
await runtime.process_until_idle()
await run_context.stop_when_idle()
if __name__ == "__main__":

View File

@ -43,7 +43,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
@message_handler
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
response = await self.send_message(message, self._nested_agent, cancellation_token=cancellation_token)
response = self.send_message(message, self._nested_agent, cancellation_token=cancellation_token)
try:
val = await response
assert isinstance(val, MessageType)
@ -59,10 +59,14 @@ async def test_cancellation_with_token() -> None:
long_running = runtime.register_and_get("long_running", LongRunningAgent)
token = CancellationToken()
response = await runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
await asyncio.sleep(0.01)
await runtime.process_next()
token.cancel()
with pytest.raises(asyncio.CancelledError):
@ -83,9 +87,12 @@ async def test_nested_cancellation_only_outer_called() -> None:
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = await runtime.send_message(MessageType(), nested, cancellation_token=token)
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
await asyncio.sleep(0.01)
await runtime.process_next()
token.cancel()
@ -108,9 +115,12 @@ async def test_nested_cancellation_inner_called() -> None:
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = await runtime.send_message(MessageType(), nested, cancellation_token=token)
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
assert not response.done()
while len(runtime.unprocessed_messages) == 0:
await asyncio.sleep(0.01)
await runtime.process_next()
# allow the inner agent to process
await runtime.process_next()

View File

@ -29,11 +29,12 @@ async def test_register_receives_publish() -> None:
await queue.put((namespace, message.content))
runtime.register("name", lambda: ClosureAgent("My agent", log_message))
run_context = runtime.start()
await runtime.publish_message(Message("first message"), namespace="default")
await runtime.publish_message(Message("second message"), namespace="default")
await runtime.publish_message(Message("third message"), namespace="default")
await runtime.process_until_idle()
await run_context.stop_when_idle()
assert queue.qsize() == 3
assert queue.get_nowait() == ("default", "first message")

View File

@ -20,11 +20,11 @@ async def test_intervention_count_messages() -> None:
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
response = await runtime.send_message(MessageType(), recipient=loopback)
_response = await runtime.send_message(MessageType(), recipient=loopback)
while not response.done():
await runtime.process_next()
await run_context.stop()
assert handler.num_messages == 1
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
@ -41,13 +41,12 @@ async def test_intervention_drop_send() -> None:
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
response = await runtime.send_message(MessageType(), recipient=loopback)
while not response.done():
await runtime.process_next()
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
await response
_response = await runtime.send_message(MessageType(), recipient=loopback)
await run_context.stop()
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
assert loopback_agent.num_calls == 0
@ -64,13 +63,12 @@ async def test_intervention_drop_response() -> None:
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
response = await runtime.send_message(MessageType(), recipient=loopback)
while not response.done():
await runtime.process_next()
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
await response
_response = await runtime.send_message(MessageType(), recipient=loopback)
await run_context.stop()
@pytest.mark.asyncio
@ -87,13 +85,12 @@ async def test_intervention_raise_exception_on_send() -> None:
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
response = await runtime.send_message(MessageType(), recipient=long_running)
while not response.done():
await runtime.process_next()
run_context = runtime.start()
with pytest.raises(InterventionException):
await response
_response = await runtime.send_message(MessageType(), recipient=long_running)
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 0
@ -112,13 +109,11 @@ async def test_intervention_raise_exception_on_respond() -> None:
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
response = await runtime.send_message(MessageType(), recipient=long_running)
while not response.done():
await runtime.process_next()
run_context = runtime.start()
with pytest.raises(InterventionException):
await response
_response = await runtime.send_message(MessageType(), recipient=long_running)
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 1

View File

@ -28,9 +28,10 @@ async def test_register_receives_publish() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register("name", LoopbackAgent)
run_context = runtime.start()
await runtime.publish_message(MessageType(), namespace="default")
await runtime.process_until_idle()
await run_context.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
@ -54,13 +55,15 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
run_context = runtime.start()
# Publish messages
for _ in range(num_initial_messages):
await runtime.publish_message(CascadingMessageType(round=1), namespace="default")
# Process until idle.
await runtime.process_until_idle()
await run_context.stop_when_idle()
# Check that each agent received the correct number of messages.
for i in range(num_agents):