Update DataLake and PostgreSQL connection (#22682)

This commit is contained in:
IceS2 2025-08-01 11:08:43 +02:00 committed by GitHub
parent 4364d9cea4
commit 7f8298d49e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 163 additions and 174 deletions

View File

@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect
MysqlConnection as MysqlConnectionConfig, MysqlConnection as MysqlConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection, PostgresConnection as PostgresConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.testConnectionResult import ( from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult, TestConnectionResult,
@ -47,9 +47,7 @@ from metadata.ingestion.source.dashboard.superset.queries import (
FETCH_DASHBOARDS_TEST, FETCH_DASHBOARDS_TEST,
) )
from metadata.ingestion.source.database.mysql.connection import MySQLConnection from metadata.ingestion.source.database.mysql.connection import MySQLConnection
from metadata.ingestion.source.database.postgres.connection import ( from metadata.ingestion.source.database.postgres.connection import PostgresConnection
get_connection as pg_get_connection,
)
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
@ -61,10 +59,10 @@ def get_connection(
""" """
if isinstance(connection.connection, SupersetApiConnection): if isinstance(connection.connection, SupersetApiConnection):
return SupersetAPIClient(connection) return SupersetAPIClient(connection)
if isinstance(connection.connection, PostgresConnection): if isinstance(connection.connection, PostgresConnectionConfig):
return pg_get_connection(connection=connection.connection) return PostgresConnection(connection.connection).client
if isinstance(connection.connection, MysqlConnectionConfig): if isinstance(connection.connection, MysqlConnectionConfig):
return MySQLConnection(connection.connection).get_client() return MySQLConnection(connection.connection).client
return None return None

View File

@ -12,8 +12,6 @@
""" """
Source connection handler Source connection handler
""" """
from dataclasses import dataclass
from functools import singledispatch
from typing import Optional from typing import Optional
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
@ -29,71 +27,49 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
S3Config, S3Config,
) )
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import ( from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection, DatalakeConnection as DatalakeConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.testConnectionResult import ( from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult, TestConnectionResult,
) )
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.test_connections import test_connection_steps from metadata.ingestion.connections.test_connections import test_connection_steps
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.datalake.clients.azure_blob import ( from metadata.ingestion.source.database.datalake.clients.azure_blob import (
DatalakeAzureBlobClient, DatalakeAzureBlobClient,
) )
from metadata.ingestion.source.database.datalake.clients.base import DatalakeBaseClient
from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient
from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
# Only import specific datalake dependencies if necessary class DatalakeConnection(BaseConnection[DatalakeConnectionConfig, DatalakeBaseClient]):
# pylint: disable=import-outside-toplevel def _get_client(self) -> DatalakeBaseClient:
@dataclass
class DatalakeClient:
def __init__(self, client, config) -> None:
self.client = client
self.config = config
@singledispatch
def get_datalake_client(config):
""" """
Method to retrieve datalake client from the config Return the appropriate Datalake client based on configSource.
""" """
if config: connection = self.service_connection
msg = f"Config not implemented for type {type(config)}: {config}"
if isinstance(connection.configSource, S3Config):
return DatalakeS3Client.from_config(connection.configSource)
elif isinstance(connection.configSource, GCSConfig):
return DatalakeGcsClient.from_config(connection.configSource)
elif isinstance(connection.configSource, AzureConfig):
return DatalakeAzureBlobClient.from_config(connection.configSource)
else:
msg = f"Config not implemented for type {type(connection.configSource)}: {connection.configSource}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def get_connection_dict(self) -> dict:
@get_datalake_client.register
def _(config: S3Config):
return DatalakeS3Client.from_config(config)
@get_datalake_client.register
def _(config: GCSConfig):
return DatalakeGcsClient.from_config(config)
@get_datalake_client.register
def _(config: AzureConfig):
return DatalakeAzureBlobClient.from_config(config)
def get_connection(connection: DatalakeConnection) -> DatalakeClient:
""" """
Create connection. Return the connection dictionary for this service.
Returns an AWS, Azure or GCS Clients.
""" """
return DatalakeClient( raise NotImplementedError("get_connection_dict is not implemented for Datalake")
client=get_datalake_client(connection.configSource),
config=connection,
)
def test_connection( def test_connection(
self,
metadata: OpenMetadata, metadata: OpenMetadata,
connection: DatalakeClient,
service_connection: DatalakeConnection,
automation_workflow: Optional[AutomationWorkflow] = None, automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN, timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult: ) -> TestConnectionResult:
@ -102,15 +78,17 @@ def test_connection(
of a metadata workflow or during an Automation Workflow of a metadata workflow or during an Automation Workflow
""" """
test_fn = { test_fn = {
"ListBuckets": connection.client.get_test_list_buckets_fn( "ListBuckets": self.client.get_test_list_buckets_fn(
connection.config.bucketName self.service_connection.bucketName
), ),
} }
return test_connection_steps( return test_connection_steps(
metadata=metadata, metadata=metadata,
test_fn=test_fn, test_fn=test_fn,
service_type=service_connection.type.value, service_type=self.service_connection.type.value
if self.service_connection.type
else "Datalake",
automation_workflow=automation_workflow, automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
) )

View File

@ -88,15 +88,14 @@ class DatalakeSource(DatabaseServiceSource):
) )
self.metadata = metadata self.metadata = metadata
self.service_connection = self.config.serviceConnection.root.config self.service_connection = self.config.serviceConnection.root.config
self.connection = get_connection(self.service_connection) self.client = get_connection(self.service_connection)
self.client = self.connection.client
self.table_constraints = None self.table_constraints = None
self.database_source_state = set() self.database_source_state = set()
self.config_source = self.service_connection.configSource self.config_source = self.service_connection.configSource
self.connection_obj = self.connection self.connection_obj = self.client
self.test_connection() self.test_connection()
self.reader = get_reader( self.reader = get_reader(
config_source=self.config_source, client=self.client._client config_source=self.config_source, client=self.client.client
) )
@classmethod @classmethod

