mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 07:29:54 +00:00
fix: update SK model adapter constructor (#5150)
* update constructor * fix typing error * revert/fix doc changes * add unsaved changes --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
parent
5e9b24c3d9
commit
3fe106621e
@ -44,6 +44,15 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
Args:
|
||||
sk_client (ChatCompletionClientBase):
|
||||
The Semantic Kernel client to wrap (e.g., AzureChatCompletion, GoogleAIChatCompletion, OllamaChatCompletion).
|
||||
kernel (Optional[Kernel]):
|
||||
The Semantic Kernel instance to use for executing requests. If not provided, one must be passed
|
||||
in the extra_create_args for each request.
|
||||
prompt_settings (Optional[PromptExecutionSettings]):
|
||||
Default prompt execution settings to use. Can be overridden per request.
|
||||
model_info (Optional[ModelInfo]):
|
||||
Information about the model's capabilities.
|
||||
service_id (Optional[str]):
|
||||
Optional service identifier.
|
||||
|
||||
Example usage:
|
||||
|
||||
@ -100,8 +109,8 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
api_key = "<AZURE_OPENAI_API_KEY>"
|
||||
|
||||
azure_client = AzureChatCompletion(deployment_name=deployment_name, endpoint=endpoint, api_key=api_key)
|
||||
azure_request_settings = AzureChatPromptExecutionSettings(temperature=0.8)
|
||||
azure_adapter = SKChatCompletionAdapter(sk_client=azure_client, default_prompt_settings=azure_request_settings)
|
||||
azure_settings = AzureChatPromptExecutionSettings(temperature=0.8)
|
||||
azure_adapter = SKChatCompletionAdapter(sk_client=azure_client, kernel=kernel, prompt_settings=azure_settings)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Example B: Google Gemini
|
||||
@ -127,7 +136,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
"temperature": 0.8,
|
||||
},
|
||||
)
|
||||
ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client, default_prompt_settings=request_settings)
|
||||
ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client, prompt_settings=request_settings)
|
||||
|
||||
# 3) Create a tool and register it with the kernel
|
||||
calc_tool = CalculatorTool()
|
||||
@ -143,7 +152,6 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
azure_result = await azure_adapter.create(
|
||||
messages=messages,
|
||||
tools=[calc_tool],
|
||||
extra_create_args={"kernel": kernel, "prompt_execution_settings": azure_request_settings},
|
||||
)
|
||||
print("Azure result:", azure_result.content)
|
||||
|
||||
@ -151,7 +159,6 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
google_result = await google_adapter.create(
|
||||
messages=messages,
|
||||
tools=[calc_tool],
|
||||
extra_create_args={"kernel": kernel},
|
||||
)
|
||||
print("Google result:", google_result.content)
|
||||
|
||||
@ -159,7 +166,6 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
ollama_result = await ollama_adapter.create(
|
||||
messages=messages,
|
||||
tools=[calc_tool],
|
||||
extra_create_args={"kernel": kernel, "prompt_execution_settings": request_settings},
|
||||
)
|
||||
print("Ollama result:", ollama_result.content)
|
||||
|
||||
@ -171,12 +177,14 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
sk_client: ChatCompletionClientBase,
|
||||
kernel: Optional[Kernel] = None,
|
||||
prompt_settings: Optional[PromptExecutionSettings] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
service_id: Optional[str] = None,
|
||||
default_prompt_settings: Optional[PromptExecutionSettings] = None,
|
||||
):
|
||||
self._service_id = service_id
|
||||
self._default_prompt_settings = default_prompt_settings
|
||||
self._kernel = kernel
|
||||
self._prompt_settings = prompt_settings
|
||||
self._sk_client = sk_client
|
||||
self._model_info = model_info or ModelInfo(
|
||||
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
|
||||
@ -287,6 +295,17 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
function_calls.append(FunctionCall(id=item.id, name=full_name, arguments=arguments))
|
||||
return function_calls
|
||||
|
||||
def _get_kernel(self, extra_create_args: Mapping[str, Any]) -> Kernel:
|
||||
kernel = extra_create_args.get("kernel", self._kernel)
|
||||
if not kernel:
|
||||
raise ValueError("kernel must be provided either in constructor or extra_create_args")
|
||||
if not isinstance(kernel, Kernel):
|
||||
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")
|
||||
return kernel
|
||||
|
||||
def _get_prompt_settings(self, extra_create_args: Mapping[str, Any]) -> Optional[PromptExecutionSettings]:
|
||||
return extra_create_args.get("prompt_execution_settings", None) or self._prompt_settings
|
||||
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
@ -300,9 +319,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
|
||||
The `extra_create_args` dictionary can include two special keys:
|
||||
|
||||
1) `"kernel"` (required):
|
||||
1) `"kernel"` (optional):
|
||||
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
|
||||
If not provided, a ValueError is raised.
|
||||
If not provided either in constructor or extra_create_args, a ValueError is raised.
|
||||
|
||||
2) `"prompt_execution_settings"` (optional):
|
||||
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
|
||||
@ -320,19 +339,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
Returns:
|
||||
CreateResult: The result of the chat completion.
|
||||
"""
|
||||
if "kernel" not in extra_create_args:
|
||||
raise ValueError("kernel is required in extra_create_args")
|
||||
|
||||
kernel = extra_create_args["kernel"]
|
||||
if not isinstance(kernel, Kernel):
|
||||
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")
|
||||
|
||||
kernel = self._get_kernel(extra_create_args)
|
||||
chat_history = self._convert_to_chat_history(messages)
|
||||
|
||||
# Build execution settings from extra args and tools
|
||||
user_settings = extra_create_args.get("prompt_execution_settings", None)
|
||||
if user_settings is None:
|
||||
user_settings = self._default_prompt_settings
|
||||
user_settings = self._get_prompt_settings(extra_create_args)
|
||||
settings = self._build_execution_settings(user_settings, tools)
|
||||
|
||||
# Sync tools with kernel
|
||||
@ -380,9 +389,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
|
||||
The `extra_create_args` dictionary can include two special keys:
|
||||
|
||||
1) `"kernel"` (required):
|
||||
1) `"kernel"` (optional):
|
||||
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
|
||||
If not provided, a ValueError is raised.
|
||||
If not provided either in constructor or extra_create_args, a ValueError is raised.
|
||||
|
||||
2) `"prompt_execution_settings"` (optional):
|
||||
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
|
||||
@ -400,17 +409,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
Yields:
|
||||
Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
|
||||
"""
|
||||
if "kernel" not in extra_create_args:
|
||||
raise ValueError("kernel is required in extra_create_args")
|
||||
|
||||
kernel = extra_create_args["kernel"]
|
||||
if not isinstance(kernel, Kernel):
|
||||
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")
|
||||
|
||||
kernel = self._get_kernel(extra_create_args)
|
||||
chat_history = self._convert_to_chat_history(messages)
|
||||
user_settings = extra_create_args.get("prompt_execution_settings", None)
|
||||
if user_settings is None:
|
||||
user_settings = self._default_prompt_settings
|
||||
user_settings = self._get_prompt_settings(extra_create_args)
|
||||
settings = self._build_execution_settings(user_settings, tools)
|
||||
self._sync_tools_with_kernel(kernel, tools)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user