mirror of https://github.com/microsoft/graphrag.git (synced 2025-07-03 07:04:19 +00:00)
Fix encoding model parameter in prompt auto templating (#500)
* Fix encoding model parameter in prompt tune
* Format changes
This commit is contained in: parent c7da7f1afb, commit a0caadb320
New changelog entry:

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model parameter on prompt tune"
+}
graphrag/index/utils/tokens.py:

@@ -2,6 +2,7 @@
 # Licensed under the MIT License
 
 """Utilities for working with tokens."""
+
 import logging
 
 import tiktoken
@@ -17,9 +18,9 @@ def num_tokens_from_string(
     if model is not None:
         try:
             encoding = tiktoken.encoding_for_model(model)
-        except KeyError as e:
-            log.error(f"Failed to get encoding for {model} when getting num_tokens_from_string, "
-                      f"fall back to default encoding {DEFAULT_ENCODING_NAME}")
+        except KeyError:
+            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
+            log.warning(msg)
             encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
     else:
         encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
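For context, a minimal sketch of `num_tokens_from_string` as it stands after this hunk. The signature, the module-level `DEFAULT_ENCODING_NAME` and `log` definitions, and the return line are reconstructed from the hunk and the call sites below, so treat them as assumptions rather than quoted source:

```python
import logging

import tiktoken

# Assumed module-level definitions referenced by the hunk above.
DEFAULT_ENCODING_NAME = "cl100k_base"
log = logging.getLogger(__name__)


def num_tokens_from_string(
    string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
    """Return the number of tokens in a text string."""
    if model is not None:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # After this change an unrecognized model name warns and falls
            # back to the default encoding instead of logging an error.
            msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
            log.warning(msg)
            encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
    else:
        encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
    return len(encoding.encode(string))
```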
Call site in the prompt tuning entry point:

@@ -218,8 +218,8 @@ async def generate_indexing_prompts(
         examples=examples,
         language=language,
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
-        model_name=config.llm.model,
         output_path=output_path,
+        encoding_model=config.encoding_model,
         max_token_count=max_tokens,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
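The substance of the fix is visible here: the call site now passes the configured tiktoken encoding name (`config.encoding_model`) rather than the LLM model name. In tiktoken terms the two lookups behave differently, roughly as below (model and encoding names are illustrative):

```python
import tiktoken

# A model *name* is mapped to its encoding; unknown names raise KeyError.
enc = tiktoken.encoding_for_model("gpt-4")  # resolves to the cl100k_base encoding

# An encoding *name* is looked up directly, with no model mapping involved.
enc = tiktoken.get_encoding("cl100k_base")

print(len(enc.encode("hello world")))  # token count under this encoding
```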
Entity extraction prompt generator:

@@ -5,6 +5,7 @@
 
 from pathlib import Path
 
+import graphrag.config.defaults as defs
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.prompt_tune.template import (
     EXAMPLE_EXTRACTION_TEMPLATE,
@@ -22,8 +23,8 @@ def create_entity_extraction_prompt(
     docs: list[str],
     examples: list[str],
     language: str,
-    model_name: str,
     max_token_count: int,
+    encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
 ) -> str:
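A hypothetical call under the new signature. The leading `entity_types` argument and the exact import path are assumptions inferred from the function body and the package naming, and the values are placeholders:

```python
from graphrag.prompt_tune.generator.entity_extraction_prompt import (
    create_entity_extraction_prompt,
)

prompt = create_entity_extraction_prompt(
    entity_types="organization, person",  # assumed parameter, referenced in the body below
    docs=["Acme Corp. hired Jane Doe in 2021."],
    examples=["<worked extraction example>"],
    language="English",
    max_token_count=2048,
    encoding_model="cl100k_base",  # a tiktoken encoding name, not an LLM name
)
```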
@@ -36,7 +37,7 @@ def create_entity_extraction_prompt(
     - docs (list[str]): The list of documents to extract entities from
     - examples (list[str]): The list of examples to use for entity extraction
     - language (str): The language of the inputs and outputs
-    - model_name (str): The name of the model to use for token counting
+    - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
@@ -55,8 +56,8 @@ def create_entity_extraction_prompt(
 
     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=model_name)
-        - num_tokens_from_string(entity_types, model=model_name)
+        - num_tokens_from_string(prompt, model=encoding_model)
+        - num_tokens_from_string(entity_types, model=encoding_model)
         if entity_types
         else 0
     )
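The budget arithmetic above, isolated with illustrative numbers. Note that `num_tokens_from_string` still treats its `model` argument as a model name, so an encoding name such as `cl100k_base` appears to miss in `encoding_for_model` and take the warned fallback path added earlier, which lands on the same default encoding:

```python
# Illustrative numbers only: the bare template and the entity-type list are
# charged against the budget before any examples are packed in.
max_token_count = 2048
prompt_tokens = 180        # tokens in the prompt template (hypothetical)
entity_type_tokens = 12    # tokens in "organization, person" (hypothetical)

tokens_left = max_token_count - prompt_tokens - entity_type_tokens  # 1856
```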
@@ -76,7 +77,7 @@ def create_entity_extraction_prompt(
         )
     )
 
-    example_tokens = num_tokens_from_string(example_formatted, model=model_name)
+    example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
 
     # Squeeze in at least one example
     if i > 0 and example_tokens > tokens_left: