mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-02 10:50:03 +00:00
Add UserProxyAgent in AgentChat API (#4255)
* initial addition of a user proxy agent in agentchat, related to #3614 * fix typing/mypy errors * format fixes * format and pyright checks * update, add support for returning handoff message, add tests --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com> Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
parent
c9835f3b52
commit
0ff1687485
@ -4,6 +4,7 @@ from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
from ._user_proxy_agent import UserProxyAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
@ -13,4 +14,5 @@ __all__ = [
|
||||
"CodingAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
"SocietyOfMindAgent",
|
||||
"UserProxyAgent",
|
||||
]
|
||||
|
||||
@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import ChatMessage, HandoffMessage, TextMessage
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
# Define input function types more precisely
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||
|
||||
|
||||
class UserProxyAgent(BaseChatAgent):
|
||||
"""An agent that can represent a human user in a chat."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str = "a human user",
|
||||
input_func: Optional[InputFuncType] = None,
|
||||
) -> None:
|
||||
"""Initialize the UserProxyAgent."""
|
||||
super().__init__(name=name, description=description)
|
||||
self.input_func = input_func or input
|
||||
self._is_async = iscoroutinefunction(self.input_func)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""Message types this agent can produce."""
|
||||
return [TextMessage, HandoffMessage]
|
||||
|
||||
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
|
||||
"""Find the most recent HandoffMessage in the message sequence."""
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, HandoffMessage):
|
||||
return message
|
||||
return None
|
||||
|
||||
async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
"""Handle input based on function signature."""
|
||||
try:
|
||||
if self._is_async:
|
||||
# Cast to AsyncInputFunc for proper typing
|
||||
async_func = cast(AsyncInputFunc, self.input_func)
|
||||
return await async_func(prompt, cancellation_token)
|
||||
else:
|
||||
# Cast to SyncInputFunc for proper typing
|
||||
sync_func = cast(SyncInputFunc, self.input_func)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, sync_func, prompt)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_messages(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
|
||||
) -> Response:
|
||||
"""Handle incoming messages by requesting user input."""
|
||||
try:
|
||||
# Check for handoff first
|
||||
handoff = self._get_latest_handoff(messages)
|
||||
prompt = (
|
||||
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
|
||||
)
|
||||
|
||||
user_input = await self._get_input(prompt, cancellation_token)
|
||||
|
||||
# Return appropriate message type based on handoff presence
|
||||
if handoff:
|
||||
return Response(
|
||||
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
|
||||
)
|
||||
else:
|
||||
return Response(chat_message=TextMessage(content=user_input, source=self.name))
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||
"""Reset agent state."""
|
||||
pass
|
||||
103
python/packages/autogen-agentchat/tests/test_userproxy_agent.py
Normal file
103
python/packages/autogen-agentchat/tests/test_userproxy_agent.py
Normal file
@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import UserProxyAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_input() -> None:
|
||||
"""Test basic message handling with custom input"""
|
||||
|
||||
def custom_input(prompt: str) -> str:
|
||||
return "The height of the eiffel tower is 324 meters. Aloha!"
|
||||
|
||||
agent = UserProxyAgent(name="test_user", input_func=custom_input)
|
||||
messages = [TextMessage(content="What is the height of the eiffel tower?", source="assistant")]
|
||||
|
||||
response = await agent.on_messages(messages, CancellationToken())
|
||||
|
||||
assert isinstance(response, Response)
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
assert response.chat_message.content == "The height of the eiffel tower is 324 meters. Aloha!"
|
||||
assert response.chat_message.source == "test_user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_input() -> None:
|
||||
"""Test handling of async input function"""
|
||||
|
||||
async def async_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
|
||||
await asyncio.sleep(0.1)
|
||||
return "async response"
|
||||
|
||||
agent = UserProxyAgent(name="test_user", input_func=async_input)
|
||||
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||
|
||||
response = await agent.on_messages(messages, CancellationToken())
|
||||
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
assert response.chat_message.content == "async response"
|
||||
assert response.chat_message.source == "test_user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoff_handling() -> None:
|
||||
"""Test handling of handoff messages"""
|
||||
|
||||
def custom_input(prompt: str) -> str:
|
||||
return "handoff response"
|
||||
|
||||
agent = UserProxyAgent(name="test_user", input_func=custom_input)
|
||||
|
||||
messages: Sequence[ChatMessage] = [
|
||||
TextMessage(content="Initial message", source="assistant"),
|
||||
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"),
|
||||
]
|
||||
|
||||
response = await agent.on_messages(messages, CancellationToken())
|
||||
|
||||
assert isinstance(response.chat_message, HandoffMessage)
|
||||
assert response.chat_message.content == "handoff response"
|
||||
assert response.chat_message.source == "test_user"
|
||||
assert response.chat_message.target == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation() -> None:
|
||||
"""Test cancellation during message handling"""
|
||||
|
||||
async def cancellable_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
|
||||
await asyncio.sleep(0.1)
|
||||
if token and token.is_cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
return "cancellable response"
|
||||
|
||||
agent = UserProxyAgent(name="test_user", input_func=cancellable_input)
|
||||
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||
token = CancellationToken()
|
||||
|
||||
async def cancel_after_delay() -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await asyncio.gather(agent.on_messages(messages, token), cancel_after_delay())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling() -> None:
|
||||
"""Test error handling with problematic input function"""
|
||||
|
||||
def failing_input(_: str) -> str:
|
||||
raise ValueError("Input function failed")
|
||||
|
||||
agent = UserProxyAgent(name="test_user", input_func=failing_input)
|
||||
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await agent.on_messages(messages, CancellationToken())
|
||||
assert "Failed to get user input" in str(exc_info.value)
|
||||
Loading…
x
Reference in New Issue
Block a user