From c60a60e3c9ecd9d9f13304a289aece0cb15e44c2 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Wed, 18 Sep 2024 21:21:09 +1000 Subject: [PATCH] Add OAI image token count logic (#518) --- .../components/models/_openai_client.py | 90 +++++++++++++++++-- .../autogen-core/tests/test_model_client.py | 38 +++++++- 2 files changed, 117 insertions(+), 11 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py b/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py index 7e81a4ff4..b2232756b 100644 --- a/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py +++ b/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py @@ -2,6 +2,7 @@ import asyncio import inspect import json import logging +import math import re import warnings from typing import ( @@ -206,6 +207,55 @@ def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]: return tool_message_to_oai(message) +def calculate_vision_tokens(image: Image, detail: str = "auto") -> int: + MAX_LONG_EDGE = 2048 + BASE_TOKEN_COUNT = 85 + TOKENS_PER_TILE = 170 + MAX_SHORT_EDGE = 768 + TILE_SIZE = 512 + + if detail == "low": + return BASE_TOKEN_COUNT + + width, height = image.image.size + + # Scale down to fit within a MAX_LONG_EDGE x MAX_LONG_EDGE square if necessary + + if width > MAX_LONG_EDGE or height > MAX_LONG_EDGE: + aspect_ratio = width / height + if aspect_ratio > 1: + # Width is greater than height + width = MAX_LONG_EDGE + height = int(MAX_LONG_EDGE / aspect_ratio) + else: + # Height is greater than or equal to width + height = MAX_LONG_EDGE + width = int(MAX_LONG_EDGE * aspect_ratio) + + # Resize such that the shortest side is MAX_SHORT_EDGE if both dimensions exceed MAX_SHORT_EDGE + aspect_ratio = width / height + if width > MAX_SHORT_EDGE and height > MAX_SHORT_EDGE: + if aspect_ratio > 1: + # Width is greater than height + height = MAX_SHORT_EDGE + width = int(MAX_SHORT_EDGE * aspect_ratio) + else: + # Height is greater than or equal to width + width = MAX_SHORT_EDGE + height = int(MAX_SHORT_EDGE / aspect_ratio) + + # Calculate the number of tiles based on TILE_SIZE + + tiles_width = math.ceil(width / TILE_SIZE) + tiles_height = math.ceil(height / TILE_SIZE) + total_tiles = tiles_width * tiles_height + # Calculate the total tokens based on the number of tiles and the base token count + + total_tokens = BASE_TOKEN_COUNT + TOKENS_PER_TILE * total_tiles + + return total_tokens + + def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage: return RequestUsage( prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens, @@ -604,15 +654,37 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): for key, value in oai_message_part.items(): if value is None: continue - if not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: - trace_logger.warning(f"Could not convert {value} to string, skipping.") - continue - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name + + if isinstance(message, UserMessage) and isinstance(value, list): + typed_message_value = cast(List[ChatCompletionContentPartParam], value) + + assert len(typed_message_value) == len( + message.content + ), "Mismatch in message content and typed message value" + + # We need image properties that are only in the original message + for part, content_part in zip(typed_message_value, message.content, strict=False): + if isinstance(content_part, Image): + # TODO: add detail parameter + num_tokens += calculate_vision_tokens(content_part) + elif isinstance(part, str): + num_tokens += len(encoding.encode(part)) + else: + try: + serialized_part = json.dumps(part) + num_tokens += len(encoding.encode(serialized_part)) + except TypeError: + trace_logger.warning(f"Could not convert {part} to string, skipping.") + else: + if not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: + trace_logger.warning(f"Could not convert {value} to string, skipping.") + continue + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> # Tool tokens. diff --git a/python/packages/autogen-core/tests/test_model_client.py b/python/packages/autogen-core/tests/test_model_client.py index 988a6fe6a..8fa75b154 100644 --- a/python/packages/autogen-core/tests/test_model_client.py +++ b/python/packages/autogen-core/tests/test_model_client.py @@ -1,6 +1,8 @@ import asyncio -from typing import Any, AsyncGenerator, List +from typing import Any, AsyncGenerator, List, Tuple +from unittest.mock import MagicMock, patch +from autogen_core.components.models._openai_client import calculate_vision_tokens import pytest from autogen_core.base import CancellationToken from autogen_core.components import Image @@ -136,7 +138,7 @@ async def test_openai_chat_completion_client_create_stream_cancel(monkeypatch: p @pytest.mark.asyncio -async def test_openai_chat_completion_client_count_tokens() -> None: +async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.MonkeyPatch) -> None: client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key") messages: List[LLMMessage] = [ SystemMessage(content="Hello"), @@ -161,8 +163,40 @@ async def test_openai_chat_completion_client_count_tokens() -> None: return str(test1) + str(test2) tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")] + + mockcalculate_vision_tokens = MagicMock() + monkeypatch.setattr( + "autogen_core.components.models._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens + ) + num_tokens = client.count_tokens(messages, tools=tools) assert num_tokens + # Check that calculate_vision_tokens was called + mockcalculate_vision_tokens.assert_called_once() + remaining_tokens = client.remaining_tokens(messages, tools=tools) assert remaining_tokens + + +@pytest.mark.parametrize( + "mock_size, expected_num_tokens", + [ + ((1, 1), 255), + ((512, 512), 255), + ((2048, 512), 765), + ((2048, 2048), 765), + ((512, 1024), 425), + ], +) +def test_openai_count_image_tokens(mock_size: Tuple[int, int], expected_num_tokens: int) -> None: + # Step 1: Mock the Image class with only the 'image' attribute + mock_image_attr = MagicMock() + mock_image_attr.size = mock_size + + mock_image = MagicMock() + mock_image.image = mock_image_attr + + # Directly call calculate_vision_tokens and check the result + calculated_tokens = calculate_vision_tokens(mock_image, detail="auto") + assert calculated_tokens == expected_num_tokens