2024-06-26 15:45:06 -04:00
|
|
|
# Copyright (c) Microsoft Corporation.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
from fastapi import (
|
|
|
|
APIRouter,
|
|
|
|
Depends,
|
|
|
|
HTTPException,
|
|
|
|
)
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from graphrag.prompt_tune.cli import fine_tune as generate_fine_tune_prompts
|
|
|
|
|
|
|
|
from src.api.azure_clients import (
|
|
|
|
AzureStorageClientManager,
|
|
|
|
BlobServiceClientSingleton,
|
|
|
|
)
|
|
|
|
from src.api.common import (
|
|
|
|
sanitize_name,
|
|
|
|
verify_subscription_key_exist,
|
|
|
|
)
|
|
|
|
from src.models import (
|
|
|
|
BaseResponse,
|
|
|
|
EntityConfiguration,
|
|
|
|
EntityNameList,
|
|
|
|
)
|
|
|
|
from src.reporting import ReporterSingleton
|
|
|
|
|
|
|
|
# Module-level manager the handlers below use to obtain Cosmos container clients.
azure_storage_client_manager = AzureStorageClientManager()

# Router grouping every endpoint of this module under the /index/config prefix.
index_configuration_route = APIRouter(
    prefix="/index/config",
    tags=["Index Configuration"],
)

# The subscription-key dependency is attached only when the
# KUBERNETES_SERVICE_HOST environment variable is set to a non-empty value.
if os.environ.get("KUBERNETES_SERVICE_HOST"):
    index_configuration_route.dependencies.append(
        Depends(verify_subscription_key_exist)
    )
