Remove require_response, rename broadcast to publish, remove publish responses (#25)

* rename broadcast to publish

* remove require response, remove responses from publishing
This commit is contained in:
Jack Gerrits 2024-05-26 08:45:02 -04:00 committed by GitHub
parent b6dd861166
commit cb55e00819
14 changed files with 69 additions and 161 deletions

View File

@ -19,10 +19,7 @@ class Inner(TypeRoutedAgent):
super().__init__(name, router) super().__init__(name, router)
@message_handler(MessageType) @message_handler(MessageType)
async def on_new_message( async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
) -> MessageType:
assert require_response
return MessageType(body=f"Inner: {message.body}", sender=self.name) return MessageType(body=f"Inner: {message.body}", sender=self.name)
@ -32,11 +29,8 @@ class Outer(TypeRoutedAgent):
self._inner = inner self._inner = inner
@message_handler(MessageType) @message_handler(MessageType)
async def on_new_message( async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken inner_response = self._send_message(message, self._inner)
) -> MessageType:
assert require_response
inner_response = self._send_message(message, self._inner, require_response=True)
inner_message = await inner_response inner_message = await inner_response
assert isinstance(inner_message, MessageType) assert isinstance(inner_message, MessageType)
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name) return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)

View File

@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"]
ignore = ["F401", "E501"] ignore = ["F401", "E501"]
[tool.mypy] [tool.mypy]
files = ["src", "examples"] files = ["src", "examples", "tests"]
strict = true strict = true
python_version = "3.10" python_version = "3.10"
@ -53,7 +53,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true disallow_any_unimported = true
[tool.pyright] [tool.pyright]
include = ["src", "examples"] include = ["src", "examples", "tests"]
typeCheckingMode = "strict" typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false reportMissingTypeStubs = false

View File

