diff --git a/ingestion/src/metadata/data_quality/interface/test_suite_interface.py b/ingestion/src/metadata/data_quality/interface/test_suite_interface.py
index f5bfc9f1df6..9e98cf66a1a 100644
--- a/ingestion/src/metadata/data_quality/interface/test_suite_interface.py
+++ b/ingestion/src/metadata/data_quality/interface/test_suite_interface.py
@@ -46,8 +46,8 @@ class TestSuiteInterface(ABC):
     @abstractmethod
     def __init__(
         self,
-        ometa_client: OpenMetadata,
         service_connection_config: DatabaseConnection,
+        ometa_client: OpenMetadata,
         table_entity: Table,
     ):
         """Required attribute for the interface"""
diff --git a/ingestion/src/metadata/data_quality/interface/test_suite_interface_factory.py b/ingestion/src/metadata/data_quality/interface/test_suite_interface_factory.py
index da73b780019..15b76cbc479 100644
--- a/ingestion/src/metadata/data_quality/interface/test_suite_interface_factory.py
+++ b/ingestion/src/metadata/data_quality/interface/test_suite_interface_factory.py
@@ -8,27 +8,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=import-outside-toplevel
 """
 Interface factory
 """
 import traceback
 from logging import Logger
+from typing import Callable, Dict, Type
 
-from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
-    PandasTestSuiteInterface,
-)
-from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
-    DatabricksTestSuiteInterface,
-)
-from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
-    SnowflakeTestSuiteInterface,
-)
-from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
-    SQATestSuiteInterface,
-)
-from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
-    UnityCatalogTestSuiteInterface,
-)
 from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
 from metadata.generated.schema.entity.data.table import Table
 from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
@@ -55,19 +42,18 @@ class TestSuiteInterfaceFactory:
 
     def __init__(self):
         """Initialize the interface factory"""
-        self._interface_type = {
-            "base": SQATestSuiteInterface,
-            DatalakeConnection.__name__: PandasTestSuiteInterface,
+        self._interface_type: Dict[str, Callable[[], Type[TestSuiteInterface]]] = {
+            "base": self.sqa,
         }
 
-    def register(self, interface_type: str, interface: TestSuiteInterface):
+    def register(self, interface_type: str, fn: Callable[[], Type[TestSuiteInterface]]):
         """Register the interface
 
         Args:
             interface_type (str): type of the interface
-            interface (TestSuiteInterface): a class that implements the TestSuiteInterface
+            interface (callable): a class that implements the TestSuiteInterface
         """
-        self._interface_type[interface_type] = interface
+        self._interface_type[interface_type] = fn
 
     def register_many(self, interface_dict):
         """
@@ -77,8 +63,8 @@ class TestSuiteInterfaceFactory:
             interface_dict: A dictionary mapping connection class names (strings)
                 to their corresponding profiler interface classes.
         """
-        for interface_type, interface_class in interface_dict.items():
-            self.register(interface_type, interface_class)
+        for interface_type, interface_fn in interface_dict.items():
+            self.register(interface_type, interface_fn)
 
     def create(
         self,
@@ -104,25 +90,69 @@ class TestSuiteInterfaceFactory:
         except AttributeError as err:
             logger.debug(traceback.format_exc())
             raise AttributeError(f"Could not instantiate interface class: {err}")
-        interface = self._interface_type.get(connection_type)
+        interface_fn = self._interface_type.get(connection_type)
 
-        if not interface:
-            interface = self._interface_type["base"]
+        if not interface_fn:
+            interface_fn = self._interface_type["base"]
 
-        return interface(
+        interface_class = interface_fn()
+        return interface_class(
             service_connection_config, ometa_client, table_entity, *args, **kwargs
         )
 
+    @staticmethod
+    def sqa() -> Type[TestSuiteInterface]:
+        """Lazy load the SQATestSuiteInterface"""
+        from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
+            SQATestSuiteInterface,
+        )
 
-test_suite_interface_factory = TestSuiteInterfaceFactory()
+        return SQATestSuiteInterface
+
+    @staticmethod
+    def pandas() -> Type[TestSuiteInterface]:
+        """Lazy load the PandasTestSuiteInterface"""
+        from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
+            PandasTestSuiteInterface,
+        )
+
+        return PandasTestSuiteInterface
+
+    @staticmethod
+    def snowflake() -> Type[TestSuiteInterface]:
+        """Lazy load the SnowflakeTestSuiteInterface"""
+        from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
+            SnowflakeTestSuiteInterface,
+        )
+
+        return SnowflakeTestSuiteInterface
+
+    @staticmethod
+    def unity_catalog() -> Type[TestSuiteInterface]:
+        """Lazy load the UnityCatalogTestSuiteInterface"""
+        from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
+            UnityCatalogTestSuiteInterface,
+        )
+
+        return UnityCatalogTestSuiteInterface
+
+    @staticmethod
+    def databricks() -> Type[TestSuiteInterface]:
+        """Lazy load the DatabricksTestSuiteInterface"""
+        from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
+            DatabricksTestSuiteInterface,
+        )
+
+        return DatabricksTestSuiteInterface
 
 test_suite_interface = {
-    DatabaseConnection.__name__: SQATestSuiteInterface,
-    DatalakeConnection.__name__: PandasTestSuiteInterface,
-    SnowflakeConnection.__name__: SnowflakeTestSuiteInterface,
-    UnityCatalogConnection.__name__: UnityCatalogTestSuiteInterface,
-    DatabricksConnection.__name__: DatabricksTestSuiteInterface,
+    DatabaseConnection.__name__: TestSuiteInterfaceFactory.sqa,
+    DatalakeConnection.__name__: TestSuiteInterfaceFactory.pandas,
+    SnowflakeConnection.__name__: TestSuiteInterfaceFactory.snowflake,
+    UnityCatalogConnection.__name__: TestSuiteInterfaceFactory.unity_catalog,
+    DatabricksConnection.__name__: TestSuiteInterfaceFactory.databricks,
 }
+test_suite_interface_factory = TestSuiteInterfaceFactory()
 test_suite_interface_factory.register_many(test_suite_interface)
diff --git a/ingestion/src/metadata/profiler/source/profiler_source_factory.py b/ingestion/src/metadata/profiler/source/profiler_source_factory.py
index 0e616354e8d..e24adff127b 100644
--- a/ingestion/src/metadata/profiler/source/profiler_source_factory.py
+++ b/ingestion/src/metadata/profiler/source/profiler_source_factory.py
@@ -8,47 +8,80 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+# pylint: disable=import-outside-toplevel
 """
 Factory class for creating profiler source objects
 """
+from typing import Callable, Dict, Type
+
 from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
     BigqueryType,
 )
 from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
     DatabricksType,
 )
-from metadata.profiler.source.base.profiler_source import ProfilerSource
-from metadata.profiler.source.bigquery.profiler_source import BigQueryProfilerSource
-from metadata.profiler.source.databricks.profiler_source import DataBricksProfilerSource
+from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
 
 
 class ProfilerSourceFactory:
     """Creational factory for profiler source objects"""
 
     def __init__(self):
-        self._source_type = {"base": ProfilerSource}
+        self._source_type: Dict[str, Callable[[], Type[ProfilerSourceInterface]]] = {
+            "base": self.base
+        }
 
-    def register_source(self, source_type: str, source_class):
+    def register_source(self, type_: str, source_fn):
         """Register a new source type"""
-        self._source_type[source_type] = source_class
+        self._source_type[type_] = source_fn
 
-    def create(self, source_type: str, *args, **kwargs) -> ProfilerSource:
+    def register_many_sources(
+        self, source_dict: Dict[str, Callable[[], Type[ProfilerSourceInterface]]]
+    ):
+        """Register multiple source types at once"""
+        for type_, source_fn in source_dict.items():
+            self.register_source(type_, source_fn)
+
+    def create(self, type_: str, *args, **kwargs) -> ProfilerSourceInterface:
         """Create source object based on source type"""
-        source_class = self._source_type.get(source_type)
-        if not source_class:
-            source_class = self._source_type["base"]
-        return source_class(*args, **kwargs)
+        source_fn = self._source_type.get(type_)
+        if not source_fn:
+            source_fn = self._source_type["base"]
+
+        source_class = source_fn()
+        return source_class(*args, **kwargs)
 
+    @staticmethod
+    def base() -> Type[ProfilerSourceInterface]:
+        """Lazy loading of the base source"""
+        from metadata.profiler.source.base.profiler_source import ProfilerSource
+
+        return ProfilerSource
+
+    @staticmethod
+    def bigquery() -> Type[ProfilerSourceInterface]:
+        """Lazy loading of the BigQuery source"""
+        from metadata.profiler.source.bigquery.profiler_source import (
+            BigQueryProfilerSource,
+        )
+
+        return BigQueryProfilerSource
+
+    @staticmethod
+    def databricks() -> Type[ProfilerSourceInterface]:
+        """Lazy loading of the Databricks source"""
+        from metadata.profiler.source.databricks.profiler_source import (
+            DataBricksProfilerSource,
+        )
+
+        return DataBricksProfilerSource
+
+
+source = {
+    BigqueryType.BigQuery.value.lower(): ProfilerSourceFactory.bigquery,
+    DatabricksType.Databricks.value.lower(): ProfilerSourceFactory.databricks,
+}
 
 profiler_source_factory = ProfilerSourceFactory()
-profiler_source_factory.register_source(
-    BigqueryType.BigQuery.value.lower(),
-    BigQueryProfilerSource,
-)
-profiler_source_factory.register_source(
-    DatabricksType.Databricks.value.lower(),
-    DataBricksProfilerSource,
-)
+profiler_source_factory.register_many_sources(source)
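
Usage sketch (not part of the patch): with this change, both factories store zero-argument callables and resolve them inside create(), so importing the factory modules no longer imports the SQLAlchemy, pandas, or warehouse-specific implementations. The snippet below shows how downstream code could register against the new callable-based API; my_plugin.test_suite_interface, MyTestSuiteInterface, and the "MyConnection" key are hypothetical placeholders used only for illustration.

    # Hypothetical plugin registration against the lazy-loading factory above.
    from typing import Type

    from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
    from metadata.data_quality.interface.test_suite_interface_factory import (
        test_suite_interface_factory,
    )


    def my_test_suite_interface() -> Type[TestSuiteInterface]:
        """Import the interface class only when the factory actually needs it."""
        # Hypothetical module/class; a real plugin would import its own implementation.
        from my_plugin.test_suite_interface import MyTestSuiteInterface

        return MyTestSuiteInterface


    # Register the zero-argument callable, not the class itself: the heavy import
    # runs only when create() resolves this connection type and calls the callable.
    test_suite_interface_factory.register("MyConnection", my_test_suite_interface)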