mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-03 12:08:31 +00:00
fix: lazy load classes from factory method (#18321)
This commit is contained in:
parent
29d6e26dab
commit
dcf71aa0ea
@ -46,8 +46,8 @@ class TestSuiteInterface(ABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
ometa_client: OpenMetadata,
|
||||
service_connection_config: DatabaseConnection,
|
||||
ometa_client: OpenMetadata,
|
||||
table_entity: Table,
|
||||
):
|
||||
"""Required attribute for the interface"""
|
||||
|
||||
@ -8,27 +8,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# pylint: disable=import-outside-toplevel
|
||||
"""
|
||||
Interface factory
|
||||
"""
|
||||
import traceback
|
||||
from logging import Logger
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
|
||||
PandasTestSuiteInterface,
|
||||
)
|
||||
from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
|
||||
DatabricksTestSuiteInterface,
|
||||
)
|
||||
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
|
||||
SnowflakeTestSuiteInterface,
|
||||
)
|
||||
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
|
||||
SQATestSuiteInterface,
|
||||
)
|
||||
from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
|
||||
UnityCatalogTestSuiteInterface,
|
||||
)
|
||||
from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
|
||||
from metadata.generated.schema.entity.data.table import Table
|
||||
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
|
||||
@ -55,19 +42,18 @@ class TestSuiteInterfaceFactory:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the interface factory"""
|
||||
self._interface_type = {
|
||||
"base": SQATestSuiteInterface,
|
||||
DatalakeConnection.__name__: PandasTestSuiteInterface,
|
||||
self._interface_type: Dict[str, Callable[[], Type[TestSuiteInterface]]] = {
|
||||
"base": self.sqa,
|
||||
}
|
||||
|
||||
def register(self, interface_type: str, interface: TestSuiteInterface):
|
||||
def register(self, interface_type: str, fn: Callable[[], Type[TestSuiteInterface]]):
|
||||
"""Register the interface
|
||||
|
||||
Args:
|
||||
interface_type (str): type of the interface
|
||||
interface (TestSuiteInterface): a class that implements the TestSuiteInterface
|
||||
interface (callable): a class that implements the TestSuiteInterface
|
||||
"""
|
||||
self._interface_type[interface_type] = interface
|
||||
self._interface_type[interface_type] = fn
|
||||
|
||||
def register_many(self, interface_dict):
|
||||
"""
|
||||
@ -77,8 +63,8 @@ class TestSuiteInterfaceFactory:
|
||||
interface_dict: A dictionary mapping connection class names (strings) to their
|
||||
corresponding profiler interface classes.
|
||||
"""
|
||||
for interface_type, interface_class in interface_dict.items():
|
||||
self.register(interface_type, interface_class)
|
||||
for interface_type, interface_fn in interface_dict.items():
|
||||
self.register(interface_type, interface_fn)
|
||||
|
||||
def create(
|
||||
self,
|
||||
@ -104,25 +90,69 @@ class TestSuiteInterfaceFactory:
|
||||
except AttributeError as err:
|
||||
logger.debug(traceback.format_exc())
|
||||
raise AttributeError(f"Could not instantiate interface class: {err}")
|
||||
interface = self._interface_type.get(connection_type)
|
||||
interface_fn = self._interface_type.get(connection_type)
|
||||
|
||||
if not interface:
|
||||
interface = self._interface_type["base"]
|
||||
if not interface_fn:
|
||||
interface_fn = self._interface_type["base"]
|
||||
|
||||
return interface(
|
||||
interface_class = interface_fn()
|
||||
return interface_class(
|
||||
service_connection_config, ometa_client, table_entity, *args, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sqa() -> Type[TestSuiteInterface]:
|
||||
"""Lazy load the SQATestSuiteInterface"""
|
||||
from metadata.data_quality.interface.sqlalchemy.sqa_test_suite_interface import (
|
||||
SQATestSuiteInterface,
|
||||
)
|
||||
|
||||
test_suite_interface_factory = TestSuiteInterfaceFactory()
|
||||
return SQATestSuiteInterface
|
||||
|
||||
@staticmethod
|
||||
def pandas() -> Type[TestSuiteInterface]:
|
||||
"""Lazy load the PandasTestSuiteInterface"""
|
||||
from metadata.data_quality.interface.pandas.pandas_test_suite_interface import (
|
||||
PandasTestSuiteInterface,
|
||||
)
|
||||
|
||||
return PandasTestSuiteInterface
|
||||
|
||||
@staticmethod
|
||||
def snowflake() -> Type[TestSuiteInterface]:
|
||||
"""Lazy load the SnowflakeTestSuiteInterface"""
|
||||
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
|
||||
SnowflakeTestSuiteInterface,
|
||||
)
|
||||
|
||||
return SnowflakeTestSuiteInterface
|
||||
|
||||
@staticmethod
|
||||
def unity_catalog() -> Type[TestSuiteInterface]:
|
||||
"""Lazy load the UnityCatalogTestSuiteInterface"""
|
||||
from metadata.data_quality.interface.sqlalchemy.unity_catalog.test_suite_interface import (
|
||||
UnityCatalogTestSuiteInterface,
|
||||
)
|
||||
|
||||
return UnityCatalogTestSuiteInterface
|
||||
|
||||
@staticmethod
|
||||
def databricks() -> Type[TestSuiteInterface]:
|
||||
"""Lazy load the DatabricksTestSuiteInterface"""
|
||||
from metadata.data_quality.interface.sqlalchemy.databricks.test_suite_interface import (
|
||||
DatabricksTestSuiteInterface,
|
||||
)
|
||||
|
||||
return DatabricksTestSuiteInterface
|
||||
|
||||
|
||||
test_suite_interface = {
|
||||
DatabaseConnection.__name__: SQATestSuiteInterface,
|
||||
DatalakeConnection.__name__: PandasTestSuiteInterface,
|
||||
SnowflakeConnection.__name__: SnowflakeTestSuiteInterface,
|
||||
UnityCatalogConnection.__name__: UnityCatalogTestSuiteInterface,
|
||||
DatabricksConnection.__name__: DatabricksTestSuiteInterface,
|
||||
DatabaseConnection.__name__: TestSuiteInterfaceFactory.sqa,
|
||||
DatalakeConnection.__name__: TestSuiteInterfaceFactory.pandas,
|
||||
SnowflakeConnection.__name__: TestSuiteInterfaceFactory.snowflake,
|
||||
UnityCatalogConnection.__name__: TestSuiteInterfaceFactory.unity_catalog,
|
||||
DatabricksConnection.__name__: TestSuiteInterfaceFactory.databricks,
|
||||
}
|
||||
|
||||
test_suite_interface_factory = TestSuiteInterfaceFactory()
|
||||
test_suite_interface_factory.register_many(test_suite_interface)
|
||||
|
||||
@ -8,47 +8,80 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
"""
|
||||
Factory class for creating profiler source objects
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
|
||||
BigqueryType,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
|
||||
DatabricksType,
|
||||
)
|
||||
from metadata.profiler.source.base.profiler_source import ProfilerSource
|
||||
from metadata.profiler.source.bigquery.profiler_source import BigQueryProfilerSource
|
||||
from metadata.profiler.source.databricks.profiler_source import DataBricksProfilerSource
|
||||
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
|
||||
|
||||
|
||||
class ProfilerSourceFactory:
|
||||
"""Creational factory for profiler source objects"""
|
||||
|
||||
def __init__(self):
|
||||
self._source_type = {"base": ProfilerSource}
|
||||
self._source_type: Dict[str, Callable[[], Type[ProfilerSourceInterface]]] = {
|
||||
"base": self.base
|
||||
}
|
||||
|
||||
def register_source(self, source_type: str, source_class):
|
||||
def register_source(self, type_: str, source_fn):
|
||||
"""Register a new source type"""
|
||||
self._source_type[source_type] = source_class
|
||||
self._source_type[type_] = source_fn
|
||||
|
||||
def create(self, source_type: str, *args, **kwargs) -> ProfilerSource:
|
||||
def register_many_sources(
|
||||
self, source_dict: Dict[str, Callable[[], Type[ProfilerSourceInterface]]]
|
||||
):
|
||||
"""Register multiple source types at once"""
|
||||
for type_, source_fn in source_dict.items():
|
||||
self.register_source(type_, source_fn)
|
||||
|
||||
def create(self, type_: str, *args, **kwargs) -> ProfilerSourceInterface:
|
||||
"""Create source object based on source type"""
|
||||
source_class = self._source_type.get(source_type)
|
||||
if not source_class:
|
||||
source_class = self._source_type["base"]
|
||||
return source_class(*args, **kwargs)
|
||||
source_fn = self._source_type.get(type_)
|
||||
if not source_fn:
|
||||
source_fn = self._source_type["base"]
|
||||
|
||||
source_class = source_fn()
|
||||
return source_class(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def base() -> Type[ProfilerSourceInterface]:
|
||||
"""Lazy loading of the base source"""
|
||||
from metadata.profiler.source.base.profiler_source import ProfilerSource
|
||||
|
||||
return ProfilerSource
|
||||
|
||||
@staticmethod
|
||||
def bigquery() -> Type[ProfilerSourceInterface]:
|
||||
"""Lazy loading of the BigQuery source"""
|
||||
from metadata.profiler.source.bigquery.profiler_source import (
|
||||
BigQueryProfilerSource,
|
||||
)
|
||||
|
||||
return BigQueryProfilerSource
|
||||
|
||||
@staticmethod
|
||||
def databricks() -> Type[ProfilerSourceInterface]:
|
||||
"""Lazy loading of the Databricks source"""
|
||||
from metadata.profiler.source.databricks.profiler_source import (
|
||||
DataBricksProfilerSource,
|
||||
)
|
||||
|
||||
return DataBricksProfilerSource
|
||||
|
||||
|
||||
source = {
|
||||
BigqueryType.BigQuery.value.lower(): ProfilerSourceFactory.bigquery,
|
||||
DatabricksType.Databricks.value.lower(): ProfilerSourceFactory.databricks,
|
||||
}
|
||||
|
||||
profiler_source_factory = ProfilerSourceFactory()
|
||||
profiler_source_factory.register_source(
|
||||
BigqueryType.BigQuery.value.lower(),
|
||||
BigQueryProfilerSource,
|
||||
)
|
||||
profiler_source_factory.register_source(
|
||||
DatabricksType.Databricks.value.lower(),
|
||||
DataBricksProfilerSource,
|
||||
)
|
||||
profiler_source_factory.register_many_sources(source)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user