mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-12 19:40:40 +00:00
Add OAI image token count logic (#518)
This commit is contained in:
parent
53aabb88cb
commit
c60a60e3c9
@ -2,6 +2,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -206,6 +207,55 @@ def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
|
|||||||
return tool_message_to_oai(message)
|
return tool_message_to_oai(message)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
    """Estimate how many OpenAI vision tokens *image* will consume.

    Follows the published OpenAI pricing rules: ``detail="low"`` costs a
    flat base amount; otherwise the image is virtually rescaled so its long
    edge fits in 2048 px, then (if both edges still exceed 768 px) so its
    short edge is 768 px, and the cost is the base amount plus a fixed
    charge per covering 512x512 tile.

    Args:
        image: Image wrapper exposing a PIL-style ``image.size`` tuple
            of ``(width, height)``.
        detail: OpenAI detail setting; only ``"low"`` is special-cased,
            every other value is billed at high detail.

    Returns:
        Estimated token count for the image.
    """
    LONG_EDGE_LIMIT = 2048
    SHORT_EDGE_LIMIT = 768
    TILE_EDGE = 512
    BASE_COST = 85
    COST_PER_TILE = 170

    # Low detail is a flat fee regardless of image dimensions.
    if detail == "low":
        return BASE_COST

    width, height = image.image.size

    # Step 1: shrink (preserving aspect ratio) so the long edge fits
    # within LONG_EDGE_LIMIT.
    ratio = width / height
    if max(width, height) > LONG_EDGE_LIMIT:
        if ratio > 1:
            # Landscape: width is the long edge.
            width, height = LONG_EDGE_LIMIT, int(LONG_EDGE_LIMIT / ratio)
        else:
            # Portrait or square: height is the long edge.
            width, height = int(LONG_EDGE_LIMIT * ratio), LONG_EDGE_LIMIT

    # Step 2: if both edges still exceed SHORT_EDGE_LIMIT, shrink so the
    # short edge equals SHORT_EDGE_LIMIT.
    ratio = width / height
    if min(width, height) > SHORT_EDGE_LIMIT:
        if ratio > 1:
            width, height = int(SHORT_EDGE_LIMIT * ratio), SHORT_EDGE_LIMIT
        else:
            width, height = SHORT_EDGE_LIMIT, int(SHORT_EDGE_LIMIT / ratio)

    # Step 3: one fixed charge per 512x512 tile needed to cover the image.
    tile_count = math.ceil(width / TILE_EDGE) * math.ceil(height / TILE_EDGE)
    return BASE_COST + COST_PER_TILE * tile_count
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
||||||
return RequestUsage(
|
return RequestUsage(
|
||||||
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
|
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
|
||||||
@ -604,15 +654,37 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
for key, value in oai_message_part.items():
|
for key, value in oai_message_part.items():
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
if not isinstance(value, str):
|
|
||||||
try:
|
if isinstance(message, UserMessage) and isinstance(value, list):
|
||||||
value = json.dumps(value)
|
typed_message_value = cast(List[ChatCompletionContentPartParam], value)
|
||||||
except TypeError:
|
|
||||||
trace_logger.warning(f"Could not convert {value} to string, skipping.")
|
assert len(typed_message_value) == len(
|
||||||
continue
|
message.content
|
||||||
num_tokens += len(encoding.encode(value))
|
), "Mismatch in message content and typed message value"
|
||||||
if key == "name":
|
|
||||||
num_tokens += tokens_per_name
|
# We need image properties that are only in the original message
|
||||||
|
for part, content_part in zip(typed_message_value, message.content, strict=False):
|
||||||
|
if isinstance(content_part, Image):
|
||||||
|
# TODO: add detail parameter
|
||||||
|
num_tokens += calculate_vision_tokens(content_part)
|
||||||
|
elif isinstance(part, str):
|
||||||
|
num_tokens += len(encoding.encode(part))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
serialized_part = json.dumps(part)
|
||||||
|
num_tokens += len(encoding.encode(serialized_part))
|
||||||
|
except TypeError:
|
||||||
|
trace_logger.warning(f"Could not convert {part} to string, skipping.")
|
||||||
|
else:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
try:
|
||||||
|
value = json.dumps(value)
|
||||||
|
except TypeError:
|
||||||
|
trace_logger.warning(f"Could not convert {value} to string, skipping.")
|
||||||
|
continue
|
||||||
|
num_tokens += len(encoding.encode(value))
|
||||||
|
if key == "name":
|
||||||
|
num_tokens += tokens_per_name
|
||||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||||
|
|
||||||
# Tool tokens.
|
# Tool tokens.
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, AsyncGenerator, List
|
from typing import Any, AsyncGenerator, List, Tuple
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from autogen_core.components.models._openai_client import calculate_vision_tokens
|
||||||
import pytest
|
import pytest
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components import Image
|
from autogen_core.components import Image
|
||||||
@ -136,7 +138,7 @@ async def test_openai_chat_completion_client_create_stream_cancel(monkeypatch: p
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_openai_chat_completion_client_count_tokens() -> None:
|
async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
||||||
messages: List[LLMMessage] = [
|
messages: List[LLMMessage] = [
|
||||||
SystemMessage(content="Hello"),
|
SystemMessage(content="Hello"),
|
||||||
@ -161,8 +163,40 @@ async def test_openai_chat_completion_client_count_tokens() -> None:
|
|||||||
return str(test1) + str(test2)
|
return str(test1) + str(test2)
|
||||||
|
|
||||||
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
||||||
|
|
||||||
|
mockcalculate_vision_tokens = MagicMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"autogen_core.components.models._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens
|
||||||
|
)
|
||||||
|
|
||||||
num_tokens = client.count_tokens(messages, tools=tools)
|
num_tokens = client.count_tokens(messages, tools=tools)
|
||||||
assert num_tokens
|
assert num_tokens
|
||||||
|
|
||||||
|
# Check that calculate_vision_tokens was called
|
||||||
|
mockcalculate_vision_tokens.assert_called_once()
|
||||||
|
|
||||||
remaining_tokens = client.remaining_tokens(messages, tools=tools)
|
remaining_tokens = client.remaining_tokens(messages, tools=tools)
|
||||||
assert remaining_tokens
|
assert remaining_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "mock_size, expected_num_tokens",
    [
        ((1, 1), 255),
        ((512, 512), 255),
        ((2048, 512), 765),
        ((2048, 2048), 765),
        ((512, 1024), 425),
    ],
)
def test_openai_count_image_tokens(mock_size: Tuple[int, int], expected_num_tokens: int) -> None:
    """calculate_vision_tokens returns the documented cost for each image size."""
    # Stand in for Image: only the nested ``.image.size`` attribute is read.
    fake_pil = MagicMock()
    fake_pil.size = mock_size

    fake_image = MagicMock()
    fake_image.image = fake_pil

    # Invoke the function under test directly and verify the token estimate.
    assert calculate_vision_tokens(fake_image, detail="auto") == expected_num_tokens
||||||
|
Loading…
x
Reference in New Issue
Block a user