Final version of the upgraded GraphRAG accelerator

This commit is contained in:
Gabriel Nieves 2025-04-04 15:52:21 +00:00
parent 226e7e95dd
commit b6ef05335b
7 changed files with 172 additions and 75 deletions

View File

@@ -13,6 +13,7 @@ from fastapi import (
HTTPException,
UploadFile,
)
from graphrag.config.enums import IndexingMethod
from kubernetes import (
client as kubernetes_client,
)
@@ -56,8 +57,12 @@ async def schedule_index_job(
index_container_name: str,
entity_extraction_prompt: UploadFile | None = None,
entity_summarization_prompt: UploadFile | None = None,
community_summarization_prompt: UploadFile | None = None,
community_summarization_graph_prompt: UploadFile | None = None,
community_summarization_text_prompt: UploadFile | None = None,
indexing_method: IndexingMethod = IndexingMethod.Standard.value,
):
indexing_method = IndexingMethod(indexing_method).value
azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
pipelinejob = PipelineJob()
@@ -86,9 +91,14 @@ async def schedule_index_job(
if entity_summarization_prompt
else None
)
community_summarization_prompt_content = (
community_summarization_prompt.file.read().decode("utf-8")
if community_summarization_prompt
community_summarization_graph_content = (
community_summarization_graph_prompt.file.read().decode("utf-8")
if community_summarization_graph_prompt
else None
)
community_summarization_text_content = (
community_summarization_text_prompt.file.read().decode("utf-8")
if community_summarization_text_prompt
else None
)
@@ -119,9 +129,14 @@ async def schedule_index_job(
) = []
existing_job._entity_extraction_prompt = entity_extraction_prompt_content
existing_job._entity_summarization_prompt = entity_summarization_prompt_content
existing_job._community_summarization_prompt = (
community_summarization_prompt_content
existing_job.community_summarization_graph_prompt = (
community_summarization_graph_content
)
existing_job.community_summarization_text_prompt = (
community_summarization_text_content
)
existing_job._indexing_method = indexing_method
existing_job._epoch_request_time = int(time())
existing_job.update_db()
else:
@@ -131,7 +146,9 @@ async def schedule_index_job(
human_readable_storage_name=storage_container_name,
entity_extraction_prompt=entity_extraction_prompt_content,
entity_summarization_prompt=entity_summarization_prompt_content,
community_summarization_prompt=community_summarization_prompt_content,
community_summarization_graph_prompt=community_summarization_graph_content,
community_summarization_text_prompt=community_summarization_text_content,
indexing_method=indexing_method,
status=PipelineJobState.SCHEDULED,
)

View File

@@ -12,8 +12,8 @@ from fastapi import (
Depends,
HTTPException,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.utils.azure_clients import AzureClientManager
from graphrag_app.utils.common import sanitize_name, subscription_key_check
@@ -46,12 +46,15 @@ async def generate_prompts(
detail=f"Storage container '{container_name}' does not exist.",
)
# load pipeline configuration file (settings.yaml) for input data and other settings
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
with (ROOT_DIR / "scripts/settings.yaml").open("r") as f:
data = yaml.safe_load(f)
data["input"]["container_name"] = sanitized_container_name
graphrag_config = create_graphrag_config(values=data, root_dir=".")
# load custom pipeline settings
ROOT_DIR = Path(__file__).resolve().parent.parent.parent / "scripts/settings.yaml"
# layer the custom settings on top of the default configuration settings of graphrag
graphrag_config: GraphRagConfig = load_config(
root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR
)
graphrag_config.input.container_name = sanitized_container_name
# generate prompts
try:

View File

@@ -5,9 +5,8 @@ import hashlib
import os
import sys
import traceback
from typing import Annotated
from pathlib import Path
from typing import Dict, List
from typing import Annotated, Dict, List
import pandas as pd
from azure.core.exceptions import ResourceNotFoundError
@@ -17,7 +16,6 @@ from azure.storage.blob.aio import ContainerClient
from fastapi import Header, HTTPException
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import QueryData
@@ -287,6 +285,10 @@ def get_data_tables(
root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR
)
# update the config to use the correct blob storage containers
config.cache.container_name = index_names["sanitized_name"]
config.reporting.container_name = index_names["sanitized_name"]
config.output.container_name = index_names["sanitized_name"]
# dynamically assign the sanitized index name
config.vector_store["default_vector_store"].container_name = sanitized_name

View File

@@ -7,6 +7,7 @@ from typing import (
)
from azure.cosmos.exceptions import CosmosHttpResponseError
from graphrag.config.enums import IndexingMethod
from graphrag_app.typing.pipeline import PipelineJobState
from graphrag_app.utils.azure_clients import AzureClientManager
@@ -39,7 +40,9 @@ class PipelineJob:
_entity_extraction_prompt: str = field(default=None, init=False)
_entity_summarization_prompt: str = field(default=None, init=False)
_community_summarization_prompt: str = field(default=None, init=False)
_community_summarization_graph_prompt: str = field(default=None, init=False)
_community_summarization_text_prompt: str = field(default=None, init=False)
_indexing_method: str = field(default=IndexingMethod.Standard.value, init=False)
@staticmethod
def _jobs_container():
@@ -56,7 +59,9 @@
human_readable_storage_name: str,
entity_extraction_prompt: str | None = None,
entity_summarization_prompt: str | None = None,
community_summarization_prompt: str | None = None,
community_summarization_graph_prompt: str | None = None,
community_summarization_text_prompt: str | None = None,
indexing_method: str = IndexingMethod.Standard.value,
**kwargs,
) -> "PipelineJob":
"""
@@ -112,7 +117,10 @@
instance._entity_extraction_prompt = entity_extraction_prompt
instance._entity_summarization_prompt = entity_summarization_prompt
instance._community_summarization_prompt = community_summarization_prompt
instance._community_summarization_graph_prompt = community_summarization_graph_prompt
instance._community_summarization_text_prompt = community_summarization_text_prompt
instance._indexing_method = IndexingMethod(indexing_method).value
# Create the item in the database
instance.update_db()
@@ -160,9 +168,15 @@
instance._entity_summarization_prompt = db_item.get(
"entity_summarization_prompt"
)
instance._community_summarization_prompt = db_item.get(
"community_summarization_prompt"
instance._community_summarization_graph_prompt = db_item.get(
"community_summarization_graph_prompt"
)
instance._community_summarization_text_prompt = db_item.get(
"community_summarization_text_prompt"
)
instance._indexing_method = db_item.get("indexing_method")
return instance
@staticmethod
@@ -200,14 +214,19 @@
"status": self._status.value,
"percent_complete": self._percent_complete,
"progress": self._progress,
"indexing_method": self._indexing_method,
}
if self._entity_extraction_prompt:
model["entity_extraction_prompt"] = self._entity_extraction_prompt
if self._entity_summarization_prompt:
model["entity_summarization_prompt"] = self._entity_summarization_prompt
if self._community_summarization_prompt:
model["community_summarization_prompt"] = (
self._community_summarization_prompt
if self._community_summarization_graph_prompt:
model["community_summarization_graph_prompt"] = (
self._community_summarization_graph_prompt
)
if self._community_summarization_text_prompt:
model["community_summarization_text_prompt"] = (
self._community_summarization_text_prompt
)
return model
@@ -291,14 +310,34 @@
self.update_db()
@property
def community_summarization_prompt(self) -> str:
return self._community_summarization_prompt
def community_summarization_graph_prompt(self) -> str:
return self._community_summarization_graph_prompt
@community_summarization_prompt.setter
def community_summarization_prompt(
self, community_summarization_prompt: str
@community_summarization_graph_prompt.setter
def community_summarization_graph_prompt(
self, community_summarization_graph_prompt: str
) -> None:
self._community_summarization_prompt = community_summarization_prompt
self._community_summarization_graph_prompt = community_summarization_graph_prompt
self.update_db()
@property
def community_summarization_text_prompt(self) -> str:
return self._community_summarization_text_prompt
@community_summarization_text_prompt.setter
def community_summarization_text_prompt(
self, community_summarization_text_prompt: str
) -> None:
self._community_summarization_text_prompt = community_summarization_text_prompt
self.update_db()
@property
def indexing_method(self) -> str:
return self._indexing_method
@indexing_method.setter
def indexing_method(self, indexing_method: str) -> None:
self._indexing_method = IndexingMethod(indexing_method).value
self.update_db()
@property

View File

@@ -7,11 +7,15 @@ import traceback
from pathlib import Path
import graphrag.api as api
import yaml
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.typing import PipelineRunResult
# from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.config.enums import IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.workflows.factory import PipelineFactory
from graphrag_app.logger import (
PipelineJobUpdater,
@@ -48,54 +52,76 @@ def start_indexing_job(index_name: str):
storage_name = pipeline_job.human_readable_index_name
# load custom pipeline settings
SCRIPT_DIR = Path(__file__).resolve().parent
with (SCRIPT_DIR / "settings.yaml").open("r") as f:
data = yaml.safe_load(f)
# dynamically set some values
data["input"]["container_name"] = sanitized_storage_name
data["storage"]["container_name"] = sanitized_index_name
data["reporting"]["container_name"] = sanitized_index_name
data["cache"]["container_name"] = sanitized_index_name
if "vector_store" in data["embeddings"]:
data["embeddings"]["vector_store"]["collection_name"] = (
f"{sanitized_index_name}_description_embedding"
ROOT_DIR = Path(__file__).resolve().parent / "settings.yaml"
config: GraphRagConfig = load_config(
root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR
)
# dynamically assign the sanitized index name
config.vector_store["default_vector_store"].container_name = sanitized_index_name
# set prompt for entity extraction
# dynamically set indexing storage values
config.input.container_name = sanitized_storage_name
config.output.container_name = sanitized_index_name
config.reporting.container_name = sanitized_index_name
config.cache.container_name = sanitized_index_name
# update extraction prompts
PROMPT_DIR = Path(__file__).resolve().parent
# set prompt for entity extraction / graph construction
if pipeline_job.entity_extraction_prompt:
fname = "entity-extraction-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.entity_extraction_prompt)
data["entity_extraction"]["prompt"] = fname
# use the default prompt
config.extract_graph.prompt = None
else:
data.pop("entity_extraction")
# try to load the custom prompt
fname = "extract_graph.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.entity_extraction_prompt)
config.extract_graph.prompt = fname
# set prompt for entity summarization
if pipeline_job.entity_summarization_prompt:
fname = "entity-summarization-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.entity_summarization_prompt)
data["summarize_descriptions"]["prompt"] = fname
if pipeline_job.entity_summarization_prompt is None:
# use the default prompt
config.summarize_descriptions.prompt = None
else:
data.pop("summarize_descriptions")
# try to load the custom prompt
fname = "summarize_descriptions.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.entity_summarization_prompt)
config.summarize_descriptions.prompt = fname
# set prompt for community summarization
if pipeline_job.community_summarization_prompt:
fname = "community-summarization-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.community_summarization_prompt)
data["community_reports"]["prompt"] = fname
# set prompt for community graph summarization
if pipeline_job.community_summarization_graph_prompt is None:
# use the default prompt
config.community_reports.graph_prompt = None
else:
data.pop("community_reports")
# try to load the custom prompt
fname = "community_report_graph.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.community_summarization_graph_prompt)
pipeline_job.community_summarization_graph_prompt = fname
# generate default graphrag config parameters and override with custom settings
parameters = create_graphrag_config(data, ".")
# set prompt for community text summarization
if pipeline_job.community_summarization_text_prompt is None:
# use the default prompt
config.community_reports.text_prompt = None
else:
fname = "community_report_text.txt"
# try to load the custom prompt
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.community_summarization_text_prompt)
config.community_reports.text_prompt = fname
# set the extraction strategy
indexing_method = IndexingMethod(pipeline_job.indexing_method)
pipeline_workflows = PipelineFactory.create_pipeline(config, indexing_method)
# reset pipeline job details
pipeline_job.status = PipelineJobState.RUNNING
pipeline_config = create_pipeline_config(parameters)
pipeline_job.all_workflows = [
workflow.name for workflow in pipeline_config.workflows
workflow for workflow in pipeline_workflows.names()
]
pipeline_job.completed_workflows = []
pipeline_job.failed_workflows = []
@@ -117,7 +143,7 @@ def start_indexing_job(index_name: str):
print("Building index...")
pipeline_results: list[PipelineRunResult] = asyncio.run(
api.build_index(
config=parameters,
config=config,
callbacks=[logger, pipeline_job_updater],
)
)

View File

@@ -68,7 +68,7 @@ models:
vector_store:
default_vector_store:
type: cosmosdb # or [lancedb, azure_ai_search, cosmosdb]
url: ${COSMOS_URL}
url: ${COSMOS_URI_ENDPOINT}
container_name: PLACEHOLDER
overwrite: True
database_name: vectordb
@@ -85,7 +85,7 @@ extract_graph:
extract_graph_nlp:
text_analyzer:
extractor_type: syntactic_parser # [regex_english, syntactic_parser, cfg]
extractor_type: regex_english # [regex_english, syntactic_parser, cfg]
summarize_descriptions:
model_id: default_chat_model

View File

@@ -29,6 +29,8 @@ GRAPHRAG_EMBEDDING_MODEL="text-embedding-ada-002"
GRAPHRAG_EMBEDDING_MODEL_VERSION="2"
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME="text-embedding-ada-002"
GRAPHRAG_EMBEDDING_MODEL_QUOTA="300"
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST="15"
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST="15"
requiredParams=(
LOCATION
@@ -56,6 +58,8 @@ optionalParams=(
GRAPHRAG_EMBEDDING_MODEL_QUOTA
GRAPHRAG_EMBEDDING_MODEL_VERSION
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST
)
errorBanner () {
@@ -527,6 +531,9 @@ installGraphRAGHelmChart () {
exitIfValueEmpty "$graphragEmbeddingModelDeployment" "Unable to parse embedding model deployment name from deployment outputs, exiting..."
fi
graphragLlmModelConcurrentRequest="$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST"
graphragEmbeddingModelConcurrentRequest="$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
reset_x=true
if ! [ -o xtrace ]; then
set -x
@@ -552,7 +559,10 @@
--set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$graphragLlmModelDeployment" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$graphragEmbeddingModel" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$graphragEmbeddingModelDeployment" \
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl"
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" \
--set "graphragConfig.GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
local helmResult
helmResult=$?