fix: lazy load classes from factory method (#18321)

This commit is contained in:
Teddy 2024-10-21 11:29:03 +02:00 committed by GitHub
parent 29d6e26dab
commit dcf71aa0ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 117 additions and 54 deletions

View File

@ -46,8 +46,8 @@ class TestSuiteInterface(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
ometa_client: OpenMetadata,
service_connection_config: DatabaseConnection, service_connection_config: DatabaseConnection,
ometa_client: OpenMetadata,
table_entity: Table, table_entity: Table,
): ):
"""Required attribute for the interface""" """Required attribute for the interface"""

View File

@ -8,27 +8,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=import-outside-toplevel
""" """
Interface factory Interface factory
""" """
import traceback import traceback
from logging import Logger from logging import Logger
from typing import Callable, Dict, Type
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
PandasTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
DatabricksTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
SnowflakeTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
SQATestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
UnityCatalogTestSuiteInterface,
)
from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
@ -55,19 +42,18 @@ class TestSuiteInterfaceFactory:
def __init__(self): def __init__(self):
"""Initialize the interface factory""" """Initialize the interface factory"""
self._interface_type = { self._interface_type: Dict[str, Callable[[], Type[TestSuiteInterface]]] = {
"base": SQATestSuiteInterface, "base": self.sqa,
DatalakeConnection.__name__: PandasTestSuiteInterface,
} }
def register(self, interface_type: str, interface: TestSuiteInterface): def register(self, interface_type: str, fn: Callable[[], Type[TestSuiteInterface]]):
"""Register the interface """Register the interface
Args: Args:
interface_type (str): type of the interface interface_type (str): type of the interface
interface (TestSuiteInterface): a class that implements the TestSuiteInterface interface (callable): a class that implements the TestSuiteInterface
""" """
self._interface_type[interface_type] = interface self._interface_type[interface_type] = fn
def register_many(self, interface_dict): def register_many(self, interface_dict):
""" """
@ -77,8 +63,8 @@ class TestSuiteInterfaceFactory:
interface_dict: A dictionary mapping connection class names (strings) to their interface_dict: A dictionary mapping connection class names (strings) to their
corresponding profiler interface classes. corresponding profiler interface classes.
""" """
for interface_type, interface_class in interface_dict.items(): for interface_type, interface_fn in interface_dict.items():
self.register(interface_type, interface_class) self.register(interface_type, interface_fn)
def create( def create(
self, self,
@ -104,25 +90,69 @@ class TestSuiteInterfaceFactory:
except AttributeError as err: except AttributeError as err:
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())
raise AttributeError(f"Could not instantiate interface class: {err}") raise AttributeError(f"Could not instantiate interface class: {err}")
interface = self._interface_type.get(connection_type) interface_fn = self._interface_type.get(connection_type)
if not interface: if not interface_fn:
interface = self._interface_type["base"] interface_fn = self._interface_type["base"]
return interface( interface_class = interface_fn()
return interface_class(
service_connection_config, ometa_client, table_entity, *args, **kwargs service_connection_config, ometa_client, table_entity, *args, **kwargs
) )
@staticmethod
def sqa() -> Type[TestSuiteInterface]:
"""Lazy load the SQATestSuiteInterface"""
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
SQATestSuiteInterface,
)
test_suite_interface_factory = TestSuiteInterfaceFactory() return SQATestSuiteInterface
@staticmethod
def pandas() -> Type[TestSuiteInterface]:
"""Lazy load the PandasTestSuiteInterface"""
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
PandasTestSuiteInterface,
)
return PandasTestSuiteInterface
@staticmethod
def snowflake() -> Type[TestSuiteInterface]:
"""Lazy load the SnowflakeTestSuiteInterface"""
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
SnowflakeTestSuiteInterface,
)
return SnowflakeTestSuiteInterface
@staticmethod
def unity_catalog() -> Type[TestSuiteInterface]:
"""Lazy load the UnityCatalogTestSuiteInterface"""
from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
UnityCatalogTestSuiteInterface,
)
return UnityCatalogTestSuiteInterface
@staticmethod
def databricks() -> Type[TestSuiteInterface]:
"""Lazy load the DatabricksTestSuiteInterface"""
from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
DatabricksTestSuiteInterface,
)
return DatabricksTestSuiteInterface
test_suite_interface = { test_suite_interface = {
DatabaseConnection.__name__: SQATestSuiteInterface, DatabaseConnection.__name__: TestSuiteInterfaceFactory.sqa,
DatalakeConnection.__name__: PandasTestSuiteInterface, DatalakeConnection.__name__: TestSuiteInterfaceFactory.pandas,
SnowflakeConnection.__name__: SnowflakeTestSuiteInterface, SnowflakeConnection.__name__: TestSuiteInterfaceFactory.snowflake,
UnityCatalogConnection.__name__: UnityCatalogTestSuiteInterface, UnityCatalogConnection.__name__: TestSuiteInterfaceFactory.unity_catalog,
DatabricksConnection.__name__: DatabricksTestSuiteInterface, DatabricksConnection.__name__: TestSuiteInterfaceFactory.databricks,
} }
test_suite_interface_factory = TestSuiteInterfaceFactory()
test_suite_interface_factory.register_many(test_suite_interface) test_suite_interface_factory.register_many(test_suite_interface)

