mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-12 08:04:11 +00:00
feat: enhance AzureAIChatCompletionClient validation and add unit tests (#5417)
Resolves #5414
This commit is contained in:
parent
af5dcc7fdf
commit
901ab1276d
@ -246,6 +246,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
|||||||
raise ValueError("credential is required for AzureAIChatCompletionClient")
|
raise ValueError("credential is required for AzureAIChatCompletionClient")
|
||||||
if "model_info" not in config:
|
if "model_info" not in config:
|
||||||
raise ValueError("model_info is required for AzureAIChatCompletionClient")
|
raise ValueError("model_info is required for AzureAIChatCompletionClient")
|
||||||
|
if "family" not in config["model_info"]:
|
||||||
|
raise ValueError(
|
||||||
|
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
|
||||||
|
)
|
||||||
if _is_github_model(config["endpoint"]) and "model" not in config:
|
if _is_github_model(config["endpoint"]) and "model" not in config:
|
||||||
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
|
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
|
||||||
return cast(AzureAIChatCompletionClientConfig, config)
|
return cast(AzureAIChatCompletionClientConfig, config)
|
||||||
@ -512,6 +516,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
|||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
# TODO: This is a hack to close the open client
|
# TODO: This is a hack to close the open client
|
||||||
|
if hasattr(self, "_client"):
|
||||||
try:
|
try:
|
||||||
asyncio.get_running_loop().create_task(self._client.close())
|
asyncio.get_running_loop().create_task(self._client.close())
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import pytest
|
|||||||
from autogen_core import CancellationToken, FunctionCall, Image
|
from autogen_core import CancellationToken, FunctionCall, Image
|
||||||
from autogen_core.models import CreateResult, ModelFamily, UserMessage
|
from autogen_core.models import CreateResult, ModelFamily, UserMessage
|
||||||
from autogen_ext.models.azure import AzureAIChatCompletionClient
|
from autogen_ext.models.azure import AzureAIChatCompletionClient
|
||||||
|
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
|
||||||
from azure.ai.inference.aio import (
|
from azure.ai.inference.aio import (
|
||||||
ChatCompletionsClient,
|
ChatCompletionsClient,
|
||||||
)
|
)
|
||||||
@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_azure_ai_chat_completion_client_validation() -> None:
|
||||||
|
with pytest.raises(ValueError, match="endpoint is required"):
|
||||||
|
AzureAIChatCompletionClient(
|
||||||
|
model="model",
|
||||||
|
credential=AzureKeyCredential("api_key"),
|
||||||
|
model_info={
|
||||||
|
"json_output": False,
|
||||||
|
"function_calling": False,
|
||||||
|
"vision": False,
|
||||||
|
"family": "unknown",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="credential is required"):
|
||||||
|
AzureAIChatCompletionClient(
|
||||||
|
model="model",
|
||||||
|
endpoint="endpoint",
|
||||||
|
model_info={
|
||||||
|
"json_output": False,
|
||||||
|
"function_calling": False,
|
||||||
|
"vision": False,
|
||||||
|
"family": "unknown",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="model is required"):
|
||||||
|
AzureAIChatCompletionClient(
|
||||||
|
endpoint=GITHUB_MODELS_ENDPOINT,
|
||||||
|
credential=AzureKeyCredential("api_key"),
|
||||||
|
model_info={
|
||||||
|
"json_output": False,
|
||||||
|
"function_calling": False,
|
||||||
|
"vision": False,
|
||||||
|
"family": "unknown",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="model_info is required"):
|
||||||
|
AzureAIChatCompletionClient(
|
||||||
|
model="model",
|
||||||
|
endpoint="endpoint",
|
||||||
|
credential=AzureKeyCredential("api_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="family is required"):
|
||||||
|
AzureAIChatCompletionClient(
|
||||||
|
model="model",
|
||||||
|
endpoint="endpoint",
|
||||||
|
credential=AzureKeyCredential("api_key"),
|
||||||
|
model_info={
|
||||||
|
"json_output": False,
|
||||||
|
"function_calling": False,
|
||||||
|
"vision": False,
|
||||||
|
# Missing family.
|
||||||
|
}, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
|
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
|
||||||
assert azure_client
|
assert azure_client
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user