Simplify handler decorator (#50)

* Simplify handler decorator

* add more tests

* mypy

* formatting

* fix 3.10 and improve type handling of decorator

* test fix

* format
This commit is contained in:
Jack Gerrits 2024-06-05 08:51:49 -04:00 committed by GitHub
parent ad513d5017
commit 8cb530f65e
10 changed files with 191 additions and 29 deletions

View File

@ -16,7 +16,7 @@ class Inner(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
return MessageType(body=f"Inner: {message.body}", sender=self.name)
@ -26,7 +26,7 @@ class Outer(TypeRoutedAgent):
super().__init__(name, router)
self._inner = inner
@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
inner_response = self._send_message(message, self._inner)
inner_message = await inner_response

View File

@ -38,17 +38,17 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
self._chat_messages: List[Message] = []
self._function_executor = function_executor
@message_handler(TextMessage)
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Add a user message.
self._chat_messages.append(message)
@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the chat messages.
self._chat_messages = []
@message_handler(RespondNow)
@message_handler()
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
@ -101,7 +101,7 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
# Return the response.
return final_response
@message_handler(FunctionCallMessage)
@message_handler()
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
) -> FunctionExecutionResultMessage:

View File

@ -24,7 +24,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
self._assistant_id = assistant_id
self._thread_id = thread_id
@message_handler(TextMessage)
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
@ -34,7 +34,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
metadata={"sender": message.source},
)
@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Get all messages in this thread.
all_msgs: List[str] = []
@ -52,7 +52,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
assert status.deleted is True
@message_handler(RespondNow)
@message_handler()
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Handle response format.
if message.response_format == ResponseFormat.json_object:

View File

@ -35,16 +35,16 @@ class GroupChat(BaseChatAgent, TypeRoutedAgent):
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
self._history.clear()
@message_handler(RespondNow)
@message_handler()
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
return self._output.get_output()
@message_handler(TextMessage)
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> Any:
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Should this class disallow it?

View File

@ -34,7 +34,7 @@ class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
def children(self) -> Sequence[str]:
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
@message_handler(TextMessage)
@message_handler()
async def on_text_message(
self,
message: TextMessage,

View File

@ -0,0 +1,18 @@
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from ...core import AgentRuntime
from ..agents.base import BaseChatAgent
class TwoAgentChat(GroupChat):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agent1: BaseChatAgent,
agent2: BaseChatAgent,
num_rounds: int,
output: GroupChatOutput,
) -> None:
super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output)

View File

@ -1,28 +1,132 @@
from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar
import logging
from functools import wraps
from types import NoneType, UnionType
from typing import (
Any,
Callable,
Coroutine,
Dict,
Literal,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
runtime_checkable,
)
from agnext.core import AgentRuntime, BaseAgent, CancellationToken
from agnext.core.exceptions import CantHandleException
ReceivesT = TypeVar("ReceivesT")
logger = logging.getLogger("agnext")
ReceivesT = TypeVar("ReceivesT", contravariant=True)
ProducesT = TypeVar("ProducesT", covariant=True)
# TODO: Generic typevar bound binding U to agent type
# Can't do because python doesnt support it
def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType
def is_optional(t: object) -> bool:
origin = get_origin(t)
return origin is Optional
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
class AnyType:
pass
def get_types(t: object) -> Sequence[Type[Any]] | None:
if is_union(t):
return get_args(t)
elif is_optional(t):
return tuple(list(get_args(t)) + [NoneType])
elif t is Any:
return (AnyType,)
elif isinstance(t, type):
return (t,)
elif isinstance(t, NoneType):
return (NoneType,)
else:
return None
@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]):
target_types: Sequence[type]
produces_types: Sequence[type]
is_message_handler: Literal[True]
async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocl for the outer function to check checked arg names
def message_handler(
*target_types: Type[ReceivesT],
strict: bool = True,
) -> Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]:
def decorator(
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")
print(type_hints)
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
# Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore
return func
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
if strict:
if type(message) not in target_types:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, cancellation_token)
if strict:
if return_value is not AnyType and type(return_value) not in return_types:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
elif return_value is not AnyType:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
return wrapper_handler
return decorator
@ -35,9 +139,10 @@ class TypeRoutedAgent(BaseAgent):
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
if hasattr(handler, "_target_types"):
for target_type in handler._target_types:
self._handlers[target_type] = handler
if hasattr(handler, "is_message_handler"):
message_handler = cast(MessageHandler[Any, Any], handler)
for target_type in message_handler.target_types:
self._handlers[target_type] = message_handler
super().__init__(name, router)

View File

@ -22,7 +22,7 @@ class LongRunningAgent(TypeRoutedAgent):
self.called = False
self.cancelled = False
@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
@ -41,7 +41,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
self.cancelled = False
self._nested_agent = nested_agent
@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)

View File

@ -19,7 +19,7 @@ class LoopbackAgent(TypeRoutedAgent):
self.num_calls = 0
@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1
return message

39
tests/test_types.py Normal file
View File

@ -0,0 +1,39 @@
from types import NoneType
from typing import Any, Optional, Union
from agnext.components.type_routed_agent import AnyType, get_types, message_handler
from agnext.core import CancellationToken
def test_get_types() -> None:
assert get_types(Union[int, str]) == (int, str)
assert get_types(int | str) == (int, str)
assert get_types(int) == (int,)
assert get_types(str) == (str,)
assert get_types("test") is None
assert get_types(Optional[int]) == (int, NoneType)
assert get_types(NoneType) == (NoneType,)
assert get_types(None) == (NoneType,)
def test_handler() -> None:
class HandlerClass:
@message_handler()
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
return None
@message_handler()
async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None:
return None
assert HandlerClass.handler.target_types == [int]
assert HandlerClass.handler.produces_types == [AnyType]
assert HandlerClass.handler2.target_types == [str, bool]
assert HandlerClass.handler2.produces_types == [NoneType]
class HandlerClass:
@message_handler()
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
return None