final version of the upgraded graphrag accelerator

This commit is contained in:
Gabriel Nieves 2025-04-04 15:52:21 +00:00
parent 226e7e95dd
commit b6ef05335b
7 changed files with 172 additions and 75 deletions

View File

@ -13,6 +13,7 @@ from fastapi import (
HTTPException, HTTPException,
UploadFile, UploadFile,
) )
from graphrag.config.enums import IndexingMethod
from kubernetes import ( from kubernetes import (
client as kubernetes_client, client as kubernetes_client,
) )
@ -56,8 +57,12 @@ async def schedule_index_job(
index_container_name: str, index_container_name: str,
entity_extraction_prompt: UploadFile | None = None, entity_extraction_prompt: UploadFile | None = None,
entity_summarization_prompt: UploadFile | None = None, entity_summarization_prompt: UploadFile | None = None,
community_summarization_prompt: UploadFile | None = None, community_summarization_graph_prompt: UploadFile | None = None,
community_summarization_text_prompt: UploadFile | None = None,
indexing_method: IndexingMethod = IndexingMethod.Standard.value,
): ):
indexing_method = IndexingMethod(indexing_method).value
azure_client_manager = AzureClientManager() azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client() blob_service_client = azure_client_manager.get_blob_service_client()
pipelinejob = PipelineJob() pipelinejob = PipelineJob()
@ -86,9 +91,14 @@ async def schedule_index_job(
if entity_summarization_prompt if entity_summarization_prompt
else None else None
) )
community_summarization_prompt_content = ( community_summarization_graph_content = (
community_summarization_prompt.file.read().decode("utf-8") community_summarization_graph_prompt.file.read().decode("utf-8")
if community_summarization_prompt if community_summarization_graph_prompt
else None
)
community_summarization_text_content = (
community_summarization_text_prompt.file.read().decode("utf-8")
if community_summarization_text_prompt
else None else None
) )
@ -119,9 +129,14 @@ async def schedule_index_job(
) = [] ) = []
existing_job._entity_extraction_prompt = entity_extraction_prompt_content existing_job._entity_extraction_prompt = entity_extraction_prompt_content
existing_job._entity_summarization_prompt = entity_summarization_prompt_content existing_job._entity_summarization_prompt = entity_summarization_prompt_content
existing_job._community_summarization_prompt = ( existing_job.community_summarization_graph_prompt = (
community_summarization_prompt_content community_summarization_graph_content
) )
existing_job.community_summarization_text_prompt = (
community_summarization_text_content
)
existing_job._indexing_method = indexing_method
existing_job._epoch_request_time = int(time()) existing_job._epoch_request_time = int(time())
existing_job.update_db() existing_job.update_db()
else: else:
@ -131,7 +146,9 @@ async def schedule_index_job(
human_readable_storage_name=storage_container_name, human_readable_storage_name=storage_container_name,
entity_extraction_prompt=entity_extraction_prompt_content, entity_extraction_prompt=entity_extraction_prompt_content,
entity_summarization_prompt=entity_summarization_prompt_content, entity_summarization_prompt=entity_summarization_prompt_content,
community_summarization_prompt=community_summarization_prompt_content, community_summarization_graph_prompt=community_summarization_graph_content,
community_summarization_text_prompt=community_summarization_text_content,
indexing_method=indexing_method,
status=PipelineJobState.SCHEDULED, status=PipelineJobState.SCHEDULED,
) )

View File

