feat: Add thought process handling in tool calls and expose ThoughtEvent through stream in AgentChat (#5500)

Resolves #5192

Test

```python
import asyncio
import os
from random import randint
from typing import List
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console

async def get_current_time(city: str) -> str:
    return f"The current time in {city} is {randint(0, 23)}:{randint(0, 59)}."

tools: List[BaseTool] = [
    FunctionTool(
        get_current_time,
        name="get_current_time",
        description="Get current time for a city.",
    ),
]

model_client = OpenAIChatCompletionClient(
    model="anthropic/claude-3.5-haiku-20241022",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
    model_info={
        "family": "claude-3.5-haiku",
        "function_calling": True,
        "vision": False,
        "json_output": False,
    }
)

agent = AssistantAgent(
    name="Agent",
    model_client=model_client,
    tools=tools,
    system_message= "You are an assistant with some tools that can be used to answer some questions",
)

async def main() -> None:
    await Console(agent.run_stream(task="What is current time of Paris and Toronto?"))

asyncio.run(main())
```

```
---------- user ----------
What is current time of Paris and Toronto?
---------- Agent ----------
I'll help you find the current time for Paris and Toronto by using the get_current_time function for each city.
---------- Agent ----------
[FunctionCall(id='toolu_01NwP3fNAwcYKn1x656Dq9xW', arguments='{"city": "Paris"}', name='get_current_time'), FunctionCall(id='toolu_018d4cWSy3TxXhjgmLYFrfRt', arguments='{"city": "Toronto"}', name='get_current_time')]
---------- Agent ----------
[FunctionExecutionResult(content='The current time in Paris is 1:10.', call_id='toolu_01NwP3fNAwcYKn1x656Dq9xW', is_error=False), FunctionExecutionResult(content='The current time in Toronto is 7:28.', call_id='toolu_018d4cWSy3TxXhjgmLYFrfRt', is_error=False)]
---------- Agent ----------
The current time in Paris is 1:10.
The current time in Toronto is 7:28.
```

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2025-02-21 14:58:32 -07:00 committed by GitHub
parent 45c6d133c2
commit 7784f44ea6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 184 additions and 60 deletions

View File

@ -44,6 +44,7 @@ from ..messages import (
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
@ -418,7 +419,15 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
await self._model_context.add_message(
AssistantMessage(content=model_result.content, source=self.name, thought=model_result.thought)
)
# Add thought to the inner messages.
if model_result.thought:
thought_event = ThoughtEvent(content=model_result.thought, source=self.name)
inner_messages.append(thought_event)
yield thought_event
# Check if the response is a string and return it.
if isinstance(model_result.content, str):
@ -479,7 +488,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
# Current context for handoff.
handoff_context: List[LLMMessage] = []
if len(tool_calls) > 0:
handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
handoff_context.append(
AssistantMessage(content=tool_calls, source=self.name, thought=model_result.thought)
)
handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
# Return the output messages to signal the handoff.
yield Response(
@ -515,7 +526,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
assert isinstance(reflection_model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(
AssistantMessage(content=reflection_model_result.content, source=self.name)
AssistantMessage(
content=reflection_model_result.content, source=self.name, thought=reflection_model_result.thought
)
)
# Yield the response.
yield Response(

View File

@ -137,6 +137,17 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
class ThoughtEvent(BaseAgentEvent):
"""An event signaling the thought process of an agent.
It is used to communicate the reasoning tokens generated by a reasoning model,
or the extra text content generated by a function call."""
content: str
"""The thought process."""
type: Literal["ThoughtEvent"] = "ThoughtEvent"
ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
@ -148,7 +159,8 @@ AgentEvent = Annotated[
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent,
| ModelClientStreamingChunkEvent
| ThoughtEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@ -168,4 +180,5 @@ __all__ = [
"MemoryQueryEvent",
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
"ThoughtEvent",
]

View File

@ -17,6 +17,7 @@ from autogen_agentchat.messages import (
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
ThoughtEvent,
)
from autogen_core import ComponentModel, FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
@ -89,7 +90,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
content="Calling pass function",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
@ -151,18 +152,20 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await agent.run(task="task")
assert len(result.messages) == 4
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling pass function"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 5
assert result.messages[2].models_usage.prompt_tokens == 10
assert isinstance(result.messages[3], ToolCallExecutionEvent)
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], ToolCallSummaryMessage)
assert result.messages[4].content == "pass"
assert result.messages[4].models_usage is None
# Test streaming.
mock.curr_index = 0 # Reset the mock
@ -302,7 +305,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
content="Calling pass and echo functions",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
@ -380,30 +383,32 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await agent.run(task="task")
assert len(result.messages) == 4
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].content == [
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling pass and echo functions"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].content == [
FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
]
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 5
assert result.messages[2].models_usage.prompt_tokens == 10
assert isinstance(result.messages[3], ToolCallExecutionEvent)
expected_content = [
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
FunctionExecutionResult(call_id="2", content="pass", is_error=False),
FunctionExecutionResult(call_id="3", content="task3", is_error=False),
]
for expected in expected_content:
assert expected in result.messages[2].content
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass\npass\ntask3"
assert expected in result.messages[3].content
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], ToolCallSummaryMessage)
assert result.messages[4].content == "pass\npass\ntask3"
assert result.messages[4].models_usage is None
# Test streaming.
mock.curr_index = 0 # Reset the mock

View File

@ -44,6 +44,9 @@ class AssistantMessage(BaseModel):
content: Union[str, List[FunctionCall]]
"""The content of the message."""
thought: str | None = None
"""The reasoning text for the completion if available. Used for reasoning model and additional text content besides function calls."""
source: str
"""The name of the agent that sent this message."""

View File

@ -208,11 +208,19 @@ def assistant_message_to_oai(
) -> ChatCompletionAssistantMessageParam:
assert_valid_name(message.source)
if isinstance(message.content, list):
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
if message.thought is not None:
return ChatCompletionAssistantMessageParam(
content=message.thought,
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
content=message.content,
@ -572,6 +580,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
# Detect whether it is a function call or not.
# We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if choice.message.function_call is not None:
raise ValueError("function_call is deprecated and is not supported by this model client.")
elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
@ -583,11 +592,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
stacklevel=2,
)
if choice.message.content is not None and choice.message.content != "":
warnings.warn(
"Both tool_calls and content are present in the message. "
"This is unexpected. content will be ignored, tool_calls will be used.",
stacklevel=2,
)
# Put the content in the thought field.
thought = choice.message.content
# NOTE: If OAI response type changes, this will need to be updated
content = []
for tool_call in choice.message.tool_calls:
@ -626,8 +632,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None
response = CreateResult(
finish_reason=normalize_stop_reason(finish_reason),
@ -788,6 +792,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue
# Otherwise, get tool calls
@ -832,20 +838,24 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
raise ValueError("Function calls are not supported in this context")
content: Union[str, List[FunctionCall]]
if len(content_deltas) > 1:
thought: str | None = None
if full_tool_calls:
# This is a tool call.
content = list(full_tool_calls.values())
if len(content_deltas) > 1:
# Put additional text content in the thought field.
thought = "".join(content_deltas)
elif len(content_deltas) > 0:
# This is a text-only content.
content = "".join(content_deltas)
if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
else:
warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
content = ""
if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
# TODO: fix assumption that dict values were added in order and actually order by int index
# for tool_call in full_tool_calls.values():
# # value = json.dumps(tool_call)
# # completion_tokens += count_token(value, model=model)
# completion_tokens += 0
content = list(full_tool_calls.values())
usage = RequestUsage(
prompt_tokens=prompt_tokens,
@ -854,8 +864,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None
result = CreateResult(
finish_reason=normalize_stop_reason(stop_reason),

View File

@ -26,7 +26,12 @@ from autogen_ext.models.openai._openai_client import calculate_vision_tokens, co
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
@ -734,7 +739,7 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Warning completion when content is not None.
# Thought field is populated when content is not None.
ChatCompletion(
id="id4",
choices=[
@ -850,13 +855,11 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"
# Warning completion when content is not None.
with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
)
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"
# Thought field is populated when content is not None.
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"
assert create_result.thought == "I should make a tool call."
# Should not be returning tool calls when the tool_calls are empty
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
@ -872,6 +875,85 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
assert create_result.finish_reason == "function_calls"
@pytest.mark.asyncio
async def test_tool_calling_with_stream(monkeypatch: pytest.MonkeyPatch) -> None:
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
model = resolve_model(kwargs.get("model", "gpt-4o"))
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
mock_chunks = [
# generate the list of mock chunk content
MockChunkDefinition(
chunk_choice=ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
),
usage=None,
)
for mock_chunk_content in mock_chunks_content
] + [
# generate the function call chunk
MockChunkDefinition(
chunk_choice=ChunkChoice(
finish_reason="tool_calls",
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="1",
type="function",
function=ChoiceDeltaToolCallFunction(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
),
),
usage=None,
)
]
for mock_chunk in mock_chunks:
await asyncio.sleep(0.1)
yield ChatCompletionChunk(
id="id",
choices=[mock_chunk.chunk_choice],
created=0,
model=model,
object="chat.completion.chunk",
usage=mock_chunk.usage,
)
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
stream = kwargs.get("stream", False)
if not stream:
raise ValueError("Stream is not False")
else:
return _mock_create_stream(*args, **kwargs)
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
pass_tool = FunctionTool(_pass_function, description="pass tool.")
stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert chunks[0] == "Hello"
assert chunks[1] == " Another Hello"
assert chunks[2] == " Yet Another Hello"
assert isinstance(chunks[-1], CreateResult)
assert chunks[-1].content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert chunks[-1].finish_reason == "function_calls"
assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
# Test basic completion
create_result = await model_client.create(