mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-10-11 16:58:54 +00:00
final version of the upgraded graphrag accelerator
This commit is contained in:
parent
226e7e95dd
commit
b6ef05335b
@ -13,6 +13,7 @@ from fastapi import (
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
)
|
||||
from graphrag.config.enums import IndexingMethod
|
||||
from kubernetes import (
|
||||
client as kubernetes_client,
|
||||
)
|
||||
@ -56,8 +57,12 @@ async def schedule_index_job(
|
||||
index_container_name: str,
|
||||
entity_extraction_prompt: UploadFile | None = None,
|
||||
entity_summarization_prompt: UploadFile | None = None,
|
||||
community_summarization_prompt: UploadFile | None = None,
|
||||
community_summarization_graph_prompt: UploadFile | None = None,
|
||||
community_summarization_text_prompt: UploadFile | None = None,
|
||||
indexing_method: IndexingMethod = IndexingMethod.Standard.value,
|
||||
):
|
||||
indexing_method = IndexingMethod(indexing_method).value
|
||||
|
||||
azure_client_manager = AzureClientManager()
|
||||
blob_service_client = azure_client_manager.get_blob_service_client()
|
||||
pipelinejob = PipelineJob()
|
||||
@ -86,9 +91,14 @@ async def schedule_index_job(
|
||||
if entity_summarization_prompt
|
||||
else None
|
||||
)
|
||||
community_summarization_prompt_content = (
|
||||
community_summarization_prompt.file.read().decode("utf-8")
|
||||
if community_summarization_prompt
|
||||
community_summarization_graph_content = (
|
||||
community_summarization_graph_prompt.file.read().decode("utf-8")
|
||||
if community_summarization_graph_prompt
|
||||
else None
|
||||
)
|
||||
community_summarization_text_content = (
|
||||
community_summarization_text_prompt.file.read().decode("utf-8")
|
||||
if community_summarization_text_prompt
|
||||
else None
|
||||
)
|
||||
|
||||
@ -119,9 +129,14 @@ async def schedule_index_job(
|
||||
) = []
|
||||
existing_job._entity_extraction_prompt = entity_extraction_prompt_content
|
||||
existing_job._entity_summarization_prompt = entity_summarization_prompt_content
|
||||
existing_job._community_summarization_prompt = (
|
||||
community_summarization_prompt_content
|
||||
existing_job.community_summarization_graph_prompt = (
|
||||
community_summarization_graph_content
|
||||
)
|
||||
existing_job.community_summarization_text_prompt = (
|
||||
community_summarization_text_content
|
||||
)
|
||||
existing_job._indexing_method = indexing_method
|
||||
|
||||
existing_job._epoch_request_time = int(time())
|
||||
existing_job.update_db()
|
||||
else:
|
||||
@ -131,7 +146,9 @@ async def schedule_index_job(
|
||||
human_readable_storage_name=storage_container_name,
|
||||
entity_extraction_prompt=entity_extraction_prompt_content,
|
||||
entity_summarization_prompt=entity_summarization_prompt_content,
|
||||
community_summarization_prompt=community_summarization_prompt_content,
|
||||
community_summarization_graph_prompt=community_summarization_graph_content,
|
||||
community_summarization_text_prompt=community_summarization_text_content,
|
||||
indexing_method=indexing_method,
|
||||
status=PipelineJobState.SCHEDULED,
|
||||
)
|
||||
|
||||
|
@ -12,8 +12,8 @@ from fastapi import (
|
||||
Depends,
|
||||
HTTPException,
|
||||
)
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag_app.logger.load_logger import load_pipeline_logger
|
||||
from graphrag_app.utils.azure_clients import AzureClientManager
|
||||
from graphrag_app.utils.common import sanitize_name, subscription_key_check
|
||||
@ -46,12 +46,15 @@ async def generate_prompts(
|
||||
detail=f"Storage container '{container_name}' does not exist.",
|
||||
)
|
||||
|
||||
# load pipeline configuration file (settings.yaml) for input data and other settings
|
||||
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
with (ROOT_DIR / "scripts/settings.yaml").open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
data["input"]["container_name"] = sanitized_container_name
|
||||
graphrag_config = create_graphrag_config(values=data, root_dir=".")
|
||||
# load custom pipeline settings
|
||||
ROOT_DIR = Path(__file__).resolve().parent.parent.parent / "scripts/settings.yaml"
|
||||
|
||||
# layer the custom settings on top of the default configuration settings of graphrag
|
||||
graphrag_config: GraphRagConfig = load_config(
|
||||
root_dir=ROOT_DIR.parent,
|
||||
config_filepath=ROOT_DIR
|
||||
)
|
||||
graphrag_config.input.container_name = sanitized_container_name
|
||||
|
||||
# generate prompts
|
||||
try:
|
||||
|
@ -5,9 +5,8 @@ import hashlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Annotated
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Annotated, Dict, List
|
||||
|
||||
import pandas as pd
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
@ -17,7 +16,6 @@ from azure.storage.blob.aio import ContainerClient
|
||||
from fastapi import Header, HTTPException
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
|
||||
from graphrag_app.logger.load_logger import load_pipeline_logger
|
||||
from graphrag_app.typing.models import QueryData
|
||||
@ -287,6 +285,10 @@ def get_data_tables(
|
||||
root_dir=ROOT_DIR.parent,
|
||||
config_filepath=ROOT_DIR
|
||||
)
|
||||
# update the config to use the correct blob storage containers
|
||||
config.cache.container_name = index_names["sanitized_name"]
|
||||
config.reporting.container_name = index_names["sanitized_name"]
|
||||
config.output.container_name = index_names["sanitized_name"]
|
||||
# dynamically assign the sanitized index name
|
||||
config.vector_store["default_vector_store"].container_name = sanitized_name
|
||||
|
||||
|
@ -7,6 +7,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from azure.cosmos.exceptions import CosmosHttpResponseError
|
||||
from graphrag.config.enums import IndexingMethod
|
||||
|
||||
from graphrag_app.typing.pipeline import PipelineJobState
|
||||
from graphrag_app.utils.azure_clients import AzureClientManager
|
||||
@ -39,7 +40,9 @@ class PipelineJob:
|
||||
|
||||
_entity_extraction_prompt: str = field(default=None, init=False)
|
||||
_entity_summarization_prompt: str = field(default=None, init=False)
|
||||
_community_summarization_prompt: str = field(default=None, init=False)
|
||||
_community_summarization_graph_prompt: str = field(default=None, init=False)
|
||||
_community_summarization_text_prompt: str = field(default=None, init=False)
|
||||
_indexing_method: str = field(default=IndexingMethod.Standard.value, init=False)
|
||||
|
||||
@staticmethod
|
||||
def _jobs_container():
|
||||
@ -56,7 +59,9 @@ class PipelineJob:
|
||||
human_readable_storage_name: str,
|
||||
entity_extraction_prompt: str | None = None,
|
||||
entity_summarization_prompt: str | None = None,
|
||||
community_summarization_prompt: str | None = None,
|
||||
community_summarization_graph_prompt: str | None = None,
|
||||
community_summarization_text_prompt: str | None = None,
|
||||
indexing_method: str = IndexingMethod.Standard.value,
|
||||
**kwargs,
|
||||
) -> "PipelineJob":
|
||||
"""
|
||||
@ -112,7 +117,10 @@ class PipelineJob:
|
||||
|
||||
instance._entity_extraction_prompt = entity_extraction_prompt
|
||||
instance._entity_summarization_prompt = entity_summarization_prompt
|
||||
instance._community_summarization_prompt = community_summarization_prompt
|
||||
instance._community_summarization_graph_prompt = community_summarization_graph_prompt
|
||||
instance._community_summarization_text_prompt = community_summarization_text_prompt
|
||||
|
||||
instance._indexing_method = IndexingMethod(indexing_method).value
|
||||
|
||||
# Create the item in the database
|
||||
instance.update_db()
|
||||
@ -160,9 +168,15 @@ class PipelineJob:
|
||||
instance._entity_summarization_prompt = db_item.get(
|
||||
"entity_summarization_prompt"
|
||||
)
|
||||
instance._community_summarization_prompt = db_item.get(
|
||||
"community_summarization_prompt"
|
||||
instance._community_summarization_graph_prompt = db_item.get(
|
||||
"community_summarization_graph_prompt"
|
||||
)
|
||||
instance._community_summarization_text_prompt = db_item.get(
|
||||
"community_summarization_text_prompt"
|
||||
)
|
||||
|
||||
instance._indexing_method = db_item.get("indexing_method")
|
||||
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
@ -200,14 +214,19 @@ class PipelineJob:
|
||||
"status": self._status.value,
|
||||
"percent_complete": self._percent_complete,
|
||||
"progress": self._progress,
|
||||
"indexing_method": self._indexing_method,
|
||||
}
|
||||
if self._entity_extraction_prompt:
|
||||
model["entity_extraction_prompt"] = self._entity_extraction_prompt
|
||||
if self._entity_summarization_prompt:
|
||||
model["entity_summarization_prompt"] = self._entity_summarization_prompt
|
||||
if self._community_summarization_prompt:
|
||||
model["community_summarization_prompt"] = (
|
||||
self._community_summarization_prompt
|
||||
if self._community_summarization_graph_prompt:
|
||||
model["community_summarization_graph_prompt"] = (
|
||||
self._community_summarization_graph_prompt
|
||||
)
|
||||
if self._community_summarization_text_prompt:
|
||||
model["community_summarization_text_prompt"] = (
|
||||
self._community_summarization_text_prompt
|
||||
)
|
||||
return model
|
||||
|
||||
@ -291,14 +310,34 @@ class PipelineJob:
|
||||
self.update_db()
|
||||
|
||||
@property
|
||||
def community_summarization_prompt(self) -> str:
|
||||
return self._community_summarization_prompt
|
||||
def community_summarization_graph_prompt(self) -> str:
|
||||
return self._community_summarization_graph_prompt
|
||||
|
||||
@community_summarization_prompt.setter
|
||||
def community_summarization_prompt(
|
||||
self, community_summarization_prompt: str
|
||||
@community_summarization_graph_prompt.setter
|
||||
def community_summarization_graph_prompt(
|
||||
self, community_summarization_graph_prompt: str
|
||||
) -> None:
|
||||
self._community_summarization_prompt = community_summarization_prompt
|
||||
self._community_summarization_graph_prompt = community_summarization_graph_prompt
|
||||
self.update_db()
|
||||
|
||||
@property
|
||||
def community_summarization_text_prompt(self) -> str:
|
||||
return self._community_summarization_text_prompt
|
||||
|
||||
@community_summarization_text_prompt.setter
|
||||
def community_summarization_text_prompt(
|
||||
self, community_summarization_text_prompt: str
|
||||
) -> None:
|
||||
self._community_summarization_text_prompt = community_summarization_text_prompt
|
||||
self.update_db()
|
||||
|
||||
@property
|
||||
def indexing_method(self) -> str:
|
||||
return self._indexing_method
|
||||
|
||||
@indexing_method.setter
|
||||
def indexing_method(self, indexing_method: str) -> None:
|
||||
self._indexing_method = IndexingMethod(indexing_method).value
|
||||
self.update_db()
|
||||
|
||||
@property
|
||||
|
@ -7,11 +7,15 @@ import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.api as api
|
||||
import yaml
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.create_pipeline_config import create_pipeline_config
|
||||
from graphrag.index.typing import PipelineRunResult
|
||||
|
||||
# from graphrag.index.create_pipeline_config import create_pipeline_config
|
||||
from graphrag.config.enums import IndexingMethod
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.index.workflows.factory import PipelineFactory
|
||||
|
||||
from graphrag_app.logger import (
|
||||
PipelineJobUpdater,
|
||||
@ -48,54 +52,76 @@ def start_indexing_job(index_name: str):
|
||||
storage_name = pipeline_job.human_readable_index_name
|
||||
|
||||
# load custom pipeline settings
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
with (SCRIPT_DIR / "settings.yaml").open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
# dynamically set some values
|
||||
data["input"]["container_name"] = sanitized_storage_name
|
||||
data["storage"]["container_name"] = sanitized_index_name
|
||||
data["reporting"]["container_name"] = sanitized_index_name
|
||||
data["cache"]["container_name"] = sanitized_index_name
|
||||
if "vector_store" in data["embeddings"]:
|
||||
data["embeddings"]["vector_store"]["collection_name"] = (
|
||||
f"{sanitized_index_name}_description_embedding"
|
||||
ROOT_DIR = Path(__file__).resolve().parent / "settings.yaml"
|
||||
config: GraphRagConfig = load_config(
|
||||
root_dir=ROOT_DIR.parent,
|
||||
config_filepath=ROOT_DIR
|
||||
)
|
||||
# dynamically assign the sanitized index name
|
||||
config.vector_store["default_vector_store"].container_name = sanitized_index_name
|
||||
|
||||
# set prompt for entity extraction
|
||||
# dynamically set indexing storage values
|
||||
config.input.container_name = sanitized_storage_name
|
||||
config.output.container_name = sanitized_index_name
|
||||
config.reporting.container_name = sanitized_index_name
|
||||
config.cache.container_name = sanitized_index_name
|
||||
|
||||
# update extraction prompts
|
||||
PROMPT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# set prompt for entity extraction / graph construction
|
||||
if pipeline_job.entity_extraction_prompt:
|
||||
fname = "entity-extraction-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.entity_extraction_prompt)
|
||||
data["entity_extraction"]["prompt"] = fname
|
||||
# use the default prompt
|
||||
config.extract_graph.prompt = None
|
||||
else:
|
||||
data.pop("entity_extraction")
|
||||
# try to load the custom prompt
|
||||
fname = "extract_graph.txt"
|
||||
with open(PROMPT_DIR / fname, "w") as file:
|
||||
file.write(pipeline_job.entity_extraction_prompt)
|
||||
config.extract_graph.prompt = fname
|
||||
|
||||
# set prompt for entity summarization
|
||||
if pipeline_job.entity_summarization_prompt:
|
||||
fname = "entity-summarization-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.entity_summarization_prompt)
|
||||
data["summarize_descriptions"]["prompt"] = fname
|
||||
if pipeline_job.entity_summarization_prompt is None:
|
||||
# use the default prompt
|
||||
config.summarize_descriptions.prompt = None
|
||||
else:
|
||||
data.pop("summarize_descriptions")
|
||||
# try to load the custom prompt
|
||||
fname = "summarize_descriptions.txt"
|
||||
with open(PROMPT_DIR / fname, "w") as file:
|
||||
file.write(pipeline_job.entity_summarization_prompt)
|
||||
config.summarize_descriptions.prompt = fname
|
||||
|
||||
# set prompt for community summarization
|
||||
if pipeline_job.community_summarization_prompt:
|
||||
fname = "community-summarization-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.community_summarization_prompt)
|
||||
data["community_reports"]["prompt"] = fname
|
||||
# set prompt for community graph summarization
|
||||
if pipeline_job.community_summarization_graph_prompt is None:
|
||||
# use the default prompt
|
||||
config.community_reports.graph_prompt = None
|
||||
else:
|
||||
data.pop("community_reports")
|
||||
# try to load the custom prompt
|
||||
fname = "community_report_graph.txt"
|
||||
with open(PROMPT_DIR / fname, "w") as file:
|
||||
file.write(pipeline_job.community_summarization_graph_prompt)
|
||||
pipeline_job.community_summarization_graph_prompt = fname
|
||||
|
||||
# generate default graphrag config parameters and override with custom settings
|
||||
parameters = create_graphrag_config(data, ".")
|
||||
# set prompt for community text summarization
|
||||
if pipeline_job.community_summarization_text_prompt is None:
|
||||
# use the default prompt
|
||||
config.community_reports.text_prompt = None
|
||||
else:
|
||||
fname = "community_report_text.txt"
|
||||
# try to load the custom prompt
|
||||
with open(PROMPT_DIR / fname, "w") as file:
|
||||
file.write(pipeline_job.community_summarization_text_prompt)
|
||||
config.community_reports.text_prompt = fname
|
||||
|
||||
# set the extraction strategy
|
||||
indexing_method = IndexingMethod(pipeline_job.indexing_method)
|
||||
pipeline_workflows = PipelineFactory.create_pipeline(config, indexing_method)
|
||||
|
||||
# reset pipeline job details
|
||||
pipeline_job.status = PipelineJobState.RUNNING
|
||||
pipeline_config = create_pipeline_config(parameters)
|
||||
|
||||
pipeline_job.all_workflows = [
|
||||
workflow.name for workflow in pipeline_config.workflows
|
||||
workflow for workflow in pipeline_workflows.names()
|
||||
]
|
||||
pipeline_job.completed_workflows = []
|
||||
pipeline_job.failed_workflows = []
|
||||
@ -117,7 +143,7 @@ def start_indexing_job(index_name: str):
|
||||
print("Building index...")
|
||||
pipeline_results: list[PipelineRunResult] = asyncio.run(
|
||||
api.build_index(
|
||||
config=parameters,
|
||||
config=config,
|
||||
callbacks=[logger, pipeline_job_updater],
|
||||
)
|
||||
)
|
||||
|
@ -68,7 +68,7 @@ models:
|
||||
vector_store:
|
||||
default_vector_store:
|
||||
type: cosmosdb # or [lancedb, azure_ai_search, cosmosdb]
|
||||
url: ${COSMOS_URL}
|
||||
url: ${COSMOS_URI_ENDPOINT}
|
||||
container_name: PLACEHOLDER
|
||||
overwrite: True
|
||||
database_name: vectordb
|
||||
@ -85,7 +85,7 @@ extract_graph:
|
||||
|
||||
extract_graph_nlp:
|
||||
text_analyzer:
|
||||
extractor_type: syntactic_parser # [regex_english, syntactic_parser, cfg]
|
||||
extractor_type: regex_english # [regex_english, syntactic_parser, cfg]
|
||||
|
||||
summarize_descriptions:
|
||||
model_id: default_chat_model
|
||||
|
@ -29,6 +29,8 @@ GRAPHRAG_EMBEDDING_MODEL="text-embedding-ada-002"
|
||||
GRAPHRAG_EMBEDDING_MODEL_VERSION="2"
|
||||
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME="text-embedding-ada-002"
|
||||
GRAPHRAG_EMBEDDING_MODEL_QUOTA="300"
|
||||
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST="15"
|
||||
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST="15"
|
||||
|
||||
requiredParams=(
|
||||
LOCATION
|
||||
@ -56,6 +58,8 @@ optionalParams=(
|
||||
GRAPHRAG_EMBEDDING_MODEL_QUOTA
|
||||
GRAPHRAG_EMBEDDING_MODEL_VERSION
|
||||
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
|
||||
GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST
|
||||
GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST
|
||||
)
|
||||
|
||||
errorBanner () {
|
||||
@ -527,6 +531,9 @@ installGraphRAGHelmChart () {
|
||||
exitIfValueEmpty "$graphragEmbeddingModelDeployment" "Unable to parse embedding model deployment name from deployment outputs, exiting..."
|
||||
fi
|
||||
|
||||
graphragLlmModelConcurrentRequest="$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST"
|
||||
graphragEmbeddingModelConcurrentRequest="$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
|
||||
|
||||
reset_x=true
|
||||
if ! [ -o xtrace ]; then
|
||||
set -x
|
||||
@ -552,7 +559,10 @@ installGraphRAGHelmChart () {
|
||||
--set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$graphragLlmModelDeployment" \
|
||||
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$graphragEmbeddingModel" \
|
||||
--set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$graphragEmbeddingModelDeployment" \
|
||||
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl"
|
||||
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" \
|
||||
--set "graphragConfig.GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST" \
|
||||
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST=$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST"
|
||||
|
||||
|
||||
local helmResult
|
||||
helmResult=$?
|
||||
|
Loading…
x
Reference in New Issue
Block a user