2024-07-09 10:46:55 -07:00
|
|
|
import json
|
2024-07-09 13:51:05 -07:00
|
|
|
import logging
|
2024-07-09 10:46:55 -07:00
|
|
|
import os
|
2024-07-09 13:51:05 -07:00
|
|
|
from datetime import datetime
|
2024-07-10 00:01:13 -07:00
|
|
|
from typing import Any, Dict, List
|
2024-07-09 10:46:55 -07:00
|
|
|
|
|
|
|
from agnext.components.models import (
|
|
|
|
AzureOpenAIChatCompletionClient,
|
|
|
|
ChatCompletionClient,
|
|
|
|
ModelCapabilities,
|
|
|
|
OpenAIChatCompletionClient,
|
|
|
|
)
|
|
|
|
|
2024-07-10 00:01:13 -07:00
|
|
|
from .messages import AssistantContent, FunctionExecutionContent, OrchestrationEvent, SystemContent, UserContent
|
2024-07-09 13:51:05 -07:00
|
|
|
|
2024-07-09 10:46:55 -07:00
|
|
|
# Name of the environment variable that selects the chat-completion provider.
# create_completion_client_from_env() accepts "openai" (the default when unset) or "azure".
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER = "CHAT_COMPLETION_PROVIDER"
|
|
|
|
# Name of the environment variable holding a JSON document whose entries are used as
# keyword arguments for the client constructor (treated as "{}" when unset).
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON = "CHAT_COMPLETION_KWARGS_JSON"
|
|
|
|
|
|
|
|
# The singleton _default_azure_ad_token_provider, which will be created if needed
|
|
|
|
_default_azure_ad_token_provider = None
|
|
|
|
|
|
|
|
|
|
|
|
# Create a model client based on information provided in environment variables.
|
|
|
|
def create_completion_client_from_env(env: Dict[str, str] | None = None, **kwargs: Any) -> ChatCompletionClient:
    """
    Create a model client based on information provided in environment variables.

    Args:
        env (Optional): When provided, read from this dictionary rather than os.environ.
        **kwargs: ChatClient arguments to override (e.g., model). These take precedence
            over the values decoded from the CHAT_COMPLETION_KWARGS_JSON entry.

    Returns:
        An OpenAIChatCompletionClient when the provider is "openai" (the default),
        or an AzureOpenAIChatCompletionClient when the provider is "azure".

    Raises:
        ValueError: If the configured provider is neither "openai" nor "azure".

    NOTE: If 'azure_ad_token_provider' is included, and equals the string 'DEFAULT'
    (case-insensitive), then it is replaced with
    azure.identity.get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default").
    Any other value (for example a caller-supplied callable) is passed through unchanged.
    """
    global _default_azure_ad_token_provider

    # If a dictionary was not provided, load it from the environment
    if env is None:
        env = dict(os.environ)

    # Load the kwargs, and override with provided kwargs
    _kwargs = json.loads(env.get(ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON, "{}"))
    _kwargs.update(kwargs)

    # If model capabilities were provided, deserialize them as well
    if "model_capabilities" in _kwargs:
        capabilities = _kwargs["model_capabilities"]
        _kwargs["model_capabilities"] = ModelCapabilities(
            vision=capabilities.get("vision"),
            function_calling=capabilities.get("function_calling"),
            json_output=capabilities.get("json_output"),
        )

    # Figure out what provider we are using. Default to OpenAI
    _provider = env.get(ENVIRON_KEY_CHAT_COMPLETION_PROVIDER, "openai").lower().strip()

    # Instantiate the correct client
    if _provider == "openai":
        return OpenAIChatCompletionClient(**_kwargs)

    if _provider == "azure":
        token_provider = _kwargs.get("azure_ad_token_provider")
        # Only the literal string "default" is a placeholder. A caller may pass a real
        # token-provider callable through **kwargs, which has no .lower() method.
        if isinstance(token_provider, str) and token_provider.lower() == "default":
            if _default_azure_ad_token_provider is None:
                # Imported lazily so azure.identity is only required when actually used.
                from azure.identity import DefaultAzureCredential, get_bearer_token_provider

                # Created once and stored in the module-level singleton for later calls.
                _default_azure_ad_token_provider = get_bearer_token_provider(
                    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
                )
            _kwargs["azure_ad_token_provider"] = _default_azure_ad_token_provider
        return AzureOpenAIChatCompletionClient(**_kwargs)

    raise ValueError(f"Unknown OAI provider '{_provider}'")
|
2024-07-09 13:51:05 -07:00
|
|
|
|
|
|
|
|
2024-07-10 00:01:13 -07:00
|
|
|
# Convert message content (user, assistant, system, or function-execution) to a string
|
|
|
|
def message_content_to_str(
    message_content: UserContent | AssistantContent | SystemContent | FunctionExecutionContent,
) -> str:
    """
    Flatten message content into a single string.

    A plain string is returned unchanged. A list is rendered item by item:
    string items are used as-is, anything else goes through str(); each piece
    has its trailing whitespace removed and the pieces are joined with newlines.

    Raises:
        AssertionError: If the content is neither a string nor a list.
    """
    # Strings pass straight through, with no stripping applied.
    if isinstance(message_content, str):
        return message_content

    # Anything that is not a list at this point is unsupported.
    if not isinstance(message_content, list):
        raise AssertionError("Unexpected response type.")

    return "\n".join(
        (part if isinstance(part, str) else str(part)).rstrip() for part in message_content
    )
|
|
|
|
|
|
|
|
|
2024-07-09 13:51:05 -07:00
|
|
|
# TeamOne log event handler
|
|
|
|
class LogHandler(logging.Handler):
    """Logging handler that prints OrchestrationEvent records to standard output."""

    def __init__(self) -> None:
        super().__init__()

    def emit(self, record: logging.LogRecord) -> None:
        """
        Print the record if its message is an OrchestrationEvent.

        Records carrying any other kind of message are ignored. Any exception
        raised while formatting or printing is routed to handleError().
        """
        try:
            event = record.msg
            if not isinstance(event, OrchestrationEvent):
                return

            # record.created is an epoch timestamp; render it as a local ISO-8601 string.
            stamp = datetime.fromtimestamp(record.created).isoformat()

            # Layout: blank line, rule, header wrapped in ANSI colour codes
            # (\033[91m ... \033[0m), blank line, then the event message.
            rendered = (
                "\n"
                "---------------------------------------------------------------------------\n"
                f"\033[91m[{stamp}], {event.source}:\033[0m\n"
                "\n"
                f"{event.message}"
            )
            print(rendered, flush=True)
        except Exception:
            self.handleError(record)
|