# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
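"""FastAPI routes for looking up individual artifacts of a graphrag index
(community reports, text units, entities, claims, and relationships) stored
as parquet files in Azure Blob Storage."""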
import os

import pandas as pd
from azure.identity import DefaultAzureCredential
from fastapi import APIRouter, Depends, HTTPException

from src.api.common import (
    sanitize_name,
    validate_index_file_exist,
    verify_subscription_key_exist,
)
from src.models import (
    ClaimResponse,
    EntityResponse,
    RelationshipResponse,
    ReportResponse,
    TextUnitResponse,
)
from src.reporting import ReporterSingleton

source_route = APIRouter(
    prefix="/source",
    tags=["Sources"],
)
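
# when running inside Kubernetes (the KUBERNETES_SERVICE_HOST env var is set by
# the cluster), require a valid subscription key on every request to these routes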
if os.getenv("KUBERNETES_SERVICE_HOST"):
    source_route.dependencies.append(Depends(verify_subscription_key_exist))
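
# parquet artifacts written by the graphrag indexing pipeline; the routes below
# check for these files to verify that an index is complete before querying it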
COMMUNITY_REPORT_TABLE = "output/create_final_community_reports.parquet"
COVARIATES_TABLE = "output/create_final_covariates.parquet"
ENTITY_EMBEDDING_TABLE = "output/create_final_entities.parquet"
RELATIONSHIPS_TABLE = "output/create_final_relationships.parquet"
TEXT_UNITS_TABLE = "output/create_base_text_units.parquet"
DOCUMENTS_TABLE = "output/create_base_documents.parquet"
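
# derive the account name and host from the blob endpoint URL,
# e.g. https://<account_name>.blob.core.windows.net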
storage_account_blob_url = os.environ["STORAGE_ACCOUNT_BLOB_URL"]
storage_account_name = storage_account_blob_url.split("//")[1].split(".")[0]
storage_account_host = storage_account_blob_url.split("//")[1]
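# pandas forwards these options to the underlying fsspec/adlfs filesystem
# when reading the abfs:// paths below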
storage_options = {
    "account_name": storage_account_name,
    "account_host": storage_account_host,
    "credential": DefaultAzureCredential(),
}


@source_route.get(
    "/report/{index_name}/{report_id}",
    summary="Return a single community report.",
    response_model=ReportResponse,
    responses={200: {"model": ReportResponse}},
)
async def get_report_info(index_name: str, report_id: str):
    # check that the file this query relies on exists, to validate that the index is complete
    sanitized_index_name = sanitize_name(index_name)
    validate_index_file_exist(sanitized_index_name, COMMUNITY_REPORT_TABLE)
    try:
        report_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}",
            storage_options=storage_options,
        )
        row = report_table[report_table.community == report_id]
        return ReportResponse(text=row["full_content"].values[0])
    except Exception:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error("Could not get report.")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving report '{report_id}' from index '{index_name}'.",
        )


@source_route.get(
    "/text/{index_name}/{text_unit_id}",
    summary="Return a single base text unit.",
    response_model=TextUnitResponse,
    responses={200: {"model": TextUnitResponse}},
)
async def get_chunk_info(index_name: str, text_unit_id: str):
    # check that the files this query relies on exist, to validate that the index is complete
    sanitized_index_name = sanitize_name(index_name)
    validate_index_file_exist(sanitized_index_name, TEXT_UNITS_TABLE)
    validate_index_file_exist(sanitized_index_name, DOCUMENTS_TABLE)
    try:
        text_unit_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{TEXT_UNITS_TABLE}",
            storage_options=storage_options,
        )
        docs = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{DOCUMENTS_TABLE}",
            storage_options=storage_options,
        )
        # map document id -> title so each text unit can be labeled with its source document
        links = {
            el["id"]: el["title"]
            for el in docs[["id", "title"]].to_dict(orient="records")
        }
        text_unit_table["source_doc"] = text_unit_table["document_ids"].apply(
            lambda x: links[x[0]]
        )
        row = text_unit_table[text_unit_table.chunk_id == text_unit_id][
            ["chunk", "source_doc"]
        ]
        return TextUnitResponse(
            text=row["chunk"].values[0], source_document=row["source_doc"].values[0]
        )
    except Exception:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error("Could not get text chunk.")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving text chunk '{text_unit_id}' from index '{index_name}'.",
        )


@source_route.get(
    "/entity/{index_name}/{entity_id}",
    summary="Return a single entity.",
    response_model=EntityResponse,
    responses={200: {"model": EntityResponse}},
)
async def get_entity_info(index_name: str, entity_id: int):
    # check that the file this query relies on exists, to validate that the index is complete
    sanitized_index_name = sanitize_name(index_name)
    validate_index_file_exist(sanitized_index_name, ENTITY_EMBEDDING_TABLE)
    try:
        entity_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
            storage_options=storage_options,
        )
        row = entity_table[entity_table.human_readable_id == entity_id]
        return EntityResponse(
            name=row["name"].values[0],
            description=row["description"].values[0],
            text_units=row["text_unit_ids"].values[0].tolist(),
        )
    except Exception:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error("Could not get entity.")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving entity '{entity_id}' from index '{index_name}'.",
        )


@source_route.get(
    "/claim/{index_name}/{claim_id}",
    summary="Return a single claim.",
    response_model=ClaimResponse,
    responses={200: {"model": ClaimResponse}},
)
async def get_claim_info(index_name: str, claim_id: int):
    # check that the file this query relies on exists, to validate that the index is complete
    # (claim extraction is optional in graphrag, so this table may be missing)
    sanitized_index_name = sanitize_name(index_name)
    try:
        validate_index_file_exist(sanitized_index_name, COVARIATES_TABLE)
    except ValueError:
        raise HTTPException(
            status_code=500,
            detail=f"Claim data unavailable for index '{index_name}'.",
        )
    try:
        claims_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{COVARIATES_TABLE}",
            storage_options=storage_options,
        )
        # cast via float first, since the stored ids may not be plain integers
        claims_table.human_readable_id = claims_table.human_readable_id.astype(
            float
        ).astype(int)
        row = claims_table[claims_table.human_readable_id == claim_id]
        return ClaimResponse(
            covariate_type=row["covariate_type"].values[0],
            type=row["type"].values[0],
            description=row["description"].values[0],
            subject_id=row["subject_id"].values[0],
            object_id=row["object_id"].values[0],
            source_text=row["source_text"].values[0],
            text_unit_id=row["text_unit_id"].values[0],
            document_ids=row["document_ids"].values[0].tolist(),
        )
    except Exception:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error("Could not get claim.")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving claim '{claim_id}' from index '{index_name}'.",
        )


@source_route.get(
    "/relationship/{index_name}/{relationship_id}",
    summary="Return a single relationship.",
    response_model=RelationshipResponse,
    responses={200: {"model": RelationshipResponse}},
)
async def get_relationship_info(index_name: str, relationship_id: int):
    # check that the files this query relies on exist, to validate that the index is complete
    sanitized_index_name = sanitize_name(index_name)
    validate_index_file_exist(sanitized_index_name, RELATIONSHIPS_TABLE)
    validate_index_file_exist(sanitized_index_name, ENTITY_EMBEDDING_TABLE)
    try:
        relationship_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{RELATIONSHIPS_TABLE}",
            storage_options=storage_options,
        )
        # the entity table is needed to resolve source/target entity names to their ids
        entity_table = pd.read_parquet(
            f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
            storage_options=storage_options,
        )
        row = relationship_table[
            relationship_table.human_readable_id == str(relationship_id)
        ]
        return RelationshipResponse(
            source=row["source"].values[0],
            source_id=entity_table[
                entity_table.name == row["source"].values[0]
            ].human_readable_id.values[0],
            target=row["target"].values[0],
            target_id=entity_table[
                entity_table.name == row["target"].values[0]
            ].human_readable_id.values[0],
            description=row["description"].values[0],
            text_units=[
                x[0] for x in row["text_unit_ids"].to_list()
            ],  # extract text_unit_ids from a list of pandas Series
        )
    except Exception:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error("Could not get relationship.")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving relationship '{relationship_id}' from index '{index_name}'.",
        )