mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-12 19:40:40 +00:00
Add OAI image token count logic (#518)
This commit is contained in:
parent
53aabb88cb
commit
c60a60e3c9
@ -2,6 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
@ -206,6 +207,55 @@ def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
|
||||
return tool_message_to_oai(message)
|
||||
|
||||
|
||||
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
    """Estimate the OpenAI token cost of sending *image* to a vision-capable model.

    Mirrors OpenAI's published pricing rules: "low" detail is a flat fee;
    otherwise the image is virtually downscaled (long edge to at most 2048 px,
    then short edge to 768 px) and billed per 512x512 tile plus a base fee.

    Args:
        image: Wrapper whose ``image`` attribute exposes a PIL-style
            ``size`` tuple of ``(width, height)``.
        detail: OpenAI detail setting; only ``"low"`` changes the outcome
            here — ``"auto"`` and ``"high"`` are priced identically.

    Returns:
        The estimated number of prompt tokens consumed by the image.
    """
    base_cost = 85
    per_tile_cost = 170
    long_edge_limit = 2048
    short_edge_limit = 768
    tile_px = 512

    # Low-detail images are billed at a flat rate regardless of size.
    if detail == "low":
        return base_cost

    w, h = image.image.size

    # First pass: shrink so the longer edge fits within long_edge_limit,
    # preserving aspect ratio (int() truncation matches OpenAI's math).
    if max(w, h) > long_edge_limit:
        ratio = w / h
        if ratio > 1:
            w, h = long_edge_limit, int(long_edge_limit / ratio)
        else:
            w, h = int(long_edge_limit * ratio), long_edge_limit

    # Second pass: if both edges still exceed short_edge_limit, shrink so the
    # shorter edge lands exactly on short_edge_limit.
    ratio = w / h
    if min(w, h) > short_edge_limit:
        if ratio > 1:
            w, h = int(short_edge_limit * ratio), short_edge_limit
        else:
            w, h = short_edge_limit, int(short_edge_limit / ratio)

    # Bill one tile per (partial) 512x512 square, plus the base fee.
    n_tiles = math.ceil(w / tile_px) * math.ceil(h / tile_px)
    return base_cost + per_tile_cost * n_tiles
|
||||
|
||||
|
||||
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
||||
return RequestUsage(
|
||||
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
|
||||
@ -604,15 +654,37 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
for key, value in oai_message_part.items():
|
||||
if value is None:
|
||||
continue
|
||||
if not isinstance(value, str):
|
||||
try:
|
||||
value = json.dumps(value)
|
||||
except TypeError:
|
||||
trace_logger.warning(f"Could not convert {value} to string, skipping.")
|
||||
continue
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
if isinstance(message, UserMessage) and isinstance(value, list):
|
||||
typed_message_value = cast(List[ChatCompletionContentPartParam], value)
|
||||
|
||||
assert len(typed_message_value) == len(
|
||||
message.content
|
||||
), "Mismatch in message content and typed message value"
|
||||
|
||||
# We need image properties that are only in the original message
|
||||
for part, content_part in zip(typed_message_value, message.content, strict=False):
|
||||
if isinstance(content_part, Image):
|
||||
# TODO: add detail parameter
|
||||
num_tokens += calculate_vision_tokens(content_part)
|
||||
elif isinstance(part, str):
|
||||
num_tokens += len(encoding.encode(part))
|
||||
else:
|
||||
try:
|
||||
serialized_part = json.dumps(part)
|
||||
num_tokens += len(encoding.encode(serialized_part))
|
||||
except TypeError:
|
||||
trace_logger.warning(f"Could not convert {part} to string, skipping.")
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
try:
|
||||
value = json.dumps(value)
|
||||
except TypeError:
|
||||
trace_logger.warning(f"Could not convert {value} to string, skipping.")
|
||||
continue
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
|
||||
# Tool tokens.
|
||||
|
@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, List
|
||||
from typing import Any, AsyncGenerator, List, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from autogen_core.components.models._openai_client import calculate_vision_tokens
|
||||
import pytest
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import Image
|
||||
@ -136,7 +138,7 @@ async def test_openai_chat_completion_client_create_stream_cancel(monkeypatch: p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion_client_count_tokens() -> None:
|
||||
async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="Hello"),
|
||||
@ -161,8 +163,40 @@ async def test_openai_chat_completion_client_count_tokens() -> None:
|
||||
return str(test1) + str(test2)
|
||||
|
||||
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
||||
|
||||
mockcalculate_vision_tokens = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"autogen_core.components.models._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens
|
||||
)
|
||||
|
||||
num_tokens = client.count_tokens(messages, tools=tools)
|
||||
assert num_tokens
|
||||
|
||||
# Check that calculate_vision_tokens was called
|
||||
mockcalculate_vision_tokens.assert_called_once()
|
||||
|
||||
remaining_tokens = client.remaining_tokens(messages, tools=tools)
|
||||
assert remaining_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "mock_size, expected_num_tokens",
    [
        ((1, 1), 255),
        ((512, 512), 255),
        ((2048, 512), 765),
        ((2048, 2048), 765),
        ((512, 1024), 425),
    ],
)
def test_openai_count_image_tokens(mock_size: Tuple[int, int], expected_num_tokens: int) -> None:
    """calculate_vision_tokens matches OpenAI tile pricing for a range of image sizes."""
    # Stand-in exposing only the .image.size attribute chain the function reads.
    fake_pil_image = MagicMock()
    fake_pil_image.size = mock_size

    fake_image = MagicMock()
    fake_image.image = fake_pil_image

    assert calculate_vision_tokens(fake_image, detail="auto") == expected_num_tokens
|
||||
|
Loading…
x
Reference in New Issue
Block a user