Validate the OpenAI API key format (#1635)

* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

* Add unit test coverage for the `is_valid_api_key` function.

* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

* Add unit test coverage for the `is_valid_api_key` function.

* Log a warning when registering a default client fails.

* Validate the OpenAI API key format

Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`. We'll log a
  warning when the OpenAI API key isn't valid.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

* Add unit test coverage for the `is_valid_api_key` function.

* Check for OpenAI base_url before API key validation

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Gunnar Kudrjavets 2024-02-14 10:51:38 -08:00 committed by GitHub
parent a62b5c3b2d
commit f68c09b035
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 62 additions and 19 deletions

View File

@ -10,7 +10,7 @@ from pydantic import BaseModel
from typing import Protocol from typing import Protocol
from autogen.cache.cache import Cache from autogen.cache.cache import Cache
from autogen.oai.openai_utils import get_key, OAI_PRICE1K from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
from autogen.token_count_utils import count_token from autogen.token_count_utils import count_token
TOOL_ENABLED = False TOOL_ENABLED = False
@ -47,6 +47,7 @@ if not logger.handlers:
LEGACY_DEFAULT_CACHE_SEED = 41 LEGACY_DEFAULT_CACHE_SEED = 41
LEGACY_CACHE_DIR = ".cache" LEGACY_CACHE_DIR = ".cache"
OPEN_API_BASE_URL_PREFIX = "https://api.openai.com"
class ModelClient(Protocol): class ModelClient(Protocol):
@ -111,6 +112,14 @@ class OpenAIClient:
def __init__(self, client: Union[OpenAI, AzureOpenAI]): def __init__(self, client: Union[OpenAI, AzureOpenAI]):
self._oai_client = client self._oai_client = client
if (
not isinstance(client, openai.AzureOpenAI)
and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
and not is_valid_api_key(self._oai_client.api_key)
):
logger.warning(
"The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
)
def message_retrieval( def message_retrieval(
self, response: Union[ChatCompletion, Completion] self, response: Union[ChatCompletion, Completion]

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
import os import os
import re
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@ -74,6 +75,19 @@ def get_key(config: Dict[str, Any]) -> str:
return json.dumps(config, sort_keys=True) return json.dumps(config, sort_keys=True)
def is_valid_api_key(api_key: str) -> bool:
    """Determine if the input string is a valid OpenAI API key.

    A valid key is ``sk-`` followed by at least 32 alphanumeric characters.
    Azure OpenAI keys use a different format and are NOT matched here.

    Args:
        api_key (str): An input string to be validated.

    Returns:
        bool: True if the input is a valid OpenAI API key, False otherwise.
    """
    # Call fullmatch on the compiled pattern directly; the ^/$ anchors are
    # redundant with fullmatch but kept to document the whole-string intent.
    api_key_re = re.compile(r"^sk-[A-Za-z0-9]{32,}$")
    return api_key_re.fullmatch(api_key) is not None
def get_config_list( def get_config_list(
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> List[Dict]: ) -> List[Dict]:

View File

@ -5,6 +5,8 @@ import pytest
import autogen import autogen
from conftest import MOCK_OPEN_AI_API_KEY
try: try:
from autogen.agentchat.contrib.llava_agent import ( from autogen.agentchat.contrib.llava_agent import (
LLaVAAgent, LLaVAAgent,
@ -26,7 +28,7 @@ class TestLLaVAAgent(unittest.TestCase):
llm_config={ llm_config={
"timeout": 600, "timeout": 600,
"seed": 42, "seed": 42,
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}], "config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": MOCK_OPEN_AI_API_KEY}],
}, },
) )
@ -103,7 +105,7 @@ class TestLLavaCall(unittest.TestCase):
# Set up the llm_config dictionary # Set up the llm_config dictionary
llm_config = { llm_config = {
"config_list": [{"api_key": "value", "base_url": "localhost:8000"}], "config_list": [{"api_key": MOCK_OPEN_AI_API_KEY, "base_url": "localhost:8000"}],
"max_new_tokens": 2000, "max_new_tokens": 2000,
"temperature": 0.5, "temperature": 0.5,
"seed": 1, "seed": 1,

View File

@ -6,6 +6,8 @@ import pytest
import autogen import autogen
from autogen.agentchat.conversable_agent import ConversableAgent from autogen.agentchat.conversable_agent import ConversableAgent
from conftest import MOCK_OPEN_AI_API_KEY
try: try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
except ImportError: except ImportError:
@ -28,7 +30,7 @@ class TestMultimodalConversableAgent(unittest.TestCase):
llm_config={ llm_config={
"timeout": 600, "timeout": 600,
"seed": 42, "seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}], "config_list": [{"model": "gpt-4-vision-preview", "api_key": MOCK_OPEN_AI_API_KEY}],
}, },
) )

View File

@ -7,7 +7,7 @@ from autogen.oai.openai_utils import filter_config
from autogen.cache import Cache from autogen.cache import Cache
sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import skip_openai # noqa: E402 from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
@ -48,7 +48,7 @@ if not skip_oai:
def test_web_surfer() -> None: def test_web_surfer() -> None:
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
# we mock the API key so we can register functions (llm_config must be present for this to work) # we mock the API key so we can register functions (llm_config must be present for this to work)
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
page_size = 4096 page_size = 4096
web_surfer = WebSurferAgent( web_surfer = WebSurferAgent(
"web_surfer", llm_config={"config_list": []}, browser_config={"viewport_size": page_size} "web_surfer", llm_config={"config_list": []}, browser_config={"viewport_size": page_size}

View File

@ -15,7 +15,7 @@ import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent from autogen.agentchat import ConversableAgent, UserProxyAgent
from autogen.agentchat.conversable_agent import register_function from autogen.agentchat.conversable_agent import register_function
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai from conftest import MOCK_OPEN_AI_API_KEY, skip_openai
try: try:
import openai import openai
@ -473,7 +473,7 @@ async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_a
def test_update_function_signature_and_register_functions() -> None: def test_update_function_signature_and_register_functions() -> None:
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={}) agent = ConversableAgent(name="agent", llm_config={})
def exec_python(cell: str) -> None: def exec_python(cell: str) -> None:
@ -617,7 +617,7 @@ def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]
def test_register_for_llm(): def test_register_for_llm():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []}) agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []}) agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []}) agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
@ -690,7 +690,7 @@ def test_register_for_llm():
def test_register_for_llm_api_style_function(): def test_register_for_llm_api_style_function():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []}) agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []}) agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []}) agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
@ -761,7 +761,7 @@ def test_register_for_llm_api_style_function():
def test_register_for_llm_without_description(): def test_register_for_llm_without_description():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={}) agent = ConversableAgent(name="agent", llm_config={})
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
@ -775,7 +775,7 @@ def test_register_for_llm_without_description():
def test_register_for_llm_without_LLM(): def test_register_for_llm_without_LLM():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config=None) agent = ConversableAgent(name="agent", llm_config=None)
agent.llm_config = None agent.llm_config = None
assert agent.llm_config is None assert agent.llm_config is None
@ -791,7 +791,7 @@ def test_register_for_llm_without_LLM():
def test_register_for_execution(): def test_register_for_execution():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={"config_list": []}) agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy_1 = UserProxyAgent(name="user_proxy_1") user_proxy_1 = UserProxyAgent(name="user_proxy_1")
user_proxy_2 = UserProxyAgent(name="user_proxy_2") user_proxy_2 = UserProxyAgent(name="user_proxy_2")
@ -826,7 +826,7 @@ def test_register_for_execution():
def test_register_functions(): def test_register_functions():
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={"config_list": []}) agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy = UserProxyAgent(name="user_proxy") user_proxy = UserProxyAgent(name="user_proxy")

View File

@ -7,7 +7,7 @@ from autogen.coding.factory import CodeExecutorFactory
from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor
from autogen.oai.openai_utils import config_list_from_json from autogen.oai.openai_utils import config_list_from_json
from conftest import skip_openai from conftest import MOCK_OPEN_AI_API_KEY, skip_openai
def test_create() -> None: def test_create() -> None:
@ -152,7 +152,7 @@ def test_local_commandline_executor_conversable_agent_code_execution() -> None:
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir) executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
_test_conversable_agent_code_execution(executor) _test_conversable_agent_code_execution(executor)

View File

@ -7,7 +7,7 @@ from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.coding.base import CodeBlock, CodeExecutor from autogen.coding.base import CodeBlock, CodeExecutor
from autogen.coding.factory import CodeExecutorFactory from autogen.coding.factory import CodeExecutorFactory
from autogen.oai.openai_utils import config_list_from_json from autogen.oai.openai_utils import config_list_from_json
from conftest import skip_openai # noqa: E402 from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
try: try:
from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
@ -211,7 +211,7 @@ print(test_function(123, 4))
``` ```
""" """
with pytest.MonkeyPatch.context() as mp: with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock") mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
reply = agent.generate_reply( reply = agent.generate_reply(
[{"role": "user", "content": msg}], [{"role": "user", "content": msg}],
sender=ConversableAgent("user", llm_config=False, code_execution_config=False), sender=ConversableAgent("user", llm_config=False, code_execution_config=False),

View File

@ -4,6 +4,8 @@ skip_openai = False
skip_redis = False skip_redis = False
skip_docker = False skip_docker = False
MOCK_OPEN_AI_API_KEY = "sk-mockopenaiAPIkeyinexpectedformatfortestingonly"
# Registers command-line options like '--skip-openai' and '--skip-redis' via pytest hook. # Registers command-line options like '--skip-openai' and '--skip-redis' via pytest hook.
# When these flags are set, it indicates that tests requiring OpenAI or Redis (respectively) should be skipped. # When these flags are set, it indicates that tests requiring OpenAI or Redis (respectively) should be skipped.

View File

@ -8,7 +8,9 @@ from unittest.mock import patch
import pytest import pytest
import autogen # noqa: E402 import autogen # noqa: E402
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config, is_valid_api_key
from conftest import MOCK_OPEN_AI_API_KEY
# Example environment variables # Example environment variables
ENV_VARS = { ENV_VARS = {
@ -370,5 +372,17 @@ def test_tags():
assert len(list_5) == 0 assert len(list_5) == 0
def test_is_valid_api_key():
    """Check is_valid_api_key against known-bad and known-good key strings."""
    # Must be rejected: empty, bare prefix, wrong-case prefix, too short,
    # missing prefix, and a non-alphanumeric character in the body.
    rejected = [
        "",
        "sk-",
        "SK-",
        "sk-asajsdjsd2",
        "FooBar",
        "sk-asajsdjsd22372%23kjdfdfdf2329ffUUDSDS",
    ]
    for key in rejected:
        assert not is_valid_api_key(key)

    # Must be accepted: ``sk-`` followed by 32+ alphanumeric characters.
    accepted = [
        "sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS",
        "sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS1212121221212sssXX",
        MOCK_OPEN_AI_API_KEY,
    ]
    for key in accepted:
        assert is_valid_api_key(key)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main() pytest.main()