graphrag/graphrag/index/reporting/blob_workflow_callbacks.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A reporter that writes to a blob storage."""

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from datashaper import NoopWorkflowCallbacks


class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
    """A reporter that writes to a blob storage."""

    _blob_service_client: BlobServiceClient
    _container_name: str
    _max_block_count: int = 25000  # 25k blocks per blob

    def __init__(
        self,
        connection_string: str | None,
        container_name: str,
        blob_name: str = "",
        base_dir: str | None = None,
        storage_account_blob_url: str | None = None,
    ):  # type: ignore
        """Create a new instance of the BlobWorkflowCallbacks class."""
        if container_name is None:
            msg = "No container name provided for blob storage."
            raise ValueError(msg)
        if connection_string is None and storage_account_blob_url is None:
            msg = "No connection string or storage account blob url provided for blob storage."
            raise ValueError(msg)
        self._connection_string = connection_string
        self._storage_account_blob_url = storage_account_blob_url
        if self._connection_string:
            self._blob_service_client = BlobServiceClient.from_connection_string(
                self._connection_string
            )
        else:
            if storage_account_blob_url is None:
                msg = "Either connection_string or storage_account_blob_url must be provided."
                raise ValueError(msg)
            self._blob_service_client = BlobServiceClient(
                storage_account_blob_url,
                credential=DefaultAzureCredential(),
            )
        if blob_name == "":
            blob_name = f"report/{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d-%H:%M:%S:%f')}.logs.json"
        self._blob_name = str(Path(base_dir or "") / blob_name)
        self._container_name = container_name
        self._blob_client = self._blob_service_client.get_blob_client(
            self._container_name, self._blob_name
        )
        if not self._blob_client.exists():
            self._blob_client.create_append_blob()
        self._num_blocks = 0  # refresh block counter

    def _write_log(self, log: dict[str, Any]):
        # Azure append blobs allow at most 50k blocks; start a new blob once
        # this one approaches the 25k cap configured above.
        if self._num_blocks >= self._max_block_count:
            self.__init__(
                self._connection_string,
                self._container_name,
                storage_account_blob_url=self._storage_account_blob_url,
            )

        blob_client = self._blob_service_client.get_blob_client(
            self._container_name, self._blob_name
        )
        blob_client.append_block(json.dumps(log, ensure_ascii=False) + "\n")

        # update the blob's block count
        self._num_blocks += 1

    def on_error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ):
        """Report an error."""
        self._write_log({
            "type": "error",
            "data": message,
            "cause": str(cause),
            "stack": stack,
            "details": details,
        })

    def on_warning(self, message: str, details: dict | None = None):
        """Report a warning."""
        self._write_log({"type": "warning", "data": message, "details": details})

    def on_log(self, message: str, details: dict | None = None):
        """Report a generic log message."""
        self._write_log({"type": "log", "data": message, "details": details})