diff --git a/ingestion/src/metadata/ingestion/source/dashboard/superset/connection.py b/ingestion/src/metadata/ingestion/source/dashboard/superset/connection.py index 90d8eb7f78b..995dc4f261d 100644 --- a/ingestion/src/metadata/ingestion/source/dashboard/superset/connection.py +++ b/ingestion/src/metadata/ingestion/source/dashboard/superset/connection.py @@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect MysqlConnection as MysqlConnectionConfig, ) from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( - PostgresConnection, + PostgresConnection as PostgresConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, @@ -47,9 +47,7 @@ from metadata.ingestion.source.dashboard.superset.queries import ( FETCH_DASHBOARDS_TEST, ) from metadata.ingestion.source.database.mysql.connection import MySQLConnection -from metadata.ingestion.source.database.postgres.connection import ( - get_connection as pg_get_connection, -) +from metadata.ingestion.source.database.postgres.connection import PostgresConnection from metadata.utils.constants import THREE_MIN @@ -61,10 +59,10 @@ def get_connection( """ if isinstance(connection.connection, SupersetApiConnection): return SupersetAPIClient(connection) - if isinstance(connection.connection, PostgresConnection): - return pg_get_connection(connection=connection.connection) + if isinstance(connection.connection, PostgresConnectionConfig): + return PostgresConnection(connection.connection).client if isinstance(connection.connection, MysqlConnectionConfig): - return MySQLConnection(connection.connection).get_client() + return MySQLConnection(connection.connection).client return None diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py index b9920f62692..49c150de2f0 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py @@ -12,8 +12,6 @@ """ Source connection handler """ -from dataclasses import dataclass -from functools import singledispatch from typing import Optional from metadata.generated.schema.entity.automations.workflow import ( @@ -29,88 +27,68 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C S3Config, ) from metadata.generated.schema.entity.services.connections.database.datalakeConnection import ( - DatalakeConnection, + DatalakeConnection as DatalakeConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, ) +from metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.connections.test_connections import test_connection_steps from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.database.datalake.clients.azure_blob import ( DatalakeAzureBlobClient, ) +from metadata.ingestion.source.database.datalake.clients.base import DatalakeBaseClient from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client from metadata.utils.constants import THREE_MIN -# Only import specific datalake dependencies if necessary -# pylint: disable=import-outside-toplevel -@dataclass -class DatalakeClient: - def __init__(self, client, config) -> None: - self.client = client - self.config = config +class DatalakeConnection(BaseConnection[DatalakeConnectionConfig, DatalakeBaseClient]): + def _get_client(self) -> DatalakeBaseClient: + """ + Return the appropriate Datalake client based on configSource. + """ + connection = self.service_connection + if isinstance(connection.configSource, S3Config): + return DatalakeS3Client.from_config(connection.configSource) + elif isinstance(connection.configSource, GCSConfig): + return DatalakeGcsClient.from_config(connection.configSource) + elif isinstance(connection.configSource, AzureConfig): + return DatalakeAzureBlobClient.from_config(connection.configSource) + else: + msg = f"Config not implemented for type {type(connection.configSource)}: {connection.configSource}" + raise NotImplementedError(msg) -@singledispatch -def get_datalake_client(config): - """ - Method to retrieve datalake client from the config - """ - if config: - msg = f"Config not implemented for type {type(config)}: {config}" - raise NotImplementedError(msg) + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. + """ + raise NotImplementedError("get_connection_dict is not implemented for Datalake") + def test_connection( + self, + metadata: OpenMetadata, + automation_workflow: Optional[AutomationWorkflow] = None, + timeout_seconds: Optional[int] = THREE_MIN, + ) -> TestConnectionResult: + """ + Test connection. This can be executed either as part + of a metadata workflow or during an Automation Workflow + """ + test_fn = { + "ListBuckets": self.client.get_test_list_buckets_fn( + self.service_connection.bucketName + ), + } -@get_datalake_client.register -def _(config: S3Config): - return DatalakeS3Client.from_config(config) - - -@get_datalake_client.register -def _(config: GCSConfig): - return DatalakeGcsClient.from_config(config) - - -@get_datalake_client.register -def _(config: AzureConfig): - return DatalakeAzureBlobClient.from_config(config) - - -def get_connection(connection: DatalakeConnection) -> DatalakeClient: - """ - Create connection. - - Returns an AWS, Azure or GCS Clients. - """ - return DatalakeClient( - client=get_datalake_client(connection.configSource), - config=connection, - ) - - -def test_connection( - metadata: OpenMetadata, - connection: DatalakeClient, - service_connection: DatalakeConnection, - automation_workflow: Optional[AutomationWorkflow] = None, - timeout_seconds: Optional[int] = THREE_MIN, -) -> TestConnectionResult: - """ - Test connection. This can be executed either as part - of a metadata workflow or during an Automation Workflow - """ - test_fn = { - "ListBuckets": connection.client.get_test_list_buckets_fn( - connection.config.bucketName - ), - } - - return test_connection_steps( - metadata=metadata, - test_fn=test_fn, - service_type=service_connection.type.value, - automation_workflow=automation_workflow, - timeout_seconds=timeout_seconds, - ) + return test_connection_steps( + metadata=metadata, + test_fn=test_fn, + service_type=self.service_connection.type.value + if self.service_connection.type + else "Datalake", + automation_workflow=automation_workflow, + timeout_seconds=timeout_seconds, + ) diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py index 00ed5d91ee3..81b88cbc277 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py @@ -88,15 +88,14 @@ class DatalakeSource(DatabaseServiceSource): ) self.metadata = metadata self.service_connection = self.config.serviceConnection.root.config - self.connection = get_connection(self.service_connection) - self.client = self.connection.client + self.client = get_connection(self.service_connection) self.table_constraints = None self.database_source_state = set() self.config_source = self.service_connection.configSource - self.connection_obj = self.connection + self.connection_obj = self.client self.test_connection() self.reader = get_reader( - config_source=self.config_source, client=self.client._client + config_source=self.config_source, client=self.client.client ) @classmethod diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/service_spec.py b/ingestion/src/metadata/ingestion/source/database/datalake/service_spec.py index 98a417a7e41..f4aa063ffe5 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/service_spec.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/service_spec.py @@ -1,6 +1,7 @@ from metadata.data_quality.interface.pandas.pandas_test_suite_interface import ( PandasTestSuiteInterface, ) +from metadata.ingestion.source.database.datalake.connection import DatalakeConnection from metadata.ingestion.source.database.datalake.metadata import DatalakeSource from metadata.profiler.interface.pandas.profiler_interface import ( PandasProfilerInterface, @@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec( profiler_class=PandasProfilerInterface, test_suite_class=PandasTestSuiteInterface, sampler_class=DatalakeSampler, + connection_class=DatalakeConnection, ) diff --git a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py index 25db52514d3..8dedf739892 100644 --- a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py @@ -16,7 +16,6 @@ from typing import Optional from sqlalchemy.engine import Engine -from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) @@ -47,6 +46,7 @@ from metadata.ingestion.source.database.mysql.queries import ( MYSQL_TEST_GET_QUERIES_SLOW_LOGS, ) from metadata.utils.constants import THREE_MIN +from metadata.utils.credentials import get_azure_access_token class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]): @@ -57,20 +57,8 @@ class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]): connection = self.service_connection if isinstance(connection.authType, AzureConfigurationSource): - if not connection.authType.azureConfig: - raise ValueError("Azure Config is missing") - azure_client = AzureClient(connection.authType.azureConfig).create_client() - if not connection.authType.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/" - "en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-micr" - "osoft-entra-access-token and fetch the resource associated with it," - " for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - access_token_obj = azure_client.get_token( - *connection.authType.azureConfig.scopes.split(",") - ) - connection.authType = BasicAuth(password=access_token_obj.token) # type: ignore + access_token = get_azure_access_token(connection.authType) + connection.authType = BasicAuth(password=access_token) # type: ignore return create_generic_db_connection( connection=connection, get_connection_url_fn=get_connection_url_common, diff --git a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py index d8584526630..c03d49d0480 100644 --- a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py @@ -12,20 +12,21 @@ """ Source connection handler """ - from typing import Optional from sqlalchemy.engine import Engine -from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) +from metadata.generated.schema.entity.services.connections.database.common.azureConfig import ( + AzureConfigurationSource, +) from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( BasicAuth, ) from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( - PostgresConnection, + PostgresConnection as PostgresConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, @@ -35,6 +36,7 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_url_common, ) +from metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.connections.test_connections import test_connection_db_common from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.database.postgres.queries import ( @@ -46,54 +48,55 @@ from metadata.ingestion.source.database.postgres.utils import ( get_postgres_time_column_name, ) from metadata.utils.constants import THREE_MIN +from metadata.utils.credentials import get_azure_access_token -def get_connection(connection: PostgresConnection) -> Engine: - """ - Create connection - """ +class PostgresConnection(BaseConnection[PostgresConnectionConfig, Engine]): + def _get_client(self) -> Engine: + """ + Return the SQLAlchemy Engine for PostgreSQL. + """ + connection = self.service_connection - if hasattr(connection.authType, "azureConfig"): - azure_client = AzureClient(connection.authType.azureConfig).create_client() - if not connection.authType.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - access_token_obj = azure_client.get_token( - *connection.authType.azureConfig.scopes.split(",") + if isinstance(connection.authType, AzureConfigurationSource): + access_token = get_azure_access_token(connection.authType) + connection.authType = BasicAuth(password=access_token) # type: ignore + return create_generic_db_connection( + connection=connection, + get_connection_url_fn=get_connection_url_common, + get_connection_args_fn=get_connection_args_common, ) - connection.authType = BasicAuth(password=access_token_obj.token) - return create_generic_db_connection( - connection=connection, - get_connection_url_fn=get_connection_url_common, - get_connection_args_fn=get_connection_args_common, - ) + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. + """ + raise NotImplementedError( + "get_connection_dict is not implemented for PostgreSQL" + ) -def test_connection( - metadata: OpenMetadata, - engine: Engine, - service_connection: PostgresConnection, - automation_workflow: Optional[AutomationWorkflow] = None, - timeout_seconds: Optional[int] = THREE_MIN, -) -> TestConnectionResult: - """ - Test connection. This can be executed either as part - of a metadata workflow or during an Automation Workflow - """ - - queries = { - "GetQueries": POSTGRES_TEST_GET_QUERIES.format( - time_column_name=get_postgres_time_column_name(engine=engine), - ), - "GetDatabases": POSTGRES_GET_DATABASE, - "GetTags": POSTGRES_TEST_GET_TAGS, - } - return test_connection_db_common( - metadata=metadata, - engine=engine, - service_connection=service_connection, - automation_workflow=automation_workflow, - queries=queries, - timeout_seconds=timeout_seconds, - ) + def test_connection( + self, + metadata: OpenMetadata, + automation_workflow: Optional[AutomationWorkflow] = None, + timeout_seconds: Optional[int] = THREE_MIN, + ) -> TestConnectionResult: + """ + Test connection. This can be executed either as part + of a metadata workflow or during an Automation Workflow + """ + queries = { + "GetQueries": POSTGRES_TEST_GET_QUERIES.format( + time_column_name=get_postgres_time_column_name(engine=self.client), + ), + "GetDatabases": POSTGRES_GET_DATABASE, + "GetTags": POSTGRES_TEST_GET_TAGS, + } + return test_connection_db_common( + metadata=metadata, + engine=self.client, + service_connection=self.service_connection, + automation_workflow=automation_workflow, + timeout_seconds=timeout_seconds, + queries=queries, + ) diff --git a/ingestion/src/metadata/ingestion/source/database/postgres/service_spec.py b/ingestion/src/metadata/ingestion/source/database/postgres/service_spec.py index 3bea308b164..95662ca3ff7 100644 --- a/ingestion/src/metadata/ingestion/source/database/postgres/service_spec.py +++ b/ingestion/src/metadata/ingestion/source/database/postgres/service_spec.py @@ -1,3 +1,4 @@ +from metadata.ingestion.source.database.postgres.connection import PostgresConnection from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource from metadata.ingestion.source.database.postgres.metadata import PostgresSource from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource @@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec( metadata_source_class=PostgresSource, lineage_source_class=PostgresLineageSource, usage_source_class=PostgresUsageSource, + connection_class=PostgresConnection, ) diff --git a/ingestion/src/metadata/ingestion/source/database/trino/connection.py b/ingestion/src/metadata/ingestion/source/database/trino/connection.py index 0dcdc9f8dc7..22f032b251f 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/trino/connection.py @@ -20,7 +20,6 @@ from requests import Session from sqlalchemy.engine import Engine from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication -from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) @@ -53,6 +52,7 @@ from metadata.ingestion.connections.test_connections import ( from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE from metadata.utils.constants import THREE_MIN +from metadata.utils.credentials import get_azure_access_token # pylint: disable=unused-argument @@ -82,17 +82,11 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]): connection_copy = deepcopy(connection) if hasattr(connection.authType, "azureConfig"): - azure_client = AzureClient(connection.authType.azureConfig).create_client() - if not connection.authType.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - access_token_obj = azure_client.get_token( - *connection.authType.azureConfig.scopes.split(",") - ) + auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType) + access_token = get_azure_access_token(auth_type) if not connection.connectionOptions: connection.connectionOptions = init_empty_connection_options() - connection.connectionOptions.root["access_token"] = access_token_obj.token + connection.connectionOptions.root["access_token"] = access_token # Update the connection with the connection arguments connection_copy.connectionArguments = self.build_connection_args( @@ -355,12 +349,4 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]): Get the azure token for the trino connection """ auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType) - - if not auth_type.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - - azure_client = AzureClient(auth_type.azureConfig).create_client() - - return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token + return get_azure_access_token(auth_type) diff --git a/ingestion/src/metadata/sampler/pandas/sampler.py b/ingestion/src/metadata/sampler/pandas/sampler.py index 5cddf6b4858..f4bc10e9803 100644 --- a/ingestion/src/metadata/sampler/pandas/sampler.py +++ b/ingestion/src/metadata/sampler/pandas/sampler.py @@ -60,7 +60,7 @@ class DatalakeSampler(SamplerInterface, PandasInterfaceMixin): return self._table def get_client(self): - return self.connection.client + return self.connection def _partitioned_table(self): """Get partitioned table""" diff --git a/ingestion/src/metadata/utils/credentials.py b/ingestion/src/metadata/utils/credentials.py index 763923934fa..45a4caf1a97 100644 --- a/ingestion/src/metadata/utils/credentials.py +++ b/ingestion/src/metadata/utils/credentials.py @@ -21,6 +21,10 @@ from cryptography.hazmat.primitives import serialization from google import auth from google.auth import impersonated_credentials +from metadata.clients.azure_client import AzureClient +from metadata.generated.schema.entity.services.connections.database.common.azureConfig import ( + AzureConfigurationSource, +) from metadata.generated.schema.security.credentials.gcpCredentials import ( GcpADC, GCPCredentials, @@ -237,3 +241,34 @@ def get_gcp_impersonate_credentials( target_scopes=scopes, lifetime=lifetime, ) + + +def get_azure_access_token(azure_config: AzureConfigurationSource) -> str: + """ + Get Azure access token using the provided Azure configuration. + + Args: + azure_config: Azure configuration containing the necessary credentials and scopes + + Returns: + str: The access token + + Raises: + ValueError: If Azure config is missing or scopes are not provided + """ + if not azure_config.azureConfig: + raise ValueError("Azure Config is missing") + + if not azure_config.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/" + "en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token " + "and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" + ) + + azure_client = AzureClient(azure_config.azureConfig).create_client() + access_token_obj = azure_client.get_token( + *azure_config.azureConfig.scopes.split(",") + ) + + return access_token_obj.token diff --git a/ingestion/tests/unit/test_build_connection_url.py b/ingestion/tests/unit/test_build_connection_url.py index 3c5516c3b2b..f6d6ea1092d 100644 --- a/ingestion/tests/unit/test_build_connection_url.py +++ b/ingestion/tests/unit/test_build_connection_url.py @@ -16,19 +16,17 @@ from metadata.generated.schema.entity.services.connections.database.common.basic BasicAuth, ) from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( - MysqlConnection, + MysqlConnection as MysqlConnectionConfig, ) from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( - PostgresConnection, + PostgresConnection as PostgresConnectionConfig, ) from metadata.generated.schema.security.credentials.azureCredentials import ( AzureCredentials, ) from metadata.ingestion.source.database.azuresql.connection import get_connection_url from metadata.ingestion.source.database.mysql.connection import MySQLConnection -from metadata.ingestion.source.database.postgres.connection import ( - get_connection as postgres_get_connection, -) +from metadata.ingestion.source.database.postgres.connection import PostgresConnection class TestGetConnectionURL(unittest.TestCase): @@ -64,7 +62,7 @@ class TestGetConnectionURL(unittest.TestCase): self.assertEqual(str(get_connection_url(connection)), expected_url) def test_get_connection_url_mysql(self): - connection = MysqlConnection( + connection = MysqlConnectionConfig( username="openmetadata_user", authType=BasicAuth(password="openmetadata_password"), hostPort="localhost:3306", @@ -75,7 +73,7 @@ class TestGetConnectionURL(unittest.TestCase): str(engine_connection.url), "mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", ) - connection = MysqlConnection( + connection = MysqlConnectionConfig( username="openmetadata_user", authType=AzureConfigurationSource( azureConfig=AzureCredentials( @@ -100,18 +98,18 @@ class TestGetConnectionURL(unittest.TestCase): ) def test_get_connection_url_postgres(self): - connection = PostgresConnection( + connection = PostgresConnectionConfig( username="openmetadata_user", authType=BasicAuth(password="openmetadata_password"), hostPort="localhost:3306", database="openmetadata_db", ) - engine_connection = postgres_get_connection(connection) + engine_connection = PostgresConnection(connection).client self.assertEqual( str(engine_connection.url), "postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", ) - connection = PostgresConnection( + connection = PostgresConnectionConfig( username="openmetadata_user", authType=AzureConfigurationSource( azureConfig=AzureCredentials( @@ -129,7 +127,7 @@ class TestGetConnectionURL(unittest.TestCase): "get_token", return_value=AccessToken(token="mocked_token", expires_on=100), ): - engine_connection = postgres_get_connection(connection) + engine_connection = PostgresConnection(connection).client self.assertEqual( str(engine_connection.url), "postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",