mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-18 05:57:17 +00:00
feat: use native trino client authentication classes (#16196)
--------- Co-authored-by: ulixius9 <mayursingal9@gmail.com>
This commit is contained in:
parent
bde6ee4125
commit
3d2dfeb583
@ -12,11 +12,13 @@
|
||||
"""
|
||||
Source connection handler
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from requests import Session
|
||||
from sqlalchemy.engine import Engine
|
||||
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
|
||||
|
||||
from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
@ -25,6 +27,7 @@ from metadata.generated.schema.entity.automations.workflow import (
|
||||
from metadata.generated.schema.entity.services.connections.database.common import (
|
||||
basicAuth,
|
||||
jwtAuth,
|
||||
noConfigAuthenticationTypes,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
TrinoConnection,
|
||||
@ -36,7 +39,6 @@ from metadata.ingestion.connections.builders import (
|
||||
create_generic_db_connection,
|
||||
get_connection_args_common,
|
||||
init_empty_connection_arguments,
|
||||
init_empty_connection_options,
|
||||
)
|
||||
from metadata.ingestion.connections.secrets import connection_with_options_secrets
|
||||
from metadata.ingestion.connections.test_connections import (
|
||||
@ -52,26 +54,17 @@ def get_connection_url(connection: TrinoConnection) -> str:
|
||||
Prepare the connection url for trino
|
||||
"""
|
||||
url = f"{connection.scheme.value}://"
|
||||
|
||||
# leaving username here as, even though with basic auth is used directly
|
||||
# in BasicAuthentication class, it's often also required as a part of url.
|
||||
# For example - it will be used by OAuth2Authentication to persist token in
|
||||
# cache more efficiently (per user instead of per host)
|
||||
if connection.username:
|
||||
# we need to encode twice because trino dialect internally
|
||||
# url decodes the username and if there is an special char in username
|
||||
# it will fail to authenticate
|
||||
url += f"{quote_plus(quote_plus(connection.username))}"
|
||||
if (
|
||||
isinstance(connection.authType, basicAuth.BasicAuth)
|
||||
and connection.authType.password
|
||||
):
|
||||
url += f":{quote_plus(connection.authType.password.get_secret_value())}"
|
||||
url += "@"
|
||||
url += f"{quote_plus(connection.username)}@"
|
||||
|
||||
url += f"{connection.hostPort}"
|
||||
if connection.catalog:
|
||||
url += f"/{connection.catalog}"
|
||||
if isinstance(connection.authType, jwtAuth.JwtAuth):
|
||||
if not connection.connectionOptions:
|
||||
connection.connectionOptions = init_empty_connection_options()
|
||||
connection.connectionOptions.root[
|
||||
"access_token"
|
||||
] = connection.authType.jwt.get_secret_value()
|
||||
if connection.connectionOptions is not None:
|
||||
params = "&".join(
|
||||
f"{key}={quote_plus(value)}"
|
||||
@ -84,14 +77,54 @@ def get_connection_url(connection: TrinoConnection) -> str:
|
||||
|
||||
@connection_with_options_secrets
|
||||
def get_connection_args(connection: TrinoConnection):
|
||||
if connection.proxies:
|
||||
session = Session()
|
||||
session.proxies = connection.proxies
|
||||
if not connection.connectionArguments:
|
||||
connection.connectionArguments = init_empty_connection_arguments()
|
||||
|
||||
if connection.proxies:
|
||||
session = Session()
|
||||
session.proxies = connection.proxies
|
||||
|
||||
connection.connectionArguments.root["http_session"] = session
|
||||
|
||||
if isinstance(connection.authType, basicAuth.BasicAuth):
|
||||
connection.connectionArguments.root["auth"] = BasicAuthentication(
|
||||
connection.username,
|
||||
connection.authType.password.get_secret_value()
|
||||
if connection.authType.password
|
||||
else None,
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
|
||||
elif isinstance(connection.authType, jwtAuth.JwtAuth):
|
||||
connection.connectionArguments.root["auth"] = JWTAuthentication(
|
||||
connection.authType.jwt.get_secret_value()
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
|
||||
elif hasattr(connection.authType, "azureConfig"):
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
|
||||
connection.connectionArguments.root["auth"] = JWTAuthentication(
|
||||
access_token_obj.token
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
|
||||
elif (
|
||||
connection.authType
|
||||
== noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
|
||||
):
|
||||
connection.connectionArguments.root["auth"] = OAuth2Authentication()
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
|
||||
return get_connection_args_common(connection)
|
||||
|
||||
|
||||
@ -99,9 +132,13 @@ def get_connection(connection: TrinoConnection) -> Engine:
|
||||
"""
|
||||
Create connection
|
||||
"""
|
||||
if connection.verify:
|
||||
connection.connectionArguments = (
|
||||
connection.connectionArguments or init_empty_connection_arguments()
|
||||
# here we are creating a copy of connection, because we need to dynamically
|
||||
# add auth params to connectionArguments, which we do no intend to store
|
||||
# in original connection object and in OpenMetadata database
|
||||
connection_copy = deepcopy(connection)
|
||||
if connection_copy.verify:
|
||||
connection_copy.connectionArguments = (
|
||||
connection_copy.connectionArguments or init_empty_connection_arguments()
|
||||
)
|
||||
connection.connectionArguments.root["verify"] = {"verify": connection.verify}
|
||||
if hasattr(connection.authType, "azureConfig"):
|
||||
@ -117,7 +154,7 @@ def get_connection(connection: TrinoConnection) -> Engine:
|
||||
connection.connectionOptions = init_empty_connection_options()
|
||||
connection.connectionOptions.root["access_token"] = access_token_obj.token
|
||||
return create_generic_db_connection(
|
||||
connection=connection,
|
||||
connection=connection_copy,
|
||||
get_connection_url_fn=get_connection_url,
|
||||
get_connection_args_fn=get_connection_args,
|
||||
)
|
||||
|
@ -17,10 +17,10 @@ supporting sqlalchemy abstraction layer
|
||||
import threading
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from mlflow.protos.databricks_uc_registry_messages_pb2 import Table
|
||||
from more_itertools import partition
|
||||
from sqlalchemy import Column
|
||||
|
||||
from metadata.generated.schema.entity.data.table import Table
|
||||
from metadata.mixins.sqalchemy.sqa_mixin import Root
|
||||
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
|
||||
SQAProfilerInterface,
|
||||
|
@ -11,6 +11,8 @@
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication
|
||||
|
||||
from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
|
||||
AthenaConnection,
|
||||
AthenaScheme,
|
||||
@ -19,6 +21,9 @@ from metadata.generated.schema.entity.services.connections.database.clickhouseCo
|
||||
ClickhouseConnection,
|
||||
ClickhouseScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common import (
|
||||
noConfigAuthenticationTypes,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
|
||||
BasicAuth,
|
||||
)
|
||||
@ -107,6 +112,7 @@ from metadata.ingestion.connections.builders import (
|
||||
get_connection_args_common,
|
||||
get_connection_url_common,
|
||||
)
|
||||
from metadata.ingestion.source.database.trino.connection import get_connection_args
|
||||
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
@ -401,7 +407,7 @@ class SourceConnectionTest(TestCase):
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = "trino://username:pass@localhost:443/catalog"
|
||||
expected_url = "trino://username@localhost:443/catalog"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
@ -413,7 +419,7 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
|
||||
# Passing @ in username and password
|
||||
expected_url = "trino://username%2540444:pass%40111@localhost:443/catalog"
|
||||
expected_url = "trino://username%40444@localhost:443/catalog"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
@ -430,7 +436,10 @@ class SourceConnectionTest(TestCase):
|
||||
)
|
||||
|
||||
# connection arguments without connectionArguments and without proxies
|
||||
expected_args = {}
|
||||
expected_args = {
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
@ -442,7 +451,11 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
|
||||
# connection arguments with connectionArguments and without proxies
|
||||
expected_args = {"user": "user-to-be-impersonated"}
|
||||
expected_args = {
|
||||
"user": "user-to-be-impersonated",
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
@ -454,7 +467,10 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
|
||||
# connection arguments without connectionArguments and with proxies
|
||||
expected_args = {}
|
||||
expected_args = {
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
@ -470,7 +486,11 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_args == conn_args
|
||||
|
||||
# connection arguments with connectionArguments and with proxies
|
||||
expected_args = {"user": "user-to-be-impersonated"}
|
||||
expected_args = {
|
||||
"user": "user-to-be-impersonated",
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
@ -490,7 +510,7 @@ class SourceConnectionTest(TestCase):
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = "trino://username:pass@localhost:443/catalog?param=value"
|
||||
expected_url = "trino://username@localhost:443/catalog?param=value"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
@ -506,9 +526,11 @@ class SourceConnectionTest(TestCase):
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = (
|
||||
"trino://username@localhost:443/catalog?access_token=jwt_token_value"
|
||||
)
|
||||
expected_url = "trino://username@localhost:443/catalog"
|
||||
expected_args = {
|
||||
"auth": JWTAuthentication("jwt_token_value"),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
@ -517,6 +539,7 @@ class SourceConnectionTest(TestCase):
|
||||
catalog="catalog",
|
||||
)
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
|
||||
def test_trino_with_proxies(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
@ -543,7 +566,7 @@ class SourceConnectionTest(TestCase):
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username:pass@localhost:443"
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
@ -553,6 +576,40 @@ class SourceConnectionTest(TestCase):
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
|
||||
def test_trino_without_catalog(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=BasicAuth(password="pass"),
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
|
||||
def test_trino_with_oauth2(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2,
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication
|
||||
)
|
||||
|
||||
def test_vertica_url(self):
|
||||
expected_url = (
|
||||
"vertica+vertica_python://username:password@localhost:5443/database"
|
||||
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"$id": "https://open-metadata.org/schema/entity/services/connections/database/noConfigAuthenticationTypes.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "No Config Authentication Types",
|
||||
"javaType": "org.openmetadata.schema.services.connections.database.common.NoConfigAuthenticationTypes",
|
||||
"description": "Database Authentication types not requiring config.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"OAuth2"
|
||||
],
|
||||
"default": "OAuth2"
|
||||
}
|
@ -49,6 +49,9 @@
|
||||
},
|
||||
{
|
||||
"$ref": "./common/azureConfig.json"
|
||||
},
|
||||
{
|
||||
"$ref": "./common/noConfigAuthenticationTypes.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user