mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-24 21:49:42 +00:00
fix: Add model info validation and improve error messaging (#5556)
Introduce validation for the ModelInfo dictionary to ensure required fields are present. Resolves #5501
This commit is contained in:
parent
36da8f2af7
commit
69c0b2b5ef
@ -1,4 +1,10 @@
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo # type: ignore
|
||||
from ._model_client import (
|
||||
ChatCompletionClient,
|
||||
ModelCapabilities, # type: ignore
|
||||
ModelFamily,
|
||||
ModelInfo,
|
||||
validate_model_info,
|
||||
)
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
ChatCompletionTokenLogprob,
|
||||
@ -29,4 +35,5 @@ __all__ = [
|
||||
"ChatCompletionTokenLogprob",
|
||||
"ModelFamily",
|
||||
"ModelInfo",
|
||||
"validate_model_info",
|
||||
]
|
||||
|
||||
@ -93,6 +93,12 @@ class ModelCapabilities(TypedDict, total=False):
|
||||
|
||||
|
||||
class ModelInfo(TypedDict, total=False):
|
||||
"""ModelInfo is a dictionary that contains information about a model's properties.
|
||||
It is expected to be used in the model_info property of a model client.
|
||||
|
||||
We are expecting this to grow over time as we add more features.
|
||||
"""
|
||||
|
||||
vision: Required[bool]
|
||||
"""True if the model supports vision, aka image input, otherwise False."""
|
||||
function_calling: Required[bool]
|
||||
@ -103,6 +109,21 @@ class ModelInfo(TypedDict, total=False):
|
||||
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
|
||||
|
||||
|
||||
def validate_model_info(model_info: ModelInfo) -> None:
|
||||
"""Validates the model info dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model info dictionary is missing required fields.
|
||||
"""
|
||||
required_fields = ["vision", "function_calling", "json_output", "family"]
|
||||
for field in required_fields:
|
||||
if field not in model_info:
|
||||
raise ValueError(
|
||||
f"Missing required field '{field}' in ModelInfo. "
|
||||
"Starting in v0.4.7, the required fields are enforced."
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
|
||||
@abstractmethod
|
||||
|
||||
22
python/packages/autogen-core/tests/test_models.py
Normal file
22
python/packages/autogen-core/tests/test_models.py
Normal file
@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
from autogen_core.models import ModelInfo, validate_model_info
|
||||
|
||||
|
||||
def test_model_info() -> None:
|
||||
# Valid model info.
|
||||
info: ModelInfo = {
|
||||
"family": "gpt-4o",
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
}
|
||||
validate_model_info(info)
|
||||
|
||||
# Invalid model info.
|
||||
info = {
|
||||
"family": "gpt-4o",
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
} # type: ignore
|
||||
with pytest.raises(ValueError):
|
||||
validate_model_info(info)
|
||||
@ -17,6 +17,7 @@ from autogen_core.models import (
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
validate_model_info,
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from azure.ai.inference.aio import ChatCompletionsClient
|
||||
@ -246,10 +247,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
raise ValueError("credential is required for AzureAIChatCompletionClient")
|
||||
if "model_info" not in config:
|
||||
raise ValueError("model_info is required for AzureAIChatCompletionClient")
|
||||
if "family" not in config["model_info"]:
|
||||
raise ValueError(
|
||||
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
|
||||
)
|
||||
validate_model_info(config["model_info"])
|
||||
if _is_github_model(config["endpoint"]) and "model" not in config:
|
||||
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
|
||||
return cast(AzureAIChatCompletionClientConfig, config)
|
||||
|
||||
@ -47,6 +47,7 @@ from autogen_core.models import (
|
||||
SystemMessage,
|
||||
TopLogprob,
|
||||
UserMessage,
|
||||
validate_model_info,
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
@ -385,6 +386,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
elif model_capabilities is None and model_info is not None:
|
||||
self._model_info = model_info
|
||||
|
||||
# Validate model_info, check if all required fields are present
|
||||
validate_model_info(self._model_info)
|
||||
|
||||
self._resolved_model: Optional[str] = None
|
||||
if "model" in create_args:
|
||||
self._resolved_model = _model_info.resolve_model(create_args["model"])
|
||||
|
||||
@ -11,6 +11,7 @@ from autogen_core.models import (
|
||||
ModelFamily,
|
||||
ModelInfo,
|
||||
RequestUsage,
|
||||
validate_model_info,
|
||||
)
|
||||
from autogen_core.tools import BaseTool, Tool, ToolSchema
|
||||
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
|
||||
@ -261,6 +262,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
self._model_info = model_info or ModelInfo(
|
||||
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
|
||||
)
|
||||
validate_model_info(self._model_info)
|
||||
self._total_prompt_tokens = 0
|
||||
self._total_completion_tokens = 0
|
||||
self._tools_plugin: Optional[KernelPlugin] = None
|
||||
|
||||
@ -150,7 +150,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
|
||||
credential=AzureKeyCredential("api_key"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="family is required"):
|
||||
with pytest.raises(ValueError, match="Missing required field 'family'"):
|
||||
AzureAIChatCompletionClient(
|
||||
model="model",
|
||||
endpoint="endpoint",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user