Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2026-01-06 12:36:56 +00:00)
* Clean up test suite workflow and interface
* Fixed tests
* Split profiler and testSuite interfaces
* Cleaned up workflows and runners
* Fixed code formatting
* Removed old code; removed the `table` attribute used for testing and used a mock instead
* Fixed execution bugs from refactor
* Fixed static type checking for profiler/api/workflow.py
* Fixed linting
* Added __init__ files
This commit is contained in:
parent b914afcb6a
commit f883863b8a
0	ingestion/src/metadata/interfaces/__init__.py	Normal file
63	ingestion/src/metadata/interfaces/profiler_protocol.py	Normal file
@@ -0,0 +1,63 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interfaces with database for all database engines
supporting the sqlalchemy abstraction layer
"""

from abc import ABC, abstractmethod
from typing import Dict, Union

from sqlalchemy import Column

from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
    DatalakeConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.orm_profiler.metrics.registry import Metrics


class ProfilerProtocol(ABC):
    """Protocol interface for the profiler processor"""

    @abstractmethod
    def __init__(
        self,
        ometa_client: OpenMetadata,
        service_connection_config: Union[DatabaseConnection, DatalakeConnection],
    ):
        """Required attributes for the interface"""
        raise NotImplementedError

    @property
    @abstractmethod
    def table(self):
        """OM Table entity"""
        raise NotImplementedError

    @abstractmethod
    def get_all_metrics(self, metric_funcs) -> dict:
        """run profiler metrics"""
        raise NotImplementedError

    @abstractmethod
    def get_composed_metrics(
        self, column: Column, metric: Metrics, column_results: Dict
    ) -> dict:
        """compute composed metrics from other metric results"""
        raise NotImplementedError

    @abstractmethod
    def fetch_sample_data(self, table) -> dict:
        """fetch sample data for a table"""
        raise NotImplementedError
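For orientation, here is a minimal sketch of what a concrete implementation of ProfilerProtocol has to provide. The class below is hypothetical and not part of this commit; only the ProfilerProtocol base comes from the file above.

# Hypothetical no-op implementation of ProfilerProtocol, showing the
# required surface: __init__, the `table` property, and the three
# metric/sample methods. Real implementations (e.g. SQAProfilerInterface,
# added later in this commit) do the actual work.
class NoOpProfilerInterface(ProfilerProtocol):
    def __init__(self, ometa_client, service_connection_config):
        self.ometa_client = ometa_client
        self.service_connection_config = service_connection_config
        self._table = None

    @property
    def table(self):
        return self._table

    def get_all_metrics(self, metric_funcs) -> dict:
        return {"table": {}, "columns": {}}

    def get_composed_metrics(self, column, metric, column_results) -> dict:
        return {}

    def fetch_sample_data(self, table) -> dict:
        return {}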
ingestion/src/metadata/interfaces/sqa_interface.py (deleted)
@@ -1,555 +0,0 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interfaces with database for all database engines
supporting the sqlalchemy abstraction layer
"""

import concurrent.futures
import threading
import traceback
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union

from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import DeclarativeMeta, Session

from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
    SnowflakeType,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
    OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
    DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.interface_protocol import InterfaceProtocol
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.converter import ometa_to_orm
from metadata.orm_profiler.profiler.handle_partition import (
    get_partition_cols,
    is_partitioned,
)
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.test_suite.validations.core import validation_enum_registry
from metadata.utils.connections import (
    create_and_bind_thread_safe_session,
    get_connection,
    test_connection,
)
from metadata.utils.constants import TEN_MIN
from metadata.utils.dispatch import enum_register
from metadata.utils.helpers import get_start_and_end
from metadata.utils.logger import sqa_interface_registry_logger
from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY
from metadata.utils.timeout import cls_timeout

logger = sqa_interface_registry_logger()

thread_local = threading.local()


class SQAInterface(InterfaceProtocol):
    """
    Interface to interact with registry supporting
    sqlalchemy.
    """

    def __init__(
        self,
        service_connection_config,
        sqa_metadata_obj: Optional[MetaData] = None,
        metadata_config: Optional[OpenMetadataConnection] = None,
        thread_count: Optional[int] = 5,
        table_entity: Optional[Table] = None,
        table: Optional[DeclarativeMeta] = None,
        profile_sample: Optional[float] = None,
        profile_query: Optional[str] = None,
        partition_config: Optional[TablePartitionConfig] = None,
    ):
        """Instantiate SQA Interface object"""
        self._thread_count = thread_count
        self.table_entity = table_entity
        self._create_ometa_obj(metadata_config)

        self.processor_status = ProfilerProcessorStatus()
        self.processor_status.entity = (
            self.table_entity.fullyQualifiedName.__root__
            if self.table_entity.fullyQualifiedName
            else None
        )

        # Allows SQA Interface to be used without OM server config
        self.table = table or self._convert_table_to_orm_object(sqa_metadata_obj)
        self.service_connection_config = service_connection_config

        self.session_factory = self._session_factory(service_connection_config)
        self.session: Session = self.session_factory()
        self.set_session_tag(self.session)

        self.profile_sample = profile_sample
        self.profile_query = profile_query
        self.partition_details = (
            self._get_partition_details(partition_config)
            if not self.profile_query
            else None
        )

        self._sampler = self.create_sampler()
        self._runner = self.create_runner()

    @property
    def sample(self):
        """Getter method for sample attribute"""
        if not self.sampler:
            raise RuntimeError(
                "You must create a sampler first `<instance>.create_sampler(...)`."
            )

        return self.sampler.random_sample()

    @property
    def runner(self):
        """Getter method for runner attribute"""
        return self._runner

    @property
    def sampler(self):
        """Getter method for sampler attribute"""
        return self._sampler

    @staticmethod
    def _session_factory(service_connection_config):
        """Create thread safe session that will be automatically
        garbage collected once the application thread ends
        """
        engine = get_connection(service_connection_config)
        return create_and_bind_thread_safe_session(engine)

    def set_session_tag(self, session):
        """
        Set session query tag
        Args:
            service_connection_config: connection details for the specific service
        """
        if (
            self.service_connection_config.type.value == SnowflakeType.Snowflake.value
            and hasattr(self.service_connection_config, "queryTag")
            and self.service_connection_config.queryTag
        ):
            session.execute(
                SNOWFLAKE_SESSION_TAG_QUERY.format(
                    query_tag=self.service_connection_config.queryTag
                )
            )

    def _get_engine(self, service_connection_config):
        """Get engine for database

        Args:
            service_connection_config: connection details for the specific service
        Returns:
            sqlalchemy engine
        """
        engine = get_connection(service_connection_config)
        test_connection(engine)

        return engine

    def _get_partition_details(
        self, partition_config: TablePartitionConfig
    ) -> Optional[Dict]:
        """From partition config, get the partition table for a table entity"""
        if self.table_entity.serviceType == DatabaseServiceType.BigQuery:
            if is_partitioned(self.session, self.table):
                start, end = get_start_and_end(partition_config.partitionQueryDuration)
                partition_details = {
                    "partition_field": partition_config.partitionField
                    or get_partition_cols(self.session, self.table),
                    "partition_start": start,
                    "partition_end": end,
                    "partition_values": partition_config.partitionValues,
                }

                return partition_details

        return None

    def _create_ometa_obj(self, metadata_config):
        try:
            self._metadata = OpenMetadata(metadata_config)
            self._metadata.health_check()
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(
                f"No OpenMetadata server configuration found. Running profiler interface without OM server connection: {exc}"
            )
            self._metadata = None

    def _create_thread_safe_sampler(
        self,
        session,
        table,
    ):
        """Create thread safe sampler"""
        if not hasattr(thread_local, "sampler"):
            thread_local.sampler = Sampler(
                session=session,
                table=table,
                profile_sample=self.profile_sample,
                partition_details=self.partition_details,
                profile_sample_query=self.profile_query,
            )
        return thread_local.sampler

    def _create_thread_safe_runner(
        self,
        session,
        table,
        sample,
    ):
        """Create thread safe runner"""
        if not hasattr(thread_local, "runner"):
            thread_local.runner = QueryRunner(
                session=session,
                table=table,
                sample=sample,
                partition_details=self.partition_details,
                profile_sample_query=self.profile_query,
            )
        return thread_local.runner

    def _convert_table_to_orm_object(
        self, sqa_metadata_obj: Optional[MetaData]
    ) -> DeclarativeMeta:
        """Given a table entity return a SQA ORM object

        Args:
            sqa_metadata_obj: sqa metadata registry
        Returns:
            DeclarativeMeta
        """
        return ometa_to_orm(self.table_entity, self._metadata, sqa_metadata_obj)

    def get_columns(self) -> Column:
        """get columns from an orm object"""
        return inspect(self.table).c

    def compute_metrics_in_thread(
        self,
        metric_funcs,
    ):
        """Run metrics in processor worker"""
        (
            metrics,
            metric_type,
            column,
            table,
        ) = metric_funcs
        logger.debug(
            f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
        )
        Session = self.session_factory
        with Session() as session:
            self.set_session_tag(session)
            sampler = self._create_thread_safe_sampler(
                session,
                table,
            )
            sample = sampler.random_sample()
            runner = self._create_thread_safe_runner(
                session,
                table,
                sample,
            )

            row = compute_metrics_registry.registry[metric_type.value](
                metrics,
                runner=runner,
                session=session,
                column=column,
                sample=sample,
                processor_status=self.processor_status,
            )

            if column is not None:
                column = column.name

            return row, column

    def get_all_metrics(
        self,
        metric_funcs: list,
    ):
        """get all profiler metrics"""
        logger.info(f"Computing metrics with {self._thread_count} threads.")
        profile_results = {"table": dict(), "columns": defaultdict(dict)}
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._thread_count
        ) as executor:
            futures = [
                executor.submit(
                    self.compute_metrics_in_thread,
                    metric_func,
                )
                for metric_func in metric_funcs
            ]

            for future in concurrent.futures.as_completed(futures):
                profile, column = future.result()
                if not isinstance(profile, dict):
                    profile = dict()
                if not column:
                    profile_results["table"].update(profile)
                else:
                    profile_results["columns"][column].update(
                        {
                            "name": column,
                            "timestamp": datetime.now(tz=timezone.utc).timestamp(),
                            **profile,
                        }
                    )
        return profile_results

    def fetch_sample_data(self):
        if not self.sampler:
            raise RuntimeError(
                "You must create a sampler first `<instance>.create_sampler(...)`."
            )
        return self.sampler.fetch_sample_data()

    def create_sampler(self) -> Sampler:
        """Create sampler instance"""
        return Sampler(
            session=self.session,
            table=self.table,
            profile_sample=self.profile_sample,
            partition_details=self.partition_details,
            profile_sample_query=self.profile_query,
        )

    def create_runner(self) -> None:
        """Create a QueryRunner Instance"""

        return cls_timeout(TEN_MIN)(
            QueryRunner(
                session=self.session,
                table=self.table,
                sample=self.sample,
                partition_details=self.partition_details,
                profile_sample_query=self.profile_sample,
            )
        )

    def get_composed_metrics(
        self, column: Column, metric: Metrics, column_results: Dict
    ):
        """Given a list of metrics, compute the given results
        and returns the values

        Args:
            column: the column to compute the metrics against
            metrics: list of metrics to compute
        Returns:
            dictionary of results
        """
        try:
            return metric(column).fn(column_results)
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Unexpected exception computing metrics: {exc}")
            self.session.rollback()

    def run_test_case(
        self,
        test_case: TestCase,
    ) -> Optional[TestCaseResult]:
        """Run table tests where platformsTest=OpenMetadata

        Args:
            table_test_type: test type to be run
            table_profile: table profile
            table: SQA table,
            profile_sample: sample for the profile
        """

        try:
            return validation_enum_registry.registry[
                test_case.testDefinition.fullyQualifiedName
            ](
                test_case,
                execution_date=datetime.now(tz=timezone.utc).timestamp(),
                runner=self.runner,
            )
        except KeyError as err:
            logger.warning(
                f"Test definition {test_case.testDefinition.fullyQualifiedName} not registered in OpenMetadata TestDefinition registry. Skipping test case {test_case.name.__root__}"
            )
            return None

    def close(self):
        """close session"""
        self.session.close()


def get_table_metrics(
    metrics: List[Metrics],
    runner: QueryRunner,
    session: Session,
    *args,
    **kwargs,
):
    """Given a list of metrics, compute the given results
    and returns the values

    Args:
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(*[metric().fn() for metric in metrics])

        if row:
            return dict(row)

    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}: {exc}"
        )
        session.rollback()


def get_static_metrics(
    metrics: List[Metrics],
    runner: QueryRunner,
    session: Session,
    column: Column,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Dict[str, Union[str, int]]:
    """Given a list of metrics, compute the given results
    and returns the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(
            *[
                metric(column).fn()
                for metric in metrics
                if not metric.is_window_metric()
            ]
        )
        return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Static Metrics", f"{exc}")


def get_query_metrics(
    metric: Metrics,
    runner: QueryRunner,
    session: Session,
    column: Column,
    sample,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
    """Given a list of metrics, compute the given results
    and returns the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        col_metric = metric(column)
        metric_query = col_metric.query(sample=sample, session=session)
        if not metric_query:
            return None
        if col_metric.metric_type == dict:
            results = runner.select_all_from_query(metric_query)
            data = {k: [result[k] for result in results] for k in dict(results[0])}
            return {metric.name(): data}

        else:
            row = runner.select_first_from_query(metric_query)
            return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Query Metrics", f"{exc}")


def get_window_metrics(
    metric: Metrics,
    runner: QueryRunner,
    session: Session,
    column: Column,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Dict[str, Union[str, int]]:
    """Given a list of metrics, compute the given results
    and returns the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(metric(column).fn())
        if not isinstance(row, Row):
            return {metric.name(): row}
        return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Window Metrics", f"{exc}")


compute_metrics_registry = enum_register()
compute_metrics_registry.add("Static")(get_static_metrics)
compute_metrics_registry.add("Table")(get_table_metrics)
compute_metrics_registry.add("Query")(get_query_metrics)
compute_metrics_registry.add("Window")(get_window_metrics)
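The deleted class above (and its replacement below) caches one Sampler and one QueryRunner per worker thread through threading.local(). A self-contained sketch of that pattern, with a plain dict standing in for the real sampler/runner objects:

import concurrent.futures
import threading

thread_local = threading.local()

def _get_thread_resource():
    # `thread_local` attributes are visible only to the current thread,
    # so each worker builds its resource exactly once and then reuses it.
    if not hasattr(thread_local, "resource"):
        thread_local.resource = {"owner": threading.current_thread().name}
    return thread_local.resource

def compute(item):
    resource = _get_thread_resource()
    return item, resource["owner"]

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(compute, i) for i in range(10)]
    for future in concurrent.futures.as_completed(futures):
        print(future.result())  # e.g. (4, 'ThreadPoolExecutor-0_1')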
123	ingestion/src/metadata/interfaces/sqalchemy/mixins/sqa_mixin.py	Normal file
@@ -0,0 +1,123 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interfaces with database for all database engines
supporting the sqlalchemy abstraction layer
"""


from typing import Dict, Optional

from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.orm import DeclarativeMeta

from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
    SnowflakeType,
)
from metadata.generated.schema.entity.services.databaseService import (
    DatabaseServiceType,
)
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.orm.converter import ometa_to_orm
from metadata.orm_profiler.profiler.handle_partition import (
    get_partition_cols,
    is_partitioned,
)
from metadata.utils.connections import get_connection
from metadata.utils.helpers import get_start_and_end
from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY


class SQAInterfaceMixin:
    """SQLAlchemy interface mixin grouping shared methods between the sequential and threaded executors"""

    @property
    def table(self):
        """OM Table entity"""
        return self._table

    def _get_engine(self):
        """Get engine for database

        Args:
            service_connection_config: connection details for the specific service
        Returns:
            sqlalchemy engine
        """
        engine = get_connection(self.service_connection_config)

        return engine

    def _convert_table_to_orm_object(
        self,
        sqa_metadata_obj: Optional[MetaData] = None,
    ) -> DeclarativeMeta:
        """Given a table entity return a SQA ORM object

        Args:
            sqa_metadata_obj: sqa metadata registry
        Returns:
            DeclarativeMeta
        """
        return ometa_to_orm(self.table_entity, self.ometa_client, sqa_metadata_obj)

    def get_columns(self) -> Column:
        """get columns from an orm object"""
        return inspect(self.table).c

    def set_session_tag(self, session) -> None:
        """
        Set session query tag for snowflake

        Args:
            service_connection_config: connection details for the specific service
        """
        if (
            self.service_connection_config.type.value == SnowflakeType.Snowflake.value
            and hasattr(self.service_connection_config, "queryTag")
            and self.service_connection_config.queryTag
        ):
            session.execute(
                SNOWFLAKE_SESSION_TAG_QUERY.format(
                    query_tag=self.service_connection_config.queryTag
                )
            )

    def get_partition_details(
        self, partition_config: TablePartitionConfig
    ) -> Optional[Dict]:
        """From partition config, get the partition table for a table entity

        Args:
            partition_config: TablePartitionConfig object with some partition details

        Returns:
            dict or None: dictionary with all the elements constituting a partition
        """
        if (
            self.table_entity.serviceType == DatabaseServiceType.BigQuery
            and is_partitioned(self.session, self.table)
        ):
            start, end = get_start_and_end(partition_config.partitionQueryDuration)
            return {
                "partition_field": partition_config.partitionField
                or get_partition_cols(self.session, self.table),
                "partition_start": start,
                "partition_end": end,
                "partition_values": partition_config.partitionValues,
            }

        return None

    def close(self):
        """close session"""
        self.session.close()
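To make get_partition_details concrete: for a BigQuery table entity partitioned on a date column, the returned dictionary has this shape. The field name and dates below are invented for illustration; the start and end values come from get_start_and_end.

# Hypothetical return value of get_partition_details for a partitioned
# BigQuery table -- every value here is illustrative.
{
    "partition_field": "ingestion_date",  # partitionField, or the detected partition column
    "partition_start": "2022-07-01",      # from get_start_and_end(partitionQueryDuration)
    "partition_end": "2022-07-02",
    "partition_values": None,             # partitionValues from the config, if any
}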
ingestion/src/metadata/interfaces/sqalchemy/sqa_profiler_interface.py	Normal file
@@ -0,0 +1,247 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interfaces with database for all database engines
supporting the sqlalchemy abstraction layer
"""

import concurrent.futures
import threading
import traceback
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, Optional

from sqlalchemy import Column, MetaData

from metadata.generated.schema.entity.data.table import Table, TableData
from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.interfaces.sqalchemy.mixins.sqa_mixin import SQAInterfaceMixin
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.metrics.sqa_metrics_computation_registry import (
    compute_metrics_registry,
)
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.utils.connections import (
    create_and_bind_thread_safe_session,
    get_connection,
)
from metadata.utils.logger import sqa_interface_registry_logger

logger = sqa_interface_registry_logger()
thread_local = threading.local()


class SQAProfilerInterface(SQAInterfaceMixin, ProfilerProtocol):
    """
    Interface to interact with registry supporting
    sqlalchemy.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        service_connection_config,
        ometa_client: OpenMetadata,
        sqa_metadata_obj: Optional[MetaData] = None,
        thread_count: Optional[float] = 5,
        table_entity: Optional[Table] = None,
        table_sample_precentage: Optional[float] = None,
        table_sample_query: Optional[str] = None,
        table_partition_config: Optional[TablePartitionConfig] = None,
    ):
        """Instantiate SQA Interface object"""
        self._thread_count = thread_count
        self.table_entity = table_entity
        self.ometa_client = ometa_client
        self.service_connection_config = service_connection_config

        self.processor_status = ProfilerProcessorStatus()
        self.processor_status.entity = (
            self.table_entity.fullyQualifiedName.__root__
            if self.table_entity.fullyQualifiedName
            else None
        )

        self._table = self._convert_table_to_orm_object(sqa_metadata_obj)

        self.session_factory = self._session_factory(service_connection_config)
        self.session = self.session_factory()
        self.set_session_tag(self.session)

        self.profile_sample = table_sample_precentage
        self.profile_query = table_sample_query
        self.partition_details = (
            self.get_partition_details(table_partition_config)
            if not self.profile_query
            else None
        )

    @staticmethod
    def _session_factory(service_connection_config):
        """Create thread safe session that will be automatically
        garbage collected once the application thread ends
        """
        engine = get_connection(service_connection_config)
        return create_and_bind_thread_safe_session(engine)

    def _create_thread_safe_sampler(
        self,
        session,
        table,
    ):
        """Create thread safe sampler"""
        if not hasattr(thread_local, "sampler"):
            thread_local.sampler = Sampler(
                session=session,
                table=table,
                profile_sample=self.profile_sample,
                partition_details=self.partition_details,
                profile_sample_query=self.profile_query,
            )
        return thread_local.sampler

    def _create_thread_safe_runner(
        self,
        session,
        table,
        sample,
    ):
        """Create thread safe runner"""
        if not hasattr(thread_local, "runner"):
            thread_local.runner = QueryRunner(
                session=session,
                table=table,
                sample=sample,
                partition_details=self.partition_details,
                profile_sample_query=self.profile_query,
            )
        return thread_local.runner

    def compute_metrics_in_thread(
        self,
        metric_funcs,
    ):
        """Run metrics in processor worker"""
        (
            metrics,
            metric_type,
            column,
            table,
        ) = metric_funcs
        logger.debug(
            f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
        )
        Session = self.session_factory  # pylint: disable=invalid-name
        with Session() as session:
            self.set_session_tag(session)
            sampler = self._create_thread_safe_sampler(
                session,
                table,
            )
            sample = sampler.random_sample()
            runner = self._create_thread_safe_runner(
                session,
                table,
                sample,
            )

            row = compute_metrics_registry.registry[metric_type.value](
                metrics,
                runner=runner,
                session=session,
                column=column,
                sample=sample,
                processor_status=self.processor_status,
            )

            if column is not None:
                column = column.name

            return row, column

    # pylint: disable=use-dict-literal
    def get_all_metrics(
        self,
        metric_funcs: list,
    ):
        """get all profiler metrics"""
        logger.info(f"Computing metrics with {self._thread_count} threads.")
        profile_results = {"table": dict(), "columns": defaultdict(dict)}
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._thread_count
        ) as executor:
            futures = [
                executor.submit(
                    self.compute_metrics_in_thread,
                    metric_func,
                )
                for metric_func in metric_funcs
            ]

            for future in concurrent.futures.as_completed(futures):
                profile, column = future.result()
                if not isinstance(profile, dict):
                    profile = dict()
                if not column:
                    profile_results["table"].update(profile)
                else:
                    profile_results["columns"][column].update(
                        {
                            "name": column,
                            "timestamp": datetime.now(tz=timezone.utc).timestamp(),
                            **profile,
                        }
                    )
        return profile_results

    def fetch_sample_data(self, table) -> TableData:
        """Fetch sample data from database

        Args:
            table: ORM declarative table

        Returns:
            TableData: sample table data
        """
        sampler = Sampler(
            session=self.session,
            table=table,
            profile_sample=self.profile_sample,
            partition_details=self.partition_details,
            profile_sample_query=self.profile_query,
        )
        return sampler.fetch_sample_data()

    def get_composed_metrics(
        self, column: Column, metric: Metrics, column_results: Dict
    ):
        """Given a list of metrics, compute the given results
        and returns the values

        Args:
            column: the column to compute the metrics against
            metrics: list of metrics to compute
        Returns:
            dictionary of results
        """
        try:
            return metric(column).fn(column_results)
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Unexpected exception computing metrics: {exc}")
            self.session.rollback()
            return None
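compute_metrics_in_thread dispatches on metric_type.value through compute_metrics_registry. Below is a minimal, runnable sketch of that dispatch style; the Registry class here is a plain dict-backed stand-in written purely for illustration, and only mimics the add/registry surface used above, not the real enum_register.

# Sketch of registry-based dispatch, assuming nothing beyond the
# pattern visible in the diff.
class Registry:
    def __init__(self):
        self.registry = {}

    def add(self, name):
        def decorator(func):
            self.registry[name] = func
            return func
        return decorator

compute_metrics = Registry()

@compute_metrics.add("Table")
def table_metrics(metrics, **kwargs):
    # A real implementation would run the metrics against a sample.
    return {"rowCount": 42}

# Dispatch by metric-type name, mirroring
# compute_metrics_registry.registry[metric_type.value](...)
row = compute_metrics.registry["Table"]([], runner=None, session=None)
print(row)  # {'rowCount': 42}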
@@ -0,0 +1,159 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interfaces with database for all database engines
supporting the sqlalchemy abstraction layer
"""

from datetime import datetime, timezone
from typing import Optional, Union

from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm.util import AliasedClass

from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.sqalchemy.mixins.sqa_mixin import SQAInterfaceMixin
from metadata.interfaces.test_suite_protocol import TestSuiteProtocol
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.test_suite.validations.core import validation_enum_registry
from metadata.utils.connections import create_and_bind_session, get_connection
from metadata.utils.constants import TEN_MIN
from metadata.utils.logger import sqa_interface_registry_logger
from metadata.utils.timeout import cls_timeout

logger = sqa_interface_registry_logger()


class SQATestSuiteInterface(SQAInterfaceMixin, TestSuiteProtocol):
    """
    Sequential interface protocol for testSuite and Profiler. This class
    implements specific operations needed to run profiler and test suite workflows
    against a SQAlchemy source.
    """

    def __init__(
        self,
        service_connection_config: DatabaseConnection,
        ometa_client: OpenMetadata,
        table_sample_precentage: float = None,
        table_sample_query: str = None,
        table_partition_config: dict = None,
        table_entity: Table = None,
    ):
        self.ometa_client = ometa_client
        self.table_entity = table_entity
        self.service_connection_config = service_connection_config
        self.session = create_and_bind_session(
            get_connection(self.service_connection_config)
        )
        self.set_session_tag(self.session)

        self._table = self._convert_table_to_orm_object()

        self.table_sample_precentage = table_sample_precentage
        self.table_sample_query = table_sample_query
        self.table_partition_config = (
            self.get_partition_details(table_partition_config)
            if not self.table_sample_query
            else None
        )

        self._sampler = self._create_sampler()
        self._runner = self._create_runner()

    @property
    def sample(self) -> Union[DeclarativeMeta, AliasedClass]:
        """Random sample of the table

        Returns:
            Union[DeclarativeMeta, AliasedClass]: a sampled representation of the table
        """
        if not self.sampler:
            raise RuntimeError(
                "You must create a sampler first `<instance>.create_sampler(...)`."
            )

        return self.sampler.random_sample()

    @property
    def runner(self) -> QueryRunner:
        """getter method for the QueryRunner object

        Returns:
            QueryRunner: runner object
        """
        return self._runner

    @property
    def sampler(self) -> Sampler:
        """getter method for the Sampler object

        Returns:
            Sampler: sampler object
        """
        return self._sampler

    def _create_sampler(self) -> Sampler:
        """Create sampler instance"""
        return Sampler(
            session=self.session,
            table=self.table,
            profile_sample=self.table_sample_precentage,
            partition_details=self.table_partition_config,
            profile_sample_query=self.table_sample_query,
        )

    def _create_runner(self) -> None:
        """Create a QueryRunner Instance"""

        return cls_timeout(TEN_MIN)(
            QueryRunner(
                session=self.session,
                table=self.table,
                sample=self.sample,
                partition_details=self.table_partition_config,
                profile_sample_query=self.table_sample_query,
            )
        )

    def run_test_case(
        self,
        test_case: TestCase,
    ) -> Optional[TestCaseResult]:
        """Run table tests where platformsTest=OpenMetadata

        Args:
            test_case: test case object to execute

        Returns:
            TestCaseResult object
        """

        try:
            return validation_enum_registry.registry[
                test_case.testDefinition.fullyQualifiedName
            ](
                test_case,
                execution_date=datetime.now(tz=timezone.utc).timestamp(),
                runner=self.runner,
            )
        except KeyError as err:
            logger.warning(
                f"Test definition {test_case.testDefinition.fullyQualifiedName} not registered in OpenMetadata "
                f"TestDefinition registry. Skipping test case {test_case.name.__root__} - {err}"
            )
            return None
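A hedged usage sketch of run_test_case: the interface instance and test_case entity below are placeholders for objects that normally come from the OpenMetadata API and workflow wiring, and the result fields printed are assumptions about the TestCaseResult model rather than confirmed by this diff.

# Illustrative only -- `interface` and `test_case` are built elsewhere.
result = interface.run_test_case(test_case)
if result is not None:
    print(result.testCaseStatus, result.timestamp)
else:
    print("test definition not registered; test case was skipped")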
@@ -15,42 +15,30 @@ supporting sqlalchemy abstraction layer
 """

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Union

-from metadata.generated.schema.entity.data.table import Table
-from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
-    OpenMetadataConnection,
+from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
+    DatalakeConnection,
 )
+from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
 from metadata.generated.schema.tests.basic import TestCaseResult
-from metadata.orm_profiler.api.models import ProfilerProcessorConfig
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.ingestion.ometa.ometa_api import OpenMetadata


-class InterfaceProtocol(ABC):
+class TestSuiteProtocol(ABC):
     """Protocol interface for the processor"""

     @abstractmethod
     def __init__(
         self,
-        metadata_config: OpenMetadataConnection = None,
-        profiler_config: ProfilerProcessorConfig = None,
-        workflow_profile_sample: float = None,
-        thread_count: int = 5,
-        table: Table = None,
+        ometa_client: OpenMetadata = None,
+        service_connection_config: Union[DatabaseConnection, DatalakeConnection] = None,
     ):
         """Required attributes for the interface"""
         raise NotImplementedError

-    @abstractmethod
-    def create_sampler(*args, **kwargs) -> None:
-        """Method to instantiate a Sampler object"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def create_runner(*args, **kwargs) -> None:
-        """Method to instantiate a Runner object"""
-        raise NotImplementedError
-
     @abstractmethod
-    def run_test_case(*args, **kwargs) -> Optional[TestCaseResult]:
+    def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]:
         """run column data quality tests"""
         raise NotImplementedError
@@ -18,7 +18,7 @@ Workflow definition for the ORM Profiler.
 """
 import traceback
 from copy import deepcopy
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, cast

 from pydantic import ValidationError
 from sqlalchemy import MetaData
@@ -30,7 +30,6 @@ from metadata.generated.schema.entity.data.table import (
     ColumnProfilerConfig,
     IntervalType,
     Table,
-    TableProfile,
 )
 from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
     OpenMetadataConnection,
@@ -48,9 +47,11 @@ from metadata.ingestion.api.processor import ProcessorStatus
 from metadata.ingestion.api.sink import Sink
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.profiler_protocol import ProfilerProtocol
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
 from metadata.orm_profiler.api.models import (
     ProfilerProcessorConfig,
+    ProfilerResponse,
     TableConfig,
     TablePartitionConfig,
 )
@@ -63,12 +64,17 @@ from metadata.utils.class_helper import (
     get_service_type_from_source_type,
 )
 from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
+from metadata.utils.helpers import create_ometa_client
 from metadata.utils.logger import profiler_logger
 from metadata.utils.workflow_output_handler import print_profiler_status

 logger = profiler_logger()


+class ProfilerInterfaceInstantiationError(Exception):
+    """Raise when interface cannot be instantiated"""
+
+
 class ProfilerWorkflow:
     """
     Configure and run the ORM profiler
@@ -93,9 +99,9 @@ class ProfilerWorkflow:
         self._retrieve_service_connection_if_needed()

         # Init and type the source config
-        self.source_config: DatabaseServiceProfilerPipeline = (
-            self.config.source.sourceConfig.config
-        )
+        self.source_config: DatabaseServiceProfilerPipeline = cast(
+            DatabaseServiceProfilerPipeline, self.config.source.sourceConfig.config
+        )  # Used to satisfy the type checker
         self.source_status = SQLSourceStatus()
         self.status = ProcessorStatus()
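The cast() introduced here performs no runtime conversion; it only tells the type checker to treat the value as the narrower type. A minimal illustration:

from typing import cast

def load_config() -> object:
    return {"threadCount": 5}

# cast() is a no-op at runtime: it returns its second argument unchanged
# and merely informs mypy/pyright of the intended type.
config = cast(dict, load_config())
print(config["threadCount"])  # 5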
@@ -143,14 +149,14 @@ class ProfilerWorkflow:
                 table_config
                 for table_config in self.profiler_config.tableConfig
                 if table_config.fullyQualifiedName.__root__
-                == entity.fullyQualifiedName.__root__
+                == entity.fullyQualifiedName.__root__  # type: ignore
             ),
             None,
         )

     def get_include_columns(self, entity) -> Optional[List[ColumnProfilerConfig]]:
         """get included columns"""
-        entity_config: TableConfig = self.get_config_for_entity(entity)
+        entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
         if entity_config and entity_config.columnConfig:
             return entity_config.columnConfig.includeColumns

@@ -161,7 +167,7 @@ class ProfilerWorkflow:

     def get_exclude_columns(self, entity) -> Optional[List[str]]:
         """get excluded columns"""
-        entity_config: TableConfig = self.get_config_for_entity(entity)
+        entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
         if entity_config and entity_config.columnConfig:
             return entity_config.columnConfig.excludeColumns

@@ -176,7 +182,7 @@ class ProfilerWorkflow:
         Args:
             entity: table entity
         """
-        entity_config: TableConfig = self.get_config_for_entity(entity)
+        entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
         if entity_config:
             return entity_config.profileSample

@@ -188,13 +194,13 @@ class ProfilerWorkflow:

         return None

-    def get_profile_query(self, entity: Table) -> Optional[float]:
+    def get_profile_query(self, entity: Table) -> Optional[str]:
         """Get profile query

         Args:
             entity: table entity
         """
-        entity_config: TableConfig = self.get_config_for_entity(entity)
+        entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
         if entity_config:
             return entity_config.profileQuery

@@ -215,7 +221,7 @@ class ProfilerWorkflow:
         ):
             return None

-        entity_config: TableConfig = self.get_config_for_entity(entity)
+        entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
         if entity_config:
             return entity_config.partitionConfig

@@ -240,25 +246,30 @@ class ProfilerWorkflow:
         sqa_metadata_obj: Optional[MetaData] = None,
     ):
         """Creates a profiler interface object"""
-        return SQAInterface(
-            service_connection_config,
-            sqa_metadata_obj=sqa_metadata_obj,
-            metadata_config=self.metadata_config,
-            thread_count=self.source_config.threadCount,
-            table_entity=table_entity,
-            profile_sample=self.get_profile_sample(table_entity)
-            if not self.get_profile_query(table_entity)
-            else None,
-            profile_query=self.get_profile_query(table_entity)
-            if not self.get_profile_sample(table_entity)
-            else None,
-            partition_config=self.get_partition_details(table_entity)
-            if not self.get_profile_query(table_entity)
-            else None,
-        )
+        try:
+            return SQAProfilerInterface(
+                service_connection_config,
+                sqa_metadata_obj=sqa_metadata_obj,
+                ometa_client=create_ometa_client(self.metadata_config),
+                thread_count=self.source_config.threadCount,
+                table_entity=table_entity,
+                table_sample_precentage=self.get_profile_sample(table_entity)
+                if not self.get_profile_query(table_entity)
+                else None,
+                table_sample_query=self.get_profile_query(table_entity)
+                if not self.get_profile_sample(table_entity)
+                else None,
+                table_partition_config=self.get_partition_details(table_entity)
+                if not self.get_profile_query(table_entity)
+                else None,
+            )
+        except Exception as exc:
+            logger.debug(traceback.format_exc())
+            logger.error("We could not create a profiler interface")
+            raise ProfilerInterfaceInstantiationError(exc)

     def create_profiler_obj(
-        self, table_entity: Table, profiler_interface: SQAInterface
+        self, table_entity: Table, profiler_interface: ProfilerProtocol
     ):
         """Profile a single entity"""
         if not self.profiler_config.profiler:
@@ -275,7 +286,7 @@ class ProfilerWorkflow:
         )

         self.profiler_obj = Profiler(
-            *metrics,
+            *metrics,  # type: ignore
             profiler_interface=profiler_interface,
             include_columns=self.get_include_columns(table_entity),
             exclude_columns=self.get_exclude_columns(table_entity),
@@ -293,7 +304,7 @@ class ProfilerWorkflow:
             return None
         return database

-    def filter_entities(self, tables: List[Table]) -> Iterable[Table]:
+    def filter_entities(self, tables: Iterable[Table]) -> Iterable[Table]:
         """
         From a list of tables, apply the SQLSourceConfig
         filter patterns.
@@ -304,10 +315,10 @@ class ProfilerWorkflow:
             try:
                 if filter_by_schema(
                     self.source_config.schemaFilterPattern,
-                    table.databaseSchema.name,
+                    table.databaseSchema.name,  # type: ignore
                 ):
                     self.source_status.filter(
-                        f"Schema pattern not allowed: {table.fullyQualifiedName.__root__}",
+                        f"Schema pattern not allowed: {table.fullyQualifiedName.__root__}",  # type: ignore
                         "Schema pattern not allowed",
                     )
                     continue
@@ -316,7 +327,7 @@ class ProfilerWorkflow:
                     table.name.__root__,
                 ):
                     self.source_status.filter(
-                        f"Table pattern not allowed: {table.fullyQualifiedName.__root__}",
+                        f"Table pattern not allowed: {table.fullyQualifiedName.__root__}",  # type: ignore
                         "Table pattern not allowed",
                     )
                     continue
@@ -325,9 +336,10 @@ class ProfilerWorkflow:
             except Exception as exc:  # pylint: disable=broad-except
                 logger.debug(traceback.format_exc())
                 logger.warning(
-                    f"Unexpected error filtering entities for table [{table.fullyQualifiedName.__root__}]: {exc}"
+                    "Unexpected error filtering entities for table "
+                    f"[{table.fullyQualifiedName.__root__}]: {exc}"  # type: ignore
                 )
-                self.source_status.failure(table.fullyQualifiedName.__root__, f"{exc}")
+                self.source_status.failure(table.fullyQualifiedName.__root__, f"{exc}")  # type: ignore

     def get_database_entities(self):
         """List all databases in service"""
@@ -370,25 +382,28 @@ class ProfilerWorkflow:
                     service_name=self.config.source.serviceName,
                     database_name=database.name.__root__,
                 ),
-            },
+            },  # type: ignore
         )

         yield from self.filter_entities(all_tables)

-    def copy_service_config(self, database) -> None:
+    def copy_service_config(self, database) -> DatabaseService.__config__:
         copy_service_connection_config = deepcopy(
-            self.config.source.serviceConnection.__root__.config
+            self.config.source.serviceConnection.__root__.config  # type: ignore
         )
         if hasattr(
-            self.config.source.serviceConnection.__root__.config,
+            self.config.source.serviceConnection.__root__.config,  # type: ignore
             "supportsDatabase",
         ):
-            if hasattr(
-                self.config.source.serviceConnection.__root__.config, "database"
-            ):
-                copy_service_connection_config.database = database.name.__root__
-            if hasattr(self.config.source.serviceConnection.__root__.config, "catalog"):
-                copy_service_connection_config.catalog = database.name.__root__
+            if hasattr(copy_service_connection_config, "database"):
+                copy_service_connection_config.database = database.name.__root__  # type: ignore
+            if hasattr(copy_service_connection_config, "catalog"):
+                copy_service_connection_config.catalog = database.name.__root__  # type: ignore
+
+        # we know we'll only be working with databaseServices, we cast the type to satisfy the type checker
+        copy_service_connection_config = cast(
+            DatabaseService.__config__, copy_service_connection_config
+        )

         return copy_service_connection_config

@@ -402,8 +417,8 @@ class ProfilerWorkflow:
         if not databases:
             raise ValueError(
                 "databaseFilterPattern returned 0 result. At least 1 database must be returned by the filter pattern."
-                f"\n\t- includes: {self.source_config.databaseFilterPattern.includes}\n\t"
-                f"- excludes: {self.source_config.databaseFilterPattern.excludes}"
+                f"\n\t- includes: {self.source_config.databaseFilterPattern.includes if self.source_config.databaseFilterPattern else None}"  # pylint: disable=line-too-long
+                f"\n\t- excludes: {self.source_config.databaseFilterPattern.excludes if self.source_config.databaseFilterPattern else None}"  # pylint: disable=line-too-long
             )

         for database in databases:
@@ -416,7 +431,7 @@ class ProfilerWorkflow:
                     copied_service_config, entity, sqa_metadata_obj
                 )
                 self.create_profiler_obj(entity, profiler_interface)
-                profile: TableProfile = self.profiler_obj.process(
+                profile: ProfilerResponse = self.profiler_obj.process(
                     self.source_config.generateSampleData
                 )
                 profiler_interface.close()
@@ -425,18 +440,19 @@ class ProfilerWorkflow:
                 self.status.failures.extend(
                     profiler_interface.processor_status.failures
                 )  # we can have column level failures we need to report
-                self.status.processed(entity.fullyQualifiedName.__root__)
-                self.source_status.scanned(entity.fullyQualifiedName.__root__)
+                self.status.processed(entity.fullyQualifiedName.__root__)  # type: ignore
+                self.source_status.scanned(entity.fullyQualifiedName.__root__)  # type: ignore
             except Exception as exc:  # pylint: disable=broad-except
                 logger.debug(traceback.format_exc())
                 logger.warning(
-                    f"Unexpected exception processing entity [{entity.fullyQualifiedName.__root__}]: {exc}"
+                    "Unexpected exception processing entity "
+                    f"[{entity.fullyQualifiedName.__root__}]: {exc}"  # type: ignore
                 )
                 self.status.failures.extend(
-                    profiler_interface.processor_status.failures
+                    profiler_interface.processor_status.failures  # type: ignore
                 )
                 self.source_status.failure(
-                    entity.fullyQualifiedName.__root__, f"{exc}"
+                    entity.fullyQualifiedName.__root__, f"{exc}"  # type: ignore
                 )
         except Exception as exc:  # pylint: disable=broad-except
             logger.debug(traceback.format_exc())
@ -0,0 +1,185 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
OpenMetadata Profiler supported metrics

Use these registries to avoid messy imports.

Note that we are using our own Registry definition
that allows us to directly call our metrics without
having to verbosely pass .value all the time...
"""
# pylint: disable=unused-argument

import traceback
from typing import Dict, List, Optional, Union

from sqlalchemy import Column
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session

from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.utils.dispatch import enum_register
from metadata.utils.logger import sqa_interface_registry_logger

logger = sqa_interface_registry_logger()


def get_table_metrics(
    metrics: List[Metrics],
    runner: QueryRunner,
    session: Session,
    *args,
    **kwargs,
):
    """Given a list of metrics, compute the given results
    and return the values

    Args:
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(*[metric().fn() for metric in metrics])

        if row:
            return dict(row)
        return None

    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}: {exc}"
        )
        session.rollback()
        return None


def get_static_metrics(
    metrics: List[Metrics],
    runner: QueryRunner,
    session: Session,
    column: Column,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
    """Given a list of metrics, compute the given results
    and return the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(
            *[
                metric(column).fn()
                for metric in metrics
                if not metric.is_window_metric()
            ]
        )
        return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Static Metrics", f"{exc}")
        return None


def get_query_metrics(
    metric: Metrics,
    runner: QueryRunner,
    session: Session,
    column: Column,
    sample,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
    """Given a list of metrics, compute the given results
    and return the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        col_metric = metric(column)
        metric_query = col_metric.query(sample=sample, session=session)
        if not metric_query:
            return None
        if col_metric.metric_type == dict:
            results = runner.select_all_from_query(metric_query)
            data = {k: [result[k] for result in results] for k in dict(results[0])}
            return {metric.name(): data}

        row = runner.select_first_from_query(metric_query)
        return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Query Metrics", f"{exc}")
        return None


def get_window_metrics(
    metric: Metrics,
    runner: QueryRunner,
    session: Session,
    column: Column,
    processor_status: ProfilerProcessorStatus,
    *args,
    **kwargs,
) -> Dict[str, Union[str, int]]:
    """Given a list of metrics, compute the given results
    and return the values

    Args:
        column: the column to compute the metrics against
        metrics: list of metrics to compute
    Returns:
        dictionary of results
    """
    try:
        row = runner.select_first_from_sample(metric(column).fn())
        if not isinstance(row, Row):
            return {metric.name(): row}
        return dict(row)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
        )
        session.rollback()
        processor_status.failure(f"{column.name}", "Window Metrics", f"{exc}")
        return None


compute_metrics_registry = enum_register()
compute_metrics_registry.add("Static")(get_static_metrics)
compute_metrics_registry.add("Table")(get_table_metrics)
compute_metrics_registry.add("Query")(get_query_metrics)
compute_metrics_registry.add("Window")(get_window_metrics)
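With the registry in place, the profiler interface can look up the right compute function from a metric-type name instead of importing each function directly. A self-contained sketch of the same decorator-registry pattern (an illustration of the idea, not the actual `enum_register` implementation):

from typing import Callable, Dict

class Register:
    """Tiny decorator-based registry mirroring the pattern above."""

    def __init__(self) -> None:
        self.registry: Dict[str, Callable] = {}

    def add(self, name: str) -> Callable:
        def inner(fn: Callable) -> Callable:
            self.registry[name] = fn
            return fn

        return inner

compute_registry = Register()

@compute_registry.add("Table")
def table_metrics(**kwargs) -> dict:
    # stand-in for get_table_metrics
    return {"rowCount": 42}

# Dispatch by metric-type name, as the profiler interface would:
print(compute_registry.registry["Table"]())  # {'rowCount': 42}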
@ -47,6 +47,7 @@ def _(elements, compiler, **kwargs):
    return "median(%s)" % col


+# pylint: disable=unused-argument
@compiles(MedianFn, Dialects.Trino)
@compiles(MedianFn, Dialects.Presto)
def _(elements, compiler, **kwargs):

@ -15,14 +15,12 @@ Main Profile definition and queries to execute
from __future__ import annotations

import traceback
-import warnings
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type

from pydantic import ValidationError
from sqlalchemy import Column
from sqlalchemy.orm import DeclarativeMeta
-from sqlalchemy.orm.session import Session

from metadata.generated.schema.api.data.createTableProfile import (
    CreateTableProfileRequest,
@ -32,8 +30,7 @@ from metadata.generated.schema.entity.data.table import (
    ColumnProfilerConfig,
    TableProfile,
)
-from metadata.interfaces.interface_protocol import InterfaceProtocol
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.orm_profiler.api.models import ProfilerResponse
from metadata.orm_profiler.metrics.core import (
    ComposedMetric,
@ -71,10 +68,10 @@ class Profiler(Generic[TMetric]):
    def __init__(
        self,
        *metrics: Type[TMetric],
-       profiler_interface: InterfaceProtocol,
+       profiler_interface: ProfilerProtocol,
        profile_date: datetime = datetime.now(tz=timezone.utc).timestamp(),
-       include_columns: List[Optional[ColumnProfilerConfig]] = None,
-       exclude_columns: List[Optional[str]] = None,
+       include_columns: Optional[List[ColumnProfilerConfig]] = None,
+       exclude_columns: Optional[List[str]] = None,
    ):
        """
        :param metrics: Metrics to run. We are receiving the uninitialized classes
@ -99,21 +96,6 @@ class Profiler(Generic[TMetric]):
        # We will get columns from the property
        self._columns: Optional[List[Column]] = None

-   @property
-   def session(self) -> Session:
-       """Kept for backward compatibility"""
-       warnings.warn(
-           "`<instance>`.session will be retired as platform specific. Instead use"
-           "`<instance>.profiler_interface` to see if session is supported by the profiler interface",
-           DeprecationWarning,
-       )
-       if isinstance(self.profiler_interface, SQAInterface):
-           return self.profiler_interface.session
-
-       raise ValueError(
-           f"session is not supported for profiler interface {self.profiler_interface.__class__.__name__}"
-       )

    @property
    def table(self) -> DeclarativeMeta:
        return self.profiler_interface.table
@ -392,7 +374,7 @@ class Profiler(Generic[TMetric]):

        return self

-   def process(self, generate_sample_data: bool) -> ProfilerResponse:
+   def process(self, generate_sample_data: Optional[bool]) -> ProfilerResponse:
        """
        Given a table, we will prepare the profiler for
        all its columns and return all the run profilers
@ -408,7 +390,7 @@ class Profiler(Generic[TMetric]):
            logger.info(
                f"Fetching sample data for {self.profiler_interface.table_entity.fullyQualifiedName.__root__}..."
            )
-           sample_data = self.profiler_interface.fetch_sample_data()
+           sample_data = self.profiler_interface.fetch_sample_data(self.table)
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error fetching sample data: {err}")

@ -17,7 +17,7 @@ from typing import List, Optional
from sqlalchemy.orm import DeclarativeMeta

from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.orm_profiler.metrics.core import Metric, add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.core import Profiler
@ -58,9 +58,9 @@ class DefaultProfiler(Profiler):

    def __init__(
        self,
-       profiler_interface: SQAInterface,
-       include_columns: List[Optional[ColumnProfilerConfig]] = None,
-       exclude_columns: List[Optional[str]] = None,
+       profiler_interface: ProfilerProtocol,
+       include_columns: Optional[List[ColumnProfilerConfig]] = None,
+       exclude_columns: Optional[List[str]] = None,
    ):

        _metrics = get_default_metrics(profiler_interface.table)

@ -56,7 +56,8 @@ class Sampler:
            self.table, (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL)
        )
        .suffix_with(
-           f"SAMPLE SYSTEM ({self.profile_sample})", dialect=Dialects.Snowflake
+           f"SAMPLE SYSTEM ({self.profile_sample or 100})",
+           dialect=Dialects.Snowflake,
        )
        .cte(f"{self.table.__tablename__}_rnd")
    )
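The `or 100` fallback matters because `profile_sample` can be `None` when no sampling percentage is configured; interpolating `None` into the suffix would emit invalid Snowflake SQL. A quick illustration of the guard:

profile_sample = None
print(f"SAMPLE SYSTEM ({profile_sample})")         # SAMPLE SYSTEM (None) -- invalid SQL
print(f"SAMPLE SYSTEM ({profile_sample or 100})")  # SAMPLE SYSTEM (100)  -- samples the full table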
@ -47,11 +47,12 @@ from metadata.generated.schema.tests.testSuite import TestSuite
from metadata.ingestion.api.parser import parse_workflow_config_gracefully
from metadata.ingestion.api.processor import ProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.test_suite.api.models import TestCaseDefinition, TestSuiteProcessorConfig
from metadata.test_suite.runner.core import DataTestsRunner
from metadata.utils import entity_link
+from metadata.utils.helpers import create_ometa_client
from metadata.utils.logger import test_suite_logger
from metadata.utils.workflow_output_handler import print_test_suite_status

@ -115,27 +116,27 @@ class TestSuiteWorkflow:
        )
        raise err

-   def _filter_test_cases_for_table_entity(
-       self, table_fqn: str, test_cases: List[TestCase]
+   def _filter_test_cases_for_entity(
+       self, entity_fqn: str, test_cases: List[TestCase]
    ) -> list[TestCase]:
        """Filter test cases for specific entity"""
        return [
            test_case
            for test_case in test_cases
            if test_case.entityLink.__root__.split("::")[2].replace(">", "")
-           == table_fqn
+           == entity_fqn
        ]

-   def _get_unique_table_entities(self, test_cases: List[TestCase]) -> Set:
+   def _get_unique_entities_from_test_cases(self, test_cases: List[TestCase]) -> Set:
        """from a list of test cases extract unique table entities"""
-       table_fqns = [
+       entity_fqns = [
            test_case.entityLink.__root__.split("::")[2].replace(">", "")
            for test_case in test_cases
        ]

-       return set(table_fqns)
+       return set(entity_fqns)
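Both helpers extract the entity FQN the same way: take the third `::`-separated token of the entityLink and strip the closing `>`. A sketch with an illustrative link (the link value below is hypothetical; the parsing matches the code above):

entity_link = "<#E::table::sample_service.sample_db.sample_schema.users>"
fqn = entity_link.split("::")[2].replace(">", "")
print(fqn)  # sample_service.sample_db.sample_schema.users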
-   def _get_service_connection_from_test_case(self, table_fqn: str):
+   def _get_service_connection_from_test_case(self, entity_fqn: str):
        """given an entityLink return the service connection

        Args:
@ -143,7 +144,7 @@ class TestSuiteWorkflow:
        """
        service = self.metadata.get_by_name(
            entity=DatabaseService,
-           fqn=table_fqn.split(".")[0],
+           fqn=entity_fqn.split(".")[0],
        )

        if service:
@ -162,7 +163,7 @@ class TestSuiteWorkflow:
                )
                and not service_connection_config.database
            ):
-               service_connection_config.database = table_fqn.split(".")[1]
+               service_connection_config.database = entity_fqn.split(".")[1]
            if (
                hasattr(
                    service_connection_config,
@ -170,13 +171,13 @@ class TestSuiteWorkflow:
                )
                and not service_connection_config.catalog
            ):
-               service_connection_config.catalog = table_fqn.split(".")[1]
+               service_connection_config.catalog = entity_fqn.split(".")[1]
            return service_connection_config

        logger.error(f"Could not retrieve connection details for entity {entity_link}")
        raise ValueError()

-   def _get_table_entity_from_test_case(self, table_fqn: str):
+   def _get_table_entity_from_test_case(self, entity_fqn: str):
        """given an entityLink return the table entity

        Args:
@ -184,7 +185,7 @@ class TestSuiteWorkflow:
        """
        return self.metadata.get_by_name(
            entity=Table,
-           fqn=table_fqn,
+           fqn=entity_fqn,
            fields=["profile"],
        )

@ -235,22 +236,22 @@ class TestSuiteWorkflow:
            )
            return None

-   def _create_sqa_tests_runner_interface(self, table_fqn: str):
+   def _create_runner_interface(self, entity_fqn: str):
        """create the interface to execute test against SQA sources"""
-       table_entity = self._get_table_entity_from_test_case(table_fqn)
-       return SQAInterface(
+       table_entity = self._get_table_entity_from_test_case(entity_fqn)
+       return SQATestSuiteInterface(
            service_connection_config=self._get_service_connection_from_test_case(
-               table_fqn
+               entity_fqn
            ),
-           metadata_config=self.metadata_config,
+           ometa_client=create_ometa_client(self.metadata_config),
            table_entity=table_entity,
-           profile_sample=self._get_profile_sample(table_entity)
+           table_sample_precentage=self._get_profile_sample(table_entity)
            if not self._get_profile_query(table_entity)
            else None,
-           profile_query=self._get_profile_query(table_entity)
+           table_sample_query=self._get_profile_query(table_entity)
            if not self._get_profile_sample(table_entity)
            else None,
-           partition_config=self._get_partition_details(table_entity)
+           table_partition_config=self._get_partition_details(table_entity)
            if not self._get_profile_query(table_entity)
            else None,
        )
@ -385,6 +386,15 @@ class TestSuiteWorkflow:

        return created_test_case

+   def add_test_cases_from_cli_config(self, test_cases: list) -> list:
+       cli_config_test_cases_def = self.get_test_case_from_cli_config()
+       runtime_created_test_cases = self.compare_and_create_test_cases(
+           cli_config_test_cases_def, test_cases
+       )
+       if runtime_created_test_cases:
+           return runtime_created_test_cases
+       return []
+
    def execute(self):
        """Execute test suite workflow"""
        test_suites = (
@ -397,23 +407,19 @@ class TestSuiteWorkflow:

        test_cases = self.get_test_cases_from_test_suite(test_suites)
        if self.processor_config.testSuites:
-           cli_config_test_cases_def = self.get_test_case_from_cli_config()
-           runtime_created_test_cases = self.compare_and_create_test_cases(
-               cli_config_test_cases_def, test_cases
-           )
-           if runtime_created_test_cases:
-               test_cases.extend(runtime_created_test_cases)
+           test_cases.extend(self.add_test_cases_from_cli_config(test_cases))

-       unique_table_fqns = self._get_unique_table_entities(test_cases)
+       unique_entity_fqns = self._get_unique_entities_from_test_cases(test_cases)

-       for table_fqn in unique_table_fqns:
+       for entity_fqn in unique_entity_fqns:
            try:
-               sqa_interface = self._create_sqa_tests_runner_interface(table_fqn)
-               for test_case in self._filter_test_cases_for_table_entity(
-                   table_fqn, test_cases
+               runner_interface = self._create_runner_interface(entity_fqn)
+               data_test_runner = self._create_data_tests_runner(runner_interface)
+
+               for test_case in self._filter_test_cases_for_entity(
+                   entity_fqn, test_cases
                ):
                    try:
-                       data_test_runner = self._create_data_tests_runner(sqa_interface)
                        test_result = data_test_runner.run_and_handle(test_case)
                        if not test_result:
                            continue
@ -430,8 +436,8 @@ class TestSuiteWorkflow:
                    )
                except TypeError as exc:
                    logger.debug(traceback.format_exc())
-                   logger.warning(f"Could not run test case for table {table_fqn}: {exc}")
-                   self.status.failure(table_fqn)
+                   logger.warning(f"Could not run test case for table {entity_fqn}: {exc}")
+                   self.status.failure(entity_fqn)

    def print_status(self) -> None:
        """

@ -15,7 +15,7 @@ Main class to run data tests


from metadata.generated.schema.tests.testCase import TestCase
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.test_suite_protocol import TestSuiteProtocol
from metadata.test_suite.runner.models import TestCaseResultResponse
from metadata.utils.logger import test_suite_logger

@ -25,7 +25,7 @@ logger = test_suite_logger()
class DataTestsRunner:
    """class to execute the test validation"""

-   def __init__(self, test_runner_interface: SQAInterface):
+   def __init__(self, test_runner_interface: TestSuiteProtocol):
        self.test_runner_interace = test_runner_interface

    def run_and_handle(self, test_case: TestCase):

@ -34,6 +34,7 @@ from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()


+# pylint: disable=abstract-class-instantiated
def column_values_in_set(
    test_case: TestCase,
    execution_date: datetime,
@ -70,9 +71,7 @@ def column_values_in_set(
            f"Cannot find the configured column {column_name} for test case {test_case.name.__root__}"
        )

-       set_count_dict = dict(
-           runner.dispatch_query_select_first(set_count(col).fn())
-       )  # pylint: disable=abstract-class-instantiated
+       set_count_dict = dict(runner.dispatch_query_select_first(set_count(col).fn()))
        set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)

    except Exception as exc:  # pylint: disable=broad-except

@ -99,8 +99,8 @@ def column_values_missing_count_to_be_equal(
    try:
        set_count_dict = dict(
            runner.dispatch_query_select_first(
-               set_count(col).fn()
-           )  # pylint: disable=abstract-class-instantiated
+               set_count(col).fn()  # pylint: disable=abstract-class-instantiated
+           )
        )
        set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)


@ -33,7 +33,7 @@ from metadata.utils.logger import test_suite_logger

logger = test_suite_logger()


+# pylint: disable=abstract-class-instantiated
def column_values_not_in_set(
    test_case: TestCase,
    execution_date: datetime,
@ -69,9 +69,7 @@ def column_values_not_in_set(
        raise ValueError(
            f"Cannot find the configured column {column_name} for test case {test_case.name}"
        )
-       set_count_dict = dict(
-           runner.dispatch_query_select_first(set_count(col).fn())
-       )  # pylint: disable=abstract-class-instantiated
+       set_count_dict = dict(runner.dispatch_query_select_first(set_count(col).fn()))
        set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)

    except Exception as exc:

@ -71,8 +71,8 @@ def column_values_to_match_regex(
    value_count_value_res = value_count_value_dict.get(Metrics.COUNT.name)
    like_count_value_dict = dict(
        runner.dispatch_query_select_first(
-           like_count(col).fn()
-       )  # pylint: disable=abstract-class-instantiated
+           like_count(col).fn()  # pylint: disable=abstract-class-instantiated
+       )
    )
    like_count_value_res = like_count_value_dict.get(Metrics.LIKE_COUNT.name)


@ -71,8 +71,8 @@ def column_values_to_not_match_regex(

    not_like_count_dict = dict(
        runner.dispatch_query_select_first(
-           not_like_count(col).fn()
-       )  # pylint: disable=abstract-class-instantiated
+           not_like_count(col).fn()  # pylint: disable=abstract-class-instantiated
+       )
    )
    not_like_count_res = not_like_count_dict.get(Metrics.NOT_LIKE_COUNT.name)


@ -92,9 +92,8 @@ def table_column_to_match_set(
        None,
    )
    expected_column_names = [item.strip() for item in column_name.split(",")]
-   compare = lambda x, y: collections.Counter(x) == collections.Counter(
-       y
-   )  # pylint: disable=unnecessary-lambda-assignment
+   # pylint: disable=unnecessary-lambda-assignment
+   compare = lambda x, y: collections.Counter(x) == collections.Counter(y)

    if ordered:
        _status = expected_column_names == [col.name for col in column_names]
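Behavior is unchanged by moving the pylint directive: `collections.Counter` still compares the expected and actual column names as multisets, ignoring order but not multiplicity. For example:

import collections

compare = lambda x, y: collections.Counter(x) == collections.Counter(y)
print(compare(["id", "name"], ["name", "id"]))        # True: order does not matter
print(compare(["id", "id", "name"], ["id", "name"]))  # False: duplicates still count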
@ -14,6 +14,7 @@ Helpers module for ingestion related methods
"""

import re
+import traceback
from datetime import datetime, timedelta
from functools import wraps
from time import perf_counter
@ -33,6 +34,9 @@ from metadata.generated.schema.api.services.createStorageService import (
)
from metadata.generated.schema.entity.data.chart import Chart, ChartType
from metadata.generated.schema.entity.data.table import Column, Table
+from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
+    OpenMetadataConnection,
+)
from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.messagingService import MessagingService
@ -382,6 +386,30 @@ def list_to_dict(original: Optional[List[str]], sep: str = "=") -> Dict[str, str
    return dict(split_original)


+def create_ometa_client(
+    metadata_config: OpenMetadataConnection,
+) -> OpenMetadata:
+    """Create an OpenMetadata client
+
+    Args:
+        metadata_config (OpenMetadataConnection): OM connection config
+
+    Returns:
+        OpenMetadata: an OM client
+    """
+    try:
+        metadata = OpenMetadata(metadata_config)
+        metadata.health_check()
+        return metadata
+    except Exception as exc:
+        logger.debug(traceback.format_exc())
+        logger.warning(
+            f"No OpenMetadata server configuration found. "
+            f"Setting client to `None`. You won't be able to access the server from the client: {exc}"
+        )
+        raise ValueError(exc)
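A usage sketch for the new helper, assuming a locally running server (the hostPort below is illustrative): the helper health-checks the connection and either returns a ready client or raises ValueError.

from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
    OpenMetadataConnection,
)
from metadata.utils.helpers import create_ometa_client

config = OpenMetadataConnection(hostPort="http://localhost:8585/api")  # adjust to your server
try:
    client = create_ometa_client(config)  # runs health_check() before returning
except ValueError:
    client = None  # server unreachable or misconfigured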
def clean_up_starting_ending_double_quotes_in_string(string: str) -> str:
    """Remove start and ending double quotes in a string


@ -16,6 +16,7 @@ Validate workflow e2e
import os
import unittest
from datetime import datetime, timedelta
+from unittest.mock import patch

import sqlalchemy as sqa
from sqlalchemy.orm import declarative_base
@ -45,7 +46,7 @@ from metadata.generated.schema.entity.services.databaseService import (
)
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.test_suite.api.workflow import TestSuiteWorkflow

test_suite_config = {
@ -168,12 +169,14 @@ class TestE2EWorkflow(unittest.TestCase):
                ),
            )
        )

-       sqa_profiler_interface = SQAInterface(
-           cls.sqlite_conn.config,
-           table=User,
-           table_entity=table,
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           sqa_profiler_interface = SQAProfilerInterface(
+               cls.sqlite_conn.config,
+               table_entity=table,
+               ometa_client=None,
+           )
        engine = sqa_profiler_interface.session.get_bind()
        session = sqa_profiler_interface.session

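The tests now construct the interface under `patch.object`, so `_convert_table_to_orm_object` returns the declarative `User` model directly instead of reverse-engineering an ORM class from a table entity. The mocking pattern in isolation, with toy classes standing in for the real interface:

from unittest.mock import patch

class Interface:
    def __init__(self):
        # normally derived from the OpenMetadata table entity
        self.orm = self._convert_table_to_orm_object()

    def _convert_table_to_orm_object(self):
        raise RuntimeError("needs a live table entity")

class FakeOrm:
    pass

with patch.object(Interface, "_convert_table_to_orm_object", return_value=FakeOrm):
    interface = Interface()  # no failure: the patched method returns FakeOrm

assert interface.orm is FakeOrm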
@ -15,6 +15,7 @@ Test Metrics behavior
import datetime
import os
from unittest import TestCase
+from unittest.mock import patch
from uuid import uuid4

from sqlalchemy import TEXT, Column, Date, DateTime, Integer, String, Time
@ -26,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
    SQLiteConnection,
    SQLiteScheme,
)
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.functions.sum import SumFn
@ -71,16 +72,23 @@ class MetricsTest(TestCase):
            )
        ],
    )
-   sqa_profiler_interface = SQAInterface(
-       sqlite_conn, table=User, table_entity=table_entity
-   )
-   engine = sqa_profiler_interface.session.get_bind()

    @classmethod
    def setUpClass(cls) -> None:
        """
        Prepare Ingredients
        """
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           cls.sqa_profiler_interface = SQAProfilerInterface(
+               cls.sqlite_conn,
+               table_entity=cls.table_entity,
+               ometa_client=None,
+           )
+       cls.engine = cls.sqa_profiler_interface.session.get_bind()

        User.__table__.create(bind=cls.engine)

        data = [
@ -688,9 +696,14 @@ class MetricsTest(TestCase):

        EmptyUser.__table__.create(bind=self.engine)

-       sqa_profiler_interface = SQAInterface(
-           self.sqlite_conn, table=EmptyUser, table_entity=self.table_entity
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=EmptyUser
+       ):
+           sqa_profiler_interface = SQAProfilerInterface(
+               self.sqlite_conn,
+               table_entity=self.table_entity,
+               ometa_client=None,
+           )

        hist = add_props(bins=5)(Metrics.HISTOGRAM.value)
        res = (

@ -15,6 +15,7 @@ Test Profiler behavior
import os
from datetime import datetime, timezone
from unittest import TestCase
+from unittest.mock import patch
from uuid import uuid4

import pytest
@ -38,7 +39,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
    SQLiteScheme,
)
from metadata.ingestion.source import sqa_types
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.core import MissingMetricException, Profiler
@ -79,9 +80,12 @@ class ProfilerTest(TestCase):
        )
    ],
)
-   sqa_profiler_interface = SQAInterface(
-       sqlite_conn, table=User, table_entity=table_entity
-   )
+   with patch.object(
+       SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+   ):
+       sqa_profiler_interface = SQAProfilerInterface(
+           sqlite_conn, table_entity=table_entity, ometa_client=None
+       )

    @classmethod
    def setUpClass(cls) -> None:

@ -14,6 +14,7 @@ Test Sample behavior
"""
import os
from unittest import TestCase
+from unittest.mock import patch
from uuid import uuid4

from sqlalchemy import TEXT, Column, Integer, String, func
@ -25,7 +26,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
    SQLiteConnection,
    SQLiteScheme,
)
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.registry import CustomTypes
from metadata.orm_profiler.profiler.core import Profiler
@ -68,9 +69,12 @@ class SampleTest(TestCase):
        ],
    )

-   sqa_profiler_interface = SQAInterface(
-       sqlite_conn, table=User, table_entity=table_entity
-   )
+   with patch.object(
+       SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+   ):
+       sqa_profiler_interface = SQAProfilerInterface(
+           sqlite_conn, table_entity=table_entity, ometa_client=None
+       )
    engine = sqa_profiler_interface.session.get_bind()
    session = sqa_profiler_interface.session

@ -125,19 +129,21 @@ class SampleTest(TestCase):
        """

        # Randomly pick table_count to init the Profiler; we don't care for this test
        table_count = Metrics.ROW_COUNT.value
-       sqa_profiler_interface = SQAInterface(
-           self.sqlite_conn,
-           table=User,
-           table_entity=self.table_entity,
-           profile_sample=50,
-       )
-       profiler = Profiler(
-           table_count,
-           profiler_interface=sqa_profiler_interface,
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           sqa_profiler_interface = SQAProfilerInterface(
+               self.sqlite_conn,
+               table_entity=self.table_entity,
+               table_sample_precentage=50,
+               ometa_client=None,
+           )

-       res = self.session.query(func.count()).select_from(profiler.sample).first()
+       sample = sqa_profiler_interface._create_thread_safe_sampler(
+           self.session, User
+       ).random_sample()
+
+       res = self.session.query(func.count()).select_from(sample).first()
        assert res[0] < 30

    def test_table_row_count(self):
@ -162,15 +168,18 @@ class SampleTest(TestCase):
        """

        count = Metrics.COUNT.value
-       profiler = Profiler(
-           count,
-           profiler_interface=SQAInterface(
-               self.sqlite_conn,
-               table=User,
-               table_entity=self.table_entity,
-               profile_sample=50,
-           ),
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           profiler = Profiler(
+               count,
+               profiler_interface=SQAProfilerInterface(
+                   self.sqlite_conn,
+                   table_entity=self.table_entity,
+                   table_sample_precentage=50,
+                   ometa_client=None,
+               ),
+           )
        res = profiler.compute_metrics()._column_results
        assert res.get(User.name.name)[Metrics.COUNT.name] < 30

@ -179,15 +188,18 @@ class SampleTest(TestCase):
        Histogram should run correctly
        """
        hist = Metrics.HISTOGRAM.value
-       profiler = Profiler(
-           hist,
-           profiler_interface=SQAInterface(
-               self.sqlite_conn,
-               table=User,
-               table_entity=self.table_entity,
-               profile_sample=50,
-           ),
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           profiler = Profiler(
+               hist,
+               profiler_interface=SQAProfilerInterface(
+                   self.sqlite_conn,
+                   table_entity=self.table_entity,
+                   table_sample_precentage=50,
+                   ometa_client=None,
+               ),
+           )
        res = profiler.compute_metrics()._column_results

        # The sum of all frequencies should be sampled

@ -16,6 +16,7 @@ Test SQA Interface
import os
from datetime import datetime, timezone
from unittest import TestCase
+from unittest.mock import patch
from uuid import uuid4

from pytest import raises
@ -38,7 +39,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
    SQLiteConnection,
    SQLiteScheme,
)
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import (
    ComposedMetric,
    MetricTypes,
@ -74,24 +75,19 @@ class SQAInterfaceTest(TestCase):
        sqlite_conn = SQLiteConnection(
            scheme=SQLiteScheme.sqlite_pysqlite,
        )
-       self.sqa_profiler_interface = SQAInterface(
-           sqlite_conn, table=User, table_entity=table_entity
-       )
+       with patch.object(
+           SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+       ):
+           self.sqa_profiler_interface = SQAProfilerInterface(
+               sqlite_conn, table_entity=table_entity, ometa_client=None
+           )
        self.table = User

    def test_init_interface(self):
        """Test we can instantiate our interface object correctly"""

-       assert self.sqa_profiler_interface._sampler != None
-       assert self.sqa_profiler_interface._runner != None
        assert isinstance(self.sqa_profiler_interface.session, Session)

    def test_private_attributes(self):
        with raises(AttributeError):
            self.sqa_profiler_interface.runner = None
            self.sqa_profiler_interface.sampler = None
            self.sqa_profiler_interface.sample = None

    def tearDown(self) -> None:
        self.sqa_profiler_interface._sampler = None

@ -113,9 +109,12 @@ class SQAInterfaceTestMultiThread(TestCase):
        scheme=SQLiteScheme.sqlite_pysqlite,
        databaseMode=db_path + "?check_same_thread=False",
    )
-   sqa_profiler_interface = SQAInterface(
-       sqlite_conn, table=User, table_entity=table_entity
-   )
+   with patch.object(
+       SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
+   ):
+       sqa_profiler_interface = SQAProfilerInterface(
+           sqlite_conn, table_entity=table_entity, ometa_client=None
+       )

    @classmethod
    def setUpClass(cls) -> None:
@ -152,8 +151,6 @@ class SQAInterfaceTestMultiThread(TestCase):
    def test_init_interface(self):
        """Test we can instantiate our interface object correctly"""

-       assert self.sqa_profiler_interface._sampler != None
-       assert self.sqa_profiler_interface._runner != None
        assert isinstance(self.sqa_profiler_interface.session, Session)

    def test_get_all_metrics(self):

@ -33,7 +33,7 @@ from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline
    DatabaseServiceProfilerPipeline,
)
from metadata.generated.schema.type.entityReference import EntityReference
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.api.models import ProfilerProcessorConfig
from metadata.orm_profiler.api.workflow import ProfilerWorkflow
from metadata.orm_profiler.profiler.default import DefaultProfiler
@ -88,7 +88,7 @@ class User(Base):


@patch.object(
-   SQAInterface,
+   SQAProfilerInterface,
    "_convert_table_to_orm_object",
    return_value=User,
)
@ -190,7 +190,7 @@ def test_filter_entities(mocked_method):


@patch.object(
-   SQAInterface,
+   SQAProfilerInterface,
    "_convert_table_to_orm_object",
    return_value=User,
)
@ -227,7 +227,7 @@ def test_profile_def(mocked_method, mocked_orm):


@patch.object(
-   SQAInterface,
+   SQAProfilerInterface,
    "_convert_table_to_orm_object",
    return_value=User,
)

@ -18,6 +18,7 @@ Each test should validate the Success, Failure and Aborted statuses
import os
import unittest
from datetime import datetime
+from unittest.mock import patch
from uuid import uuid4

import sqlalchemy as sqa
@ -32,7 +33,7 @@ from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
from metadata.generated.schema.tests.testSuite import TestSuite
from metadata.generated.schema.type.entityReference import EntityReference
-from metadata.interfaces.sqa_interface import SQAInterface
+from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface
from metadata.test_suite.validations.core import validation_enum_registry

EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
@ -78,11 +79,14 @@ class testSuiteValidation(unittest.TestCase):
        databaseMode=db_path + "?check_same_thread=False",
    )

-   sqa_profiler_interface = SQAInterface(
-       sqlite_conn,
-       table=User,
-       table_entity=TABLE,
-   )
+   with patch.object(
+       SQATestSuiteInterface, "_convert_table_to_orm_object", return_value=User
+   ):
+       sqa_profiler_interface = SQATestSuiteInterface(
+           sqlite_conn,
+           table_entity=TABLE,
+           ometa_client=None,
+       )
    runner = sqa_profiler_interface.runner
    engine = sqa_profiler_interface.session.get_bind()
    session = sqa_profiler_interface.session