View File

@ -1,6 +1,7 @@
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import ( from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
PandasTestSuiteInterface, PandasTestSuiteInterface,
) )
from metadata.ingestion.source.database.datalake.connection import DatalakeConnection
from metadata.ingestion.source.database.datalake.metadata import DatalakeSource from metadata.ingestion.source.database.datalake.metadata import DatalakeSource
from metadata.profiler.interface.pandas.profiler_interface import ( from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface, PandasProfilerInterface,
@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
profiler_class=PandasProfilerInterface, profiler_class=PandasProfilerInterface,
test_suite_class=PandasTestSuiteInterface, test_suite_class=PandasTestSuiteInterface,
sampler_class=DatalakeSampler, sampler_class=DatalakeSampler,
connection_class=DatalakeConnection,
) )

View File

@ -16,7 +16,6 @@ from typing import Optional
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
) )
@ -47,6 +46,7 @@ from metadata.ingestion.source.database.mysql.queries import (
MYSQL_TEST_GET_QUERIES_SLOW_LOGS, MYSQL_TEST_GET_QUERIES_SLOW_LOGS,
) )
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]): class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
@ -57,20 +57,8 @@ class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
connection = self.service_connection connection = self.service_connection
if isinstance(connection.authType, AzureConfigurationSource): if isinstance(connection.authType, AzureConfigurationSource):
if not connection.authType.azureConfig: access_token = get_azure_access_token(connection.authType)
raise ValueError("Azure Config is missing") connection.authType = BasicAuth(password=access_token) # type: ignore
azure_client = AzureClient(connection.authType.azureConfig).create_client()
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-micr"
"osoft-entra-access-token and fetch the resource associated with it,"
" for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.authType = BasicAuth(password=access_token_obj.token) # type: ignore
return create_generic_db_connection( return create_generic_db_connection(
connection=connection, connection=connection,
get_connection_url_fn=get_connection_url_common, get_connection_url_fn=get_connection_url_common,

View File

@ -12,20 +12,21 @@
""" """
Source connection handler Source connection handler
""" """
from typing import Optional from typing import Optional
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
) )
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
AzureConfigurationSource,
)
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
BasicAuth, BasicAuth,
) )
from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection, PostgresConnection as PostgresConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.testConnectionResult import ( from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult, TestConnectionResult,
@ -35,6 +36,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common, get_connection_args_common,
get_connection_url_common, get_connection_url_common,
) )
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.test_connections import test_connection_db_common from metadata.ingestion.connections.test_connections import test_connection_db_common
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.postgres.queries import ( from metadata.ingestion.source.database.postgres.queries import (
@ -46,34 +48,36 @@ from metadata.ingestion.source.database.postgres.utils import (
get_postgres_time_column_name, get_postgres_time_column_name,
) )
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
def get_connection(connection: PostgresConnection) -> Engine: class PostgresConnection(BaseConnection[PostgresConnectionConfig, Engine]):
def _get_client(self) -> Engine:
""" """
Create connection Return the SQLAlchemy Engine for PostgreSQL.
""" """
connection = self.service_connection
if hasattr(connection.authType, "azureConfig"): if isinstance(connection.authType, AzureConfigurationSource):
azure_client = AzureClient(connection.authType.azureConfig).create_client() access_token = get_azure_access_token(connection.authType)
if not connection.authType.azureConfig.scopes: connection.authType = BasicAuth(password=access_token) # type: ignore
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.authType = BasicAuth(password=access_token_obj.token)
return create_generic_db_connection( return create_generic_db_connection(
connection=connection, connection=connection,
get_connection_url_fn=get_connection_url_common, get_connection_url_fn=get_connection_url_common,
get_connection_args_fn=get_connection_args_common, get_connection_args_fn=get_connection_args_common,
) )
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""
raise NotImplementedError(
"get_connection_dict is not implemented for PostgreSQL"
)
def test_connection( def test_connection(
self,
metadata: OpenMetadata, metadata: OpenMetadata,
engine: Engine,
service_connection: PostgresConnection,
automation_workflow: Optional[AutomationWorkflow] = None, automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN, timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult: ) -> TestConnectionResult:
@ -81,19 +85,18 @@ def test_connection(
Test connection. This can be executed either as part Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow of a metadata workflow or during an Automation Workflow
""" """
queries = { queries = {
"GetQueries": POSTGRES_TEST_GET_QUERIES.format( "GetQueries": POSTGRES_TEST_GET_QUERIES.format(
time_column_name=get_postgres_time_column_name(engine=engine), time_column_name=get_postgres_time_column_name(engine=self.client),
), ),
"GetDatabases": POSTGRES_GET_DATABASE, "GetDatabases": POSTGRES_GET_DATABASE,
"GetTags": POSTGRES_TEST_GET_TAGS, "GetTags": POSTGRES_TEST_GET_TAGS,
} }
return test_connection_db_common( return test_connection_db_common(
metadata=metadata, metadata=metadata,
engine=engine, engine=self.client,
service_connection=service_connection, service_connection=self.service_connection,
automation_workflow=automation_workflow, automation_workflow=automation_workflow,
queries=queries,
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
queries=queries,
) )

View File

@ -1,3 +1,4 @@
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource
from metadata.ingestion.source.database.postgres.metadata import PostgresSource from metadata.ingestion.source.database.postgres.metadata import PostgresSource
from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource
@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec(
metadata_source_class=PostgresSource, metadata_source_class=PostgresSource,
lineage_source_class=PostgresLineageSource, lineage_source_class=PostgresLineageSource,
usage_source_class=PostgresUsageSource, usage_source_class=PostgresUsageSource,
connection_class=PostgresConnection,
) )

View File

@ -20,7 +20,6 @@ from requests import Session
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
) )
@ -53,6 +52,7 @@ from metadata.ingestion.connections.test_connections import (
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
from metadata.utils.credentials import get_azure_access_token
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -82,17 +82,11 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
connection_copy = deepcopy(connection) connection_copy = deepcopy(connection)
if hasattr(connection.authType, "azureConfig"): if hasattr(connection.authType, "azureConfig"):
azure_client = AzureClient(connection.authType.azureConfig).create_client() auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
if not connection.authType.azureConfig.scopes: access_token = get_azure_access_token(auth_type)
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
if not connection.connectionOptions: if not connection.connectionOptions:
connection.connectionOptions = init_empty_connection_options() connection.connectionOptions = init_empty_connection_options()
connection.connectionOptions.root["access_token"] = access_token_obj.token connection.connectionOptions.root["access_token"] = access_token
# Update the connection with the connection arguments # Update the connection with the connection arguments
connection_copy.connectionArguments = self.build_connection_args( connection_copy.connectionArguments = self.build_connection_args(
@ -355,12 +349,4 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
Get the azure token for the trino connection Get the azure token for the trino connection
""" """
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType) auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
return get_azure_access_token(auth_type)
if not auth_type.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(auth_type.azureConfig).create_client()
return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token

View File

@ -60,7 +60,7 @@ class DatalakeSampler(SamplerInterface, PandasInterfaceMixin):
return self._table return self._table
def get_client(self): def get_client(self):
return self.connection.client return self.connection
def _partitioned_table(self): def _partitioned_table(self):
"""Get partitioned table""" """Get partitioned table"""

View File

@ -21,6 +21,10 @@ from cryptography.hazmat.primitives import serialization
from google import auth from google import auth
from google.auth import impersonated_credentials from google.auth import impersonated_credentials
from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
AzureConfigurationSource,
)
from metadata.generated.schema.security.credentials.gcpCredentials import ( from metadata.generated.schema.security.credentials.gcpCredentials import (
GcpADC, GcpADC,
GCPCredentials, GCPCredentials,
@ -237,3 +241,34 @@ def get_gcp_impersonate_credentials(
target_scopes=scopes, target_scopes=scopes,
lifetime=lifetime, lifetime=lifetime,
) )
def get_azure_access_token(azure_config: AzureConfigurationSource) -> str:
"""
Get Azure access token using the provided Azure configuration.
Args:
azure_config: Azure configuration containing the necessary credentials and scopes
Returns:
str: The access token
Raises:
ValueError: If Azure config is missing or scopes are not provided
"""
if not azure_config.azureConfig:
raise ValueError("Azure Config is missing")
if not azure_config.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token "
"and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(azure_config.azureConfig).create_client()
access_token_obj = azure_client.get_token(
*azure_config.azureConfig.scopes.split(",")
)
return access_token_obj.token

View File

@ -16,19 +16,17 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
BasicAuth, BasicAuth,
) )
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
MysqlConnection, MysqlConnection as MysqlConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection, PostgresConnection as PostgresConnectionConfig,
) )
from metadata.generated.schema.security.credentials.azureCredentials import ( from metadata.generated.schema.security.credentials.azureCredentials import (
AzureCredentials, AzureCredentials,
) )
from metadata.ingestion.source.database.azuresql.connection import get_connection_url from metadata.ingestion.source.database.azuresql.connection import get_connection_url
from metadata.ingestion.source.database.mysql.connection import MySQLConnection from metadata.ingestion.source.database.mysql.connection import MySQLConnection
from metadata.ingestion.source.database.postgres.connection import ( from metadata.ingestion.source.database.postgres.connection import PostgresConnection
get_connection as postgres_get_connection,
)
class TestGetConnectionURL(unittest.TestCase): class TestGetConnectionURL(unittest.TestCase):
@ -64,7 +62,7 @@ class TestGetConnectionURL(unittest.TestCase):
self.assertEqual(str(get_connection_url(connection)), expected_url) self.assertEqual(str(get_connection_url(connection)), expected_url)
def test_get_connection_url_mysql(self): def test_get_connection_url_mysql(self):
connection = MysqlConnection( connection = MysqlConnectionConfig(
username="openmetadata_user", username="openmetadata_user",
authType=BasicAuth(password="openmetadata_password"), authType=BasicAuth(password="openmetadata_password"),
hostPort="localhost:3306", hostPort="localhost:3306",
@ -75,7 +73,7 @@ class TestGetConnectionURL(unittest.TestCase):
str(engine_connection.url), str(engine_connection.url),
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", "mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
) )
connection = MysqlConnection( connection = MysqlConnectionConfig(
username="openmetadata_user", username="openmetadata_user",
authType=AzureConfigurationSource( authType=AzureConfigurationSource(
azureConfig=AzureCredentials( azureConfig=AzureCredentials(
@ -100,18 +98,18 @@ class TestGetConnectionURL(unittest.TestCase):
) )
def test_get_connection_url_postgres(self): def test_get_connection_url_postgres(self):
connection = PostgresConnection( connection = PostgresConnectionConfig(
username="openmetadata_user", username="openmetadata_user",
authType=BasicAuth(password="openmetadata_password"), authType=BasicAuth(password="openmetadata_password"),
hostPort="localhost:3306", hostPort="localhost:3306",
database="openmetadata_db", database="openmetadata_db",
) )
engine_connection = postgres_get_connection(connection) engine_connection = PostgresConnection(connection).client
self.assertEqual( self.assertEqual(
str(engine_connection.url), str(engine_connection.url),
"postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", "postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
) )
connection = PostgresConnection( connection = PostgresConnectionConfig(
username="openmetadata_user", username="openmetadata_user",
authType=AzureConfigurationSource( authType=AzureConfigurationSource(
azureConfig=AzureCredentials( azureConfig=AzureCredentials(
@ -129,7 +127,7 @@ class TestGetConnectionURL(unittest.TestCase):
"get_token", "get_token",
return_value=AccessToken(token="mocked_token", expires_on=100), return_value=AccessToken(token="mocked_token", expires_on=100),
): ):
engine_connection = postgres_get_connection(connection) engine_connection = PostgresConnection(connection).client
self.assertEqual( self.assertEqual(
str(engine_connection.url), str(engine_connection.url),
"postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db", "postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",