Mirror of https://github.com/microsoft/graphrag.git, synced 2025-06-26 23:19:58 +00:00
Fix encoding model parameter in prompt auto templating (#500)

* Fix encoding model parameter in prompt tune
* Format changes
parent c7da7f1afb
commit a0caadb320
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model parameter on prompt tune"
+}
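The new file above looks like a semversioner-style changeset (graphrag records pending release notes as small JSON files; the exact path is not shown in this diff). A minimal sketch of validating a file of exactly the shape shown, under that assumption:

import json
from pathlib import Path

ALLOWED_BUMPS = {"major", "minor", "patch"}  # semver release levels

def check_changeset(path: Path) -> None:
    """Validate a changeset file of the shape shown in the hunk above."""
    data = json.loads(path.read_text())
    if data.get("type") not in ALLOWED_BUMPS:
        raise ValueError(f"unknown bump type: {data.get('type')!r}")
    if not isinstance(data.get("description"), str) or not data["description"]:
        raise ValueError("description must be a non-empty string")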
@@ -2,6 +2,7 @@
 # Licensed under the MIT License
 
 """Utilities for working with tokens."""
 
+import logging
 
 import tiktoken
@@ -17,9 +18,9 @@ def num_tokens_from_string(
     if model is not None:
         try:
             encoding = tiktoken.encoding_for_model(model)
-        except KeyError as e:
-            log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, "
-                      f"fall back to default encoding {DEFAULT_ENCODING_NAME}")
+        except KeyError:
+            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
+            log.warning(msg)
             encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
     else:
         encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
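Putting the two hunks above together, the patched helper would read roughly as follows. This is a sketch reconstructed from the diff, not a verbatim copy: the full signature and the value of DEFAULT_ENCODING_NAME (assumed here to be "cl100k_base") are not shown in the hunk.

import logging

import tiktoken

DEFAULT_ENCODING_NAME = "cl100k_base"  # assumed default
log = logging.getLogger(__name__)


def num_tokens_from_string(
    string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
    """Return the number of tokens in a text string."""
    if model is not None:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # tiktoken raises KeyError for names it does not know
            # (e.g. custom deployment names); fall back gracefully.
            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
            log.warning(msg)
            encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
    else:
        encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
    return len(encoding.encode(string))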
@@ -218,8 +218,8 @@ async def generate_indexing_prompts(
         examples=examples,
         language=language,
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
-        model_name=config.llm.model,
         output_path=output_path,
+        encoding_model=config.encoding_model,
         max_token_count=max_tokens,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
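The substance of the fix is visible here: the call previously passed config.llm.model (an LLM name, possibly a custom or Azure deployment id), which tiktoken often cannot resolve, so every token count fell into the KeyError fallback. Passing config.encoding_model (a tiktoken encoding name) resolves directly. A small illustration of the failure mode; the deployment name is made up:

import tiktoken

tiktoken.encoding_for_model("gpt-4")  # OK: a model name tiktoken knows

try:
    tiktoken.encoding_for_model("my-azure-gpt4-deployment")  # made-up name
except KeyError:
    # The path the pre-fix code hit on every call.
    encoding = tiktoken.get_encoding("cl100k_base")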
@@ -5,6 +5,7 @@
 
 from pathlib import Path
 
+import graphrag.config.defaults as defs
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.prompt_tune.template import (
     EXAMPLE_EXTRACTION_TEMPLATE,
@@ -22,8 +23,8 @@ def create_entity_extraction_prompt(
     docs: list[str],
     examples: list[str],
     language: str,
-    model_name: str,
     max_token_count: int,
+    encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
 ) -> str:
@@ -36,7 +37,7 @@ def create_entity_extraction_prompt(
     - docs (list[str]): The list of documents to extract entities from
     - examples (list[str]): The list of examples to use for entity extraction
     - language (str): The language of the inputs and outputs
-    - model_name (str): The name of the model to use for token counting
+    - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
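With the new signature, a caller selects the tokenizer via encoding_model instead of an LLM name. A hypothetical invocation (all values are illustrative; the entity_types parameter sits above the shown hunk and is inferred from the token-budget code below):

sample_docs = ["Contoso Ltd. hired Jane Doe in Paris."]  # illustrative
sample_examples = ["<formatted extraction example>"]     # illustrative

prompt = create_entity_extraction_prompt(
    entity_types="organization, person, geo",  # inferred parameter, illustrative value
    docs=sample_docs,
    examples=sample_examples,
    language="English",
    max_token_count=2048,
    encoding_model="cl100k_base",  # a tiktoken encoding name, not an LLM name
    json_mode=False,
    output_path=None,  # skip writing the prompt to disk
)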
@@ -55,8 +56,8 @@ def create_entity_extraction_prompt(
 
     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=model_name)
-        - num_tokens_from_string(entity_types, model=model_name)
+        - num_tokens_from_string(prompt, model=encoding_model)
+        - num_tokens_from_string(entity_types, model=encoding_model)
         if entity_types
         else 0
     )
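The budget arithmetic here is simple: whatever the prompt scaffold and the entity-type list consume is subtracted from max_token_count before any examples are packed in. With made-up numbers:

max_token_count = 2000
prompt_tokens = 450      # num_tokens_from_string(prompt, model=encoding_model)
entity_type_tokens = 50  # num_tokens_from_string(entity_types, model=encoding_model)

tokens_left = max_token_count - prompt_tokens - entity_type_tokens  # 1500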
@@ -76,7 +77,7 @@ def create_entity_extraction_prompt(
         )
     )
 
-    example_tokens = num_tokens_from_string(example_formatted, model=model_name)
+    example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
 
     # Squeeze in at least one example
     if i > 0 and example_tokens > tokens_left:
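The surrounding loop is not shown in the hunk, but the i > 0 guard implies a greedy packing scheme that always keeps the first example even when it exceeds the budget. A minimal sketch under that assumption (names mirror the diff; formatted_examples is hypothetical):

selected: list[str] = []
for i, example_formatted in enumerate(formatted_examples):
    example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
    # Squeeze in at least one example: stop only once the budget is
    # exceeded AND at least one example has already been kept.
    if i > 0 and example_tokens > tokens_left:
        break
    selected.append(example_formatted)
    tokens_left -= example_tokens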