mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-01 18:29:49 +00:00
feat: enhance AzureAIChatCompletionClient validation and add unit tests (#5417)
Resolves #5414
This commit is contained in:
parent
af5dcc7fdf
commit
901ab1276d
@ -246,6 +246,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
raise ValueError("credential is required for AzureAIChatCompletionClient")
|
||||
if "model_info" not in config:
|
||||
raise ValueError("model_info is required for AzureAIChatCompletionClient")
|
||||
if "family" not in config["model_info"]:
|
||||
raise ValueError(
|
||||
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
|
||||
)
|
||||
if _is_github_model(config["endpoint"]) and "model" not in config:
|
||||
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
|
||||
return cast(AzureAIChatCompletionClientConfig, config)
|
||||
@ -512,7 +516,8 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
def __del__(self) -> None:
|
||||
# TODO: This is a hack to close the open client
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(self._client.close())
|
||||
except RuntimeError:
|
||||
asyncio.run(self._client.close())
|
||||
if hasattr(self, "_client"):
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(self._client.close())
|
||||
except RuntimeError:
|
||||
asyncio.run(self._client.close())
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
from autogen_core import CancellationToken, FunctionCall, Image
|
||||
from autogen_core.models import CreateResult, ModelFamily, UserMessage
|
||||
from autogen_ext.models.azure import AzureAIChatCompletionClient
|
||||
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
|
||||
from azure.ai.inference.aio import (
|
||||
ChatCompletionsClient,
|
||||
)
|
||||
@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_ai_chat_completion_client_validation() -> None:
|
||||
with pytest.raises(ValueError, match="endpoint is required"):
|
||||
AzureAIChatCompletionClient(
|
||||
model="model",
|
||||
credential=AzureKeyCredential("api_key"),
|
||||
model_info={
|
||||
"json_output": False,
|
||||
"function_calling": False,
|
||||
"vision": False,
|
||||
"family": "unknown",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="credential is required"):
|
||||
AzureAIChatCompletionClient(
|
||||
model="model",
|
||||
endpoint="endpoint",
|
||||
model_info={
|
||||
"json_output": False,
|
||||
"function_calling": False,
|
||||
"vision": False,
|
||||
"family": "unknown",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="model is required"):
|
||||
AzureAIChatCompletionClient(
|
||||
endpoint=GITHUB_MODELS_ENDPOINT,
|
||||
credential=AzureKeyCredential("api_key"),
|
||||
model_info={
|
||||
"json_output": False,
|
||||
"function_calling": False,
|
||||
"vision": False,
|
||||
"family": "unknown",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="model_info is required"):
|
||||
AzureAIChatCompletionClient(
|
||||
model="model",
|
||||
endpoint="endpoint",
|
||||
credential=AzureKeyCredential("api_key"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="family is required"):
|
||||
AzureAIChatCompletionClient(
|
||||
model="model",
|
||||
endpoint="endpoint",
|
||||
credential=AzureKeyCredential("api_key"),
|
||||
model_info={
|
||||
"json_output": False,
|
||||
"function_calling": False,
|
||||
"vision": False,
|
||||
# Missing family.
|
||||
}, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
|
||||
assert azure_client
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user