Update DataLake and PostgreSQL connection (#22682)

IceS2 2025-08-01 11:08:43 +02:00 committed by GitHub
parent 4364d9cea4
commit 7f8298d49e
11 changed files with 163 additions and 174 deletions

View File

@@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect
MysqlConnection as MysqlConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
PostgresConnection as PostgresConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult,
@@ -47,9 +47,7 @@ from metadata.ingestion.source.dashboard.superset.queries import (
FETCH_DASHBOARDS_TEST,
)
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
from metadata.ingestion.source.database.postgres.connection import (
get_connection as pg_get_connection,
)
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
from metadata.utils.constants import THREE_MIN
@@ -61,10 +59,10 @@ def get_connection(
"""
if isinstance(connection.connection, SupersetApiConnection):
return SupersetAPIClient(connection)
if isinstance(connection.connection, PostgresConnection):
return pg_get_connection(connection=connection.connection)
if isinstance(connection.connection, PostgresConnectionConfig):
return PostgresConnection(connection.connection).client
if isinstance(connection.connection, MysqlConnectionConfig):
return MySQLConnection(connection.connection).get_client()
return MySQLConnection(connection.connection).client
return None
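
For reference, the reworked dispatch above hands back the raw client for each branch. A minimal sketch of the Postgres-backed path with placeholder credentials; the SupersetConnection field names and the module path of get_connection are assumed from the OpenMetadata layout, not shown in this hunk:

from metadata.generated.schema.entity.services.connections.dashboard.supersetConnection import (
    SupersetConnection,
)
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
    BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
    PostgresConnection as PostgresConnectionConfig,
)
from metadata.ingestion.source.dashboard.superset.connection import get_connection

# Superset backed by its Postgres metadata database (placeholder values).
service_connection = SupersetConnection(
    hostPort="http://localhost:8088",
    connection=PostgresConnectionConfig(
        username="superset",
        authType=BasicAuth(password="superset"),
        hostPort="localhost:5432",
        database="superset",
    ),
)
engine = get_connection(service_connection)  # SQLAlchemy Engine via PostgresConnection(...).client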

View File

@@ -12,8 +12,6 @@
"""
Source connection handler
"""
from dataclasses import dataclass
from functools import singledispatch
from typing import Optional
from metadata.generated.schema.entity.automations.workflow import (
@@ -29,71 +27,49 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
S3Config,
)
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
DatalakeConnection as DatalakeConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult,
)
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.test_connections import test_connection_steps
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.datalake.clients.azure_blob import (
DatalakeAzureBlobClient,
)
from metadata.ingestion.source.database.datalake.clients.base import DatalakeBaseClient
from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient
from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client
from metadata.utils.constants import THREE_MIN
# Only import specific datalake dependencies if necessary
# pylint: disable=import-outside-toplevel
@dataclass
class DatalakeClient:
def __init__(self, client, config) -> None:
self.client = client
self.config = config
@singledispatch
def get_datalake_client(config):
class DatalakeConnection(BaseConnection[DatalakeConnectionConfig, DatalakeBaseClient]):
def _get_client(self) -> DatalakeBaseClient:
"""
Method to retrieve datalake client from the config
Return the appropriate Datalake client based on configSource.
"""
if config:
msg = f"Config not implemented for type {type(config)}: {config}"
connection = self.service_connection
if isinstance(connection.configSource, S3Config):
return DatalakeS3Client.from_config(connection.configSource)
elif isinstance(connection.configSource, GCSConfig):
return DatalakeGcsClient.from_config(connection.configSource)
elif isinstance(connection.configSource, AzureConfig):
return DatalakeAzureBlobClient.from_config(connection.configSource)
else:
msg = f"Config not implemented for type {type(connection.configSource)}: {connection.configSource}"
raise NotImplementedError(msg)
@get_datalake_client.register
def _(config: S3Config):
return DatalakeS3Client.from_config(config)
@get_datalake_client.register
def _(config: GCSConfig):
return DatalakeGcsClient.from_config(config)
@get_datalake_client.register
def _(config: AzureConfig):
return DatalakeAzureBlobClient.from_config(config)
def get_connection(connection: DatalakeConnection) -> DatalakeClient:
def get_connection_dict(self) -> dict:
"""
Create connection.
Returns an AWS, Azure or GCS Clients.
Return the connection dictionary for this service.
"""
return DatalakeClient(
client=get_datalake_client(connection.configSource),
config=connection,
)
raise NotImplementedError("get_connection_dict is not implemented for Datalake")
def test_connection(
self,
metadata: OpenMetadata,
connection: DatalakeClient,
service_connection: DatalakeConnection,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
@@ -102,15 +78,17 @@ def test_connection(
of a metadata workflow or during an Automation Workflow
"""
test_fn = {
"ListBuckets": connection.client.get_test_list_buckets_fn(
connection.config.bucketName
"ListBuckets": self.client.get_test_list_buckets_fn(
self.service_connection.bucketName
),
}
return test_connection_steps(
metadata=metadata,
test_fn=test_fn,
service_type=service_connection.type.value,
service_type=self.service_connection.type.value
if self.service_connection.type
else "Datalake",
automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds,
)
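
Taken together, the class replaces the singledispatch registry with a single isinstance chain behind BaseConnection. A minimal usage sketch, assuming a parsed DatalakeConnectionConfig (its construction is outside this diff):

from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
    DatalakeConnection as DatalakeConnectionConfig,
)
from metadata.ingestion.source.database.datalake.connection import DatalakeConnection

def build_datalake_client(config: DatalakeConnectionConfig):
    connection = DatalakeConnection(config)
    # .client resolves to DatalakeS3Client / DatalakeGcsClient /
    # DatalakeAzureBlobClient depending on config.configSource
    return connection.client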

View File

@@ -88,15 +88,14 @@ class DatalakeSource(DatabaseServiceSource):
)
self.metadata = metadata
self.service_connection = self.config.serviceConnection.root.config
self.connection = get_connection(self.service_connection)
self.client = self.connection.client
self.client = get_connection(self.service_connection)
self.table_constraints = None
self.database_source_state = set()
self.config_source = self.service_connection.configSource
self.connection_obj = self.connection
self.connection_obj = self.client
self.test_connection()
self.reader = get_reader(
config_source=self.config_source, client=self.client._client
config_source=self.config_source, client=self.client.client
)
@classmethod

View File

@@ -1,6 +1,7 @@
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
PandasTestSuiteInterface,
)
from metadata.ingestion.source.database.datalake.connection import DatalakeConnection
from metadata.ingestion.source.database.datalake.metadata import DatalakeSource
from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface,
@@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
profiler_class=PandasProfilerInterface,
test_suite_class=PandasTestSuiteInterface,
sampler_class=DatalakeSampler,
connection_class=DatalakeConnection,
)

View File

@@ -16,7 +16,6 @@ from typing import Optional
from sqlalchemy.engine import Engine
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
@@ -47,6 +46,7 @@ from metadata.ingestion.source.database.mysql.queries import (
MYSQL_TEST_GET_QUERIES_SLOW_LOGS,
)
from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
@@ -57,20 +57,8 @@ class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
connection = self.service_connection
if isinstance(connection.authType, AzureConfigurationSource):
if not connection.authType.azureConfig:
raise ValueError("Azure Config is missing")
azure_client = AzureClient(connection.authType.azureConfig).create_client()
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-micr"
"osoft-entra-access-token and fetch the resource associated with it,"
" for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.authType = BasicAuth(password=access_token_obj.token) # type: ignore
access_token = get_azure_access_token(connection.authType)
connection.authType = BasicAuth(password=access_token) # type: ignore
return create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url_common,
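
The MySQL engine path itself is unchanged; only the Azure token retrieval moved into the shared helper. A sketch mirroring the updated unit test at the bottom of this commit (placeholder credentials; databaseSchema is the schema field name assumed for MysqlConnection):

from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
    BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
    MysqlConnection as MysqlConnectionConfig,
)
from metadata.ingestion.source.database.mysql.connection import MySQLConnection

config = MysqlConnectionConfig(
    username="openmetadata_user",
    authType=BasicAuth(password="openmetadata_password"),
    hostPort="localhost:3306",
    databaseSchema="openmetadata_db",
)
engine = MySQLConnection(config).client  # mysql+pymysql Engine, as asserted in the test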

View File

@@ -12,20 +12,21 @@
"""
Source connection handler
"""
from typing import Optional
from sqlalchemy.engine import Engine
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
AzureConfigurationSource,
)
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
PostgresConnection as PostgresConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult,
@@ -35,6 +36,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common,
get_connection_url_common,
)
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.test_connections import test_connection_db_common
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.postgres.queries import (
@@ -46,34 +48,36 @@ from metadata.ingestion.source.database.postgres.utils import (
get_postgres_time_column_name,
)
from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
def get_connection(connection: PostgresConnection) -> Engine:
class PostgresConnection(BaseConnection[PostgresConnectionConfig, Engine]):
def _get_client(self) -> Engine:
"""
Create connection
Return the SQLAlchemy Engine for PostgreSQL.
"""
connection = self.service_connection
if hasattr(connection.authType, "azureConfig"):
azure_client = AzureClient(connection.authType.azureConfig).create_client()
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.authType = BasicAuth(password=access_token_obj.token)
if isinstance(connection.authType, AzureConfigurationSource):
access_token = get_azure_access_token(connection.authType)
connection.authType = BasicAuth(password=access_token) # type: ignore
return create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url_common,
get_connection_args_fn=get_connection_args_common,
)
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""
raise NotImplementedError(
"get_connection_dict is not implemented for PostgreSQL"
)
def test_connection(
self,
metadata: OpenMetadata,
engine: Engine,
service_connection: PostgresConnection,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
@@ -81,19 +85,18 @@ def test_connection(
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow
"""
queries = {
"GetQueries": POSTGRES_TEST_GET_QUERIES.format(
time_column_name=get_postgres_time_column_name(engine=engine),
time_column_name=get_postgres_time_column_name(engine=self.client),
),
"GetDatabases": POSTGRES_GET_DATABASE,
"GetTags": POSTGRES_TEST_GET_TAGS,
}
return test_connection_db_common(
metadata=metadata,
engine=engine,
service_connection=service_connection,
engine=self.client,
service_connection=self.service_connection,
automation_workflow=automation_workflow,
queries=queries,
timeout_seconds=timeout_seconds,
queries=queries,
)
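
Callers migrate from the two module-level functions to the class surface. A sketch, given a PostgresConnectionConfig named config (built as in the updated tests below) and an OpenMetadata client named metadata:

from metadata.ingestion.source.database.postgres.connection import PostgresConnection

connection = PostgresConnection(config)
engine = connection.client                     # replaces get_connection(config)
result = connection.test_connection(metadata)  # replaces the module-level test_connection(...)
# connection.get_connection_dict() raises NotImplementedError, as defined above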

View File

@@ -1,3 +1,4 @@
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource
from metadata.ingestion.source.database.postgres.metadata import PostgresSource
from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource
@@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec(
metadata_source_class=PostgresSource,
lineage_source_class=PostgresLineageSource,
usage_source_class=PostgresUsageSource,
connection_class=PostgresConnection,
)

View File

@@ -20,7 +20,6 @@ from requests import Session
from sqlalchemy.engine import Engine
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
@@ -53,6 +52,7 @@ from metadata.ingestion.connections.test_connections import (
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
# pylint: disable=unused-argument
@@ -82,17 +82,11 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
connection_copy = deepcopy(connection)
if hasattr(connection.authType, "azureConfig"):
azure_client = AzureClient(connection.authType.azureConfig).create_client()
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
access_token = get_azure_access_token(auth_type)
if not connection.connectionOptions:
connection.connectionOptions = init_empty_connection_options()
connection.connectionOptions.root["access_token"] = access_token_obj.token
connection.connectionOptions.root["access_token"] = access_token
# Update the connection with the connection arguments
connection_copy.connectionArguments = self.build_connection_args(
@@ -355,12 +349,4 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
Get the azure token for the trino connection
"""
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
if not auth_type.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(auth_type.azureConfig).create_client()
return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token
return get_azure_access_token(auth_type)

View File

@@ -60,7 +60,7 @@ class DatalakeSampler(SamplerInterface, PandasInterfaceMixin):
return self._table
def get_client(self):
return self.connection.client
return self.connection
def _partitioned_table(self):
"""Get partitioned table"""

View File

@@ -21,6 +21,10 @@ from cryptography.hazmat.primitives import serialization
from google import auth
from google.auth import impersonated_credentials
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
AzureConfigurationSource,
)
from metadata.generated.schema.security.credentials.gcpCredentials import (
GcpADC,
GCPCredentials,
@@ -237,3 +241,34 @@ def get_gcp_impersonate_credentials(
target_scopes=scopes,
lifetime=lifetime,
)
def get_azure_access_token(azure_config: AzureConfigurationSource) -> str:
"""
Get Azure access token using the provided Azure configuration.
Args:
azure_config: Azure configuration containing the necessary credentials and scopes
Returns:
str: The access token
Raises:
ValueError: If Azure config is missing or scopes are not provided
"""
if not azure_config.azureConfig:
raise ValueError("Azure Config is missing")
if not azure_config.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token "
"and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(azure_config.azureConfig).create_client()
access_token_obj = azure_client.get_token(
*azure_config.azureConfig.scopes.split(",")
)
return access_token_obj.token
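
This helper consolidates the three copies of the scope check that the MySQL, Postgres, and Trino connectors carried. A minimal sketch with placeholder values; the AzureCredentials field names are assumed from the OpenMetadata schema, not shown in this hunk:

from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
    AzureConfigurationSource,
)
from metadata.generated.schema.security.credentials.azureCredentials import (
    AzureCredentials,
)
from metadata.utils.credentials import get_azure_access_token

auth_type = AzureConfigurationSource(
    azureConfig=AzureCredentials(
        clientId="00000000-0000-0000-0000-000000000000",   # placeholder
        clientSecret="placeholder-secret",                 # placeholder
        tenantId="00000000-0000-0000-0000-000000000000",   # placeholder
        scopes="https://ossrdbms-aad.database.windows.net/.default",
    )
)
token = get_azure_access_token(auth_type)  # raises ValueError when azureConfig or scopes are missing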

View File

@@ -16,19 +16,17 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
MysqlConnection,
MysqlConnection as MysqlConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
PostgresConnection as PostgresConnectionConfig,
)
from metadata.generated.schema.security.credentials.azureCredentials import (
AzureCredentials,
)
from metadata.ingestion.source.database.azuresql.connection import get_connection_url
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
from metadata.ingestion.source.database.postgres.connection import (
get_connection as postgres_get_connection,
)
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
class TestGetConnectionURL(unittest.TestCase):
@@ -64,7 +62,7 @@ class TestGetConnectionURL(unittest.TestCase):
self.assertEqual(str(get_connection_url(connection)), expected_url)
def test_get_connection_url_mysql(self):
connection = MysqlConnection(
connection = MysqlConnectionConfig(
username="openmetadata_user",
authType=BasicAuth(password="openmetadata_password"),
hostPort="localhost:3306",
@@ -75,7 +73,7 @@
str(engine_connection.url),
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
)
connection = MysqlConnection(
connection = MysqlConnectionConfig(
username="openmetadata_user",
authType=AzureConfigurationSource(
azureConfig=AzureCredentials(
@@ -100,18 +98,18 @@
)
def test_get_connection_url_postgres(self):
connection = PostgresConnection(
connection = PostgresConnectionConfig(
username="openmetadata_user",
authType=BasicAuth(password="openmetadata_password"),
hostPort="localhost:3306",
database="openmetadata_db",
)
engine_connection = postgres_get_connection(connection)
engine_connection = PostgresConnection(connection).client
self.assertEqual(
str(engine_connection.url),
"postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
)
connection = PostgresConnection(
connection = PostgresConnectionConfig(
username="openmetadata_user",
authType=AzureConfigurationSource(
azureConfig=AzureCredentials(
@@ -129,7 +127,7 @@
"get_token",
return_value=AccessToken(token="mocked_token", expires_on=100),
):
engine_connection = postgres_get_connection(connection)
engine_connection = PostgresConnection(connection).client
self.assertEqual(
str(engine_connection.url),
"postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",