Fixes #7490 - Split Profiler and TestSuite Interface (#8032)

* Clean up test suite workflow and interface

* Fixed tests

* Split profiler and testSuite interfaces

* Cleaned up workflows and runners

* Fixed code formatting

* - remove old code
- remove `table` attribute used for testing and use a mock instead

* Fixed execution bugs from refactor

* Fixed static type checking for profiler/api/workflow.py

* Fixed linting

* Added __init__ files
Teddy 2022-10-11 15:57:25 +02:00 committed by GitHub
parent b914afcb6a
commit f883863b8a
34 changed files with 1067 additions and 794 deletions
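The core of this commit is splitting the old combined SQAInterface into a profiler-specific and a test-suite-specific interface. Below is a minimal, hypothetical sketch of how the two new interfaces are constructed after the split. It assumes the openmetadata-ingestion package at this commit; `service_connection`, `ometa_client` and `table_entity` are placeholder objects that would normally come from the workflow configuration, and a real run requires a live database connection.

# Hypothetical sketch -- not part of the diff. Builds the two new interfaces
# side by side to illustrate the constructor signatures introduced here.
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface

def build_interfaces(service_connection, ometa_client, table_entity):
    # Profiler concerns: metric computation, sampling, threading
    profiler_interface = SQAProfilerInterface(
        service_connection,
        ometa_client=ometa_client,
        table_entity=table_entity,
        thread_count=5,
    )
    # Test suite concerns: running data quality test cases
    test_suite_interface = SQATestSuiteInterface(
        service_connection_config=service_connection,
        ometa_client=ometa_client,
        table_entity=table_entity,
    )
    return profiler_interface, test_suite_interface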

View File

@ -0,0 +1,63 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces with the database for all database engines
supporting the SQLAlchemy abstraction layer
"""
from abc import ABC, abstractmethod
from typing import Dict, Union
from sqlalchemy import Column
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.orm_profiler.metrics.registry import Metrics
class ProfilerProtocol(ABC):
"""Protocol interface for the profiler processor"""
@abstractmethod
def __init__(
self,
ometa_client: OpenMetadata,
service_connection_config: Union[DatabaseConnection, DatalakeConnection],
):
"""Required attribute for the interface"""
raise NotImplementedError
@property
@abstractmethod
def table(self):
"""OM Table entity"""
raise NotImplementedError
@abstractmethod
def get_all_metrics(self, metric_funcs) -> dict:
"""run profiler metrics"""
raise NotImplementedError
@abstractmethod
def get_composed_metrics(
self, column: Column, metric: Metrics, column_results: Dict
) -> dict:
"""run profiler metrics"""
raise NotImplementedError
@abstractmethod
def fetch_sample_data(self, table) -> dict:
"""run profiler metrics"""
raise NotImplementedError
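For reference, a toy, hypothetical implementation of the protocol above; names such as `InMemoryProfilerInterface` are illustrative only and not part of this commit. It shows the minimum surface a concrete profiler interface has to provide.

# Hypothetical sketch -- a minimal concrete ProfilerProtocol implementation.
from typing import Dict
from sqlalchemy import Column
from metadata.interfaces.profiler_protocol import ProfilerProtocol

class InMemoryProfilerInterface(ProfilerProtocol):
    """Toy implementation that satisfies every abstract method."""

    def __init__(self, ometa_client, service_connection_config):
        self.ometa_client = ometa_client
        self.service_connection_config = service_connection_config
        self._table = None  # would normally be the table ORM/entity object

    @property
    def table(self):
        return self._table

    def get_all_metrics(self, metric_funcs) -> dict:
        # real implementations fan metric_funcs out to workers
        return {"table": {}, "columns": {}}

    def get_composed_metrics(self, column: Column, metric, column_results: Dict) -> dict:
        # mirrors the SQA implementation: composed metrics derive from prior results
        return metric(column).fn(column_results)

    def fetch_sample_data(self, table) -> dict:
        return {"columns": [], "rows": []}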

View File

@ -1,555 +0,0 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces with the database for all database engines
supporting the SQLAlchemy abstraction layer
"""
import concurrent.futures
import threading
import traceback
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union
from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import DeclarativeMeta, Session
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeType,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.interface_protocol import InterfaceProtocol
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.converter import ometa_to_orm
from metadata.orm_profiler.profiler.handle_partition import (
get_partition_cols,
is_partitioned,
)
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.test_suite.validations.core import validation_enum_registry
from metadata.utils.connections import (
create_and_bind_thread_safe_session,
get_connection,
test_connection,
)
from metadata.utils.constants import TEN_MIN
from metadata.utils.dispatch import enum_register
from metadata.utils.helpers import get_start_and_end
from metadata.utils.logger import sqa_interface_registry_logger
from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY
from metadata.utils.timeout import cls_timeout
logger = sqa_interface_registry_logger()
thread_local = threading.local()
class SQAInterface(InterfaceProtocol):
"""
Interface to interact with registry supporting
sqlalchemy.
"""
def __init__(
self,
service_connection_config,
sqa_metadata_obj: Optional[MetaData] = None,
metadata_config: Optional[OpenMetadataConnection] = None,
thread_count: Optional[int] = 5,
table_entity: Optional[Table] = None,
table: Optional[DeclarativeMeta] = None,
profile_sample: Optional[float] = None,
profile_query: Optional[str] = None,
partition_config: Optional[TablePartitionConfig] = None,
):
"""Instantiate SQA Interface object"""
self._thread_count = thread_count
self.table_entity = table_entity
self._create_ometa_obj(metadata_config)
self.processor_status = ProfilerProcessorStatus()
self.processor_status.entity = (
self.table_entity.fullyQualifiedName.__root__
if self.table_entity.fullyQualifiedName
else None
)
# Allows SQA Interface to be used without OM server config
self.table = table or self._convert_table_to_orm_object(sqa_metadata_obj)
self.service_connection_config = service_connection_config
self.session_factory = self._session_factory(service_connection_config)
self.session: Session = self.session_factory()
self.set_session_tag(self.session)
self.profile_sample = profile_sample
self.profile_query = profile_query
self.partition_details = (
self._get_partition_details(partition_config)
if not self.profile_query
else None
)
self._sampler = self.create_sampler()
self._runner = self.create_runner()
@property
def sample(self):
"""Getter method for sample attribute"""
if not self.sampler:
raise RuntimeError(
"You must create a sampler first `<instance>.create_sampler(...)`."
)
return self.sampler.random_sample()
@property
def runner(self):
"""Getter method for runner attribute"""
return self._runner
@property
def sampler(self):
"""Getter methid for sampler attribute"""
return self._sampler
@staticmethod
def _session_factory(service_connection_config):
"""Create thread safe session that will be automatically
garbage collected once the application thread ends
"""
engine = get_connection(service_connection_config)
return create_and_bind_thread_safe_session(engine)
def set_session_tag(self, session):
"""
Set session query tag
Args:
session: the SQLAlchemy session to tag
"""
if (
self.service_connection_config.type.value == SnowflakeType.Snowflake.value
and hasattr(self.service_connection_config, "queryTag")
and self.service_connection_config.queryTag
):
session.execute(
SNOWFLAKE_SESSION_TAG_QUERY.format(
query_tag=self.service_connection_config.queryTag
)
)
def _get_engine(self, service_connection_config):
"""Get engine for database
Args:
service_connection_config: connection details for the specific service
Returns:
sqlalchemy engine
"""
engine = get_connection(service_connection_config)
test_connection(engine)
return engine
def _get_partition_details(
self, partition_config: TablePartitionConfig
) -> Optional[Dict]:
"""From partition config, get the partition table for a table entity"""
if self.table_entity.serviceType == DatabaseServiceType.BigQuery:
if is_partitioned(self.session, self.table):
start, end = get_start_and_end(partition_config.partitionQueryDuration)
partition_details = {
"partition_field": partition_config.partitionField
or get_partition_cols(self.session, self.table),
"partition_start": start,
"partition_end": end,
"partition_values": partition_config.partitionValues,
}
return partition_details
return None
def _create_ometa_obj(self, metadata_config):
try:
self._metadata = OpenMetadata(metadata_config)
self._metadata.health_check()
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"No OpenMetadata server configuration found. Running profiler interface without OM server connection: {exc}"
)
self._metadata = None
def _create_thread_safe_sampler(
self,
session,
table,
):
"""Create thread safe runner"""
if not hasattr(thread_local, "sampler"):
thread_local.sampler = Sampler(
session=session,
table=table,
profile_sample=self.profile_sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return thread_local.sampler
def _create_thread_safe_runner(
self,
session,
table,
sample,
):
"""Create thread safe runner"""
if not hasattr(thread_local, "runner"):
thread_local.runner = QueryRunner(
session=session,
table=table,
sample=sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return thread_local.runner
def _convert_table_to_orm_object(
self, sqa_metadata_obj: Optional[MetaData]
) -> DeclarativeMeta:
"""Given a table entity return a SQA ORM object
Args:
sqa_metadata_obj: sqa metadata registry
Returns:
DeclarativeMeta
"""
return ometa_to_orm(self.table_entity, self._metadata, sqa_metadata_obj)
def get_columns(self) -> Column:
"""get columns from an orm object"""
return inspect(self.table).c
def compute_metrics_in_thread(
self,
metric_funcs,
):
"""Run metrics in processor worker"""
(
metrics,
metric_type,
column,
table,
) = metric_funcs
logger.debug(
f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
)
Session = self.session_factory
with Session() as session:
self.set_session_tag(session)
sampler = self._create_thread_safe_sampler(
session,
table,
)
sample = sampler.random_sample()
runner = self._create_thread_safe_runner(
session,
table,
sample,
)
row = compute_metrics_registry.registry[metric_type.value](
metrics,
runner=runner,
session=session,
column=column,
sample=sample,
processor_status=self.processor_status,
)
if column is not None:
column = column.name
return row, column
def get_all_metrics(
self,
metric_funcs: list,
):
"""get all profiler metrics"""
logger.info(f"Computing metrics with {self._thread_count} threads.")
profile_results = {"table": dict(), "columns": defaultdict(dict)}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self._thread_count
) as executor:
futures = [
executor.submit(
self.compute_metrics_in_thread,
metric_func,
)
for metric_func in metric_funcs
]
for future in concurrent.futures.as_completed(futures):
profile, column = future.result()
if not isinstance(profile, dict):
profile = dict()
if not column:
profile_results["table"].update(profile)
else:
profile_results["columns"][column].update(
{
"name": column,
"timestamp": datetime.now(tz=timezone.utc).timestamp(),
**profile,
}
)
return profile_results
def fetch_sample_data(self):
if not self.sampler:
raise RuntimeError(
"You must create a sampler first `<instance>.create_sampler(...)`."
)
return self.sampler.fetch_sample_data()
def create_sampler(self) -> Sampler:
"""Create sampler instance"""
return Sampler(
session=self.session,
table=self.table,
profile_sample=self.profile_sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
def create_runner(self) -> None:
"""Create a QueryRunner Instance"""
return cls_timeout(TEN_MIN)(
QueryRunner(
session=self.session,
table=self.table,
sample=self.sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_sample,
)
)
def get_composed_metrics(
self, column: Column, metric: Metrics, column_results: Dict
):
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
return metric(column).fn(column_results)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Unexpected exception computing metrics: {exc}")
self.session.rollback()
def run_test_case(
self,
test_case: TestCase,
) -> Optional[TestCaseResult]:
"""Run table tests where platformsTest=OpenMetadata
Args:
table_test_type: test type to be ran
table_profile: table profile
table: SQA table,
profile_sample: sample for the profile
"""
try:
return validation_enum_registry.registry[
test_case.testDefinition.fullyQualifiedName
](
test_case,
execution_date=datetime.now(tz=timezone.utc).timestamp(),
runner=self.runner,
)
except KeyError as err:
logger.warning(
f"Test definition {test_case.testDefinition.fullyQualifiedName} not registered in OpenMetadata TestDefintion registry. Skipping test case {test_case.name.__root__}"
)
return None
def close(self):
"""close session"""
self.session.close()
def get_table_metrics(
metrics: List[Metrics],
runner: QueryRunner,
session: Session,
*args,
**kwargs,
):
"""Given a list of metrics, compute the given results
and returns the values
Args:
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(*[metric().fn() for metric in metrics])
if row:
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}: {exc}"
)
session.rollback()
def get_static_metrics(
metrics: List[Metrics],
runner: QueryRunner,
session: Session,
column: Column,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Dict[str, Union[str, int]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(
*[
metric(column).fn()
for metric in metrics
if not metric.is_window_metric()
]
)
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Static Metrics", f"{exc}")
def get_query_metrics(
metric: Metrics,
runner: QueryRunner,
session: Session,
column: Column,
sample,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
col_metric = metric(column)
metric_query = col_metric.query(sample=sample, session=session)
if not metric_query:
return None
if col_metric.metric_type == dict:
results = runner.select_all_from_query(metric_query)
data = {k: [result[k] for result in results] for k in dict(results[0])}
return {metric.name(): data}
else:
row = runner.select_first_from_query(metric_query)
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Query Metrics", f"{exc}")
def get_window_metrics(
metric: Metrics,
runner: QueryRunner,
session: Session,
column: Column,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Dict[str, Union[str, int]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(metric(column).fn())
if not isinstance(row, Row):
return {metric.name(): row}
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Window Metrics", f"{exc}")
compute_metrics_registry = enum_register()
compute_metrics_registry.add("Static")(get_static_metrics)
compute_metrics_registry.add("Table")(get_table_metrics)
compute_metrics_registry.add("Query")(get_query_metrics)
compute_metrics_registry.add("Window")(get_window_metrics)

View File

@ -0,0 +1,123 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces with the database for all database engines
supporting the SQLAlchemy abstraction layer
"""
from typing import Dict, Optional
from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.orm import DeclarativeMeta
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeType,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.orm.converter import ometa_to_orm
from metadata.orm_profiler.profiler.handle_partition import (
get_partition_cols,
is_partitioned,
)
from metadata.utils.connections import get_connection
from metadata.utils.helpers import get_start_and_end
from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY
class SQAInterfaceMixin:
"""SQLAlchemy inteface mixin grouping shared methods between sequential and threaded executor"""
@property
def table(self):
"""OM Table entity"""
return self._table
def _get_engine(self):
"""Get engine for database
Args:
service_connection_config: connection details for the specific service
Returns:
sqlalchemy engine
"""
engine = get_connection(self.service_connection_config)
return engine
def _convert_table_to_orm_object(
self,
sqa_metadata_obj: Optional[MetaData] = None,
) -> DeclarativeMeta:
"""Given a table entity return a SQA ORM object
Args:
sqa_metadata_obj: sqa metadata registry
Returns:
DeclarativeMeta
"""
return ometa_to_orm(self.table_entity, self.ometa_client, sqa_metadata_obj)
def get_columns(self) -> Column:
"""get columns from an orm object"""
return inspect(self.table).c
def set_session_tag(self, session) -> None:
"""
Set session query tag for snowflake
Args:
session: the SQLAlchemy session to tag
"""
if (
self.service_connection_config.type.value == SnowflakeType.Snowflake.value
and hasattr(self.service_connection_config, "queryTag")
and self.service_connection_config.queryTag
):
session.execute(
SNOWFLAKE_SESSION_TAG_QUERY.format(
query_tag=self.service_connection_config.queryTag
)
)
def get_partition_details(
self, partition_config: TablePartitionConfig
) -> Optional[Dict]:
"""From partition config, get the partition table for a table entity
Args:
partition_config: TablePartitionConfig object with some partition details
Returns:
dict or None: dictionary with all the elements constituting a partition
"""
if (
self.table_entity.serviceType == DatabaseServiceType.BigQuery
and is_partitioned(self.session, self.table)
):
start, end = get_start_and_end(partition_config.partitionQueryDuration)
return {
"partition_field": partition_config.partitionField
or get_partition_cols(self.session, self.table),
"partition_start": start,
"partition_end": end,
"partition_values": partition_config.partitionValues,
}
return None
def close(self):
"""close session"""
self.session.close()
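Purely illustrative example, not taken from the commit: the shape of the dictionary `get_partition_details` builds for a partitioned BigQuery table. The keys mirror the code above; the values are made up.

# Illustrative only -- keys mirror get_partition_details above, values are invented.
example_partition_details = {
    "partition_field": "created_at",   # partitionField from config, or the detected partition column
    "partition_start": "2022-09-11",   # start of the partitionQueryDuration window
    "partition_end": "2022-10-11",     # end of the partitionQueryDuration window
    "partition_values": None,          # optional explicit partition values
}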

View File

@ -0,0 +1,247 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces with the database for all database engines
supporting the SQLAlchemy abstraction layer
"""
import concurrent.futures
import threading
import traceback
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, Optional
from sqlalchemy import Column, MetaData
from metadata.generated.schema.entity.data.table import Table, TableData
from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.interfaces.sqalchemy.mixins.sqa_mixin import SQAInterfaceMixin
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.metrics.sqa_metrics_computation_registry import (
compute_metrics_registry,
)
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.utils.connections import (
create_and_bind_thread_safe_session,
get_connection,
)
from metadata.utils.logger import sqa_interface_registry_logger
logger = sqa_interface_registry_logger()
thread_local = threading.local()
class SQAProfilerInterface(SQAInterfaceMixin, ProfilerProtocol):
"""
Interface to interact with registry supporting
sqlalchemy.
"""
# pylint: disable=too-many-arguments
def __init__(
self,
service_connection_config,
ometa_client: OpenMetadata,
sqa_metadata_obj: Optional[MetaData] = None,
thread_count: Optional[float] = 5,
table_entity: Optional[Table] = None,
table_sample_precentage: Optional[float] = None,
table_sample_query: Optional[str] = None,
table_partition_config: Optional[TablePartitionConfig] = None,
):
"""Instantiate SQA Interface object"""
self._thread_count = thread_count
self.table_entity = table_entity
self.ometa_client = ometa_client
self.service_connection_config = service_connection_config
self.processor_status = ProfilerProcessorStatus()
self.processor_status.entity = (
self.table_entity.fullyQualifiedName.__root__
if self.table_entity.fullyQualifiedName
else None
)
self._table = self._convert_table_to_orm_object(sqa_metadata_obj)
self.session_factory = self._session_factory(service_connection_config)
self.session = self.session_factory()
self.set_session_tag(self.session)
self.profile_sample = table_sample_precentage
self.profile_query = table_sample_query
self.partition_details = (
self.get_partition_details(table_partition_config)
if not self.profile_query
else None
)
@staticmethod
def _session_factory(service_connection_config):
"""Create thread safe session that will be automatically
garbage collected once the application thread ends
"""
engine = get_connection(service_connection_config)
return create_and_bind_thread_safe_session(engine)
def _create_thread_safe_sampler(
self,
session,
table,
):
"""Create thread safe runner"""
if not hasattr(thread_local, "sampler"):
thread_local.sampler = Sampler(
session=session,
table=table,
profile_sample=self.profile_sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return thread_local.sampler
def _create_thread_safe_runner(
self,
session,
table,
sample,
):
"""Create thread safe runner"""
if not hasattr(thread_local, "runner"):
thread_local.runner = QueryRunner(
session=session,
table=table,
sample=sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return thread_local.runner
def compute_metrics_in_thread(
self,
metric_funcs,
):
"""Run metrics in processor worker"""
(
metrics,
metric_type,
column,
table,
) = metric_funcs
logger.debug(
f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
)
Session = self.session_factory # pylint: disable=invalid-name
with Session() as session:
self.set_session_tag(session)
sampler = self._create_thread_safe_sampler(
session,
table,
)
sample = sampler.random_sample()
runner = self._create_thread_safe_runner(
session,
table,
sample,
)
row = compute_metrics_registry.registry[metric_type.value](
metrics,
runner=runner,
session=session,
column=column,
sample=sample,
processor_status=self.processor_status,
)
if column is not None:
column = column.name
return row, column
# pylint: disable=use-dict-literal
def get_all_metrics(
self,
metric_funcs: list,
):
"""get all profiler metrics"""
logger.info(f"Computing metrics with {self._thread_count} threads.")
profile_results = {"table": dict(), "columns": defaultdict(dict)}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self._thread_count
) as executor:
futures = [
executor.submit(
self.compute_metrics_in_thread,
metric_func,
)
for metric_func in metric_funcs
]
for future in concurrent.futures.as_completed(futures):
profile, column = future.result()
if not isinstance(profile, dict):
profile = dict()
if not column:
profile_results["table"].update(profile)
else:
profile_results["columns"][column].update(
{
"name": column,
"timestamp": datetime.now(tz=timezone.utc).timestamp(),
**profile,
}
)
return profile_results
def fetch_sample_data(self, table) -> TableData:
"""Fetch sample data from database
Args:
table: ORM declarative table
Returns:
TableData: sample table data
"""
sampler = Sampler(
session=self.session,
table=table,
profile_sample=self.profile_sample,
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return sampler.fetch_sample_data()
def get_composed_metrics(
self, column: Column, metric: Metrics, column_results: Dict
):
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
return metric(column).fn(column_results)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Unexpected exception computing metrics: {exc}")
self.session.rollback()
return None
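A short, hypothetical usage sketch of the profiler interface above; `interface` is an already constructed SQAProfilerInterface and `metric_funcs` the list of (metrics, metric_type, column, table) tuples prepared by the profiler core, both assumed to exist.

# Hypothetical usage -- assumes `interface` and `metric_funcs` already exist.
profile_results = interface.get_all_metrics(metric_funcs)
table_metrics = profile_results["table"]        # table-level metric values
column_metrics = profile_results["columns"]     # per-column dicts keyed by column name
sample_data = interface.fetch_sample_data(interface.table)  # TableData with sample rows
interface.close()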

View File

@ -0,0 +1,159 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces with the database for all database engines
supporting the SQLAlchemy abstraction layer
"""
from datetime import datetime, timezone
from typing import Optional, Union
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm.util import AliasedClass
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.sqalchemy.mixins.sqa_mixin import SQAInterfaceMixin
from metadata.interfaces.test_suite_protocol import TestSuiteProtocol
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.orm_profiler.profiler.sampler import Sampler
from metadata.test_suite.validations.core import validation_enum_registry
from metadata.utils.connections import create_and_bind_session, get_connection
from metadata.utils.constants import TEN_MIN
from metadata.utils.logger import sqa_interface_registry_logger
from metadata.utils.timeout import cls_timeout
logger = sqa_interface_registry_logger()
class SQATestSuiteInterface(SQAInterfaceMixin, TestSuiteProtocol):
"""
Sequential interface protocol for the test suite. This class
implements the specific operations needed to run the test suite workflow
against a SQLAlchemy source.
"""
def __init__(
self,
service_connection_config: DatabaseConnection,
ometa_client: OpenMetadata,
table_sample_precentage: float = None,
table_sample_query: str = None,
table_partition_config: dict = None,
table_entity: Table = None,
):
self.ometa_client = ometa_client
self.table_entity = table_entity
self.service_connection_config = service_connection_config
self.session = create_and_bind_session(
get_connection(self.service_connection_config)
)
self.set_session_tag(self.session)
self._table = self._convert_table_to_orm_object()
self.table_sample_precentage = table_sample_precentage
self.table_sample_query = table_sample_query
self.table_partition_config = (
self.get_partition_details(table_partition_config)
if not self.table_sample_query
else None
)
self._sampler = self._create_sampler()
self._runner = self._create_runner()
@property
def sample(self) -> Union[DeclarativeMeta, AliasedClass]:
"""_summary_
Returns:
Union[DeclarativeMeta, AliasedClass]: _description_
"""
if not self.sampler:
raise RuntimeError(
"You must create a sampler first `<instance>.create_sampler(...)`."
)
return self.sampler.random_sample()
@property
def runner(self) -> QueryRunner:
"""getter method for the QueryRunner object
Returns:
QueryRunner: runner object
"""
return self._runner
@property
def sampler(self) -> Sampler:
"""getter method for the Runner object
Returns:
Sampler: sampler object
"""
return self._sampler
def _create_sampler(self) -> Sampler:
"""Create sampler instance"""
return Sampler(
session=self.session,
table=self.table,
profile_sample=self.table_sample_precentage,
partition_details=self.table_partition_config,
profile_sample_query=self.table_sample_query,
)
def _create_runner(self) -> QueryRunner:
"""Create a QueryRunner Instance"""
return cls_timeout(TEN_MIN)(
QueryRunner(
session=self.session,
table=self.table,
sample=self.sample,
partition_details=self.table_partition_config,
profile_sample_query=self.table_sample_query,
)
)
def run_test_case(
self,
test_case: TestCase,
) -> Optional[TestCaseResult]:
"""Run table tests where platformsTest=OpenMetadata
Args:
test_case: test case object to execute
Returns:
TestCaseResult object
"""
try:
return validation_enum_registry.registry[
test_case.testDefinition.fullyQualifiedName
](
test_case,
execution_date=datetime.now(tz=timezone.utc).timestamp(),
runner=self.runner,
)
except KeyError as err:
logger.warning(
f"Test definition {test_case.testDefinition.fullyQualifiedName} not registered in OpenMetadata "
f"TestDefintion registry. Skipping test case {test_case.name.__root__} - {err}"
)
return None
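A hypothetical sketch of running a single test case through the interface above; `test_suite_interface` (an SQATestSuiteInterface) and `test_case` (a TestCase entity fetched from the OpenMetadata backend) are assumed to already exist.

# Hypothetical usage -- assumes `test_suite_interface` and `test_case` exist.
result = test_suite_interface.run_test_case(test_case)
if result:
    print(result)  # TestCaseResult produced by the matching validation function
test_suite_interface.close()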

View File

@ -15,42 +15,30 @@ supporting sqlalchemy abstraction layer
"""
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.orm_profiler.api.models import ProfilerProcessorConfig
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata
class InterfaceProtocol(ABC):
class TestSuiteProtocol(ABC):
"""Protocol interface for the processor"""
@abstractmethod
def __init__(
self,
metadata_config: OpenMetadataConnection = None,
profiler_config: ProfilerProcessorConfig = None,
workflow_profile_sample: float = None,
thread_count: int = 5,
table: Table = None,
ometa_client: OpenMetadata = None,
service_connection_config: Union[DatabaseConnection, DatalakeConnection] = None,
):
"""Required attribute for the interface"""
raise NotImplementedError
@abstractmethod
def create_sampler(*args, **kwargs) -> None:
"""Method to instantiate a Sampler object"""
raise NotImplementedError
@abstractmethod
def create_runner(*args, **kwargs) -> None:
"""Method to instantiate a Runner object"""
raise NotImplementedError
@abstractmethod
def run_test_case(*args, **kwargs) -> Optional[TestCaseResult]:
def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]:
"""run column data quality tests"""
raise NotImplementedError

View File

@ -18,7 +18,7 @@ Workflow definition for the ORM Profiler.
"""
import traceback
from copy import deepcopy
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, cast
from pydantic import ValidationError
from sqlalchemy import MetaData
@ -30,7 +30,6 @@ from metadata.generated.schema.entity.data.table import (
ColumnProfilerConfig,
IntervalType,
Table,
TableProfile,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
@ -48,9 +47,11 @@ from metadata.ingestion.api.processor import ProcessorStatus
from metadata.ingestion.api.sink import Sink
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.api.models import (
ProfilerProcessorConfig,
ProfilerResponse,
TableConfig,
TablePartitionConfig,
)
@ -63,12 +64,17 @@ from metadata.utils.class_helper import (
get_service_type_from_source_type,
)
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.helpers import create_ometa_client
from metadata.utils.logger import profiler_logger
from metadata.utils.workflow_output_handler import print_profiler_status
logger = profiler_logger()
class ProfilerInterfaceInstantiationError(Exception):
"""Raise when interface cannot be instantiated"""
class ProfilerWorkflow:
"""
Configure and run the ORM profiler
@ -93,9 +99,9 @@ class ProfilerWorkflow:
self._retrieve_service_connection_if_needed()
# Init and type the source config
self.source_config: DatabaseServiceProfilerPipeline = (
self.config.source.sourceConfig.config
)
self.source_config: DatabaseServiceProfilerPipeline = cast(
DatabaseServiceProfilerPipeline, self.config.source.sourceConfig.config
) # Used to satisfy the type checker
self.source_status = SQLSourceStatus()
self.status = ProcessorStatus()
@ -143,14 +149,14 @@ class ProfilerWorkflow:
table_config
for table_config in self.profiler_config.tableConfig
if table_config.fullyQualifiedName.__root__
== entity.fullyQualifiedName.__root__
== entity.fullyQualifiedName.__root__ # type: ignore
),
None,
)
def get_include_columns(self, entity) -> Optional[List[ColumnProfilerConfig]]:
"""get included columns"""
entity_config: TableConfig = self.get_config_for_entity(entity)
entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
if entity_config and entity_config.columnConfig:
return entity_config.columnConfig.includeColumns
@ -161,7 +167,7 @@ class ProfilerWorkflow:
def get_exclude_columns(self, entity) -> Optional[List[str]]:
"""get included columns"""
entity_config: TableConfig = self.get_config_for_entity(entity)
entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
if entity_config and entity_config.columnConfig:
return entity_config.columnConfig.excludeColumns
@ -176,7 +182,7 @@ class ProfilerWorkflow:
Args:
entity: table entity
"""
entity_config: TableConfig = self.get_config_for_entity(entity)
entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
if entity_config:
return entity_config.profileSample
@ -188,13 +194,13 @@ class ProfilerWorkflow:
return None
def get_profile_query(self, entity: Table) -> Optional[float]:
def get_profile_query(self, entity: Table) -> Optional[str]:
"""Get profile sample
Args:
entity: table entity
"""
entity_config: TableConfig = self.get_config_for_entity(entity)
entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
if entity_config:
return entity_config.profileQuery
@ -215,7 +221,7 @@ class ProfilerWorkflow:
):
return None
entity_config: TableConfig = self.get_config_for_entity(entity)
entity_config: Optional[TableConfig] = self.get_config_for_entity(entity)
if entity_config:
return entity_config.partitionConfig
@ -240,25 +246,30 @@ class ProfilerWorkflow:
sqa_metadata_obj: Optional[MetaData] = None,
):
"""Creates a profiler interface object"""
return SQAInterface(
service_connection_config,
sqa_metadata_obj=sqa_metadata_obj,
metadata_config=self.metadata_config,
thread_count=self.source_config.threadCount,
table_entity=table_entity,
profile_sample=self.get_profile_sample(table_entity)
if not self.get_profile_query(table_entity)
else None,
profile_query=self.get_profile_query(table_entity)
if not self.get_profile_sample(table_entity)
else None,
partition_config=self.get_partition_details(table_entity)
if not self.get_profile_query(table_entity)
else None,
)
try:
return SQAProfilerInterface(
service_connection_config,
sqa_metadata_obj=sqa_metadata_obj,
ometa_client=create_ometa_client(self.metadata_config),
thread_count=self.source_config.threadCount,
table_entity=table_entity,
table_sample_precentage=self.get_profile_sample(table_entity)
if not self.get_profile_query(table_entity)
else None,
table_sample_query=self.get_profile_query(table_entity)
if not self.get_profile_sample(table_entity)
else None,
table_partition_config=self.get_partition_details(table_entity)
if not self.get_profile_query(table_entity)
else None,
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error("We could not create a profiler interface")
raise ProfilerInterfaceInstantiationError(exc)
def create_profiler_obj(
self, table_entity: Table, profiler_interface: SQAInterface
self, table_entity: Table, profiler_interface: ProfilerProtocol
):
"""Profile a single entity"""
if not self.profiler_config.profiler:
@ -275,7 +286,7 @@ class ProfilerWorkflow:
)
self.profiler_obj = Profiler(
*metrics,
*metrics, # type: ignore
profiler_interface=profiler_interface,
include_columns=self.get_include_columns(table_entity),
exclude_columns=self.get_exclude_columns(table_entity),
@ -293,7 +304,7 @@ class ProfilerWorkflow:
return None
return database
def filter_entities(self, tables: List[Table]) -> Iterable[Table]:
def filter_entities(self, tables: Iterable[Table]) -> Iterable[Table]:
"""
From a list of tables, apply the SQLSourceConfig
filter patterns.
@ -304,10 +315,10 @@ class ProfilerWorkflow:
try:
if filter_by_schema(
self.source_config.schemaFilterPattern,
table.databaseSchema.name,
table.databaseSchema.name, # type: ignore
):
self.source_status.filter(
f"Schema pattern not allowed: {table.fullyQualifiedName.__root__}",
f"Schema pattern not allowed: {table.fullyQualifiedName.__root__}", # type: ignore
"Schema pattern not allowed",
)
continue
@ -316,7 +327,7 @@ class ProfilerWorkflow:
table.name.__root__,
):
self.source_status.filter(
f"Table pattern not allowed: {table.fullyQualifiedName.__root__}",
f"Table pattern not allowed: {table.fullyQualifiedName.__root__}", # type: ignore
"Table pattern not allowed",
)
continue
@ -325,9 +336,10 @@ class ProfilerWorkflow:
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.warning(
f"Unexpected error filtering entities for table [{table.fullyQualifiedName.__root__}]: {exc}"
"Unexpected error filtering entities for table "
f"[{table.fullyQualifiedName.__root__}]: {exc}" # type: ignore
)
self.source_status.failure(table.fullyQualifiedName.__root__, f"{exc}")
self.source_status.failure(table.fullyQualifiedName.__root__, f"{exc}") # type: ignore
def get_database_entities(self):
"""List all databases in service"""
@ -370,25 +382,28 @@ class ProfilerWorkflow:
service_name=self.config.source.serviceName,
database_name=database.name.__root__,
),
},
}, # type: ignore
)
yield from self.filter_entities(all_tables)
def copy_service_config(self, database) -> None:
def copy_service_config(self, database) -> DatabaseService.__config__:
copy_service_connection_config = deepcopy(
self.config.source.serviceConnection.__root__.config
self.config.source.serviceConnection.__root__.config # type: ignore
)
if hasattr(
self.config.source.serviceConnection.__root__.config,
self.config.source.serviceConnection.__root__.config, # type: ignore
"supportsDatabase",
):
if hasattr(
self.config.source.serviceConnection.__root__.config, "database"
):
copy_service_connection_config.database = database.name.__root__
if hasattr(self.config.source.serviceConnection.__root__.config, "catalog"):
copy_service_connection_config.catalog = database.name.__root__
if hasattr(copy_service_connection_config, "database"):
copy_service_connection_config.database = database.name.__root__ # type: ignore
if hasattr(copy_service_connection_config, "catalog"):
copy_service_connection_config.catalog = database.name.__root__ # type: ignore
# we know we'll only be working with databaseServices, so we cast the type to satisfy the type checker
copy_service_connection_config = cast(
DatabaseService.__config__, copy_service_connection_config
)
return copy_service_connection_config
@ -402,8 +417,8 @@ class ProfilerWorkflow:
if not databases:
raise ValueError(
"databaseFilterPattern returned 0 result. At least 1 database must be returned by the filter pattern."
f"\n\t- includes: {self.source_config.databaseFilterPattern.includes}\n\t"
f"- excludes: {self.source_config.databaseFilterPattern.excludes}"
f"\n\t- includes: {self.source_config.databaseFilterPattern.includes if self.source_config.databaseFilterPattern else None}" # pylint: disable=line-too-long
f"\n\t- excludes: {self.source_config.databaseFilterPattern.excludes if self.source_config.databaseFilterPattern else None}" # pylint: disable=line-too-long
)
for database in databases:
@ -416,7 +431,7 @@ class ProfilerWorkflow:
copied_service_config, entity, sqa_metadata_obj
)
self.create_profiler_obj(entity, profiler_interface)
profile: TableProfile = self.profiler_obj.process(
profile: ProfilerResponse = self.profiler_obj.process(
self.source_config.generateSampleData
)
profiler_interface.close()
@ -425,18 +440,19 @@ class ProfilerWorkflow:
self.status.failures.extend(
profiler_interface.processor_status.failures
) # we can have column level failures we need to report
self.status.processed(entity.fullyQualifiedName.__root__)
self.source_status.scanned(entity.fullyQualifiedName.__root__)
self.status.processed(entity.fullyQualifiedName.__root__) # type: ignore
self.source_status.scanned(entity.fullyQualifiedName.__root__) # type: ignore
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.warning(
f"Unexpected exception processing entity [{entity.fullyQualifiedName.__root__}]: {exc}"
"Unexpected exception processing entity "
f"[{entity.fullyQualifiedName.__root__}]: {exc}" # type: ignore
)
self.status.failures.extend(
profiler_interface.processor_status.failures
profiler_interface.processor_status.failures # type: ignore
)
self.source_status.failure(
entity.fullyQualifiedName.__root__, f"{exc}"
entity.fullyQualifiedName.__root__, f"{exc}" # type: ignore
)
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())

View File

@ -0,0 +1,185 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenMetadata Profiler supported metrics
Use these registries to avoid messy imports.
Note that we are using our own Registry definition
that allows us to directly call our metrics without
having to verbosely pass .value all the time...
"""
# pylint: disable=unused-argument
import traceback
from typing import Dict, List, Optional, Union
from sqlalchemy import Column
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
from metadata.ingestion.api.processor import ProfilerProcessorStatus
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.runner import QueryRunner
from metadata.utils.dispatch import enum_register
from metadata.utils.logger import sqa_interface_registry_logger
logger = sqa_interface_registry_logger()
def get_table_metrics(
metrics: List[Metrics],
runner: QueryRunner,
session: Session,
*args,
**kwargs,
):
"""Given a list of metrics, compute the given results
and returns the values
Args:
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(*[metric().fn() for metric in metrics])
if row:
return dict(row)
return None
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}: {exc}"
)
session.rollback()
return None
def get_static_metrics(
metrics: List[Metrics],
runner: QueryRunner,
session: Session,
column: Column,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(
*[
metric(column).fn()
for metric in metrics
if not metric.is_window_metric()
]
)
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Static Metrics", f"{exc}")
return None
def get_query_metrics(
metric: Metrics,
runner: QueryRunner,
session: Session,
column: Column,
sample,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Optional[Dict[str, Union[str, int]]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
col_metric = metric(column)
metric_query = col_metric.query(sample=sample, session=session)
if not metric_query:
return None
if col_metric.metric_type == dict:
results = runner.select_all_from_query(metric_query)
data = {k: [result[k] for result in results] for k in dict(results[0])}
return {metric.name(): data}
row = runner.select_first_from_query(metric_query)
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Query Metrics", f"{exc}")
return None
def get_window_metrics(
metric: Metrics,
runner: QueryRunner,
session: Session,
column: Column,
processor_status: ProfilerProcessorStatus,
*args,
**kwargs,
) -> Dict[str, Union[str, int]]:
"""Given a list of metrics, compute the given results
and returns the values
Args:
column: the column to compute the metrics against
metrics: list of metrics to compute
Returns:
dictionary of results
"""
try:
row = runner.select_first_from_sample(metric(column).fn())
if not isinstance(row, Row):
return {metric.name(): row}
return dict(row)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
)
session.rollback()
processor_status.failure(f"{column.name}", "Window Metrics", f"{exc}")
return None
compute_metrics_registry = enum_register()
compute_metrics_registry.add("Static")(get_static_metrics)
compute_metrics_registry.add("Table")(get_table_metrics)
compute_metrics_registry.add("Query")(get_query_metrics)
compute_metrics_registry.add("Window")(get_window_metrics)

View File

@ -47,6 +47,7 @@ def _(elements, compiler, **kwargs):
return "median(%s)" % col
# pylint: disable=unused-argument
@compiles(MedianFn, Dialects.Trino)
@compiles(MedianFn, Dialects.Presto)
def _(elements, compiler, **kwargs):

View File

@ -15,14 +15,12 @@ Main Profile definition and queries to execute
from __future__ import annotations
import traceback
import warnings
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type
from pydantic import ValidationError
from sqlalchemy import Column
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm.session import Session
from metadata.generated.schema.api.data.createTableProfile import (
CreateTableProfileRequest,
@ -32,8 +30,7 @@ from metadata.generated.schema.entity.data.table import (
ColumnProfilerConfig,
TableProfile,
)
from metadata.interfaces.interface_protocol import InterfaceProtocol
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.orm_profiler.api.models import ProfilerResponse
from metadata.orm_profiler.metrics.core import (
ComposedMetric,
@ -71,10 +68,10 @@ class Profiler(Generic[TMetric]):
def __init__(
self,
*metrics: Type[TMetric],
profiler_interface: InterfaceProtocol,
profiler_interface: ProfilerProtocol,
profile_date: datetime = datetime.now(tz=timezone.utc).timestamp(),
include_columns: List[Optional[ColumnProfilerConfig]] = None,
exclude_columns: List[Optional[str]] = None,
include_columns: Optional[List[ColumnProfilerConfig]] = None,
exclude_columns: Optional[List[str]] = None,
):
"""
:param metrics: Metrics to run. We are receiving the uninitialized classes
@ -99,21 +96,6 @@ class Profiler(Generic[TMetric]):
# We will get columns from the property
self._columns: Optional[List[Column]] = None
@property
def session(self) -> Session:
"""Kept for backward compatibility"""
warnings.warn(
"`<instance>`.session will be retired as platform specific. Instead use"
"`<instance>.profiler_interface` to see if session is supported by the profiler interface",
DeprecationWarning,
)
if isinstance(self.profiler_interface, SQAInterface):
return self.profiler_interface.session
raise ValueError(
f"session is not supported for profiler interface {self.profiler_interface.__class__.__name__}"
)
@property
def table(self) -> DeclarativeMeta:
return self.profiler_interface.table
@ -392,7 +374,7 @@ class Profiler(Generic[TMetric]):
return self
def process(self, generate_sample_data: bool) -> ProfilerResponse:
def process(self, generate_sample_data: Optional[bool]) -> ProfilerResponse:
"""
Given a table, we will prepare the profiler for
all its columns and return all the run profilers
@ -408,7 +390,7 @@ class Profiler(Generic[TMetric]):
logger.info(
f"Fetching sample data for {self.profiler_interface.table_entity.fullyQualifiedName.__root__}..."
)
sample_data = self.profiler_interface.fetch_sample_data()
sample_data = self.profiler_interface.fetch_sample_data(self.table)
except Exception as err:
logger.debug(traceback.format_exc())
logger.warning(f"Error fetching sample data: {err}")

View File

@ -17,7 +17,7 @@ from typing import List, Optional
from sqlalchemy.orm import DeclarativeMeta
from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.profiler_protocol import ProfilerProtocol
from metadata.orm_profiler.metrics.core import Metric, add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.core import Profiler
@ -58,9 +58,9 @@ class DefaultProfiler(Profiler):
def __init__(
self,
profiler_interface: SQAInterface,
include_columns: List[Optional[ColumnProfilerConfig]] = None,
exclude_columns: List[Optional[str]] = None,
profiler_interface: ProfilerProtocol,
include_columns: Optional[List[ColumnProfilerConfig]] = None,
exclude_columns: Optional[List[str]] = None,
):
_metrics = get_default_metrics(profiler_interface.table)

View File

@ -56,7 +56,8 @@ class Sampler:
self.table, (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL)
)
.suffix_with(
f"SAMPLE SYSTEM ({self.profile_sample})", dialect=Dialects.Snowflake
f"SAMPLE SYSTEM ({self.profile_sample or 100})",
dialect=Dialects.Snowflake,
)
.cte(f"{self.table.__tablename__}_rnd")
)

View File

@ -47,11 +47,12 @@ from metadata.generated.schema.tests.testSuite import TestSuite
from metadata.ingestion.api.parser import parse_workflow_config_gracefully
from metadata.ingestion.api.processor import ProcessorStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.test_suite.api.models import TestCaseDefinition, TestSuiteProcessorConfig
from metadata.test_suite.runner.core import DataTestsRunner
from metadata.utils import entity_link
from metadata.utils.helpers import create_ometa_client
from metadata.utils.logger import test_suite_logger
from metadata.utils.workflow_output_handler import print_test_suite_status
@ -115,27 +116,27 @@ class TestSuiteWorkflow:
)
raise err
def _filter_test_cases_for_table_entity(
self, table_fqn: str, test_cases: List[TestCase]
def _filter_test_cases_for_entity(
self, entity_fqn: str, test_cases: List[TestCase]
) -> list[TestCase]:
"""Filter test cases for specific entity"""
return [
test_case
for test_case in test_cases
if test_case.entityLink.__root__.split("::")[2].replace(">", "")
== table_fqn
== entity_fqn
]
def _get_unique_table_entities(self, test_cases: List[TestCase]) -> Set:
def _get_unique_entities_from_test_cases(self, test_cases: List[TestCase]) -> Set:
"""from a list of test cases extract unique table entities"""
table_fqns = [
entity_fqns = [
test_case.entityLink.__root__.split("::")[2].replace(">", "")
for test_case in test_cases
]
return set(table_fqns)
return set(entity_fqns)
def _get_service_connection_from_test_case(self, table_fqn: str):
def _get_service_connection_from_test_case(self, entity_fqn: str):
"""given an entityLink return the service connection
Args:
@ -143,7 +144,7 @@ class TestSuiteWorkflow:
"""
service = self.metadata.get_by_name(
entity=DatabaseService,
fqn=table_fqn.split(".")[0],
fqn=entity_fqn.split(".")[0],
)
if service:
@ -162,7 +163,7 @@ class TestSuiteWorkflow:
)
and not service_connection_config.database
):
service_connection_config.database = table_fqn.split(".")[1]
service_connection_config.database = entity_fqn.split(".")[1]
if (
hasattr(
service_connection_config,
@ -170,13 +171,13 @@ class TestSuiteWorkflow:
)
and not service_connection_config.catalog
):
service_connection_config.catalog = table_fqn.split(".")[1]
service_connection_config.catalog = entity_fqn.split(".")[1]
return service_connection_config
logger.error(f"Could not retrive connection details for entity {entity_link}")
raise ValueError()
def _get_table_entity_from_test_case(self, table_fqn: str):
def _get_table_entity_from_test_case(self, entity_fqn: str):
"""given an entityLink return the table entity
Args:
@ -184,7 +185,7 @@ class TestSuiteWorkflow:
"""
return self.metadata.get_by_name(
entity=Table,
fqn=table_fqn,
fqn=entity_fqn,
fields=["profile"],
)
@ -235,22 +236,22 @@ class TestSuiteWorkflow:
)
return None
def _create_sqa_tests_runner_interface(self, table_fqn: str):
def _create_runner_interface(self, entity_fqn: str):
"""create the interface to execute test against SQA sources"""
table_entity = self._get_table_entity_from_test_case(table_fqn)
return SQAInterface(
table_entity = self._get_table_entity_from_test_case(entity_fqn)
return SQATestSuiteInterface(
service_connection_config=self._get_service_connection_from_test_case(
table_fqn
entity_fqn
),
metadata_config=self.metadata_config,
ometa_client=create_ometa_client(self.metadata_config),
table_entity=table_entity,
profile_sample=self._get_profile_sample(table_entity)
table_sample_precentage=self._get_profile_sample(table_entity)
if not self._get_profile_query(table_entity)
else None,
profile_query=self._get_profile_query(table_entity)
table_sample_query=self._get_profile_query(table_entity)
if not self._get_profile_sample(table_entity)
else None,
partition_config=self._get_partition_details(table_entity)
table_partition_config=self._get_partition_details(table_entity)
if not self._get_profile_query(table_entity)
else None,
)
@ -385,6 +386,15 @@ class TestSuiteWorkflow:
return created_test_case
def add_test_cases_from_cli_config(self, test_cases: list) -> list:
cli_config_test_cases_def = self.get_test_case_from_cli_config()
runtime_created_test_cases = self.compare_and_create_test_cases(
cli_config_test_cases_def, test_cases
)
if runtime_created_test_cases:
return runtime_created_test_cases
return []
def execute(self):
"""Execute test suite workflow"""
test_suites = (
@ -397,23 +407,19 @@ class TestSuiteWorkflow:
test_cases = self.get_test_cases_from_test_suite(test_suites)
if self.processor_config.testSuites:
cli_config_test_cases_def = self.get_test_case_from_cli_config()
runtime_created_test_cases = self.compare_and_create_test_cases(
cli_config_test_cases_def, test_cases
)
if runtime_created_test_cases:
test_cases.extend(runtime_created_test_cases)
test_cases.extend(self.add_test_cases_from_cli_config(test_cases))
unique_table_fqns = self._get_unique_table_entities(test_cases)
unique_entity_fqns = self._get_unique_entities_from_test_cases(test_cases)
for table_fqn in unique_table_fqns:
for entity_fqn in unique_entity_fqns:
try:
sqa_interface = self._create_sqa_tests_runner_interface(table_fqn)
for test_case in self._filter_test_cases_for_table_entity(
table_fqn, test_cases
runner_interface = self._create_runner_interface(entity_fqn)
data_test_runner = self._create_data_tests_runner(runner_interface)
for test_case in self._filter_test_cases_for_entity(
entity_fqn, test_cases
):
try:
data_test_runner = self._create_data_tests_runner(sqa_interface)
test_result = data_test_runner.run_and_handle(test_case)
if not test_result:
continue
@ -430,8 +436,8 @@ class TestSuiteWorkflow:
)
except TypeError as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Could not run test case for table {table_fqn}: {exc}")
self.status.failure(table_fqn)
logger.warning(f"Could not run test case for table {entity_fqn}: {exc}")
self.status.failure(entity_fqn)
def print_status(self) -> None:
"""

View File

@ -15,7 +15,7 @@ Main class to run data tests
from metadata.generated.schema.tests.testCase import TestCase
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.test_suite_protocol import TestSuiteProtocol
from metadata.test_suite.runner.models import TestCaseResultResponse
from metadata.utils.logger import test_suite_logger
@ -25,7 +25,7 @@ logger = test_suite_logger()
class DataTestsRunner:
"""class to execute the test validation"""
def __init__(self, test_runner_interface: SQAInterface):
def __init__(self, test_runner_interface: TestSuiteProtocol):
self.test_runner_interace = test_runner_interface
def run_and_handle(self, test_case: TestCase):
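A small sketch of what this type change buys (imports taken from the modules referenced above; the wrapper function is hypothetical): the runner is now constructed against the TestSuiteProtocol abstraction rather than the SQLAlchemy implementation, so any interface satisfying the protocol can be injected.

from metadata.interfaces.test_suite_protocol import TestSuiteProtocol
from metadata.test_suite.runner.core import DataTestsRunner

def build_runner(interface: TestSuiteProtocol) -> DataTestsRunner:
    # Works for SQATestSuiteInterface today, and for any future non-SQA interface
    # that implements the same protocol.
    return DataTestsRunner(interface)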

View File

@ -34,6 +34,7 @@ from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()
# pylint: disable=abstract-class-instantiated
def column_values_in_set(
test_case: TestCase,
execution_date: datetime,
@ -70,9 +71,7 @@ def column_values_in_set(
f"Cannot find the configured column {column_name} for test case {test_case.name.__root__}"
)
set_count_dict = dict(
runner.dispatch_query_select_first(set_count(col).fn())
) # pylint: disable=abstract-class-instantiated
set_count_dict = dict(runner.dispatch_query_select_first(set_count(col).fn()))
set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)
except Exception as exc: # pylint: disable=broad-except

View File

@ -99,8 +99,8 @@ def column_values_missing_count_to_be_equal(
try:
set_count_dict = dict(
runner.dispatch_query_select_first(
set_count(col).fn()
) # pylint: disable=abstract-class-instantiated
set_count(col).fn() # pylint: disable=abstract-class-instantiated
)
)
set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)

View File

@ -33,7 +33,7 @@ from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()
# pylint: disable=abstract-class-instantiated
def column_values_not_in_set(
test_case: TestCase,
execution_date: datetime,
@ -69,9 +69,7 @@ def column_values_not_in_set(
raise ValueError(
f"Cannot find the configured column {column_name} for test case {test_case.name}"
)
set_count_dict = dict(
runner.dispatch_query_select_first(set_count(col).fn())
) # pylint: disable=abstract-class-instantiated
set_count_dict = dict(runner.dispatch_query_select_first(set_count(col).fn()))
set_count_res = set_count_dict.get(Metrics.COUNT_IN_SET.name)
except Exception as exc:

View File

@ -71,8 +71,8 @@ def column_values_to_match_regex(
value_count_value_res = value_count_value_dict.get(Metrics.COUNT.name)
like_count_value_dict = dict(
runner.dispatch_query_select_first(
like_count(col).fn()
) # pylint: disable=abstract-class-instantiated
like_count(col).fn() # pylint: disable=abstract-class-instantiated
)
)
like_count_value_res = like_count_value_dict.get(Metrics.LIKE_COUNT.name)

View File

@ -71,8 +71,8 @@ def column_values_to_not_match_regex(
not_like_count_dict = dict(
runner.dispatch_query_select_first(
not_like_count(col).fn()
) # pylint: disable=abstract-class-instantiated
not_like_count(col).fn() # pylint: disable=abstract-class-instantiated
)
)
not_like_count_res = not_like_count_dict.get(Metrics.NOT_LIKE_COUNT.name)

View File

@ -92,9 +92,8 @@ def table_column_to_match_set(
None,
)
expected_column_names = [item.strip() for item in column_name.split(",")]
compare = lambda x, y: collections.Counter(x) == collections.Counter(
y
) # pylint: disable=unnecessary-lambda-assignment
# pylint: disable=unnecessary-lambda-assignment
compare = lambda x, y: collections.Counter(x) == collections.Counter(y)
if ordered:
_status = expected_column_names == [col.name for col in column_names]
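The lambda only moved so the pylint directive sits on its own line; the comparison itself is unchanged. A standalone sketch of the order-insensitive check it performs:

import collections

compare = lambda x, y: collections.Counter(x) == collections.Counter(y)

assert compare(["id", "name"], ["name", "id"])               # same names, any order
assert not compare(["id", "name"], ["id", "name", "name"])   # multiplicities must also match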

View File

@ -14,6 +14,7 @@ Helpers module for ingestion related methods
"""
import re
import traceback
from datetime import datetime, timedelta
from functools import wraps
from time import perf_counter
@ -33,6 +34,9 @@ from metadata.generated.schema.api.services.createStorageService import (
)
from metadata.generated.schema.entity.data.chart import Chart, ChartType
from metadata.generated.schema.entity.data.table import Column, Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.messagingService import MessagingService
@ -382,6 +386,30 @@ def list_to_dict(original: Optional[List[str]], sep: str = "=") -> Dict[str, str
return dict(split_original)
def create_ometa_client(
metadata_config: OpenMetadataConnection,
) -> OpenMetadata:
"""Create an OpenMetadata client
Args:
metadata_config (OpenMetadataConnection): OM connection config
Returns:
OpenMetadata: an OM client
"""
try:
metadata = OpenMetadata(metadata_config)
metadata.health_check()
return metadata
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"No OpenMetadata server configuration found. "
f"Setting client to `None`. You won't be able to access the server from the client: {exc}"
)
raise ValueError(exc)
def clean_up_starting_ending_double_quotes_in_string(string: str) -> str:
"""Remove start and ending double quotes in a string

View File

@ -16,6 +16,7 @@ Validate workflow e2e
import os
import unittest
from datetime import datetime, timedelta
from unittest.mock import patch
import sqlalchemy as sqa
from sqlalchemy.orm import declarative_base
@ -45,7 +46,7 @@ from metadata.generated.schema.entity.services.databaseService import (
)
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.test_suite.api.workflow import TestSuiteWorkflow
test_suite_config = {
@ -168,12 +169,14 @@ class TestE2EWorkflow(unittest.TestCase):
),
)
)
sqa_profiler_interface = SQAInterface(
cls.sqlite_conn.config,
table=User,
table_entity=table,
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQAProfilerInterface(
cls.sqlite_conn.config,
table_entity=table,
ometa_client=None,
)
engine = sqa_profiler_interface.session.get_bind()
session = sqa_profiler_interface.session
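The same mocking pattern recurs throughout the updated tests, so a hedged sketch is worth spelling out (User, table and the SQLite connection are fixtures assumed to exist in the test module): patching `_convert_table_to_orm_object` lets the interface run against a locally declared ORM class instead of building one from OpenMetadata table metadata, and `ometa_client=None` avoids the need for a live server.

from unittest.mock import patch

from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface

with patch.object(
    SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
    interface = SQAProfilerInterface(
        sqlite_conn,          # SQLiteConnection fixture, assumed defined as in the tests
        table_entity=table,   # OpenMetadata Table fixture, assumed defined as in the tests
        ometa_client=None,    # no live OpenMetadata server needed
    )

engine = interface.session.get_bind()
session = interface.session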

View File

@ -15,6 +15,7 @@ Test Metrics behavior
import datetime
import os
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import TEXT, Column, Date, DateTime, Integer, String, Time
@ -26,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.functions.sum import SumFn
@ -71,16 +72,23 @@ class MetricsTest(TestCase):
)
],
)
sqa_profiler_interface = SQAInterface(
sqlite_conn, table=User, table_entity=table_entity
)
engine = sqa_profiler_interface.session.get_bind()
@classmethod
def setUpClass(cls) -> None:
"""
Prepare Ingredients
"""
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
cls.sqa_profiler_interface = SQAProfilerInterface(
cls.sqlite_conn,
table_entity=cls.table_entity,
ometa_client=None,
)
cls.engine = cls.sqa_profiler_interface.session.get_bind()
User.__table__.create(bind=cls.engine)
data = [
@ -688,9 +696,14 @@ class MetricsTest(TestCase):
EmptyUser.__table__.create(bind=self.engine)
sqa_profiler_interface = SQAInterface(
self.sqlite_conn, table=EmptyUser, table_entity=self.table_entity
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=EmptyUser
):
sqa_profiler_interface = SQAProfilerInterface(
self.sqlite_conn,
table_entity=self.table_entity,
ometa_client=None,
)
hist = add_props(bins=5)(Metrics.HISTOGRAM.value)
res = (

View File

@ -15,6 +15,7 @@ Test Profiler behavior
import os
from datetime import datetime, timezone
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
import pytest
@ -38,7 +39,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteScheme,
)
from metadata.ingestion.source import sqa_types
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.core import MissingMetricException, Profiler
@ -79,9 +80,12 @@ class ProfilerTest(TestCase):
)
],
)
sqa_profiler_interface = SQAInterface(
sqlite_conn, table=User, table_entity=table_entity
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn, table_entity=table_entity, ometa_client=None
)
@classmethod
def setUpClass(cls) -> None:

View File

@ -14,6 +14,7 @@ Test Sample behavior
"""
import os
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import TEXT, Column, Integer, String, func
@ -25,7 +26,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.registry import CustomTypes
from metadata.orm_profiler.profiler.core import Profiler
@ -68,9 +69,12 @@ class SampleTest(TestCase):
],
)
sqa_profiler_interface = SQAInterface(
sqlite_conn, table=User, table_entity=table_entity
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn, table_entity=table_entity, ometa_client=None
)
engine = sqa_profiler_interface.session.get_bind()
session = sqa_profiler_interface.session
@ -125,19 +129,21 @@ class SampleTest(TestCase):
"""
# Randomly pick table_count to init the Profiler, we don't care for this test
table_count = Metrics.ROW_COUNT.value
sqa_profiler_interface = SQAInterface(
self.sqlite_conn,
table=User,
table_entity=self.table_entity,
profile_sample=50,
)
profiler = Profiler(
table_count,
profiler_interface=sqa_profiler_interface,
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQAProfilerInterface(
self.sqlite_conn,
table_entity=self.table_entity,
table_sample_precentage=50,
ometa_client=None,
)
res = self.session.query(func.count()).select_from(profiler.sample).first()
sample = sqa_profiler_interface._create_thread_safe_sampler(
self.session, User
).random_sample()
res = self.session.query(func.count()).select_from(sample).first()
assert res[0] < 30
def test_table_row_count(self):
@ -162,15 +168,18 @@ class SampleTest(TestCase):
"""
count = Metrics.COUNT.value
profiler = Profiler(
count,
profiler_interface=SQAInterface(
self.sqlite_conn,
table=User,
table_entity=self.table_entity,
profile_sample=50,
),
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
profiler = Profiler(
count,
profiler_interface=SQAProfilerInterface(
self.sqlite_conn,
table_entity=self.table_entity,
table_sample_precentage=50,
ometa_client=None,
),
)
res = profiler.compute_metrics()._column_results
assert res.get(User.name.name)[Metrics.COUNT.name] < 30
@ -179,15 +188,18 @@ class SampleTest(TestCase):
Histogram should run correctly
"""
hist = Metrics.HISTOGRAM.value
profiler = Profiler(
hist,
profiler_interface=SQAInterface(
self.sqlite_conn,
table=User,
table_entity=self.table_entity,
profile_sample=50,
),
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
profiler = Profiler(
hist,
profiler_interface=SQAProfilerInterface(
self.sqlite_conn,
table_entity=self.table_entity,
table_sample_precentage=50,
ometa_client=None,
),
)
res = profiler.compute_metrics()._column_results
# The sum of all frequencies should be sampled

View File

@ -16,6 +16,7 @@ Test SQA Interface
import os
from datetime import datetime, timezone
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from pytest import raises
@ -38,7 +39,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.core import (
ComposedMetric,
MetricTypes,
@ -74,24 +75,19 @@ class SQAInterfaceTest(TestCase):
sqlite_conn = SQLiteConnection(
scheme=SQLiteScheme.sqlite_pysqlite,
)
self.sqa_profiler_interface = SQAInterface(
sqlite_conn, table=User, table_entity=table_entity
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
self.sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn, table_entity=table_entity, ometa_client=None
)
self.table = User
def test_init_interface(self):
"""Test we can instantiate our interface object correctly"""
assert self.sqa_profiler_interface._sampler != None
assert self.sqa_profiler_interface._runner != None
assert isinstance(self.sqa_profiler_interface.session, Session)
def test_private_attributes(self):
with raises(AttributeError):
self.sqa_profiler_interface.runner = None
self.sqa_profiler_interface.sampler = None
self.sqa_profiler_interface.sample = None
def tearDown(self) -> None:
self.sqa_profiler_interface._sampler = None
@ -113,9 +109,12 @@ class SQAInterfaceTestMultiThread(TestCase):
scheme=SQLiteScheme.sqlite_pysqlite,
databaseMode=db_path + "?check_same_thread=False",
)
sqa_profiler_interface = SQAInterface(
sqlite_conn, table=User, table_entity=table_entity
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn, table_entity=table_entity, ometa_client=None
)
@classmethod
def setUpClass(cls) -> None:
@ -152,8 +151,6 @@ class SQAInterfaceTestMultiThread(TestCase):
def test_init_interface(self):
"""Test we can instantiate our interface object correctly"""
assert self.sqa_profiler_interface._sampler != None
assert self.sqa_profiler_interface._runner != None
assert isinstance(self.sqa_profiler_interface.session, Session)
def test_get_all_metrics(self):

View File

@ -33,7 +33,7 @@ from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline
DatabaseServiceProfilerPipeline,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.api.models import ProfilerProcessorConfig
from metadata.orm_profiler.api.workflow import ProfilerWorkflow
from metadata.orm_profiler.profiler.default import DefaultProfiler
@ -88,7 +88,7 @@ class User(Base):
@patch.object(
SQAInterface,
SQAProfilerInterface,
"_convert_table_to_orm_object",
return_value=User,
)
@ -190,7 +190,7 @@ def test_filter_entities(mocked_method):
@patch.object(
SQAInterface,
SQAProfilerInterface,
"_convert_table_to_orm_object",
return_value=User,
)
@ -227,7 +227,7 @@ def test_profile_def(mocked_method, mocked_orm):
@patch.object(
SQAInterface,
SQAProfilerInterface,
"_convert_table_to_orm_object",
return_value=User,
)

View File

@ -18,6 +18,7 @@ Each test should validate the Success, Failure and Aborted statuses
import os
import unittest
from datetime import datetime
from unittest.mock import patch
from uuid import uuid4
import sqlalchemy as sqa
@ -32,7 +33,7 @@ from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
from metadata.generated.schema.tests.testSuite import TestSuite
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.interfaces.sqa_interface import SQAInterface
from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface
from metadata.test_suite.validations.core import validation_enum_registry
EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
@ -78,11 +79,14 @@ class testSuiteValidation(unittest.TestCase):
databaseMode=db_path + "?check_same_thread=False",
)
sqa_profiler_interface = SQAInterface(
sqlite_conn,
table=User,
table_entity=TABLE,
)
with patch.object(
SQATestSuiteInterface, "_convert_table_to_orm_object", return_value=User
):
sqa_profiler_interface = SQATestSuiteInterface(
sqlite_conn,
table_entity=TABLE,
ometa_client=None,
)
runner = sqa_profiler_interface.runner
engine = sqa_profiler_interface.session.get_bind()
session = sqa_profiler_interface.session