fix: lazy load classes from factory method (#18321)

This commit is contained in:
Teddy 2024-10-21 11:29:03 +02:00 committed by GitHub
parent 29d6e26dab
commit dcf71aa0ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 117 additions and 54 deletions

View File

@ -46,8 +46,8 @@ class TestSuiteInterface(ABC):
@abstractmethod
def __init__(
self,
ometa_client: OpenMetadata,
service_connection_config: DatabaseConnection,
ometa_client: OpenMetadata,
table_entity: Table,
):
"""Required attribute for the interface"""

View File

@ -8,27 +8,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
"""
Interface factory
"""
import traceback
from logging import Logger
from typing import Callable, Dict, Type
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
PandasTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
DatabricksTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
SnowflakeTestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
SQATestSuiteInterface,
)
from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
UnityCatalogTestSuiteInterface,
)
from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
@ -55,19 +42,18 @@ class TestSuiteInterfaceFactory:
def __init__(self):
"""Initialize the interface factory"""
self._interface_type = {
"base": SQATestSuiteInterface,
DatalakeConnection.__name__: PandasTestSuiteInterface,
self._interface_type: Dict[str, Callable[[], Type[TestSuiteInterface]]] = {
"base": self.sqa,
}
def register(self, interface_type: str, interface: TestSuiteInterface):
def register(self, interface_type: str, fn: Callable[[], Type[TestSuiteInterface]]):
"""Register the interface
Args:
interface_type (str): type of the interface
interface (TestSuiteInterface): a class that implements the TestSuiteInterface
fn (callable): a zero-argument callable that returns a class implementing the TestSuiteInterface
"""
self._interface_type[interface_type] = interface
self._interface_type[interface_type] = fn
def register_many(self, interface_dict):
"""
@ -77,8 +63,8 @@ class TestSuiteInterfaceFactory:
interface_dict: A dictionary mapping connection class names (strings) to
zero-argument callables that return the corresponding test suite interface classes.
"""
for interface_type, interface_class in interface_dict.items():
self.register(interface_type, interface_class)
for interface_type, interface_fn in interface_dict.items():
self.register(interface_type, interface_fn)
def create(
self,
@ -104,25 +90,69 @@ class TestSuiteInterfaceFactory:
except AttributeError as err:
logger.debug(traceback.format_exc())
raise AttributeError(f"Could not instantiate interface class: {err}")
interface = self._interface_type.get(connection_type)
interface_fn = self._interface_type.get(connection_type)
if not interface:
interface = self._interface_type["base"]
if not interface_fn:
interface_fn = self._interface_type["base"]
return interface(
interface_class = interface_fn()
return interface_class(
service_connection_config, ometa_client, table_entity, *args, **kwargs
)
@staticmethod
def sqa() -> Type[TestSuiteInterface]:
"""Lazy load the SQATestSuiteInterface"""
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
SQATestSuiteInterface,
)
test_suite_interface_factory = TestSuiteInterfaceFactory()
return SQATestSuiteInterface
@staticmethod
def pandas() -> Type[TestSuiteInterface]:
    """Resolve the pandas test suite interface class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The PandasTestSuiteInterface class itself, not an instance.
    """
    import metadata.data_quality.interface.pandas.pandas_test_suite_interface as _module

    return _module.PandasTestSuiteInterface
@staticmethod
def snowflake() -> Type[TestSuiteInterface]:
    """Resolve the Snowflake test suite interface class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The SnowflakeTestSuiteInterface class itself, not an instance.
    """
    import metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface as _module

    return _module.SnowflakeTestSuiteInterface
@staticmethod
def unity_catalog() -> Type[TestSuiteInterface]:
    """Resolve the Unity Catalog test suite interface class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The UnityCatalogTestSuiteInterface class itself, not an instance.
    """
    import metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface as _module

    return _module.UnityCatalogTestSuiteInterface
@staticmethod
def databricks() -> Type[TestSuiteInterface]:
    """Resolve the Databricks test suite interface class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The DatabricksTestSuiteInterface class itself, not an instance.
    """
    import metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface as _module

    return _module.DatabricksTestSuiteInterface
test_suite_interface = {
DatabaseConnection.__name__: SQATestSuiteInterface,
DatalakeConnection.__name__: PandasTestSuiteInterface,
SnowflakeConnection.__name__: SnowflakeTestSuiteInterface,
UnityCatalogConnection.__name__: UnityCatalogTestSuiteInterface,
DatabricksConnection.__name__: DatabricksTestSuiteInterface,
DatabaseConnection.__name__: TestSuiteInterfaceFactory.sqa,
DatalakeConnection.__name__: TestSuiteInterfaceFactory.pandas,
SnowflakeConnection.__name__: TestSuiteInterfaceFactory.snowflake,
UnityCatalogConnection.__name__: TestSuiteInterfaceFactory.unity_catalog,
DatabricksConnection.__name__: TestSuiteInterfaceFactory.databricks,
}
test_suite_interface_factory = TestSuiteInterfaceFactory()
test_suite_interface_factory.register_many(test_suite_interface)

View File

@ -8,47 +8,80 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
"""
Factory class for creating profiler source objects
"""
from typing import Callable, Dict, Type
from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
BigqueryType,
)
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
DatabricksType,
)
from metadata.profiler.source.base.profiler_source import ProfilerSource
from metadata.profiler.source.bigquery.profiler_source import BigQueryProfilerSource
from metadata.profiler.source.databricks.profiler_source import DataBricksProfilerSource
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
class ProfilerSourceFactory:
"""Creational factory for profiler source objects"""
def __init__(self):
self._source_type = {"base": ProfilerSource}
self._source_type: Dict[str, Callable[[], Type[ProfilerSourceInterface]]] = {
"base": self.base
}
def register_source(self, source_type: str, source_class):
def register_source(self, type_: str, source_fn):
"""Register a new source type"""
self._source_type[source_type] = source_class
self._source_type[type_] = source_fn
def create(self, source_type: str, *args, **kwargs) -> ProfilerSource:
def register_many_sources(
    self, source_dict: Dict[str, Callable[[], Type[ProfilerSourceInterface]]]
):
    """Register every entry of a type-to-loader mapping.

    Args:
        source_dict: maps a source type key to a zero-argument callable
            that returns the profiler source class for that type.
    """
    for entry in source_dict.items():
        # Each entry is a (type key, loader) pair, passed on positionally.
        self.register_source(*entry)
