Implement prompt tuning API (#855)

* initial setup commit

* cleanup API and CLI interfaces

* move datatype definition to types.py

* code cleanup

* add semversioner file

* remove unused import

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Josh Bradley 2024-08-12 17:09:00 -04:00 committed by GitHub
parent 4bcbfd10eb
commit 238f1c2adc
9 changed files with 308 additions and 278 deletions


@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Implement auto templating API."
}


@ -1,37 +1,32 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Prompt auto templating package root."""
"""The auto templating package root."""
import argparse
import asyncio
from enum import Enum
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
from .api import DocSelectionType
from .cli import prompt_tune
class DocSelectionType(Enum):
"""The type of document selection to use."""
ALL = "all"
RANDOM = "random"
TOP = "top"
AUTO = "auto"
def __str__(self):
"""Return the string representation of the enum value."""
return self.value
from .generator import MAX_TOKEN_COUNT
from .loader import MIN_CHUNK_SIZE
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
prog="python -m graphrag.prompt_tune",
description="The graphrag auto templating module.",
)
parser.add_argument(
"--config",
help="Configuration yaml file to use when generating prompts",
required=True,
type=str,
)
parser.add_argument(
"--root",
help="The data project root. Including the config yml, json or .env",
help="Data project root. Default: current directory",
required=False,
type=str,
default=".",
@ -39,15 +34,15 @@ if __name__ == "__main__":
parser.add_argument(
"--domain",
help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
help="Domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, the domain will be inferred from the input data.",
required=False,
default="",
type=str,
)
parser.add_argument(
"--method",
help="The method to select documents, one of: all, random, top or auto",
"--selection-method",
help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}",
required=False,
type=DocSelectionType,
choices=list(DocSelectionType),
@ -56,7 +51,7 @@ if __name__ == "__main__":
parser.add_argument(
"--n_subset_max",
help="The number of text chunks to embed when using auto selection method",
help="Number of text chunks to embed when using auto selection method. Default: 300",
required=False,
type=int,
default=300,
@ -64,7 +59,7 @@ if __name__ == "__main__":
parser.add_argument(
"--k",
help="The maximum number of documents to select from each centroid when using auto selection method",
help="Maximum number of documents to select from each centroid when using auto selection method. Default: 15",
required=False,
type=int,
default=15,
@ -72,7 +67,7 @@ if __name__ == "__main__":
parser.add_argument(
"--limit",
help="The limit of files to load when doing random or top selection",
help="Number of documents to load when doing random or top selection. Default: 15",
type=int,
required=False,
default=15,
@ -80,7 +75,7 @@ if __name__ == "__main__":
parser.add_argument(
"--max-tokens",
help="Max token count for prompt generation",
help=f"Max token count for prompt generation. Default: {MAX_TOKEN_COUNT}",
type=int,
required=False,
default=MAX_TOKEN_COUNT,
@ -88,7 +83,7 @@ if __name__ == "__main__":
parser.add_argument(
"--min-examples-required",
help="The minimum number of examples required in entity extraction prompt",
help="Minimum number of examples required in the entity extraction prompt. Default: 2",
type=int,
required=False,
default=2,
@ -96,7 +91,7 @@ if __name__ == "__main__":
parser.add_argument(
"--chunk-size",
help="Max token count for prompt generation",
help=f"Max token count for prompt generation. Default: {MIN_CHUNK_SIZE}",
type=int,
required=False,
default=MIN_CHUNK_SIZE,
@ -120,7 +115,7 @@ if __name__ == "__main__":
parser.add_argument(
"--output",
help="Folder to save the generated prompts to",
help="Directory to save generated prompts to. Default: 'prompts'",
type=str,
required=False,
default="prompts",
@ -132,17 +127,18 @@ if __name__ == "__main__":
loop.run_until_complete(
prompt_tune(
args.root,
args.domain,
str(args.method),
args.limit,
args.max_tokens,
args.chunk_size,
args.language,
args.no_entity_types,
args.output,
args.n_subset_max,
args.k,
args.min_examples_required,
config=args.config,
root=args.root,
domain=args.domain,
selection_method=args.selection_method,
limit=args.limit,
max_tokens=args.max_tokens,
chunk_size=args.chunk_size,
language=args.language,
skip_entity_types=args.no_entity_types,
output=args.output,
n_subset_max=args.n_subset_max,
k=args.k,
min_examples_required=args.min_examples_required,
)
)
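For reference, a minimal programmatic equivalent of a CLI run under the new keyword-argument signature — a sketch only, with hypothetical settings path and argument values:

import asyncio

from graphrag.prompt_tune.cli import prompt_tune
from graphrag.prompt_tune.types import DocSelectionType

# Mirrors: python -m graphrag.prompt_tune --config ./settings.yaml --selection-method auto
asyncio.run(
    prompt_tune(
        config="./settings.yaml",  # hypothetical settings file
        root=".",
        domain="",  # empty string: infer the domain from the input data
        selection_method=DocSelectionType.AUTO,
    )
)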

graphrag/prompt_tune/api.py (new file)

@ -0,0 +1,173 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Auto Templating API.
This API provides access to the auto templating feature of graphrag, allowing external applications
to hook into graphrag and generate prompts from private data.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
from datashaper import NoopVerbCallbacks
from pydantic import PositiveInt, validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from .cli import DocSelectionType
from .generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
create_entity_summarization_prompt,
detect_language,
generate_community_report_rating,
generate_community_reporter_role,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)
from .loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
)
@validate_call
async def generate_indexing_prompts(
config: GraphRagConfig,
root: str,
chunk_size: PositiveInt = MIN_CHUNK_SIZE,
limit: PositiveInt = 15,
selection_method: DocSelectionType = DocSelectionType.RANDOM,
domain: str | None = None,
language: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
skip_entity_types: bool = False,
min_examples_required: PositiveInt = 2,
n_subset_max: PositiveInt = 300,
k: PositiveInt = 15,
) -> tuple[str, str, str]:
"""Generate indexing prompts.
Parameters
----------
- config: The GraphRag configuration.
- root: The data project root directory.
- chunk_size: The chunk token size to use for input text units.
- limit: The limit of chunks to load.
- selection_method: The chunk selection method.
- domain: The domain to map the input documents to.
- language: The language to use for the prompts.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- skip_entity_types: Skip generating entity types.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
Returns
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
reporter = PrintProgressReporter("")
# Retrieve documents
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=selection_method,
reporter=reporter,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
)
# Create LLM from config
llm = load_llm(
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
)
if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")
if not language:
reporter.info("Detecting language...")
language = await detect_language(llm, doc_list)
reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)
entity_types = None
if not skip_entity_types:
reporter.info("Generating entity types...")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
)
reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
)
reporter.info("Generating entity extraction prompt...")
entity_extraction_prompt = create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
encoding_model=config.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
reporter.info("Generating entity summarization prompt...")
entity_summarization_prompt = create_entity_summarization_prompt(
persona=persona,
language=language,
)
reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info("Generating community summarization prompt...")
community_summarization_prompt = create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,
report_rating_description=community_report_ranking,
language=language,
)
return (
entity_extraction_prompt,
entity_summarization_prompt,
community_summarization_prompt,
)
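A minimal end-to-end sketch of calling the new API from an external application, using the same config loader and progress reporter the CLI uses (the project root path is hypothetical):

import asyncio

from graphrag.index.progress import PrintProgressReporter
from graphrag.prompt_tune.api import generate_indexing_prompts
from graphrag.prompt_tune.loader import read_config_parameters
from graphrag.prompt_tune.types import DocSelectionType

async def main() -> None:
    root = "./ragtest"  # hypothetical project root
    reporter = PrintProgressReporter("")
    config = read_config_parameters(root, reporter)
    prompts = await generate_indexing_prompts(
        config=config,
        root=root,
        selection_method=DocSelectionType.RANDOM,
    )
    entity_extraction, entity_summarization, community_summarization = prompts
    print(entity_extraction[:200])

asyncio.run(main())

Note that, unlike the CLI, the API returns the prompts as strings and leaves writing them to disk to the caller.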


@ -5,37 +5,25 @@
from pathlib import Path
from datashaper import NoopVerbCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import CompletionLLM
from graphrag.prompt_tune.generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
create_entity_summarization_prompt,
detect_language,
generate_community_report_rating,
generate_community_reporter_role,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
read_config_parameters,
)
from . import api
from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME
from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME
from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME
from .types import DocSelectionType
async def prompt_tune(
config: str,
root: str,
domain: str,
select: str = "random",
selection_method: DocSelectionType = DocSelectionType.RANDOM,
limit: int = 15,
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
@ -50,223 +38,51 @@ async def prompt_tune(
Parameters
----------
- config: The configuration file.
- root: The root directory.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- selection_method: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use.
- language: The language to use for the prompts.
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
reporter = PrintProgressReporter("")
config = read_config_parameters(root, reporter)
graph_config = read_config_parameters(root, reporter, config)
await prompt_tune_with_config(
root,
config,
domain,
select,
limit,
max_tokens,
chunk_size,
language,
skip_entity_types,
output,
reporter,
n_subset_max,
k,
min_examples_required,
)
async def prompt_tune_with_config(
root: str,
config: GraphRagConfig,
domain: str,
select: str = "random",
limit: int = 15,
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
language: str | None = None,
skip_entity_types: bool = False,
output: str = "prompts",
reporter: ProgressReporter | None = None,
n_subset_max: int = 300,
k: int = 15,
min_examples_required: int = 2,
):
"""Prompt tune the model with a configuration.
Parameters
----------
- root: The root directory.
- config: The GraphRag configuration.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use for input text units.
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- reporter: The progress reporter.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
Returns
-------
- None
"""
if not reporter:
reporter = PrintProgressReporter("")
output_path = Path(config.root_dir) / output
doc_list = await load_docs_in_chunks(
prompts = await api.generate_indexing_prompts(
config=graph_config,
root=root,
config=config,
limit=limit,
select_method=select,
reporter=reporter,
chunk_size=chunk_size,
limit=limit,
selection_method=selection_method,
domain=domain,
language=language,
max_tokens=max_tokens,
skip_entity_types=skip_entity_types,
min_examples_required=min_examples_required,
n_subset_max=n_subset_max,
k=k,
)
# Create LLM from config
llm = load_llm(
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
)
await generate_indexing_prompts(
llm,
config,
doc_list,
output_path,
reporter,
domain,
language,
max_tokens,
skip_entity_types,
min_examples_required,
)
async def generate_indexing_prompts(
llm: CompletionLLM,
config: GraphRagConfig,
doc_list: list[str],
output_path: Path,
reporter: ProgressReporter,
domain: str | None = None,
language: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
skip_entity_types: bool = False,
min_examples_required: int = 2,
):
"""Generate indexing prompts.
Parameters
----------
- llm: The LLM model to use.
- config: The GraphRag configuration.
- doc_list: The list of documents to use.
- output_path: The path to store the prompts.
- reporter: The progress reporter.
- domain: The domain to map the input documents to.
- max_tokens: The maximum number of tokens to use on entity extraction prompts
- skip_entity_types: Skip generating entity types.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")
if not language:
reporter.info("Detecting language...")
language = await detect_language(llm, doc_list)
reporter.info(f"Detected language: {language}")
reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info(f"Generated persona: {persona}")
reporter.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(
f"Generated community report ranking description: {community_report_ranking}"
)
entity_types = None
if not skip_entity_types:
reporter.info("Generating entity types")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
output_path = Path(output)
if output_path:
reporter.info(f"Writing prompts to {output_path}")
output_path.mkdir(parents=True, exist_ok=True)
entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
community_summarization_prompt_path = (
output_path / COMMUNITY_SUMMARIZATION_FILENAME
)
reporter.info(f"Generated entity types: {entity_types}")
reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
)
reporter.info("Done generating entity relationship examples")
reporter.info("Generating entity extraction prompt...")
create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
output_path=output_path,
encoding_model=config.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
reporter.info("Generating entity summarization prompt...")
create_entity_summarization_prompt(
persona=persona,
language=language,
output_path=output_path,
)
reporter.info(
f"Generated entity summarization prompt, stored in folder {output_path}"
)
reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(f"Generated community reporter role: {community_reporter_role}")
reporter.info("Generating community summarization prompt...")
create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,
report_rating_description=community_report_ranking,
language=language,
output_path=output_path,
)
reporter.info(
f"Generated community summarization prompt, stored in folder {output_path}"
)
# Write files to output path
with entity_extraction_prompt_path.open("wb") as file:
file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
with entity_summarization_prompt_path.open("wb") as file:
file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
with community_summarization_prompt_path.open("wb") as file:
file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
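The three binary writes above are equivalent to Path.write_text with an explicit encoding; a condensed sketch using the same filename constants imported at the top of the file:

from pathlib import Path

from graphrag.prompt_tune.generator.community_report_summarization import (
    COMMUNITY_SUMMARIZATION_FILENAME,
)
from graphrag.prompt_tune.generator.entity_extraction_prompt import (
    ENTITY_EXTRACTION_FILENAME,
)
from graphrag.prompt_tune.generator.entity_summarization_prompt import (
    ENTITY_SUMMARIZATION_FILENAME,
)

def write_prompts(output_path: Path, prompts: tuple[str, str, str]) -> None:
    # Same on-disk result as the open("wb") + encode("utf-8") pattern above.
    output_path.mkdir(parents=True, exist_ok=True)
    filenames = (
        ENTITY_EXTRACTION_FILENAME,
        ENTITY_SUMMARIZATION_FILENAME,
        COMMUNITY_SUMMARIZATION_FILENAME,
    )
    for filename, text in zip(filenames, prompts):
        (output_path / filename).write_text(text, encoding="utf-8")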


@ -41,7 +41,7 @@ def create_entity_extraction_prompt(
- encoding_model (str): The name of the model to use for token counting
- max_token_count (int): The maximum number of tokens to use for the prompt
- json_mode (bool): Whether to use JSON mode for the prompt. Default is False
- output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
- output_path (Path | None): The path to write the prompt to. Default is None.
- min_examples_required (int): The minimum number of examples required. Default is 2.
Returns
@ -58,8 +58,8 @@ def create_entity_extraction_prompt(
tokens_left = (
max_token_count
- num_tokens_from_string(prompt, model=encoding_model)
- num_tokens_from_string(entity_types, model=encoding_model)
- num_tokens_from_string(prompt, encoding_name=encoding_model)
- num_tokens_from_string(entity_types, encoding_name=encoding_model)
if entity_types
else 0
)
@ -79,7 +79,9 @@ def create_entity_extraction_prompt(
)
)
example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
example_tokens = num_tokens_from_string(
example_formatted, encoding_name=encoding_model
)
# Ensure at least the minimum required number of examples is included
if i >= min_examples_required and example_tokens > tokens_left:
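The keyword rename above (model= to encoding_name=) matters because the encoding_model config value is a tiktoken encoding name such as "cl100k_base" rather than a model name (assumed from the default configuration). A sketch of the distinction using tiktoken directly, assuming num_tokens_from_string wraps these calls:

import tiktoken

def num_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
    # tiktoken.get_encoding expects an encoding name ("cl100k_base");
    # tiktoken.encoding_for_model expects a model name ("gpt-4") and
    # resolves it to an encoding, raising KeyError for unknown names.
    encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(text))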


@ -15,13 +15,14 @@ def create_entity_summarization_prompt(
language: str,
output_path: Path | None = None,
) -> str:
"""Create a prompt for entity summarization. If output_path is provided, write the prompt to a file.
"""
Create a prompt for entity summarization.
Parameters
----------
- persona (str): The persona to use for the entity summarization prompt
- language (str): The language to use for the entity summarization prompt
- output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
- output_path (Path | None): The path to write the prompt to. Default is None.
"""
prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language)


@ -9,20 +9,38 @@ from graphrag.config import create_graphrag_config
from graphrag.index.progress.types import ProgressReporter
def read_config_parameters(root: str, reporter: ProgressReporter):
def read_config_parameters(
root: str, reporter: ProgressReporter, config: str | None = None
):
"""Read the configuration parameters from the settings file or environment variables.
Parameters
----------
- root: The root directory where the parameters are.
- reporter: The progress reporter.
- config: The path to the settings file.
"""
_root = Path(root)
settings_yaml = _root / "settings.yaml"
settings_yaml = (
Path(config)
if config and Path(config).suffix in [".yaml", ".yml"]
else _root / "settings.yaml"
)
if not settings_yaml.exists():
settings_yaml = _root / "settings.yml"
settings_json = _root / "settings.json"
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("rb") as file:
import yaml
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
settings_json = (
Path(config)
if config and Path(config).suffix == ".json"
else _root / "settings.json"
)
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("rb") as file:
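A condensed sketch of the path-resolution order this change introduces: an explicit --config path wins when its suffix matches, otherwise resolution falls back to settings.yaml, settings.yml, then settings.json under root (the real function goes on to load the file and build a GraphRagConfig):

from pathlib import Path

def resolve_settings_file(root: str, config: str | None = None) -> Path:
    _root = Path(root)
    settings_yaml = (
        Path(config)
        if config and Path(config).suffix in (".yaml", ".yml")
        else _root / "settings.yaml"
    )
    if not settings_yaml.exists():
        settings_yaml = _root / "settings.yml"
    settings_json = (
        Path(config)
        if config and Path(config).suffix == ".json"
        else _root / "settings.json"
    )
    return settings_yaml if settings_yaml.exists() else settings_json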


@ -16,6 +16,7 @@ from graphrag.index.llm import load_llm_embeddings
from graphrag.index.progress.types import ProgressReporter
from graphrag.index.verbs import chunk
from graphrag.llm.types.llm_types import EmbeddingLLM
from graphrag.prompt_tune.types import DocSelectionType
MIN_CHUNK_OVERLAP = 0
MIN_CHUNK_SIZE = 200
@ -50,7 +51,7 @@ def _sample_chunks_from_embeddings(
async def load_docs_in_chunks(
root: str,
config: GraphRagConfig,
select_method: str,
select_method: DocSelectionType,
limit: int,
reporter: ProgressReporter,
chunk_size: int = MIN_CHUNK_SIZE,
@ -85,11 +86,11 @@ async def load_docs_in_chunks(
if limit <= 0 or limit > len(chunks_df):
limit = len(chunks_df)
if select_method == "top":
if select_method == DocSelectionType.TOP:
chunks_df = chunks_df[:limit]
elif select_method == "random":
elif select_method == DocSelectionType.RANDOM:
chunks_df = chunks_df.sample(n=limit)
elif select_method == "auto":
elif select_method == DocSelectionType.AUTO:
if k is None or k <= 0:
msg = "k must be an integer > 0"
raise ValueError(msg)
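Typing select_method as DocSelectionType is what forces the comparisons above to change: a plain Enum member never compares equal to its value string, so the old select_method == "top" checks would silently stop matching. A small demonstration:

from graphrag.prompt_tune.types import DocSelectionType

assert DocSelectionType.TOP != "top"  # plain Enum: no equality with the raw string
assert str(DocSelectionType.TOP) == "top"  # __str__ still yields the raw value
assert DocSelectionType("top") is DocSelectionType.TOP  # lookup by value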


@ -0,0 +1,19 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Types for prompt tuning."""
from enum import Enum
class DocSelectionType(Enum):
"""The type of document selection to use."""
ALL = "all"
RANDOM = "random"
TOP = "top"
AUTO = "auto"
def __str__(self):
"""Return the string representation of the enum value."""
return self.value
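Returning the raw value from __str__ keeps the argparse integration readable: choices=list(DocSelectionType) renders in help text as all, random, top, auto rather than as enum reprs. A minimal sketch of that interplay:

import argparse

from graphrag.prompt_tune.types import DocSelectionType

parser = argparse.ArgumentParser()
parser.add_argument(
    "--selection-method",
    type=DocSelectionType,  # argparse calls DocSelectionType("auto") on the raw string
    choices=list(DocSelectionType),  # rendered via __str__ as: all, random, top, auto
    default=DocSelectionType.RANDOM,
    help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}",
)
args = parser.parse_args(["--selection-method", "auto"])
assert args.selection_method is DocSelectionType.AUTO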