Rename model_usage to models_usage. (#4053)

This commit is contained in:
Eric Zhu 2024-11-04 09:25:53 -08:00 committed by GitHub
parent f46e52e6ff
commit 16e64c4c10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 27 additions and 27 deletions

View File

@@ -266,8 +266,8 @@ class AssistantAgent(BaseChatAgent):
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage))
yield ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage)
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage))
yield ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
# Execute the tool calls.
results = await asyncio.gather(
@@ -303,7 +303,7 @@ class AssistantAgent(BaseChatAgent):
assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, model_usage=result.usage),
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)

View File

@@ -11,7 +11,7 @@ class BaseMessage(BaseModel):
source: str
"""The name of the agent that sent this message."""
model_usage: RequestUsage | None = None
models_usage: RequestUsage | None = None
"""The model client usage incurred when producing this message."""

View File

@@ -131,10 +131,10 @@ class TokenUsageTermination(TerminationCondition):
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
if message.model_usage is not None:
self._prompt_token_count += message.model_usage.prompt_tokens
self._completion_token_count += message.model_usage.completion_tokens
self._total_token_count += message.model_usage.prompt_tokens + message.model_usage.completion_tokens
if message.models_usage is not None:
self._prompt_token_count += message.models_usage.prompt_tokens
self._completion_token_count += message.models_usage.completion_tokens
self._total_token_count += message.models_usage.prompt_tokens + message.models_usage.completion_tokens
if self.terminated:
content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}."
return StopMessage(content=content, source="TokenUsageTermination")

View File

@@ -113,17 +113,17 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 5
assert result.messages[1].model_usage.prompt_tokens == 10
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].model_usage is not None
assert result.messages[3].model_usage.completion_tokens == 5
assert result.messages[3].model_usage.prompt_tokens == 10
assert result.messages[3].models_usage is not None
assert result.messages[3].models_usage.completion_tokens == 5
assert result.messages[3].models_usage.prompt_tokens == 10
# Test streaming.
mock._curr_index = 0 # pyright: ignore
@@ -181,17 +181,17 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 43
assert result.messages[1].model_usage.prompt_tokens == 42
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 43
assert result.messages[1].models_usage.prompt_tokens == 42
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[3].content == handoff.message
assert result.messages[3].target == handoff.target
assert result.messages[3].model_usage is None
assert result.messages[3].models_usage is None
# Test streaming.
mock._curr_index = 0 # pyright: ignore

View File

@@ -66,7 +66,7 @@ async def test_token_usage_termination() -> None:
await termination(
[
TextMessage(
content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=10, completion_tokens=10)
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=10, completion_tokens=10)
)
]
)
@@ -77,10 +77,10 @@ async def test_token_usage_termination() -> None:
await termination(
[
TextMessage(
content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
),
TextMessage(
content="World", source="agent", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
content="World", source="agent", models_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
),
]
)
@@ -91,10 +91,10 @@ async def test_token_usage_termination() -> None:
await termination(
[
TextMessage(
content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=5, completion_tokens=0)
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=5, completion_tokens=0)
),
TextMessage(
content="stop", source="user", model_usage=RequestUsage(prompt_tokens=0, completion_tokens=5)
content="stop", source="user", models_usage=RequestUsage(prompt_tokens=0, completion_tokens=5)
),
]
)