feat: enhance AzureAIChatCompletionClient validation and add unit tests (#5417)

Resolves #5414
This commit is contained in:
Eric Zhu 2025-02-07 10:32:14 -08:00 committed by GitHub
parent af5dcc7fdf
commit 901ab1276d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 4 deletions

View File

@ -246,6 +246,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_info" not in config:
raise ValueError("model_info is required for AzureAIChatCompletionClient")
if "family" not in config["model_info"]:
raise ValueError(
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
)
if _is_github_model(config["endpoint"]) and "model" not in config:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
return cast(AzureAIChatCompletionClientConfig, config)
@ -512,7 +516,8 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
def __del__(self) -> None:
# TODO: This is a hack to close the open client
try:
asyncio.get_running_loop().create_task(self._client.close())
except RuntimeError:
asyncio.run(self._client.close())
if hasattr(self, "_client"):
try:
asyncio.get_running_loop().create_task(self._client.close())
except RuntimeError:
asyncio.run(self._client.close())

View File

@ -7,6 +7,7 @@ import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
from azure.ai.inference.aio import (
ChatCompletionsClient,
)
@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
)
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_validation() -> None:
with pytest.raises(ValueError, match="endpoint is required"):
AzureAIChatCompletionClient(
model="model",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)
with pytest.raises(ValueError, match="credential is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)
with pytest.raises(ValueError, match="model is required"):
AzureAIChatCompletionClient(
endpoint=GITHUB_MODELS_ENDPOINT,
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)
with pytest.raises(ValueError, match="model_info is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
)
with pytest.raises(ValueError, match="family is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
# Missing family.
}, # type: ignore
)
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
assert azure_client