Add OAI image token count logic (#518)

Leonardo Pinheiro 2024-09-18 21:21:09 +10:00 committed by GitHub
parent 53aabb88cb
commit c60a60e3c9
2 changed files with 117 additions and 11 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import inspect
import json
import logging
import math
import re
import warnings
from typing import (
@@ -206,6 +207,55 @@ def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
        return tool_message_to_oai(message)
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
    MAX_LONG_EDGE = 2048
    BASE_TOKEN_COUNT = 85
    TOKENS_PER_TILE = 170
    MAX_SHORT_EDGE = 768
    TILE_SIZE = 512

    if detail == "low":
        return BASE_TOKEN_COUNT

    width, height = image.image.size

    # Scale down to fit within a MAX_LONG_EDGE x MAX_LONG_EDGE square if necessary
    if width > MAX_LONG_EDGE or height > MAX_LONG_EDGE:
        aspect_ratio = width / height
        if aspect_ratio > 1:
            # Width is greater than height
            width = MAX_LONG_EDGE
            height = int(MAX_LONG_EDGE / aspect_ratio)
        else:
            # Height is greater than or equal to width
            height = MAX_LONG_EDGE
            width = int(MAX_LONG_EDGE * aspect_ratio)

    # Resize such that the shortest side is MAX_SHORT_EDGE if both dimensions exceed MAX_SHORT_EDGE
    aspect_ratio = width / height
    if width > MAX_SHORT_EDGE and height > MAX_SHORT_EDGE:
        if aspect_ratio > 1:
            # Width is greater than height
            height = MAX_SHORT_EDGE
            width = int(MAX_SHORT_EDGE * aspect_ratio)
        else:
            # Height is greater than or equal to width
            width = MAX_SHORT_EDGE
            height = int(MAX_SHORT_EDGE / aspect_ratio)

    # Calculate the number of tiles based on TILE_SIZE
    tiles_width = math.ceil(width / TILE_SIZE)
    tiles_height = math.ceil(height / TILE_SIZE)
    total_tiles = tiles_width * tiles_height

    # Calculate the total tokens based on the number of tiles and the base token count
    total_tokens = BASE_TOKEN_COUNT + TOKENS_PER_TILE * total_tiles
    return total_tokens

def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
    return RequestUsage(
        prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
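
To make the tile arithmetic concrete, here is a hand-worked case using only the math from the function above; the size and expected value are taken from the parametrized test added further down.

import math

# 2048x512 with detail="auto": the long edge does not exceed MAX_LONG_EDGE (2048)
# and the short edge (512) does not exceed MAX_SHORT_EDGE (768), so no rescaling happens.
tiles = math.ceil(2048 / 512) * math.ceil(512 / 512)  # 4 * 1 = 4 tiles of 512x512
tokens = 85 + 170 * tiles  # BASE_TOKEN_COUNT + TOKENS_PER_TILE * tiles = 765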
@@ -604,6 +654,28 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                for key, value in oai_message_part.items():
                    if value is None:
                        continue
                    if isinstance(message, UserMessage) and isinstance(value, list):
                        typed_message_value = cast(List[ChatCompletionContentPartParam], value)

                        assert len(typed_message_value) == len(
                            message.content
                        ), "Mismatch in message content and typed message value"

                        # We need image properties that are only in the original message
                        for part, content_part in zip(typed_message_value, message.content, strict=False):
                            if isinstance(content_part, Image):
                                # TODO: add detail parameter
                                num_tokens += calculate_vision_tokens(content_part)
                            elif isinstance(part, str):
                                num_tokens += len(encoding.encode(part))
                            else:
                                try:
                                    serialized_part = json.dumps(part)
                                    num_tokens += len(encoding.encode(serialized_part))
                                except TypeError:
                                    trace_logger.warning(f"Could not convert {part} to string, skipping.")
                    else:
                        if not isinstance(value, str):
                            try:
                                value = json.dumps(value)

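The per-part branch above is what lets count_tokens price a multimodal UserMessage before it is sent. Below is a minimal sketch of that path; it assumes the Image wrapper can be constructed directly from a PIL image (it exposes the underlying image as .image, which is all calculate_vision_tokens reads), and the import paths mirror the package layout used in this commit.

from PIL import Image as PILImage

from autogen_core.components import Image
from autogen_core.components.models import OpenAIChatCompletionClient, UserMessage

client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")

# Assumption: Image wraps a PIL image and stores it on .image.
pil_img = PILImage.new("RGB", (1024, 1024))
image = Image(pil_img)

# A 1024x1024 image is scaled to 768x768, i.e. 4 tiles -> 85 + 170 * 4 = 765 image tokens,
# on top of the text and per-message overhead.
message = UserMessage(content=["Describe this picture", image], source="user")
print(client.count_tokens([message], tools=[]))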
View File

@@ -1,6 +1,8 @@
import asyncio
from typing import Any, AsyncGenerator, List
from typing import Any, AsyncGenerator, List, Tuple
from unittest.mock import MagicMock, patch
from autogen_core.components.models._openai_client import calculate_vision_tokens
import pytest
from autogen_core.base import CancellationToken
from autogen_core.components import Image
@@ -136,7 +138,7 @@ async def test_openai_chat_completion_client_create_stream_cancel(monkeypatch: p
@pytest.mark.asyncio
async def test_openai_chat_completion_client_count_tokens() -> None:
async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
    messages: List[LLMMessage] = [
        SystemMessage(content="Hello"),
@@ -161,8 +163,40 @@ async def test_openai_chat_completion_client_count_tokens() -> None:
        return str(test1) + str(test2)

    tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]

    mockcalculate_vision_tokens = MagicMock()
    monkeypatch.setattr(
        "autogen_core.components.models._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens
    )

    num_tokens = client.count_tokens(messages, tools=tools)
    assert num_tokens

    # Check that calculate_vision_tokens was called
    mockcalculate_vision_tokens.assert_called_once()

    remaining_tokens = client.remaining_tokens(messages, tools=tools)
    assert remaining_tokens

@pytest.mark.parametrize(
    "mock_size, expected_num_tokens",
    [
        ((1, 1), 255),
        ((512, 512), 255),
        ((2048, 512), 765),
        ((2048, 2048), 765),
        ((512, 1024), 425),
    ],
)
def test_openai_count_image_tokens(mock_size: Tuple[int, int], expected_num_tokens: int) -> None:
    # Mock the Image class with only the 'image' attribute
    mock_image_attr = MagicMock()
    mock_image_attr.size = mock_size

    mock_image = MagicMock()
    mock_image.image = mock_image_attr

    # Directly call calculate_vision_tokens and check the result
    calculated_tokens = calculate_vision_tokens(mock_image, detail="auto")
    assert calculated_tokens == expected_num_tokens
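
The same expectations can be reproduced outside pytest. Since calculate_vision_tokens only reads image.image.size, any stand-in object with that shape works; the SimpleNamespace below plays the role the MagicMock plays in the test above.

from types import SimpleNamespace

from autogen_core.components.models._openai_client import calculate_vision_tokens

cases = [((1, 1), 255), ((512, 512), 255), ((2048, 512), 765), ((2048, 2048), 765), ((512, 1024), 425)]
for size, expected in cases:
    fake_image = SimpleNamespace(image=SimpleNamespace(size=size))
    assert calculate_vision_tokens(fake_image, detail="auto") == expected  # type: ignore[arg-type]

# detail="low" short-circuits to the flat base cost regardless of size.
low_detail_image = SimpleNamespace(image=SimpleNamespace(size=(4096, 4096)))
assert calculate_vision_tokens(low_detail_image, detail="low") == 85  # type: ignore[arg-type]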