|
|
|
|
|
|
|
|
# NOTE: the /entity endpoints below are currently hidden from the generated API schema (include_in_schema=False) - to be replaced by the auto-generation of prompts
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.get(
    "/entity",
    summary="Get all entity configurations",
    response_model=EntityNameList,
    responses={200: {"model": EntityNameList}, 400: {"model": EntityNameList}},
    include_in_schema=False,
)
async def get_all_entitys():
    """
    Retrieve a list of all entity configuration names.

    Reads every item of the "entities" container in the "graphrag" database
    and collects each item's "human_readable_name". Failures are reported
    through the reporter instead of being raised, so the response holds
    whatever names were collected before the error (possibly none).
    """
    names = []
    try:
        container = azure_storage_client_manager.get_cosmos_container_client(
            database_name="graphrag", container_name="entities"
        )
        # Append one name at a time so names gathered before a failure are kept.
        for record in container.read_all_items():
            names.append(record["human_readable_name"])
    except Exception:
        ReporterSingleton.get_instance().on_error(
            "Error getting all entity configurations"
        )
    return EntityNameList(entity_configuration_name=names)
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.post(
    "/entity",
    summary="Create an entity configuration",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
    include_in_schema=False,
)
async def create_entity(request: EntityConfiguration):
    """
    Create a new entity configuration and store it in the "entities" container.

    Args:
        request: The entity configuration to create (name, entity types and
            the examples associated with them).

    Returns:
        BaseResponse with status "Success" once the item has been created, or
        a BaseResponse describing the validation problem when an example has
        an empty field or an entity type has no associated example.

    Raises:
        HTTPException: 500 if a configuration with the same (sanitized) name
            already exists.
    """
    entity_container = azure_storage_client_manager.get_cosmos_container_client(
        database_name="graphrag", container_name="entities"
    )
    sanitized_entity_config_name = sanitize_name(request.entity_configuration_name)

    # Check for entity configuration existence. The "already exists" error is
    # raised from the `else` branch: raising it inside the `try` body would let
    # the broad `except Exception` below swallow it (HTTPException is an
    # Exception subclass), silently disabling the duplicate check.
    try:
        entity_container.read_item(
            item=sanitized_entity_config_name,
            partition_key=sanitized_entity_config_name,
        )
    except Exception:
        # Best effort: any failure to read the item is treated as
        # "configuration does not exist yet".
        pass
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Entity configuration name '{request.entity_configuration_name}' already exists.",
        )

    # Validate the examples and convert them to plain dicts for storage.
    example_type_strings = []
    entity_examples = []
    for example in request.entity_examples:
        if not (example.entity_types and example.text and example.output):
            return BaseResponse(status="Example contains empty parameters")
        example_type_strings.append(example.entity_types)
        entity_examples.append(
            {
                "entity_types": example.entity_types,
                "text": example.text,
                "output": example.output,
            }
        )

    # Every requested entity type must be mentioned by at least one example
    # (substring match against the concatenated example entity types).
    all_examples = "".join(example_type_strings)
    for entity in request.entity_types:
        if entity not in all_examples:
            return BaseResponse(
                status=f"Entity '{entity}' does not have an associated example."
            )

    # create entity configuration and add to database
    entity_container.create_item(
        {
            "id": sanitized_entity_config_name,
            "human_readable_name": request.entity_configuration_name,
            "entity_types": request.entity_types,
            "entity_examples": entity_examples,
        }
    )
    return BaseResponse(status="Success")
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.put(
    "/entity",
    summary="Update an existing entity configuration",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
    include_in_schema=False,
)
async def update_entity(request: EntityConfiguration):
    """
    Update the entity types and examples of an existing entity configuration.

    Args:
        request: The entity configuration carrying the name of the item to
            update together with its new entity types and examples.

    Returns:
        BaseResponse with status "Success." once the item has been replaced,
        or a BaseResponse describing the validation problem when an example
        has an empty field or an entity type has no associated example.

    Raises:
        HTTPException: 500 if the configuration cannot be read, or if writing
            the updated item back fails.
    """
    reporter = ReporterSingleton.get_instance()

    # check for entity configuration existence
    try:
        entity_container = azure_storage_client_manager.get_cosmos_container_client(
            database_name="graphrag", container_name="entities"
        )
        sanitized_config_name = sanitize_name(request.entity_configuration_name)
        existing_item = entity_container.read_item(
            item=sanitized_config_name,
            partition_key=sanitized_config_name,
        )
    except Exception:
        reporter.on_error("Error getting entity type")
        reporter.on_error(
            f"Item with entity configuration name '{request.entity_configuration_name}' not found."
        )
        raise HTTPException(
            status_code=500,
            detail=f"Entity configuration '{request.entity_configuration_name}' not found.",
        )

    # Validate the examples and convert them to plain dicts for storage.
    example_type_strings = []
    updated_examples = []
    for example in request.entity_examples:
        if not (example.entity_types and example.text and example.output):
            return BaseResponse(status="Example contains empty parameters")
        example_type_strings.append(example.entity_types)
        updated_examples.append(
            {
                "entity_types": example.entity_types,
                "text": example.text,
                "output": example.output,
            }
        )

    # Every requested entity type must be mentioned by at least one example
    # (substring match against the concatenated example entity types).
    all_examples = "".join(example_type_strings)
    for entity in request.entity_types:
        if entity not in all_examples:
            return BaseResponse(
                status=f"Entity '{entity}' does not have an example associated."
            )

    # update entity configuration and add back to database
    # The item is replaced unconditionally below, so the new values are simply
    # assigned; an "only if different" check would not change what is written.
    existing_item["entity_types"] = request.entity_types
    existing_item["entity_examples"] = updated_examples
    try:
        entity_container.replace_item(sanitized_config_name, existing_item)
    except Exception:
        # Surface the failure instead of reporting "Success." for an update
        # that never reached the database.
        reporter.on_error("Error updating entity type")
        raise HTTPException(
            status_code=500,
            detail=f"Error updating entity configuration '{request.entity_configuration_name}'.",
        )
    return BaseResponse(status="Success.")
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.get(
    "/entity/{entity_configuration_name}",
    summary="Get a specified entity configuration",
    response_model=Union[EntityConfiguration, BaseResponse],
    responses={200: {"model": EntityConfiguration}, 400: {"model": BaseResponse}},
    include_in_schema=False,
)
async def get_entity(entity_configuration_name: str):
    """
    Fetch a single entity configuration by name.

    The name is sanitized and used as both item id and partition key when
    reading from the "entities" container of the "graphrag" database.

    Raises:
        HTTPException: 500 when the item cannot be read or converted into an
            EntityConfiguration.
    """
    reporter = ReporterSingleton.get_instance()
    try:
        container = azure_storage_client_manager.get_cosmos_container_client(
            database_name="graphrag", container_name="entities"
        )
        item_id = sanitize_name(entity_configuration_name)
        stored = container.read_item(item=item_id, partition_key=item_id)
        return EntityConfiguration(
            entity_configuration_name=stored["human_readable_name"],
            entity_types=stored["entity_types"],
            entity_examples=stored["entity_examples"],
        )
    except Exception:
        # Any failure above is reported and surfaced as "not found".
        reporter.on_error("Error getting entity type")
        reporter.on_error(
            f"Item with entity configuration name '{entity_configuration_name}' not found."
        )
        raise HTTPException(
            status_code=500,
            detail=f"Entity configuration '{entity_configuration_name}' not found.",
        )
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.delete(
    "/entity/{entity_configuration_name}",
    summary="Delete a specified entity configuration",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
    include_in_schema=False,
)
async def delete_entity(entity_configuration_name: str):
    """
    Delete a single entity configuration by name.

    The name is sanitized and used as both item id and partition key when
    deleting from the "entities" container of the "graphrag" database.

    Returns:
        BaseResponse with status "Success" when the delete call completes.

    Raises:
        HTTPException: 500 when the delete fails for any reason.
    """
    reporter = ReporterSingleton.get_instance()
    try:
        container = azure_storage_client_manager.get_cosmos_container_client(
            database_name="graphrag", container_name="entities"
        )
        item_id = sanitize_name(entity_configuration_name)
        container.delete_item(item=item_id, partition_key=item_id)
        return BaseResponse(status="Success")
    except Exception:
        # Any failure above is reported and surfaced as "not found".
        reporter.on_error("Error deleting entity")
        reporter.on_error(
            f"Item with entity configuration name '{entity_configuration_name}' not found."
        )
        raise HTTPException(
            status_code=500,
            detail=f"Entity configuration '{entity_configuration_name}' not found.",
        )