View File

@ -8,47 +8,80 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=import-outside-toplevel
""" """
Factory class for creating profiler source objects Factory class for creating profiler source objects
""" """
from typing import Callable, Dict, Type
from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import ( from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
BigqueryType, BigqueryType,
) )
from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
DatabricksType, DatabricksType,
) )
from metadata.profiler.source.base.profiler_source import ProfilerSource from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
from metadata.profiler.source.bigquery.profiler_source import BigQueryProfilerSource
from metadata.profiler.source.databricks.profiler_source import DataBricksProfilerSource
class ProfilerSourceFactory: class ProfilerSourceFactory:
"""Creational factory for profiler source objects""" """Creational factory for profiler source objects"""
def __init__(self): def __init__(self):
self._source_type = {"base": ProfilerSource} self._source_type: Dict[str, Callable[[], Type[ProfilerSourceInterface]]] = {
"base": self.base
}
def register_source(self, source_type: str, source_class): def register_source(self, type_: str, source_fn):
"""Register a new source type""" """Register a new source type"""
self._source_type[source_type] = source_class self._source_type[type_] = source_fn
def create(self, source_type: str, *args, **kwargs) -> ProfilerSource: def register_many_sources(
self, source_dict: Dict[str, Callable[[], Type[ProfilerSourceInterface]]]
):
"""Register multiple source types at once"""
for type_, source_fn in source_dict.items():
self.register_source(type_, source_fn)
def create(self, type_: str, *args, **kwargs) -> ProfilerSourceInterface:
"""Create source object based on source type""" """Create source object based on source type"""
source_class = self._source_type.get(source_type) source_fn = self._source_type.get(type_)
if not source_class: if not source_fn:
source_class = self._source_type["base"] source_fn = self._source_type["base"]
return source_class(*args, **kwargs)
source_class = source_fn()
return source_class(*args, **kwargs) return source_class(*args, **kwargs)
@staticmethod
def base() -> Type[ProfilerSourceInterface]:
"""Lazy loading of the base source"""
from metadata.profiler.source.base.profiler_source import ProfilerSource
return ProfilerSource
@staticmethod
def bigquery() -> Type[ProfilerSourceInterface]:
"""Lazy loading of the BigQuery source"""
from metadata.profiler.source.bigquery.profiler_source import (
BigQueryProfilerSource,
)
return BigQueryProfilerSource
@staticmethod
def databricks() -> Type[ProfilerSourceInterface]:
"""Lazy loading of the Databricks source"""
from metadata.profiler.source.databricks.profiler_source import (
DataBricksProfilerSource,
)
return DataBricksProfilerSource
source = {
BigqueryType.BigQuery.value.lower(): ProfilerSourceFactory.bigquery,
DatabricksType.Databricks.value.lower(): ProfilerSourceFactory.databricks,
}
profiler_source_factory = ProfilerSourceFactory() profiler_source_factory = ProfilerSourceFactory()
profiler_source_factory.register_source( profiler_source_factory.register_many_sources(source)
BigqueryType.BigQuery.value.lower(),
BigQueryProfilerSource,
)
profiler_source_factory.register_source(
DatabricksType.Databricks.value.lower(),
DataBricksProfilerSource,
)