fix: Add model info validation and improve error messaging (#5556)

Introduce validation for the ModelInfo dictionary to ensure required
fields are present.

Resolves #5501
This commit is contained in:
Eric Zhu 2025-02-14 18:09:33 -08:00 committed by GitHub
parent 36da8f2af7
commit 69c0b2b5ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 60 additions and 6 deletions

View File

@ -1,4 +1,10 @@
from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo # type: ignore
from ._model_client import (
ChatCompletionClient,
ModelCapabilities, # type: ignore
ModelFamily,
ModelInfo,
validate_model_info,
)
from ._types import (
AssistantMessage,
ChatCompletionTokenLogprob,
@ -29,4 +35,5 @@ __all__ = [
"ChatCompletionTokenLogprob",
"ModelFamily",
"ModelInfo",
"validate_model_info",
]

View File

@ -93,6 +93,12 @@ class ModelCapabilities(TypedDict, total=False):
class ModelInfo(TypedDict, total=False):
"""ModelInfo is a dictionary that contains information about a model's properties.
It is expected to be used in the model_info property of a model client.
We are expecting this to grow over time as we add more features.
"""
vision: Required[bool]
"""True if the model supports vision, aka image input, otherwise False."""
function_calling: Required[bool]
@ -103,6 +109,21 @@ class ModelInfo(TypedDict, total=False):
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
def validate_model_info(model_info: ModelInfo) -> None:
"""Validates the model info dictionary.
Raises:
ValueError: If the model info dictionary is missing required fields.
"""
required_fields = ["vision", "function_calling", "json_output", "family"]
for field in required_fields:
if field not in model_info:
raise ValueError(
f"Missing required field '{field}' in ModelInfo. "
"Starting in v0.4.7, the required fields are enforced."
)
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
@abstractmethod

View File

@ -0,0 +1,22 @@
import pytest
from autogen_core.models import ModelInfo, validate_model_info
def test_model_info() -> None:
# Valid model info.
info: ModelInfo = {
"family": "gpt-4o",
"vision": True,
"function_calling": True,
"json_output": True,
}
validate_model_info(info)
# Invalid model info.
info = {
"family": "gpt-4o",
"vision": True,
"function_calling": True,
} # type: ignore
with pytest.raises(ValueError):
validate_model_info(info)

View File

@ -17,6 +17,7 @@ from autogen_core.models import (
RequestUsage,
SystemMessage,
UserMessage,
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from azure.ai.inference.aio import ChatCompletionsClient
@ -246,10 +247,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_info" not in config:
raise ValueError("model_info is required for AzureAIChatCompletionClient")
if "family" not in config["model_info"]:
raise ValueError(
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
)
validate_model_info(config["model_info"])
if _is_github_model(config["endpoint"]) and "model" not in config:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
return cast(AzureAIChatCompletionClientConfig, config)

View File

@ -47,6 +47,7 @@ from autogen_core.models import (
SystemMessage,
TopLogprob,
UserMessage,
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from openai import AsyncAzureOpenAI, AsyncOpenAI
@ -385,6 +386,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
elif model_capabilities is None and model_info is not None:
self._model_info = model_info
# Validate model_info, check if all required fields are present
validate_model_info(self._model_info)
self._resolved_model: Optional[str] = None
if "model" in create_args:
self._resolved_model = _model_info.resolve_model(create_args["model"])

View File

@ -11,6 +11,7 @@ from autogen_core.models import (
ModelFamily,
ModelInfo,
RequestUsage,
validate_model_info,
)
from autogen_core.tools import BaseTool, Tool, ToolSchema
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
@ -261,6 +262,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
self._model_info = model_info or ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
)
validate_model_info(self._model_info)
self._total_prompt_tokens = 0
self._total_completion_tokens = 0
self._tools_plugin: Optional[KernelPlugin] = None

View File

@ -150,7 +150,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
credential=AzureKeyCredential("api_key"),
)
with pytest.raises(ValueError, match="family is required"):
with pytest.raises(ValueError, match="Missing required field 'family'"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",