From 3d2dfeb583307e0eac19adb78f03e9897ca2afb8 Mon Sep 17 00:00:00 2001 From: mgorsk1 Date: Fri, 15 Nov 2024 08:24:42 +0100 Subject: [PATCH] feat: use native trino client authentication classes (#16196) --------- Co-authored-by: ulixius9 --- .../source/database/trino/connection.py | 83 ++++++++++++++----- .../sqlalchemy/stored_statistics_profiler.py | 2 +- .../tests/unit/test_source_connection.py | 79 +++++++++++++++--- .../database/common/azureConfig.json | 2 +- .../common/noConfigAuthenticationTypes.json | 12 +++ .../connections/database/trinoConnection.json | 3 + 6 files changed, 145 insertions(+), 36 deletions(-) create mode 100644 openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/noConfigAuthenticationTypes.json diff --git a/ingestion/src/metadata/ingestion/source/database/trino/connection.py b/ingestion/src/metadata/ingestion/source/database/trino/connection.py index 6ee18950850..6dbae4ac9c3 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/trino/connection.py @@ -12,11 +12,13 @@ """ Source connection handler """ +from copy import deepcopy from typing import Optional from urllib.parse import quote_plus from requests import Session from sqlalchemy.engine import Engine +from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( @@ -25,6 +27,7 @@ from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.services.connections.database.common import ( basicAuth, jwtAuth, + noConfigAuthenticationTypes, ) from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( TrinoConnection, @@ -36,7 +39,6 @@ from metadata.ingestion.connections.builders import ( create_generic_db_connection, get_connection_args_common, init_empty_connection_arguments, - init_empty_connection_options, ) from metadata.ingestion.connections.secrets import connection_with_options_secrets from metadata.ingestion.connections.test_connections import ( @@ -52,26 +54,17 @@ def get_connection_url(connection: TrinoConnection) -> str: Prepare the connection url for trino """ url = f"{connection.scheme.value}://" + + # leaving username here as, even though with basic auth is used directly + # in BasicAuthentication class, it's often also required as a part of url. + # For example - it will be used by OAuth2Authentication to persist token in + # cache more efficiently (per user instead of per host) if connection.username: - # we need to encode twice because trino dialect internally - # url decodes the username and if there is an special char in username - # it will fail to authenticate - url += f"{quote_plus(quote_plus(connection.username))}" - if ( - isinstance(connection.authType, basicAuth.BasicAuth) - and connection.authType.password - ): - url += f":{quote_plus(connection.authType.password.get_secret_value())}" - url += "@" + url += f"{quote_plus(connection.username)}@" + url += f"{connection.hostPort}" if connection.catalog: url += f"/{connection.catalog}" - if isinstance(connection.authType, jwtAuth.JwtAuth): - if not connection.connectionOptions: - connection.connectionOptions = init_empty_connection_options() - connection.connectionOptions.root[ - "access_token" - ] = connection.authType.jwt.get_secret_value() if connection.connectionOptions is not None: params = "&".join( f"{key}={quote_plus(value)}" @@ -84,14 +77,54 @@ def get_connection_url(connection: TrinoConnection) -> str: @connection_with_options_secrets def get_connection_args(connection: TrinoConnection): + if not connection.connectionArguments: + connection.connectionArguments = init_empty_connection_arguments() + if connection.proxies: session = Session() session.proxies = connection.proxies - if not connection.connectionArguments: - connection.connectionArguments = init_empty_connection_arguments() connection.connectionArguments.root["http_session"] = session + if isinstance(connection.authType, basicAuth.BasicAuth): + connection.connectionArguments.root["auth"] = BasicAuthentication( + connection.username, + connection.authType.password.get_secret_value() + if connection.authType.password + else None, + ) + connection.connectionArguments.root["http_scheme"] = "https" + + elif isinstance(connection.authType, jwtAuth.JwtAuth): + connection.connectionArguments.root["auth"] = JWTAuthentication( + connection.authType.jwt.get_secret_value() + ) + connection.connectionArguments.root["http_scheme"] = "https" + + elif hasattr(connection.authType, "azureConfig"): + if not connection.authType.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" + ) + + azure_client = AzureClient(connection.authType.azureConfig).create_client() + + access_token_obj = azure_client.get_token( + *connection.authType.azureConfig.scopes.split(",") + ) + + connection.connectionArguments.root["auth"] = JWTAuthentication( + access_token_obj.token + ) + connection.connectionArguments.root["http_scheme"] = "https" + + elif ( + connection.authType + == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2 + ): + connection.connectionArguments.root["auth"] = OAuth2Authentication() + connection.connectionArguments.root["http_scheme"] = "https" + return get_connection_args_common(connection) @@ -99,9 +132,13 @@ def get_connection(connection: TrinoConnection) -> Engine: """ Create connection """ - if connection.verify: - connection.connectionArguments = ( - connection.connectionArguments or init_empty_connection_arguments() + # here we are creating a copy of connection, because we need to dynamically + # add auth params to connectionArguments, which we do no intend to store + # in original connection object and in OpenMetadata database + connection_copy = deepcopy(connection) + if connection_copy.verify: + connection_copy.connectionArguments = ( + connection_copy.connectionArguments or init_empty_connection_arguments() ) connection.connectionArguments.root["verify"] = {"verify": connection.verify} if hasattr(connection.authType, "azureConfig"): @@ -117,7 +154,7 @@ def get_connection(connection: TrinoConnection) -> Engine: connection.connectionOptions = init_empty_connection_options() connection.connectionOptions.root["access_token"] = access_token_obj.token return create_generic_db_connection( - connection=connection, + connection=connection_copy, get_connection_url_fn=get_connection_url, get_connection_args_fn=get_connection_args, ) diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/stored_statistics_profiler.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/stored_statistics_profiler.py index cf7fb2123fd..63612cd9535 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/stored_statistics_profiler.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/stored_statistics_profiler.py @@ -17,10 +17,10 @@ supporting sqlalchemy abstraction layer import threading from typing import Any, Dict, List, Set -from mlflow.protos.databricks_uc_registry_messages_pb2 import Table from more_itertools import partition from sqlalchemy import Column +from metadata.generated.schema.entity.data.table import Table from metadata.mixins.sqalchemy.sqa_mixin import Root from metadata.profiler.interface.sqlalchemy.profiler_interface import ( SQAProfilerInterface, diff --git a/ingestion/tests/unit/test_source_connection.py b/ingestion/tests/unit/test_source_connection.py index d1984048980..4912097f336 100644 --- a/ingestion/tests/unit/test_source_connection.py +++ b/ingestion/tests/unit/test_source_connection.py @@ -11,6 +11,8 @@ from unittest import TestCase +from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication + from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( AthenaConnection, AthenaScheme, @@ -19,6 +21,9 @@ from metadata.generated.schema.entity.services.connections.database.clickhouseCo ClickhouseConnection, ClickhouseScheme, ) +from metadata.generated.schema.entity.services.connections.database.common import ( + noConfigAuthenticationTypes, +) from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( BasicAuth, ) @@ -107,6 +112,7 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_url_common, ) +from metadata.ingestion.source.database.trino.connection import get_connection_args # pylint: disable=import-outside-toplevel @@ -401,7 +407,7 @@ class SourceConnectionTest(TestCase): get_connection_url, ) - expected_url = "trino://username:pass@localhost:443/catalog" + expected_url = "trino://username@localhost:443/catalog" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, hostPort="localhost:443", @@ -413,7 +419,7 @@ class SourceConnectionTest(TestCase): assert expected_url == get_connection_url(trino_conn_obj) # Passing @ in username and password - expected_url = "trino://username%2540444:pass%40111@localhost:443/catalog" + expected_url = "trino://username%40444@localhost:443/catalog" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, hostPort="localhost:443", @@ -430,7 +436,10 @@ class SourceConnectionTest(TestCase): ) # connection arguments without connectionArguments and without proxies - expected_args = {} + expected_args = { + "auth": BasicAuthentication("user", None), + "http_scheme": "https", + } trino_conn_obj = TrinoConnection( username="user", authType=BasicAuth(password=None), @@ -442,7 +451,11 @@ class SourceConnectionTest(TestCase): assert expected_args == get_connection_args(trino_conn_obj) # connection arguments with connectionArguments and without proxies - expected_args = {"user": "user-to-be-impersonated"} + expected_args = { + "user": "user-to-be-impersonated", + "auth": BasicAuthentication("user", None), + "http_scheme": "https", + } trino_conn_obj = TrinoConnection( username="user", authType=BasicAuth(password=None), @@ -454,7 +467,10 @@ class SourceConnectionTest(TestCase): assert expected_args == get_connection_args(trino_conn_obj) # connection arguments without connectionArguments and with proxies - expected_args = {} + expected_args = { + "auth": BasicAuthentication("user", None), + "http_scheme": "https", + } trino_conn_obj = TrinoConnection( username="user", authType=BasicAuth(password=None), @@ -470,7 +486,11 @@ class SourceConnectionTest(TestCase): assert expected_args == conn_args # connection arguments with connectionArguments and with proxies - expected_args = {"user": "user-to-be-impersonated"} + expected_args = { + "user": "user-to-be-impersonated", + "auth": BasicAuthentication("user", None), + "http_scheme": "https", + } trino_conn_obj = TrinoConnection( username="user", authType=BasicAuth(password=None), @@ -490,7 +510,7 @@ class SourceConnectionTest(TestCase): get_connection_url, ) - expected_url = "trino://username:pass@localhost:443/catalog?param=value" + expected_url = "trino://username@localhost:443/catalog?param=value" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, hostPort="localhost:443", @@ -506,9 +526,11 @@ class SourceConnectionTest(TestCase): get_connection_url, ) - expected_url = ( - "trino://username@localhost:443/catalog?access_token=jwt_token_value" - ) + expected_url = "trino://username@localhost:443/catalog" + expected_args = { + "auth": JWTAuthentication("jwt_token_value"), + "http_scheme": "https", + } trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, hostPort="localhost:443", @@ -517,6 +539,7 @@ class SourceConnectionTest(TestCase): catalog="catalog", ) assert expected_url == get_connection_url(trino_conn_obj) + assert expected_args == get_connection_args(trino_conn_obj) def test_trino_with_proxies(self): from metadata.ingestion.source.database.trino.connection import ( @@ -543,7 +566,7 @@ class SourceConnectionTest(TestCase): ) # Test trino url without catalog - expected_url = "trino://username:pass@localhost:443" + expected_url = "trino://username@localhost:443" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, hostPort="localhost:443", @@ -553,6 +576,40 @@ class SourceConnectionTest(TestCase): assert expected_url == get_connection_url(trino_conn_obj) + def test_trino_without_catalog(self): + from metadata.ingestion.source.database.trino.connection import ( + get_connection_url, + ) + + # Test trino url without catalog + expected_url = "trino://username@localhost:443" + trino_conn_obj = TrinoConnection( + scheme=TrinoScheme.trino, + hostPort="localhost:443", + username="username", + authType=BasicAuth(password="pass"), + ) + + assert expected_url == get_connection_url(trino_conn_obj) + + def test_trino_with_oauth2(self): + from metadata.ingestion.source.database.trino.connection import ( + get_connection_url, + ) + + # Test trino url without catalog + expected_url = "trino://username@localhost:443" + trino_conn_obj = TrinoConnection( + scheme=TrinoScheme.trino, + hostPort="localhost:443", + username="username", + authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2, + ) + + assert isinstance( + get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication + ) + def test_vertica_url(self): expected_url = ( "vertica+vertica_python://username:password@localhost:5443/database" diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json index 364f69347c8..65f16199af9 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/azureConfig.json @@ -12,4 +12,4 @@ } }, "additionalProperties": false -} \ No newline at end of file +} diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/noConfigAuthenticationTypes.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/noConfigAuthenticationTypes.json new file mode 100644 index 00000000000..0b20ce14c18 --- /dev/null +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/common/noConfigAuthenticationTypes.json @@ -0,0 +1,12 @@ +{ + "$id": "https://open-metadata.org/schema/entity/services/connections/database/noConfigAuthenticationTypes.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "No Config Authentication Types", + "javaType": "org.openmetadata.schema.services.connections.database.common.NoConfigAuthenticationTypes", + "description": "Database Authentication types not requiring config.", + "type": "string", + "enum": [ + "OAuth2" + ], + "default": "OAuth2" +} diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/trinoConnection.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/trinoConnection.json index df1b761be8f..8bc8f28d124 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/trinoConnection.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/database/trinoConnection.json @@ -49,6 +49,9 @@ }, { "$ref": "./common/azureConfig.json" + }, + { + "$ref": "./common/noConfigAuthenticationTypes.json" } ] },