mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 14:09:00 +00:00
feat: Add thought process handling in tool calls and expose ThoughtEvent through stream in AgentChat (#5500)
Resolves #5192

Test

```python
import asyncio
import os
from random import randint
from typing import List

from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console


async def get_current_time(city: str) -> str:
    return f"The current time in {city} is {randint(0, 23)}:{randint(0, 59)}."


tools: List[BaseTool] = [
    FunctionTool(
        get_current_time,
        name="get_current_time",
        description="Get current time for a city.",
    ),
]

model_client = OpenAIChatCompletionClient(
    model="anthropic/claude-3.5-haiku-20241022",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
    model_info={
        "family": "claude-3.5-haiku",
        "function_calling": True,
        "vision": False,
        "json_output": False,
    },
)

agent = AssistantAgent(
    name="Agent",
    model_client=model_client,
    tools=tools,
    system_message="You are an assistant with some tools that can be used to answer some questions",
)


async def main() -> None:
    await Console(agent.run_stream(task="What is current time of Paris and Toronto?"))


asyncio.run(main())
```

```
---------- user ----------
What is current time of Paris and Toronto?
---------- Agent ----------
I'll help you find the current time for Paris and Toronto by using the get_current_time function for each city.
---------- Agent ----------
[FunctionCall(id='toolu_01NwP3fNAwcYKn1x656Dq9xW', arguments='{"city": "Paris"}', name='get_current_time'), FunctionCall(id='toolu_018d4cWSy3TxXhjgmLYFrfRt', arguments='{"city": "Toronto"}', name='get_current_time')]
---------- Agent ----------
[FunctionExecutionResult(content='The current time in Paris is 1:10.', call_id='toolu_01NwP3fNAwcYKn1x656Dq9xW', is_error=False), FunctionExecutionResult(content='The current time in Toronto is 7:28.', call_id='toolu_018d4cWSy3TxXhjgmLYFrfRt', is_error=False)]
---------- Agent ----------
The current time in Paris is 1:10.
The current time in Toronto is 7:28.
```

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
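For reference (an editorial sketch, not part of the commit message): the newly exposed `ThoughtEvent` can also be picked directly out of the stream instead of going through `Console`. The snippet reuses the `agent` defined above; `content` and `source` are the fields added by this change, and `print_thoughts` is just an illustrative helper name.

```python
from autogen_agentchat.messages import ThoughtEvent


async def print_thoughts() -> None:
    # Iterate the stream directly and pick out the ThoughtEvent messages
    # that this change now yields alongside tool-call events.
    async for message in agent.run_stream(task="What is current time of Paris and Toronto?"):
        if isinstance(message, ThoughtEvent):
            print(f"[thought from {message.source}] {message.content}")
```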
This commit is contained in:
parent
45c6d133c2
commit
7784f44ea6
```diff
@@ -44,6 +44,7 @@ from ..messages import (
     MemoryQueryEvent,
     ModelClientStreamingChunkEvent,
     TextMessage,
+    ThoughtEvent,
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
```
```diff
@@ -418,7 +419,15 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         )

         # Add the response to the model context.
-        await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
+        await self._model_context.add_message(
+            AssistantMessage(content=model_result.content, source=self.name, thought=model_result.thought)
+        )
+
+        # Add thought to the inner messages.
+        if model_result.thought:
+            thought_event = ThoughtEvent(content=model_result.thought, source=self.name)
+            inner_messages.append(thought_event)
+            yield thought_event

         # Check if the response is a string and return it.
         if isinstance(model_result.content, str):
```
```diff
@@ -479,7 +488,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             # Current context for handoff.
             handoff_context: List[LLMMessage] = []
             if len(tool_calls) > 0:
-                handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+                handoff_context.append(
+                    AssistantMessage(content=tool_calls, source=self.name, thought=model_result.thought)
+                )
                 handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
             # Return the output messages to signal the handoff.
             yield Response(
```
```diff
@@ -515,7 +526,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             assert isinstance(reflection_model_result.content, str)
             # Add the response to the model context.
             await self._model_context.add_message(
-                AssistantMessage(content=reflection_model_result.content, source=self.name)
+                AssistantMessage(
+                    content=reflection_model_result.content, source=self.name, thought=reflection_model_result.thought
+                )
             )
             # Yield the response.
             yield Response(
```
```diff
@@ -137,6 +137,17 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
     type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"


+class ThoughtEvent(BaseAgentEvent):
+    """An event signaling the thought process of an agent.
+    It is used to communicate the reasoning tokens generated by a reasoning model,
+    or the extra text content generated by a function call."""
+
+    content: str
+    """The thought process."""
+
+    type: Literal["ThoughtEvent"] = "ThoughtEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
```
```diff
@@ -148,7 +159,8 @@ AgentEvent = Annotated[
     | ToolCallExecutionEvent
     | MemoryQueryEvent
     | UserInputRequestedEvent
-    | ModelClientStreamingChunkEvent,
+    | ModelClientStreamingChunkEvent
+    | ThoughtEvent,
     Field(discriminator="type"),
 ]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
```
```diff
@@ -168,4 +180,5 @@ __all__ = [
     "MemoryQueryEvent",
     "UserInputRequestedEvent",
     "ModelClientStreamingChunkEvent",
+    "ThoughtEvent",
 ]
```
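As a side note (not part of the diff): since `ThoughtEvent` is a Pydantic model with a `type` discriminator like the other message types, a quick round trip through `model_dump` / `model_validate` shows the new field in action. This is only a sketch based on the fields visible above; the literal values are made up.

```python
from autogen_agentchat.messages import ThoughtEvent

event = ThoughtEvent(content="I should call get_current_time first.", source="Agent")
data = event.model_dump()                      # includes type="ThoughtEvent" as the discriminator
restored = ThoughtEvent.model_validate(data)   # rebuilds the event from plain data
assert restored.content == event.content
```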
```diff
@@ -17,6 +17,7 @@ from autogen_agentchat.messages import (
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
+    ThoughtEvent,
 )
 from autogen_core import ComponentModel, FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
```
```diff
@@ -89,7 +90,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
                     finish_reason="tool_calls",
                     index=0,
                     message=ChatCompletionMessage(
-                        content=None,
+                        content="Calling pass function",
                         tool_calls=[
                             ChatCompletionMessageToolCall(
                                 id="1",
```
```diff
@@ -151,18 +152,20 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     result = await agent.run(task="task")

-    assert len(result.messages) == 4
+    assert len(result.messages) == 5
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
-    assert isinstance(result.messages[1], ToolCallRequestEvent)
-    assert result.messages[1].models_usage is not None
-    assert result.messages[1].models_usage.completion_tokens == 5
-    assert result.messages[1].models_usage.prompt_tokens == 10
-    assert isinstance(result.messages[2], ToolCallExecutionEvent)
-    assert result.messages[2].models_usage is None
-    assert isinstance(result.messages[3], ToolCallSummaryMessage)
-    assert result.messages[3].content == "pass"
+    assert isinstance(result.messages[1], ThoughtEvent)
+    assert result.messages[1].content == "Calling pass function"
+    assert isinstance(result.messages[2], ToolCallRequestEvent)
+    assert result.messages[2].models_usage is not None
+    assert result.messages[2].models_usage.completion_tokens == 5
+    assert result.messages[2].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[3], ToolCallExecutionEvent)
+    assert result.messages[3].models_usage is None
+    assert isinstance(result.messages[4], ToolCallSummaryMessage)
+    assert result.messages[4].content == "pass"
+    assert result.messages[4].models_usage is None

     # Test streaming.
     mock.curr_index = 0  # Reset the mock
```
```diff
@@ -302,7 +305,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
                     finish_reason="tool_calls",
                     index=0,
                     message=ChatCompletionMessage(
-                        content=None,
+                        content="Calling pass and echo functions",
                         tool_calls=[
                             ChatCompletionMessageToolCall(
                                 id="1",
```
```diff
@@ -380,30 +383,32 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     result = await agent.run(task="task")

-    assert len(result.messages) == 4
+    assert len(result.messages) == 5
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
-    assert isinstance(result.messages[1], ToolCallRequestEvent)
-    assert result.messages[1].content == [
+    assert isinstance(result.messages[1], ThoughtEvent)
+    assert result.messages[1].content == "Calling pass and echo functions"
+    assert isinstance(result.messages[2], ToolCallRequestEvent)
+    assert result.messages[2].content == [
         FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
         FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
         FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
     ]
-    assert result.messages[1].models_usage is not None
-    assert result.messages[1].models_usage.completion_tokens == 5
-    assert result.messages[1].models_usage.prompt_tokens == 10
-    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[2].models_usage is not None
+    assert result.messages[2].models_usage.completion_tokens == 5
+    assert result.messages[2].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[3], ToolCallExecutionEvent)
     expected_content = [
         FunctionExecutionResult(call_id="1", content="pass", is_error=False),
         FunctionExecutionResult(call_id="2", content="pass", is_error=False),
         FunctionExecutionResult(call_id="3", content="task3", is_error=False),
     ]
     for expected in expected_content:
-        assert expected in result.messages[2].content
-    assert result.messages[2].models_usage is None
-    assert isinstance(result.messages[3], ToolCallSummaryMessage)
-    assert result.messages[3].content == "pass\npass\ntask3"
+        assert expected in result.messages[3].content
+    assert result.messages[3].models_usage is None
+    assert isinstance(result.messages[4], ToolCallSummaryMessage)
+    assert result.messages[4].content == "pass\npass\ntask3"
+    assert result.messages[4].models_usage is None

     # Test streaming.
     mock.curr_index = 0  # Reset the mock
```
```diff
@@ -44,6 +44,9 @@ class AssistantMessage(BaseModel):
     content: Union[str, List[FunctionCall]]
     """The content of the message."""

+    thought: str | None = None
+    """The reasoning text for the completion if available. Used for reasoning model and additional text content besides function calls."""
+
     source: str
     """The name of the agent that sent this message."""

```
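To illustrate the new field (an editorial sketch, not part of the diff): an `AssistantMessage` carrying tool calls can now also carry the accompanying free text as `thought`. The import paths follow the `autogen_core` package structure; the literal values are made up.

```python
from autogen_core import FunctionCall
from autogen_core.models import AssistantMessage

# Tool-call content plus the extra text the model produced alongside it.
msg = AssistantMessage(
    content=[FunctionCall(id="1", arguments='{"city": "Paris"}', name="get_current_time")],
    thought="I need to look up the time before answering.",
    source="Agent",
)
```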
```diff
@@ -208,11 +208,19 @@ def assistant_message_to_oai(
 ) -> ChatCompletionAssistantMessageParam:
     assert_valid_name(message.source)
     if isinstance(message.content, list):
-        return ChatCompletionAssistantMessageParam(
-            tool_calls=[func_call_to_oai(x) for x in message.content],
-            role="assistant",
-            name=message.source,
-        )
+        if message.thought is not None:
+            return ChatCompletionAssistantMessageParam(
+                content=message.thought,
+                tool_calls=[func_call_to_oai(x) for x in message.content],
+                role="assistant",
+                name=message.source,
+            )
+        else:
+            return ChatCompletionAssistantMessageParam(
+                tool_calls=[func_call_to_oai(x) for x in message.content],
+                role="assistant",
+                name=message.source,
+            )
     else:
         return ChatCompletionAssistantMessageParam(
             content=message.content,
```
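Roughly, the branch added above yields an OpenAI-format assistant message whose `content` carries the thought next to the tool calls. The dict below is an illustrative shape only (the real values come from `message.thought` and `func_call_to_oai`), not output captured from the library.

```python
# Illustrative only: approximate contents of the ChatCompletionAssistantMessageParam
# when message.thought is not None (all values here are made up).
oai_message = {
    "role": "assistant",
    "name": "Agent",
    "content": "I need to look up the time before answering.",  # message.thought
    "tool_calls": [
        {
            "id": "1",
            "type": "function",
            "function": {"name": "get_current_time", "arguments": '{"city": "Paris"}'},
        }
    ],
}
```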
```diff
@@ -572,6 +580,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         # Detect whether it is a function call or not.
         # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
+        thought: str | None = None
         if choice.message.function_call is not None:
             raise ValueError("function_call is deprecated and is not supported by this model client.")
         elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
```
```diff
@@ -583,11 +592,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                     stacklevel=2,
                 )
             if choice.message.content is not None and choice.message.content != "":
-                warnings.warn(
-                    "Both tool_calls and content are present in the message. "
-                    "This is unexpected. content will be ignored, tool_calls will be used.",
-                    stacklevel=2,
-                )
+                # Put the content in the thought field.
+                thought = choice.message.content
             # NOTE: If OAI response type changes, this will need to be updated
             content = []
             for tool_call in choice.message.tool_calls:
```
```diff
@@ -626,8 +632,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):

         if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
             thought, content = parse_r1_content(content)
-        else:
-            thought = None

         response = CreateResult(
             finish_reason=normalize_stop_reason(finish_reason),
```
```diff
@@ -788,6 +792,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                     content_deltas.append(choice.delta.content)
                     if len(choice.delta.content) > 0:
                         yield choice.delta.content
+                    # NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
+                    # However, this may not be the case for other APIs -- we should expect this may need to be updated.
                     continue

                 # Otherwise, get tool calls
```
```diff
@@ -832,20 +838,24 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             raise ValueError("Function calls are not supported in this context")

         content: Union[str, List[FunctionCall]]
-        if len(content_deltas) > 1:
+        thought: str | None = None
+        if full_tool_calls:
+            # This is a tool call.
+            content = list(full_tool_calls.values())
+            if len(content_deltas) > 1:
+                # Put additional text content in the thought field.
+                thought = "".join(content_deltas)
+        elif len(content_deltas) > 0:
+            # This is a text-only content.
             content = "".join(content_deltas)
-            if chunk and chunk.usage:
-                completion_tokens = chunk.usage.completion_tokens
-            else:
-                completion_tokens = 0
         else:
+            warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
+            content = ""
+
+        if chunk and chunk.usage:
+            completion_tokens = chunk.usage.completion_tokens
+        else:
             completion_tokens = 0
-            # TODO: fix assumption that dict values were added in order and actually order by int index
-            # for tool_call in full_tool_calls.values():
-            #     # value = json.dumps(tool_call)
-            #     # completion_tokens += count_token(value, model=model)
-            #     completion_tokens += 0
-            content = list(full_tool_calls.values())

         usage = RequestUsage(
             prompt_tokens=prompt_tokens,
```
```diff
@@ -854,8 +864,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):

         if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
             thought, content = parse_r1_content(content)
-        else:
-            thought = None

         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
```
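On the client side (again just a sketch, mirroring the tests below rather than quoting them), the extra text now surfaces on the returned `CreateResult` as `thought` for both `create` and `create_stream`; `show_thought` is an illustrative helper name, and `model_client` / `pass_tool` are assumed to be set up as in the tests.

```python
from autogen_core.models import UserMessage


async def show_thought() -> None:
    result = await model_client.create(
        messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
    )
    if result.thought is not None:
        # Extra text the model emitted alongside its tool calls.
        print("thought:", result.thought)
```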
```diff
@@ -26,7 +26,12 @@ from autogen_ext.models.openai._openai_client import calculate_vision_tokens, co
 from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk,
+    ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (
```
```diff
@@ -734,7 +739,7 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
             object="chat.completion",
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
         ),
-        # Warning completion when content is not None.
+        # Thought field is populated when content is not None.
         ChatCompletion(
             id="id4",
             choices=[
```
```diff
@@ -850,13 +855,11 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
     assert create_result.finish_reason == "function_calls"

-    # Warning completion when content is not None.
-    with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
-        create_result = await model_client.create(
-            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
-        )
-    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
-    assert create_result.finish_reason == "function_calls"
+    # Thought field is populated when content is not None.
+    create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.thought == "I should make a tool call."

     # Should not be returning tool calls when the tool_calls are empty
     create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
```
```diff
@@ -872,6 +875,85 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.finish_reason == "function_calls"


+@pytest.mark.asyncio
+async def test_tool_calling_with_stream(monkeypatch: pytest.MonkeyPatch) -> None:
+    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        model = resolve_model(kwargs.get("model", "gpt-4o"))
+        mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
+        mock_chunks = [
+            # generate the list of mock chunk content
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason=None,
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=mock_chunk_content,
+                        role="assistant",
+                    ),
+                ),
+                usage=None,
+            )
+            for mock_chunk_content in mock_chunks_content
+        ] + [
+            # generate the function call chunk
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                id="1",
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                    ),
+                ),
+                usage=None,
+            )
+        ]
+        for mock_chunk in mock_chunks:
+            await asyncio.sleep(0.1)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[mock_chunk.chunk_choice],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=mock_chunk.usage,
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        stream = kwargs.get("stream", False)
+        if not stream:
+            raise ValueError("Stream is not False")
+        else:
+            return _mock_create_stream(*args, **kwargs)
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
+    pass_tool = FunctionTool(_pass_function, description="pass tool.")
+    stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert chunks[0] == "Hello"
+    assert chunks[1] == " Another Hello"
+    assert chunks[2] == " Yet Another Hello"
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    assert chunks[-1].finish_reason == "function_calls"
+    assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
+
+
 async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
     # Test basic completion
     create_result = await model_client.create(
```