| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | # Copyright (c) Microsoft Corporation. | 
					
						
							|  |  |  | # Licensed under the MIT License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-07-09 22:58:25 -04:00
										 |  |  | import traceback | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  | from contextlib import asynccontextmanager | 
					
						
							| 
									
										
										
										
											2025-01-25 04:07:53 -05:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  | import yaml | 
					
						
							| 
									
										
										
										
											2025-01-23 00:23:58 -05:00
										 |  |  | from azure.cosmos import PartitionKey, ThroughputProperties | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | from fastapi import ( | 
					
						
							|  |  |  |     FastAPI, | 
					
						
							|  |  |  |     Request, | 
					
						
							|  |  |  |     status, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | from fastapi.middleware.cors import CORSMiddleware | 
					
						
							|  |  |  | from fastapi.responses import Response | 
					
						
							| 
									
										
										
										
											2024-12-30 00:59:35 -05:00
										 |  |  | from fastapi_offline import FastAPIOffline | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  | from kubernetes import ( | 
					
						
							|  |  |  |     client, | 
					
						
							|  |  |  |     config, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-25 04:07:53 -05:00
										 |  |  | from graphrag_app.api.data import data_route | 
					
						
							|  |  |  | from graphrag_app.api.graph import graph_route | 
					
						
							|  |  |  | from graphrag_app.api.index import index_route | 
					
						
							|  |  |  | from graphrag_app.api.prompt_tuning import prompt_tuning_route | 
					
						
							|  |  |  | from graphrag_app.api.query import query_route | 
					
						
							|  |  |  | from graphrag_app.api.query_streaming import query_streaming_route | 
					
						
							|  |  |  | from graphrag_app.api.source import source_route | 
					
						
							|  |  |  | from graphrag_app.logger.load_logger import load_pipeline_logger | 
					
						
							|  |  |  | from graphrag_app.utils.azure_clients import AzureClientManager | 
					
						
							| 
									
										
										
										
											2024-07-15 16:42:22 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def catch_all_exceptions_middleware(request: Request, call_next): | 
					
						
							|  |  |  |     """a function to globally catch all exceptions and return a 500 response with the exception message""" | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await call_next(request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-01-21 18:43:55 -05:00
										 |  |  |         reporter = load_pipeline_logger() | 
					
						
							| 
									
										
										
										
											2024-12-30 01:59:08 -05:00
										 |  |  |         stack = traceback.format_exc() | 
					
						
							| 
									
										
										
										
											2025-01-21 00:29:48 -05:00
										 |  |  |         reporter.error( | 
					
						
							| 
									
										
										
										
											2024-07-15 16:42:22 -07:00
										 |  |  |             message="Unexpected internal server error", | 
					
						
							|  |  |  |             cause=e, | 
					
						
							| 
									
										
										
										
											2024-12-30 01:59:08 -05:00
										 |  |  |             stack=stack, | 
					
						
							| 
									
										
										
										
											2024-07-15 16:42:22 -07:00
										 |  |  |         ) | 
					
						
							|  |  |  |         return Response("Unexpected internal server error.", status_code=500) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-30 01:59:08 -05:00
										 |  |  | def intialize_cosmosdb_setup(): | 
					
						
							|  |  |  |     """Initialise CosmosDB (if necessary) by setting up a database and containers that are expected at startup time.""" | 
					
						
							|  |  |  |     azure_client_manager = AzureClientManager() | 
					
						
							|  |  |  |     client = azure_client_manager.get_cosmos_client() | 
					
						
							| 
									
										
										
										
											2025-01-23 00:23:58 -05:00
										 |  |  |     db_client = client.create_database_if_not_exists("graphrag") | 
					
						
							|  |  |  |     # create containers with default settings | 
					
						
							|  |  |  |     throughput = ThroughputProperties( | 
					
						
							|  |  |  |         auto_scale_max_throughput=1000, auto_scale_increment_percent=1 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     db_client.create_container_if_not_exists( | 
					
						
							|  |  |  |         id="jobs", partition_key=PartitionKey(path="/id"), offer_throughput=throughput | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     db_client.create_container_if_not_exists( | 
					
						
							|  |  |  |         id="container-store", | 
					
						
							|  |  |  |         partition_key=PartitionKey(path="/id"), | 
					
						
							|  |  |  |         offer_throughput=throughput, | 
					
						
							| 
									
										
										
										
											2024-12-30 01:59:08 -05:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  | @asynccontextmanager | 
					
						
							|  |  |  | async def lifespan(app: FastAPI): | 
					
						
							| 
									
										
										
										
											2024-12-30 01:59:08 -05:00
										 |  |  |     """Deploy a cronjob to manage indexing jobs.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     This function is called when the FastAPI application first starts up. | 
					
						
							|  |  |  |     To manage multiple graphrag indexing jobs, we deploy a k8s cronjob. | 
					
						
							|  |  |  |     This cronjob will act as a job manager that creates/manages the execution of graphrag indexing jobs as they are requested. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     # if running in a TESTING environment, exit early to avoid creating k8s resources | 
					
						
							|  |  |  |     if os.getenv("TESTING"): | 
					
						
							|  |  |  |         yield | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Initialize CosmosDB setup | 
					
						
							|  |  |  |     intialize_cosmosdb_setup() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |     try: | 
					
						
							|  |  |  |         # Check if the cronjob exists and create it if it does not exist | 
					
						
							|  |  |  |         config.load_incluster_config() | 
					
						
							|  |  |  |         # retrieve the running pod spec | 
					
						
							|  |  |  |         core_v1 = client.CoreV1Api() | 
					
						
							|  |  |  |         pod_name = os.environ["HOSTNAME"] | 
					
						
							|  |  |  |         pod = core_v1.read_namespaced_pod( | 
					
						
							|  |  |  |             name=pod_name, namespace=os.environ["AKS_NAMESPACE"] | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-01-26 03:11:51 -05:00
										 |  |  |         # load the k8s cronjob template and update PLACEHOLDER values with correct values based on the running pod spec | 
					
						
							|  |  |  |         ROOT_DIR = Path(__file__).resolve().parent.parent | 
					
						
							|  |  |  |         with (ROOT_DIR / "manifests/cronjob.yaml").open("r") as f: | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |             manifest = yaml.safe_load(f) | 
					
						
							|  |  |  |         manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]["containers"][0][ | 
					
						
							|  |  |  |             "image" | 
					
						
							|  |  |  |         ] = pod.spec.containers[0].image | 
					
						
							|  |  |  |         manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"][ | 
					
						
							|  |  |  |             "serviceAccountName" | 
					
						
							|  |  |  |         ] = pod.spec.service_account_name | 
					
						
							|  |  |  |         # retrieve list of existing cronjobs | 
					
						
							|  |  |  |         batch_v1 = client.BatchV1Api() | 
					
						
							| 
									
										
										
										
											2024-09-19 01:09:26 -04:00
										 |  |  |         namespace_cronjobs = batch_v1.list_namespaced_cron_job( | 
					
						
							|  |  |  |             namespace=os.environ["AKS_NAMESPACE"] | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |         cronjob_names = [cronjob.metadata.name for cronjob in namespace_cronjobs.items] | 
					
						
							|  |  |  |         # create cronjob if it does not exist | 
					
						
							|  |  |  |         if manifest["metadata"]["name"] not in cronjob_names: | 
					
						
							| 
									
										
										
										
											2024-09-19 01:09:26 -04:00
										 |  |  |             batch_v1.create_namespaced_cron_job( | 
					
						
							|  |  |  |                 namespace=os.environ["AKS_NAMESPACE"], body=manifest | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-09-12 21:41:46 -04:00
										 |  |  |         print("Failed to create graphrag cronjob.") | 
					
						
							| 
									
										
										
										
											2025-01-21 18:43:55 -05:00
										 |  |  |         logger = load_pipeline_logger() | 
					
						
							| 
									
										
										
										
											2025-01-21 00:29:48 -05:00
										 |  |  |         logger.error( | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |             message="Failed to create graphrag cronjob", | 
					
						
							|  |  |  |             cause=str(e), | 
					
						
							|  |  |  |             stack=traceback.format_exc(), | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     yield  # This is where the application starts up. | 
					
						
							|  |  |  |     # shutdown/garbage collection code goes here | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-30 00:59:35 -05:00
										 |  |  | app = FastAPIOffline( | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  |     docs_url="/manpage/docs", | 
					
						
							|  |  |  |     openapi_url="/manpage/openapi.json", | 
					
						
							| 
									
										
										
										
											2024-12-30 00:59:35 -05:00
										 |  |  |     root_path=os.getenv("API_ROOT_PATH", ""), | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  |     title="GraphRAG", | 
					
						
							| 
									
										
										
										
											2024-08-09 22:22:49 -04:00
										 |  |  |     version=os.getenv("GRAPHRAG_VERSION", "undefined_version"), | 
					
						
							|  |  |  |     lifespan=lifespan, | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | ) | 
					
						
							|  |  |  | app.middleware("http")(catch_all_exceptions_middleware) | 
					
						
							|  |  |  | app.add_middleware( | 
					
						
							|  |  |  |     CORSMiddleware, | 
					
						
							|  |  |  |     allow_origins=["*"], | 
					
						
							|  |  |  |     allow_credentials=True, | 
					
						
							|  |  |  |     allow_methods=["*"], | 
					
						
							|  |  |  |     allow_headers=["*"], | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | app.include_router(data_route) | 
					
						
							|  |  |  | app.include_router(index_route) | 
					
						
							|  |  |  | app.include_router(query_route) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:41:46 -04:00
										 |  |  | app.include_router(query_streaming_route) | 
					
						
							| 
									
										
										
										
											2025-01-21 00:29:48 -05:00
										 |  |  | app.include_router(prompt_tuning_route) | 
					
						
							| 
									
										
										
										
											2024-06-26 15:45:06 -04:00
										 |  |  | app.include_router(source_route) | 
					
						
							|  |  |  | app.include_router(graph_route) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # health check endpoint | 
					
						
							|  |  |  | @app.get( | 
					
						
							|  |  |  |     "/health", | 
					
						
							|  |  |  |     summary="API health check", | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | def health_check(): | 
					
						
							|  |  |  |     """Returns a 200 response to indicate the API is healthy.""" | 
					
						
							|  |  |  |     return Response(status_code=status.HTTP_200_OK) |