Add text-only model support to M1 (#5344)

Modify M1 agents to support text-only settings.
This allows M1 to be used with models like o3-mini and Llama3.1+
This commit is contained in:
afourney 2025-02-04 08:25:48 -08:00 committed by GitHub
parent 517e3f000e
commit cf6fa77273
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 112 additions and 91 deletions

View File

@ -2,7 +2,7 @@ import json
import logging
from typing import Any, Dict, List, Mapping
from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@ -24,6 +24,7 @@ from ....messages import (
ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from ....utils import content_to_str, remove_images
from .._base_group_chat_manager import BaseGroupChatManager
from .._events import (
GroupChatAgentResponse,
@ -138,7 +139,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Create the initial task ledger
#################################
# Combine all message contents for task
self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
self._task = " ".join([content_to_str(msg.content) for msg in message.messages])
planning_conversation: List[LLMMessage] = []
# 1. GATHER FACTS
@ -146,7 +147,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
planning_conversation.append(
UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
)
assert isinstance(response.content, str)
self._facts = response.content
@ -157,7 +160,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
planning_conversation.append(
UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
)
assert isinstance(response.content, str)
self._plan = response.content
@ -281,7 +286,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
assert self._max_json_retries > 0
key_error: bool = False
for _ in range(self._max_json_retries):
response = await self._model_client.create(context, json_output=True)
response = await self._model_client.create(self._get_compatible_context(context), json_output=True)
ledger_str = response.content
try:
assert isinstance(ledger_str, str)
@ -397,7 +402,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
context.append(UserMessage(content=update_facts_prompt, source=self._name))
response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._facts = response.content
@ -407,7 +414,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
context.append(UserMessage(content=update_plan_prompt, source=self._name))
response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._plan = response.content
@ -420,7 +429,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
final_answer_prompt = self._get_final_answer_prompt(self._task)
context.append(UserMessage(content=final_answer_prompt, source=self._name))
response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
message = TextMessage(content=response.content, source=self._name)
@ -464,15 +475,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
context.append(UserMessage(content=m.content, source=m.source))
return context
def _content_to_str(self, content: str | List[str | Image]) -> str:
"""Convert the content to a string."""
if isinstance(content, str):
return content
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
if self._model_client.model_info["vision"]:
return messages
else:
result: List[str] = []
for c in content:
if isinstance(c, str):
result.append(c)
else:
result.append("<image>")
return "\n".join(result)
return remove_images(messages)

View File

@ -2,6 +2,6 @@
This module implements various utilities common to AgentChat agents and teams.
"""
from ._utils import remove_images
from ._utils import content_to_str, remove_images
__all__ = ["remove_images"]
__all__ = ["content_to_str", "remove_images"]

View File

@ -1,11 +1,17 @@
from typing import List
from typing import List, Union
from autogen_core import Image
from autogen_core.models import LLMMessage, UserMessage
from autogen_core import FunctionCall, Image
from autogen_core.models import FunctionExecutionResult, LLMMessage, UserMessage
# Type aliases for convenience
_UserContent = Union[str, List[Union[str, Image]]]
_AssistantContent = Union[str, List[FunctionCall]]
_FunctionExecutionContent = List[FunctionExecutionResult]
_SystemContent = str
def _image_content_to_str(content: str | List[str | Image]) -> str:
"""Convert the content of an LLMMessageto a string."""
def content_to_str(content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent) -> str:
"""Convert the content of an LLMMessage to a string."""
if isinstance(content, str):
return content
else:
@ -16,7 +22,7 @@ def _image_content_to_str(content: str | List[str | Image]) -> str:
elif isinstance(c, Image):
result.append("<image>")
else:
raise AssertionError("Received unexpected content type.")
result.append(str(c))
return "\n".join(result)
@ -26,7 +32,7 @@ def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
str_messages: List[LLMMessage] = []
for message in messages:
if isinstance(message, UserMessage) and isinstance(message.content, list):
str_messages.append(UserMessage(content=_image_content_to_str(message.content), source=message.source))
str_messages.append(UserMessage(content=content_to_str(message.content), source=message.source))
else:
str_messages.append(message)
return str_messages

View File

@ -9,6 +9,7 @@ from autogen_agentchat.messages import (
MultiModalMessage,
TextMessage,
)
from autogen_agentchat.utils import remove_images
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models import (
AssistantMessage,
@ -126,7 +127,7 @@ class FileSurfer(BaseChatAgent):
)
create_result = await self._model_client.create(
messages=history + [context_message, task_message],
messages=self._get_compatible_context(history + [context_message, task_message]),
tools=[
TOOL_OPEN_PATH,
TOOL_PAGE_DOWN,
@ -172,3 +173,10 @@ class FileSurfer(BaseChatAgent):
final_response = "TERMINATE"
return False, final_response
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Return *messages* in a form the underlying model client can accept.

    Vision-capable clients receive the messages untouched; for text-only
    clients any image content is stripped out first.
    """
    # Text-only models cannot process image content, so remove it.
    if not self._model_client.model_info["vision"]:
        return remove_images(messages)
    return messages

View File

@ -24,6 +24,7 @@ import PIL.Image
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
from autogen_agentchat.utils import content_to_str, remove_images
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core import Image as AGImage
from autogen_core.models import (
@ -40,7 +41,13 @@ from pydantic import BaseModel
from typing_extensions import Self
from ._events import WebSurferEvent
from ._prompts import WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT
from ._prompts import (
WEB_SURFER_OCR_PROMPT,
WEB_SURFER_QA_PROMPT,
WEB_SURFER_QA_SYSTEM_MESSAGE,
WEB_SURFER_TOOL_PROMPT_MM,
WEB_SURFER_TOOL_PROMPT_TEXT,
)
from ._set_of_mark import add_set_of_mark
from ._tool_definitions import (
TOOL_CLICK,
@ -56,7 +63,6 @@ from ._tool_definitions import (
TOOL_WEB_SEARCH,
)
from ._types import InteractiveRegion, UserContent
from ._utils import message_content_to_str
from .playwright_controller import PlaywrightController
@ -215,8 +221,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
raise ValueError(
"The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
)
if model_client.model_info["vision"] is False:
raise ValueError("The model is not multimodal. MultimodalWebSurfer requires a multimodal model.")
self._model_client = model_client
self.headless = headless
self.browser_channel = browser_channel
@ -404,7 +409,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
self.model_usage: List[RequestUsage] = []
try:
content = await self._generate_reply(cancellation_token=cancellation_token)
self._chat_history.append(AssistantMessage(content=message_content_to_str(content), source=self.name))
self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
final_usage = RequestUsage(
prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]),
completion_tokens=sum([u.completion_tokens for u in self.model_usage]),
@ -434,22 +439,8 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
assert self._page is not None
# Clone the messages to give context, removing old screenshots
history: List[LLMMessage] = []
for m in self._chat_history:
assert isinstance(m, UserMessage | AssistantMessage | SystemMessage)
assert isinstance(m.content, str | list)
if isinstance(m.content, str):
history.append(m)
else:
content = message_content_to_str(m.content)
if isinstance(m, UserMessage):
history.append(UserMessage(content=content, source=m.source))
elif isinstance(m, AssistantMessage):
history.append(AssistantMessage(content=content, source=m.source))
elif isinstance(m, SystemMessage):
history.append(SystemMessage(content=content))
# Clone the messages, removing old screenshots
history: List[LLMMessage] = remove_images(self._chat_history)
# Ask the page for interactive elements, then prepare the state-of-mark screenshot
rects = await self._playwright_controller.get_interactive_rects(self._page)
@ -512,22 +503,37 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
tool_names = "\n".join([t["name"] for t in tools])
text_prompt = WEB_SURFER_TOOL_PROMPT.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
).strip()
if self._model_client.model_info["vision"]:
text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
).strip()
# Scale the screenshot for the MLM, and close the original
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
som_screenshot.close()
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore
# Scale the screenshot for the MLM, and close the original
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
som_screenshot.close()
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore
# Add the multimodal message and make the request
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
# Add the message
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
else:
visible_text = await self._playwright_controller.get_visible_text(self._page)
text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
visible_text=visible_text.strip(),
).strip()
# Add the message
history.append(UserMessage(content=text_prompt, source=self.name))
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token

View File

@ -1,4 +1,4 @@
WEB_SURFER_TOOL_PROMPT = """
WEB_SURFER_TOOL_PROMPT_MM = """
Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{visible_targets}{other_targets_str}{focused_hint}
@ -13,6 +13,27 @@ When deciding between tools, consider if the request can be best addressed by:
- on some other website entirely (in which case actions like performing a new web search might be the best option)
"""
WEB_SURFER_TOOL_PROMPT_TEXT = """
Your web browser is open to the page '{url}'. The following text is visible in the viewport:
```
{visible_text}
```
You have also identified the following interactive components:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text might be most appropriate, or hovering over element)
- contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate)
- on some other website entirely (in which case actions like performing a new web search might be the best option)
"""
WEB_SURFER_OCR_PROMPT = """
Please transcribe all visible text on this page, including both main content and the labels of UI elements.
"""

View File

@ -1,25 +0,0 @@
from typing import List
from autogen_core import Image
from ._types import AssistantContent, FunctionExecutionContent, SystemContent, UserContent
# Convert UserContent to a string
def message_content_to_str(
    message_content: UserContent | AssistantContent | SystemContent | FunctionExecutionContent,
) -> str:
    """Flatten message content into a single newline-joined string.

    A plain string passes through unchanged. List content is rendered one
    item per line: strings are right-stripped, images become the
    placeholder ``<Image>``, and anything else is stringified and
    right-stripped. Any other content type is rejected.
    """
    if isinstance(message_content, str):
        return message_content
    if isinstance(message_content, List):
        parts: List[str] = []
        for entry in message_content:
            if isinstance(entry, str):
                parts.append(entry.rstrip())
            elif isinstance(entry, Image):
                parts.append("<Image>")
            else:
                parts.append(str(entry).rstrip())
        return "\n".join(parts)
    raise AssertionError("Unexpected response type.")

View File

@ -142,7 +142,7 @@ class MagenticOne(MagenticOneGroupChat):
def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
capabilities = client.model_info
required_capabilities = ["vision", "function_calling", "json_output"]
required_capabilities = ["function_calling", "json_output"]
if not all(capabilities.get(cap) for cap in required_capabilities):
warnings.warn(