mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 23:49:13 +00:00
Added Support for Github Models
This commit is contained in:
parent
09c071ea45
commit
a24901cca0
@ -40,12 +40,14 @@ from autogen_core.models import (
|
||||
FunctionExecutionResultMessage,
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from autogen_ext.models.azure.config import AzureAIConfig
|
||||
from autogen_ext.models.azure.config import AzureAIChatCompletionClientConfig, GITHUB_MODELS_ENDPOINT
|
||||
|
||||
create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
|
||||
|
||||
|
||||
# create_args
|
||||
def _is_github_model(endpoint: str) -> bool:
|
||||
return endpoint == GITHUB_MODELS_ENDPOINT
|
||||
|
||||
|
||||
def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
|
||||
result: List[ChatCompletionsToolDefinition] = []
|
||||
@ -127,15 +129,18 @@ def to_azure_message(message: LLMMessage):
|
||||
else:
|
||||
return _tool_message_to_azure(message)
|
||||
|
||||
|
||||
# TODO: Add Support for Github Models
|
||||
class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
def __init__(self, **kwargs: Unpack[AzureAIConfig]):
|
||||
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
|
||||
if "endpoint" not in kwargs:
|
||||
raise ValueError("endpoint is required for AzureAIChatCompletionClient")
|
||||
if "credential" not in kwargs:
|
||||
raise ValueError("credential is required for AzureAIChatCompletionClient")
|
||||
if "model_capabilities" not in kwargs:
|
||||
raise ValueError("model_capabilities is required for AzureAIChatCompletionClient")
|
||||
if _is_github_model(kwargs['endpoint']) and "model" not in kwargs:
|
||||
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
|
||||
|
||||
# TODO: Change
|
||||
_endpoint = kwargs.pop("endpoint")
|
||||
@ -189,11 +194,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
task = asyncio.create_task(
|
||||
self._client.complete(
|
||||
messages=azure_messages,
|
||||
tools=converted_tools,
|
||||
**create_args
|
||||
)
|
||||
self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)
|
||||
)
|
||||
else:
|
||||
task = asyncio.create_task(
|
||||
@ -254,11 +255,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
if self.capabilities["vision"] is False:
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
if json_output is not None:
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
@ -281,21 +281,11 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
task = asyncio.create_task(
|
||||
self._client.complete(
|
||||
messages=azure_messages,
|
||||
tools=converted_tools,
|
||||
stream=True,
|
||||
**create_args
|
||||
)
|
||||
self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
|
||||
)
|
||||
else:
|
||||
task = asyncio.create_task(
|
||||
self._client.complete(
|
||||
messages=azure_messages,
|
||||
max_tokens=20,
|
||||
stream=True,
|
||||
**create_args
|
||||
)
|
||||
self._client.complete(messages=azure_messages, max_tokens=20, stream=True, **create_args)
|
||||
)
|
||||
|
||||
if cancellation_token is not None:
|
||||
|
||||
@ -8,9 +8,10 @@ from azure.ai.inference.models import (
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
from autogen_core.models import ModelCapabilities
|
||||
|
||||
GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"
|
||||
|
||||
|
||||
class AzureAIClientArguments(TypedDict, total=False):
|
||||
endpoint: str
|
||||
@ -18,7 +19,7 @@ class AzureAIClientArguments(TypedDict, total=False):
|
||||
model_capabilities: ModelCapabilities
|
||||
|
||||
|
||||
class AzureAIRequestArguments(TypedDict, total=False):
|
||||
class AzureAICreateArguments(TypedDict, total=False):
|
||||
frequency_penalty: Optional[float]
|
||||
presence_penalty: Optional[float]
|
||||
temperature: Optional[float]
|
||||
@ -33,5 +34,5 @@ class AzureAIRequestArguments(TypedDict, total=False):
|
||||
model_extras: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class AzureAIConfig(AzureAIClientArguments, AzureAIRequestArguments):
|
||||
class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
|
||||
pass
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user