MINOR: Update Trino Connection to fix data diff (#21983)

This commit is contained in:
IceS2 2025-06-27 07:58:48 +02:00 committed by GitHub
parent 7ff36b2478
commit c899d45e8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 493 additions and 287 deletions

View File

@ -157,7 +157,7 @@ base_requirements = {
"packaging", # For version parsing
"setuptools~=70.0",
"shapely",
"collate-data-diff",
"collate-data-diff>=0.11.6",
"jaraco.functools<4.2.0", # above 4.2 breaks the build
# TODO: Remove one once we have updated datadiff version
"snowflake-connector-python>=3.13.1,<4.0.0",

View File

@ -1,6 +1,6 @@
"""Models for the TableDiff test case"""
from typing import List, Optional
from typing import List, Optional, Union
from pydantic import BaseModel
@ -12,7 +12,7 @@ from metadata.ingestion.models.custom_pydantic import CustomSecretStr
class TableParameter(BaseModel):
serviceUrl: str
serviceUrl: Union[str, dict]
path: str
columns: List[Column]
database_service_type: DatabaseServiceType

View File

@ -1,15 +1,50 @@
"""Base class for param setter logic for table data diff"""
from typing import List, Optional, Set
from typing import List, Optional, Set, Type, Union
from urllib.parse import urlparse
from metadata.data_quality.validations.models import Column, TableParameter
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module
# TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods.
# Using the specs class method causes circular import as TestSuiteInterface
# imports RuntimeParameterSetter
class ServiceSpecPatch:
    """Resolve a source's ``ServiceSpec`` and the classes it points to.

    The spec is imported on first access of ``service_spec`` and cached on
    the instance. The data diff and connection classes are imported from the
    dotted paths stored on the spec (``data_diff`` and ``connection_class``).
    """

    def __init__(self, service_type: ServiceType, source_type: str):
        self.service_type = service_type
        self.source_type = source_type
        # Populated by the ``service_spec`` property the first time it is read.
        self._service_spec = None

    @property
    def service_spec(self):
        """Return the ``ServiceSpec``, importing it once and reusing it after."""
        spec = self._service_spec
        if spec is None:
            spec = self._service_spec = self.get_for_source()
        return spec

    def get_for_source(self):
        """Import the ``ServiceSpec`` object from the source's ``service_spec`` module."""
        service_package = self.service_type.name.lower()
        source_package = get_module_dir(self.source_type)
        spec_path = "metadata.{}.source.{}.{}.{}.ServiceSpec".format(  # pylint: disable=C0209
            "ingestion",
            service_package,
            source_package,
            "service_spec",
        )
        return import_from_module(spec_path)

    def get_data_diff_class(self) -> Type["BaseTableParameter"]:
        """Import the class referenced by the spec's ``data_diff`` path."""
        data_diff_path = self.service_spec.data_diff
        return import_from_module(data_diff_path)

    def get_connection_class(self) -> Type[BaseConnection]:
        """Import the class referenced by the spec's ``connection_class`` path."""
        connection_class_path = self.service_spec.connection_class
        return import_from_module(connection_class_path)
class BaseTableParameter:
@ -22,7 +57,7 @@ class BaseTableParameter:
key_columns,
extra_columns,
case_sensitive_columns,
service_url: Optional[str],
service_url: Optional[Union[str, dict]],
) -> TableParameter:
"""Getter table parameter for the table diff test.
@ -62,10 +97,35 @@ class BaseTableParameter:
"___SERVICE___", "__DATABASE__", schema, table
).replace("___SERVICE___.__DATABASE__.", "")
@staticmethod
def _get_service_connection_config(
    db_service: DatabaseService,
) -> Optional[Union[str, dict]]:
    """
    Get the connection dictionary for the service.

    The connection class declared in the source's service spec is asked for
    its ``get_connection_dict()``. When that path raises ``ValueError``,
    ``AttributeError`` or ``NotImplementedError``, the URL of the object
    returned by ``get_connection`` is used instead, as a string.

    Args:
        db_service (DatabaseService): The database service entity

    Returns:
        Optional[Union[str, dict]]: The connection dictionary, the connection
            url, or None when the service has no connection config
    """
    service_connection_config = db_service.connection.config
    # Check for a missing config before reading ``.type`` from it. The guard
    # used to live only in the fallback branch, which is reached after the
    # attribute access below, so it could never return None.
    if not service_connection_config:
        return None

    service_spec_patch = ServiceSpecPatch(
        ServiceType.Database, service_connection_config.type.value.lower()
    )

    try:
        connection_class = service_spec_patch.get_connection_class()
        connection = connection_class(service_connection_config)
        return connection.get_connection_dict()
    except (ValueError, AttributeError, NotImplementedError):
        # Connectors without a connection class / connection dict fall back
        # to the plain connection url.
        return str(get_connection(service_connection_config).url)
@staticmethod
def get_data_diff_url(
db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
) -> str:
db_service: DatabaseService,
table_fqn,
override_url: Optional[Union[str, dict]] = None,
) -> Union[str, dict]:
"""Get the url for the data diff service.
Args:
@ -77,10 +137,14 @@ class BaseTableParameter:
str: The url for the data diff service
"""
source_url = (
str(get_connection(db_service.connection.config).url)
BaseTableParameter._get_service_connection_config(db_service)
if not override_url
else override_url
)
if isinstance(source_url, dict):
source_url["driver"] = source_url["driver"].split("+")[0]
return source_url
url = urlparse(source_url)
# remove the driver name from the url because table-diff doesn't support it
kwargs = {"scheme": url.scheme.split("+")[0]}

View File

@ -11,10 +11,13 @@
"""Module that defines the TableDiffParamsSetter class."""
from ast import literal_eval
from typing import List, Optional, Set
from urllib.parse import urlparse
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
ServiceSpecPatch,
)
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
@ -22,24 +25,8 @@ from metadata.generated.schema.entity.data.table import Constraint, Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module
def get_for_source(
    service_type: ServiceType, source_type: str, from_: str = "ingestion"
):
    """Import the ``ServiceSpec`` object from a source's ``service_spec`` module.

    Args:
        service_type (ServiceType): its lower-cased name is the service package
        source_type (str): passed to ``get_module_dir`` to get the source package
        from_ (str): package placed right after ``metadata`` in the module path

    Returns:
        Whatever ``import_from_module`` returns for the built path
    """
    service_package = service_type.name.lower()
    source_package = get_module_dir(source_type)
    spec_path = "metadata.{}.source.{}.{}.{}.ServiceSpec".format(  # pylint: disable=C0209
        from_,
        service_package,
        source_package,
        "service_spec",
    )
    return import_from_module(spec_path)
class TableDiffParamsSetter(RuntimeParameterSetter):
@ -62,23 +49,16 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
}
def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
# Using the specs class method causes circular import as TestSuiteInterface
# imports RuntimeParameterSetter
cls_path = get_for_source(
ServiceType.Database,
source_type=self.service_connection_config.type.value.lower(),
).data_diff
cls = import_from_module(cls_path)()
service_spec_patch = ServiceSpecPatch(
ServiceType.Database, self.service_connection_config.type.value.lower()
)
cls = service_spec_patch.get_data_diff_class()()
service1: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, self.table_entity.service.id, nullable=False
)
service1_url = (
str(get_connection(self.service_connection_config).url)
if self.service_connection_config
else None
)
service1_url = BaseTableParameter._get_service_connection_config(service1)
table2_fqn = self.get_parameter(test_case, "table2")
if table2_fqn is None:
@ -209,41 +189,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
(p.value for p in test_case.parameterValues if p.name == key), default
)
@staticmethod
def get_data_diff_url(
    db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
) -> str:
    """Build the url handed to the data diff service.

    Args:
        db_service (DatabaseService): The database service entity
        table_fqn (str): The fully qualified name of the table
        override_url (Optional[str], optional): Url used instead of the one
            derived from the service connection. Defaults to None.

    Returns:
        str: The url for the data diff service
    """
    if override_url:
        source_url = override_url
    else:
        source_url = str(get_connection(db_service.connection.config).url)

    parsed = urlparse(source_url)
    # table-diff doesn't support the driver suffix, so keep only the dialect
    scheme = parsed.scheme.split("+")[0]
    replacements = {"scheme": scheme}

    _service, database, schema, _table = fqn.split(table_fqn)

    # some connectors need the database in the path...
    if hasattr(db_service.connection.config, "supportsDatabase"):
        replacements["path"] = f"/{database}"
    # ...and some need the database AND the schema. This can be found by going to:
    # https://github.com/open-metadata/collate-data-diff/blob/main/data_diff/databases/<connector>.py
    # and looking at the `CONNECT_URI_HELPER` variable
    if scheme in {Dialects.MSSQL, Dialects.Snowflake, Dialects.Trino}:
        replacements["path"] = f"/{database}/{schema}"

    return parsed._replace(**replacements).geturl()
@staticmethod
def get_data_diff_table_path(table_fqn: str) -> str:
service, database, schema, table = fqn.split( # pylint: disable=unused-variable

View File

@ -508,6 +508,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
("table1.serviceUrl", self.runtime_params.table1.serviceUrl),
("table2.serviceUrl", self.runtime_params.table2.serviceUrl),
]:
if isinstance(param, dict):
dialect = param.get("driver")
else:
dialect = urlparse(param).scheme
if dialect not in SUPPORTED_DIALECTS:
raise UnsupportedDialectError(name, dialect)

View File

@ -41,12 +41,23 @@ class BaseConnection(ABC, Generic[S, C]):
"""
service_connection: S
_client: Optional[C]
def __init__(self, service_connection: S) -> None:
self.service_connection = service_connection
self._client = None
@property
def client(self) -> C:
"""
Return the main client/engine/connection object for this service.
"""
if self._client is None:
self._client = self._get_client()
return self._client
@abstractmethod
def get_client(self) -> C:
def _get_client(self) -> C:
"""
Return the main client/engine/connection object for this service.
"""
@ -61,3 +72,9 @@ class BaseConnection(ABC, Generic[S, C]):
"""
Test the connection to the service.
"""
@abstractmethod
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""

View File

@ -73,7 +73,7 @@ def _get_connection_fn_from_service_spec(connection: BaseModel) -> Optional[Call
if connection_class:
def _get_client(conn):
return connection_class(conn).get_client()
return connection_class(conn).client
return _get_client
return None

View File

@ -50,7 +50,7 @@ from metadata.utils.constants import THREE_MIN
class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
def get_client(self) -> Engine:
def _get_client(self) -> Engine:
"""
Return the SQLAlchemy Engine for MySQL.
"""
@ -77,6 +77,12 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
get_connection_args_fn=get_connection_args_common,
)
def get_connection_dict(self) -> dict:
    """
    Connection dictionaries are not provided for MySQL.

    Raises:
        NotImplementedError: always
    """
    message = "get_connection_dict is not implemented for MySQL"
    raise NotImplementedError(message)
def test_connection(
self,
metadata: OpenMetadata,
@ -94,7 +100,7 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
}
return test_connection_db_schema_sources(
metadata=metadata,
engine=self.get_client(),
engine=self.client,
service_connection=self.service_connection,
automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds,

View File

@ -13,7 +13,7 @@
Source connection handler
"""
from copy import deepcopy
from typing import Optional
from typing import Optional, cast
from urllib.parse import quote_plus
from requests import Session
@ -24,13 +24,17 @@ from metadata.clients.azure_client import AzureClient
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
ConnectionArguments,
)
from metadata.generated.schema.entity.services.connections.database.common import (
azureConfig,
basicAuth,
jwtAuth,
noConfigAuthenticationTypes,
)
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection,
TrinoConnection as TrinoConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult,
@ -41,6 +45,7 @@ from metadata.ingestion.connections.builders import (
init_empty_connection_arguments,
init_empty_connection_options,
)
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.secrets import connection_with_options_secrets
from metadata.ingestion.connections.test_connections import (
test_connection_db_schema_sources,
@ -50,10 +55,140 @@ from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE
from metadata.utils.constants import THREE_MIN
def get_connection_url(connection: TrinoConnection) -> str:
# pylint: disable=unused-argument
def _is_disconnect(self, e, connection, cursor):
    """is_disconnect method for the Trino dialect.

    Installed as ``TrinoDialect.is_disconnect``. An error counts as a
    disconnect when its string form contains "JWT expired"; ``connection``
    and ``cursor`` are accepted but not inspected.
    """
    return "JWT expired" in str(e)
class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]):
def __init__(self, connection: TrinoConnectionConfig):
    """Pass the Trino connection config straight to the base class initializer."""
    super().__init__(connection)
def _get_client(self) -> Engine:
    """
    Build the engine for the Trino service connection.

    Raises:
        ValueError: when Azure auth is configured without scopes
    """
    from trino.sqlalchemy.dialect import TrinoDialect

    # Use the module-level hook to decide which errors are disconnects.
    TrinoDialect.is_disconnect = _is_disconnect  # type: ignore

    original = self.service_connection
    # Auth params are added to connectionArguments dynamically. They go on a
    # copy because we do not intend to store them in the original connection
    # object or in the OpenMetadata database.
    engine_config = deepcopy(original)

    auth_type = original.authType
    if hasattr(auth_type, "azureConfig"):
        azure_client = AzureClient(auth_type.azureConfig).create_client()
        if not auth_type.azureConfig.scopes:
            raise ValueError(
                "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
            )

        token = azure_client.get_token(*auth_type.azureConfig.scopes.split(","))

        if not original.connectionOptions:
            original.connectionOptions = init_empty_connection_options()
        # NOTE(review): the token is written to the original connection, not
        # to the copy the engine is built from -- confirm this is intended.
        original.connectionOptions.root["access_token"] = token.token

    # Update the copy with the connection arguments (auth, proxies, ...)
    engine_config.connectionArguments = self.build_connection_args(engine_config)

    return create_generic_db_connection(
        connection=engine_config,
        get_connection_url_fn=self.get_connection_url,
        get_connection_args_fn=get_connection_args_common,
    )
def test_connection(
    self,
    metadata: OpenMetadata,
    automation_workflow: Optional[AutomationWorkflow] = None,
    timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
    """
    Test the Trino connection. This can be executed either as part
    of a metadata workflow or during an Automation Workflow.

    Delegates to ``test_connection_db_schema_sources`` with this
    connection's engine and a ``GetDatabases`` query.
    """
    return test_connection_db_schema_sources(
        metadata=metadata,
        engine=self.client,
        service_connection=self.service_connection,
        automation_workflow=automation_workflow,
        queries={"GetDatabases": TRINO_GET_DATABASE},
        timeout_seconds=timeout_seconds,
    )
def get_connection_dict(self) -> dict:
    """
    Return the connection dictionary for this service.

    The dictionary starts from the engine url (driver, host, port, user,
    catalog, schema plus every url query parameter) and is then extended
    with the proxies, the common connection arguments and, when an auth
    type matches, an ``auth`` entry together with ``http_scheme`` "https".
    """
    engine_url = self.client.url
    config = deepcopy(self.service_connection)

    result = {
        "driver": engine_url.drivername,
        "host": engine_url.host,
        "port": engine_url.port,
        "user": engine_url.username,
        "catalog": engine_url.database,
        "schema": engine_url.query.get("schema"),
    }
    result.update(engine_url.query)

    if config.proxies:
        result["http_session"] = config.proxies

    if config.connectionArguments and config.connectionArguments.root:
        # NOTE(review): the value returned by this call is discarded --
        # confirm whether the call has any effect on ``config``.
        connection_with_options_secrets(lambda: config)
        result.update(get_connection_args_common(config))

    # Pick the auth description matching the configured auth type, if any.
    auth_type = config.authType
    auth = None
    if isinstance(auth_type, basicAuth.BasicAuth):
        auth = TrinoConnection.get_basic_auth_dict(config)
    elif isinstance(auth_type, jwtAuth.JwtAuth):
        auth = TrinoConnection.get_jwt_auth_dict(config)
    elif hasattr(auth_type, "azureConfig"):
        auth = TrinoConnection.get_azure_auth_dict(config)
    elif auth_type == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2:
        auth = TrinoConnection.get_oauth2_auth_dict(config)

    if auth is not None:
        result["auth"] = auth
        result["http_scheme"] = "https"

    return result
@staticmethod
def get_connection_url(connection: TrinoConnectionConfig) -> str:
"""
Prepare the connection url for trino
"""
url = f"{connection.scheme.value}://"
# leaving username here as, even though with basic auth is used directly
@ -75,124 +210,157 @@ def get_connection_url(connection: TrinoConnection) -> str:
url = f"{url}?{params}"
return url
@staticmethod
@connection_with_options_secrets
def get_connection_args(connection: TrinoConnection):
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
def build_connection_args(connection: TrinoConnectionConfig) -> ConnectionArguments:
"""
Get the connection args for the trino connection
"""
connection_args: ConnectionArguments = (
connection.connectionArguments or init_empty_connection_arguments()
)
assert connection_args.root is not None
if connection.verify:
connection_args.root["verify"] = {"verify": connection.verify}
if connection.proxies:
session = Session()
session.proxies = connection.proxies
connection.connectionArguments.root["http_session"] = session
connection_args.root["http_session"] = session
if isinstance(connection.authType, basicAuth.BasicAuth):
connection.connectionArguments.root["auth"] = BasicAuthentication(
connection.username,
connection.authType.password.get_secret_value()
if connection.authType.password
else None,
)
connection.connectionArguments.root["http_scheme"] = "https"
TrinoConnection.set_basic_auth(connection, connection_args)
elif isinstance(connection.authType, jwtAuth.JwtAuth):
connection.connectionArguments.root["auth"] = JWTAuthentication(
connection.authType.jwt.get_secret_value()
)
connection.connectionArguments.root["http_scheme"] = "https"
TrinoConnection.set_jwt_auth(connection, connection_args)
elif hasattr(connection.authType, "azureConfig"):
if not connection.authType.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
azure_client = AzureClient(connection.authType.azureConfig).create_client()
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
connection.connectionArguments.root["auth"] = JWTAuthentication(
access_token_obj.token
)
connection.connectionArguments.root["http_scheme"] = "https"
TrinoConnection.set_azure_auth(connection, connection_args)
elif (
connection.authType
== noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2
):
connection.connectionArguments.root["auth"] = OAuth2Authentication()
connection.connectionArguments.root["http_scheme"] = "https"
TrinoConnection.set_oauth2_auth(connection, connection_args)
return get_connection_args_common(connection)
return connection_args
def get_connection(connection: TrinoConnection) -> Engine:
@staticmethod
def get_basic_auth_dict(connection: TrinoConnectionConfig) -> dict:
"""
Create connection
Get the basic auth dictionary for the trino connection
"""
# here we are creating a copy of connection, because we need to dynamically
# add auth params to connectionArguments, which we do no intend to store
# in original connection object and in OpenMetadata database
from trino.sqlalchemy.dialect import TrinoDialect
auth_type = cast(basicAuth.BasicAuth, connection.authType)
return {
"authType": "basic",
"username": connection.username,
"password": auth_type.password.get_secret_value()
if auth_type.password
else None,
}
TrinoDialect.is_disconnect = _is_disconnect
@staticmethod
def set_basic_auth(
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
) -> None:
"""
Set the basic auth for the trino connection
"""
assert connection_args.root is not None
auth_type = cast(basicAuth.BasicAuth, connection.authType)
connection_copy = deepcopy(connection)
if connection_copy.verify:
connection_copy.connectionArguments = (
connection_copy.connectionArguments or init_empty_connection_arguments()
connection_args.root["auth"] = BasicAuthentication(
connection.username,
auth_type.password.get_secret_value() if auth_type.password else None,
)
connection.connectionArguments.root["verify"] = {"verify": connection.verify}
if hasattr(connection.authType, "azureConfig"):
azure_client = AzureClient(connection.authType.azureConfig).create_client()
if not connection.authType.azureConfig.scopes:
connection_args.root["http_scheme"] = "https"
@staticmethod
def get_jwt_auth_dict(connection: TrinoConnectionConfig) -> dict:
"""
Get the jwt auth dictionary for the trino connection
"""
auth_type = cast(jwtAuth.JwtAuth, connection.authType)
return {
"authType": "jwt",
"jwt": auth_type.jwt.get_secret_value(),
}
@staticmethod
def set_jwt_auth(
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
) -> None:
"""
Set the jwt auth for the trino connection
"""
assert connection_args.root is not None
auth_type = cast(jwtAuth.JwtAuth, connection.authType)
connection_args.root["auth"] = JWTAuthentication(
auth_type.jwt.get_secret_value()
)
connection_args.root["http_scheme"] = "https"
@staticmethod
def get_azure_auth_dict(connection: TrinoConnectionConfig) -> dict:
"""
Get the azure auth dictionary for the trino connection
"""
return {
"authType": "jwt",
"jwt": TrinoConnection.get_azure_token(connection),
}
@staticmethod
def set_azure_auth(
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
) -> None:
"""
Set the azure auth for the trino connection
"""
assert connection_args.root is not None
connection_args.root["auth"] = JWTAuthentication(
TrinoConnection.get_azure_token(connection)
)
connection_args.root["http_scheme"] = "https"
@staticmethod
def get_oauth2_auth_dict(connection: TrinoConnectionConfig) -> dict:
"""
Get the oauth2 auth dictionary for the trino connection
"""
return {
"authType": "oauth2",
}
@staticmethod
def set_oauth2_auth(
connection: TrinoConnectionConfig, connection_args: ConnectionArguments
) -> None:
"""
Set the oauth2 auth for the trino connection
"""
assert connection_args.root is not None
connection_args.root["auth"] = OAuth2Authentication()
connection_args.root["http_scheme"] = "https"
@staticmethod
def get_azure_token(connection: TrinoConnectionConfig) -> str:
"""
Get the azure token for the trino connection
"""
auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType)
if not auth_type.azureConfig.scopes:
raise ValueError(
"Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default"
)
access_token_obj = azure_client.get_token(
*connection.authType.azureConfig.scopes.split(",")
)
if not connection.connectionOptions:
connection.connectionOptions = init_empty_connection_options()
connection.connectionOptions.root["access_token"] = access_token_obj.token
return create_generic_db_connection(
connection=connection_copy,
get_connection_url_fn=get_connection_url,
get_connection_args_fn=get_connection_args,
)
azure_client = AzureClient(auth_type.azureConfig).create_client()
def test_connection(
metadata: OpenMetadata,
engine: Engine,
service_connection: TrinoConnection,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
"""
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow
"""
queries = {
"GetDatabases": TRINO_GET_DATABASE,
}
return test_connection_db_schema_sources(
metadata=metadata,
engine=engine,
service_connection=service_connection,
automation_workflow=automation_workflow,
queries=queries,
timeout_seconds=timeout_seconds,
)
# pylint: disable=unused-argument
def _is_disconnect(self, e, connection, cursor):
    """is_disconnect method for the Trino dialect.

    Installed as ``TrinoDialect.is_disconnect``. An error counts as a
    disconnect when its string form contains "JWT expired"; ``connection``
    and ``cursor`` are accepted but not inspected.
    """
    return "JWT expired" in str(e)
return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token

View File

@ -1,3 +1,4 @@
from metadata.ingestion.source.database.trino.connection import TrinoConnection
from metadata.ingestion.source.database.trino.lineage import TrinoLineageSource
from metadata.ingestion.source.database.trino.metadata import TrinoSource
from metadata.ingestion.source.database.trino.usage import TrinoUsageSource
@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec(
usage_source_class=TrinoUsageSource,
profiler_class=TrinoProfilerInterface,
sampler_class=TrinoSampler,
connection_class=TrinoConnection,
)

View File

@ -6,6 +6,9 @@ from sqlalchemy import Column as SAColumn
from sqlalchemy import MetaData, String, create_engine
from sqlalchemy.orm import declarative_base
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
)
@ -111,9 +114,9 @@ SERVICE_CONNECTION_CONFIG = MysqlConnection(
],
)
def test_get_data_diff_url(input, expected):
assert expected == TableDiffParamsSetter(
None, None, MOCK_TABLE, None
).get_data_diff_url(input, "service.database.schema.table")
assert expected == BaseTableParameter.get_data_diff_url(
input, "service.database.schema.table"
)
@pytest.mark.parametrize(

View File

@ -44,7 +44,7 @@ def test_connection(mock_service_connection):
class TestConnection(BaseConnection):
"""Concrete implementation of BaseConnection for testing"""
def get_client(self):
def _get_client(self):
return MagicMock()
def test_connection(
@ -64,6 +64,9 @@ def test_connection(mock_service_connection):
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
)
def get_connection_dict(self):
return {}
return TestConnection(mock_service_connection)
@ -100,7 +103,7 @@ class TestBaseConnection:
mock_client = MagicMock()
class TestConnectionWithMockClient(BaseConnection):
def get_client(self):
def _get_client(self):
return mock_client
def test_connection(
@ -120,6 +123,9 @@ class TestBaseConnection:
lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)),
)
def get_connection_dict(self):
return {}
connection = TestConnectionWithMockClient(test_connection.service_connection)
client = connection.get_client()
client = connection.client
assert client == mock_client

View File

@ -70,7 +70,7 @@ class TestGetConnectionURL(unittest.TestCase):
hostPort="localhost:3306",
databaseSchema="openmetadata_db",
)
engine_connection = MySQLConnection(connection).get_client()
engine_connection = MySQLConnection(connection).client
self.assertEqual(
str(engine_connection.url),
"mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db",
@ -93,7 +93,7 @@ class TestGetConnectionURL(unittest.TestCase):
"get_token",
return_value=AccessToken(token="mocked_token", expires_on=100),
):
engine_connection = MySQLConnection(connection).get_client()
engine_connection = MySQLConnection(connection).client
self.assertEqual(
str(engine_connection.url),
"mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db",

View File

@ -100,7 +100,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon
SnowflakeScheme,
)
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection,
TrinoConnection as TrinoConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoScheme,
)
from metadata.generated.schema.entity.services.connections.database.verticaConnection import (
@ -112,7 +114,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common,
get_connection_url_common,
)
from metadata.ingestion.source.database.trino.connection import get_connection_args
from metadata.ingestion.source.database.trino.connection import TrinoConnection
# pylint: disable=import-outside-toplevel
@ -405,44 +407,38 @@ class SourceConnectionTest(TestCase):
assert expected_result == get_connection_url(impala_conn_obj)
def test_trino_url_without_params(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=BasicAuth(password="pass"),
catalog="catalog",
)
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == get_connection_url(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
# Passing @ in username and password
expected_url = "trino://username%40444@localhost:443/catalog"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username@444",
authType=BasicAuth(password="pass@111"),
catalog="catalog",
)
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == get_connection_url(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_conn_arguments(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_args,
)
# connection arguments without connectionArguments and without proxies
expected_args = {
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
username="user",
authType=BasicAuth(password=None),
hostPort="localhost:443",
@ -450,7 +446,10 @@ class SourceConnectionTest(TestCase):
connectionArguments=None,
scheme=TrinoScheme.trino,
)
assert expected_args == get_connection_args(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
assert (
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
)
# connection arguments with connectionArguments and without proxies
expected_args = {
@ -458,7 +457,7 @@ class SourceConnectionTest(TestCase):
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
username="user",
authType=BasicAuth(password=None),
hostPort="localhost:443",
@ -466,14 +465,17 @@ class SourceConnectionTest(TestCase):
connectionArguments={"user": "user-to-be-impersonated"},
scheme=TrinoScheme.trino,
)
assert expected_args == get_connection_args(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
assert (
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
)
# connection arguments without connectionArguments and with proxies
expected_args = {
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
username="user",
authType=BasicAuth(password=None),
hostPort="localhost:443",
@ -482,7 +484,9 @@ class SourceConnectionTest(TestCase):
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
scheme=TrinoScheme.trino,
)
conn_args = get_connection_args(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
assert "http_session" in conn_args
conn_args.pop("http_session")
assert expected_args == conn_args
@ -493,7 +497,7 @@ class SourceConnectionTest(TestCase):
"auth": BasicAuthentication("user", None),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
username="user",
authType=BasicAuth(password=None),
hostPort="localhost:443",
@ -502,18 +506,15 @@ class SourceConnectionTest(TestCase):
proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"},
scheme=TrinoScheme.trino,
)
conn_args = get_connection_args(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
conn_args = trino_connection.build_connection_args(trino_conn_obj).root
assert "http_session" in conn_args
conn_args.pop("http_session")
assert expected_args == conn_args
def test_trino_url_with_params(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog?param=value"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
@ -521,35 +522,31 @@ class SourceConnectionTest(TestCase):
catalog="catalog",
connectionOptions={"param": "value"},
)
assert expected_url == get_connection_url(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_url_with_jwt_auth(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
expected_url = "trino://username@localhost:443/catalog"
expected_args = {
"auth": JWTAuthentication("jwt_token_value"),
"http_scheme": "https",
}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=JwtAuth(jwt="jwt_token_value"),
catalog="catalog",
)
assert expected_url == get_connection_url(trino_conn_obj)
assert expected_args == get_connection_args(trino_conn_obj)
def test_trino_with_proxies(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_args,
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
assert (
expected_args == trino_connection.build_connection_args(trino_conn_obj).root
)
def test_trino_with_proxies(self):
test_proxies = {"http": "http_proxy", "https": "https_proxy"}
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
@ -557,59 +554,54 @@ class SourceConnectionTest(TestCase):
catalog="catalog",
proxies=test_proxies,
)
trino_connection = TrinoConnection(trino_conn_obj)
assert (
test_proxies
== get_connection_args(trino_conn_obj).get("http_session").proxies
== trino_connection.build_connection_args(trino_conn_obj)
.root.get("http_session")
.proxies
)
def test_trino_without_catalog(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog
expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=BasicAuth(password="pass"),
)
assert expected_url == get_connection_url(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_without_catalog(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test trino url without catalog
expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=BasicAuth(password="pass"),
)
assert expected_url == get_connection_url(trino_conn_obj)
trino_connection = TrinoConnection(trino_conn_obj)
assert expected_url == str(trino_connection.client.url)
def test_trino_with_oauth2(self):
from metadata.ingestion.source.database.trino.connection import (
get_connection_url,
)
# Test that OAuth2 auth type yields an OAuth2Authentication connection argument
expected_url = "trino://username@localhost:443"
trino_conn_obj = TrinoConnection(
trino_conn_obj = TrinoConnectionConfig(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2,
)
assert isinstance(
get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication
trino_connection = TrinoConnection(trino_conn_obj)
assert (
trino_connection.build_connection_args(trino_conn_obj).root.get("auth")
== OAuth2Authentication()
)
def test_vertica_url(self):