diff --git a/.semversioner/next-release/patch-20240711092703710242.json b/.semversioner/next-release/patch-20240711092703710242.json new file mode 100644 index 00000000..7868b33b --- /dev/null +++ b/.semversioner/next-release/patch-20240711092703710242.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Support non-OpenAI model config for prompt tuning" +} diff --git a/graphrag/index/utils/tokens.py b/graphrag/index/utils/tokens.py index e308c81f..b0fa435b 100644 --- a/graphrag/index/utils/tokens.py +++ b/graphrag/index/utils/tokens.py @@ -2,10 +2,12 @@ # Licensed under the MIT License """Utilities for working with tokens.""" +import logging import tiktoken DEFAULT_ENCODING_NAME = "cl100k_base" +log = logging.getLogger(__name__) def num_tokens_from_string( @@ -13,7 +15,12 @@ def num_tokens_from_string( ) -> int: """Return the number of tokens in a text string.""" if model is not None: - encoding = tiktoken.encoding_for_model(model) + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError as e: + log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, " + f"fall back to default encoding {DEFAULT_ENCODING_NAME}") + encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME) else: encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME) return len(encoding.encode(string))