mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-03 05:17:07 +00:00
Update OpenAI Assistante Agent on_reset (#4423)
* update reset and imports * add missing imports * update lock file --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
parent
7eb8b4645b
commit
1f90dc5ea9
@ -20,13 +20,10 @@ dependencies = [
|
|||||||
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
langchain-tools = ["langchain_core~= 0.3.3"]
|
|
||||||
azure-code-executor = ["azure-core"]
|
|
||||||
docker-code-executor = ["docker~=7.0"]
|
|
||||||
langchain = ["langchain_core~= 0.3.3"]
|
langchain = ["langchain_core~= 0.3.3"]
|
||||||
azure = ["azure-core", "azure-identity"]
|
azure = ["azure-core", "azure-identity"]
|
||||||
docker = ["docker~=7.0"]
|
docker = ["docker~=7.0"]
|
||||||
openai = ["openai>=1.3"]
|
openai = ["openai>=1.3", "aiofiles"]
|
||||||
web-surfer = [
|
web-surfer = [
|
||||||
"playwright>=1.48.0",
|
"playwright>=1.48.0",
|
||||||
"pillow>=11.0.0",
|
"pillow>=11.0.0",
|
||||||
|
@ -3,6 +3,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
@ -13,11 +14,11 @@ from typing import (
|
|||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Set,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import aiofiles
|
|
||||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||||
from autogen_agentchat.agents import BaseChatAgent
|
from autogen_agentchat.agents import BaseChatAgent
|
||||||
from autogen_agentchat.base import Response
|
from autogen_agentchat.base import Response
|
||||||
@ -35,6 +36,21 @@ from autogen_core.base import CancellationToken
|
|||||||
from autogen_core.components import FunctionCall
|
from autogen_core.components import FunctionCall
|
||||||
from autogen_core.components.models._types import FunctionExecutionResult
|
from autogen_core.components.models._types import FunctionExecutionResult
|
||||||
from autogen_core.components.tools import FunctionTool, Tool
|
from autogen_core.components.tools import FunctionTool, Tool
|
||||||
|
|
||||||
|
_has_openai_dependencies: bool = True
|
||||||
|
try:
|
||||||
|
import aiofiles
|
||||||
|
from openai import NOT_GIVEN
|
||||||
|
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||||
|
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||||
|
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||||
|
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||||
|
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||||
|
except ImportError:
|
||||||
|
_has_openai_dependencies = False
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import aiofiles
|
||||||
from openai import NOT_GIVEN, AsyncClient, NotGiven
|
from openai import NOT_GIVEN, AsyncClient, NotGiven
|
||||||
from openai.pagination import AsyncCursorPage
|
from openai.pagination import AsyncCursorPage
|
||||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||||
@ -54,8 +70,13 @@ from openai.types.shared_params.function_definition import FunctionDefinition
|
|||||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam:
|
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
|
||||||
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
|
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
|
||||||
|
if not _has_openai_dependencies:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
|
||||||
|
)
|
||||||
|
|
||||||
schema = tool.schema
|
schema = tool.schema
|
||||||
parameters: Dict[str, object] = {}
|
parameters: Dict[str, object] = {}
|
||||||
if "parameters" in schema:
|
if "parameters" in schema:
|
||||||
@ -160,7 +181,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
description: str,
|
description: str,
|
||||||
client: AsyncClient,
|
client: "AsyncClient",
|
||||||
model: str,
|
model: str,
|
||||||
instructions: str,
|
instructions: str,
|
||||||
tools: Optional[
|
tools: Optional[
|
||||||
@ -174,18 +195,23 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
assistant_id: Optional[str] = None,
|
assistant_id: Optional[str] = None,
|
||||||
thread_id: Optional[str] = None,
|
thread_id: Optional[str] = None,
|
||||||
metadata: Optional[object] = None,
|
metadata: Optional[object] = None,
|
||||||
response_format: Optional[AssistantResponseFormatOptionParam] = None,
|
response_format: Optional["AssistantResponseFormatOptionParam"] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
tool_resources: Optional[ToolResources] = None,
|
tool_resources: Optional["ToolResources"] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not _has_openai_dependencies:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(name, description)
|
super().__init__(name, description)
|
||||||
if tools is None:
|
if tools is None:
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# Store original tools and converted tools separately
|
# Store original tools and converted tools separately
|
||||||
self._original_tools: List[Tool] = []
|
self._original_tools: List[Tool] = []
|
||||||
converted_tools: List[AssistantToolParam] = []
|
converted_tools: List["AssistantToolParam"] = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if isinstance(tool, str):
|
if isinstance(tool, str):
|
||||||
if tool == "code_interpreter":
|
if tool == "code_interpreter":
|
||||||
@ -207,8 +233,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||||
|
|
||||||
self._client = client
|
self._client = client
|
||||||
self._assistant: Optional[Assistant] = None
|
self._assistant: Optional["Assistant"] = None
|
||||||
self._thread: Optional[Thread] = None
|
self._thread: Optional["Thread"] = None
|
||||||
self._init_thread_id = thread_id
|
self._init_thread_id = thread_id
|
||||||
self._model = model
|
self._model = model
|
||||||
self._instructions = instructions
|
self._instructions = instructions
|
||||||
@ -222,6 +248,10 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
self._vector_store_id: Optional[str] = None
|
self._vector_store_id: Optional[str] = None
|
||||||
self._uploaded_file_ids: List[str] = []
|
self._uploaded_file_ids: List[str] = []
|
||||||
|
|
||||||
|
# Variables to track initial state
|
||||||
|
self._initial_message_ids: Set[str] = set()
|
||||||
|
self._initial_state_retrieved: bool = False
|
||||||
|
|
||||||
async def _ensure_initialized(self) -> None:
|
async def _ensure_initialized(self) -> None:
|
||||||
"""Ensure assistant and thread are created."""
|
"""Ensure assistant and thread are created."""
|
||||||
if self._assistant is None:
|
if self._assistant is None:
|
||||||
@ -246,6 +276,27 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
else:
|
else:
|
||||||
self._thread = await self._client.beta.threads.create()
|
self._thread = await self._client.beta.threads.create()
|
||||||
|
|
||||||
|
# Retrieve initial state only once
|
||||||
|
if not self._initial_state_retrieved:
|
||||||
|
await self._retrieve_initial_state()
|
||||||
|
self._initial_state_retrieved = True
|
||||||
|
|
||||||
|
async def _retrieve_initial_state(self) -> None:
|
||||||
|
"""Retrieve and store the initial state of messages and runs."""
|
||||||
|
# Retrieve all initial message IDs
|
||||||
|
initial_message_ids: Set[str] = set()
|
||||||
|
after: str | NotGiven = NOT_GIVEN
|
||||||
|
while True:
|
||||||
|
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(
|
||||||
|
self._thread_id, after=after, order="asc", limit=100
|
||||||
|
)
|
||||||
|
for msg in msgs.data:
|
||||||
|
initial_message_ids.add(msg.id)
|
||||||
|
if not msgs.has_next_page():
|
||||||
|
break
|
||||||
|
after = msgs.data[-1].id
|
||||||
|
self._initial_message_ids = initial_message_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
"""The types of messages that the assistant agent produces."""
|
"""The types of messages that the assistant agent produces."""
|
||||||
@ -291,6 +342,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
"""Handle incoming messages and return a response."""
|
"""Handle incoming messages and return a response."""
|
||||||
|
|
||||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||||
if isinstance(message, Response):
|
if isinstance(message, Response):
|
||||||
return message
|
return message
|
||||||
@ -421,22 +473,27 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
"""Handle reset command by deleting all messages in the thread."""
|
"""Handle reset command by deleting new messages and runs since initialization."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
|
||||||
# Retrieve all message IDs in the thread
|
# Retrieve all message IDs in the thread
|
||||||
all_msgs: List[str] = []
|
new_message_ids: List[str] = []
|
||||||
after: str | NotGiven = NOT_GIVEN
|
after: str | NotGiven = NOT_GIVEN
|
||||||
while True:
|
while True:
|
||||||
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
||||||
asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after))
|
asyncio.ensure_future(
|
||||||
|
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for msg in msgs.data:
|
for msg in msgs.data:
|
||||||
all_msgs.append(msg.id)
|
if msg.id not in self._initial_message_ids:
|
||||||
after = msg.id
|
new_message_ids.append(msg.id)
|
||||||
if not msgs.has_next_page():
|
if not msgs.has_next_page():
|
||||||
break
|
break
|
||||||
|
after = msgs.data[-1].id
|
||||||
|
|
||||||
# Delete all messages
|
# Delete new messages
|
||||||
for msg_id in all_msgs:
|
for msg_id in new_message_ids:
|
||||||
status: MessageDeleted = await cancellation_token.link_future(
|
status: MessageDeleted = await cancellation_token.link_future(
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
|
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
|
||||||
|
@ -7,6 +7,7 @@ from autogen_agentchat.messages import TextMessage
|
|||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components.tools._base import BaseTool, Tool
|
from autogen_core.components.tools._base import BaseTool, Tool
|
||||||
from autogen_ext.agents import OpenAIAssistantAgent
|
from autogen_ext.agents import OpenAIAssistantAgent
|
||||||
|
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||||
from openai import AsyncAzureOpenAI
|
from openai import AsyncAzureOpenAI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -62,11 +63,22 @@ def client() -> AsyncAzureOpenAI:
|
|||||||
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
|
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
|
||||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||||
|
|
||||||
if not all([azure_endpoint, api_key]):
|
if not azure_endpoint:
|
||||||
pytest.skip("Azure OpenAI credentials not found in environment variables")
|
pytest.skip("Azure OpenAI endpoint not found in environment variables")
|
||||||
|
|
||||||
assert azure_endpoint is not None
|
# Try Azure CLI credentials if API key not provided
|
||||||
assert api_key is not None
|
if not api_key:
|
||||||
|
try:
|
||||||
|
token_provider = get_bearer_token_provider(
|
||||||
|
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
|
||||||
|
)
|
||||||
|
return AsyncAzureOpenAI(
|
||||||
|
azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pytest.skip("Failed to get Azure CLI credentials and no API key provided")
|
||||||
|
|
||||||
|
# Fall back to API key auth if provided
|
||||||
return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
|
return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
@ -138,3 +150,41 @@ async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: Ca
|
|||||||
assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content"))
|
assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content"))
|
||||||
|
|
||||||
await agent.delete_assistant(cancellation_token)
|
await agent.delete_assistant(cancellation_token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: CancellationToken) -> None:
|
||||||
|
# Create thread with initial message
|
||||||
|
thread = await client.beta.threads.create()
|
||||||
|
await client.beta.threads.messages.create(
|
||||||
|
thread_id=thread.id,
|
||||||
|
content="Hi, my name is John and I'm a software engineer. Use this information to help me.",
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create agent with existing thread
|
||||||
|
agent = OpenAIAssistantAgent(
|
||||||
|
name="assistant",
|
||||||
|
instructions="Help the user with their task.",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
description="OpenAI Assistant Agent",
|
||||||
|
client=client,
|
||||||
|
thread_id=thread.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test before reset
|
||||||
|
message1 = TextMessage(source="user", content="What is my name?")
|
||||||
|
response1 = await agent.on_messages([message1], cancellation_token)
|
||||||
|
assert isinstance(response1.chat_message.content, str)
|
||||||
|
assert "john" in response1.chat_message.content.lower()
|
||||||
|
|
||||||
|
# Reset agent state
|
||||||
|
await agent.on_reset(cancellation_token)
|
||||||
|
|
||||||
|
# Test after reset
|
||||||
|
message2 = TextMessage(source="user", content="What is my name?")
|
||||||
|
response2 = await agent.on_messages([message2], cancellation_token)
|
||||||
|
assert isinstance(response2.chat_message.content, str)
|
||||||
|
assert "john" in response2.chat_message.content.lower()
|
||||||
|
|
||||||
|
await agent.delete_assistant(cancellation_token)
|
||||||
|
15
python/uv.lock
generated
15
python/uv.lock
generated
@ -31,7 +31,6 @@ resolution-markers = [
|
|||||||
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'",
|
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux'",
|
||||||
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
|
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
|
||||||
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
|
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux')",
|
||||||
"python_version < '0'",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[manifest]
|
[manifest]
|
||||||
@ -487,22 +486,14 @@ azure = [
|
|||||||
{ name = "azure-core" },
|
{ name = "azure-core" },
|
||||||
{ name = "azure-identity" },
|
{ name = "azure-identity" },
|
||||||
]
|
]
|
||||||
azure-code-executor = [
|
|
||||||
{ name = "azure-core" },
|
|
||||||
]
|
|
||||||
docker = [
|
docker = [
|
||||||
{ name = "docker" },
|
{ name = "docker" },
|
||||||
]
|
]
|
||||||
docker-code-executor = [
|
|
||||||
{ name = "docker" },
|
|
||||||
]
|
|
||||||
langchain = [
|
langchain = [
|
||||||
{ name = "langchain-core" },
|
{ name = "langchain-core" },
|
||||||
]
|
]
|
||||||
langchain-tools = [
|
|
||||||
{ name = "langchain-core" },
|
|
||||||
]
|
|
||||||
openai = [
|
openai = [
|
||||||
|
{ name = "aiofiles" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
]
|
]
|
||||||
video-surfer = [
|
video-surfer = [
|
||||||
@ -518,16 +509,14 @@ web-surfer = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "aiofiles", marker = "extra == 'openai'" },
|
||||||
{ name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" },
|
{ name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" },
|
||||||
{ name = "autogen-core", editable = "packages/autogen-core" },
|
{ name = "autogen-core", editable = "packages/autogen-core" },
|
||||||
{ name = "azure-core", marker = "extra == 'azure'" },
|
{ name = "azure-core", marker = "extra == 'azure'" },
|
||||||
{ name = "azure-core", marker = "extra == 'azure-code-executor'" },
|
|
||||||
{ name = "azure-identity", marker = "extra == 'azure'" },
|
{ name = "azure-identity", marker = "extra == 'azure'" },
|
||||||
{ name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" },
|
{ name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" },
|
||||||
{ name = "docker", marker = "extra == 'docker-code-executor'", specifier = "~=7.0" },
|
|
||||||
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
|
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
|
||||||
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
|
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
|
||||||
{ name = "langchain-core", marker = "extra == 'langchain-tools'", specifier = "~=0.3.3" },
|
|
||||||
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" },
|
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" },
|
||||||
{ name = "openai-whisper", marker = "extra == 'video-surfer'" },
|
{ name = "openai-whisper", marker = "extra == 'video-surfer'" },
|
||||||
{ name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
|
{ name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
|
||||||
|
Loading…
x
Reference in New Issue
Block a user