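"""Unit tests for the LLaVA agent in autogen.agentchat.contrib.llava_agent.

Covers LLaVAAgent construction, the local and remote paths of
_llava_call_binary_with_config, and llava_call. All network calls are mocked.
"""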
import unittest
from unittest.mock import MagicMock, patch

import pytest

import autogen

try:
    from autogen.agentchat.contrib.llava_agent import (
        LLaVAAgent,
        _llava_call_binary_with_config,
        llava_call,
        llava_call_binary,
    )
except ImportError:
    skip = True
else:
    skip = False


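# Construction smoke test: build a LLaVAAgent with a fake local config and check
# that it instantiates without contacting any server.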
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLaVAAgent(unittest.TestCase):
    def setUp(self):
        self.agent = LLaVAAgent(
            name="TestAgent",
            llm_config={
                "timeout": 600,
                "seed": 42,
                "config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}],
            },
        )

    def test_init(self):
        self.assertIsInstance(self.agent, LLaVAAgent)


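# _llava_call_binary_with_config has a local path (requests.post to a LLaVA worker)
# and a remote path (replicate.run); both are exercised below with the network mocked.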
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLavaCallBinaryWithConfig(unittest.TestCase):
    @patch("requests.post")
    def test_local_mode(self, mock_post):
        # Mocking the response of requests.post
        mock_response = MagicMock()
        mock_response.iter_lines.return_value = [b'{"text":"response text"}']
        mock_post.return_value = mock_response

        # Calling the function
        output = _llava_call_binary_with_config(
            prompt="Test Prompt",
            images=[],
            config={"base_url": "http://0.0.0.0/api", "model": "test-model"},
            max_new_tokens=1000,
            temperature=0.5,
            seed=1,
        )

        # Verifying the results
        self.assertEqual(output, "response text")
        mock_post.assert_called_once_with(
            "http://0.0.0.0/api/worker_generate_stream",
            headers={"User-Agent": "LLaVA Client"},
            json={
                "model": "test-model",
                "prompt": "Test Prompt",
                "max_new_tokens": 1000,
                "temperature": 0.5,
                "stop": "###",
                "images": [],
            },
            stream=False,
        )

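    # Remote mode: with a non-local base_url, the call is expected to go through
    # replicate.run, sending the image as a base64 data URI.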
@patch("replicate.run")
|
|
def test_remote_mode(self, mock_run):
|
|
# Mocking the response of replicate.run
|
|
mock_run.return_value = iter(["response ", "text"])
|
|
|
|
# Calling the function
|
|
output = _llava_call_binary_with_config(
|
|
prompt="Test Prompt",
|
|
images=["image_data"],
|
|
config={"base_url": "http://remote/api", "model": "test-model"},
|
|
max_new_tokens=1000,
|
|
temperature=0.5,
|
|
seed=1,
|
|
)
|
|
|
|
# Verifying the results
|
|
self.assertEqual(output, "response text")
|
|
mock_run.assert_called_once_with(
|
|
"http://remote/api",
|
|
input={"image": "data:image/jpeg;base64,image_data", "prompt": "Test Prompt", "seed": 1},
|
|
)
|
|
|
|
|
|
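# llava_call is expected to format the prompt via llava_formater and delegate to
# llava_call_binary with the parameters taken from llm_config.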
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLavaCall(unittest.TestCase):
    @patch("autogen.agentchat.contrib.llava_agent.llava_formater")
    @patch("autogen.agentchat.contrib.llava_agent.llava_call_binary")
    def test_llava_call(self, mock_llava_call_binary, mock_llava_formater):
        # Set up the mocks
        mock_llava_formater.return_value = ("formatted prompt", ["image1", "image2"])
        mock_llava_call_binary.return_value = "Generated Text"

        # Set up the llm_config dictionary
        llm_config = {
            "config_list": [{"api_key": "value", "base_url": "localhost:8000"}],
            "max_new_tokens": 2000,
            "temperature": 0.5,
            "seed": 1,
        }

        # Call the function
        result = llava_call("Test Prompt", llm_config)

        # Check the results
        mock_llava_formater.assert_called_once_with("Test Prompt", order_image_tokens=False)
        mock_llava_call_binary.assert_called_once_with(
            "formatted prompt",
            ["image1", "image2"],
            config_list=llm_config["config_list"],
            max_new_tokens=2000,
            temperature=0.5,
            seed=1,
        )
        self.assertEqual(result, "Generated Text")


if __name__ == "__main__":
    unittest.main()