mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-03 03:10:04 +00:00
Add UserProxyAgent in AgentChat API (#4255)
* initial addition of a user proxy agent in agentchat, related to #3614 * fix typing/mypy errors * format fixes * format and pyright checks * update, add support for returning handoff message, add tests --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com> Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
parent
c9835f3b52
commit
0ff1687485
@ -4,6 +4,7 @@ from ._code_executor_agent import CodeExecutorAgent
|
|||||||
from ._coding_assistant_agent import CodingAssistantAgent
|
from ._coding_assistant_agent import CodingAssistantAgent
|
||||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||||
|
from ._user_proxy_agent import UserProxyAgent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseChatAgent",
|
"BaseChatAgent",
|
||||||
@ -13,4 +14,5 @@ __all__ = [
|
|||||||
"CodingAssistantAgent",
|
"CodingAssistantAgent",
|
||||||
"ToolUseAssistantAgent",
|
"ToolUseAssistantAgent",
|
||||||
"SocietyOfMindAgent",
|
"SocietyOfMindAgent",
|
||||||
|
"UserProxyAgent",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -0,0 +1,89 @@
|
|||||||
|
import asyncio
|
||||||
|
from inspect import iscoroutinefunction
|
||||||
|
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast
|
||||||
|
|
||||||
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
|
from ..base import Response
|
||||||
|
from ..messages import ChatMessage, HandoffMessage, TextMessage
|
||||||
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
|
# Define input function types more precisely
|
||||||
|
SyncInputFunc = Callable[[str], str]
|
||||||
|
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||||
|
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||||
|
|
||||||
|
|
||||||
|
class UserProxyAgent(BaseChatAgent):
|
||||||
|
"""An agent that can represent a human user in a chat."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str = "a human user",
|
||||||
|
input_func: Optional[InputFuncType] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the UserProxyAgent."""
|
||||||
|
super().__init__(name=name, description=description)
|
||||||
|
self.input_func = input_func or input
|
||||||
|
self._is_async = iscoroutinefunction(self.input_func)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
|
"""Message types this agent can produce."""
|
||||||
|
return [TextMessage, HandoffMessage]
|
||||||
|
|
||||||
|
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
|
||||||
|
"""Find the most recent HandoffMessage in the message sequence."""
|
||||||
|
for message in reversed(messages):
|
||||||
|
if isinstance(message, HandoffMessage):
|
||||||
|
return message
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||||
|
"""Handle input based on function signature."""
|
||||||
|
try:
|
||||||
|
if self._is_async:
|
||||||
|
# Cast to AsyncInputFunc for proper typing
|
||||||
|
async_func = cast(AsyncInputFunc, self.input_func)
|
||||||
|
return await async_func(prompt, cancellation_token)
|
||||||
|
else:
|
||||||
|
# Cast to SyncInputFunc for proper typing
|
||||||
|
sync_func = cast(SyncInputFunc, self.input_func)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return await loop.run_in_executor(None, sync_func, prompt)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||||
|
|
||||||
|
async def on_messages(
|
||||||
|
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
|
||||||
|
) -> Response:
|
||||||
|
"""Handle incoming messages by requesting user input."""
|
||||||
|
try:
|
||||||
|
# Check for handoff first
|
||||||
|
handoff = self._get_latest_handoff(messages)
|
||||||
|
prompt = (
|
||||||
|
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
|
||||||
|
)
|
||||||
|
|
||||||
|
user_input = await self._get_input(prompt, cancellation_token)
|
||||||
|
|
||||||
|
# Return appropriate message type based on handoff presence
|
||||||
|
if handoff:
|
||||||
|
return Response(
|
||||||
|
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return Response(chat_message=TextMessage(content=user_input, source=self.name))
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||||
|
|
||||||
|
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||||
|
"""Reset agent state."""
|
||||||
|
pass
|
||||||
103
python/packages/autogen-agentchat/tests/test_userproxy_agent.py
Normal file
103
python/packages/autogen-agentchat/tests/test_userproxy_agent.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from autogen_agentchat.agents import UserProxyAgent
|
||||||
|
from autogen_agentchat.base import Response
|
||||||
|
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
|
||||||
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_input() -> None:
|
||||||
|
"""Test basic message handling with custom input"""
|
||||||
|
|
||||||
|
def custom_input(prompt: str) -> str:
|
||||||
|
return "The height of the eiffel tower is 324 meters. Aloha!"
|
||||||
|
|
||||||
|
agent = UserProxyAgent(name="test_user", input_func=custom_input)
|
||||||
|
messages = [TextMessage(content="What is the height of the eiffel tower?", source="assistant")]
|
||||||
|
|
||||||
|
response = await agent.on_messages(messages, CancellationToken())
|
||||||
|
|
||||||
|
assert isinstance(response, Response)
|
||||||
|
assert isinstance(response.chat_message, TextMessage)
|
||||||
|
assert response.chat_message.content == "The height of the eiffel tower is 324 meters. Aloha!"
|
||||||
|
assert response.chat_message.source == "test_user"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_input() -> None:
|
||||||
|
"""Test handling of async input function"""
|
||||||
|
|
||||||
|
async def async_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return "async response"
|
||||||
|
|
||||||
|
agent = UserProxyAgent(name="test_user", input_func=async_input)
|
||||||
|
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||||
|
|
||||||
|
response = await agent.on_messages(messages, CancellationToken())
|
||||||
|
|
||||||
|
assert isinstance(response.chat_message, TextMessage)
|
||||||
|
assert response.chat_message.content == "async response"
|
||||||
|
assert response.chat_message.source == "test_user"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handoff_handling() -> None:
|
||||||
|
"""Test handling of handoff messages"""
|
||||||
|
|
||||||
|
def custom_input(prompt: str) -> str:
|
||||||
|
return "handoff response"
|
||||||
|
|
||||||
|
agent = UserProxyAgent(name="test_user", input_func=custom_input)
|
||||||
|
|
||||||
|
messages: Sequence[ChatMessage] = [
|
||||||
|
TextMessage(content="Initial message", source="assistant"),
|
||||||
|
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await agent.on_messages(messages, CancellationToken())
|
||||||
|
|
||||||
|
assert isinstance(response.chat_message, HandoffMessage)
|
||||||
|
assert response.chat_message.content == "handoff response"
|
||||||
|
assert response.chat_message.source == "test_user"
|
||||||
|
assert response.chat_message.target == "assistant"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancellation() -> None:
|
||||||
|
"""Test cancellation during message handling"""
|
||||||
|
|
||||||
|
async def cancellable_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
if token and token.is_cancelled():
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
return "cancellable response"
|
||||||
|
|
||||||
|
agent = UserProxyAgent(name="test_user", input_func=cancellable_input)
|
||||||
|
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||||
|
token = CancellationToken()
|
||||||
|
|
||||||
|
async def cancel_after_delay() -> None:
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
token.cancel()
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await asyncio.gather(agent.on_messages(messages, token), cancel_after_delay())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_handling() -> None:
|
||||||
|
"""Test error handling with problematic input function"""
|
||||||
|
|
||||||
|
def failing_input(_: str) -> str:
|
||||||
|
raise ValueError("Input function failed")
|
||||||
|
|
||||||
|
agent = UserProxyAgent(name="test_user", input_func=failing_input)
|
||||||
|
messages = [TextMessage(content="test prompt", source="assistant")]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
await agent.on_messages(messages, CancellationToken())
|
||||||
|
assert "Failed to get user input" in str(exc_info.value)
|
||||||
Loading…
x
Reference in New Issue
Block a user