# graphrag-accelerator/backend/src/api/azure_clients.py
# (178 lines, 5.6 KiB, Python)
# 2024-06-26 15:45:06 -04:00
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os

from azure.cosmos import (
    ContainerProxy,
    CosmosClient,
    DatabaseProxy,
)
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as BlobServiceClientAsync
from dotenv import load_dotenv
from environs import Env
# Message text associated with a missing endpoint environment variable.
ENDPOINT_ERROR_MSG = "Could not find connection string in environment variables"

# Load variables from a local ``.env`` file (if one is found) into os.environ
# at import time, so the os.environ lookups in this module can see them.
load_dotenv()
class CosmosClientSingleton:
    """Lazily created, shared ``CosmosClient``.

    The client is built on the first ``get_instance()`` call from the
    ``COSMOS_URI_ENDPOINT`` environment variable and a
    ``DefaultAzureCredential``; every later call returns that same object.
    """

    _instance = None
    _env = Env()

    @classmethod
    def get_instance(cls):
        """Return the shared Cosmos client, creating it on first use.

        Raises:
            KeyError: If ``COSMOS_URI_ENDPOINT`` is unset when the client
                has to be created.
        """
        if cls._instance is not None:
            return cls._instance
        cls._instance = CosmosClient(
            os.environ["COSMOS_URI_ENDPOINT"], DefaultAzureCredential()
        )
        return cls._instance
class BlobServiceClientSingleton:
    """Lazily created, shared synchronous ``BlobServiceClient``.

    The client is built on the first ``get_instance()`` call from the
    ``STORAGE_ACCOUNT_BLOB_URL`` environment variable and a
    ``DefaultAzureCredential``; every later call returns that same object.
    """

    _instance = None
    _env = Env()

    @classmethod
    def get_instance(cls):
        """Return the shared blob client, creating it on first use.

        Raises:
            KeyError: If ``STORAGE_ACCOUNT_BLOB_URL`` is unset when the
                client has to be created.
        """
        if cls._instance is not None:
            return cls._instance
        blob_url = os.environ["STORAGE_ACCOUNT_BLOB_URL"]
        cls._instance = BlobServiceClient(
            blob_url, credential=DefaultAzureCredential()
        )
        return cls._instance

    @classmethod
    def get_storage_account_name(cls):
        """Return the storage account name parsed from the blob URL.

        Keeps the part of the URL that follows the first ``//`` (up to any
        further ``//``), then everything before its first ``.`` -- e.g.
        ``https://acct.blob.core.windows.net`` gives ``acct``.

        Raises:
            KeyError: If ``STORAGE_ACCOUNT_BLOB_URL`` is unset.
            IndexError: If the URL contains no ``//``.
        """
        host = os.environ["STORAGE_ACCOUNT_BLOB_URL"].split("//")[1]
        return host.split(".")[0]
class BlobServiceClientSingletonAsync:
    """Lazily created, shared asynchronous (``aio``) ``BlobServiceClient``.

    The client is built on the first ``get_instance()`` call from the
    ``STORAGE_ACCOUNT_BLOB_URL`` environment variable and a
    ``DefaultAzureCredential``; every later call returns that same object.
    """

    _instance = None
    _env = Env()

    @classmethod
    def get_instance(cls):
        """Return the shared async blob client, creating it on first use.

        Raises:
            KeyError: If ``STORAGE_ACCOUNT_BLOB_URL`` is unset when the
                client has to be created.
        """
        if cls._instance is None:
            cls._instance = BlobServiceClientAsync(
                os.environ["STORAGE_ACCOUNT_BLOB_URL"],
                credential=DefaultAzureCredential(),
            )
        return cls._instance

    @classmethod
    def get_storage_account_name(cls):
        """Return the storage account name parsed from the blob URL.

        Keeps the part of the URL that follows the first ``//`` (up to any
        further ``//``), then everything before its first ``.`` -- e.g.
        ``https://acct.blob.core.windows.net`` gives ``acct``.

        Raises:
            KeyError: If ``STORAGE_ACCOUNT_BLOB_URL`` is unset.
            IndexError: If the URL contains no ``//``.
        """
        blob_url = os.environ["STORAGE_ACCOUNT_BLOB_URL"]
        account_name, *_ = blob_url.split("//")[1].split(".")
        return account_name
def get_database_client(database_name: str) -> DatabaseProxy:
    """Return a database proxy for ``database_name`` from the shared Cosmos client.

    Args:
        database_name: Name of the Cosmos DB database.

    Returns:
        DatabaseProxy: The proxy returned by the shared client's
        ``get_database_client``.
    """
    return CosmosClientSingleton.get_instance().get_database_client(database_name)
