From 1f90dc5ea91a1eae2fce087b4e6feaab87d3dc92 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Mon, 2 Dec 2024 15:48:18 +1000 Subject: [PATCH] Update OpenAI Assistante Agent on_reset (#4423) * update reset and imports * add missing imports * update lock file --------- Co-authored-by: Leonardo Pinheiro --- python/packages/autogen-ext/pyproject.toml | 5 +- .../agents/_openai_assistant_agent.py | 117 +++++++++++++----- .../tests/test_openai_assistant_agent.py | 58 ++++++++- python/uv.lock | 15 +-- 4 files changed, 144 insertions(+), 51 deletions(-) diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index 9da90d577..3a25e1bf6 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -20,13 +20,10 @@ dependencies = [ [project.optional-dependencies] -langchain-tools = ["langchain_core~= 0.3.3"] -azure-code-executor = ["azure-core"] -docker-code-executor = ["docker~=7.0"] langchain = ["langchain_core~= 0.3.3"] azure = ["azure-core", "azure-identity"] docker = ["docker~=7.0"] -openai = ["openai>=1.3"] +openai = ["openai>=1.3", "aiofiles"] web-surfer = [ "playwright>=1.48.0", "pillow>=11.0.0", diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/_openai_assistant_agent.py b/python/packages/autogen-ext/src/autogen_ext/agents/_openai_assistant_agent.py index 7e1124728..c72f051ff 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/_openai_assistant_agent.py @@ -3,6 +3,7 @@ import json import logging import os from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, Awaitable, @@ -13,11 +14,11 @@ from typing import ( Literal, Optional, Sequence, + Set, Union, cast, ) -import aiofiles from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import BaseChatAgent from autogen_agentchat.base import Response @@ -35,27 +36,47 @@ from autogen_core.base 
import CancellationToken from autogen_core.components import FunctionCall from autogen_core.components.models._types import FunctionExecutionResult from autogen_core.components.tools import FunctionTool, Tool -from openai import NOT_GIVEN, AsyncClient, NotGiven -from openai.pagination import AsyncCursorPage -from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads -from openai.types import FileObject -from openai.types.beta import thread_update_params -from openai.types.beta.assistant import Assistant -from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam -from openai.types.beta.assistant_tool_param import AssistantToolParam -from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam -from openai.types.beta.file_search_tool_param import FileSearchToolParam -from openai.types.beta.function_tool_param import FunctionToolParam -from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter -from openai.types.beta.threads import Message, MessageDeleted, Run -from openai.types.beta.vector_store import VectorStore -from openai.types.shared_params.function_definition import FunctionDefinition + +_has_openai_dependencies: bool = True +try: + import aiofiles + from openai import NOT_GIVEN + from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads + from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam + from openai.types.beta.file_search_tool_param import FileSearchToolParam + from openai.types.beta.function_tool_param import FunctionToolParam + from openai.types.shared_params.function_definition import FunctionDefinition +except ImportError: + _has_openai_dependencies = False + +if TYPE_CHECKING: + import aiofiles + from openai import NOT_GIVEN, AsyncClient, NotGiven + from openai.pagination import AsyncCursorPage + from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads + 
from openai.types import FileObject +    from openai.types.beta import thread_update_params +    from openai.types.beta.assistant import Assistant +    from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam +    from openai.types.beta.assistant_tool_param import AssistantToolParam +    from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam +    from openai.types.beta.file_search_tool_param import FileSearchToolParam +    from openai.types.beta.function_tool_param import FunctionToolParam +    from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter +    from openai.types.beta.threads import Message, MessageDeleted, Run +    from openai.types.beta.vector_store import VectorStore +    from openai.types.shared_params.function_definition import FunctionDefinition event_logger = logging.getLogger(EVENT_LOGGER_NAME) -def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: +def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam": """Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" +    if not _has_openai_dependencies: +        raise RuntimeError( +            "Missing dependencies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra." 
+        ) +     schema = tool.schema     parameters: Dict[str, object] = {}     if "parameters" in schema: @@ -160,7 +181,7 @@ class OpenAIAssistantAgent(BaseChatAgent):         self,         name: str,         description: str, -        client: AsyncClient, +        client: "AsyncClient",         model: str,         instructions: str,         tools: Optional[ @@ -174,18 +195,23 @@ class OpenAIAssistantAgent(BaseChatAgent):         assistant_id: Optional[str] = None,         thread_id: Optional[str] = None,         metadata: Optional[object] = None, -        response_format: Optional[AssistantResponseFormatOptionParam] = None, +        response_format: Optional["AssistantResponseFormatOptionParam"] = None,         temperature: Optional[float] = None, -        tool_resources: Optional[ToolResources] = None, +        tool_resources: Optional["ToolResources"] = None,         top_p: Optional[float] = None,     ) -> None: +        if not _has_openai_dependencies: +            raise RuntimeError( +                "Missing dependencies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra." +            ) +         super().__init__(name, description)         if tools is None:             tools = []          # Store original tools and converted tools separately         self._original_tools: List[Tool] = [] -        converted_tools: List[AssistantToolParam] = [] +        converted_tools: List["AssistantToolParam"] = []         for tool in tools:             if isinstance(tool, str):                 if tool == "code_interpreter": @@ -207,8 +233,8 @@ class OpenAIAssistantAgent(BaseChatAgent):                 raise ValueError(f"Unsupported tool type: {type(tool)}")          self._client = client -        self._assistant: Optional[Assistant] = None -        self._thread: Optional[Thread] = None +        self._assistant: Optional["Assistant"] = None +        self._thread: Optional["Thread"] = None         self._init_thread_id = thread_id         self._model = model         self._instructions = instructions @@ -222,6 +248,10 @@ class OpenAIAssistantAgent(BaseChatAgent):         self._vector_store_id: Optional[str] = None         self._uploaded_file_ids: List[str] = [] +        # Variables to track initial state +        self._initial_message_ids: Set[str] = set() +        self._initial_state_retrieved: bool = False +     async def 
_ensure_initialized(self) -> None: """Ensure assistant and thread are created.""" if self._assistant is None: @@ -246,6 +276,27 @@ class OpenAIAssistantAgent(BaseChatAgent): else: self._thread = await self._client.beta.threads.create() + # Retrieve initial state only once + if not self._initial_state_retrieved: + await self._retrieve_initial_state() + self._initial_state_retrieved = True + + async def _retrieve_initial_state(self) -> None: + """Retrieve and store the initial state of messages and runs.""" + # Retrieve all initial message IDs + initial_message_ids: Set[str] = set() + after: str | NotGiven = NOT_GIVEN + while True: + msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list( + self._thread_id, after=after, order="asc", limit=100 + ) + for msg in msgs.data: + initial_message_ids.add(msg.id) + if not msgs.has_next_page(): + break + after = msgs.data[-1].id + self._initial_message_ids = initial_message_ids + @property def produced_message_types(self) -> List[type[ChatMessage]]: """The types of messages that the assistant agent produces.""" @@ -291,6 +342,7 @@ class OpenAIAssistantAgent(BaseChatAgent): async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handle incoming messages and return a response.""" + async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): return message @@ -421,22 +473,27 @@ class OpenAIAssistantAgent(BaseChatAgent): ) async def on_reset(self, cancellation_token: CancellationToken) -> None: - """Handle reset command by deleting all messages in the thread.""" + """Handle reset command by deleting new messages and runs since initialization.""" + await self._ensure_initialized() + # Retrieve all message IDs in the thread - all_msgs: List[str] = [] + new_message_ids: List[str] = [] after: str | NotGiven = NOT_GIVEN while True: msgs: AsyncCursorPage[Message] = await cancellation_token.link_future( - 
asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after)) + asyncio.ensure_future( + self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100) + ) ) for msg in msgs.data: - all_msgs.append(msg.id) - after = msg.id + if msg.id not in self._initial_message_ids: + new_message_ids.append(msg.id) if not msgs.has_next_page(): break + after = msgs.data[-1].id - # Delete all messages - for msg_id in all_msgs: + # Delete new messages + for msg_id in new_message_ids: status: MessageDeleted = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) diff --git a/python/packages/autogen-ext/tests/test_openai_assistant_agent.py b/python/packages/autogen-ext/tests/test_openai_assistant_agent.py index 3cd3e87da..efc01dead 100644 --- a/python/packages/autogen-ext/tests/test_openai_assistant_agent.py +++ b/python/packages/autogen-ext/tests/test_openai_assistant_agent.py @@ -7,6 +7,7 @@ from autogen_agentchat.messages import TextMessage from autogen_core.base import CancellationToken from autogen_core.components.tools._base import BaseTool, Tool from autogen_ext.agents import OpenAIAssistantAgent +from azure.identity import DefaultAzureCredential, get_bearer_token_provider from openai import AsyncAzureOpenAI from pydantic import BaseModel @@ -62,11 +63,22 @@ def client() -> AsyncAzureOpenAI: api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview") api_key = os.getenv("AZURE_OPENAI_API_KEY") - if not all([azure_endpoint, api_key]): - pytest.skip("Azure OpenAI credentials not found in environment variables") + if not azure_endpoint: + pytest.skip("Azure OpenAI endpoint not found in environment variables") - assert azure_endpoint is not None - assert api_key is not None + # Try Azure CLI credentials if API key not provided + if not api_key: + try: + token_provider = get_bearer_token_provider( + 
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + return AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider + ) + except Exception: + pytest.skip("Failed to get Azure CLI credentials and no API key provided") + + # Fall back to API key auth if provided return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key) @@ -138,3 +150,41 @@ async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: Ca assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content")) await agent.delete_assistant(cancellation_token) + + +@pytest.mark.asyncio +async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: CancellationToken) -> None: + # Create thread with initial message + thread = await client.beta.threads.create() + await client.beta.threads.messages.create( + thread_id=thread.id, + content="Hi, my name is John and I'm a software engineer. 
Use this information to help me.", + role="user", + ) + + # Create agent with existing thread + agent = OpenAIAssistantAgent( + name="assistant", + instructions="Help the user with their task.", + model="gpt-4o-mini", + description="OpenAI Assistant Agent", + client=client, + thread_id=thread.id, + ) + + # Test before reset + message1 = TextMessage(source="user", content="What is my name?") + response1 = await agent.on_messages([message1], cancellation_token) + assert isinstance(response1.chat_message.content, str) + assert "john" in response1.chat_message.content.lower() + + # Reset agent state + await agent.on_reset(cancellation_token) + + # Test after reset + message2 = TextMessage(source="user", content="What is my name?") + response2 = await agent.on_messages([message2], cancellation_token) + assert isinstance(response2.chat_message.content, str) + assert "john" in response2.chat_message.content.lower() + + await agent.delete_assistant(cancellation_token) diff --git a/python/uv.lock b/python/uv.lock index 3662a7900..7fb6a6c7a 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -31,7 +31,6 @@ resolution-markers = [ "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'", "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')", "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')", - "python_version < '0'", ] [manifest] @@ -487,22 +486,14 @@ azure = [ { name = "azure-core" }, { name = "azure-identity" }, ] -azure-code-executor = [ - { name = "azure-core" }, -] docker = [ { name = "docker" }, ] -docker-code-executor = [ - { name = "docker" }, -] langchain = [ { name = 
"langchain-core" }, ] -langchain-tools = [ - { name = "langchain-core" }, -] openai = [ + { name = "aiofiles" }, { name = "openai" }, ] video-surfer = [ @@ -518,16 +509,14 @@ web-surfer = [ [package.metadata] requires-dist = [ + { name = "aiofiles", marker = "extra == 'openai'" }, { name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" }, { name = "autogen-core", editable = "packages/autogen-core" }, { name = "azure-core", marker = "extra == 'azure'" }, - { name = "azure-core", marker = "extra == 'azure-code-executor'" }, { name = "azure-identity", marker = "extra == 'azure'" }, { name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" }, - { name = "docker", marker = "extra == 'docker-code-executor'", specifier = "~=7.0" }, { name = "ffmpeg-python", marker = "extra == 'video-surfer'" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, - { name = "langchain-core", marker = "extra == 'langchain-tools'", specifier = "~=0.3.3" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" }, { name = "openai-whisper", marker = "extra == 'video-surfer'" }, { name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },