# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
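"""FastAPI routes for building, listing, monitoring, and deleting graphrag indexes."""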
import inspect
import os
import traceback
from time import time

import graphrag.api as api
import yaml
from azure.identity import DefaultAzureCredential
from azure.search.documents.indexes import SearchIndexClient
from fastapi import (
    APIRouter,
    HTTPException,
    UploadFile,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.bootstrap import bootstrap
from graphrag.index.create_pipeline_config import create_pipeline_config
from kubernetes import (
    client as kubernetes_client,
    config as kubernetes_config,
)

from src.api.azure_clients import AzureClientManager
from src.api.common import (
    delete_blob_container,
    sanitize_name,
    validate_blob_container_name,
)
from src.logger import (
    LoggerSingleton,
    PipelineJobWorkflowCallbacks,
    Reporters,
    load_pipeline_logger,
)
from src.models import (
    BaseResponse,
    IndexNameList,
    IndexStatusResponse,
)
from src.typing.pipeline import PipelineJobState
from src.utils.pipeline import PipelineJob

index_route = APIRouter(
prefix="/index",
tags=["Index Operations"],
)
@index_route.post(
"",
summary="Build an index",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
)
async def setup_indexing_pipeline(
storage_name: str,
index_name: str,
entity_extraction_prompt: UploadFile | None = None,
community_report_prompt: UploadFile | None = None,
summarize_descriptions_prompt: UploadFile | None = None,
):
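    """
    Schedule an indexing job that builds a graphrag index over the data in
    `storage_name`, optionally using custom prompts for entity extraction,
    community reports, and description summarization.
    """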
azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
pipelinejob = PipelineJob()
# validate index name against blob container naming rules
sanitized_index_name = sanitize_name(index_name)
try:
validate_blob_container_name(sanitized_index_name)
except ValueError:
raise HTTPException(
status_code=500,
detail=f"Invalid index name: {index_name}",
)
# check for data container existence
sanitized_storage_name = sanitize_name(storage_name)
if not blob_service_client.get_container_client(sanitized_storage_name).exists():
raise HTTPException(
status_code=500,
detail=f"Storage blob container {storage_name} does not exist",
)
# check for prompts
entity_extraction_prompt_content = (
entity_extraction_prompt.file.read().decode("utf-8")
if entity_extraction_prompt
else None
)
community_report_prompt_content = (
community_report_prompt.file.read().decode("utf-8")
if community_report_prompt
else None
)
summarize_descriptions_prompt_content = (
summarize_descriptions_prompt.file.read().decode("utf-8")
if summarize_descriptions_prompt
else None
)
# check for existing index job
# it is okay if job doesn't exist, but if it does,
# it must not be scheduled or running
if pipelinejob.item_exist(sanitized_index_name):
existing_job = pipelinejob.load_item(sanitized_index_name)
if (PipelineJobState(existing_job.status) == PipelineJobState.SCHEDULED) or (
PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
):
raise HTTPException(
status_code=202, # request has been accepted for processing but is not complete.
detail=f"Index '{index_name}' already exists and has not finished building.",
)
# if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
_delete_k8s_job(
f"indexing-job-{sanitized_index_name}", os.environ["AKS_NAMESPACE"]
)
# reset the pipeline job details
existing_job._status = PipelineJobState.SCHEDULED
existing_job._percent_complete = 0
existing_job._progress = ""
existing_job._all_workflows = existing_job._completed_workflows = (
existing_job._failed_workflows
) = []
existing_job._entity_extraction_prompt = entity_extraction_prompt_content
existing_job._community_report_prompt = community_report_prompt_content
existing_job._summarize_descriptions_prompt = (
summarize_descriptions_prompt_content
)
existing_job._epoch_request_time = int(time())
existing_job.update_db()
else:
pipelinejob.create_item(
id=sanitized_index_name,
human_readable_index_name=index_name,
human_readable_storage_name=storage_name,
entity_extraction_prompt=entity_extraction_prompt_content,
community_report_prompt=community_report_prompt_content,
summarize_descriptions_prompt=summarize_descriptions_prompt_content,
status=PipelineJobState.SCHEDULED,
)
return BaseResponse(status="Indexing job scheduled")
async def _start_indexing_pipeline(index_name: str):
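    """
    Run the graphrag indexing pipeline for a previously scheduled job,
    updating the job's state and progress as workflows complete.
    """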
# get sanitized name
sanitized_index_name = sanitize_name(index_name)
# update or create new item in container-store in cosmosDB
azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
if not blob_service_client.get_container_client(sanitized_index_name).exists():
blob_service_client.create_container(sanitized_index_name)
cosmos_container_client = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="container-store"
)
cosmos_container_client.upsert_item({
"id": sanitized_index_name,
"human_readable_name": index_name,
"type": "index",
})
logger = LoggerSingleton().get_instance()
pipelinejob = PipelineJob()
pipeline_job = pipelinejob.load_item(sanitized_index_name)
    sanitized_storage_name = pipeline_job.sanitized_storage_name
    storage_name = pipeline_job.human_readable_storage_name
# download nltk dependencies
bootstrap()
# load custom pipeline settings
this_directory = os.path.dirname(
os.path.abspath(inspect.getfile(inspect.currentframe()))
)
    with open(f"{this_directory}/pipeline-settings.yaml") as settings_file:
        data = yaml.safe_load(settings_file)
# dynamically set some values
data["input"]["container_name"] = sanitized_storage_name
data["storage"]["container_name"] = sanitized_index_name
data["reporting"]["container_name"] = sanitized_index_name
data["cache"]["container_name"] = sanitized_index_name
if "vector_store" in data["embeddings"]:
data["embeddings"]["vector_store"]["collection_name"] = (
f"{sanitized_index_name}_description_embedding"
)
# set prompt for entity extraction
if pipeline_job.entity_extraction_prompt:
fname = "entity-extraction-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.entity_extraction_prompt)
data["entity_extraction"]["prompt"] = fname
else:
data.pop("entity_extraction")
# set prompt for summarize descriptions
if pipeline_job.summarize_descriptions_prompt:
fname = "summarize-descriptions-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.summarize_descriptions_prompt)
data["summarize_descriptions"]["prompt"] = fname
else:
data.pop("summarize_descriptions")
# set prompt for community report
if pipeline_job.community_report_prompt:
fname = "community-report-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.community_report_prompt)
data["community_reports"]["prompt"] = fname
else:
data.pop("community_reports")
# generate a default GraphRagConfig and override with custom settings
parameters = create_graphrag_config(data, ".")
# reset pipeline job details
pipeline_job.status = PipelineJobState.RUNNING
pipeline_job.all_workflows = []
pipeline_job.completed_workflows = []
pipeline_job.failed_workflows = []
pipeline_config = create_pipeline_config(parameters)
for workflow in pipeline_config.workflows:
pipeline_job.all_workflows.append(workflow.name)
# create new loggers/callbacks just for this job
loggers = []
logger_names = os.getenv("REPORTERS", Reporters.CONSOLE.name.upper()).split(",")
for logger_name in logger_names:
try:
loggers.append(Reporters[logger_name.upper()])
except KeyError:
raise ValueError(f"Unknown logger type: {logger_name}")
workflow_callbacks = load_pipeline_logger(
index_name=index_name,
num_workflow_steps=len(pipeline_job.all_workflows),
reporting_dir=sanitized_index_name,
reporters=loggers,
)
# add pipeline job callback to monitor job progress
pipeline_job_callback = PipelineJobWorkflowCallbacks(pipeline_job)
# run the pipeline
try:
await api.build_index(
config=parameters,
callbacks=[workflow_callbacks, pipeline_job_callback],
)
# if job is done, check if any workflow steps failed
if len(pipeline_job.failed_workflows) > 0:
pipeline_job.status = PipelineJobState.FAILED
            workflow_callbacks.on_log(
                message=f"Indexing pipeline encountered an error for index '{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline encountered error",
},
)
else:
# record the workflow completion
pipeline_job.status = PipelineJobState.COMPLETE
pipeline_job.percent_complete = 100
            workflow_callbacks.on_log(
                message=f"Indexing pipeline complete for index '{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline complete",
},
)
pipeline_job.progress = (
f"{len(pipeline_job.completed_workflows)} out of "
f"{len(pipeline_job.all_workflows)} workflows completed successfully."
)
del workflow_callbacks # garbage collect
if pipeline_job.status == PipelineJobState.FAILED:
exit(1) # signal to AKS that indexing job failed
except Exception as e:
pipeline_job.status = PipelineJobState.FAILED
# update failed state in cosmos db
error_details = {
"index": index_name,
"storage_name": storage_name,
}
# log error in local index directory logs
workflow_callbacks.on_error(
message=f"Indexing pipeline failed for index '{index_name}'.",
cause=e,
stack=traceback.format_exc(),
details=error_details,
)
# log error in global index directory logs
logger.on_error(
message=f"Indexing pipeline failed for index '{index_name}'.",
cause=e,
stack=traceback.format_exc(),
details=error_details,
)
raise HTTPException(
status_code=500,
detail=f"Error encountered during indexing job for index '{index_name}'.",
)
@index_route.get(
"",
summary="Get all indexes",
response_model=IndexNameList,
responses={200: {"model": IndexNameList}},
)
async def get_all_indexes():
"""
Retrieve a list of all index names.
"""
items = []
try:
azure_client_manager = AzureClientManager()
container_store_client = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="container-store"
)
for item in container_store_client.read_all_items():
if item["type"] == "index":
items.append(item["human_readable_name"])
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Error retrieving index names")
return IndexNameList(index_name=items)
def _get_pod_name(job_name: str, namespace: str) -> str | None:
"""Retrieve the name of a kubernetes pod associated with a given job name."""
# function should work only when running in AKS
if not os.getenv("KUBERNETES_SERVICE_HOST"):
return None
kubernetes_config.load_incluster_config()
v1 = kubernetes_client.CoreV1Api()
ret = v1.list_namespaced_pod(namespace=namespace)
for i in ret.items:
if job_name in i.metadata.name:
return i.metadata.name
return None
def _delete_k8s_job(job_name: str, namespace: str) -> None:
"""Delete a kubernetes job.
Must delete K8s job first and then any pods associated with it
"""
# function should only work when running in AKS
if not os.getenv("KUBERNETES_SERVICE_HOST"):
return None
logger = LoggerSingleton().get_instance()
kubernetes_config.load_incluster_config()
try:
batch_v1 = kubernetes_client.BatchV1Api()
batch_v1.delete_namespaced_job(name=job_name, namespace=namespace)
except Exception:
logger.on_error(
message=f"Error deleting k8s job {job_name}.",
details={"container": job_name},
)
try:
core_v1 = kubernetes_client.CoreV1Api()
job_pod = _get_pod_name(job_name, os.environ["AKS_NAMESPACE"])
if job_pod:
core_v1.delete_namespaced_pod(job_pod, namespace=namespace)
except Exception:
logger.on_error(
message=f"Error deleting k8s pod for job {job_name}.",
details={"container": job_name},
)
@index_route.delete(
"/{index_name}",
summary="Delete a specified index",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
)
async def delete_index(index_name: str):
"""
Delete a specified index.
"""
sanitized_index_name = sanitize_name(index_name)
azure_client_manager = AzureClientManager()
try:
# kill indexing job if it is running
if os.getenv("KUBERNETES_SERVICE_HOST"): # only found if in AKS
_delete_k8s_job(f"indexing-job-{sanitized_index_name}", "graphrag")
# remove blob container and all associated entries in cosmos db
try:
delete_blob_container(sanitized_index_name)
except Exception:
pass
# update container-store in cosmosDB
try:
container_store_client = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="container-store"
)
container_store_client.delete_item(
item=sanitized_index_name, partition_key=sanitized_index_name
)
except Exception:
pass
# update jobs database in cosmosDB
try:
jobs_container = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="jobs"
)
jobs_container.delete_item(
item=sanitized_index_name, partition_key=sanitized_index_name
)
except Exception:
pass
index_client = SearchIndexClient(
endpoint=os.environ["AI_SEARCH_URL"],
credential=DefaultAzureCredential(),
audience=os.environ["AI_SEARCH_AUDIENCE"],
)
ai_search_index_name = f"{sanitized_index_name}_description_embedding"
if ai_search_index_name in index_client.list_index_names():
index_client.delete_index(ai_search_index_name)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error(
message=f"Error encountered while deleting all data for index {index_name}.",
stack=None,
details={"container": index_name},
)
raise HTTPException(
status_code=500, detail=f"Error deleting index '{index_name}'."
)
return BaseResponse(status="Success")
@index_route.get(
"/status/{index_name}",
summary="Track the status of an indexing job",
response_model=IndexStatusResponse,
)
async def get_index_job_status(index_name: str):
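    """
    Retrieve the status, percent complete, and progress of an indexing job.
    """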
    pipelinejob = PipelineJob()  # TODO: fix class so initialization is not required
sanitized_index_name = sanitize_name(index_name)
if pipelinejob.item_exist(sanitized_index_name):
pipeline_job = pipelinejob.load_item(sanitized_index_name)
return IndexStatusResponse(
status_code=200,
index_name=pipeline_job.human_readable_index_name,
storage_name=pipeline_job.human_readable_storage_name,
status=pipeline_job.status.value,
percent_complete=pipeline_job.percent_complete,
progress=pipeline_job.progress,
)
raise HTTPException(status_code=404, detail=f"Index '{index_name}' does not exist.")