82 lines
2.9 KiB
Python
Raw Normal View History

2024-06-26 15:45:06 -04:00
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import os
2024-09-12 21:41:46 -04:00
import traceback
2024-06-26 15:45:06 -04:00
import graphrag.api as api
2024-06-26 15:45:06 -04:00
import yaml
from fastapi import (
APIRouter,
HTTPException,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
2024-06-26 15:45:06 -04:00
2025-01-25 04:07:53 -05:00
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.utils.azure_clients import AzureClientManager
from graphrag_app.utils.common import sanitize_name
2024-06-26 15:45:06 -04:00
# Router for the index-configuration endpoints; it is created with the
# "/index/config" URL prefix and the "Index Configuration" tag.
prompt_tuning_route = APIRouter(prefix="/index/config", tags=["Index Configuration"])
2024-06-26 15:45:06 -04:00
@prompt_tuning_route.get(
    "/prompts",
    summary="Generate prompts from user-provided data.",
    description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
)
async def generate_prompts(storage_name: str, limit: int = 5):
    """
    Automatically generate custom prompts for entity extraction,
    community reports, and summarize descriptions based on a sample of provided data.

    Args:
        storage_name: Name of the blob storage container holding the input data.
            It is passed through ``sanitize_name`` before the container lookup.
        limit: Forwarded unchanged as the ``limit`` argument of
            ``api.generate_indexing_prompts``.

    Returns:
        A dict with the keys ``entity_extraction_prompt``,
        ``entity_summarization_prompt`` and ``community_summarization_prompt``,
        taken in that order from the tuple returned by
        ``api.generate_indexing_prompts``.

    Raises:
        HTTPException: With status code 500 when the storage container does not
            exist, or when prompt generation raises any exception.
    """
    # check for storage container existence
    azure_client_manager = AzureClientManager()
    blob_service_client = azure_client_manager.get_blob_service_client()
    sanitized_storage_name = sanitize_name(storage_name)
    if not blob_service_client.get_container_client(sanitized_storage_name).exists():
        # NOTE(review): a missing container is reported as 500 (not 404); kept
        # as-is because existing clients may depend on this status code.
        raise HTTPException(
            status_code=500,
            detail=f"Data container '{storage_name}' does not exist.",
        )

    # load pipeline configuration file (settings.yaml) for input data and other settings
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
    settings_path = os.path.join(this_directory, "..", "indexer", "settings.yaml")
    # Use a context manager so the file handle is closed deterministically.
    with open(settings_path) as settings_file:
        data = yaml.safe_load(settings_file)
    # Point the indexing input at the caller's (sanitized) container.
    data["input"]["container_name"] = sanitized_storage_name
    graphrag_config = create_graphrag_config(values=data, root_dir=".")

    # generate prompts
    try:
        # NOTE: we need to call api.generate_indexing_prompts
        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
            config=graphrag_config,
            root=".",
            limit=limit,
            selection_method="random",
        )
    except Exception as e:
        logger = load_pipeline_logger()
        error_details = {
            "storage_name": storage_name,
        }
        logger.error(
            message="Auto-prompt generation failed.",
            cause=e,
            stack=traceback.format_exc(),
            details=error_details,
        )
        # Chain the original exception so the root cause stays attached to the
        # HTTP error in tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
        ) from e

    content = {
        "entity_extraction_prompt": prompts[0],
        "entity_summarization_prompt": prompts[1],
        "community_summarization_prompt": prompts[2],
    }
    # A plain dict is returned; FastAPI serializes it into the JSON response.
    return content