@ -16,12 +16,12 @@ ProducesT = TypeVar("ProducesT", covariant=True)
def message_handler( def message_handler(
*target_types: Type[ReceivesT], *target_types: Type[ReceivesT],
) -> Callable[ ) -> Callable[
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
]: ]:
def decorator( def decorator(
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: ) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
# Convert target_types to list and stash # Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore func._target_types = list(target_types) # type: ignore
return func return func
@ -34,7 +34,7 @@ class TypeRoutedAgent(BaseAgent):
super().__init__(name, router) super().__init__(name, router)
# Self is already bound to the handlers # Self is already bound to the handlers
self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
router.add_agent(self) router.add_agent(self)
@ -49,17 +49,13 @@ class TypeRoutedAgent(BaseAgent):
def subscriptions(self) -> Sequence[Type[Any]]: def subscriptions(self) -> Sequence[Type[Any]]:
return list(self._handlers.keys()) return list(self._handlers.keys())
async def on_message( async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
key_type: Type[Any] = type(message) # type: ignore key_type: Type[Any] = type(message) # type: ignore
handler = self._handlers.get(key_type) # type: ignore handler = self._handlers.get(key_type) # type: ignore
if handler is not None: if handler is not None:
return await handler(message, require_response, cancellation_token) return await handler(message, cancellation_token)
else: else:
return await self.on_unhandled_message(message, require_response, cancellation_token) return await self.on_unhandled_message(message, cancellation_token)
async def on_unhandled_message( async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> NoReturn:
raise CantHandleException(f"Unhandled message: {message}") raise CantHandleException(f"Unhandled message: {message}")

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
from asyncio import Future from asyncio import Future
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Sequence, Set, cast from typing import Any, Awaitable, Dict, List, Set
from agnext.core.cancellation_token import CancellationToken from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException from agnext.core.exceptions import MessageDroppedException
@ -12,15 +12,13 @@ from ..core.agent_runtime import AgentRuntime
@dataclass(kw_only=True) @dataclass(kw_only=True)
class BroadcastMessageEnvelope: class PublishMessageEnvelope:
"""A message envelope for broadcasting messages to all agents that can handle """A message envelope for publishing messages to all agents that can handle
the message of the type T.""" the message of the type T."""
message: Any message: Any
future: Future[Sequence[Any] | None]
cancellation_token: CancellationToken cancellation_token: CancellationToken
sender: Agent | None sender: Agent | None
require_response: bool
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -31,9 +29,8 @@ class SendMessageEnvelope:
message: Any message: Any
sender: Agent | None sender: Agent | None
recipient: Agent recipient: Agent
future: Future[Any | None] future: Future[Any]
cancellation_token: CancellationToken cancellation_token: CancellationToken
require_response: bool
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -46,20 +43,9 @@ class ResponseMessageEnvelope:
recipient: Agent | None recipient: Agent | None
@dataclass(kw_only=True)
class BroadcastResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Sequence[Any]
future: Future[Sequence[Any]]
recipient: Agent | None
class SingleThreadedAgentRuntime(AgentRuntime): class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None: def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
self._message_queue: List[ self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope
] = []
self._per_type_subscribers: Dict[type, List[Agent]] = {} self._per_type_subscribers: Dict[type, List[Agent]] = {}
self._agents: Set[Agent] = set() self._agents: Set[Agent] = set()
self._before_send = before_send self._before_send = before_send
@ -77,7 +63,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message: Any, message: Any,
recipient: Agent, recipient: Agent,
*, *,
require_response: bool = True,
sender: Agent | None = None, sender: Agent | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]: ) -> Future[Any | None]:
@ -95,36 +80,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future=future, future=future,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
sender=sender, sender=sender,
require_response=require_response,
) )
) )
return future return future
# send message, require_response=False -> returns after delivery, gives None def publish_message(
# send message, require_response=True -> returns after handling, gives Response
def broadcast_message(
self, self,
message: Any, message: Any,
*, *,
require_response: bool = True,
sender: Agent | None = None, sender: Agent | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]: ) -> Future[None]:
if cancellation_token is None: if cancellation_token is None:
cancellation_token = CancellationToken() cancellation_token = CancellationToken()
future = asyncio.get_event_loop().create_future()
self._message_queue.append( self._message_queue.append(
BroadcastMessageEnvelope( PublishMessageEnvelope(
message=message, message=message,
future=future,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
sender=sender, sender=sender,
require_response=require_response,
) )
) )
future = asyncio.get_event_loop().create_future()
future.set_result(None)
return future return future
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
@ -134,20 +114,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
try: try:
response = await recipient.on_message( response = await recipient.on_message(
message_envelope.message, message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token, cancellation_token=message_envelope.cancellation_token,
) )
except BaseException as e: except BaseException as e:
message_envelope.future.set_exception(e) message_envelope.future.set_exception(e)
return return
if not message_envelope.require_response and response is not None:
raise Exception("Recipient returned a response for a message that did not request a response")
if message_envelope.require_response and response is None:
raise Exception("Recipient did not return a response for a message that requested a response")
if message_envelope.require_response:
self._message_queue.append( self._message_queue.append(
ResponseMessageEnvelope( ResponseMessageEnvelope(
message=response, message=response,
@ -156,42 +128,27 @@ class SingleThreadedAgentRuntime(AgentRuntime):
recipient=message_envelope.sender, recipient=message_envelope.sender,
) )
) )
else:
message_envelope.future.set_result(None)
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None: async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
responses: List[Awaitable[Any]] = [] responses: List[Awaitable[Any]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
future = agent.on_message( future = agent.on_message(
message_envelope.message, message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token, cancellation_token=message_envelope.cancellation_token,
) )
responses.append(future) responses.append(future)
try: try:
all_responses = await asyncio.gather(*responses) _all_responses = await asyncio.gather(*responses)
except BaseException as e: except BaseException:
message_envelope.future.set_exception(e) # TODO log error
return return
if message_envelope.require_response: # TODO if responses are given for a publish
self._message_queue.append(
BroadcastResponseMessageEnvelope(
message=all_responses,
future=cast(Future[Sequence[Any]], message_envelope.future),
recipient=message_envelope.sender,
)
)
else:
message_envelope.future.set_result(None)
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message) message_envelope.future.set_result(message_envelope.message)
async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None: async def process_next(self) -> None:
if len(self._message_queue) == 0: if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run # Yield control to the event loop to allow other tasks to run
@ -211,20 +168,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message message_envelope.message = temp_message
asyncio.create_task(self._process_send(message_envelope)) asyncio.create_task(self._process_send(message_envelope))
case BroadcastMessageEnvelope( case PublishMessageEnvelope(
message=message, message=message,
sender=sender, sender=sender,
future=future,
): ):
if self._before_send is not None: if self._before_send is not None:
temp_message = await self._before_send.on_broadcast(message, sender=sender) temp_message = await self._before_send.on_publish(message, sender=sender)
if temp_message is DropMessage or isinstance(temp_message, DropMessage): if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException()) # TODO log message dropped
return return
message_envelope.message = temp_message message_envelope.message = temp_message
asyncio.create_task(self._process_broadcast(message_envelope)) asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None: if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient) temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
@ -236,16 +192,5 @@ class SingleThreadedAgentRuntime(AgentRuntime):
asyncio.create_task(self._process_response(message_envelope)) asyncio.create_task(self._process_response(message_envelope))
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
if self._before_send is not None:
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = list(temp_message_list) # type: ignore
asyncio.create_task(self._process_broadcast_response(message_envelope))
# Yield control to the message loop to allow other tasks to run # Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -26,7 +26,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response # TODO: use require_response
@message_handler(TextMessage) @message_handler(TextMessage)
async def on_chat_message_with_cancellation( async def on_chat_message_with_cancellation(
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken self, message: TextMessage, cancellation_token: CancellationToken
) -> None: ) -> None:
print("---------------") print("---------------")
print(f"{self.name} received message from {message.source}: {message.content}") print(f"{self.name} received message from {message.source}: {message.content}")
@ -41,22 +41,13 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
) )
self._current_session_window_length += 1 self._current_session_window_length += 1
if require_response:
# TODO ?
...
@message_handler(Reset) @message_handler(Reset)
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None: async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window. # Reset the current session window.
self._current_session_window_length = 0 self._current_session_window_length = 0
@message_handler(RespondNow) @message_handler(RespondNow)
async def on_respond_now( async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
) -> TextMessage | None:
if not require_response:
return None
# Create a run and wait until it finishes. # Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll( run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id, thread_id=self._thread_id,

