From c7da7f1afb4ad1dbe5a2a230cefe0398954d38cd Mon Sep 17 00:00:00 2001 From: Kylin Date: Fri, 12 Jul 2024 02:03:30 +0800 Subject: [PATCH] [compatibility issue] Support open source LLM model to prompt-tune (#505) Compatibility update: support non-open ai model to prompt-tune Co-authored-by: Alonso Guevara --- .../next-release/patch-20240711092703710242.json | 4 ++++ graphrag/index/utils/tokens.py | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 .semversioner/next-release/patch-20240711092703710242.json diff --git a/.semversioner/next-release/patch-20240711092703710242.json b/.semversioner/next-release/patch-20240711092703710242.json new file mode 100644 index 00000000..7868b33b --- /dev/null +++ b/.semversioner/next-release/patch-20240711092703710242.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "support non-open ai model config to prompt tune" +} diff --git a/graphrag/index/utils/tokens.py b/graphrag/index/utils/tokens.py index e308c81f..b0fa435b 100644 --- a/graphrag/index/utils/tokens.py +++ b/graphrag/index/utils/tokens.py @@ -2,10 +2,12 @@ # Licensed under the MIT License """Utilities for working with tokens.""" +import logging import tiktoken DEFAULT_ENCODING_NAME = "cl100k_base" +log = logging.getLogger(__name__) def num_tokens_from_string( @@ -13,7 +15,12 @@ def num_tokens_from_string( ) -> int: """Return the number of tokens in a text string.""" if model is not None: - encoding = tiktoken.encoding_for_model(model) + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError as e: + log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, " + f"fall back to default encoding {DEFAULT_ENCODING_NAME}") + encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME) else: encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME) return len(encoding.encode(string))