mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Support openai assistant v2 API (#2466)
* adapted to openai assistant v2 api * fix comments * format code * fix ci * Update autogen/agentchat/contrib/gpt_assistant_agent.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
2daae42708
commit
a41182a93f
@ -10,7 +10,7 @@ import openai
|
||||
from autogen import OpenAIWrapper
|
||||
from autogen.agentchat.agent import Agent
|
||||
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
|
||||
from autogen.oai.openai_utils import retrieve_assistants_by_name
|
||||
from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,7 +50,8 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
- check_every_ms: check thread run status interval
|
||||
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
|
||||
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
|
||||
- file_ids: files used by retrieval in run
|
||||
- file_ids: (Deprecated) files used by retrieval in run. It is Deprecated, use tool_resources instead. https://platform.openai.com/docs/assistants/migration/what-has-changed.
|
||||
- tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
|
||||
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
|
||||
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
|
||||
kwargs (dict): Additional configuration options for the agent.
|
||||
@ -90,7 +91,6 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
candidate_assistants,
|
||||
instructions,
|
||||
openai_assistant_cfg.get("tools", []),
|
||||
openai_assistant_cfg.get("file_ids", []),
|
||||
)
|
||||
|
||||
if len(candidate_assistants) == 0:
|
||||
@ -101,12 +101,12 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
|
||||
)
|
||||
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
|
||||
self._openai_assistant = self._openai_client.beta.assistants.create(
|
||||
self._openai_assistant = create_gpt_assistant(
|
||||
self._openai_client,
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=openai_assistant_cfg.get("tools", []),
|
||||
model=model_name,
|
||||
file_ids=openai_assistant_cfg.get("file_ids", []),
|
||||
assistant_config=openai_assistant_cfg,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@ -127,9 +127,12 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
logger.warning(
|
||||
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
|
||||
)
|
||||
self._openai_assistant = self._openai_client.beta.assistants.update(
|
||||
self._openai_assistant = update_gpt_assistant(
|
||||
self._openai_client,
|
||||
assistant_id=openai_assistant_id,
|
||||
instructions=instructions,
|
||||
assistant_config={
|
||||
"instructions": instructions,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@ -154,9 +157,13 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
logger.warning(
|
||||
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
|
||||
)
|
||||
self._openai_assistant = self._openai_client.beta.assistants.update(
|
||||
self._openai_assistant = update_gpt_assistant(
|
||||
self._openai_client,
|
||||
assistant_id=openai_assistant_id,
|
||||
tools=openai_assistant_cfg.get("tools", []),
|
||||
assistant_config={
|
||||
"tools": specified_tools,
|
||||
"tool_resources": openai_assistant_cfg.get("tool_resources", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
|
||||
@ -198,6 +205,8 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
assistant_thread = self._openai_threads[sender]
|
||||
# Process each unread message
|
||||
for message in pending_messages:
|
||||
if message["content"].strip() == "":
|
||||
continue
|
||||
self._openai_client.beta.threads.messages.create(
|
||||
thread_id=assistant_thread.id,
|
||||
content=message["content"],
|
||||
@ -426,22 +435,23 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
logger.warning("Permanently deleting assistant...")
|
||||
self._openai_client.beta.assistants.delete(self.assistant_id)
|
||||
|
||||
def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
|
||||
def find_matching_assistant(self, candidate_assistants, instructions, tools):
|
||||
"""
|
||||
Find the matching assistant from a list of candidate assistants.
|
||||
Filter out candidates with the same name but different instructions, file IDs, and function names.
|
||||
TODO: implement accurate match based on assistant metadata fields.
|
||||
Filter out candidates with the same name but different instructions, and function names.
|
||||
"""
|
||||
matching_assistants = []
|
||||
|
||||
# Preprocess the required tools for faster comparison
|
||||
required_tool_types = set(tool.get("type") for tool in tools)
|
||||
required_tool_types = set(
|
||||
"file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
|
||||
)
|
||||
|
||||
required_function_names = set(
|
||||
tool.get("function", {}).get("name")
|
||||
for tool in tools
|
||||
if tool.get("type") not in ["code_interpreter", "retrieval"]
|
||||
if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
|
||||
)
|
||||
required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison
|
||||
|
||||
for assistant in candidate_assistants:
|
||||
# Check if instructions are similar
|
||||
@ -454,11 +464,12 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
continue
|
||||
|
||||
# Preprocess the assistant's tools
|
||||
assistant_tool_types = set(tool.type for tool in assistant.tools)
|
||||
assistant_tool_types = set(
|
||||
"file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
|
||||
)
|
||||
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
|
||||
assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison
|
||||
|
||||
# Check if the tool types, function names, and file IDs match
|
||||
# Check if the tool types, function names match
|
||||
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
|
||||
logger.warning(
|
||||
"tools not match, skip assistant(%s): tools %s, functions %s",
|
||||
@ -467,9 +478,6 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
assistant_function_names,
|
||||
)
|
||||
continue
|
||||
if required_file_ids != assistant_file_ids:
|
||||
logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
|
||||
continue
|
||||
|
||||
# Append assistant to matching list if all conditions are met
|
||||
matching_assistants.append(assistant)
|
||||
@ -496,7 +504,7 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
|
||||
# Move the assistant related configurations to assistant_config
|
||||
# It's important to keep forward compatibility
|
||||
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
|
||||
assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
|
||||
for item in assistant_config_items:
|
||||
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
|
||||
openai_assistant_cfg[item] = openai_client_cfg[item]
|
||||
|
||||
@ -1,14 +1,17 @@
|
||||
import importlib.metadata
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from packaging.version import parse
|
||||
|
||||
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
|
||||
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
|
||||
@ -675,3 +678,103 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
|
||||
if assistant.name == name:
|
||||
candidate_assistants.append(assistant)
|
||||
return candidate_assistants
|
||||
|
||||
|
||||
def detect_gpt_assistant_api_version() -> str:
    """Report which OpenAI Assistants API generation the installed SDK targets.

    Returns:
        "v1" when the installed `openai` distribution is older than 1.21,
        otherwise "v2".
    """
    installed_version = parse(importlib.metadata.version("openai"))
    first_v2_release = parse("1.21")
    return "v1" if installed_version < first_v2_release else "v2"
|
||||
|
||||
|
||||
def create_gpt_vector_store(client: OpenAI, name: str, fild_ids: List[str]) -> Any:
    """Create an OpenAI vector store and load the given files into it.

    Args:
        client: OpenAI client used for the vector store calls.
        name: name given to the new vector store.
        fild_ids: ids of the files added to the store as a single batch.

    Returns:
        The object returned by `client.beta.vector_stores.create`.

    Raises:
        ValueError: if the file batch does not end with status "completed".
    """
    store = client.beta.vector_stores.create(name=name)

    # Upload all files as one batch; the helper polls the batch before returning.
    file_batch = client.beta.vector_stores.file_batches.create_and_poll(
        vector_store_id=store.id, file_ids=fild_ids
    )

    # If the batch is still reported as running, wait briefly and poll one more time.
    if file_batch.status == "in_progress":
        time.sleep(1)
        logging.debug(f"file batch status: {file_batch.file_counts}")
        file_batch = client.beta.vector_stores.file_batches.poll(
            vector_store_id=store.id, batch_id=file_batch.id
        )

    if file_batch.status != "completed":
        raise ValueError(f"Failed to upload files to vector store {store.id}:{file_batch.status}")

    return store
|
||||
|
||||
|
||||
def create_gpt_assistant(
    client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
) -> Assistant:
    """Create an OpenAI assistant, adapting the config to the installed Assistants API version.

    On the V2 API, a V1-style config is translated for backwards compatibility:
    `retrieval` tools are renamed to `file_search`, and `file_ids` are attached
    through `tool_resources` (a new vector store for file search, a plain file
    list for the code interpreter). On the V1 API the config is passed through,
    and V2-only options are rejected.

    The caller's `assistant_config` (including its `tools` entries and its
    `tool_resources` dict) is never modified; the translation works on copies.

    Args:
        client: OpenAI client used to create the assistant (and vector store, if needed).
        name: assistant name; also used to name the vector store created for file search.
        instructions: system instructions for the assistant.
        model: model name for the assistant.
        assistant_config: assistant options; the keys read here are `tools`,
            `file_ids` and `tool_resources`.

    Returns:
        The assistant object returned by `client.beta.assistants.create`.

    Raises:
        ValueError: on V2, if both `tool_resources['file_search']` and `file_ids`
            are given; on V1, if `tool_resources` or a `file_search` tool is given.
    """
    assistant_create_kwargs = {}
    gpt_assistant_api_version = detect_gpt_assistant_api_version()
    # Copy every tool spec so the retrieval -> file_search rename below cannot leak
    # into the caller's config. `or []` also tolerates an explicit `"tools": None`.
    tools = [dict(tool) for tool in assistant_config.get("tools") or []]

    if gpt_assistant_api_version == "v2":
        # Copy for the same reason as `tools`; `or {}` tolerates an explicit None value.
        tool_resources = dict(assistant_config.get("tool_resources") or {})
        file_ids = assistant_config.get("file_ids")
        if tool_resources.get("file_search") is not None and file_ids is not None:
            raise ValueError(
                "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
            )

        # Designed for backwards compatibility for the V1 API
        # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
        for tool in tools:
            if tool["type"] == "retrieval":
                tool["type"] = "file_search"
                if file_ids is not None:
                    # create a vector store for the file search tool
                    vs = create_gpt_vector_store(client, f"{name}-vectorestore", file_ids)
                    tool_resources["file_search"] = {
                        "vector_store_ids": [vs.id],
                    }
            elif tool["type"] == "code_interpreter" and file_ids is not None:
                tool_resources["code_interpreter"] = {
                    "file_ids": file_ids,
                }

        assistant_create_kwargs["tools"] = tools
        # Only send tool_resources when there is something to attach.
        if len(tool_resources) > 0:
            assistant_create_kwargs["tool_resources"] = tool_resources
    else:
        # not support forwards compatibility
        if "tool_resources" in assistant_config:
            raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
        if any(tool["type"] == "file_search" for tool in tools):
            raise ValueError(
                "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
            )
        assistant_create_kwargs["tools"] = tools
        assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])

    logging.info("Creating assistant with config: %s", assistant_create_kwargs)
    return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
|
||||
|
||||
|
||||
def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
    """Update an existing OpenAI assistant with the options present in the config.

    Only options whose value is not None are forwarded. `tools` and
    `instructions` are considered on every API version; `tool_resources` is
    considered only on V2 and `file_ids` only on V1.

    Args:
        client: OpenAI client used to perform the update.
        assistant_id: id of the assistant to update.
        assistant_config: options to apply; the keys read here are `tools`,
            `instructions`, and `tool_resources` (V2) or `file_ids` (V1).

    Returns:
        The assistant object returned by `client.beta.assistants.update`.
    """
    api_version = detect_gpt_assistant_api_version()

    # The attachment option differs between API generations.
    attachment_key = "tool_resources" if api_version == "v2" else "file_ids"

    update_kwargs = {
        key: assistant_config[key]
        for key in ("tools", "instructions", attachment_key)
        if assistant_config.get(key) is not None
    }

    return client.beta.assistants.update(assistant_id=assistant_id, **update_kwargs)
|
||||
|
||||
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ with open(os.path.join(here, "autogen/version.py")) as fp:
|
||||
__version__ = version["__version__"]
|
||||
|
||||
install_requires = [
|
||||
"openai>=1.3,<1.21",
|
||||
"openai>=1.3",
|
||||
"diskcache",
|
||||
"termcolor",
|
||||
"flaml",
|
||||
|
||||
@ -11,7 +11,7 @@ import pytest
|
||||
import autogen
|
||||
from autogen import OpenAIWrapper, UserProxyAgent
|
||||
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
|
||||
from autogen.oai.openai_utils import retrieve_assistants_by_name
|
||||
from autogen.oai.openai_utils import detect_gpt_assistant_api_version, retrieve_assistants_by_name
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
from conftest import reason, skip_openai # noqa: E402
|
||||
@ -264,6 +264,7 @@ def test_get_assistant_files() -> None:
|
||||
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
|
||||
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
|
||||
name = f"For test_get_assistant_files {uuid.uuid4()}"
|
||||
gpt_assistant_api_version = detect_gpt_assistant_api_version()
|
||||
|
||||
# keep it to test older version of assistant config
|
||||
assistant = GPTAssistantAgent(
|
||||
@ -277,10 +278,17 @@ def test_get_assistant_files() -> None:
|
||||
)
|
||||
|
||||
try:
|
||||
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
|
||||
retrieved_file_ids = [fild.id for fild in files]
|
||||
if gpt_assistant_api_version == "v1":
|
||||
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
|
||||
retrieved_file_ids = [fild.id for fild in files]
|
||||
elif gpt_assistant_api_version == "v2":
|
||||
oas_assistant = assistant.openai_client.beta.assistants.retrieve(assistant_id=assistant.assistant_id)
|
||||
vectorstore_ids = oas_assistant.tool_resources.file_search.vector_store_ids
|
||||
retrieved_file_ids = []
|
||||
for vectorstore_id in vectorstore_ids:
|
||||
files = assistant.openai_client.beta.vector_stores.files.list(vector_store_id=vectorstore_id)
|
||||
retrieved_file_ids.extend([fild.id for fild in files])
|
||||
expected_file_id = file.id
|
||||
|
||||
finally:
|
||||
assistant.delete_assistant()
|
||||
openai_client.files.delete(file.id)
|
||||
@ -401,7 +409,7 @@ def test_assistant_mismatch_retrieval() -> None:
|
||||
"tools": [
|
||||
{"type": "function", "function": function_1_schema},
|
||||
{"type": "function", "function": function_2_schema},
|
||||
{"type": "retrieval"},
|
||||
{"type": "file_search"},
|
||||
{"type": "code_interpreter"},
|
||||
],
|
||||
"file_ids": [file_1.id, file_2.id],
|
||||
@ -411,7 +419,6 @@ def test_assistant_mismatch_retrieval() -> None:
|
||||
name = f"For test_assistant_retrieval {uuid.uuid4()}"
|
||||
|
||||
assistant_first, assistant_instructions_mistaching = None, None
|
||||
assistant_file_ids_mismatch, assistant_tools_mistaching = None, None
|
||||
try:
|
||||
assistant_first = GPTAssistantAgent(
|
||||
name,
|
||||
@ -432,30 +439,11 @@ def test_assistant_mismatch_retrieval() -> None:
|
||||
)
|
||||
assert len(candidate_instructions_mistaching) == 2
|
||||
|
||||
# test mismatch fild ids
|
||||
file_ids_mismatch_llm_config = {
|
||||
"tools": [
|
||||
{"type": "code_interpreter"},
|
||||
{"type": "retrieval"},
|
||||
{"type": "function", "function": function_2_schema},
|
||||
{"type": "function", "function": function_1_schema},
|
||||
],
|
||||
"file_ids": [file_2.id],
|
||||
"config_list": openai_config_list,
|
||||
}
|
||||
assistant_file_ids_mismatch = GPTAssistantAgent(
|
||||
name,
|
||||
instructions="This is a test",
|
||||
llm_config=file_ids_mismatch_llm_config,
|
||||
)
|
||||
candidate_file_ids_mismatch = retrieve_assistants_by_name(assistant_file_ids_mismatch.openai_client, name)
|
||||
assert len(candidate_file_ids_mismatch) == 3
|
||||
|
||||
# test tools mismatch
|
||||
tools_mismatch_llm_config = {
|
||||
"tools": [
|
||||
{"type": "code_interpreter"},
|
||||
{"type": "retrieval"},
|
||||
{"type": "file_search"},
|
||||
{"type": "function", "function": function_3_schema},
|
||||
],
|
||||
"file_ids": [file_2.id, file_1.id],
|
||||
@ -467,15 +455,13 @@ def test_assistant_mismatch_retrieval() -> None:
|
||||
llm_config=tools_mismatch_llm_config,
|
||||
)
|
||||
candidate_tools_mismatch = retrieve_assistants_by_name(assistant_tools_mistaching.openai_client, name)
|
||||
assert len(candidate_tools_mismatch) == 4
|
||||
assert len(candidate_tools_mismatch) == 3
|
||||
|
||||
finally:
|
||||
if assistant_first:
|
||||
assistant_first.delete_assistant()
|
||||
if assistant_instructions_mistaching:
|
||||
assistant_instructions_mistaching.delete_assistant()
|
||||
if assistant_file_ids_mismatch:
|
||||
assistant_file_ids_mismatch.delete_assistant()
|
||||
if assistant_tools_mistaching:
|
||||
assistant_tools_mistaching.delete_assistant()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user