Update OpenAI Assistant Agent on_reset (#4423)

* update reset and imports

* add missing imports

* update lock file

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
Leonardo Pinheiro 2024-12-02 15:48:18 +10:00 committed by GitHub
parent 7eb8b4645b
commit 1f90dc5ea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 144 additions and 51 deletions

View File

@ -20,13 +20,10 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
langchain-tools = ["langchain_core~= 0.3.3"]
azure-code-executor = ["azure-core"]
docker-code-executor = ["docker~=7.0"]
langchain = ["langchain_core~= 0.3.3"] langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"] azure = ["azure-core", "azure-identity"]
docker = ["docker~=7.0"] docker = ["docker~=7.0"]
openai = ["openai>=1.3"] openai = ["openai>=1.3", "aiofiles"]
web-surfer = [ web-surfer = [
"playwright>=1.48.0", "playwright>=1.48.0",
"pillow>=11.0.0", "pillow>=11.0.0",

View File

@ -3,6 +3,7 @@ import json
import logging import logging
import os import os
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
AsyncGenerator, AsyncGenerator,
Awaitable, Awaitable,
@ -13,11 +14,11 @@ from typing import (
Literal, Literal,
Optional, Optional,
Sequence, Sequence,
Set,
Union, Union,
cast, cast,
) )
import aiofiles
from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response from autogen_agentchat.base import Response
@ -35,6 +36,21 @@ from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall from autogen_core.components import FunctionCall
from autogen_core.components.models._types import FunctionExecutionResult from autogen_core.components.models._types import FunctionExecutionResult
from autogen_core.components.tools import FunctionTool, Tool from autogen_core.components.tools import FunctionTool, Tool
_has_openai_dependencies: bool = True
try:
import aiofiles
from openai import NOT_GIVEN
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
from openai.types.beta.file_search_tool_param import FileSearchToolParam
from openai.types.beta.function_tool_param import FunctionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition
except ImportError:
_has_openai_dependencies = False
if TYPE_CHECKING:
import aiofiles
from openai import NOT_GIVEN, AsyncClient, NotGiven from openai import NOT_GIVEN, AsyncClient, NotGiven
from openai.pagination import AsyncCursorPage from openai.pagination import AsyncCursorPage
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
@ -54,8 +70,13 @@ from openai.types.shared_params.function_definition import FunctionDefinition
event_logger = logging.getLogger(EVENT_LOGGER_NAME) event_logger = logging.getLogger(EVENT_LOGGER_NAME)
def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" """Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
if not _has_openai_dependencies:
raise RuntimeError(
"Missing dependencies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
)
schema = tool.schema schema = tool.schema
parameters: Dict[str, object] = {} parameters: Dict[str, object] = {}
if "parameters" in schema: if "parameters" in schema:
@ -160,7 +181,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
self, self,
name: str, name: str,
description: str, description: str,
client: AsyncClient, client: "AsyncClient",
model: str, model: str,
instructions: str, instructions: str,
tools: Optional[ tools: Optional[
@ -174,18 +195,23 @@ class OpenAIAssistantAgent(BaseChatAgent):
assistant_id: Optional[str] = None, assistant_id: Optional[str] = None,
thread_id: Optional[str] = None, thread_id: Optional[str] = None,
metadata: Optional[object] = None, metadata: Optional[object] = None,
response_format: Optional[AssistantResponseFormatOptionParam] = None, response_format: Optional["AssistantResponseFormatOptionParam"] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
tool_resources: Optional[ToolResources] = None, tool_resources: Optional["ToolResources"] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
) -> None: ) -> None:
if not _has_openai_dependencies:
raise RuntimeError(
"Missing dependencies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
)
super().__init__(name, description) super().__init__(name, description)
if tools is None: if tools is None:
tools = [] tools = []
# Store original tools and converted tools separately # Store original tools and converted tools separately
self._original_tools: List[Tool] = [] self._original_tools: List[Tool] = []
converted_tools: List[AssistantToolParam] = [] converted_tools: List["AssistantToolParam"] = []
for tool in tools: for tool in tools:
if isinstance(tool, str): if isinstance(tool, str):
if tool == "code_interpreter": if tool == "code_interpreter":
@ -207,8 +233,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
raise ValueError(f"Unsupported tool type: {type(tool)}") raise ValueError(f"Unsupported tool type: {type(tool)}")
self._client = client self._client = client
self._assistant: Optional[Assistant] = None self._assistant: Optional["Assistant"] = None
self._thread: Optional[Thread] = None self._thread: Optional["Thread"] = None
self._init_thread_id = thread_id self._init_thread_id = thread_id
self._model = model self._model = model
self._instructions = instructions self._instructions = instructions
@ -222,6 +248,10 @@ class OpenAIAssistantAgent(BaseChatAgent):
self._vector_store_id: Optional[str] = None self._vector_store_id: Optional[str] = None
self._uploaded_file_ids: List[str] = [] self._uploaded_file_ids: List[str] = []
# Variables to track initial state
self._initial_message_ids: Set[str] = set()
self._initial_state_retrieved: bool = False
async def _ensure_initialized(self) -> None: async def _ensure_initialized(self) -> None:
"""Ensure assistant and thread are created.""" """Ensure assistant and thread are created."""
if self._assistant is None: if self._assistant is None:
@ -246,6 +276,27 @@ class OpenAIAssistantAgent(BaseChatAgent):
else: else:
self._thread = await self._client.beta.threads.create() self._thread = await self._client.beta.threads.create()
# Retrieve initial state only once
if not self._initial_state_retrieved:
await self._retrieve_initial_state()
self._initial_state_retrieved = True
async def _retrieve_initial_state(self) -> None:
"""Retrieve and store the initial state of messages and runs."""
# Retrieve all initial message IDs
initial_message_ids: Set[str] = set()
after: str | NotGiven = NOT_GIVEN
while True:
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(
self._thread_id, after=after, order="asc", limit=100
)
for msg in msgs.data:
initial_message_ids.add(msg.id)
if not msgs.has_next_page():
break
after = msgs.data[-1].id
self._initial_message_ids = initial_message_ids
@property @property
def produced_message_types(self) -> List[type[ChatMessage]]: def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces.""" """The types of messages that the assistant agent produces."""
@ -291,6 +342,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response.""" """Handle incoming messages and return a response."""
async for message in self.on_messages_stream(messages, cancellation_token): async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response): if isinstance(message, Response):
return message return message
@ -421,22 +473,27 @@ class OpenAIAssistantAgent(BaseChatAgent):
) )
async def on_reset(self, cancellation_token: CancellationToken) -> None: async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Handle reset command by deleting all messages in the thread.""" """Handle reset command by deleting new messages and runs since initialization."""
await self._ensure_initialized()
# Retrieve all message IDs in the thread # Retrieve all message IDs in the thread
all_msgs: List[str] = [] new_message_ids: List[str] = []
after: str | NotGiven = NOT_GIVEN after: str | NotGiven = NOT_GIVEN
while True: while True:
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future( msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after)) asyncio.ensure_future(
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)
)
) )
for msg in msgs.data: for msg in msgs.data:
all_msgs.append(msg.id) if msg.id not in self._initial_message_ids:
after = msg.id new_message_ids.append(msg.id)
if not msgs.has_next_page(): if not msgs.has_next_page():
break break
after = msgs.data[-1].id
# Delete all messages # Delete new messages
for msg_id in all_msgs: for msg_id in new_message_ids:
status: MessageDeleted = await cancellation_token.link_future( status: MessageDeleted = await cancellation_token.link_future(
asyncio.ensure_future( asyncio.ensure_future(
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)

View File

@ -7,6 +7,7 @@ from autogen_agentchat.messages import TextMessage
from autogen_core.base import CancellationToken from autogen_core.base import CancellationToken
from autogen_core.components.tools._base import BaseTool, Tool from autogen_core.components.tools._base import BaseTool, Tool
from autogen_ext.agents import OpenAIAssistantAgent from autogen_ext.agents import OpenAIAssistantAgent
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
from pydantic import BaseModel from pydantic import BaseModel
@ -62,11 +63,22 @@ def client() -> AsyncAzureOpenAI:
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview") api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
api_key = os.getenv("AZURE_OPENAI_API_KEY") api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not all([azure_endpoint, api_key]): if not azure_endpoint:
pytest.skip("Azure OpenAI credentials not found in environment variables") pytest.skip("Azure OpenAI endpoint not found in environment variables")
assert azure_endpoint is not None # Try Azure CLI credentials if API key not provided
assert api_key is not None if not api_key:
try:
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
return AsyncAzureOpenAI(
azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider
)
except Exception:
pytest.skip("Failed to get Azure CLI credentials and no API key provided")
# Fall back to API key auth if provided
return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key) return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
@ -138,3 +150,41 @@ async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: Ca
assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content")) assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content"))
await agent.delete_assistant(cancellation_token) await agent.delete_assistant(cancellation_token)
@pytest.mark.asyncio
async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: CancellationToken) -> None:
# Create thread with initial message
thread = await client.beta.threads.create()
await client.beta.threads.messages.create(
thread_id=thread.id,
content="Hi, my name is John and I'm a software engineer. Use this information to help me.",
role="user",
)
# Create agent with existing thread
agent = OpenAIAssistantAgent(
name="assistant",
instructions="Help the user with their task.",
model="gpt-4o-mini",
description="OpenAI Assistant Agent",
client=client,
thread_id=thread.id,
)
# Test before reset
message1 = TextMessage(source="user", content="What is my name?")
response1 = await agent.on_messages([message1], cancellation_token)
assert isinstance(response1.chat_message.content, str)
assert "john" in response1.chat_message.content.lower()
# Reset agent state
await agent.on_reset(cancellation_token)
# Test after reset
message2 = TextMessage(source="user", content="What is my name?")
response2 = await agent.on_messages([message2], cancellation_token)
assert isinstance(response2.chat_message.content, str)
assert "john" in response2.chat_message.content.lower()
await agent.delete_assistant(cancellation_token)

15
python/uv.lock generated
View File

@ -31,7 +31,6 @@ resolution-markers = [
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'", "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'",
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')", "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')", "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_version < '0'",
] ]
[manifest] [manifest]
@ -487,22 +486,14 @@ azure = [
{ name = "azure-core" }, { name = "azure-core" },
{ name = "azure-identity" }, { name = "azure-identity" },
] ]
azure-code-executor = [
{ name = "azure-core" },
]
docker = [ docker = [
{ name = "docker" }, { name = "docker" },
] ]
docker-code-executor = [
{ name = "docker" },
]
langchain = [ langchain = [
{ name = "langchain-core" }, { name = "langchain-core" },
] ]
langchain-tools = [
{ name = "langchain-core" },
]
openai = [ openai = [
{ name = "aiofiles" },
{ name = "openai" }, { name = "openai" },
] ]
video-surfer = [ video-surfer = [
@ -518,16 +509,14 @@ web-surfer = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiofiles", marker = "extra == 'openai'" },
{ name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" }, { name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" },
{ name = "autogen-core", editable = "packages/autogen-core" }, { name = "autogen-core", editable = "packages/autogen-core" },
{ name = "azure-core", marker = "extra == 'azure'" }, { name = "azure-core", marker = "extra == 'azure'" },
{ name = "azure-core", marker = "extra == 'azure-code-executor'" },
{ name = "azure-identity", marker = "extra == 'azure'" }, { name = "azure-identity", marker = "extra == 'azure'" },
{ name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" }, { name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" },
{ name = "docker", marker = "extra == 'docker-code-executor'", specifier = "~=7.0" },
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" }, { name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
{ name = "langchain-core", marker = "extra == 'langchain-tools'", specifier = "~=0.3.3" },
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" },
{ name = "openai-whisper", marker = "extra == 'video-surfer'" }, { name = "openai-whisper", marker = "extra == 'video-surfer'" },
{ name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" }, { name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },