MINOR: Update Trino Connection to fix data diff (#21983)

This commit is contained in:
IceS2 2025-06-27 07:58:48 +02:00 committed by GitHub
parent 7ff36b2478
commit c899d45e8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 493 additions and 287 deletions

View File

@ -157,7 +157,7 @@ base_requirements = {
"packaging", # For version parsing "packaging", # For version parsing
"setuptools~=70.0", "setuptools~=70.0",
"shapely", "shapely",
"collate-data-diff", "collate-data-diff>=0.11.6",
"jaraco.functools<4.2.0", # above 4.2 breaks the build "jaraco.functools<4.2.0", # above 4.2 breaks the build
# TODO: Remove one once we have updated datadiff version # TODO: Remove one once we have updated datadiff version
"snowflake-connector-python>=3.13.1,<4.0.0", "snowflake-connector-python>=3.13.1,<4.0.0",

View File

@ -1,6 +1,6 @@
"""Models for the TableDiff test case""" """Models for the TableDiff test case"""
from typing import List, Optional from typing import List, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -12,7 +12,7 @@ from metadata.ingestion.models.custom_pydantic import CustomSecretStr
class TableParameter(BaseModel): class TableParameter(BaseModel):
serviceUrl: str serviceUrl: Union[str, dict]
path: str path: str
columns: List[Column] columns: List[Column]
database_service_type: DatabaseServiceType database_service_type: DatabaseServiceType

View File

@ -1,15 +1,50 @@
"""Base class for param setter logic for table data diff""" """Base class for param setter logic for table data diff"""
from typing import List, Optional, Set from typing import List, Optional, Set, Type, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from metadata.data_quality.validations.models import Column, TableParameter from metadata.data_quality.validations.models import Column, TableParameter
from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.source.connections import get_connection from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module
# TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods.
# Using the specs class method causes circular import as TestSuiteInterface
# imports RuntimeParameterSetter
class ServiceSpecPatch:
    """Lazily resolve a connector's ``ServiceSpec`` and the classes it points to.

    Local replacement for the BaseSpec helper methods, which cannot be used
    from this module without triggering a circular import.
    """

    def __init__(self, service_type: ServiceType, source_type: str):
        # Populated on first access of ``service_spec``.
        self._service_spec = None
        self.source_type = source_type
        self.service_type = service_type

    @property
    def service_spec(self):
        """Return the connector's ``ServiceSpec``, importing it on first use."""
        cached = self._service_spec
        if cached is None:
            cached = self._service_spec = self.get_for_source()
        return cached

    def get_for_source(self):
        """Import ``ServiceSpec`` from the connector's ``service_spec`` module.

        The dotted path is built from the lower-cased service type name and
        the module directory derived from ``source_type``.
        """
        path_parts = (
            "ingestion",
            self.service_type.name.lower(),
            get_module_dir(self.source_type),
            "service_spec",
        )
        spec_path = "metadata.{}.source.{}.{}.{}.ServiceSpec".format(  # pylint: disable=C0209
            *path_parts
        )
        return import_from_module(spec_path)

    def get_data_diff_class(self) -> Type["BaseTableParameter"]:
        """Import the class referenced by the spec's ``data_diff`` attribute."""
        data_diff_path = self.service_spec.data_diff
        return import_from_module(data_diff_path)

    def get_connection_class(self) -> Type[BaseConnection]:
        """Import the class referenced by the spec's ``connection_class`` attribute."""
        connection_class_path = self.service_spec.connection_class
        return import_from_module(connection_class_path)
class BaseTableParameter: class BaseTableParameter:
@ -22,7 +57,7 @@ class BaseTableParameter:
key_columns, key_columns,
extra_columns, extra_columns,
case_sensitive_columns, case_sensitive_columns,
service_url: Optional[str], service_url: Optional[Union[str, dict]],
) -> TableParameter: ) -> TableParameter:
"""Getter table parameter for the table diff test. """Getter table parameter for the table diff test.
@ -62,10 +97,35 @@ class BaseTableParameter:
"___SERVICE___", "__DATABASE__", schema, table "___SERVICE___", "__DATABASE__", schema, table
).replace("___SERVICE___.__DATABASE__.", "") ).replace("___SERVICE___.__DATABASE__.", "")
@staticmethod
def _get_service_connection_config(
    db_service: DatabaseService,
) -> Optional[Union[str, dict]]:
    """Resolve the connection parameters data diff should use for a service.

    The connector's connection class is asked for a connection dictionary
    first. If resolving the class, instantiating it or building the dictionary
    raises ``ValueError``, ``AttributeError`` or ``NotImplementedError``, the
    connection URL is returned as a string instead.

    Args:
        db_service: The database service entity holding the connection config.

    Returns:
        A connection dictionary, a connection URL string, or ``None`` when the
        service has no connection config.
    """
    service_connection_config = db_service.connection.config
    # Bail out before touching ``.type``: without this guard a missing config
    # raised AttributeError here and never reached the ``None`` fallback.
    if not service_connection_config:
        return None

    service_spec_patch = ServiceSpecPatch(
        ServiceType.Database, service_connection_config.type.value.lower()
    )

    try:
        connection_class = service_spec_patch.get_connection_class()
        connection = connection_class(service_connection_config)
        return connection.get_connection_dict()
    except (ValueError, AttributeError, NotImplementedError):
        # Connector has no usable connection class / dictionary support:
        # fall back to the plain connection URL.
        return str(get_connection(service_connection_config).url)
@staticmethod @staticmethod
def get_data_diff_url( def get_data_diff_url(
db_service: DatabaseService, table_fqn, override_url: Optional[str] = None db_service: DatabaseService,
) -> str: table_fqn,
override_url: Optional[Union[str, dict]] = None,
) -> Union[str, dict]:
"""Get the url for the data diff service. """Get the url for the data diff service.
Args: Args:
@ -77,10 +137,14 @@ class BaseTableParameter:
str: The url for the data diff service str: The url for the data diff service
""" """
source_url = ( source_url = (
str(get_connection(db_service.connection.config).url) BaseTableParameter._get_service_connection_config(db_service)
if not override_url if not override_url
else override_url else override_url
) )
if isinstance(source_url, dict):
source_url["driver"] = source_url["driver"].split("+")[0]
return source_url
url = urlparse(source_url) url = urlparse(source_url)
# remove the driver name from the url because table-diff doesn't support it # remove the driver name from the url because table-diff doesn't support it
kwargs = {"scheme": url.scheme.split("+")[0]} kwargs = {"scheme": url.scheme.split("+")[0]}

View File

@ -11,10 +11,13 @@
"""Module that defines the TableDiffParamsSetter class.""" """Module that defines the TableDiffParamsSetter class."""
from ast import literal_eval from ast import literal_eval
from typing import List, Optional, Set from typing import List, Optional, Set
from urllib.parse import urlparse
from metadata.data_quality.validations import utils from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
ServiceSpecPatch,
)
from metadata.data_quality.validations.runtime_param_setter.param_setter import ( from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter, RuntimeParameterSetter,
) )
@ -22,24 +25,8 @@ from metadata.generated.schema.entity.data.table import Constraint, Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.serviceType import ServiceType from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.generated.schema.tests.testCase import TestCase from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module
def get_for_source(
    service_type: ServiceType, source_type: str, from_: str = "ingestion"
):
    """Import the ``ServiceSpec`` object of a connector.

    The dotted path is assembled from ``from_``, the lower-cased service type
    name and the module directory derived from ``source_type``.
    """
    spec_path = "metadata.{}.source.{}.{}.{}.ServiceSpec".format(  # pylint: disable=C0209
        from_,
        service_type.name.lower(),
        get_module_dir(source_type),
        "service_spec",
    )
    return import_from_module(spec_path)
class TableDiffParamsSetter(RuntimeParameterSetter): class TableDiffParamsSetter(RuntimeParameterSetter):
@ -62,23 +49,16 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
} }
def get_parameters(self, test_case) -> TableDiffRuntimeParameters: def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
# Using the specs class method causes circular import as TestSuiteInterface service_spec_patch = ServiceSpecPatch(
# imports RuntimeParameterSetter ServiceType.Database, self.service_connection_config.type.value.lower()
cls_path = get_for_source( )
ServiceType.Database, cls = service_spec_patch.get_data_diff_class()()
source_type=self.service_connection_config.type.value.lower(),
).data_diff
cls = import_from_module(cls_path)()
service1: DatabaseService = self.ometa_client.get_by_id( service1: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, self.table_entity.service.id, nullable=False DatabaseService, self.table_entity.service.id, nullable=False
) )
service1_url = ( service1_url = BaseTableParameter._get_service_connection_config(service1)
str(get_connection(self.service_connection_config).url)
if self.service_connection_config
else None
)
table2_fqn = self.get_parameter(test_case, "table2") table2_fqn = self.get_parameter(test_case, "table2")
if table2_fqn is None: if table2_fqn is None:
@ -209,41 +189,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
(p.value for p in test_case.parameterValues if p.name == key), default (p.value for p in test_case.parameterValues if p.name == key), default
) )
@staticmethod
def get_data_diff_url(
    db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
) -> str:
    """Build the URL handed to data diff for a given table.

    Args:
        db_service (DatabaseService): The database service entity
        table_fqn (str): The fully qualified name of the table
        override_url (Optional[str], optional): URL used instead of the one
            derived from the service connection. Defaults to None.

    Returns:
        str: The url for the data diff service
    """
    if override_url:
        source_url = override_url
    else:
        source_url = str(get_connection(db_service.connection.config).url)

    parsed = urlparse(source_url)
    # table-diff does not support "dialect+driver" schemes: keep the dialect only
    replacements = {"scheme": parsed.scheme.split("+")[0]}
    _service, database, schema, _table = fqn.split(table_fqn)

    # path needs to include the database AND schema in some of the connectors
    if hasattr(db_service.connection.config, "supportsDatabase"):
        replacements["path"] = f"/{database}"

    # this can be found by going to:
    # https://github.com/open-metadata/collate-data-diff/blob/main/data_diff/databases/<connector>.py
    # and looking at the `CONNECT_URI_HELPER` variable
    needs_schema_in_path = {Dialects.MSSQL, Dialects.Snowflake, Dialects.Trino}
    if replacements["scheme"] in needs_schema_in_path:
        replacements["path"] = f"/{database}/{schema}"

    return parsed._replace(**replacements).geturl()
@staticmethod @staticmethod
def get_data_diff_table_path(table_fqn: str) -> str: def get_data_diff_table_path(table_fqn: str) -> str:
service, database, schema, table = fqn.split( # pylint: disable=unused-variable service, database, schema, table = fqn.split( # pylint: disable=unused-variable

View File

@ -508,7 +508,10 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
("table1.serviceUrl", self.runtime_params.table1.serviceUrl), ("table1.serviceUrl", self.runtime_params.table1.serviceUrl),
("table2.serviceUrl", self.runtime_params.table2.serviceUrl), ("table2.serviceUrl", self.runtime_params.table2.serviceUrl),
]: ]:
dialect = urlparse(param).scheme if isinstance(param, dict):
dialect = param.get("driver")
else:
dialect = urlparse(param).scheme
if dialect not in SUPPORTED_DIALECTS: if dialect not in SUPPORTED_DIALECTS:
raise UnsupportedDialectError(name, dialect) raise UnsupportedDialectError(name, dialect)

View File

@ -41,12 +41,23 @@ class BaseConnection(ABC, Generic[S, C]):
""" """
service_connection: S service_connection: S
_client: Optional[C]
def __init__(self, service_connection: S) -> None: def __init__(self, service_connection: S) -> None:
self.service_connection = service_connection self.service_connection = service_connection
self._client = None
@property
def client(self) -> C:
"""
Return the main client/engine/connection object for this service.
"""
if self._client is None:
self._client = self._get_client()
return self._client
@abstractmethod @abstractmethod
def get_client(self) -> C: def _get_client(self) -> C:
""" """
Return the main client/engine/connection object for this service. Return the main client/engine/connection object for this service.
""" """
@ -61,3 +72,9 @@ class BaseConnection(ABC, Generic[S, C]):
""" """
Test the connection to the service. Test the connection to the service.
""" """
@abstractmethod
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""

View File

@ -73,7 +73,7 @@ def _get_connection_fn_from_service_spec(connection: BaseModel) -> Optional[Call
if connection_class: if connection_class:
def _get_client(conn): def _get_client(conn):
return connection_class(conn).get_client() return connection_class(conn).client
return _get_client return _get_client
return None return None

View File

@ -50,7 +50,7 @@ from metadata.utils.constants import THREE_MIN
class MySQLConnection(BaseConnection[MysqlConnection, Engine]): class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
def get_client(self) -> Engine: def _get_client(self) -> Engine:
""" """
Return the SQLAlchemy Engine for MySQL. Return the SQLAlchemy Engine for MySQL.
""" """
@ -77,6 +77,12 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
get_connection_args_fn=get_connection_args_common, get_connection_args_fn=get_connection_args_common,
) )
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""
raise NotImplementedError("get_connection_dict is not implemented for MySQL")
def test_connection( def test_connection(
self, self,
metadata: OpenMetadata, metadata: OpenMetadata,
@ -94,7 +100,7 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
} }
return test_connection_db_schema_sources( return test_connection_db_schema_sources(
metadata=metadata, metadata=metadata,
engine=self.get_client(), engine=self.client,
service_connection=self.service_connection, service_connection=self.service_connection,
automation_workflow=automation_workflow, automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,

View File

@ -13,7 +13,7 @@
Source connection handler Source connection handler
""" """
from copy import deepcopy from copy import deepcopy
from typing import Optional from typing import Optional, cast
from urllib.parse import quote_plus from urllib.parse import quote_plus
from requests import Session from requests import Session
@ -24,13 +24,17 @@ from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
) )
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
ConnectionArguments,
)
from metadata.generated.schema.entity.services.connections.database.common import ( from metadata.generated.schema.entity.services.connections.database.common import (
azureConfig,
basicAuth, basicAuth,
jwtAuth, jwtAuth,
noConfigAuthenticationTypes, noConfigAuthenticationTypes,
) )
from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection, TrinoConnection as TrinoConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.testConnectionResult import ( from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult, TestConnectionResult,
@ -41,6 +45,7 @@ from metadata.ingestion.connections.builders import (
init_empty_connection_arguments, init_empty_connection_arguments,
init_empty_connection_options, init_empty_connection_options,
) )
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.secrets import connection_with_options_secrets from metadata.ingestion.connections.secrets import connection_with_options_secrets
from metadata.ingestion.connections.test_connections import ( from metadata.ingestion.connections.test_connections import (
test_connection_db_schema_sources, test_connection_db_schema_sources,
@ -50,149 +55,312 @@ from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
from metadata.utils.constants import THREE_MIN from metadata.utils.constants import THREE_MIN
def get_connection_url(connection: TrinoConnection) -> str:
    """
    Prepare the connection url for trino

    The url is ``<scheme>://[<username>@]<hostPort>[/<catalog>][?<options>]``.
    Options with a falsy value are skipped, and the ``?`` separator is only
    added when at least one option remains.
    """
    url = f"{connection.scheme.value}://"

    # leaving username here as, even though with basic auth is used directly
    # in BasicAuthentication class, it's often also required as a part of url.
    # For example - it will be used by OAuth2Authentication to persist token in
    # cache more efficiently (per user instead of per host)
    if connection.username:
        url += f"{quote_plus(connection.username)}@"

    url += f"{connection.hostPort}"
    if connection.catalog:
        url += f"/{connection.catalog}"

    if connection.connectionOptions is not None:
        options = connection.connectionOptions.root or {}
        params = "&".join(
            f"{key}={quote_plus(value)}" for key, value in options.items() if value
        )
        # Avoid a dangling "?" when no option carries a value.
        if params:
            url = f"{url}?{params}"
    return url
@connection_with_options_secrets
def get_connection_args(connection: TrinoConnection):
    """Add proxy and authentication settings to the connection arguments.

    ``connection.connectionArguments`` is initialised when missing and then
    updated in place; every supported auth type also forces ``http_scheme`` to
    ``https``. The result of ``get_connection_args_common`` is returned.

    Raises:
        ValueError: Azure auth is configured without scopes.
    """
    if not connection.connectionArguments:
        connection.connectionArguments = init_empty_connection_arguments()

    if connection.proxies:
        proxied_session = Session()
        proxied_session.proxies = connection.proxies
        connection.connectionArguments.root["http_session"] = proxied_session

    auth_type = connection.authType
    auth = None

    if isinstance(auth_type, basicAuth.BasicAuth):
        password = (
            auth_type.password.get_secret_value() if auth_type.password else None
        )
        auth = BasicAuthentication(connection.username, password)
    elif isinstance(auth_type, jwtAuth.JwtAuth):
        auth = JWTAuthentication(auth_type.jwt.get_secret_value())
    elif hasattr(auth_type, "azureConfig"):
        if not auth_type.azureConfig.scopes:
            raise ValueError(
                "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
            )
        azure_client = AzureClient(auth_type.azureConfig).create_client()
        scopes = auth_type.azureConfig.scopes.split(",")
        auth = JWTAuthentication(azure_client.get_token(*scopes).token)
    elif auth_type == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2:
        auth = OAuth2Authentication()

    if auth is not None:
        connection.connectionArguments.root["auth"] = auth
        connection.connectionArguments.root["http_scheme"] = "https"

    return get_connection_args_common(connection)
def get_connection(connection: TrinoConnection) -> Engine:
    """
    Create connection

    Builds the engine from a deep copy of ``connection`` so that the
    dynamically added connection arguments are not written back to the
    caller's object.

    Raises:
        ValueError: Azure auth is configured without scopes.
    """
    # here we are creating a copy of connection, because we need to dynamically
    # add auth params to connectionArguments, which we do not intend to store
    # in original connection object and in OpenMetadata database
    from trino.sqlalchemy.dialect import TrinoDialect

    TrinoDialect.is_disconnect = _is_disconnect

    connection_copy = deepcopy(connection)
    if connection_copy.verify:
        connection_copy.connectionArguments = (
            connection_copy.connectionArguments or init_empty_connection_arguments()
        )
        # Write to the copy: the original may have no connectionArguments at
        # all, and the engine below is built from the copy only.
        connection_copy.connectionArguments.root["verify"] = {
            "verify": connection_copy.verify
        }

    if hasattr(connection.authType, "azureConfig"):
        azure_client = AzureClient(connection.authType.azureConfig).create_client()
        if not connection.authType.azureConfig.scopes:
            raise ValueError(
                "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
            )
        access_token_obj = azure_client.get_token(
            *connection.authType.azureConfig.scopes.split(",")
        )
        # NOTE(review): this stores the token on the ORIGINAL connection after
        # the copy was taken, so the engine below does not see it. Left as-is
        # because other code may read it from the original - confirm intent.
        if not connection.connectionOptions:
            connection.connectionOptions = init_empty_connection_options()
        connection.connectionOptions.root["access_token"] = access_token_obj.token

    return create_generic_db_connection(
        connection=connection_copy,
        get_connection_url_fn=get_connection_url,
        get_connection_args_fn=get_connection_args,
    )
def test_connection(
    metadata: OpenMetadata,
    engine: Engine,
    service_connection: TrinoConnection,
    automation_workflow: Optional[AutomationWorkflow] = None,
    timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
    """
    Run the Trino connection test, either as part of a metadata workflow
    or during an Automation Workflow.

    Delegates to ``test_connection_db_schema_sources`` with the
    ``GetDatabases`` query set to ``TRINO_GET_DATABASE``.
    """
    return test_connection_db_schema_sources(
        metadata=metadata,
        engine=engine,
        service_connection=service_connection,
        automation_workflow=automation_workflow,
        queries={"GetDatabases": TRINO_GET_DATABASE},
        timeout_seconds=timeout_seconds,
    )
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _is_disconnect(self, e, connection, cursor): def _is_disconnect(self, e, connection, cursor):
"""is_disconnect method for the Trino dialect""" """is_disconnect method for the Trino dialect"""
if "JWT expired" in str(e): if "JWT expired" in str(e):
return True return True
return False return False
class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
    """Trino implementation of ``BaseConnection``.

    Builds the SQLAlchemy engine, runs the connection test and exposes the
    connection settings as a plain dictionary (``get_connection_dict``).
    """

    def __init__(self, connection: TrinoConnectionConfig):
        super().__init__(connection)

    def _get_client(self) -> Engine:
        """
        Create connection

        Raises:
            ValueError: Azure auth is configured without scopes.
        """
        # here we are creating a copy of connection, because we need to dynamically
        # add auth params to connectionArguments, which we do not intend to store
        # in original connection object and in OpenMetadata database
        from trino.sqlalchemy.dialect import TrinoDialect

        TrinoDialect.is_disconnect = _is_disconnect  # type: ignore

        connection = self.service_connection

        connection_copy = deepcopy(connection)

        if hasattr(connection.authType, "azureConfig"):
            azure_client = AzureClient(connection.authType.azureConfig).create_client()
            if not connection.authType.azureConfig.scopes:
                raise ValueError(
                    "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
                )
            access_token_obj = azure_client.get_token(
                *connection.authType.azureConfig.scopes.split(",")
            )
            # NOTE(review): this stores the token on the ORIGINAL connection
            # after the copy was taken, so the engine below does not see it
            # (Azure auth is applied by build_connection_args instead).
            # Left unchanged - confirm whether anything reads it.
            if not connection.connectionOptions:
                connection.connectionOptions = init_empty_connection_options()
            connection.connectionOptions.root["access_token"] = access_token_obj.token

        # Update the connection with the connection arguments
        connection_copy.connectionArguments = self.build_connection_args(
            connection_copy
        )

        return create_generic_db_connection(
            connection=connection_copy,
            get_connection_url_fn=self.get_connection_url,
            get_connection_args_fn=get_connection_args_common,
        )

    def test_connection(
        self,
        metadata: OpenMetadata,
        automation_workflow: Optional[AutomationWorkflow] = None,
        timeout_seconds: Optional[int] = THREE_MIN,
    ) -> TestConnectionResult:
        """
        Test connection. This can be executed either as part
        of a metadata workflow or during an Automation Workflow
        """
        queries = {
            "GetDatabases": TRINO_GET_DATABASE,
        }

        return test_connection_db_schema_sources(
            metadata=metadata,
            engine=self.client,
            service_connection=self.service_connection,
            automation_workflow=automation_workflow,
            queries=queries,
            timeout_seconds=timeout_seconds,
        )

    def get_connection_dict(self) -> dict:
        """
        Return the connection dictionary for this service.

        The dictionary is seeded from the engine url (driver, host, port, user,
        catalog, schema and every url query parameter), then extended with the
        proxies, the connection arguments and an ``auth`` entry describing the
        configured authentication type.
        """
        url = self.client.url
        # Work on a copy so the service connection itself is left untouched.
        connection_copy = deepcopy(self.service_connection)

        connection_dict = {
            "driver": url.drivername,
            "host": url.host,
            "port": url.port,
            "user": url.username,
            "catalog": url.database,
            "schema": url.query.get("schema"),
        }

        connection_dict.update(url.query)

        if connection_copy.proxies:
            connection_dict["http_session"] = connection_copy.proxies

        if (
            connection_copy.connectionArguments
            and connection_copy.connectionArguments.root
        ):
            # The decorator only acts when the wrapped callable is invoked with
            # the connection; wrapping a lambda and discarding the result (as
            # was done before) never ran it.
            resolve_connection_args = connection_with_options_secrets(
                get_connection_args_common
            )
            connection_dict.update(resolve_connection_args(connection_copy))

        if isinstance(connection_copy.authType, basicAuth.BasicAuth):
            connection_dict["auth"] = TrinoConnection.get_basic_auth_dict(
                connection_copy
            )
            connection_dict["http_scheme"] = "https"
        elif isinstance(connection_copy.authType, jwtAuth.JwtAuth):
            connection_dict["auth"] = TrinoConnection.get_jwt_auth_dict(connection_copy)
            connection_dict["http_scheme"] = "https"
        elif hasattr(connection_copy.authType, "azureConfig"):
            connection_dict["auth"] = TrinoConnection.get_azure_auth_dict(
                connection_copy
            )
            connection_dict["http_scheme"] = "https"
        elif (
            connection_copy.authType
            == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
        ):
            connection_dict["auth"] = TrinoConnection.get_oauth2_auth_dict(
                connection_copy
            )
            connection_dict["http_scheme"] = "https"

        return connection_dict

    @staticmethod
    def get_connection_url(connection: TrinoConnectionConfig) -> str:
        """
        Prepare the connection url for trino

        Options with a falsy value are skipped, and the ``?`` separator is only
        added when at least one option remains.
        """
        url = f"{connection.scheme.value}://"

        # leaving username here as, even though with basic auth is used directly
        # in BasicAuthentication class, it's often also required as a part of url.
        # For example - it will be used by OAuth2Authentication to persist token in
        # cache more efficiently (per user instead of per host)
        if connection.username:
            url += f"{quote_plus(connection.username)}@"

        url += f"{connection.hostPort}"
        if connection.catalog:
            url += f"/{connection.catalog}"
        if connection.connectionOptions is not None:
            options = connection.connectionOptions.root or {}
            params = "&".join(
                f"{key}={quote_plus(value)}" for key, value in options.items() if value
            )
            # Avoid a dangling "?" when no option carries a value.
            if params:
                url = f"{url}?{params}"
        return url

    @staticmethod
    @connection_with_options_secrets
    def build_connection_args(connection: TrinoConnectionConfig) -> ConnectionArguments:
        """
        Get the connection args for the trino connection

        Starts from the existing connection arguments (or an empty set) and
        adds the ``verify`` flag, a proxied http session and the authentication
        object matching ``connection.authType``.
        """
        connection_args: ConnectionArguments = (
            connection.connectionArguments or init_empty_connection_arguments()
        )
        assert connection_args.root is not None

        if connection.verify:
            connection_args.root["verify"] = {"verify": connection.verify}

        if connection.proxies:
            session = Session()
            session.proxies = connection.proxies
            connection_args.root["http_session"] = session

        if isinstance(connection.authType, basicAuth.BasicAuth):
            TrinoConnection.set_basic_auth(connection, connection_args)
        elif isinstance(connection.authType, jwtAuth.JwtAuth):
            TrinoConnection.set_jwt_auth(connection, connection_args)
        elif hasattr(connection.authType, "azureConfig"):
            TrinoConnection.set_azure_auth(connection, connection_args)
        elif (
            connection.authType
            == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
        ):
            TrinoConnection.set_oauth2_auth(connection, connection_args)

        return connection_args

    @staticmethod
    def get_basic_auth_dict(connection: TrinoConnectionConfig) -> dict:
        """
        Get the basic auth dictionary for the trino connection
        """
        auth_type = cast(basicAuth.BasicAuth, connection.authType)
        return {
            "authType": "basic",
            "username": connection.username,
            "password": auth_type.password.get_secret_value()
            if auth_type.password
            else None,
        }

    @staticmethod
    def set_basic_auth(
        connection: TrinoConnectionConfig, connection_args: ConnectionArguments
    ) -> None:
        """
        Set the basic auth for the trino connection
        """
        assert connection_args.root is not None

        auth_type = cast(basicAuth.BasicAuth, connection.authType)

        connection_args.root["auth"] = BasicAuthentication(
            connection.username,
            auth_type.password.get_secret_value() if auth_type.password else None,
        )
        connection_args.root["http_scheme"] = "https"

    @staticmethod
    def get_jwt_auth_dict(connection: TrinoConnectionConfig) -> dict:
        """
        Get the jwt auth dictionary for the trino connection
        """
        auth_type = cast(jwtAuth.JwtAuth, connection.authType)
        return {
            "authType": "jwt",
            "jwt": auth_type.jwt.get_secret_value(),
        }

    @staticmethod
    def set_jwt_auth(
        connection: TrinoConnectionConfig, connection_args: ConnectionArguments
    ) -> None:
        """
        Set the jwt auth for the trino connection
        """
        assert connection_args.root is not None

        auth_type = cast(jwtAuth.JwtAuth, connection.authType)

        connection_args.root["auth"] = JWTAuthentication(
            auth_type.jwt.get_secret_value()
        )
        connection_args.root["http_scheme"] = "https"

    @staticmethod
    def get_azure_auth_dict(connection: TrinoConnectionConfig) -> dict:
        """
        Get the azure auth dictionary for the trino connection

        The Azure access token is exposed as a ``jwt`` auth entry.
        """
        return {
            "authType": "jwt",
            "jwt": TrinoConnection.get_azure_token(connection),
        }

    @staticmethod
    def set_azure_auth(
        connection: TrinoConnectionConfig, connection_args: ConnectionArguments
    ) -> None:
        """
        Set the azure auth for the trino connection
        """
        assert connection_args.root is not None

        connection_args.root["auth"] = JWTAuthentication(
            TrinoConnection.get_azure_token(connection)
        )
        connection_args.root["http_scheme"] = "https"

    @staticmethod
    def get_oauth2_auth_dict(connection: TrinoConnectionConfig) -> dict:
        """
        Get the oauth2 auth dictionary for the trino connection
        """
        return {
            "authType": "oauth2",
        }

    @staticmethod
    def set_oauth2_auth(
        connection: TrinoConnectionConfig, connection_args: ConnectionArguments
    ) -> None:
        """
        Set the oauth2 auth for the trino connection
        """
        assert connection_args.root is not None

        connection_args.root["auth"] = OAuth2Authentication()
        connection_args.root["http_scheme"] = "https"

    @staticmethod
    def get_azure_token(connection: TrinoConnectionConfig) -> str:
        """
        Get the azure token for the trino connection

        Raises:
            ValueError: the Azure configuration has no scopes.
        """
        auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)

        if not auth_type.azureConfig.scopes:
            raise ValueError(
                "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
            )

        azure_client = AzureClient(auth_type.azureConfig).create_client()

        # scopes is a comma separated string; each entry is passed as its own argument
        return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token

View File

@ -1,3 +1,4 @@
from metadata.ingestion.source.database.trino.connection import TrinoConnection
from metadata.ingestion.source.database.trino.lineage import TrinoLineageSource from metadata.ingestion.source.database.trino.lineage import TrinoLineageSource
from metadata.ingestion.source.database.trino.metadata import TrinoSource from metadata.ingestion.source.database.trino.metadata import TrinoSource
from metadata.ingestion.source.database.trino.usage import TrinoUsageSource from metadata.ingestion.source.database.trino.usage import TrinoUsageSource
@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
usage_source_class=TrinoUsageSource, usage_source_class=TrinoUsageSource,
profiler_class=TrinoProfilerInterface, profiler_class=TrinoProfilerInterface,
sampler_class=TrinoSampler, sampler_class=TrinoSampler,
connection_class=TrinoConnection,
) )

View File

@ -6,6 +6,9 @@ from sqlalchemy import Column as SAColumn
from sqlalchemy import MetaData, String, create_engine from sqlalchemy import MetaData, String, create_engine
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import ( from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter, TableDiffParamsSetter,
) )
@ -111,9 +114,9 @@ SERVICE_CONNECTION_CONFIG = MysqlConnection(
], ],
) )
def test_get_data_diff_url(input, expected): def test_get_data_diff_url(input, expected):
assert expected == TableDiffParamsSetter( assert expected == BaseTableParameter.get_data_diff_url(
None, None, MOCK_TABLE, None input, "service.database.schema.table"
).get_data_diff_url(input, "service.database.schema.table") )
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -44,7 +44,7 @@ def test_connection(mock_service_connection):
class TestConnection(BaseConnection): class TestConnection(BaseConnection):
"""Concrete implementation of BaseConnection for testing""" """Concrete implementation of BaseConnection for testing"""
def get_client(self): def _get_client(self):
return MagicMock() return MagicMock()
def test_connection( def test_connection(
@ -64,6 +64,9 @@ def test_connection(mock_service_connection):
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)), lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
) )
def get_connection_dict(self):
return {}
return TestConnection(mock_service_connection) return TestConnection(mock_service_connection)
@ -100,7 +103,7 @@ class TestBaseConnection:
mock_client = MagicMock() mock_client = MagicMock()
class TestConnectionWithMockClient(BaseConnection): class TestConnectionWithMockClient(BaseConnection):
def get_client(self): def _get_client(self):
return mock_client return mock_client
def test_connection( def test_connection(
@ -120,6 +123,9 @@ class TestBaseConnection:
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)), lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
) )
def get_connection_dict(self):
return {}
connection = TestConnectionWithMockClient(test_connection.service_connection) connection = TestConnectionWithMockClient(test_connection.service_connection)
client = connection.get_client() client = connection.client
assert client == mock_client assert client == mock_client

