Added Support for Github Models

This commit is contained in:
Rohan Thacker 2024-12-16 11:36:29 +05:30 committed by Rohan Thacker
parent 09c071ea45
commit a24901cca0
2 changed files with 19 additions and 28 deletions

View File

@ -40,12 +40,14 @@ from autogen_core.models import (
FunctionExecutionResultMessage,
)
from autogen_core.tools import Tool, ToolSchema
from autogen_ext.models.azure.config import AzureAIConfig
from autogen_ext.models.azure.config import AzureAIChatCompletionClientConfig, GITHUB_MODELS_ENDPOINT
create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
# create_args
def _is_github_model(endpoint: str) -> bool:
return endpoint == GITHUB_MODELS_ENDPOINT
def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
result: List[ChatCompletionsToolDefinition] = []
@ -127,15 +129,18 @@ def to_azure_message(message: LLMMessage):
else:
return _tool_message_to_azure(message)
# TODO: Add Support for Github Models
class AzureAIChatCompletionClient(ChatCompletionClient):
def __init__(self, **kwargs: Unpack[AzureAIConfig]):
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
if "endpoint" not in kwargs:
raise ValueError("endpoint is required for AzureAIChatCompletionClient")
if "credential" not in kwargs:
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_capabilities" not in kwargs:
raise ValueError("model_capabilities is required for AzureAIChatCompletionClient")
if _is_github_model(kwargs['endpoint']) and "model" not in kwargs:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
# TODO: Change
_endpoint = kwargs.pop("endpoint")
@ -189,11 +194,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
if len(tools) > 0:
converted_tools = convert_tools(tools)
task = asyncio.create_task(
self._client.complete(
messages=azure_messages,
tools=converted_tools,
**create_args
)
self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)
)
else:
task = asyncio.create_task(
@ -254,11 +255,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
create_args.update(extra_create_args)
if self.capabilities["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.capabilities["json_output"] is False and json_output is True:
@ -281,21 +281,11 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
if len(tools) > 0:
converted_tools = convert_tools(tools)
task = asyncio.create_task(
self._client.complete(
messages=azure_messages,
tools=converted_tools,
stream=True,
**create_args
)
self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
)
else:
task = asyncio.create_task(
self._client.complete(
messages=azure_messages,
max_tokens=20,
stream=True,
**create_args
)
self._client.complete(messages=azure_messages, max_tokens=20, stream=True, **create_args)
)
if cancellation_token is not None:

View File

@ -8,9 +8,10 @@ from azure.ai.inference.models import (
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from autogen_core.models import ModelCapabilities
GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"
class AzureAIClientArguments(TypedDict, total=False):
endpoint: str
@ -18,7 +19,7 @@ class AzureAIClientArguments(TypedDict, total=False):
model_capabilities: ModelCapabilities
class AzureAIRequestArguments(TypedDict, total=False):
class AzureAICreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
presence_penalty: Optional[float]
temperature: Optional[float]
@ -33,5 +34,5 @@ class AzureAIRequestArguments(TypedDict, total=False):
model_extras: Optional[Dict[str, Any]]
class AzureAIConfig(AzureAIClientArguments, AzureAIRequestArguments):
class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
pass