mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-15 12:37:18 +00:00
MINOR: Add injection to profiler (#21738)
* Initial implementation for our Connection Class * Implement the Initial Connection class * Add Unit Tests * Implement Dependency Injection for the Ingestion Framework * Fix Test * Fix Profile Test Connection * Add Injection to Metrics in Profiler * Add Injection to the Profiler * Fix UnitTests * Fix Pytests * Fix Tests * Fix types
This commit is contained in:
parent
bd948de115
commit
e79c54e6a5
@ -11,7 +11,14 @@
|
|||||||
"""
|
"""
|
||||||
OpenMetadata package initialization.
|
OpenMetadata package initialization.
|
||||||
"""
|
"""
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from metadata.profiler.metrics.registry import Metrics
|
||||||
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.profiler.source.database.base.profiler_resolver import (
|
||||||
|
DefaultProfilerResolver,
|
||||||
|
ProfilerResolver,
|
||||||
|
)
|
||||||
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
|
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
|
||||||
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader
|
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader
|
||||||
|
|
||||||
@ -20,3 +27,5 @@ container = DependencyContainer()
|
|||||||
|
|
||||||
# Register the source loader
|
# Register the source loader
|
||||||
container.register(SourceLoader, DefaultSourceLoader)
|
container.register(SourceLoader, DefaultSourceLoader)
|
||||||
|
container.register(Type[MetricRegistry], lambda: Metrics)
|
||||||
|
container.register(Type[ProfilerResolver], lambda: DefaultProfilerResolver)
|
||||||
|
@ -34,7 +34,7 @@ class TableRowInsertedCountToBeBetweenValidator(
|
|||||||
"""Validator for table row inserted count to be between test case"""
|
"""Validator for table row inserted count to be between test case"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_threshold_date(range_type: str, range_interval: int):
|
def get_threshold_date_dt(range_type: str, range_interval: int) -> datetime:
|
||||||
"""returns the threshold datetime in utc to count the numbers of rows inserted
|
"""returns the threshold datetime in utc to count the numbers of rows inserted
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -55,7 +55,22 @@ class TableRowInsertedCountToBeBetweenValidator(
|
|||||||
threshold_date = threshold_date.replace(
|
threshold_date = threshold_date.replace(
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
)
|
)
|
||||||
return threshold_date.strftime("%Y%m%d%H%M%S")
|
return threshold_date
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_threshold_date(
|
||||||
|
range_type: str, range_interval: int, date_format: str = "%Y%m%d%H%M%S"
|
||||||
|
):
|
||||||
|
"""returns the threshold datetime in utc as string to count the numbers of rows inserted
|
||||||
|
|
||||||
|
Args:
|
||||||
|
range_type (str): type of range (i.e. HOUR, DAY, MONTH, YEAR)
|
||||||
|
range_interval (int): interval of range (i.e. 1, 2, 3, 4)
|
||||||
|
date_format (str): format of the date (i.e. %Y%m%d%H%M%S, %Y-%m-%d %H:%M:%S)
|
||||||
|
"""
|
||||||
|
return TableRowInsertedCountToBeBetweenValidator.get_threshold_date_dt(
|
||||||
|
range_type, range_interval
|
||||||
|
).strftime(date_format)
|
||||||
|
|
||||||
def _get_column_name(self):
|
def _get_column_name(self):
|
||||||
"""returns the column name to be validated"""
|
"""returns the column name to be validated"""
|
||||||
|
@ -113,6 +113,9 @@ def get_test_connection_fn(connection: BaseModel) -> Callable:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def get_connection(connection: BaseModel) -> Any:
|
def get_connection(connection: BaseModel) -> Any:
|
||||||
"""
|
"""
|
||||||
Main method to prepare a connection from
|
Main method to prepare a connection from
|
||||||
|
@ -13,7 +13,7 @@ System table profiler
|
|||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||||
|
|
||||||
from more_itertools import partition
|
from more_itertools import partition
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
@ -25,7 +25,11 @@ from metadata.profiler.interface.sqlalchemy.stored_statistics_profiler import (
|
|||||||
StoredStatisticsSource,
|
StoredStatisticsSource,
|
||||||
)
|
)
|
||||||
from metadata.profiler.metrics.core import Metric
|
from metadata.profiler.metrics.core import Metric
|
||||||
from metadata.profiler.metrics.registry import Metrics
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
|
Inject,
|
||||||
|
inject_class_attributes,
|
||||||
|
)
|
||||||
from metadata.utils.logger import profiler_logger
|
from metadata.utils.logger import profiler_logger
|
||||||
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
|
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
|
||||||
from metadata.utils.ssl_manager import get_ssl_connection
|
from metadata.utils.ssl_manager import get_ssl_connection
|
||||||
@ -61,23 +65,28 @@ class TableStats(BaseModel):
|
|||||||
columns: Dict[str, ColumnStats] = {}
|
columns: Dict[str, ColumnStats] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
class TrinoStoredStatisticsSource(StoredStatisticsSource):
|
class TrinoStoredStatisticsSource(StoredStatisticsSource):
|
||||||
"""Trino system profile source"""
|
"""Trino system profile source"""
|
||||||
|
|
||||||
metric_stats_map: Dict[Metrics, str] = {
|
metrics: Inject[Type[MetricRegistry]]
|
||||||
Metrics.NULL_RATIO: "nulls_fractions",
|
|
||||||
Metrics.DISTINCT_COUNT: "distinct_values_count",
|
|
||||||
Metrics.ROW_COUNT: "row_count",
|
|
||||||
Metrics.MAX: "high_value",
|
|
||||||
Metrics.MIN: "low_value",
|
|
||||||
}
|
|
||||||
|
|
||||||
metric_stats_by_name: Dict[str, str] = {
|
@classmethod
|
||||||
k.name: v for k, v in metric_stats_map.items()
|
def get_metric_stats_map(cls) -> Dict[MetricRegistry, str]:
|
||||||
}
|
return {
|
||||||
|
cls.metrics.NULL_RATIO: "nulls_fractions",
|
||||||
|
cls.metrics.DISTINCT_COUNT: "distinct_values_count",
|
||||||
|
cls.metrics.ROW_COUNT: "row_count",
|
||||||
|
cls.metrics.MAX: "high_value",
|
||||||
|
cls.metrics.MIN: "low_value",
|
||||||
|
}
|
||||||
|
|
||||||
def get_statistics_metrics(self) -> Set[Metrics]:
|
@classmethod
|
||||||
return set(self.metric_stats_map.keys())
|
def get_metric_stats_by_name(cls) -> Dict[str, str]:
|
||||||
|
return {k.name: v for k, v in cls.get_metric_stats_map().items()}
|
||||||
|
|
||||||
|
def get_statistics_metrics(self) -> Set[MetricRegistry]:
|
||||||
|
return set(self.get_metric_stats_map().keys())
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -96,7 +105,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
|
|||||||
f"Column {column} not found in table {table_name}. Statistics might be stale or missing."
|
f"Column {column} not found in table {table_name}. Statistics might be stale or missing."
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
m.name(): getattr(column_stats, self.metric_stats_by_name[m.name()])
|
m.name(): getattr(column_stats, self.get_metric_stats_by_name()[m.name()])
|
||||||
for m in metric
|
for m in metric
|
||||||
}
|
}
|
||||||
result.update(self.get_hybrid_statistics(table_stats, column_stats))
|
result.update(self.get_hybrid_statistics(table_stats, column_stats))
|
||||||
@ -108,7 +117,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
table_stats = self._get_cached_stats(schema, table_name)
|
table_stats = self._get_cached_stats(schema, table_name)
|
||||||
return {
|
return {
|
||||||
m.name(): getattr(table_stats, self.metric_stats_by_name[m.name()])
|
m.name(): getattr(table_stats, self.get_metric_stats_by_name()[m.name()])
|
||||||
for m in metric
|
for m in metric
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,7 +168,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
# trino stats are in fractions, so we need to convert them to counts (unlike our default profiler)
|
# trino stats are in fractions, so we need to convert them to counts (unlike our default profiler)
|
||||||
Metrics.NULL_COUNT.name: (
|
self.metrics.NULL_COUNT.name: (
|
||||||
int(table_stats.row_count * column_stats.nulls_fraction)
|
int(table_stats.row_count * column_stats.nulls_fraction)
|
||||||
if None not in [table_stats.row_count, column_stats.nulls_fraction]
|
if None not in [table_stats.row_count, column_stats.nulls_fraction]
|
||||||
else None
|
else None
|
||||||
|
@ -16,7 +16,7 @@ Run profiler metrics on the table
|
|||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
|
from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
|
||||||
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
|
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
|
||||||
@ -27,13 +27,29 @@ from metadata.generated.schema.entity.data.table import TableType
|
|||||||
from metadata.profiler.metrics.registry import Metrics
|
from metadata.profiler.metrics.registry import Metrics
|
||||||
from metadata.profiler.orm.registry import Dialects
|
from metadata.profiler.orm.registry import Dialects
|
||||||
from metadata.profiler.processor.runner import QueryRunner
|
from metadata.profiler.processor.runner import QueryRunner
|
||||||
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
|
DependencyNotFoundError,
|
||||||
|
Inject,
|
||||||
|
inject,
|
||||||
|
)
|
||||||
from metadata.utils.logger import profiler_interface_registry_logger
|
from metadata.utils.logger import profiler_interface_registry_logger
|
||||||
|
|
||||||
logger = profiler_interface_registry_logger()
|
logger = profiler_interface_registry_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@inject
|
||||||
|
def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None):
|
||||||
|
if metrics is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
|
||||||
|
)
|
||||||
|
return metrics.ROW_COUNT().name()
|
||||||
|
|
||||||
|
|
||||||
COLUMN_COUNT = "columnCount"
|
COLUMN_COUNT = "columnCount"
|
||||||
COLUMN_NAMES = "columnNames"
|
COLUMN_NAMES = "columnNames"
|
||||||
ROW_COUNT = Metrics.ROW_COUNT().name()
|
ROW_COUNT = get_row_count_metric()
|
||||||
SIZE_IN_BYTES = "sizeInBytes"
|
SIZE_IN_BYTES = "sizeInBytes"
|
||||||
CREATE_DATETIME = "createDateTime"
|
CREATE_DATETIME = "createDateTime"
|
||||||
|
|
||||||
@ -362,9 +378,15 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer):
|
|||||||
class MySQLTableMetricComputer(BaseTableMetricComputer):
|
class MySQLTableMetricComputer(BaseTableMetricComputer):
|
||||||
"""MySQL Table Metric Computer"""
|
"""MySQL Table Metric Computer"""
|
||||||
|
|
||||||
def compute(self):
|
@inject
|
||||||
|
def compute(self, metrics: Inject[Type[MetricRegistry]] = None):
|
||||||
"""compute table metrics for mysql"""
|
"""compute table metrics for mysql"""
|
||||||
|
|
||||||
|
if metrics is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
|
||||||
|
)
|
||||||
|
|
||||||
columns = [
|
columns = [
|
||||||
Column("TABLE_ROWS").label(ROW_COUNT),
|
Column("TABLE_ROWS").label(ROW_COUNT),
|
||||||
(Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
|
(Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
|
||||||
@ -390,7 +412,7 @@ class MySQLTableMetricComputer(BaseTableMetricComputer):
|
|||||||
res = res._asdict()
|
res = res._asdict()
|
||||||
# innodb row count is an estimate we need to patch the row count with COUNT(*)
|
# innodb row count is an estimate we need to patch the row count with COUNT(*)
|
||||||
# https://dev.mysql.com/doc/refman/8.3/en/information-schema-innodb-tablestats-table.html
|
# https://dev.mysql.com/doc/refman/8.3/en/information-schema-innodb-tablestats-table.html
|
||||||
row_count = self.runner.select_first_from_table(Metrics.ROW_COUNT().fn())
|
row_count = self.runner.select_first_from_table(metrics.ROW_COUNT().fn())
|
||||||
res.update({ROW_COUNT: row_count.rowCount})
|
res.update({ROW_COUNT: row_count.rowCount})
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type
|
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, cast
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import Column
|
from sqlalchemy import Column
|
||||||
@ -94,7 +94,7 @@ class Profiler(Generic[TMetric]):
|
|||||||
:param profile_sample: % of rows to use for sampling column metrics
|
:param profile_sample: % of rows to use for sampling column metrics
|
||||||
"""
|
"""
|
||||||
self.global_profiler_configuration: Optional[ProfilerConfiguration] = (
|
self.global_profiler_configuration: Optional[ProfilerConfiguration] = (
|
||||||
global_profiler_configuration.config_value
|
cast(ProfilerConfiguration, global_profiler_configuration.config_value)
|
||||||
if global_profiler_configuration
|
if global_profiler_configuration
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -12,55 +12,54 @@
|
|||||||
"""
|
"""
|
||||||
Default simple profiler to use
|
Default simple profiler to use
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeMeta
|
from sqlalchemy.orm import DeclarativeMeta
|
||||||
|
|
||||||
from metadata.generated.schema.configuration.profilerConfiguration import (
|
|
||||||
ProfilerConfiguration,
|
|
||||||
)
|
|
||||||
from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
|
from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
|
||||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||||
|
from metadata.generated.schema.settings.settings import Settings
|
||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||||
from metadata.profiler.metrics.core import Metric, add_props
|
from metadata.profiler.metrics.core import Metric, add_props
|
||||||
from metadata.profiler.metrics.registry import Metrics
|
|
||||||
from metadata.profiler.processor.core import Profiler
|
from metadata.profiler.processor.core import Profiler
|
||||||
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
|
||||||
|
|
||||||
def get_default_metrics(
|
def get_default_metrics(
|
||||||
|
metrics_registry: Type[MetricRegistry],
|
||||||
table: DeclarativeMeta,
|
table: DeclarativeMeta,
|
||||||
ometa_client: Optional[OpenMetadata] = None,
|
ometa_client: Optional[OpenMetadata] = None,
|
||||||
db_service: Optional[DatabaseService] = None,
|
db_service: Optional[DatabaseService] = None,
|
||||||
) -> List[Metric]:
|
) -> List[Metric]:
|
||||||
return [
|
return [
|
||||||
# Table Metrics
|
# Table Metrics
|
||||||
Metrics.ROW_COUNT.value,
|
metrics_registry.ROW_COUNT.value,
|
||||||
add_props(table=table)(Metrics.COLUMN_COUNT.value),
|
add_props(table=table)(metrics_registry.COLUMN_COUNT.value),
|
||||||
add_props(table=table)(Metrics.COLUMN_NAMES.value),
|
add_props(table=table)(metrics_registry.COLUMN_NAMES.value),
|
||||||
# We'll use the ometa_client & db_service in case we need to fetch info to ES
|
# We'll use the ometa_client & db_service in case we need to fetch info to ES
|
||||||
add_props(table=table, ometa_client=ometa_client, db_service=db_service)(
|
add_props(table=table, ometa_client=ometa_client, db_service=db_service)(
|
||||||
Metrics.SYSTEM.value
|
metrics_registry.SYSTEM.value
|
||||||
),
|
),
|
||||||
# Column Metrics
|
# Column Metrics
|
||||||
Metrics.MEDIAN.value,
|
metrics_registry.MEDIAN.value,
|
||||||
Metrics.FIRST_QUARTILE.value,
|
metrics_registry.FIRST_QUARTILE.value,
|
||||||
Metrics.THIRD_QUARTILE.value,
|
metrics_registry.THIRD_QUARTILE.value,
|
||||||
Metrics.MEAN.value,
|
metrics_registry.MEAN.value,
|
||||||
Metrics.COUNT.value,
|
metrics_registry.COUNT.value,
|
||||||
Metrics.DISTINCT_COUNT.value,
|
metrics_registry.DISTINCT_COUNT.value,
|
||||||
Metrics.DISTINCT_RATIO.value,
|
metrics_registry.DISTINCT_RATIO.value,
|
||||||
Metrics.MIN.value,
|
metrics_registry.MIN.value,
|
||||||
Metrics.MAX.value,
|
metrics_registry.MAX.value,
|
||||||
Metrics.NULL_COUNT.value,
|
metrics_registry.NULL_COUNT.value,
|
||||||
Metrics.NULL_RATIO.value,
|
metrics_registry.NULL_RATIO.value,
|
||||||
Metrics.STDDEV.value,
|
metrics_registry.STDDEV.value,
|
||||||
Metrics.SUM.value,
|
metrics_registry.SUM.value,
|
||||||
Metrics.UNIQUE_COUNT.value,
|
metrics_registry.UNIQUE_COUNT.value,
|
||||||
Metrics.UNIQUE_RATIO.value,
|
metrics_registry.UNIQUE_RATIO.value,
|
||||||
Metrics.IQR.value,
|
metrics_registry.IQR.value,
|
||||||
Metrics.HISTOGRAM.value,
|
metrics_registry.HISTOGRAM.value,
|
||||||
Metrics.NON_PARAMETRIC_SKEW.value,
|
metrics_registry.NON_PARAMETRIC_SKEW.value,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -74,12 +73,14 @@ class DefaultProfiler(Profiler):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
profiler_interface: ProfilerInterface,
|
profiler_interface: ProfilerInterface,
|
||||||
|
metrics_registry: Type[MetricRegistry],
|
||||||
include_columns: Optional[List[ColumnProfilerConfig]] = None,
|
include_columns: Optional[List[ColumnProfilerConfig]] = None,
|
||||||
exclude_columns: Optional[List[str]] = None,
|
exclude_columns: Optional[List[str]] = None,
|
||||||
global_profiler_configuration: Optional[ProfilerConfiguration] = None,
|
global_profiler_configuration: Optional[Settings] = None,
|
||||||
db_service=None,
|
db_service=None,
|
||||||
):
|
):
|
||||||
_metrics = get_default_metrics(
|
_metrics = get_default_metrics(
|
||||||
|
metrics_registry=metrics_registry,
|
||||||
table=profiler_interface.table,
|
table=profiler_interface.table,
|
||||||
ometa_client=profiler_interface.ometa_client,
|
ometa_client=profiler_interface.ometa_client,
|
||||||
db_service=db_service,
|
db_service=db_service,
|
||||||
|
@ -33,22 +33,35 @@ from metadata.profiler.metrics.core import (
|
|||||||
SystemMetric,
|
SystemMetric,
|
||||||
TMetric,
|
TMetric,
|
||||||
)
|
)
|
||||||
from metadata.profiler.metrics.registry import Metrics
|
|
||||||
from metadata.profiler.orm.converter.converter_registry import converter_registry
|
from metadata.profiler.orm.converter.converter_registry import converter_registry
|
||||||
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
|
DependencyNotFoundError,
|
||||||
|
Inject,
|
||||||
|
inject,
|
||||||
|
)
|
||||||
from metadata.utils.sqa_like_column import SQALikeColumn
|
from metadata.utils.sqa_like_column import SQALikeColumn
|
||||||
|
|
||||||
|
|
||||||
class MetricFilter:
|
class MetricFilter:
|
||||||
"""Metric filter class for profiler"""
|
"""Metric filter class for profiler"""
|
||||||
|
|
||||||
|
@inject
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
metrics: Tuple[Type[TMetric]],
|
metrics: Tuple[Type[TMetric]],
|
||||||
global_profiler_config: Optional[ProfilerConfiguration] = None,
|
global_profiler_config: Optional[ProfilerConfiguration] = None,
|
||||||
table_profiler_config: Optional[TableProfilerConfig] = None,
|
table_profiler_config: Optional[TableProfilerConfig] = None,
|
||||||
column_profiler_config: Optional[List[ColumnProfilerConfig]] = None,
|
column_profiler_config: Optional[List[ColumnProfilerConfig]] = None,
|
||||||
|
metrics_registry: Inject[Type[MetricRegistry]] = None,
|
||||||
):
|
):
|
||||||
|
if metrics_registry is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
|
||||||
|
)
|
||||||
|
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
|
self.metrics_registry = metrics_registry
|
||||||
self.global_profiler_config = global_profiler_config
|
self.global_profiler_config = global_profiler_config
|
||||||
self.table_profiler_config = table_profiler_config
|
self.table_profiler_config = table_profiler_config
|
||||||
self.column_profiler_config = column_profiler_config
|
self.column_profiler_config = column_profiler_config
|
||||||
@ -196,7 +209,7 @@ class MetricFilter:
|
|||||||
|
|
||||||
metrics = [
|
metrics = [
|
||||||
Metric.value
|
Metric.value
|
||||||
for Metric in Metrics
|
for Metric in self.metrics_registry
|
||||||
if Metric.value.name() in {mtrc.value for mtrc in col_dtype_config.metrics}
|
if Metric.value.name() in {mtrc.value for mtrc in col_dtype_config.metrics}
|
||||||
and Metric.value in metrics
|
and Metric.value in metrics
|
||||||
]
|
]
|
||||||
@ -240,7 +253,7 @@ class MetricFilter:
|
|||||||
|
|
||||||
metrics = [
|
metrics = [
|
||||||
Metric.value
|
Metric.value
|
||||||
for Metric in Metrics
|
for Metric in self.metrics_registry
|
||||||
if Metric.value.name().lower() in {mtrc.lower() for mtrc in metric_names}
|
if Metric.value.name().lower() in {mtrc.lower() for mtrc in metric_names}
|
||||||
and Metric.value in metrics
|
and Metric.value in metrics
|
||||||
]
|
]
|
||||||
|
@ -13,20 +13,30 @@
|
|||||||
Models to map profiler definitions
|
Models to map profiler definitions
|
||||||
JSON workflows to the profiler
|
JSON workflows to the profiler
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, BeforeValidator
|
from pydantic import BaseModel, BeforeValidator
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from metadata.profiler.metrics.registry import Metrics
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
|
DependencyNotFoundError,
|
||||||
|
Inject,
|
||||||
|
inject,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def valid_metric(value: str):
|
@inject
|
||||||
|
def valid_metric(value: str, metrics: Inject[Type[MetricRegistry]] = None):
|
||||||
"""
|
"""
|
||||||
Validate that the input metrics are correctly named
|
Validate that the input metrics are correctly named
|
||||||
and can be found in the Registry
|
and can be found in the Registry
|
||||||
"""
|
"""
|
||||||
if not Metrics.get(value.upper()):
|
if metrics is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
|
||||||
|
)
|
||||||
|
if not metrics.get(value.upper()):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Metric name {value} is not a proper metric name from the Registry"
|
f"Metric name {value} is not a proper metric name from the Registry"
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Tuple, Type
|
||||||
|
|
||||||
|
from metadata.generated.schema.entity.services.serviceType import ServiceType
|
||||||
|
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||||
|
from metadata.sampler.sampler_interface import SamplerInterface
|
||||||
|
from metadata.utils.service_spec.service_spec import (
|
||||||
|
import_profiler_class,
|
||||||
|
import_sampler_class,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfilerResolver(ABC):
|
||||||
|
"""Abstract class for the profiler resolver"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def resolve(
|
||||||
|
processing_engine: str, service_type: ServiceType, source_type: str
|
||||||
|
) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
|
||||||
|
"""Resolve the sampler and profiler based on the processing engine."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultProfilerResolver(ProfilerResolver):
|
||||||
|
"""Default profiler resolver"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resolve(
|
||||||
|
processing_engine: str, service_type: ServiceType, source_type: str
|
||||||
|
) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
|
||||||
|
"""Resolve the sampler and profiler based on the processing engine."""
|
||||||
|
sampler_class = import_sampler_class(service_type, source_type=source_type)
|
||||||
|
profiler_class = import_profiler_class(service_type, source_type=source_type)
|
||||||
|
return sampler_class, profiler_class
|
@ -14,7 +14,7 @@ Base source for the profiler used to instantiate a profiler runner with
|
|||||||
its interface
|
its interface
|
||||||
"""
|
"""
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional, cast
|
from typing import Optional, Type, cast
|
||||||
|
|
||||||
from metadata.generated.schema.configuration.profilerConfiguration import (
|
from metadata.generated.schema.configuration.profilerConfiguration import (
|
||||||
ProfilerConfiguration,
|
ProfilerConfiguration,
|
||||||
@ -22,10 +22,7 @@ from metadata.generated.schema.configuration.profilerConfiguration import (
|
|||||||
from metadata.generated.schema.entity.data.database import Database
|
from metadata.generated.schema.entity.data.database import Database
|
||||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||||
from metadata.generated.schema.entity.data.table import Table
|
from metadata.generated.schema.entity.data.table import Table
|
||||||
from metadata.generated.schema.entity.services.databaseService import (
|
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
|
||||||
DatabaseConnection,
|
|
||||||
DatabaseService,
|
|
||||||
)
|
|
||||||
from metadata.generated.schema.entity.services.serviceType import ServiceType
|
from metadata.generated.schema.entity.services.serviceType import ServiceType
|
||||||
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
|
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
|
||||||
DatabaseServiceProfilerPipeline,
|
DatabaseServiceProfilerPipeline,
|
||||||
@ -36,9 +33,10 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
|||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.profiler.api.models import ProfilerProcessorConfig, TableConfig
|
from metadata.profiler.api.models import ProfilerProcessorConfig, TableConfig
|
||||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||||
from metadata.profiler.metrics.registry import Metrics
|
|
||||||
from metadata.profiler.processor.core import Profiler
|
from metadata.profiler.processor.core import Profiler
|
||||||
from metadata.profiler.processor.default import DefaultProfiler, get_default_metrics
|
from metadata.profiler.processor.default import DefaultProfiler, get_default_metrics
|
||||||
|
from metadata.profiler.registry import MetricRegistry
|
||||||
|
from metadata.profiler.source.database.base.profiler_resolver import ProfilerResolver
|
||||||
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
|
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
|
||||||
from metadata.sampler.config import (
|
from metadata.sampler.config import (
|
||||||
get_config_for_table,
|
get_config_for_table,
|
||||||
@ -47,12 +45,13 @@ from metadata.sampler.config import (
|
|||||||
)
|
)
|
||||||
from metadata.sampler.models import SampleConfig
|
from metadata.sampler.models import SampleConfig
|
||||||
from metadata.sampler.sampler_interface import SamplerInterface
|
from metadata.sampler.sampler_interface import SamplerInterface
|
||||||
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
|
DependencyNotFoundError,
|
||||||
|
Inject,
|
||||||
|
inject,
|
||||||
|
)
|
||||||
from metadata.utils.logger import profiler_logger
|
from metadata.utils.logger import profiler_logger
|
||||||
from metadata.utils.profiler_utils import get_context_entities
|
from metadata.utils.profiler_utils import get_context_entities
|
||||||
from metadata.utils.service_spec.service_spec import (
|
|
||||||
import_profiler_class,
|
|
||||||
import_sampler_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = profiler_logger()
|
logger = profiler_logger()
|
||||||
|
|
||||||
@ -77,8 +76,7 @@ class ProfilerSource(ProfilerSourceInterface):
|
|||||||
self.ometa_client = ometa_client
|
self.ometa_client = ometa_client
|
||||||
self._interface_type: str = config.source.type.lower()
|
self._interface_type: str = config.source.type.lower()
|
||||||
self._interface = None
|
self._interface = None
|
||||||
# We define this in create_profiler_interface to help us reuse
|
|
||||||
# this method for the sampler, which does not have a DatabaseServiceProfilerPipeline
|
|
||||||
self.source_config = None
|
self.source_config = None
|
||||||
self.global_profiler_configuration = global_profiler_configuration
|
self.global_profiler_configuration = global_profiler_configuration
|
||||||
|
|
||||||
@ -122,25 +120,34 @@ class ProfilerSource(ProfilerSourceInterface):
|
|||||||
|
|
||||||
return config_copy
|
return config_copy
|
||||||
|
|
||||||
|
@inject
|
||||||
def create_profiler_interface(
|
def create_profiler_interface(
|
||||||
self,
|
self,
|
||||||
entity: Table,
|
entity: Table,
|
||||||
config: Optional[TableConfig],
|
config: Optional[TableConfig],
|
||||||
profiler_config: Optional[ProfilerProcessorConfig],
|
schema_entity: DatabaseSchema,
|
||||||
schema_entity: Optional[DatabaseSchema],
|
database_entity: Database,
|
||||||
database_entity: Optional[Database],
|
profiler_resolver: Inject[Type[ProfilerResolver]] = None,
|
||||||
db_service: Optional[DatabaseService],
|
|
||||||
) -> ProfilerInterface:
|
) -> ProfilerInterface:
|
||||||
"""Create sqlalchemy profiler interface"""
|
"""Create the appropriate profiler interface based on processing engine."""
|
||||||
|
if profiler_resolver is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"ProfilerResolver dependency not found. Please ensure the ProfilerResolver is properly registered."
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: For some reason I do not understand, if we instantiate this on the __init__ method, we break the
|
||||||
|
# autoclassification workflow. This should be fixed. There should not be an impact on AutoClassification.
|
||||||
|
# We have an issue to track this here: https://github.com/open-metadata/OpenMetadata/issues/21790
|
||||||
self.source_config = DatabaseServiceProfilerPipeline.model_validate(
|
self.source_config = DatabaseServiceProfilerPipeline.model_validate(
|
||||||
self.config.source.sourceConfig.config
|
self.config.source.sourceConfig.config
|
||||||
)
|
)
|
||||||
profiler_class = import_profiler_class(
|
|
||||||
ServiceType.Database, source_type=self._interface_type
|
sampler_class, profiler_class = profiler_resolver.resolve(
|
||||||
)
|
processing_engine=self.get_processing_engine(self.source_config),
|
||||||
sampler_class = import_sampler_class(
|
service_type=ServiceType.Database,
|
||||||
ServiceType.Database, source_type=self._interface_type
|
source_type=self._interface_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is shared between the sampler and profiler interfaces
|
# This is shared between the sampler and profiler interfaces
|
||||||
sampler_interface: SamplerInterface = sampler_class.create(
|
sampler_interface: SamplerInterface = sampler_class.create(
|
||||||
service_connection_config=self.service_conn_config,
|
service_connection_config=self.service_conn_config,
|
||||||
@ -155,6 +162,8 @@ class ProfilerSource(ProfilerSourceInterface):
|
|||||||
samplingMethodType=self.source_config.samplingMethodType,
|
samplingMethodType=self.source_config.samplingMethodType,
|
||||||
randomizedSample=self.source_config.randomizedSample,
|
randomizedSample=self.source_config.randomizedSample,
|
||||||
),
|
),
|
||||||
|
# TODO: Change this when we have the processing engine configuration implemented. Right now it does nothing.
|
||||||
|
processing_engine=self.get_processing_engine(self.source_config),
|
||||||
)
|
)
|
||||||
|
|
||||||
profiler_interface: ProfilerInterface = profiler_class.create(
|
profiler_interface: ProfilerInterface = profiler_class.create(
|
||||||
@ -168,28 +177,33 @@ class ProfilerSource(ProfilerSourceInterface):
|
|||||||
self.interface = profiler_interface
|
self.interface = profiler_interface
|
||||||
return self.interface
|
return self.interface
|
||||||
|
|
||||||
|
@inject
|
||||||
def get_profiler_runner(
|
def get_profiler_runner(
|
||||||
self, entity: Table, profiler_config: ProfilerProcessorConfig
|
self,
|
||||||
|
entity: Table,
|
||||||
|
profiler_config: ProfilerProcessorConfig,
|
||||||
|
metrics_registry: Inject[Type[MetricRegistry]] = None,
|
||||||
) -> Profiler:
|
) -> Profiler:
|
||||||
"""
|
"""
|
||||||
Returns the runner for the profiler
|
Returns the runner for the profiler
|
||||||
"""
|
"""
|
||||||
|
if metrics_registry is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
|
||||||
|
)
|
||||||
|
|
||||||
table_config = get_config_for_table(entity, profiler_config)
|
table_config = get_config_for_table(entity, profiler_config)
|
||||||
schema_entity, database_entity, db_service = get_context_entities(
|
schema_entity, database_entity, db_service = get_context_entities(
|
||||||
entity=entity, metadata=self.ometa_client
|
entity=entity, metadata=self.ometa_client
|
||||||
)
|
)
|
||||||
profiler_interface = self.create_profiler_interface(
|
profiler_interface = self.create_profiler_interface(
|
||||||
entity,
|
entity, table_config, schema_entity, database_entity
|
||||||
table_config,
|
|
||||||
profiler_config,
|
|
||||||
schema_entity,
|
|
||||||
database_entity,
|
|
||||||
db_service,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not profiler_config.profiler:
|
if not profiler_config.profiler:
|
||||||
return DefaultProfiler(
|
return DefaultProfiler(
|
||||||
profiler_interface=profiler_interface,
|
profiler_interface=profiler_interface,
|
||||||
|
metrics_registry=metrics_registry,
|
||||||
include_columns=get_include_columns(entity, table_config),
|
include_columns=get_include_columns(entity, table_config),
|
||||||
exclude_columns=get_exclude_columns(entity, table_config),
|
exclude_columns=get_exclude_columns(entity, table_config),
|
||||||
global_profiler_configuration=self.global_profiler_configuration,
|
global_profiler_configuration=self.global_profiler_configuration,
|
||||||
@ -197,9 +211,10 @@ class ProfilerSource(ProfilerSourceInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
metrics = (
|
metrics = (
|
||||||
[Metrics.get(name) for name in profiler_config.profiler.metrics]
|
[metrics_registry.get(name) for name in profiler_config.profiler.metrics]
|
||||||
if profiler_config.profiler.metrics
|
if profiler_config.profiler.metrics
|
||||||
else get_default_metrics(
|
else get_default_metrics(
|
||||||
|
metrics_registry=metrics_registry,
|
||||||
table=profiler_interface.table,
|
table=profiler_interface.table,
|
||||||
ometa_client=self.ometa_client,
|
ometa_client=self.ometa_client,
|
||||||
db_service=db_service,
|
db_service=db_service,
|
||||||
|
@ -16,6 +16,9 @@ Class defining the interface for the profiler source
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
|
||||||
|
DatabaseServiceProfilerPipeline,
|
||||||
|
)
|
||||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||||
|
|
||||||
|
|
||||||
@ -34,20 +37,12 @@ class ProfilerSourceInterface(ABC):
|
|||||||
"""Set the interface"""
|
"""Set the interface"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_profiler_interface(
|
|
||||||
self,
|
|
||||||
entity,
|
|
||||||
config,
|
|
||||||
profiler_config,
|
|
||||||
schema_entity,
|
|
||||||
database_entity,
|
|
||||||
db_service,
|
|
||||||
) -> ProfilerInterface:
|
|
||||||
"""Create the profiler interface"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_profiler_runner(self, entity, profiler_config):
|
def get_profiler_runner(self, entity, profiler_config):
|
||||||
"""Get the profiler runner"""
|
"""Get the profiler runner"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_processing_engine(config: DatabaseServiceProfilerPipeline):
|
||||||
|
"""Get the processing engine based on the configuration."""
|
||||||
|
return "Native"
|
||||||
|
@ -13,12 +13,13 @@ Interface for sampler
|
|||||||
"""
|
"""
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Set, Union
|
from typing import List, Optional, Set, Union
|
||||||
|
|
||||||
from metadata.generated.schema.entity.data.database import Database
|
from metadata.generated.schema.entity.data.database import Database
|
||||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||||
from metadata.generated.schema.entity.data.table import (
|
from metadata.generated.schema.entity.data.table import (
|
||||||
ColumnProfilerConfig,
|
ColumnProfilerConfig,
|
||||||
|
PartitionProfilerConfig,
|
||||||
Table,
|
Table,
|
||||||
TableData,
|
TableData,
|
||||||
)
|
)
|
||||||
@ -65,20 +66,20 @@ class SamplerInterface(ABC):
|
|||||||
include_columns: Optional[List[ColumnProfilerConfig]] = None,
|
include_columns: Optional[List[ColumnProfilerConfig]] = None,
|
||||||
exclude_columns: Optional[List[str]] = None,
|
exclude_columns: Optional[List[str]] = None,
|
||||||
sample_config: SampleConfig = SampleConfig(),
|
sample_config: SampleConfig = SampleConfig(),
|
||||||
partition_details: Optional[Dict] = None,
|
partition_details: Optional[PartitionProfilerConfig] = None,
|
||||||
sample_query: Optional[str] = None,
|
sample_query: Optional[str] = None,
|
||||||
storage_config: DataStorageConfig = None,
|
storage_config: Optional[DataStorageConfig] = None,
|
||||||
sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
|
sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
|
||||||
**__,
|
**__,
|
||||||
):
|
):
|
||||||
self.ometa_client = ometa_client
|
self.ometa_client = ometa_client
|
||||||
self._sample = None
|
self._sample = None
|
||||||
self._columns: Optional[List[SQALikeColumn]] = None
|
self._columns: List[SQALikeColumn] = []
|
||||||
self.sample_config = sample_config
|
self.sample_config = sample_config
|
||||||
|
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
self.include_columns = include_columns
|
self.include_columns = include_columns or []
|
||||||
self.exclude_columns = exclude_columns
|
self.exclude_columns = exclude_columns or []
|
||||||
self.sample_query = sample_query
|
self.sample_query = sample_query
|
||||||
self.sample_limit = sample_data_count
|
self.sample_limit = sample_data_count
|
||||||
self.partition_details = partition_details
|
self.partition_details = partition_details
|
||||||
@ -99,7 +100,7 @@ class SamplerInterface(ABC):
|
|||||||
table_config: Optional[TableConfig] = None,
|
table_config: Optional[TableConfig] = None,
|
||||||
storage_config: Optional[DataStorageConfig] = None,
|
storage_config: Optional[DataStorageConfig] = None,
|
||||||
default_sample_config: Optional[SampleConfig] = None,
|
default_sample_config: Optional[SampleConfig] = None,
|
||||||
default_sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
|
default_sample_data_count: int = SAMPLE_DATA_DEFAULT_COUNT,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> "SamplerInterface":
|
) -> "SamplerInterface":
|
||||||
"""Create sampler"""
|
"""Create sampler"""
|
||||||
@ -165,16 +166,20 @@ class SamplerInterface(ABC):
|
|||||||
|
|
||||||
return self._columns
|
return self._columns
|
||||||
|
|
||||||
def _get_excluded_columns(self) -> Optional[Set[str]]:
|
def _get_excluded_columns(self) -> Set[str]:
|
||||||
"""Get excluded columns for table being profiled"""
|
"""Get excluded columns for table being profiled"""
|
||||||
if self.exclude_columns:
|
if self.exclude_columns:
|
||||||
return set(self.exclude_columns)
|
return set(self.exclude_columns)
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def _get_included_columns(self) -> Optional[Set[str]]:
|
def _get_included_columns(self) -> Set[str]:
|
||||||
"""Get include columns for table being profiled"""
|
"""Get include columns for table being profiled"""
|
||||||
if self.include_columns:
|
if self.include_columns:
|
||||||
return {include_col.columnName for include_col in self.include_columns}
|
return {
|
||||||
|
include_col.columnName
|
||||||
|
for include_col in self.include_columns
|
||||||
|
if include_col.columnName
|
||||||
|
}
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -152,3 +152,38 @@ The dependency container is thread-safe and uses a reentrant lock (RLock) to sup
|
|||||||
3. Circular dependencies are not supported
|
3. Circular dependencies are not supported
|
||||||
4. Dependencies are always treated as optional and can be overridden
|
4. Dependencies are always treated as optional and can be overridden
|
||||||
5. Dependencies can't be passed as *arg. Must be passed as *kwargs
|
5. Dependencies can't be passed as *arg. Must be passed as *kwargs
|
||||||
|
|
||||||
|
### Class-Level Dependency Injection
|
||||||
|
|
||||||
|
For cases where you want to share dependencies across all instances of a class, you can use the `@inject_class_attributes` decorator:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: ClassVar[Inject[Database]]
|
||||||
|
cache: ClassVar[Inject[Cache]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> dict:
|
||||||
|
cache_key = f"user:{user_id}"
|
||||||
|
cached = cls.cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||||
|
|
||||||
|
# The dependencies are shared across all instances and accessed via class methods
|
||||||
|
user = UserService.get_user(user_id=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
The `@inject_class_attributes` decorator will:
|
||||||
|
1. Look for class attributes annotated with `ClassVar[Inject[Type]]`
|
||||||
|
2. Automatically inject the dependencies at the class level
|
||||||
|
3. Make the dependencies available to all instances and class methods
|
||||||
|
4. Raise `DependencyNotFoundError` if a required dependency is not registered
|
||||||
|
|
||||||
|
This is particularly useful for:
|
||||||
|
- Utility classes that don't need instance-specific state
|
||||||
|
- Services that should share the same dependencies across all instances
|
||||||
|
- Performance optimization when the same dependencies are used by multiple instances
|
@ -132,6 +132,15 @@ class DependencyContainer:
|
|||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
def get_key(self, dependency_type: Type[Any]) -> str:
|
||||||
|
"""
|
||||||
|
Get the key for a dependency.
|
||||||
|
"""
|
||||||
|
if get_origin(dependency_type) is type:
|
||||||
|
inner_type = get_args(dependency_type)[0]
|
||||||
|
return f"Type[{inner_type.__name__}]"
|
||||||
|
return dependency_type.__name__
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -145,10 +154,11 @@ class DependencyContainer:
|
|||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||||
|
container.register(Type[Metrics], lambda: Metrics) # For registering types themselves
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._dependencies[dependency_type.__name__] = dependency
|
self._dependencies[self.get_key(dependency_type)] = dependency
|
||||||
|
|
||||||
def override(
|
def override(
|
||||||
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
||||||
@ -169,7 +179,7 @@ class DependencyContainer:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._overrides[dependency_type.__name__] = dependency
|
self._overrides[self.get_key(dependency_type)] = dependency
|
||||||
|
|
||||||
def remove_override(self, dependency_type: Type[T]) -> None:
|
def remove_override(self, dependency_type: Type[T]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -184,7 +194,7 @@ class DependencyContainer:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._overrides.pop(dependency_type.__name__, None)
|
self._overrides.pop(self.get_key(dependency_type), None)
|
||||||
|
|
||||||
def get(self, dependency_type: Type[Any]) -> Optional[Any]:
|
def get(self, dependency_type: Type[Any]) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
@ -206,10 +216,9 @@ class DependencyContainer:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
type_name = dependency_type.__name__
|
factory = self._overrides.get(
|
||||||
factory = self._overrides.get(type_name) or self._dependencies.get(
|
self.get_key(dependency_type)
|
||||||
type_name
|
) or self._dependencies.get(self.get_key(dependency_type))
|
||||||
)
|
|
||||||
if factory is None:
|
if factory is None:
|
||||||
return None
|
return None
|
||||||
return factory()
|
return factory()
|
||||||
@ -243,8 +252,10 @@ class DependencyContainer:
|
|||||||
print("Database dependency is registered")
|
print("Database dependency is registered")
|
||||||
```"""
|
```"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
type_name = dependency_type.__name__
|
return (
|
||||||
return type_name in self._overrides or type_name in self._dependencies
|
self.get_key(dependency_type) in self._overrides
|
||||||
|
or self.get_key(dependency_type) in self._dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def inject(func: Callable[..., Any]) -> Callable[..., Any]:
|
def inject(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@ -341,3 +352,52 @@ def extract_inject_arg(tp: Any) -> Any:
|
|||||||
f"Type {tp} is not Inject or Optional[Inject]. "
|
f"Type {tp} is not Inject or Optional[Inject]. "
|
||||||
f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
|
f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def inject_class_attributes(cls: Type[Any]) -> Type[Any]:
|
||||||
|
"""
|
||||||
|
Decorator to inject dependencies into class-level (static) attributes based on type hints.
|
||||||
|
|
||||||
|
This decorator automatically injects dependencies into class attributes
|
||||||
|
based on their type hints. The dependencies are shared across all instances
|
||||||
|
of the class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to inject dependencies into
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class with dependencies injected into its class-level attributes
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: ClassVar[Inject[Database]]
|
||||||
|
cache: ClassVar[Inject[Cache]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> dict:
|
||||||
|
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
container = DependencyContainer()
|
||||||
|
type_hints = get_type_hints(cls, include_extras=True)
|
||||||
|
|
||||||
|
# Inject dependencies into class attributes
|
||||||
|
for attr_name, attr_type in type_hints.items():
|
||||||
|
# Skip if attribute is already set
|
||||||
|
if hasattr(cls, attr_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if it's an Inject type
|
||||||
|
if is_inject_type(attr_type):
|
||||||
|
dependency_type = extract_inject_arg(attr_type)
|
||||||
|
dependency = container.get(dependency_type)
|
||||||
|
if dependency is None:
|
||||||
|
raise DependencyNotFoundError(
|
||||||
|
f"Dependency of type {dependency_type} not found in container. "
|
||||||
|
f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
|
||||||
|
)
|
||||||
|
setattr(cls, attr_name, dependency)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from metadata.utils.dependency_injector.dependency_injector import (
|
from metadata.utils.dependency_injector.dependency_injector import (
|
||||||
@ -5,6 +7,7 @@ from metadata.utils.dependency_injector.dependency_injector import (
|
|||||||
DependencyNotFoundError,
|
DependencyNotFoundError,
|
||||||
Inject,
|
Inject,
|
||||||
inject,
|
inject,
|
||||||
|
inject_class_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -21,8 +24,10 @@ class Cache:
|
|||||||
def __init__(self, host: str):
|
def __init__(self, host: str):
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
||||||
def get(self, key: str) -> str:
|
def get(self, key: str) -> Optional[str]:
|
||||||
return f"Cache hit for {key}"
|
if key == "user:1":
|
||||||
|
return "Cache hit for user:1"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Test functions for injection
|
# Test functions for injection
|
||||||
@ -139,3 +144,99 @@ class TestInjectDecorator:
|
|||||||
custom_db = Database("postgresql://custom:5432")
|
custom_db = Database("postgresql://custom:5432")
|
||||||
result = get_user(user_id=1, db=custom_db)
|
result = get_user(user_id=1, db=custom_db)
|
||||||
assert result == "Executed: SELECT * FROM users WHERE id = 1"
|
assert result == "Executed: SELECT * FROM users WHERE id = 1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestInjectClassAttributes:
|
||||||
|
def test_inject_class_attributes_single_dependency(self):
|
||||||
|
container = DependencyContainer()
|
||||||
|
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||||
|
container.register(Database, db_factory)
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: Inject[Database]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> str:
|
||||||
|
if cls.db is None:
|
||||||
|
raise DependencyNotFoundError("Database dependency not found")
|
||||||
|
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||||
|
|
||||||
|
result = UserService.get_user(user_id=1)
|
||||||
|
assert result == "Executed: SELECT * FROM users WHERE id = 1"
|
||||||
|
|
||||||
|
def test_inject_class_attributes_multiple_dependencies(self):
|
||||||
|
container = DependencyContainer()
|
||||||
|
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||||
|
cache_factory = lambda: Cache("localhost")
|
||||||
|
|
||||||
|
container.register(Database, db_factory)
|
||||||
|
container.register(Cache, cache_factory)
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: Inject[Database]
|
||||||
|
cache: Inject[Cache]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> str:
|
||||||
|
if cls.db is None:
|
||||||
|
raise DependencyNotFoundError("Database dependency not found")
|
||||||
|
if cls.cache is None:
|
||||||
|
raise DependencyNotFoundError("Cache dependency not found")
|
||||||
|
cache_key = f"user:{user_id}"
|
||||||
|
cached = cls.cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||||
|
|
||||||
|
result = UserService.get_user(user_id=1)
|
||||||
|
assert result == "Cache hit for user:1"
|
||||||
|
|
||||||
|
def test_inject_class_attributes_missing_dependency(self):
|
||||||
|
container = DependencyContainer()
|
||||||
|
container.clear() # Ensure no dependencies are registered
|
||||||
|
|
||||||
|
with pytest.raises(DependencyNotFoundError):
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: Inject[Database]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> str:
|
||||||
|
if cls.db is None:
|
||||||
|
raise DependencyNotFoundError("Database dependency not found")
|
||||||
|
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||||
|
|
||||||
|
def test_inject_class_attributes_shared_dependencies(self):
|
||||||
|
container = DependencyContainer()
|
||||||
|
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||||
|
container.register(Database, db_factory)
|
||||||
|
|
||||||
|
@inject_class_attributes
|
||||||
|
class UserService:
|
||||||
|
db: Inject[Database]
|
||||||
|
counter: int = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def increment_counter(cls) -> int:
|
||||||
|
cls.counter += 1
|
||||||
|
return cls.counter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user(cls, user_id: int) -> str:
|
||||||
|
if cls.db is None:
|
||||||
|
raise DependencyNotFoundError("Database dependency not found")
|
||||||
|
return f"Counter: {cls.counter}, {cls.db.query(f'SELECT * FROM users WHERE id = {user_id}')}"
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = UserService.get_user(user_id=1)
|
||||||
|
assert result1 == "Counter: 0, Executed: SELECT * FROM users WHERE id = 1"
|
||||||
|
|
||||||
|
# Increment counter
|
||||||
|
UserService.increment_counter()
|
||||||
|
|
||||||
|
# Second call - counter should be shared
|
||||||
|
result2 = UserService.get_user(user_id=1)
|
||||||
|
assert result2 == "Counter: 1, Executed: SELECT * FROM users WHERE id = 1"
|
||||||
|
@ -201,6 +201,7 @@ class ProfilerTest(TestCase):
|
|||||||
"""
|
"""
|
||||||
simple = DefaultProfiler(
|
simple = DefaultProfiler(
|
||||||
profiler_interface=self.datalake_profiler_interface,
|
profiler_interface=self.datalake_profiler_interface,
|
||||||
|
metrics_registry=Metrics,
|
||||||
)
|
)
|
||||||
simple.compute_metrics()
|
simple.compute_metrics()
|
||||||
|
|
||||||
@ -337,6 +338,7 @@ class ProfilerTest(TestCase):
|
|||||||
|
|
||||||
default_profiler = DefaultProfiler(
|
default_profiler = DefaultProfiler(
|
||||||
profiler_interface=self.datalake_profiler_interface,
|
profiler_interface=self.datalake_profiler_interface,
|
||||||
|
metrics_registry=Metrics,
|
||||||
)
|
)
|
||||||
column_metrics = default_profiler._prepare_column_metrics()
|
column_metrics = default_profiler._prepare_column_metrics()
|
||||||
for metric in column_metrics:
|
for metric in column_metrics:
|
||||||
|
@ -50,6 +50,7 @@ from metadata.profiler.metrics.core import (
|
|||||||
QueryMetric,
|
QueryMetric,
|
||||||
StaticMetric,
|
StaticMetric,
|
||||||
)
|
)
|
||||||
|
from metadata.profiler.metrics.registry import Metrics
|
||||||
from metadata.profiler.metrics.static.row_count import RowCount
|
from metadata.profiler.metrics.static.row_count import RowCount
|
||||||
from metadata.profiler.processor.default import get_default_metrics
|
from metadata.profiler.processor.default import get_default_metrics
|
||||||
from metadata.sampler.pandas.sampler import DatalakeSampler
|
from metadata.sampler.pandas.sampler import DatalakeSampler
|
||||||
@ -197,7 +198,7 @@ class PandasInterfaceTest(TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
cls.table = User
|
cls.table = User
|
||||||
cls.metrics = get_default_metrics(cls.table)
|
cls.metrics = get_default_metrics(Metrics, cls.table)
|
||||||
cls.static_metrics = [
|
cls.static_metrics = [
|
||||||
metric for metric in cls.metrics if issubclass(metric, StaticMetric)
|
metric for metric in cls.metrics if issubclass(metric, StaticMetric)
|
||||||
]
|
]
|
||||||
|
@ -139,7 +139,7 @@ class ProfilerTest(TestCase):
|
|||||||
Check our pre-cooked profiler
|
Check our pre-cooked profiler
|
||||||
"""
|
"""
|
||||||
simple = DefaultProfiler(
|
simple = DefaultProfiler(
|
||||||
profiler_interface=self.sqa_profiler_interface,
|
profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
|
||||||
)
|
)
|
||||||
simple.compute_metrics()
|
simple.compute_metrics()
|
||||||
|
|
||||||
@ -297,7 +297,7 @@ class ProfilerTest(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
simple = DefaultProfiler(
|
simple = DefaultProfiler(
|
||||||
profiler_interface=sqa_profiler_interface,
|
profiler_interface=sqa_profiler_interface, metrics_registry=Metrics
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
@ -316,7 +316,7 @@ class ProfilerTest(TestCase):
|
|||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
default_profiler = DefaultProfiler(
|
default_profiler = DefaultProfiler(
|
||||||
profiler_interface=self.sqa_profiler_interface,
|
profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
|
||||||
)
|
)
|
||||||
|
|
||||||
column_metrics = default_profiler._prepare_column_metrics()
|
column_metrics = default_profiler._prepare_column_metrics()
|
||||||
|
@ -15,10 +15,10 @@ Test SQA Interface
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest import TestCase
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
from sqlalchemy import TEXT, Column, Integer, String, inspect
|
from sqlalchemy import TEXT, Column, Integer, String, inspect
|
||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import declarative_base
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
@ -49,6 +49,7 @@ from metadata.profiler.metrics.core import (
|
|||||||
QueryMetric,
|
QueryMetric,
|
||||||
StaticMetric,
|
StaticMetric,
|
||||||
)
|
)
|
||||||
|
from metadata.profiler.metrics.registry import Metrics
|
||||||
from metadata.profiler.metrics.static.row_count import RowCount
|
from metadata.profiler.metrics.static.row_count import RowCount
|
||||||
from metadata.profiler.processor.default import get_default_metrics
|
from metadata.profiler.processor.default import get_default_metrics
|
||||||
from metadata.sampler.sqlalchemy.sampler import SQASampler
|
from metadata.sampler.sqlalchemy.sampler import SQASampler
|
||||||
@ -64,46 +65,9 @@ class User(declarative_base()):
|
|||||||
age = Column(Integer)
|
age = Column(Integer)
|
||||||
|
|
||||||
|
|
||||||
class SQAInterfaceTest(TestCase):
|
@pytest.fixture
|
||||||
def setUp(self) -> None:
|
def table_entity():
|
||||||
table_entity = Table(
|
return Table(
|
||||||
id=uuid4(),
|
|
||||||
name="user",
|
|
||||||
columns=[
|
|
||||||
EntityColumn(
|
|
||||||
name=ColumnName("id"),
|
|
||||||
dataType=DataType.INT,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
sqlite_conn = SQLiteConnection(
|
|
||||||
scheme=SQLiteScheme.sqlite_pysqlite,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
|
||||||
sampler = SQASampler(
|
|
||||||
service_connection_config=sqlite_conn,
|
|
||||||
ometa_client=None,
|
|
||||||
entity=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
|
||||||
self.sqa_profiler_interface = SQAProfilerInterface(
|
|
||||||
sqlite_conn, None, table_entity, None, sampler, 5, 43200
|
|
||||||
)
|
|
||||||
self.table = User
|
|
||||||
|
|
||||||
def test_init_interface(self):
|
|
||||||
"""Test we can instantiate our interface object correctly"""
|
|
||||||
|
|
||||||
assert isinstance(self.sqa_profiler_interface.session, Session)
|
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
|
||||||
self.sqa_profiler_interface._sampler = None
|
|
||||||
|
|
||||||
|
|
||||||
class SQAInterfaceTestMultiThread(TestCase):
|
|
||||||
table_entity = Table(
|
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
name="user",
|
name="user",
|
||||||
columns=[
|
columns=[
|
||||||
@ -113,12 +77,17 @@ class SQAInterfaceTestMultiThread(TestCase):
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
db_path = os.path.join(os.path.dirname(__file__), "test.db")
|
|
||||||
sqlite_conn = SQLiteConnection(
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sqlite_conn():
|
||||||
|
return SQLiteConnection(
|
||||||
scheme=SQLiteScheme.sqlite_pysqlite,
|
scheme=SQLiteScheme.sqlite_pysqlite,
|
||||||
databaseMode=db_path + "?check_same_thread=False",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sqa_profiler_interface(table_entity, sqlite_conn):
|
||||||
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
||||||
sampler = SQASampler(
|
sampler = SQASampler(
|
||||||
service_connection_config=sqlite_conn,
|
service_connection_config=sqlite_conn,
|
||||||
@ -126,140 +95,205 @@ class SQAInterfaceTestMultiThread(TestCase):
|
|||||||
entity=None,
|
entity=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
sqa_profiler_interface = SQAProfilerInterface(
|
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
||||||
sqlite_conn,
|
interface = SQAProfilerInterface(
|
||||||
|
sqlite_conn, None, table_entity, None, sampler, 5, 43200
|
||||||
|
)
|
||||||
|
return interface
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_interface(sqa_profiler_interface):
|
||||||
|
"""Test we can instantiate our interface object correctly"""
|
||||||
|
assert isinstance(sqa_profiler_interface.session, Session)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def db_path():
|
||||||
|
return os.path.join(os.path.dirname(__file__), "test.db")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def class_sqlite_conn(db_path):
|
||||||
|
return SQLiteConnection(
|
||||||
|
scheme=SQLiteScheme.sqlite_pysqlite,
|
||||||
|
databaseMode=db_path + "?check_same_thread=False",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def class_table_entity():
|
||||||
|
return Table(
|
||||||
|
id=uuid4(),
|
||||||
|
name="user",
|
||||||
|
columns=[
|
||||||
|
EntityColumn(
|
||||||
|
name=ColumnName("id"),
|
||||||
|
dataType=DataType.INT,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def class_sqa_profiler_interface(class_sqlite_conn, class_table_entity):
|
||||||
|
with patch.object(SQASampler, "build_table_orm", return_value=User):
|
||||||
|
sampler = SQASampler(
|
||||||
|
service_connection_config=class_sqlite_conn,
|
||||||
|
ometa_client=None,
|
||||||
|
entity=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
interface = SQAProfilerInterface(
|
||||||
|
class_sqlite_conn,
|
||||||
None,
|
None,
|
||||||
table_entity,
|
class_table_entity,
|
||||||
None,
|
None,
|
||||||
sampler,
|
sampler,
|
||||||
5,
|
5,
|
||||||
43200,
|
43200,
|
||||||
)
|
)
|
||||||
|
return interface
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls) -> None:
|
|
||||||
"""
|
|
||||||
Prepare Ingredients
|
|
||||||
"""
|
|
||||||
User.__table__.create(bind=cls.sqa_profiler_interface.session.get_bind())
|
|
||||||
|
|
||||||
data = [
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
|
def setup_database(class_sqa_profiler_interface):
|
||||||
User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
|
"""Setup test database and tables"""
|
||||||
]
|
try:
|
||||||
cls.sqa_profiler_interface.session.add_all(data)
|
# Drop the table if it exists
|
||||||
cls.sqa_profiler_interface.session.commit()
|
User.__table__.drop(
|
||||||
cls.table = User
|
bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
|
||||||
cls.metrics = get_default_metrics(cls.table)
|
)
|
||||||
cls.static_metrics = [
|
# Create the table
|
||||||
metric for metric in cls.metrics if issubclass(metric, StaticMetric)
|
User.__table__.create(bind=class_sqa_profiler_interface.session.get_bind())
|
||||||
]
|
except Exception as e:
|
||||||
cls.composed_metrics = [
|
print(f"Error during table setup: {str(e)}")
|
||||||
metric for metric in cls.metrics if issubclass(metric, ComposedMetric)
|
raise e
|
||||||
]
|
|
||||||
cls.window_metrics = [
|
data = [
|
||||||
|
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
|
||||||
|
User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
|
||||||
|
]
|
||||||
|
class_sqa_profiler_interface.session.add_all(data)
|
||||||
|
class_sqa_profiler_interface.session.commit()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
try:
|
||||||
|
User.__table__.drop(
|
||||||
|
bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
|
||||||
|
)
|
||||||
|
class_sqa_profiler_interface.session.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during cleanup: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def metrics(class_sqa_profiler_interface):
|
||||||
|
metrics = get_default_metrics(Metrics, User)
|
||||||
|
return {
|
||||||
|
"all": metrics,
|
||||||
|
"static": [metric for metric in metrics if issubclass(metric, StaticMetric)],
|
||||||
|
"composed": [
|
||||||
|
metric for metric in metrics if issubclass(metric, ComposedMetric)
|
||||||
|
],
|
||||||
|
"window": [
|
||||||
metric
|
metric
|
||||||
for metric in cls.metrics
|
for metric in metrics
|
||||||
if issubclass(metric, StaticMetric) and metric.is_window_metric()
|
if issubclass(metric, StaticMetric) and metric.is_window_metric()
|
||||||
]
|
],
|
||||||
cls.query_metrics = [
|
"query": [
|
||||||
metric
|
metric
|
||||||
for metric in cls.metrics
|
for metric in metrics
|
||||||
if issubclass(metric, QueryMetric) and metric.is_col_metric()
|
if issubclass(metric, QueryMetric) and metric.is_col_metric()
|
||||||
]
|
],
|
||||||
|
}
|
||||||
|
|
||||||
def test_init_interface(self):
|
|
||||||
"""Test we can instantiate our interface object correctly"""
|
|
||||||
|
|
||||||
assert isinstance(self.sqa_profiler_interface.session, Session)
|
def test_init_interface_multi_thread(class_sqa_profiler_interface):
|
||||||
|
"""Test we can instantiate our interface object correctly"""
|
||||||
|
assert isinstance(class_sqa_profiler_interface.session, Session)
|
||||||
|
|
||||||
def test_get_all_metrics(self):
|
|
||||||
table_metrics = [
|
def test_get_all_metrics(class_sqa_profiler_interface, metrics):
|
||||||
|
table_metrics = [
|
||||||
|
ThreadPoolMetrics(
|
||||||
|
metrics=[
|
||||||
|
metric
|
||||||
|
for metric in metrics["all"]
|
||||||
|
if (not metric.is_col_metric() and not metric.is_system_metrics())
|
||||||
|
],
|
||||||
|
metric_type=MetricTypes.Table,
|
||||||
|
column=None,
|
||||||
|
table=User,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
column_metrics = []
|
||||||
|
query_metrics = []
|
||||||
|
window_metrics = []
|
||||||
|
for col in inspect(User).c:
|
||||||
|
column_metrics.append(
|
||||||
ThreadPoolMetrics(
|
ThreadPoolMetrics(
|
||||||
metrics=[
|
metrics=[
|
||||||
metric
|
metric
|
||||||
for metric in self.metrics
|
for metric in metrics["static"]
|
||||||
if (not metric.is_col_metric() and not metric.is_system_metrics())
|
if metric.is_col_metric() and not metric.is_window_metric()
|
||||||
],
|
],
|
||||||
metric_type=MetricTypes.Table,
|
metric_type=MetricTypes.Static,
|
||||||
column=None,
|
column=col,
|
||||||
table=self.table,
|
table=User,
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
column_metrics = []
|
for query_metric in metrics["query"]:
|
||||||
query_metrics = []
|
query_metrics.append(
|
||||||
window_metrics = []
|
|
||||||
for col in inspect(User).c:
|
|
||||||
column_metrics.append(
|
|
||||||
ThreadPoolMetrics(
|
ThreadPoolMetrics(
|
||||||
metrics=[
|
metrics=query_metric,
|
||||||
metric
|
metric_type=MetricTypes.Query,
|
||||||
for metric in self.static_metrics
|
|
||||||
if metric.is_col_metric() and not metric.is_window_metric()
|
|
||||||
],
|
|
||||||
metric_type=MetricTypes.Static,
|
|
||||||
column=col,
|
column=col,
|
||||||
table=self.table,
|
table=User,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for query_metric in self.query_metrics:
|
window_metrics.append(
|
||||||
query_metrics.append(
|
ThreadPoolMetrics(
|
||||||
ThreadPoolMetrics(
|
metrics=[
|
||||||
metrics=query_metric,
|
metric for metric in metrics["window"] if metric.is_window_metric()
|
||||||
metric_type=MetricTypes.Query,
|
],
|
||||||
column=col,
|
metric_type=MetricTypes.Window,
|
||||||
table=self.table,
|
column=col,
|
||||||
)
|
table=User,
|
||||||
)
|
|
||||||
window_metrics.append(
|
|
||||||
ThreadPoolMetrics(
|
|
||||||
metrics=[
|
|
||||||
metric
|
|
||||||
for metric in self.window_metrics
|
|
||||||
if metric.is_window_metric()
|
|
||||||
],
|
|
||||||
metric_type=MetricTypes.Window,
|
|
||||||
column=col,
|
|
||||||
table=self.table,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]
|
|
||||||
|
|
||||||
profile_results = self.sqa_profiler_interface.get_all_metrics(
|
|
||||||
all_metrics,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
column_profile = [
|
all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]
|
||||||
ColumnProfile(**profile_results["columns"].get(col.name))
|
|
||||||
for col in inspect(User).c
|
|
||||||
if profile_results["columns"].get(col.name)
|
|
||||||
]
|
|
||||||
|
|
||||||
table_profile = TableProfile(
|
profile_results = class_sqa_profiler_interface.get_all_metrics(
|
||||||
columnCount=profile_results["table"].get("columnCount"),
|
all_metrics,
|
||||||
rowCount=profile_results["table"].get(RowCount.name()),
|
)
|
||||||
timestamp=Timestamp(int(datetime.now().timestamp())),
|
|
||||||
)
|
|
||||||
|
|
||||||
profile_request = CreateTableProfileRequest(
|
column_profile = [
|
||||||
tableProfile=table_profile, columnProfile=column_profile
|
ColumnProfile(**profile_results["columns"].get(col.name))
|
||||||
)
|
for col in inspect(User).c
|
||||||
|
if profile_results["columns"].get(col.name)
|
||||||
|
]
|
||||||
|
|
||||||
assert profile_request.tableProfile.columnCount == 6
|
table_profile = TableProfile(
|
||||||
assert profile_request.tableProfile.rowCount == 2
|
columnCount=profile_results["table"].get("columnCount"),
|
||||||
name_column_profile = [
|
rowCount=profile_results["table"].get(RowCount.name()),
|
||||||
profile
|
timestamp=Timestamp(int(datetime.now().timestamp())),
|
||||||
for profile in profile_request.columnProfile
|
)
|
||||||
if profile.name == "name"
|
|
||||||
][0]
|
|
||||||
id_column_profile = [
|
|
||||||
profile for profile in profile_request.columnProfile if profile.name == "id"
|
|
||||||
][0]
|
|
||||||
assert name_column_profile.nullCount == 0
|
|
||||||
assert id_column_profile.median == 1.0
|
|
||||||
|
|
||||||
@classmethod
|
profile_request = CreateTableProfileRequest(
|
||||||
def tearDownClass(cls) -> None:
|
tableProfile=table_profile, columnProfile=column_profile
|
||||||
os.remove(cls.db_path)
|
)
|
||||||
return super().tearDownClass()
|
|
||||||
|
assert profile_request.tableProfile.columnCount == 6
|
||||||
|
assert profile_request.tableProfile.rowCount == 2
|
||||||
|
name_column_profile = [
|
||||||
|
profile for profile in profile_request.columnProfile if profile.name == "name"
|
||||||
|
][0]
|
||||||
|
id_column_profile = [
|
||||||
|
profile for profile in profile_request.columnProfile if profile.name == "id"
|
||||||
|
][0]
|
||||||
|
assert name_column_profile.nullCount == 0
|
||||||
|
assert id_column_profile.median == 1.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user