mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-17 04:03:49 +00:00
Validate the OpenAI API key format (#1635)
* Validate the OpenAI API key format Increase the amount of internal validation for OpenAI API keys. The intent is to shorten the debugging loop in case of typos. The changes do *not* add validation for Azure OpenAI API keys. * Add the validation in `__init__` of `OpenAIClient`. * Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing. * Add unit test coverage for the `is_valid_api_key` function. * Validate the OpenAI API key format Increase the amount of internal validation for OpenAI API keys. The intent is to shorten the debugging loop in case of typos. The changes do *not* add validation for Azure OpenAI API keys. * Add the validation in `__init__` of `OpenAIClient`. * Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing. *Add unit test coverage for the `is_valid_api_key` function. * Log a warning when register a default client fails. * Validate the OpenAI API key format Increase the amount of internal validation for OpenAI API keys. The intent is to shorten the debugging loop in case of typos. The changes do *not* add validation for Azure OpenAI API keys. * Add the validation in `__init__` of `OpenAIClient`. We'll log a warning when the OpenAI API key isn't valid. * Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing. * Add unit test coverage for the `is_valid_api_key` function. * Check for OpenAI base_url before API key validation --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
a62b5c3b2d
commit
f68c09b035
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from autogen.cache.cache import Cache
|
from autogen.cache.cache import Cache
|
||||||
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
|
from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
|
||||||
from autogen.token_count_utils import count_token
|
from autogen.token_count_utils import count_token
|
||||||
|
|
||||||
TOOL_ENABLED = False
|
TOOL_ENABLED = False
|
||||||
@ -47,6 +47,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
LEGACY_DEFAULT_CACHE_SEED = 41
|
LEGACY_DEFAULT_CACHE_SEED = 41
|
||||||
LEGACY_CACHE_DIR = ".cache"
|
LEGACY_CACHE_DIR = ".cache"
|
||||||
|
OPEN_API_BASE_URL_PREFIX = "https://api.openai.com"
|
||||||
|
|
||||||
|
|
||||||
class ModelClient(Protocol):
|
class ModelClient(Protocol):
|
||||||
@ -111,6 +112,14 @@ class OpenAIClient:
|
|||||||
|
|
||||||
def __init__(self, client: Union[OpenAI, AzureOpenAI]):
|
def __init__(self, client: Union[OpenAI, AzureOpenAI]):
|
||||||
self._oai_client = client
|
self._oai_client = client
|
||||||
|
if (
|
||||||
|
not isinstance(client, openai.AzureOpenAI)
|
||||||
|
and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
|
||||||
|
and not is_valid_api_key(self._oai_client.api_key)
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
|
||||||
|
)
|
||||||
|
|
||||||
def message_retrieval(
|
def message_retrieval(
|
||||||
self, response: Union[ChatCompletion, Completion]
|
self, response: Union[ChatCompletion, Completion]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@ -74,6 +75,19 @@ def get_key(config: Dict[str, Any]) -> str:
|
|||||||
return json.dumps(config, sort_keys=True)
|
return json.dumps(config, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_api_key(api_key: str):
|
||||||
|
"""Determine if input is valid OpenAI API key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): An input string to be validated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: A boolean that indicates if input is valid OpenAI API key.
|
||||||
|
"""
|
||||||
|
api_key_re = re.compile(r"^sk-[A-Za-z0-9]{32,}$")
|
||||||
|
return bool(re.fullmatch(api_key_re, api_key))
|
||||||
|
|
||||||
|
|
||||||
def get_config_list(
|
def get_config_list(
|
||||||
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
|
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
|
@ -5,6 +5,8 @@ import pytest
|
|||||||
|
|
||||||
import autogen
|
import autogen
|
||||||
|
|
||||||
|
from conftest import MOCK_OPEN_AI_API_KEY
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from autogen.agentchat.contrib.llava_agent import (
|
from autogen.agentchat.contrib.llava_agent import (
|
||||||
LLaVAAgent,
|
LLaVAAgent,
|
||||||
@ -26,7 +28,7 @@ class TestLLaVAAgent(unittest.TestCase):
|
|||||||
llm_config={
|
llm_config={
|
||||||
"timeout": 600,
|
"timeout": 600,
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}],
|
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": MOCK_OPEN_AI_API_KEY}],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ class TestLLavaCall(unittest.TestCase):
|
|||||||
|
|
||||||
# Set up the llm_config dictionary
|
# Set up the llm_config dictionary
|
||||||
llm_config = {
|
llm_config = {
|
||||||
"config_list": [{"api_key": "value", "base_url": "localhost:8000"}],
|
"config_list": [{"api_key": MOCK_OPEN_AI_API_KEY, "base_url": "localhost:8000"}],
|
||||||
"max_new_tokens": 2000,
|
"max_new_tokens": 2000,
|
||||||
"temperature": 0.5,
|
"temperature": 0.5,
|
||||||
"seed": 1,
|
"seed": 1,
|
||||||
|
@ -6,6 +6,8 @@ import pytest
|
|||||||
import autogen
|
import autogen
|
||||||
from autogen.agentchat.conversable_agent import ConversableAgent
|
from autogen.agentchat.conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
from conftest import MOCK_OPEN_AI_API_KEY
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -28,7 +30,7 @@ class TestMultimodalConversableAgent(unittest.TestCase):
|
|||||||
llm_config={
|
llm_config={
|
||||||
"timeout": 600,
|
"timeout": 600,
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}],
|
"config_list": [{"model": "gpt-4-vision-preview", "api_key": MOCK_OPEN_AI_API_KEY}],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from autogen.oai.openai_utils import filter_config
|
|||||||
from autogen.cache import Cache
|
from autogen.cache import Cache
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||||
from conftest import skip_openai # noqa: E402
|
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
||||||
@ -48,7 +48,7 @@ if not skip_oai:
|
|||||||
def test_web_surfer() -> None:
|
def test_web_surfer() -> None:
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
# we mock the API key so we can register functions (llm_config must be present for this to work)
|
# we mock the API key so we can register functions (llm_config must be present for this to work)
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
page_size = 4096
|
page_size = 4096
|
||||||
web_surfer = WebSurferAgent(
|
web_surfer = WebSurferAgent(
|
||||||
"web_surfer", llm_config={"config_list": []}, browser_config={"viewport_size": page_size}
|
"web_surfer", llm_config={"config_list": []}, browser_config={"viewport_size": page_size}
|
||||||
|
@ -15,7 +15,7 @@ import autogen
|
|||||||
from autogen.agentchat import ConversableAgent, UserProxyAgent
|
from autogen.agentchat import ConversableAgent, UserProxyAgent
|
||||||
from autogen.agentchat.conversable_agent import register_function
|
from autogen.agentchat.conversable_agent import register_function
|
||||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
||||||
from conftest import skip_openai
|
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
@ -473,7 +473,7 @@ async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_a
|
|||||||
|
|
||||||
def test_update_function_signature_and_register_functions() -> None:
|
def test_update_function_signature_and_register_functions() -> None:
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent = ConversableAgent(name="agent", llm_config={})
|
agent = ConversableAgent(name="agent", llm_config={})
|
||||||
|
|
||||||
def exec_python(cell: str) -> None:
|
def exec_python(cell: str) -> None:
|
||||||
@ -617,7 +617,7 @@ def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]
|
|||||||
|
|
||||||
def test_register_for_llm():
|
def test_register_for_llm():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
|
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
|
||||||
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
|
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
|
||||||
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
|
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
|
||||||
@ -690,7 +690,7 @@ def test_register_for_llm():
|
|||||||
|
|
||||||
def test_register_for_llm_api_style_function():
|
def test_register_for_llm_api_style_function():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
|
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
|
||||||
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
|
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
|
||||||
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
|
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
|
||||||
@ -761,7 +761,7 @@ def test_register_for_llm_api_style_function():
|
|||||||
|
|
||||||
def test_register_for_llm_without_description():
|
def test_register_for_llm_without_description():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent = ConversableAgent(name="agent", llm_config={})
|
agent = ConversableAgent(name="agent", llm_config={})
|
||||||
|
|
||||||
with pytest.raises(ValueError) as e:
|
with pytest.raises(ValueError) as e:
|
||||||
@ -775,7 +775,7 @@ def test_register_for_llm_without_description():
|
|||||||
|
|
||||||
def test_register_for_llm_without_LLM():
|
def test_register_for_llm_without_LLM():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent = ConversableAgent(name="agent", llm_config=None)
|
agent = ConversableAgent(name="agent", llm_config=None)
|
||||||
agent.llm_config = None
|
agent.llm_config = None
|
||||||
assert agent.llm_config is None
|
assert agent.llm_config is None
|
||||||
@ -791,7 +791,7 @@ def test_register_for_llm_without_LLM():
|
|||||||
|
|
||||||
def test_register_for_execution():
|
def test_register_for_execution():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
|
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
|
||||||
user_proxy_1 = UserProxyAgent(name="user_proxy_1")
|
user_proxy_1 = UserProxyAgent(name="user_proxy_1")
|
||||||
user_proxy_2 = UserProxyAgent(name="user_proxy_2")
|
user_proxy_2 = UserProxyAgent(name="user_proxy_2")
|
||||||
@ -826,7 +826,7 @@ def test_register_for_execution():
|
|||||||
|
|
||||||
def test_register_functions():
|
def test_register_functions():
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
|
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
|
||||||
user_proxy = UserProxyAgent(name="user_proxy")
|
user_proxy = UserProxyAgent(name="user_proxy")
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from autogen.coding.factory import CodeExecutorFactory
|
|||||||
from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor
|
from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor
|
||||||
from autogen.oai.openai_utils import config_list_from_json
|
from autogen.oai.openai_utils import config_list_from_json
|
||||||
|
|
||||||
from conftest import skip_openai
|
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai
|
||||||
|
|
||||||
|
|
||||||
def test_create() -> None:
|
def test_create() -> None:
|
||||||
@ -152,7 +152,7 @@ def test_local_commandline_executor_conversable_agent_code_execution() -> None:
|
|||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
|
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
_test_conversable_agent_code_execution(executor)
|
_test_conversable_agent_code_execution(executor)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from autogen.agentchat.conversable_agent import ConversableAgent
|
|||||||
from autogen.coding.base import CodeBlock, CodeExecutor
|
from autogen.coding.base import CodeBlock, CodeExecutor
|
||||||
from autogen.coding.factory import CodeExecutorFactory
|
from autogen.coding.factory import CodeExecutorFactory
|
||||||
from autogen.oai.openai_utils import config_list_from_json
|
from autogen.oai.openai_utils import config_list_from_json
|
||||||
from conftest import skip_openai # noqa: E402
|
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
|
from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
|
||||||
@ -211,7 +211,7 @@ print(test_function(123, 4))
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("OPENAI_API_KEY", "mock")
|
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
|
||||||
reply = agent.generate_reply(
|
reply = agent.generate_reply(
|
||||||
[{"role": "user", "content": msg}],
|
[{"role": "user", "content": msg}],
|
||||||
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
|
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
|
||||||
|
@ -4,6 +4,8 @@ skip_openai = False
|
|||||||
skip_redis = False
|
skip_redis = False
|
||||||
skip_docker = False
|
skip_docker = False
|
||||||
|
|
||||||
|
MOCK_OPEN_AI_API_KEY = "sk-mockopenaiAPIkeyinexpectedformatfortestingonly"
|
||||||
|
|
||||||
|
|
||||||
# Registers command-line options like '--skip-openai' and '--skip-redis' via pytest hook.
|
# Registers command-line options like '--skip-openai' and '--skip-redis' via pytest hook.
|
||||||
# When these flags are set, it indicates that tests requiring OpenAI or Redis (respectively) should be skipped.
|
# When these flags are set, it indicates that tests requiring OpenAI or Redis (respectively) should be skipped.
|
||||||
|
@ -8,7 +8,9 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import autogen # noqa: E402
|
import autogen # noqa: E402
|
||||||
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config
|
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config, is_valid_api_key
|
||||||
|
|
||||||
|
from conftest import MOCK_OPEN_AI_API_KEY
|
||||||
|
|
||||||
# Example environment variables
|
# Example environment variables
|
||||||
ENV_VARS = {
|
ENV_VARS = {
|
||||||
@ -370,5 +372,17 @@ def test_tags():
|
|||||||
assert len(list_5) == 0
|
assert len(list_5) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_valid_api_key():
|
||||||
|
assert not is_valid_api_key("")
|
||||||
|
assert not is_valid_api_key("sk-")
|
||||||
|
assert not is_valid_api_key("SK-")
|
||||||
|
assert not is_valid_api_key("sk-asajsdjsd2")
|
||||||
|
assert not is_valid_api_key("FooBar")
|
||||||
|
assert not is_valid_api_key("sk-asajsdjsd22372%23kjdfdfdf2329ffUUDSDS")
|
||||||
|
assert is_valid_api_key("sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS")
|
||||||
|
assert is_valid_api_key("sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS1212121221212sssXX")
|
||||||
|
assert is_valid_api_key(MOCK_OPEN_AI_API_KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main()
|
pytest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user