Multi-filetype Support via Markitdown (#269)

Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
KennyZhang1 2025-04-08 11:00:28 -04:00 committed by GitHub
parent 5d2ab180c7
commit 004fc65cdb
14 changed files with 1422 additions and 720 deletions


@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 import asyncio
+import hashlib
 import os
 import re
 import traceback
@@ -14,7 +15,9 @@ from fastapi import (
     Depends,
     HTTPException,
     UploadFile,
+    status,
 )
+from markitdown import MarkItDown, StreamInfo
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.typing.models import (
@@ -22,12 +25,15 @@ from graphrag_app.typing.models import (
     StorageNameList,
 )
 from graphrag_app.utils.common import (
+    check_cache,
+    create_cache,
     delete_cosmos_container_item_if_exist,
     delete_storage_container_if_exist,
     get_blob_container_client,
     get_cosmos_container_store_client,
     sanitize_name,
     subscription_key_check,
+    update_cache,
 )
 
 data_route = APIRouter(
@@ -42,7 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "",
     summary="Get list of data containers.",
     response_model=StorageNameList,
-    responses={200: {"model": StorageNameList}},
+    responses={status.HTTP_200_OK: {"model": StorageNameList}},
 )
 async def get_all_data_containers():
     """
@@ -67,56 +73,66 @@ async def get_all_data_containers():
     return StorageNameList(storage_name=items)
 
 
-async def upload_file_async(
+async def upload_file(
     upload_file: UploadFile, container_client: ContainerClient, overwrite: bool = True
-) -> None:
+):
     """
-    Asynchronously upload a file to the specified blob container.
-    Silently ignore errors that occur when overwrite=False.
+    Convert and upload a file to a specified blob container.
+
+    Returns one of the following:
+        * Tuple[str, str] - a tuple of (filename, file_hash) for a successful upload
+        * Tuple[str, None] - a tuple of (filename, None) for a failed upload
+        * None for a skipped file
     """
-    blob_client = container_client.get_blob_client(upload_file.filename)
+    filename = upload_file.filename
+    extension = os.path.splitext(filename)[1]
+    converted_filename = filename + ".txt"
+    converted_blob_client = container_client.get_blob_client(converted_filename)
     with upload_file.file as file_stream:
         try:
-            await blob_client.upload_blob(file_stream, overwrite=overwrite)
+            file_hash = hashlib.sha256(file_stream.read()).hexdigest()
+            if not await check_cache(file_hash, container_client):
+                # extract text from file using MarkItDown
+                md = MarkItDown()
+                stream_info = StreamInfo(
+                    extension=extension,
+                )
+                file_stream._file.seek(0)
+                file_stream = file_stream._file
+                result = md.convert_stream(
+                    stream=file_stream,
+                    stream_info=stream_info,
+                )
+                # remove illegal unicode characters and upload to blob storage
+                cleaned_result = _clean_output(result.text_content)
+                await converted_blob_client.upload_blob(
+                    cleaned_result, overwrite=overwrite
+                )
+                # return tuple of (filename, file_hash) to indicate success
+                return (filename, file_hash)
         except Exception:
-            pass
+            # if any exception occurs, return (filename, None) to indicate conversion/upload failure
+            return (upload_file.filename, None)
 
 
-class Cleaner:
-    def __init__(self, file):
-        self.file = file
-        self.name = file.name
-        self.changes = 0
-
-    def clean(self, val, replacement=""):
-        # fmt: off
-        _illegal_xml_chars_RE = re.compile(
-            "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
-        )
-        # fmt: on
-        self.changes += len(_illegal_xml_chars_RE.findall(val))
-        return _illegal_xml_chars_RE.sub(replacement, val)
-
-    def read(self, n):
-        return self.clean(self.file.read(n).decode()).encode(
-            encoding="utf-8", errors="strict"
-        )
-
-    def name(self):
-        return self.file.name
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args):
-        self.file.close()
+def _clean_output(val: str, replacement: str = ""):
+    """Removes unicode characters that are invalid XML characters (not valid for graphml files at least)."""
+    # fmt: off
+    _illegal_xml_chars_RE = re.compile(
+        "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
+    )
+    # fmt: on
+    return _illegal_xml_chars_RE.sub(replacement, val)
 @data_route.post(
     "",
     summary="Upload data to a data storage container",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_201_CREATED: {"model": BaseResponse}},
 )
 async def upload_files(
     files: List[UploadFile],
@@ -125,36 +141,33 @@ async def upload_files(
     overwrite: bool = True,
 ):
     """
-    Create a Azure Storage container and upload files to it.
-
-    Args:
-        files (List[UploadFile]): A list of files to be uploaded.
-        storage_name (str): The name of the Azure Blob Storage container to which files will be uploaded.
-        overwrite (bool): Whether to overwrite existing files with the same name. Defaults to True. If False, files that already exist will be skipped.
-
-    Returns:
-        BaseResponse: An instance of the BaseResponse model with a status message indicating the result of the upload.
-
-    Raises:
-        HTTPException: If the container name is invalid or if any error occurs during the upload process.
+    Create an Azure Storage container (if needed) and upload files. Multiple file types are supported, including pdf, powerpoint, word, excel, html, csv, json, xml, etc.
+    The complete set of supported file types can be found in the MarkItDown (https://github.com/microsoft/markitdown) library.
     """
     try:
-        # clean files - remove illegal XML characters
-        files = [UploadFile(Cleaner(f.file), filename=f.filename) for f in files]
-        # upload files in batches of 1000 to avoid exceeding Azure Storage API limits
+        # create the initial cache if it doesn't exist
         blob_container_client = await get_blob_container_client(
             sanitized_container_name
         )
-        batch_size = 1000
+        await create_cache(blob_container_client)
+
+        # process file uploads in batches to avoid exceeding Azure Storage API limits
+        processing_errors = []
+        batch_size = 100
         num_batches = ceil(len(files) / batch_size)
         for i in range(num_batches):
             batch_files = files[i * batch_size : (i + 1) * batch_size]
             tasks = [
-                upload_file_async(file, blob_container_client, overwrite)
+                upload_file(file, blob_container_client, overwrite)
                 for file in batch_files
             ]
-            await asyncio.gather(*tasks)
+            upload_results = await asyncio.gather(*tasks)
+            successful_uploads = [r for r in upload_results if r and r[1] is not None]
+            # update the file cache with successful uploads
+            await update_cache(successful_uploads, blob_container_client)
+            # collect failed uploads
+            failed_uploads = [r[0] for r in upload_results if r and r[1] is None]
+            processing_errors.extend(failed_uploads)
 
         # update container-store entry in cosmosDB once upload process is successful
         cosmos_container_store_client = get_cosmos_container_store_client()
@@ -163,17 +176,23 @@ async def upload_files(
             "human_readable_name": container_name,
             "type": "data",
         })
-        return BaseResponse(status="File upload successful.")
+
+        if len(processing_errors) > 0:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Error uploading files: {processing_errors}.",
+            )
+        return BaseResponse(status="Success.")
     except Exception as e:
         logger = load_pipeline_logger()
         logger.error(
             message="Error uploading files.",
             cause=e,
             stack=traceback.format_exc(),
-            details={"files": [f.filename for f in files]},
+            details={"files": processing_errors},
         )
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Error uploading files to container '{container_name}'.",
         )
@@ -182,7 +201,7 @@ async def upload_files(
     "/{container_name}",
     summary="Delete a data storage container",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_200_OK: {"model": BaseResponse}},
 )
 async def delete_files(
     container_name: str, sanitized_container_name: str = Depends(sanitize_name)
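
Note on the conversion path above: upload_file() hashes the raw bytes, consults the cache, then hands the stream to MarkItDown and stores the cleaned text as "<filename>.txt". A minimal standalone sketch of that conversion step, assuming the markitdown 0.1.x API exercised in this diff (the input filename is hypothetical):

import io

from markitdown import MarkItDown, StreamInfo

# read a local file of any supported type (pdf, docx, xlsx, html, ...)
with open("report.pdf", "rb") as f:  # hypothetical input file
    stream = io.BytesIO(f.read())

# hint the converter with the original extension, as upload_file() does
md = MarkItDown()
result = md.convert_stream(stream=stream, stream_info=StreamInfo(extension=".pdf"))

# result.text_content holds the extracted markdown text; the route above
# runs it through _clean_output() before uploading it to blob storage
print(result.text_content[:200])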


@@ -8,6 +8,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from fastapi.responses import StreamingResponse
 
@@ -31,6 +32,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/graphml/{container_name}",
     summary="Retrieve a GraphML file of the knowledge graph",
     response_description="GraphML file successfully downloaded",
+    status_code=status.HTTP_200_OK,
 )
 async def get_graphml_file(
     container_name, sanitized_container_name: str = Depends(sanitize_name)


@@ -12,6 +12,7 @@ from fastapi import (
     Depends,
     HTTPException,
     UploadFile,
+    status,
 )
 from kubernetes import (
     client as kubernetes_client,
 
@@ -49,7 +50,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "",
     summary="Build an index",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_202_ACCEPTED: {"model": BaseResponse}},
 )
 async def schedule_index_job(
     storage_container_name: str,
 
@@ -71,7 +72,7 @@ async def schedule_index_job(
         sanitized_storage_container_name
     ).exists():
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_412_PRECONDITION_FAILED,
             detail=f"Storage container '{storage_container_name}' does not exist",
         )
 
@@ -101,7 +102,7 @@ async def schedule_index_job(
         PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
     ):
         raise HTTPException(
-            status_code=202,  # request has been accepted for processing but is not complete.
+            status_code=status.HTTP_425_TOO_EARLY,  # request has been accepted for processing but is not complete.
             detail=f"Index '{index_container_name}' already exists and has not finished building.",
         )
     # if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
 
@@ -142,7 +143,7 @@ async def schedule_index_job(
     "",
     summary="Get all index names",
     response_model=IndexNameList,
-    responses={200: {"model": IndexNameList}},
+    responses={status.HTTP_200_OK: {"model": IndexNameList}},
 )
 async def get_all_index_names(
     container_store_client=Depends(get_cosmos_container_store_client),
 
@@ -218,7 +219,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
     "/{container_name}",
     summary="Delete a specified index",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_200_OK: {"model": BaseResponse}},
 )
 async def delete_index(
     container_name: str,
 
@@ -257,7 +258,8 @@ async def delete_index(
             details={"container": container_name},
         )
         raise HTTPException(
-            status_code=500, detail=f"Error deleting '{container_name}'."
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error deleting '{container_name}'.",
         )
     return BaseResponse(status="Success")
 
@@ -267,6 +269,7 @@ async def delete_index(
     "/status/{container_name}",
     summary="Track the status of an indexing job",
     response_model=IndexStatusResponse,
+    status_code=status.HTTP_200_OK,
 )
 async def get_index_status(
     container_name: str, sanitized_container_name: str = Depends(sanitize_name)
 
@@ -275,7 +278,7 @@ async def get_index_status(
     if pipelinejob.item_exist(sanitized_container_name):
         pipeline_job = pipelinejob.load_item(sanitized_container_name)
         return IndexStatusResponse(
-            status_code=200,
+            status_code=status.HTTP_200_OK,
             index_name=pipeline_job.human_readable_index_name,
             storage_name=pipeline_job.human_readable_storage_name,
             status=pipeline_job.status.value,
 
@@ -284,5 +287,6 @@ async def get_index_status(
         )
     else:
         raise HTTPException(
-            status_code=404, detail=f"'{container_name}' does not exist."
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"'{container_name}' does not exist.",
         )


@@ -11,6 +11,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 
 from graphrag.config.create_graphrag_config import create_graphrag_config
 
@@ -27,6 +28,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/prompts",
     summary="Generate custom graphrag prompts based on user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
+    status_code=status.HTTP_200_OK,
 )
 async def generate_prompts(
     container_name: str,


@@ -10,6 +10,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from graphrag.api.query import global_search, local_search
 from graphrag.config.create_graphrag_config import create_graphrag_config
 
@@ -42,7 +43,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     summary="Perform a global search across the knowledge graph index",
     description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
     response_model=GraphResponse,
-    responses={200: {"model": GraphResponse}},
+    responses={status.HTTP_200_OK: {"model": GraphResponse}},
 )
 async def global_query(request: GraphRequest):
     # this is a slightly modified version of the graphrag.query.cli.run_global_search method
 
@@ -51,7 +52,7 @@ async def global_query(request: GraphRequest):
     if not _is_index_complete(sanitized_index_name):
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_425_TOO_EARLY,
             detail=f"{index_name} not ready for querying.",
         )
 
@@ -122,7 +123,7 @@ async def global_query(request: GraphRequest):
     summary="Perform a local search across the knowledge graph index.",
     description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
     response_model=GraphResponse,
-    responses={200: {"model": GraphResponse}},
+    responses={status.HTTP_200_OK: {"model": GraphResponse}},
 )
 async def local_query(request: GraphRequest):
     index_name = request.index_name
 
@@ -130,7 +131,7 @@ async def local_query(request: GraphRequest):
     if not _is_index_complete(sanitized_index_name):
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_425_TOO_EARLY,
             detail=f"{index_name} not ready for querying.",
         )


@@ -12,6 +12,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from fastapi.responses import StreamingResponse
 from graphrag.api.query import (
 
@@ -47,6 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/global",
     summary="Stream a response back after performing a global search",
     description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
+    status_code=status.HTTP_200_OK,
 )
 async def global_search_streaming(request: GraphRequest):
     # this is a slightly modified version of graphrag_app.api.query.global_query() method
 
@@ -204,6 +206,7 @@ async def global_search_streaming(request: GraphRequest):
     "/local",
     summary="Stream a response back after performing a local search",
     description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
+    status_code=status.HTTP_200_OK,
 )
 async def local_search_streaming(request: GraphRequest):
     # this is a slightly modified version of graphrag_app.api.query.local_query() method


@@ -5,7 +5,12 @@ import os
 import traceback
 
 import pandas as pd
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    status,
+)
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.typing.models import (
 
@@ -43,7 +48,7 @@ DOCUMENTS_TABLE = "output/create_final_documents.parquet"
     "/report/{container_name}/{report_id}",
     summary="Return a single community report.",
     response_model=ReportResponse,
-    responses={200: {"model": ReportResponse}},
+    responses={status.HTTP_200_OK: {"model": ReportResponse}},
 )
 async def get_report_info(
     report_id: int,
 
@@ -88,7 +93,7 @@ async def get_report_info(
     "/text/{container_name}/{text_unit_id}",
     summary="Return a single base text unit.",
     response_model=TextUnitResponse,
-    responses={200: {"model": TextUnitResponse}},
+    responses={status.HTTP_200_OK: {"model": TextUnitResponse}},
 )
 async def get_chunk_info(
     text_unit_id: str,
 
@@ -148,7 +153,7 @@ async def get_chunk_info(
     "/entity/{container_name}/{entity_id}",
     summary="Return a single entity.",
     response_model=EntityResponse,
-    responses={200: {"model": EntityResponse}},
+    responses={status.HTTP_200_OK: {"model": EntityResponse}},
 )
 async def get_entity_info(
     entity_id: int,
 
@@ -190,7 +195,7 @@ async def get_entity_info(
     "/claim/{container_name}/{claim_id}",
     summary="Return a single claim.",
     response_model=ClaimResponse,
-    responses={200: {"model": ClaimResponse}},
+    responses={status.HTTP_200_OK: {"model": ClaimResponse}},
 )
 async def get_claim_info(
     claim_id: int,
 
@@ -240,7 +245,7 @@ async def get_claim_info(
     "/relationship/{container_name}/{relationship_id}",
     summary="Return a single relationship.",
     response_model=RelationshipResponse,
-    responses={200: {"model": RelationshipResponse}},
+    responses={status.HTTP_200_OK: {"model": RelationshipResponse}},
 )
 async def get_relationship_info(
     relationship_id: int,


@@ -33,7 +33,7 @@ from graphrag_app.utils.common import subscription_key_check
 async def catch_all_exceptions_middleware(request: Request, call_next):
-    """a function to globally catch all exceptions and return a 500 response with the exception message"""
+    """A global function to catch all exceptions and produce a standard error message"""
     try:
         return await call_next(request)
     except Exception as e:
 
@@ -44,7 +44,10 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
             cause=e,
             stack=stack,
         )
-        return Response("Unexpected internal server error.", status_code=500)
+        return Response(
+            "Unexpected internal server error.",
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        )
 
 
 # NOTE: this function is not currently used, but it is a placeholder for future use once RBAC issues have been resolved
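
For orientation, middleware like the one above is typically attached with FastAPI's @app.middleware("http") decorator; the wiring below is an illustrative assumption, not part of this diff:

from fastapi import FastAPI, Request, Response, status

app = FastAPI()

# every HTTP request passes through this handler; any exception raised by a
# route is converted into a generic 500 response instead of leaking a traceback
@app.middleware("http")
async def catch_all_exceptions_middleware(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception:
        return Response(
            "Unexpected internal server error.",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )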


@@ -1,27 +1,35 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+"""
+Common utility functions used by the API endpoints.
+"""
+
+import csv
 import hashlib
 import os
 import traceback
-from typing import Annotated
+from io import StringIO
+from typing import Annotated, Tuple
 
 import pandas as pd
 from azure.core.exceptions import ResourceNotFoundError
 from azure.cosmos import ContainerProxy, exceptions
 from azure.identity import DefaultAzureCredential
 from azure.storage.blob.aio import ContainerClient
-from fastapi import Header, HTTPException
+from fastapi import Header, HTTPException, status
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.utils.azure_clients import AzureClientManager
 
+FILE_UPLOAD_CACHE = "cache/uploaded_files.csv"
+
 
 def get_df(
-    table_path: str,
+    filepath: str,
 ) -> pd.DataFrame:
+    """Read a parquet file from Azure Storage and return it as a pandas DataFrame."""
     df = pd.read_parquet(
-        table_path,
+        filepath,
         storage_options=pandas_storage_options(),
     )
     return df
@@ -123,7 +131,10 @@ def get_cosmos_container_store_client() -> ContainerProxy:
             cause=e,
             stack=traceback.format_exc(),
         )
-        raise HTTPException(status_code=500, detail="Error fetching cosmosdb client.")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error fetching cosmosdb client.",
+        )
 
 
 async def get_blob_container_client(name: str) -> ContainerClient:
 
@@ -141,7 +152,10 @@ async def get_blob_container_client(name: str) -> ContainerClient:
             cause=e,
             stack=traceback.format_exc(),
         )
-        raise HTTPException(status_code=500, detail="Error fetching storage client.")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error fetching storage client.",
+        )
 
 
 def sanitize_name(container_name: str) -> str:
 
@@ -188,7 +202,8 @@ def desanitize_name(sanitized_container_name: str) -> str | None:
             return None
     except Exception:
         raise HTTPException(
-            status_code=500, detail="Error retrieving original container name."
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error retrieving original container name.",
        )
@@ -196,13 +211,102 @@ async def subscription_key_check(
     Ocp_Apim_Subscription_Key: Annotated[str, Header()],
 ):
     """
-    Verifies if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
-    If it is not present, an HTTPException with a 400 status code is raised.
-    Note: this check is unnecessary (APIM validates subscription keys automatically), but this will add the key
+    Verify if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
+    Note: this check is unnecessary (APIM validates subscription keys automatically), but it effectively adds the key
     as a required parameter in the swagger docs page, enabling users to send requests using the swagger docs "Try it out" feature.
     """
     if not Ocp_Apim_Subscription_Key:
         raise HTTPException(
-            status_code=400, detail="Ocp-Apim-Subscription-Key required"
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Ocp-Apim-Subscription-Key required",
         )
     return Ocp_Apim_Subscription_Key
+
+
+async def create_cache(container_client: ContainerClient) -> None:
+    """
+    Create a file cache (csv).
+    """
+    try:
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        if not await cache_blob_client.exists():
+            # create the empty file cache csv
+            headers = [["Filename", "Hash"]]
+            tmp_cache_file = "uploaded_files_cache.csv"
+            with open(tmp_cache_file, "w", newline="") as f:
+                writer = csv.writer(f, delimiter=",")
+                writer.writerows(headers)
+            # upload to Azure Blob Storage and remove the temporary file
+            with open(tmp_cache_file, "rb") as f:
+                await cache_blob_client.upload_blob(f, overwrite=True)
+            if os.path.exists(tmp_cache_file):
+                os.remove(tmp_cache_file)
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error creating file cache in Azure Blob Storage.",
+        )
+
+
+async def check_cache(file_hash: str, container_client: ContainerClient) -> bool:
+    """
+    Check a file cache (csv) to determine if a file hash has previously been uploaded.
+    Note: This function creates/checks a CSV file in azure storage to act as a cache of previously uploaded files.
+    """
+    try:
+        # load the file cache
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        cache_download_stream = await cache_blob_client.download_blob()
+        cache_bytes = await cache_download_stream.readall()
+        cache_content = StringIO(cache_bytes.decode("utf-8"))
+        # compute the sha256 hash of the file and check if it exists in the cache
+        cache_reader = csv.reader(cache_content, delimiter=",")
+        for row in cache_reader:
+            if file_hash in row:
+                return True
+        return False
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error checking file cache in Azure Blob Storage.",
+        )
+
+
+async def update_cache(
+    new_files: Tuple[str, str], container_client: ContainerClient
+) -> None:
+    """
+    Update an existing file cache (csv) with new files.
+    """
+    try:
+        # load the existing cache
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        cache_download_stream = await cache_blob_client.download_blob()
+        cache_bytes = await cache_download_stream.readall()
+        cache_content = StringIO(cache_bytes.decode("utf-8"))
+        cache_reader = csv.reader(cache_content, delimiter=",")
+        # append new data
+        existing_rows = list(cache_reader)
+        for filename, file_hash in new_files:
+            row = [filename, file_hash]
+            existing_rows.append(row)
+        # write the updated content back to the StringIO object
+        updated_cache_content = StringIO()
+        cache_writer = csv.writer(updated_cache_content, delimiter=",")
+        cache_writer.writerows(existing_rows)
+        # upload the updated cache to Azure Blob Storage
+        updated_cache_content.seek(0)
+        await cache_blob_client.upload_blob(
+            updated_cache_content.getvalue().encode("utf-8"), overwrite=True
+        )
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error updating file cache in Azure Blob Storage.",
+        )
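
The three cache helpers above implement a simple dedupe table: a CSV of (Filename, Hash) rows kept in blob storage, keyed by the SHA-256 digest of each upload's raw bytes. A minimal in-memory sketch of the same logic using only the standard library (no Azure calls; names are illustrative):

import csv
import hashlib
from io import StringIO

# start an empty cache with the same header row create_cache() writes
cache = StringIO()
csv.writer(cache).writerow(["Filename", "Hash"])

def seen(data: bytes) -> bool:
    """Mirror check_cache(): scan the CSV rows for the sha256 digest."""
    cache.seek(0)
    digest = hashlib.sha256(data).hexdigest()
    return any(digest in row for row in csv.reader(cache))

def record(filename: str, data: bytes) -> None:
    """Mirror update_cache(): append a (filename, hash) row."""
    cache.seek(0, 2)  # move to end before appending
    csv.writer(cache).writerow([filename, hashlib.sha256(data).hexdigest()])

record("doc1.pdf", b"hello")
assert seen(b"hello") and not seen(b"world")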

backend/poetry.lock (generated): 1756 lines changed; file diff suppressed because it is too large.


@@ -53,6 +53,7 @@ fsspec = ">=2024.2.0"
 graphrag = "==1.2.0"
 httpx = ">=0.25.2"
 kubernetes = ">=29.0.0"
+markitdown = {extras = ["all"], version = "^0.1.1"}
 networkx = ">=3.2.1"
 nltk = "*"
 pandas = ">=2.2.1"


@@ -22,7 +22,7 @@ def test_schedule_index_without_data(client, cosmos_client: CosmosClient):
             "storage_container_name": "nonexistent-data-container",
         },
     )
-    assert response.status_code == 500
+    assert response.status_code == 412
 
 
 # def test_schedule_index_with_data(client, cosmos_client, blob_with_data_container_name):
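
The updated assertion matches the route change because the fastapi.status constants (re-exported from Starlette) are plain integers, e.g.:

from fastapi import status

assert status.HTTP_412_PRECONDITION_FAILED == 412
assert status.HTTP_422_UNPROCESSABLE_ENTITY == 422
assert status.HTTP_425_TOO_EARLY == 425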


@@ -21,7 +21,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "! pip install devtools python-magic requests tqdm"
+    "! pip install devtools requests tqdm"
    ]
   },
   {
@@ -35,7 +35,6 @@
     "import time\n",
     "from pathlib import Path\n",
     "\n",
-    "import magic\n",
     "import requests\n",
     "from devtools import pprint\n",
     "from tqdm import tqdm"
@@ -127,7 +126,7 @@
    "source": [
     "## Upload Files\n",
     "\n",
-    "For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag."
+    "For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.**"
    ]
   },
   {
@@ -177,20 +176,14 @@
     "        return response\n",
     "\n",
     "    batch_files = []\n",
-    "    accepted_file_types = [\"text/plain\"]\n",
     "    filepaths = list(Path(file_directory).iterdir())\n",
     "    for file in tqdm(filepaths):\n",
     "        # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
-    "        if (\n",
-    "            not file.is_file()\n",
-    "            or file.suffix != \".txt\"\n",
-    "            or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
-    "        ):\n",
+    "        if (not file.is_file()):\n",
     "            print(f\"Skipping invalid file: {file}\")\n",
     "            continue\n",
-    "        # open and decode file as utf-8, ignore bad characters\n",
     "        batch_files.append(\n",
-    "            (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
+    "            (\"files\", open(file=file, mode=\"rb\"))\n",
     "        )\n",
     "        # upload batch of files\n",
     "        if len(batch_files) == batch_size:\n",
@@ -199,7 +192,7 @@
     "            if not response.ok:\n",
     "                return response\n",
     "            batch_files.clear()\n",
-    "    # upload remaining files\n",
+    "    # upload last batch of remaining files\n",
     "    if len(batch_files) > 0:\n",
     "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
     "    return response\n",
@@ -330,9 +323,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%%time\n",
-    "\n",
-    "\n",
     "def global_search(\n",
     "    index_name: str | list[str], query: str, community_level: int\n",
     ") -> requests.Response:\n",
@@ -372,9 +362,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%%time\n",
-    "\n",
-    "\n",
     "def local_search(\n",
     "    index_name: str | list[str], query: str, community_level: int\n",
     ") -> requests.Response:\n",
@@ -402,7 +389,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "graphrag-venv",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
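
With the server now converting files via MarkItDown, the notebook's client-side loop only needs to stream raw bytes. A condensed sketch of that upload pattern; the endpoint URL, header, and query-parameter names are assumptions based on the notebook, not verbatim from it:

from pathlib import Path

import requests

url = "https://<apim-gateway>/data"  # hypothetical deployment endpoint
headers = {"Ocp-Apim-Subscription-Key": "<subscription-key>"}

# open every file in binary mode; no client-side type or encoding checks needed
batch_files = [
    ("files", open(f, mode="rb")) for f in Path("data/").iterdir() if f.is_file()
]
response = requests.post(
    url,
    files=batch_files,
    params={"container_name": "<storage-name>"},  # parameter name assumed
    headers=headers,
)
print(response.status_code, response.text)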


@@ -27,17 +27,14 @@
     "| DELETE | /index/{index_name}\n",
     "| GET | /index/status/{index_name}\n",
     "| POST | /query/global\n",
-    "| POST | /query/streaming/global\n",
     "| POST | /query/local\n",
-    "| POST | /query/streaming/local\n",
     "| GET | /index/config/prompts\n",
     "| GET | /source/report/{index_name}/{report_id}\n",
     "| GET | /source/text/{index_name}/{text_unit_id}\n",
     "| GET | /source/entity/{index_name}/{entity_id}\n",
     "| GET | /source/claim/{index_name}/{claim_id}\n",
     "| GET | /source/relationship/{index_name}/{relationship_id}\n",
-    "| GET | /graph/graphml/{index_name}\n",
-    "| GET | /graph/stats/{index_name}"
+    "| GET | /graph/graphml/{index_name}"
    ]
   },
   {
@@ -56,7 +53,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "! pip install devtools pandas python-magic requests tqdm"
+    "! pip install devtools pandas requests tqdm"
    ]
   },
   {
@@ -72,7 +69,6 @@
     "import time\n",
     "from pathlib import Path\n",
     "\n",
-    "import magic\n",
     "import pandas as pd\n",
     "import requests\n",
     "from devtools import pprint\n",
@@ -229,20 +225,14 @@
     "        return response\n",
     "\n",
     "    batch_files = []\n",
-    "    accepted_file_types = [\"text/plain\"]\n",
     "    filepaths = list(Path(file_directory).iterdir())\n",
     "    for file in tqdm(filepaths):\n",
     "        # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
-    "        if (\n",
-    "            not file.is_file()\n",
-    "            or file.suffix != \".txt\"\n",
-    "            or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
-    "        ):\n",
+    "        if (not file.is_file()):\n",
     "            print(f\"Skipping invalid file: {file}\")\n",
     "            continue\n",
-    "        # open and decode file as utf-8, ignore bad characters\n",
     "        batch_files.append(\n",
-    "            (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
+    "            (\"files\", open(file=file, mode=\"rb\"))\n",
     "        )\n",
     "        # upload batch of files\n",
     "        if len(batch_files) == batch_size:\n",
@@ -251,7 +241,7 @@
     "            if not response.ok:\n",
     "                return response\n",
     "            batch_files.clear()\n",
-    "    # upload remaining files\n",
+    "    # upload last batch of remaining files\n",
     "    if len(batch_files) > 0:\n",
     "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
     "    return response\n",
@@ -290,7 +280,10 @@
     "    return requests.post(\n",
     "        url,\n",
     "        files=prompts if len(prompts) > 0 else None,\n",
-    "        params={\"index_container_name\": index_name, \"storage_container_name\": storage_name},\n",
+    "        params={\n",
+    "            \"index_container_name\": index_name,\n",
+    "            \"storage_container_name\": storage_name,\n",
+    "        },\n",
     "        headers=headers,\n",
     "    )\n",
     "\n",
@@ -486,7 +479,7 @@
    "source": [
     "## Upload files\n",
     "\n",
-    "Use the API to upload a collection of local files. The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
+    "Use the API to upload a collection of files. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.** The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
    ]
   },
   {
@@ -619,7 +612,9 @@
     "    community_summarization_prompt = prompts[\"community_summarization_prompt\"]\n",
     "    summarize_description_prompt = prompts[\"entity_summarization_prompt\"]\n",
     "else:\n",
-    "    entity_extraction_prompt = community_summarization_prompt = summarize_description_prompt = None\n",
+    "    entity_extraction_prompt = community_summarization_prompt = (\n",
+    "        summarize_description_prompt\n",
+    "    ) = None\n",
     "\n",
     "response = build_index(\n",
     "    storage_name=storage_name,\n",
@@ -745,8 +740,8 @@
     "# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
     "global_response = global_search(\n",
     "    index_name=index_name,\n",
-    "    query=\"Summarize the main topics found in this data\",\n",
-    "    community_level=1,\n",
+    "    query=\"Summarize the qualifications to being a delivery data scientist\",\n",
+    "    community_level=2,\n",
     ")\n",
     "# print the result and save context data in a variable\n",
     "global_response_data = parse_query_response(global_response, return_context_data=True)\n",
@@ -956,7 +951,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "graphrag-venv",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },