# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import asyncio
import re
from math import ceil
from typing import List
from azure.storage.blob.aio import ContainerClient
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    UploadFile,
)
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import (
    BaseResponse,
    StorageNameList,
)
from graphrag_app.utils.common import (
    delete_blob_container,
    delete_cosmos_container_item,
    desanitize_name,
    get_blob_container_client,
    get_cosmos_container_store_client,
    sanitize_name,
)

data_route = APIRouter(
    prefix="/data",
    tags=["Data Management"],
)
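
# A sketch of how this router might be mounted by the service entrypoint
# (illustrative only; the actual app setup lives elsewhere in graphrag_app):
#
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.include_router(data_route)  # exposes GET/POST /data and DELETE /data/{container_name}
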

@data_route.get(
    "",
    summary="Get list of data containers.",
    response_model=StorageNameList,
    responses={200: {"model": StorageNameList}},
)
async def get_all_data_containers():
    """
    Retrieve a list of all data containers.
    """
    items = []
    try:
        container_store_client = get_cosmos_container_store_client()
        for item in container_store_client.read_all_items():
            if item["type"] == "data":
                items.append(item["human_readable_name"])
    except Exception:
        logger = load_pipeline_logger()
        logger.error("Error getting list of blob containers.")
        raise HTTPException(
            status_code=500, detail="Error getting list of blob containers."
        )
    return StorageNameList(storage_name=items)
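
# Example response for GET /data (shape follows the StorageNameList model,
# which wraps a list of human-readable container names):
#
#   {"storage_name": ["dataset-a", "dataset-b"]}
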

async def upload_file_async(
    upload_file: UploadFile, container_client: ContainerClient, overwrite: bool = True
) -> None:
    """
    Asynchronously upload a file to the specified blob container.
    Silently ignore errors that occur when overwrite=False.
    """
    blob_client = container_client.get_blob_client(upload_file.filename)
    with upload_file.file as file_stream:
        try:
            await blob_client.upload_blob(file_stream, overwrite=overwrite)
        except Exception:
            # with overwrite=False, uploading a blob that already exists raises
            # an error - skip the file instead of failing the whole batch
            pass


class Cleaner:
    """
    Wrap a file object and strip illegal XML characters as it is read.
    """

    def __init__(self, file):
        self.file = file
        self.name = file.name  # expose the wrapped file's name as an attribute
        self.changes = 0

    def clean(self, val, replacement=""):
        # fmt: off
        _illegal_xml_chars_RE = re.compile(
            "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
        )
        # fmt: on
        self.changes += len(_illegal_xml_chars_RE.findall(val))
        return _illegal_xml_chars_RE.sub(replacement, val)

    def read(self, n):
        return self.clean(self.file.read(n).decode()).encode(
            encoding="utf-8", errors="strict"
        )

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.file.close()
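
# A minimal usage sketch for Cleaner (illustrative only; the names below are
# hypothetical, not part of this module):
#
#   import tempfile
#   tmp = tempfile.NamedTemporaryFile()
#   tmp.write(b"hello\x00world")
#   tmp.seek(0)
#   with Cleaner(tmp) as cleaned:
#       data = cleaned.read(-1)  # b"helloworld"
#   # cleaned.changes == 1 (one NUL byte removed)
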

@data_route.post(
    "",
    summary="Upload data to a data storage container",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
)
async def upload_files(
    files: List[UploadFile],
    sanitized_container_name: str = Depends(sanitize_name),
    overwrite: bool = True,
):
    """
    Create a data storage container in Azure and upload files to it.

    Args:
        files (List[UploadFile]): A list of files to be uploaded.
        sanitized_container_name (str): The sanitized name of the Azure Blob Storage
            container to which files will be uploaded, injected by the sanitize_name
            dependency.
        overwrite (bool): Whether to overwrite existing files with the same name.
            Defaults to True. If False, files that already exist will be skipped.

    Returns:
        BaseResponse: A status message indicating the result of the upload.

    Raises:
        HTTPException: If the container name is invalid or if any error occurs
            during the upload process.
    """
    original_container_name = desanitize_name(sanitized_container_name)
    try:
        # clean files - remove illegal XML characters
        files = [UploadFile(Cleaner(f.file), filename=f.filename) for f in files]

        # upload files in batches of 1000 to avoid exceeding Azure Storage API limits
        blob_container_client = get_blob_container_client(sanitized_container_name)
        batch_size = 1000
        batches = ceil(len(files) / batch_size)
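        # e.g. 2,500 files -> ceil(2500 / 1000) = 3 batches
        # of 1000, 1000, and 500 files, uploaded concurrently per batch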
        for i in range(batches):
            batch_files = files[i * batch_size : (i + 1) * batch_size]
            tasks = [
                upload_file_async(file, blob_container_client, overwrite)
                for file in batch_files
            ]
            await asyncio.gather(*tasks)

        # update container-store entry in cosmosDB once upload process is successful
        cosmos_container_store_client = get_cosmos_container_store_client()
        cosmos_container_store_client.upsert_item({
            "id": sanitized_container_name,
            "human_readable_name": original_container_name,
            "type": "data",
        })
        return BaseResponse(status="File upload successful.")
    except Exception:
        logger = load_pipeline_logger()
        logger.error("Error uploading files.", details={"files": files})
        raise HTTPException(
            status_code=500,
            detail=f"Error uploading files to container '{original_container_name}'.",
        )
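
# A sketch of exercising the upload endpoint with FastAPI's TestClient
# (hypothetical names; assumes sanitize_name reads a container_name parameter):
#
#   from fastapi.testclient import TestClient
#   client = TestClient(app)  # an app with data_route mounted
#   with open("doc.txt", "rb") as f:
#       resp = client.post("/data?container_name=my-data", files=[("files", f)])
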

@data_route.delete(
    "/{container_name}",
    summary="Delete a data storage container",
    response_model=BaseResponse,
    responses={200: {"model": BaseResponse}},
)
async def delete_files(container_name: str):
    """
    Delete a specified data storage container.
    """
    sanitized_container_name = sanitize_name(container_name)
    try:
        # delete container in Azure Storage
        delete_blob_container(sanitized_container_name)
        # delete entry from container-store in cosmosDB
        delete_cosmos_container_item("container-store", sanitized_container_name)
    except Exception:
        logger = load_pipeline_logger()
        logger.error(
            f"Error deleting container {container_name}.",
            details={"Container": container_name},
        )
        raise HTTPException(
            status_code=500,
            detail=f"Error deleting container '{container_name}'.",
        )
    return BaseResponse(status="Success")
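
# A corresponding sketch for the delete endpoint (continuing the hypothetical
# TestClient example above):
#
#   resp = client.delete("/data/my-data")
#   assert resp.json()["status"] == "Success"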