mirror of https://github.com/microsoft/autogen.git (synced 2025-12-25 05:59:19 +00:00)
Add text-only model support to M1 (#5344)
Modify the M1 agents to support text-only model settings. This allows M1 to be used with models such as o3-mini and Llama 3.1+.
parent 517e3f000e
commit cf6fa77273
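The whole change reduces to one small pattern, repeated in each affected agent: before every call to the model client, check the client's model_info["vision"] flag and strip images from the message context when the model is text-only. A condensed sketch of the helper this diff introduces (written module-level here for illustration; in the diff it is a private _get_compatible_context method on each agent):

```python
from typing import List

from autogen_agentchat.utils import remove_images
from autogen_core.models import ChatCompletionClient, LLMMessage


def get_compatible_context(client: ChatCompletionClient, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Pass messages through unchanged for vision models; otherwise replace images with text placeholders."""
    if client.model_info["vision"]:
        return messages
    # remove_images rewrites any UserMessage whose content contains images into plain text
    return remove_images(messages)
```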
@@ -2,7 +2,7 @@ import json
 import logging
 from typing import Any, Dict, List, Mapping

-from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
+from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
@@ -24,6 +24,7 @@ from ....messages import (
     ToolCallSummaryMessage,
 )
 from ....state import MagenticOneOrchestratorState
+from ....utils import content_to_str, remove_images
 from .._base_group_chat_manager import BaseGroupChatManager
 from .._events import (
     GroupChatAgentResponse,
@@ -138,7 +139,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         # Create the initial task ledger
         #################################
         # Combine all message contents for task
-        self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
+        self._task = " ".join([content_to_str(msg.content) for msg in message.messages])
         planning_conversation: List[LLMMessage] = []

         # 1. GATHER FACTS
@@ -146,7 +147,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         planning_conversation.append(
             UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
         )
-        response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
+        response = await self._model_client.create(
+            self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
+        )

         assert isinstance(response.content, str)
         self._facts = response.content
@@ -157,7 +160,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         planning_conversation.append(
             UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
         )
-        response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
+        response = await self._model_client.create(
+            self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
+        )

         assert isinstance(response.content, str)
         self._plan = response.content
@@ -281,7 +286,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         assert self._max_json_retries > 0
         key_error: bool = False
         for _ in range(self._max_json_retries):
-            response = await self._model_client.create(context, json_output=True)
+            response = await self._model_client.create(self._get_compatible_context(context), json_output=True)
             ledger_str = response.content
             try:
                 assert isinstance(ledger_str, str)
@@ -397,7 +402,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
         context.append(UserMessage(content=update_facts_prompt, source=self._name))

-        response = await self._model_client.create(context, cancellation_token=cancellation_token)
+        response = await self._model_client.create(
+            self._get_compatible_context(context), cancellation_token=cancellation_token
+        )

         assert isinstance(response.content, str)
         self._facts = response.content
@@ -407,7 +414,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
         context.append(UserMessage(content=update_plan_prompt, source=self._name))

-        response = await self._model_client.create(context, cancellation_token=cancellation_token)
+        response = await self._model_client.create(
+            self._get_compatible_context(context), cancellation_token=cancellation_token
+        )

         assert isinstance(response.content, str)
         self._plan = response.content
@@ -420,7 +429,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         final_answer_prompt = self._get_final_answer_prompt(self._task)
         context.append(UserMessage(content=final_answer_prompt, source=self._name))

-        response = await self._model_client.create(context, cancellation_token=cancellation_token)
+        response = await self._model_client.create(
+            self._get_compatible_context(context), cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         message = TextMessage(content=response.content, source=self._name)

@@ -464,15 +475,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
                 context.append(UserMessage(content=m.content, source=m.source))
         return context

-    def _content_to_str(self, content: str | List[str | Image]) -> str:
-        """Convert the content to a string."""
-        if isinstance(content, str):
-            return content
+    def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
+        """Ensure that the messages are compatible with the underlying client, by removing images if needed."""
+        if self._model_client.model_info["vision"]:
+            return messages
         else:
-            result: List[str] = []
-            for c in content:
-                if isinstance(c, str):
-                    result.append(c)
-                else:
-                    result.append("<image>")
-            return "\n".join(result)
+            return remove_images(messages)
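All five orchestrator call sites above now funnel through the single _get_compatible_context gate, and the initial task flattening switches from the deleted private helper to the shared content_to_str. The effect on a multimodal task, sketched with hypothetical message values:

```python
import PIL.Image
from autogen_agentchat.messages import MultiModalMessage, TextMessage
from autogen_agentchat.utils import content_to_str
from autogen_core import Image as AGImage

# Hypothetical incoming task messages, as the orchestrator might receive them.
messages = [
    TextMessage(content="Summarize the attached dashboard.", source="user"),
    MultiModalMessage(content=["Here it is:", AGImage.from_pil(PIL.Image.new("RGB", (4, 4)))], source="user"),
]

# Mirrors the diff's task-flattening line: images collapse to an "<image>" placeholder.
task = " ".join([content_to_str(msg.content) for msg in messages])
# -> "Summarize the attached dashboard. Here it is:\n<image>"
```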
@@ -2,6 +2,6 @@
 This module implements various utilities common to AgentChat agents and teams.
 """

-from ._utils import remove_images
+from ._utils import content_to_str, remove_images

-__all__ = ["remove_images"]
+__all__ = ["content_to_str", "remove_images"]
@@ -1,11 +1,17 @@
-from typing import List
+from typing import List, Union

-from autogen_core import Image
-from autogen_core.models import LLMMessage, UserMessage
+from autogen_core import FunctionCall, Image
+from autogen_core.models import FunctionExecutionResult, LLMMessage, UserMessage

+# Type aliases for convenience
+_UserContent = Union[str, List[Union[str, Image]]]
+_AssistantContent = Union[str, List[FunctionCall]]
+_FunctionExecutionContent = List[FunctionExecutionResult]
+_SystemContent = str
+

-def _image_content_to_str(content: str | List[str | Image]) -> str:
-    """Convert the content of an LLMMessageto a string."""
+def content_to_str(content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent) -> str:
+    """Convert the content of an LLMMessage to a string."""
     if isinstance(content, str):
         return content
     else:
@@ -16,7 +22,7 @@ def _image_content_to_str(content: str | List[str | Image]) -> str:
         elif isinstance(c, Image):
             result.append("<image>")
         else:
-            raise AssertionError("Received unexpected content type.")
+            result.append(str(c))

     return "\n".join(result)

@@ -26,7 +32,7 @@ def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
     str_messages: List[LLMMessage] = []
     for message in messages:
         if isinstance(message, UserMessage) and isinstance(message.content, list):
-            str_messages.append(UserMessage(content=_image_content_to_str(message.content), source=message.source))
+            str_messages.append(UserMessage(content=content_to_str(message.content), source=message.source))
         else:
             str_messages.append(message)
     return str_messages
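The renamed content_to_str is also more permissive than its predecessor: instead of raising on unexpected entries, it accepts assistant and function-execution content and stringifies unknown items. A small, self-contained check of the two utilities (the tiny PIL image is just a stand-in for a real screenshot):

```python
import PIL.Image
from autogen_agentchat.utils import content_to_str, remove_images
from autogen_core import FunctionCall
from autogen_core import Image as AGImage
from autogen_core.models import UserMessage

img = AGImage.from_pil(PIL.Image.new("RGB", (4, 4)))
msg = UserMessage(content=["What is in this image?", img], source="user")

print(content_to_str(msg.content))  # "What is in this image?" then "<image>" on the next line

# FunctionCall lists no longer raise; they are stringified via str(c).
calls = [FunctionCall(id="1", name="web_search", arguments='{"query": "autogen"}')]
print(content_to_str(calls))

stripped = remove_images([msg])  # the image becomes its "<image>" placeholder text
assert isinstance(stripped[0].content, str)
```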
@@ -9,6 +9,7 @@ from autogen_agentchat.messages import (
     MultiModalMessage,
     TextMessage,
 )
+from autogen_agentchat.utils import remove_images
 from autogen_core import CancellationToken, FunctionCall
 from autogen_core.models import (
     AssistantMessage,
@@ -126,7 +127,7 @@ class FileSurfer(BaseChatAgent):
         )

         create_result = await self._model_client.create(
-            messages=history + [context_message, task_message],
+            messages=self._get_compatible_context(history + [context_message, task_message]),
             tools=[
                 TOOL_OPEN_PATH,
                 TOOL_PAGE_DOWN,
@@ -172,3 +173,10 @@ class FileSurfer(BaseChatAgent):

         final_response = "TERMINATE"
         return False, final_response
+
+    def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
+        """Ensure that the messages are compatible with the underlying client, by removing images if needed."""
+        if self._model_client.model_info["vision"]:
+            return messages
+        else:
+            return remove_images(messages)
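With the helper in place, FileSurfer no longer assumes a multimodal client. A hedged usage sketch, assuming the OpenAI client from autogen-ext (whose built-in model_info for o3-mini reports vision=False, so every request is routed through remove_images):

```python
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Text-only model: _get_compatible_context strips images before each create() call.
client = OpenAIChatCompletionClient(model="o3-mini")
file_surfer = FileSurfer("FileSurfer", model_client=client)
```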
@@ -24,6 +24,7 @@ import PIL.Image
 from autogen_agentchat.agents import BaseChatAgent
 from autogen_agentchat.base import Response
 from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
+from autogen_agentchat.utils import content_to_str, remove_images
 from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
 from autogen_core import Image as AGImage
 from autogen_core.models import (
@@ -40,7 +41,13 @@ from pydantic import BaseModel
 from typing_extensions import Self

 from ._events import WebSurferEvent
-from ._prompts import WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT
+from ._prompts import (
+    WEB_SURFER_OCR_PROMPT,
+    WEB_SURFER_QA_PROMPT,
+    WEB_SURFER_QA_SYSTEM_MESSAGE,
+    WEB_SURFER_TOOL_PROMPT_MM,
+    WEB_SURFER_TOOL_PROMPT_TEXT,
+)
 from ._set_of_mark import add_set_of_mark
 from ._tool_definitions import (
     TOOL_CLICK,
@@ -56,7 +63,6 @@ from ._tool_definitions import (
     TOOL_WEB_SEARCH,
 )
 from ._types import InteractiveRegion, UserContent
-from ._utils import message_content_to_str
 from .playwright_controller import PlaywrightController


@@ -215,8 +221,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
             raise ValueError(
                 "The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
             )
-        if model_client.model_info["vision"] is False:
-            raise ValueError("The model is not multimodal. MultimodalWebSurfer requires a multimodal model.")

         self._model_client = model_client
         self.headless = headless
         self.browser_channel = browser_channel
@@ -404,7 +409,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
         self.model_usage: List[RequestUsage] = []
         try:
             content = await self._generate_reply(cancellation_token=cancellation_token)
-            self._chat_history.append(AssistantMessage(content=message_content_to_str(content), source=self.name))
+            self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
             final_usage = RequestUsage(
                 prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]),
                 completion_tokens=sum([u.completion_tokens for u in self.model_usage]),
@@ -434,22 +439,8 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):

         assert self._page is not None

-        # Clone the messages to give context, removing old screenshots
-        history: List[LLMMessage] = []
-        for m in self._chat_history:
-            assert isinstance(m, UserMessage | AssistantMessage | SystemMessage)
-            assert isinstance(m.content, str | list)
-
-            if isinstance(m.content, str):
-                history.append(m)
-            else:
-                content = message_content_to_str(m.content)
-                if isinstance(m, UserMessage):
-                    history.append(UserMessage(content=content, source=m.source))
-                elif isinstance(m, AssistantMessage):
-                    history.append(AssistantMessage(content=content, source=m.source))
-                elif isinstance(m, SystemMessage):
-                    history.append(SystemMessage(content=content))
+        # Clone the messages, removing old screenshots
+        history: List[LLMMessage] = remove_images(self._chat_history)

         # Ask the page for interactive elements, then prepare the state-of-mark screenshot
         rects = await self._playwright_controller.get_interactive_rects(self._page)
@@ -512,22 +503,37 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):

         tool_names = "\n".join([t["name"] for t in tools])

-        text_prompt = WEB_SURFER_TOOL_PROMPT.format(
-            url=self._page.url,
-            visible_targets=visible_targets,
-            other_targets_str=other_targets_str,
-            focused_hint=focused_hint,
-            tool_names=tool_names,
-        ).strip()
+        if self._model_client.model_info["vision"]:
+            text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
+                url=self._page.url,
+                visible_targets=visible_targets,
+                other_targets_str=other_targets_str,
+                focused_hint=focused_hint,
+                tool_names=tool_names,
+            ).strip()

-        # Scale the screenshot for the MLM, and close the original
-        scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
-        som_screenshot.close()
-        if self.to_save_screenshots:
-            scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png"))  # type: ignore
+            # Scale the screenshot for the MLM, and close the original
+            scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
+            som_screenshot.close()
+            if self.to_save_screenshots:
+                scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png"))  # type: ignore

-        # Add the multimodal message and make the request
-        history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
+            # Add the message
+            history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
+        else:
+            visible_text = await self._playwright_controller.get_visible_text(self._page)
+
+            text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
+                url=self._page.url,
+                visible_targets=visible_targets,
+                other_targets_str=other_targets_str,
+                focused_hint=focused_hint,
+                tool_names=tool_names,
+                visible_text=visible_text.strip(),
+            ).strip()
+
+            # Add the message
+            history.append(UserMessage(content=text_prompt, source=self.name))

         response = await self._model_client.create(
             history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
@@ -1,4 +1,4 @@
-WEB_SURFER_TOOL_PROMPT = """
+WEB_SURFER_TOOL_PROMPT_MM = """
 Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:

 {visible_targets}{other_targets_str}{focused_hint}
@@ -13,6 +13,27 @@ When deciding between tools, consider if the request can be best addressed by:
     - on some other website entirely (in which case actions like performing a new web search might be the best option)
 """

+WEB_SURFER_TOOL_PROMPT_TEXT = """
+Your web browser is open to the page '{url}'. The following text is visible in the viewport:
+
+```
+{visible_text}
+```
+
+You have also identified the following interactive components:
+
+{visible_targets}{other_targets_str}{focused_hint}
+
+You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools:
+
+{tool_names}
+
+When deciding between tools, consider if the request can be best addressed by:
+    - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text might be most appropriate, or hovering over element)
+    - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate)
+    - on some other website entirely (in which case actions like performing a new web search might be the best option)
+"""
+
 WEB_SURFER_OCR_PROMPT = """
 Please transcribe all visible text on this page, including both main content and the labels of UI elements.
 """
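The two prompt variants carry the same target and tool information; the text variant simply substitutes a {visible_text} field for the screenshot. A quick formatting check with hypothetical placeholder values (importing from the private _prompts module purely for demonstration):

```python
from autogen_ext.agents.web_surfer._prompts import WEB_SURFER_TOOL_PROMPT_TEXT

prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
    url="https://example.com",
    visible_targets='{"id": 12, "name": "More information...", "role": "link"}',
    other_targets_str="",
    focused_hint="",
    tool_names="click\ninput_text\nweb_search",
    visible_text="Example Domain\nThis domain is for use in illustrative examples.",
).strip()
print(prompt)
```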
@@ -1,25 +0,0 @@
-from typing import List
-
-from autogen_core import Image
-
-from ._types import AssistantContent, FunctionExecutionContent, SystemContent, UserContent
-
-
-# Convert UserContent to a string
-def message_content_to_str(
-    message_content: UserContent | AssistantContent | SystemContent | FunctionExecutionContent,
-) -> str:
-    if isinstance(message_content, str):
-        return message_content
-    elif isinstance(message_content, List):
-        converted: List[str] = list()
-        for item in message_content:
-            if isinstance(item, str):
-                converted.append(item.rstrip())
-            elif isinstance(item, Image):
-                converted.append("<Image>")
-            else:
-                converted.append(str(item).rstrip())
-        return "\n".join(converted)
-    else:
-        raise AssertionError("Unexpected response type.")
@@ -142,7 +142,7 @@ class MagenticOne(MagenticOneGroupChat):

     def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
         capabilities = client.model_info
-        required_capabilities = ["vision", "function_calling", "json_output"]
+        required_capabilities = ["function_calling", "json_output"]

         if not all(capabilities.get(cap) for cap in required_capabilities):
             warnings.warn(
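Because vision is dropped from the required capabilities, the whole team can now run end to end on a text-only model. A hedged sketch of what this enables (the participant, model name, and task are illustrative, assuming the autogen-agentchat 0.4 team API):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    client = OpenAIChatCompletionClient(model="o3-mini")  # text-only model
    assistant = AssistantAgent("Assistant", model_client=client)
    # The orchestrator strips images from its planning and ledger calls as needed.
    team = MagenticOneGroupChat([assistant], model_client=client)
    result = await team.run(task="Write a short summary of the Magentic-One architecture.")
    print(result.messages[-1].content)


asyncio.run(main())
```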