Bing grounding citations (#6370)
Adding support for Bing grounding citations to the AzureAIAgent.

## Checks

- [X] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [X] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Dheeraj Bandaru <BandaruDheeraj@users.noreply.github.com>
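To make the new behavior concrete, here is a minimal consumer-side sketch (not part of the diff). It assumes an already-configured `agent_with_bing_grounding`, as in the updated docstring example further down, and only shows how a caller might read the JSON-encoded `citations` list that this change attaches to the response message's metadata; the helper name `print_bing_citations` is illustrative.

```python
import json

from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken


async def print_bing_citations(agent_with_bing_grounding) -> None:
    # The Bing grounding tool only returns citations when the prompt asks for them,
    # so the instruction is part of the message content.
    result = await agent_with_bing_grounding.on_messages(
        messages=[
            TextMessage(
                content="What is Microsoft's annual leave policy? Provide citations for your answers.",
                source="user",
            )
        ],
        cancellation_token=CancellationToken(),
    )

    # The agent stores extracted annotations as a JSON string under the "citations"
    # metadata key and leaves metadata empty when no annotations are present.
    citations = json.loads(result.chat_message.metadata.get("citations", "[]"))
    for citation in citations:
        # URL citations carry "url"/"title"; file citations carry "file_id"/"text".
        print(citation.get("url") or citation.get("file_id"), citation.get("title"))
```

The citation dictionary keys (`url`, `title`, `text`, `file_id`) follow the extraction logic and the new tests in this diff.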
Commit 881cd6a75c (parent 998840f7e0)
@@ -12,9 +12,10 @@ from typing import (
    Optional,
    Sequence,
    Set,
    cast,
)

from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat import TRACE_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
@@ -38,7 +39,7 @@ from azure.ai.projects.aio import AIProjectClient

from ._types import AzureAIAgentState, ListToolType

event_logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)


class AzureAIAgent(BaseChatAgent):
@@ -111,8 +112,16 @@ class AzureAIAgent(BaseChatAgent):
    metadata={"source": "AzureAIAgent"},
)

# For the Bing grounding tool to return citations, the message must contain an instruction for the model to return them.
# For example: "Please provide citations for the answers"

result = await agent_with_bing_grounding.on_messages(
    messages=[TextMessage(content="What is Microsoft's annual leave policy?", source="user")],
    messages=[
        TextMessage(
            content="What is Microsoft's annual leave policy? Provide citations for your answers.",
            source="user",
        )
    ],
    cancellation_token=CancellationToken(),
    message_limit=5,
)
@@ -575,7 +584,7 @@ class AzureAIAgent(BaseChatAgent):
if file.status != models.FileState.PROCESSED:
    raise ValueError(f"File upload failed with status {file.status}")

event_logger.debug(f"File uploaded successfully: {file.id}, {file_name}")
trace_logger.debug(f"File uploaded successfully: {file.id}, {file_name}")

file_ids.append(file.id)
self._uploaded_file_ids.append(file.id)
@@ -644,11 +653,11 @@ class AzureAIAgent(BaseChatAgent):
Raises:
    ValueError: If the run fails or no message is received from the assistant
"""
await self._ensure_initialized()

if cancellation_token is None:
    cancellation_token = CancellationToken()

await self._ensure_initialized()

# Process all messages in sequence
for message in messages:
    if isinstance(message, (TextMessage, MultiModalMessage)):
@@ -701,7 +710,7 @@ class AzureAIAgent(BaseChatAgent):
# Add tool call message to inner messages
tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
inner_messages.append(tool_call_msg)
event_logger.debug(tool_call_msg)
trace_logger.debug(tool_call_msg)
yield tool_call_msg

# Execute tool calls and get results
@@ -725,7 +734,7 @@ class AzureAIAgent(BaseChatAgent):
# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
inner_messages.append(tool_result_msg)
event_logger.debug(tool_result_msg)
trace_logger.debug(tool_result_msg)
yield tool_result_msg

# Submit tool outputs back to the run
@@ -748,15 +757,8 @@ class AzureAIAgent(BaseChatAgent):
# TODO support for parameter to control polling interval
await asyncio.sleep(sleep_interval)

# run_steps: models.OpenAIPageableListOfRunStep = await cancellation_token.link_future(
#     asyncio.ensure_future(
#         self._project_client.agents.list_run_steps(
#             thread_id=self.thread_id,
#             run_id=run.id,
#         )
#     )
# )
# Get messages after run completion
# After run is completed, get the messages
trace_logger.debug("Retrieving messages from thread")
agent_messages: models.OpenAIPageableListOfThreadMessage = await cancellation_token.link_future(
    asyncio.ensure_future(
        self._project_client.agents.list_messages(
@@ -768,18 +770,67 @@ class AzureAIAgent(BaseChatAgent):
if not agent_messages.data:
    raise ValueError("No messages received from assistant")

# Get the last message's content
last_message = agent_messages.data[0]
# Get the last message from the agent
last_message: Optional[models.ThreadMessage] = agent_messages.get_last_message_by_role(models.MessageRole.AGENT)

if not last_message:
    trace_logger.debug("No message with AGENT role found, falling back to first message")
    last_message = agent_messages.data[0]  # Fallback to first message

if not last_message.content:
    raise ValueError(f"No content in the last message: {last_message}")
    raise ValueError("No content in the last message")

# Extract text content
text_content = agent_messages.text_messages
if not text_content:
    raise ValueError(f"Expected text content in the last message: {last_message.content}")
message_text = ""
for text_message in last_message.text_messages:
    message_text += text_message.text.value

# Extract citations
citations: list[Any] = []

# Try accessing annotations directly

annotations = getattr(last_message, "annotations", [])

if isinstance(annotations, list) and annotations:
    annotations = cast(List[models.MessageTextUrlCitationAnnotation], annotations)

    trace_logger.debug(f"Found {len(annotations)} annotations")
    for annotation in annotations:
        if hasattr(annotation, "url_citation"):  # type: ignore
            trace_logger.debug(f"Citation found: {annotation.url_citation.url}")
            citations.append(
                {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None}  # type: ignore
            )
# For backwards compatibility
elif hasattr(last_message, "url_citation_annotations") and last_message.url_citation_annotations:
    url_annotations = cast(List[Any], last_message.url_citation_annotations)

    trace_logger.debug(f"Found {len(url_annotations)} URL citations")

    for annotation in url_annotations:
        citations.append(
            {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None}  # type: ignore
        )

elif hasattr(last_message, "file_citation_annotations") and last_message.file_citation_annotations:
    file_annotations = cast(List[Any], last_message.file_citation_annotations)

    trace_logger.debug(f"Found {len(file_annotations)} file citations")

    for annotation in file_annotations:
        citations.append(
            {"file_id": annotation.file_citation.file_id, "title": None, "text": annotation.file_citation.quote}  # type: ignore
        )

trace_logger.debug(f"Total citations extracted: {len(citations)}")

# Create the response message with citations as JSON string
chat_message = TextMessage(
    source=self.name, content=message_text, metadata={"citations": json.dumps(citations)} if citations else {}
)

# Return the assistant's response as a Response with inner messages
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
yield Response(chat_message=chat_message, inner_messages=inner_messages)

async def handle_text_message(self, content: str, cancellation_token: Optional[CancellationToken] = None) -> None:

@@ -1,13 +1,4 @@
from typing import (
    Any,
    Awaitable,
    Callable,
    Iterable,
    List,
    Literal,
    Optional,
    Union,
)
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union

from autogen_core.tools import Tool
from pydantic import BaseModel, Field
@@ -59,3 +50,7 @@ class AzureAIAgentState(BaseModel):
    initial_message_ids: List[str] = Field(default_factory=list)
    vector_store_id: Optional[str] = None
    uploaded_file_ids: List[str] = Field(default_factory=list)


def has_annotations(obj: Any) -> TypeGuard[list[models.MessageTextUrlCitationAnnotation]]:
    return obj is not None and isinstance(obj, list)

@@ -1,6 +1,7 @@
import json
from asyncio import CancelledError
from types import SimpleNamespace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock, call

import azure.ai.projects.models as models
@@ -37,15 +38,126 @@ class FakeMessage:
    def text_messages(self) -> List[FakeTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[MessageTextContent]
        :rtype: List[FakeTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, FakeTextContent)]


class FakeMessageUrlCitationDetails:
    def __init__(self, url: str, title: str) -> None:
        self.url = url
        self.title = title


class FakeTextUrlCitationAnnotation:
    def __init__(self, citation_details: FakeMessageUrlCitationDetails, text: str) -> None:
        self.type = "url_citation"
        self.url_citation = citation_details
        self.text = text


class FakeTextFileCitationDetails:
    def __init__(self, file_id: str, quote: str) -> None:
        self.file_id = file_id
        self.quote = quote


class FakeTextFileCitationAnnotation:
    def __init__(self, citation_details: FakeTextFileCitationDetails) -> None:
        self.type = "file_citation"
        self.file_citation = citation_details


class FakeMessageWithUrlCitationAnnotation:
    def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
        self.id = id
        # The agent expects content to be a list of objects with a "type" attribute.
        self.content = [FakeTextContent(text)]
        self.role = "user"
        self._annotations = annotations

    @property
    def text_messages(self) -> List[FakeTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[FakeTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, FakeTextContent)]

    @property
    def url_citation_annotations(self) -> List[FakeTextUrlCitationAnnotation]:
        """Returns all URL citation annotations from text message annotations in the messages.

        :rtype: List[FakeTextUrlCitationAnnotation]
        """
        return self._annotations


class FakeMessageWithFileCitationAnnotation:
    def __init__(self, id: str, text: str, annotations: list[FakeTextFileCitationAnnotation]) -> None:
        self.id = id
        # The agent expects content to be a list of objects with a "type" attribute.
        self.content = [FakeTextContent(text)]
        self.role = "user"
        self._annotations = annotations

    @property
    def text_messages(self) -> List[FakeTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[FakeTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, FakeTextContent)]

    @property
    def file_citation_annotations(self) -> List[FakeTextFileCitationAnnotation]:
        """Returns all file citation annotations from text message annotations in the messages.

        :rtype: List[FakeTextFileCitationAnnotation]
        """
        return self._annotations


class FakeMessageWithAnnotation:
    def __init__(self, id: str, text: str, annotations: list[FakeTextUrlCitationAnnotation]) -> None:
        self.id = id
        # The agent expects content to be a list of objects with a "type" attribute.
        self.content = [FakeTextContent(text)]
        self.role = "user"
        self.annotations = annotations

    @property
    def text_messages(self) -> List[FakeTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[FakeTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, FakeTextContent)]


FakeMessageType = Union[
    ThreadMessage
    | FakeMessage
    | FakeMessageWithAnnotation
    | FakeMessageWithUrlCitationAnnotation
    | FakeMessageWithFileCitationAnnotation
]


class FakeOpenAIPageableListOfThreadMessage:
    def __init__(self, data: List[ThreadMessage | FakeMessage], has_more: bool = False) -> None:
    def __init__(
        self,
        data: List[FakeMessageType],
        has_more: bool = False,
    ) -> None:
        self.data = data
        self._has_more = has_more

@@ -62,9 +174,31 @@ class FakeOpenAIPageableListOfThreadMessage:
        texts = [content for msg in self.data for content in msg.text_messages]  # type: ignore
        return texts  # type: ignore

    def get_last_message_by_role(
        self, role: models.MessageRole
    ) -> Optional[ThreadMessage | FakeMessage | FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation]:
        """Returns the last message from a sender in the specified role.

def mock_list() -> FakeOpenAIPageableListOfThreadMessage:
    return FakeOpenAIPageableListOfThreadMessage([FakeMessage("msg-mock", "response")])
        :param role: The role of the sender.
        :type role: MessageRole

        :return: The last message from a sender in the specified role.
        :rtype: ~azure.ai.projects.models.ThreadMessage
        """
        for msg in self.data:
            if msg.role == role:
                return msg  # type: ignore
        return None


def mock_list(
    data: Optional[List[FakeMessageType]] = None,
    has_more: bool = False,
) -> FakeOpenAIPageableListOfThreadMessage:
    if data is None or len(data) == 0:
        data = [FakeMessage("msg-mock", "response")]

    return FakeOpenAIPageableListOfThreadMessage(data, has_more=has_more)


def create_agent(
@@ -207,47 +341,6 @@ async def test_on_upload_for_file_search(mock_project_client: MagicMock) -> None
    mock_project_client.agents.create_vector_store_file_batch_and_poll.assert_called_once()


# @pytest.mark.asyncio
# async def test_add_tools(mock_project_client: MagicMock) -> None:
#     agent = create_agent(mock_project_client)

#     tools: Optional[ListToolType] = ["file_search", "code_interpreter"]
#     converted_tools: List[ToolDefinition] = []
#     agent._add_tools(tools, converted_tools)

#     assert len(converted_tools) == 2
#     assert isinstance(converted_tools[0], models.FileSearchToolDefinition)
#     assert isinstance(converted_tools[1], models.CodeInterpreterToolDefinition)


# @pytest.mark.asyncio
# async def test_ensure_initialized(mock_project_client: MagicMock) -> None:
#     agent = create_agent(mock_project_client)

#     await agent._ensure_initialized(create_new_agent=True, create_new_thread=True)

#     assert agent._agent is not None
#     assert agent._thread is not None


# @pytest.mark.asyncio
# async def test_execute_tool_call(mock_project_client: MagicMock) -> None:
#     mock_tool = MagicMock()
#     mock_tool.name = "test_tool"
#     mock_tool.run_json = AsyncMock(return_value={"result": "success"})
#     mock_tool.return_value_as_string = MagicMock(return_value="success")

#     agent = create_agent(mock_project_client)

#     agent._original_tools = [mock_tool]

#     tool_call = FunctionCall(id="test_tool", name="test_tool", arguments="{}")
#     result = await agent._execute_tool_call(tool_call, CancellationToken())

#     assert result == "success"
#     mock_tool.run_json.assert_called_once()


@pytest.mark.asyncio
async def test_upload_files(mock_project_client: MagicMock) -> None:
    mock_project_client.agents.create_vector_store_file_batch_and_poll = AsyncMock()
@@ -590,3 +683,96 @@ async def test_uploading_multiple_files(
    mock_project_client.agents.upload_file_and_poll.assert_has_calls(
        [call(file_path=file_path, purpose=models.FilePurpose.AGENTS, sleep_interval=0.1) for file_path in file_paths]
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "fake_message, url, title",
    [
        (
            FakeMessageWithAnnotation(
                "msg-mock-1",
                "response-1",
                [FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url1", "title1"), "text")],
            ),
            "url1",
            "title1",
        ),
        (
            FakeMessageWithUrlCitationAnnotation(
                "msg-mock-2",
                "response-2",
                [FakeTextUrlCitationAnnotation(FakeMessageUrlCitationDetails("url2", "title2"), "text")],
            ),
            "url2",
            "title2",
        ),
    ],
)
async def test_on_message_stream_mapping_url_citation(
    mock_project_client: MagicMock,
    fake_message: FakeMessageWithAnnotation | FakeMessageWithUrlCitationAnnotation,
    url: str,
    title: str,
) -> None:
    mock_project_client.agents.create_run = AsyncMock(
        return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
    )

    list = mock_list([fake_message], has_more=False)

    mock_project_client.agents.list_messages = AsyncMock(return_value=list)

    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]

    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message is not None
        assert response.chat_message.metadata is not None

        citations = json.loads(response.chat_message.metadata["citations"])
        assert citations is not None

        assert len(citations) == 1

        assert citations[0]["url"] == url
        assert citations[0]["title"] == title


@pytest.mark.asyncio
async def test_on_message_stream_mapping_file_citation(mock_project_client: MagicMock) -> None:
    mock_project_client.agents.create_run = AsyncMock(
        return_value=MagicMock(id="run-id", status=models.RunStatus.COMPLETED)
    )

    expected_file_id = "file_id_1"
    expected_quote = "this part of a file"

    fake_message = FakeMessageWithFileCitationAnnotation(
        "msg-mock-1",
        "response-1",
        [FakeTextFileCitationAnnotation(FakeTextFileCitationDetails(expected_file_id, expected_quote))],
    )

    list = mock_list([fake_message], has_more=False)

    mock_project_client.agents.list_messages = AsyncMock(return_value=list)

    agent = create_agent(mock_project_client)

    messages = [TextMessage(content="Hello", source="user")]

    async for response in agent.on_messages_stream(messages):
        assert isinstance(response, Response)
        assert response.chat_message is not None
        assert response.chat_message.metadata is not None

        citations = json.loads(response.chat_message.metadata["citations"])
        assert citations is not None

        assert len(citations) == 1

        assert citations[0]["file_id"] == expected_file_id
        assert citations[0]["text"] == expected_quote