@ -12,8 +12,8 @@ from fastapi import (
Depends, Depends,
HTTPException, HTTPException,
) )
from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag_app.logger.load_logger import load_pipeline_logger from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.utils.azure_clients import AzureClientManager from graphrag_app.utils.azure_clients import AzureClientManager
from graphrag_app.utils.common import sanitize_name, subscription_key_check from graphrag_app.utils.common import sanitize_name, subscription_key_check
@ -46,12 +46,15 @@ async def generate_prompts(
detail=f"Storage container '{container_name}' does not exist.", detail=f"Storage container '{container_name}' does not exist.",
) )
# load pipeline configuration file (settings.yaml) for input data and other settings # load custom pipeline settings
ROOT_DIR = Path(__file__).resolve().parent.parent.parent ROOT_DIR = Path(__file__).resolve().parent.parent.parent / "scripts/settings.yaml"
with (ROOT_DIR / "scripts/settings.yaml").open("r") as f:
data = yaml.safe_load(f) # layer the custom settings on top of the default configuration settings of graphrag
data["input"]["container_name"] = sanitized_container_name graphrag_config: GraphRagConfig = load_config(
graphrag_config = create_graphrag_config(values=data, root_dir=".") root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR
)
graphrag_config.input.container_name = sanitized_container_name
# generate prompts # generate prompts
try: try:

View File

@ -5,9 +5,8 @@ import hashlib
import os import os
import sys import sys
import traceback import traceback
from typing import Annotated
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Annotated, Dict, List
import pandas as pd import pandas as pd
from azure.core.exceptions import ResourceNotFoundError from azure.core.exceptions import ResourceNotFoundError
@ -17,7 +16,6 @@ from azure.storage.blob.aio import ContainerClient
from fastapi import Header, HTTPException from fastapi import Header, HTTPException
from graphrag.config.load_config import load_config from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag_app.logger.load_logger import load_pipeline_logger from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import QueryData from graphrag_app.typing.models import QueryData
@ -287,6 +285,10 @@ def get_data_tables(
root_dir=ROOT_DIR.parent, root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR config_filepath=ROOT_DIR
) )
# update the config to use the correct blob storage containers
config.cache.container_name = index_names["sanitized_name"]
config.reporting.container_name = index_names["sanitized_name"]
config.output.container_name = index_names["sanitized_name"]
# dynamically assign the sanitized index name # dynamically assign the sanitized index name
config.vector_store["default_vector_store"].container_name = sanitized_name config.vector_store["default_vector_store"].container_name = sanitized_name

View File

