Gunnar Kudrjavets f68c09b035
Validate the OpenAI API key format (#1635)
* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

*  Add unit test coverage for the `is_valid_api_key` function.

* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

*Add unit test coverage for the `is_valid_api_key` function.

* Log a warning when register a default client fails.

* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`. We'll log a
  warning when the OpenAI API key isn't valid.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

* Add unit test coverage for the `is_valid_api_key` function.

* Check for OpenAI base_url before API key validation

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
2024-02-14 18:51:38 +00:00

134 lines
5.1 KiB
Python

import unittest
from unittest.mock import MagicMock
import pytest
import autogen
from autogen.agentchat.conversable_agent import ConversableAgent
from conftest import MOCK_OPEN_AI_API_KEY
try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
except ImportError:
skip = True
else:
skip = False
base64_encoded_image = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestMultimodalConversableAgent(unittest.TestCase):
def setUp(self):
self.agent = MultimodalConversableAgent(
name="TestAgent",
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": MOCK_OPEN_AI_API_KEY}],
},
)
def test_system_message(self):
# Test default system message
self.assertEqual(
self.agent.system_message,
[
{
"type": "text",
"text": "You are a helpful AI assistant.",
}
],
)
# Test updating system message
new_message = f"We will discuss <img {base64_encoded_image}> in this conversation."
self.agent.update_system_message(new_message)
self.assertEqual(
self.agent.system_message,
[
{"type": "text", "text": "We will discuss "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "text", "text": " in this conversation."},
],
)
def test_message_to_dict(self):
# Test string message
message_str = "Hello"
expected_dict = {"content": [{"type": "text", "text": "Hello"}]}
self.assertDictEqual(self.agent._message_to_dict(message_str), expected_dict)
# Test list message
message_list = [{"type": "text", "text": "Hello"}]
expected_dict = {"content": message_list}
self.assertDictEqual(self.agent._message_to_dict(message_list), expected_dict)
# Test dictionary message
message_dict = {"content": [{"type": "text", "text": "Hello"}]}
self.assertDictEqual(self.agent._message_to_dict(message_dict), message_dict)
def test_print_received_message(self):
sender = ConversableAgent(name="SenderAgent", llm_config=False, code_execution_config=False)
message_str = "Hello"
self.agent._print_received_message = MagicMock() # Mocking print method to avoid actual print
self.agent._print_received_message(message_str, sender)
self.agent._print_received_message.assert_called_with(message_str, sender)
@pytest.mark.skipif(skip, reason="Dependency not installed")
def test_group_chat_with_lmm():
"""
Tests the group chat functionality with two MultimodalConversable Agents.
Verifies that the chat is correctly limited by the max_round parameter.
Each agent is set to describe an image in a unique style, but the chat should not exceed the specified max_rounds.
"""
# Configuration parameters
max_round = 5
max_consecutive_auto_reply = 10
llm_config = False
# Creating two MultimodalConversable Agents with different descriptive styles
agent1 = MultimodalConversableAgent(
name="image-explainer-1",
max_consecutive_auto_reply=max_consecutive_auto_reply,
llm_config=llm_config,
system_message="Your image description is poetic and engaging.",
)
agent2 = MultimodalConversableAgent(
name="image-explainer-2",
max_consecutive_auto_reply=max_consecutive_auto_reply,
llm_config=llm_config,
system_message="Your image description is factual and to the point.",
)
# Creating a user proxy agent for initiating the group chat
user_proxy = autogen.UserProxyAgent(
name="User_proxy",
system_message="Ask both image explainer 1 and 2 for their description.",
human_input_mode="NEVER", # Options: 'ALWAYS' or 'NEVER'
max_consecutive_auto_reply=max_consecutive_auto_reply,
)
# Setting up the group chat
groupchat = autogen.GroupChat(agents=[agent1, agent2, user_proxy], messages=[], max_round=max_round)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
# Initiating the group chat and observing the number of rounds
user_proxy.initiate_chat(group_chat_manager, message=f"What do you see? <img {base64_encoded_image}>")
# Assertions to check if the number of rounds does not exceed max_round
assert all(len(arr) <= max_round for arr in agent1._oai_messages.values()), "Agent 1 exceeded max rounds"
assert all(len(arr) <= max_round for arr in agent2._oai_messages.values()), "Agent 2 exceeded max rounds"
assert all(len(arr) <= max_round for arr in user_proxy._oai_messages.values()), "User proxy exceeded max rounds"
if __name__ == "__main__":
unittest.main()