mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-06-27 04:39:57 +00:00
178 lines
6.4 KiB
Python
178 lines
6.4 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import os
|
|
import traceback
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from azure.cosmos import PartitionKey, ThroughputProperties
|
|
from fastapi import (
|
|
Depends,
|
|
FastAPI,
|
|
Request,
|
|
status,
|
|
)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import Response
|
|
from kubernetes import (
|
|
client,
|
|
config,
|
|
)
|
|
|
|
from graphrag_app.api.data import data_route
|
|
from graphrag_app.api.graph import graph_route
|
|
from graphrag_app.api.index import index_route
|
|
from graphrag_app.api.prompt_tuning import prompt_tuning_route
|
|
from graphrag_app.api.query import query_route
|
|
from graphrag_app.api.source import source_route
|
|
from graphrag_app.logger.load_logger import load_pipeline_logger
|
|
from graphrag_app.utils.azure_clients import AzureClientManager
|
|
from graphrag_app.utils.common import subscription_key_check
|
|
|
|
|
|
async def catch_all_exceptions_middleware(request: Request, call_next):
    """Pass the request downstream and convert any uncaught exception into a generic 500.

    Args:
        request: The incoming HTTP request.
        call_next: Awaitable callable that forwards the request to the rest of
            the application and produces its response.

    Returns:
        The downstream response when no exception escapes; otherwise a plain
        response with status 500 and a fixed, non-revealing message body.
    """
    try:
        response = await call_next(request)
    except Exception as e:
        # Record the failure (exception object plus formatted traceback) through
        # the pipeline logger; the client only ever sees the generic message.
        load_pipeline_logger().error(
            message="Unexpected internal server error",
            cause=e,
            stack=traceback.format_exc(),
        )
        return Response(
            "Unexpected internal server error.",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    return response
|
|
|
|
|
|
# NOTE: this function is not currently used, but it is a placeholder for future use once RBAC issues have been resolved
|
|
def intialize_cosmosdb_setup():
    """Create the CosmosDB database and startup containers when they are missing.

    Ensures a database named "graphrag" exists (requested with autoscale
    throughput: max 1000, increment percent 1), then ensures the "jobs" and
    "container-store" containers exist, each partitioned on "/id".
    """
    cosmos_client = AzureClientManager().get_cosmos_client()
    database = cosmos_client.create_database_if_not_exists(
        "graphrag",
        offer_throughput=ThroughputProperties(
            auto_scale_max_throughput=1000, auto_scale_increment_percent=1
        ),
    )
    # Both containers use default settings and the same partition key path.
    for container_id in ("jobs", "container-store"):
        database.create_container_if_not_exists(
            id=container_id,
            partition_key=PartitionKey(path="/id"),
        )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Deploy a cronjob to manage indexing jobs.

    This function is called when the FastAPI application first starts up.
    To manage multiple graphrag indexing jobs, we deploy a k8s cronjob.
    This cronjob will act as a job manager that creates/manages the execution
    of graphrag indexing jobs as they are requested.

    The cronjob manifest is loaded from ``manifests/cronjob.yaml`` and patched
    with the image, service account and (when present) image pull secrets of
    the pod this process is running in. The cronjob is created only if no
    cronjob with the same name already exists in the namespace.

    Args:
        app: The FastAPI application being started (not used by this function).

    Yields:
        Control to the application once startup work has been attempted.
        Startup failures are logged and swallowed so the API still comes up.
    """
    # if running in a TESTING environment, exit early to avoid creating k8s resources
    if os.getenv("TESTING"):
        yield
        return

    # TODO: must identify proper CosmosDB RBAC roles before databases and containers can be created by this web app
    # intialize_cosmosdb_setup()

    try:
        config.load_incluster_config()
        namespace = os.environ["AKS_NAMESPACE"]

        # retrieve the running pod spec
        core_v1 = client.CoreV1Api()
        pod = core_v1.read_namespaced_pod(
            name=os.environ["HOSTNAME"], namespace=namespace
        )

        # load the k8s cronjob template and update PLACEHOLDER values with
        # correct values based on the running pod spec
        ROOT_DIR = Path(__file__).resolve().parent.parent
        with (ROOT_DIR / "manifests/cronjob.yaml").open("r") as f:
            manifest = yaml.safe_load(f)

        job_pod_spec = manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]
        # set docker image name
        job_pod_spec["containers"][0]["image"] = pod.spec.containers[0].image
        # set service account name
        job_pod_spec["serviceAccountName"] = pod.spec.service_account_name
        # set image pull secrets only if they were provided as part of the
        # deployment. The kubernetes client models define every field as an
        # attribute (None when unset), so hasattr() would always be true;
        # check the value itself instead.
        image_pull_secrets = getattr(pod.spec, "image_pull_secrets", None)
        if image_pull_secrets:
            job_pod_spec["imagePullSecrets"] = image_pull_secrets

        # retrieve the names of existing cronjobs
        batch_v1 = client.BatchV1Api()
        namespace_cronjobs = batch_v1.list_namespaced_cron_job(namespace=namespace)
        cronjob_names = {cronjob.metadata.name for cronjob in namespace_cronjobs.items}

        # create cronjob if it does not exist
        if manifest["metadata"]["name"] not in cronjob_names:
            batch_v1.create_namespaced_cron_job(namespace=namespace, body=manifest)
    except Exception as e:
        # Best effort: a failure here must not prevent the API from starting.
        print("Failed to create graphrag cronjob.")
        logger = load_pipeline_logger()
        logger.error(
            message="Failed to create graphrag cronjob",
            cause=str(e),
            stack=traceback.format_exc(),
        )

    # The yield stays outside the try block so that exceptions thrown into the
    # generator while the app is running are not caught by the handler above.
    yield  # This is where the application starts up.
    # shutdown/garbage collection code goes here
|
|
|
|
|
|
# Application instance. Docs and the OpenAPI schema are served under /manpage.
# The root path and reported version come from the environment.
app = FastAPI(
    title="GraphRAG",
    version=os.getenv("GRAPHRAG_VERSION", "undefined_version"),
    root_path=os.getenv("API_ROOT_PATH", ""),
    docs_url="/manpage/docs",
    openapi_url="/manpage/openapi.json",
    # only set lifespan if running in AKS (by checking for a default k8s environment variable)
    lifespan=lifespan if os.getenv("KUBERNETES_SERVICE_HOST") else None,
)
|
|
|
|
# Register the global exception handler as an "http" middleware, then CORS.
app.middleware("http")(catch_all_exceptions_middleware)
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployed environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Mount the API routers in a fixed order.
for _router in (
    data_route,
    index_route,
    query_route,
    # query_streaming_route,  # temporarily disable streaming endpoints
    prompt_tuning_route,
    source_route,
    graph_route,
):
    app.include_router(_router)
|
|
|
|
|
|
# health check endpoint
|
|
@app.get(
    "/health",
    summary="API health check",
    # The subscription-key dependency is attached only when the default k8s
    # environment variable is present; otherwise no dependencies are set.
    dependencies=(
        [Depends(subscription_key_check)]
        if os.getenv("KUBERNETES_SERVICE_HOST")
        else None
    ),
)
def health_check():
    """Signal that the API is healthy by returning an empty 200 response."""
    return Response(status_code=status.HTTP_200_OK)
|