diff --git a/ingestion/src/metadata/data_quality/runner/base_test_suite_source.py b/ingestion/src/metadata/data_quality/runner/base_test_suite_source.py index 8d8583e7dad..485a352e5b9 100644 --- a/ingestion/src/metadata/data_quality/runner/base_test_suite_source.py +++ b/ingestion/src/metadata/data_quality/runner/base_test_suite_source.py @@ -112,10 +112,14 @@ class BaseTestSuiteRunner: entity=self.entity, metadata=self.ometa_client ) test_suite_class = import_test_suite_class( - ServiceType.Database, source_type=self._interface_type + ServiceType.Database, + source_type=self._interface_type, + source_config_type=self.service_conn_config.type.value, ) sampler_class = import_sampler_class( - ServiceType.Database, source_type=self._interface_type + ServiceType.Database, + source_type=self._interface_type, + source_config_type=self.service_conn_config.type.value, ) # This is shared between the sampler and DQ interfaces _orm = self._build_table_orm(self.entity) diff --git a/ingestion/src/metadata/utils/service_spec/service_spec.py b/ingestion/src/metadata/utils/service_spec/service_spec.py index d91c2b77704..3f9ca297683 100644 --- a/ingestion/src/metadata/utils/service_spec/service_spec.py +++ b/ingestion/src/metadata/utils/service_spec/service_spec.py @@ -14,6 +14,7 @@ from metadata.profiler.interface.profiler_interface import ProfilerInterface from metadata.sampler.sampler_interface import SamplerInterface from metadata.utils.importer import ( TYPE_SEPARATOR, + DynamicImportException, get_class_path, get_module_dir, import_from_module, @@ -112,14 +113,34 @@ def import_profiler_class( def import_test_suite_class( - service_type: ServiceType, source_type: str + service_type: ServiceType, + source_type: str, + source_config_type: Optional[str] = None, ) -> Type[TestSuiteInterface]: - class_path = BaseSpec.get_for_source(service_type, source_type).test_suite_class + try: + class_path = BaseSpec.get_for_source(service_type, source_type).test_suite_class + except DynamicImportException: + if source_config_type: + class_path = BaseSpec.get_for_source( + service_type, source_config_type.lower() + ).test_suite_class + else: + raise return cast(Type[TestSuiteInterface], import_from_module(class_path)) def import_sampler_class( - service_type: ServiceType, source_type: str + service_type: ServiceType, + source_type: str, + source_config_type: Optional[str] = None, ) -> Type[SamplerInterface]: - class_path = BaseSpec.get_for_source(service_type, source_type).sampler_class + try: + class_path = BaseSpec.get_for_source(service_type, source_type).sampler_class + except DynamicImportException: + if source_config_type: + class_path = BaseSpec.get_for_source( + service_type, source_config_type.lower() + ).sampler_class + else: + raise return cast(Type[SamplerInterface], import_from_module(class_path))