feat: enhance Gemini model support in OpenAI client and tests (#5461)

This commit is contained in:
Eric Zhu 2025-02-09 10:12:59 -08:00 committed by GitHub
parent 5308b76d5f
commit 9a028acf9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 116 additions and 112 deletions

View File

@ -145,13 +145,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
roles=roles, participants=str(participants), history=history
)
select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
if self._model_client.model_info["family"] in [
ModelFamily.GPT_4,
ModelFamily.GPT_4O,
ModelFamily.GPT_35,
ModelFamily.O1,
ModelFamily.O3,
]:
if ModelFamily.is_openai(self._model_client.model_info["family"]):
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
else:
# Many other models need a UserMessage to respond to

View File

@ -4,7 +4,7 @@ import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ChatCompletionClient, ModelFamily
from autogen_core.models import ChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
@ -36,14 +36,7 @@ async def test_selector_group_chat_gemini() -> None:
model_client = OpenAIChatCompletionClient(
model="gemini-1.5-flash",
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
api_key=api_key,
model_info={
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
)
await _test_selector_group_chat(model_client)

View File

@ -12,11 +12,7 @@
"metadata": {},
"source": [
"Via AgentChat, you can build applications quickly using preset agents.\n",
"To illustrate this, we will begin with creating a team of a single tool-use\n",
"agent that you can chat with.\n",
"\n",
"The following code uses the OpenAI model. If you haven't already, you need to\n",
"install the following package and extension:"
"To illustrate this, we will begin with creating a single tool-use agent."
]
},
{
@ -29,7 +25,7 @@
},
"outputs": [],
"source": [
"pip install -U \"autogen-agentchat\" \"autogen-ext[openai,azure]\""
"pip install -U \"autogen-ext[openai,azure]\""
]
},
{
@ -37,12 +33,14 @@
"metadata": {},
"source": [
"To use Azure OpenAI models and AAD authentication,\n",
"you can follow the instructions [here](./tutorial/models.ipynb#azure-openai)."
"you can follow the instructions [here](./tutorial/models.ipynb#azure-openai).\n",
"\n",
"To use other models, see [Models](./tutorial/models.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@ -50,94 +48,53 @@
"output_type": "stream",
"text": [
"---------- user ----------\n",
"What is the weather in NYC?\n",
"What is the weather in New York?\n",
"---------- weather_agent ----------\n",
"[FunctionCall(id='call_vN04UiNJgqSz6g3MHt7Renig', arguments='{\"city\":\"New York City\"}', name='get_weather')]\n",
"[Prompt tokens: 75, Completion tokens: 16]\n",
"[FunctionCall(id='call_ciy1Ecys9LH201cyim10xlnQ', arguments='{\"city\":\"New York\"}', name='get_weather')]\n",
"---------- weather_agent ----------\n",
"[FunctionExecutionResult(content='The weather in New York City is 73 degrees and Sunny.', call_id='call_vN04UiNJgqSz6g3MHt7Renig')]\n",
"[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_ciy1Ecys9LH201cyim10xlnQ')]\n",
"---------- weather_agent ----------\n",
"The weather in New York City is 73 degrees and Sunny.\n",
"---------- Summary ----------\n",
"Number of messages: 4\n",
"Finish reason: Maximum number of turns 1 reached.\n",
"Total prompt tokens: 75\n",
"Total completion tokens: 16\n",
"Duration: 1.15 seconds\n",
"---------- user ----------\n",
"What is the weather in Seattle?\n",
"---------- weather_agent ----------\n",
"[FunctionCall(id='call_BesYutZXJIMfu2TlDZgodIEj', arguments='{\"city\":\"Seattle\"}', name='get_weather')]\n",
"[Prompt tokens: 127, Completion tokens: 14]\n",
"---------- weather_agent ----------\n",
"[FunctionExecutionResult(content='The weather in Seattle is 73 degrees and Sunny.', call_id='call_BesYutZXJIMfu2TlDZgodIEj')]\n",
"---------- weather_agent ----------\n",
"The weather in Seattle is 73 degrees and Sunny.\n",
"---------- Summary ----------\n",
"Number of messages: 4\n",
"Finish reason: Maximum number of turns 1 reached.\n",
"Total prompt tokens: 127\n",
"Total completion tokens: 14\n",
"Duration: 2.38 seconds\n"
"The weather in New York is currently 73 degrees and sunny.\n"
]
}
],
"source": [
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"\n",
"# Define a tool\n",
"async def get_weather(city: str) -> str:\n",
" \"\"\"Get the weather for a given city.\"\"\"\n",
" return f\"The weather in {city} is 73 degrees and Sunny.\"\n",
"\n",
"\n",
"async def main() -> None:\n",
" # Define an agent\n",
" weather_agent = AssistantAgent(\n",
" agent = AssistantAgent(\n",
" name=\"weather_agent\",\n",
" model_client=OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-2024-08-06\",\n",
" model=\"gpt-4o\",\n",
" # api_key=\"YOUR_API_KEY\",\n",
" ),\n",
" tools=[get_weather],\n",
" system_message=\"You are a helpful assistant.\",\n",
" reflect_on_tool_use=True,\n",
" )\n",
"\n",
" # Define a team with a single agent and maximum auto-gen turns of 1.\n",
" agent_team = RoundRobinGroupChat([weather_agent], max_turns=1)\n",
"\n",
" while True:\n",
" # Get user input from the console.\n",
" user_input = input(\"Enter a message (type 'exit' to leave): \")\n",
" if user_input.strip().lower() == \"exit\":\n",
" break\n",
" # Run the team and stream messages to the console.\n",
" stream = agent_team.run_stream(task=user_input)\n",
" await Console(stream)\n",
" await Console(agent.run_stream(task=\"What is the weather in New York?\"))\n",
"\n",
"\n",
"# NOTE: if running this inside a Python script you'll need to use asyncio.run(main()).\n",
"await main()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code snippet above introduces two high level concepts in AgentChat: *Agent* and *Team*. An Agent helps us define what actions are taken when a message is received. Specifically, we use the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset - an agent that can be given access to a model (e.g., LLM) and tools (functions) that it can then use to address tasks. A Team helps us define the rules for how agents interact with each other. In the {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team, agents respond in a sequential round-robin fashion.\n",
"In this case, we have a single agent, so the same agent is used for each round."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What's Next?\n",
"\n",
"Now that you have a basic understanding of how to define an agent and a team, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat.\n",
"\n"
"Now that you have a basic understanding of how to define an agent, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat."
]
}
],
@ -157,7 +114,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@ -295,33 +295,24 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n"
"finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n"
]
}
],
"source": [
"import os\n",
"\n",
"from autogen_core.models import ModelFamily, UserMessage\n",
"from autogen_core.models import UserMessage\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"model_client = OpenAIChatCompletionClient(\n",
" model=\"gemini-1.5-flash\",\n",
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
" api_key=os.environ[\"GEMINI_API_KEY\"],\n",
" model_info={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
" },\n",
" model=\"gemini-1.5-flash-8b\",\n",
" # api_key=\"GEMINI_API_KEY\",\n",
")\n",
"\n",
"response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n",

View File

@ -303,33 +303,24 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n"
"finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n"
]
}
],
"source": [
"import os\n",
"\n",
"from autogen_core.models import ModelFamily, UserMessage\n",
"from autogen_core.models import UserMessage\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"model_client = OpenAIChatCompletionClient(\n",
" model=\"gemini-1.5-flash\",\n",
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
" api_key=os.environ[\"GEMINI_API_KEY\"],\n",
" model_info={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
" },\n",
" model=\"gemini-1.5-flash-8b\",\n",
" # api_key=\"GEMINI_API_KEY\",\n",
")\n",
"\n",
"response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n",

View File

@ -55,6 +55,35 @@ class ModelFamily:
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
@staticmethod
def is_claude(family: str) -> bool:
return family in (
ModelFamily.CLAUDE_3_HAIKU,
ModelFamily.CLAUDE_3_SONNET,
ModelFamily.CLAUDE_3_OPUS,
ModelFamily.CLAUDE_3_5_HAIKU,
ModelFamily.CLAUDE_3_5_SONNET,
)
@staticmethod
def is_gemini(family: str) -> bool:
return family in (
ModelFamily.GEMINI_1_5_FLASH,
ModelFamily.GEMINI_1_5_PRO,
ModelFamily.GEMINI_2_0_FLASH,
)
@staticmethod
def is_openai(family: str) -> bool:
return family in (
ModelFamily.GPT_4O,
ModelFamily.O1,
ModelFamily.O3,
ModelFamily.GPT_4,
ModelFamily.GPT_35,
ModelFamily.R1,
)
@deprecated("Use the ModelInfo class instead ModelCapabilities.")
class ModelCapabilities(TypedDict, total=False):

View File

@ -134,6 +134,36 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"json_output": True,
"family": ModelFamily.GPT_35,
},
"gemini-1.5-flash": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
"gemini-1.5-flash-8b": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
"gemini-1.5-pro": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_PRO,
},
"gemini-2.0-flash": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_2_0_FLASH,
},
"gemini-2.0-flash-lite-preview-02-05": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_2_0_FLASH,
},
}
_MODEL_TOKEN_LIMITS: Dict[str, int] = {
@ -156,8 +186,14 @@ _MODEL_TOKEN_LIMITS: Dict[str, int] = {
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k-0613": 16385,
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.5-pro": 2097152,
"gemini-2.0-flash": 1048576,
}
GEMINI_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
def resolve_model(model: str) -> str:
if model in _MODEL_POINTERS:

View File

@ -3,6 +3,7 @@ import inspect
import json
import logging
import math
import os
import re
import warnings
from asyncio import Task
@ -1093,6 +1094,14 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]
# Special handling for Gemini model.
assert "model" in copied_args and isinstance(copied_args["model"], str)
if copied_args["model"].startswith("gemini-"):
if "base_url" not in copied_args:
copied_args["base_url"] = _model_info.GEMINI_OPENAI_BASE_URL
if "api_key" not in copied_args and "GEMINI_API_KEY" in os.environ:
copied_args["api_key"] = os.environ["GEMINI_API_KEY"]
client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)

View File

@ -197,6 +197,18 @@ async def test_openai_chat_completion_client() -> None:
assert client
@pytest.mark.asyncio
async def test_openai_chat_completion_client_with_gemini_model() -> None:
client = OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key="api_key")
assert client
@pytest.mark.asyncio
async def test_openai_chat_completion_client_raise_on_unknown_model() -> None:
with pytest.raises(ValueError, match="model_info is required"):
_ = OpenAIChatCompletionClient(model="unknown", api_key="api_key")
@pytest.mark.asyncio
async def test_custom_model_with_capabilities() -> None:
with pytest.raises(ValueError, match="model_info is required"):
@ -952,14 +964,6 @@ async def test_gemini() -> None:
model_client = OpenAIChatCompletionClient(
model="gemini-1.5-flash",
api_key=api_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
model_info={
"function_calling": True,
"json_output": True,
"vision": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
)
await _test_model_client_basic_completion(model_client)
await _test_model_client_with_function_calling(model_client)