def create(self, type_: str, *args, **kwargs) -> ProfilerSourceInterface:
"""Create source object based on source type"""
source_class = self._source_type.get(source_type)
if not source_class:
source_class = self._source_type["base"]
return source_class(*args, **kwargs)
source_fn = self._source_type.get(type_)
if not source_fn:
source_fn = self._source_type["base"]
source_class = source_fn()
return source_class(*args, **kwargs)
@staticmethod
def base() -> Type[ProfilerSourceInterface]:
    """Resolve the base profiler source class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The ProfilerSource class itself, not an instance.
    """
    import metadata.profiler.source.base.profiler_source as _module

    return _module.ProfilerSource
@staticmethod
def bigquery() -> Type[ProfilerSourceInterface]:
    """Resolve the BigQuery profiler source class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The BigQueryProfilerSource class itself, not an instance.
    """
    import metadata.profiler.source.bigquery.profiler_source as _module

    return _module.BigQueryProfilerSource
@staticmethod
def databricks() -> Type[ProfilerSourceInterface]:
    """Resolve the Databricks profiler source class on demand.

    The module import happens at call time rather than when the factory
    module is loaded.

    Returns:
        The DataBricksProfilerSource class itself, not an instance.
    """
    import metadata.profiler.source.databricks.profiler_source as _module

    return _module.DataBricksProfilerSource
source = {
BigqueryType.BigQuery.value.lower(): ProfilerSourceFactory.bigquery,
DatabricksType.Databricks.value.lower(): ProfilerSourceFactory.databricks,
}
profiler_source_factory = ProfilerSourceFactory()
profiler_source_factory.register_source(
BigqueryType.BigQuery.value.lower(),
BigQueryProfilerSource,
)
profiler_source_factory.register_source(
DatabricksType.Databricks.value.lower(),
DataBricksProfilerSource,
)
profiler_source_factory.register_many_sources(source)