mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-31 12:00:11 +00:00
OpenAI Assistants Agent (#4131)
* initial assistant client draft * expose assistants client * initial openai assistant agentchat draft * update file search * add delete methods and fix typing * add tool execution * fix tool call and add docstring * abstract tools and support thread management * add tests * removed unused typevars * add unsaved test changes * test typing fixes --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
parent
ad271d975c
commit
df32d5e1d1
@ -2,6 +2,7 @@ from ._assistant_agent import AssistantAgent, Handoff
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
from ._openai_assistant_agent import OpenAIAssistantAgent
|
||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
|
||||
@ -11,6 +12,7 @@ __all__ = [
|
||||
"Handoff",
|
||||
"CodeExecutorAgent",
|
||||
"CodingAssistantAgent",
|
||||
"OpenAIAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
"SocietyOfMindAgent",
|
||||
]
|
||||
|
@ -0,0 +1,538 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiofiles
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall
|
||||
from autogen_core.components.models._types import FunctionExecutionResult
|
||||
from autogen_core.components.tools import FunctionTool, Tool
|
||||
from openai import NOT_GIVEN, AsyncClient, NotGiven
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types import FileObject
|
||||
from openai.types.beta import thread_update_params
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
|
||||
from openai.types.beta.threads import Message, MessageDeleted, Run
|
||||
from openai.types.beta.vector_store import VectorStore
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from autogen_agentchat.messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
)
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam:
|
||||
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
|
||||
schema = tool.schema
|
||||
parameters: Dict[str, object] = {}
|
||||
if "parameters" in schema:
|
||||
parameters = {
|
||||
"type": schema["parameters"]["type"],
|
||||
"properties": schema["parameters"]["properties"],
|
||||
}
|
||||
if "required" in schema["parameters"]:
|
||||
parameters["required"] = schema["parameters"]["required"]
|
||||
|
||||
function_def = FunctionDefinition(
|
||||
name=schema["name"],
|
||||
description=schema.get("description", ""),
|
||||
parameters=parameters,
|
||||
)
|
||||
return FunctionToolParam(type="function", function=function_def)
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(BaseChatAgent):
|
||||
"""An agent implementation that uses the OpenAI Assistant API to generate responses.
|
||||
|
||||
This agent leverages the OpenAI Assistant API to create AI assistants with capabilities like:
|
||||
- Code interpretation and execution
|
||||
- File handling and search
|
||||
- Custom function calling
|
||||
- Multi-turn conversations
|
||||
|
||||
The agent maintains a thread of conversation and can use various tools including:
|
||||
- Code interpreter: For executing code and working with files
|
||||
- File search: For searching through uploaded documents
|
||||
- Custom functions: For extending capabilities with user-defined tools
|
||||
|
||||
Key Features:
|
||||
- Supports multiple file formats including code, documents, images
|
||||
- Can handle up to 128 tools per assistant
|
||||
- Maintains conversation context in threads
|
||||
- Supports file uploads for code interpreter and search
|
||||
- Vector store integration for efficient file search
|
||||
- Automatic file parsing and embedding
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncClient
|
||||
from autogen_agentchat.agents import OpenAIAssistantAgent
|
||||
|
||||
# Create an OpenAI client
|
||||
client = AsyncClient(api_key="your-api-key", base_url="your-base-url")
|
||||
|
||||
# Create an assistant with code interpreter
|
||||
assistant = OpenAIAssistantAgent(
|
||||
name="Python Helper",
|
||||
description="Helps with Python programming",
|
||||
client=client,
|
||||
model="gpt-4",
|
||||
instructions="You are a helpful Python programming assistant.",
|
||||
tools=["code_interpreter"],
|
||||
)
|
||||
|
||||
# Upload files for the assistant to use
|
||||
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
|
||||
|
||||
# Get response from the assistant
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
|
||||
)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_uploaded_files(cancellation_token)
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
||||
Args:
|
||||
name (str): Name of the assistant
|
||||
description (str): Description of the assistant's purpose
|
||||
client (AsyncClient): OpenAI API client instance
|
||||
model (str): Model to use (e.g. "gpt-4")
|
||||
instructions (str): System instructions for the assistant
|
||||
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
|
||||
assistant_id (Optional[str]): ID of existing assistant to use
|
||||
metadata (Optional[object]): Additional metadata for the assistant
|
||||
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
|
||||
temperature (Optional[float]): Temperature for response generation
|
||||
tool_resources (Optional[ToolResources]): Additional tool configuration
|
||||
top_p (Optional[float]): Top p sampling parameter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
client: AsyncClient,
|
||||
model: str,
|
||||
instructions: str,
|
||||
tools: Optional[
|
||||
Iterable[
|
||||
Union[
|
||||
Literal["code_interpreter", "file_search"],
|
||||
Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
|
||||
]
|
||||
]
|
||||
] = None,
|
||||
assistant_id: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
metadata: Optional[object] = None,
|
||||
response_format: Optional[AssistantResponseFormatOptionParam] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_resources: Optional[ToolResources] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(name, description)
|
||||
if tools is None:
|
||||
tools = []
|
||||
|
||||
# Store original tools and converted tools separately
|
||||
self._original_tools: List[Tool] = []
|
||||
converted_tools: List[AssistantToolParam] = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, str):
|
||||
if tool == "code_interpreter":
|
||||
converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
|
||||
elif tool == "file_search":
|
||||
converted_tools.append(FileSearchToolParam(type="file_search"))
|
||||
elif isinstance(tool, Tool):
|
||||
self._original_tools.append(tool)
|
||||
converted_tools.append(_convert_tool_to_function_param(tool))
|
||||
elif callable(tool):
|
||||
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
|
||||
description = tool.__doc__
|
||||
else:
|
||||
description = ""
|
||||
function_tool = FunctionTool(tool, description=description)
|
||||
self._original_tools.append(function_tool)
|
||||
converted_tools.append(_convert_tool_to_function_param(function_tool))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
self._client = client
|
||||
self._assistant: Optional[Assistant] = None
|
||||
self._thread: Optional[Thread] = None
|
||||
self._init_thread_id = thread_id
|
||||
self._model = model
|
||||
self._instructions = instructions
|
||||
self._api_tools = converted_tools
|
||||
self._assistant_id = assistant_id
|
||||
self._metadata = metadata
|
||||
self._response_format = response_format
|
||||
self._temperature = temperature
|
||||
self._tool_resources = tool_resources
|
||||
self._top_p = top_p
|
||||
self._vector_store_id: Optional[str] = None
|
||||
self._uploaded_file_ids: List[str] = []
|
||||
|
||||
async def _ensure_initialized(self) -> None:
|
||||
"""Ensure assistant and thread are created."""
|
||||
if self._assistant is None:
|
||||
if self._assistant_id:
|
||||
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)
|
||||
else:
|
||||
self._assistant = await self._client.beta.assistants.create(
|
||||
model=self._model,
|
||||
description=self.description,
|
||||
instructions=self._instructions,
|
||||
tools=self._api_tools,
|
||||
metadata=self._metadata,
|
||||
response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore
|
||||
temperature=self._temperature,
|
||||
tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore
|
||||
top_p=self._top_p,
|
||||
)
|
||||
|
||||
if self._thread is None:
|
||||
if self._init_thread_id:
|
||||
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)
|
||||
else:
|
||||
self._thread = await self._client.beta.threads.create()
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
return [TextMessage]
|
||||
|
||||
@property
|
||||
def threads(self) -> AsyncThreads:
|
||||
return self._client.beta.threads
|
||||
|
||||
@property
|
||||
def runs(self) -> AsyncRuns:
|
||||
return self._client.beta.threads.runs
|
||||
|
||||
@property
|
||||
def messages(self) -> AsyncMessages:
|
||||
return self._client.beta.threads.messages
|
||||
|
||||
@property
|
||||
def _get_assistant_id(self) -> str:
|
||||
if self._assistant is None:
|
||||
raise ValueError("Assistant not initialized")
|
||||
return self._assistant.id
|
||||
|
||||
@property
|
||||
def _thread_id(self) -> str:
|
||||
if self._thread is None:
|
||||
raise ValueError("Thread not initialized")
|
||||
return self._thread.id
|
||||
|
||||
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
|
||||
"""Execute a tool call and return the result."""
|
||||
try:
|
||||
if not self._original_tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
result = await tool.run_json(arguments, cancellation_token)
|
||||
return tool.return_value_as_string(result)
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handle incoming messages and return a response."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# Process all messages in sequence
|
||||
for message in messages:
|
||||
if isinstance(message, (TextMessage, MultiModalMessage)):
|
||||
await self.handle_text_message(str(message.content), cancellation_token)
|
||||
elif isinstance(message, (StopMessage, HandoffMessage)):
|
||||
await self.handle_text_message(message.content, cancellation_token)
|
||||
|
||||
# Inner messages for tool calls
|
||||
inner_messages: List[AgentMessage] = []
|
||||
|
||||
# Create and start a run
|
||||
run: Run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.create(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._get_assistant_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for run completion by polling
|
||||
while True:
|
||||
run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.retrieve(
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if run.status == "failed":
|
||||
raise ValueError(f"Run failed: {run.last_error}")
|
||||
|
||||
# If the run requires action (function calls), execute tools and continue
|
||||
if run.status == "requires_action" and run.required_action is not None:
|
||||
tool_calls: List[FunctionCall] = []
|
||||
for required_tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
||||
if required_tool_call.type == "function":
|
||||
tool_calls.append(
|
||||
FunctionCall(
|
||||
id=required_tool_call.id,
|
||||
name=required_tool_call.function.name,
|
||||
arguments=required_tool_call.function.arguments,
|
||||
)
|
||||
)
|
||||
|
||||
# Add tool call message to inner messages
|
||||
tool_call_msg = ToolCallMessage(source=self.name, content=tool_calls)
|
||||
inner_messages.append(tool_call_msg)
|
||||
event_logger.debug(tool_call_msg)
|
||||
|
||||
# Execute tool calls and get results
|
||||
tool_outputs: List[FunctionExecutionResult] = []
|
||||
for tool_call in tool_calls:
|
||||
result = await self._execute_tool_call(tool_call, cancellation_token)
|
||||
tool_outputs.append(FunctionExecutionResult(content=result, call_id=tool_call.id))
|
||||
|
||||
# Add tool result message to inner messages
|
||||
tool_result_msg = ToolCallResultMessage(source=self.name, content=tool_outputs)
|
||||
inner_messages.append(tool_result_msg)
|
||||
event_logger.debug(tool_result_msg)
|
||||
|
||||
# Submit tool outputs back to the run
|
||||
run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.submit_tool_outputs(
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
|
||||
)
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if run.status == "completed":
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get messages after run completion
|
||||
assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)
|
||||
)
|
||||
)
|
||||
|
||||
if not assistant_messages.data:
|
||||
raise ValueError("No messages received from assistant")
|
||||
|
||||
# Get the last message's content
|
||||
last_message = assistant_messages.data[0]
|
||||
if not last_message.content:
|
||||
raise ValueError(f"No content in the last message: {last_message}")
|
||||
|
||||
# Extract text content
|
||||
text_content = [content for content in last_message.content if content.type == "text"]
|
||||
if not text_content:
|
||||
raise ValueError(f"Expected text content in the last message: {last_message.content}")
|
||||
|
||||
# Return the assistant's response as a Response with inner messages
|
||||
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
|
||||
return Response(chat_message=chat_message, inner_messages=inner_messages)
|
||||
|
||||
async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle regular text messages by adding them to the thread."""
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.create(
|
||||
thread_id=self._thread_id,
|
||||
content=content,
|
||||
role="user",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle reset command by deleting all messages in the thread."""
|
||||
# Retrieve all message IDs in the thread
|
||||
all_msgs: List[str] = []
|
||||
after: str | NotGiven = NOT_GIVEN
|
||||
while True:
|
||||
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after))
|
||||
)
|
||||
for msg in msgs.data:
|
||||
all_msgs.append(msg.id)
|
||||
after = msg.id
|
||||
if not msgs.has_next_page():
|
||||
break
|
||||
|
||||
# Delete all messages
|
||||
for msg_id in all_msgs:
|
||||
status: MessageDeleted = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
|
||||
)
|
||||
)
|
||||
assert status.deleted is True
|
||||
|
||||
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
|
||||
"""Upload files and return their IDs."""
|
||||
if isinstance(file_paths, str):
|
||||
file_paths = [file_paths]
|
||||
|
||||
file_ids: List[str] = []
|
||||
for file_path in file_paths:
|
||||
async with aiofiles.open(file_path, mode="rb") as f:
|
||||
file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read()))
|
||||
file_name = os.path.basename(file_path)
|
||||
|
||||
file: FileObject = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants"))
|
||||
)
|
||||
file_ids.append(file.id)
|
||||
self._uploaded_file_ids.append(file.id)
|
||||
|
||||
return file_ids
|
||||
|
||||
async def on_upload_for_code_interpreter(
|
||||
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle file uploads for the code interpreter."""
|
||||
file_ids = await self._upload_files(file_paths, cancellation_token)
|
||||
|
||||
# Update thread with the new files
|
||||
thread = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))
|
||||
)
|
||||
tool_resources: ToolResources = thread.tool_resources or ToolResources()
|
||||
code_interpreter: ToolResourcesCodeInterpreter = (
|
||||
tool_resources.code_interpreter or ToolResourcesCodeInterpreter()
|
||||
)
|
||||
existing_file_ids: List[str] = code_interpreter.file_ids or []
|
||||
existing_file_ids.extend(file_ids)
|
||||
tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=existing_file_ids)
|
||||
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.update(
|
||||
thread_id=self._thread_id,
|
||||
tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def on_upload_for_file_search(
|
||||
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle file uploads for file search."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# Check if file_search is enabled in tools
|
||||
if not any(tool.get("type") == "file_search" for tool in self._api_tools):
|
||||
raise ValueError(
|
||||
"File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
|
||||
)
|
||||
|
||||
# Create vector store if not already created
|
||||
if self._vector_store_id is None:
|
||||
vector_store: VectorStore = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.vector_stores.create())
|
||||
)
|
||||
self._vector_store_id = vector_store.id
|
||||
|
||||
# Update assistant with vector store ID
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.assistants.update(
|
||||
assistant_id=self._get_assistant_id,
|
||||
tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
file_ids = await self._upload_files(file_paths, cancellation_token)
|
||||
|
||||
# Create file batch with the file IDs
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.vector_stores.file_batches.create_and_poll(
|
||||
vector_store_id=self._vector_store_id, file_ids=file_ids
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete all files that were uploaded by this agent instance."""
|
||||
for file_id in self._uploaded_file_ids:
|
||||
try:
|
||||
await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete file {file_id}: {str(e)}")
|
||||
self._uploaded_file_ids = []
|
||||
|
||||
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the assistant if it was created by this instance."""
|
||||
if self._assistant is not None and not self._assistant_id:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))
|
||||
)
|
||||
self._assistant = None
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete assistant: {str(e)}")
|
||||
|
||||
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the vector store if it was created by this instance."""
|
||||
if self._vector_store_id is not None:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.vector_stores.delete(vector_store_id=self._vector_store_id))
|
||||
)
|
||||
self._vector_store_id = None
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete vector store: {str(e)}")
|
@ -0,0 +1,140 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import OpenAIAssistantAgent
|
||||
from autogen_agentchat.messages import TextMessage
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.tools._base import BaseTool, Tool
|
||||
from openai import AsyncAzureOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class QuestionType(str, Enum):
|
||||
MULTIPLE_CHOICE = "MULTIPLE_CHOICE"
|
||||
FREE_RESPONSE = "FREE_RESPONSE"
|
||||
|
||||
|
||||
class Question(BaseModel):
|
||||
question_text: str
|
||||
question_type: QuestionType
|
||||
choices: Optional[List[str]] = None
|
||||
|
||||
|
||||
class DisplayQuizArgs(BaseModel):
|
||||
title: str
|
||||
questions: List[Question]
|
||||
|
||||
|
||||
class QuizResponses(BaseModel):
|
||||
responses: List[str]
|
||||
|
||||
|
||||
class DisplayQuizTool(BaseTool[DisplayQuizArgs, QuizResponses]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
args_type=DisplayQuizArgs,
|
||||
return_type=QuizResponses,
|
||||
name="display_quiz",
|
||||
description=(
|
||||
"Displays a quiz to the student and returns the student's responses. "
|
||||
"A single quiz can have multiple questions."
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken) -> QuizResponses:
|
||||
responses: List[str] = []
|
||||
for q in args.questions:
|
||||
if q.question_type == QuestionType.MULTIPLE_CHOICE:
|
||||
response = q.choices[0] if q.choices else ""
|
||||
elif q.question_type == QuestionType.FREE_RESPONSE:
|
||||
response = "Sample free response"
|
||||
else:
|
||||
response = ""
|
||||
responses.append(response)
|
||||
return QuizResponses(responses=responses)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> AsyncAzureOpenAI:
|
||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
|
||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||
|
||||
if not all([azure_endpoint, api_key]):
|
||||
pytest.skip("Azure OpenAI credentials not found in environment variables")
|
||||
|
||||
assert azure_endpoint is not None
|
||||
assert api_key is not None
|
||||
return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent(client: AsyncAzureOpenAI) -> OpenAIAssistantAgent:
|
||||
tools: List[Union[Literal["code_interpreter", "file_search"], Tool]] = [
|
||||
"code_interpreter",
|
||||
"file_search",
|
||||
DisplayQuizTool(),
|
||||
]
|
||||
|
||||
return OpenAIAssistantAgent(
|
||||
name="assistant",
|
||||
instructions="Help the user with their task.",
|
||||
model="gpt-4o-mini",
|
||||
description="OpenAI Assistant Agent",
|
||||
client=client,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cancellation_token() -> CancellationToken:
|
||||
return CancellationToken()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
|
||||
file_path = r"C:\Users\lpinheiro\Github\autogen-test\data\SampleBooks\jungle_book.txt"
|
||||
await agent.on_upload_for_file_search(file_path, cancellation_token)
|
||||
|
||||
message = TextMessage(source="user", content="What is the first sentence of the jungle scout book?")
|
||||
response = await agent.on_messages([message], cancellation_token)
|
||||
|
||||
assert response.chat_message.content is not None
|
||||
assert isinstance(response.chat_message.content, str)
|
||||
assert len(response.chat_message.content) > 0
|
||||
|
||||
await agent.delete_uploaded_files(cancellation_token)
|
||||
await agent.delete_vector_store(cancellation_token)
|
||||
await agent.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
|
||||
message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?")
|
||||
response = await agent.on_messages([message], cancellation_token)
|
||||
|
||||
assert response.chat_message.content is not None
|
||||
assert isinstance(response.chat_message.content, str)
|
||||
assert len(response.chat_message.content) > 0
|
||||
assert "x = 1" in response.chat_message.content.lower()
|
||||
|
||||
await agent.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
|
||||
message = TextMessage(
|
||||
source="user",
|
||||
content="Create a short quiz about basic math with one multiple choice question and one free response question.",
|
||||
)
|
||||
response = await agent.on_messages([message], cancellation_token)
|
||||
|
||||
assert response.chat_message.content is not None
|
||||
assert isinstance(response.chat_message.content, str)
|
||||
assert len(response.chat_message.content) > 0
|
||||
assert isinstance(response.inner_messages, list)
|
||||
assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content"))
|
||||
|
||||
await agent.delete_assistant(cancellation_token)
|
Loading…
x
Reference in New Issue
Block a user