mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-06-27 04:39:57 +00:00
Multi-filetype Support via Markitdown (#269)
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
This commit is contained in:
parent
5d2ab180c7
commit
004fc65cdb
@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
@ -14,7 +15,9 @@ from fastapi import (
|
||||
Depends,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from markitdown import MarkItDown, StreamInfo
|
||||
|
||||
from graphrag_app.logger.load_logger import load_pipeline_logger
|
||||
from graphrag_app.typing.models import (
|
||||
@ -22,12 +25,15 @@ from graphrag_app.typing.models import (
|
||||
StorageNameList,
|
||||
)
|
||||
from graphrag_app.utils.common import (
|
||||
check_cache,
|
||||
create_cache,
|
||||
delete_cosmos_container_item_if_exist,
|
||||
delete_storage_container_if_exist,
|
||||
get_blob_container_client,
|
||||
get_cosmos_container_store_client,
|
||||
sanitize_name,
|
||||
subscription_key_check,
|
||||
update_cache,
|
||||
)
|
||||
|
||||
data_route = APIRouter(
|
||||
@ -42,7 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
"",
|
||||
summary="Get list of data containers.",
|
||||
response_model=StorageNameList,
|
||||
responses={200: {"model": StorageNameList}},
|
||||
responses={status.HTTP_200_OK: {"model": StorageNameList}},
|
||||
)
|
||||
async def get_all_data_containers():
|
||||
"""
|
||||
@ -67,56 +73,66 @@ async def get_all_data_containers():
|
||||
return StorageNameList(storage_name=items)
|
||||
|
||||
|
||||
async def upload_file_async(
|
||||
async def upload_file(
|
||||
upload_file: UploadFile, container_client: ContainerClient, overwrite: bool = True
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Asynchronously upload a file to the specified blob container.
|
||||
Silently ignore errors that occur when overwrite=False.
|
||||
Convert and upload a file to a specified blob container.
|
||||
|
||||
Returns a list of objects where each object will have one of the following types:
|
||||
* Tuple[str, str] - a tuple of (filename, file_hash) for successful uploads
|
||||
* Tuple[str, None] - a tuple of (filename, None) for failed uploads or
|
||||
* None for skipped files
|
||||
"""
|
||||
blob_client = container_client.get_blob_client(upload_file.filename)
|
||||
filename = upload_file.filename
|
||||
extension = os.path.splitext(filename)[1]
|
||||
converted_filename = filename + ".txt"
|
||||
converted_blob_client = container_client.get_blob_client(converted_filename)
|
||||
|
||||
with upload_file.file as file_stream:
|
||||
try:
|
||||
await blob_client.upload_blob(file_stream, overwrite=overwrite)
|
||||
file_hash = hashlib.sha256(file_stream.read()).hexdigest()
|
||||
if not await check_cache(file_hash, container_client):
|
||||
# extract text from file using MarkItDown
|
||||
md = MarkItDown()
|
||||
stream_info = StreamInfo(
|
||||
extension=extension,
|
||||
)
|
||||
file_stream._file.seek(0)
|
||||
file_stream = file_stream._file
|
||||
result = md.convert_stream(
|
||||
stream=file_stream,
|
||||
stream_info=stream_info,
|
||||
)
|
||||
|
||||
# remove illegal unicode characters and upload to blob storage
|
||||
cleaned_result = _clean_output(result.text_content)
|
||||
await converted_blob_client.upload_blob(
|
||||
cleaned_result, overwrite=overwrite
|
||||
)
|
||||
|
||||
# return tuple of (filename, file_hash) to indicate success
|
||||
return (filename, file_hash)
|
||||
except Exception:
|
||||
pass
|
||||
# if any exception occurs, return a tuple of (filename, None) to indicate conversion/upload failure
|
||||
return (upload_file.filename, None)
|
||||
|
||||
|
||||
class Cleaner:
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
self.name = file.name
|
||||
self.changes = 0
|
||||
|
||||
def clean(self, val, replacement=""):
|
||||
# fmt: off
|
||||
_illegal_xml_chars_RE = re.compile(
|
||||
def _clean_output(val: str, replacement: str = ""):
|
||||
"""Removes unicode characters that are invalid XML characters (not valid for graphml files at least)."""
|
||||
# fmt: off
|
||||
_illegal_xml_chars_RE = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
|
||||
)
|
||||
# fmt: on
|
||||
self.changes += len(_illegal_xml_chars_RE.findall(val))
|
||||
return _illegal_xml_chars_RE.sub(replacement, val)
|
||||
|
||||
def read(self, n):
|
||||
return self.clean(self.file.read(n).decode()).encode(
|
||||
encoding="utf-8", errors="strict"
|
||||
)
|
||||
|
||||
def name(self):
|
||||
return self.file.name
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.file.close()
|
||||
# fmt: on
|
||||
return _illegal_xml_chars_RE.sub(replacement, val)
|
||||
|
||||
|
||||
@data_route.post(
|
||||
"",
|
||||
summary="Upload data to a data storage container",
|
||||
response_model=BaseResponse,
|
||||
responses={200: {"model": BaseResponse}},
|
||||
responses={status.HTTP_201_CREATED: {"model": BaseResponse}},
|
||||
)
|
||||
async def upload_files(
|
||||
files: List[UploadFile],
|
||||
@ -125,36 +141,33 @@ async def upload_files(
|
||||
overwrite: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a Azure Storage container and upload files to it.
|
||||
|
||||
Args:
|
||||
files (List[UploadFile]): A list of files to be uploaded.
|
||||
storage_name (str): The name of the Azure Blob Storage container to which files will be uploaded.
|
||||
overwrite (bool): Whether to overwrite existing files with the same name. Defaults to True. If False, files that already exist will be skipped.
|
||||
|
||||
Returns:
|
||||
BaseResponse: An instance of the BaseResponse model with a status message indicating the result of the upload.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the container name is invalid or if any error occurs during the upload process.
|
||||
Create a Azure Storage container (if needed) and upload files. Multiple file types are supported, including pdf, powerpoint, word, excel, html, csv, json, xml, etc.
|
||||
The complete set of supported file types can be found in the MarkItDown (https://github.com/microsoft/markitdown) library.
|
||||
"""
|
||||
try:
|
||||
# clean files - remove illegal XML characters
|
||||
files = [UploadFile(Cleaner(f.file), filename=f.filename) for f in files]
|
||||
|
||||
# upload files in batches of 1000 to avoid exceeding Azure Storage API limits
|
||||
# create the initial cache if it doesn't exist
|
||||
blob_container_client = await get_blob_container_client(
|
||||
sanitized_container_name
|
||||
)
|
||||
batch_size = 1000
|
||||
await create_cache(blob_container_client)
|
||||
|
||||
# process file uploads in batches to avoid exceeding Azure Storage API limits
|
||||
processing_errors = []
|
||||
batch_size = 100
|
||||
num_batches = ceil(len(files) / batch_size)
|
||||
for i in range(num_batches):
|
||||
batch_files = files[i * batch_size : (i + 1) * batch_size]
|
||||
tasks = [
|
||||
upload_file_async(file, blob_container_client, overwrite)
|
||||
upload_file(file, blob_container_client, overwrite)
|
||||
for file in batch_files
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
upload_results = await asyncio.gather(*tasks)
|
||||
successful_uploads = [r for r in upload_results if r and r[1] is not None]
|
||||
# update the file cache with successful uploads
|
||||
await update_cache(successful_uploads, blob_container_client)
|
||||
# collect failed uploads
|
||||
failed_uploads = [r[0] for r in upload_results if r and r[1] is None]
|
||||
processing_errors.extend(failed_uploads)
|
||||
|
||||
# update container-store entry in cosmosDB once upload process is successful
|
||||
cosmos_container_store_client = get_cosmos_container_store_client()
|
||||
@ -163,17 +176,23 @@ async def upload_files(
|
||||
"human_readable_name": container_name,
|
||||
"type": "data",
|
||||
})
|
||||
return BaseResponse(status="File upload successful.")
|
||||
|
||||
if len(processing_errors) > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"Error uploading files: {processing_errors}.",
|
||||
)
|
||||
return BaseResponse(status="Success.")
|
||||
except Exception as e:
|
||||
logger = load_pipeline_logger()
|
||||
logger.error(
|
||||
message="Error uploading files.",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
details={"files": [f.filename for f in files]},
|
||||
details={"files": processing_errors},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error uploading files to container '{container_name}'.",
|
||||
)
|
||||
|
||||
@ -182,7 +201,7 @@ async def upload_files(
|
||||
"/{container_name}",
|
||||
summary="Delete a data storage container",
|
||||
response_model=BaseResponse,
|
||||
responses={200: {"model": BaseResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": BaseResponse}},
|
||||
)
|
||||
async def delete_files(
|
||||
container_name: str, sanitized_container_name: str = Depends(sanitize_name)
|
||||
|
@ -8,6 +8,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
@ -31,6 +32,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
"/graphml/{container_name}",
|
||||
summary="Retrieve a GraphML file of the knowledge graph",
|
||||
response_description="GraphML file successfully downloaded",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def get_graphml_file(
|
||||
container_name, sanitized_container_name: str = Depends(sanitize_name)
|
||||
|
@ -12,6 +12,7 @@ from fastapi import (
|
||||
Depends,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from kubernetes import (
|
||||
client as kubernetes_client,
|
||||
@ -49,7 +50,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
"",
|
||||
summary="Build an index",
|
||||
response_model=BaseResponse,
|
||||
responses={200: {"model": BaseResponse}},
|
||||
responses={status.HTTP_202_ACCEPTED: {"model": BaseResponse}},
|
||||
)
|
||||
async def schedule_index_job(
|
||||
storage_container_name: str,
|
||||
@ -71,7 +72,7 @@ async def schedule_index_job(
|
||||
sanitized_storage_container_name
|
||||
).exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
status_code=status.HTTP_412_PRECONDITION_FAILED,
|
||||
detail=f"Storage container '{storage_container_name}' does not exist",
|
||||
)
|
||||
|
||||
@ -101,7 +102,7 @@ async def schedule_index_job(
|
||||
PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=202, # request has been accepted for processing but is not complete.
|
||||
status_code=status.HTTP_425_TOO_EARLY, # request has been accepted for processing but is not complete.
|
||||
detail=f"Index '{index_container_name}' already exists and has not finished building.",
|
||||
)
|
||||
# if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
|
||||
@ -142,7 +143,7 @@ async def schedule_index_job(
|
||||
"",
|
||||
summary="Get all index names",
|
||||
response_model=IndexNameList,
|
||||
responses={200: {"model": IndexNameList}},
|
||||
responses={status.HTTP_200_OK: {"model": IndexNameList}},
|
||||
)
|
||||
async def get_all_index_names(
|
||||
container_store_client=Depends(get_cosmos_container_store_client),
|
||||
@ -218,7 +219,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
|
||||
"/{container_name}",
|
||||
summary="Delete a specified index",
|
||||
response_model=BaseResponse,
|
||||
responses={200: {"model": BaseResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": BaseResponse}},
|
||||
)
|
||||
async def delete_index(
|
||||
container_name: str,
|
||||
@ -257,7 +258,8 @@ async def delete_index(
|
||||
details={"container": container_name},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error deleting '{container_name}'."
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error deleting '{container_name}'.",
|
||||
)
|
||||
|
||||
return BaseResponse(status="Success")
|
||||
@ -267,6 +269,7 @@ async def delete_index(
|
||||
"/status/{container_name}",
|
||||
summary="Track the status of an indexing job",
|
||||
response_model=IndexStatusResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def get_index_status(
|
||||
container_name: str, sanitized_container_name: str = Depends(sanitize_name)
|
||||
@ -275,7 +278,7 @@ async def get_index_status(
|
||||
if pipelinejob.item_exist(sanitized_container_name):
|
||||
pipeline_job = pipelinejob.load_item(sanitized_container_name)
|
||||
return IndexStatusResponse(
|
||||
status_code=200,
|
||||
status_code=status.HTTP_200_OK,
|
||||
index_name=pipeline_job.human_readable_index_name,
|
||||
storage_name=pipeline_job.human_readable_storage_name,
|
||||
status=pipeline_job.status.value,
|
||||
@ -284,5 +287,6 @@ async def get_index_status(
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"'{container_name}' does not exist."
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"'{container_name}' does not exist.",
|
||||
)
|
||||
|
@ -11,6 +11,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
|
||||
@ -27,6 +28,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
"/prompts",
|
||||
summary="Generate custom graphrag prompts based on user-provided data.",
|
||||
description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def generate_prompts(
|
||||
container_name: str,
|
||||
|
@ -10,6 +10,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
from graphrag.api.query import global_search, local_search
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
@ -42,7 +43,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
summary="Perform a global search across the knowledge graph index",
|
||||
description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
|
||||
response_model=GraphResponse,
|
||||
responses={200: {"model": GraphResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": GraphResponse}},
|
||||
)
|
||||
async def global_query(request: GraphRequest):
|
||||
# this is a slightly modified version of the graphrag.query.cli.run_global_search method
|
||||
@ -51,7 +52,7 @@ async def global_query(request: GraphRequest):
|
||||
|
||||
if not _is_index_complete(sanitized_index_name):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
status_code=status.HTTP_425_TOO_EARLY,
|
||||
detail=f"{index_name} not ready for querying.",
|
||||
)
|
||||
|
||||
@ -122,7 +123,7 @@ async def global_query(request: GraphRequest):
|
||||
summary="Perform a local search across the knowledge graph index.",
|
||||
description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
|
||||
response_model=GraphResponse,
|
||||
responses={200: {"model": GraphResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": GraphResponse}},
|
||||
)
|
||||
async def local_query(request: GraphRequest):
|
||||
index_name = request.index_name
|
||||
@ -130,7 +131,7 @@ async def local_query(request: GraphRequest):
|
||||
|
||||
if not _is_index_complete(sanitized_index_name):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
status_code=status.HTTP_425_TOO_EARLY,
|
||||
detail=f"{index_name} not ready for querying.",
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from graphrag.api.query import (
|
||||
@ -47,6 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
|
||||
"/global",
|
||||
summary="Stream a response back after performing a global search",
|
||||
description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def global_search_streaming(request: GraphRequest):
|
||||
# this is a slightly modified version of graphrag_app.api.query.global_query() method
|
||||
@ -204,6 +206,7 @@ async def global_search_streaming(request: GraphRequest):
|
||||
"/local",
|
||||
summary="Stream a response back after performing a local search",
|
||||
description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def local_search_streaming(request: GraphRequest):
|
||||
# this is a slightly modified version of graphrag_app.api.query.local_query() method
|
||||
|
@ -5,7 +5,12 @@ import os
|
||||
import traceback
|
||||
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
|
||||
from graphrag_app.logger.load_logger import load_pipeline_logger
|
||||
from graphrag_app.typing.models import (
|
||||
@ -43,7 +48,7 @@ DOCUMENTS_TABLE = "output/create_final_documents.parquet"
|
||||
"/report/{container_name}/{report_id}",
|
||||
summary="Return a single community report.",
|
||||
response_model=ReportResponse,
|
||||
responses={200: {"model": ReportResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": ReportResponse}},
|
||||
)
|
||||
async def get_report_info(
|
||||
report_id: int,
|
||||
@ -88,7 +93,7 @@ async def get_report_info(
|
||||
"/text/{container_name}/{text_unit_id}",
|
||||
summary="Return a single base text unit.",
|
||||
response_model=TextUnitResponse,
|
||||
responses={200: {"model": TextUnitResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": TextUnitResponse}},
|
||||
)
|
||||
async def get_chunk_info(
|
||||
text_unit_id: str,
|
||||
@ -148,7 +153,7 @@ async def get_chunk_info(
|
||||
"/entity/{container_name}/{entity_id}",
|
||||
summary="Return a single entity.",
|
||||
response_model=EntityResponse,
|
||||
responses={200: {"model": EntityResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": EntityResponse}},
|
||||
)
|
||||
async def get_entity_info(
|
||||
entity_id: int,
|
||||
@ -190,7 +195,7 @@ async def get_entity_info(
|
||||
"/claim/{container_name}/{claim_id}",
|
||||
summary="Return a single claim.",
|
||||
response_model=ClaimResponse,
|
||||
responses={200: {"model": ClaimResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": ClaimResponse}},
|
||||
)
|
||||
async def get_claim_info(
|
||||
claim_id: int,
|
||||
@ -240,7 +245,7 @@ async def get_claim_info(
|
||||
"/relationship/{container_name}/{relationship_id}",
|
||||
summary="Return a single relationship.",
|
||||
response_model=RelationshipResponse,
|
||||
responses={200: {"model": RelationshipResponse}},
|
||||
responses={status.HTTP_200_OK: {"model": RelationshipResponse}},
|
||||
)
|
||||
async def get_relationship_info(
|
||||
relationship_id: int,
|
||||
|
@ -33,7 +33,7 @@ from graphrag_app.utils.common import subscription_key_check
|
||||
|
||||
|
||||
async def catch_all_exceptions_middleware(request: Request, call_next):
|
||||
"""a function to globally catch all exceptions and return a 500 response with the exception message"""
|
||||
"""A global function to catch all exceptions and produce a standard error message"""
|
||||
try:
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
@ -44,7 +44,10 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
|
||||
cause=e,
|
||||
stack=stack,
|
||||
)
|
||||
return Response("Unexpected internal server error.", status_code=500)
|
||||
return Response(
|
||||
"Unexpected internal server error.",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this function is not currently used, but it is a placeholder for future use once RBAC issues have been resolved
|
||||
|
@ -1,27 +1,35 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Common utility functions used by the API endpoints.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import os
|
||||
import traceback
|
||||
from typing import Annotated
|
||||
from io import StringIO
|
||||
from typing import Annotated, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.cosmos import ContainerProxy, exceptions
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.storage.blob.aio import ContainerClient
|
||||
from fastapi import Header, HTTPException
|
||||
from fastapi import Header, HTTPException, status
|
||||
|
||||
from graphrag_app.logger.load_logger import load_pipeline_logger
|
||||
from graphrag_app.utils.azure_clients import AzureClientManager
|
||||
|
||||
FILE_UPLOAD_CACHE = "cache/uploaded_files.csv"
|
||||
|
||||
|
||||
def get_df(
|
||||
table_path: str,
|
||||
filepath: str,
|
||||
) -> pd.DataFrame:
|
||||
"""Read a parquet file from Azure Storage and return it as a pandas DataFrame."""
|
||||
df = pd.read_parquet(
|
||||
table_path,
|
||||
filepath,
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
return df
|
||||
@ -123,7 +131,10 @@ def get_cosmos_container_store_client() -> ContainerProxy:
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Error fetching cosmosdb client.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error fetching cosmosdb client.",
|
||||
)
|
||||
|
||||
|
||||
async def get_blob_container_client(name: str) -> ContainerClient:
|
||||
@ -141,7 +152,10 @@ async def get_blob_container_client(name: str) -> ContainerClient:
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Error fetching storage client.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error fetching storage client.",
|
||||
)
|
||||
|
||||
|
||||
def sanitize_name(container_name: str) -> str:
|
||||
@ -188,7 +202,8 @@ def desanitize_name(sanitized_container_name: str) -> str | None:
|
||||
return None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error retrieving original container name."
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving original container name.",
|
||||
)
|
||||
|
||||
|
||||
@ -196,13 +211,102 @@ async def subscription_key_check(
|
||||
Ocp_Apim_Subscription_Key: Annotated[str, Header()],
|
||||
):
|
||||
"""
|
||||
Verifies if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
|
||||
If it is not present, an HTTPException with a 400 status code is raised.
|
||||
Note: this check is unnecessary (APIM validates subscription keys automatically), but this will add the key
|
||||
Verify if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
|
||||
Note: this check is unnecessary (APIM validates subscription keys automatically), but it effectively adds the key
|
||||
as a required parameter in the swagger docs page, enabling users to send requests using the swagger docs "Try it out" feature.
|
||||
"""
|
||||
if not Ocp_Apim_Subscription_Key:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ocp-Apim-Subscription-Key required"
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Ocp-Apim-Subscription-Key required",
|
||||
)
|
||||
return Ocp_Apim_Subscription_Key
|
||||
|
||||
|
||||
async def create_cache(container_client: ContainerClient) -> None:
|
||||
"""
|
||||
Create a file cache (csv).
|
||||
"""
|
||||
try:
|
||||
cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
|
||||
if not await cache_blob_client.exists():
|
||||
# create the empty file cache csv
|
||||
headers = [["Filename", "Hash"]]
|
||||
tmp_cache_file = "uploaded_files_cache.csv"
|
||||
with open(tmp_cache_file, "w", newline="") as f:
|
||||
writer = csv.writer(f, delimiter=",")
|
||||
writer.writerows(headers)
|
||||
|
||||
# upload to Azure Blob Storage and remove the temporary file
|
||||
with open(tmp_cache_file, "rb") as f:
|
||||
await cache_blob_client.upload_blob(f, overwrite=True)
|
||||
if os.path.exists(tmp_cache_file):
|
||||
os.remove(tmp_cache_file)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error creating file cache in Azure Blob Storage.",
|
||||
)
|
||||
|
||||
|
||||
async def check_cache(file_hash: str, container_client: ContainerClient) -> bool:
|
||||
"""
|
||||
Check a file cache (csv) to determine if a file hash has previously been uploaded.
|
||||
|
||||
Note: This function creates/checks a CSV file in azure storage to act as a cache of previously uploaded files.
|
||||
"""
|
||||
try:
|
||||
# load the file cache
|
||||
cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
|
||||
cache_download_stream = await cache_blob_client.download_blob()
|
||||
cache_bytes = await cache_download_stream.readall()
|
||||
cache_content = StringIO(cache_bytes.decode("utf-8"))
|
||||
|
||||
# comupte the sha256 hash of the file and check if it exists in the cache
|
||||
cache_reader = csv.reader(cache_content, delimiter=",")
|
||||
for row in cache_reader:
|
||||
if file_hash in row:
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error checking file cache in Azure Blob Storage.",
|
||||
)
|
||||
|
||||
|
||||
async def update_cache(
|
||||
new_files: Tuple[str, str], container_client: ContainerClient
|
||||
) -> None:
|
||||
"""
|
||||
Update an existing file cache (csv) with new files.
|
||||
"""
|
||||
try:
|
||||
# Load the existing cache
|
||||
cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
|
||||
cache_download_stream = await cache_blob_client.download_blob()
|
||||
cache_bytes = await cache_download_stream.readall()
|
||||
cache_content = StringIO(cache_bytes.decode("utf-8"))
|
||||
cache_reader = csv.reader(cache_content, delimiter=",")
|
||||
|
||||
# append new data
|
||||
existing_rows = list(cache_reader)
|
||||
for filename, file_hash in new_files:
|
||||
row = [filename, file_hash]
|
||||
existing_rows.append(row)
|
||||
|
||||
# Write the updated content back to the StringIO object
|
||||
updated_cache_content = StringIO()
|
||||
cache_writer = csv.writer(updated_cache_content, delimiter=",")
|
||||
cache_writer.writerows(existing_rows)
|
||||
|
||||
# Upload the updated cache to Azure Blob Storage
|
||||
updated_cache_content.seek(0)
|
||||
await cache_blob_client.upload_blob(
|
||||
updated_cache_content.getvalue().encode("utf-8"), overwrite=True
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating file cache in Azure Blob Storage.",
|
||||
)
|
||||
|
1756
backend/poetry.lock
generated
1756
backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -53,6 +53,7 @@ fsspec = ">=2024.2.0"
|
||||
graphrag = "==1.2.0"
|
||||
httpx = ">=0.25.2"
|
||||
kubernetes = ">=29.0.0"
|
||||
markitdown = {extras = ["all"], version = "^0.1.1"}
|
||||
networkx = ">=3.2.1"
|
||||
nltk = "*"
|
||||
pandas = ">=2.2.1"
|
||||
|
@ -22,7 +22,7 @@ def test_schedule_index_without_data(client, cosmos_client: CosmosClient):
|
||||
"storage_container_name": "nonexistent-data-container",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.status_code == 412
|
||||
|
||||
|
||||
# def test_schedule_index_with_data(client, cosmos_client, blob_with_data_container_name):
|
||||
|
@ -21,7 +21,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install devtools python-magic requests tqdm"
|
||||
"! pip install devtools requests tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -35,7 +35,6 @@
|
||||
"import time\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import magic\n",
|
||||
"import requests\n",
|
||||
"from devtools import pprint\n",
|
||||
"from tqdm import tqdm"
|
||||
@ -127,7 +126,7 @@
|
||||
"source": [
|
||||
"## Upload Files\n",
|
||||
"\n",
|
||||
"For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag."
|
||||
"For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.**"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -177,20 +176,14 @@
|
||||
" return response\n",
|
||||
"\n",
|
||||
" batch_files = []\n",
|
||||
" accepted_file_types = [\"text/plain\"]\n",
|
||||
" filepaths = list(Path(file_directory).iterdir())\n",
|
||||
" for file in tqdm(filepaths):\n",
|
||||
" # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
|
||||
" if (\n",
|
||||
" not file.is_file()\n",
|
||||
" or file.suffix != \".txt\"\n",
|
||||
" or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
|
||||
" ):\n",
|
||||
" if (not file.is_file()):\n",
|
||||
" print(f\"Skipping invalid file: {file}\")\n",
|
||||
" continue\n",
|
||||
" # open and decode file as utf-8, ignore bad characters\n",
|
||||
" batch_files.append(\n",
|
||||
" (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
|
||||
" (\"files\", open(file=file, mode=\"rb\"))\n",
|
||||
" )\n",
|
||||
" # upload batch of files\n",
|
||||
" if len(batch_files) == batch_size:\n",
|
||||
@ -199,7 +192,7 @@
|
||||
" if not response.ok:\n",
|
||||
" return response\n",
|
||||
" batch_files.clear()\n",
|
||||
" # upload remaining files\n",
|
||||
" # upload last batch of remaining files\n",
|
||||
" if len(batch_files) > 0:\n",
|
||||
" response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
|
||||
" return response\n",
|
||||
@ -330,9 +323,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def global_search(\n",
|
||||
" index_name: str | list[str], query: str, community_level: int\n",
|
||||
") -> requests.Response:\n",
|
||||
@ -372,9 +362,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def local_search(\n",
|
||||
" index_name: str | list[str], query: str, community_level: int\n",
|
||||
") -> requests.Response:\n",
|
||||
@ -402,7 +389,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag-venv",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -27,17 +27,14 @@
|
||||
"| DELETE | /index/{index_name}\n",
|
||||
"| GET | /index/status/{index_name}\n",
|
||||
"| POST | /query/global\n",
|
||||
"| POST | /query/streaming/global\n",
|
||||
"| POST | /query/local\n",
|
||||
"| POST | /query/streaming/local\n",
|
||||
"| GET | /index/config/prompts\n",
|
||||
"| GET | /source/report/{index_name}/{report_id}\n",
|
||||
"| GET | /source/text/{index_name}/{text_unit_id}\n",
|
||||
"| GET | /source/entity/{index_name}/{entity_id}\n",
|
||||
"| GET | /source/claim/{index_name}/{claim_id}\n",
|
||||
"| GET | /source/relationship/{index_name}/{relationship_id}\n",
|
||||
"| GET | /graph/graphml/{index_name}\n",
|
||||
"| GET | /graph/stats/{index_name}"
|
||||
"| GET | /graph/graphml/{index_name}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -56,7 +53,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install devtools pandas python-magic requests tqdm"
|
||||
"! pip install devtools pandas requests tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -72,7 +69,6 @@
|
||||
"import time\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import magic\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"from devtools import pprint\n",
|
||||
@ -229,20 +225,14 @@
|
||||
" return response\n",
|
||||
"\n",
|
||||
" batch_files = []\n",
|
||||
" accepted_file_types = [\"text/plain\"]\n",
|
||||
" filepaths = list(Path(file_directory).iterdir())\n",
|
||||
" for file in tqdm(filepaths):\n",
|
||||
" # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
|
||||
" if (\n",
|
||||
" not file.is_file()\n",
|
||||
" or file.suffix != \".txt\"\n",
|
||||
" or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
|
||||
" ):\n",
|
||||
" if (not file.is_file()):\n",
|
||||
" print(f\"Skipping invalid file: {file}\")\n",
|
||||
" continue\n",
|
||||
" # open and decode file as utf-8, ignore bad characters\n",
|
||||
" batch_files.append(\n",
|
||||
" (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
|
||||
" (\"files\", open(file=file, mode=\"rb\"))\n",
|
||||
" )\n",
|
||||
" # upload batch of files\n",
|
||||
" if len(batch_files) == batch_size:\n",
|
||||
@ -251,7 +241,7 @@
|
||||
" if not response.ok:\n",
|
||||
" return response\n",
|
||||
" batch_files.clear()\n",
|
||||
" # upload remaining files\n",
|
||||
" # upload last batch of remaining files\n",
|
||||
" if len(batch_files) > 0:\n",
|
||||
" response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
|
||||
" return response\n",
|
||||
@ -290,7 +280,10 @@
|
||||
" return requests.post(\n",
|
||||
" url,\n",
|
||||
" files=prompts if len(prompts) > 0 else None,\n",
|
||||
" params={\"index_container_name\": index_name, \"storage_container_name\": storage_name},\n",
|
||||
" params={\n",
|
||||
" \"index_container_name\": index_name,\n",
|
||||
" \"storage_container_name\": storage_name,\n",
|
||||
" },\n",
|
||||
" headers=headers,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@ -486,7 +479,7 @@
|
||||
"source": [
|
||||
"## Upload files\n",
|
||||
"\n",
|
||||
"Use the API to upload a collection of local files. The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
|
||||
"Use the API to upload a collection of files. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.** The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -619,7 +612,9 @@
|
||||
" community_summarization_prompt = prompts[\"community_summarization_prompt\"]\n",
|
||||
" summarize_description_prompt = prompts[\"entity_summarization_prompt\"]\n",
|
||||
"else:\n",
|
||||
" entity_extraction_prompt = community_summarization_prompt = summarize_description_prompt = None\n",
|
||||
" entity_extraction_prompt = community_summarization_prompt = (\n",
|
||||
" summarize_description_prompt\n",
|
||||
" ) = None\n",
|
||||
"\n",
|
||||
"response = build_index(\n",
|
||||
" storage_name=storage_name,\n",
|
||||
@ -745,8 +740,8 @@
|
||||
"# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
|
||||
"global_response = global_search(\n",
|
||||
" index_name=index_name,\n",
|
||||
" query=\"Summarize the main topics found in this data\",\n",
|
||||
" community_level=1,\n",
|
||||
" query=\"Summarize the qualifications to being a delivery data scientist\",\n",
|
||||
" community_level=2,\n",
|
||||
")\n",
|
||||
"# print the result and save context data in a variable\n",
|
||||
"global_response_data = parse_query_response(global_response, return_context_data=True)\n",
|
||||
@ -956,7 +951,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag-venv",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user