diff --git a/backend/graphrag_app/api/index.py b/backend/graphrag_app/api/index.py index 6ec3d55..0c55da9 100644 --- a/backend/graphrag_app/api/index.py +++ b/backend/graphrag_app/api/index.py @@ -13,6 +13,7 @@ from fastapi import ( HTTPException, UploadFile, ) +from graphrag.config.enums import IndexingMethod from kubernetes import ( client as kubernetes_client, ) @@ -56,8 +57,12 @@ async def schedule_index_job( index_container_name: str, entity_extraction_prompt: UploadFile | None = None, entity_summarization_prompt: UploadFile | None = None, - community_summarization_prompt: UploadFile | None = None, + community_summarization_graph_prompt: UploadFile | None = None, + community_summarization_text_prompt: UploadFile | None = None, + indexing_method: IndexingMethod = IndexingMethod.Standard.value, ): + indexing_method = IndexingMethod(indexing_method).value + azure_client_manager = AzureClientManager() blob_service_client = azure_client_manager.get_blob_service_client() pipelinejob = PipelineJob() @@ -86,9 +91,14 @@ async def schedule_index_job( if entity_summarization_prompt else None ) - community_summarization_prompt_content = ( - community_summarization_prompt.file.read().decode("utf-8") - if community_summarization_prompt + community_summarization_graph_content = ( + community_summarization_graph_prompt.file.read().decode("utf-8") + if community_summarization_graph_prompt + else None + ) + community_summarization_text_content = ( + community_summarization_text_prompt.file.read().decode("utf-8") + if community_summarization_text_prompt else None ) @@ -119,9 +129,14 @@ async def schedule_index_job( ) = [] existing_job._entity_extraction_prompt = entity_extraction_prompt_content existing_job._entity_summarization_prompt = entity_summarization_prompt_content - existing_job._community_summarization_prompt = ( - community_summarization_prompt_content + existing_job.community_summarization_graph_prompt = ( + community_summarization_graph_content ) + 
existing_job.community_summarization_text_prompt = ( + community_summarization_text_content + ) + existing_job._indexing_method = indexing_method + existing_job._epoch_request_time = int(time()) existing_job.update_db() else: @@ -131,7 +146,9 @@ async def schedule_index_job( human_readable_storage_name=storage_container_name, entity_extraction_prompt=entity_extraction_prompt_content, entity_summarization_prompt=entity_summarization_prompt_content, - community_summarization_prompt=community_summarization_prompt_content, + community_summarization_graph_prompt=community_summarization_graph_content, + community_summarization_text_prompt=community_summarization_text_content, + indexing_method=indexing_method, status=PipelineJobState.SCHEDULED, ) diff --git a/backend/graphrag_app/api/prompt_tuning.py b/backend/graphrag_app/api/prompt_tuning.py index 1a6f078..65427b7 100644 --- a/backend/graphrag_app/api/prompt_tuning.py +++ b/backend/graphrag_app/api/prompt_tuning.py @@ -12,8 +12,8 @@ from fastapi import ( Depends, HTTPException, ) -from graphrag.config.create_graphrag_config import create_graphrag_config - +from graphrag.config.load_config import load_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag_app.logger.load_logger import load_pipeline_logger from graphrag_app.utils.azure_clients import AzureClientManager from graphrag_app.utils.common import sanitize_name, subscription_key_check @@ -46,12 +46,15 @@ async def generate_prompts( detail=f"Storage container '{container_name}' does not exist.", ) - # load pipeline configuration file (settings.yaml) for input data and other settings - ROOT_DIR = Path(__file__).resolve().parent.parent.parent - with (ROOT_DIR / "scripts/settings.yaml").open("r") as f: - data = yaml.safe_load(f) - data["input"]["container_name"] = sanitized_container_name - graphrag_config = create_graphrag_config(values=data, root_dir=".") + # load custom pipeline settings + ROOT_DIR = 
Path(__file__).resolve().parent.parent.parent / "scripts/settings.yaml" + + # layer the custom settings on top of the default configuration settings of graphrag + graphrag_config: GraphRagConfig = load_config( + root_dir=ROOT_DIR.parent, + config_filepath=ROOT_DIR + ) + graphrag_config.input.container_name = sanitized_container_name # generate prompts try: diff --git a/backend/graphrag_app/utils/common.py b/backend/graphrag_app/utils/common.py index 4827d28..16e4408 100644 --- a/backend/graphrag_app/utils/common.py +++ b/backend/graphrag_app/utils/common.py @@ -5,9 +5,8 @@ import hashlib import os import sys import traceback -from typing import Annotated from pathlib import Path -from typing import Dict, List +from typing import Annotated, Dict, List import pandas as pd from azure.core.exceptions import ResourceNotFoundError @@ -17,7 +16,6 @@ from azure.storage.blob.aio import ContainerClient from fastapi import Header, HTTPException from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.vector_store_config import VectorStoreConfig from graphrag_app.logger.load_logger import load_pipeline_logger from graphrag_app.typing.models import QueryData @@ -287,6 +285,10 @@ def get_data_tables( root_dir=ROOT_DIR.parent, config_filepath=ROOT_DIR ) + # update the config to use the correct blob storage containers + config.cache.container_name = index_names["sanitized_name"] + config.reporting.container_name = index_names["sanitized_name"] + config.output.container_name = index_names["sanitized_name"] # dynamically assign the sanitized index name config.vector_store["default_vector_store"].container_name = sanitized_name diff --git a/backend/graphrag_app/utils/pipeline.py b/backend/graphrag_app/utils/pipeline.py index 1e1b3ab..0236327 100644 --- a/backend/graphrag_app/utils/pipeline.py +++ b/backend/graphrag_app/utils/pipeline.py @@ -7,6 +7,7 @@ from typing import ( ) from 
azure.cosmos.exceptions import CosmosHttpResponseError +from graphrag.config.enums import IndexingMethod from graphrag_app.typing.pipeline import PipelineJobState from graphrag_app.utils.azure_clients import AzureClientManager @@ -39,7 +40,9 @@ class PipelineJob: _entity_extraction_prompt: str = field(default=None, init=False) _entity_summarization_prompt: str = field(default=None, init=False) - _community_summarization_prompt: str = field(default=None, init=False) + _community_summarization_graph_prompt: str = field(default=None, init=False) + _community_summarization_text_prompt: str = field(default=None, init=False) + _indexing_method: str = field(default=IndexingMethod.Standard.value, init=False) @staticmethod def _jobs_container(): @@ -56,7 +59,9 @@ class PipelineJob: human_readable_storage_name: str, entity_extraction_prompt: str | None = None, entity_summarization_prompt: str | None = None, - community_summarization_prompt: str | None = None, + community_summarization_graph_prompt: str | None = None, + community_summarization_text_prompt: str | None = None, + indexing_method: str = IndexingMethod.Standard.value, **kwargs, ) -> "PipelineJob": """ @@ -112,7 +117,10 @@ class PipelineJob: instance._entity_extraction_prompt = entity_extraction_prompt instance._entity_summarization_prompt = entity_summarization_prompt - instance._community_summarization_prompt = community_summarization_prompt + instance._community_summarization_graph_prompt = community_summarization_graph_prompt + instance._community_summarization_text_prompt = community_summarization_text_prompt + + instance._indexing_method = IndexingMethod(indexing_method).value # Create the item in the database instance.update_db() @@ -160,9 +168,15 @@ class PipelineJob: instance._entity_summarization_prompt = db_item.get( "entity_summarization_prompt" ) - instance._community_summarization_prompt = db_item.get( - "community_summarization_prompt" + instance._community_summarization_graph_prompt = db_item.get( + 
"community_summarization_graph_prompt" ) + instance._community_summarization_text_prompt = db_item.get( + "community_summarization_text_prompt" + ) + + instance._indexing_method = db_item.get("indexing_method", IndexingMethod.Standard.value) + return instance @staticmethod @@ -200,14 +214,19 @@ "status": self._status.value, "percent_complete": self._percent_complete, "progress": self._progress, + "indexing_method": self._indexing_method, } if self._entity_extraction_prompt: model["entity_extraction_prompt"] = self._entity_extraction_prompt if self._entity_summarization_prompt: model["entity_summarization_prompt"] = self._entity_summarization_prompt - if self._community_summarization_prompt: - model["community_summarization_prompt"] = ( - self._community_summarization_prompt + if self._community_summarization_graph_prompt: + model["community_summarization_graph_prompt"] = ( + self._community_summarization_graph_prompt + ) + if self._community_summarization_text_prompt: + model["community_summarization_text_prompt"] = ( + self._community_summarization_text_prompt ) return model @@ -291,14 +310,34 @@ self.update_db() @property - def community_summarization_prompt(self) -> str: - return self._community_summarization_prompt + def community_summarization_graph_prompt(self) -> str: + return self._community_summarization_graph_prompt - @community_summarization_prompt.setter - def community_summarization_prompt( - self, community_summarization_prompt: str + @community_summarization_graph_prompt.setter + def community_summarization_graph_prompt( + self, community_summarization_graph_prompt: str ) -> None: - self._community_summarization_prompt = community_summarization_prompt + self._community_summarization_graph_prompt = community_summarization_graph_prompt + self.update_db() + + @property + def community_summarization_text_prompt(self) -> str: + return self._community_summarization_text_prompt + + @community_summarization_text_prompt.setter + def 
community_summarization_text_prompt( + self, community_summarization_text_prompt: str + ) -> None: + self._community_summarization_text_prompt = community_summarization_text_prompt + self.update_db() + + @property + def indexing_method(self) -> str: + return self._indexing_method + + @indexing_method.setter + def indexing_method(self, indexing_method: str) -> None: + self._indexing_method = IndexingMethod(indexing_method).value self.update_db() @property diff --git a/backend/scripts/indexer.py b/backend/scripts/indexer.py index ed8dcfb..bdcd978 100644 --- a/backend/scripts/indexer.py +++ b/backend/scripts/indexer.py @@ -7,11 +7,15 @@ import traceback from pathlib import Path import graphrag.api as api -import yaml from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.create_graphrag_config import create_graphrag_config -from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.index.typing import PipelineRunResult + +# from graphrag.index.create_pipeline_config import create_pipeline_config +from graphrag.config.enums import IndexingMethod +from graphrag.config.load_config import load_config +from graphrag.config.models.community_reports_config import CommunityReportsConfig +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.typing.pipeline_run_result import PipelineRunResult +from graphrag.index.workflows.factory import PipelineFactory from graphrag_app.logger import ( PipelineJobUpdater, @@ -48,54 +52,76 @@ def start_indexing_job(index_name: str): storage_name = pipeline_job.human_readable_index_name # load custom pipeline settings - SCRIPT_DIR = Path(__file__).resolve().parent - with (SCRIPT_DIR / "settings.yaml").open("r") as f: - data = yaml.safe_load(f) - # dynamically set some values - data["input"]["container_name"] = sanitized_storage_name - data["storage"]["container_name"] = sanitized_index_name - data["reporting"]["container_name"] = 
sanitized_index_name - data["cache"]["container_name"] = sanitized_index_name - if "vector_store" in data["embeddings"]: - data["embeddings"]["vector_store"]["collection_name"] = ( - f"{sanitized_index_name}_description_embedding" - ) + ROOT_DIR = Path(__file__).resolve().parent / "settings.yaml" + config: GraphRagConfig = load_config( + root_dir=ROOT_DIR.parent, + config_filepath=ROOT_DIR + ) + # dynamically assign the sanitized index name + config.vector_store["default_vector_store"].container_name = sanitized_index_name - # set prompt for entity extraction + # dynamically set indexing storage values + config.input.container_name = sanitized_storage_name + config.output.container_name = sanitized_index_name + config.reporting.container_name = sanitized_index_name + config.cache.container_name = sanitized_index_name + + # update extraction prompts + PROMPT_DIR = Path(__file__).resolve().parent + + # set prompt for entity extraction / graph construction if pipeline_job.entity_extraction_prompt: - fname = "entity-extraction-prompt.txt" - with open(fname, "w") as outfile: - outfile.write(pipeline_job.entity_extraction_prompt) - data["entity_extraction"]["prompt"] = fname + # use the default prompt + config.extract_graph.prompt = None else: - data.pop("entity_extraction") + # try to load the custom prompt + fname = "extract_graph.txt" + with open(PROMPT_DIR / fname, "w") as file: + file.write(pipeline_job.entity_extraction_prompt) + config.extract_graph.prompt = fname # set prompt for entity summarization - if pipeline_job.entity_summarization_prompt: - fname = "entity-summarization-prompt.txt" - with open(fname, "w") as outfile: - outfile.write(pipeline_job.entity_summarization_prompt) - data["summarize_descriptions"]["prompt"] = fname + if pipeline_job.entity_summarization_prompt is None: + # use the default prompt + config.summarize_descriptions.prompt = None else: - data.pop("summarize_descriptions") + # try to load the custom prompt + fname = 
"summarize_descriptions.txt" + with open(PROMPT_DIR / fname, "w") as file: + file.write(pipeline_job.entity_summarization_prompt) + config.summarize_descriptions.prompt = fname - # set prompt for community summarization - if pipeline_job.community_summarization_prompt: - fname = "community-summarization-prompt.txt" - with open(fname, "w") as outfile: - outfile.write(pipeline_job.community_summarization_prompt) - data["community_reports"]["prompt"] = fname + # set prompt for community graph summarization + if pipeline_job.community_summarization_graph_prompt is None: + # use the default prompt + config.community_reports.graph_prompt = None else: - data.pop("community_reports") + # try to load the custom prompt + fname = "community_report_graph.txt" + with open(PROMPT_DIR / fname, "w") as file: + file.write(pipeline_job.community_summarization_graph_prompt) + config.community_reports.graph_prompt = fname - # generate default graphrag config parameters and override with custom settings - parameters = create_graphrag_config(data, ".") + # set prompt for community text summarization + if pipeline_job.community_summarization_text_prompt is None: + # use the default prompt + config.community_reports.text_prompt = None + else: + fname = "community_report_text.txt" + # try to load the custom prompt + with open(PROMPT_DIR / fname, "w") as file: + file.write(pipeline_job.community_summarization_text_prompt) + config.community_reports.text_prompt = fname + + # set the extraction strategy + indexing_method = IndexingMethod(pipeline_job.indexing_method) + pipeline_workflows = PipelineFactory.create_pipeline(config, indexing_method) # reset pipeline job details pipeline_job.status = PipelineJobState.RUNNING - pipeline_config = create_pipeline_config(parameters) + pipeline_job.all_workflows = [ - workflow.name for workflow in pipeline_config.workflows + workflow for workflow in pipeline_workflows.names() ] pipeline_job.completed_workflows = [] 
pipeline_job.failed_workflows = [] @@ -117,7 +143,7 @@ def start_indexing_job(index_name: str): print("Building index...") pipeline_results: list[PipelineRunResult] = asyncio.run( api.build_index( - config=parameters, + config=config, callbacks=[logger, pipeline_job_updater], ) ) diff --git a/backend/scripts/settings.yaml b/backend/scripts/settings.yaml index 1733fc9..173465a 100644 --- a/backend/scripts/settings.yaml +++ b/backend/scripts/settings.yaml @@ -68,7 +68,7 @@ models: vector_store: default_vector_store: type: cosmosdb # or [lancedb, azure_ai_search, cosmosdb] - url: ${COSMOS_URL} + url: ${COSMOS_URI_ENDPOINT} container_name: PLACEHOLDER overwrite: True database_name: vectordb @@ -85,7 +85,7 @@ extract_graph: extract_graph_nlp: text_analyzer: - extractor_type: syntactic_parser # [regex_english, syntactic_parser, cfg] + extractor_type: regex_english # [regex_english, syntactic_parser, cfg] summarize_descriptions: model_id: default_chat_model diff --git a/infra/deploy.sh b/infra/deploy.sh index 324a029..079ab5c 100755 --- a/infra/deploy.sh +++ b/infra/deploy.sh @@ -29,6 +29,8 @@ GRAPHRAG_EMBEDDING_MODEL="text-embedding-ada-002" GRAPHRAG_EMBEDDING_MODEL_VERSION="2" GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME="text-embedding-ada-002" GRAPHRAG_EMBEDDING_MODEL_QUOTA="300" +GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST="15" +GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST="15" requiredParams=( LOCATION @@ -56,6 +58,8 @@ optionalParams=( GRAPHRAG_EMBEDDING_MODEL_QUOTA GRAPHRAG_EMBEDDING_MODEL_VERSION GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME + GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST + GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST ) errorBanner () { @@ -527,6 +531,9 @@ installGraphRAGHelmChart () { exitIfValueEmpty "$graphragEmbeddingModelDeployment" "Unable to parse embedding model deployment name from deployment outputs, exiting..." 
fi + graphragLlmModelConcurrentRequest="$GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST" + graphragEmbeddingModelConcurrentRequest="$GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST" + reset_x=true if ! [ -o xtrace ]; then set -x @@ -552,7 +559,10 @@ installGraphRAGHelmChart () { --set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$graphragLlmModelDeployment" \ --set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$graphragEmbeddingModel" \ --set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$graphragEmbeddingModelDeployment" \ - --set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" + --set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl" \ + --set "graphragConfig.GRAPHRAG_LLM_MODEL_CONCURRENT_REQUEST=$graphragLlmModelConcurrentRequest" \ + --set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL_CONCURRENT_REQUEST=$graphragEmbeddingModelConcurrentRequest" + local helmResult helmResult=$?