From a0caadb320c5db4e7b8e83625f00c19be893170b Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Thu, 11 Jul 2024 16:12:25 -0600
Subject: [PATCH] Fix encoding model parameter in prompt auto templating (#500)

* Fix encoding model parameter in prompt tune

* Format changes
---
 .../next-release/patch-20240711004716103302.json      |  4 ++++
 graphrag/index/utils/tokens.py                        |  7 ++++---
 graphrag/prompt_tune/cli.py                           |  2 +-
 .../prompt_tune/generator/entity_extraction_prompt.py | 11 ++++++-----
 4 files changed, 15 insertions(+), 9 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240711004716103302.json

diff --git a/.semversioner/next-release/patch-20240711004716103302.json b/.semversioner/next-release/patch-20240711004716103302.json
new file mode 100644
index 00000000..d912ee41
--- /dev/null
+++ b/.semversioner/next-release/patch-20240711004716103302.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model parameter on prompt tune"
+}
diff --git a/graphrag/index/utils/tokens.py b/graphrag/index/utils/tokens.py
index b0fa435b..4a189b9b 100644
--- a/graphrag/index/utils/tokens.py
+++ b/graphrag/index/utils/tokens.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT License
 
 """Utilities for working with tokens."""
+
 import logging
 
 import tiktoken
@@ -17,9 +18,9 @@ def num_tokens_from_string(
     if model is not None:
         try:
             encoding = tiktoken.encoding_for_model(model)
-        except KeyError as e:
-            log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, "
-                      f"fall back to default encoding {DEFAULT_ENCODING_NAME}")
+        except KeyError:
+            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
+            log.warning(msg)
             encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
     else:
         encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
diff --git a/graphrag/prompt_tune/cli.py b/graphrag/prompt_tune/cli.py
index d3fd08d4..f26b1c09 100644
--- a/graphrag/prompt_tune/cli.py
+++ b/graphrag/prompt_tune/cli.py
@@ -218,8 +218,8 @@ async def generate_indexing_prompts(
         examples=examples,
         language=language,
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
-        model_name=config.llm.model,
         output_path=output_path,
+        encoding_model=config.encoding_model,
         max_token_count=max_tokens,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
diff --git a/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/graphrag/prompt_tune/generator/entity_extraction_prompt.py
index d56e894d..faac8da0 100644
--- a/graphrag/prompt_tune/generator/entity_extraction_prompt.py
+++ b/graphrag/prompt_tune/generator/entity_extraction_prompt.py
@@ -5,6 +5,7 @@
 
 from pathlib import Path
 
+import graphrag.config.defaults as defs
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.prompt_tune.template import (
     EXAMPLE_EXTRACTION_TEMPLATE,
@@ -22,8 +23,8 @@ def create_entity_extraction_prompt(
     docs: list[str],
     examples: list[str],
     language: str,
-    model_name: str,
     max_token_count: int,
+    encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
 ) -> str:
@@ -36,7 +37,7 @@ def create_entity_extraction_prompt(
     - docs (list[str]): The list of documents to extract entities from
     - examples (list[str]): The list of examples to use for entity extraction
     - language (str): The language of the inputs and outputs
-    - model_name (str): The name of the model to use for token counting
+    - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file.
@@ -55,8 +56,8 @@ def create_entity_extraction_prompt(
     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=model_name)
-        - num_tokens_from_string(entity_types, model=model_name)
+        - num_tokens_from_string(prompt, model=encoding_model)
+        - num_tokens_from_string(entity_types, model=encoding_model)
         if entity_types
         else 0
     )
 
@@ -76,7 +77,7 @@ def create_entity_extraction_prompt(
             )
         )
 
-        example_tokens = num_tokens_from_string(example_formatted, model=model_name)
+        example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
 
         # Squeeze in at least one example
         if i > 0 and example_tokens > tokens_left:
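
For reviewers who want to sanity-check the new fallback behavior outside the repo,
below is a minimal standalone sketch of the patched num_tokens_from_string logic.
It assumes only that tiktoken is installed; the DEFAULT_ENCODING_NAME value of
"cl100k_base" is an assumption mirroring graphrag's default, and the function body
follows the diff above rather than importing graphrag itself.

    import logging

    import tiktoken

    # Assumed to match graphrag's default encoding name.
    DEFAULT_ENCODING_NAME = "cl100k_base"

    log = logging.getLogger(__name__)


    def num_tokens_from_string(
        string: str, model: str | None = None, encoding_name: str | None = None
    ) -> int:
        """Return the token count for a string, falling back on unknown models."""
        if model is not None:
            try:
                encoding = tiktoken.encoding_for_model(model)
            except KeyError:
                # As of this patch, unknown model names warn and fall back
                # to the default encoding instead of logging an error.
                msg = f"Failed to get encoding for {model}. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
                log.warning(msg)
                encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
        else:
            encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
        return len(encoding.encode(string))


    if __name__ == "__main__":
        print(num_tokens_from_string("hello world", model="gpt-4"))          # known model
        print(num_tokens_from_string("hello world", model="no-such-model"))  # warns, falls back

The substantive change in the patch is that prompt tuning now threads
config.encoding_model into this counter instead of the chat model name
(config.llm.model), so token budgeting uses the tokenizer that will actually
encode the prompt.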