mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-09 17:22:50 +00:00
67 lines
2.8 KiB
Python
67 lines
2.8 KiB
Python
![]() |
import json
|
||
|
import os
|
||
|
from typing import Any, Dict
|
||
|
|
||
|
from agnext.components.models import (
|
||
|
AzureOpenAIChatCompletionClient,
|
||
|
ChatCompletionClient,
|
||
|
ModelCapabilities,
|
||
|
OpenAIChatCompletionClient,
|
||
|
)
|
||
|
|
||
|
# Name of the environment variable that selects the chat completion provider
# (create_completion_client_from_env recognizes "openai", the default, and "azure").
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER: str = "CHAT_COMPLETION_PROVIDER"
# Name of the environment variable holding a JSON-encoded mapping of keyword
# arguments that are forwarded to the client constructor.
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON: str = "CHAT_COMPLETION_KWARGS_JSON"

# Module-level singleton for the default Azure AD token provider. It stays None
# until an "azure" client is requested with azure_ad_token_provider == "DEFAULT",
# at which point it is built once and reused.
_default_azure_ad_token_provider: Any = None
|
||
|
|
||
|
|
||
|
def _get_default_azure_ad_token_provider() -> Any:
    """Return the shared default Azure AD token provider, creating it on first use.

    The provider is
    ``get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")``
    and is cached in the module-level ``_default_azure_ad_token_provider``.
    """
    global _default_azure_ad_token_provider

    if _default_azure_ad_token_provider is None:
        # Imported lazily so azure-identity is only needed when this path is taken.
        from azure.identity import DefaultAzureCredential, get_bearer_token_provider

        _default_azure_ad_token_provider = get_bearer_token_provider(
            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
        )
    return _default_azure_ad_token_provider


# Create a model client based on information provided in environment variables.
def create_completion_client_from_env(env: Dict[str, str] | None = None, **kwargs: Any) -> ChatCompletionClient:
    """
    Create a model client based on information provided in environment variables.

    The provider is read from CHAT_COMPLETION_PROVIDER ("openai" by default, or
    "azure"; case-insensitive, surrounding whitespace ignored). Constructor
    arguments are read from the JSON object in CHAT_COMPLETION_KWARGS_JSON and
    then overridden by ``kwargs``.

    Args:
        env (Optional): When provided, read from this dictionary rather than os.environ.
        **kwargs: ChatCompletionClient arguments to override (e.g., model).

    Returns:
        An OpenAIChatCompletionClient or an AzureOpenAIChatCompletionClient.

    Raises:
        ValueError: If the provider is unknown, or if the kwargs JSON is invalid
            or does not decode to a JSON object.

    NOTE: If 'azure_ad_token_provider' is included, and equals the string 'DEFAULT'
    (case-insensitive), then replace it with
    azure.identity.get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
    """
    # If a dictionary was not provided, load it from the environment
    if env is None:
        env = dict(os.environ)

    # Load the kwargs, and override with provided kwargs
    _kwargs = json.loads(env.get(ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON, "{}"))
    if not isinstance(_kwargs, dict):
        raise ValueError(
            f"{ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON} must decode to a JSON object, "
            f"got {type(_kwargs).__name__}"
        )
    _kwargs.update(kwargs)

    # If model capabilities were provided, deserialize them as well
    if "model_capabilities" in _kwargs:
        capabilities = _kwargs["model_capabilities"]
        _kwargs["model_capabilities"] = ModelCapabilities(
            vision=capabilities.get("vision"),
            function_calling=capabilities.get("function_calling"),
            json_output=capabilities.get("json_output"),
        )

    # Figure out what provider we are using. Default to OpenAI
    _provider = env.get(ENVIRON_KEY_CHAT_COMPLETION_PROVIDER, "openai").lower().strip()

    # Instantiate the correct client
    if _provider == "openai":
        return OpenAIChatCompletionClient(**_kwargs)
    elif _provider == "azure":
        # Only the sentinel string triggers substitution; a real token-provider
        # callable passed via kwargs (or a JSON null) is forwarded unchanged
        # instead of crashing on .lower().
        token_provider = _kwargs.get("azure_ad_token_provider")
        if isinstance(token_provider, str) and token_provider.lower() == "default":
            _kwargs["azure_ad_token_provider"] = _get_default_azure_ad_token_provider()
        return AzureOpenAIChatCompletionClient(**_kwargs)
    else:
        raise ValueError(f"Unknown OAI provider '{_provider}'")
|