mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-09 07:23:39 +00:00
Update DataLake and PostgreSQL connection (#22682)
This commit is contained in:
parent
4364d9cea4
commit
7f8298d49e
@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect
|
|||||||
MysqlConnection as MysqlConnectionConfig,
|
MysqlConnection as MysqlConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
||||||
PostgresConnection,
|
PostgresConnection as PostgresConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
||||||
TestConnectionResult,
|
TestConnectionResult,
|
||||||
@ -47,9 +47,7 @@ from metadata.ingestion.source.dashboard.superset.queries import (
|
|||||||
FETCH_DASHBOARDS_TEST,
|
FETCH_DASHBOARDS_TEST,
|
||||||
)
|
)
|
||||||
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
|
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
|
||||||
from metadata.ingestion.source.database.postgres.connection import (
|
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
|
||||||
get_connection as pg_get_connection,
|
|
||||||
)
|
|
||||||
from metadata.utils.constants import THREE_MIN
|
from metadata.utils.constants import THREE_MIN
|
||||||
|
|
||||||
|
|
||||||
@ -61,10 +59,10 @@ def get_connection(
|
|||||||
"""
|
"""
|
||||||
if isinstance(connection.connection, SupersetApiConnection):
|
if isinstance(connection.connection, SupersetApiConnection):
|
||||||
return SupersetAPIClient(connection)
|
return SupersetAPIClient(connection)
|
||||||
if isinstance(connection.connection, PostgresConnection):
|
if isinstance(connection.connection, PostgresConnectionConfig):
|
||||||
return pg_get_connection(connection=connection.connection)
|
return PostgresConnection(connection.connection).client
|
||||||
if isinstance(connection.connection, MysqlConnectionConfig):
|
if isinstance(connection.connection, MysqlConnectionConfig):
|
||||||
return MySQLConnection(connection.connection).get_client()
|
return MySQLConnection(connection.connection).client
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,6 @@
|
|||||||
"""
|
"""
|
||||||
Source connection handler
|
Source connection handler
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import singledispatch
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from metadata.generated.schema.entity.automations.workflow import (
|
from metadata.generated.schema.entity.automations.workflow import (
|
||||||
@ -29,71 +27,49 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
|
|||||||
S3Config,
|
S3Config,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
|
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
|
||||||
DatalakeConnection,
|
DatalakeConnection as DatalakeConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
||||||
TestConnectionResult,
|
TestConnectionResult,
|
||||||
)
|
)
|
||||||
|
from metadata.ingestion.connections.connection import BaseConnection
|
||||||
from metadata.ingestion.connections.test_connections import test_connection_steps
|
from metadata.ingestion.connections.test_connections import test_connection_steps
|
||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.ingestion.source.database.datalake.clients.azure_blob import (
|
from metadata.ingestion.source.database.datalake.clients.azure_blob import (
|
||||||
DatalakeAzureBlobClient,
|
DatalakeAzureBlobClient,
|
||||||
)
|
)
|
||||||
|
from metadata.ingestion.source.database.datalake.clients.base import DatalakeBaseClient
|
||||||
from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient
|
from metadata.ingestion.source.database.datalake.clients.gcs import DatalakeGcsClient
|
||||||
from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client
|
from metadata.ingestion.source.database.datalake.clients.s3 import DatalakeS3Client
|
||||||
from metadata.utils.constants import THREE_MIN
|
from metadata.utils.constants import THREE_MIN
|
||||||
|
|
||||||
|
|
||||||
# Only import specific datalake dependencies if necessary
|
class DatalakeConnection(BaseConnection[DatalakeConnectionConfig, DatalakeBaseClient]):
|
||||||
# pylint: disable=import-outside-toplevel
|
def _get_client(self) -> DatalakeBaseClient:
|
||||||
@dataclass
|
|
||||||
class DatalakeClient:
|
|
||||||
def __init__(self, client, config) -> None:
|
|
||||||
self.client = client
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
|
||||||
def get_datalake_client(config):
|
|
||||||
"""
|
"""
|
||||||
Method to retrieve datalake client from the config
|
Return the appropriate Datalake client based on configSource.
|
||||||
"""
|
"""
|
||||||
if config:
|
connection = self.service_connection
|
||||||
msg = f"Config not implemented for type {type(config)}: {config}"
|
|
||||||
|
if isinstance(connection.configSource, S3Config):
|
||||||
|
return DatalakeS3Client.from_config(connection.configSource)
|
||||||
|
elif isinstance(connection.configSource, GCSConfig):
|
||||||
|
return DatalakeGcsClient.from_config(connection.configSource)
|
||||||
|
elif isinstance(connection.configSource, AzureConfig):
|
||||||
|
return DatalakeAzureBlobClient.from_config(connection.configSource)
|
||||||
|
else:
|
||||||
|
msg = f"Config not implemented for type {type(connection.configSource)}: {connection.configSource}"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
def get_connection_dict(self) -> dict:
|
||||||
@get_datalake_client.register
|
|
||||||
def _(config: S3Config):
|
|
||||||
return DatalakeS3Client.from_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
@get_datalake_client.register
|
|
||||||
def _(config: GCSConfig):
|
|
||||||
return DatalakeGcsClient.from_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
@get_datalake_client.register
|
|
||||||
def _(config: AzureConfig):
|
|
||||||
return DatalakeAzureBlobClient.from_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_connection(connection: DatalakeConnection) -> DatalakeClient:
|
|
||||||
"""
|
"""
|
||||||
Create connection.
|
Return the connection dictionary for this service.
|
||||||
|
|
||||||
Returns an AWS, Azure or GCS Clients.
|
|
||||||
"""
|
"""
|
||||||
return DatalakeClient(
|
raise NotImplementedError("get_connection_dict is not implemented for Datalake")
|
||||||
client=get_datalake_client(connection.configSource),
|
|
||||||
config=connection,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connection(
|
def test_connection(
|
||||||
|
self,
|
||||||
metadata: OpenMetadata,
|
metadata: OpenMetadata,
|
||||||
connection: DatalakeClient,
|
|
||||||
service_connection: DatalakeConnection,
|
|
||||||
automation_workflow: Optional[AutomationWorkflow] = None,
|
automation_workflow: Optional[AutomationWorkflow] = None,
|
||||||
timeout_seconds: Optional[int] = THREE_MIN,
|
timeout_seconds: Optional[int] = THREE_MIN,
|
||||||
) -> TestConnectionResult:
|
) -> TestConnectionResult:
|
||||||
@ -102,15 +78,17 @@ def test_connection(
|
|||||||
of a metadata workflow or during an Automation Workflow
|
of a metadata workflow or during an Automation Workflow
|
||||||
"""
|
"""
|
||||||
test_fn = {
|
test_fn = {
|
||||||
"ListBuckets": connection.client.get_test_list_buckets_fn(
|
"ListBuckets": self.client.get_test_list_buckets_fn(
|
||||||
connection.config.bucketName
|
self.service_connection.bucketName
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
return test_connection_steps(
|
return test_connection_steps(
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
test_fn=test_fn,
|
test_fn=test_fn,
|
||||||
service_type=service_connection.type.value,
|
service_type=self.service_connection.type.value
|
||||||
|
if self.service_connection.type
|
||||||
|
else "Datalake",
|
||||||
automation_workflow=automation_workflow,
|
automation_workflow=automation_workflow,
|
||||||
timeout_seconds=timeout_seconds,
|
timeout_seconds=timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -88,15 +88,14 @@ class DatalakeSource(DatabaseServiceSource):
|
|||||||
)
|
)
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.service_connection = self.config.serviceConnection.root.config
|
self.service_connection = self.config.serviceConnection.root.config
|
||||||
self.connection = get_connection(self.service_connection)
|
self.client = get_connection(self.service_connection)
|
||||||
self.client = self.connection.client
|
|
||||||
self.table_constraints = None
|
self.table_constraints = None
|
||||||
self.database_source_state = set()
|
self.database_source_state = set()
|
||||||
self.config_source = self.service_connection.configSource
|
self.config_source = self.service_connection.configSource
|
||||||
self.connection_obj = self.connection
|
self.connection_obj = self.client
|
||||||
self.test_connection()
|
self.test_connection()
|
||||||
self.reader = get_reader(
|
self.reader = get_reader(
|
||||||
config_source=self.config_source, client=self.client._client
|
config_source=self.config_source, client=self.client.client
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
|
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
|
||||||
PandasTestSuiteInterface,
|
PandasTestSuiteInterface,
|
||||||
)
|
)
|
||||||
|
from metadata.ingestion.source.database.datalake.connection import DatalakeConnection
|
||||||
from metadata.ingestion.source.database.datalake.metadata import DatalakeSource
|
from metadata.ingestion.source.database.datalake.metadata import DatalakeSource
|
||||||
from metadata.profiler.interface.pandas.profiler_interface import (
|
from metadata.profiler.interface.pandas.profiler_interface import (
|
||||||
PandasProfilerInterface,
|
PandasProfilerInterface,
|
||||||
@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
|
|||||||
profiler_class=PandasProfilerInterface,
|
profiler_class=PandasProfilerInterface,
|
||||||
test_suite_class=PandasTestSuiteInterface,
|
test_suite_class=PandasTestSuiteInterface,
|
||||||
sampler_class=DatalakeSampler,
|
sampler_class=DatalakeSampler,
|
||||||
|
connection_class=DatalakeConnection,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from typing import Optional
|
|||||||
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
from metadata.clients.azure_client import AzureClient
|
|
||||||
from metadata.generated.schema.entity.automations.workflow import (
|
from metadata.generated.schema.entity.automations.workflow import (
|
||||||
Workflow as AutomationWorkflow,
|
Workflow as AutomationWorkflow,
|
||||||
)
|
)
|
||||||
@ -47,6 +46,7 @@ from metadata.ingestion.source.database.mysql.queries import (
|
|||||||
MYSQL_TEST_GET_QUERIES_SLOW_LOGS,
|
MYSQL_TEST_GET_QUERIES_SLOW_LOGS,
|
||||||
)
|
)
|
||||||
from metadata.utils.constants import THREE_MIN
|
from metadata.utils.constants import THREE_MIN
|
||||||
|
from metadata.utils.credentials import get_azure_access_token
|
||||||
|
|
||||||
|
|
||||||
class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
|
class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
|
||||||
@ -57,20 +57,8 @@ class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
|
|||||||
connection = self.service_connection
|
connection = self.service_connection
|
||||||
|
|
||||||
if isinstance(connection.authType, AzureConfigurationSource):
|
if isinstance(connection.authType, AzureConfigurationSource):
|
||||||
if not connection.authType.azureConfig:
|
access_token = get_azure_access_token(connection.authType)
|
||||||
raise ValueError("Azure Config is missing")
|
connection.authType = BasicAuth(password=access_token) # type: ignore
|
||||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
|
||||||
if not connection.authType.azureConfig.scopes:
|
|
||||||
raise ValueError(
|
|
||||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
|
|
||||||
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-micr"
|
|
||||||
"osoft-entra-access-token and fetch the resource associated with it,"
|
|
||||||
" for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
|
||||||
)
|
|
||||||
access_token_obj = azure_client.get_token(
|
|
||||||
*connection.authType.azureConfig.scopes.split(",")
|
|
||||||
)
|
|
||||||
connection.authType = BasicAuth(password=access_token_obj.token) # type: ignore
|
|
||||||
return create_generic_db_connection(
|
return create_generic_db_connection(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
get_connection_url_fn=get_connection_url_common,
|
get_connection_url_fn=get_connection_url_common,
|
||||||
|
|||||||
@ -12,20 +12,21 @@
|
|||||||
"""
|
"""
|
||||||
Source connection handler
|
Source connection handler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
from metadata.clients.azure_client import AzureClient
|
|
||||||
from metadata.generated.schema.entity.automations.workflow import (
|
from metadata.generated.schema.entity.automations.workflow import (
|
||||||
Workflow as AutomationWorkflow,
|
Workflow as AutomationWorkflow,
|
||||||
)
|
)
|
||||||
|
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
|
||||||
|
AzureConfigurationSource,
|
||||||
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
||||||
BasicAuth,
|
BasicAuth,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
||||||
PostgresConnection,
|
PostgresConnection as PostgresConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
||||||
TestConnectionResult,
|
TestConnectionResult,
|
||||||
@ -35,6 +36,7 @@ from metadata.ingestion.connections.builders import (
|
|||||||
get_connection_args_common,
|
get_connection_args_common,
|
||||||
get_connection_url_common,
|
get_connection_url_common,
|
||||||
)
|
)
|
||||||
|
from metadata.ingestion.connections.connection import BaseConnection
|
||||||
from metadata.ingestion.connections.test_connections import test_connection_db_common
|
from metadata.ingestion.connections.test_connections import test_connection_db_common
|
||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.ingestion.source.database.postgres.queries import (
|
from metadata.ingestion.source.database.postgres.queries import (
|
||||||
@ -46,34 +48,36 @@ from metadata.ingestion.source.database.postgres.utils import (
|
|||||||
get_postgres_time_column_name,
|
get_postgres_time_column_name,
|
||||||
)
|
)
|
||||||
from metadata.utils.constants import THREE_MIN
|
from metadata.utils.constants import THREE_MIN
|
||||||
|
from metadata.utils.credentials import get_azure_access_token
|
||||||
|
|
||||||
|
|
||||||
def get_connection(connection: PostgresConnection) -> Engine:
|
class PostgresConnection(BaseConnection[PostgresConnectionConfig, Engine]):
|
||||||
|
def _get_client(self) -> Engine:
|
||||||
"""
|
"""
|
||||||
Create connection
|
Return the SQLAlchemy Engine for PostgreSQL.
|
||||||
"""
|
"""
|
||||||
|
connection = self.service_connection
|
||||||
|
|
||||||
if hasattr(connection.authType, "azureConfig"):
|
if isinstance(connection.authType, AzureConfigurationSource):
|
||||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
access_token = get_azure_access_token(connection.authType)
|
||||||
if not connection.authType.azureConfig.scopes:
|
connection.authType = BasicAuth(password=access_token) # type: ignore
|
||||||
raise ValueError(
|
|
||||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication#retrieve-the-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
|
||||||
)
|
|
||||||
access_token_obj = azure_client.get_token(
|
|
||||||
*connection.authType.azureConfig.scopes.split(",")
|
|
||||||
)
|
|
||||||
connection.authType = BasicAuth(password=access_token_obj.token)
|
|
||||||
return create_generic_db_connection(
|
return create_generic_db_connection(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
get_connection_url_fn=get_connection_url_common,
|
get_connection_url_fn=get_connection_url_common,
|
||||||
get_connection_args_fn=get_connection_args_common,
|
get_connection_args_fn=get_connection_args_common,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_connection_dict(self) -> dict:
|
||||||
|
"""
|
||||||
|
Return the connection dictionary for this service.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"get_connection_dict is not implemented for PostgreSQL"
|
||||||
|
)
|
||||||
|
|
||||||
def test_connection(
|
def test_connection(
|
||||||
|
self,
|
||||||
metadata: OpenMetadata,
|
metadata: OpenMetadata,
|
||||||
engine: Engine,
|
|
||||||
service_connection: PostgresConnection,
|
|
||||||
automation_workflow: Optional[AutomationWorkflow] = None,
|
automation_workflow: Optional[AutomationWorkflow] = None,
|
||||||
timeout_seconds: Optional[int] = THREE_MIN,
|
timeout_seconds: Optional[int] = THREE_MIN,
|
||||||
) -> TestConnectionResult:
|
) -> TestConnectionResult:
|
||||||
@ -81,19 +85,18 @@ def test_connection(
|
|||||||
Test connection. This can be executed either as part
|
Test connection. This can be executed either as part
|
||||||
of a metadata workflow or during an Automation Workflow
|
of a metadata workflow or during an Automation Workflow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
queries = {
|
queries = {
|
||||||
"GetQueries": POSTGRES_TEST_GET_QUERIES.format(
|
"GetQueries": POSTGRES_TEST_GET_QUERIES.format(
|
||||||
time_column_name=get_postgres_time_column_name(engine=engine),
|
time_column_name=get_postgres_time_column_name(engine=self.client),
|
||||||
),
|
),
|
||||||
"GetDatabases": POSTGRES_GET_DATABASE,
|
"GetDatabases": POSTGRES_GET_DATABASE,
|
||||||
"GetTags": POSTGRES_TEST_GET_TAGS,
|
"GetTags": POSTGRES_TEST_GET_TAGS,
|
||||||
}
|
}
|
||||||
return test_connection_db_common(
|
return test_connection_db_common(
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
engine=engine,
|
engine=self.client,
|
||||||
service_connection=service_connection,
|
service_connection=self.service_connection,
|
||||||
automation_workflow=automation_workflow,
|
automation_workflow=automation_workflow,
|
||||||
queries=queries,
|
|
||||||
timeout_seconds=timeout_seconds,
|
timeout_seconds=timeout_seconds,
|
||||||
|
queries=queries,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
|
||||||
from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource
|
from metadata.ingestion.source.database.postgres.lineage import PostgresLineageSource
|
||||||
from metadata.ingestion.source.database.postgres.metadata import PostgresSource
|
from metadata.ingestion.source.database.postgres.metadata import PostgresSource
|
||||||
from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource
|
from metadata.ingestion.source.database.postgres.usage import PostgresUsageSource
|
||||||
@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec(
|
|||||||
metadata_source_class=PostgresSource,
|
metadata_source_class=PostgresSource,
|
||||||
lineage_source_class=PostgresLineageSource,
|
lineage_source_class=PostgresLineageSource,
|
||||||
usage_source_class=PostgresUsageSource,
|
usage_source_class=PostgresUsageSource,
|
||||||
|
connection_class=PostgresConnection,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -20,7 +20,6 @@ from requests import Session
|
|||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
|
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
|
||||||
|
|
||||||
from metadata.clients.azure_client import AzureClient
|
|
||||||
from metadata.generated.schema.entity.automations.workflow import (
|
from metadata.generated.schema.entity.automations.workflow import (
|
||||||
Workflow as AutomationWorkflow,
|
Workflow as AutomationWorkflow,
|
||||||
)
|
)
|
||||||
@ -53,6 +52,7 @@ from metadata.ingestion.connections.test_connections import (
|
|||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
|
from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
|
||||||
from metadata.utils.constants import THREE_MIN
|
from metadata.utils.constants import THREE_MIN
|
||||||
|
from metadata.utils.credentials import get_azure_access_token
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@ -82,17 +82,11 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
|
|||||||
connection_copy = deepcopy(connection)
|
connection_copy = deepcopy(connection)
|
||||||
|
|
||||||
if hasattr(connection.authType, "azureConfig"):
|
if hasattr(connection.authType, "azureConfig"):
|
||||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
|
||||||
if not connection.authType.azureConfig.scopes:
|
access_token = get_azure_access_token(auth_type)
|
||||||
raise ValueError(
|
|
||||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
|
||||||
)
|
|
||||||
access_token_obj = azure_client.get_token(
|
|
||||||
*connection.authType.azureConfig.scopes.split(",")
|
|
||||||
)
|
|
||||||
if not connection.connectionOptions:
|
if not connection.connectionOptions:
|
||||||
connection.connectionOptions = init_empty_connection_options()
|
connection.connectionOptions = init_empty_connection_options()
|
||||||
connection.connectionOptions.root["access_token"] = access_token_obj.token
|
connection.connectionOptions.root["access_token"] = access_token
|
||||||
|
|
||||||
# Update the connection with the connection arguments
|
# Update the connection with the connection arguments
|
||||||
connection_copy.connectionArguments = self.build_connection_args(
|
connection_copy.connectionArguments = self.build_connection_args(
|
||||||
@ -355,12 +349,4 @@ class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
|
|||||||
Get the azure token for the trino connection
|
Get the azure token for the trino connection
|
||||||
"""
|
"""
|
||||||
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
|
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
|
||||||
|
return get_azure_access_token(auth_type)
|
||||||
if not auth_type.azureConfig.scopes:
|
|
||||||
raise ValueError(
|
|
||||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
|
||||||
)
|
|
||||||
|
|
||||||
azure_client = AzureClient(auth_type.azureConfig).create_client()
|
|
||||||
|
|
||||||
return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token
|
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class DatalakeSampler(SamplerInterface, PandasInterfaceMixin):
|
|||||||
return self._table
|
return self._table
|
||||||
|
|
||||||
def get_client(self):
|
def get_client(self):
|
||||||
return self.connection.client
|
return self.connection
|
||||||
|
|
||||||
def _partitioned_table(self):
|
def _partitioned_table(self):
|
||||||
"""Get partitioned table"""
|
"""Get partitioned table"""
|
||||||
|
|||||||
@ -21,6 +21,10 @@ from cryptography.hazmat.primitives import serialization
|
|||||||
from google import auth
|
from google import auth
|
||||||
from google.auth import impersonated_credentials
|
from google.auth import impersonated_credentials
|
||||||
|
|
||||||
|
from metadata.clients.azure_client import AzureClient
|
||||||
|
from metadata.generated.schema.entity.services.connections.database.common.azureConfig import (
|
||||||
|
AzureConfigurationSource,
|
||||||
|
)
|
||||||
from metadata.generated.schema.security.credentials.gcpCredentials import (
|
from metadata.generated.schema.security.credentials.gcpCredentials import (
|
||||||
GcpADC,
|
GcpADC,
|
||||||
GCPCredentials,
|
GCPCredentials,
|
||||||
@ -237,3 +241,34 @@ def get_gcp_impersonate_credentials(
|
|||||||
target_scopes=scopes,
|
target_scopes=scopes,
|
||||||
lifetime=lifetime,
|
lifetime=lifetime,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_azure_access_token(azure_config: AzureConfigurationSource) -> str:
|
||||||
|
"""
|
||||||
|
Get Azure access token using the provided Azure configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
azure_config: Azure configuration containing the necessary credentials and scopes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The access token
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If Azure config is missing or scopes are not provided
|
||||||
|
"""
|
||||||
|
if not azure_config.azureConfig:
|
||||||
|
raise ValueError("Azure Config is missing")
|
||||||
|
|
||||||
|
if not azure_config.azureConfig.scopes:
|
||||||
|
raise ValueError(
|
||||||
|
"Azure Scopes are missing, please refer https://learn.microsoft.com/"
|
||||||
|
"en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token "
|
||||||
|
"and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||||
|
)
|
||||||
|
|
||||||
|
azure_client = AzureClient(azure_config.azureConfig).create_client()
|
||||||
|
access_token_obj = azure_client.get_token(
|
||||||
|
*azure_config.azureConfig.scopes.split(",")
|
||||||
|
)
|
||||||
|
|
||||||
|
return access_token_obj.token
|
||||||
|
|||||||
@ -16,19 +16,17 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
|
|||||||
BasicAuth,
|
BasicAuth,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
|
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
|
||||||
MysqlConnection,
|
MysqlConnection as MysqlConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
||||||
PostgresConnection,
|
PostgresConnection as PostgresConnectionConfig,
|
||||||
)
|
)
|
||||||
from metadata.generated.schema.security.credentials.azureCredentials import (
|
from metadata.generated.schema.security.credentials.azureCredentials import (
|
||||||
AzureCredentials,
|
AzureCredentials,
|
||||||
)
|
)
|
||||||
from metadata.ingestion.source.database.azuresql.connection import get_connection_url
|
from metadata.ingestion.source.database.azuresql.connection import get_connection_url
|
||||||
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
|
from metadata.ingestion.source.database.mysql.connection import MySQLConnection
|
||||||
from metadata.ingestion.source.database.postgres.connection import (
|
from metadata.ingestion.source.database.postgres.connection import PostgresConnection
|
||||||
get_connection as postgres_get_connection,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetConnectionURL(unittest.TestCase):
|
class TestGetConnectionURL(unittest.TestCase):
|
||||||
@ -64,7 +62,7 @@ class TestGetConnectionURL(unittest.TestCase):
|
|||||||
self.assertEqual(str(get_connection_url(connection)), expected_url)
|
self.assertEqual(str(get_connection_url(connection)), expected_url)
|
||||||
|
|
||||||
def test_get_connection_url_mysql(self):
|
def test_get_connection_url_mysql(self):
|
||||||
connection = MysqlConnection(
|
connection = MysqlConnectionConfig(
|
||||||
username="openmetadata_user",
|
username="openmetadata_user",
|
||||||
authType=BasicAuth(password="openmetadata_password"),
|
authType=BasicAuth(password="openmetadata_password"),
|
||||||
hostPort="localhost:3306",
|
hostPort="localhost:3306",
|
||||||
@ -75,7 +73,7 @@ class TestGetConnectionURL(unittest.TestCase):
|
|||||||
str(engine_connection.url),
|
str(engine_connection.url),
|
||||||
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
||||||
)
|
)
|
||||||
connection = MysqlConnection(
|
connection = MysqlConnectionConfig(
|
||||||
username="openmetadata_user",
|
username="openmetadata_user",
|
||||||
authType=AzureConfigurationSource(
|
authType=AzureConfigurationSource(
|
||||||
azureConfig=AzureCredentials(
|
azureConfig=AzureCredentials(
|
||||||
@ -100,18 +98,18 @@ class TestGetConnectionURL(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_get_connection_url_postgres(self):
|
def test_get_connection_url_postgres(self):
|
||||||
connection = PostgresConnection(
|
connection = PostgresConnectionConfig(
|
||||||
username="openmetadata_user",
|
username="openmetadata_user",
|
||||||
authType=BasicAuth(password="openmetadata_password"),
|
authType=BasicAuth(password="openmetadata_password"),
|
||||||
hostPort="localhost:3306",
|
hostPort="localhost:3306",
|
||||||
database="openmetadata_db",
|
database="openmetadata_db",
|
||||||
)
|
)
|
||||||
engine_connection = postgres_get_connection(connection)
|
engine_connection = PostgresConnection(connection).client
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(engine_connection.url),
|
str(engine_connection.url),
|
||||||
"postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
"postgresql+psycopg2://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
||||||
)
|
)
|
||||||
connection = PostgresConnection(
|
connection = PostgresConnectionConfig(
|
||||||
username="openmetadata_user",
|
username="openmetadata_user",
|
||||||
authType=AzureConfigurationSource(
|
authType=AzureConfigurationSource(
|
||||||
azureConfig=AzureCredentials(
|
azureConfig=AzureCredentials(
|
||||||
@ -129,7 +127,7 @@ class TestGetConnectionURL(unittest.TestCase):
|
|||||||
"get_token",
|
"get_token",
|
||||||
return_value=AccessToken(token="mocked_token", expires_on=100),
|
return_value=AccessToken(token="mocked_token", expires_on=100),
|
||||||
):
|
):
|
||||||
engine_connection = postgres_get_connection(connection)
|
engine_connection = PostgresConnection(connection).client
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(engine_connection.url),
|
str(engine_connection.url),
|
||||||
"postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",
|
"postgresql+psycopg2://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user