mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-10 09:42:18 +00:00

* Add isort * Apply isort on py files * Fix circular import * Fix format for notebooks * Fix format --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
227 lines
8.9 KiB
Python
227 lines
8.9 KiB
Python
import itertools
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import pytest
|
|
|
|
from autogen import code_utils
|
|
from autogen.agentchat.conversable_agent import ConversableAgent
|
|
from autogen.agentchat.user_proxy_agent import UserProxyAgent
|
|
from autogen.cache.cache import Cache
|
|
from autogen.oai import openai_utils
|
|
|
|
try:
|
|
from PIL import Image
|
|
|
|
from autogen.agentchat.contrib.capabilities import generate_images
|
|
from autogen.agentchat.contrib.img_utils import get_pil_image
|
|
except ImportError:
|
|
skip_requirement = True
|
|
else:
|
|
skip_requirement = False
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
|
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
|
|
|
|
filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
|
|
|
|
RESOLUTIONS = ["256x256", "512x512", "1024x1024"]
|
|
QUALITIES = ["standard", "hd"]
|
|
PROMPTS = [
|
|
"Generate an image of a robot holding a 'I Love Autogen' sign",
|
|
"Generate an image of a dog holding a 'I Love Autogen' sign",
|
|
]
|
|
|
|
|
|
class _TestImageGenerator:
|
|
def __init__(self, image):
|
|
self._image = image
|
|
|
|
def generate_image(self, prompt: str):
|
|
return self._image
|
|
|
|
def cache_key(self, prompt: str):
|
|
return prompt
|
|
|
|
|
|
def create_test_agent(name: str = "test_agent", default_auto_reply: str = "") -> ConversableAgent:
|
|
return ConversableAgent(name=name, llm_config=False, default_auto_reply=default_auto_reply)
|
|
|
|
|
|
def dalle_image_generator(dalle_config: Dict[str, Any], resolution: str, quality: str):
|
|
return generate_images.DalleImageGenerator(dalle_config, resolution=resolution, quality=quality, num_images=1)
|
|
|
|
|
|
def api_key():
|
|
return MOCK_OPEN_AI_API_KEY if skip_openai else os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
|
@pytest.fixture
|
|
def dalle_config() -> Dict[str, Any]:
|
|
config_list = openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")
|
|
if not config_list:
|
|
config_list = [{"model": "dall-e-2", "api_key": api_key()}]
|
|
return {"config_list": config_list, "timeout": 120, "cache_seed": None}
|
|
|
|
|
|
@pytest.fixture
|
|
def gpt3_config() -> Dict[str, Any]:
|
|
config_list = [
|
|
{
|
|
"model": "gpt-35-turbo-16k",
|
|
"api_key": api_key(),
|
|
},
|
|
{
|
|
"model": "gpt-3.5-turbo-16k",
|
|
"api_key": api_key(),
|
|
},
|
|
]
|
|
return {"config_list": config_list, "timeout": 120, "cache_seed": None}
|
|
|
|
|
|
@pytest.fixture
|
|
def image_gen_capability():
|
|
image_generator = _TestImageGenerator(Image.new("RGB", (256, 256)))
|
|
return generate_images.ImageGeneration(image_generator)
|
|
|
|
|
|
@pytest.mark.skipif(skip_openai, reason="Requested to skip.")
|
|
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
|
|
def test_dalle_image_generator(dalle_config: Dict[str, Any]):
|
|
"""Tests DalleImageGenerator capability to generate images by calling the OpenAI API."""
|
|
dalle_generator = dalle_image_generator(dalle_config, RESOLUTIONS[0], QUALITIES[0])
|
|
image = dalle_generator.generate_image(PROMPTS[0])
|
|
|
|
assert isinstance(image, Image.Image)
|
|
|
|
|
|
# Using cartesian product to generate all possible combinations of resolution, quality, and prompt
|
|
@pytest.mark.parametrize("gen_config_1", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
|
|
@pytest.mark.parametrize("gen_config_2", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
|
|
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
|
|
def test_dalle_image_generator_cache_key(
|
|
dalle_config: Dict[str, Any], gen_config_1: Tuple[str, str, str], gen_config_2: Tuple[str, str, str]
|
|
):
|
|
"""Tests if DalleImageGenerator creates unique cache keys.
|
|
|
|
Args:
|
|
dalle_config: The LLM config for the DalleImageGenerator.
|
|
gen_config_1: A tuple containing the resolution, quality, and prompt for the first image generator.
|
|
gen_config_2: A tuple containing the resolution, quality, and prompt for the second image generator.
|
|
"""
|
|
|
|
dalle_generator_1 = dalle_image_generator(dalle_config, resolution=gen_config_1[0], quality=gen_config_1[1])
|
|
dalle_generator_2 = dalle_image_generator(dalle_config, resolution=gen_config_2[0], quality=gen_config_2[1])
|
|
|
|
cache_key_1 = dalle_generator_1.cache_key(gen_config_1[2])
|
|
cache_key_2 = dalle_generator_2.cache_key(gen_config_2[2])
|
|
|
|
if gen_config_1 == gen_config_2:
|
|
assert cache_key_1 == cache_key_2
|
|
else:
|
|
assert cache_key_1 != cache_key_2
|
|
|
|
|
|
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
|
|
def test_image_generation_capability_positive(monkeypatch, image_gen_capability):
|
|
"""Tests ImageGeneration capability to generate images by calling the ImageGenerator.
|
|
|
|
This tests if the message is asking the agent to generate an image.
|
|
"""
|
|
auto_reply = "Didn't need to generate an image."
|
|
|
|
# Patching the _should_generate_image and _extract_prompt methods to avoid TextAnalyzerAgent to make API calls
|
|
# Improves reproducibility and falkiness of the test
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: True)
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])
|
|
|
|
user = UserProxyAgent("user", human_input_mode="NEVER")
|
|
agent = create_test_agent(default_auto_reply=auto_reply)
|
|
image_gen_capability.add_to_agent(agent)
|
|
|
|
user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
|
|
last_message = agent.last_message()
|
|
|
|
assert last_message
|
|
|
|
processed_message = code_utils.content_str(last_message["content"])
|
|
|
|
assert "<image>" in processed_message
|
|
assert auto_reply not in processed_message
|
|
|
|
|
|
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
|
|
def test_image_generation_capability_negative(monkeypatch, image_gen_capability):
|
|
"""Tests ImageGeneration capability to generate images by calling the ImageGenerator.
|
|
|
|
This tests if the message is not asking the agent to generate an image.
|
|
"""
|
|
auto_reply = "Didn't need to generate an image."
|
|
|
|
# Patching the _should_generate_image and _extract_prompt methods to avoid TextAnalyzerAgent making API calls.
|
|
# Improves reproducibility and flakiness of the test.
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: False)
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])
|
|
|
|
user = UserProxyAgent("user", human_input_mode="NEVER")
|
|
agent = ConversableAgent("test_agent", llm_config=False, default_auto_reply=auto_reply)
|
|
image_gen_capability.add_to_agent(agent)
|
|
|
|
user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
|
|
last_message = agent.last_message()
|
|
|
|
assert last_message
|
|
|
|
processed_message = code_utils.content_str(last_message["content"])
|
|
|
|
assert "<image>" not in processed_message
|
|
assert auto_reply == processed_message
|
|
|
|
|
|
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
|
|
def test_image_generation_capability_cache(monkeypatch):
|
|
"""Tests ImageGeneration capability to cache the generated images."""
|
|
test_image_size = (256, 256)
|
|
|
|
# Patching the _should_generate_image and _extract_prompt methods to avoid TextAnalyzerAgent making API calls.
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: True)
|
|
monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
cache = Cache.disk(cache_path_root=temp_dir)
|
|
|
|
user = UserProxyAgent("user", human_input_mode="NEVER")
|
|
agent = create_test_agent()
|
|
|
|
test_image_generator = _TestImageGenerator(Image.new("RGB", test_image_size))
|
|
image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
|
|
image_gen_capability.add_to_agent(agent)
|
|
|
|
user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
|
|
|
|
# Checking if the image has been cached by creating a new agent with a different image generator.
|
|
agent = create_test_agent(name="test_agent_2")
|
|
test_image_generator = _TestImageGenerator(Image.new("RGB", (512, 512)))
|
|
image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
|
|
image_gen_capability.add_to_agent(agent)
|
|
|
|
user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
|
|
|
|
last_message = agent.last_message()
|
|
|
|
assert last_message
|
|
|
|
image_dict = [image for image in last_message["content"] if image["type"] == "image_url"]
|
|
image = get_pil_image(image_dict[0]["image_url"]["url"])
|
|
|
|
assert image.size == test_image_size
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_dalle_image_generator(
|
|
dalle_config={"config_list": openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")}
|
|
)
|