Mirror of https://github.com/Azure-Samples/graphrag-accelerator.git (synced 2025-06-27 04:39:57 +00:00)
Add indexing job manager (#133)
Parent: a0b0629e4c · Commit: 8118923232
backend/indexing-job-manager-template.yaml (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# NOTE: the location of this file is important as it gets referenced by the src/main.py script
+# and depends on the relative path to this file when uvicorn is run
+
+apiVersion: batch/v1
+kind: CronJob
+metadata:
+  name: graphrag-index-manager
+spec:
+  schedule: "*/5 * * * *"
+  jobTemplate:
+    spec:
+      ttlSecondsAfterFinished: 30
+      template:
+        metadata:
+          labels:
+            azure.workload.identity/use: "true"
+        spec:
+          serviceAccountName: PLACEHOLDER
+          restartPolicy: OnFailure
+          containers:
+            - name: index-job-manager
+              image: PLACEHOLDER
+              imagePullPolicy: Always
+              resources:
+                requests:
+                  cpu: "0.5"
+                  memory: "0.5Gi"
+                limits:
+                  cpu: "1"
+                  memory: "1Gi"
+              envFrom:
+                - configMapRef:
+                    name: graphrag
+              command:
+                - python
+                - "manage-indexing-jobs.py"
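
The two PLACEHOLDER values above are not meant to be deployed by hand: the backend patches them from its own pod spec at startup (see the src/main.py diff further down). A minimal sketch of that substitution, assuming PyYAML and illustrative image/service-account values:

    import yaml

    # load the CronJob template and fill in the PLACEHOLDER fields
    with open("indexing-job-manager-template.yaml") as f:
        manifest = yaml.safe_load(f)

    pod_spec = manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]
    pod_spec["serviceAccountName"] = "graphrag-sa"                      # assumed value
    pod_spec["containers"][0]["image"] = "myacr.azurecr.io/graphrag:1"  # assumed value
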
@@ -1,15 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-# NOTE: the location of this file is important, as it is referenced by the api/index.py script and depends on the relative path to this file when uvicorn is run
-# To account for periods of time where an AOAI endpoint may have be getting hammered with too much work and rate-limiting will cause indexing jobs to fail, we set the backoffLimit to a high number (meaning the job will be retried 30 times before it is considered a failure) with exponential backoff
+# NOTE: the location of this file is important as it gets referenced by the manage-indexing-jobs.py script
+# and depends on the relative path to this file when uvicorn is run
+
 apiVersion: batch/v1
 kind: Job
 metadata:
   name: PLACEHOLDER
 spec:
-  ttlSecondsAfterFinished: 0
-  backoffLimit: 6
+  ttlSecondsAfterFinished: 30
+  backoffLimit: 3
   template:
     metadata:
       labels:
@@ -23,10 +24,10 @@ spec:
       imagePullPolicy: Always
       resources:
         requests:
-          cpu: "4"
+          cpu: "6"
           memory: "24Gi"
         limits:
-          cpu: "8"
+          cpu: "10"
           memory: "32Gi"
       envFrom:
         - configMapRef:
backend/manage-indexing-jobs.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+A naive implementation of a job manager that leverages k8s CronJob and CosmosDB
+to schedule graphrag indexing jobs in a first-come-first-serve manner (based on epoch time).
+"""
+
+import os
+
+import pandas as pd
+import yaml
+from kubernetes import (
+    client,
+    config,
+)
+from src.api.azure_clients import AzureStorageClientManager
+from src.api.common import sanitize_name
+from src.models import PipelineJob
+from src.reporting.reporter_singleton import ReporterSingleton
+from src.typing.pipeline import PipelineJobState
+
+
+def schedule_indexing_job(index_name: str):
+    """
+    Schedule a k8s job to run graphrag indexing for a given index name.
+    """
+    try:
+        config.load_incluster_config()
+        # get container image name
+        core_v1 = client.CoreV1Api()
+        pod_name = os.environ["HOSTNAME"]
+        pod = core_v1.read_namespaced_pod(
+            name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
+        )
+        # retrieve job manifest template and replace necessary values
+        job_manifest = _generate_aks_job_manifest(
+            docker_image_name=pod.spec.containers[0].image,
+            index_name=index_name,
+            service_account_name=pod.spec.service_account_name,
+        )
+        batch_v1 = client.BatchV1Api()
+        batch_v1.create_namespaced_job(
+            body=job_manifest, namespace=os.environ["AKS_NAMESPACE"]
+        )
+    except Exception:
+        reporter = ReporterSingleton().get_instance()
+        reporter.on_error(
+            "Index job manager encountered error scheduling indexing job",
+        )
+        # In the event of a catastrophic scheduling failure, something in k8s or the job manifest is likely broken.
+        # Set job status to failed to prevent an infinite loop of re-scheduling
+        pipelinejob = PipelineJob()
+        pipeline_job = pipelinejob.load_item(sanitize_name(index_name))
+        pipeline_job["status"] = PipelineJobState.FAILED
+
+
+def _generate_aks_job_manifest(
+    docker_image_name: str,
+    index_name: str,
+    service_account_name: str,
+) -> dict:
+    """Generate an AKS Jobs manifest file with the specified parameters.
+
+    The manifest must be valid YAML with certain values replaced by the provided arguments.
+    """
+    # NOTE: this file location is relative to the WORKDIR set in Dockerfile-backend
+    with open("indexing-job-template.yaml", "r") as f:
+        manifest = yaml.safe_load(f)
+    manifest["metadata"]["name"] = f"indexing-job-{sanitize_name(index_name)}"
+    manifest["spec"]["template"]["spec"]["serviceAccountName"] = service_account_name
+    manifest["spec"]["template"]["spec"]["containers"][0]["image"] = docker_image_name
+    manifest["spec"]["template"]["spec"]["containers"][0]["command"] = [
+        "python",
+        "run-indexing-job.py",
+        f"-i={index_name}",
+    ]
+    return manifest
+
+
+def main():
+    azure_storage_client_manager = AzureStorageClientManager()
+    job_container_store_client = (
+        azure_storage_client_manager.get_cosmos_container_client(
+            database_name="graphrag", container_name="jobs"
+        )
+    )
+    # retrieve status for all jobs that are either scheduled or running
+    job_metadata = []
+    for item in job_container_store_client.read_all_items():
+        # exit if a job is running
+        if item["status"] == PipelineJobState.RUNNING.value:
+            print(
+                f"Indexing job for '{item['human_readable_index_name']}' already running. Will not schedule another. Exiting..."
+            )
+            exit()
+        if item["status"] == PipelineJobState.SCHEDULED.value:
+            job_metadata.append(
+                {
+                    "human_readable_index_name": item["human_readable_index_name"],
+                    "epoch_request_time": item["epoch_request_time"],
+                    "status": item["status"],
+                    "percent_complete": item["percent_complete"],
+                }
+            )
+    # exit if no jobs found
+    if not job_metadata:
+        print("No jobs found")
+        exit()
+    # convert to dataframe for easy processing
+    df = pd.DataFrame(job_metadata)
+    # jobs are run in the order they were requested - sort by epoch_request_time
+    df.sort_values(by="epoch_request_time", ascending=True, inplace=True)
+    index_to_schedule = df.iloc[0]["human_readable_index_name"]
+    print(f"Scheduling job for index: {index_to_schedule}")
+    schedule_indexing_job(index_to_schedule)
+
+
+if __name__ == "__main__":
+    main()
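
The pandas dependency above is used only to pick the oldest scheduled job. A standard-library sketch of the same first-come-first-serve selection (hypothetical helper, not part of this commit):

    def pick_next_index(job_metadata: list[dict]) -> str:
        # first come, first served: the smallest epoch_request_time wins
        oldest = min(job_metadata, key=lambda job: job["epoch_request_time"])
        return oldest["human_readable_index_name"]
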
@@ -5,6 +5,7 @@ import asyncio
 import inspect
 import os
 import traceback
+from time import time
 from typing import cast
 
 import yaml
@@ -25,7 +26,6 @@ from kubernetes import (
     client,
     config,
 )
-from kubernetes.client.rest import ApiException
 
 from src.api.azure_clients import (
     AzureStorageClientManager,
@@ -34,7 +34,6 @@ from src.api.azure_clients import (
 )
 from src.api.common import (
     delete_blob_container,
-    retrieve_original_blob_container_name,
     sanitize_name,
     validate_blob_container_name,
     verify_subscription_key_exist,
@@ -129,7 +128,7 @@ async def setup_indexing_pipeline(
     ):
         raise HTTPException(
            status_code=202,  # request has been accepted for processing but is not complete.
-            detail=f"an index with name {index_name} already exists and has not finished building.",
+            detail=f"Index '{index_name}' already exists and has not finished building.",
        )
     # if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
     if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
@@ -146,25 +145,28 @@ async def setup_indexing_pipeline(
         existing_job._summarize_descriptions_prompt = (
             summarize_descriptions_prompt_content
         )
+        existing_job._epoch_request_time = int(time())
         existing_job.update_db()
     else:
         pipelinejob.create_item(
             id=sanitized_index_name,
-            index_name=sanitized_index_name,
-            storage_name=sanitized_storage_name,
+            human_readable_index_name=index_name,
+            human_readable_storage_name=storage_name,
             entity_extraction_prompt=entity_extraction_prompt_content,
             community_report_prompt=community_report_prompt_content,
             summarize_descriptions_prompt=summarize_descriptions_prompt_content,
             status=PipelineJobState.SCHEDULED,
         )
-    """
-    At this point, we know:
-    1) the index name is valid
-    2) the data container exists
-    3) there is no indexing job with this name currently running or a previous job has finished
-    """
+
+    return BaseResponse(status="Indexing job scheduled")
+
+
+async def _start_indexing_pipeline(index_name: str):
+    # get sanitized name
+    sanitized_index_name = sanitize_name(index_name)
 
     # update or create new item in container-store in cosmosDB
+    _blob_service_client = BlobServiceClientSingleton().get_instance()
     if not _blob_service_client.get_container_client(sanitized_index_name).exists():
         _blob_service_client.create_container(sanitized_index_name)
     container_store_client = get_database_container_client(
@@ -178,57 +180,11 @@ async def setup_indexing_pipeline(
         }
     )
 
-    # schedule AKS job
-    try:
-        config.load_incluster_config()
-        # get container image name
-        core_v1 = client.CoreV1Api()
-        pod_name = os.environ["HOSTNAME"]
-        pod = core_v1.read_namespaced_pod(
-            name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
-        )
-        # retrieve job manifest template and replace necessary values
-        job_manifest = _generate_aks_job_manifest(
-            docker_image_name=pod.spec.containers[0].image,
-            index_name=index_name,
-            service_account_name=pod.spec.service_account_name,
-        )
-        try:
-            batch_v1 = client.BatchV1Api()
-            batch_v1.create_namespaced_job(
-                body=job_manifest, namespace=os.environ["AKS_NAMESPACE"]
-            )
-        except ApiException:
-            raise HTTPException(
-                status_code=500,
-                detail="exception when calling BatchV1Api->create_namespaced_job",
-            )
-        return BaseResponse(status="Indexing operation scheduled")
-    except Exception:
-        reporter = ReporterSingleton().get_instance()
-        job_details = {
-            "storage_name": storage_name,
-            "index_name": index_name,
-        }
-        reporter.on_error(
-            "Error creating a new index",
-            details={"job_details": job_details},
-        )
-        raise HTTPException(
-            status_code=500,
-            detail=f"Error occurred during setup of indexing job for index {index_name}",
-        )
-
-
-async def _start_indexing_pipeline(index_name: str):
-    # get sanitized name
-    sanitized_index_name = sanitize_name(index_name)
-
     reporter = ReporterSingleton().get_instance()
     pipelinejob = PipelineJob()
     pipeline_job = pipelinejob.load_item(sanitized_index_name)
-    sanitized_storage_name = pipeline_job.storage_name
-    storage_name = retrieve_original_blob_container_name(sanitized_storage_name)
+    sanitized_storage_name = pipeline_job.sanitized_storage_name
+    storage_name = pipeline_job.human_readable_index_name
 
     # download nltk dependencies
     bootstrap()
@@ -331,7 +287,7 @@ async def _start_indexing_pipeline(index_name: str):
         )
 
         workflow_callbacks.on_log(
-            message=f"Indexing pipeline complete for index {index_name}.",
+            message=f"Indexing pipeline complete for index '{index_name}'.",
             details={
                 "index": index_name,
                 "storage_name": storage_name,
@@ -353,47 +309,24 @@ async def _start_indexing_pipeline(index_name: str):
         }
         # log error in local index directory logs
         workflow_callbacks.on_error(
-            message=f"Indexing pipeline failed for index {index_name}.",
+            message=f"Indexing pipeline failed for index '{index_name}'.",
             cause=e,
             stack=traceback.format_exc(),
             details=error_details,
         )
         # log error in global index directory logs
         reporter.on_error(
-            message=f"Indexing pipeline failed for index {index_name}.",
+            message=f"Indexing pipeline failed for index '{index_name}'.",
             cause=e,
             stack=traceback.format_exc(),
             details=error_details,
         )
         raise HTTPException(
             status_code=500,
-            detail=f"Error encountered during indexing job for index {index_name}.",
+            detail=f"Error encountered during indexing job for index '{index_name}'.",
         )
 
 
-def _generate_aks_job_manifest(
-    docker_image_name: str,
-    index_name: str,
-    service_account_name: str,
-) -> dict:
-    """Generate an AKS Jobs manifest file with the specified parameters.
-
-    The manifest file must be valid YAML with certain values replaced by the provided arguments.
-    """
-    # NOTE: the relative file locations are based on the WORKDIR set in Dockerfile-indexing
-    with open("src/aks-batch-job-template.yaml", "r") as f:
-        manifest = yaml.safe_load(f)
-    manifest["metadata"]["name"] = f"indexing-job-{sanitize_name(index_name)}"
-    manifest["spec"]["template"]["spec"]["serviceAccountName"] = service_account_name
-    manifest["spec"]["template"]["spec"]["containers"][0]["image"] = docker_image_name
-    manifest["spec"]["template"]["spec"]["containers"][0]["command"] = [
-        "python",
-        "run-indexing-job.py",
-        f"-i={index_name}",
-    ]
-    return manifest
-
-
 @index_route.get(
     "",
     summary="Get all indexes",
@@ -474,7 +407,6 @@ async def delete_index(index_name: str):
     Delete a specified index.
     """
     sanitized_index_name = sanitize_name(index_name)
-    reporter = ReporterSingleton().get_instance()
     try:
         # kill indexing job if it is running
         if os.getenv("KUBERNETES_SERVICE_HOST"):  # only found if in AKS
@@ -518,6 +450,7 @@ async def delete_index(index_name: str):
             index_client.delete_index(ai_search_index_name)
 
     except Exception:
+        reporter = ReporterSingleton().get_instance()
         reporter.on_error(
             message=f"Error encountered while deleting all data for index {index_name}.",
             stack=None,
@@ -542,10 +475,8 @@ async def get_index_job_status(index_name: str):
         pipeline_job = pipelinejob.load_item(sanitized_index_name)
         return IndexStatusResponse(
             status_code=200,
-            index_name=retrieve_original_blob_container_name(pipeline_job.index_name),
-            storage_name=retrieve_original_blob_container_name(
-                pipeline_job.storage_name
-            ),
+            index_name=pipeline_job.human_readable_index_name,
+            storage_name=pipeline_job.human_readable_storage_name,
             status=pipeline_job.status.value,
             percent_complete=pipeline_job.percent_complete,
             progress=pipeline_job.progress,
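
Net effect of the changes above: the index-creation endpoint no longer touches the Kubernetes API. It records a SCHEDULED item in CosmosDB and returns at once, and the CronJob picks the item up on its next five-minute tick. A hypothetical client-side view of the new flow (endpoint paths, parameter names, and status strings are assumptions for illustration, not taken from this diff):

    import time
    import requests

    api = "https://example-apim.azure-api.net"  # assumed gateway URL

    # request an index build; the server now only records a SCHEDULED job
    requests.post(f"{api}/index", params={"index_name": "mydata", "storage_name": "mydocs"})

    # poll until the CronJob-managed worker picks the job up and finishes
    while True:
        job = requests.get(f"{api}/index/status/mydata").json()
        if job["status"] not in ("scheduled", "running"):
            break
        time.sleep(60)
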
@@ -3,7 +3,9 @@
 
 import os
 import traceback
+from contextlib import asynccontextmanager
 
+import yaml
 from fastapi import (
     Depends,
     FastAPI,
@@ -12,6 +14,10 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response
+from kubernetes import (
+    client,
+    config,
+)
 
 from src.api.common import verify_subscription_key_exist
 from src.api.data import data_route
@@ -38,13 +44,55 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
     return Response("Unexpected internal server error.", status_code=500)
 
 
-version = os.getenv("GRAPHRAG_VERSION", "undefined_version")
-
+# deploy a cronjob to manage indexing jobs
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # This function is called when the FastAPI application first starts up.
+    # To manage multiple graphrag indexing jobs, we deploy a k8s cronjob.
+    # This cronjob will act as a job manager that creates/manages the execution of graphrag indexing jobs as they are requested.
+    try:
+        # Check if the cronjob exists and create it if it does not exist
+        config.load_incluster_config()
+        # retrieve the running pod spec
+        core_v1 = client.CoreV1Api()
+        pod_name = os.environ["HOSTNAME"]
+        pod = core_v1.read_namespaced_pod(
+            name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
+        )
+        # load the cronjob manifest template and update PLACEHOLDER values with correct values using the pod spec
+        with open("indexing-job-manager-template.yaml", "r") as f:
+            manifest = yaml.safe_load(f)
+        manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]["containers"][0][
+            "image"
+        ] = pod.spec.containers[0].image
+        manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"][
+            "serviceAccountName"
+        ] = pod.spec.service_account_name
+        # retrieve list of existing cronjobs
+        batch_v1 = client.BatchV1Api()
+        namespace_cronjobs = batch_v1.list_namespaced_cron_job(namespace="graphrag")
+        cronjob_names = [cronjob.metadata.name for cronjob in namespace_cronjobs.items]
+        # create cronjob if it does not exist
+        if manifest["metadata"]["name"] not in cronjob_names:
+            batch_v1.create_namespaced_cron_job(namespace="graphrag", body=manifest)
+    except Exception as e:
+        print(f"Failed to create graphrag cronjob.\n{e}")
+        reporter = ReporterSingleton().get_instance()
+        reporter.on_error(
+            message="Failed to create graphrag cronjob",
+            cause=str(e),
+            stack=traceback.format_exc(),
+        )
+    yield  # This is where the application starts up.
+    # shutdown/garbage collection code goes here
+
+
 app = FastAPI(
     docs_url="/manpage/docs",
     openapi_url="/manpage/openapi.json",
     title="GraphRAG",
-    version=version,
+    version=os.getenv("GRAPHRAG_VERSION", "undefined_version"),
+    lifespan=lifespan,
 )
 app.middleware("http")(catch_all_exceptions_middleware)
 app.add_middleware(
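
Since create_namespaced_cron_job fails with a conflict when the CronJob already exists, the lifespan hook above lists existing CronJobs first and only creates graphrag-index-manager when it is missing, keeping pod restarts idempotent. The guard in isolation, as a minimal sketch (assumes in-cluster credentials and the hard-coded graphrag namespace used above):

    from kubernetes import client, config

    config.load_incluster_config()
    batch_v1 = client.BatchV1Api()

    existing = {
        cronjob.metadata.name
        for cronjob in batch_v1.list_namespaced_cron_job(namespace="graphrag").items
    }
    if "graphrag-index-manager" not in existing:
        pass  # safe to call batch_v1.create_namespaced_cron_job(...) here
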
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 from dataclasses import dataclass, field
+from time import time
 from typing import (
     Any,
     List,
@@ -11,6 +12,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
 from pydantic import BaseModel
 
 from src.api.azure_clients import AzureStorageClientManager
+from src.api.common import sanitize_name
 from src.typing import PipelineJobState
 
 
@@ -104,8 +106,12 @@ class TextUnitResponse(BaseModel):
 @dataclass
 class PipelineJob:
     _id: str = field(default=None, init=False)
+    _epoch_request_time: int = field(default=None, init=False)
     _index_name: str = field(default=None, init=False)
-    _storage_name: str = field(default=None, init=False)
+    _human_readable_index_name: str = field(default=None, init=False)
+    _sanitized_index_name: str = field(default=None, init=False)
+    _human_readable_storage_name: str = field(default=None, init=False)
+    _sanitized_storage_name: str = field(default=None, init=False)
     _entity_extraction_prompt: str = field(default=None, init=False)
     _community_report_prompt: str = field(default=None, init=False)
     _summarize_descriptions_prompt: str = field(default=None, init=False)
@@ -127,8 +133,8 @@ class PipelineJob:
     def create_item(
         cls,
         id: str,
-        index_name: str,
-        storage_name: str,
+        human_readable_index_name: str,
+        human_readable_storage_name: str,
         entity_extraction_prompt: str | None = None,
         community_report_prompt: str | None = None,
         summarize_descriptions_prompt: str | None = None,
@@ -160,15 +166,20 @@ class PipelineJob:
         )
 
         assert id is not None, "ID cannot be None."
-        assert index_name is not None, "index_name cannot be None."
-        assert len(index_name) > 0, "index_name cannot be empty."
-        assert storage_name is not None, "storage_name cannot be None."
-        assert len(storage_name) > 0, "storage_name cannot be empty."
+        assert human_readable_index_name is not None, "index_name cannot be None."
+        assert len(human_readable_index_name) > 0, "index_name cannot be empty."
+        assert human_readable_storage_name is not None, "storage_name cannot be None."
+        assert len(human_readable_storage_name) > 0, "storage_name cannot be empty."
 
-        instance = cls.__new__(cls, id, index_name, storage_name, **kwargs)
+        instance = cls.__new__(
+            cls, id, human_readable_index_name, human_readable_storage_name, **kwargs
+        )
         instance._id = id
-        instance._index_name = index_name
-        instance._storage_name = storage_name
+        instance._epoch_request_time = int(time())
+        instance._human_readable_index_name = human_readable_index_name
+        instance._sanitized_index_name = sanitize_name(human_readable_index_name)
+        instance._human_readable_storage_name = human_readable_storage_name
+        instance._sanitized_storage_name = sanitize_name(human_readable_storage_name)
         instance._entity_extraction_prompt = entity_extraction_prompt
         instance._community_report_prompt = community_report_prompt
         instance._summarize_descriptions_prompt = summarize_descriptions_prompt
@@ -206,8 +217,14 @@ class PipelineJob:
         )
         instance = cls.__new__(cls, **db_item)
         instance._id = db_item.get("id")
+        instance._epoch_request_time = db_item.get("epoch_request_time")
         instance._index_name = db_item.get("index_name")
-        instance._storage_name = db_item.get("storage_name")
+        instance._human_readable_index_name = db_item.get("human_readable_index_name")
+        instance._sanitized_index_name = db_item.get("sanitized_index_name")
+        instance._human_readable_storage_name = db_item.get(
+            "human_readable_storage_name"
+        )
+        instance._sanitized_storage_name = db_item.get("sanitized_storage_name")
         instance._entity_extraction_prompt = db_item.get("entity_extraction_prompt")
         instance._community_report_prompt = db_item.get("community_report_prompt")
         instance._summarize_descriptions_prompt = db_item.get(
@@ -245,8 +262,11 @@ class PipelineJob:
     def dump_model(self) -> dict:
         model = {
             "id": self._id,
-            "index_name": self._index_name,
-            "storage_name": self._storage_name,
+            "epoch_request_time": self._epoch_request_time,
+            "human_readable_index_name": self._human_readable_index_name,
+            "sanitized_index_name": self._sanitized_index_name,
+            "human_readable_storage_name": self._human_readable_storage_name,
+            "sanitized_storage_name": self._sanitized_storage_name,
             "all_workflows": self._all_workflows,
             "completed_workflows": self._completed_workflows,
             "failed_workflows": self._failed_workflows,
@@ -277,21 +297,52 @@ class PipelineJob:
             raise ValueError("ID cannot be changed once set.")
 
     @property
-    def index_name(self) -> str:
-        return self._index_name
+    def epoch_request_time(self) -> int:
+        return self._epoch_request_time
 
-    @index_name.setter
-    def index_name(self, index_name: str) -> None:
-        self._index_name = index_name
+    @epoch_request_time.setter
+    def epoch_request_time(self, epoch_request_time: int) -> None:
+        if self._epoch_request_time is not None:
+            self._epoch_request_time = epoch_request_time
+        else:
+            raise ValueError("ID cannot be changed once set.")
+
+    @property
+    def human_readable_index_name(self) -> str:
+        return self._human_readable_index_name
+
+    @human_readable_index_name.setter
+    def human_readable_index_name(self, human_readable_index_name: str) -> None:
+        self._human_readable_index_name = human_readable_index_name
         self.update_db()
 
     @property
-    def storage_name(self) -> str:
-        return self._storage_name
+    def sanitized_index_name(self) -> str:
+        return self._sanitized_index_name
 
-    @storage_name.setter
-    def storage_name(self, storage_name: str) -> None:
-        self._storage_name = storage_name
+    @sanitized_index_name.setter
+    def sanitized_index_name(self, sanitized_index_name: str) -> None:
+        self._sanitized_index_name = sanitized_index_name
+        self.update_db()
+        self._sanitized_storage_name = sanitized_storage_name
+        self.update_db()
+
+    @property
+    def human_readable_storage_name(self) -> str:
+        return self._human_readable_storage_name
+
+    @human_readable_storage_name.setter
+    def human_readable_storage_name(self, human_readable_storage_name: str) -> None:
+        self._human_readable_storage_name = human_readable_storage_name
+        self.update_db()
+
+    @property
+    def sanitized_storage_name(self) -> str:
+        return self._sanitized_storage_name
+
+    @sanitized_storage_name.setter
+    def sanitized_storage_name(self, sanitized_storage_name: str) -> None:
+        self._sanitized_storage_name = sanitized_storage_name
         self.update_db()
 
     @property
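
With these fields, the job document persisted to CosmosDB carries the request time plus both display and sanitized names. An illustrative shape of a freshly scheduled item (values invented for illustration; field names come from dump_model above, remaining workflow fields omitted):

    scheduled_job = {
        "id": "mydata",                    # create_item uses the sanitized index name as the id
        "epoch_request_time": 1718900000,  # int(time()) captured at request time
        "human_readable_index_name": "mydata",
        "sanitized_index_name": "mydata",
        "human_readable_storage_name": "mydocs",
        "sanitized_storage_name": "mydocs",
        "status": "scheduled",             # assumed value of PipelineJobState.SCHEDULED
        "percent_complete": 0,
    }
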
@@ -66,6 +66,8 @@ def load_pipeline_reporter(
                     num_workflow_steps=num_workflow_steps,
                 )
             )
+        case Reporters.CONSOLE:
+            pass
         case _:
             print(f"WARNING: unknown reporter type: {reporter}. Skipping.")
     # always register the console reporter as a fallback
@@ -58,7 +58,7 @@ param dnsLabelPrefix string = toLower('${publicIpName}-${uniqueString(resourceGr
 @description('The workspace id of the Log Analytics resource.')
 param logAnalyticsWorkspaceId string
 
-// var subnetRef = resourceId('Microsoft.Network/virtualNetworks/subnets', virtualNetworkName, subnetName)
+param restoreAPIM bool = false
 param subnetId string
 
 resource publicIp 'Microsoft.Network/publicIPAddresses@2024-01-01' = {
@@ -85,6 +85,7 @@ resource apiManagementService 'Microsoft.ApiManagement/service@2023-09-01-previe
   }
   zones: ((length(availabilityZones) == 0) ? null : availabilityZones)
   properties: {
+    restore: restoreAPIM
     publisherEmail: publisherEmail
     publisherName: publisherName
     virtualNetworkType: 'External'
|
|||||||
databaseAccountOfferType: 'Standard'
|
databaseAccountOfferType: 'Standard'
|
||||||
defaultIdentity: 'FirstPartyIdentity'
|
defaultIdentity: 'FirstPartyIdentity'
|
||||||
networkAclBypass: 'None'
|
networkAclBypass: 'None'
|
||||||
disableLocalAuth: false
|
disableLocalAuth: true
|
||||||
enablePartitionMerge: false
|
enablePartitionMerge: false
|
||||||
minimalTlsVersion: 'Tls12'
|
minimalTlsVersion: 'Tls12'
|
||||||
consistencyPolicy: {
|
consistencyPolicy: {
|
||||||
|
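
Flipping disableLocalAuth to true means key-based connections to CosmosDB are rejected; every client must authenticate with Microsoft Entra ID (workload identity inside AKS). A minimal sketch of the credential-based client, assuming the azure-identity and azure-cosmos packages and an illustrative endpoint:

    from azure.cosmos import CosmosClient
    from azure.identity import DefaultAzureCredential

    # account keys no longer work once disableLocalAuth is true
    cosmos = CosmosClient(
        url="https://example-cosmos.documents.azure.com:443/",  # assumed endpoint
        credential=DefaultAzureCredential(),
    )
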
@@ -289,6 +289,27 @@ getAksCredentials () {
     printf "Done\n"
 }
 
+checkForApimSoftDelete () {
+    printf "Checking if APIM was soft-deleted... "
+    # This is an optional step to check if an APIM instance previously existed in the
+    # resource group and is in a soft-deleted state. If so, purge it before deploying
+    # a new APIM instance to prevent conflicts with the new deployment.
+    local RESULTS=$(az apim deletedservice list -o json --query "[?contains(serviceId, 'resourceGroups/$RESOURCE_GROUP/')].{name:name, location:location}")
+    exitIfCommandFailed $? "Error checking for soft-deleted APIM instances, exiting..."
+    local apimName=$(jq -r .[0].name <<< $RESULTS)
+    local location=$(jq -r .[0].location <<< $RESULTS)
+    # jq returns "null" if a value is not found
+    if [ -z "$apimName" ] || [[ "$apimName" == "null" ]] || [ -z "$location" ] || [[ "$location" == "null" ]]; then
+        printf "Done.\n"
+        return 0
+    fi
+    if [ ! -z "$apimName" ] && [ ! -z "$location" ]; then
+        printf "\nAPIM instance found in soft-deleted state. Purging...\n"
+        az apim deletedservice purge -n $apimName --location "$location" > /dev/null
+    fi
+    printf "Done.\n"
+}
+
 deployAzureResources () {
     echo "Deploying Azure resources..."
     local SSH_PUBLICKEY=$(jq -r .publicKey <<< $SSHKEY_DETAILS)
@@ -606,6 +627,7 @@ createResourceGroupIfNotExists $LOCATION $RESOURCE_GROUP
 createSshkeyIfNotExists $RESOURCE_GROUP
 
 # Deploy Azure resources
+checkForApimSoftDelete
 deployAzureResources
 
 # Deploy the graphrag backend docker image to ACR
@@ -9,5 +9,5 @@ rules:
     resources: ["pods"]
     verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
   - apiGroups: ["batch", "extensions"]
-    resources: ["jobs"]
-    verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
+    resources: ["*"]
+    verbs: ["*"]
@@ -46,6 +46,9 @@ param aksSshRsaPublicKey string
 @description('Whether to enable private endpoints.')
 param enablePrivateEndpoints bool = true
 
+@description('Whether to restore the API Management instance.')
+param restoreAPIM bool = false
+
 param acrName string = ''
 param apimName string = ''
 param apimTier string = 'Developer'
@@ -264,6 +267,7 @@ module apim 'core/apim/apim.bicep' = {
   name: 'apim'
   params: {
     apiManagementName: !empty(apimName) ? apimName : '${abbrs.apiManagementService}${resourceBaseNameFinal}'
+    restoreAPIM: restoreAPIM
     appInsightsName: '${abbrs.insightsComponents}${resourceBaseNameFinal}'
     appInsightsPublicNetworkAccessForIngestion: enablePrivateEndpoints ? 'Disabled' : 'Enabled'
     publicIpName: '${abbrs.networkPublicIPAddresses}${resourceBaseNameFinal}'