def get_database_container_client(
    database_name: str, container_name: str
) -> ContainerProxy:
    """Return a container proxy inside ``database_name`` via the shared Cosmos client.

    Args:
        database_name: Name of the Cosmos DB database.
        container_name: Name of the container within that database.

    Returns:
        ContainerProxy: The proxy returned by the database proxy's
        ``get_container_client``.
    """
    return get_database_client(database_name).get_container_client(container_name)
class AzureStorageClientManager:
    """
    Manages the Azure storage clients for blob storage and Cosmos DB.

    All clients are created in ``__init__`` and share a single
    ``DefaultAzureCredential``. Cosmos DB database and container clients are
    created on first request and cached per name, so a request for a
    different database or container gets its own client.

    Attributes:
        azure_storage_blob_url (str): The blob endpoint for azure storage
            (``STORAGE_ACCOUNT_BLOB_URL`` environment variable).
        cosmos_uri_endpoint (str): The uri endpoint for the Cosmos DB
            (``COSMOS_URI_ENDPOINT`` environment variable).
        _blob_service_client (BlobServiceClient): The blob service client.
        _blob_service_client_async (BlobServiceClientAsync): The asynchronous blob service client.
        _cosmos_client (CosmosClient): The Cosmos DB client.
        _cosmos_database_clients (dict[str, DatabaseProxy]): Cosmos DB
            database clients, keyed by database name.
        _cosmos_container_clients (dict[tuple[str, str], ContainerProxy]):
            Cosmos DB container clients, keyed by
            ``(database_name, container_name)``.
    """

    def __init__(self) -> None:
        """
        Reads the endpoints from the environment and creates the clients.

        Raises:
            KeyError: If ``STORAGE_ACCOUNT_BLOB_URL`` or ``COSMOS_URI_ENDPOINT``
                is not set.
        """
        # Plain os.environ lookups so a missing variable raises KeyError.
        # (The second positional argument of ``Env.str`` is a *default
        # value*, so passing an error message there would store the message
        # text as the endpoint.) Each endpoint is read once and reused below.
        self.azure_storage_blob_url = os.environ["STORAGE_ACCOUNT_BLOB_URL"]
        self.cosmos_uri_endpoint = os.environ["COSMOS_URI_ENDPOINT"]

        credential = DefaultAzureCredential()
        self._blob_service_client = BlobServiceClient(
            account_url=self.azure_storage_blob_url, credential=credential
        )
        self._blob_service_client_async = BlobServiceClientAsync(
            account_url=self.azure_storage_blob_url, credential=credential
        )
        self._cosmos_client = CosmosClient(
            url=self.cosmos_uri_endpoint, credential=credential
        )
        # Per-name caches; a single cached proxy would be returned for every
        # later request regardless of the name asked for.
        self._cosmos_database_clients: dict[str, DatabaseProxy] = {}
        self._cosmos_container_clients: dict[tuple[str, str], ContainerProxy] = {}

    def get_blob_service_client(self) -> BlobServiceClient:
        """
        Returns the blob service client.

        Returns:
            BlobServiceClient: The blob service client.
        """
        return self._blob_service_client

    def get_blob_service_client_async(self) -> BlobServiceClientAsync:
        """
        Returns the asynchronous blob service client.

        Returns:
            BlobServiceClientAsync: The asynchronous blob service client.
        """
        return self._blob_service_client_async

    def get_cosmos_client(self) -> CosmosClient:
        """
        Returns the Cosmos DB client.

        Returns:
            CosmosClient: The Cosmos DB client.
        """
        return self._cosmos_client

    def get_cosmos_database_client(self, database_name: str) -> DatabaseProxy:
        """
        Returns the Cosmos DB database client for ``database_name``.

        The client is created on the first request for that name and reused
        for later requests with the same name.

        Args:
            database_name (str): The name of the database.

        Returns:
            DatabaseProxy: The Cosmos DB database client.
        """
        database_client = self._cosmos_database_clients.get(database_name)
        if database_client is None:
            database_client = self._cosmos_client.get_database_client(
                database=database_name
            )
            self._cosmos_database_clients[database_name] = database_client
        return database_client

    def get_cosmos_container_client(
        self, database_name: str, container_name: str
    ) -> ContainerProxy:
        """
        Returns the Cosmos DB container client for ``container_name`` in
        ``database_name``.

        The client is created on the first request for that
        ``(database_name, container_name)`` pair and reused afterwards.

        Args:
            database_name (str): The name of the database.
            container_name (str): The name of the container.

        Returns:
            ContainerProxy: The Cosmos DB container client.
        """
        key = (database_name, container_name)
        container_client = self._cosmos_container_clients.get(key)
        if container_client is None:
            container_client = self.get_cosmos_database_client(
                database_name=database_name
            ).get_container_client(container=container_name)
            self._cosmos_container_clients[key] = container_client
        return container_client