From 2041905acb361d564a7a4dcec83b032ecd81ec0d Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 19 Jul 2024 18:44:22 -0700 Subject: [PATCH] Add token counting to chat completion client #220 (#239) * Add token counting to chat completion client * fix mypy * ignore pyright for object type * format --- python/pyproject.toml | 4 +- .../agnext/components/models/_model_client.py | 4 + .../agnext/components/models/_model_info.py | 20 +++++ .../components/models/_openai_client.py | 76 ++++++++++++++++++- python/tests/test_model_client.py | 66 ++++++++++++++++ 5 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 python/tests/test_model_client.py diff --git a/python/pyproject.toml b/python/pyproject.toml index e98106492..63a7880b7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ "pydantic>=1.10,<3", "types-aiofiles", "grpcio", - "protobuf" + "protobuf", + "tiktoken" ] [tool.hatch.envs.default] @@ -30,6 +31,7 @@ dependencies = [ "pyright==1.1.368", "mypy==1.10.0", "ruff==0.4.8", + "tiktoken", "types-Pillow", "polars", "chess", diff --git a/python/src/agnext/components/models/_model_client.py b/python/src/agnext/components/models/_model_client.py index b87e50b91..06cb39de9 100644 --- a/python/src/agnext/components/models/_model_client.py +++ b/python/src/agnext/components/models/_model_client.py @@ -48,5 +48,9 @@ class ChatCompletionClient(Protocol): def total_usage(self) -> RequestUsage: ... + def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ... + + def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ... + @property def capabilities(self) -> ModelCapabilities: ... diff --git a/python/src/agnext/components/models/_model_info.py b/python/src/agnext/components/models/_model_info.py index 9c919c0f7..06a28d6ad 100644 --- a/python/src/agnext/components/models/_model_info.py +++ b/python/src/agnext/components/models/_model_info.py @@ -77,6 +77,21 @@ _MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = { }, } +_MODEL_TOKEN_LIMITS: Dict[str, int] = { + "gpt-4o-2024-05-13": 128000, + "gpt-4-turbo-2024-04-09": 128000, + "gpt-4-0125-preview": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-1106-vision-preview": 128000, + "gpt-4-0613": 8192, + "gpt-4-32k-0613": 32768, + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k-0613": 16385, +} + def resolve_model(model: str) -> str: if model in _MODEL_POINTERS: @@ -87,3 +102,8 @@ def resolve_model(model: str) -> str: def get_capabilties(model: str) -> ModelCapabilities: resolved_model = resolve_model(model) return _MODEL_CAPABILITIES[resolved_model] + + +def get_token_limit(model: str) -> int: + resolved_model = resolve_model(model) + return _MODEL_TOKEN_LIMITS[resolved_model] diff --git a/python/src/agnext/components/models/_openai_client.py b/python/src/agnext/components/models/_openai_client.py index 68ef9acb9..f6f2597d4 100644 --- a/python/src/agnext/components/models/_openai_client.py +++ b/python/src/agnext/components/models/_openai_client.py @@ -1,4 +1,5 @@ import inspect +import json import logging import re import warnings @@ -15,6 +16,7 @@ from typing import ( cast, ) +import tiktoken from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ( ChatCompletionAssistantMessageParam, @@ -32,7 +34,7 @@ from openai.types.chat import ( from openai.types.shared_params import FunctionDefinition, FunctionParameters from typing_extensions import Unpack -from ...application.logging import EVENT_LOGGER_NAME +from ...application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME from ...application.logging.events import LLMCallEvent from .. import ( FunctionCall, @@ -53,6 +55,7 @@ from ._types import ( from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration logger = logging.getLogger(EVENT_LOGGER_NAME) +trace_logger = logging.getLogger(TRACE_LOGGER_NAME) openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs) aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs) @@ -518,6 +521,77 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): def total_usage(self) -> RequestUsage: return self._total_usage + def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + model = self._create_args["model"] + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + tokens_per_message = 3 + tokens_per_name = 1 + num_tokens = 0 + + # Message tokens. + for message in messages: + num_tokens += tokens_per_message + oai_message = to_oai_type(message) + for oai_message_part in oai_message: + for key, value in oai_message_part.items(): + if value is None: + continue + if not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: + trace_logger.warning(f"Could not convert {value} to string, skipping.") + continue + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + + # Tool tokens. + oai_tools = convert_tools(tools) + for tool in oai_tools: + function = tool["function"] + tool_tokens = len(encoding.encode(function["name"])) + if "description" in function: + tool_tokens += len(encoding.encode(function["description"])) + tool_tokens -= 2 + if "parameters" in function: + parameters = function["parameters"] + if "properties" in parameters: + assert isinstance(parameters["properties"], dict) + for propertiesKey in parameters["properties"]: # pyright: ignore + assert isinstance(propertiesKey, str) + tool_tokens += len(encoding.encode(propertiesKey)) + v = parameters["properties"][propertiesKey] # pyright: ignore + for field in v: # pyright: ignore + if field == "type": + tool_tokens += 2 + tool_tokens += len(encoding.encode(v["type"])) # pyright: ignore + elif field == "description": + tool_tokens += 2 + tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore + elif field == "enum": + tool_tokens -= 3 + for o in v["enum"]: # pyright: ignore + tool_tokens += 3 + tool_tokens += len(encoding.encode(o)) # pyright: ignore + else: + trace_logger.warning(f"Not supported field {field}") + tool_tokens += 11 + if len(parameters["properties"]) == 0: # pyright: ignore + tool_tokens -= 2 + num_tokens += tool_tokens + num_tokens += 12 + return num_tokens + + def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + token_limit = _model_info.get_token_limit(self._create_args["model"]) + return token_limit - self.count_tokens(messages, tools) + @property def capabilities(self) -> ModelCapabilities: return self._model_capabilities diff --git a/python/tests/test_model_client.py b/python/tests/test_model_client.py new file mode 100644 index 000000000..18112d788 --- /dev/null +++ b/python/tests/test_model_client.py @@ -0,0 +1,66 @@ +from typing import List + +import pytest +from agnext.components import Image +from agnext.components.models import ( + AssistantMessage, + AzureOpenAIChatCompletionClient, + FunctionExecutionResult, + FunctionExecutionResultMessage, + LLMMessage, + OpenAIChatCompletionClient, + SystemMessage, + UserMessage, +) +from agnext.components.tools import FunctionTool + + +@pytest.mark.asyncio +async def test_openai_chat_completion_client() -> None: + client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key") + assert client + + +@pytest.mark.asyncio +async def test_azure_openai_chat_completion_client() -> None: + client = AzureOpenAIChatCompletionClient( + model="gpt-4o", + api_key="api_key", + api_version="2020-08-04", + azure_endpoint="https://dummy.com", + model_capabilities={"vision": True, "function_calling": True, "json_output": True}, + ) + assert client + + +@pytest.mark.asyncio +async def test_openai_chat_completion_client_count_tokens() -> None: + client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key") + messages : List[LLMMessage] = [ + SystemMessage(content="Hello"), + UserMessage(content="Hello", source="user"), + AssistantMessage(content="Hello", source="assistant"), + UserMessage( + content=[ + "str1", + Image.from_base64( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC" + ), + ], + source="user", + ), + FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1")]), + ] + + def tool1(test: str, test2: str) -> str: + return test + test2 + + def tool2(test1: int, test2: List[int]) -> str: + return str(test1) + str(test2) + + tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")] + num_tokens = client.count_tokens(messages, tools=tools) + assert num_tokens + + remaining_tokens = client.remaining_tokens(messages, tools=tools) + assert remaining_tokens