Mirror of https://github.com/microsoft/autogen.git (synced 2025-12-12 23:41:28 +00:00)
Bring back OpenAIAssistantAgent (#6867)
commit 7d627f45ca
parent d6ec7b85e3
@@ -30,7 +30,7 @@ azure = [
 ]
 docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
 ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
-openai = ["openai>=1.66.5", "tiktoken>=0.8.0", "aiofiles"]
+openai = ["openai>=1.93", "tiktoken>=0.8.0", "aiofiles"]
 file-surfer = [
     "autogen-agentchat==0.7.0",
     "magika>=0.6.1rc2",
@@ -177,7 +177,7 @@ exclude = ["src/autogen_ext/agents/web_surfer/*.js", "src/autogen_ext/runtimes/g
 [tool.pyright]
 extends = "../../pyproject.toml"
 include = ["src", "tests"]
-exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos", "src/autogen_ext/agents/openai/_openai_assistant_agent.py", "tests/test_openai_assistant_agent.py"]
+exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos"]

 [tool.pytest.ini_options]
 minversion = "6.0"
@@ -197,7 +197,7 @@ test.sequence = [
 test.default_item_type = "cmd"
 test-grpc = "pytest -n 1 --cov=src --cov-report=term-missing --cov-report=xml --grpc"
 test-windows = "pytest -n 1 --cov=src --cov-report=term-missing --cov-report=xml -m 'windows'"
-mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos --exclude src/autogen_ext/agents/openai/_openai_assistant_agent.py --exclude tests/test_openai_assistant_agent.py --ignore-missing-imports src tests"
+mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos --ignore-missing-imports src tests"

 [tool.mypy]
 [[tool.mypy.overrides]]
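
The two pyproject.toml hunks above drop the blanket pyright and mypy excludes for the assistant-agent module and its tests; the code hunks below compensate by suppressing the deprecation diagnostics at each call site instead. As a rough illustration only (the function below is hypothetical, not from this commit), the per-line suppression pattern looks like this:

from typing_extensions import deprecated


@deprecated("Use a newer client helper instead")
def legacy_create_thread() -> str:
    """Stand-in for a deprecated SDK call (hypothetical, not a real API)."""
    return "thread_123"


# Instead of excluding the whole file from the type checkers, the commit
# annotates each deprecated call site, so the rest of the module stays checked:
thread_id = legacy_create_thread()  # type: ignore[reportDeprecated]
print(thread_id)

A per-line suppression keeps every other diagnostic in the file active, whereas a file-level exclude hides them all.
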
@@ -1,39 +1,7 @@
-try:
-    from ._openai_agent import OpenAIAgent
-
-    # Check OpenAI version to conditionally import OpenAIAssistantAgent
-    try:
-        from openai import __version__ as openai_version
-
-        def _parse_openai_version(version_str: str) -> tuple[int, int, int]:
-            """Parse a semantic version string into a tuple of integers."""
-            try:
-                parts = version_str.split(".")
-                major = int(parts[0])
-                minor = int(parts[1]) if len(parts) > 1 else 0
-                patch = int(parts[2].split("-")[0]) if len(parts) > 2 else 0  # Handle pre-release versions
-                return (major, minor, patch)
-            except (ValueError, IndexError):
-                # If version parsing fails, assume it's a newer version
-                return (999, 999, 999)
-
-        _current_version = _parse_openai_version(openai_version)
-        _target_version = (1, 83, 0)
-
-        # Only import OpenAIAssistantAgent if OpenAI version is less than 1.83
-        if _current_version < _target_version:
-            from ._openai_assistant_agent import OpenAIAssistantAgent  # type: ignore[import]
-
-            __all__ = ["OpenAIAssistantAgent", "OpenAIAgent"]
-        else:
-            __all__ = ["OpenAIAgent"]
-    except ImportError:
-        # If OpenAI is not available, skip OpenAIAssistantAgent import
-        __all__ = ["OpenAIAgent"]
-
-except ImportError as e:
-    raise ImportError(
-        "Dependencies for OpenAI agents not found. "
-        'Please install autogen-ext with the "openai" extra: '
-        'pip install "autogen-ext[openai]"'
-    ) from e
+from ._openai_agent import OpenAIAgent
+from ._openai_assistant_agent import OpenAIAssistantAgent
+
+__all__ = [
+    "OpenAIAgent",
+    "OpenAIAssistantAgent",
+]
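
For reference, the version gate removed above parsed the installed openai version by hand instead of depending on a packaging library. A standalone sketch of that helper (copied from the removed lines) and of the tuple comparison it was used for:

def _parse_openai_version(version_str: str) -> tuple[int, int, int]:
    """Parse a semantic version string into a tuple of integers."""
    try:
        parts = version_str.split(".")
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        patch = int(parts[2].split("-")[0]) if len(parts) > 2 else 0  # Handle pre-release versions
        return (major, minor, patch)
    except (ValueError, IndexError):
        # If version parsing fails, assume it's a newer version
        return (999, 999, 999)


# Tuple comparison gave the gating behavior the old __init__.py relied on:
assert _parse_openai_version("1.82.0") < (1, 83, 0)        # older SDK: agent was importable
assert not _parse_openai_version("1.93.1") < (1, 83, 0)    # newer SDK: agent was hidden
assert _parse_openai_version("2.0.0-beta.1") == (2, 0, 0)  # pre-release suffix is stripped
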
@@ -1,17 +1,7 @@
-"""
-OpenAI Assistant Agent implementation.
-
-This module is deprecated starting v0.7.0 and will be removed in a future version.
-"""
-# pyright: ignore
-# mypy: ignore-errors
-
 import asyncio
 import json
 import logging
 import os
-import warnings
-from functools import wraps
 from typing import (
     Any,
     AsyncGenerator,
@@ -67,33 +57,9 @@ from openai.types.beta.threads.text_content_block_param import TextContentBlockP
 from openai.types.shared_params.function_definition import FunctionDefinition
 from openai.types.vector_store import VectorStore

-# Deprecation warning
-warnings.warn(
-    "The OpenAIAssistantAgent module is deprecated and will be removed in a future version, use OpenAIAgent instead.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)


-def deprecated_class(reason: str) -> Callable[[type], type]:
-    """Decorator to mark a class as deprecated."""
-
-    def decorator(cls: type) -> type:
-        original_init = cls.__init__
-
-        @wraps(original_init)
-        def new_init(self, *args, **kwargs) -> None:
-            warnings.warn(f"{cls.__name__} is deprecated: {reason}", DeprecationWarning, stacklevel=2)
-            original_init(self, *args, **kwargs)
-
-        cls.__init__ = new_init
-        return cls
-
-    return decorator
-
-
 def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
     """Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
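
The deprecated_class helper removed in this hunk wrapped a class's __init__ so that instantiating the class emitted a DeprecationWarning. A minimal, self-contained sketch of that behavior (LegacyAgent is a made-up example class, not part of the codebase):

import warnings
from functools import wraps
from typing import Callable


def deprecated_class(reason: str) -> Callable[[type], type]:
    """Decorator to mark a class as deprecated (as removed by this commit)."""

    def decorator(cls: type) -> type:
        original_init = cls.__init__

        @wraps(original_init)
        def new_init(self, *args, **kwargs) -> None:
            warnings.warn(f"{cls.__name__} is deprecated: {reason}", DeprecationWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        cls.__init__ = new_init
        return cls

    return decorator


@deprecated_class("use OpenAIAgent instead")
class LegacyAgent:  # hypothetical example class
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LegacyAgent()
print(caught[0].message)  # LegacyAgent is deprecated: use OpenAIAgent instead
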
@@ -124,9 +90,6 @@ class OpenAIAssistantAgentState(BaseModel):
     uploaded_file_ids: List[str] = Field(default_factory=list)


-@deprecated_class(
-    "This class is deprecated starting v0.7.0 and will be removed in a future version. Use OpenAIAgent instead."
-)
 class OpenAIAssistantAgent(BaseChatAgent):
     """An agent implementation that uses the Assistant API to generate responses.
@@ -353,9 +316,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
         """Ensure assistant and thread are created."""
         if self._assistant is None:
             if self._assistant_id:
-                self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)
+                self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)  # type: ignore[reportDeprecated]
             else:
-                self._assistant = await self._client.beta.assistants.create(
+                self._assistant = await self._client.beta.assistants.create(  # type: ignore[reportDeprecated]
                     model=self._model,
                     description=self.description,
                     instructions=self._instructions,
@@ -369,9 +332,9 @@ class OpenAIAssistantAgent(BaseChatAgent):

         if self._thread is None:
             if self._init_thread_id:
-                self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)
+                self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)  # type: ignore[reportDeprecated]
             else:
-                self._thread = await self._client.beta.threads.create()
+                self._thread = await self._client.beta.threads.create()  # type: ignore[reportDeprecated]

         # Retrieve initial state only once
         if not self._initial_state_retrieved:
@@ -384,7 +347,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
             initial_message_ids: Set[str] = set()
             after: str | NotGiven = NOT_GIVEN
             while True:
-                msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(
+                msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(  # type: ignore[reportDeprecated]
                     self._thread_id, after=after, order="asc", limit=100
                 )
                 for msg in msgs.data:
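
The messages.list call patched above sits inside a cursor-style pagination loop: pages of up to 100 messages are fetched in ascending order, and the loop advances by passing the last seen id as the after cursor. A simplified standalone sketch of that idea, assuming an AsyncOpenAI client and a placeholder thread id, with the end-of-list check reduced to an empty-page test:

from openai import NOT_GIVEN, AsyncOpenAI, NotGiven


async def collect_message_ids(client: AsyncOpenAI, thread_id: str) -> set[str]:
    """Walk a thread's messages page by page and collect their ids."""
    message_ids: set[str] = set()
    after: str | NotGiven = NOT_GIVEN
    while True:
        # Deprecated beta endpoint; silenced per call site as in the patch above.
        page = await client.beta.threads.messages.list(  # type: ignore[reportDeprecated]
            thread_id, after=after, order="asc", limit=100
        )
        if not page.data:
            break
        for msg in page.data:
            message_ids.add(msg.id)
        after = page.data[-1].id  # the last id on this page becomes the next cursor
    return message_ids


# Usage (placeholder thread id; requires OPENAI_API_KEY in the environment):
# asyncio.run(collect_message_ids(AsyncOpenAI(), "thread_abc123"))
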
@@ -458,7 +421,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
         # Create and start a run
         run: Run = await cancellation_token.link_future(
             asyncio.ensure_future(
-                self._client.beta.threads.runs.create(
+                self._client.beta.threads.runs.create(  # type: ignore[reportDeprecated]
                     thread_id=self._thread_id,
                     assistant_id=self._get_assistant_id,
                 )
@@ -469,7 +432,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
         while True:
             run = await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.runs.retrieve(
+                    self._client.beta.threads.runs.retrieve(  # type: ignore[reportDeprecated]
                         thread_id=self._thread_id,
                         run_id=run.id,
                     )
@@ -522,7 +485,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
             # Submit tool outputs back to the run
             run = await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.runs.submit_tool_outputs(
+                    self._client.beta.threads.runs.submit_tool_outputs(  # type: ignore[reportDeprecated]
                         thread_id=self._thread_id,
                         run_id=run.id,
                         tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
@@ -539,7 +502,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
             # Get messages after run completion
             assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)
+                    self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)  # type: ignore[reportDeprecated]
                 )
             )

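
Every call patched in the hunks above is wrapped in the same cancellation pattern: the request coroutine is turned into a task with asyncio.ensure_future and linked to an autogen_core CancellationToken, so cancelling the token cancels the in-flight request. A minimal sketch with a stand-in coroutine (slow_call is hypothetical):

import asyncio

from autogen_core import CancellationToken


async def slow_call() -> str:
    """Stand-in for a long-running API request (hypothetical)."""
    await asyncio.sleep(10)
    return "done"


async def main() -> None:
    token = CancellationToken()
    # Schedule cancellation shortly after the call starts.
    asyncio.get_running_loop().call_later(0.1, token.cancel)
    try:
        result = await token.link_future(asyncio.ensure_future(slow_call()))
        print(result)
    except asyncio.CancelledError:
        print("request was cancelled")


asyncio.run(main())
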
@@ -577,7 +540,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
                 raise ValueError(f"Unsupported content type: {type(c)} in {message}")
             await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.messages.create(
+                    self._client.beta.threads.messages.create(  # type: ignore[reportDeprecated]
                         thread_id=self._thread_id,
                         content=content,
                         role="user",
@@ -595,7 +558,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
         while True:
             msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)
+                    self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)  # type: ignore[reportDeprecated]
                 )
             )
             for msg in msgs.data:
@@ -609,7 +572,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
         for msg_id in new_message_ids:
             status: MessageDeleted = await cancellation_token.link_future(
                 asyncio.ensure_future(
-                    self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
+                    self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)  # type: ignore[reportDeprecated]
                 )
             )
             assert status.deleted is True
@@ -645,7 +608,7 @@ class OpenAIAssistantAgent(BaseChatAgent):

         # Update thread with the new files
         thread = await cancellation_token.link_future(
-            asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))
+            asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))  # type: ignore[reportDeprecated]
         )
         tool_resources: ToolResources = thread.tool_resources or ToolResources()
         code_interpreter: ToolResourcesCodeInterpreter = (
@@ -657,7 +620,7 @@ class OpenAIAssistantAgent(BaseChatAgent):

         await cancellation_token.link_future(
             asyncio.ensure_future(
-                self._client.beta.threads.update(
+                self._client.beta.threads.update(  # type: ignore[reportDeprecated]
                     thread_id=self._thread_id,
                     tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()),
                 )
@@ -720,7 +683,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
         if self._assistant is not None and not self._assistant_id:
             try:
                 await cancellation_token.link_future(
-                    asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))
+                    asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))  # type: ignore[reportDeprecated]
                 )
                 self._assistant = None
             except Exception as e:
@@ -11,20 +11,11 @@ import pytest
 from autogen_agentchat.messages import BaseChatMessage, TextMessage, ToolCallRequestEvent
 from autogen_core import CancellationToken
 from autogen_core.tools._base import BaseTool, Tool
+from autogen_ext.agents.openai import OpenAIAssistantAgent
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai import __version__ as openai_version
 from pydantic import BaseModel

-# Only import the OpenAIAssistantAgent if the version is compatible.
-try:
-    from autogen_ext.agents.openai import OpenAIAssistantAgent
-except ImportError:
-    pytest.skip(
-        "OpenAIAssistantAgent not available. Skipping all tests in this module.",
-        allow_module_level=True,
-    )
-

 class QuestionType(str, Enum):
     MULTIPLE_CHOICE = "MULTIPLE_CHOICE"
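
The block removed above relied on pytest's module-level skip: if the guarded import failed, every test in the file was skipped rather than failing at collection time. With the import now unconditional the guard is unnecessary; for reference, the general pattern looks like this (the dependency name is a placeholder, not a real autogen extra):

import pytest

# Skip every test in this module when an optional dependency is missing.
try:
    import some_optional_dependency  # noqa: F401
except ImportError:
    pytest.skip(
        "some_optional_dependency not available. Skipping all tests in this module.",
        allow_module_level=True,
    )


def test_uses_optional_dependency() -> None:
    assert some_optional_dependency is not None
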
@@ -345,8 +336,8 @@ async def test_quiz_creation(
 @pytest.mark.asyncio
 async def test_on_reset_behavior(client: AsyncOpenAI, cancellation_token: CancellationToken) -> None:
     # Arrange: Use the default behavior for reset.
-    thread = await client.beta.threads.create()
-    await client.beta.threads.messages.create(
+    thread = await client.beta.threads.create()  # type: ignore[reportDeprecated]
+    await client.beta.threads.messages.create(  # type: ignore[reportDeprecated]
         thread_id=thread.id,
         content="Hi, my name is John and I'm a software engineer. Use this information to help me.",
         role="user",
python/uv.lock (generated, 2 changes)
@@ -777,7 +777,7 @@ requires-dist = [
     { name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" },
     { name = "neo4j", marker = "extra == 'mem0-local'", specifier = ">=5.25.0" },
     { name = "ollama", marker = "extra == 'ollama'", specifier = ">=0.4.7" },
-    { name = "openai", marker = "extra == 'openai'", specifier = ">=1.66.5" },
+    { name = "openai", marker = "extra == 'openai'", specifier = ">=1.93" },
     { name = "openai-whisper", marker = "extra == 'video-surfer'" },
     { name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
     { name = "pillow", marker = "extra == 'magentic-one'", specifier = ">=11.0.0" },