"""Unit tests for ``AzureAIAgent`` driven by a mocked ``AIProjectClient``.

The ``Fake*`` classes below are lightweight stand-ins for the Azure AI Projects
message / annotation models; they expose only the attributes these tests set up.
"""

import json
from asyncio import CancelledError
from types import SimpleNamespace
from typing import Any, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock, call

import azure.ai.projects.models as models
import pytest
from autogen_agentchat.base._chat_agent import Response
from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent
from autogen_core._cancellation_token import CancellationToken
from autogen_core.tools._function_tool import FunctionTool
from autogen_ext.agents.azure._azure_ai_agent import AzureAIAgent
from autogen_ext.agents.azure._types import ListToolType
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import ThreadMessage


class FakeText:
    """Wraps a plain string in a ``.value`` attribute."""

    def __init__(self, value: str) -> None:
        self.value = value


class FakeTextContent:
    """A text content part: ``type == "text"`` and a ``FakeText`` in ``.text``."""

    def __init__(self, text: str) -> None:
        self.type = "text"
        self.text = FakeText(text)


class FakeMessage:
    """Minimal fake thread message holding one text content part, with role ``"user"``.

    Also the base class for the annotation-carrying fake messages below, which only
    add their annotation attributes on top of this.
    """

    def __init__(self, id: str, text: str) -> None:
        self.id = id
        # The agent expects content to be a list of objects with a "type" attribute.
        self.content = [FakeTextContent(text)]
        self.role = "user"

    @property
    def text_messages(self) -> List[FakeTextContent]:
        """Returns all text message contents in the message.

        :rtype: List[FakeTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, FakeTextContent)]


class FakeMessageUrlCitationDetails:
    """URL citation details: the cited ``url`` and its ``title``."""

    def __init__(self, url: str, title: str) -> None:
        self.url = url
        self.title = title


class FakeTextUrlCitationAnnotation:
    """A ``"url_citation"`` annotation carrying its details and the annotated ``text``."""

    def __init__(self, citation_details: FakeMessageUrlCitationDetails, text: str) -> None:
        self.type = "url_citation"
        self.url_citation = citation_details
        self.text = text


class FakeTextFileCitationDetails:
    """File citation details: the cited ``file_id`` and the quoted passage."""

    def __init__(self, file_id: str, quote: str) -> None:
        self.file_id = file_id
        self.quote = quote


class FakeTextFileCitationAnnotation:
    """A ``"file_citation"`` annotation carrying its details."""

    def __init__(self, citation_details: FakeTextFileCitationDetails) -> None:
        self.type = "file_citation"
        self.file_citation = citation_details


class FakeMessageWithUrlCitationAnnotation(FakeMessage):
    """Fake message exposing URL citations through a ``url_citation_annotations`` property."""

    def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
        super().__init__(id, text)
        self._annotations = annotations

    @property
    def url_citation_annotations(self) -> List[FakeTextUrlCitationAnnotation]:
        """Returns all URL citation annotations from text message annotations in the message.

        :rtype: List[FakeTextUrlCitationAnnotation]
        """
        return self._annotations


class FakeMessageWithFileCitationAnnotation(FakeMessage):
    """Fake message exposing file citations through a ``file_citation_annotations`` property."""

    def __init__(self, id: str, text: str, annotations: list[FakeTextFileCitationAnnotation]) -> None:
        super().__init__(id, text)
        self._annotations = annotations

    @property
    def file_citation_annotations(self) -> List[FakeTextFileCitationAnnotation]:
        """Returns all file citation annotations from text message annotations in the message.

        :rtype: List[FakeTextFileCitationAnnotation]
        """
        return self._annotations


class FakeMessageWithAnnotation(FakeMessage):
    """Fake message exposing URL citations through a plain ``annotations`` attribute."""

    def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
        super().__init__(id, text)
        self.annotations = annotations


# Any message object a fake message page may contain.
FakeMessageType = Union[
    ThreadMessage,
    FakeMessage,
    FakeMessageWithAnnotation,
    FakeMessageWithUrlCitationAnnotation,
    FakeMessageWithFileCitationAnnotation,
]


class FakeOpenAIPageableListOfThreadMessage:
    """One page of thread messages, as returned by the mocked ``list_messages``."""

    def __init__(
        self,
        data: List[FakeMessageType],
        has_more: bool = False,
    ) -> None:
        self.data = data
        self._has_more = has_more

    @property
    def has_more(self) -> bool:
        """Whether another page of messages follows this one."""
        return self._has_more

    @property
    def text_messages(self) -> List[ThreadMessage | FakeTextContent]:
        """Returns all text message contents across every message on this page.

        :rtype: List[FakeTextContent]
        """
        texts = [content for msg in self.data for content in msg.text_messages]  # type: ignore
        return texts  # type: ignore

    def get_last_message_by_role(self, role: models.MessageRole) -> Optional[FakeMessageType]:
        """Returns the last message from a sender in the specified role.

        Scans ``data`` front to back and returns the first match.
        NOTE(review): this presumably mirrors the SDK listing newest messages first — confirm.

        :param role: The role of the sender.
        :type role: MessageRole
        :return: The matching message, or ``None`` if no message has that role.
        """
        for msg in self.data:
            if msg.role == role:
                return msg
        return None


def mock_list(
    data: Optional[List[FakeMessageType]] = None,
    has_more: bool = False,
) -> FakeOpenAIPageableListOfThreadMessage:
    """Build a message page; ``None`` or empty ``data`` yields one ``"response"`` message."""
    if not data:
        data = [FakeMessage("msg-mock", "response")]
    return FakeOpenAIPageableListOfThreadMessage(data, has_more=has_more)


def create_agent(
    mock_project_client: MagicMock,
    tools: Optional[ListToolType] = None,
    agent_name: str = "test_agent",
    description: str = "Test Azure AI Agent",
    instructions: str = "Test instructions",
    agent_id: Optional[str] = None,
    thread_id: Optional[str] = None,
) -> AzureAIAgent:
    """Construct an ``AzureAIAgent`` against the mocked client with test defaults."""
    return AzureAIAgent(
        name=agent_name,
        description=description,
        project_client=mock_project_client,
        deployment_name="test_model",
        tools=tools,
        instructions=instructions,
        agent_id=agent_id,
        thread_id=thread_id,
    )


@pytest.fixture
def mock_project_client() -> MagicMock:
    """A mocked ``AIProjectClient`` whose agent/run/thread/message calls all succeed."""
    client = MagicMock(spec=AIProjectClient)
    client.agents = MagicMock()
    client.agents.create_agent = AsyncMock(return_value=MagicMock(id="assistant-mock"))
    client.agents.get_agent = AsyncMock(return_value=MagicMock(id="assistant-mock"))

    agent_run = MagicMock()
    agent_run.id = "run-mock"
    agent_run.status = "completed"
    client.agents.create_run = AsyncMock(return_value=agent_run)
    client.agents.get_run = AsyncMock(return_value=agent_run)

    client.agents.list_messages = AsyncMock(return_value=mock_list())
    client.agents.create_message = AsyncMock()
    client.agents.get_thread = AsyncMock(return_value=MagicMock(id="thread-mock"))
    client.agents.create_thread = AsyncMock(return_value=MagicMock(id="thread-mock"))
    return client


@pytest.mark.asyncio
async def test_azure_ai_agent_initialization(mock_project_client: MagicMock) -> None:
    """Constructor arguments are exposed on the agent and a literal tool is registered."""
    agent = create_agent(mock_project_client, ["file_search"])

    assert agent.name == "test_agent"
    assert agent.description == "Test Azure AI Agent"
    assert agent.deployment_name == "test_model"
    assert agent.instructions == "Test instructions"
    assert len(agent.tools) == 1


@pytest.mark.asyncio
async def test_on_messages(mock_project_client: MagicMock) -> None:
    """``on_messages`` returns a response for a single user message."""
    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]
    response = await agent.on_messages(messages)

    assert response is not None


@pytest.mark.asyncio
async def test_on_reset(mock_project_client: MagicMock) -> None:
    """Resetting the agent creates a new thread exactly once."""
    agent = create_agent(mock_project_client)

    await agent.on_reset(CancellationToken())

    mock_project_client.agents.create_thread.assert_called_once()


@pytest.mark.asyncio
async def test_save_and_load_state(mock_project_client: MagicMock) -> None:
    """State saved by the agent can be loaded back and keeps the agent id."""
    agent = create_agent(mock_project_client, agent_id="agent-mock", thread_id="thread-mock")

    state = await agent.save_state()
    assert state is not None

    await agent.load_state(state)

    assert agent.agent_id == state["agent_id"]
    # NOTE(review): the thread id round-trip is not asserted (it was previously checked via
    # the private ``_init_thread_id``) — consider asserting it through a public accessor.


@pytest.mark.asyncio
async def test_on_upload_for_code_interpreter(mock_project_client: MagicMock) -> None:
    """Uploading files for the code interpreter uploads them and updates the thread once."""
    file_mock = AsyncMock()
    file_mock.id = "file-mock"
    file_mock.status = "processed"

    thread_mock = AsyncMock()
    thread_mock.tool_resources = AsyncMock()
    thread_mock.tool_resources.code_interpreter = AsyncMock()
    thread_mock.tool_resources.code_interpreter.file_ids = []  # Set as a valid list

    mock_project_client.agents.upload_file_and_poll = AsyncMock(return_value=file_mock)
    mock_project_client.agents.get_thread = AsyncMock(return_value=thread_mock)
    mock_project_client.agents.update_thread = AsyncMock()

    agent = create_agent(mock_project_client)

    file_paths = ["test_file_1.txt", "test_file_2.txt"]
    await agent.on_upload_for_code_interpreter(file_paths)

    mock_project_client.agents.upload_file_and_poll.assert_called()
    mock_project_client.agents.get_thread.assert_called_once()
    mock_project_client.agents.update_thread.assert_called_once()


@pytest.mark.asyncio
async def test_on_upload_for_file_search(mock_project_client: MagicMock) -> None:
    """Uploading for file search creates a vector store, batches the files, and updates the agent."""
    file_mock = AsyncMock()
    file_mock.id = "file-mock"
    file_mock.status = "processed"  # Set a valid status

    mock_project_client.agents.upload_file_and_poll = AsyncMock(return_value=file_mock)
    mock_project_client.agents.create_vector_store_and_poll = AsyncMock(return_value=AsyncMock(id="vector_store_id"))
    mock_project_client.agents.update_agent = AsyncMock()
    mock_project_client.agents.create_vector_store_file_batch_and_poll = AsyncMock()

    agent = create_agent(mock_project_client, tools=["file_search"])

    file_paths = ["test_file_1.txt", "test_file_2.txt"]
    await agent.on_upload_for_file_search(file_paths, cancellation_token=CancellationToken())

    mock_project_client.agents.upload_file_and_poll.assert_called()
    mock_project_client.agents.create_vector_store_and_poll.assert_called_once()
    mock_project_client.agents.update_agent.assert_called_once()
    mock_project_client.agents.create_vector_store_file_batch_and_poll.assert_called_once()


@pytest.mark.asyncio
async def test_upload_files(mock_project_client: MagicMock) -> None:
    """File-search uploads call ``upload_file_and_poll`` with the agents purpose and default interval."""
    mock_project_client.agents.create_vector_store_file_batch_and_poll = AsyncMock()
    mock_project_client.agents.update_agent = AsyncMock()
    mock_project_client.agents.create_vector_store_and_poll = AsyncMock(return_value=AsyncMock(id="vector_store_id"))
    mock_project_client.agents.upload_file_and_poll = AsyncMock(
        return_value=AsyncMock(id="file-id", status=models.FileState.PROCESSED)
    )

    agent = create_agent(mock_project_client, tools=["file_search"])

    await agent.on_upload_for_file_search(["test_file.txt"], cancellation_token=CancellationToken())

    mock_project_client.agents.upload_file_and_poll.assert_any_await(
        file_path="test_file.txt", purpose=models.FilePurpose.AGENTS, sleep_interval=0.5
    )


@pytest.mark.asyncio
async def test_on_messages_stream(mock_project_client: MagicMock) -> None:
    """Every streamed item is a ``Response`` whose chat message carries the mocked text."""
    mock_project_client.agents.create_run = AsyncMock(
        return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
    )
    mock_project_client.agents.list_messages = AsyncMock(return_value=mock_list())

    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]
    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message.to_model_message().content == "response"


@pytest.mark.asyncio
async def test_on_messages_stream_with_tool(mock_project_client: MagicMock) -> None:
    """Streaming with a registered tool still yields ``Response`` items with the mocked text."""
    agent = create_agent(mock_project_client, tools=["file_search"])

    messages = [TextMessage(content="Hello", source="user")]
    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message.to_model_message().content == "response"


@pytest.mark.asyncio
async def test_thread_id_validation(mock_project_client: MagicMock) -> None:
    """Reading ``thread_id`` before a thread exists raises ``ValueError``."""
    agent = create_agent(mock_project_client)

    with pytest.raises(ValueError, match="Thread not"):
        _ = agent.thread_id  # Using _ for intentionally unused variable


@pytest.mark.asyncio
async def test_get_agent_id_validation(mock_project_client: MagicMock) -> None:
    """Reading ``agent_id`` before an agent exists raises ``ValueError``."""
    agent = create_agent(mock_project_client)

    with pytest.raises(ValueError, match="Agent not"):
        _ = agent.agent_id  # Using _ for intentionally unused variable


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "tool_name, should_raise_error",
    [
        ("file_search", False),
        ("code_interpreter", False),
        ("bing_grounding", False),
        ("azure_function", False),
        ("azure_ai_search", False),
        ("sharepoint_grounding", False),
        ("unknown_tool", True),
    ],
)
async def test_adding_tools_as_literals(
    mock_project_client: MagicMock, tool_name: Any, should_raise_error: bool
) -> None:
    """Known tool literals are registered by type; unknown ones raise ``ValueError`` naming the tool."""
    if should_raise_error:
        with pytest.raises(ValueError, match=tool_name):
            create_agent(mock_project_client, tools=[tool_name])
    else:
        agent = create_agent(mock_project_client, tools=[tool_name])
        assert agent.tools[0].type == tool_name


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "tool_definition",
    [
        models.FileSearchToolDefinition(),
        models.CodeInterpreterToolDefinition(),
        models.BingGroundingToolDefinition(),  # type: ignore
        models.AzureFunctionToolDefinition(),  # type: ignore
        models.AzureAISearchToolDefinition(),
        models.SharepointToolDefinition(),  # type: ignore
    ],
)
async def test_adding_tools_as_typed_definition(mock_project_client: MagicMock, tool_definition: Any) -> None:
    """Typed tool definitions are registered as-is, keeping their ``type``."""
    agent = create_agent(mock_project_client, tools=[tool_definition])

    assert len(agent.tools) == 1
    assert agent.tools[0].type == tool_definition.type


@pytest.mark.asyncio
async def test_adding_callable_func_as_tool(mock_project_client: MagicMock) -> None:
    """A plain Python callable is registered as a ``"function"`` tool."""

    def mock_tool_func() -> None:
        """Mock tool function."""
        pass

    agent = create_agent(mock_project_client, tools=[mock_tool_func])

    assert len(agent.tools) == 1
    assert agent.tools[0].type == "function"


@pytest.mark.asyncio
async def test_adding_core_autogen_tool(mock_project_client: MagicMock) -> None:
    """An autogen-core ``FunctionTool`` is registered as a ``"function"`` tool."""

    def mock_tool_func() -> None:
        """Mock tool function."""
        pass

    tool = FunctionTool(
        func=mock_tool_func,
        name="mock_tool",
        description="Mock tool function",
    )

    agent = create_agent(mock_project_client, tools=[tool])

    assert len(agent.tools) == 1
    assert agent.tools[0].type == "function"


@pytest.mark.asyncio
async def test_adding_core_autogen_tool_without_doc_string(mock_project_client: MagicMock) -> None:
    """A callable without a docstring is registered with an empty description."""

    def mock_tool_func() -> None:
        pass

    agent = create_agent(mock_project_client, tools=[mock_tool_func])

    assert len(agent.tools) == 1
    assert agent.tools[0].type == "function"
    assert agent.tools[0].function.description == ""  # type: ignore


@pytest.mark.asyncio
async def test_adding_unsupported_tool(mock_project_client: MagicMock) -> None:
    """A tool of an unsupported type raises ``ValueError`` naming that type."""
    tool_name: Any = 5

    with pytest.raises(ValueError, match="class 'int'"):
        create_agent(mock_project_client, tools=[tool_name])


@pytest.mark.asyncio
async def test_agent_initialization_with_no_agent_id(mock_project_client: MagicMock) -> None:
    """Without an ``agent_id`` the first message creates a new agent."""
    agent = create_agent(mock_project_client)

    await agent.on_messages([TextMessage(content="Hello", source="user")])

    mock_project_client.agents.create_agent.assert_awaited_once()


@pytest.mark.asyncio
async def test_agent_initialization_with_agent_id(mock_project_client: MagicMock) -> None:
    """With an ``agent_id`` the first message fetches the existing agent."""
    agent = create_agent(mock_project_client, agent_id="agent-mock")

    await agent.on_messages([TextMessage(content="Hello", source="user")])

    mock_project_client.agents.get_agent.assert_awaited_once()


@pytest.mark.asyncio
async def test_agent_initialization_with_no_thread_id(mock_project_client: MagicMock) -> None:
    """Without a ``thread_id`` the first message creates a new thread."""
    agent = create_agent(mock_project_client)

    await agent.on_messages([TextMessage(content="Hello", source="user")])

    mock_project_client.agents.create_thread.assert_awaited_once()


@pytest.mark.asyncio
async def test_agent_initialization_with_thread_id(mock_project_client: MagicMock) -> None:
    """With a ``thread_id`` the first message fetches the existing thread."""
    agent = create_agent(mock_project_client, thread_id="thread-mock")

    await agent.on_messages([TextMessage(content="Hello", source="user")])

    mock_project_client.agents.get_thread.assert_awaited_once()


@pytest.mark.asyncio
async def test_agent_initialization_fetching_multiple_pages_of_thread_messages(mock_project_client: MagicMock) -> None:
    """Initial message ids are collected across every page of an existing thread."""
    list_messages = [
        FakeOpenAIPageableListOfThreadMessage([FakeMessage("msg-mock-1", "response-1")], has_more=True),
        FakeOpenAIPageableListOfThreadMessage([FakeMessage("msg-mock-2", "response-2")]),
        FakeOpenAIPageableListOfThreadMessage(
            [FakeMessage("msg-mock-1", "response-1"), FakeMessage("msg-mock-2", "response-2")]
        ),
    ]
    mock_project_client.agents.get_thread = AsyncMock(return_value=MagicMock(id="thread-id"))
    # Mock the list_messages method to return multiple pages of messages
    mock_project_client.agents.list_messages = AsyncMock(side_effect=list_messages)

    agent = create_agent(mock_project_client, thread_id="thread-id")

    def assert_messages(actual: list[str], expected: List[str]) -> None:
        """Assert ``actual`` has the same length as ``expected`` and every id is expected."""
        assert len(actual) == len(expected)
        for message_id in actual:
            assert message_id in expected

    try:
        await agent.on_messages([TextMessage(content="Hello", source="user")])
        state = await agent.save_state()
        assert state is not None
        assert len(state["initial_message_ids"]) == 2
        assert_messages(state["initial_message_ids"], ["msg-mock-1", "msg-mock-2"])
    except StopAsyncIteration:
        # NOTE(review): an exhausted AsyncMock side_effect raises StopAsyncIteration; swallowing it
        # lets the test pass without reaching the assertions above — confirm this is intended.
        pass


@pytest.mark.asyncio
async def test_on_messages_with_cancellation(mock_project_client: MagicMock) -> None:
    """An already-cancelled token makes ``on_messages`` raise ``CancelledError``."""
    agent = create_agent(mock_project_client)

    # Create a cancellation token that's already cancelled
    token = CancellationToken()
    token.cancel()

    messages = [TextMessage(content="Hello", source="user")]

    with pytest.raises(CancelledError):
        await agent.on_messages(messages, token)


def mock_run(action: str, run_id: str, required_action: Optional[models.RequiredAction] = None) -> MagicMock:
    """Build a mocked run with the given status, id, and optional required action."""
    run = MagicMock()
    run.id = run_id
    run.status = action
    run.required_action = required_action
    return run


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "tool_name, registered_tools, error",
    [
        (
            "function",
            [
                FunctionTool(
                    func=lambda: None,
                    name="mock_tool",
                    description="Mock tool function",
                )
            ],
            "is not available",
        ),
        ("function", None, "No tools"),
    ],
)
async def test_on_messages_return_required_action_with_no_tool_raise_error(
    mock_project_client: MagicMock, tool_name: str, registered_tools: ListToolType, error: str
) -> None:
    """A required tool call that cannot be served yields one errored ``ToolCallExecutionEvent``."""
    agent = create_agent(mock_project_client, tools=registered_tools)

    complete_run = mock_run("completed", "run-mock")
    mock_project_client.agents.submit_tool_outputs_to_run = AsyncMock(return_value=complete_run)

    required_action = models.SubmitToolOutputsAction()  # type: ignore
    required_action.submit_tool_outputs = SimpleNamespace(  # type: ignore
        tool_calls=[
            SimpleNamespace(
                type="function", id="tool-mock", name=tool_name, function=SimpleNamespace(arguments={}, name="function")
            )
        ]
    )

    requires_action_run = mock_run("requires_action", "run-mock", required_action)
    # First poll asks for a tool call, the next one reports completion.
    mock_project_client.agents.get_run = AsyncMock(side_effect=[requires_action_run, complete_run])

    messages = [TextMessage(content="Hello", source="user")]

    response: Response = await agent.on_messages(messages)

    # TODO: investigate why the response carries 2 inner messages; only the execution event is checked.
    tool_call_events = [event for event in response.inner_messages if isinstance(event, ToolCallExecutionEvent)]  # type: ignore
    assert len(tool_call_events) == 1
    event: ToolCallExecutionEvent = tool_call_events[0]
    assert event.content[0].is_error is True
    assert error in event.content[0].content


@pytest.mark.asyncio
async def test_on_message_raise_error_when_stream_return_nothing(mock_project_client: MagicMock) -> None:
    """``on_messages`` asserts when the underlying stream yields no final result."""
    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]
    agent.on_messages_stream = MagicMock(name="on_messages_stream")  # type: ignore
    agent.on_messages_stream.__aiter__.return_value = []

    with pytest.raises(AssertionError, match="have returned the final result"):
        await agent.on_messages(messages)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "file_paths, file_status, should_raise_error",
    [
        (["file1.txt", "file2.txt"], models.FileState.PROCESSED, False),
        (["file3.txt"], models.FileState.ERROR, True),
    ],
)
async def test_uploading_multiple_files(
    mock_project_client: MagicMock, file_paths: list[str], file_status: models.FileState, should_raise_error: bool
) -> None:
    """Every path is uploaded with the given interval; an errored upload raises."""
    agent = create_agent(mock_project_client)

    file_mock = AsyncMock(id="file-id", status=file_status)
    mock_project_client.agents.update_thread = AsyncMock()
    mock_project_client.agents.upload_file_and_poll = AsyncMock(return_value=file_mock)

    async def upload_files() -> None:
        await agent.on_upload_for_code_interpreter(
            file_paths,
            cancellation_token=CancellationToken(),
            sleep_interval=0.1,
        )

    if should_raise_error:
        with pytest.raises(Exception, match="upload failed with status"):
            await upload_files()
    else:
        await upload_files()

    mock_project_client.agents.upload_file_and_poll.assert_has_calls(
        [call(file_path=file_path, purpose=models.FilePurpose.AGENTS, sleep_interval=0.1) for file_path in file_paths]
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "fake_message, url, title",
    [
        (
            FakeMessageWithAnnotation(
                "msg-mock-1",
                "response-1",
                [FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url1", "title1"), "text")],
            ),
            "url1",
            "title1",
        ),
        (
            FakeMessageWithUrlCitationAnnotation(
                "msg-mock-2",
                "response-2",
                [FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url2", "title2"), "text")],
            ),
            "url2",
            "title2",
        ),
    ],
)
async def test_on_message_stream_mapping_url_citation(
    mock_project_client: MagicMock,
    fake_message: FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation,
    url: str,
    title: str,
) -> None:
    """URL citations on the reply are serialized as JSON into ``metadata["citations"]``."""
    mock_project_client.agents.create_run = AsyncMock(
        return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
    )

    message_page = mock_list([fake_message], has_more=False)
    mock_project_client.agents.list_messages = AsyncMock(return_value=message_page)

    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]

    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message is not None
        assert response.chat_message.metadata is not None

        citations = json.loads(response.chat_message.metadata["citations"])
        assert citations is not None

        assert len(citations) == 1

        assert citations[0]["url"] == url
        assert citations[0]["title"] == title


@pytest.mark.asyncio
async def test_on_message_stream_mapping_file_citation(mock_project_client: MagicMock) -> None:
    """File citations on the reply are serialized with ``file_id`` and the quote as ``text``."""
    mock_project_client.agents.create_run = AsyncMock(
        return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
    )

    expected_file_id = "file_id_1"
    expected_quote = "this part of a file"

    fake_message = FakeMessageWithFileCitationAnnotation(
        "msg-mock-1",
        "response-1",
        [FakeTextFileCitationAnnotation(FakeTextFileCitationDetails(expected_file_id, expected_quote))],
    )

    message_page = mock_list([fake_message], has_more=False)
    mock_project_client.agents.list_messages = AsyncMock(return_value=message_page)

    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]

    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message is not None
        assert response.chat_message.metadata is not None

        citations = json.loads(response.chat_message.metadata["citations"])
        assert citations is not None

        assert len(citations) == 1

        assert citations[0]["file_id"] == expected_file_id
        assert citations[0]["text"] == expected_quote