Implement prompt tuning API (#855)
* initial setup commit
* cleanup API and CLI interfaces
* move datatype definition to types.py
* code cleanup
* add semversioner file
* remove unused import
---------
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
parent 4bcbfd10eb
commit 238f1c2adc
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Implement auto templating API."
+}
@@ -1,37 +1,32 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""The Prompt auto templating package root."""
+"""The auto templating package root."""

 import argparse
 import asyncio
-from enum import Enum
-
-from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
-from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE

+from .api import DocSelectionType
 from .cli import prompt_tune
-
-
-class DocSelectionType(Enum):
-    """The type of document selection to use."""
-
-    ALL = "all"
-    RANDOM = "random"
-    TOP = "top"
-    AUTO = "auto"
-
-    def __str__(self):
-        """Return the string representation of the enum value."""
-        return self.value
+from .generator import MAX_TOKEN_COUNT
+from .loader import MIN_CHUNK_SIZE

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        prog="python -m graphrag.prompt_tune",
+        description="The graphrag auto templating module.",
+    )
+
+    parser.add_argument(
+        "--config",
+        help="Configuration yaml file to use when generating prompts",
+        required=True,
+        type=str,
+    )

     parser.add_argument(
         "--root",
-        help="The data project root. Including the config yml, json or .env",
+        help="Data project root. Default: current directory",
         required=False,
         type=str,
         default=".",
@@ -39,15 +34,15 @@ if __name__ == "__main__":

     parser.add_argument(
         "--domain",
-        help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
+        help="Domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, the domain will be inferred from the input data.",
         required=False,
         default="",
         type=str,
     )

     parser.add_argument(
-        "--method",
-        help="The method to select documents, one of: all, random, top or auto",
+        "--selection-method",
+        help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}",
         required=False,
         type=DocSelectionType,
         choices=list(DocSelectionType),
@@ -56,7 +51,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--n_subset_max",
-        help="The number of text chunks to embed when using auto selection method",
+        help="Number of text chunks to embed when using auto selection method. Default: 300",
         required=False,
         type=int,
         default=300,
@@ -64,7 +59,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--k",
-        help="The maximum number of documents to select from each centroid when using auto selection method",
+        help="Maximum number of documents to select from each centroid when using auto selection method. Default: 15",
         required=False,
         type=int,
         default=15,
@@ -72,7 +67,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--limit",
-        help="The limit of files to load when doing random or top selection",
+        help="Number of documents to load when doing random or top selection. Default: 15",
         type=int,
         required=False,
         default=15,
@@ -80,7 +75,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--max-tokens",
-        help="Max token count for prompt generation",
+        help=f"Max token count for prompt generation. Default: {MAX_TOKEN_COUNT}",
         type=int,
         required=False,
         default=MAX_TOKEN_COUNT,
@@ -88,7 +83,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--min-examples-required",
-        help="The minimum number of examples required in entity extraction prompt",
+        help="Minimum number of examples required in the entity extraction prompt. Default: 2",
         type=int,
         required=False,
         default=2,
@@ -96,7 +91,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--chunk-size",
-        help="Max token count for prompt generation",
+        help=f"Max token count for prompt generation. Default: {MIN_CHUNK_SIZE}",
         type=int,
         required=False,
         default=MIN_CHUNK_SIZE,
@@ -120,7 +115,7 @@ if __name__ == "__main__":

     parser.add_argument(
         "--output",
-        help="Folder to save the generated prompts to",
+        help="Directory to save generated prompts to. Default: 'prompts'",
         type=str,
         required=False,
         default="prompts",
@@ -132,17 +127,18 @@ if __name__ == "__main__":

     loop.run_until_complete(
         prompt_tune(
-            args.root,
-            args.domain,
-            str(args.method),
-            args.limit,
-            args.max_tokens,
-            args.chunk_size,
-            args.language,
-            args.no_entity_types,
-            args.output,
-            args.n_subset_max,
-            args.k,
-            args.min_examples_required,
+            config=args.config,
+            root=args.root,
+            domain=args.domain,
+            selection_method=args.selection_method,
+            limit=args.limit,
+            max_tokens=args.max_tokens,
+            chunk_size=args.chunk_size,
+            language=args.language,
+            skip_entity_types=args.no_entity_types,
+            output=args.output,
+            n_subset_max=args.n_subset_max,
+            k=args.k,
+            min_examples_required=args.min_examples_required,
         )
     )
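For reference, a minimal sketch of the equivalent programmatic call (the path and flag values are hypothetical; the names are taken from the argparse definitions above and the prompt_tune signature shown later in this commit):

    import asyncio

    from graphrag.prompt_tune.cli import prompt_tune
    from graphrag.prompt_tune.types import DocSelectionType

    # Equivalent of:
    #   python -m graphrag.prompt_tune --config ./settings.yaml --selection-method random
    asyncio.run(
        prompt_tune(
            config="./settings.yaml",  # hypothetical path
            root=".",
            domain="",  # empty: the domain is inferred from the input data
            selection_method=DocSelectionType.RANDOM,
        )
    )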
graphrag/prompt_tune/api.py (new file, 173 lines)
@@ -0,0 +1,173 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""
+Auto Templating API.
+
+This API provides access to the auto templating feature of graphrag, allowing external applications
+to hook into graphrag and generate prompts from private data.
+
+WARNING: This API is under development and may undergo changes in future releases.
+Backwards compatibility is not guaranteed at this time.
+"""
+
+from datashaper import NoopVerbCallbacks
+from pydantic import PositiveInt, validate_call
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.llm import load_llm
+from graphrag.index.progress import PrintProgressReporter
+
+from .cli import DocSelectionType
+from .generator import (
+    MAX_TOKEN_COUNT,
+    create_community_summarization_prompt,
+    create_entity_extraction_prompt,
+    create_entity_summarization_prompt,
+    detect_language,
+    generate_community_report_rating,
+    generate_community_reporter_role,
+    generate_domain,
+    generate_entity_relationship_examples,
+    generate_entity_types,
+    generate_persona,
+)
+from .loader import (
+    MIN_CHUNK_SIZE,
+    load_docs_in_chunks,
+)
+
+
+@validate_call
+async def generate_indexing_prompts(
+    config: GraphRagConfig,
+    root: str,
+    chunk_size: PositiveInt = MIN_CHUNK_SIZE,
+    limit: PositiveInt = 15,
+    selection_method: DocSelectionType = DocSelectionType.RANDOM,
+    domain: str | None = None,
+    language: str | None = None,
+    max_tokens: int = MAX_TOKEN_COUNT,
+    skip_entity_types: bool = False,
+    min_examples_required: PositiveInt = 2,
+    n_subset_max: PositiveInt = 300,
+    k: PositiveInt = 15,
+) -> tuple[str, str, str]:
+    """Generate indexing prompts.
+
+    Parameters
+    ----------
+    - config: The GraphRag configuration.
+    - root: The root directory.
+    - chunk_size: The chunk token size to use for input text units.
+    - limit: The limit of chunks to load.
+    - selection_method: The chunk selection method.
+    - domain: The domain to map the input documents to.
+    - language: The language to use for the prompts.
+    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
+    - skip_entity_types: Skip generating entity types.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.
+    - n_subset_max: The number of text chunks to embed when using auto selection method.
+    - k: The number of documents to select when using auto selection method.
+
+    Returns
+    -------
+    tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
+    """
+    reporter = PrintProgressReporter("")
+
+    # Retrieve documents
+    doc_list = await load_docs_in_chunks(
+        root=root,
+        config=config,
+        limit=limit,
+        select_method=selection_method,
+        reporter=reporter,
+        chunk_size=chunk_size,
+        n_subset_max=n_subset_max,
+        k=k,
+    )
+
+    # Create LLM from config
+    llm = load_llm(
+        "prompt_tuning",
+        config.llm.type,
+        NoopVerbCallbacks(),
+        None,
+        config.llm.model_dump(),
+    )
+
+    if not domain:
+        reporter.info("Generating domain...")
+        domain = await generate_domain(llm, doc_list)
+        reporter.info(f"Generated domain: {domain}")
+
+    if not language:
+        reporter.info("Detecting language...")
+        language = await detect_language(llm, doc_list)
+
+    reporter.info("Generating persona...")
+    persona = await generate_persona(llm, domain)
+
+    reporter.info("Generating community report ranking description...")
+    community_report_ranking = await generate_community_report_rating(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+
+    entity_types = None
+    if not skip_entity_types:
+        reporter.info("Generating entity types...")
+        entity_types = await generate_entity_types(
+            llm,
+            domain=domain,
+            persona=persona,
+            docs=doc_list,
+            json_mode=config.llm.model_supports_json or False,
+        )
+
+    reporter.info("Generating entity relationship examples...")
+    examples = await generate_entity_relationship_examples(
+        llm,
+        persona=persona,
+        entity_types=entity_types,
+        docs=doc_list,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json by the index engine
+    )
+
+    reporter.info("Generating entity extraction prompt...")
+    entity_extraction_prompt = create_entity_extraction_prompt(
+        entity_types=entity_types,
+        docs=doc_list,
+        examples=examples,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json by the index engine
+        encoding_model=config.encoding_model,
+        max_token_count=max_tokens,
+        min_examples_required=min_examples_required,
+    )
+
+    reporter.info("Generating entity summarization prompt...")
+    entity_summarization_prompt = create_entity_summarization_prompt(
+        persona=persona,
+        language=language,
+    )
+
+    reporter.info("Generating community reporter role...")
+    community_reporter_role = await generate_community_reporter_role(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+
+    reporter.info("Generating community summarization prompt...")
+    community_summarization_prompt = create_community_summarization_prompt(
+        persona=persona,
+        role=community_reporter_role,
+        report_rating_description=community_report_ranking,
+        language=language,
+    )
+
+    return (
+        entity_extraction_prompt,
+        entity_summarization_prompt,
+        community_summarization_prompt,
+    )
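A minimal sketch of calling the new API from an external application, assuming a settings file resolvable via read_config_parameters (shown later in this commit); the project paths are hypothetical:

    import asyncio

    from graphrag.index.progress import PrintProgressReporter
    from graphrag.prompt_tune import api
    from graphrag.prompt_tune.loader import read_config_parameters
    from graphrag.prompt_tune.types import DocSelectionType


    async def main() -> None:
        reporter = PrintProgressReporter("")
        # Build a GraphRagConfig from ./ragtest/settings.yaml (hypothetical project root)
        config = read_config_parameters("./ragtest", reporter, "./ragtest/settings.yaml")
        entity_extraction, entity_summarization, community_summarization = (
            await api.generate_indexing_prompts(
                config=config,
                root="./ragtest",
                selection_method=DocSelectionType.AUTO,
            )
        )
        # The caller decides what to do with the returned prompt strings
        print(entity_extraction[:200])


    asyncio.run(main())

Note the @validate_call decorator on the API: passing, say, a non-positive limit raises a pydantic ValidationError before any documents are loaded.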
@@ -5,37 +5,25 @@

 from pathlib import Path

-from datashaper import NoopVerbCallbacks
-
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.llm import load_llm
 from graphrag.index.progress import PrintProgressReporter
-from graphrag.index.progress.types import ProgressReporter
-from graphrag.llm.types.llm_types import CompletionLLM
-from graphrag.prompt_tune.generator import (
-    MAX_TOKEN_COUNT,
-    create_community_summarization_prompt,
-    create_entity_extraction_prompt,
-    create_entity_summarization_prompt,
-    detect_language,
-    generate_community_report_rating,
-    generate_community_reporter_role,
-    generate_domain,
-    generate_entity_relationship_examples,
-    generate_entity_types,
-    generate_persona,
-)
+from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    load_docs_in_chunks,
     read_config_parameters,
 )

+from . import api
+from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME
+from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME
+from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME
+from .types import DocSelectionType


 async def prompt_tune(
+    config: str,
     root: str,
     domain: str,
-    select: str = "random",
+    selection_method: DocSelectionType = DocSelectionType.RANDOM,
     limit: int = 15,
     max_tokens: int = MAX_TOKEN_COUNT,
     chunk_size: int = MIN_CHUNK_SIZE,
@@ -50,223 +38,51 @@ async def prompt_tune(

     Parameters
     ----------
+    - config: The configuration file.
     - root: The root directory.
     - domain: The domain to map the input documents to.
-    - select: The chunk selection method.
+    - selection_method: The chunk selection method.
     - limit: The limit of chunks to load.
     - max_tokens: The maximum number of tokens to use on entity extraction prompts.
     - chunk_size: The chunk token size to use.
     - language: The language to use for the prompts.
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
     - n_subset_max: The number of text chunks to embed when using auto selection method.
     - k: The number of documents to select when using auto selection method.
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    config = read_config_parameters(root, reporter)
+    graph_config = read_config_parameters(root, reporter, config)

-    await prompt_tune_with_config(
-        root,
-        config,
-        domain,
-        select,
-        limit,
-        max_tokens,
-        chunk_size,
-        language,
-        skip_entity_types,
-        output,
-        reporter,
-        n_subset_max,
-        k,
-        min_examples_required,
-    )
-
-
-async def prompt_tune_with_config(
-    root: str,
-    config: GraphRagConfig,
-    domain: str,
-    select: str = "random",
-    limit: int = 15,
-    max_tokens: int = MAX_TOKEN_COUNT,
-    chunk_size: int = MIN_CHUNK_SIZE,
-    language: str | None = None,
-    skip_entity_types: bool = False,
-    output: str = "prompts",
-    reporter: ProgressReporter | None = None,
-    n_subset_max: int = 300,
-    k: int = 15,
-    min_examples_required: int = 2,
-):
-    """Prompt tune the model with a configuration.
-
-    Parameters
-    ----------
-    - root: The root directory.
-    - config: The GraphRag configuration.
-    - domain: The domain to map the input documents to.
-    - select: The chunk selection method.
-    - limit: The limit of chunks to load.
-    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
-    - chunk_size: The chunk token size to use for input text units.
-    - skip_entity_types: Skip generating entity types.
-    - output: The output folder to store the prompts.
-    - reporter: The progress reporter.
-    - n_subset_max: The number of text chunks to embed when using auto selection method.
-    - k: The number of documents to select when using auto selection method.
-
-    Returns
-    -------
-    - None
-    """
-    if not reporter:
-        reporter = PrintProgressReporter("")
-
-    output_path = Path(config.root_dir) / output
-
-    doc_list = await load_docs_in_chunks(
+    prompts = await api.generate_indexing_prompts(
+        config=graph_config,
         root=root,
-        config=config,
-        limit=limit,
-        select_method=select,
-        reporter=reporter,
         chunk_size=chunk_size,
+        limit=limit,
+        selection_method=selection_method,
+        domain=domain,
+        language=language,
+        max_tokens=max_tokens,
+        skip_entity_types=skip_entity_types,
+        min_examples_required=min_examples_required,
         n_subset_max=n_subset_max,
         k=k,
     )

-    # Create LLM from config
-    llm = load_llm(
-        "prompt_tuning",
-        config.llm.type,
-        NoopVerbCallbacks(),
-        None,
-        config.llm.model_dump(),
-    )
-
-    await generate_indexing_prompts(
-        llm,
-        config,
-        doc_list,
-        output_path,
-        reporter,
-        domain,
-        language,
-        max_tokens,
-        skip_entity_types,
-        min_examples_required,
-    )
-
-
-async def generate_indexing_prompts(
-    llm: CompletionLLM,
-    config: GraphRagConfig,
-    doc_list: list[str],
-    output_path: Path,
-    reporter: ProgressReporter,
-    domain: str | None = None,
-    language: str | None = None,
-    max_tokens: int = MAX_TOKEN_COUNT,
-    skip_entity_types: bool = False,
-    min_examples_required: int = 2,
-):
-    """Generate indexing prompts.
-
-    Parameters
-    ----------
-    - llm: The LLM model to use.
-    - config: The GraphRag configuration.
-    - doc_list: The list of documents to use.
-    - output_path: The path to store the prompts.
-    - reporter: The progress reporter.
-    - domain: The domain to map the input documents to.
-    - max_tokens: The maximum number of tokens to use on entity extraction prompts
-    - skip_entity_types: Skip generating entity types.
-    - min_examples_required: The minimum number of examples required for entity extraction prompts.
-    """
-    if not domain:
-        reporter.info("Generating domain...")
-        domain = await generate_domain(llm, doc_list)
-        reporter.info(f"Generated domain: {domain}")
-
-    if not language:
-        reporter.info("Detecting language...")
-        language = await detect_language(llm, doc_list)
-        reporter.info(f"Detected language: {language}")
-
-    reporter.info("Generating persona...")
-    persona = await generate_persona(llm, domain)
-    reporter.info(f"Generated persona: {persona}")
-
-    reporter.info("Generating community report ranking description...")
-    community_report_ranking = await generate_community_report_rating(
-        llm, domain=domain, persona=persona, docs=doc_list
-    )
-    reporter.info(
-        f"Generated community report ranking description: {community_report_ranking}"
-    )
-
-    entity_types = None
-    if not skip_entity_types:
-        reporter.info("Generating entity types")
-        entity_types = await generate_entity_types(
-            llm,
-            domain=domain,
-            persona=persona,
-            docs=doc_list,
-            json_mode=config.llm.model_supports_json or False,
+    output_path = Path(output)
+    if output_path:
+        reporter.info(f"Writing prompts to {output_path}")
+        output_path.mkdir(parents=True, exist_ok=True)
+        entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME
+        entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
+        community_summarization_prompt_path = (
+            output_path / COMMUNITY_SUMMARIZATION_FILENAME
         )
-    reporter.info(f"Generated entity types: {entity_types}")
-
-    reporter.info("Generating entity relationship examples...")
-    examples = await generate_entity_relationship_examples(
-        llm,
-        persona=persona,
-        entity_types=entity_types,
-        docs=doc_list,
-        language=language,
-        json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
-    )
-    reporter.info("Done generating entity relationship examples")
-
-    reporter.info("Generating entity extraction prompt...")
-    create_entity_extraction_prompt(
-        entity_types=entity_types,
-        docs=doc_list,
-        examples=examples,
-        language=language,
-        json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
-        output_path=output_path,
-        encoding_model=config.encoding_model,
-        max_token_count=max_tokens,
-        min_examples_required=min_examples_required,
-    )
-    reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
-
-    reporter.info("Generating entity summarization prompt...")
-    create_entity_summarization_prompt(
-        persona=persona,
-        language=language,
-        output_path=output_path,
-    )
-    reporter.info(
-        f"Generated entity summarization prompt, stored in folder {output_path}"
-    )
-
-    reporter.info("Generating community reporter role...")
-    community_reporter_role = await generate_community_reporter_role(
-        llm, domain=domain, persona=persona, docs=doc_list
-    )
-    reporter.info(f"Generated community reporter role: {community_reporter_role}")
-
-    reporter.info("Generating community summarization prompt...")
-    create_community_summarization_prompt(
-        persona=persona,
-        role=community_reporter_role,
-        report_rating_description=community_report_ranking,
-        language=language,
-        output_path=output_path,
-    )
-    reporter.info(
-        f"Generated community summarization prompt, stored in folder {output_path}"
-    )
+        # Write files to output path
+        with entity_extraction_prompt_path.open("wb") as file:
+            file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
+        with entity_summarization_prompt_path.open("wb") as file:
+            file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
+        with community_summarization_prompt_path.open("wb") as file:
+            file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
@@ -41,7 +41,7 @@ def create_entity_extraction_prompt(
     - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
-    - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - output_path (Path | None): The path to write the prompt to. Default is None.
     - min_examples_required (int): The minimum number of examples required. Default is 2.

     Returns
@@ -58,8 +58,8 @@ def create_entity_extraction_prompt(

     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=encoding_model)
-        - num_tokens_from_string(entity_types, model=encoding_model)
+        - num_tokens_from_string(prompt, encoding_name=encoding_model)
+        - num_tokens_from_string(entity_types, encoding_name=encoding_model)
         if entity_types
         else 0
     )
@@ -79,7 +79,9 @@ def create_entity_extraction_prompt(
         )
     )

-    example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
+    example_tokens = num_tokens_from_string(
+        example_formatted, encoding_name=encoding_model
+    )

     # Ensure at least three examples are included
     if i >= min_examples_required and example_tokens > tokens_left:
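The keyword rename matters because the token counter is driven by a tiktoken encoding name (e.g. cl100k_base) rather than a model name. A rough standalone equivalent of the budgeting logic, assuming tiktoken; the helper below is illustrative, not graphrag's own implementation:

    import tiktoken


    def count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
        # Count tokens by encoding name, mirroring the encoding_name= keyword above
        return len(tiktoken.get_encoding(encoding_name).encode(text))


    max_token_count = 2000  # hypothetical budget
    prompt_template = "-- prompt template text --"
    tokens_left = max_token_count - count_tokens(prompt_template)
    # Examples are then appended one by one until i >= min_examples_required
    # and the next example no longer fits in tokens_left.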
@@ -15,13 +15,14 @@ def create_entity_summarization_prompt(
     language: str,
     output_path: Path | None = None,
 ) -> str:
-    """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file.
+    """
+    Create a prompt for entity summarization.

     Parameters
     ----------
     - persona (str): The persona to use for the entity summarization prompt
     - language (str): The language to use for the entity summarization prompt
-    - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - output_path (Path | None): The path to write the prompt to. Default is None.
     """
     prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language)

@@ -9,20 +9,38 @@ from graphrag.config import create_graphrag_config
 from graphrag.index.progress.types import ProgressReporter


-def read_config_parameters(root: str, reporter: ProgressReporter):
+def read_config_parameters(
+    root: str, reporter: ProgressReporter, config: str | None = None
+):
+    """Read the configuration parameters from the settings file or environment variables.
+
+    Parameters
+    ----------
+    - root: The root directory where the parameters are.
+    - reporter: The progress reporter.
+    - config: The path to the settings file.
+    """
     _root = Path(root)
-    settings_yaml = _root / "settings.yaml"
+    settings_yaml = (
+        Path(config)
+        if config and Path(config).suffix in [".yaml", ".yml"]
+        else _root / "settings.yaml"
+    )
     if not settings_yaml.exists():
         settings_yaml = _root / "settings.yml"
-    settings_json = _root / "settings.json"

     if settings_yaml.exists():
         reporter.info(f"Reading settings from {settings_yaml}")
         with settings_yaml.open("rb") as file:
             import yaml

             data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)

+    settings_json = (
+        Path(config)
+        if config and Path(config).suffix == ".json"
+        else _root / "settings.json"
+    )
     if settings_json.exists():
         reporter.info(f"Reading settings from {settings_json}")
         with settings_json.open("rb") as file:
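In effect, an explicit config path now takes precedence when it points at a yaml/yml (or json) file, with the old <root>/settings.* lookup as the fallback. A usage sketch (paths hypothetical):

    from graphrag.index.progress import PrintProgressReporter
    from graphrag.prompt_tune.loader import read_config_parameters

    reporter = PrintProgressReporter("")

    # Explicit settings file wins over <root>/settings.yaml, settings.yml, settings.json
    cfg = read_config_parameters("./ragtest", reporter, config="./configs/dev.yaml")

    # With config=None the lookup behaves exactly as before this commit
    cfg_default = read_config_parameters("./ragtest", reporter)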
@@ -16,6 +16,7 @@ from graphrag.index.llm import load_llm_embeddings
 from graphrag.index.progress.types import ProgressReporter
 from graphrag.index.verbs import chunk
 from graphrag.llm.types.llm_types import EmbeddingLLM
+from graphrag.prompt_tune.types import DocSelectionType

 MIN_CHUNK_OVERLAP = 0
 MIN_CHUNK_SIZE = 200
@@ -50,7 +51,7 @@ def _sample_chunks_from_embeddings(
 async def load_docs_in_chunks(
     root: str,
     config: GraphRagConfig,
-    select_method: str,
+    select_method: DocSelectionType,
     limit: int,
     reporter: ProgressReporter,
     chunk_size: int = MIN_CHUNK_SIZE,
@@ -85,11 +86,11 @@ async def load_docs_in_chunks(
     if limit <= 0 or limit > len(chunks_df):
         limit = len(chunks_df)

-    if select_method == "top":
+    if select_method == DocSelectionType.TOP:
         chunks_df = chunks_df[:limit]
-    elif select_method == "random":
+    elif select_method == DocSelectionType.RANDOM:
         chunks_df = chunks_df.sample(n=limit)
-    elif select_method == "auto":
+    elif select_method == DocSelectionType.AUTO:
         if k is None or k <= 0:
             msg = "k must be an integer > 0"
             raise ValueError(msg)
graphrag/prompt_tune/types.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Types for prompt tuning."""
+
+from enum import Enum
+
+
+class DocSelectionType(Enum):
+    """The type of document selection to use."""
+
+    ALL = "all"
+    RANDOM = "random"
+    TOP = "top"
+    AUTO = "auto"
+
+    def __str__(self):
+        """Return the string representation of the enum value."""
+        return self.value
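The __str__ override is what makes the enum pleasant to use with argparse: values render as their raw strings in help text, and value lookup drives the type=DocSelectionType conversion. A small sketch of both behaviors:

    from graphrag.prompt_tune.types import DocSelectionType

    # f-strings and argparse help render the raw value, not "DocSelectionType.RANDOM"
    assert str(DocSelectionType.RANDOM) == "random"
    assert f"Default: {DocSelectionType.RANDOM}" == "Default: random"

    # Value lookup is what argparse's type=DocSelectionType performs
    # when parsing "--selection-method auto"
    assert DocSelectionType("auto") is DocSelectionType.AUTO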