diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py index ddbdc7c59dc..8335ceb1c60 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py @@ -43,8 +43,10 @@ class ServiceSpecPatch: def get_data_diff_class(self) -> Type["BaseTableParameter"]: return import_from_module(self.service_spec.data_diff) - def get_connection_class(self) -> Type[BaseConnection]: - return import_from_module(self.service_spec.connection_class) + def get_connection_class(self) -> Optional[Type[BaseConnection]]: + if self.service_spec.connection_class: + return import_from_module(self.service_spec.connection_class) + return None class BaseTableParameter: @@ -99,18 +101,23 @@ class BaseTableParameter: @staticmethod def _get_service_connection_config( - db_service: DatabaseService, + service_connection_config, ) -> Optional[Union[str, dict]]: """ Get the connection dictionary for the service. """ - service_connection_config = db_service.connection.config service_spec_patch = ServiceSpecPatch( ServiceType.Database, service_connection_config.type.value.lower() ) try: connection_class = service_spec_patch.get_connection_class() + if not connection_class: + return ( + str(get_connection(service_connection_config).url) + if service_connection_config + else None + ) connection = connection_class(service_connection_config) return connection.get_connection_dict() except (ValueError, AttributeError, NotImplementedError): @@ -137,7 +144,9 @@ class BaseTableParameter: str: The url for the data diff service """ source_url = ( - BaseTableParameter._get_service_connection_config(db_service) + BaseTableParameter._get_service_connection_config( + db_service.connection.config + ) if not override_url else override_url ) diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index bd4e5fae743..403f7ceee79 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -58,7 +58,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter): DatabaseService, self.table_entity.service.id, nullable=False ) - service1_url = BaseTableParameter._get_service_connection_config(service1) + service1_url = BaseTableParameter._get_service_connection_config( + self.service_connection_config + ) table2_fqn = self.get_parameter(test_case, "table2") if table2_fqn is None: diff --git a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py index 9bf32d37fa2..25db52514d3 100644 --- a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py @@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.common.basic BasicAuth, ) from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( - MysqlConnection, + MysqlConnection as MySQLConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, @@ -49,7 +49,7 @@ from metadata.ingestion.source.database.mysql.queries import ( from metadata.utils.constants import THREE_MIN -class MySQLConnection(BaseConnection[MysqlConnection, Engine]): +class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]): def _get_client(self) -> Engine: """ Return the SQLAlchemy Engine for MySQL.