fix: fix user input in m1 (#4995)

* Add lock for input and output management in m1

* Use event to signal it is time to prompt for input

* undo stop change

* undo changes

* Update python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* reduce exported surface area

* fix

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
Jack Gerrits 2025-01-13 10:28:08 -05:00 committed by GitHub
parent 0554fa3e2a
commit 466848ac65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 157 additions and 29 deletions

View File

@ -1,15 +1,17 @@
import asyncio
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast
from aioconsole import ainput # type: ignore
from autogen_core import CancellationToken
from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent
# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
@ -109,6 +111,33 @@ class UserProxyAgent(BaseChatAgent):
print(f"BaseException: {e}")
"""
class InputRequestContext:
    """Static holder for the id of the in-flight user input request.

    The id is published through a :class:`~contextvars.ContextVar` so the
    input callback of a ``UserProxyAgent`` can discover which request it is
    serving without threading the id through its signature.
    """

    def __init__(self) -> None:
        raise RuntimeError(
            "InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
        )

    _INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")

    @classmethod
    @contextmanager
    def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
        """:meta private:"""
        # Use cls rather than the fully qualified outer-class path so the
        # context manager does not depend on the enclosing class name.
        token = cls._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
        try:
            yield
        finally:
            cls._INPUT_REQUEST_CONTEXT_VAR.reset(token)

    @classmethod
    def request_id(cls) -> str:
        """Return the id of the current input request.

        Raises:
            RuntimeError: if called outside ``populate_context`` (i.e. not
                from within the input callback of a UserProxyAgent).
        """
        try:
            return cls._INPUT_REQUEST_CONTEXT_VAR.get()
        except LookupError as e:
            # Fixed: the message previously pointed callers at a
            # non-existent ``runtime()`` method instead of request_id().
            raise RuntimeError(
                "InputRequestContext.request_id() must be called within the input callback of a UserProxyAgent."
            ) from e
def __init__(
self,
name: str,
@ -153,9 +182,15 @@ class UserProxyAgent(BaseChatAgent):
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
    """Drain the streaming handler and return its terminal Response."""
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if isinstance(item, Response):
            return item
    raise AssertionError("The stream should have returned the final result.")
# NOTE(review): this span is a rendered diff fragment — removed and added
# lines appear interleaved, a hunk header sits mid-body, and indentation has
# been stripped, so it is not valid Python as-is; comments annotate intent.
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
@ -164,15 +199,18 @@ class UserProxyAgent(BaseChatAgent):
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)
# NOTE(review): the next line is the pre-change call; the lines after it
# are its replacement, which mints a request id, announces the request,
# and only then blocks for input.
user_input = await self._get_input(prompt, cancellation_token)
request_id = str(uuid.uuid4())
# A UserInputRequestedEvent is yielded BEFORE blocking so a UI consuming
# the stream (e.g. Console) can acknowledge the prompt.
input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
yield input_requested_event
# Expose the request id to the input callback via InputRequestContext.
with UserProxyAgent.InputRequestContext.populate_context(request_id):
user_input = await self._get_input(prompt, cancellation_token)
# Return appropriate message type based on handoff presence
# NOTE(review): return-style (old) and yield-style (new) lines below are
# both present in this diff rendering.
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))
yield Response(chat_message=TextMessage(content=user_input, source=self.name))
except asyncio.CancelledError:
raise

View File

@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
class UserInputRequestedEvent(BaseAgentEvent):
"""An event signaling that the user proxy has requested user input. Published prior to invoking the input callback."""
request_id: str
"""Identifier for the user input request."""
content: Literal[""] = ""
"""Empty content for compat with consumers expecting a content field."""
# Discriminator value used by the AgentEvent annotated union.
type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"
ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""
AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
# Public API of the messages module. Deduplicated: the rendered diff merged
# the old and new lists, listing several names twice.
__all__ = [
    "AgentEvent",
    "BaseMessage",
    "ChatMessage",
    "HandoffMessage",
    "MultiModalMessage",
    "StopMessage",
    "TextMessage",
    "ToolCallExecutionEvent",
    "ToolCallRequestEvent",
    "ToolCallSummaryMessage",
    "UserInputRequestedEvent",
]

View File

@ -2,6 +2,6 @@
This module implements utility classes for formatting/printing agent messages.
"""
from ._console import Console
from ._console import Console, UserInputManager
# Public API of the ui package. The rendered diff showed both the old and
# the new assignment; only the final one is kept (the second overwrote the
# first anyway, so behavior is unchanged).
__all__ = ["Console", "UserInputManager"]

View File

@ -1,14 +1,17 @@
import asyncio
import os
import sys
import time
from typing import AsyncGenerator, List, Optional, TypeVar, cast
from inspect import iscoroutinefunction
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast
from aioconsole import aprint # type: ignore
from autogen_core import Image
from autogen_core import CancellationToken, Image
from autogen_core.models import RequestUsage
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
def _is_running_in_iterm() -> bool:
@ -19,14 +22,60 @@ def _is_output_a_tty() -> bool:
return sys.stdout.isatty()
# Input-callback signatures accepted by UserInputManager: either a plain
# synchronous prompt->str function, or an async variant that also receives
# an optional CancellationToken.
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
# Result type flowing through Console: a TaskResult or a Response.
T = TypeVar("T", bound=TaskResult | Response)
class UserInputManager:
    """Coordinates user-input prompts between a UserProxyAgent and a console UI.

    The agent yields a ``UserInputRequestedEvent`` before its input callback
    runs, but the UI consumes the stream concurrently, so either side may act
    first. An ``asyncio.Event`` keyed by request id makes the wrapped
    callback wait until the UI has acknowledged the request.
    """

    def __init__(self, callback: InputFuncType):
        # request_id -> event signalling the UI has observed the request.
        self.input_events: Dict[str, asyncio.Event] = {}
        self.callback = callback

    def get_wrapped_callback(self) -> AsyncInputFunc:
        """Return an async input function that waits for the UI's ack first."""

        async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
            # Look up the event for this request, creating it if the UI has
            # not announced it yet, then wait for the acknowledgement.
            request_id = UserProxyAgent.InputRequestContext.request_id()
            if request_id in self.input_events:
                event = self.input_events[request_id]
            else:
                event = asyncio.Event()
                self.input_events[request_id] = event

            await event.wait()

            # One-shot: the request is served, drop its bookkeeping entry.
            del self.input_events[request_id]

            if iscoroutinefunction(self.callback):
                # Cast to AsyncInputFunc for proper typing
                async_func = cast(AsyncInputFunc, self.callback)
                return await async_func(prompt, cancellation_token)
            else:
                # Cast to SyncInputFunc for proper typing
                sync_func = cast(SyncInputFunc, self.callback)
                # We are inside a coroutine, so get_running_loop() is the
                # correct (non-deprecated) way to obtain the loop.
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, sync_func, prompt)

        return user_input_func_wrapper

    def notify_event_received(self, request_id: str) -> None:
        """Record that the UI has observed the request with ``request_id``.

        Safe to call before or after the wrapped callback starts waiting.
        """
        if request_id in self.input_events:
            self.input_events[request_id].set()
        else:
            # The notification can arrive before the callback runs. Create
            # the event already set so the later wait() returns immediately
            # instead of deadlocking on an event nobody will set again.
            event = asyncio.Event()
            event.set()
            self.input_events[request_id] = event