View File

@ -70,7 +70,7 @@ class TestGetConnectionURL(unittest.TestCase):
hostPort="localhost:3306", hostPort="localhost:3306",
databaseSchema="openmetadata_db", databaseSchema="openmetadata_db",
) )
engine_connection = MySQLConnection(connection).get_client() engine_connection = MySQLConnection(connection).client
self.assertEqual( self.assertEqual(
str(engine_connection.url), str(engine_connection.url),
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", "mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
@ -93,7 +93,7 @@ class TestGetConnectionURL(unittest.TestCase):
"get_token", "get_token",
return_value=AccessToken(token="mocked_token", expires_on=100), return_value=AccessToken(token="mocked_token", expires_on=100),
): ):
engine_connection = MySQLConnection(connection).get_client() engine_connection = MySQLConnection(connection).client
self.assertEqual( self.assertEqual(
str(engine_connection.url), str(engine_connection.url),
"mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db", "mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",

View File

@ -100,7 +100,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon
SnowflakeScheme, SnowflakeScheme,
) )
from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection, TrinoConnection as TrinoConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoScheme, TrinoScheme,
) )
from metadata.generated.schema.entity.services.connections.database.verticaConnection import ( from metadata.generated.schema.entity.services.connections.database.verticaConnection import (
@ -112,7 +114,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common, get_connection_args_common,
get_connection_url_common, get_connection_url_common,
) )
from metadata.ingestion.source.database.trino.connection import get_connection_args from metadata.ingestion.source.database.trino.connection import TrinoConnection
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -405,44 +407,38 @@ class SourceConnectionTest(TestCase):
assert expected_result == get_connection_url(impala_conn_obj) assert expected_result == get_connection_url(impala_conn_obj)
def test_trino_url_without_params(self): def test_trino_url_without_params(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog" expected_url = "trino://username@localhost:443/catalog"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
authType=BasicAuth(password="pass"), authType=BasicAuth(password="pass"),
catalog="catalog", catalog="catalog",
) )
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == get_connection_url(trino_conn_obj) assert expected_url == str(trino_connection.client.url)
# Passing @ in username and password # Passing @ in username and password
expected_url = "trino://username%40444@localhost:443/catalog" expected_url = "trino://username%40444@localhost:443/catalog"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username@444", username="username@444",
authType=BasicAuth(password="pass@111"), authType=BasicAuth(password="pass@111"),
catalog="catalog", catalog="catalog",
) )
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == get_connection_url(trino_conn_obj) assert expected_url == str(trino_connection.client.url)
def test_trino_conn_arguments(self): def test_trino_conn_arguments(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_args,
)
# connection arguments without connectionArguments and without proxies # connection arguments without connectionArguments and without proxies
expected_args = { expected_args = {
"auth": BasicAuthentication("user", None), "auth": BasicAuthentication("user", None),
"http_scheme": "https", "http_scheme": "https",
} }
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
hostPort="localhost:443", hostPort="localhost:443",
@ -450,7 +446,10 @@ class SourceConnectionTest(TestCase):
connectionArguments=None, connectionArguments=None,
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
) )
assert expected_args == get_connection_args(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert (
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
)
# connection arguments with connectionArguments and without proxies # connection arguments with connectionArguments and without proxies
expected_args = { expected_args = {
@ -458,7 +457,7 @@ class SourceConnectionTest(TestCase):
"auth": BasicAuthentication("user", None), "auth": BasicAuthentication("user", None),
"http_scheme": "https", "http_scheme": "https",
} }
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
hostPort="localhost:443", hostPort="localhost:443",
@ -466,14 +465,17 @@ class SourceConnectionTest(TestCase):
connectionArguments={"user": "user-to-be-impersonated"}, connectionArguments={"user": "user-to-be-impersonated"},
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
) )
assert expected_args == get_connection_args(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert (
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
)
# connection arguments without connectionArguments and with proxies # connection arguments without connectionArguments and with proxies
expected_args = { expected_args = {
"auth": BasicAuthentication("user", None), "auth": BasicAuthentication("user", None),
"http_scheme": "https", "http_scheme": "https",
} }
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
hostPort="localhost:443", hostPort="localhost:443",
@ -482,7 +484,9 @@ class SourceConnectionTest(TestCase):
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"}, proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
) )
conn_args = get_connection_args(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
assert "http_session" in conn_args assert "http_session" in conn_args
conn_args.pop("http_session") conn_args.pop("http_session")
assert expected_args == conn_args assert expected_args == conn_args
@ -493,7 +497,7 @@ class SourceConnectionTest(TestCase):
"auth": BasicAuthentication("user", None), "auth": BasicAuthentication("user", None),
"http_scheme": "https", "http_scheme": "https",
} }
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
username="user", username="user",
authType=BasicAuth(password=None), authType=BasicAuth(password=None),
hostPort="localhost:443", hostPort="localhost:443",
@ -502,18 +506,15 @@ class SourceConnectionTest(TestCase):
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"}, proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
) )
conn_args = get_connection_args(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
assert "http_session" in conn_args assert "http_session" in conn_args
conn_args.pop("http_session") conn_args.pop("http_session")
assert expected_args == conn_args assert expected_args == conn_args
def test_trino_url_with_params(self): def test_trino_url_with_params(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog?param=value" expected_url = "trino://username@localhost:443/catalog?param=value"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
@ -521,35 +522,31 @@ class SourceConnectionTest(TestCase):
catalog="catalog", catalog="catalog",
connectionOptions={"param": "value"}, connectionOptions={"param": "value"},
) )
assert expected_url == get_connection_url(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_url_with_jwt_auth(self): def test_trino_url_with_jwt_auth(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog" expected_url = "trino://username@localhost:443/catalog"
expected_args = { expected_args = {
"auth": JWTAuthentication("jwt_token_value"), "auth": JWTAuthentication("jwt_token_value"),
"http_scheme": "https", "http_scheme": "https",
} }
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
authType=JwtAuth(jwt="jwt_token_value"), authType=JwtAuth(jwt="jwt_token_value"),
catalog="catalog", catalog="catalog",
) )
assert expected_url == get_connection_url(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert expected_args == get_connection_args(trino_conn_obj) assert expected_url == str(trino_connection.client.url)
assert (
def test_trino_with_proxies(self): expected_args == trino_connection.build_connection_args(trino_conn_obj).root
from metadata.ingestion.source.database.trino.connection import (
get_connection_args,
) )
def test_trino_with_proxies(self):
test_proxies = {"http": "http_proxy", "https": "https_proxy"} test_proxies = {"http": "http_proxy", "https": "https_proxy"}
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
@ -557,59 +554,54 @@ class SourceConnectionTest(TestCase):
catalog="catalog", catalog="catalog",
proxies=test_proxies, proxies=test_proxies,
) )
trino_connection = TrinoConnection(trino_conn_obj)
assert ( assert (
test_proxies test_proxies
== get_connection_args(trino_conn_obj).get("http_session").proxies == trino_connection.build_connection_args(trino_conn_obj)
.root.get("http_session")
.proxies
) )
def test_trino_without_catalog(self): def test_trino_without_catalog(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog # Test trino url without catalog
expected_url = "trino://username@localhost:443" expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
authType=BasicAuth(password="pass"), authType=BasicAuth(password="pass"),
) )
assert expected_url == get_connection_url(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_without_catalog(self): def test_trino_without_catalog(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog # Test trino url without catalog
expected_url = "trino://username@localhost:443" expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
authType=BasicAuth(password="pass"), authType=BasicAuth(password="pass"),
) )
assert expected_url == get_connection_url(trino_conn_obj) trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_with_oauth2(self): def test_trino_with_oauth2(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog # Test trino url without catalog
expected_url = "trino://username@localhost:443" expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection( trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino, scheme=TrinoScheme.trino,
hostPort="localhost:443", hostPort="localhost:443",
username="username", username="username",
authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2, authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2,
) )
assert isinstance( trino_connection = TrinoConnection(trino_conn_obj)
get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication assert (
trino_connection.build_connection_args(trino_conn_obj).root.get("auth")
== OAuth2Authentication()
) )
def test_vertica_url(self): def test_vertica_url(self):