2024-06-26 15:45:06 -04:00
|
|
|
# Copyright (c) Microsoft Corporation.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import os
|
2024-09-12 21:41:46 -04:00
|
|
|
import traceback
|
2024-06-26 15:45:06 -04:00
|
|
|
|
2025-01-02 23:19:28 -05:00
|
|
|
import graphrag.api as api
|
2024-06-26 15:45:06 -04:00
|
|
|
import yaml
|
|
|
|
from fastapi import (
|
|
|
|
APIRouter,
|
|
|
|
HTTPException,
|
|
|
|
)
|
2025-01-02 23:19:28 -05:00
|
|
|
from graphrag.config.create_graphrag_config import create_graphrag_config
|
2024-06-26 15:45:06 -04:00
|
|
|
|
2024-12-30 01:59:08 -05:00
|
|
|
from src.api.azure_clients import AzureClientManager
|
|
|
|
from src.logger import LoggerSingleton
|
2025-01-21 00:29:48 -05:00
|
|
|
from src.utils.common import sanitize_name
|
2024-06-26 15:45:06 -04:00
|
|
|
|
2025-01-21 00:29:48 -05:00
|
|
|
# Router for the prompt-tuning endpoints: every route registered on it is
# served under the "/index/config" prefix and grouped under the
# "Index Configuration" tag in the generated API docs.
prompt_tuning_route = APIRouter(
    prefix="/index/config",
    tags=["Index Configuration"],
)
|
2024-06-26 15:45:06 -04:00
|
|
|
|
|
|
|
|
2025-01-21 00:29:48 -05:00
|
|
|
@prompt_tuning_route.get(
    "/prompts",
    summary="Generate prompts from user-provided data.",
    description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
)
async def generate_prompts(storage_name: str, limit: int = 5):
    """
    Automatically generate custom prompts for entity extraction,
    community reports, and summarize descriptions based on a sample of provided data.

    Args:
        storage_name: Name of the user's data storage container. It is passed
            through ``sanitize_name`` before being looked up in blob storage
            and written into the pipeline settings.
        limit: Sample size forwarded as ``limit`` to
            ``api.generate_indexing_prompts``. Must be at least 1.

    Returns:
        A dict with the keys ``entity_extraction_prompt``,
        ``entity_summarization_prompt`` and ``community_summarization_prompt``
        (items 0, 1 and 2 of the tuple returned by
        ``api.generate_indexing_prompts``). FastAPI serializes it to JSON.

    Raises:
        HTTPException: 400 if ``limit`` is less than 1; 500 if the storage
            container does not exist or prompt generation fails.
    """
    # A non-positive sample size cannot yield prompts. Reject it here with a
    # clear client error instead of letting it surface as the generic 500
    # "Please try a lower limit" error raised below.
    if limit < 1:
        raise HTTPException(
            status_code=400,
            detail="limit must be a positive integer.",
        )

    # check for storage container existence
    azure_client_manager = AzureClientManager()
    blob_service_client = azure_client_manager.get_blob_service_client()
    sanitized_storage_name = sanitize_name(storage_name)
    if not blob_service_client.get_container_client(sanitized_storage_name).exists():
        # NOTE(review): a missing container is arguably a client error (4xx);
        # 500 is kept so existing clients keep seeing the same status code.
        raise HTTPException(
            status_code=500,
            detail=f"Data container '{storage_name}' does not exist.",
        )

    # Load the pipeline configuration (pipeline-settings.yaml, located next to
    # this module) and point its input at the user's container. The `with`
    # block guarantees the file handle is closed.
    this_directory = os.path.dirname(os.path.abspath(__file__))
    settings_path = os.path.join(this_directory, "pipeline-settings.yaml")
    with open(settings_path, encoding="utf-8") as settings_file:
        data = yaml.safe_load(settings_file)
    data["input"]["container_name"] = sanitized_storage_name
    graphrag_config = create_graphrag_config(values=data, root_dir=".")

    # generate prompts from a random sample of the input data
    try:
        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
            config=graphrag_config,
            root=".",
            limit=limit,
            selection_method="random",
        )
    except Exception as e:
        logger = LoggerSingleton().get_instance()
        error_details = {
            "storage_name": storage_name,
        }
        logger.error(
            message="Auto-prompt generation failed.",
            cause=e,
            stack=traceback.format_exc(),
            details=error_details,
        )
        # Chain the original exception so the root cause stays attached.
        raise HTTPException(
            status_code=500,
            detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
        ) from e

    content = {
        "entity_extraction_prompt": prompts[0],
        "entity_summarization_prompt": prompts[1],
        "community_summarization_prompt": prompts[2],
    }
    # FastAPI serializes the returned dict into a JSON response body.
    return content
|