@ -7,6 +7,7 @@ from typing import (
) )
from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.exceptions import CosmosHttpResponseError
from graphrag.config.enums import IndexingMethod
from graphrag_app.typing.pipeline import PipelineJobState from graphrag_app.typing.pipeline import PipelineJobState
from graphrag_app.utils.azure_clients import AzureClientManager from graphrag_app.utils.azure_clients import AzureClientManager
@ -39,7 +40,9 @@ class PipelineJob:
_entity_extraction_prompt: str = field(default=None, init=False) _entity_extraction_prompt: str = field(default=None, init=False)
_entity_summarization_prompt: str = field(default=None, init=False) _entity_summarization_prompt: str = field(default=None, init=False)
_community_summarization_prompt: str = field(default=None, init=False) _community_summarization_graph_prompt: str = field(default=None, init=False)
_community_summarization_text_prompt: str = field(default=None, init=False)
_indexing_method: str = field(default=IndexingMethod.Standard.value, init=False)
@staticmethod @staticmethod
def _jobs_container(): def _jobs_container():
@ -56,7 +59,9 @@ class PipelineJob:
human_readable_storage_name: str, human_readable_storage_name: str,
entity_extraction_prompt: str | None = None, entity_extraction_prompt: str | None = None,
entity_summarization_prompt: str | None = None, entity_summarization_prompt: str | None = None,
community_summarization_prompt: str | None = None, community_summarization_graph_prompt: str | None = None,
community_summarization_text_prompt: str | None = None,
indexing_method: str = IndexingMethod.Standard.value,
**kwargs, **kwargs,
) -> "PipelineJob": ) -> "PipelineJob":
""" """
@ -112,7 +117,10 @@ class PipelineJob:
instance._entity_extraction_prompt = entity_extraction_prompt instance._entity_extraction_prompt = entity_extraction_prompt
instance._entity_summarization_prompt = entity_summarization_prompt instance._entity_summarization_prompt = entity_summarization_prompt
instance._community_summarization_prompt = community_summarization_prompt instance._community_summarization_graph_prompt = community_summarization_graph_prompt
instance._community_summarization_text_prompt = community_summarization_text_prompt
instance._indexing_method = IndexingMethod(indexing_method).value
# Create the item in the database # Create the item in the database
instance.update_db() instance.update_db()
@ -160,9 +168,15 @@ class PipelineJob:
instance._entity_summarization_prompt = db_item.get( instance._entity_summarization_prompt = db_item.get(
"entity_summarization_prompt" "entity_summarization_prompt"
) )
instance._community_summarization_prompt = db_item.get( instance._community_summarization_graph_prompt = db_item.get(
"community_summarization_prompt" "community_summarization_graph_prompt"
) )
instance._community_summarization_text_prompt = db_item.get(
"community_summarization_text_prompt"
)
instance._indexing_method = db_item.get("indexing_method")
return instance return instance
@staticmethod @staticmethod
@ -200,14 +214,19 @@ class PipelineJob:
"status": self._status.value, "status": self._status.value,
"percent_complete": self._percent_complete, "percent_complete": self._percent_complete,
"progress": self._progress, "progress": self._progress,
"indexing_method": self._indexing_method,
} }
if self._entity_extraction_prompt: if self._entity_extraction_prompt:
model["entity_extraction_prompt"] = self._entity_extraction_prompt model["entity_extraction_prompt"] = self._entity_extraction_prompt
if self._entity_summarization_prompt: if self._entity_summarization_prompt:
model["entity_summarization_prompt"] = self._entity_summarization_prompt model["entity_summarization_prompt"] = self._entity_summarization_prompt
if self._community_summarization_prompt: if self._community_summarization_graph_prompt:
model["community_summarization_prompt"] = ( model["community_summarization_graph_prompt"] = (
self._community_summarization_prompt self._community_summarization_graph_prompt
)
if self._community_summarization_text_prompt:
model["community_summarization_text_prompt"] = (
self._community_summarization_text_prompt
) )
return model return model
@ -291,14 +310,34 @@ class PipelineJob:
self.update_db() self.update_db()
@property @property
def community_summarization_prompt(self) -> str: def community_summarization_graph_prompt(self) -> str:
return self._community_summarization_prompt return self._community_summarization_graph_prompt
@community_summarization_prompt.setter @community_summarization_graph_prompt.setter
def community_summarization_prompt( def community_summarization_graph_prompt(
self, community_summarization_prompt: str self, community_summarization_graph_prompt: str
) -> None: ) -> None:
self._community_summarization_prompt = community_summarization_prompt self._community_summarization_graph_prompt = community_summarization_graph_prompt
self.update_db()
@property
def community_summarization_text_prompt(self) -> str:
return self._community_summarization_text_prompt
@community_summarization_text_prompt.setter
def community_summarization_text_prompt(
self, community_summarization_text_prompt: str
) -> None:
self._community_summarization_text_prompt = community_summarization_text_prompt
self.update_db()
@property
def indexing_method(self) -> str:
return self._indexing_method
@indexing_method.setter
def indexing_method(self, indexing_method: str) -> None:
self._indexing_method = IndexingMethod(indexing_method).value
self.update_db() self.update_db()
@property @property

View File

