# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
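"""
FastAPI routes for looking up the source data behind a graphrag index:
community reports, text units, entities, claims, and relationships.
"""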
import os
import traceback

import pandas as pd
from fastapi import (
APIRouter,
Depends,
HTTPException,
status,
)

from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import (
ClaimResponse,
EntityResponse,
RelationshipResponse,
ReportResponse,
TextUnitResponse,
)
from graphrag_app.utils.common import (
pandas_storage_options,
sanitize_name,
subscription_key_check,
validate_index_file_exist,
)

source_route = APIRouter(
prefix="/source",
tags=["Sources"],
)
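
# when running inside a Kubernetes cluster, require a subscription key check
# on every route in this router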
if os.getenv("KUBERNETES_SERVICE_HOST"):
source_route.dependencies.append(Depends(subscription_key_check))
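
# parquet tables produced by the graphrag indexing pipeline; each endpoint
# below reads the ones it needs from the index's storage container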
COMMUNITY_TABLE = "output/communities.parquet"
COMMUNITY_REPORT_TABLE = "output/community_reports.parquet"
COVARIATES_TABLE = "output/covariates.parquet"
ENTITIES_TABLE = "output/entities.parquet"
RELATIONSHIPS_TABLE = "output/relationships.parquet"
TEXT_UNITS_TABLE = "output/text_units.parquet"
DOCUMENTS_TABLE = "output/documents.parquet"


@source_route.get(
"/report/{container_name}/{report_id}",
summary="Return a single community report.",
response_model=ReportResponse,
responses={status.HTTP_200_OK: {"model": ReportResponse}},
)
async def get_report_info(
report_id: int,
container_name: str,
sanitized_container_name: str = Depends(sanitize_name),
):
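    """Retrieve a community report by its human-readable id."""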
    # check for the existence of the file the query relies on, to validate that the index is complete
validate_index_file_exist(sanitized_container_name, COMMUNITY_REPORT_TABLE)
try:
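        # the parquet files live in blob storage; pandas resolves the abfs://
        # URL through fsspec (adlfs) with credentials from pandas_storage_options()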
report_table = pd.read_parquet(
f"abfs://{sanitized_container_name}/{COMMUNITY_REPORT_TABLE}",
storage_options=pandas_storage_options(),
)
# check if report_id exists in the index
if not report_table["human_readable_id"].isin([report_id]).any():
raise ValueError(
f"Report '{report_id}' not found in index '{container_name}'."
)

        # check if multiple reports with the same id exist (should not happen)
if len(report_table.loc[report_table["human_readable_id"] == report_id]) > 1:
raise ValueError(
f"Multiple reports with id '{report_id}' found in index '{container_name}'."
)
report_content = report_table.loc[
report_table["human_readable_id"] == report_id, "full_content_json"
].to_numpy()[0]
return ReportResponse(text=report_content)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Could not get report.",
cause=e,
stack=traceback.format_exc(),
)
raise HTTPException(
status_code=500,
detail=f"Error retrieving report '{report_id}' from index '{container_name}'.",
        )


@source_route.get(
"/text/{container_name}/{text_unit_id}",
summary="Return a single base text unit.",
response_model=TextUnitResponse,
responses={status.HTTP_200_OK: {"model": TextUnitResponse}},
)
async def get_chunk_info(
text_unit_id: int,
container_name: str,
sanitized_container_name: str = Depends(sanitize_name),
):
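    """Retrieve a base text unit (chunk) and its source document by human-readable id."""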
    # check for the existence of the files the query relies on, to validate that the index is complete
validate_index_file_exist(sanitized_container_name, TEXT_UNITS_TABLE)
validate_index_file_exist(sanitized_container_name, DOCUMENTS_TABLE)
try:
text_units = pd.read_parquet(
f"abfs://{sanitized_container_name}/{TEXT_UNITS_TABLE}",
storage_options=pandas_storage_options(),
)
text_units_filter = text_units["human_readable_id"].isin([text_unit_id])

        # verify that text_unit_id exists in the index
if not text_units_filter.any():
raise ValueError(
f"Text unit '{text_unit_id}' not found in index '{container_name}'."
)
        # explode the 'document_ids' column so each row carries a single 'document_id'
text_units = text_units[text_units_filter].explode("document_ids")

        docs = pd.read_parquet(
f"abfs://{sanitized_container_name}/{DOCUMENTS_TABLE}",
storage_options=pandas_storage_options(),
)

        # rename columns for easy joining
        docs = docs[["id", "title", "human_readable_id"]].rename(
            columns={
                "id": "document_id",
                "title": "source_document",
                "human_readable_id": "document_human_readable_id",
            }
        )

        # combine tables to create a (chunk_id -> source_document) mapping
merged_table = text_units.merge(
docs, left_on="document_ids", right_on="document_id", how="left"
)
row = merged_table.loc[
merged_table["human_readable_id"] == text_unit_id,
[
"text",
"source_document",
"human_readable_id",
"document_human_readable_id"
]
]
return TextUnitResponse(
text_unit_id=row["human_readable_id"].to_numpy()[0],
source_document_id=row["document_human_readable_id"].to_numpy()[0],
            text=row["text"].to_numpy()[0],
            source_document=row["source_document"].to_numpy()[0],
)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Could not get text chunk.",
cause=e,
stack=traceback.format_exc(),
)
raise HTTPException(
status_code=500,
detail=f"Error retrieving text chunk '{text_unit_id}' from index '{container_name}'.",
        )


@source_route.get(
"/entity/{container_name}/{entity_id}",
summary="Return a single entity.",
response_model=EntityResponse,
responses={status.HTTP_200_OK: {"model": EntityResponse}},
)
async def get_entity_info(
entity_id: int,
container_name: str,
sanitized_container_name: str = Depends(sanitize_name),
):
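    """Retrieve an entity and the human-readable ids of the text units that mention it."""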
    # check for the existence of the file the query relies on, to validate that the index is complete
validate_index_file_exist(sanitized_container_name, ENTITIES_TABLE)
try:
entity_table = pd.read_parquet(
f"abfs://{sanitized_container_name}/{ENTITIES_TABLE}",
storage_options=pandas_storage_options(),
)
text_units = pd.read_parquet(
f"abfs://{sanitized_container_name}/{TEXT_UNITS_TABLE}",
storage_options=pandas_storage_options(),
)

        # check if entity_id exists in the index
if not entity_table["human_readable_id"].isin([entity_id]).any():
raise ValueError(
f"Entity '{entity_id}' not found in index '{container_name}'."
)

        row = entity_table[entity_table["human_readable_id"] == entity_id]
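        # map the entity's internal text_unit_ids to their human-readable ids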
text_unit_human_readable_ids = text_units[
text_units["id"].isin(row["text_unit_ids"].to_numpy()[0].tolist())
]["human_readable_id"].to_list()

        return EntityResponse(
name=row["title"].to_numpy()[0],
            type=row["type"].to_numpy()[0],
            description=row["description"].to_numpy()[0],
            text_units=text_unit_human_readable_ids,
)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
            message="Could not get entity.",
cause=e,
stack=traceback.format_exc(),
)
raise HTTPException(
status_code=500,
detail=f"Error retrieving entity '{entity_id}' from index '{container_name}'.",
        )


@source_route.get(
"/claim/{container_name}/{claim_id}",
summary="Return a single claim.",
response_model=ClaimResponse,
responses={status.HTTP_200_OK: {"model": ClaimResponse}},
)
async def get_claim_info(
claim_id: int,
container_name: str,
sanitized_container_name: str = Depends(sanitize_name),
):
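    """Retrieve a claim (covariate) by its human-readable id."""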
    # check for the existence of the file the query relies on, to validate that the index is complete
    # note: claims are an optional feature of graphrag, so the covariates table may not exist
try:
validate_index_file_exist(sanitized_container_name, COVARIATES_TABLE)
except ValueError:
raise HTTPException(
status_code=500,
detail=f"Claim data unavailable for index '{container_name}'.",
)
try:
claims_table = pd.read_parquet(
f"abfs://{sanitized_container_name}/{COVARIATES_TABLE}",
storage_options=pandas_storage_options(),
)
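        # human_readable_id may be stored as a float-like value; normalize it to
        # int so the equality check against claim_id behaves as expected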
claims_table.human_readable_id = claims_table.human_readable_id.astype(
float
).astype(int)
        row = claims_table[claims_table.human_readable_id == claim_id]
        # check if claim_id exists in the index
        if row.empty:
            raise ValueError(
                f"Claim '{claim_id}' not found in index '{container_name}'."
            )
return ClaimResponse(
covariate_type=row["covariate_type"].values[0],
type=row["type"].values[0],
description=row["description"].values[0],
subject_id=row["subject_id"].values[0],
object_id=row["object_id"].values[0],
source_text=row["source_text"].values[0],
text_unit_id=row["text_unit_id"].values[0],
document_ids=row["document_ids"].values[0].tolist(),
)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Could not get claim.", cause=e, stack=traceback.format_exc()
)
raise HTTPException(
status_code=500,
detail=f"Error retrieving claim '{claim_id}' for index '{container_name}'.",
        )


@source_route.get(
"/relationship/{container_name}/{relationship_id}",
summary="Return a single relationship.",
response_model=RelationshipResponse,
responses={status.HTTP_200_OK: {"model": RelationshipResponse}},
)
async def get_relationship_info(
relationship_id: int,
container_name: str,
sanitized_container_name: str = Depends(sanitize_name),
):
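    """Retrieve a relationship, its source and target entities, and its supporting text units."""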
    # check for the existence of the files the query relies on, to validate that the index is complete
validate_index_file_exist(sanitized_container_name, RELATIONSHIPS_TABLE)
validate_index_file_exist(sanitized_container_name, ENTITIES_TABLE)
try:
relationship_table = pd.read_parquet(
f"abfs://{sanitized_container_name}/{RELATIONSHIPS_TABLE}",
storage_options=pandas_storage_options(),
)
        relationship_table_row = relationship_table[
            relationship_table.human_readable_id == relationship_id
        ]
        # check if relationship_id exists in the index
        if relationship_table_row.empty:
            raise ValueError(
                f"Relationship '{relationship_id}' not found in index '{container_name}'."
            )

entity_table = pd.read_parquet(
f"abfs://{sanitized_container_name}/{ENTITIES_TABLE}",
storage_options=pandas_storage_options(),
)
text_units = pd.read_parquet(
f"abfs://{sanitized_container_name}/{TEXT_UNITS_TABLE}",
storage_options=pandas_storage_options(),
)
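        # map the relationship's internal text unit ids to their human-readable ids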
text_unit_ids = text_units[text_units["id"].isin(
relationship_table_row["text_unit_ids"].values[0]
)]["human_readable_id"]

        return RelationshipResponse(
            source=relationship_table_row["source"].values[0],
            source_id=entity_table[
                entity_table.title == relationship_table_row["source"].values[0]
            ].human_readable_id.values[0],
            target=relationship_table_row["target"].values[0],
            target_id=entity_table[
                entity_table.title == relationship_table_row["target"].values[0]
            ].human_readable_id.values[0],
            description=relationship_table_row["description"].values[0],
            text_units=text_unit_ids.to_list(),  # convert the pandas Series of ids to a plain list
)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Could not get relationship.", cause=e, stack=traceback.format_exc()
)
raise HTTPException(
status_code=500,
detail=f"Error retrieving relationship '{relationship_id}' from index '{container_name}'.",
)