|
|
|
|
|
|
|
|
|
|
|
|
@index_configuration_route.get(
    "/prompts",
    summary="Generate graphrag prompts from user-provided data.",
    description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
)
async def generate_prompts(storage_name: str, limit: int = 5):
    """
    Automatically generate custom prompts for entity extraction,
    community reports, and summarize descriptions based on a sample of provided data.

    Args:
        storage_name: Name of the blob storage container holding the data.
        limit: Value forwarded as the `limit` argument of the prompt
            generation call (defaults to 5).

    Returns:
        StreamingResponse streaming the zip archive of the generated prompts.

    Raises:
        HTTPException: 500 if the storage container does not exist or prompt
            generation fails.
    """
    # check for storage container existence
    blob_service_client = BlobServiceClientSingleton().get_instance()
    sanitized_storage_name = sanitize_name(storage_name)
    if not blob_service_client.get_container_client(sanitized_storage_name).exists():
        raise HTTPException(
            status_code=500,
            detail=f"Data container '{storage_name}' does not exist.",
        )
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )

    # write custom settings.yaml to a file and store in a temporary directory
    # Use a context manager so the settings file handle is closed promptly
    # instead of being left open until garbage collection.
    with open(f"{this_directory}/pipeline-settings.yaml") as settings_file:
        data = yaml.safe_load(settings_file)
    data["input"]["container_name"] = sanitized_storage_name
    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
    # Start from a clean directory so output of a previous run is not reused.
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir, exist_ok=True)
    with open(f"{temp_dir}/settings.yaml", "w") as f:
        yaml.dump(data, f, default_flow_style=False)

    # generate prompts
    try:
        await generate_fine_tune_prompts(
            root=temp_dir,
            domain="",
            select="random",
            limit=limit,
            skip_entity_types=True,
            output="prompts",
        )
    except Exception:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
        )

    # zip up the generated prompt files and return the zip file
    temp_archive = (
        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
    )
    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")

    def iterfile(file_path: str):
        # Yield the archive piece by piece; the file is closed once the
        # generator is exhausted or closed.
        with open(file_path, mode="rb") as file_like:
            yield from file_like

    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
|