mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 22:48:40 +00:00
feat: enhance Gemini model support in OpenAI client and tests (#5461)
This commit is contained in:
parent
5308b76d5f
commit
9a028acf9f
@ -145,13 +145,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
roles=roles, participants=str(participants), history=history
|
||||
)
|
||||
select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
|
||||
if self._model_client.model_info["family"] in [
|
||||
ModelFamily.GPT_4,
|
||||
ModelFamily.GPT_4O,
|
||||
ModelFamily.GPT_35,
|
||||
ModelFamily.O1,
|
||||
ModelFamily.O3,
|
||||
]:
|
||||
if ModelFamily.is_openai(self._model_client.model_info["family"]):
|
||||
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
|
||||
else:
|
||||
# Many other models need a UserMessage to respond to
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_core.models import ChatCompletionClient, ModelFamily
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
@ -36,14 +36,7 @@ async def test_selector_group_chat_gemini() -> None:
|
||||
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gemini-1.5-flash",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
api_key=api_key,
|
||||
model_info={
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_1_5_FLASH,
|
||||
},
|
||||
)
|
||||
await _test_selector_group_chat(model_client)
|
||||
|
||||
|
||||
@ -12,11 +12,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Via AgentChat, you can build applications quickly using preset agents.\n",
|
||||
"To illustrate this, we will begin with creating a team of a single tool-use\n",
|
||||
"agent that you can chat with.\n",
|
||||
"\n",
|
||||
"The following code uses the OpenAI model. If you haven't already, you need to\n",
|
||||
"install the following package and extension:"
|
||||
"To illustrate this, we will begin with creating a single tool-use agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -29,7 +25,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install -U \"autogen-agentchat\" \"autogen-ext[openai,azure]\""
|
||||
"pip install -U \"autogen-ext[openai,azure]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -37,12 +33,14 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To use Azure OpenAI models and AAD authentication,\n",
|
||||
"you can follow the instructions [here](./tutorial/models.ipynb#azure-openai)."
|
||||
"you can follow the instructions [here](./tutorial/models.ipynb#azure-openai).\n",
|
||||
"\n",
|
||||
"To use other models, see [Models](./tutorial/models.ipynb)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -50,94 +48,53 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"What is the weather in NYC?\n",
|
||||
"What is the weather in New York?\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"[FunctionCall(id='call_vN04UiNJgqSz6g3MHt7Renig', arguments='{\"city\":\"New York City\"}', name='get_weather')]\n",
|
||||
"[Prompt tokens: 75, Completion tokens: 16]\n",
|
||||
"[FunctionCall(id='call_ciy1Ecys9LH201cyim10xlnQ', arguments='{\"city\":\"New York\"}', name='get_weather')]\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"[FunctionExecutionResult(content='The weather in New York City is 73 degrees and Sunny.', call_id='call_vN04UiNJgqSz6g3MHt7Renig')]\n",
|
||||
"[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_ciy1Ecys9LH201cyim10xlnQ')]\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"The weather in New York City is 73 degrees and Sunny.\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 4\n",
|
||||
"Finish reason: Maximum number of turns 1 reached.\n",
|
||||
"Total prompt tokens: 75\n",
|
||||
"Total completion tokens: 16\n",
|
||||
"Duration: 1.15 seconds\n",
|
||||
"---------- user ----------\n",
|
||||
"What is the weather in Seattle?\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"[FunctionCall(id='call_BesYutZXJIMfu2TlDZgodIEj', arguments='{\"city\":\"Seattle\"}', name='get_weather')]\n",
|
||||
"[Prompt tokens: 127, Completion tokens: 14]\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"[FunctionExecutionResult(content='The weather in Seattle is 73 degrees and Sunny.', call_id='call_BesYutZXJIMfu2TlDZgodIEj')]\n",
|
||||
"---------- weather_agent ----------\n",
|
||||
"The weather in Seattle is 73 degrees and Sunny.\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 4\n",
|
||||
"Finish reason: Maximum number of turns 1 reached.\n",
|
||||
"Total prompt tokens: 127\n",
|
||||
"Total completion tokens: 14\n",
|
||||
"Duration: 2.38 seconds\n"
|
||||
"The weather in New York is currently 73 degrees and sunny.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define a tool\n",
|
||||
"async def get_weather(city: str) -> str:\n",
|
||||
" \"\"\"Get the weather for a given city.\"\"\"\n",
|
||||
" return f\"The weather in {city} is 73 degrees and Sunny.\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def main() -> None:\n",
|
||||
" # Define an agent\n",
|
||||
" weather_agent = AssistantAgent(\n",
|
||||
" agent = AssistantAgent(\n",
|
||||
" name=\"weather_agent\",\n",
|
||||
" model_client=OpenAIChatCompletionClient(\n",
|
||||
" model=\"gpt-4o-2024-08-06\",\n",
|
||||
" model=\"gpt-4o\",\n",
|
||||
" # api_key=\"YOUR_API_KEY\",\n",
|
||||
" ),\n",
|
||||
" tools=[get_weather],\n",
|
||||
" system_message=\"You are a helpful assistant.\",\n",
|
||||
" reflect_on_tool_use=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Define a team with a single agent and maximum auto-gen turns of 1.\n",
|
||||
" agent_team = RoundRobinGroupChat([weather_agent], max_turns=1)\n",
|
||||
"\n",
|
||||
" while True:\n",
|
||||
" # Get user input from the console.\n",
|
||||
" user_input = input(\"Enter a message (type 'exit' to leave): \")\n",
|
||||
" if user_input.strip().lower() == \"exit\":\n",
|
||||
" break\n",
|
||||
" # Run the team and stream messages to the console.\n",
|
||||
" stream = agent_team.run_stream(task=user_input)\n",
|
||||
" await Console(stream)\n",
|
||||
" await Console(agent.run_stream(task=\"What is the weather in New York?\"))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# NOTE: if running this inside a Python script you'll need to use asyncio.run(main()).\n",
|
||||
"await main()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code snippet above introduces two high level concepts in AgentChat: *Agent* and *Team*. An Agent helps us define what actions are taken when a message is received. Specifically, we use the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset - an agent that can be given access to a model (e.g., LLM) and tools (functions) that it can then use to address tasks. A Team helps us define the rules for how agents interact with each other. In the {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team, agents respond in a sequential round-robin fashion.\n",
|
||||
"In this case, we have a single agent, so the same agent is used for each round."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## What's Next?\n",
|
||||
"\n",
|
||||
"Now that you have a basic understanding of how to define an agent and a team, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat.\n",
|
||||
"\n"
|
||||
"Now that you have a basic understanding of how to define an agent, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat."
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -157,7 +114,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -295,33 +295,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n"
|
||||
"finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from autogen_core.models import ModelFamily, UserMessage\n",
|
||||
"from autogen_core.models import UserMessage\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"model_client = OpenAIChatCompletionClient(\n",
|
||||
" model=\"gemini-1.5-flash\",\n",
|
||||
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
|
||||
" api_key=os.environ[\"GEMINI_API_KEY\"],\n",
|
||||
" model_info={\n",
|
||||
" \"vision\": True,\n",
|
||||
" \"function_calling\": True,\n",
|
||||
" \"json_output\": True,\n",
|
||||
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
|
||||
" },\n",
|
||||
" model=\"gemini-1.5-flash-8b\",\n",
|
||||
" # api_key=\"GEMINI_API_KEY\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n",
|
||||
|
||||
@ -303,33 +303,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n"
|
||||
"finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from autogen_core.models import ModelFamily, UserMessage\n",
|
||||
"from autogen_core.models import UserMessage\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"model_client = OpenAIChatCompletionClient(\n",
|
||||
" model=\"gemini-1.5-flash\",\n",
|
||||
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
|
||||
" api_key=os.environ[\"GEMINI_API_KEY\"],\n",
|
||||
" model_info={\n",
|
||||
" \"vision\": True,\n",
|
||||
" \"function_calling\": True,\n",
|
||||
" \"json_output\": True,\n",
|
||||
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
|
||||
" },\n",
|
||||
" model=\"gemini-1.5-flash-8b\",\n",
|
||||
" # api_key=\"GEMINI_API_KEY\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n",
|
||||
|
||||
@ -55,6 +55,35 @@ class ModelFamily:
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
|
||||
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
|
||||
|
||||
@staticmethod
def is_claude(family: str) -> bool:
    """Report whether ``family`` names one of the Claude model families.

    Args:
        family: Model family identifier to classify.

    Returns:
        True when ``family`` equals one of the Claude 3 / Claude 3.5
        family constants enumerated below, False otherwise.
    """
    claude_families = (
        ModelFamily.CLAUDE_3_HAIKU,
        ModelFamily.CLAUDE_3_SONNET,
        ModelFamily.CLAUDE_3_OPUS,
        ModelFamily.CLAUDE_3_5_HAIKU,
        ModelFamily.CLAUDE_3_5_SONNET,
    )
    return family in claude_families
|
||||
|
||||
@staticmethod
def is_gemini(family: str) -> bool:
    """Report whether ``family`` names one of the Gemini model families.

    Args:
        family: Model family identifier to classify.

    Returns:
        True when ``family`` equals one of the Gemini 1.5 / 2.0 family
        constants enumerated below, False otherwise.
    """
    gemini_families = (
        ModelFamily.GEMINI_1_5_FLASH,
        ModelFamily.GEMINI_1_5_PRO,
        ModelFamily.GEMINI_2_0_FLASH,
    )
    return family in gemini_families
|
||||
|
||||
@staticmethod
def is_openai(family: str) -> bool:
    """Report whether ``family`` names one of the OpenAI model families.

    Callers use this check to pick OpenAI-specific behavior (for example,
    prompting with a lone system message), so only genuine OpenAI
    families may be listed here.

    Args:
        family: Model family identifier to classify.

    Returns:
        True when ``family`` is one of the GPT / o-series family constants
        enumerated below, False otherwise.
    """
    # NOTE: ModelFamily.R1 is intentionally absent. R1 is DeepSeek's
    # reasoning model family, not an OpenAI one; classifying it as OpenAI
    # would route it down OpenAI-only code paths.
    return family in (
        ModelFamily.GPT_4O,
        ModelFamily.O1,
        ModelFamily.O3,
        ModelFamily.GPT_4,
        ModelFamily.GPT_35,
    )
|
||||
|
||||
|
||||
@deprecated("Use the ModelInfo class instead ModelCapabilities.")
|
||||
class ModelCapabilities(TypedDict, total=False):
|
||||
|
||||
@ -134,6 +134,36 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GPT_35,
|
||||
},
|
||||
"gemini-1.5-flash": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_1_5_FLASH,
|
||||
},
|
||||
"gemini-1.5-flash-8b": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_1_5_FLASH,
|
||||
},
|
||||
"gemini-1.5-pro": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_1_5_PRO,
|
||||
},
|
||||
"gemini-2.0-flash": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_2_0_FLASH,
|
||||
},
|
||||
"gemini-2.0-flash-lite-preview-02-05": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GEMINI_2_0_FLASH,
|
||||
},
|
||||
}
|
||||
|
||||
_MODEL_TOKEN_LIMITS: Dict[str, int] = {
|
||||
@ -156,8 +186,14 @@ _MODEL_TOKEN_LIMITS: Dict[str, int] = {
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k-0613": 16385,
|
||||
"gemini-1.5-flash": 1048576,
|
||||
"gemini-1.5-flash-8b": 1048576,
|
||||
"gemini-1.5-pro": 2097152,
|
||||
"gemini-2.0-flash": 1048576,
|
||||
}
|
||||
|
||||
# Base URL of the OpenAI-compatible endpoint of the Gemini API, kept as a
# module-level constant so client code can reference it instead of repeating the literal.
GEMINI_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
|
||||
def resolve_model(model: str) -> str:
|
||||
if model in _MODEL_POINTERS:
|
||||
|
||||
@ -3,6 +3,7 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from asyncio import Task
|
||||
@ -1093,6 +1094,14 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
||||
if "add_name_prefixes" in kwargs:
|
||||
add_name_prefixes = kwargs["add_name_prefixes"]
|
||||
|
||||
# Special handling for Gemini model.
|
||||
assert "model" in copied_args and isinstance(copied_args["model"], str)
|
||||
if copied_args["model"].startswith("gemini-"):
|
||||
if "base_url" not in copied_args:
|
||||
copied_args["base_url"] = _model_info.GEMINI_OPENAI_BASE_URL
|
||||
if "api_key" not in copied_args and "GEMINI_API_KEY" in os.environ:
|
||||
copied_args["api_key"] = os.environ["GEMINI_API_KEY"]
|
||||
|
||||
client = _openai_client_from_config(copied_args)
|
||||
create_args = _create_args_from_config(copied_args)
|
||||
|
||||
|
||||
@ -197,6 +197,18 @@ async def test_openai_chat_completion_client() -> None:
|
||||
assert client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_chat_completion_client_with_gemini_model() -> None:
    """A Gemini model name must be accepted without an explicit ``model_info`` or ``base_url``."""
    gemini_client = OpenAIChatCompletionClient(
        model="gemini-1.5-flash",
        api_key="api_key",
    )
    assert gemini_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_chat_completion_client_raise_on_unknown_model() -> None:
    """Constructing a client for an unrecognized model without ``model_info`` must raise ValueError."""
    expects_model_info_error = pytest.raises(ValueError, match="model_info is required")
    with expects_model_info_error:
        OpenAIChatCompletionClient(model="unknown", api_key="api_key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_model_with_capabilities() -> None:
|
||||
with pytest.raises(ValueError, match="model_info is required"):
|
||||
@ -952,14 +964,6 @@ async def test_gemini() -> None:
|
||||
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gemini-1.5-flash",
|
||||
api_key=api_key,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
model_info={
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
"vision": True,
|
||||
"family": ModelFamily.GEMINI_1_5_FLASH,
|
||||
},
|
||||
)
|
||||
await _test_model_client_basic_completion(model_client)
|
||||
await _test_model_client_with_function_calling(model_client)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user