@ -7,11 +7,15 @@ import traceback
from pathlib import Path from pathlib import Path
import graphrag.api as api import graphrag.api as api
import yaml
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.create_pipeline_config import create_pipeline_config # from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.typing import PipelineRunResult from graphrag.config.enums import IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.workflows.factory import PipelineFactory
from graphrag_app.logger import ( from graphrag_app.logger import (
PipelineJobUpdater, PipelineJobUpdater,
@ -48,54 +52,76 @@ def start_indexing_job(index_name: str):
storage_name = pipeline_job.human_readable_index_name storage_name = pipeline_job.human_readable_index_name
# load custom pipeline settings # load custom pipeline settings
SCRIPT_DIR = Path(__file__).resolve().parent ROOT_DIR = Path(__file__).resolve().parent / "settings.yaml"
with (SCRIPT_DIR / "settings.yaml").open("r") as f: config: GraphRagConfig = load_config(
data = yaml.safe_load(f) root_dir=ROOT_DIR.parent,
# dynamically set some values config_filepath=ROOT_DIR
data["input"]["container_name"] = sanitized_storage_name )
data["storage"]["container_name"] = sanitized_index_name # dynamically assign the sanitized index name
data["reporting"]["container_name"] = sanitized_index_name config.vector_store["default_vector_store"].container_name = sanitized_index_name
data["cache"]["container_name"] = sanitized_index_name
if "vector_store" in data["embeddings"]:
data["embeddings"]["vector_store"]["collection_name"] = (
f"{sanitized_index_name}_description_embedding"
)
# set prompt for entity extraction # dynamically set indexing storage values
config.input.container_name = sanitized_storage_name
config.output.container_name = sanitized_index_name
config.reporting.container_name = sanitized_index_name
config.cache.container_name = sanitized_index_name
# update extraction prompts
PROMPT_DIR = Path(__file__).resolve().parent
# set prompt for entity extraction / graph construction
if pipeline_job.entity_extraction_prompt: if pipeline_job.entity_extraction_prompt:
fname = "entity-extraction-prompt.txt" # use the default prompt
with open(fname, "w") as outfile: config.extract_graph.prompt = None
outfile.write(pipeline_job.entity_extraction_prompt)
data["entity_extraction"]["prompt"] = fname
else: else:
data.pop("entity_extraction") # try to load the custom prompt
fname = "extract_graph.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.entity_extraction_prompt)
config.extract_graph.prompt = fname
# set prompt for entity summarization # set prompt for entity summarization
if pipeline_job.entity_summarization_prompt: if pipeline_job.entity_summarization_prompt is None:
fname = "entity-summarization-prompt.txt" # use the default prompt
with open(fname, "w") as outfile: config.summarize_descriptions.prompt = None
outfile.write(pipeline_job.entity_summarization_prompt)
data["summarize_descriptions"]["prompt"] = fname
else: else:
data.pop("summarize_descriptions") # try to load the custom prompt
fname = "summarize_descriptions.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.entity_summarization_prompt)
config.summarize_descriptions.prompt = fname
# set prompt for community summarization # set prompt for community graph summarization
if pipeline_job.community_summarization_prompt: if pipeline_job.community_summarization_graph_prompt is None:
fname = "community-summarization-prompt.txt" # use the default prompt
with open(fname, "w") as outfile: config.community_reports.graph_prompt = None
outfile.write(pipeline_job.community_summarization_prompt)
data["community_reports"]["prompt"] = fname
else: else:
data.pop("community_reports") # try to load the custom prompt
fname = "community_report_graph.txt"
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.community_summarization_graph_prompt)
pipeline_job.community_summarization_graph_prompt = fname
# generate default graphrag config parameters and override with custom settings # set prompt for community text summarization
parameters = create_graphrag_config(data, ".") if pipeline_job.community_summarization_text_prompt is None:
# use the default prompt
config.community_reports.text_prompt = None
else:
fname = "community_report_text.txt"
# try to load the custom prompt
with open(PROMPT_DIR / fname, "w") as file:
file.write(pipeline_job.community_summarization_text_prompt)
config.community_reports.text_prompt = fname
# set the extraction strategy
indexing_method = IndexingMethod(pipeline_job.indexing_method)
pipeline_workflows = PipelineFactory.create_pipeline(config, indexing_method)
# reset pipeline job details # reset pipeline job details
pipeline_job.status = PipelineJobState.RUNNING pipeline_job.status = PipelineJobState.RUNNING
pipeline_config = create_pipeline_config(parameters)
pipeline_job.all_workflows = [ pipeline_job.all_workflows = [
workflow.name for workflow in pipeline_config.workflows workflow for workflow in pipeline_workflows.names()
] ]
pipeline_job.completed_workflows = [] pipeline_job.completed_workflows = []
pipeline_job.failed_workflows = [] pipeline_job.failed_workflows = []
@ -117,7 +143,7 @@ def start_indexing_job(index_name: str):
print("Building index...") print("Building index...")
pipeline_results: list[PipelineRunResult] = asyncio.run( pipeline_results: list[PipelineRunResult] = asyncio.run(
api.build_index( api.build_index(
config=parameters, config=config,
callbacks=[logger, pipeline_job_updater], callbacks=[logger, pipeline_job_updater],
) )
) )

View File

@ -68,7 +68,7 @@ models:
vector_store: vector_store:
default_vector_store: default_vector_store:
type: cosmosdb # or [lancedb, azure_ai_search, cosmosdb] type: cosmosdb # or [lancedb, azure_ai_search, cosmosdb]
url: ${COSMOS_URL} url: ${COSMOS_URI_ENDPOINT}
container_name: PLACEHOLDER container_name: PLACEHOLDER
overwrite: True overwrite: True
database_name: vectordb database_name: vectordb
@ -85,7 +85,7 @@ extract_graph:
extract_graph_nlp: extract_graph_nlp:
text_analyzer: text_analyzer:
extractor_type: syntactic_parser # [regex_english, syntactic_parser, cfg] extractor_type: regex_english # [regex_english, syntactic_parser, cfg]
summarize_descriptions: summarize_descriptions:
model_id: default_chat_model model_id: default_chat_model

View File

@ -29,6 +29,8 @@ GRAPHRAG_EMBEDDING_MODEL="text-embedding-ada-002"
GRAPHRAG_EMBEDDING_MODEL_VERSION="2" GRAPHRAG_EMBEDDING_MODEL_VERSION="2"
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME="text-embedding-ada-002" GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME="text-embedding-ada-002"
GRAPHRAG_EMBEDDING_MODEL_QUOTA="300" GRAPHRAG_EMBEDDING_MODEL_QUOTA="300"
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST="15"
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST="15"
requiredParams=( requiredParams=(
LOCATION LOCATION
@ -56,6 +58,8 @@ optionalParams=(
GRAPHRAG_EMBEDDING_MODEL_QUOTA GRAPHRAG_EMBEDDING_MODEL_QUOTA
GRAPHRAG_EMBEDDING_MODEL_VERSION GRAPHRAG_EMBEDDING_MODEL_VERSION
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST
) )
errorBanner () { errorBanner () {
@ -527,6 +531,9 @@ installGraphRAGHelmChart () {
exitIfValueEmpty "$graphragEmbeddingModelDeployment" "Unable to parse embedding model deployment name from deployment outputs, exiting..." exitIfValueEmpty "$graphragEmbeddingModelDeployment" "Unable to parse embedding model deployment name from deployment outputs, exiting..."
fi fi
graphragLlmModelConcurrentRequest="$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST"
graphragEmbeddingModelConcurrentRequest="$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
reset_x=true reset_x=true
if ! [ -o xtrace ]; then if ! [ -o xtrace ]; then
set -x set -x
@ -552,7 +559,10 @@ installGraphRAGHelmChart () {
--set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$graphragLlmModelDeployment" \ --set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$graphragLlmModelDeployment" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$graphragEmbeddingModel" \ --set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$graphragEmbeddingModel" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$graphragEmbeddingModelDeployment" \ --set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$graphragEmbeddingModelDeployment" \
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" --set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" \
--set "graphragConfig.GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
local helmResult local helmResult
helmResult=$? helmResult=$?