mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 22:48:40 +00:00
Update send_message to be a single async operation. Add start helper to runtime to manage this (#165)
This commit is contained in:
parent
28f11c726d
commit
766635394a
@ -93,10 +93,11 @@ The local embedded runtime {py:class}`~agnext.application.SingleThreadedAgentRun
|
||||
can be called to process messages until there are no more messages to process.
|
||||
|
||||
```python
|
||||
await runtime.process_until_idle()
|
||||
run_context = runtime.start()
|
||||
await run_context.stop_when_idle()
|
||||
```
|
||||
|
||||
It can also be called to process a single message:
|
||||
`runtime.start()` will start a background task to process messages. You can directly process messages without a background task using:
|
||||
|
||||
```python
|
||||
await runtime.process_next()
|
||||
@ -157,7 +158,7 @@ class MyAgent(TypeRoutedAgent):
|
||||
@message_handler
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
print(f"Hello, {message.source}, you said {message.content}!")
|
||||
|
||||
|
||||
@message_handler
|
||||
async def on_image_message(self, message: ImageMessage, cancellation_token: CancellationToken) -> None:
|
||||
print(f"Hello, {message.source}, you sent me {message.url}!")
|
||||
@ -165,9 +166,10 @@ class MyAgent(TypeRoutedAgent):
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent = runtime.register_and_get("my_agent", lambda: MyAgent("My Agent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.send_message(TextMessage(content="Hello, World!", source="User"), agent)
|
||||
await runtime.send_message(ImageMessage(url="https://example.com/image.jpg", source="User"), agent)
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
@ -189,7 +191,7 @@ Awaiting calls to these methods will return the return value of the
|
||||
receiving agent's message handler.
|
||||
|
||||
```{note}
|
||||
If the invoked agent raises an exception while the sender is awaiting,
|
||||
If the invoked agent raises an exception while the sender is awaiting,
|
||||
the exception will be propagated back to the sender.
|
||||
```
|
||||
|
||||
@ -228,8 +230,9 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
inner = runtime.register_and_get("inner_agent", lambda: InnerAgent("InnerAgent"))
|
||||
outer = runtime.register_and_get("outer_agent", lambda: OuterAgent("OuterAgent", inner))
|
||||
run_context = runtime.start()
|
||||
await runtime.send_message("Hello, World!", outer)
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
@ -244,11 +247,6 @@ Received message: Hello, World!
|
||||
Received inner response: Hello from inner, Hello from outer, Hello, World!
|
||||
```
|
||||
|
||||
```{note}
|
||||
To get the response after sending a message, the sender must await on the
|
||||
response future. So you can also write `response = await await self.send_message(...)`.
|
||||
```
|
||||
|
||||
#### Command/Notification
|
||||
|
||||
In many scenarios, an agent can commanded another agent to perform an action,
|
||||
@ -278,7 +276,7 @@ When an agent publishes a message it is one way only, it cannot receive a respon
|
||||
from any other agent, even if a receiving agent sends a response.
|
||||
|
||||
```{note}
|
||||
An agent receiving a message does not know if it is handling a published or direct message.
|
||||
An agent receiving a message does not know if it is handling a published or direct message.
|
||||
So, if a response is given to a published message, it will be thrown away.
|
||||
```
|
||||
|
||||
@ -313,8 +311,9 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
broadcaster = runtime.register_and_get("broadcasting_agent", lambda: BroadcastingAgent("Broadcasting Agent"))
|
||||
runtime.register("receiving_agent", lambda: ReceivingAgent("Receiving Agent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.send_message("Hello, World!", broadcaster)
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@ -47,12 +47,12 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
inner = runtime.register_and_get("inner", Inner)
|
||||
outer = runtime.register_and_get("outer", lambda: Outer(inner))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
response = await runtime.send_message(MessageType(body="Hello", sender="external"), outer)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
print(await response)
|
||||
print(response)
|
||||
await run_context.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -50,19 +50,19 @@ async def main() -> None:
|
||||
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-3.5-turbo")),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a message to the agent.
|
||||
message = Message(content="Can you tell me something fun about SF?")
|
||||
result = await runtime.send_message(message, agent)
|
||||
|
||||
# Process messages until the agent responds.
|
||||
while result.done() is False:
|
||||
await runtime.process_next()
|
||||
|
||||
# Get the response from the agent.
|
||||
response = await result
|
||||
assert isinstance(response, Message)
|
||||
print(response.content)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
@ -100,12 +100,14 @@ async def main() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a message to Jack to start the conversation.
|
||||
message = Message(content="Can you tell me something fun about SF?", source="User")
|
||||
await runtime.send_message(message, jack)
|
||||
|
||||
# Process messages.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -226,13 +226,12 @@ Type "exit" to exit the chat.
|
||||
"""
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
user = assistant_chat(runtime)
|
||||
_run_context = runtime.start()
|
||||
print(usage)
|
||||
# Request the user to start the conversation.
|
||||
await runtime.send_message(PublishNow(), user)
|
||||
while True:
|
||||
# TODO: have a way to exit the loop.
|
||||
await runtime.process_next()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# TODO: have a way to exit the loop.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -17,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.types import Message, TextMessage
|
||||
from common.utils import convert_messages_to_llm_messages, get_chat_completion_client_from_envs
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
from utils import TextualChatApp, TextualUserAgent
|
||||
|
||||
|
||||
# Define a custom agent that can handle chat room messages.
|
||||
@ -140,7 +140,7 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
chat_room(runtime, app)
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
_run_context = runtime.start()
|
||||
await app.run_async()
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from common.agents import ChatCompletionAgent, ImageGenerationAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
from utils import TextualChatApp, TextualUserAgent
|
||||
|
||||
|
||||
def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
@ -98,7 +98,7 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
illustrator_critics(runtime, app)
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
_run_context = runtime.start()
|
||||
await app.run_async()
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from common.agents import ChatCompletionAgent
|
||||
from common.memory import HeadAndTailChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
from utils import TextualChatApp, TextualUserAgent
|
||||
|
||||
|
||||
async def write_file(filename: str, content: str) -> str:
|
||||
@ -281,7 +281,7 @@ async def main() -> None:
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
software_consultancy(runtime, app)
|
||||
# Start the runtime.
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
_run_context = runtime.start()
|
||||
# Start the app.
|
||||
await app.run_async()
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import random
|
||||
import sys
|
||||
from asyncio import Future
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import Image, TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
from textual.app import App, ComposeResult
|
||||
@ -189,9 +188,3 @@ class TextualUserAgent(TypeRoutedAgent): # type: ignore
|
||||
self, message: ToolApprovalRequest, cancellation_token: CancellationToken
|
||||
) -> ToolApprovalResponse:
|
||||
return await self._app.handle_tool_approval_request(message)
|
||||
|
||||
|
||||
async def start_runtime(runtime: SingleThreadedAgentRuntime) -> None: # type: ignore
|
||||
"""Run the runtime in a loop."""
|
||||
while True:
|
||||
await runtime.process_next()
|
||||
|
||||
@ -182,12 +182,12 @@ async def main(task: str, temp_dir: str) -> None:
|
||||
# Register the agents.
|
||||
runtime.register("coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")))
|
||||
runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish the task message.
|
||||
await runtime.publish_message(TaskMessage(content=task), namespace="default")
|
||||
|
||||
# Run the runtime until no more message.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -265,6 +265,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
),
|
||||
)
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(
|
||||
message=CodeWritingTask(
|
||||
task="Write a function to find the directory with the largest number of files using multi-processing."
|
||||
@ -273,7 +274,7 @@ async def main() -> None:
|
||||
)
|
||||
|
||||
# Keep processing messages until idle.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -148,11 +148,13 @@ async def main() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
# Start the runtime.
|
||||
run_context = runtime.start()
|
||||
|
||||
# Start the conversation.
|
||||
await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")
|
||||
|
||||
# Run the runtime.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -149,10 +149,11 @@ async def main() -> None:
|
||||
num_references=3,
|
||||
),
|
||||
)
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
|
||||
|
||||
# Keep processing messages.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -246,11 +246,12 @@ async def main(question: str) -> None:
|
||||
# Register the aggregator agent.
|
||||
runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a math problem to the aggregator agent.
|
||||
await runtime.publish_message(Question(content=question), namespace="default")
|
||||
|
||||
# Run the runtime.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -130,10 +130,10 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
|
||||
async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
chat = scenario(runtime)
|
||||
run_context = runtime.start()
|
||||
response = await runtime.send_message(TextMessage(content=message, source=user), chat.id)
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
print((await response).content) # type: ignore
|
||||
await run_context.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -140,14 +140,15 @@ async def main() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
result = await runtime.send_message(
|
||||
UserRequest("Run the following Python code: print('Hello, World!')"), tool_agent
|
||||
)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
while not result.done():
|
||||
await runtime.process_next()
|
||||
await run_context.stop()
|
||||
|
||||
# Print the result.
|
||||
ai_response = result.result()
|
||||
|
||||
@ -202,13 +202,14 @@ async def main() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish a task.
|
||||
await runtime.publish_message(
|
||||
UserRequest("Run the following Python code: print('Hello, World!')"), namespace="default"
|
||||
)
|
||||
|
||||
# Run the runtime.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -49,12 +49,13 @@ async def main() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
result = await runtime.send_message(UserRequest("What is the stock price of NVDA on 2024/06/01"), tool_agent)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
while not result.done():
|
||||
await runtime.process_next()
|
||||
await run_context.stop()
|
||||
|
||||
# Print the result.
|
||||
ai_response = result.result()
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from asyncio import CancelledError, Future
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
||||
|
||||
from ..core import (
|
||||
@ -80,6 +83,40 @@ class Counter:
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class RunContext:
|
||||
class RunState(Enum):
|
||||
RUNNING = 0
|
||||
CANCELLED = 1
|
||||
UNTIL_IDLE = 2
|
||||
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._run_state == RunContext.RunState.CANCELLED:
|
||||
return
|
||||
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
|
||||
if self._runtime.idle:
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
await self._run_task
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
@ -90,6 +127,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
@ -113,7 +151,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]:
|
||||
) -> Any:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
@ -149,7 +187,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
)
|
||||
)
|
||||
|
||||
return future
|
||||
cancellation_token.link_future(future)
|
||||
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
@ -334,7 +374,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_send(message_envelope))
|
||||
task = asyncio.create_task(self._process_send(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
@ -352,7 +394,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
@ -369,16 +413,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_response(message_envelope))
|
||||
task = asyncio.create_task(self._process_response(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def process_until_idle(self) -> None:
|
||||
"""Process messages until there is no unprocessed message and no message currently being processed."""
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
|
||||
|
||||
while len(self.unprocessed_messages) > 0 or self.outstanding_tasks > 0:
|
||||
await self.process_next()
|
||||
def start(self) -> RunContext:
|
||||
return RunContext(self)
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return self._get_agent(agent).metadata
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from typing import TYPE_CHECKING, Any, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
@ -32,7 +31,7 @@ class AgentProxy:
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
) -> Any:
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
|
||||
@ -19,8 +18,6 @@ agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextV
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
# Returns the response of the message
|
||||
# Can raise CantHandleException
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
@ -28,17 +25,8 @@ class AgentRuntime(Protocol):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
"""Send a message to an agent and return a future that will resolve to the response.
|
||||
|
||||
The act of sending a message may be asynchronous, and the response to the message itself is also asynchronous. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
response_future = await runtime.send_message(MyMessage("Hello"), recipient=agent_id)
|
||||
response = await response_future
|
||||
|
||||
The returned future only needs to be awaited if the response is needed. If the response is not needed, the future can be ignored.
|
||||
) -> Any:
|
||||
"""Send a message to an agent and get a response.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
@ -49,14 +37,14 @@ class AgentRuntime(Protocol):
|
||||
Raises:
|
||||
CantHandleException: If the recipient cannot handle the message.
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
Other: Any other exception raised by the recipient.
|
||||
|
||||
Returns:
|
||||
Future[Any]: A future that will resolve to the response of the message.
|
||||
Any: The response from the agent.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
# No responses from publishing
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from ._agent import Agent
|
||||
@ -55,19 +54,17 @@ class BaseAgent(ABC, Agent):
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
) -> Any:
|
||||
"""See :py:meth:`agnext.core.AgentRuntime.send_message` for more information."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future = await self._runtime.send_message(
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
|
||||
@ -24,12 +24,14 @@ async def main() -> None:
|
||||
|
||||
task = input("Enter a task: ")
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
await runtime.publish_message(
|
||||
BroadcastMessage(content=UserMessage(content=task, source="human")), namespace="default"
|
||||
)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -18,12 +18,14 @@ async def main() -> None:
|
||||
task = input(f"Enter a task for {file_surfer.name}: ")
|
||||
msg = BroadcastMessage(content=UserMessage(content=task, source="human"))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
await runtime.publish_message(msg, namespace="default")
|
||||
await runtime.publish_message(RequestReplyMessage(), namespace="default")
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -16,9 +16,10 @@ async def main() -> None:
|
||||
runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
|
||||
|
||||
task_message = UserMessage(content="Test Message", source="User")
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(BroadcastMessage(task_message), namespace="default")
|
||||
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -43,7 +43,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
@message_handler
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
response = await self.send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||
response = self.send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||
try:
|
||||
val = await response
|
||||
assert isinstance(val, MessageType)
|
||||
@ -59,10 +59,14 @@ async def test_cancellation_with_token() -> None:
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent)
|
||||
token = CancellationToken()
|
||||
response = await runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
@ -83,9 +87,12 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = await runtime.send_message(MessageType(), nested, cancellation_token=token)
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
token.cancel()
|
||||
|
||||
@ -108,9 +115,12 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = await runtime.send_message(MessageType(), nested, cancellation_token=token)
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
# allow the inner agent to process
|
||||
await runtime.process_next()
|
||||
|
||||
@ -29,11 +29,12 @@ async def test_register_receives_publish() -> None:
|
||||
await queue.put((namespace, message.content))
|
||||
|
||||
runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(Message("first message"), namespace="default")
|
||||
await runtime.publish_message(Message("second message"), namespace="default")
|
||||
await runtime.publish_message(Message("third message"), namespace="default")
|
||||
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
assert queue.qsize() == 3
|
||||
assert queue.get_nowait() == ("default", "first message")
|
||||
|
||||
@ -20,11 +20,11 @@ async def test_intervention_count_messages() -> None:
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
|
||||
response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
await run_context.stop()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
@ -41,13 +41,12 @@ async def test_intervention_drop_send() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 0
|
||||
@ -64,13 +63,12 @@ async def test_intervention_drop_response() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -87,13 +85,12 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
run_context = runtime.start()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
_response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
@ -112,13 +109,11 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
run_context = runtime.start()
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
_response = await runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
await run_context.stop()
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
@ -28,9 +28,10 @@ async def test_register_receives_publish() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent)
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(MessageType(), namespace="default")
|
||||
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
|
||||
@ -54,13 +55,15 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish messages
|
||||
for _ in range(num_initial_messages):
|
||||
await runtime.publish_message(CascadingMessageType(round=1), namespace="default")
|
||||
|
||||
# Process until idle.
|
||||
await runtime.process_until_idle()
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user