Normalize Azure-style model aliases in get_max_token_limit.

get_max_token_limit() now rewrites a leading "gpt35"/"gpt-35" to "gpt-3.5"
and a leading "gpt4" to "gpt-4" before the table lookup, so the separate
"gpt-35-*" table entries are dropped. The table gains the 16k-0613, 1106,
gpt-4-1106-preview and gpt-4-vision-preview entries, and the 16k limit
becomes 16385. test_model_aliases() covers the alias rewriting; the
__main__ guard of the test file runs it alongside the existing tests.

diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 9e254932f..18e4d9e4e 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -2,27 +2,33 @@ from typing import List, Union, Dict
 import logging
 import json
 import tiktoken
+import re
 
 
 logger = logging.getLogger(__name__)
 
 
 def get_max_token_limit(model="gpt-3.5-turbo-0613"):
+    # Handle common azure model names/aliases
+    model = re.sub(r"^gpt\-?35", "gpt-3.5", model)
+    model = re.sub(r"^gpt4", "gpt-4", model)
+
     max_token_limit = {
         "gpt-3.5-turbo": 4096,
         "gpt-3.5-turbo-0301": 4096,
         "gpt-3.5-turbo-0613": 4096,
         "gpt-3.5-turbo-instruct": 4096,
-        "gpt-3.5-turbo-16k": 16384,
-        "gpt-35-turbo": 4096,
-        "gpt-35-turbo-16k": 16384,
-        "gpt-35-turbo-instruct": 4096,
+        "gpt-3.5-turbo-16k": 16385,
+        "gpt-3.5-turbo-16k-0613": 16385,
+        "gpt-3.5-turbo-1106": 16385,
         "gpt-4": 8192,
         "gpt-4-32k": 32768,
         "gpt-4-32k-0314": 32768,  # deprecate in Sep
         "gpt-4-0314": 8192,  # deprecate in Sep
         "gpt-4-0613": 8192,
         "gpt-4-32k-0613": 32768,
+        "gpt-4-1106-preview": 128000,
+        "gpt-4-vision-preview": 128000,
     }
 
     return max_token_limit[model]
diff --git a/test/test_token_count.py b/test/test_token_count.py
index 1da4ccabe..8fe6e3666 100644
--- a/test/test_token_count.py
+++ b/test/test_token_count.py
@@ -1,4 +1,10 @@
-from autogen.token_count_utils import count_token, num_tokens_from_functions, token_left, percentile_used
+from autogen.token_count_utils import (
+    count_token,
+    num_tokens_from_functions,
+    token_left,
+    percentile_used,
+    get_max_token_limit,
+)
 import pytest
 
 func1 = {
@@ -67,6 +73,14 @@ def test_count_token():
     assert percentile_used(text) == 10 / 4096
 
 
+def test_model_aliases():
+    assert get_max_token_limit("gpt35-turbo") == get_max_token_limit("gpt-3.5-turbo")
+    assert get_max_token_limit("gpt-35-turbo") == get_max_token_limit("gpt-3.5-turbo")
+    assert get_max_token_limit("gpt4") == get_max_token_limit("gpt-4")
+    assert get_max_token_limit("gpt4-32k") == get_max_token_limit("gpt-4-32k")
+
+
 if __name__ == "__main__":
     test_num_tokens_from_functions()
     test_count_token()
+    test_model_aliases()