mirror of https://github.com/Azure-Samples/graphrag-accelerator.git (synced 2025-06-27 04:39:57 +00:00)

commit 004fc65cdb (parent: 5d2ab180c7)
Multi-filetype Support via Markitdown (#269)

Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 import asyncio
+import hashlib
 import os
 import re
 import traceback
@@ -14,7 +15,9 @@ from fastapi import (
     Depends,
     HTTPException,
     UploadFile,
+    status,
 )
+from markitdown import MarkItDown, StreamInfo
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.typing.models import (
@@ -22,12 +25,15 @@ from graphrag_app.typing.models import (
     StorageNameList,
 )
 from graphrag_app.utils.common import (
+    check_cache,
+    create_cache,
     delete_cosmos_container_item_if_exist,
     delete_storage_container_if_exist,
     get_blob_container_client,
     get_cosmos_container_store_client,
     sanitize_name,
     subscription_key_check,
+    update_cache,
 )
 
 data_route = APIRouter(
@@ -42,7 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "",
     summary="Get list of data containers.",
     response_model=StorageNameList,
-    responses={200: {"model": StorageNameList}},
+    responses={status.HTTP_200_OK: {"model": StorageNameList}},
 )
 async def get_all_data_containers():
     """
@@ -67,56 +73,66 @@ async def get_all_data_containers():
     return StorageNameList(storage_name=items)
 
 
-async def upload_file_async(
+async def upload_file(
     upload_file: UploadFile, container_client: ContainerClient, overwrite: bool = True
-) -> None:
+):
     """
-    Asynchronously upload a file to the specified blob container.
-    Silently ignore errors that occur when overwrite=False.
+    Convert and upload a file to a specified blob container.
+    Returns a list of objects where each object will have one of the following types:
+    * Tuple[str, str] - a tuple of (filename, file_hash) for successful uploads
+    * Tuple[str, None] - a tuple of (filename, None) for failed uploads or
+    * None for skipped files
     """
-    blob_client = container_client.get_blob_client(upload_file.filename)
+    filename = upload_file.filename
+    extension = os.path.splitext(filename)[1]
+    converted_filename = filename + ".txt"
+    converted_blob_client = container_client.get_blob_client(converted_filename)
 
     with upload_file.file as file_stream:
         try:
-            await blob_client.upload_blob(file_stream, overwrite=overwrite)
+            file_hash = hashlib.sha256(file_stream.read()).hexdigest()
+            if not await check_cache(file_hash, container_client):
+                # extract text from file using MarkItDown
+                md = MarkItDown()
+                stream_info = StreamInfo(
+                    extension=extension,
+                )
+                file_stream._file.seek(0)
+                file_stream = file_stream._file
+                result = md.convert_stream(
+                    stream=file_stream,
+                    stream_info=stream_info,
+                )
+
+                # remove illegal unicode characters and upload to blob storage
+                cleaned_result = _clean_output(result.text_content)
+                await converted_blob_client.upload_blob(
+                    cleaned_result, overwrite=overwrite
+                )
+
+                # return tuple of (filename, file_hash) to indicate success
+                return (filename, file_hash)
         except Exception:
-            pass
+            # if any exception occurs, return a tuple of (filename, None) to indicate conversion/upload failure
+            return (upload_file.filename, None)
 
 
-class Cleaner:
-    def __init__(self, file):
-        self.file = file
-        self.name = file.name
-        self.changes = 0
-
-    def clean(self, val, replacement=""):
-        # fmt: off
-        _illegal_xml_chars_RE = re.compile(
-            "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
-        )
-        # fmt: on
-        self.changes += len(_illegal_xml_chars_RE.findall(val))
-        return _illegal_xml_chars_RE.sub(replacement, val)
-
-    def read(self, n):
-        return self.clean(self.file.read(n).decode()).encode(
-            encoding="utf-8", errors="strict"
-        )
-
-    def name(self):
-        return self.file.name
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args):
-        self.file.close()
+def _clean_output(val: str, replacement: str = ""):
+    """Removes unicode characters that are invalid XML characters (not valid for graphml files at least)."""
+    # fmt: off
+    _illegal_xml_chars_RE = re.compile(
+        "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
+    )
+    # fmt: on
+    return _illegal_xml_chars_RE.sub(replacement, val)
 
 
 @data_route.post(
     "",
     summary="Upload data to a data storage container",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_201_CREATED: {"model": BaseResponse}},
 )
 async def upload_files(
     files: List[UploadFile],
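For illustration, the new conversion path boils down to the following minimal sketch (assuming a local file on disk rather than a FastAPI UploadFile; convert_file_to_text is a hypothetical helper, not part of the commit):

import hashlib
import os
import re

from markitdown import MarkItDown, StreamInfo


def convert_file_to_text(path: str) -> tuple[str, str]:
    """Return (sha256 hex digest, cleaned markdown text) for one local file."""
    with open(path, "rb") as f:
        file_hash = hashlib.sha256(f.read()).hexdigest()
    md = MarkItDown()
    with open(path, "rb") as stream:
        # MarkItDown picks a converter based on the declared extension
        result = md.convert_stream(
            stream, stream_info=StreamInfo(extension=os.path.splitext(path)[1])
        )
    # same illegal-XML-character filter as _clean_output above
    cleaned = re.sub(
        "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]",
        "",
        result.text_content,
    )
    return file_hash, cleaned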
@@ -125,36 +141,33 @@ async def upload_files(
     overwrite: bool = True,
 ):
     """
-    Create a Azure Storage container and upload files to it.
-
-    Args:
-        files (List[UploadFile]): A list of files to be uploaded.
-        storage_name (str): The name of the Azure Blob Storage container to which files will be uploaded.
-        overwrite (bool): Whether to overwrite existing files with the same name. Defaults to True. If False, files that already exist will be skipped.
-
-    Returns:
-        BaseResponse: An instance of the BaseResponse model with a status message indicating the result of the upload.
-
-    Raises:
-        HTTPException: If the container name is invalid or if any error occurs during the upload process.
+    Create a Azure Storage container (if needed) and upload files. Multiple file types are supported, including pdf, powerpoint, word, excel, html, csv, json, xml, etc.
+    The complete set of supported file types can be found in the MarkItDown (https://github.com/microsoft/markitdown) library.
     """
     try:
-        # clean files - remove illegal XML characters
-        files = [UploadFile(Cleaner(f.file), filename=f.filename) for f in files]
-
-        # upload files in batches of 1000 to avoid exceeding Azure Storage API limits
+        # create the initial cache if it doesn't exist
         blob_container_client = await get_blob_container_client(
             sanitized_container_name
         )
-        batch_size = 1000
+        await create_cache(blob_container_client)
+
+        # process file uploads in batches to avoid exceeding Azure Storage API limits
+        processing_errors = []
+        batch_size = 100
         num_batches = ceil(len(files) / batch_size)
         for i in range(num_batches):
             batch_files = files[i * batch_size : (i + 1) * batch_size]
             tasks = [
-                upload_file_async(file, blob_container_client, overwrite)
+                upload_file(file, blob_container_client, overwrite)
                 for file in batch_files
             ]
-            await asyncio.gather(*tasks)
+            upload_results = await asyncio.gather(*tasks)
+            successful_uploads = [r for r in upload_results if r and r[1] is not None]
+            # update the file cache with successful uploads
+            await update_cache(successful_uploads, blob_container_client)
+            # collect failed uploads
+            failed_uploads = [r[0] for r in upload_results if r and r[1] is None]
+            processing_errors.extend(failed_uploads)
 
         # update container-store entry in cosmosDB once upload process is successful
         cosmos_container_store_client = get_cosmos_container_store_client()
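The batching logic above reduces to this pattern (a sketch with simplified names; upload_one stands in for upload_file):

import asyncio
from math import ceil


async def upload_in_batches(files, upload_one, batch_size: int = 100):
    """Gather uploads batch by batch and partition the results."""
    successes, failures = [], []
    for i in range(ceil(len(files) / batch_size)):
        batch = files[i * batch_size : (i + 1) * batch_size]
        results = await asyncio.gather(*(upload_one(f) for f in batch))
        # each result is (filename, hash) on success, (filename, None) on failure, or None if skipped
        successes += [r for r in results if r and r[1] is not None]
        failures += [r[0] for r in results if r and r[1] is None]
    return successes, failures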
@@ -163,17 +176,23 @@ async def upload_files(
             "human_readable_name": container_name,
             "type": "data",
         })
-        return BaseResponse(status="File upload successful.")
+
+        if len(processing_errors) > 0:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Error uploading files: {processing_errors}.",
+            )
+        return BaseResponse(status="Success.")
     except Exception as e:
         logger = load_pipeline_logger()
         logger.error(
             message="Error uploading files.",
             cause=e,
             stack=traceback.format_exc(),
-            details={"files": [f.filename for f in files]},
+            details={"files": processing_errors},
         )
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Error uploading files to container '{container_name}'.",
         )
 
@@ -182,7 +201,7 @@ async def upload_files(
     "/{container_name}",
     summary="Delete a data storage container",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_200_OK: {"model": BaseResponse}},
 )
 async def delete_files(
     container_name: str, sanitized_container_name: str = Depends(sanitize_name)
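The status constants introduced across these routes come from fastapi.status (re-exported from starlette) and are ordinary integers, so the responses mappings are unchanged at runtime:

from fastapi import status

assert status.HTTP_200_OK == 200
assert status.HTTP_201_CREATED == 201
assert status.HTTP_422_UNPROCESSABLE_ENTITY == 422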
@@ -8,6 +8,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from fastapi.responses import StreamingResponse
 
@@ -31,6 +32,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/graphml/{container_name}",
     summary="Retrieve a GraphML file of the knowledge graph",
     response_description="GraphML file successfully downloaded",
+    status_code=status.HTTP_200_OK,
 )
 async def get_graphml_file(
     container_name, sanitized_container_name: str = Depends(sanitize_name)
@@ -12,6 +12,7 @@ from fastapi import (
     Depends,
     HTTPException,
     UploadFile,
+    status,
 )
 from kubernetes import (
     client as kubernetes_client,
@@ -49,7 +50,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "",
     summary="Build an index",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_202_ACCEPTED: {"model": BaseResponse}},
 )
 async def schedule_index_job(
     storage_container_name: str,
@@ -71,7 +72,7 @@ async def schedule_index_job(
         sanitized_storage_container_name
     ).exists():
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_412_PRECONDITION_FAILED,
             detail=f"Storage container '{storage_container_name}' does not exist",
         )
 
@@ -101,7 +102,7 @@ async def schedule_index_job(
         PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
     ):
         raise HTTPException(
-            status_code=202,  # request has been accepted for processing but is not complete.
+            status_code=status.HTTP_425_TOO_EARLY,  # request has been accepted for processing but is not complete.
             detail=f"Index '{index_container_name}' already exists and has not finished building.",
         )
     # if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
@@ -142,7 +143,7 @@ async def schedule_index_job(
     "",
     summary="Get all index names",
     response_model=IndexNameList,
-    responses={200: {"model": IndexNameList}},
+    responses={status.HTTP_200_OK: {"model": IndexNameList}},
 )
 async def get_all_index_names(
     container_store_client=Depends(get_cosmos_container_store_client),
@@ -218,7 +219,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
     "/{container_name}",
     summary="Delete a specified index",
     response_model=BaseResponse,
-    responses={200: {"model": BaseResponse}},
+    responses={status.HTTP_200_OK: {"model": BaseResponse}},
 )
 async def delete_index(
     container_name: str,
@@ -257,7 +258,8 @@ async def delete_index(
             details={"container": container_name},
         )
         raise HTTPException(
-            status_code=500, detail=f"Error deleting '{container_name}'."
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error deleting '{container_name}'.",
         )
 
     return BaseResponse(status="Success")
@@ -267,6 +269,7 @@ async def delete_index(
     "/status/{container_name}",
     summary="Track the status of an indexing job",
     response_model=IndexStatusResponse,
+    status_code=status.HTTP_200_OK,
 )
 async def get_index_status(
     container_name: str, sanitized_container_name: str = Depends(sanitize_name)
@@ -275,7 +278,7 @@ async def get_index_status(
     if pipelinejob.item_exist(sanitized_container_name):
         pipeline_job = pipelinejob.load_item(sanitized_container_name)
         return IndexStatusResponse(
-            status_code=200,
+            status_code=status.HTTP_200_OK,
             index_name=pipeline_job.human_readable_index_name,
             storage_name=pipeline_job.human_readable_storage_name,
             status=pipeline_job.status.value,
@@ -284,5 +287,6 @@ async def get_index_status(
         )
     else:
         raise HTTPException(
-            status_code=404, detail=f"'{container_name}' does not exist."
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"'{container_name}' does not exist.",
         )
@@ -11,6 +11,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from graphrag.config.create_graphrag_config import create_graphrag_config
 
@@ -27,6 +28,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/prompts",
     summary="Generate custom graphrag prompts based on user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
+    status_code=status.HTTP_200_OK,
 )
 async def generate_prompts(
     container_name: str,
@@ -10,6 +10,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from graphrag.api.query import global_search, local_search
 from graphrag.config.create_graphrag_config import create_graphrag_config
@@ -42,7 +43,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     summary="Perform a global search across the knowledge graph index",
     description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
     response_model=GraphResponse,
-    responses={200: {"model": GraphResponse}},
+    responses={status.HTTP_200_OK: {"model": GraphResponse}},
 )
 async def global_query(request: GraphRequest):
     # this is a slightly modified version of the graphrag.query.cli.run_global_search method
@@ -51,7 +52,7 @@ async def global_query(request: GraphRequest):
 
     if not _is_index_complete(sanitized_index_name):
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_425_TOO_EARLY,
             detail=f"{index_name} not ready for querying.",
         )
 
@@ -122,7 +123,7 @@ async def global_query(request: GraphRequest):
     summary="Perform a local search across the knowledge graph index.",
     description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
     response_model=GraphResponse,
-    responses={200: {"model": GraphResponse}},
+    responses={status.HTTP_200_OK: {"model": GraphResponse}},
 )
 async def local_query(request: GraphRequest):
     index_name = request.index_name
@@ -130,7 +131,7 @@ async def local_query(request: GraphRequest):
 
     if not _is_index_complete(sanitized_index_name):
         raise HTTPException(
-            status_code=500,
+            status_code=status.HTTP_425_TOO_EARLY,
             detail=f"{index_name} not ready for querying.",
         )
 
@@ -12,6 +12,7 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
+    status,
 )
 from fastapi.responses import StreamingResponse
 from graphrag.api.query import (
@@ -47,6 +48,7 @@ if os.getenv("KUBERNETES_SERVICE_HOST"):
     "/global",
     summary="Stream a response back after performing a global search",
     description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
+    status_code=status.HTTP_200_OK,
 )
 async def global_search_streaming(request: GraphRequest):
     # this is a slightly modified version of graphrag_app.api.query.global_query() method
@@ -204,6 +206,7 @@ async def global_search_streaming(request: GraphRequest):
     "/local",
     summary="Stream a response back after performing a local search",
     description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
+    status_code=status.HTTP_200_OK,
 )
 async def local_search_streaming(request: GraphRequest):
     # this is a slightly modified version of graphrag_app.api.query.local_query() method
@@ -5,7 +5,12 @@ import os
 import traceback
 
 import pandas as pd
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    status,
+)
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.typing.models import (
@@ -43,7 +48,7 @@ DOCUMENTS_TABLE = "output/create_final_documents.parquet"
     "/report/{container_name}/{report_id}",
     summary="Return a single community report.",
     response_model=ReportResponse,
-    responses={200: {"model": ReportResponse}},
+    responses={status.HTTP_200_OK: {"model": ReportResponse}},
 )
 async def get_report_info(
     report_id: int,
@@ -88,7 +93,7 @@ async def get_report_info(
     "/text/{container_name}/{text_unit_id}",
     summary="Return a single base text unit.",
     response_model=TextUnitResponse,
-    responses={200: {"model": TextUnitResponse}},
+    responses={status.HTTP_200_OK: {"model": TextUnitResponse}},
 )
 async def get_chunk_info(
     text_unit_id: str,
@@ -148,7 +153,7 @@ async def get_chunk_info(
     "/entity/{container_name}/{entity_id}",
     summary="Return a single entity.",
     response_model=EntityResponse,
-    responses={200: {"model": EntityResponse}},
+    responses={status.HTTP_200_OK: {"model": EntityResponse}},
 )
 async def get_entity_info(
     entity_id: int,
@@ -190,7 +195,7 @@ async def get_entity_info(
     "/claim/{container_name}/{claim_id}",
     summary="Return a single claim.",
     response_model=ClaimResponse,
-    responses={200: {"model": ClaimResponse}},
+    responses={status.HTTP_200_OK: {"model": ClaimResponse}},
 )
 async def get_claim_info(
     claim_id: int,
@@ -240,7 +245,7 @@ async def get_claim_info(
     "/relationship/{container_name}/{relationship_id}",
     summary="Return a single relationship.",
     response_model=RelationshipResponse,
-    responses={200: {"model": RelationshipResponse}},
+    responses={status.HTTP_200_OK: {"model": RelationshipResponse}},
 )
 async def get_relationship_info(
     relationship_id: int,
@@ -33,7 +33,7 @@ from graphrag_app.utils.common import subscription_key_check
 
 
 async def catch_all_exceptions_middleware(request: Request, call_next):
-    """a function to globally catch all exceptions and return a 500 response with the exception message"""
+    """A global function to catch all exceptions and produce a standard error message"""
     try:
         return await call_next(request)
     except Exception as e:
@@ -44,7 +44,10 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
             cause=e,
             stack=stack,
         )
-        return Response("Unexpected internal server error.", status_code=500)
+        return Response(
+            "Unexpected internal server error.",
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        )
 
 
 # NOTE: this function is not currently used, but it is a placeholder for future use once RBAC issues have been resolved
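For context, a catch-all middleware of this shape is registered on the FastAPI app object; a minimal sketch (not necessarily how this repo wires it up):

from fastapi import FastAPI, Request, Response, status

app = FastAPI()


@app.middleware("http")
async def catch_all(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception:
        # logging elided; see catch_all_exceptions_middleware above
        return Response(
            "Unexpected internal server error.",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )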
@@ -1,27 +1,35 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+"""
+Common utility functions used by the API endpoints.
+"""
 
+import csv
 import hashlib
 import os
 import traceback
-from typing import Annotated
+from io import StringIO
+from typing import Annotated, Tuple
 
 import pandas as pd
 from azure.core.exceptions import ResourceNotFoundError
 from azure.cosmos import ContainerProxy, exceptions
 from azure.identity import DefaultAzureCredential
 from azure.storage.blob.aio import ContainerClient
-from fastapi import Header, HTTPException
+from fastapi import Header, HTTPException, status
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.utils.azure_clients import AzureClientManager
 
+FILE_UPLOAD_CACHE = "cache/uploaded_files.csv"
+
 
 def get_df(
-    table_path: str,
+    filepath: str,
 ) -> pd.DataFrame:
+    """Read a parquet file from Azure Storage and return it as a pandas DataFrame."""
     df = pd.read_parquet(
-        table_path,
+        filepath,
         storage_options=pandas_storage_options(),
     )
     return df
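get_df resolves remote paths through fsspec; an equivalent direct call looks roughly like this sketch (the account and container names are placeholders, and the exact keys returned by pandas_storage_options may differ):

import pandas as pd
from azure.identity import DefaultAzureCredential

df = pd.read_parquet(
    "abfs://my-container/output/create_final_documents.parquet",
    storage_options={
        "account_name": "mystorageaccount",
        "credential": DefaultAzureCredential(),
    },
)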
@@ -123,7 +131,10 @@ def get_cosmos_container_store_client() -> ContainerProxy:
             cause=e,
             stack=traceback.format_exc(),
         )
-        raise HTTPException(status_code=500, detail="Error fetching cosmosdb client.")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error fetching cosmosdb client.",
+        )
 
 
 async def get_blob_container_client(name: str) -> ContainerClient:
@@ -141,7 +152,10 @@ async def get_blob_container_client(name: str) -> ContainerClient:
             cause=e,
             stack=traceback.format_exc(),
         )
-        raise HTTPException(status_code=500, detail="Error fetching storage client.")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error fetching storage client.",
+        )
 
 
 def sanitize_name(container_name: str) -> str:
@@ -188,7 +202,8 @@ def desanitize_name(sanitized_container_name: str) -> str | None:
         return None
     except Exception:
        raise HTTPException(
-            status_code=500, detail="Error retrieving original container name."
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error retrieving original container name.",
        )
 
 
@@ -196,13 +211,102 @@ async def subscription_key_check(
     Ocp_Apim_Subscription_Key: Annotated[str, Header()],
 ):
     """
-    Verifies if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
-    If it is not present, an HTTPException with a 400 status code is raised.
-    Note: this check is unnecessary (APIM validates subscription keys automatically), but this will add the key
+    Verify if user has passed the Ocp_Apim_Subscription_Key (APIM subscription key) in the request header.
+    Note: this check is unnecessary (APIM validates subscription keys automatically), but it effectively adds the key
     as a required parameter in the swagger docs page, enabling users to send requests using the swagger docs "Try it out" feature.
     """
     if not Ocp_Apim_Subscription_Key:
         raise HTTPException(
-            status_code=400, detail="Ocp-Apim-Subscription-Key required"
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Ocp-Apim-Subscription-Key required",
         )
     return Ocp_Apim_Subscription_Key
+
+
+async def create_cache(container_client: ContainerClient) -> None:
+    """
+    Create a file cache (csv).
+    """
+    try:
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        if not await cache_blob_client.exists():
+            # create the empty file cache csv
+            headers = [["Filename", "Hash"]]
+            tmp_cache_file = "uploaded_files_cache.csv"
+            with open(tmp_cache_file, "w", newline="") as f:
+                writer = csv.writer(f, delimiter=",")
+                writer.writerows(headers)
+
+            # upload to Azure Blob Storage and remove the temporary file
+            with open(tmp_cache_file, "rb") as f:
+                await cache_blob_client.upload_blob(f, overwrite=True)
+            if os.path.exists(tmp_cache_file):
+                os.remove(tmp_cache_file)
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error creating file cache in Azure Blob Storage.",
+        )
+
+
+async def check_cache(file_hash: str, container_client: ContainerClient) -> bool:
+    """
+    Check a file cache (csv) to determine if a file hash has previously been uploaded.
+
+    Note: This function creates/checks a CSV file in azure storage to act as a cache of previously uploaded files.
+    """
+    try:
+        # load the file cache
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        cache_download_stream = await cache_blob_client.download_blob()
+        cache_bytes = await cache_download_stream.readall()
+        cache_content = StringIO(cache_bytes.decode("utf-8"))
+
+        # compute the sha256 hash of the file and check if it exists in the cache
+        cache_reader = csv.reader(cache_content, delimiter=",")
+        for row in cache_reader:
+            if file_hash in row:
+                return True
+        return False
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error checking file cache in Azure Blob Storage.",
+        )
+
+
+async def update_cache(
+    new_files: Tuple[str, str], container_client: ContainerClient
+) -> None:
+    """
+    Update an existing file cache (csv) with new files.
+    """
+    try:
+        # Load the existing cache
+        cache_blob_client = container_client.get_blob_client(FILE_UPLOAD_CACHE)
+        cache_download_stream = await cache_blob_client.download_blob()
+        cache_bytes = await cache_download_stream.readall()
+        cache_content = StringIO(cache_bytes.decode("utf-8"))
+        cache_reader = csv.reader(cache_content, delimiter=",")
+
+        # append new data
+        existing_rows = list(cache_reader)
+        for filename, file_hash in new_files:
+            row = [filename, file_hash]
+            existing_rows.append(row)
+
+        # Write the updated content back to the StringIO object
+        updated_cache_content = StringIO()
+        cache_writer = csv.writer(updated_cache_content, delimiter=",")
+        cache_writer.writerows(existing_rows)
+
+        # Upload the updated cache to Azure Blob Storage
+        updated_cache_content.seek(0)
+        await cache_blob_client.upload_blob(
+            updated_cache_content.getvalue().encode("utf-8"), overwrite=True
+        )
+    except Exception:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Error updating file cache in Azure Blob Storage.",
+        )
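Taken together, the three cache helpers support the upload route above; a sketch of the intended call sequence (upload_with_cache is a hypothetical wrapper, and the aio ContainerClient is assumed to be already constructed):

from azure.storage.blob.aio import ContainerClient


async def upload_with_cache(
    container_client: ContainerClient, filename: str, file_hash: str, text: str
) -> None:
    await create_cache(container_client)  # no-op if cache/uploaded_files.csv exists
    if await check_cache(file_hash, container_client):
        return  # content already uploaded once; skip
    await container_client.get_blob_client(filename + ".txt").upload_blob(
        text, overwrite=True
    )
    await update_cache([(filename, file_hash)], container_client)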
backend/poetry.lock (generated, 1756 lines changed)
File diff suppressed because it is too large
@@ -53,6 +53,7 @@ fsspec = ">=2024.2.0"
 graphrag = "==1.2.0"
 httpx = ">=0.25.2"
 kubernetes = ">=29.0.0"
+markitdown = {extras = ["all"], version = "^0.1.1"}
 networkx = ">=3.2.1"
 nltk = "*"
 pandas = ">=2.2.1"
@@ -22,7 +22,7 @@ def test_schedule_index_without_data(client, cosmos_client: CosmosClient):
             "storage_container_name": "nonexistent-data-container",
         },
     )
-    assert response.status_code == 500
+    assert response.status_code == 412
 
 
 # def test_schedule_index_with_data(client, cosmos_client, blob_with_data_container_name):
@@ -21,7 +21,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"! pip install devtools python-magic requests tqdm"
+"! pip install devtools requests tqdm"
 ]
 },
 {
@@ -35,7 +35,6 @@
 "import time\n",
 "from pathlib import Path\n",
 "\n",
-"import magic\n",
 "import requests\n",
 "from devtools import pprint\n",
 "from tqdm import tqdm"
@@ -127,7 +126,7 @@
 "source": [
 "## Upload Files\n",
 "\n",
-"For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag."
+"For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.**"
 ]
 },
 {
@@ -177,20 +176,14 @@
 "        return response\n",
 "\n",
 "    batch_files = []\n",
-"    accepted_file_types = [\"text/plain\"]\n",
 "    filepaths = list(Path(file_directory).iterdir())\n",
 "    for file in tqdm(filepaths):\n",
 "        # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
-"        if (\n",
-"            not file.is_file()\n",
-"            or file.suffix != \".txt\"\n",
-"            or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
-"        ):\n",
+"        if (not file.is_file()):\n",
 "            print(f\"Skipping invalid file: {file}\")\n",
 "            continue\n",
-"        # open and decode file as utf-8, ignore bad characters\n",
 "        batch_files.append(\n",
-"            (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
+"            (\"files\", open(file=file, mode=\"rb\"))\n",
 "        )\n",
 "        # upload batch of files\n",
 "        if len(batch_files) == batch_size:\n",
@@ -199,7 +192,7 @@
 "            if not response.ok:\n",
 "                return response\n",
 "            batch_files.clear()\n",
-"    # upload remaining files\n",
+"    # upload last batch of remaining files\n",
 "    if len(batch_files) > 0:\n",
 "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
 "    return response\n",
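Unescaped from the notebook JSON, the simplified client-side loop now sends raw bytes and defers file-type handling to the server; a sketch (upload_batch and the other parameters come from surrounding notebook cells):

from pathlib import Path

from tqdm import tqdm


def collect_and_upload(file_directory, container_name, batch_size, overwrite, max_retries, upload_batch):
    batch_files = []
    response = None
    for file in tqdm(list(Path(file_directory).iterdir())):
        if not file.is_file():
            print(f"Skipping invalid file: {file}")
            continue
        # binary mode: the API now converts any supported filetype via MarkItDown
        batch_files.append(("files", open(file=file, mode="rb")))
        if len(batch_files) == batch_size:
            response = upload_batch(batch_files, container_name, overwrite, max_retries)
            if not response.ok:
                return response
            batch_files.clear()
    # upload last batch of remaining files
    if len(batch_files) > 0:
        response = upload_batch(batch_files, container_name, overwrite, max_retries)
    return response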
@@ -330,9 +323,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%%time\n",
-"\n",
-"\n",
 "def global_search(\n",
 "    index_name: str | list[str], query: str, community_level: int\n",
 ") -> requests.Response:\n",
@@ -372,9 +362,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%%time\n",
-"\n",
-"\n",
 "def local_search(\n",
 "    index_name: str | list[str], query: str, community_level: int\n",
 ") -> requests.Response:\n",
@@ -402,7 +389,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "graphrag-venv",
+"display_name": "Python 3",
 "language": "python",
 "name": "python3"
 },
@@ -27,17 +27,14 @@
 "| DELETE | /index/{index_name}\n",
 "| GET | /index/status/{index_name}\n",
 "| POST | /query/global\n",
-"| POST | /query/streaming/global\n",
 "| POST | /query/local\n",
-"| POST | /query/streaming/local\n",
 "| GET | /index/config/prompts\n",
 "| GET | /source/report/{index_name}/{report_id}\n",
 "| GET | /source/text/{index_name}/{text_unit_id}\n",
 "| GET | /source/entity/{index_name}/{entity_id}\n",
 "| GET | /source/claim/{index_name}/{claim_id}\n",
 "| GET | /source/relationship/{index_name}/{relationship_id}\n",
-"| GET | /graph/graphml/{index_name}\n",
-"| GET | /graph/stats/{index_name}"
+"| GET | /graph/graphml/{index_name}"
 ]
 },
 {
@@ -56,7 +53,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"! pip install devtools pandas python-magic requests tqdm"
+"! pip install devtools pandas requests tqdm"
 ]
 },
 {
@@ -72,7 +69,6 @@
 "import time\n",
 "from pathlib import Path\n",
 "\n",
-"import magic\n",
 "import pandas as pd\n",
 "import requests\n",
 "from devtools import pprint\n",
@@ -229,20 +225,14 @@
 "        return response\n",
 "\n",
 "    batch_files = []\n",
-"    accepted_file_types = [\"text/plain\"]\n",
 "    filepaths = list(Path(file_directory).iterdir())\n",
 "    for file in tqdm(filepaths):\n",
 "        # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
-"        if (\n",
-"            not file.is_file()\n",
-"            or file.suffix != \".txt\"\n",
-"            or magic.from_file(str(file), mime=True) not in accepted_file_types\n",
-"        ):\n",
+"        if (not file.is_file()):\n",
 "            print(f\"Skipping invalid file: {file}\")\n",
 "            continue\n",
-"        # open and decode file as utf-8, ignore bad characters\n",
 "        batch_files.append(\n",
-"            (\"files\", open(file=file, mode=\"r\", encoding=\"utf-8\", errors=\"ignore\"))\n",
+"            (\"files\", open(file=file, mode=\"rb\"))\n",
 "        )\n",
 "        # upload batch of files\n",
 "        if len(batch_files) == batch_size:\n",
@@ -251,7 +241,7 @@
 "            if not response.ok:\n",
 "                return response\n",
 "            batch_files.clear()\n",
-"    # upload remaining files\n",
+"    # upload last batch of remaining files\n",
 "    if len(batch_files) > 0:\n",
 "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
 "    return response\n",
@@ -290,7 +280,10 @@
 "    return requests.post(\n",
 "        url,\n",
 "        files=prompts if len(prompts) > 0 else None,\n",
-"        params={\"index_container_name\": index_name, \"storage_container_name\": storage_name},\n",
+"        params={\n",
+"            \"index_container_name\": index_name,\n",
+"            \"storage_container_name\": storage_name,\n",
+"        },\n",
 "        headers=headers,\n",
 "    )\n",
 "\n",
@@ -486,7 +479,7 @@
 "source": [
 "## Upload files\n",
 "\n",
-"Use the API to upload a collection of local files. The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
+"Use the API to upload a collection of files. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.** The API will create a new storage blob container to host these files in. For a set of large files, consider reducing the batch upload size in order to not overwhelm the API endpoint and prevent out-of-memory problems."
 ]
 },
 {
@@ -619,7 +612,9 @@
 "    community_summarization_prompt = prompts[\"community_summarization_prompt\"]\n",
 "    summarize_description_prompt = prompts[\"entity_summarization_prompt\"]\n",
 "else:\n",
-"    entity_extraction_prompt = community_summarization_prompt = summarize_description_prompt = None\n",
+"    entity_extraction_prompt = community_summarization_prompt = (\n",
+"        summarize_description_prompt\n",
+"    ) = None\n",
 "\n",
 "response = build_index(\n",
 "    storage_name=storage_name,\n",
@@ -745,8 +740,8 @@
 "# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
 "global_response = global_search(\n",
 "    index_name=index_name,\n",
-"    query=\"Summarize the main topics found in this data\",\n",
-"    community_level=1,\n",
+"    query=\"Summarize the qualifications to being a delivery data scientist\",\n",
+"    community_level=2,\n",
 ")\n",
 "# print the result and save context data in a variable\n",
 "global_response_data = parse_query_response(global_response, return_context_data=True)\n",
@@ -956,7 +951,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "graphrag-venv",
+"display_name": "Python 3",
 "language": "python",
 "name": "python3"
 },