Bing grounding citations (#6370)

Adding support for Bing grounding citations to the AzureAIAgent.


## Why are these changes needed?

The Bing grounding tool in the Azure AI Agent Service returns URL citation annotations with its answers, but `AzureAIAgent` previously dropped them when building the response. This change extracts URL (and file search) citation annotations from the agent's final thread message and surfaces them to callers as a JSON-encoded list in the response message's `metadata["citations"]` field.
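
A minimal usage sketch, assuming an `AzureAIAgent` already configured with the Bing grounding tool (construction omitted; `print_bing_citations` is an illustrative helper, not part of this PR). Note that the prompt must explicitly ask for citations, otherwise the service may not return them:

```python
import json
from typing import Any

from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken


async def print_bing_citations(agent: Any) -> None:
    # `agent` is assumed to be an AzureAIAgent configured with the Bing
    # grounding tool; its construction is not shown in this sketch.
    result = await agent.on_messages(
        messages=[
            TextMessage(
                content="What is Microsoft's annual leave policy? Provide citations for your answers.",
                source="user",
            )
        ],
        cancellation_token=CancellationToken(),
    )
    # Citations, when present, are attached to the response message as a
    # JSON-encoded list under metadata["citations"].
    citations = json.loads(result.chat_message.metadata.get("citations", "[]"))
    for citation in citations:
        print(citation.get("url"), citation.get("title"))
```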

## Related issue number


## Checks

- [X] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [X] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Dheeraj Bandaru <BandaruDheeraj@users.noreply.github.com>
Abdo Talema authored on 2025-04-29 06:09:13 +10:00; committed by GitHub
parent 998840f7e0
commit 881cd6a75c
3 changed files with 312 additions and 80 deletions

View File

@@ -12,9 +12,10 @@ from typing import (
Optional,
Sequence,
Set,
cast,
)
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat import TRACE_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
@@ -38,7 +39,7 @@ from azure.ai.projects.aio import AIProjectClient
from ._types import AzureAIAgentState, ListToolType
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
class AzureAIAgent(BaseChatAgent):
@@ -111,8 +112,16 @@ class AzureAIAgent(BaseChatAgent):
metadata={"source": "AzureAIAgent"},
)
# For the Bing grounding tool to return citations, the message must contain an instruction for the model to return them.
# For example: "Please provide citations for the answers"
result = await agent_with_bing_grounding.on_messages(
messages=[TextMessage(content="What is Microsoft's annual leave policy?", source="user")],
messages=[
TextMessage(
content="What is Microsoft's annual leave policy? Provide citations for your answers.",
source="user",
)
],
cancellation_token=CancellationToken(),
message_limit=5,
)
@@ -575,7 +584,7 @@ class AzureAIAgent(BaseChatAgent):
if file.status != models.FileState.PROCESSED:
raise ValueError(f"File upload failed with status {file.status}")
event_logger.debug(f"File uploaded successfully: {file.id}, {file_name}")
trace_logger.debug(f"File uploaded successfully: {file.id}, {file_name}")
file_ids.append(file.id)
self._uploaded_file_ids.append(file.id)
@@ -644,11 +653,11 @@
Raises:
ValueError: If the run fails or no message is received from the assistant
"""
await self._ensure_initialized()
if cancellation_token is None:
cancellation_token = CancellationToken()
await self._ensure_initialized()
# Process all messages in sequence
for message in messages:
if isinstance(message, (TextMessage, MultiModalMessage)):
@@ -701,7 +710,7 @@
# Add tool call message to inner messages
tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
inner_messages.append(tool_call_msg)
event_logger.debug(tool_call_msg)
trace_logger.debug(tool_call_msg)
yield tool_call_msg
# Execute tool calls and get results
@@ -725,7 +734,7 @@
# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
inner_messages.append(tool_result_msg)
event_logger.debug(tool_result_msg)
trace_logger.debug(tool_result_msg)
yield tool_result_msg
# Submit tool outputs back to the run
@@ -748,15 +757,8 @@
# TODO support for parameter to control polling interval
await asyncio.sleep(sleep_interval)
# run_steps: models.OpenAIPageableListOfRunStep = await cancellation_token.link_future(
# asyncio.ensure_future(
# self._project_client.agents.list_run_steps(
# thread_id=self.thread_id,
# run_id=run.id,
# )
# )
# )
# Get messages after run completion
# After the run is completed, get the messages
trace_logger.debug("Retrieving messages from thread")
agent_messages: models.OpenAIPageableListOfThreadMessage = await cancellation_token.link_future(
asyncio.ensure_future(
self._project_client.agents.list_messages(
@@ -768,18 +770,67 @@
if not agent_messages.data:
raise ValueError("No messages received from assistant")
# Get the last message's content
last_message = agent_messages.data[0]
# Get the last message from the agent
last_message: Optional[models.ThreadMessage] = agent_messages.get_last_message_by_role(models.MessageRole.AGENT)
if not last_message:
trace_logger.debug("No message with AGENT role found, falling back to first message")
last_message = agent_messages.data[0] # Fallback to first message
if not last_message.content:
raise ValueError(f"No content in the last message: {last_message}")
raise ValueError("No content in the last message")
# Extract text content
text_content = agent_messages.text_messages
if not text_content:
raise ValueError(f"Expected text content in the last message: {last_message.content}")
message_text = ""
for text_message in last_message.text_messages:
message_text += text_message.text.value
# Extract citations
citations: list[Any] = []
# Try accessing annotations directly
annotations = getattr(last_message, "annotations", [])
if isinstance(annotations, list) and annotations:
annotations = cast(List[models.MessageTextUrlCitationAnnotation], annotations)
trace_logger.debug(f"Found {len(annotations)} annotations")
for annotation in annotations:
if hasattr(annotation, "url_citation"): # type: ignore
trace_logger.debug(f"Citation found: {annotation.url_citation.url}")
citations.append(
{"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None} # type: ignore
)
# For backwards compatibility
elif hasattr(last_message, "url_citation_annotations") and last_message.url_citation_annotations:
url_annotations = cast(List[Any], last_message.url_citation_annotations)
trace_logger.debug(f"Found {len(url_annotations)} URL citations")
for annotation in url_annotations:
citations.append(
{"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None} # type: ignore
)
elif hasattr(last_message, "file_citation_annotations") and last_message.file_citation_annotations:
file_annotations = cast(List[Any], last_message.file_citation_annotations)
trace_logger.debug(f"Found {len(file_annotations)} URL citations")
for annotation in file_annotations:
citations.append(
{"file_id": annotation.file_citation.file_id, "title": None, "text": annotation.file_citation.quote} # type: ignore
)
trace_logger.debug(f"Total citations extracted: {len(citations)}")
# Create the response message with citations as JSON string
chat_message = TextMessage(
source=self.name, content=message_text, metadata={"citations": json.dumps(citations)} if citations else {}
)
# Return the assistant's response as a Response with inner messages
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
yield Response(chat_message=chat_message, inner_messages=inner_messages)
async def handle_text_message(self, content: str, cancellation_token: Optional[CancellationToken] = None) -> None:
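
For reference, each entry serialized into `metadata["citations"]` by the code above takes one of two shapes: `{"url", "title", "text"}` for Bing grounding URL citations and `{"file_id", "title", "text"}` for file search citations. A minimal sketch of how a consumer might distinguish them (the `describe_citations` helper is illustrative, not part of this PR):

```python
import json

from autogen_agentchat.messages import TextMessage


def describe_citations(chat_message: TextMessage) -> None:
    # `chat_message` stands in for the TextMessage produced by AzureAIAgent above.
    for citation in json.loads(chat_message.metadata.get("citations", "[]")):
        if citation.get("url") is not None:
            # Bing grounding: {"url": ..., "title": ..., "text": None}
            print(f"web source: {citation['title']} ({citation['url']})")
        elif citation.get("file_id") is not None:
            # File search: {"file_id": ..., "title": None, "text": <quoted passage>}
            print(f"file source: {citation['file_id']}: {citation['text']}")
```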

View File

@@ -1,13 +1,4 @@
from typing import (
Any,
Awaitable,
Callable,
Iterable,
List,
Literal,
Optional,
Union,
)
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union
from autogen_core.tools import Tool
from pydantic import BaseModel, Field
@@ -59,3 +50,7 @@ class AzureAIAgentState(BaseModel):
initial_message_ids: List[str] = Field(default_factory=list)
vector_store_id: Optional[str] = None
uploaded_file_ids: List[str] = Field(default_factory=list)
def has_annotations(obj: Any) -> TypeGuard[list[models.MessageTextUrlCitationAnnotation]]:
return obj is not None and isinstance(obj, list)
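
The `has_annotations` helper added here is a `TypeGuard`: at runtime it only checks for a non-None list, while the return annotation tells type checkers to narrow the value to `List[models.MessageTextUrlCitationAnnotation]`. A minimal sketch of how such a guard could be used (the `collect_citation_urls` helper and its `message` argument are illustrative, not part of this PR):

```python
from typing import Any, List, TypeGuard

import azure.ai.projects.models as models


def has_annotations(obj: Any) -> TypeGuard[List[models.MessageTextUrlCitationAnnotation]]:
    # The runtime check only confirms a non-None list; the element type is
    # asserted to the type checker through the TypeGuard annotation.
    return obj is not None and isinstance(obj, list)


def collect_citation_urls(message: Any) -> List[str]:
    # `message` stands in for a ThreadMessage-like object with optional annotations.
    annotations = getattr(message, "annotations", None)
    if has_annotations(annotations):
        # Narrowed to List[MessageTextUrlCitationAnnotation] here.
        return [a.url_citation.url for a in annotations]
    return []
```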

View File

@@ -1,6 +1,7 @@
import json
from asyncio import CancelledError
from types import SimpleNamespace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock, call
import azure.ai.projects.models as models
@@ -37,15 +38,126 @@ class FakeMessage:
def text_messages(self) -> List[FakeTextContent]:
"""Returns all text message contents in the messages.
:rtype: List[MessageTextContent]
:rtype: List[FakeTextContent]
"""
if not self.content:
return []
return [content for content in self.content if isinstance(content, FakeTextContent)]
class FakeMessageUrlCitationDetails:
def __init__(self, url: str, title: str) -> None:
self.url = url
self.title = title
class FakeTextUrlCitationAnnotation:
def __init__(self, citation_details: FakeMessageUrlCitationDetails, text: str) -> None:
self.type = "url_citation"
self.url_citation = citation_details
self.text = text
class FakeTextFileCitationDetails:
def __init__(self, file_id: str, quote: str) -> None:
self.file_id = file_id
self.quote = quote
class FakeTextFileCitationAnnotation:
def __init__(self, citation_details: FakeTextFileCitationDetails) -> None:
self.type = "file_citation"
self.file_citation = citation_details
class FakeMessageWithUrlCitationAnnotation:
def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
self.id = id
# The agent expects content to be a list of objects with a "type" attribute.
self.content = [FakeTextContent(text)]
self.role = "user"
self._annotations = annotations
@property
def text_messages(self) -> List[FakeTextContent]:
"""Returns all text message contents in the messages.
:rtype: List[FakeTextContent]
"""
if not self.content:
return []
return [content for content in self.content if isinstance(content, FakeTextContent)]
@property
def url_citation_annotations(self) -> List[FakeTextUrlCitationAnnotation]:
"""Returns all URL citation annotations from text message annotations in the messages.
:rtype: List[FakeTextUrlCitationAnnotation]
"""
return self._annotations
class FakeMessageWithFileCitationAnnotation:
def __init__(self, id: str, text: str, annotations: list[FakeTextFileCitationAnnotation]) -> None:
self.id = id
# The agent expects content to be a list of objects with a "type" attribute.
self.content = [FakeTextContent(text)]
self.role = "user"
self._annotations = annotations
@property
def text_messages(self) -> List[FakeTextContent]:
"""Returns all text message contents in the messages.
:rtype: List[FakeTextContent]
"""
if not self.content:
return []
return [content for content in self.content if isinstance(content, FakeTextContent)]
@property
def file_citation_annotations(self) -> List[FakeTextFileCitationAnnotation]:
"""Returns all URL citation annotations from text message annotations in the messages.
:rtype: List[FakeTextFileCitationAnnotation]
"""
return self._annotations
class FakeMessageWithAnnotation:
def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
self.id = id
# The agent expects content to be a list of objects with a "type" attribute.
self.content = [FakeTextContent(text)]
self.role = "user"
self.annotations = annotations
@property
def text_messages(self) -> List[FakeTextContent]:
"""Returns all text message contents in the messages.
:rtype: List[FakeTextContent]
"""
if not self.content:
return []
return [content for content in self.content if isinstance(content, FakeTextContent)]
FakeMessageType = Union[
ThreadMessage,
FakeMessage,
FakeMessageWithAnnotation,
FakeMessageWithUrlCitationAnnotation,
FakeMessageWithFileCitationAnnotation,
]
class FakeOpenAIPageableListOfThreadMessage:
def __init__(self, data: List[ThreadMessage | FakeMessage], has_more: bool = False) -> None:
def __init__(
self,
data: List[FakeMessageType],
has_more: bool = False,
) -> None:
self.data = data
self._has_more = has_more
@@ -62,9 +174,31 @@ class FakeOpenAIPageableListOfThreadMessage:
texts = [content for msg in self.data for content in msg.text_messages] # type: ignore
return texts # type: ignore
def get_last_message_by_role(
self, role: models.MessageRole
) -> Optional[ThreadMessage | FakeMessage | FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation]:
"""Returns the last message from a sender in the specified role.
def mock_list() -> FakeOpenAIPageableListOfThreadMessage:
return FakeOpenAIPageableListOfThreadMessage([FakeMessage("msg-mock", "response")])
:param role: The role of the sender.
:type role: MessageRole
:return: The last message from a sender in the specified role.
:rtype: ~azure.ai.projects.models.ThreadMessage
"""
for msg in self.data:
if msg.role == role:
return msg # type: ignore
return None
def mock_list(
data: Optional[List[FakeMessageType]] = None,
has_more: bool = False,
) -> FakeOpenAIPageableListOfThreadMessage:
if data is None or len(data) == 0:
data = [FakeMessage("msg-mock", "response")]
return FakeOpenAIPageableListOfThreadMessage(data, has_more=has_more)
def create_agent(
@@ -207,47 +341,6 @@ async def test_on_upload_for_file_search(mock_project_client: MagicMock) -> None
mock_project_client.agents.create_vector_store_file_batch_and_poll.assert_called_once()
# @pytest.mark.asyncio
# async def test_add_tools(mock_project_client: MagicMock) -> None:
# agent = create_agent(mock_project_client)
# tools: Optional[ListToolType] = ["file_search", "code_interpreter"]
# converted_tools: List[ToolDefinition] = []
# agent._add_tools(tools, converted_tools)
# assert len(converted_tools) == 2
# assert isinstance(converted_tools[0], models.FileSearchToolDefinition)
# assert isinstance(converted_tools[1], models.CodeInterpreterToolDefinition)
# @pytest.mark.asyncio
# async def test_ensure_initialized(mock_project_client: MagicMock) -> None:
# agent = create_agent(mock_project_client)
# await agent._ensure_initialized(create_new_agent=True, create_new_thread=True)
# assert agent._agent is not None
# assert agent._thread is not None
# @pytest.mark.asyncio
# async def test_execute_tool_call(mock_project_client: MagicMock) -> None:
# mock_tool = MagicMock()
# mock_tool.name = "test_tool"
# mock_tool.run_json = AsyncMock(return_value={"result": "success"})
# mock_tool.return_value_as_string = MagicMock(return_value="success")
# agent = create_agent(mock_project_client)
# agent._original_tools = [mock_tool]
# tool_call = FunctionCall(id="test_tool", name="test_tool", arguments="{}")
# result = await agent._execute_tool_call(tool_call, CancellationToken())
# assert result == "success"
# mock_tool.run_json.assert_called_once()
@pytest.mark.asyncio
async def test_upload_files(mock_project_client: MagicMock) -> None:
mock_project_client.agents.create_vector_store_file_batch_and_poll = AsyncMock()
@@ -590,3 +683,96 @@ async def test_uploading_multiple_files(
mock_project_client.agents.upload_file_and_poll.assert_has_calls(
[call(file_path=file_path, purpose=models.FilePurpose.AGENTS, sleep_interval=0.1) for file_path in file_paths]
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"fake_message, url, title",
[
(
FakeMessageWithAnnotation(
"msg-mock-1",
"response-1",
[FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url1", "title1"), "text")],
),
"url1",
"title1",
),
(
FakeMessageWithUrlCitationAnnotation(
"msg-mock-2",
"response-2",
[FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url2", "title2"), "text")],
),
"url2",
"title2",
),
],
)
async def test_on_message_stream_mapping_url_citation(
mock_project_client: MagicMock,
fake_message: FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation,
url: str,
title: str,
) -> None:
mock_project_client.agents.create_run = AsyncMock(
return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
)
list = mock_list([fake_message], has_more=False)
mock_project_client.agents.list_messages = AsyncMock(return_value=list)
agent = create_agent(mock_project_client)
messages = [TextMessage(content="Hello", source="user")]
async for response in agent.on_messages_stream(messages):
assert isinstance(response, Response)
assert response.chat_message is not None
assert response.chat_message.metadata is not None
citations = json.loads(response.chat_message.metadata["citations"])
assert citations is not None
assert len(citations) == 1
assert citations[0]["url"] == url
assert citations[0]["title"] == title
@pytest.mark.asyncio
async def test_on_message_stream_mapping_file_citation(mock_project_client: MagicMock) -> None:
mock_project_client.agents.create_run = AsyncMock(
return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
)
expected_file_id = "file_id_1"
expected_quote = "this part of a file"
fake_message = FakeMessageWithFileCitationAnnotation(
"msg-mock-1",
"response-1",
[FakeTextFileCitationAnnotation(FakeTextFileCitationDetails(expected_file_id, expected_quote))],
)
list = mock_list([fake_message], has_more=False)
mock_project_client.agents.list_messages = AsyncMock(return_value=list)
agent = create_agent(mock_project_client)
messages = [TextMessage(content="Hello", source="user")]
async for response in agent.on_messages_stream(messages):
assert isinstance(response, Response)
assert response.chat_message is not None
assert response.chat_message.metadata is not None
citations = json.loads(response.chat_message.metadata["citations"])
assert citations is not None
assert len(citations) == 1
assert citations[0]["file_id"] == expected_file_id
assert citations[0]["text"] == expected_quote