mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 07:29:54 +00:00
Use PIL Image internally for the Multimodal Agent (#1124)
* Change default model for `lmm` * Try to use PIL image for LMM's _oai_messages * Update test cases and llava * Remove redundant files * Update the imports for lmm tests * Test case fix * Docstring update * LMM notebook lint * Typo correction for img_utils and its test * Update test_llava.py debug, reformat --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Shaokun Zhang <shaokunzhang529@gmail.com> Co-authored-by: Shaokun Zhang <shaokun.zhang@psu.edu>
This commit is contained in:
parent
2d29d36b1d
commit
9de374a495
@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import copy
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
@ -8,17 +10,63 @@ import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_image_data(image_file: str, use_b64=True) -> bytes:
|
||||
def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
|
||||
"""
|
||||
Loads an image from a file and returns a PIL Image object.
|
||||
|
||||
Parameters:
|
||||
image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.
|
||||
|
||||
Returns:
|
||||
Image.Image: The PIL Image object.
|
||||
"""
|
||||
if isinstance(image_file, Image.Image):
|
||||
# Already a PIL Image object
|
||||
return image_file
|
||||
|
||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
# A URL file
|
||||
response = requests.get(image_file)
|
||||
content = response.content
|
||||
content = BytesIO(response.content)
|
||||
image = Image.open(content)
|
||||
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
|
||||
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
|
||||
# A URI. Remove the prefix and decode the base64 string.
|
||||
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
|
||||
image = _to_pil(base64_data)
|
||||
elif os.path.exists(image_file):
|
||||
# A local file
|
||||
image = Image.open(image_file)
|
||||
else:
|
||||
image = Image.open(image_file).convert("RGB")
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
# base64 encoded string
|
||||
image = _to_pil(image_file)
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
|
||||
"""
|
||||
Loads an image and returns its data either as raw bytes or in base64-encoded format.
|
||||
|
||||
This function first loads an image from the specified file, URL, or base64 string using
|
||||
the `get_pil_image` function. It then saves this image in memory in PNG format and
|
||||
retrieves its binary content. Depending on the `use_b64` flag, this binary content is
|
||||
either returned directly or as a base64-encoded string.
|
||||
|
||||
Parameters:
|
||||
image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
|
||||
string of the image.
|
||||
use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
|
||||
If False, it returns the raw byte data of the image. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string
|
||||
if `use_b64` is True.
|
||||
"""
|
||||
image = get_pil_image(image_file)
|
||||
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
|
||||
if use_b64:
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
@ -72,6 +120,22 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str,
|
||||
return new_prompt, images
|
||||
|
||||
|
||||
def pil_to_data_uri(image: Image.Image) -> str:
|
||||
"""
|
||||
Converts a PIL Image object to a data URI.
|
||||
|
||||
Parameters:
|
||||
image (Image.Image): The PIL Image object.
|
||||
|
||||
Returns:
|
||||
str: The data URI string.
|
||||
"""
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
|
||||
|
||||
|
||||
def convert_base64_to_data_uri(base64_image):
|
||||
def _get_mime_type_from_data_uri(base64_image):
|
||||
# Decode the base64 string
|
||||
@ -92,16 +156,19 @@ def convert_base64_to_data_uri(base64_image):
|
||||
return data_uri
|
||||
|
||||
|
||||
def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
|
||||
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
|
||||
"""
|
||||
Formats the input prompt by replacing image tags and returns a list of text and images.
|
||||
|
||||
Parameters:
|
||||
Args:
|
||||
- prompt (str): The input string that may contain image tags like <img ...>.
|
||||
- img_format (str): what image format should be used. One of "uri", "url", "pil".
|
||||
|
||||
Returns:
|
||||
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
|
||||
"""
|
||||
assert img_format in ["uri", "url", "pil"]
|
||||
|
||||
output = []
|
||||
last_index = 0
|
||||
image_count = 0
|
||||
@ -114,7 +181,15 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
|
||||
image_location = match.group(1)
|
||||
|
||||
try:
|
||||
img_data = get_image_data(image_location)
|
||||
if img_format == "pil":
|
||||
img_data = get_pil_image(image_location)
|
||||
elif img_format == "uri":
|
||||
img_data = get_image_data(image_location)
|
||||
img_data = convert_base64_to_data_uri(img_data)
|
||||
elif img_format == "url":
|
||||
img_data = image_location
|
||||
else:
|
||||
raise ValueError(f"Unknown image format {img_format}")
|
||||
except Exception as e:
|
||||
# Warning and skip this token
|
||||
print(f"Warning! Unable to load image from {image_location}, because {e}")
|
||||
@ -124,7 +199,7 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
|
||||
output.append({"type": "text", "text": prompt[last_index : match.start()]})
|
||||
|
||||
# Add image data to output list
|
||||
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
|
||||
output.append({"type": "image_url", "image_url": {"url": img_data}})
|
||||
|
||||
last_index = match.end()
|
||||
image_count += 1
|
||||
@ -162,9 +237,61 @@ def _to_pil(data: str) -> Image.Image:
|
||||
and finally creates and returns a PIL Image object from the BytesIO object.
|
||||
|
||||
Parameters:
|
||||
data (str): The base64 encoded image data string.
|
||||
data (str): The encoded image data string.
|
||||
|
||||
Returns:
|
||||
Image.Image: The PIL Image object created from the input data.
|
||||
"""
|
||||
return Image.open(BytesIO(base64.b64decode(data)))
|
||||
|
||||
|
||||
def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Converts the PIL image URLs in the messages to base64 encoded data URIs.
|
||||
|
||||
This function iterates over a list of message dictionaries. For each message,
|
||||
if it contains a 'content' key with a list of items, it looks for items
|
||||
with an 'image_url' key. The function then converts the PIL image URL
|
||||
(pointed to by 'image_url') to a base64 encoded data URI.
|
||||
|
||||
Parameters:
|
||||
messages (List[Dict]): A list of message dictionaries. Each dictionary
|
||||
may contain a 'content' key with a list of items,
|
||||
some of which might be image URLs.
|
||||
|
||||
Returns:
|
||||
List[Dict]: A new list of message dictionaries with PIL image URLs in the
|
||||
'image_url' key converted to base64 encoded data URIs.
|
||||
|
||||
Example Input:
|
||||
[
|
||||
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
|
||||
{'content': [
|
||||
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
|
||||
{'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
|
||||
{'type': 'text', 'text': '.'}],
|
||||
'role': 'user'}
|
||||
]
|
||||
|
||||
Example Output:
|
||||
[
|
||||
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
|
||||
{'content': [
|
||||
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
|
||||
{'type': 'image_url', 'image_url': {'url': a B64 Image}},
|
||||
{'type': 'text', 'text': '.'}],
|
||||
'role': 'user'}
|
||||
]
|
||||
"""
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
# Handle the new GPT messages format.
|
||||
if isinstance(message, dict) and "content" in message and isinstance(message["content"], list):
|
||||
message = copy.deepcopy(message)
|
||||
for item in message["content"]:
|
||||
if isinstance(item, dict) and "image_url" in item:
|
||||
item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])
|
||||
|
||||
new_messages.append(message)
|
||||
|
||||
return new_messages
|
||||
|
||||
@ -77,7 +77,9 @@ class LLaVAAgent(MultimodalConversableAgent):
|
||||
content_prompt = content_str(msg["content"])
|
||||
prompt += f"{SEP}{role}: {content_prompt}\n"
|
||||
prompt += "\n" + SEP + "Assistant: "
|
||||
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
|
||||
|
||||
# TODO: PIL to base64
|
||||
images = [get_image_data(im) for im in images]
|
||||
print(colored(prompt, "blue"))
|
||||
|
||||
out = ""
|
||||
|
||||
@ -3,7 +3,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from autogen import OpenAIWrapper
|
||||
from autogen.agentchat import Agent, ConversableAgent
|
||||
from autogen.agentchat.contrib.img_utils import gpt4v_formatter
|
||||
from autogen.agentchat.contrib.img_utils import (
|
||||
convert_base64_to_data_uri,
|
||||
gpt4v_formatter,
|
||||
message_formatter_pil_to_b64,
|
||||
pil_to_data_uri,
|
||||
)
|
||||
|
||||
from ..._pydantic import model_dump
|
||||
|
||||
try:
|
||||
from termcolor import colored
|
||||
@ -55,6 +62,21 @@ class MultimodalConversableAgent(ConversableAgent):
|
||||
else (lambda x: content_str(x.get("content")) == "TERMINATE")
|
||||
)
|
||||
|
||||
# Override the `generate_oai_reply`
|
||||
def _replace_reply_func(arr, x, y):
|
||||
for item in arr:
|
||||
if item["reply_func"] is x:
|
||||
item["reply_func"] = y
|
||||
|
||||
_replace_reply_func(
|
||||
self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply
|
||||
)
|
||||
_replace_reply_func(
|
||||
self._reply_func_list,
|
||||
ConversableAgent.a_generate_oai_reply,
|
||||
MultimodalConversableAgent.a_generate_oai_reply,
|
||||
)
|
||||
|
||||
def update_system_message(self, system_message: Union[Dict, List, str]):
|
||||
"""Update the system message.
|
||||
|
||||
@ -76,14 +98,14 @@ class MultimodalConversableAgent(ConversableAgent):
|
||||
will be processed using the gpt4v_formatter.
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return {"content": gpt4v_formatter(message)}
|
||||
return {"content": gpt4v_formatter(message, img_format="pil")}
|
||||
if isinstance(message, list):
|
||||
return {"content": message}
|
||||
if isinstance(message, dict):
|
||||
assert "content" in message, "The message dict must have a `content` field"
|
||||
if isinstance(message["content"], str):
|
||||
message = copy.deepcopy(message)
|
||||
message["content"] = gpt4v_formatter(message["content"])
|
||||
message["content"] = gpt4v_formatter(message["content"], img_format="pil")
|
||||
try:
|
||||
content_str(message["content"])
|
||||
except (TypeError, ValueError) as e:
|
||||
@ -91,3 +113,27 @@ class MultimodalConversableAgent(ConversableAgent):
|
||||
raise e
|
||||
return message
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
def generate_oai_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[OpenAIWrapper] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
"""Generate a reply using autogen.oai."""
|
||||
client = self.client if config is None else config
|
||||
if client is None:
|
||||
return False, None
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
||||
messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)
|
||||
|
||||
# TODO: #1143 handle token limit exceeded error
|
||||
response = client.create(context=messages[-1].pop("context", None), messages=messages_with_b64_img)
|
||||
|
||||
# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
|
||||
extracted_response = client.extract_text_or_completion_object(response)[0]
|
||||
if not isinstance(extracted_response, str):
|
||||
extracted_response = model_dump(extracted_response)
|
||||
return True, extracted_response
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
import pdb
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -8,9 +7,18 @@ import pytest
|
||||
import requests
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from autogen.agentchat.contrib.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formatter
|
||||
from autogen.agentchat.contrib.img_utils import (
|
||||
convert_base64_to_data_uri,
|
||||
extract_img_paths,
|
||||
get_image_data,
|
||||
get_pil_image,
|
||||
gpt4v_formatter,
|
||||
llava_formatter,
|
||||
message_formatter_pil_to_b64,
|
||||
)
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
@ -18,7 +26,8 @@ else:
|
||||
|
||||
|
||||
base64_encoded_image = (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
|
||||
"data:image/png;base64,"
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
|
||||
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
@ -27,6 +36,35 @@ raw_encoded_image = (
|
||||
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
if skip:
|
||||
raw_pil_image = None
|
||||
else:
|
||||
raw_pil_image = Image.new("RGB", (10, 10), color="red")
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestGetPilImage(unittest.TestCase):
|
||||
def test_read_local_file(self):
|
||||
# Create a small red image for testing
|
||||
temp_file = "_temp.png"
|
||||
raw_pil_image.save(temp_file)
|
||||
img2 = get_pil_image(temp_file)
|
||||
self.assertTrue((np.array(raw_pil_image) == np.array(img2)).all())
|
||||
|
||||
def test_read_pil(self):
|
||||
# A PIL image passed in directly should come back with identical pixel data
|
||||
img2 = get_pil_image(raw_pil_image)
|
||||
self.assertTrue((np.array(raw_pil_image) == np.array(img2)).all())
|
||||
|
||||
|
||||
def are_b64_images_equal(x: str, y: str):
|
||||
"""
|
||||
Returns True if two base64 encoded images decode to identical pixel data.
|
||||
"""
|
||||
img1 = get_pil_image(x)
|
||||
img2 = get_pil_image(y)
|
||||
return (np.array(img1) == np.array(img2)).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestGetImageData(unittest.TestCase):
|
||||
@ -34,20 +72,20 @@ class TestGetImageData(unittest.TestCase):
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = requests.Response()
|
||||
mock_response.status_code = 200
|
||||
mock_response._content = b"fake image content"
|
||||
mock_response._content = base64.b64decode(raw_encoded_image)
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = get_image_data("http://example.com/image.png")
|
||||
self.assertEqual(result, base64.b64encode(b"fake image content").decode("utf-8"))
|
||||
self.assertTrue(are_b64_images_equal(result, raw_encoded_image))
|
||||
|
||||
def test_base64_encoded_image(self):
|
||||
result = get_image_data(base64_encoded_image)
|
||||
self.assertEqual(result, base64_encoded_image.split(",", 1)[1])
|
||||
|
||||
self.assertTrue(are_b64_images_equal(result, base64_encoded_image.split(",", 1)[1]))
|
||||
|
||||
def test_local_image(self):
|
||||
# Create a temporary file to simulate a local image file.
|
||||
temp_file = "_temp.png"
|
||||
|
||||
image = Image.new("RGB", (60, 30), color=(73, 109, 137))
|
||||
image.save(temp_file)
|
||||
|
||||
@ -126,6 +164,36 @@ class TestGpt4vFormatter(unittest.TestCase):
|
||||
result = gpt4v_formatter(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.agentchat.contrib.img_utils.get_pil_image")
|
||||
def test_with_images_for_pil(self, mock_get_pil_image):
|
||||
"""
|
||||
Test the gpt4v_formatter function with a prompt containing images.
|
||||
"""
|
||||
# Mock the get_pil_image function to return a fixed PIL image.
|
||||
mock_get_pil_image.return_value = raw_pil_image
|
||||
|
||||
prompt = "This is a test with an image <img http://example.com/image.png>."
|
||||
expected_output = [
|
||||
{"type": "text", "text": "This is a test with an image "},
|
||||
{"type": "image_url", "image_url": {"url": raw_pil_image}},
|
||||
{"type": "text", "text": "."},
|
||||
]
|
||||
result = gpt4v_formatter(prompt, img_format="pil")
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
def test_with_images_for_url(self):
|
||||
"""
|
||||
Test the gpt4v_formatter function with a prompt containing images.
|
||||
"""
|
||||
prompt = "This is a test with an image <img http://example.com/image.png>."
|
||||
expected_output = [
|
||||
{"type": "text", "text": "This is a test with an image "},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.png"}},
|
||||
{"type": "text", "text": "."},
|
||||
]
|
||||
result = gpt4v_formatter(prompt, img_format="url")
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.agentchat.contrib.img_utils.get_image_data")
|
||||
def test_multiple_images(self, mock_get_image_data):
|
||||
"""
|
||||
@ -189,5 +257,36 @@ class TestExtractImgPaths(unittest.TestCase):
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class MessageFormatterPILtoB64Test(unittest.TestCase):
|
||||
def test_formatting(self):
|
||||
messages = [
|
||||
{"content": [{"type": "text", "text": "You are a helpful AI assistant."}], "role": "system"},
|
||||
{
|
||||
"content": [
|
||||
{"type": "text", "text": "What's the breed of this dog here? \n"},
|
||||
{"type": "image_url", "image_url": {"url": raw_pil_image}},
|
||||
{"type": "text", "text": "."},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
]
|
||||
|
||||
img_uri_data = convert_base64_to_data_uri(get_image_data(raw_pil_image))
|
||||
expected_output = [
|
||||
{"content": [{"type": "text", "text": "You are a helpful AI assistant."}], "role": "system"},
|
||||
{
|
||||
"content": [
|
||||
{"type": "text", "text": "What's the breed of this dog here? \n"},
|
||||
{"type": "image_url", "image_url": {"url": img_uri_data}},
|
||||
{"type": "text", "text": "."},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
]
|
||||
result = message_formatter_pil_to_b64(messages)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -8,14 +8,10 @@ import autogen
|
||||
from conftest import MOCK_OPEN_AI_API_KEY
|
||||
|
||||
try:
|
||||
from autogen.agentchat.contrib.llava_agent import (
|
||||
LLaVAAgent,
|
||||
_llava_call_binary_with_config,
|
||||
llava_call,
|
||||
llava_call_binary,
|
||||
)
|
||||
from autogen.agentchat.contrib.llava_agent import LLaVAAgent, _llava_call_binary_with_config, llava_call
|
||||
except ImportError:
|
||||
skip = True
|
||||
|
||||
else:
|
||||
skip = False
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from autogen.agentchat.conversable_agent import ConversableAgent
|
||||
from conftest import MOCK_OPEN_AI_API_KEY
|
||||
|
||||
try:
|
||||
from autogen.agentchat.contrib.img_utils import get_pil_image
|
||||
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
||||
except ImportError:
|
||||
skip = True
|
||||
@ -22,6 +23,12 @@ base64_encoded_image = (
|
||||
)
|
||||
|
||||
|
||||
if skip:
|
||||
pil_image = None
|
||||
else:
|
||||
pil_image = get_pil_image(base64_encoded_image)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestMultimodalConversableAgent(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -53,7 +60,7 @@ class TestMultimodalConversableAgent(unittest.TestCase):
|
||||
self.agent.system_message,
|
||||
[
|
||||
{"type": "text", "text": "We will discuss "},
|
||||
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
|
||||
{"type": "image_url", "image_url": {"url": pil_image}},
|
||||
{"type": "text", "text": " in this conversation."},
|
||||
],
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user