Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-28 15:38:53 +00:00
fix: fix user input in m1 (#4995)
* Add lock for input and output management in m1
* Use event to signal it is time to prompt for input
* undo stop change
* undo changes
* Update python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py
  Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
* reduce exported surface area
* fix

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
parent 0554fa3e2a
commit 466848ac65
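For orientation, here is a minimal sketch, assembled from the diff below, of how the new pieces fit together: a UserInputManager wraps the input callback so the actual prompt is deferred until Console has seen the matching UserInputRequestedEvent and flushed its output. The run() wrapper and the example task string are illustrative only; the API names come from the changes in this commit.

# Minimal sketch based on the changes below (task string and run() wrapper are illustrative).
import asyncio
from typing import Optional

from aioconsole import ainput  # type: ignore
from autogen_agentchat.ui import Console, UserInputManager
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne


async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
    # Run the async console prompt as a task so it can be cancelled via the token.
    task: asyncio.Task[str] = asyncio.create_task(ainput(prompt))  # type: ignore
    if cancellation_token is not None:
        cancellation_token.link_future(task)
    return await task


async def run(task: str) -> None:
    input_manager = UserInputManager(callback=cancellable_input)
    client = OpenAIChatCompletionClient(model="gpt-4o")
    # The wrapped callback blocks until Console signals the matching request_id.
    m1 = MagenticOne(client=client, hil_mode=True, input_func=input_manager.get_wrapped_callback())
    await Console(m1.run_stream(task=task), user_input_manager=input_manager)


if __name__ == "__main__":
    asyncio.run(run("Summarize the latest AutoGen release notes."))

This is the same wiring the CLI change at the bottom of the diff adopts.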
@@ -1,15 +1,17 @@
import asyncio
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from aioconsole import ainput  # type: ignore
from autogen_core import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
@@ -109,6 +111,33 @@ class UserProxyAgent(BaseChatAgent):
print(f"BaseException: {e}")
"""

class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(
"InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
)

_INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")

@classmethod
@contextmanager
def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
""":meta private:"""
token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
try:
yield
finally:
UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token)

@classmethod
def request_id(cls) -> str:
try:
return cls._INPUT_REQUEST_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(
"InputRequestContext.request_id() must be called within the input callback of a UserProxyAgent."
) from e

def __init__(
self,
name: str,
@@ -153,9 +182,15 @@ class UserProxyAgent(BaseChatAgent):
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
@@ -164,15 +199,18 @@ class UserProxyAgent(BaseChatAgent):
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)

user_input = await self._get_input(prompt, cancellation_token)
request_id = str(uuid.uuid4())

input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
yield input_requested_event
with UserProxyAgent.InputRequestContext.populate_context(request_id):
user_input = await self._get_input(prompt, cancellation_token)

# Return appropriate message type based on handoff presence
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))
yield Response(chat_message=TextMessage(content=user_input, source=self.name))

except asyncio.CancelledError:
raise
@@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class UserInputRequestedEvent(BaseAgentEvent):
"""An event signaling that the user proxy has requested user input. Published prior to invoking the input callback."""

request_id: str
"""Identifier for the user input request."""

content: Literal[""] = ""
"""Empty content for compat with consumers expecting a content field."""

type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


__all__ = [
"AgentEvent",
"BaseMessage",
"TextMessage",
"ChatMessage",
"HandoffMessage",
"MultiModalMessage",
"StopMessage",
"HandoffMessage",
"ToolCallRequestEvent",
"TextMessage",
"ToolCallExecutionEvent",
"ToolCallRequestEvent",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"UserInputRequestedEvent",
]
@@ -2,6 +2,6 @@
This module implements utility classes for formatting/printing agent messages.
"""

from ._console import Console
from ._console import Console, UserInputManager

__all__ = ["Console"]
__all__ = ["Console", "UserInputManager"]
@@ -1,14 +1,17 @@
import asyncio
import os
import sys
import time
from typing import AsyncGenerator, List, Optional, TypeVar, cast
from inspect import iscoroutinefunction
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast

from aioconsole import aprint  # type: ignore
from autogen_core import Image
from autogen_core import CancellationToken, Image
from autogen_core.models import RequestUsage

from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent


def _is_running_in_iterm() -> bool:
@@ -19,14 +22,60 @@ def _is_output_a_tty() -> bool:
return sys.stdout.isatty()


SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]

T = TypeVar("T", bound=TaskResult | Response)


class UserInputManager:
def __init__(self, callback: InputFuncType):
self.input_events: Dict[str, asyncio.Event] = {}
self.callback = callback

def get_wrapped_callback(self) -> AsyncInputFunc:
async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
# Lookup the event for the prompt, if it exists wait for it.
# If it doesn't exist, create it and store it.
# Get request ID:
request_id = UserProxyAgent.InputRequestContext.request_id()
if request_id in self.input_events:
event = self.input_events[request_id]
else:
event = asyncio.Event()
self.input_events[request_id] = event

await event.wait()

del self.input_events[request_id]

if iscoroutinefunction(self.callback):
# Cast to AsyncInputFunc for proper typing
async_func = cast(AsyncInputFunc, self.callback)
return await async_func(prompt, cancellation_token)
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(SyncInputFunc, self.callback)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

return user_input_func_wrapper

def notify_event_received(self, request_id: str) -> None:
if request_id in self.input_events:
self.input_events[request_id].set()
else:
event = asyncio.Event()
self.input_events[request_id] = event


async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = False,
user_input_manager: UserInputManager | None = None,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
@@ -67,6 +116,7 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message  # type: ignore

@@ -96,9 +146,13 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message  # type: ignore

# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message)  # type: ignore
@@ -1,9 +1,10 @@
import warnings
from typing import List
from typing import Awaitable, Callable, List, Optional, Union

from autogen_agentchat.agents import CodeExecutorAgent, UserProxyAgent
from autogen_agentchat.base import ChatAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient

from autogen_ext.agents.file_surfer import FileSurfer
@@ -12,6 +13,10 @@ from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai._openai_client import BaseOpenAIChatCompletionClient

SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


class MagenticOne(MagenticOneGroupChat):
"""
@@ -116,7 +121,12 @@ class MagenticOne(MagenticOneGroupChat):

"""

def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
def __init__(
self,
client: ChatCompletionClient,
hil_mode: bool = False,
input_func: InputFuncType | None = None,
):
self.client = client
self._validate_client_capabilities(client)

@@ -126,7 +136,7 @@ class MagenticOne(MagenticOneGroupChat):
executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor())
agents: List[ChatAgent] = [fs, ws, coder, executor]
if hil_mode:
user_proxy = UserProxyAgent("User")
user_proxy = UserProxyAgent("User", input_func=input_func)
agents.append(user_proxy)
super().__init__(agents, model_client=client)
@@ -1,8 +1,11 @@
import argparse
import asyncio
import warnings
from typing import Optional

from autogen_agentchat.ui import Console
from aioconsole import ainput  # type: ignore
from autogen_agentchat.ui import Console, UserInputManager
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne

@@ -10,6 +13,13 @@ from autogen_ext.teams.magentic_one import MagenticOne
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)


async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
task: asyncio.Task[str] = asyncio.create_task(ainput(prompt))  # type: ignore
if cancellation_token is not None:
cancellation_token.link_future(task)
return await task


def main() -> None:
"""
Command-line interface for running a complex task using MagenticOne.
@@ -37,9 +47,10 @@ def main() -> None:
args = parser.parse_args()

async def run_task(task: str, hil_mode: bool) -> None:
input_manager = UserInputManager(callback=cancellable_input)
client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client, hil_mode=hil_mode)
await Console(m1.run_stream(task=task), output_stats=False)
m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback())
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)

task = args.task[0]
asyncio.run(run_task(task, not args.no_hil))
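The heart of the fix is the handshake between UserProxyAgent, which publishes a UserInputRequestedEvent carrying a fresh request_id before invoking the wrapped input callback, and Console, which calls notify_event_received(request_id) once it has drained its output. Below is a standalone sketch of that handshake using only the standard library; the names request_input and console_loop are hypothetical stand-ins for the agent side and the console side.

# Standalone sketch of the per-request_id Event handshake (hypothetical names).
import asyncio
import uuid
from typing import Dict


async def main() -> None:
    input_events: Dict[str, asyncio.Event] = {}
    ready: asyncio.Queue[str] = asyncio.Queue()

    async def request_input() -> str:
        # Agent side: announce the request, then wait for the console's signal
        # before prompting (stubbed here with a canned reply).
        request_id = str(uuid.uuid4())
        event = input_events.setdefault(request_id, asyncio.Event())
        await ready.put(request_id)
        await event.wait()
        del input_events[request_id]
        return "user reply"

    async def console_loop() -> None:
        # Console side: finish printing pending output, then release the prompt.
        request_id = await ready.get()
        print("...console output flushed...")
        input_events.setdefault(request_id, asyncio.Event()).set()

    reply, _ = await asyncio.gather(request_input(), console_loop())
    print(f"Got: {reply}")


if __name__ == "__main__":
    asyncio.run(main())

UserInputManager above keeps the same per-request_id asyncio.Event bookkeeping, with Console.notify_event_received playing the console_loop role; the point of the design is to keep the input prompt from interleaving with streamed agent output.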