mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-09-01 13:13:10 +00:00
MINOR: Update Trino Connection to fix data diff (#21983)
This commit is contained in:
parent
7ff36b2478
commit
c899d45e8e
@ -157,7 +157,7 @@ base_requirements = {
|
||||
"packaging", # For version parsing
|
||||
"setuptools~=70.0",
|
||||
"shapely",
|
||||
"collate-data-diff",
|
||||
"collate-data-diff>=0.11.6",
|
||||
"jaraco.functools<4.2.0", # above 4.2 breaks the build
|
||||
# TODO: Remove one once we have updated datadiff version
|
||||
"snowflake-connector-python>=3.13.1,<4.0.0",
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Models for the TableDiff test case"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -12,7 +12,7 @@ from metadata.ingestion.models.custom_pydantic import CustomSecretStr
|
||||
|
||||
|
||||
class TableParameter(BaseModel):
|
||||
serviceUrl: str
|
||||
serviceUrl: Union[str, dict]
|
||||
path: str
|
||||
columns: List[Column]
|
||||
database_service_type: DatabaseServiceType
|
||||
|
@ -1,15 +1,50 @@
|
||||
"""Base class for param setter logic for table data diff"""
|
||||
|
||||
from typing import List, Optional, Set
|
||||
from typing import List, Optional, Set, Type, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from metadata.data_quality.validations.models import Column, TableParameter
|
||||
from metadata.generated.schema.entity.data.table import Table
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.generated.schema.entity.services.serviceType import ServiceType
|
||||
from metadata.ingestion.connections.connection import BaseConnection
|
||||
from metadata.ingestion.source.connections import get_connection
|
||||
from metadata.profiler.orm.registry import Dialects
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.collections import CaseInsensitiveList
|
||||
from metadata.utils.importer import get_module_dir, import_from_module
|
||||
|
||||
|
||||
# TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods.
|
||||
# Using the specs class method causes circular import as TestSuiteInterface
|
||||
# imports RuntimeParameterSetter
|
||||
class ServiceSpecPatch:
|
||||
def __init__(self, service_type: ServiceType, source_type: str):
|
||||
self.service_type = service_type
|
||||
self.source_type = source_type
|
||||
self._service_spec = None
|
||||
|
||||
@property
|
||||
def service_spec(self):
|
||||
if self._service_spec is None:
|
||||
self._service_spec = self.get_for_source()
|
||||
return self._service_spec
|
||||
|
||||
def get_for_source(self):
|
||||
return import_from_module(
|
||||
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
|
||||
"ingestion",
|
||||
self.service_type.name.lower(),
|
||||
get_module_dir(self.source_type),
|
||||
"service_spec",
|
||||
)
|
||||
)
|
||||
|
||||
def get_data_diff_class(self) -> Type["BaseTableParameter"]:
|
||||
return import_from_module(self.service_spec.data_diff)
|
||||
|
||||
def get_connection_class(self) -> Type[BaseConnection]:
|
||||
return import_from_module(self.service_spec.connection_class)
|
||||
|
||||
|
||||
class BaseTableParameter:
|
||||
@ -22,7 +57,7 @@ class BaseTableParameter:
|
||||
key_columns,
|
||||
extra_columns,
|
||||
case_sensitive_columns,
|
||||
service_url: Optional[str],
|
||||
service_url: Optional[Union[str, dict]],
|
||||
) -> TableParameter:
|
||||
"""Getter table parameter for the table diff test.
|
||||
|
||||
@ -62,10 +97,35 @@ class BaseTableParameter:
|
||||
"___SERVICE___", "__DATABASE__", schema, table
|
||||
).replace("___SERVICE___.__DATABASE__.", "")
|
||||
|
||||
@staticmethod
|
||||
def _get_service_connection_config(
|
||||
db_service: DatabaseService,
|
||||
) -> Optional[Union[str, dict]]:
|
||||
"""
|
||||
Get the connection dictionary for the service.
|
||||
"""
|
||||
service_connection_config = db_service.connection.config
|
||||
service_spec_patch = ServiceSpecPatch(
|
||||
ServiceType.Database, service_connection_config.type.value.lower()
|
||||
)
|
||||
|
||||
try:
|
||||
connection_class = service_spec_patch.get_connection_class()
|
||||
connection = connection_class(service_connection_config)
|
||||
return connection.get_connection_dict()
|
||||
except (ValueError, AttributeError, NotImplementedError):
|
||||
return (
|
||||
str(get_connection(service_connection_config).url)
|
||||
if service_connection_config
|
||||
else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_data_diff_url(
|
||||
db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
|
||||
) -> str:
|
||||
db_service: DatabaseService,
|
||||
table_fqn,
|
||||
override_url: Optional[Union[str, dict]] = None,
|
||||
) -> Union[str, dict]:
|
||||
"""Get the url for the data diff service.
|
||||
|
||||
Args:
|
||||
@ -77,10 +137,14 @@ class BaseTableParameter:
|
||||
str: The url for the data diff service
|
||||
"""
|
||||
source_url = (
|
||||
str(get_connection(db_service.connection.config).url)
|
||||
BaseTableParameter._get_service_connection_config(db_service)
|
||||
if not override_url
|
||||
else override_url
|
||||
)
|
||||
if isinstance(source_url, dict):
|
||||
source_url["driver"] = source_url["driver"].split("+")[0]
|
||||
return source_url
|
||||
|
||||
url = urlparse(source_url)
|
||||
# remove the driver name from the url because table-diff doesn't support it
|
||||
kwargs = {"scheme": url.scheme.split("+")[0]}
|
||||
|
@ -11,10 +11,13 @@
|
||||
"""Module that defines the TableDiffParamsSetter class."""
|
||||
from ast import literal_eval
|
||||
from typing import List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from metadata.data_quality.validations import utils
|
||||
from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters
|
||||
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
|
||||
BaseTableParameter,
|
||||
ServiceSpecPatch,
|
||||
)
|
||||
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
|
||||
RuntimeParameterSetter,
|
||||
)
|
||||
@ -22,24 +25,8 @@ from metadata.generated.schema.entity.data.table import Constraint, Table
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.generated.schema.entity.services.serviceType import ServiceType
|
||||
from metadata.generated.schema.tests.testCase import TestCase
|
||||
from metadata.ingestion.source.connections import get_connection
|
||||
from metadata.profiler.orm.registry import Dialects
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.collections import CaseInsensitiveList
|
||||
from metadata.utils.importer import get_module_dir, import_from_module
|
||||
|
||||
|
||||
def get_for_source(
|
||||
service_type: ServiceType, source_type: str, from_: str = "ingestion"
|
||||
):
|
||||
return import_from_module(
|
||||
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
|
||||
from_,
|
||||
service_type.name.lower(),
|
||||
get_module_dir(source_type),
|
||||
"service_spec",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TableDiffParamsSetter(RuntimeParameterSetter):
|
||||
@ -62,23 +49,16 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
|
||||
}
|
||||
|
||||
def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
|
||||
# Using the specs class method causes circular import as TestSuiteInterface
|
||||
# imports RuntimeParameterSetter
|
||||
cls_path = get_for_source(
|
||||
ServiceType.Database,
|
||||
source_type=self.service_connection_config.type.value.lower(),
|
||||
).data_diff
|
||||
cls = import_from_module(cls_path)()
|
||||
service_spec_patch = ServiceSpecPatch(
|
||||
ServiceType.Database, self.service_connection_config.type.value.lower()
|
||||
)
|
||||
cls = service_spec_patch.get_data_diff_class()()
|
||||
|
||||
service1: DatabaseService = self.ometa_client.get_by_id(
|
||||
DatabaseService, self.table_entity.service.id, nullable=False
|
||||
)
|
||||
|
||||
service1_url = (
|
||||
str(get_connection(self.service_connection_config).url)
|
||||
if self.service_connection_config
|
||||
else None
|
||||
)
|
||||
service1_url = BaseTableParameter._get_service_connection_config(service1)
|
||||
|
||||
table2_fqn = self.get_parameter(test_case, "table2")
|
||||
if table2_fqn is None:
|
||||
@ -209,41 +189,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
|
||||
(p.value for p in test_case.parameterValues if p.name == key), default
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_data_diff_url(
|
||||
db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the url for the data diff service.
|
||||
|
||||
Args:
|
||||
db_service (DatabaseService): The database service entity
|
||||
table_fqn (str): The fully qualified name of the table
|
||||
override_url (Optional[str], optional): Override the url. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The url for the data diff service
|
||||
"""
|
||||
source_url = (
|
||||
str(get_connection(db_service.connection.config).url)
|
||||
if not override_url
|
||||
else override_url
|
||||
)
|
||||
url = urlparse(source_url)
|
||||
# remove the driver name from the url because table-diff doesn't support it
|
||||
kwargs = {"scheme": url.scheme.split("+")[0]}
|
||||
service, database, schema, table = fqn.split( # pylint: disable=unused-variable
|
||||
table_fqn
|
||||
)
|
||||
# path needs to include the database AND schema in some of the connectors
|
||||
if hasattr(db_service.connection.config, "supportsDatabase"):
|
||||
kwargs["path"] = f"/{database}"
|
||||
# this can be found by going to:
|
||||
# https://github.com/open-metadata/collate-data-diff/blob/main/data_diff/databases/<connector>.py
|
||||
# and looking at the `CONNECT_URI_HELPER` variable
|
||||
if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake, Dialects.Trino}:
|
||||
kwargs["path"] = f"/{database}/{schema}"
|
||||
return url._replace(**kwargs).geturl()
|
||||
|
||||
@staticmethod
|
||||
def get_data_diff_table_path(table_fqn: str) -> str:
|
||||
service, database, schema, table = fqn.split( # pylint: disable=unused-variable
|
||||
|
@ -508,6 +508,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
||||
("table1.serviceUrl", self.runtime_params.table1.serviceUrl),
|
||||
("table2.serviceUrl", self.runtime_params.table2.serviceUrl),
|
||||
]:
|
||||
if isinstance(param, dict):
|
||||
dialect = param.get("driver")
|
||||
else:
|
||||
dialect = urlparse(param).scheme
|
||||
if dialect not in SUPPORTED_DIALECTS:
|
||||
raise UnsupportedDialectError(name, dialect)
|
||||
|
@ -41,12 +41,23 @@ class BaseConnection(ABC, Generic[S, C]):
|
||||
"""
|
||||
|
||||
service_connection: S
|
||||
_client: Optional[C]
|
||||
|
||||
def __init__(self, service_connection: S) -> None:
|
||||
self.service_connection = service_connection
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> C:
|
||||
"""
|
||||
Return the main client/engine/connection object for this service.
|
||||
"""
|
||||
if self._client is None:
|
||||
self._client = self._get_client()
|
||||
return self._client
|
||||
|
||||
@abstractmethod
|
||||
def get_client(self) -> C:
|
||||
def _get_client(self) -> C:
|
||||
"""
|
||||
Return the main client/engine/connection object for this service.
|
||||
"""
|
||||
@ -61,3 +72,9 @@ class BaseConnection(ABC, Generic[S, C]):
|
||||
"""
|
||||
Test the connection to the service.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_connection_dict(self) -> dict:
|
||||
"""
|
||||
Return the connection dictionary for this service.
|
||||
"""
|
||||
|
@ -73,7 +73,7 @@ def _get_connection_fn_from_service_spec(connection: BaseModel) -> Optional[Call
|
||||
if connection_class:
|
||||
|
||||
def _get_client(conn):
|
||||
return connection_class(conn).get_client()
|
||||
return connection_class(conn).client
|
||||
|
||||
return _get_client
|
||||
return None
|
||||
|
@ -50,7 +50,7 @@ from metadata.utils.constants import THREE_MIN
|
||||
|
||||
|
||||
class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
|
||||
def get_client(self) -> Engine:
|
||||
def _get_client(self) -> Engine:
|
||||
"""
|
||||
Return the SQLAlchemy Engine for MySQL.
|
||||
"""
|
||||
@ -77,6 +77,12 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
|
||||
get_connection_args_fn=get_connection_args_common,
|
||||
)
|
||||
|
||||
def get_connection_dict(self) -> dict:
|
||||
"""
|
||||
Return the connection dictionary for this service.
|
||||
"""
|
||||
raise NotImplementedError("get_connection_dict is not implemented for MySQL")
|
||||
|
||||
def test_connection(
|
||||
self,
|
||||
metadata: OpenMetadata,
|
||||
@ -94,7 +100,7 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
|
||||
}
|
||||
return test_connection_db_schema_sources(
|
||||
metadata=metadata,
|
||||
engine=self.get_client(),
|
||||
engine=self.client,
|
||||
service_connection=self.service_connection,
|
||||
automation_workflow=automation_workflow,
|
||||
timeout_seconds=timeout_seconds,
|
||||
|
@ -13,7 +13,7 @@
|
||||
Source connection handler
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from requests import Session
|
||||
@ -24,13 +24,17 @@ from metadata.clients.azure_client import AzureClient
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
|
||||
ConnectionArguments,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.common import (
|
||||
azureConfig,
|
||||
basicAuth,
|
||||
jwtAuth,
|
||||
noConfigAuthenticationTypes,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
TrinoConnection,
|
||||
TrinoConnection as TrinoConnectionConfig,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
||||
TestConnectionResult,
|
||||
@ -41,6 +45,7 @@ from metadata.ingestion.connections.builders import (
|
||||
init_empty_connection_arguments,
|
||||
init_empty_connection_options,
|
||||
)
|
||||
from metadata.ingestion.connections.connection import BaseConnection
|
||||
from metadata.ingestion.connections.secrets import connection_with_options_secrets
|
||||
from metadata.ingestion.connections.test_connections import (
|
||||
test_connection_db_schema_sources,
|
||||
@ -50,10 +55,140 @@ from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
|
||||
from metadata.utils.constants import THREE_MIN
|
||||
|
||||
|
||||
def get_connection_url(connection: TrinoConnection) -> str:
|
||||
# pylint: disable=unused-argument
|
||||
def _is_disconnect(self, e, connection, cursor):
|
||||
"""is_disconnect method for the Databricks dialect"""
|
||||
if "JWT expired" in str(e):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
|
||||
def __init__(self, connection: TrinoConnectionConfig):
|
||||
super().__init__(connection)
|
||||
|
||||
def _get_client(self) -> Engine:
|
||||
"""
|
||||
Create connection
|
||||
"""
|
||||
# here we are creating a copy of connection, because we need to dynamically
|
||||
# add auth params to connectionArguments, which we do no intend to store
|
||||
# in original connection object and in OpenMetadata database
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
|
||||
TrinoDialect.is_disconnect = _is_disconnect # type: ignore
|
||||
|
||||
connection = self.service_connection
|
||||
connection_copy = deepcopy(connection)
|
||||
|
||||
if hasattr(connection.authType, "azureConfig"):
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
if not connection.connectionOptions:
|
||||
connection.connectionOptions = init_empty_connection_options()
|
||||
connection.connectionOptions.root["access_token"] = access_token_obj.token
|
||||
|
||||
# Update the connection with the connection arguments
|
||||
connection_copy.connectionArguments = self.build_connection_args(
|
||||
connection_copy
|
||||
)
|
||||
|
||||
return create_generic_db_connection(
|
||||
connection=connection_copy,
|
||||
get_connection_url_fn=self.get_connection_url,
|
||||
get_connection_args_fn=get_connection_args_common,
|
||||
)
|
||||
|
||||
def test_connection(
|
||||
self,
|
||||
metadata: OpenMetadata,
|
||||
automation_workflow: Optional[AutomationWorkflow] = None,
|
||||
timeout_seconds: Optional[int] = THREE_MIN,
|
||||
) -> TestConnectionResult:
|
||||
"""
|
||||
Test connection. This can be executed either as part
|
||||
of a metadata workflow or during an Automation Workflow
|
||||
"""
|
||||
queries = {
|
||||
"GetDatabases": TRINO_GET_DATABASE,
|
||||
}
|
||||
|
||||
return test_connection_db_schema_sources(
|
||||
metadata=metadata,
|
||||
engine=self.client,
|
||||
service_connection=self.service_connection,
|
||||
automation_workflow=automation_workflow,
|
||||
queries=queries,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def get_connection_dict(self) -> dict:
|
||||
"""
|
||||
Return the connection dictionary for this service.
|
||||
"""
|
||||
url = self.client.url
|
||||
connection_copy = deepcopy(self.service_connection)
|
||||
|
||||
connection_dict = {
|
||||
"driver": url.drivername,
|
||||
"host": url.host,
|
||||
"port": url.port,
|
||||
"user": url.username,
|
||||
"catalog": url.database,
|
||||
"schema": url.query.get("schema"),
|
||||
}
|
||||
|
||||
connection_dict.update(url.query)
|
||||
|
||||
if connection_copy.proxies:
|
||||
connection_dict["http_session"] = connection_copy.proxies
|
||||
|
||||
if (
|
||||
connection_copy.connectionArguments
|
||||
and connection_copy.connectionArguments.root
|
||||
):
|
||||
connection_with_options_secrets(lambda: connection_copy)
|
||||
connection_dict.update(get_connection_args_common(connection_copy))
|
||||
|
||||
if isinstance(connection_copy.authType, basicAuth.BasicAuth):
|
||||
connection_dict["auth"] = TrinoConnection.get_basic_auth_dict(
|
||||
connection_copy
|
||||
)
|
||||
connection_dict["http_scheme"] = "https"
|
||||
|
||||
elif isinstance(connection_copy.authType, jwtAuth.JwtAuth):
|
||||
connection_dict["auth"] = TrinoConnection.get_jwt_auth_dict(connection_copy)
|
||||
connection_dict["http_scheme"] = "https"
|
||||
|
||||
elif hasattr(connection_copy.authType, "azureConfig"):
|
||||
connection_dict["auth"] = TrinoConnection.get_azure_auth_dict(
|
||||
connection_copy
|
||||
)
|
||||
connection_dict["http_scheme"] = "https"
|
||||
|
||||
elif (
|
||||
connection_copy.authType
|
||||
== noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
|
||||
):
|
||||
connection_dict["auth"] = TrinoConnection.get_oauth2_auth_dict(
|
||||
connection_copy
|
||||
)
|
||||
connection_dict["http_scheme"] = "https"
|
||||
|
||||
return connection_dict
|
||||
|
||||
@staticmethod
|
||||
def get_connection_url(connection: TrinoConnectionConfig) -> str:
|
||||
"""
|
||||
Prepare the connection url for trino
|
||||
"""
|
||||
|
||||
url = f"{connection.scheme.value}://"
|
||||
|
||||
# leaving username here as, even though with basic auth is used directly
|
||||
@ -75,124 +210,157 @@ def get_connection_url(connection: TrinoConnection) -> str:
|
||||
url = f"{url}?{params}"
|
||||
return url
|
||||
|
||||
|
||||
@staticmethod
|
||||
@connection_with_options_secrets
|
||||
def get_connection_args(connection: TrinoConnection):
|
||||
if not connection.connectionArguments:
|
||||
connection.connectionArguments = init_empty_connection_arguments()
|
||||
def build_connection_args(connection: TrinoConnectionConfig) -> ConnectionArguments:
|
||||
"""
|
||||
Get the connection args for the trino connection
|
||||
"""
|
||||
connection_args: ConnectionArguments = (
|
||||
connection.connectionArguments or init_empty_connection_arguments()
|
||||
)
|
||||
assert connection_args.root is not None
|
||||
|
||||
if connection.verify:
|
||||
connection_args.root["verify"] = {"verify": connection.verify}
|
||||
|
||||
if connection.proxies:
|
||||
session = Session()
|
||||
session.proxies = connection.proxies
|
||||
|
||||
connection.connectionArguments.root["http_session"] = session
|
||||
connection_args.root["http_session"] = session
|
||||
|
||||
if isinstance(connection.authType, basicAuth.BasicAuth):
|
||||
connection.connectionArguments.root["auth"] = BasicAuthentication(
|
||||
connection.username,
|
||||
connection.authType.password.get_secret_value()
|
||||
if connection.authType.password
|
||||
else None,
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
TrinoConnection.set_basic_auth(connection, connection_args)
|
||||
|
||||
elif isinstance(connection.authType, jwtAuth.JwtAuth):
|
||||
connection.connectionArguments.root["auth"] = JWTAuthentication(
|
||||
connection.authType.jwt.get_secret_value()
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
TrinoConnection.set_jwt_auth(connection, connection_args)
|
||||
|
||||
elif hasattr(connection.authType, "azureConfig"):
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
|
||||
connection.connectionArguments.root["auth"] = JWTAuthentication(
|
||||
access_token_obj.token
|
||||
)
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
TrinoConnection.set_azure_auth(connection, connection_args)
|
||||
|
||||
elif (
|
||||
connection.authType
|
||||
== noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
|
||||
):
|
||||
connection.connectionArguments.root["auth"] = OAuth2Authentication()
|
||||
connection.connectionArguments.root["http_scheme"] = "https"
|
||||
TrinoConnection.set_oauth2_auth(connection, connection_args)
|
||||
|
||||
return get_connection_args_common(connection)
|
||||
return connection_args
|
||||
|
||||
|
||||
def get_connection(connection: TrinoConnection) -> Engine:
|
||||
@staticmethod
|
||||
def get_basic_auth_dict(connection: TrinoConnectionConfig) -> dict:
|
||||
"""
|
||||
Create connection
|
||||
Get the basic auth dictionary for the trino connection
|
||||
"""
|
||||
# here we are creating a copy of connection, because we need to dynamically
|
||||
# add auth params to connectionArguments, which we do no intend to store
|
||||
# in original connection object and in OpenMetadata database
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
auth_type = cast(basicAuth.BasicAuth, connection.authType)
|
||||
return {
|
||||
"authType": "basic",
|
||||
"username": connection.username,
|
||||
"password": auth_type.password.get_secret_value()
|
||||
if auth_type.password
|
||||
else None,
|
||||
}
|
||||
|
||||
TrinoDialect.is_disconnect = _is_disconnect
|
||||
@staticmethod
|
||||
def set_basic_auth(
|
||||
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
|
||||
) -> None:
|
||||
"""
|
||||
Get the basic auth dictionary for the trino connection
|
||||
"""
|
||||
assert connection_args.root is not None
|
||||
auth_type = cast(basicAuth.BasicAuth, connection.authType)
|
||||
|
||||
connection_copy = deepcopy(connection)
|
||||
if connection_copy.verify:
|
||||
connection_copy.connectionArguments = (
|
||||
connection_copy.connectionArguments or init_empty_connection_arguments()
|
||||
connection_args.root["auth"] = BasicAuthentication(
|
||||
connection.username,
|
||||
auth_type.password.get_secret_value() if auth_type.password else None,
|
||||
)
|
||||
connection.connectionArguments.root["verify"] = {"verify": connection.verify}
|
||||
if hasattr(connection.authType, "azureConfig"):
|
||||
azure_client = AzureClient(connection.authType.azureConfig).create_client()
|
||||
if not connection.authType.azureConfig.scopes:
|
||||
connection_args.root["http_scheme"] = "https"
|
||||
|
||||
@staticmethod
|
||||
def get_jwt_auth_dict(connection: TrinoConnectionConfig) -> dict:
|
||||
"""
|
||||
Get the jwt auth dictionary for the trino connection
|
||||
"""
|
||||
auth_type = cast(jwtAuth.JwtAuth, connection.authType)
|
||||
|
||||
return {
|
||||
"authType": "jwt",
|
||||
"jwt": auth_type.jwt.get_secret_value(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def set_jwt_auth(
|
||||
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
|
||||
) -> None:
|
||||
"""
|
||||
Set the jwt auth for the trino connection
|
||||
"""
|
||||
assert connection_args.root is not None
|
||||
auth_type = cast(jwtAuth.JwtAuth, connection.authType)
|
||||
|
||||
connection_args.root["auth"] = JWTAuthentication(
|
||||
auth_type.jwt.get_secret_value()
|
||||
)
|
||||
connection_args.root["http_scheme"] = "https"
|
||||
|
||||
@staticmethod
|
||||
def get_azure_auth_dict(connection: TrinoConnectionConfig) -> dict:
|
||||
"""
|
||||
Get the azure auth dictionary for the trino connection
|
||||
"""
|
||||
return {
|
||||
"authType": "jwt",
|
||||
"jwt": TrinoConnection.get_azure_token(connection),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def set_azure_auth(
|
||||
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
|
||||
) -> None:
|
||||
"""
|
||||
Set the azure auth for the trino connection
|
||||
"""
|
||||
assert connection_args.root is not None
|
||||
|
||||
connection_args.root["auth"] = JWTAuthentication(
|
||||
TrinoConnection.get_azure_token(connection)
|
||||
)
|
||||
connection_args.root["http_scheme"] = "https"
|
||||
|
||||
@staticmethod
|
||||
def get_oauth2_auth_dict(connection: TrinoConnectionConfig) -> dict:
|
||||
"""
|
||||
Get the oauth2 auth dictionary for the trino connection
|
||||
"""
|
||||
return {
|
||||
"authType": "oauth2",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def set_oauth2_auth(
|
||||
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
|
||||
) -> None:
|
||||
"""
|
||||
Set the oauth2 auth for the trino connection
|
||||
"""
|
||||
assert connection_args.root is not None
|
||||
|
||||
connection_args.root["auth"] = OAuth2Authentication()
|
||||
connection_args.root["http_scheme"] = "https"
|
||||
|
||||
@staticmethod
|
||||
def get_azure_token(connection: TrinoConnectionConfig) -> str:
|
||||
"""
|
||||
Get the azure token for the trino connection
|
||||
"""
|
||||
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
|
||||
|
||||
if not auth_type.azureConfig.scopes:
|
||||
raise ValueError(
|
||||
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
|
||||
)
|
||||
access_token_obj = azure_client.get_token(
|
||||
*connection.authType.azureConfig.scopes.split(",")
|
||||
)
|
||||
if not connection.connectionOptions:
|
||||
connection.connectionOptions = init_empty_connection_options()
|
||||
connection.connectionOptions.root["access_token"] = access_token_obj.token
|
||||
return create_generic_db_connection(
|
||||
connection=connection_copy,
|
||||
get_connection_url_fn=get_connection_url,
|
||||
get_connection_args_fn=get_connection_args,
|
||||
)
|
||||
|
||||
azure_client = AzureClient(auth_type.azureConfig).create_client()
|
||||
|
||||
def test_connection(
|
||||
metadata: OpenMetadata,
|
||||
engine: Engine,
|
||||
service_connection: TrinoConnection,
|
||||
automation_workflow: Optional[AutomationWorkflow] = None,
|
||||
timeout_seconds: Optional[int] = THREE_MIN,
|
||||
) -> TestConnectionResult:
|
||||
"""
|
||||
Test connection. This can be executed either as part
|
||||
of a metadata workflow or during an Automation Workflow
|
||||
"""
|
||||
queries = {
|
||||
"GetDatabases": TRINO_GET_DATABASE,
|
||||
}
|
||||
|
||||
return test_connection_db_schema_sources(
|
||||
metadata=metadata,
|
||||
engine=engine,
|
||||
service_connection=service_connection,
|
||||
automation_workflow=automation_workflow,
|
||||
queries=queries,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def _is_disconnect(self, e, connection, cursor):
|
||||
"""is_disconnect method for the Databricks dialect"""
|
||||
if "JWT expired" in str(e):
|
||||
return True
|
||||
return False
|
||||
return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token
|
||||
|
@ -1,3 +1,4 @@
|
||||
from metadata.ingestion.source.database.trino.connection import TrinoConnection
|
||||
from metadata.ingestion.source.database.trino.lineage import TrinoLineageSource
|
||||
from metadata.ingestion.source.database.trino.metadata import TrinoSource
|
||||
from metadata.ingestion.source.database.trino.usage import TrinoUsageSource
|
||||
@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
|
||||
usage_source_class=TrinoUsageSource,
|
||||
profiler_class=TrinoProfilerInterface,
|
||||
sampler_class=TrinoSampler,
|
||||
connection_class=TrinoConnection,
|
||||
)
|
||||
|
@ -6,6 +6,9 @@ from sqlalchemy import Column as SAColumn
|
||||
from sqlalchemy import MetaData, String, create_engine
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
|
||||
BaseTableParameter,
|
||||
)
|
||||
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
|
||||
TableDiffParamsSetter,
|
||||
)
|
||||
@ -111,9 +114,9 @@ SERVICE_CONNECTION_CONFIG = MysqlConnection(
|
||||
],
|
||||
)
|
||||
def test_get_data_diff_url(input, expected):
|
||||
assert expected == TableDiffParamsSetter(
|
||||
None, None, MOCK_TABLE, None
|
||||
).get_data_diff_url(input, "service.database.schema.table")
|
||||
assert expected == BaseTableParameter.get_data_diff_url(
|
||||
input, "service.database.schema.table"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -44,7 +44,7 @@ def test_connection(mock_service_connection):
|
||||
class TestConnection(BaseConnection):
|
||||
"""Concrete implementation of BaseConnection for testing"""
|
||||
|
||||
def get_client(self):
|
||||
def _get_client(self):
|
||||
return MagicMock()
|
||||
|
||||
def test_connection(
|
||||
@ -64,6 +64,9 @@ def test_connection(mock_service_connection):
|
||||
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
|
||||
)
|
||||
|
||||
def get_connection_dict(self):
|
||||
return {}
|
||||
|
||||
return TestConnection(mock_service_connection)
|
||||
|
||||
|
||||
@ -100,7 +103,7 @@ class TestBaseConnection:
|
||||
mock_client = MagicMock()
|
||||
|
||||
class TestConnectionWithMockClient(BaseConnection):
|
||||
def get_client(self):
|
||||
def _get_client(self):
|
||||
return mock_client
|
||||
|
||||
def test_connection(
|
||||
@ -120,6 +123,9 @@ class TestBaseConnection:
|
||||
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
|
||||
)
|
||||
|
||||
def get_connection_dict(self):
|
||||
return {}
|
||||
|
||||
connection = TestConnectionWithMockClient(test_connection.service_connection)
|
||||
client = connection.get_client()
|
||||
client = connection.client
|
||||
assert client == mock_client
|
||||
|
@ -70,7 +70,7 @@ class TestGetConnectionURL(unittest.TestCase):
|
||||
hostPort="localhost:3306",
|
||||
databaseSchema="openmetadata_db",
|
||||
)
|
||||
engine_connection = MySQLConnection(connection).get_client()
|
||||
engine_connection = MySQLConnection(connection).client
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
|
||||
@ -93,7 +93,7 @@ class TestGetConnectionURL(unittest.TestCase):
|
||||
"get_token",
|
||||
return_value=AccessToken(token="mocked_token", expires_on=100),
|
||||
):
|
||||
engine_connection = MySQLConnection(connection).get_client()
|
||||
engine_connection = MySQLConnection(connection).client
|
||||
self.assertEqual(
|
||||
str(engine_connection.url),
|
||||
"mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",
|
||||
|
@ -100,7 +100,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon
|
||||
SnowflakeScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
TrinoConnection,
|
||||
TrinoConnection as TrinoConnectionConfig,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
TrinoScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.verticaConnection import (
|
||||
@ -112,7 +114,7 @@ from metadata.ingestion.connections.builders import (
|
||||
get_connection_args_common,
|
||||
get_connection_url_common,
|
||||
)
|
||||
from metadata.ingestion.source.database.trino.connection import get_connection_args
|
||||
from metadata.ingestion.source.database.trino.connection import TrinoConnection
|
||||
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
@ -405,44 +407,38 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_result == get_connection_url(impala_conn_obj)
|
||||
|
||||
def test_trino_url_without_params(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = "trino://username@localhost:443/catalog"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=BasicAuth(password="pass"),
|
||||
catalog="catalog",
|
||||
)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
|
||||
# Passing @ in username and password
|
||||
expected_url = "trino://username%40444@localhost:443/catalog"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username@444",
|
||||
authType=BasicAuth(password="pass@111"),
|
||||
catalog="catalog",
|
||||
)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
|
||||
def test_trino_conn_arguments(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_args,
|
||||
)
|
||||
|
||||
# connection arguments without connectionArguments and without proxies
|
||||
expected_args = {
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
hostPort="localhost:443",
|
||||
@ -450,7 +446,10 @@ class SourceConnectionTest(TestCase):
|
||||
connectionArguments=None,
|
||||
scheme=TrinoScheme.trino,
|
||||
)
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert (
|
||||
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
|
||||
)
|
||||
|
||||
# connection arguments with connectionArguments and without proxies
|
||||
expected_args = {
|
||||
@ -458,7 +457,7 @@ class SourceConnectionTest(TestCase):
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
hostPort="localhost:443",
|
||||
@ -466,14 +465,17 @@ class SourceConnectionTest(TestCase):
|
||||
connectionArguments={"user": "user-to-be-impersonated"},
|
||||
scheme=TrinoScheme.trino,
|
||||
)
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert (
|
||||
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
|
||||
)
|
||||
|
||||
# connection arguments without connectionArguments and with proxies
|
||||
expected_args = {
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
hostPort="localhost:443",
|
||||
@ -482,7 +484,9 @@ class SourceConnectionTest(TestCase):
|
||||
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
|
||||
scheme=TrinoScheme.trino,
|
||||
)
|
||||
conn_args = get_connection_args(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
|
||||
|
||||
assert "http_session" in conn_args
|
||||
conn_args.pop("http_session")
|
||||
assert expected_args == conn_args
|
||||
@ -493,7 +497,7 @@ class SourceConnectionTest(TestCase):
|
||||
"auth": BasicAuthentication("user", None),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
username="user",
|
||||
authType=BasicAuth(password=None),
|
||||
hostPort="localhost:443",
|
||||
@ -502,18 +506,15 @@ class SourceConnectionTest(TestCase):
|
||||
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
|
||||
scheme=TrinoScheme.trino,
|
||||
)
|
||||
conn_args = get_connection_args(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
|
||||
assert "http_session" in conn_args
|
||||
conn_args.pop("http_session")
|
||||
assert expected_args == conn_args
|
||||
|
||||
def test_trino_url_with_params(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = "trino://username@localhost:443/catalog?param=value"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
@ -521,35 +522,31 @@ class SourceConnectionTest(TestCase):
|
||||
catalog="catalog",
|
||||
connectionOptions={"param": "value"},
|
||||
)
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
|
||||
def test_trino_url_with_jwt_auth(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
expected_url = "trino://username@localhost:443/catalog"
|
||||
expected_args = {
|
||||
"auth": JWTAuthentication("jwt_token_value"),
|
||||
"http_scheme": "https",
|
||||
}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=JwtAuth(jwt="jwt_token_value"),
|
||||
catalog="catalog",
|
||||
)
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
assert expected_args == get_connection_args(trino_conn_obj)
|
||||
|
||||
def test_trino_with_proxies(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_args,
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
assert (
|
||||
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
|
||||
)
|
||||
|
||||
def test_trino_with_proxies(self):
|
||||
test_proxies = {"http": "http_proxy", "https": "https_proxy"}
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
@ -557,59 +554,54 @@ class SourceConnectionTest(TestCase):
|
||||
catalog="catalog",
|
||||
proxies=test_proxies,
|
||||
)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert (
|
||||
test_proxies
|
||||
== get_connection_args(trino_conn_obj).get("http_session").proxies
|
||||
== trino_connection.build_connection_args(trino_conn_obj)
|
||||
.root.get("http_session")
|
||||
.proxies
|
||||
)
|
||||
|
||||
def test_trino_without_catalog(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=BasicAuth(password="pass"),
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
|
||||
def test_trino_without_catalog(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=BasicAuth(password="pass"),
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert expected_url == str(trino_connection.client.url)
|
||||
|
||||
def test_trino_with_oauth2(self):
|
||||
from metadata.ingestion.source.database.trino.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
trino_conn_obj = TrinoConnectionConfig(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2,
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication
|
||||
trino_connection = TrinoConnection(trino_conn_obj)
|
||||
assert (
|
||||
trino_connection.build_connection_args(trino_conn_obj).root.get("auth")
|
||||
== OAuth2Authentication()
|
||||
)
|
||||
|
||||
def test_vertica_url(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user