View File

@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response # TODO: use require_response
@message_handler(RespondNow) @message_handler(RespondNow)
async def on_chat_message_with_cancellation( async def on_chat_message_with_cancellation(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage: ) -> TextMessage:
# Generate a random response. # Generate a random response.
response_body = random.choice( response_body = random.choice(

View File

@ -36,9 +36,7 @@ class GroupChat(BaseChatAgent):
agent_sublists = [agent.subscriptions for agent in self._agents] agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
async def on_message( async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
if isinstance(message, Reset): if isinstance(message, Reset):
# Reset the history. # Reset the history.
self._history = [] self._history = []
@ -48,10 +46,8 @@ class GroupChat(BaseChatAgent):
# TODO reset... # TODO reset...
return self._output.get_output() return self._output.get_output()
# TODO: should we do nothing here? # TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Perhaps it should be saved into the message history? # Should this class disallow it?
if not require_response:
return None
self._history.append(message) self._history.append(message)
round = 0 round = 0
@ -67,14 +63,13 @@ class GroupChat(BaseChatAgent):
_ = await self._send_message( _ = await self._send_message(
self._history[-1], self._history[-1],
agent, agent,
require_response=False,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
) )
# TODO handle if response is not None
response = await self._send_message( response = await self._send_message(
RespondNow(), RespondNow(),
speaker, speaker,
require_response=True,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
) )
@ -88,4 +83,5 @@ class GroupChat(BaseChatAgent):
output = self._output.get_output() output = self._output.get_output()
self._output.reset() self._output.reset()
self._history.clear()
return output return output

View File

@ -34,7 +34,6 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
async def on_chat_message( async def on_chat_message(
self, self,
message: ChatMessage, message: ChatMessage,
require_response: bool,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
) -> ChatMessage | None: ) -> ChatMessage | None:
# A task is received. # A task is received.

View File

@ -11,6 +11,4 @@ class Agent(Protocol):
@property @property
def subscriptions(self) -> Sequence[type]: ... def subscriptions(self) -> Sequence[type]: ...
async def on_message( async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None: ...

View File

@ -1,5 +1,5 @@
from asyncio import Future from asyncio import Future
from typing import Any, Protocol, Sequence from typing import Any, Protocol
from agnext.core.agent import Agent from agnext.core.agent import Agent
from agnext.core.cancellation_token import CancellationToken from agnext.core.cancellation_token import CancellationToken
@ -16,17 +16,15 @@ class AgentRuntime(Protocol):
message: Any, message: Any,
recipient: Agent, recipient: Agent,
*, *,
require_response: bool = True,
sender: Agent | None = None, sender: Agent | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]: ... ) -> Future[Any]: ...
# Returns the response of all handling agents # No responses from publishing
def broadcast_message( def publish_message(
self, self,
message: Any, message: Any,
*, *,
require_response: bool = True,
sender: Agent | None = None, sender: Agent | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]: ... ) -> Future[None]: ...

View File

@ -29,9 +29,7 @@ class BaseAgent(ABC, Agent):
return [] return []
@abstractmethod @abstractmethod
async def on_message( async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None: ...
# Returns the response of the message # Returns the response of the message
def _send_message( def _send_message(
@ -39,9 +37,8 @@ class BaseAgent(ABC, Agent):
message: Any, message: Any,
recipient: Agent, recipient: Agent,
*, *,
require_response: bool = True,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]: ) -> Future[Any]:
if cancellation_token is None: if cancellation_token is None:
cancellation_token = CancellationToken() cancellation_token = CancellationToken()
@ -49,23 +46,18 @@ class BaseAgent(ABC, Agent):
message, message,
sender=self, sender=self,
recipient=recipient, recipient=recipient,
require_response=require_response,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
) )
cancellation_token.link_future(future) cancellation_token.link_future(future)
return future return future
# Returns the response of all handling agents def _publish_message(
def _broadcast_message(
self, self,
message: Any, message: Any,
*, *,
require_response: bool = True,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]: ) -> Future[None]:
if cancellation_token is None: if cancellation_token is None:
cancellation_token = CancellationToken() cancellation_token = CancellationToken()
future = self._router.broadcast_message( future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
message, sender=self, require_response=require_response, cancellation_token=cancellation_token
)
return future return future

View File

@ -12,9 +12,9 @@ InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
class InterventionHandler(Protocol): class InterventionHandler(Protocol):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ... async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ... async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
async def on_broadcast_response( async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]: ... ) -> Sequence[Any] | type[DropMessage]: ...
@ -23,13 +23,13 @@ class DefaultInterventionHandler(InterventionHandler):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
return message return message
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
return message return message
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
return message return message
async def on_broadcast_response( async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]: ) -> Sequence[Any] | type[DropMessage]:
return message return message

View File

@ -23,7 +23,7 @@ class LongRunningAgent(TypeRoutedAgent):
self.cancelled = False self.cancelled = False
@message_handler(MessageType) @message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100)) sleep = asyncio.ensure_future(asyncio.sleep(100))
cancellation_token.link_future(sleep) cancellation_token.link_future(sleep)
@ -42,10 +42,9 @@ class NestingLongRunningAgent(TypeRoutedAgent):
self._nested_agent = nested_agent self._nested_agent = nested_agent
@message_handler(MessageType) @message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
assert require_response == True
self.called = True self.called = True
response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token) response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
try: try:
val = await response val = await response
assert isinstance(val, MessageType) assert isinstance(val, MessageType)

View File

@ -20,7 +20,7 @@ class LoopbackAgent(TypeRoutedAgent):
@message_handler(MessageType) @message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1 self.num_calls += 1
return message return message
@ -28,7 +28,7 @@ class LoopbackAgent(TypeRoutedAgent):
async def test_intervention_count_messages() -> None: async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler): class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self): def __init__(self) -> None:
self.num_messages = 0 self.num_messages = 0
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: