84 lines
3.0 KiB
Python
Raw Normal View History

2024-06-26 15:45:06 -04:00
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from io import BytesIO
import networkx as nx
from fastapi import (
APIRouter,
HTTPException,
)
from fastapi.responses import StreamingResponse
from src.api.azure_clients import BlobServiceClientSingleton
from src.api.common import (
sanitize_name,
validate_index_file_exist,
)
from src.models import GraphDataResponse
from src.reporting import ReporterSingleton
# Blob service client obtained once at import time via the singleton accessor
# and shared by every route handler defined in this module.
blob_service_client = BlobServiceClientSingleton.get_instance()

# Router for the graph endpoints: every path registered on it is prefixed
# with "/graph" and tagged "Graph Operations".
graph_route = APIRouter(prefix="/graph", tags=["Graph Operations"])
@graph_route.get(
    "/graphml/{index_name}",
    summary="Retrieve a GraphML file of the knowledge graph",
    response_description="GraphML file successfully downloaded",
)
async def retrieve_graphml_file(index_name: str):
    # Streams the GraphML blob of the requested index back to the caller as a
    # file download.
    #
    # index_name: index identifier taken from the URL path; its sanitized form
    #     is used as the blob container name.
    # Returns a StreamingResponse wrapping the blob's chunk iterator.
    # Raises HTTPException(500) when the blob download cannot be set up.
    #
    # NOTE: documented with comments rather than a docstring on purpose --
    # FastAPI publishes an endpoint docstring as its OpenAPI description.

    # The blob is expected under "output/" based on the workflow.
    graphml_filename = "summarized_graph.graphml"
    blob_path = f"output/{graphml_filename}"

    # Check the index name and the blob's existence before touching storage.
    # NOTE(review): presumably validate_index_file_exist raises on a missing
    # container/blob -- confirm in src.api.common.
    container_name = sanitize_name(index_name)
    validate_index_file_exist(container_name, blob_path)

    try:
        downloader = blob_service_client.get_blob_client(
            container=container_name,
            blob=blob_path,
        ).download_blob()
        # Hand the chunk iterator to the response so the file is streamed.
        return StreamingResponse(
            downloader.chunks(),
            media_type="application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename={graphml_filename}"},
        )
    except Exception:
        # Report the failure, then surface a generic 500 to the client.
        ReporterSingleton().get_instance().on_error("Could not retrieve graphml file")
        raise HTTPException(
            status_code=500,
            detail=f"Could not retrieve graphml file for index '{index_name}'.",
        )
@graph_route.get(
    "/stats/{index_name}",
    summary="Retrieve basic graph statistics, number of nodes and edges",
    response_model=GraphDataResponse,
    responses={200: {"model": GraphDataResponse}},
)
async def retrieve_graph_stats(index_name: str):
    # Downloads the GraphML blob of the requested index, parses it with
    # networkx and reports how many nodes and edges the graph has.
    #
    # index_name: index identifier taken from the URL path; its sanitized form
    #     is used as the blob container name.
    # Returns a GraphDataResponse carrying the node and edge counts.
    # Raises HTTPException(500) when the blob cannot be downloaded or parsed.
    #
    # NOTE: documented with comments rather than a docstring on purpose --
    # FastAPI publishes an endpoint docstring as its OpenAPI description.

    # Expected location of the graph file based on the indexing workflow.
    graph_blob_path = "output/summarized_graph.graphml"

    # Check the index name and the blob's existence before touching storage.
    # NOTE(review): presumably validate_index_file_exist raises on a missing
    # container/blob -- confirm in src.api.common.
    container_name = sanitize_name(index_name)
    validate_index_file_exist(container_name, graph_blob_path)

    try:
        container_client = blob_service_client.get_container_client(container_name)
        # The whole blob is read into memory before being parsed.
        raw_graphml = container_client.download_blob(graph_blob_path).readall()
        graph = nx.read_graphml(BytesIO(raw_graphml))
        node_count = len(graph.nodes)
        edge_count = len(graph.edges)
        return GraphDataResponse(nodes=node_count, edges=edge_count)
    except Exception:
        # Report the failure, then surface a generic 500 to the client.
        ReporterSingleton().get_instance().on_error("Could not retrieve graph data file")
        raise HTTPException(
            status_code=500,
            detail=f"Could not retrieve graph statistics for index '{index_name}'.",
        )