Use PIL Image internally for the Multimodal Agent (#1124)

* Change defualt model for `lmm`

* Try to use PIL image for LMM's _oai_messages

* Update test cases and llava

* Remove redundant files

* Update the imports for lmm tests

* Test case fix

* Docstring update

* LMM notebook lint

* Typo correction for img_utils and its test

* Update test_llava.py

debug, reformat

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Shaokun Zhang <shaokunzhang529@gmail.com>
Co-authored-by: Shaokun Zhang <shaokun.zhang@psu.edu>
This commit is contained in:
Beibin Li 2024-02-18 07:08:55 -08:00 committed by GitHub
parent 2d29d36b1d
commit 9de374a495
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 493 additions and 192 deletions

View File

@ -1,5 +1,7 @@
import base64
import copy
import mimetypes
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
@ -8,17 +10,63 @@ import requests
from PIL import Image
def get_image_data(image_file: str, use_b64=True) -> bytes:
def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
"""
Loads an image from a file and returns a PIL Image object.
Parameters:
image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.
Returns:
Image.Image: The PIL Image object.
"""
if isinstance(image_file, Image.Image):
# Already a PIL Image object
return image_file
if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
content = response.content
content = BytesIO(response.content)
image = Image.open(content)
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
# A URI. Remove the prefix and decode the base64 string.
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
image = _to_pil(base64_data)
elif os.path.exists(image_file):
# A local file
image = Image.open(image_file)
else:
image = Image.open(image_file).convert("RGB")
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
# base64 encoded string
image = _to_pil(image_file)
return image.convert("RGB")
def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
"""
Loads an image and returns its data either as raw bytes or in base64-encoded format.
This function first loads an image from the specified file, URL, or base64 string using
the `get_pil_image` function. It then saves this image in memory in PNG format and
retrieves its binary content. Depending on the `use_b64` flag, this binary content is
either returned directly or as a base64-encoded string.
Parameters:
image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
string of the image.
use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
If False, it returns the raw byte data of the image. Defaults to True.
Returns:
bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string
if `use_b64` is True.
"""
image = get_pil_image(image_file)
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
if use_b64:
return base64.b64encode(content).decode("utf-8")
@ -72,6 +120,22 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str,
return new_prompt, images
def pil_to_data_uri(image: Image.Image) -> str:
"""
Converts a PIL Image object to a data URI.
Parameters:
image (Image.Image): The PIL Image object.
Returns:
str: The data URI string.
"""
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
return convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
def convert_base64_to_data_uri(base64_image):
def _get_mime_type_from_data_uri(base64_image):
# Decode the base64 string
@ -92,16 +156,19 @@ def convert_base64_to_data_uri(base64_image):
return data_uri
def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.
Parameters:
Args:
- prompt (str): The input string that may contain image tags like <img ...>.
- img_format (str): what image format should be used. One of "uri", "url", "pil".
Returns:
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
"""
assert img_format in ["uri", "url", "pil"]
output = []
last_index = 0
image_count = 0
@ -114,7 +181,15 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
image_location = match.group(1)
try:
img_data = get_image_data(image_location)
if img_format == "pil":
img_data = get_pil_image(image_location)
elif img_format == "uri":
img_data = get_image_data(image_location)
img_data = convert_base64_to_data_uri(img_data)
elif img_format == "url":
img_data = image_location
else:
raise ValueError(f"Unknown image format {img_format}")
except Exception as e:
# Warning and skip this token
print(f"Warning! Unable to load image from {image_location}, because {e}")
@ -124,7 +199,7 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
output.append({"type": "text", "text": prompt[last_index : match.start()]})
# Add image data to output list
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
output.append({"type": "image_url", "image_url": {"url": img_data}})
last_index = match.end()
image_count += 1
@ -162,9 +237,61 @@ def _to_pil(data: str) -> Image.Image:
and finally creates and returns a PIL Image object from the BytesIO object.
Parameters:
data (str): The base64 encoded image data string.
data (str): The encoded image data string.
Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))
def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
"""
Converts the PIL image URLs in the messages to base64 encoded data URIs.
This function iterates over a list of message dictionaries. For each message,
if it contains a 'content' key with a list of items, it looks for items
with an 'image_url' key. The function then converts the PIL image URL
(pointed to by 'image_url') to a base64 encoded data URI.
Parameters:
messages (List[Dict]): A list of message dictionaries. Each dictionary
may contain a 'content' key with a list of items,
some of which might be image URLs.
Returns:
List[Dict]: A new list of message dictionaries with PIL image URLs in the
'image_url' key converted to base64 encoded data URIs.
Example Input:
[
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
{'content': [
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
{'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
{'type': 'text', 'text': '.'}],
'role': 'user'}
]
Example Output:
[
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
{'content': [
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
{'type': 'image_url', 'image_url': {'url': a B64 Image}},
{'type': 'text', 'text': '.'}],
'role': 'user'}
]
"""
new_messages = []
for message in messages:
# Handle the new GPT messages format.
if isinstance(message, dict) and "content" in message and isinstance(message["content"], list):
message = copy.deepcopy(message)
for item in message["content"]:
if isinstance(item, dict) and "image_url" in item:
item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])
new_messages.append(message)
return new_messages

View File

@ -77,7 +77,9 @@ class LLaVAAgent(MultimodalConversableAgent):
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
# TODO: PIL to base64
images = [get_image_data(im) for im in images]
print(colored(prompt, "blue"))
out = ""

View File

@ -3,7 +3,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent
from autogen.agentchat.contrib.img_utils import gpt4v_formatter
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
gpt4v_formatter,
message_formatter_pil_to_b64,
pil_to_data_uri,
)
from ..._pydantic import model_dump
try:
from termcolor import colored
@ -55,6 +62,21 @@ class MultimodalConversableAgent(ConversableAgent):
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
# Override the `generate_oai_reply`
def _replace_reply_func(arr, x, y):
for item in arr:
if item["reply_func"] is x:
item["reply_func"] = y
_replace_reply_func(
self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply
)
_replace_reply_func(
self._reply_func_list,
ConversableAgent.a_generate_oai_reply,
MultimodalConversableAgent.a_generate_oai_reply,
)
def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.
@ -76,14 +98,14 @@ class MultimodalConversableAgent(ConversableAgent):
will be processed using the gpt4v_formatter.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message)}
return {"content": gpt4v_formatter(message, img_format="pil")}
if isinstance(message, list):
return {"content": message}
if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
message["content"] = gpt4v_formatter(message["content"])
message["content"] = gpt4v_formatter(message["content"], img_format="pil")
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
@ -91,3 +113,27 @@ class MultimodalConversableAgent(ConversableAgent):
raise e
return message
raise ValueError(f"Unsupported message type: {type(message)}")
def generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[OpenAIWrapper] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai."""
client = self.client if config is None else config
if client is None:
return False, None
if messages is None:
messages = self._oai_messages[sender]
messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)
# TODO: #1143 handle token limit exceeded error
response = client.create(context=messages[-1].pop("context", None), messages=messages_with_b64_img)
# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
extracted_response = client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str):
extracted_response = model_dump(extracted_response)
return True, extracted_response

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,5 @@
import base64
import os
import pdb
import unittest
from unittest.mock import patch
@ -8,9 +7,18 @@ import pytest
import requests
try:
import numpy as np
from PIL import Image
from autogen.agentchat.contrib.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formatter
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
extract_img_paths,
get_image_data,
get_pil_image,
gpt4v_formatter,
llava_formatter,
message_formatter_pil_to_b64,
)
except ImportError:
skip = True
else:
@ -18,7 +26,8 @@ else:
base64_encoded_image = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
"data:image/png;base64,"
"iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
@ -27,6 +36,35 @@ raw_encoded_image = (
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
if skip:
raw_pil_image = None
else:
raw_pil_image = Image.new("RGB", (10, 10), color="red")
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGetPilImage(unittest.TestCase):
def test_read_local_file(self):
# Create a small red image for testing
temp_file = "_temp.png"
raw_pil_image.save(temp_file)
img2 = get_pil_image(temp_file)
self.assertTrue((np.array(raw_pil_image) == np.array(img2)).all())
def test_read_pil(self):
# Create a small red image for testing
img2 = get_pil_image(raw_pil_image)
self.assertTrue((np.array(raw_pil_image) == np.array(img2)).all())
def are_b64_images_equal(x: str, y: str):
"""
Asserts that two base64 encoded images are equal.
"""
img1 = get_pil_image(x)
img2 = get_pil_image(y)
return (np.array(img1) == np.array(img2)).all()
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGetImageData(unittest.TestCase):
@ -34,20 +72,20 @@ class TestGetImageData(unittest.TestCase):
with patch("requests.get") as mock_get:
mock_response = requests.Response()
mock_response.status_code = 200
mock_response._content = b"fake image content"
mock_response._content = base64.b64decode(raw_encoded_image)
mock_get.return_value = mock_response
result = get_image_data("http://example.com/image.png")
self.assertEqual(result, base64.b64encode(b"fake image content").decode("utf-8"))
self.assertTrue(are_b64_images_equal(result, raw_encoded_image))
def test_base64_encoded_image(self):
result = get_image_data(base64_encoded_image)
self.assertEqual(result, base64_encoded_image.split(",", 1)[1])
self.assertTrue(are_b64_images_equal(result, base64_encoded_image.split(",", 1)[1]))
def test_local_image(self):
# Create a temporary file to simulate a local image file.
temp_file = "_temp.png"
image = Image.new("RGB", (60, 30), color=(73, 109, 137))
image.save(temp_file)
@ -126,6 +164,36 @@ class TestGpt4vFormatter(unittest.TestCase):
result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output)
@patch("autogen.agentchat.contrib.img_utils.get_pil_image")
def test_with_images_for_pil(self, mock_get_pil_image):
"""
Test the gpt4v_formatter function with a prompt containing images.
"""
# Mock the get_image_data function to return a fixed string.
mock_get_pil_image.return_value = raw_pil_image
prompt = "This is a test with an image <img http://example.com/image.png>."
expected_output = [
{"type": "text", "text": "This is a test with an image "},
{"type": "image_url", "image_url": {"url": raw_pil_image}},
{"type": "text", "text": "."},
]
result = gpt4v_formatter(prompt, img_format="pil")
self.assertEqual(result, expected_output)
def test_with_images_for_url(self):
"""
Test the gpt4v_formatter function with a prompt containing images.
"""
prompt = "This is a test with an image <img http://example.com/image.png>."
expected_output = [
{"type": "text", "text": "This is a test with an image "},
{"type": "image_url", "image_url": {"url": "http://example.com/image.png"}},
{"type": "text", "text": "."},
]
result = gpt4v_formatter(prompt, img_format="url")
self.assertEqual(result, expected_output)
@patch("autogen.agentchat.contrib.img_utils.get_image_data")
def test_multiple_images(self, mock_get_image_data):
"""
@ -189,5 +257,36 @@ class TestExtractImgPaths(unittest.TestCase):
self.assertEqual(result, expected_output)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class MessageFormatterPILtoB64Test(unittest.TestCase):
def test_formatting(self):
messages = [
{"content": [{"type": "text", "text": "You are a helpful AI assistant."}], "role": "system"},
{
"content": [
{"type": "text", "text": "What's the breed of this dog here? \n"},
{"type": "image_url", "image_url": {"url": raw_pil_image}},
{"type": "text", "text": "."},
],
"role": "user",
},
]
img_uri_data = convert_base64_to_data_uri(get_image_data(raw_pil_image))
expected_output = [
{"content": [{"type": "text", "text": "You are a helpful AI assistant."}], "role": "system"},
{
"content": [
{"type": "text", "text": "What's the breed of this dog here? \n"},
{"type": "image_url", "image_url": {"url": img_uri_data}},
{"type": "text", "text": "."},
],
"role": "user",
},
]
result = message_formatter_pil_to_b64(messages)
self.assertEqual(result, expected_output)
if __name__ == "__main__":
unittest.main()

View File

@ -8,14 +8,10 @@ import autogen
from conftest import MOCK_OPEN_AI_API_KEY
try:
from autogen.agentchat.contrib.llava_agent import (
LLaVAAgent,
_llava_call_binary_with_config,
llava_call,
llava_call_binary,
)
from autogen.agentchat.contrib.llava_agent import LLaVAAgent, _llava_call_binary_with_config, llava_call
except ImportError:
skip = True
else:
skip = False

View File

@ -9,6 +9,7 @@ from autogen.agentchat.conversable_agent import ConversableAgent
from conftest import MOCK_OPEN_AI_API_KEY
try:
from autogen.agentchat.contrib.img_utils import get_pil_image
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
except ImportError:
skip = True
@ -22,6 +23,12 @@ base64_encoded_image = (
)
if skip:
pil_image = None
else:
pil_image = get_pil_image(base64_encoded_image)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestMultimodalConversableAgent(unittest.TestCase):
def setUp(self):
@ -53,7 +60,7 @@ class TestMultimodalConversableAgent(unittest.TestCase):
self.agent.system_message,
[
{"type": "text", "text": "We will discuss "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "image_url", "image_url": {"url": pil_image}},
{"type": "text", "text": " in this conversation."},
],
)