Fix encoding model parameter in prompt auto templating (#500)

* Fix encoding model parameter in prompt tune

* Format changes
Alonso Guevara 2024-07-11 16:12:25 -06:00, committed by GitHub
parent c7da7f1afb
commit a0caadb320
4 changed files with 15 additions and 9 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model parameter on prompt tune"
+}

@@ -2,6 +2,7 @@
 # Licensed under the MIT License

 """Utilities for working with tokens."""

+import logging
 import tiktoken
@@ -17,9 +18,9 @@ def num_tokens_from_string(
     if model is not None:
         try:
             encoding = tiktoken.encoding_for_model(model)
-        except KeyError as e:
-            log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, "
-                      f"fall back to default encoding {DEFAULT_ENCODING_NAME}")
+        except KeyError:
+            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
+            log.warning(msg)
             encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
     else:
         encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
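
For context, a runnable sketch of the helper as it behaves after this change. This is a reconstruction, not the module verbatim: the DEFAULT_ENCODING_NAME value and the encoding_name parameter are assumptions inferred from the diff.

import logging

import tiktoken

DEFAULT_ENCODING_NAME = "cl100k_base"  # assumed default; the real constant lives elsewhere in the module
log = logging.getLogger(__name__)

def num_tokens_from_string(
    string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
    """Return the number of tokens in a text string."""
    if model is not None:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model names (e.g. custom deployment ids) now log a warning
            # and fall back to the default encoding instead of logging an error.
            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
            log.warning(msg)
            encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
    else:
        encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
    return len(encoding.encode(string))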

@@ -218,8 +218,8 @@ async def generate_indexing_prompts(
         examples=examples,
         language=language,
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
-        model_name=config.llm.model,
         output_path=output_path,
+        encoding_model=config.encoding_model,
         max_token_count=max_tokens,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
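
Read as a call-site sketch, the hunk above amounts to the following; everything except the two moved arguments mirrors the context lines, while entity_types and the docs variable name are assumptions, since this hunk does not show them.

prompt = create_entity_extraction_prompt(
    entity_types=entity_types,  # assumed, cut off above this hunk
    docs=doc_list,              # assumed name for the sampled documents
    examples=examples,
    language=language,
    json_mode=False,
    output_path=output_path,
    encoding_model=config.encoding_model,  # was: model_name=config.llm.model
    max_token_count=max_tokens,
)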

@@ -5,6 +5,7 @@

 from pathlib import Path

+import graphrag.config.defaults as defs
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.prompt_tune.template import (
     EXAMPLE_EXTRACTION_TEMPLATE,
@@ -22,8 +23,8 @@ def create_entity_extraction_prompt(
     docs: list[str],
     examples: list[str],
     language: str,
-    model_name: str,
     max_token_count: int,
+    encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
 ) -> str:
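
Because encoding_model now carries a default, callers may omit it entirely. A hypothetical minimal call with placeholder data (the import path, and entity_types as the leading parameter, are assumptions not shown in this hunk):

from graphrag.prompt_tune.generator import create_entity_extraction_prompt  # path assumed

prompt = create_entity_extraction_prompt(
    entity_types="organization, person",        # placeholder
    docs=["Acme Corp hired Jane Doe in May."],  # placeholder
    examples=["..."],                           # placeholder
    language="English",
    max_token_count=2048,
    # encoding_model omitted -> defaults to defs.ENCODING_MODEL
)
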
@@ -36,7 +37,7 @@ def create_entity_extraction_prompt(
     - docs (list[str]): The list of documents to extract entities from
     - examples (list[str]): The list of examples to use for entity extraction
     - language (str): The language of the inputs and outputs
-    - model_name (str): The name of the model to use for token counting
+    - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
@@ -55,8 +56,8 @@ def create_entity_extraction_prompt(
     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=model_name)
-        - num_tokens_from_string(entity_types, model=model_name)
+        - num_tokens_from_string(prompt, model=encoding_model)
+        - num_tokens_from_string(entity_types, model=encoding_model)
         if entity_types
         else 0
     )
@@ -76,7 +77,7 @@ def create_entity_extraction_prompt(
             )
         )
-        example_tokens = num_tokens_from_string(example_formatted, model=model_name)
+        example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)

         # Squeeze in at least one example
         if i > 0 and example_tokens > tokens_left:
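
The hunk above is the tail of the example-packing loop. A self-contained toy of that budgeting pattern (tokens_left and the loop shape match the diff; the template is simplified to an f-string stand-in for EXAMPLE_EXTRACTION_TEMPLATE):

from graphrag.index.utils.tokens import num_tokens_from_string

encoding_model = "gpt-4"                 # placeholder
tokens_left = 512                        # placeholder budget
docs = ["doc one", "doc two"]            # placeholder inputs
examples = ["output one", "output two"]  # placeholder outputs

examples_prompt = ""
for i, output in enumerate(examples):
    # Simplified stand-in for the real template formatting
    example_formatted = f"Example {i + 1}:\ntext: {docs[i]}\noutput: {output}\n"
    example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
    # Squeeze in at least one example, then stop once the budget is spent
    if i > 0 and example_tokens > tokens_left:
        break
    examples_prompt += example_formatted
    tokens_left -= example_tokens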