mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-10-03 04:46:27 +00:00
Fixes 14370: Add Azure Client, support Default Creds (#15554)
* Add Azure Client, support Default Creds
This commit is contained in:
parent
cb411d0aa2
commit
8b880bbf91
85
ingestion/src/metadata/clients/azure_client.py
Normal file
85
ingestion/src/metadata/clients/azure_client.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Module containing Azure Client
|
||||
"""
|
||||
|
||||
from metadata.generated.schema.security.credentials.azureCredentials import (
|
||||
AzureCredentials,
|
||||
)
|
||||
from metadata.utils.logger import utils_logger
|
||||
|
||||
logger = utils_logger()
|
||||
|
||||
|
||||
class AzureClient:
|
||||
"""
|
||||
AzureClient based on AzureCredentials.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials: "AzureCredentials"):
|
||||
self.credentials = credentials
|
||||
if not isinstance(credentials, AzureCredentials):
|
||||
self.credentials = AzureCredentials.parse_obj(credentials)
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
):
|
||||
from azure.identity import ClientSecretCredential, DefaultAzureCredential
|
||||
|
||||
try:
|
||||
if (
|
||||
getattr(self.credentials, "tenantId", None)
|
||||
and getattr(self.credentials, "clientId", None)
|
||||
and getattr(self.credentials, "clientSecret", None)
|
||||
):
|
||||
logger.info("Using Client Secret Credentials")
|
||||
return ClientSecretCredential(
|
||||
tenant_id=self.credentials.tenantId,
|
||||
client_id=self.credentials.clientId,
|
||||
client_secret=self.credentials.clientSecret.get_secret_value(),
|
||||
)
|
||||
else:
|
||||
logger.info("Using Default Azure Credentials")
|
||||
return DefaultAzureCredential()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Azure Client: {e}")
|
||||
raise e
|
||||
|
||||
def create_blob_client(self):
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
|
||||
try:
|
||||
logger.info("Creating Blob Service Client")
|
||||
if self.credentials.accountName:
|
||||
return BlobServiceClient(
|
||||
account_url=f"https://{self.credentials.accountName}.blob.core.windows.net/",
|
||||
credential=self.create_client(),
|
||||
)
|
||||
raise ValueError("Account Name is required to create Blob Service Client")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Blob Service Client: {e}")
|
||||
raise e
|
||||
|
||||
def create_secret_client(self):
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
try:
|
||||
if self.credentials.vaultName:
|
||||
logger.info("Creating Secret Client")
|
||||
return SecretClient(
|
||||
vault_url=f"https://{self.credentials.vaultName}.vault.azure.net/",
|
||||
credential=self.create_client(),
|
||||
)
|
||||
raise ValueError("Vault Name is required to create a Secret Client")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Secret Client: {e}")
|
||||
raise e
|
@ -0,0 +1,29 @@
|
||||
source:
|
||||
type: datalake
|
||||
serviceName: local_datalake4
|
||||
serviceConnection:
|
||||
config:
|
||||
type: Datalake
|
||||
configSource:
|
||||
securityConfig:
|
||||
clientId: clientId
|
||||
accountName: accountName
|
||||
bucketName: bucket name
|
||||
prefix: prefix
|
||||
sourceConfig:
|
||||
config:
|
||||
type: DatabaseMetadata
|
||||
tableFilterPattern:
|
||||
includes:
|
||||
- ''
|
||||
sink:
|
||||
type: metadata-rest
|
||||
config: {}
|
||||
workflowConfig:
|
||||
# loggerLevel: INFO # DEBUG, INFO, WARN or ERROR
|
||||
openMetadataServerConfig:
|
||||
hostPort: http://localhost:8585/api
|
||||
authProvider: openmetadata
|
||||
securityConfig:
|
||||
jwtToken: "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
|
||||
|
@ -19,6 +19,9 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import msal
|
||||
|
||||
from metadata.generated.schema.entity.services.connections.dashboard.powerBIConnection import (
|
||||
PowerBIConnection,
|
||||
)
|
||||
from metadata.ingestion.api.steps import InvalidSourceException
|
||||
from metadata.ingestion.ometa.client import REST, ClientConfig
|
||||
from metadata.ingestion.source.dashboard.powerbi.models import (
|
||||
@ -52,7 +55,7 @@ class PowerBiApiClient:
|
||||
|
||||
client: REST
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PowerBIConnection):
|
||||
self.config = config
|
||||
self.msal_client = msal.ConfidentialClientApplication(
|
||||
client_id=self.config.clientId,
|
||||
|
@ -15,12 +15,13 @@ Source connection handler
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.azureSQLConnection import (
|
||||
Authentication,
|
||||
AzureSQLConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
|
||||
@ -40,13 +41,29 @@ def get_connection_url(connection: Union[AzureSQLConnection, MssqlConnection]) -
|
||||
Build the connection URL
|
||||
"""
|
||||
|
||||
if connection.authenticationMode:
|
||||
connection_string = f"Driver={connection.driver};Server={connection.hostPort};Database={connection.database};"
|
||||
connection_string += f"Uid={connection.username};"
|
||||
if (
|
||||
connection.authenticationMode.authentication
|
||||
== Authentication.ActiveDirectoryPassword
|
||||
):
|
||||
connection_string += f"Pwd={connection.password.get_secret_value()};"
|
||||
|
||||
connection_string += f"Encrypt={'yes' if connection.authenticationMode.encrypt else 'no'};TrustServerCertificate={'yes' if connection.authenticationMode.trustServerCertificate else 'no'};"
|
||||
connection_string += f"Connection Timeout={connection.authenticationMode.connectionTimeout or 30};Authentication={connection.authenticationMode.authentication.value};"
|
||||
|
||||
connection_url = URL.create(
|
||||
"mssql+pyodbc", query={"odbc_connect": connection_string}
|
||||
)
|
||||
return connection_url
|
||||
url = f"{connection.scheme.value}://"
|
||||
|
||||
if connection.username:
|
||||
url += f"{quote_plus(connection.username)}"
|
||||
url += (
|
||||
f":{quote_plus(connection.password.get_secret_value())}"
|
||||
if connection
|
||||
if connection.password
|
||||
else ""
|
||||
)
|
||||
url += "@"
|
||||
@ -54,12 +71,13 @@ def get_connection_url(connection: Union[AzureSQLConnection, MssqlConnection]) -
|
||||
url += f"{connection.hostPort}"
|
||||
url += f"/{quote_plus(connection.database)}" if connection.database else ""
|
||||
url += f"?driver={quote_plus(connection.driver)}"
|
||||
|
||||
options = get_connection_options_dict(connection)
|
||||
if options:
|
||||
if not connection.database:
|
||||
url += "/"
|
||||
params = "&".join(
|
||||
f"{key}={quote_plus(value)}" for (key, value) in options.items() if value
|
||||
f"{key}={quote_plus(value)}" for key, value in options.items() if value
|
||||
)
|
||||
url = f"{url}?{params}"
|
||||
|
||||
|
@ -20,6 +20,7 @@ from typing import Optional
|
||||
|
||||
from google.cloud import storage
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
@ -88,22 +89,9 @@ def _(config: GCSConfig):
|
||||
|
||||
@get_datalake_client.register
|
||||
def _(config: AzureConfig):
|
||||
from azure.identity import ClientSecretCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
|
||||
try:
|
||||
credentials = ClientSecretCredential(
|
||||
config.securityConfig.tenantId,
|
||||
config.securityConfig.clientId,
|
||||
config.securityConfig.clientSecret.get_secret_value(),
|
||||
)
|
||||
|
||||
azure_client = BlobServiceClient(
|
||||
f"https://{config.securityConfig.accountName}.blob.core.windows.net/",
|
||||
credential=credentials,
|
||||
)
|
||||
return azure_client
|
||||
|
||||
return AzureClient(config.securityConfig).create_blob_client()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Unknown error connecting with {config.securityConfig}: {exc}."
|
||||
|
@ -20,6 +20,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
|
||||
import requests
|
||||
|
||||
from metadata.clients.aws_client import AWSClient
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtAzureConfig import (
|
||||
DbtAzureConfig,
|
||||
)
|
||||
@ -357,21 +358,8 @@ def _(config: DbtGcsConfig):
|
||||
def _(config: DbtAzureConfig):
|
||||
try:
|
||||
bucket_name, prefix = get_dbt_prefix_config(config)
|
||||
from azure.identity import ( # pylint: disable=import-outside-toplevel
|
||||
ClientSecretCredential,
|
||||
)
|
||||
from azure.storage.blob import ( # pylint: disable=import-outside-toplevel
|
||||
BlobServiceClient,
|
||||
)
|
||||
|
||||
client = BlobServiceClient(
|
||||
f"https://{config.dbtSecurityConfig.accountName}.blob.core.windows.net/",
|
||||
credential=ClientSecretCredential(
|
||||
config.dbtSecurityConfig.tenantId,
|
||||
config.dbtSecurityConfig.clientId,
|
||||
config.dbtSecurityConfig.clientSecret.get_secret_value(),
|
||||
),
|
||||
)
|
||||
client = AzureClient(config.dbtSecurityConfig).create_blob_client()
|
||||
|
||||
if not bucket_name:
|
||||
container_dicts = client.list_containers()
|
||||
|
@ -16,9 +16,13 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
||||
BasicAuth,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
|
||||
MysqlConnection,
|
||||
)
|
||||
@ -38,6 +42,16 @@ def get_connection(connection: MysqlConnection) -> Engine:
|
||||
"""
|
||||
Create connection
|
||||
"""
|
||||
if hasattr(connection.authType, "azureConfig"):
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
connection.authType = BasicAuth(password=access_token_obj.token)
|
||||
if connection.sslCA or connection.sslCert or connection.sslKey:
|
||||
if not connection.connectionOptions:
|
||||
connection.connectionOptions = init_empty_connection_options()
|
||||
|
@ -17,9 +17,13 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
||||
BasicAuth,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
||||
PostgresConnection,
|
||||
SslMode,
|
||||
@ -46,6 +50,17 @@ def get_connection(connection: PostgresConnection) -> Engine:
|
||||
"""
|
||||
Create connection
|
||||
"""
|
||||
|
||||
if hasattr(connection.authType, "azureConfig"):
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
connection.authType = BasicAuth(password=access_token_obj.token)
|
||||
if connection.sslMode:
|
||||
if not connection.connectionArguments:
|
||||
connection.connectionArguments = init_empty_connection_arguments()
|
||||
|
@ -15,7 +15,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from google import auth
|
||||
@ -25,6 +25,9 @@ from metadata.generated.schema.security.credentials.gcpCredentials import (
|
||||
GCPCredentials,
|
||||
GcpCredentialsPath,
|
||||
)
|
||||
from metadata.generated.schema.security.credentials.gcpExternalAccount import (
|
||||
GcpCredentialsValuesExternalAccount,
|
||||
)
|
||||
from metadata.generated.schema.security.credentials.gcpValues import (
|
||||
GcpCredentialsValues,
|
||||
)
|
||||
@ -85,7 +88,9 @@ def create_credential_tmp_file(credentials: dict) -> str:
|
||||
return temp_file_path
|
||||
|
||||
|
||||
def build_google_credentials_dict(gcp_values: GcpCredentialsValues) -> Dict[str, str]:
|
||||
def build_google_credentials_dict(
|
||||
gcp_values: Union[GcpCredentialsValues, GcpCredentialsValuesExternalAccount]
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Given GcPCredentialsValues, build a dictionary as the JSON file
|
||||
downloaded from GCP with the service_account
|
||||
|
@ -17,9 +17,9 @@ import traceback
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from azure.identity import ClientSecretCredential, DefaultAzureCredential
|
||||
from azure.keyvault.secrets import KeyVaultSecret, SecretClient
|
||||
from azure.keyvault.secrets import KeyVaultSecret
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.security.secrets.secretsManagerClientLoader import (
|
||||
SecretsManagerClientLoader,
|
||||
)
|
||||
@ -105,23 +105,7 @@ class AzureKVSecretsManager(ExternalSecretsManager, ABC):
|
||||
):
|
||||
super().__init__(provider=SecretsManagerProvider.azure_kv, loader=loader)
|
||||
|
||||
if (
|
||||
self.credentials.tenantId
|
||||
and self.credentials.clientId
|
||||
and self.credentials.clientSecret
|
||||
):
|
||||
azure_identity = ClientSecretCredential(
|
||||
tenant_id=self.credentials.tenantId,
|
||||
client_id=self.credentials.clientId,
|
||||
client_secret=self.credentials.clientSecret.get_secret_value(),
|
||||
)
|
||||
else:
|
||||
azure_identity = DefaultAzureCredential()
|
||||
|
||||
self.client = SecretClient(
|
||||
vault_url=f"https://{self.credentials.vaultName}.vault.azure.net/",
|
||||
credential=azure_identity,
|
||||
)
|
||||
self.client = AzureClient(self.credentials).create_secret_client()
|
||||
|
||||
def get_string_value(self, secret_id: str) -> str:
|
||||
"""
|
||||
|
@ -17,6 +17,7 @@ from functools import singledispatch
|
||||
|
||||
import requests
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import (
|
||||
AzureConfig,
|
||||
)
|
||||
@ -153,21 +154,7 @@ def _(config: StorageMetadataAdlsConfig) -> ManifestMetadataConfig:
|
||||
else STORAGE_METADATA_MANIFEST_FILE_NAME
|
||||
)
|
||||
|
||||
from azure.identity import ( # pylint: disable=import-outside-toplevel
|
||||
ClientSecretCredential,
|
||||
)
|
||||
from azure.storage.blob import ( # pylint: disable=import-outside-toplevel
|
||||
BlobServiceClient,
|
||||
)
|
||||
|
||||
blob_client = BlobServiceClient(
|
||||
account_url=f"https://{config.securityConfig.accountName}.blob.core.windows.net/",
|
||||
credential=ClientSecretCredential(
|
||||
config.securityConfig.tenantId,
|
||||
config.securityConfig.clientId,
|
||||
config.securityConfig.clientSecret.get_secret_value(),
|
||||
),
|
||||
)
|
||||
blob_client = AzureClient(config.securityConfig).create_blob_client()
|
||||
|
||||
reader = get_reader(
|
||||
config_source=AzureConfig(securityConfig=config.securityConfig),
|
||||
|
@ -17,17 +17,6 @@ import time
|
||||
from datetime import datetime
|
||||
from unittest import TestCase
|
||||
|
||||
from ingestion.tests.integration.integration_base import (
|
||||
generate_name,
|
||||
get_create_entity,
|
||||
get_create_service,
|
||||
get_create_team_entity,
|
||||
get_create_test_case,
|
||||
get_create_test_definition,
|
||||
get_create_test_suite,
|
||||
get_create_user_entity,
|
||||
int_admin_ometa,
|
||||
)
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||
from metadata.generated.schema.entity.data.table import Column, DataType, Table
|
||||
@ -54,6 +43,18 @@ from metadata.ingestion.models.patch_request import (
|
||||
from metadata.ingestion.models.table_metadata import ColumnTag
|
||||
from metadata.utils.helpers import find_column_in_table
|
||||
|
||||
from ..integration_base import (
|
||||
generate_name,
|
||||
get_create_entity,
|
||||
get_create_service,
|
||||
get_create_team_entity,
|
||||
get_create_test_case,
|
||||
get_create_test_definition,
|
||||
get_create_test_suite,
|
||||
get_create_user_entity,
|
||||
int_admin_ometa,
|
||||
)
|
||||
|
||||
PII_TAG_LABEL = TagLabel(
|
||||
tagFQN="PII.Sensitive",
|
||||
labelType=LabelType.Automated,
|
||||
|
@ -33,7 +33,6 @@ from unittest import TestCase
|
||||
from pymongo import MongoClient, database
|
||||
from testcontainers.mongodb import MongoDbContainer
|
||||
|
||||
from ingestion.tests.integration.integration_base import int_admin_ometa
|
||||
from metadata.generated.schema.entity.data.table import ColumnProfile, Table
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
@ -46,6 +45,8 @@ from metadata.workflow.metadata import MetadataWorkflow
|
||||
from metadata.workflow.profiler import ProfilerWorkflow
|
||||
from metadata.workflow.workflow_output_handler import print_status
|
||||
|
||||
from ..integration_base import int_admin_ometa
|
||||
|
||||
SERVICE_NAME = Path(__file__).stem
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
This file has been generated from dag_runner.j2
|
||||
"""
|
||||
from openmetadata.workflows import workflow_factory
|
||||
from openmetadata_managed_apis.workflows import workflow_factory
|
||||
|
||||
workflow = workflow_factory.WorkflowFactory.create(
|
||||
"/airflow/dag_generated_configs/local_redshift_profiler_e9AziRXs.json"
|
||||
|
@ -2,7 +2,7 @@
|
||||
This file has been generated from dag_runner.j2
|
||||
"""
|
||||
from airflow import DAG
|
||||
from openmetadata.workflows import workflow_factory
|
||||
from openmetadata_managed_apis.workflows import workflow_factory
|
||||
|
||||
workflow = workflow_factory.WorkflowFactory.create("/airflow/dag_generated_configs/local_redshift_profiler_e9AziRXs.json")
|
||||
workflow.generate_dag(globals())
|
63
ingestion/tests/unit/test_azure_credentials.py
Normal file
63
ingestion/tests/unit/test_azure_credentials.py
Normal file
@ -0,0 +1,63 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.security.credentials.azureCredentials import (
|
||||
AzureCredentials,
|
||||
)
|
||||
|
||||
|
||||
class TestAzureClient(unittest.TestCase):
|
||||
@patch("azure.identity.ClientSecretCredential")
|
||||
@patch("azure.identity.DefaultAzureCredential")
|
||||
def test_create_client(
|
||||
self,
|
||||
mock_default_credential,
|
||||
mock_client_secret_credential,
|
||||
):
|
||||
# Test with ClientSecretCredential
|
||||
credentials = AzureCredentials(
|
||||
clientId="clientId", clientSecret="clientSecret", tenantId="tenantId"
|
||||
)
|
||||
instance = AzureClient(credentials)
|
||||
instance.create_client()
|
||||
|
||||
mock_client_secret_credential.assert_called_once()
|
||||
mock_client_secret_credential.reset_mock()
|
||||
|
||||
credentials = AzureCredentials(
|
||||
clientId="clientId",
|
||||
)
|
||||
instance = AzureClient(credentials)
|
||||
|
||||
instance.create_client()
|
||||
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
@patch("azure.storage.blob.BlobServiceClient")
|
||||
def test_create_blob_client(self, mock_blob_service_client):
|
||||
credentials = AzureCredentials(
|
||||
clientId="clientId", clientSecret="clientSecret", tenantId="tenantId"
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
AzureClient(credentials=credentials).create_blob_client()
|
||||
|
||||
credentials.accountName = "accountName"
|
||||
AzureClient(credentials=credentials).create_blob_client()
|
||||
mock_blob_service_client.assert_called_once()
|
||||
|
||||
@patch("azure.keyvault.secrets.SecretClient")
|
||||
def test_create_secret_client(self, mock_secret_client):
|
||||
credentials = AzureCredentials(
|
||||
clientId="clientId", clientSecret="clientSecret", tenantId="tenantId"
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
AzureClient(credentials=credentials).create_secret_client()
|
||||
|
||||
credentials.vaultName = "vaultName"
|
||||
AzureClient(credentials=credentials).create_secret_client()
|
||||
mock_secret_client.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
138
ingestion/tests/unit/test_build_connection_url.py
Normal file
138
ingestion/tests/unit/test_build_connection_url.py
Normal file
@ -0,0 +1,138 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from azure.core.credentials import AccessToken
|
||||
from azure.identity import ClientSecretCredential
|
||||
|
||||
from metadata.generated.schema.entity.services.connections.database.azureSQLConnection import (
|
||||
Authentication,
|
||||
AuthenticationMode,
|
||||
AzureSQLConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
|
||||
AzureConfigurationSource,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
||||
BasicAuth,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
|
||||
MysqlConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
||||
PostgresConnection,
|
||||
)
|
||||
from metadata.generated.schema.security.credentials.azureCredentials import (
|
||||
AzureCredentials,
|
||||
)
|
||||
from metadata.ingestion.source.database.azuresql.connection import get_connection_url
|
||||
from metadata.ingestion.source.database.mysql.connection import (
|
||||
get_connection as mysql_get_connection,
|
||||
)
|
||||
from metadata.ingestion.source.database.postgres.connection import (
|
||||
get_connection as postgres_get_connection,
|
||||
)
|
||||
|
||||
|
||||
class TestGetConnectionURL(unittest.TestCase):
|
||||
def test_get_connection_url_wo_active_directory_password(self):
|
||||
connection = AzureSQLConnection(
|
||||
driver="SQL Server",
|
||||
hostPort="myserver.database.windows.net",
|
||||
database="mydb",
|
||||
username="myuser",
|
||||
password="mypassword",
|
||||
authenticationMode=AuthenticationMode(
|
||||
authentication=Authentication.ActiveDirectoryPassword,
|
||||
encrypt=True,
|
||||
trustServerCertificate=False,
|
||||
connectionTimeout=45,
|
||||
),
|
||||
)
|
||||
expected_url = "mssql+pyodbc://?odbc_connect=Driver%3DSQL+Server%3BServer%3Dmyserver.database.windows.net%3BDatabase%3Dmydb%3BUid%3Dmyuser%3BPwd%3Dmypassword%3BEncrypt%3Dyes%3BTrustServerCertificate%3Dno%3BConnection+Timeout%3D45%3BAuthentication%3DActiveDirectoryPassword%3B"
|
||||
self.assertEqual(str(get_connection_url(connection)), expected_url)
|
||||
|
||||
connection = AzureSQLConnection(
|
||||
driver="SQL Server",
|
||||
hostPort="myserver.database.windows.net",
|
||||
database="mydb",
|
||||
username="myuser",
|
||||
password="mypassword",
|
||||
authenticationMode=AuthenticationMode(
|
||||
authentication=Authentication.ActiveDirectoryPassword,
|
||||
),
|
||||
)
|
||||
|
||||
expected_url = "mssql+pyodbc://?odbc_connect=Driver%3DSQL+Server%3BServer%3Dmyserver.database.windows.net%3BDatabase%3Dmydb%3BUid%3Dmyuser%3BPwd%3Dmypassword%3BEncrypt%3Dno%3BTrustServerCertificate%3Dno%3BConnection+Timeout%3D30%3BAuthentication%3DActiveDirectoryPassword%3B"
|
||||
self.assertEqual(str(get_connection_url(connection)), expected_url)
|
||||
|
||||
def test_get_connection_url_mysql(self):
|
||||
connection = MysqlConnection(
|
||||
username="openmetadata_user",
|
||||
authType=BasicAuth(password="openmetadata_password"),
|
||||
hostPort="localhost:3306",
|
||||
databaseSchema="openmetadata_db",
|
||||
)
|
||||
engine_connection = mysql_get_connection(connection)
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
||||
)
|
||||
connection = MysqlConnection(
|
||||
username="openmetadata_user",
|
||||
authType=AzureConfigurationSource(
|
||||
azureConfig=AzureCredentials(
|
||||
clientId="clientid",
|
||||
tenantId="tenantid",
|
||||
clientSecret="clientsecret",
|
||||
scopes="scope1,scope2",
|
||||
)
|
||||
),
|
||||
hostPort="localhost:3306",
|
||||
databaseSchema="openmetadata_db",
|
||||
)
|
||||
with patch.object(
|
||||
ClientSecretCredential,
|
||||
"get_token",
|
||||
return_value=AccessToken(token="mocked_token", expires_on=100),
|
||||
):
|
||||
engine_connection = mysql_get_connection(connection)
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",
|
||||
)
|
||||
|
||||
def test_get_connection_url_postgres(self):
|
||||
connection = PostgresConnection(
|
||||
username="openmetadata_user",
|
||||
authType=BasicAuth(password="openmetadata_password"),
|
||||
hostPort="localhost:3306",
|
||||
database="openmetadata_db",
|
||||
)
|
||||
engine_connection = postgres_get_connection(connection)
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
||||
)
|
||||
connection = PostgresConnection(
|
||||
username="openmetadata_user",
|
||||
authType=AzureConfigurationSource(
|
||||
azureConfig=AzureCredentials(
|
||||
clientId="clientid",
|
||||
tenantId="tenantid",
|
||||
clientSecret="clientsecret",
|
||||
scopes="scope1,scope2",
|
||||
)
|
||||
),
|
||||
hostPort="localhost:3306",
|
||||
database="openmetadata_db",
|
||||
)
|
||||
with patch.object(
|
||||
ClientSecretCredential,
|
||||
"get_token",
|
||||
return_value=AccessToken(token="mocked_token", expires_on=100),
|
||||
):
|
||||
engine_connection = postgres_get_connection(connection)
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",
|
||||
)
|
@ -9,13 +9,17 @@
|
||||
"azureSQLType": {
|
||||
"description": "Service type.",
|
||||
"type": "string",
|
||||
"enum": ["AzureSQL"],
|
||||
"enum": [
|
||||
"AzureSQL"
|
||||
],
|
||||
"default": "AzureSQL"
|
||||
},
|
||||
"azureSQLScheme": {
|
||||
"description": "SQLAlchemy driver scheme options.",
|
||||
"type": "string",
|
||||
"enum": ["mssql+pyodbc"],
|
||||
"enum": [
|
||||
"mssql+pyodbc"
|
||||
],
|
||||
"default": "mssql+pyodbc"
|
||||
}
|
||||
},
|
||||
@ -59,6 +63,37 @@
|
||||
"type": "string",
|
||||
"default": "ODBC Driver 18 for SQL Server"
|
||||
},
|
||||
"authenticationMode": {
|
||||
"title": "Authentication Mode",
|
||||
"description": "This parameter determines the mode of authentication for connecting to AzureSQL using ODBC. If 'Active Directory Password' is selected, you need to provide the password. If 'Active Directory Integrated' is selected, password is not required as it uses the logged-in user's credentials. This mode is useful for establishing secure and seamless connections with AzureSQL.",
|
||||
"properties": {
|
||||
"authentication": {
|
||||
"title": "Authentication",
|
||||
"description": "Authentication from Connection String for AzureSQL.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"ActiveDirectoryIntegrated",
|
||||
"ActiveDirectoryPassword"
|
||||
]
|
||||
},
|
||||
"encrypt": {
|
||||
"title": "Encrypt",
|
||||
"description": "Encrypt from Connection String for AzureSQL.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"trustServerCertificate": {
|
||||
"title": "Trust Server Certificate",
|
||||
"description": "Trust Server Certificate from Connection String for AzureSQL.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"connectionTimeout": {
|
||||
"title": "Connection Timeout",
|
||||
"description": "Connection Timeout from Connection String for AzureSQL.",
|
||||
"type": "integer",
|
||||
"default": 30
|
||||
}
|
||||
}
|
||||
},
|
||||
"ingestAllDatabases": {
|
||||
"title": "Ingest All Databases",
|
||||
"description": "Ingest data from all databases in Azuresql. You can use databaseFilterPattern on top of this.",
|
||||
@ -102,5 +137,9 @@
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["hostPort", "username", "database"]
|
||||
}
|
||||
"required": [
|
||||
"hostPort",
|
||||
"database",
|
||||
"username"
|
||||
]
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
{
|
||||
"$id": "https://open-metadata.org/schema/entity/services/connections/database/common/azureConfig.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "Azure Configuration Source",
|
||||
"description": "Azure Database Connection Config",
|
||||
"type": "object",
|
||||
"javaType": "org.openmetadata.schema.services.connections.database.common.AzureConfig",
|
||||
"properties": {
|
||||
"azureConfig": {
|
||||
"title": "Azure Credentials Configuration",
|
||||
"$ref": "../../../../../security/credentials/azureCredentials.json"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
@ -9,13 +9,17 @@
|
||||
"mySQLType": {
|
||||
"description": "Service type.",
|
||||
"type": "string",
|
||||
"enum": ["Mysql"],
|
||||
"enum": [
|
||||
"Mysql"
|
||||
],
|
||||
"default": "Mysql"
|
||||
},
|
||||
"mySQLScheme": {
|
||||
"description": "SQLAlchemy driver scheme options.",
|
||||
"type": "string",
|
||||
"enum": ["mysql+pymysql"],
|
||||
"enum": [
|
||||
"mysql+pymysql"
|
||||
],
|
||||
"default": "mysql+pymysql"
|
||||
}
|
||||
},
|
||||
@ -46,6 +50,9 @@
|
||||
},
|
||||
{
|
||||
"$ref": "./common/iamAuthConfig.json"
|
||||
},
|
||||
{
|
||||
"$ref": "./common/azureConfig.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -108,5 +115,8 @@
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["hostPort", "username"]
|
||||
}
|
||||
"required": [
|
||||
"hostPort",
|
||||
"username"
|
||||
]
|
||||
}
|
@ -50,6 +50,9 @@
|
||||
},
|
||||
{
|
||||
"$ref": "./common/iamAuthConfig.json"
|
||||
},
|
||||
{
|
||||
"$ref": "./common/azureConfig.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
@ -31,7 +31,12 @@
|
||||
"title": "Key Vault Name",
|
||||
"description": "Key Vault Name",
|
||||
"type": "string"
|
||||
},
|
||||
"scopes": {
|
||||
"title": "Scopes",
|
||||
"description": "Scopes to get access token, for e.g. api://6dfX33ab-XXXX-49df-XXXX-3459eX817d3e/.default",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user