diff --git a/ingestion/src/metadata/clients/azure_client.py b/ingestion/src/metadata/clients/azure_client.py new file mode 100644 index 00000000000..f80cc0ad5e6 --- /dev/null +++ b/ingestion/src/metadata/clients/azure_client.py @@ -0,0 +1,85 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module containing Azure Client +""" + +from metadata.generated.schema.security.credentials.azureCredentials import ( + AzureCredentials, +) +from metadata.utils.logger import utils_logger + +logger = utils_logger() + + +class AzureClient: + """ + AzureClient based on AzureCredentials. + """ + + def __init__(self, credentials: "AzureCredentials"): + self.credentials = credentials + if not isinstance(credentials, AzureCredentials): + self.credentials = AzureCredentials.parse_obj(credentials) + + def create_client( + self, + ): + from azure.identity import ClientSecretCredential, DefaultAzureCredential + + try: + if ( + getattr(self.credentials, "tenantId", None) + and getattr(self.credentials, "clientId", None) + and getattr(self.credentials, "clientSecret", None) + ): + logger.info("Using Client Secret Credentials") + return ClientSecretCredential( + tenant_id=self.credentials.tenantId, + client_id=self.credentials.clientId, + client_secret=self.credentials.clientSecret.get_secret_value(), + ) + else: + logger.info("Using Default Azure Credentials") + return DefaultAzureCredential() + except Exception as e: + logger.error(f"Error creating Azure Client: {e}") + raise e + + def create_blob_client(self): + from azure.storage.blob import BlobServiceClient + + try: + logger.info("Creating Blob Service Client") + if self.credentials.accountName: + return BlobServiceClient( + account_url=f"https://{self.credentials.accountName}.blob.core.windows.net/", + credential=self.create_client(), + ) + raise ValueError("Account Name is required to create Blob Service Client") + except Exception as e: + logger.error(f"Error creating Blob Service Client: {e}") + raise e + + def create_secret_client(self): + from azure.keyvault.secrets import SecretClient + + try: + if self.credentials.vaultName: + logger.info("Creating Secret Client") + return SecretClient( + vault_url=f"https://{self.credentials.vaultName}.vault.azure.net/", + credential=self.create_client(), + ) + raise ValueError("Vault Name is required to create a Secret Client") + except Exception as e: + logger.error(f"Error creating Secret Client: {e}") + raise e diff --git a/ingestion/src/metadata/examples/workflows/datalake_azure.yaml b/ingestion/src/metadata/examples/workflows/datalake_azure_client_secret.yaml similarity index 100% rename from ingestion/src/metadata/examples/workflows/datalake_azure.yaml rename to ingestion/src/metadata/examples/workflows/datalake_azure_client_secret.yaml diff --git a/ingestion/src/metadata/examples/workflows/datalake_azure_default.yaml b/ingestion/src/metadata/examples/workflows/datalake_azure_default.yaml new file mode 100644 index 00000000000..2a4f248232e --- /dev/null +++ 
b/ingestion/src/metadata/examples/workflows/datalake_azure_default.yaml @@ -0,0 +1,29 @@ +source: + type: datalake + serviceName: local_datalake4 + serviceConnection: + config: + type: Datalake + configSource: + securityConfig: + clientId: clientId + accountName: accountName + bucketName: bucket name + prefix: prefix + sourceConfig: + config: + type: DatabaseMetadata + tableFilterPattern: + includes: + - '' +sink: + type: metadata-rest + config: {} +workflowConfig: +# loggerLevel: INFO # DEBUG, INFO, WARN or ERROR + openMetadataServerConfig: + hostPort: http://localhost:8585/api + authProvider: openmetadata + securityConfig: + jwtToken: "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg" + \ No newline at end of file diff --git a/ingestion/src/metadata/ingestion/source/dashboard/powerbi/client.py b/ingestion/src/metadata/ingestion/source/dashboard/powerbi/client.py index 90014713592..e72cead1d79 100644 --- a/ingestion/src/metadata/ingestion/source/dashboard/powerbi/client.py +++ b/ingestion/src/metadata/ingestion/source/dashboard/powerbi/client.py @@ -19,6 +19,9 @@ from typing import List, Optional, Tuple import msal +from metadata.generated.schema.entity.services.connections.dashboard.powerBIConnection import ( + PowerBIConnection, +) from metadata.ingestion.api.steps import InvalidSourceException from metadata.ingestion.ometa.client import REST, ClientConfig from metadata.ingestion.source.dashboard.powerbi.models import ( @@ -52,7 +55,7 @@ class PowerBiApiClient: client: REST - def __init__(self, config): + def __init__(self, config: PowerBIConnection): self.config = config self.msal_client = msal.ConfidentialClientApplication( client_id=self.config.clientId, diff --git a/ingestion/src/metadata/ingestion/source/database/azuresql/connection.py b/ingestion/src/metadata/ingestion/source/database/azuresql/connection.py index 1cb201d7b09..9fd23b2fa19 100644 --- a/ingestion/src/metadata/ingestion/source/database/azuresql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/azuresql/connection.py @@ -15,12 +15,13 @@ Source connection handler from typing import Optional, Union from urllib.parse import quote_plus -from sqlalchemy.engine import Engine +from sqlalchemy.engine import URL, Engine from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) from metadata.generated.schema.entity.services.connections.database.azureSQLConnection import ( + Authentication, AzureSQLConnection, ) from metadata.generated.schema.entity.services.connections.database.mssqlConnection import ( @@ -40,13 +41,29 @@ def get_connection_url(connection: Union[AzureSQLConnection, MssqlConnection]) - Build the connection URL """ + if connection.authenticationMode: + connection_string = f"Driver={connection.driver};Server={connection.hostPort};Database={connection.database};" + connection_string += f"Uid={connection.username};" + if ( + connection.authenticationMode.authentication + == Authentication.ActiveDirectoryPassword + ): + connection_string 
+= f"Pwd={connection.password.get_secret_value()};" + + connection_string += f"Encrypt={'yes' if connection.authenticationMode.encrypt else 'no'};TrustServerCertificate={'yes' if connection.authenticationMode.trustServerCertificate else 'no'};" + connection_string += f"Connection Timeout={connection.authenticationMode.connectionTimeout or 30};Authentication={connection.authenticationMode.authentication.value};" + + connection_url = URL.create( + "mssql+pyodbc", query={"odbc_connect": connection_string} + ) + return connection_url url = f"{connection.scheme.value}://" if connection.username: url += f"{quote_plus(connection.username)}" url += ( f":{quote_plus(connection.password.get_secret_value())}" - if connection + if connection.password else "" ) url += "@" @@ -54,12 +71,13 @@ def get_connection_url(connection: Union[AzureSQLConnection, MssqlConnection]) - url += f"{connection.hostPort}" url += f"/{quote_plus(connection.database)}" if connection.database else "" url += f"?driver={quote_plus(connection.driver)}" + options = get_connection_options_dict(connection) if options: if not connection.database: url += "/" params = "&".join( - f"{key}={quote_plus(value)}" for (key, value) in options.items() if value + f"{key}={quote_plus(value)}" for key, value in options.items() if value ) url = f"{url}?{params}" diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py index 56e5315da5b..c5be85eb5a1 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py @@ -20,6 +20,7 @@ from typing import Optional from google.cloud import storage +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) @@ -88,22 +89,9 @@ def _(config: GCSConfig): @get_datalake_client.register def _(config: AzureConfig): - from azure.identity import ClientSecretCredential - from azure.storage.blob import BlobServiceClient try: - credentials = ClientSecretCredential( - config.securityConfig.tenantId, - config.securityConfig.clientId, - config.securityConfig.clientSecret.get_secret_value(), - ) - - azure_client = BlobServiceClient( - f"https://{config.securityConfig.accountName}.blob.core.windows.net/", - credential=credentials, - ) - return azure_client - + return AzureClient(config.securityConfig).create_blob_client() except Exception as exc: raise RuntimeError( f"Unknown error connecting with {config.securityConfig}: {exc}." 
diff --git a/ingestion/src/metadata/ingestion/source/database/dbt/dbt_config.py b/ingestion/src/metadata/ingestion/source/database/dbt/dbt_config.py index c130494c73c..3dfb91c6588 100644 --- a/ingestion/src/metadata/ingestion/source/database/dbt/dbt_config.py +++ b/ingestion/src/metadata/ingestion/source/database/dbt/dbt_config.py @@ -20,6 +20,7 @@ from typing import Dict, Iterable, List, Optional, Tuple import requests from metadata.clients.aws_client import AWSClient +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.metadataIngestion.dbtconfig.dbtAzureConfig import ( DbtAzureConfig, ) @@ -357,21 +358,8 @@ def _(config: DbtGcsConfig): def _(config: DbtAzureConfig): try: bucket_name, prefix = get_dbt_prefix_config(config) - from azure.identity import ( # pylint: disable=import-outside-toplevel - ClientSecretCredential, - ) - from azure.storage.blob import ( # pylint: disable=import-outside-toplevel - BlobServiceClient, - ) - client = BlobServiceClient( - f"https://{config.dbtSecurityConfig.accountName}.blob.core.windows.net/", - credential=ClientSecretCredential( - config.dbtSecurityConfig.tenantId, - config.dbtSecurityConfig.clientId, - config.dbtSecurityConfig.clientSecret.get_secret_value(), - ), - ) + client = AzureClient(config.dbtSecurityConfig).create_blob_client() if not bucket_name: container_dicts = client.list_containers() diff --git a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py index dca28eefbae..f5e6e61d5f4 100644 --- a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py @@ -16,9 +16,13 @@ from typing import Optional from sqlalchemy.engine import Engine +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) +from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( + BasicAuth, +) from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( MysqlConnection, ) @@ -38,6 +42,16 @@ def get_connection(connection: MysqlConnection) -> Engine: """ Create connection """ + if hasattr(connection.authType, "azureConfig"): + azure_client = AzureClient(connection.authType.azureConfig).create_client() + if not connection.authType.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. 
https://ossrdbms-aad.database.windows.net/.default" + ) + access_token_obj = azure_client.get_token( + *connection.authType.azureConfig.scopes.split(",") + ) + connection.authType = BasicAuth(password=access_token_obj.token) if connection.sslCA or connection.sslCert or connection.sslKey: if not connection.connectionOptions: connection.connectionOptions = init_empty_connection_options() diff --git a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py index 2b34896cd0e..3427fdb0e5f 100644 --- a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py @@ -17,9 +17,13 @@ from typing import Optional from sqlalchemy.engine import Engine +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) +from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( + BasicAuth, +) from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( PostgresConnection, SslMode, @@ -46,6 +50,17 @@ def get_connection(connection: PostgresConnection) -> Engine: """ Create connection """ + + if hasattr(connection.authType, "azureConfig"): + azure_client = AzureClient(connection.authType.azureConfig).create_client() + if not connection.authType.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" + ) + access_token_obj = azure_client.get_token( + *connection.authType.azureConfig.scopes.split(",") + ) + connection.authType = BasicAuth(password=access_token_obj.token) if connection.sslMode: if not connection.connectionArguments: connection.connectionArguments = init_empty_connection_arguments() diff --git a/ingestion/src/metadata/utils/credentials.py b/ingestion/src/metadata/utils/credentials.py index dea1a6e7888..233a5a5445e 100644 --- a/ingestion/src/metadata/utils/credentials.py +++ b/ingestion/src/metadata/utils/credentials.py @@ -15,7 +15,7 @@ import base64 import json import os import tempfile -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from cryptography.hazmat.primitives import serialization from google import auth @@ -25,6 +25,9 @@ from metadata.generated.schema.security.credentials.gcpCredentials import ( GCPCredentials, GcpCredentialsPath, ) +from metadata.generated.schema.security.credentials.gcpExternalAccount import ( + GcpCredentialsValuesExternalAccount, +) from metadata.generated.schema.security.credentials.gcpValues import ( GcpCredentialsValues, ) @@ -85,7 +88,9 @@ def create_credential_tmp_file(credentials: dict) -> str: return temp_file_path -def build_google_credentials_dict(gcp_values: GcpCredentialsValues) -> Dict[str, str]: +def build_google_credentials_dict( + gcp_values: Union[GcpCredentialsValues, GcpCredentialsValuesExternalAccount] +) -> Dict[str, str]: """ Given GcPCredentialsValues, build a dictionary as the JSON file downloaded from GCP with the service_account diff --git a/ingestion/src/metadata/utils/secrets/azure_kv_secrets_manager.py b/ingestion/src/metadata/utils/secrets/azure_kv_secrets_manager.py index 
4682fc23ace..566c9154850 100644 --- a/ingestion/src/metadata/utils/secrets/azure_kv_secrets_manager.py +++ b/ingestion/src/metadata/utils/secrets/azure_kv_secrets_manager.py @@ -17,9 +17,9 @@ import traceback from abc import ABC from typing import Optional -from azure.identity import ClientSecretCredential, DefaultAzureCredential -from azure.keyvault.secrets import KeyVaultSecret, SecretClient +from azure.keyvault.secrets import KeyVaultSecret +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.security.secrets.secretsManagerClientLoader import ( SecretsManagerClientLoader, ) @@ -105,23 +105,7 @@ class AzureKVSecretsManager(ExternalSecretsManager, ABC): ): super().__init__(provider=SecretsManagerProvider.azure_kv, loader=loader) - if ( - self.credentials.tenantId - and self.credentials.clientId - and self.credentials.clientSecret - ): - azure_identity = ClientSecretCredential( - tenant_id=self.credentials.tenantId, - client_id=self.credentials.clientId, - client_secret=self.credentials.clientSecret.get_secret_value(), - ) - else: - azure_identity = DefaultAzureCredential() - - self.client = SecretClient( - vault_url=f"https://{self.credentials.vaultName}.vault.azure.net/", - credential=azure_identity, - ) + self.client = AzureClient(self.credentials).create_secret_client() def get_string_value(self, secret_id: str) -> str: """ diff --git a/ingestion/src/metadata/utils/storage_metadata_config.py b/ingestion/src/metadata/utils/storage_metadata_config.py index 3eb12a1670c..7cfbdd8324e 100644 --- a/ingestion/src/metadata/utils/storage_metadata_config.py +++ b/ingestion/src/metadata/utils/storage_metadata_config.py @@ -17,6 +17,7 @@ from functools import singledispatch import requests +from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import ( AzureConfig, ) @@ -153,21 +154,7 @@ def _(config: StorageMetadataAdlsConfig) -> ManifestMetadataConfig: else STORAGE_METADATA_MANIFEST_FILE_NAME ) - from azure.identity import ( # pylint: disable=import-outside-toplevel - ClientSecretCredential, - ) - from azure.storage.blob import ( # pylint: disable=import-outside-toplevel - BlobServiceClient, - ) - - blob_client = BlobServiceClient( - account_url=f"https://{config.securityConfig.accountName}.blob.core.windows.net/", - credential=ClientSecretCredential( - config.securityConfig.tenantId, - config.securityConfig.clientId, - config.securityConfig.clientSecret.get_secret_value(), - ), - ) + blob_client = AzureClient(config.securityConfig).create_blob_client() reader = get_reader( config_source=AzureConfig(securityConfig=config.securityConfig), diff --git a/ingestion/tests/integration/ometa/test_ometa_patch.py b/ingestion/tests/integration/ometa/test_ometa_patch.py index a5eb78932c2..409536091f2 100644 --- a/ingestion/tests/integration/ometa/test_ometa_patch.py +++ b/ingestion/tests/integration/ometa/test_ometa_patch.py @@ -17,17 +17,6 @@ import time from datetime import datetime from unittest import TestCase -from ingestion.tests.integration.integration_base import ( - generate_name, - get_create_entity, - get_create_service, - get_create_team_entity, - get_create_test_case, - get_create_test_definition, - get_create_test_suite, - get_create_user_entity, - int_admin_ometa, -) from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.data.table import 
Column, DataType, Table @@ -54,6 +43,18 @@ from metadata.ingestion.models.patch_request import ( from metadata.ingestion.models.table_metadata import ColumnTag from metadata.utils.helpers import find_column_in_table +from ..integration_base import ( + generate_name, + get_create_entity, + get_create_service, + get_create_team_entity, + get_create_test_case, + get_create_test_definition, + get_create_test_suite, + get_create_user_entity, + int_admin_ometa, +) + PII_TAG_LABEL = TagLabel( tagFQN="PII.Sensitive", labelType=LabelType.Automated, diff --git a/ingestion/tests/integration/profiler/test_nosql_profiler.py b/ingestion/tests/integration/profiler/test_nosql_profiler.py index 693ad7ec7c1..2d00d6b3e4c 100644 --- a/ingestion/tests/integration/profiler/test_nosql_profiler.py +++ b/ingestion/tests/integration/profiler/test_nosql_profiler.py @@ -33,7 +33,6 @@ from unittest import TestCase from pymongo import MongoClient, database from testcontainers.mongodb import MongoDbContainer -from ingestion.tests.integration.integration_base import int_admin_ometa from metadata.generated.schema.entity.data.table import ColumnProfile, Table from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.ingestion.ometa.ometa_api import OpenMetadata @@ -46,6 +45,8 @@ from metadata.workflow.metadata import MetadataWorkflow from metadata.workflow.profiler import ProfilerWorkflow from metadata.workflow.workflow_output_handler import print_status +from ..integration_base import int_admin_ometa + SERVICE_NAME = Path(__file__).stem diff --git a/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.py b/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.py index 5a9573d5038..b9bfe8e7ed1 100644 --- a/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.py +++ b/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.py @@ -1,7 +1,7 @@ """ This file has been generated from dag_runner.j2 """ -from openmetadata.workflows import workflow_factory +from openmetadata_managed_apis.workflows import workflow_factory workflow = workflow_factory.WorkflowFactory.create( "/airflow/dag_generated_configs/local_redshift_profiler_e9AziRXs.json" diff --git a/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.txt b/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.txt index b3945ba7e25..bdb70bb1fd9 100644 --- a/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.txt +++ b/ingestion/tests/unit/metadata/cli/resources/profiler_workflow.txt @@ -2,7 +2,7 @@ This file has been generated from dag_runner.j2 """ from airflow import DAG -from openmetadata.workflows import workflow_factory +from openmetadata_managed_apis.workflows import workflow_factory workflow = workflow_factory.WorkflowFactory.create("/airflow/dag_generated_configs/local_redshift_profiler_e9AziRXs.json") workflow.generate_dag(globals()) \ No newline at end of file diff --git a/ingestion/tests/unit/test_azure_credentials.py b/ingestion/tests/unit/test_azure_credentials.py new file mode 100644 index 00000000000..bb1f03f96c5 --- /dev/null +++ b/ingestion/tests/unit/test_azure_credentials.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import patch + +from metadata.clients.azure_client import AzureClient +from metadata.generated.schema.security.credentials.azureCredentials import ( + AzureCredentials, +) + + +class TestAzureClient(unittest.TestCase): + @patch("azure.identity.ClientSecretCredential") + @patch("azure.identity.DefaultAzureCredential") + def 
test_create_client( + self, + mock_default_credential, + mock_client_secret_credential, + ): + # Test with ClientSecretCredential + credentials = AzureCredentials( + clientId="clientId", clientSecret="clientSecret", tenantId="tenantId" + ) + instance = AzureClient(credentials) + instance.create_client() + + mock_client_secret_credential.assert_called_once() + mock_client_secret_credential.reset_mock() + + credentials = AzureCredentials( + clientId="clientId", + ) + instance = AzureClient(credentials) + + instance.create_client() + + mock_default_credential.assert_called_once() + + @patch("azure.storage.blob.BlobServiceClient") + def test_create_blob_client(self, mock_blob_service_client): + credentials = AzureCredentials( + clientId="clientId", clientSecret="clientSecret", tenantId="tenantId" + ) + with self.assertRaises(ValueError): + AzureClient(credentials=credentials).create_blob_client() + + credentials.accountName = "accountName" + AzureClient(credentials=credentials).create_blob_client() + mock_blob_service_client.assert_called_once() + + @patch("azure.keyvault.secrets.SecretClient") + def test_create_secret_client(self, mock_secret_client): + credentials = AzureCredentials( + clientId="clientId", clientSecret="clientSecret", tenantId="tenantId" + ) + with self.assertRaises(ValueError): + AzureClient(credentials=credentials).create_secret_client() + + credentials.vaultName = "vaultName" + AzureClient(credentials=credentials).create_secret_client() + mock_secret_client.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/ingestion/tests/unit/test_build_connection_url.py b/ingestion/tests/unit/test_build_connection_url.py new file mode 100644 index 00000000000..8cf60dae677 --- /dev/null +++ b/ingestion/tests/unit/test_build_connection_url.py @@ -0,0 +1,138 @@ +import unittest +from unittest.mock import patch + +from azure.core.credentials import AccessToken +from azure.identity import ClientSecretCredential + +from metadata.generated.schema.entity.services.connections.database.azureSQLConnection import ( + Authentication, + AuthenticationMode, + AzureSQLConnection, +) +from metadata.generated.schema.entity.services.connections.database.common.azureConfig import ( + AzureConfigurationSource, +) +from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( + BasicAuth, +) +from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( + MysqlConnection, +) +from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( + PostgresConnection, +) +from metadata.generated.schema.security.credentials.azureCredentials import ( + AzureCredentials, +) +from metadata.ingestion.source.database.azuresql.connection import get_connection_url +from metadata.ingestion.source.database.mysql.connection import ( + get_connection as mysql_get_connection, +) +from metadata.ingestion.source.database.postgres.connection import ( + get_connection as postgres_get_connection, +) + + +class TestGetConnectionURL(unittest.TestCase): + def test_get_connection_url_wo_active_directory_password(self): + connection = AzureSQLConnection( + driver="SQL Server", + hostPort="myserver.database.windows.net", + database="mydb", + username="myuser", + password="mypassword", + authenticationMode=AuthenticationMode( + authentication=Authentication.ActiveDirectoryPassword, + encrypt=True, + trustServerCertificate=False, + connectionTimeout=45, + ), + ) + expected_url = 
"mssql+pyodbc://?odbc_connect=Driver%3DSQL+Server%3BServer%3Dmyserver.database.windows.net%3BDatabase%3Dmydb%3BUid%3Dmyuser%3BPwd%3Dmypassword%3BEncrypt%3Dyes%3BTrustServerCertificate%3Dno%3BConnection+Timeout%3D45%3BAuthentication%3DActiveDirectoryPassword%3B" + self.assertEqual(str(get_connection_url(connection)), expected_url) + + connection = AzureSQLConnection( + driver="SQL Server", + hostPort="myserver.database.windows.net", + database="mydb", + username="myuser", + password="mypassword", + authenticationMode=AuthenticationMode( + authentication=Authentication.ActiveDirectoryPassword, + ), + ) + + expected_url = "mssql+pyodbc://?odbc_connect=Driver%3DSQL+Server%3BServer%3Dmyserver.database.windows.net%3BDatabase%3Dmydb%3BUid%3Dmyuser%3BPwd%3Dmypassword%3BEncrypt%3Dno%3BTrustServerCertificate%3Dno%3BConnection+Timeout%3D30%3BAuthentication%3DActiveDirectoryPassword%3B" + self.assertEqual(str(get_connection_url(connection)), expected_url) + + def test_get_connection_url_mysql(self): + connection = MysqlConnection( + username="openmetadata_user", + authType=BasicAuth(password="openmetadata_password"), + hostPort="localhost:3306", + databaseSchema="openmetadata_db", + ) + engine_connection = mysql_get_connection(connection) + self.assertEqual( + str(engine_connection.url), + "mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", + ) + connection = MysqlConnection( + username="openmetadata_user", + authType=AzureConfigurationSource( + azureConfig=AzureCredentials( + clientId="clientid", + tenantId="tenantid", + clientSecret="clientsecret", + scopes="scope1,scope2", + ) + ), + hostPort="localhost:3306", + databaseSchema="openmetadata_db", + ) + with patch.object( + ClientSecretCredential, + "get_token", + return_value=AccessToken(token="mocked_token", expires_on=100), + ): + engine_connection = mysql_get_connection(connection) + self.assertEqual( + str(engine_connection.url), + "mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db", + ) + + def test_get_connection_url_postgres(self): + connection = PostgresConnection( + username="openmetadata_user", + authType=BasicAuth(password="openmetadata_password"), + hostPort="localhost:3306", + database="openmetadata_db", + ) + engine_connection = postgres_get_connection(connection) + self.assertEqual( + str(engine_connection.url), + "postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", + ) + connection = PostgresConnection( + username="openmetadata_user", + authType=AzureConfigurationSource( + azureConfig=AzureCredentials( + clientId="clientid", + tenantId="tenantid", + clientSecret="clientsecret", + scopes="scope1,scope2", + ) + ), + hostPort="localhost:3306", + database="openmetadata_db", + ) + with patch.object( + ClientSecretCredential, + "get_token", + return_value=AccessToken(token="mocked_token", expires_on=100), + ): + engine_connection = postgres_get_connection(connection) + self.assertEqual( + str(engine_connection.url), + "postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db", + ) diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/azureSQLConnection.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/azureSQLConnection.json index b2b0a2cdc88..1cb390f0024 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/azureSQLConnection.json +++ 
b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/azureSQLConnection.json @@ -9,13 +9,17 @@ "azureSQLType": { "description": "Service type.", "type": "string", - "enum": ["AzureSQL"], + "enum": [ + "AzureSQL" + ], "default": "AzureSQL" }, "azureSQLScheme": { "description": "SQLAlchemy driver scheme options.", "type": "string", - "enum": ["mssql+pyodbc"], + "enum": [ + "mssql+pyodbc" + ], "default": "mssql+pyodbc" } }, @@ -59,6 +63,37 @@ "type": "string", "default": "ODBC Driver 18 for SQL Server" }, + "authenticationMode": { + "title": "Authentication Mode", + "description": "This parameter determines the mode of authentication for connecting to AzureSQL using ODBC. If 'Active Directory Password' is selected, you need to provide the password. If 'Active Directory Integrated' is selected, password is not required as it uses the logged-in user's credentials. This mode is useful for establishing secure and seamless connections with AzureSQL.", + "properties": { + "authentication": { + "title": "Authentication", + "description": "Authentication from Connection String for AzureSQL.", + "type": "string", + "enum": [ + "ActiveDirectoryIntegrated", + "ActiveDirectoryPassword" + ] + }, + "encrypt": { + "title": "Encrypt", + "description": "Encrypt from Connection String for AzureSQL.", + "type": "boolean" + }, + "trustServerCertificate": { + "title": "Trust Server Certificate", + "description": "Trust Server Certificate from Connection String for AzureSQL.", + "type": "boolean" + }, + "connectionTimeout": { + "title": "Connection Timeout", + "description": "Connection Timeout from Connection String for AzureSQL.", + "type": "integer", + "default": 30 + } + } + }, "ingestAllDatabases": { "title": "Ingest All Databases", "description": "Ingest data from all databases in Azuresql. 
You can use databaseFilterPattern on top of this.", @@ -102,5 +137,9 @@ } }, "additionalProperties": false, - "required": ["hostPort", "username", "database"] -} + "required": [ + "hostPort", + "database", + "username" + ] +} \ No newline at end of file diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json new file mode 100644 index 00000000000..364f69347c8 --- /dev/null +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json @@ -0,0 +1,15 @@ +{ + "$id": "https://open-metadata.org/schema/entity/services/connections/database/common/azureConfig.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Azure Configuration Source", + "description": "Azure Database Connection Config", + "type": "object", + "javaType": "org.openmetadata.schema.services.connections.database.common.AzureConfig", + "properties": { + "azureConfig": { + "title": "Azure Credentials Configuration", + "$ref": "../../../../../security/credentials/azureCredentials.json" + } + }, + "additionalProperties": false +} \ No newline at end of file diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/mysqlConnection.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/mysqlConnection.json index c11c496c442..2a96d10ed68 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/mysqlConnection.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/mysqlConnection.json @@ -9,13 +9,17 @@ "mySQLType": { "description": "Service type.", "type": "string", - "enum": ["Mysql"], + "enum": [ + "Mysql" + ], "default": "Mysql" }, "mySQLScheme": { "description": "SQLAlchemy driver scheme options.", "type": "string", - "enum": ["mysql+pymysql"], + "enum": [ + "mysql+pymysql" + ], "default": "mysql+pymysql" } }, @@ -46,6 +50,9 @@ }, { "$ref": "./common/iamAuthConfig.json" + }, + { + "$ref": "./common/azureConfig.json" } ] }, @@ -108,5 +115,8 @@ } }, "additionalProperties": false, - "required": ["hostPort", "username"] -} + "required": [ + "hostPort", + "username" + ] +} \ No newline at end of file diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/postgresConnection.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/postgresConnection.json index ac0445d63de..b4e32b29c9e 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/postgresConnection.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/postgresConnection.json @@ -50,6 +50,9 @@ }, { "$ref": "./common/iamAuthConfig.json" + }, + { + "$ref": "./common/azureConfig.json" } ] }, diff --git a/openmetadata-spec/src/main/resources/json/schema/security/credentials/azureCredentials.json b/openmetadata-spec/src/main/resources/json/schema/security/credentials/azureCredentials.json index 330f178d565..7db1e585769 100644 --- a/openmetadata-spec/src/main/resources/json/schema/security/credentials/azureCredentials.json +++ b/openmetadata-spec/src/main/resources/json/schema/security/credentials/azureCredentials.json @@ -31,7 +31,12 @@ "title": "Key Vault Name", "description": "Key Vault Name", "type": "string" + 
}, + "scopes": { + "title": "Scopes", + "description": "Scopes to get access token, for e.g. api://6dfX33ab-XXXX-49df-XXXX-3459eX817d3e/.default", + "type": "string" } }, "additionalProperties": false -} +} \ No newline at end of file