feat: use native trino client authentication classes (#16196)

---------

Co-authored-by: ulixius9 <mayursingal9@gmail.com>
This commit is contained in:
mgorsk1 2024-11-15 08:24:42 +01:00 committed by GitHub
parent bde6ee4125
commit 3d2dfeb583
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 145 additions and 36 deletions

View File

@ -12,11 +12,13 @@
""" """
Source connection handler Source connection handler
""" """
from copy import deepcopy
from typing import Optional from typing import Optional
from urllib.parse import quote_plus from urllib.parse import quote_plus
from requests import Session from requests import Session
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
from metadata.clients.azure_client import AzureClient from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
@ -25,6 +27,7 @@ from metadata.generated.schema.entity.automations.workflow import (
from metadata.generated.schema.entity.services.connections.database.common import ( from metadata.generated.schema.entity.services.connections.database.common import (
basicAuth, basicAuth,
jwtAuth, jwtAuth,
noConfigAuthenticationTypes,
) )
from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection, TrinoConnection,
@ -36,7 +39,6 @@ from metadata.ingestion.connections.builders import (
create_generic_db_connection, create_generic_db_connection,
get_connection_args_common, get_connection_args_common,
init_empty_connection_arguments, init_empty_connection_arguments,
init_empty_connection_options,
) )
from metadata.ingestion.connections.secrets import connection_with_options_secrets from metadata.ingestion.connections.secrets import connection_with_options_secrets
from metadata.ingestion.connections.test_connections import ( from metadata.ingestion.connections.test_connections import (
@ -52,26 +54,17 @@ def get_connection_url(connection: TrinoConnection) -> str:
Prepare the connection url for trino Prepare the connection url for trino
""" """
url = f"{connection.scheme.value}://" url = f"{connection.scheme.value}://"
# leaving username here as, even though with basic auth is used directly
# in BasicAuthentication class, it's often also required as a part of url.
# For example - it will be used by OAuth2Authentication to persist token in
# cache more efficiently (per user instead of per host)
if connection.username: if connection.username:
# we need to encode twice because trino dialect internally url += f"{quote_plus(connection.username)}@"
# url decodes the username and if there is an special char in username
# it will fail to authenticate
url += f"{quote_plus(quote_plus(connection.username))}"
if (
isinstance(connection.authType, basicAuth.BasicAuth)
and connection.authType.password
):
url += f":{quote_plus(connection.authType.password.get_secret_value())}"
url += "@"
url += f"{connection.hostPort}" url += f"{connection.hostPort}"
if connection.catalog: if connection.catalog:
url += f"/{connection.catalog}" url += f"/{connection.catalog}"
if isinstance(connection.authType, jwtAuth.JwtAuth):
if not connection.connectionOptions:
connection.connectionOptions = init_empty_connection_options()
connection.connectionOptions.root[
"access_token"
] = connection.authType.jwt.get_secret_value()
if connection.connectionOptions is not None: if connection.connectionOptions is not None:
params = "&".join( params = "&".join(
f"{key}={quote_plus(value)}" f"{key}={quote_plus(value)}"
@ -84,14 +77,54 @@ def get_connection_url(connection: TrinoConnection) -> str:
@connection_with_options_secrets @connection_with_options_secrets
def get_connection_args(connection: TrinoConnection): def get_connection_args(connection: TrinoConnection):
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
if connection.proxies: if connection.proxies:
session = Session() session = Session()
session.proxies = connection.proxies session.proxies = connection.proxies
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
connection.connectionArguments.root["http_session"] = session connection.connectionArguments.root["http_session"] = session
if isinstance(connection.authType, basicAuth.BasicAuth):
connection.connectionArguments.root["auth"] = BasicAuthentication(
connection.username,
connection.authType.password.get_secret_value()
if connection.authType.password
else None,
)
connection.connectionArguments.root["http_scheme"] = "https"
elif isinstance(connection.authType, jwtAuth.JwtAuth):
connection.connectionArguments.root["auth"] = JWTAuthentication(
connection.authType.jwt.get_secret_value()
)
connection.connectionArguments.root["http_scheme"] = "https"
elif hasattr(connection.authType, "azureConfig"):
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(connection.authType.azureConfig).create_client()
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.connectionArguments.root["auth"] = JWTAuthentication(
access_token_obj.token
)
connection.connectionArguments.root["http_scheme"] = "https"
elif (
connection.authType
== noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
):
connection.connectionArguments.root["auth"] = OAuth2Authentication()
connection.connectionArguments.root["http_scheme"] = "https"
return get_connection_args_common(connection) return get_connection_args_common(connection)
@ -99,9 +132,13 @@ def get_connection(connection: TrinoConnection) -> Engine:
""" """
Create connection Create connection
""" """
if connection.verify: # here we are creating a copy of connection, because we need to dynamically
connection.connectionArguments = ( # add auth params to connectionArguments, which we do no intend to store
connection.connectionArguments or init_empty_connection_arguments() # in original connection object and in OpenMetadata database
connection_copy = deepcopy(connection)
if connection_copy.verify:
connection_copy.connectionArguments = (
connection_copy.connectionArguments or init_empty_connection_arguments()
) )
connection.connectionArguments.root["verify"] = {"verify": connection.verify} connection.connectionArguments.root["verify"] = {"verify": connection.verify}
if hasattr(connection.authType, "azureConfig"): if hasattr(connection.authType, "azureConfig"):
@ -117,7 +154,7 @@ def get_connection(connection: TrinoConnection) -> Engine:
connection.connectionOptions = init_empty_connection_options() connection.connectionOptions = init_empty_connection_options()
connection.connectionOptions.root["access_token"] = access_token_obj.token connection.connectionOptions.root["access_token"] = access_token_obj.token
return create_generic_db_connection( return create_generic_db_connection(
connection=connection, connection=connection_copy,
get_connection_url_fn=get_connection_url, get_connection_url_fn=get_connection_url,
get_connection_args_fn=get_connection_args, get_connection_args_fn=get_connection_args,
) )

View File

@ -17,10 +17,10 @@ supporting sqlalchemy abstraction layer
import threading import threading
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set
from mlflow.protos.databricks_uc_registry_messages_pb2 import Table
from more_itertools import partition from more_itertools import partition
from sqlalchemy import Column from sqlalchemy import Column
from metadata.generated.schema.entity.data.table import Table
from metadata.mixins.sqalchemy.sqa_mixin import Root from metadata.mixins.sqalchemy.sqa_mixin import Root
from metadata.profiler.interface.sqlalchemy.profiler_interface import ( from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface, SQAProfilerInterface,

View File

@ -11,6 +11,8 @@
from unittest import TestCase from unittest import TestCase
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
AthenaConnection, AthenaConnection,
AthenaScheme, AthenaScheme,
@ -19,6 +21,9 @@ from metadata.generated.schema.entity.services.connections.database.clickhouseCo
ClickhouseConnection, ClickhouseConnection,
ClickhouseScheme, ClickhouseScheme,
) )
from metadata.generated.schema.entity.services.connections.database.common import (
noConfigAuthenticationTypes,
)
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
BasicAuth, BasicAuth,
) )
@ -107,6 +112,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common, get_connection_args_common,
get_connection_url_common, get_connection_url_common,
) )
from metadata.ingestion.source.database.trino.connection import get_connection_args
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -401,7 +407,7 @@ class SourceConnectionTest(TestCase):
get_connection_url, get_connection_url,
) )
expected_url = "trino://username:pass@localhost:443/catalog" expected_url = "trino://username@localhost:443/catalog"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
@ -413,7 +419,7 @@ class SourceConnectionTest(TestCase):
assert expected_url == get_connection_url(trino_conn_obj) assert expected_url == get_connection_url(trino_conn_obj)
# Passing @ in username and password # Passing @ in username and password
expected_url = "trino://username%2540444:pass%40111@localhost:443/catalog" expected_url = "trino://username%40444@localhost:443/catalog"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
@ -430,7 +436,10 @@ class SourceConnectionTest(TestCase):
) )
# connection arguments without connectionArguments and without proxies # connection arguments without connectionArguments and without proxies
expected_args = {} expected_args = {
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
@ -442,7 +451,11 @@ class SourceConnectionTest(TestCase):
assert expected_args == get_connection_args(trino_conn_obj) assert expected_args == get_connection_args(trino_conn_obj)
# connection arguments with connectionArguments and without proxies # connection arguments with connectionArguments and without proxies
expected_args = {"user": "user-to-be-impersonated"} expected_args = {
"user": "user-to-be-impersonated",
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
@ -454,7 +467,10 @@ class SourceConnectionTest(TestCase):
assert expected_args == get_connection_args(trino_conn_obj) assert expected_args == get_connection_args(trino_conn_obj)
# connection arguments without connectionArguments and with proxies # connection arguments without connectionArguments and with proxies
expected_args = {} expected_args = {
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
@ -470,7 +486,11 @@ class SourceConnectionTest(TestCase):
assert expected_args == conn_args assert expected_args == conn_args
# connection arguments with connectionArguments and with proxies # connection arguments with connectionArguments and with proxies
expected_args = {"user": "user-to-be-impersonated"} expected_args = {
"user": "user-to-be-impersonated",
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
@ -490,7 +510,7 @@ class SourceConnectionTest(TestCase):
get_connection_url, get_connection_url,
) )
expected_url = "trino://username:pass@localhost:443/catalog?param=value" expected_url = "trino://username@localhost:443/catalog?param=value"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
@ -506,9 +526,11 @@ class SourceConnectionTest(TestCase):
get_connection_url, get_connection_url,
) )
expected_url = ( expected_url = "trino://username@localhost:443/catalog"
"trino://username@localhost:443/catalog?access_token=jwt_token_value" expected_args = {
) "auth": JWTAuthentication("jwt_token_value"),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
@ -517,6 +539,7 @@ class SourceConnectionTest(TestCase):
catalog="catalog", catalog="catalog",
) )
assert expected_url == get_connection_url(trino_conn_obj) assert expected_url == get_connection_url(trino_conn_obj)
assert expected_args == get_connection_args(trino_conn_obj)
def test_trino_with_proxies(self): def test_trino_with_proxies(self):
from metadata.ingestion.source.database.trino.connection import ( from metadata.ingestion.source.database.trino.connection import (
@ -543,7 +566,7 @@ class SourceConnectionTest(TestCase):
) )
# Test trino url without catalog # Test trino url without catalog
expected_url = "trino://username:pass@localhost:443" expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
@ -553,6 +576,40 @@ class SourceConnectionTest(TestCase):
assert expected_url == get_connection_url(trino_conn_obj) assert expected_url == get_connection_url(trino_conn_obj)
def test_trino_without_catalog(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog
expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=BasicAuth(password="pass"),
)
assert expected_url == get_connection_url(trino_conn_obj)
def test_trino_with_oauth2(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog
expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2,
)
assert isinstance(
get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication
)
def test_vertica_url(self): def test_vertica_url(self):
expected_url = ( expected_url = (
"vertica+vertica_python://username:password@localhost:5443/database" "vertica+vertica_python://username:password@localhost:5443/database"

View File

@ -12,4 +12,4 @@
} }
}, },
"additionalProperties": false "additionalProperties": false
} }

View File

@ -0,0 +1,12 @@
{
"$id": "https://open-metadata.org/schema/entity/services/connections/database/noConfigAuthenticationTypes.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "No Config Authentication Types",
"javaType": "org.openmetadata.schema.services.connections.database.common.NoConfigAuthenticationTypes",
"description": "Database Authentication types not requiring config.",
"type": "string",
"enum": [
"OAuth2"
],
"default": "OAuth2"
}

View File

@ -49,6 +49,9 @@
}, },
{ {
"$ref": "./common/azureConfig.json" "$ref": "./common/azureConfig.json"
},
{
"$ref": "./common/noConfigAuthenticationTypes.json"
} }
] ]
}, },