diff --git a/ingestion/setup.py b/ingestion/setup.py index f992cb8104d..09ce3fea39d 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -157,7 +157,7 @@ base_requirements = { "packaging", # For version parsing "setuptools~=70.0", "shapely", - "collate-data-diff", + "collate-data-diff>=0.11.6", "jaraco.functools<4.2.0", # above 4.2 breaks the build # TODO: Remove one once we have updated datadiff version "snowflake-connector-python>=3.13.1,<4.0.0", diff --git a/ingestion/src/metadata/data_quality/validations/models.py b/ingestion/src/metadata/data_quality/validations/models.py index 37fa09363a7..7af1430db87 100644 --- a/ingestion/src/metadata/data_quality/validations/models.py +++ b/ingestion/src/metadata/data_quality/validations/models.py @@ -1,6 +1,6 @@ """Models for the TableDiff test case""" -from typing import List, Optional +from typing import List, Optional, Union from pydantic import BaseModel @@ -12,7 +12,7 @@ from metadata.ingestion.models.custom_pydantic import CustomSecretStr class TableParameter(BaseModel): - serviceUrl: str + serviceUrl: Union[str, dict] path: str columns: List[Column] database_service_type: DatabaseServiceType diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py index d99ca44afe1..ddbdc7c59dc 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py @@ -1,15 +1,50 @@ """Base class for param setter logic for table data diff""" -from typing import List, Optional, Set +from typing import List, Optional, Set, Type, Union from urllib.parse import urlparse from metadata.data_quality.validations.models import Column, TableParameter from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.entity.services.serviceType import ServiceType +from metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.source.connections import get_connection from metadata.profiler.orm.registry import Dialects from metadata.utils import fqn from metadata.utils.collections import CaseInsensitiveList +from metadata.utils.importer import get_module_dir, import_from_module + + +# TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods. +# Using the specs class method causes circular import as TestSuiteInterface +# imports RuntimeParameterSetter +class ServiceSpecPatch: + def __init__(self, service_type: ServiceType, source_type: str): + self.service_type = service_type + self.source_type = source_type + self._service_spec = None + + @property + def service_spec(self): + if self._service_spec is None: + self._service_spec = self.get_for_source() + return self._service_spec + + def get_for_source(self): + return import_from_module( + "metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209 + "ingestion", + self.service_type.name.lower(), + get_module_dir(self.source_type), + "service_spec", + ) + ) + + def get_data_diff_class(self) -> Type["BaseTableParameter"]: + return import_from_module(self.service_spec.data_diff) + + def get_connection_class(self) -> Type[BaseConnection]: + return import_from_module(self.service_spec.connection_class) class BaseTableParameter: @@ -22,7 +57,7 @@ class BaseTableParameter: key_columns, extra_columns, case_sensitive_columns, - service_url: Optional[str], + service_url: Optional[Union[str, dict]], ) -> TableParameter: """Getter table parameter for the table diff test. @@ -62,10 +97,35 @@ class BaseTableParameter: "___SERVICE___", "__DATABASE__", schema, table ).replace("___SERVICE___.__DATABASE__.", "") + @staticmethod + def _get_service_connection_config( + db_service: DatabaseService, + ) -> Optional[Union[str, dict]]: + """ + Get the connection dictionary for the service. + """ + service_connection_config = db_service.connection.config + service_spec_patch = ServiceSpecPatch( + ServiceType.Database, service_connection_config.type.value.lower() + ) + + try: + connection_class = service_spec_patch.get_connection_class() + connection = connection_class(service_connection_config) + return connection.get_connection_dict() + except (ValueError, AttributeError, NotImplementedError): + return ( + str(get_connection(service_connection_config).url) + if service_connection_config + else None + ) + @staticmethod def get_data_diff_url( - db_service: DatabaseService, table_fqn, override_url: Optional[str] = None - ) -> str: + db_service: DatabaseService, + table_fqn, + override_url: Optional[Union[str, dict]] = None, + ) -> Union[str, dict]: """Get the url for the data diff service. Args: @@ -77,10 +137,14 @@ class BaseTableParameter: str: The url for the data diff service """ source_url = ( - str(get_connection(db_service.connection.config).url) + BaseTableParameter._get_service_connection_config(db_service) if not override_url else override_url ) + if isinstance(source_url, dict): + source_url["driver"] = source_url["driver"].split("+")[0] + return source_url + url = urlparse(source_url) # remove the driver name from the url because table-diff doesn't support it kwargs = {"scheme": url.scheme.split("+")[0]} diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index d28e46aae90..bd4e5fae743 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -11,10 +11,13 @@ """Module that defines the TableDiffParamsSetter class.""" from ast import literal_eval from typing import List, Optional, Set -from urllib.parse import urlparse from metadata.data_quality.validations import utils from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters +from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import ( + BaseTableParameter, + ServiceSpecPatch, +) from metadata.data_quality.validations.runtime_param_setter.param_setter import ( RuntimeParameterSetter, ) @@ -22,24 +25,8 @@ from metadata.generated.schema.entity.data.table import Constraint, Table from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.serviceType import ServiceType from metadata.generated.schema.tests.testCase import TestCase -from metadata.ingestion.source.connections import get_connection -from metadata.profiler.orm.registry import Dialects from metadata.utils import fqn from metadata.utils.collections import CaseInsensitiveList -from metadata.utils.importer import get_module_dir, import_from_module - - -def get_for_source( - service_type: ServiceType, source_type: str, from_: str = "ingestion" -): - return import_from_module( - "metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209 - from_, - service_type.name.lower(), - get_module_dir(source_type), - "service_spec", - ) - ) class TableDiffParamsSetter(RuntimeParameterSetter): @@ -62,23 +49,16 @@ class TableDiffParamsSetter(RuntimeParameterSetter): } def get_parameters(self, test_case) -> TableDiffRuntimeParameters: - # Using the specs class method causes circular import as TestSuiteInterface - # imports RuntimeParameterSetter - cls_path = get_for_source( - ServiceType.Database, - source_type=self.service_connection_config.type.value.lower(), - ).data_diff - cls = import_from_module(cls_path)() + service_spec_patch = ServiceSpecPatch( + ServiceType.Database, self.service_connection_config.type.value.lower() + ) + cls = service_spec_patch.get_data_diff_class()() service1: DatabaseService = self.ometa_client.get_by_id( DatabaseService, self.table_entity.service.id, nullable=False ) - service1_url = ( - str(get_connection(self.service_connection_config).url) - if self.service_connection_config - else None - ) + service1_url = BaseTableParameter._get_service_connection_config(service1) table2_fqn = self.get_parameter(test_case, "table2") if table2_fqn is None: @@ -209,41 +189,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter): (p.value for p in test_case.parameterValues if p.name == key), default ) - @staticmethod - def get_data_diff_url( - db_service: DatabaseService, table_fqn, override_url: Optional[str] = None - ) -> str: - """Get the url for the data diff service. - - Args: - db_service (DatabaseService): The database service entity - table_fqn (str): The fully qualified name of the table - override_url (Optional[str], optional): Override the url. Defaults to None. - - Returns: - str: The url for the data diff service - """ - source_url = ( - str(get_connection(db_service.connection.config).url) - if not override_url - else override_url - ) - url = urlparse(source_url) - # remove the driver name from the url because table-diff doesn't support it - kwargs = {"scheme": url.scheme.split("+")[0]} - service, database, schema, table = fqn.split( # pylint: disable=unused-variable - table_fqn - ) - # path needs to include the database AND schema in some of the connectors - if hasattr(db_service.connection.config, "supportsDatabase"): - kwargs["path"] = f"/{database}" - # this can be found by going to: - # https://github.com/open-metadata/collate-data-diff/blob/main/data_diff/databases/.py - # and looking at the `CONNECT_URI_HELPER` variable - if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake, Dialects.Trino}: - kwargs["path"] = f"/{database}/{schema}" - return url._replace(**kwargs).geturl() - @staticmethod def get_data_diff_table_path(table_fqn: str) -> str: service, database, schema, table = fqn.split( # pylint: disable=unused-variable diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 99e58358dee..157a5b2df57 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -508,7 +508,10 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): ("table1.serviceUrl", self.runtime_params.table1.serviceUrl), ("table2.serviceUrl", self.runtime_params.table2.serviceUrl), ]: - dialect = urlparse(param).scheme + if isinstance(param, dict): + dialect = param.get("driver") + else: + dialect = urlparse(param).scheme if dialect not in SUPPORTED_DIALECTS: raise UnsupportedDialectError(name, dialect) diff --git a/ingestion/src/metadata/ingestion/connections/connection.py b/ingestion/src/metadata/ingestion/connections/connection.py index 081d770b359..b196dd29ece 100644 --- a/ingestion/src/metadata/ingestion/connections/connection.py +++ b/ingestion/src/metadata/ingestion/connections/connection.py @@ -41,12 +41,23 @@ class BaseConnection(ABC, Generic[S, C]): """ service_connection: S + _client: Optional[C] def __init__(self, service_connection: S) -> None: self.service_connection = service_connection + self._client = None + + @property + def client(self) -> C: + """ + Return the main client/engine/connection object for this service. + """ + if self._client is None: + self._client = self._get_client() + return self._client @abstractmethod - def get_client(self) -> C: + def _get_client(self) -> C: """ Return the main client/engine/connection object for this service. """ @@ -61,3 +72,9 @@ class BaseConnection(ABC, Generic[S, C]): """ Test the connection to the service. """ + + @abstractmethod + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. + """ diff --git a/ingestion/src/metadata/ingestion/source/connections.py b/ingestion/src/metadata/ingestion/source/connections.py index e6eea6ceb2c..46c2b331086 100644 --- a/ingestion/src/metadata/ingestion/source/connections.py +++ b/ingestion/src/metadata/ingestion/source/connections.py @@ -73,7 +73,7 @@ def _get_connection_fn_from_service_spec(connection: BaseModel) -> Optional[Call if connection_class: def _get_client(conn): - return connection_class(conn).get_client() + return connection_class(conn).client return _get_client return None diff --git a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py index 0e39be1eef2..9bf32d37fa2 100644 --- a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py @@ -50,7 +50,7 @@ from metadata.utils.constants import THREE_MIN class MySQLConnection(BaseConnection[MysqlConnection, Engine]): - def get_client(self) -> Engine: + def _get_client(self) -> Engine: """ Return the SQLAlchemy Engine for MySQL. """ @@ -77,6 +77,12 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]): get_connection_args_fn=get_connection_args_common, ) + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. + """ + raise NotImplementedError("get_connection_dict is not implemented for MySQL") + def test_connection( self, metadata: OpenMetadata, @@ -94,7 +100,7 @@ class MySQLConnection(BaseConnection[MysqlConnection, Engine]): } return test_connection_db_schema_sources( metadata=metadata, - engine=self.get_client(), + engine=self.client, service_connection=self.service_connection, automation_workflow=automation_workflow, timeout_seconds=timeout_seconds, diff --git a/ingestion/src/metadata/ingestion/source/database/trino/connection.py b/ingestion/src/metadata/ingestion/source/database/trino/connection.py index 529436bbb7e..0dcdc9f8dc7 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/trino/connection.py @@ -13,7 +13,7 @@ Source connection handler """ from copy import deepcopy -from typing import Optional +from typing import Optional, cast from urllib.parse import quote_plus from requests import Session @@ -24,13 +24,17 @@ from metadata.clients.azure_client import AzureClient from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) +from metadata.generated.schema.entity.services.connections.connectionBasicType import ( + ConnectionArguments, +) from metadata.generated.schema.entity.services.connections.database.common import ( + azureConfig, basicAuth, jwtAuth, noConfigAuthenticationTypes, ) from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( - TrinoConnection, + TrinoConnection as TrinoConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, @@ -41,6 +45,7 @@ from metadata.ingestion.connections.builders import ( init_empty_connection_arguments, init_empty_connection_options, ) +from metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.connections.secrets import connection_with_options_secrets from metadata.ingestion.connections.test_connections import ( test_connection_db_schema_sources, @@ -50,149 +55,312 @@ from metadata.ingestion.source.database.trino.queries import TRINO_GET_DATABASE from metadata.utils.constants import THREE_MIN -def get_connection_url(connection: TrinoConnection) -> str: - """ - Prepare the connection url for trino - """ - url = f"{connection.scheme.value}://" - - # leaving username here as, even though with basic auth is used directly - # in BasicAuthentication class, it's often also required as a part of url. - # For example - it will be used by OAuth2Authentication to persist token in - # cache more efficiently (per user instead of per host) - if connection.username: - url += f"{quote_plus(connection.username)}@" - - url += f"{connection.hostPort}" - if connection.catalog: - url += f"/{connection.catalog}" - if connection.connectionOptions is not None: - params = "&".join( - f"{key}={quote_plus(value)}" - for (key, value) in connection.connectionOptions.root.items() - if value - ) - url = f"{url}?{params}" - return url - - -@connection_with_options_secrets -def get_connection_args(connection: TrinoConnection): - if not connection.connectionArguments: - connection.connectionArguments = init_empty_connection_arguments() - - if connection.proxies: - session = Session() - session.proxies = connection.proxies - - connection.connectionArguments.root["http_session"] = session - - if isinstance(connection.authType, basicAuth.BasicAuth): - connection.connectionArguments.root["auth"] = BasicAuthentication( - connection.username, - connection.authType.password.get_secret_value() - if connection.authType.password - else None, - ) - connection.connectionArguments.root["http_scheme"] = "https" - - elif isinstance(connection.authType, jwtAuth.JwtAuth): - connection.connectionArguments.root["auth"] = JWTAuthentication( - connection.authType.jwt.get_secret_value() - ) - connection.connectionArguments.root["http_scheme"] = "https" - - elif hasattr(connection.authType, "azureConfig"): - if not connection.authType.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - - azure_client = AzureClient(connection.authType.azureConfig).create_client() - - access_token_obj = azure_client.get_token( - *connection.authType.azureConfig.scopes.split(",") - ) - - connection.connectionArguments.root["auth"] = JWTAuthentication( - access_token_obj.token - ) - connection.connectionArguments.root["http_scheme"] = "https" - - elif ( - connection.authType - == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2 - ): - connection.connectionArguments.root["auth"] = OAuth2Authentication() - connection.connectionArguments.root["http_scheme"] = "https" - - return get_connection_args_common(connection) - - -def get_connection(connection: TrinoConnection) -> Engine: - """ - Create connection - """ - # here we are creating a copy of connection, because we need to dynamically - # add auth params to connectionArguments, which we do no intend to store - # in original connection object and in OpenMetadata database - from trino.sqlalchemy.dialect import TrinoDialect - - TrinoDialect.is_disconnect = _is_disconnect - - connection_copy = deepcopy(connection) - if connection_copy.verify: - connection_copy.connectionArguments = ( - connection_copy.connectionArguments or init_empty_connection_arguments() - ) - connection.connectionArguments.root["verify"] = {"verify": connection.verify} - if hasattr(connection.authType, "azureConfig"): - azure_client = AzureClient(connection.authType.azureConfig).create_client() - if not connection.authType.azureConfig.scopes: - raise ValueError( - "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" - ) - access_token_obj = azure_client.get_token( - *connection.authType.azureConfig.scopes.split(",") - ) - if not connection.connectionOptions: - connection.connectionOptions = init_empty_connection_options() - connection.connectionOptions.root["access_token"] = access_token_obj.token - return create_generic_db_connection( - connection=connection_copy, - get_connection_url_fn=get_connection_url, - get_connection_args_fn=get_connection_args, - ) - - -def test_connection( - metadata: OpenMetadata, - engine: Engine, - service_connection: TrinoConnection, - automation_workflow: Optional[AutomationWorkflow] = None, - timeout_seconds: Optional[int] = THREE_MIN, -) -> TestConnectionResult: - """ - Test connection. This can be executed either as part - of a metadata workflow or during an Automation Workflow - """ - queries = { - "GetDatabases": TRINO_GET_DATABASE, - } - - return test_connection_db_schema_sources( - metadata=metadata, - engine=engine, - service_connection=service_connection, - automation_workflow=automation_workflow, - queries=queries, - timeout_seconds=timeout_seconds, - ) - - # pylint: disable=unused-argument def _is_disconnect(self, e, connection, cursor): """is_disconnect method for the Databricks dialect""" if "JWT expired" in str(e): return True return False + + +class TrinoConnection(BaseConnection[TrinoConnectionConfig, Engine]): + def __init__(self, connection: TrinoConnectionConfig): + super().__init__(connection) + + def _get_client(self) -> Engine: + """ + Create connection + """ + # here we are creating a copy of connection, because we need to dynamically + # add auth params to connectionArguments, which we do no intend to store + # in original connection object and in OpenMetadata database + from trino.sqlalchemy.dialect import TrinoDialect + + TrinoDialect.is_disconnect = _is_disconnect # type: ignore + + connection = self.service_connection + connection_copy = deepcopy(connection) + + if hasattr(connection.authType, "azureConfig"): + azure_client = AzureClient(connection.authType.azureConfig).create_client() + if not connection.authType.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" + ) + access_token_obj = azure_client.get_token( + *connection.authType.azureConfig.scopes.split(",") + ) + if not connection.connectionOptions: + connection.connectionOptions = init_empty_connection_options() + connection.connectionOptions.root["access_token"] = access_token_obj.token + + # Update the connection with the connection arguments + connection_copy.connectionArguments = self.build_connection_args( + connection_copy + ) + + return create_generic_db_connection( + connection=connection_copy, + get_connection_url_fn=self.get_connection_url, + get_connection_args_fn=get_connection_args_common, + ) + + def test_connection( + self, + metadata: OpenMetadata, + automation_workflow: Optional[AutomationWorkflow] = None, + timeout_seconds: Optional[int] = THREE_MIN, + ) -> TestConnectionResult: + """ + Test connection. This can be executed either as part + of a metadata workflow or during an Automation Workflow + """ + queries = { + "GetDatabases": TRINO_GET_DATABASE, + } + + return test_connection_db_schema_sources( + metadata=metadata, + engine=self.client, + service_connection=self.service_connection, + automation_workflow=automation_workflow, + queries=queries, + timeout_seconds=timeout_seconds, + ) + + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. + """ + url = self.client.url + connection_copy = deepcopy(self.service_connection) + + connection_dict = { + "driver": url.drivername, + "host": url.host, + "port": url.port, + "user": url.username, + "catalog": url.database, + "schema": url.query.get("schema"), + } + + connection_dict.update(url.query) + + if connection_copy.proxies: + connection_dict["http_session"] = connection_copy.proxies + + if ( + connection_copy.connectionArguments + and connection_copy.connectionArguments.root + ): + connection_with_options_secrets(lambda: connection_copy) + connection_dict.update(get_connection_args_common(connection_copy)) + + if isinstance(connection_copy.authType, basicAuth.BasicAuth): + connection_dict["auth"] = TrinoConnection.get_basic_auth_dict( + connection_copy + ) + connection_dict["http_scheme"] = "https" + + elif isinstance(connection_copy.authType, jwtAuth.JwtAuth): + connection_dict["auth"] = TrinoConnection.get_jwt_auth_dict(connection_copy) + connection_dict["http_scheme"] = "https" + + elif hasattr(connection_copy.authType, "azureConfig"): + connection_dict["auth"] = TrinoConnection.get_azure_auth_dict( + connection_copy + ) + connection_dict["http_scheme"] = "https" + + elif ( + connection_copy.authType + == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2 + ): + connection_dict["auth"] = TrinoConnection.get_oauth2_auth_dict( + connection_copy + ) + connection_dict["http_scheme"] = "https" + + return connection_dict + + @staticmethod + def get_connection_url(connection: TrinoConnectionConfig) -> str: + """ + Prepare the connection url for trino + """ + + url = f"{connection.scheme.value}://" + + # leaving username here as, even though with basic auth is used directly + # in BasicAuthentication class, it's often also required as a part of url. + # For example - it will be used by OAuth2Authentication to persist token in + # cache more efficiently (per user instead of per host) + if connection.username: + url += f"{quote_plus(connection.username)}@" + + url += f"{connection.hostPort}" + if connection.catalog: + url += f"/{connection.catalog}" + if connection.connectionOptions is not None: + params = "&".join( + f"{key}={quote_plus(value)}" + for (key, value) in connection.connectionOptions.root.items() + if value + ) + url = f"{url}?{params}" + return url + + @staticmethod + @connection_with_options_secrets + def build_connection_args(connection: TrinoConnectionConfig) -> ConnectionArguments: + """ + Get the connection args for the trino connection + """ + connection_args: ConnectionArguments = ( + connection.connectionArguments or init_empty_connection_arguments() + ) + assert connection_args.root is not None + + if connection.verify: + connection_args.root["verify"] = {"verify": connection.verify} + + if connection.proxies: + session = Session() + session.proxies = connection.proxies + + connection_args.root["http_session"] = session + + if isinstance(connection.authType, basicAuth.BasicAuth): + TrinoConnection.set_basic_auth(connection, connection_args) + + elif isinstance(connection.authType, jwtAuth.JwtAuth): + TrinoConnection.set_jwt_auth(connection, connection_args) + + elif hasattr(connection.authType, "azureConfig"): + TrinoConnection.set_azure_auth(connection, connection_args) + + elif ( + connection.authType + == noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2 + ): + TrinoConnection.set_oauth2_auth(connection, connection_args) + + return connection_args + + @staticmethod + def get_basic_auth_dict(connection: TrinoConnectionConfig) -> dict: + """ + Get the basic auth dictionary for the trino connection + """ + auth_type = cast(basicAuth.BasicAuth, connection.authType) + return { + "authType": "basic", + "username": connection.username, + "password": auth_type.password.get_secret_value() + if auth_type.password + else None, + } + + @staticmethod + def set_basic_auth( + connection: TrinoConnectionConfig, connection_args: ConnectionArguments + ) -> None: + """ + Get the basic auth dictionary for the trino connection + """ + assert connection_args.root is not None + auth_type = cast(basicAuth.BasicAuth, connection.authType) + + connection_args.root["auth"] = BasicAuthentication( + connection.username, + auth_type.password.get_secret_value() if auth_type.password else None, + ) + connection_args.root["http_scheme"] = "https" + + @staticmethod + def get_jwt_auth_dict(connection: TrinoConnectionConfig) -> dict: + """ + Get the jwt auth dictionary for the trino connection + """ + auth_type = cast(jwtAuth.JwtAuth, connection.authType) + + return { + "authType": "jwt", + "jwt": auth_type.jwt.get_secret_value(), + } + + @staticmethod + def set_jwt_auth( + connection: TrinoConnectionConfig, connection_args: ConnectionArguments + ) -> None: + """ + Set the jwt auth for the trino connection + """ + assert connection_args.root is not None + auth_type = cast(jwtAuth.JwtAuth, connection.authType) + + connection_args.root["auth"] = JWTAuthentication( + auth_type.jwt.get_secret_value() + ) + connection_args.root["http_scheme"] = "https" + + @staticmethod + def get_azure_auth_dict(connection: TrinoConnectionConfig) -> dict: + """ + Get the azure auth dictionary for the trino connection + """ + return { + "authType": "jwt", + "jwt": TrinoConnection.get_azure_token(connection), + } + + @staticmethod + def set_azure_auth( + connection: TrinoConnectionConfig, connection_args: ConnectionArguments + ) -> None: + """ + Set the azure auth for the trino connection + """ + assert connection_args.root is not None + + connection_args.root["auth"] = JWTAuthentication( + TrinoConnection.get_azure_token(connection) + ) + connection_args.root["http_scheme"] = "https" + + @staticmethod + def get_oauth2_auth_dict(connection: TrinoConnectionConfig) -> dict: + """ + Get the oauth2 auth dictionary for the trino connection + """ + return { + "authType": "oauth2", + } + + @staticmethod + def set_oauth2_auth( + connection: TrinoConnectionConfig, connection_args: ConnectionArguments + ) -> None: + """ + Set the oauth2 auth for the trino connection + """ + assert connection_args.root is not None + + connection_args.root["auth"] = OAuth2Authentication() + connection_args.root["http_scheme"] = "https" + + @staticmethod + def get_azure_token(connection: TrinoConnectionConfig) -> str: + """ + Get the azure token for the trino connection + """ + auth_type = cast(azureConfig.AzureConfigurationSource, connection.authType) + + if not auth_type.azureConfig.scopes: + raise ValueError( + "Azure Scopes are missing, please refer https://learn.microsoft.com/en-gb/azure/mysql/flexible-server/how-to-azure-ad#2---retrieve-microsoft-entra-access-token and fetch the resource associated with it, for e.g. https://ossrdbms-aad.database.windows.net/.default" + ) + + azure_client = AzureClient(auth_type.azureConfig).create_client() + + return azure_client.get_token(*auth_type.azureConfig.scopes.split(",")).token diff --git a/ingestion/src/metadata/ingestion/source/database/trino/service_spec.py b/ingestion/src/metadata/ingestion/source/database/trino/service_spec.py index f4c8b506154..b81cb821cd9 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino/service_spec.py +++ b/ingestion/src/metadata/ingestion/source/database/trino/service_spec.py @@ -1,3 +1,4 @@ +from metadata.ingestion.source.database.trino.connection import TrinoConnection from metadata.ingestion.source.database.trino.lineage import TrinoLineageSource from metadata.ingestion.source.database.trino.metadata import TrinoSource from metadata.ingestion.source.database.trino.usage import TrinoUsageSource @@ -13,4 +14,5 @@ ServiceSpec = DefaultDatabaseSpec( usage_source_class=TrinoUsageSource, profiler_class=TrinoProfilerInterface, sampler_class=TrinoSampler, + connection_class=TrinoConnection, ) diff --git a/ingestion/tests/unit/metadata/data_quality/test_table_diff_param_setter.py b/ingestion/tests/unit/metadata/data_quality/test_table_diff_param_setter.py index ceb81836060..284dfd0396f 100644 --- a/ingestion/tests/unit/metadata/data_quality/test_table_diff_param_setter.py +++ b/ingestion/tests/unit/metadata/data_quality/test_table_diff_param_setter.py @@ -6,6 +6,9 @@ from sqlalchemy import Column as SAColumn from sqlalchemy import MetaData, String, create_engine from sqlalchemy.orm import declarative_base +from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import ( + BaseTableParameter, +) from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import ( TableDiffParamsSetter, ) @@ -111,9 +114,9 @@ SERVICE_CONNECTION_CONFIG = MysqlConnection( ], ) def test_get_data_diff_url(input, expected): - assert expected == TableDiffParamsSetter( - None, None, MOCK_TABLE, None - ).get_data_diff_url(input, "service.database.schema.table") + assert expected == BaseTableParameter.get_data_diff_url( + input, "service.database.schema.table" + ) @pytest.mark.parametrize( diff --git a/ingestion/tests/unit/metadata/ingestion/connections/test_connection.py b/ingestion/tests/unit/metadata/ingestion/connections/test_connection.py index 4a5ec8c6387..cbc3f6f3511 100644 --- a/ingestion/tests/unit/metadata/ingestion/connections/test_connection.py +++ b/ingestion/tests/unit/metadata/ingestion/connections/test_connection.py @@ -44,7 +44,7 @@ def test_connection(mock_service_connection): class TestConnection(BaseConnection): """Concrete implementation of BaseConnection for testing""" - def get_client(self): + def _get_client(self): return MagicMock() def test_connection( @@ -64,6 +64,9 @@ def test_connection(mock_service_connection): lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)), ) + def get_connection_dict(self): + return {} + return TestConnection(mock_service_connection) @@ -100,7 +103,7 @@ class TestBaseConnection: mock_client = MagicMock() class TestConnectionWithMockClient(BaseConnection): - def get_client(self): + def _get_client(self): return mock_client def test_connection( @@ -120,6 +123,9 @@ class TestBaseConnection: lastUpdatedAt=Timestamp(int(datetime.now().timestamp() * 1000)), ) + def get_connection_dict(self): + return {} + connection = TestConnectionWithMockClient(test_connection.service_connection) - client = connection.get_client() + client = connection.client assert client == mock_client diff --git a/ingestion/tests/unit/test_build_connection_url.py b/ingestion/tests/unit/test_build_connection_url.py index ae35d745554..3c5516c3b2b 100644 --- a/ingestion/tests/unit/test_build_connection_url.py +++ b/ingestion/tests/unit/test_build_connection_url.py @@ -70,7 +70,7 @@ class TestGetConnectionURL(unittest.TestCase): hostPort="localhost:3306", databaseSchema="openmetadata_db", ) - engine_connection = MySQLConnection(connection).get_client() + engine_connection = MySQLConnection(connection).client self.assertEqual( str(engine_connection.url), "mysql+pymysql://openmetadata_user:openmetadata_password@localhost:3306/openmetadata_db", @@ -93,7 +93,7 @@ class TestGetConnectionURL(unittest.TestCase): "get_token", return_value=AccessToken(token="mocked_token", expires_on=100), ): - engine_connection = MySQLConnection(connection).get_client() + engine_connection = MySQLConnection(connection).client self.assertEqual( str(engine_connection.url), "mysql+pymysql://openmetadata_user:mocked_token@localhost:3306/openmetadata_db", diff --git a/ingestion/tests/unit/test_source_connection.py b/ingestion/tests/unit/test_source_connection.py index ff6b6a9414f..a58652772a8 100644 --- a/ingestion/tests/unit/test_source_connection.py +++ b/ingestion/tests/unit/test_source_connection.py @@ -100,7 +100,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon SnowflakeScheme, ) from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( - TrinoConnection, + TrinoConnection as TrinoConnectionConfig, +) +from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( TrinoScheme, ) from metadata.generated.schema.entity.services.connections.database.verticaConnection import ( @@ -112,7 +114,7 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_url_common, ) -from metadata.ingestion.source.database.trino.connection import get_connection_args +from metadata.ingestion.source.database.trino.connection import TrinoConnection # pylint: disable=import-outside-toplevel @@ -405,44 +407,38 @@ class SourceConnectionTest(TestCase): assert expected_result == get_connection_url(impala_conn_obj) def test_trino_url_without_params(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - expected_url = "trino://username@localhost:443/catalog" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", authType=BasicAuth(password="pass"), catalog="catalog", ) + trino_connection = TrinoConnection(trino_conn_obj) - assert expected_url == get_connection_url(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) # Passing @ in username and password expected_url = "trino://username%40444@localhost:443/catalog" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username@444", authType=BasicAuth(password="pass@111"), catalog="catalog", ) + trino_connection = TrinoConnection(trino_conn_obj) - assert expected_url == get_connection_url(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) def test_trino_conn_arguments(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_args, - ) - # connection arguments without connectionArguments and without proxies expected_args = { "auth": BasicAuthentication("user", None), "http_scheme": "https", } - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( username="user", authType=BasicAuth(password=None), hostPort="localhost:443", @@ -450,7 +446,10 @@ class SourceConnectionTest(TestCase): connectionArguments=None, scheme=TrinoScheme.trino, ) - assert expected_args == get_connection_args(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + assert ( + expected_args == trino_connection.build_connection_args(trino_conn_obj).root + ) # connection arguments with connectionArguments and without proxies expected_args = { @@ -458,7 +457,7 @@ class SourceConnectionTest(TestCase): "auth": BasicAuthentication("user", None), "http_scheme": "https", } - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( username="user", authType=BasicAuth(password=None), hostPort="localhost:443", @@ -466,14 +465,17 @@ class SourceConnectionTest(TestCase): connectionArguments={"user": "user-to-be-impersonated"}, scheme=TrinoScheme.trino, ) - assert expected_args == get_connection_args(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + assert ( + expected_args == trino_connection.build_connection_args(trino_conn_obj).root + ) # connection arguments without connectionArguments and with proxies expected_args = { "auth": BasicAuthentication("user", None), "http_scheme": "https", } - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( username="user", authType=BasicAuth(password=None), hostPort="localhost:443", @@ -482,7 +484,9 @@ class SourceConnectionTest(TestCase): proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"}, scheme=TrinoScheme.trino, ) - conn_args = get_connection_args(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + conn_args = trino_connection.build_connection_args(trino_conn_obj).root + assert "http_session" in conn_args conn_args.pop("http_session") assert expected_args == conn_args @@ -493,7 +497,7 @@ class SourceConnectionTest(TestCase): "auth": BasicAuthentication("user", None), "http_scheme": "https", } - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( username="user", authType=BasicAuth(password=None), hostPort="localhost:443", @@ -502,18 +506,15 @@ class SourceConnectionTest(TestCase): proxies={"http": "foo.bar:3128", "http://host.name": "foo.bar:4012"}, scheme=TrinoScheme.trino, ) - conn_args = get_connection_args(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + conn_args = trino_connection.build_connection_args(trino_conn_obj).root assert "http_session" in conn_args conn_args.pop("http_session") assert expected_args == conn_args def test_trino_url_with_params(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - expected_url = "trino://username@localhost:443/catalog?param=value" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", @@ -521,35 +522,31 @@ class SourceConnectionTest(TestCase): catalog="catalog", connectionOptions={"param": "value"}, ) - assert expected_url == get_connection_url(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) def test_trino_url_with_jwt_auth(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - expected_url = "trino://username@localhost:443/catalog" expected_args = { "auth": JWTAuthentication("jwt_token_value"), "http_scheme": "https", } - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", authType=JwtAuth(jwt="jwt_token_value"), catalog="catalog", ) - assert expected_url == get_connection_url(trino_conn_obj) - assert expected_args == get_connection_args(trino_conn_obj) - - def test_trino_with_proxies(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_args, + trino_connection = TrinoConnection(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) + assert ( + expected_args == trino_connection.build_connection_args(trino_conn_obj).root ) + def test_trino_with_proxies(self): test_proxies = {"http": "http_proxy", "https": "https_proxy"} - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", @@ -557,59 +554,54 @@ class SourceConnectionTest(TestCase): catalog="catalog", proxies=test_proxies, ) + trino_connection = TrinoConnection(trino_conn_obj) assert ( test_proxies - == get_connection_args(trino_conn_obj).get("http_session").proxies + == trino_connection.build_connection_args(trino_conn_obj) + .root.get("http_session") + .proxies ) def test_trino_without_catalog(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - # Test trino url without catalog expected_url = "trino://username@localhost:443" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", authType=BasicAuth(password="pass"), ) - assert expected_url == get_connection_url(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) def test_trino_without_catalog(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - # Test trino url without catalog expected_url = "trino://username@localhost:443" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", authType=BasicAuth(password="pass"), ) - assert expected_url == get_connection_url(trino_conn_obj) + trino_connection = TrinoConnection(trino_conn_obj) + assert expected_url == str(trino_connection.client.url) def test_trino_with_oauth2(self): - from metadata.ingestion.source.database.trino.connection import ( - get_connection_url, - ) - # Test trino url without catalog expected_url = "trino://username@localhost:443" - trino_conn_obj = TrinoConnection( + trino_conn_obj = TrinoConnectionConfig( scheme=TrinoScheme.trino, hostPort="localhost:443", username="username", authType=noConfigAuthenticationTypes.NoConfigAuthenticationTypes.OAuth2, ) - assert isinstance( - get_connection_args(trino_conn_obj).get("auth"), OAuth2Authentication + trino_connection = TrinoConnection(trino_conn_obj) + assert ( + trino_connection.build_connection_args(trino_conn_obj).root.get("auth") + == OAuth2Authentication() ) def test_vertica_url(self):