mirror of
https://github.com/microsoft/autogen.git
synced 2025-06-26 22:30:10 +00:00

<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> There have been updates to the azure ai agent foundry sdk (azure-ai-project). This PR updates the autogen `AzureAIAgent` which wraps the azure ai agent. A list of some changes - Update docstring samples to use `endpoint` (instead of connection string previously) - Update imports and arguments e.g, from `azure.ai.agents` etc - Add a guide in ext docs showing Bing Search Grounding tool example. <img width="1423" alt="image" src="https://github.com/user-attachments/assets/0b7c8fa6-8aa5-4c20-831b-b525ac8243b7" /> ## Why are these changes needed? <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number Closes #6601 <!-- For example: "Closes #1234" --> ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed.
797 lines
28 KiB
Python
797 lines
28 KiB
Python
import json
|
|
from asyncio import CancelledError
|
|
from types import SimpleNamespace
|
|
from typing import Any, AsyncGenerator, List, Optional, Union
|
|
from unittest.mock import AsyncMock, MagicMock, call
|
|
|
|
import pytest
|
|
from autogen_agentchat.base._chat_agent import Response
|
|
from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent
|
|
from autogen_core._cancellation_token import CancellationToken
|
|
from autogen_core.tools._function_tool import FunctionTool
|
|
from autogen_ext.agents.azure._azure_ai_agent import AzureAIAgent
|
|
from autogen_ext.agents.azure._types import ListToolType
|
|
from azure.ai.agents.models import (
|
|
AzureAISearchToolDefinition,
|
|
AzureFunctionToolDefinition,
|
|
BingGroundingToolDefinition,
|
|
CodeInterpreterToolDefinition,
|
|
FilePurpose,
|
|
FileSearchToolDefinition,
|
|
FileState,
|
|
RequiredAction,
|
|
RunStatus,
|
|
SubmitToolOutputsAction,
|
|
ThreadMessage,
|
|
)
|
|
from azure.ai.projects.aio import AIProjectClient
|
|
|
|
|
|
class FakeText:
|
|
def __init__(self, value: str) -> None:
|
|
self.value = value
|
|
|
|
|
|
class FakeTextContent:
|
|
def __init__(self, text: str) -> None:
|
|
self.type = "text"
|
|
self.text = FakeText(text)
|
|
|
|
|
|
class FakeMessage:
|
|
def __init__(self, id: str, text: str) -> None:
|
|
self.id = id
|
|
# The agent expects content to be a list of objects with a "type" attribute.
|
|
self.content = [FakeTextContent(text)]
|
|
self.role = "user"
|
|
|
|
@property
|
|
def text_messages(self) -> List[FakeTextContent]:
|
|
"""Returns all text message contents in the messages.
|
|
|
|
:rtype: List[FakeTextContent]
|
|
"""
|
|
if not self.content:
|
|
return []
|
|
return [content for content in self.content if isinstance(content, FakeTextContent)]
|
|
|
|
|
|
class FakeMessageUrlCitationDetails:
|
|
def __init__(self, url: str, title: str) -> None:
|
|
self.url = url
|
|
self.title = title
|
|
|
|
|
|
class FakeTextUrlCitationAnnotation:
|
|
def __init__(self, citation_details: FakeMessageUrlCitationDetails, text: str) -> None:
|
|
self.type = "url_citation"
|
|
self.url_citation = citation_details
|
|
self.text = text
|
|
|
|
|
|
class FakeTextFileCitationDetails:
|
|
def __init__(self, file_id: str, quote: str) -> None:
|
|
self.file_id = file_id
|
|
self.quote = quote
|
|
|
|
|
|
class FakeTextFileCitationAnnotation:
|
|
def __init__(self, citation_details: FakeTextFileCitationDetails) -> None:
|
|
self.type = "file_citation"
|
|
self.file_citation = citation_details
|
|
|
|
|
|
class FakeMessageWithUrlCitationAnnotation:
|
|
def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
|
|
self.id = id
|
|
# The agent expects content to be a list of objects with a "type" attribute.
|
|
self.content = [FakeTextContent(text)]
|
|
self.role = "user"
|
|
self._annotations = annotations
|
|
|
|
@property
|
|
def text_messages(self) -> List[FakeTextContent]:
|
|
"""Returns all text message contents in the messages.
|
|
|
|
:rtype: List[FakeTextContent]
|
|
"""
|
|
if not self.content:
|
|
return []
|
|
return [content for content in self.content if isinstance(content, FakeTextContent)]
|
|
|
|
@property
|
|
def url_citation_annotations(self) -> List[FakeTextUrlCitationAnnotation]:
|
|
"""Returns all URL citation annotations from text message annotations in the messages.
|
|
|
|
:rtype: List[FakeTextUrlCitationAnnotation]
|
|
"""
|
|
return self._annotations
|
|
|
|
|
|
class FakeMessageWithFileCitationAnnotation:
|
|
def __init__(self, id: str, text: str, annotations: list[FakeTextFileCitationAnnotation]) -> None:
|
|
self.id = id
|
|
# The agent expects content to be a list of objects with a "type" attribute.
|
|
self.content = [FakeTextContent(text)]
|
|
self.role = "user"
|
|
self._annotations = annotations
|
|
|
|
@property
|
|
def text_messages(self) -> List[FakeTextContent]:
|
|
"""Returns all text message contents in the messages.
|
|
|
|
:rtype: List[FakeTextContent]
|
|
"""
|
|
if not self.content:
|
|
return []
|
|
return [content for content in self.content if isinstance(content, FakeTextContent)]
|
|
|
|
@property
|
|
def file_citation_annotations(self) -> List[FakeTextFileCitationAnnotation]:
|
|
"""Returns all URL citation annotations from text message annotations in the messages.
|
|
|
|
:rtype: List[FakeTextFileCitationAnnotation]
|
|
"""
|
|
return self._annotations
|
|
|
|
|
|
class FakeMessageWithAnnotation:
|
|
def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
|
|
self.id = id
|
|
# The agent expects content to be a list of objects with a "type" attribute.
|
|
self.content = [FakeTextContent(text)]
|
|
self.role = "user"
|
|
self.annotations = annotations
|
|
|
|
@property
|
|
def text_messages(self) -> List[FakeTextContent]:
|
|
"""Returns all text message contents in the messages.
|
|
|
|
:rtype: List[FakeTextContent]
|
|
"""
|
|
if not self.content:
|
|
return []
|
|
return [content for content in self.content if isinstance(content, FakeTextContent)]
|
|
|
|
|
|
FakeMessageType = Union[
|
|
ThreadMessage
|
|
| FakeMessage
|
|
| FakeMessageWithAnnotation
|
|
| FakeMessageWithUrlCitationAnnotation
|
|
| FakeMessageWithFileCitationAnnotation
|
|
]
|
|
|
|
|
|
async def mock_messages_list(**kwargs: Any) -> AsyncGenerator[FakeMessage, None]:
|
|
"""Mock async generator for messages.list()"""
|
|
messages = [FakeMessage("msg-mock", "response")]
|
|
for message in messages:
|
|
yield message
|
|
|
|
|
|
async def mock_messages_list_empty(**kwargs: Any) -> AsyncGenerator[FakeMessage, None]:
|
|
"""Mock async generator that yields no messages"""
|
|
# This generator yields nothing, simulating an empty message list
|
|
return
|
|
yield # This line is never reached but makes this a generator
|
|
|
|
|
|
async def mock_messages_list_multiple(**kwargs: Any) -> AsyncGenerator[FakeMessage, None]:
|
|
"""Mock async generator for multiple messages (pagination test)"""
|
|
messages = [
|
|
FakeMessage("msg-mock-1", "response-1"),
|
|
FakeMessage("msg-mock-2", "response-2"),
|
|
]
|
|
for message in messages:
|
|
yield message
|
|
|
|
|
|
def create_agent(
|
|
mock_project_client: MagicMock,
|
|
tools: Optional[ListToolType] = None,
|
|
agent_name: str = "test_agent",
|
|
description: str = "Test Azure AI Agent",
|
|
instructions: str = "Test instructions",
|
|
agent_id: Optional[str] = None,
|
|
thread_id: Optional[str] = None,
|
|
) -> AzureAIAgent:
|
|
return AzureAIAgent(
|
|
name=agent_name,
|
|
description=description,
|
|
project_client=mock_project_client,
|
|
deployment_name="test_model",
|
|
tools=tools,
|
|
instructions=instructions,
|
|
agent_id=agent_id,
|
|
thread_id=thread_id,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_project_client() -> MagicMock:
|
|
client = MagicMock(spec=AIProjectClient)
|
|
|
|
# Create separate operation groups to match the actual SDK structure
|
|
client.agents = MagicMock()
|
|
client.runs = MagicMock()
|
|
client.messages = MagicMock()
|
|
client.threads = MagicMock()
|
|
client.files = MagicMock()
|
|
client.vector_stores = MagicMock()
|
|
client.vector_store_files = MagicMock()
|
|
client.vector_store_file_batches = MagicMock()
|
|
|
|
# Agent operations
|
|
client.agents.create_agent = AsyncMock(return_value=MagicMock(id="assistant-mock"))
|
|
client.agents.get_agent = AsyncMock(return_value=MagicMock(id="assistant-mock"))
|
|
client.agents.update_agent = AsyncMock()
|
|
client.agents.delete_agent = AsyncMock()
|
|
|
|
agent_run = MagicMock()
|
|
agent_run.id = "run-mock"
|
|
agent_run.status = RunStatus.COMPLETED
|
|
|
|
client.agents.runs = MagicMock()
|
|
client.agents.runs.create = AsyncMock(return_value=agent_run)
|
|
client.agents.runs.get = AsyncMock(return_value=agent_run)
|
|
client.agents.runs.submit_tool_outputs = AsyncMock(return_value=agent_run)
|
|
|
|
client.agents.messages = MagicMock()
|
|
client.agents.messages.list = mock_messages_list
|
|
client.agents.messages.create = AsyncMock()
|
|
|
|
client.agents.threads = MagicMock()
|
|
client.agents.threads.get = AsyncMock(return_value=MagicMock(id="thread-mock"))
|
|
client.agents.threads.create = AsyncMock(return_value=MagicMock(id="thread-mock"))
|
|
client.agents.threads.update = AsyncMock()
|
|
|
|
client.agents.files = MagicMock()
|
|
client.agents.files.upload_and_poll = AsyncMock(return_value=MagicMock(id="file-mock", status=FileState.PROCESSED))
|
|
|
|
client.agents.vector_stores = MagicMock()
|
|
client.agents.vector_stores.create_and_poll = AsyncMock(return_value=MagicMock(id="vector_store_id"))
|
|
client.agents.vector_store_file_batches = MagicMock()
|
|
client.agents.vector_store_file_batches.create_and_poll = AsyncMock()
|
|
|
|
return client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_azure_ai_agent_initialization(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client, ["file_search"])
|
|
|
|
assert agent.name == "test_agent"
|
|
assert agent.description == "Test Azure AI Agent"
|
|
assert agent.deployment_name == "test_model"
|
|
assert agent.instructions == "Test instructions"
|
|
assert len(agent.tools) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_messages(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
response = await agent.on_messages(messages)
|
|
|
|
assert response is not None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_reset(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
await agent.on_reset(CancellationToken())
|
|
|
|
# The agent might call create_thread multiple times during initialization, so check if it was called at least once
|
|
assert mock_project_client.agents.threads.create.call_count > 0
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_save_and_load_state(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client, agent_id="agent-mock", thread_id="thread-mock")
|
|
|
|
state = await agent.save_state()
|
|
assert state is not None
|
|
|
|
await agent.load_state(state)
|
|
|
|
assert agent.agent_id == state["agent_id"]
|
|
# assert agent._init_thread_id == state["thread_id"]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_upload_for_code_interpreter(mock_project_client: MagicMock) -> None:
|
|
file_mock = MagicMock()
|
|
file_mock.id = "file-mock"
|
|
file_mock.status = FileState.PROCESSED
|
|
|
|
thread_mock = MagicMock()
|
|
thread_mock.tool_resources = MagicMock()
|
|
thread_mock.tool_resources.code_interpreter = MagicMock()
|
|
thread_mock.tool_resources.code_interpreter.file_ids = [] # Set as a valid list
|
|
|
|
mock_project_client.agents.files.upload_and_poll = AsyncMock(return_value=file_mock)
|
|
mock_project_client.agents.threads.get = AsyncMock(return_value=thread_mock)
|
|
mock_project_client.agents.threads.update = AsyncMock()
|
|
|
|
agent = create_agent(
|
|
mock_project_client,
|
|
)
|
|
|
|
file_paths = ["test_file_1.txt", "test_file_2.txt"]
|
|
await agent.on_upload_for_code_interpreter(file_paths)
|
|
|
|
mock_project_client.agents.files.upload_and_poll.assert_called()
|
|
mock_project_client.agents.threads.get.assert_called_once()
|
|
mock_project_client.agents.threads.update.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_upload_for_file_search(mock_project_client: MagicMock) -> None:
|
|
file_mock = MagicMock()
|
|
file_mock.id = "file-mock"
|
|
file_mock.status = FileState.PROCESSED # Set a valid status
|
|
|
|
mock_project_client.agents.files.upload_and_poll = AsyncMock(return_value=file_mock)
|
|
mock_project_client.agents.vector_stores.create_and_poll = AsyncMock(return_value=MagicMock(id="vector_store_id"))
|
|
mock_project_client.agents.update_agent = AsyncMock()
|
|
mock_project_client.agents.vector_store_file_batches.create_and_poll = AsyncMock()
|
|
|
|
agent = create_agent(mock_project_client, tools=["file_search"])
|
|
|
|
file_paths = ["test_file_1.txt", "test_file_2.txt"]
|
|
await agent.on_upload_for_file_search(file_paths, cancellation_token=CancellationToken())
|
|
|
|
mock_project_client.agents.files.upload_and_poll.assert_called()
|
|
mock_project_client.agents.vector_stores.create_and_poll.assert_called_once()
|
|
mock_project_client.agents.update_agent.assert_called_once()
|
|
mock_project_client.agents.vector_store_file_batches.create_and_poll.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_upload_files(mock_project_client: MagicMock) -> None:
|
|
mock_project_client.agents.vector_store_file_batches.create_and_poll = AsyncMock()
|
|
|
|
mock_project_client.agents.update_agent = AsyncMock()
|
|
mock_project_client.agents.vector_stores.create_and_poll = AsyncMock(return_value=MagicMock(id="vector_store_id"))
|
|
|
|
mock_project_client.agents.files.upload_and_poll = AsyncMock(
|
|
return_value=MagicMock(id="file-id", status=FileState.PROCESSED)
|
|
)
|
|
|
|
agent = create_agent(mock_project_client, tools=["file_search"])
|
|
|
|
await agent.on_upload_for_file_search(["test_file.txt"], cancellation_token=CancellationToken())
|
|
|
|
mock_project_client.agents.files.upload_and_poll.assert_any_await(
|
|
file_path="test_file.txt", purpose=FilePurpose.AGENTS, polling_interval=0.5
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_messages_stream(mock_project_client: MagicMock) -> None:
|
|
mock_project_client.agents.runs.create = AsyncMock( # Corrected path
|
|
return_value=MagicMock(id="run-id", status=RunStatus.COMPLETED)
|
|
)
|
|
mock_project_client.agents.messages.list = mock_messages_list # Corrected path
|
|
|
|
agent = create_agent(mock_project_client)
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
async for response in agent.on_messages_stream(messages):
|
|
assert isinstance(response, Response)
|
|
assert response.chat_message.to_model_message().content == "response"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_messages_stream_with_tool(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client, tools=["file_search"])
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
async for response in agent.on_messages_stream(messages):
|
|
assert isinstance(response, Response)
|
|
assert response.chat_message.to_model_message().content == "response"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thread_id_validation(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
with pytest.raises(ValueError, match="Thread not"):
|
|
_ = agent.thread_id # Using _ for intentionally unused variable
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_agent_id_validation(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
with pytest.raises(ValueError, match="Agent not"):
|
|
_ = agent.agent_id # Using _ for intentionally unused variable
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"tool_name, should_raise_error",
|
|
[
|
|
("file_search", False),
|
|
("code_interpreter", False),
|
|
("bing_grounding", False),
|
|
("azure_function", False),
|
|
("azure_ai_search", False),
|
|
# ("sharepoint_grounding", False),
|
|
("unknown_tool", True),
|
|
],
|
|
)
|
|
async def test_adding_tools_as_literals(
|
|
mock_project_client: MagicMock, tool_name: Any, should_raise_error: bool
|
|
) -> None:
|
|
if should_raise_error:
|
|
with pytest.raises(ValueError, match=tool_name):
|
|
agent = create_agent(mock_project_client, tools=[tool_name]) # mypy ignore
|
|
else:
|
|
agent = create_agent(mock_project_client, tools=[tool_name])
|
|
assert agent.tools[0].type == tool_name
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"tool_definition",
|
|
[
|
|
FileSearchToolDefinition(),
|
|
CodeInterpreterToolDefinition(),
|
|
BingGroundingToolDefinition(), # type: ignore
|
|
AzureFunctionToolDefinition(), # type: ignore
|
|
AzureAISearchToolDefinition(),
|
|
# SharepointToolDefinition(), # type: ignore
|
|
],
|
|
)
|
|
async def test_adding_tools_as_typed_definition(mock_project_client: MagicMock, tool_definition: Any) -> None:
|
|
agent = create_agent(mock_project_client, tools=[tool_definition])
|
|
|
|
assert len(agent.tools) == 1
|
|
assert agent.tools[0].type == tool_definition.type
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_adding_callable_func_as_tool(mock_project_client: MagicMock) -> None:
|
|
def mock_tool_func() -> None:
|
|
"""Mock tool function."""
|
|
pass
|
|
|
|
agent = create_agent(mock_project_client, tools=[mock_tool_func])
|
|
assert len(agent.tools) == 1
|
|
|
|
assert agent.tools[0].type == "function"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_adding_core_autogen_tool(mock_project_client: MagicMock) -> None:
|
|
def mock_tool_func() -> None:
|
|
"""Mock tool function."""
|
|
pass
|
|
|
|
tool = FunctionTool(
|
|
func=mock_tool_func,
|
|
name="mock_tool",
|
|
description="Mock tool function",
|
|
)
|
|
|
|
agent = create_agent(mock_project_client, tools=[tool])
|
|
|
|
assert len(agent.tools) == 1
|
|
assert agent.tools[0].type == "function"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_adding_core_autogen_tool_without_doc_string(mock_project_client: MagicMock) -> None:
|
|
def mock_tool_func() -> None:
|
|
pass
|
|
|
|
agent = create_agent(mock_project_client, tools=[mock_tool_func])
|
|
|
|
assert len(agent.tools) == 1
|
|
assert agent.tools[0].type == "function"
|
|
assert agent.tools[0].function.description == "" # type: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_adding_unsupported_tool(mock_project_client: MagicMock) -> None:
|
|
tool_name: Any = 5
|
|
|
|
with pytest.raises(ValueError, match="class 'int'"):
|
|
create_agent(mock_project_client, tools=[tool_name])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization_with_no_agent_id(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
await agent.on_messages([TextMessage(content="Hello", source="user")])
|
|
|
|
mock_project_client.agents.create_agent.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization_with_agent_id(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client, agent_id="agent-mock")
|
|
|
|
await agent.on_messages([TextMessage(content="Hello", source="user")])
|
|
|
|
mock_project_client.agents.get_agent.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization_with_no_thread_id(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
await agent.on_messages([TextMessage(content="Hello", source="user")])
|
|
|
|
mock_project_client.agents.threads.create.assert_awaited_once() # Corrected path
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization_with_thread_id(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client, thread_id="thread-mock")
|
|
|
|
await agent.on_messages([TextMessage(content="Hello", source="user")])
|
|
|
|
mock_project_client.agents.threads.get.assert_awaited_once() # Corrected path
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization_fetching_multiple_pages_of_thread_messages(mock_project_client: MagicMock) -> None:
|
|
mock_project_client.agents.threads.get = AsyncMock(return_value=MagicMock(id="thread-id")) # Corrected path
|
|
# Mock the list_messages method to return multiple messages
|
|
mock_project_client.agents.messages.list = mock_messages_list_multiple # Corrected path
|
|
|
|
agent = create_agent(mock_project_client, thread_id="thread-id")
|
|
|
|
def assert_messages(actual: list[str], expected: List[str]) -> None:
|
|
assert len(actual) == len(expected)
|
|
for i in range(len(actual)):
|
|
assert actual[i] in expected
|
|
|
|
try:
|
|
await agent.on_messages([TextMessage(content="Hello", source="user")])
|
|
|
|
state = await agent.save_state()
|
|
assert state is not None
|
|
assert len(state["initial_message_ids"]) == 2
|
|
assert_messages(state["initial_message_ids"], ["msg-mock-1", "msg-mock-2"])
|
|
except StopAsyncIteration:
|
|
# Handle the StopAsyncIteration exception to allow the test to continue
|
|
pass
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_messages_with_cancellation(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
# Create a cancellation token that's already cancelled
|
|
token = CancellationToken()
|
|
token.cancel()
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
|
|
with pytest.raises(CancelledError):
|
|
await agent.on_messages(messages, token)
|
|
|
|
|
|
def mock_run(action: str, run_id: str, required_action: Optional[RequiredAction] = None) -> MagicMock:
|
|
run = MagicMock()
|
|
run.id = run_id
|
|
run.status = action
|
|
run.required_action = required_action
|
|
return run
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"tool_name, registered_tools, error",
|
|
[
|
|
(
|
|
"function",
|
|
[
|
|
FunctionTool(
|
|
func=lambda: None,
|
|
name="mock_tool",
|
|
description="Mock tool function",
|
|
)
|
|
],
|
|
"is not available",
|
|
),
|
|
("function", None, "No tools"),
|
|
],
|
|
)
|
|
async def test_on_messages_return_required_action_with_no_tool_raise_error(
|
|
mock_project_client: MagicMock, tool_name: str, registered_tools: ListToolType, error: str
|
|
) -> None:
|
|
agent = create_agent(mock_project_client, tools=registered_tools)
|
|
|
|
complete_run = mock_run(RunStatus.COMPLETED, "run-mock")
|
|
mock_project_client.agents.runs.submit_tool_outputs = AsyncMock(return_value=complete_run) # Corrected path
|
|
|
|
required_action = SubmitToolOutputsAction(
|
|
submit_tool_outputs=SimpleNamespace( # type: ignore
|
|
tool_calls=[
|
|
SimpleNamespace(
|
|
type="function",
|
|
id="tool-mock",
|
|
name=tool_name,
|
|
function=SimpleNamespace(arguments={}, name="function"),
|
|
)
|
|
]
|
|
)
|
|
)
|
|
|
|
required_action.submit_tool_outputs = SimpleNamespace( # type: ignore
|
|
tool_calls=[
|
|
SimpleNamespace(
|
|
type="function", id="tool-mock", name=tool_name, function=SimpleNamespace(arguments={}, name="function")
|
|
)
|
|
]
|
|
) # mypy ignore
|
|
|
|
requires_action_run = mock_run(RunStatus.REQUIRES_ACTION, "run-mock", required_action)
|
|
mock_project_client.agents.runs.get = AsyncMock(side_effect=[requires_action_run, complete_run]) # Corrected path
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
|
|
response: Response = await agent.on_messages(messages)
|
|
|
|
# check why there are 2 inner messages
|
|
tool_call_events = [event for event in response.inner_messages if isinstance(event, ToolCallExecutionEvent)] # type: ignore
|
|
assert len(tool_call_events) == 1
|
|
|
|
event: ToolCallExecutionEvent = tool_call_events[0]
|
|
assert event.content[0].is_error is True
|
|
assert event.content[0].content.find(error) != -1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_message_raise_error_when_stream_return_nothing(mock_project_client: MagicMock) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
agent.on_messages_stream = MagicMock(name="on_messages_stream") # type: ignore
|
|
agent.on_messages_stream.__aiter__.return_value = []
|
|
|
|
with pytest.raises(AssertionError, match="have returned the final result"):
|
|
await agent.on_messages(messages)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"file_paths, file_status, should_raise_error",
|
|
[
|
|
(["file1.txt", "file2.txt"], FileState.PROCESSED, False),
|
|
(["file3.txt"], FileState.ERROR, True),
|
|
],
|
|
)
|
|
async def test_uploading_multiple_files(
|
|
mock_project_client: MagicMock, file_paths: list[str], file_status: FileState, should_raise_error: bool
|
|
) -> None:
|
|
agent = create_agent(mock_project_client)
|
|
|
|
file_mock = MagicMock(id="file-id", status=file_status)
|
|
mock_project_client.agents.threads.update = AsyncMock()
|
|
mock_project_client.agents.files.upload_and_poll = AsyncMock(return_value=file_mock)
|
|
|
|
async def upload_files() -> None:
|
|
await agent.on_upload_for_code_interpreter(
|
|
file_paths,
|
|
cancellation_token=CancellationToken(),
|
|
polling_interval=0.1,
|
|
)
|
|
|
|
if should_raise_error:
|
|
with pytest.raises(ValueError, match="upload failed with status"): # Changed from Exception to ValueError
|
|
await upload_files()
|
|
else:
|
|
await upload_files()
|
|
|
|
mock_project_client.agents.files.upload_and_poll.assert_has_calls(
|
|
[call(file_path=file_path, purpose=FilePurpose.AGENTS, polling_interval=0.1) for file_path in file_paths]
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"fake_message, url, title",
|
|
[
|
|
(
|
|
FakeMessageWithAnnotation(
|
|
"msg-mock-1",
|
|
"response-1",
|
|
[FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url1", "title1"), "text")],
|
|
),
|
|
"url1",
|
|
"title1",
|
|
),
|
|
(
|
|
FakeMessageWithUrlCitationAnnotation(
|
|
"msg-mock-2",
|
|
"response-2",
|
|
[FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url2", "title2"), "text")],
|
|
),
|
|
"url2",
|
|
"title2",
|
|
),
|
|
],
|
|
)
|
|
async def test_on_message_stream_mapping_url_citation(
|
|
mock_project_client: MagicMock,
|
|
fake_message: FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation,
|
|
url: str,
|
|
title: str,
|
|
) -> None:
|
|
mock_project_client.agents.runs.create = AsyncMock( # Corrected path and method name
|
|
return_value=MagicMock(id="run-id", status=RunStatus.COMPLETED)
|
|
)
|
|
|
|
async def mock_messages_list_with_citation(
|
|
**kwargs: Any,
|
|
) -> AsyncGenerator[FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation, None]:
|
|
"""Mock async generator for messages with citation"""
|
|
yield fake_message
|
|
|
|
mock_project_client.agents.messages.list = mock_messages_list_with_citation
|
|
|
|
agent = create_agent(mock_project_client)
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
|
|
async for response in agent.on_messages_stream(messages):
|
|
assert isinstance(response, Response)
|
|
assert response.chat_message is not None
|
|
assert response.chat_message.metadata is not None
|
|
|
|
citations = json.loads(response.chat_message.metadata["citations"])
|
|
assert citations is not None
|
|
|
|
assert len(citations) == 1
|
|
|
|
assert citations[0]["url"] == url
|
|
assert citations[0]["title"] == title
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_message_stream_mapping_file_citation(mock_project_client: MagicMock) -> None:
|
|
mock_project_client.agents.create_run = AsyncMock(return_value=MagicMock(id="run-id", status=RunStatus.COMPLETED))
|
|
|
|
expected_file_id = "file_id_1"
|
|
expected_quote = "this part of a file"
|
|
|
|
fake_message = FakeMessageWithFileCitationAnnotation(
|
|
"msg-mock-1",
|
|
"response-1",
|
|
[FakeTextFileCitationAnnotation(FakeTextFileCitationDetails(expected_file_id, expected_quote))],
|
|
)
|
|
|
|
async def mock_messages_list_with_file_citation(
|
|
**kwargs: Any,
|
|
) -> AsyncGenerator[FakeMessageWithFileCitationAnnotation, None]:
|
|
"""Mock async generator for messages with file citation"""
|
|
yield fake_message
|
|
|
|
mock_project_client.agents.messages.list = mock_messages_list_with_file_citation
|
|
|
|
agent = create_agent(mock_project_client)
|
|
|
|
messages = [TextMessage(content="Hello", source="user")]
|
|
|
|
async for response in agent.on_messages_stream(messages):
|
|
assert isinstance(response, Response)
|
|
assert response.chat_message is not None
|
|
assert response.chat_message.metadata is not None
|
|
|
|
citations = json.loads(response.chat_message.metadata["citations"])
|
|
assert citations is not None
|
|
|
|
assert len(citations) == 1
|
|
|
|
assert citations[0]["file_id"] == expected_file_id
|
|
assert citations[0]["text"] == expected_quote
|