mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-16 04:31:17 +00:00
Remove require_response, rename broadcast to publish, remove publish responses (#25)
* rename broadcast to publish * remove require response, remove responses from publishing
This commit is contained in:
parent
b6dd861166
commit
cb55e00819
@ -19,10 +19,7 @@ class Inner(TypeRoutedAgent):
|
|||||||
super().__init__(name, router)
|
super().__init__(name, router)
|
||||||
|
|
||||||
@message_handler(MessageType)
|
@message_handler(MessageType)
|
||||||
async def on_new_message(
|
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||||
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> MessageType:
|
|
||||||
assert require_response
|
|
||||||
return MessageType(body=f"Inner: {message.body}", sender=self.name)
|
return MessageType(body=f"Inner: {message.body}", sender=self.name)
|
||||||
|
|
||||||
|
|
||||||
@ -32,11 +29,8 @@ class Outer(TypeRoutedAgent):
|
|||||||
self._inner = inner
|
self._inner = inner
|
||||||
|
|
||||||
@message_handler(MessageType)
|
@message_handler(MessageType)
|
||||||
async def on_new_message(
|
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||||
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
|
inner_response = self._send_message(message, self._inner)
|
||||||
) -> MessageType:
|
|
||||||
assert require_response
|
|
||||||
inner_response = self._send_message(message, self._inner, require_response=True)
|
|
||||||
inner_message = await inner_response
|
inner_message = await inner_response
|
||||||
assert isinstance(inner_message, MessageType)
|
assert isinstance(inner_message, MessageType)
|
||||||
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)
|
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)
|
||||||
|
@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"]
|
|||||||
ignore = ["F401", "E501"]
|
ignore = ["F401", "E501"]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
files = ["src", "examples"]
|
files = ["src", "examples", "tests"]
|
||||||
|
|
||||||
strict = true
|
strict = true
|
||||||
python_version = "3.10"
|
python_version = "3.10"
|
||||||
@ -53,7 +53,7 @@ disallow_untyped_decorators = true
|
|||||||
disallow_any_unimported = true
|
disallow_any_unimported = true
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = ["src", "examples"]
|
include = ["src", "examples", "tests"]
|
||||||
typeCheckingMode = "strict"
|
typeCheckingMode = "strict"
|
||||||
reportUnnecessaryIsInstance = false
|
reportUnnecessaryIsInstance = false
|
||||||
reportMissingTypeStubs = false
|
reportMissingTypeStubs = false
|
||||||
|
@ -16,12 +16,12 @@ ProducesT = TypeVar("ProducesT", covariant=True)
|
|||||||
def message_handler(
|
def message_handler(
|
||||||
*target_types: Type[ReceivesT],
|
*target_types: Type[ReceivesT],
|
||||||
) -> Callable[
|
) -> Callable[
|
||||||
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
||||||
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||||
]:
|
]:
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||||
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
||||||
# Convert target_types to list and stash
|
# Convert target_types to list and stash
|
||||||
func._target_types = list(target_types) # type: ignore
|
func._target_types = list(target_types) # type: ignore
|
||||||
return func
|
return func
|
||||||
@ -34,7 +34,7 @@ class TypeRoutedAgent(BaseAgent):
|
|||||||
super().__init__(name, router)
|
super().__init__(name, router)
|
||||||
|
|
||||||
# Self is already bound to the handlers
|
# Self is already bound to the handlers
|
||||||
self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||||
|
|
||||||
router.add_agent(self)
|
router.add_agent(self)
|
||||||
|
|
||||||
@ -49,17 +49,13 @@ class TypeRoutedAgent(BaseAgent):
|
|||||||
def subscriptions(self) -> Sequence[Type[Any]]:
|
def subscriptions(self) -> Sequence[Type[Any]]:
|
||||||
return list(self._handlers.keys())
|
return list(self._handlers.keys())
|
||||||
|
|
||||||
async def on_message(
|
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> Any | None:
|
|
||||||
key_type: Type[Any] = type(message) # type: ignore
|
key_type: Type[Any] = type(message) # type: ignore
|
||||||
handler = self._handlers.get(key_type) # type: ignore
|
handler = self._handlers.get(key_type) # type: ignore
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
return await handler(message, require_response, cancellation_token)
|
return await handler(message, cancellation_token)
|
||||||
else:
|
else:
|
||||||
return await self.on_unhandled_message(message, require_response, cancellation_token)
|
return await self.on_unhandled_message(message, cancellation_token)
|
||||||
|
|
||||||
async def on_unhandled_message(
|
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
|
||||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> NoReturn:
|
|
||||||
raise CantHandleException(f"Unhandled message: {message}")
|
raise CantHandleException(f"Unhandled message: {message}")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Dict, List, Sequence, Set, cast
|
from typing import Any, Awaitable, Dict, List, Set
|
||||||
|
|
||||||
from agnext.core.cancellation_token import CancellationToken
|
from agnext.core.cancellation_token import CancellationToken
|
||||||
from agnext.core.exceptions import MessageDroppedException
|
from agnext.core.exceptions import MessageDroppedException
|
||||||
@ -12,15 +12,13 @@ from ..core.agent_runtime import AgentRuntime
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class BroadcastMessageEnvelope:
|
class PublishMessageEnvelope:
|
||||||
"""A message envelope for broadcasting messages to all agents that can handle
|
"""A message envelope for publishing messages to all agents that can handle
|
||||||
the message of the type T."""
|
the message of the type T."""
|
||||||
|
|
||||||
message: Any
|
message: Any
|
||||||
future: Future[Sequence[Any] | None]
|
|
||||||
cancellation_token: CancellationToken
|
cancellation_token: CancellationToken
|
||||||
sender: Agent | None
|
sender: Agent | None
|
||||||
require_response: bool
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
@ -31,9 +29,8 @@ class SendMessageEnvelope:
|
|||||||
message: Any
|
message: Any
|
||||||
sender: Agent | None
|
sender: Agent | None
|
||||||
recipient: Agent
|
recipient: Agent
|
||||||
future: Future[Any | None]
|
future: Future[Any]
|
||||||
cancellation_token: CancellationToken
|
cancellation_token: CancellationToken
|
||||||
require_response: bool
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
@ -46,20 +43,9 @@ class ResponseMessageEnvelope:
|
|||||||
recipient: Agent | None
|
recipient: Agent | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class BroadcastResponseMessageEnvelope:
|
|
||||||
"""A message envelope for sending a response to a message."""
|
|
||||||
|
|
||||||
message: Sequence[Any]
|
|
||||||
future: Future[Sequence[Any]]
|
|
||||||
recipient: Agent | None
|
|
||||||
|
|
||||||
|
|
||||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
||||||
self._message_queue: List[
|
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||||
BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope
|
|
||||||
] = []
|
|
||||||
self._per_type_subscribers: Dict[type, List[Agent]] = {}
|
self._per_type_subscribers: Dict[type, List[Agent]] = {}
|
||||||
self._agents: Set[Agent] = set()
|
self._agents: Set[Agent] = set()
|
||||||
self._before_send = before_send
|
self._before_send = before_send
|
||||||
@ -77,7 +63,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
message: Any,
|
message: Any,
|
||||||
recipient: Agent,
|
recipient: Agent,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
sender: Agent | None = None,
|
sender: Agent | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Any | None]:
|
) -> Future[Any | None]:
|
||||||
@ -95,36 +80,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
future=future,
|
future=future,
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
require_response=require_response,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return future
|
return future
|
||||||
|
|
||||||
# send message, require_response=False -> returns after delivery, gives None
|
def publish_message(
|
||||||
# send message, require_response=True -> returns after handling, gives Response
|
|
||||||
def broadcast_message(
|
|
||||||
self,
|
self,
|
||||||
message: Any,
|
message: Any,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
sender: Agent | None = None,
|
sender: Agent | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Sequence[Any] | None]:
|
) -> Future[None]:
|
||||||
if cancellation_token is None:
|
if cancellation_token is None:
|
||||||
cancellation_token = CancellationToken()
|
cancellation_token = CancellationToken()
|
||||||
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
|
||||||
self._message_queue.append(
|
self._message_queue.append(
|
||||||
BroadcastMessageEnvelope(
|
PublishMessageEnvelope(
|
||||||
message=message,
|
message=message,
|
||||||
future=future,
|
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
require_response=require_response,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
future.set_result(None)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||||
@ -134,20 +114,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
try:
|
try:
|
||||||
response = await recipient.on_message(
|
response = await recipient.on_message(
|
||||||
message_envelope.message,
|
message_envelope.message,
|
||||||
require_response=message_envelope.require_response,
|
|
||||||
cancellation_token=message_envelope.cancellation_token,
|
cancellation_token=message_envelope.cancellation_token,
|
||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
message_envelope.future.set_exception(e)
|
message_envelope.future.set_exception(e)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not message_envelope.require_response and response is not None:
|
|
||||||
raise Exception("Recipient returned a response for a message that did not request a response")
|
|
||||||
|
|
||||||
if message_envelope.require_response and response is None:
|
|
||||||
raise Exception("Recipient did not return a response for a message that requested a response")
|
|
||||||
|
|
||||||
if message_envelope.require_response:
|
|
||||||
self._message_queue.append(
|
self._message_queue.append(
|
||||||
ResponseMessageEnvelope(
|
ResponseMessageEnvelope(
|
||||||
message=response,
|
message=response,
|
||||||
@ -156,42 +128,27 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
recipient=message_envelope.sender,
|
recipient=message_envelope.sender,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
message_envelope.future.set_result(None)
|
|
||||||
|
|
||||||
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None:
|
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||||
responses: List[Awaitable[Any]] = []
|
responses: List[Awaitable[Any]] = []
|
||||||
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
|
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
|
||||||
future = agent.on_message(
|
future = agent.on_message(
|
||||||
message_envelope.message,
|
message_envelope.message,
|
||||||
require_response=message_envelope.require_response,
|
|
||||||
cancellation_token=message_envelope.cancellation_token,
|
cancellation_token=message_envelope.cancellation_token,
|
||||||
)
|
)
|
||||||
responses.append(future)
|
responses.append(future)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
all_responses = await asyncio.gather(*responses)
|
_all_responses = await asyncio.gather(*responses)
|
||||||
except BaseException as e:
|
except BaseException:
|
||||||
message_envelope.future.set_exception(e)
|
# TODO log error
|
||||||
return
|
return
|
||||||
|
|
||||||
if message_envelope.require_response:
|
# TODO if responses are given for a publish
|
||||||
self._message_queue.append(
|
|
||||||
BroadcastResponseMessageEnvelope(
|
|
||||||
message=all_responses,
|
|
||||||
future=cast(Future[Sequence[Any]], message_envelope.future),
|
|
||||||
recipient=message_envelope.sender,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_envelope.future.set_result(None)
|
|
||||||
|
|
||||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||||
message_envelope.future.set_result(message_envelope.message)
|
message_envelope.future.set_result(message_envelope.message)
|
||||||
|
|
||||||
async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None:
|
|
||||||
message_envelope.future.set_result(message_envelope.message)
|
|
||||||
|
|
||||||
async def process_next(self) -> None:
|
async def process_next(self) -> None:
|
||||||
if len(self._message_queue) == 0:
|
if len(self._message_queue) == 0:
|
||||||
# Yield control to the event loop to allow other tasks to run
|
# Yield control to the event loop to allow other tasks to run
|
||||||
@ -211,20 +168,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
message_envelope.message = temp_message
|
message_envelope.message = temp_message
|
||||||
|
|
||||||
asyncio.create_task(self._process_send(message_envelope))
|
asyncio.create_task(self._process_send(message_envelope))
|
||||||
case BroadcastMessageEnvelope(
|
case PublishMessageEnvelope(
|
||||||
message=message,
|
message=message,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
future=future,
|
|
||||||
):
|
):
|
||||||
if self._before_send is not None:
|
if self._before_send is not None:
|
||||||
temp_message = await self._before_send.on_broadcast(message, sender=sender)
|
temp_message = await self._before_send.on_publish(message, sender=sender)
|
||||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||||
future.set_exception(MessageDroppedException())
|
# TODO log message dropped
|
||||||
return
|
return
|
||||||
|
|
||||||
message_envelope.message = temp_message
|
message_envelope.message = temp_message
|
||||||
|
|
||||||
asyncio.create_task(self._process_broadcast(message_envelope))
|
asyncio.create_task(self._process_publish(message_envelope))
|
||||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||||
if self._before_send is not None:
|
if self._before_send is not None:
|
||||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||||
@ -236,16 +192,5 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
|
|
||||||
asyncio.create_task(self._process_response(message_envelope))
|
asyncio.create_task(self._process_response(message_envelope))
|
||||||
|
|
||||||
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
|
|
||||||
if self._before_send is not None:
|
|
||||||
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
|
|
||||||
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
|
|
||||||
future.set_exception(MessageDroppedException())
|
|
||||||
return
|
|
||||||
|
|
||||||
message_envelope.message = list(temp_message_list) # type: ignore
|
|
||||||
|
|
||||||
asyncio.create_task(self._process_broadcast_response(message_envelope))
|
|
||||||
|
|
||||||
# Yield control to the message loop to allow other tasks to run
|
# Yield control to the message loop to allow other tasks to run
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
@ -26,7 +26,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
|||||||
# TODO: use require_response
|
# TODO: use require_response
|
||||||
@message_handler(TextMessage)
|
@message_handler(TextMessage)
|
||||||
async def on_chat_message_with_cancellation(
|
async def on_chat_message_with_cancellation(
|
||||||
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
|
self, message: TextMessage, cancellation_token: CancellationToken
|
||||||
) -> None:
|
) -> None:
|
||||||
print("---------------")
|
print("---------------")
|
||||||
print(f"{self.name} received message from {message.source}: {message.content}")
|
print(f"{self.name} received message from {message.source}: {message.content}")
|
||||||
@ -41,22 +41,13 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
|||||||
)
|
)
|
||||||
self._current_session_window_length += 1
|
self._current_session_window_length += 1
|
||||||
|
|
||||||
if require_response:
|
|
||||||
# TODO ?
|
|
||||||
...
|
|
||||||
|
|
||||||
@message_handler(Reset)
|
@message_handler(Reset)
|
||||||
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||||
# Reset the current session window.
|
# Reset the current session window.
|
||||||
self._current_session_window_length = 0
|
self._current_session_window_length = 0
|
||||||
|
|
||||||
@message_handler(RespondNow)
|
@message_handler(RespondNow)
|
||||||
async def on_respond_now(
|
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> TextMessage | None:
|
|
||||||
if not require_response:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Create a run and wait until it finishes.
|
# Create a run and wait until it finishes.
|
||||||
run = await self._client.beta.threads.runs.create_and_poll(
|
run = await self._client.beta.threads.runs.create_and_poll(
|
||||||
thread_id=self._thread_id,
|
thread_id=self._thread_id,
|
||||||
|
@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
|
|||||||
# TODO: use require_response
|
# TODO: use require_response
|
||||||
@message_handler(RespondNow)
|
@message_handler(RespondNow)
|
||||||
async def on_chat_message_with_cancellation(
|
async def on_chat_message_with_cancellation(
|
||||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
self, message: RespondNow, cancellation_token: CancellationToken
|
||||||
) -> TextMessage:
|
) -> TextMessage:
|
||||||
# Generate a random response.
|
# Generate a random response.
|
||||||
response_body = random.choice(
|
response_body = random.choice(
|
||||||
|
@ -36,9 +36,7 @@ class GroupChat(BaseChatAgent):
|
|||||||
agent_sublists = [agent.subscriptions for agent in self._agents]
|
agent_sublists = [agent.subscriptions for agent in self._agents]
|
||||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||||
|
|
||||||
async def on_message(
|
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> Any | None:
|
|
||||||
if isinstance(message, Reset):
|
if isinstance(message, Reset):
|
||||||
# Reset the history.
|
# Reset the history.
|
||||||
self._history = []
|
self._history = []
|
||||||
@ -48,10 +46,8 @@ class GroupChat(BaseChatAgent):
|
|||||||
# TODO reset...
|
# TODO reset...
|
||||||
return self._output.get_output()
|
return self._output.get_output()
|
||||||
|
|
||||||
# TODO: should we do nothing here?
|
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
|
||||||
# Perhaps it should be saved into the message history?
|
# Should this class disallow it?
|
||||||
if not require_response:
|
|
||||||
return None
|
|
||||||
|
|
||||||
self._history.append(message)
|
self._history.append(message)
|
||||||
round = 0
|
round = 0
|
||||||
@ -67,14 +63,13 @@ class GroupChat(BaseChatAgent):
|
|||||||
_ = await self._send_message(
|
_ = await self._send_message(
|
||||||
self._history[-1],
|
self._history[-1],
|
||||||
agent,
|
agent,
|
||||||
require_response=False,
|
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
|
# TODO handle if response is not None
|
||||||
|
|
||||||
response = await self._send_message(
|
response = await self._send_message(
|
||||||
RespondNow(),
|
RespondNow(),
|
||||||
speaker,
|
speaker,
|
||||||
require_response=True,
|
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,4 +83,5 @@ class GroupChat(BaseChatAgent):
|
|||||||
|
|
||||||
output = self._output.get_output()
|
output = self._output.get_output()
|
||||||
self._output.reset()
|
self._output.reset()
|
||||||
|
self._history.clear()
|
||||||
return output
|
return output
|
||||||
|
@ -34,7 +34,6 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
|
|||||||
async def on_chat_message(
|
async def on_chat_message(
|
||||||
self,
|
self,
|
||||||
message: ChatMessage,
|
message: ChatMessage,
|
||||||
require_response: bool,
|
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> ChatMessage | None:
|
) -> ChatMessage | None:
|
||||||
# A task is received.
|
# A task is received.
|
||||||
|
@ -11,6 +11,4 @@ class Agent(Protocol):
|
|||||||
@property
|
@property
|
||||||
def subscriptions(self) -> Sequence[type]: ...
|
def subscriptions(self) -> Sequence[type]: ...
|
||||||
|
|
||||||
async def on_message(
|
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> Any | None: ...
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from typing import Any, Protocol, Sequence
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from agnext.core.agent import Agent
|
from agnext.core.agent import Agent
|
||||||
from agnext.core.cancellation_token import CancellationToken
|
from agnext.core.cancellation_token import CancellationToken
|
||||||
@ -16,17 +16,15 @@ class AgentRuntime(Protocol):
|
|||||||
message: Any,
|
message: Any,
|
||||||
recipient: Agent,
|
recipient: Agent,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
sender: Agent | None = None,
|
sender: Agent | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Any | None]: ...
|
) -> Future[Any]: ...
|
||||||
|
|
||||||
# Returns the response of all handling agents
|
# No responses from publishing
|
||||||
def broadcast_message(
|
def publish_message(
|
||||||
self,
|
self,
|
||||||
message: Any,
|
message: Any,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
sender: Agent | None = None,
|
sender: Agent | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Sequence[Any] | None]: ...
|
) -> Future[None]: ...
|
||||||
|
@ -29,9 +29,7 @@ class BaseAgent(ABC, Agent):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def on_message(
|
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
|
||||||
) -> Any | None: ...
|
|
||||||
|
|
||||||
# Returns the response of the message
|
# Returns the response of the message
|
||||||
def _send_message(
|
def _send_message(
|
||||||
@ -39,9 +37,8 @@ class BaseAgent(ABC, Agent):
|
|||||||
message: Any,
|
message: Any,
|
||||||
recipient: Agent,
|
recipient: Agent,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Any | None]:
|
) -> Future[Any]:
|
||||||
if cancellation_token is None:
|
if cancellation_token is None:
|
||||||
cancellation_token = CancellationToken()
|
cancellation_token = CancellationToken()
|
||||||
|
|
||||||
@ -49,23 +46,18 @@ class BaseAgent(ABC, Agent):
|
|||||||
message,
|
message,
|
||||||
sender=self,
|
sender=self,
|
||||||
recipient=recipient,
|
recipient=recipient,
|
||||||
require_response=require_response,
|
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
cancellation_token.link_future(future)
|
cancellation_token.link_future(future)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
# Returns the response of all handling agents
|
def _publish_message(
|
||||||
def _broadcast_message(
|
|
||||||
self,
|
self,
|
||||||
message: Any,
|
message: Any,
|
||||||
*,
|
*,
|
||||||
require_response: bool = True,
|
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Future[Sequence[Any] | None]:
|
) -> Future[None]:
|
||||||
if cancellation_token is None:
|
if cancellation_token is None:
|
||||||
cancellation_token = CancellationToken()
|
cancellation_token = CancellationToken()
|
||||||
future = self._router.broadcast_message(
|
future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
|
||||||
message, sender=self, require_response=require_response, cancellation_token=cancellation_token
|
|
||||||
)
|
|
||||||
return future
|
return future
|
||||||
|
@ -12,9 +12,9 @@ InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
|
|||||||
|
|
||||||
class InterventionHandler(Protocol):
|
class InterventionHandler(Protocol):
|
||||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
|
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
|
||||||
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
|
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
|
||||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
|
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
|
||||||
async def on_broadcast_response(
|
async def on_publish_response(
|
||||||
self, message: Sequence[Any], *, recipient: Agent | None
|
self, message: Sequence[Any], *, recipient: Agent | None
|
||||||
) -> Sequence[Any] | type[DropMessage]: ...
|
) -> Sequence[Any] | type[DropMessage]: ...
|
||||||
|
|
||||||
@ -23,13 +23,13 @@ class DefaultInterventionHandler(InterventionHandler):
|
|||||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
|
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
|
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
|
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def on_broadcast_response(
|
async def on_publish_response(
|
||||||
self, message: Sequence[Any], *, recipient: Agent | None
|
self, message: Sequence[Any], *, recipient: Agent | None
|
||||||
) -> Sequence[Any] | type[DropMessage]:
|
) -> Sequence[Any] | type[DropMessage]:
|
||||||
return message
|
return message
|
||||||
|
@ -23,7 +23,7 @@ class LongRunningAgent(TypeRoutedAgent):
|
|||||||
self.cancelled = False
|
self.cancelled = False
|
||||||
|
|
||||||
@message_handler(MessageType)
|
@message_handler(MessageType)
|
||||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||||
self.called = True
|
self.called = True
|
||||||
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
||||||
cancellation_token.link_future(sleep)
|
cancellation_token.link_future(sleep)
|
||||||
@ -42,10 +42,9 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
|||||||
self._nested_agent = nested_agent
|
self._nested_agent = nested_agent
|
||||||
|
|
||||||
@message_handler(MessageType)
|
@message_handler(MessageType)
|
||||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||||
assert require_response == True
|
|
||||||
self.called = True
|
self.called = True
|
||||||
response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token)
|
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||||
try:
|
try:
|
||||||
val = await response
|
val = await response
|
||||||
assert isinstance(val, MessageType)
|
assert isinstance(val, MessageType)
|
||||||
|
@ -20,7 +20,7 @@ class LoopbackAgent(TypeRoutedAgent):
|
|||||||
|
|
||||||
|
|
||||||
@message_handler(MessageType)
|
@message_handler(MessageType)
|
||||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||||
self.num_calls += 1
|
self.num_calls += 1
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class LoopbackAgent(TypeRoutedAgent):
|
|||||||
async def test_intervention_count_messages() -> None:
|
async def test_intervention_count_messages() -> None:
|
||||||
|
|
||||||
class DebugInterventionHandler(DefaultInterventionHandler):
|
class DebugInterventionHandler(DefaultInterventionHandler):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.num_messages = 0
|
self.num_messages = 0
|
||||||
|
|
||||||
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType:
|
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user