# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import asyncio
import inspect
import os
import traceback
from typing import cast

import yaml
from azure.identity import DefaultAzureCredential
from azure.search.documents.indexes import SearchIndexClient
from datashaper import WorkflowCallbacksManager
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    UploadFile,
)
from graphrag.config import create_graphrag_config
from graphrag.index import create_pipeline_config
from graphrag.index.bootstrap import bootstrap
from graphrag.index.run import run_pipeline_with_config
from kubernetes import (
    client,
    config,
)
from kubernetes.client.rest import ApiException

from src.api.azure_clients import (
    AzureStorageClientManager,
    BlobServiceClientSingleton,
    get_database_container_client,
)
from src.api.common import (
    delete_blob_container,
    retrieve_original_blob_container_name,
    sanitize_name,
    validate_blob_container_name,
    verify_subscription_key_exist,
)
from src.models import (
    BaseResponse,
    IndexNameList,
    IndexStatusResponse,
    PipelineJob,
)
from src.reporting import ReporterSingleton
from src.reporting.load_reporter import load_pipeline_reporter
from src.reporting.pipeline_job_workflow_callbacks import PipelineJobWorkflowCallbacks
from src.reporting.typing import Reporters
from src.typing import PipelineJobState

blob_service_client = BlobServiceClientSingleton.get_instance()
azure_storage_client_manager = (
    AzureStorageClientManager()
)  # TODO: update API to use the AzureStorageClientManager
ai_search_url = os.environ["AI_SEARCH_URL"]
ai_search_audience = os.environ["AI_SEARCH_AUDIENCE"]

index_route = APIRouter(
    prefix="/index",
    tags=["Index Operations"],
)
if os.getenv("KUBERNETES_SERVICE_HOST"):
    index_route.dependencies.append(Depends(verify_subscription_key_exist))
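
# NOTE: this module is configured through environment variables. Besides AI_SEARCH_URL
# and AI_SEARCH_AUDIENCE (read eagerly above), the handlers below also rely on
# KUBERNETES_SERVICE_HOST (set automatically inside AKS), HOSTNAME, AKS_NAMESPACE, and
# optionally REPORTERS; the GRAPHRAG_*_PROMPT_FILE variables are written (not read) by
# _start_indexing_pipeline.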


@index_route.post(
    "",
    summary="Build an index",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
)
async def setup_indexing_pipeline(
    storage_name: str,
    index_name: str,
    entity_extraction_prompt: UploadFile | None = None,
    community_report_prompt: UploadFile | None = None,
    summarize_descriptions_prompt: UploadFile | None = None,
):
    _blob_service_client = BlobServiceClientSingleton().get_instance()
    pipelinejob = PipelineJob()

    # validate index name against blob container naming rules
    sanitized_index_name = sanitize_name(index_name)
    try:
        validate_blob_container_name(sanitized_index_name)
    except ValueError:
        raise HTTPException(
            status_code=500,
            detail=f"Invalid index name: {index_name}",
        )

    # check for data container existence
    sanitized_storage_name = sanitize_name(storage_name)
    if not _blob_service_client.get_container_client(sanitized_storage_name).exists():
        raise HTTPException(
            status_code=500,
            detail=f"Data container '{storage_name}' does not exist.",
        )

    # check for existing index job
    # it is okay if job doesn't exist, but if it does,
    # it must not be scheduled or running
    if pipelinejob.item_exist(sanitized_index_name):
        existing_job = pipelinejob.load_item(sanitized_index_name)
        if (PipelineJobState(existing_job.status) == PipelineJobState.SCHEDULED) or (
            PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
        ):
            raise HTTPException(
                status_code=202,  # request has been accepted for processing but is not complete.
                detail=f"an index with name {index_name} already exists and has not finished building.",
            )
        # if indexing job is in a failed state, delete the associated K8s job and pod
        # to allow for a new job to be scheduled
        if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
            _delete_k8s_job(f"indexing-job-{sanitized_index_name}", "graphrag")
        # reset the job to scheduled state
        # (assign three separate lists so the workflow lists do not alias each other)
        existing_job.status = PipelineJobState.SCHEDULED
        existing_job.percent_complete = 0
        existing_job.progress = ""
        existing_job.all_workflows = []
        existing_job.completed_workflows = []
        existing_job.failed_workflows = []
        existing_job.entity_extraction_prompt = None
        existing_job.community_report_prompt = None
        existing_job.summarize_descriptions_prompt = None

    # create or update state in cosmos db
    entity_extraction_prompt_content = (
        entity_extraction_prompt.file.read().decode("utf-8")
        if entity_extraction_prompt
        else None
    )
    community_report_prompt_content = (
        community_report_prompt.file.read().decode("utf-8")
        if community_report_prompt
        else None
    )
    summarize_descriptions_prompt_content = (
        summarize_descriptions_prompt.file.read().decode("utf-8")
        if summarize_descriptions_prompt
        else None
    )
    print(f"ENTITY EXTRACTION PROMPT:\n{entity_extraction_prompt_content}")
    print(f"COMMUNITY REPORT PROMPT:\n{community_report_prompt_content}")
    print(f"SUMMARIZE DESCRIPTIONS PROMPT:\n{summarize_descriptions_prompt_content}")
    pipelinejob.create_item(
        id=sanitized_index_name,
        index_name=sanitized_index_name,
        storage_name=sanitized_storage_name,
        entity_extraction_prompt=entity_extraction_prompt_content,
        community_report_prompt=community_report_prompt_content,
        summarize_descriptions_prompt=summarize_descriptions_prompt_content,
        status=PipelineJobState.SCHEDULED,
    )

    """
    At this point, we know:
    1) the index name is valid
    2) the data container exists
    3) there is no indexing job with this name currently running, or a previous job has finished
    """
    # update or create new item in container-store in cosmosDB
    if not _blob_service_client.get_container_client(sanitized_index_name).exists():
        _blob_service_client.create_container(sanitized_index_name)
    container_store_client = get_database_container_client(
        database_name="graphrag", container_name="container-store"
    )
    container_store_client.upsert_item(
        {
            "id": sanitized_index_name,
            "human_readable_name": index_name,
            "type": "index",
        }
    )

    # schedule AKS job
    try:
        config.load_incluster_config()
        # get container image name
        core_v1 = client.CoreV1Api()
        pod_name = os.environ["HOSTNAME"]
        pod = core_v1.read_namespaced_pod(
            name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
        )
        # retrieve job manifest template and replace necessary values
        job_manifest = _generate_aks_job_manifest(
            docker_image_name=pod.spec.containers[0].image,
            index_name=index_name,
            service_account_name=pod.spec.service_account_name,
        )
        try:
            batch_v1 = client.BatchV1Api()
            batch_v1.create_namespaced_job(
                body=job_manifest, namespace=os.environ["AKS_NAMESPACE"]
            )
        except ApiException as e:
            raise HTTPException(
                status_code=500,
                detail=f"exception when calling BatchV1Api->create_namespaced_job: {str(e)}",
            )
        return BaseResponse(status="indexing operation has been scheduled.")
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
        job_details = {
            "storage_name": storage_name,
            "index_name": index_name,
        }
        reporter.on_error(
            "Error creating a new index",
            details={"error_details": str(e), "job_details": job_details},
        )
        raise HTTPException(
            status_code=500,
            detail=f"Error occurred during setup of indexing job for '{index_name}'.",
        )
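
# NOTE: the AKS job scheduled by setup_indexing_pipeline() runs `run-indexing-job.py`
# inside the indexing container (see the command set in _generate_aks_job_manifest).
# That script is not part of this module; a minimal sketch of what it is assumed to do,
# namely wire the job back into _start_indexing_pipeline() below, is:
#
#     import argparse
#     import asyncio
#
#     from src.api.index import _start_indexing_pipeline
#
#     parser = argparse.ArgumentParser(description="Kick off an indexing job.")
#     parser.add_argument("-i", "--index-name", required=True)
#     args = parser.parse_args()
#     asyncio.run(_start_indexing_pipeline(index_name=args.index_name))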


async def _start_indexing_pipeline(index_name: str):
    # get sanitized name
    sanitized_index_name = sanitize_name(index_name)

    reporter = ReporterSingleton().get_instance()
    pipelinejob = PipelineJob()
    pipeline_job = pipelinejob.load_item(sanitized_index_name)
    sanitized_storage_name = pipeline_job.storage_name
    storage_name = retrieve_original_blob_container_name(sanitized_storage_name)

    # download nltk dependencies
    bootstrap()

    # create new reporters/callbacks just for this job
    reporters = []
    reporter_names = os.getenv("REPORTERS", Reporters.CONSOLE.name.upper()).split(",")
    for reporter_name in reporter_names:
        try:
            reporters.append(Reporters[reporter_name.upper()])
        except KeyError:
            raise ValueError(f"Found unknown reporter: {reporter_name}")
    workflow_callbacks = load_pipeline_reporter(
        reporting_dir=sanitized_index_name, reporters=reporters
    )

    # load custom pipeline settings
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
    with open(f"{this_directory}/pipeline-settings.yaml") as settings_file:
        data = yaml.safe_load(settings_file)
    # dynamically set some values
    data["input"]["container_name"] = sanitized_storage_name
    data["storage"]["container_name"] = sanitized_index_name
    data["reporting"]["container_name"] = sanitized_index_name
    data["cache"]["container_name"] = sanitized_index_name
    if "vector_store" in data["embeddings"]:
        data["embeddings"]["vector_store"]["collection_name"] = (
            f"{sanitized_index_name}_description_embedding"
        )

    # set prompts for entity extraction, community report, and summarize descriptions.
    # an environment variable is set to the file path of the prompt
    if pipeline_job.entity_extraction_prompt:
        fname = "entity-extraction-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.entity_extraction_prompt)
        os.environ["GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE"] = fname
        # data["entity_extraction"]["prompt"] = fname
    # else:
    #     data["entity_extraction"]["prompt"] = None
    if pipeline_job.community_report_prompt:
        fname = "community-report-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.community_report_prompt)
        os.environ["GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE"] = fname
        # data["community_reports"]["prompt"] = fname
    # else:
    #     data["community_reports"]["prompt"] = None
    if pipeline_job.summarize_descriptions_prompt:
        fname = "summarize-descriptions-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.summarize_descriptions_prompt)
        os.environ["GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE"] = fname
        # data["summarize_descriptions"]["prompt"] = fname
    # else:
    #     data["summarize_descriptions"]["prompt"] = None

    # set placeholder values to None if they have not been set
    # if data["entity_extraction"]["prompt"] == "PLACEHOLDER":
    #     data["entity_extraction"]["prompt"] = None
    # if data["community_reports"]["prompt"] == "PLACEHOLDER":
    #     data["community_reports"]["prompt"] = None
    # if data["summarize_descriptions"]["prompt"] == "PLACEHOLDER":
    #     data["summarize_descriptions"]["prompt"] = None

    # generate the default pipeline from default parameters and override with custom settings
    parameters = create_graphrag_config(data, ".")
    pipeline_config = create_pipeline_config(parameters, True)

    # reset pipeline job details
    pipeline_job.status = PipelineJobState.RUNNING
    pipeline_job.all_workflows = []
    pipeline_job.completed_workflows = []
    pipeline_job.failed_workflows = []
    for workflow in pipeline_config.workflows:
        pipeline_job.all_workflows.append(workflow.name)

    # add pipeline_job callback to the callback manager
    cast(WorkflowCallbacksManager, workflow_callbacks).register(
        PipelineJobWorkflowCallbacks(pipeline_job)
    )

    # print("#################### PIPELINE JOB:")
    # pprint(pipeline_job.dump_model())
    print("#################### PIPELINE CONFIG:")
    print(pipeline_config)

    # run the pipeline
    try:
        async for workflow_result in run_pipeline_with_config(
            config_or_path=pipeline_config,
            callbacks=workflow_callbacks,
            progress_reporter=None,
        ):
            await asyncio.sleep(0)
            if len(workflow_result.errors or []) > 0:
                # if the workflow failed, record the failure
                pipeline_job.failed_workflows.append(workflow_result.workflow)
                pipeline_job.update_db()

        # if job is done, check if any workflow steps failed
        if len(pipeline_job.failed_workflows) > 0:
            pipeline_job.status = PipelineJobState.FAILED
        else:
            # record the workflow completion
            pipeline_job.status = PipelineJobState.COMPLETE
            pipeline_job.percent_complete = 100
        pipeline_job.progress = (
            f"{len(pipeline_job.completed_workflows)} out of "
            f"{len(pipeline_job.all_workflows)} workflows completed successfully."
        )
        workflow_callbacks.on_log(
            f"Index Name: {index_name}, Container Name: {storage_name}\n",
            details={"status_message": "Indexing pipeline complete."},
        )
        del workflow_callbacks  # garbage collect
        if pipeline_job.status == PipelineJobState.FAILED:
            exit(1)  # signal to AKS that indexing job failed
    except Exception as e:
        pipeline_job.status = PipelineJobState.FAILED
        # update failed state in cosmos db
        error_details = {
            "error_details": str(e),
            "error_message": "Indexing pipeline failed.",
        }
        # log error in local index directory logs
        workflow_callbacks.on_error(
            message=f"Index Name: {index_name}, Container Name: {storage_name}\n",
            cause=e,
            stack=traceback.format_exc(),
            details=error_details,
        )
        # log error in global index directory logs
        reporter.on_error(
            f"Index Name: {index_name}, Container Name: {storage_name}\n {str(e)} \n",
            cause=str(e),
            stack=traceback.format_exc(),
            details=error_details,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Error occurred during indexing job for index '{index_name}'.",
        )
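
# The AKS job template consumed by _generate_aks_job_manifest() below lives in
# src/aks-batch-job-template.yaml and is not reproduced here. The function only assumes
# the template already contains the keys it overwrites; an illustrative skeleton of the
# dict that yaml.safe_load would return (placeholder values, not the real template):
#
#     {
#         "apiVersion": "batch/v1",
#         "kind": "Job",
#         "metadata": {"name": "<replaced>"},
#         "spec": {
#             "template": {
#                 "spec": {
#                     "serviceAccountName": "<replaced>",
#                     "containers": [
#                         {"name": "<container>", "image": "<replaced>", "command": ["<replaced>"]}
#                     ],
#                 }
#             }
#         },
#     }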


def _generate_aks_job_manifest(
    docker_image_name: str,
    index_name: str,
    service_account_name: str,
) -> dict:
    """Generate an AKS Jobs manifest file with the specified parameters.

    The manifest file must be valid YAML with certain values replaced
    by the provided arguments.
    """
    # NOTE: the relative file locations are based on the WORKDIR set in Dockerfile-indexing
    with open("src/aks-batch-job-template.yaml", "r") as f:
        manifest = yaml.safe_load(f)
    manifest["metadata"]["name"] = f"indexing-job-{sanitize_name(index_name)}"
    manifest["spec"]["template"]["spec"]["serviceAccountName"] = service_account_name
    manifest["spec"]["template"]["spec"]["containers"][0]["image"] = docker_image_name
    manifest["spec"]["template"]["spec"]["containers"][0]["command"] = [
        "python",
        "run-indexing-job.py",
        f"-i={index_name}",
    ]
    return manifest


@index_route.get(
    "",
    summary="Get all indexes",
    response_model=IndexNameList,
    responses={200: {"model": IndexNameList}},
)
async def get_all_indexes():
    """
    Retrieve a list of all index names.
    """
    items = []
    try:
        container_store_client = get_database_container_client(
            database_name="graphrag", container_name="container-store"
        )
        for item in container_store_client.read_all_items():
            if item["type"] == "index":
                items.append(item["human_readable_name"])
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error(f"Error retrieving index names: {str(e)}")
    return IndexNameList(index_name=items)


def _get_pod_name(job_name: str, namespace: str) -> str | None:
    """Retrieve the name of a kubernetes pod associated with a given job name."""
    # function should work only when running in AKS
    if not os.getenv("KUBERNETES_SERVICE_HOST"):
        return None
    config.load_incluster_config()
    v1 = client.CoreV1Api()
    ret = v1.list_namespaced_pod(namespace=namespace)
    for i in ret.items:
        if job_name in i.metadata.name:
            return i.metadata.name
    return None


def _delete_k8s_job(job_name: str, namespace: str) -> None:
    """Delete a kubernetes job.

    Must delete the K8s job first and then any pods associated with it.
    """
    # function should only work when running in AKS
    if not os.getenv("KUBERNETES_SERVICE_HOST"):
        return None
    reporter = ReporterSingleton().get_instance()
    config.load_incluster_config()
    try:
        batch_v1 = client.BatchV1Api()
        batch_v1.delete_namespaced_job(name=job_name, namespace=namespace)
    except Exception as e:
        reporter.on_error(
            f"Error deleting k8s job {job_name}.",
            details={"error_details": str(e), "Container": job_name},
        )
    try:
        core_v1 = client.CoreV1Api()
        job_pod = _get_pod_name(job_name, os.environ["AKS_NAMESPACE"])
        if job_pod:
            core_v1.delete_namespaced_pod(job_pod, namespace=namespace)
    except Exception as e:
        reporter.on_error(
            f"Error deleting k8s pod for job {job_name}.",
            details={"error_details": str(e), "Container": job_name},
        )


@index_route.delete(
    "/{index_name}",
    summary="Delete a specified index",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
)
async def delete_index(index_name: str):
    """
    Delete a specified index.
    """
    sanitized_index_name = sanitize_name(index_name)
    reporter = ReporterSingleton().get_instance()
    try:
        # kill indexing job if it is running
        if os.getenv("KUBERNETES_SERVICE_HOST"):  # only found if in AKS
            _delete_k8s_job(f"indexing-job-{sanitized_index_name}", "graphrag")

        # remove blob container and all associated entries in cosmos db
        try:
            delete_blob_container(sanitized_index_name)
        except Exception:
            pass

        # update container-store in cosmosDB
        try:
            container_store_client = get_database_container_client(
                database_name="graphrag", container_name="container-store"
            )
            container_store_client.delete_item(
                item=sanitized_index_name, partition_key=sanitized_index_name
            )
        except Exception:
            pass

        # update jobs database in cosmosDB
        try:
            jobs_container = get_database_container_client(
                database_name="graphrag", container_name="jobs"
            )
            jobs_container.delete_item(
                item=sanitized_index_name, partition_key=sanitized_index_name
            )
        except Exception:
            pass

        # remove the AI Search index associated with this graphrag index
        index_client = SearchIndexClient(
            endpoint=ai_search_url,
            credential=DefaultAzureCredential(),
            audience=ai_search_audience,
        )
        ai_search_index_name = f"{sanitized_index_name}_description_embedding"
        if ai_search_index_name in index_client.list_index_names():
            index_client.delete_index(ai_search_index_name)
    except Exception as e:
        reporter.on_error(
            message=f"Error encountered while deleting all data for index {index_name}.",
            stack=traceback.format_exc(),
            details={"error_details": str(e), "container": index_name},
        )
        raise HTTPException(
            status_code=500, detail=f"Error deleting index '{index_name}'."
        )

    return BaseResponse(status="Success")


@index_route.get(
    "/status/{index_name}",
    summary="Track the status of an indexing job",
    response_model=IndexStatusResponse,
)
async def get_index_job_status(index_name: str):
    """
    Track the status of an indexing job.
    """
    pipelinejob = PipelineJob()  # TODO: fix class so initialization is not required
    sanitized_index_name = sanitize_name(index_name)
    if pipelinejob.item_exist(sanitized_index_name):
        pipeline_job = pipelinejob.load_item(sanitized_index_name)
        return IndexStatusResponse(
            status_code=200,
            index_name=retrieve_original_blob_container_name(pipeline_job.index_name),
            storage_name=retrieve_original_blob_container_name(
                pipeline_job.storage_name
            ),
            status=pipeline_job.status.value,
            percent_complete=pipeline_job.percent_complete,
            progress=pipeline_job.progress,
        )
    raise HTTPException(status_code=404, detail=f"Index '{index_name}' does not exist.")
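

# Illustrative client usage of the routes above (a sketch only: the base URL, port, and
# any APIM subscription-key header depend on the deployment and are assumptions, as is
# the choice of the `requests` library):
#
#     import requests
#
#     api = "https://<your-apim-gateway-or-host>"  # hypothetical base URL
#
#     # schedule an indexing job over a previously uploaded data container
#     requests.post(
#         f"{api}/index",
#         params={"storage_name": "<data-container>", "index_name": "<index-name>"},
#     )
#
#     # track indexing progress
#     requests.get(f"{api}/index/status/<index-name>").json()
#
#     # list all indexes
#     requests.get(f"{api}/index").json()
#
#     # delete an index and all associated artifacts
#     requests.delete(f"{api}/index/<index-name>")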