async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = False,
user_input_manager: UserInputManager | None = None,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
@ -67,6 +116,7 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")
# mypy ignore
last_processed = message # type: ignore
@ -96,9 +146,13 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")
# mypy ignore
last_processed = message # type: ignore
# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore

View File

@ -1,9 +1,10 @@
import warnings
from typing import List
from typing import Awaitable, Callable, List, Optional, Union
from autogen_agentchat.agents import CodeExecutorAgent, UserProxyAgent
from autogen_agentchat.base import ChatAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from autogen_ext.agents.file_surfer import FileSurfer
@ -12,6 +13,10 @@ from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai._openai_client import BaseOpenAIChatCompletionClient
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
class MagenticOne(MagenticOneGroupChat):
"""
@ -116,7 +121,12 @@ class MagenticOne(MagenticOneGroupChat):
"""
def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
def __init__(
self,
client: ChatCompletionClient,
hil_mode: bool = False,
input_func: InputFuncType | None = None,
):
self.client = client
self._validate_client_capabilities(client)
@ -126,7 +136,7 @@ class MagenticOne(MagenticOneGroupChat):
executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor())
agents: List[ChatAgent] = [fs, ws, coder, executor]
if hil_mode:
user_proxy = UserProxyAgent("User")
user_proxy = UserProxyAgent("User", input_func=input_func)
agents.append(user_proxy)
super().__init__(agents, model_client=client)

View File

@ -1,8 +1,11 @@
import argparse
import asyncio
import warnings
from typing import Optional
from autogen_agentchat.ui import Console
from aioconsole import ainput # type: ignore
from autogen_agentchat.ui import Console, UserInputManager
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne
@ -10,6 +13,13 @@ from autogen_ext.teams.magentic_one import MagenticOne
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
    """Read one line from the console, cancellable via the optional token.

    The aioconsole read is wrapped in a task so that, when a cancellation
    token is supplied, cancelling the token cancels the pending read.
    """
    read_task: asyncio.Task[str] = asyncio.create_task(ainput(prompt))  # type: ignore
    if cancellation_token is not None:
        cancellation_token.link_future(read_task)
    return await read_task
def main() -> None:
"""
Command-line interface for running a complex task using MagenticOne.
@ -37,9 +47,10 @@ def main() -> None:
args = parser.parse_args()
async def run_task(task: str, hil_mode: bool) -> None:
input_manager = UserInputManager(callback=cancellable_input)
client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client, hil_mode=hil_mode)
await Console(m1.run_stream(task=task), output_stats=False)
m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback())
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)
task = args.task[0]
asyncio.run(run_task(task, not args.no_hil))