Bring back OpenAIAssistantAgent (#6867)

This commit is contained in:
Eric Zhu 2025-07-28 01:29:06 -07:00 committed by GitHub
parent d6ec7b85e3
commit 7d627f45ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 28 additions and 106 deletions

View File

@ -30,7 +30,7 @@ azure = [
]
docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
openai = ["openai>=1.66.5", "tiktoken>=0.8.0", "aiofiles"]
openai = ["openai>=1.93", "tiktoken>=0.8.0", "aiofiles"]
file-surfer = [
"autogen-agentchat==0.7.0",
"magika>=0.6.1rc2",
@ -177,7 +177,7 @@ exclude = ["src/autogen_ext/agents/web_surfer/*.js", "src/autogen_ext/runtimes/g
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests"]
exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos", "src/autogen_ext/agents/openai/_openai_assistant_agent.py", "tests/test_openai_assistant_agent.py"]
exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos"]
[tool.pytest.ini_options]
minversion = "6.0"
@ -197,7 +197,7 @@ test.sequence = [
test.default_item_type = "cmd"
test-grpc = "pytest -n 1 --cov=src --cov-report=term-missing --cov-report=xml --grpc"
test-windows = "pytest -n 1 --cov=src --cov-report=term-missing --cov-report=xml -m 'windows'"
mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos --exclude src/autogen_ext/agents/openai/_openai_assistant_agent.py --exclude tests/test_openai_assistant_agent.py --ignore-missing-imports src tests"
mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos --ignore-missing-imports src tests"
[tool.mypy]
[[tool.mypy.overrides]]

View File

@ -1,39 +1,7 @@
try:
from ._openai_agent import OpenAIAgent
from ._openai_agent import OpenAIAgent
from ._openai_assistant_agent import OpenAIAssistantAgent
# Check OpenAI version to conditionally import OpenAIAssistantAgent
try:
from openai import __version__ as openai_version
def _parse_openai_version(version_str: str) -> tuple[int, int, int]:
"""Parse a semantic version string into a tuple of integers."""
try:
parts = version_str.split(".")
major = int(parts[0])
minor = int(parts[1]) if len(parts) > 1 else 0
patch = int(parts[2].split("-")[0]) if len(parts) > 2 else 0 # Handle pre-release versions
return (major, minor, patch)
except (ValueError, IndexError):
# If version parsing fails, assume it's a newer version
return (999, 999, 999)
_current_version = _parse_openai_version(openai_version)
_target_version = (1, 83, 0)
# Only import OpenAIAssistantAgent if OpenAI version is less than 1.83
if _current_version < _target_version:
from ._openai_assistant_agent import OpenAIAssistantAgent # type: ignore[import]
__all__ = ["OpenAIAssistantAgent", "OpenAIAgent"]
else:
__all__ = ["OpenAIAgent"]
except ImportError:
# If OpenAI is not available, skip OpenAIAssistantAgent import
__all__ = ["OpenAIAgent"]
except ImportError as e:
raise ImportError(
"Dependencies for OpenAI agents not found. "
'Please install autogen-ext with the "openai" extra: '
'pip install "autogen-ext[openai]"'
) from e
__all__ = [
"OpenAIAgent",
"OpenAIAssistantAgent",
]

View File

@ -1,17 +1,7 @@
"""
OpenAI Assistant Agent implementation.
This module is deprecated starting v0.7.0 and will be removed in a future version.
"""
# pyright: ignore
# mypy: ignore-errors
import asyncio
import json
import logging
import os
import warnings
from functools import wraps
from typing import (
Any,
AsyncGenerator,
@ -67,33 +57,9 @@ from openai.types.beta.threads.text_content_block_param import TextContentBlockP
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.vector_store import VectorStore
# Deprecation warning
warnings.warn(
"The OpenAIAssistantAgent module is deprecated and will be removed in a future version, use OpenAIAgent instead.",
DeprecationWarning,
stacklevel=2,
)
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
def deprecated_class(reason: str) -> Callable[[type], type]:
"""Decorator to mark a class as deprecated."""
def decorator(cls: type) -> type:
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs) -> None:
warnings.warn(f"{cls.__name__} is deprecated: {reason}", DeprecationWarning, stacklevel=2)
original_init(self, *args, **kwargs)
cls.__init__ = new_init
return cls
return decorator
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
@ -124,9 +90,6 @@ class OpenAIAssistantAgentState(BaseModel):
uploaded_file_ids: List[str] = Field(default_factory=list)
@deprecated_class(
"This class is deprecated starting v0.7.0 and will be removed in a future version. Use OpenAIAgent instead."
)
class OpenAIAssistantAgent(BaseChatAgent):
"""An agent implementation that uses the Assistant API to generate responses.
@ -353,9 +316,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
"""Ensure assistant and thread are created."""
if self._assistant is None:
if self._assistant_id:
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) # type: ignore[reportDeprecated]
else:
self._assistant = await self._client.beta.assistants.create(
self._assistant = await self._client.beta.assistants.create( # type: ignore[reportDeprecated]
model=self._model,
description=self.description,
instructions=self._instructions,
@ -369,9 +332,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
if self._thread is None:
if self._init_thread_id:
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) # type: ignore[reportDeprecated]
else:
self._thread = await self._client.beta.threads.create()
self._thread = await self._client.beta.threads.create() # type: ignore[reportDeprecated]
# Retrieve initial state only once
if not self._initial_state_retrieved:
@ -384,7 +347,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
initial_message_ids: Set[str] = set()
after: str | NotGiven = NOT_GIVEN
while True:
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list( # type: ignore[reportDeprecated]
self._thread_id, after=after, order="asc", limit=100
)
for msg in msgs.data:
@ -458,7 +421,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
# Create and start a run
run: Run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.create(
self._client.beta.threads.runs.create( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
assistant_id=self._get_assistant_id,
)
@ -469,7 +432,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
while True:
run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.retrieve(
self._client.beta.threads.runs.retrieve( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
run_id=run.id,
)
@ -522,7 +485,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
# Submit tool outputs back to the run
run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.submit_tool_outputs(
self._client.beta.threads.runs.submit_tool_outputs( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
run_id=run.id,
tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
@ -539,7 +502,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
# Get messages after run completion
assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)
self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) # type: ignore[reportDeprecated]
)
)
@ -577,7 +540,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
raise ValueError(f"Unsupported content type: {type(c)} in {message}")
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.create(
self._client.beta.threads.messages.create( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
content=content,
role="user",
@ -595,7 +558,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
while True:
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100) # type: ignore[reportDeprecated]
)
)
for msg in msgs.data:
@ -609,7 +572,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
for msg_id in new_message_ids:
status: MessageDeleted = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) # type: ignore[reportDeprecated]
)
)
assert status.deleted is True
@ -645,7 +608,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
# Update thread with the new files
thread = await cancellation_token.link_future(
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) # type: ignore[reportDeprecated]
)
tool_resources: ToolResources = thread.tool_resources or ToolResources()
code_interpreter: ToolResourcesCodeInterpreter = (
@ -657,7 +620,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.update(
self._client.beta.threads.update( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()),
)
@ -720,7 +683,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
if self._assistant is not None and not self._assistant_id:
try:
await cancellation_token.link_future(
asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))
asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)) # type: ignore[reportDeprecated]
)
self._assistant = None
except Exception as e:

View File

@ -11,20 +11,11 @@ import pytest
from autogen_agentchat.messages import BaseChatMessage, TextMessage, ToolCallRequestEvent
from autogen_core import CancellationToken
from autogen_core.tools._base import BaseTool, Tool
from autogen_ext.agents.openai import OpenAIAssistantAgent
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai import __version__ as openai_version
from pydantic import BaseModel
# Only import the OpenAIAssistantAgent if the version is compatible.
try:
from autogen_ext.agents.openai import OpenAIAssistantAgent
except ImportError:
pytest.skip(
"OpenAIAssistantAgent not available. Skipping all tests in this module.",
allow_module_level=True,
)
class QuestionType(str, Enum):
MULTIPLE_CHOICE = "MULTIPLE_CHOICE"
@ -345,8 +336,8 @@ async def test_quiz_creation(
@pytest.mark.asyncio
async def test_on_reset_behavior(client: AsyncOpenAI, cancellation_token: CancellationToken) -> None:
# Arrange: Use the default behavior for reset.
thread = await client.beta.threads.create()
await client.beta.threads.messages.create(
thread = await client.beta.threads.create() # type: ignore[reportDeprecated]
await client.beta.threads.messages.create( # type: ignore[reportDeprecated]
thread_id=thread.id,
content="Hi, my name is John and I'm a software engineer. Use this information to help me.",
role="user",

2
python/uv.lock generated
View File

@ -777,7 +777,7 @@ requires-dist = [
{ name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" },
{ name = "neo4j", marker = "extra == 'mem0-local'", specifier = ">=5.25.0" },
{ name = "ollama", marker = "extra == 'ollama'", specifier = ">=0.4.7" },
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.66.5" },
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.93" },
{ name = "openai-whisper", marker = "extra == 'video-surfer'" },
{ name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
{ name = "pillow", marker = "extra == 'magentic-one'", specifier = ">=11.0.0" },