MINOR: Add injection to profiler (#21738)
* Initial implementation for our Connection Class
* Implement the Initial Connection class
* Add Unit Tests
* Implement Dependency Injection for the Ingestion Framework
* Fix Test
* Fix Profile Test Connection
* Add Injection to Metrics in Profiler
* Add Injection to the Profiler
* Fix UnitTests
* Fix Pytests
* Fix Tests
* Fix types
This commit is contained in:
parent bd948de115
commit e79c54e6a5
@@ -11,7 +11,14 @@
"""
OpenMetadata package initialization.
"""

from typing import Type

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.profiler.source.database.base.profiler_resolver import (
    DefaultProfilerResolver,
    ProfilerResolver,
)
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader

@@ -20,3 +27,5 @@ container = DependencyContainer()

# Register the source loader
container.register(SourceLoader, DefaultSourceLoader)
container.register(Type[MetricRegistry], lambda: Metrics)
container.register(Type[ProfilerResolver], lambda: DefaultProfilerResolver)
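For orientation, a minimal sketch of how these registrations are consumed. It only uses the container API shown in this PR (`register`/`get`), and relies on `DependencyContainer` being a singleton, so this is the same container the package `__init__` registered into:

```python
from typing import Type

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

container = DependencyContainer()

# get() looks up the factory registered under the "Type[MetricRegistry]" key
# and calls it; the lambda above returns the Metrics enum itself.
registry = container.get(Type[MetricRegistry])
assert registry is Metrics
```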
@@ -34,7 +34,7 @@ class TableRowInsertedCountToBeBetweenValidator(
    """Validator for table row inserted count to be between test case"""

    @staticmethod
    def _get_threshold_date(range_type: str, range_interval: int):
    def get_threshold_date_dt(range_type: str, range_interval: int) -> datetime:
        """returns the threshold datetime in utc to count the numbers of rows inserted

        Args:
@@ -55,7 +55,22 @@ class TableRowInsertedCountToBeBetweenValidator(
            threshold_date = threshold_date.replace(
                hour=0, minute=0, second=0, microsecond=0
            )
        return threshold_date.strftime("%Y%m%d%H%M%S")
        return threshold_date

    @staticmethod
    def _get_threshold_date(
        range_type: str, range_interval: int, date_format: str = "%Y%m%d%H%M%S"
    ):
        """returns the threshold datetime in utc as string to count the numbers of rows inserted

        Args:
            range_type (str): type of range (i.e. HOUR, DAY, MONTH, YEAR)
            range_interval (int): interval of range (i.e. 1, 2, 3, 4)
            date_format (str): format of the date (i.e. %Y%m%d%H%M%S, %Y-%m-%d %H:%M:%S)
        """
        return TableRowInsertedCountToBeBetweenValidator.get_threshold_date_dt(
            range_type, range_interval
        ).strftime(date_format)

    def _get_column_name(self):
        """returns the column name to be validated"""
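The refactor splits date computation from formatting: `get_threshold_date_dt` returns a `datetime`, and `_get_threshold_date` formats it. A usage sketch (the import path is assumed from the repository layout; `Validator` is just a local alias):

```python
from metadata.data_quality.validations.table.base.tableRowInsertedCountToBeBetween import (  # path assumed
    TableRowInsertedCountToBeBetweenValidator as Validator,
)

# Raw datetime, for callers that compare against timestamps directly.
dt = Validator.get_threshold_date_dt("DAY", 1)

# Formatted string; the default format keeps the old "%Y%m%d%H%M%S" behavior,
# while engines with other literal formats can now pass their own.
formatted = Validator._get_threshold_date("DAY", 1, date_format="%Y-%m-%d %H:%M:%S")
```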
@@ -113,6 +113,9 @@ def get_test_connection_fn(connection: BaseModel) -> Callable:
    )


# ------------------------------------------------------------


def get_connection(connection: BaseModel) -> Any:
    """
    Main method to prepare a connection from
@@ -13,7 +13,7 @@ System table profiler
"""
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

from more_itertools import partition
from pydantic import field_validator
@@ -25,7 +25,11 @@ from metadata.profiler.interface.sqlalchemy.stored_statistics_profiler import (
    StoredStatisticsSource,
)
from metadata.profiler.metrics.core import Metric
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    Inject,
    inject_class_attributes,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
from metadata.utils.ssl_manager import get_ssl_connection
@@ -61,23 +65,28 @@ class TableStats(BaseModel):
    columns: Dict[str, ColumnStats] = {}


@inject_class_attributes
class TrinoStoredStatisticsSource(StoredStatisticsSource):
    """Trino system profile source"""

    metric_stats_map: Dict[Metrics, str] = {
        Metrics.NULL_RATIO: "nulls_fractions",
        Metrics.DISTINCT_COUNT: "distinct_values_count",
        Metrics.ROW_COUNT: "row_count",
        Metrics.MAX: "high_value",
        Metrics.MIN: "low_value",
    }
    metrics: Inject[Type[MetricRegistry]]

    metric_stats_by_name: Dict[str, str] = {
        k.name: v for k, v in metric_stats_map.items()
    }
    @classmethod
    def get_metric_stats_map(cls) -> Dict[MetricRegistry, str]:
        return {
            cls.metrics.NULL_RATIO: "nulls_fractions",
            cls.metrics.DISTINCT_COUNT: "distinct_values_count",
            cls.metrics.ROW_COUNT: "row_count",
            cls.metrics.MAX: "high_value",
            cls.metrics.MIN: "low_value",
        }

    def get_statistics_metrics(self) -> Set[Metrics]:
        return set(self.metric_stats_map.keys())
    @classmethod
    def get_metric_stats_by_name(cls) -> Dict[str, str]:
        return {k.name: v for k, v in cls.get_metric_stats_map().items()}

    def get_statistics_metrics(self) -> Set[MetricRegistry]:
        return set(self.get_metric_stats_map().keys())

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -96,7 +105,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
                f"Column {column} not found in table {table_name}. Statistics might be stale or missing."
            )
        result = {
            m.name(): getattr(column_stats, self.metric_stats_by_name[m.name()])
            m.name(): getattr(column_stats, self.get_metric_stats_by_name()[m.name()])
            for m in metric
        }
        result.update(self.get_hybrid_statistics(table_stats, column_stats))
@@ -108,7 +117,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
    ) -> dict:
        table_stats = self._get_cached_stats(schema, table_name)
        return {
            m.name(): getattr(table_stats, self.metric_stats_by_name[m.name()])
            m.name(): getattr(table_stats, self.get_metric_stats_by_name()[m.name()])
            for m in metric
        }

@@ -159,7 +168,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
    ) -> Dict[str, Any]:
        return {
            # trino stats are in fractions, so we need to convert them to counts (unlike our default profiler)
            Metrics.NULL_COUNT.name: (
            self.metrics.NULL_COUNT.name: (
                int(table_stats.row_count * column_stats.nulls_fraction)
                if None not in [table_stats.row_count, column_stats.nulls_fraction]
                else None
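The pattern above, reduced to its essentials: `@inject_class_attributes` resolves `metrics` from the container once, at class-definition time, and every classmethod then reads it from `cls`. A minimal sketch (the `StatsSource` name is hypothetical; it assumes the `Type[MetricRegistry]` registration from the package `__init__`):

```python
from typing import Dict, Type

from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    Inject,
    inject_class_attributes,
)


@inject_class_attributes
class StatsSource:  # hypothetical, mirrors TrinoStoredStatisticsSource
    # Filled from the container entry registered under Type[MetricRegistry];
    # raises DependencyNotFoundError at definition time if it is missing.
    metrics: Inject[Type[MetricRegistry]]

    @classmethod
    def get_metric_stats_map(cls) -> Dict[MetricRegistry, str]:
        # Computed per call, so the mapping always reflects the injected registry.
        return {cls.metrics.ROW_COUNT: "row_count"}
```

Moving the mappings from class-body dicts into classmethods is what makes the registry swappable: a class-body dict would have been evaluated against the hardcoded `Metrics` enum at import time.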
@@ -16,7 +16,7 @@ Run profiler metrics on the table

import traceback
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Type

from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
@@ -27,13 +27,29 @@ from metadata.generated.schema.entity.data.table import TableType
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    DependencyNotFoundError,
    Inject,
    inject,
)
from metadata.utils.logger import profiler_interface_registry_logger

logger = profiler_interface_registry_logger()


@inject
def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None):
    if metrics is None:
        raise DependencyNotFoundError(
            "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
        )
    return metrics.ROW_COUNT().name()


COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames"
ROW_COUNT = Metrics.ROW_COUNT().name()
ROW_COUNT = get_row_count_metric()
SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime"

@@ -362,9 +378,15 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer):
class MySQLTableMetricComputer(BaseTableMetricComputer):
    """MySQL Table Metric Computer"""

    def compute(self):
    @inject
    def compute(self, metrics: Inject[Type[MetricRegistry]] = None):
        """compute table metrics for mysql"""

        if metrics is None:
            raise DependencyNotFoundError(
                "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
            )

        columns = [
            Column("TABLE_ROWS").label(ROW_COUNT),
            (Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
@@ -390,7 +412,7 @@ class MySQLTableMetricComputer(BaseTableMetricComputer):
        res = res._asdict()
        # innodb row count is an estimate we need to patch the row count with COUNT(*)
        # https://dev.mysql.com/doc/refman/8.3/en/information-schema-innodb-tablestats-table.html
        row_count = self.runner.select_first_from_table(Metrics.ROW_COUNT().fn())
        row_count = self.runner.select_first_from_table(metrics.ROW_COUNT().fn())
        res.update({ROW_COUNT: row_count.rowCount})
        return res
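A usage sketch of the `@inject` pattern introduced here, assuming the `Type[MetricRegistry]` registration from the package `__init__` and that the default registry names its row-count metric `rowCount` (which is what the `rowCount` keys elsewhere in this diff suggest):

```python
from metadata.profiler.metrics.registry import Metrics

# Called with no argument: @inject fills `metrics` from the container.
assert get_row_count_metric() == "rowCount"

# Dependencies can also be passed explicitly -- as kwargs only, per the
# injector's docs -- which is how callers and tests bypass the container.
assert get_row_count_metric(metrics=Metrics) == "rowCount"
```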
@@ -16,7 +16,7 @@ from __future__ import annotations

import traceback
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, cast

from pydantic import ValidationError
from sqlalchemy import Column
@@ -94,7 +94,7 @@ class Profiler(Generic[TMetric]):
        :param profile_sample: % of rows to use for sampling column metrics
        """
        self.global_profiler_configuration: Optional[ProfilerConfiguration] = (
            global_profiler_configuration.config_value
            cast(ProfilerConfiguration, global_profiler_configuration.config_value)
            if global_profiler_configuration
            else None
        )
@@ -12,55 +12,54 @@
"""
Default simple profiler to use
"""
from typing import List, Optional
from typing import List, Optional, Type

from sqlalchemy.orm import DeclarativeMeta

from metadata.generated.schema.configuration.profilerConfiguration import (
    ProfilerConfiguration,
)
from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.settings.settings import Settings
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import Metric, add_props
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.core import Profiler
from metadata.profiler.registry import MetricRegistry


def get_default_metrics(
    metrics_registry: Type[MetricRegistry],
    table: DeclarativeMeta,
    ometa_client: Optional[OpenMetadata] = None,
    db_service: Optional[DatabaseService] = None,
) -> List[Metric]:
    return [
        # Table Metrics
        Metrics.ROW_COUNT.value,
        add_props(table=table)(Metrics.COLUMN_COUNT.value),
        add_props(table=table)(Metrics.COLUMN_NAMES.value),
        metrics_registry.ROW_COUNT.value,
        add_props(table=table)(metrics_registry.COLUMN_COUNT.value),
        add_props(table=table)(metrics_registry.COLUMN_NAMES.value),
        # We'll use the ometa_client & db_service in case we need to fetch info to ES
        add_props(table=table, ometa_client=ometa_client, db_service=db_service)(
            Metrics.SYSTEM.value
            metrics_registry.SYSTEM.value
        ),
        # Column Metrics
        Metrics.MEDIAN.value,
        Metrics.FIRST_QUARTILE.value,
        Metrics.THIRD_QUARTILE.value,
        Metrics.MEAN.value,
        Metrics.COUNT.value,
        Metrics.DISTINCT_COUNT.value,
        Metrics.DISTINCT_RATIO.value,
        Metrics.MIN.value,
        Metrics.MAX.value,
        Metrics.NULL_COUNT.value,
        Metrics.NULL_RATIO.value,
        Metrics.STDDEV.value,
        Metrics.SUM.value,
        Metrics.UNIQUE_COUNT.value,
        Metrics.UNIQUE_RATIO.value,
        Metrics.IQR.value,
        Metrics.HISTOGRAM.value,
        Metrics.NON_PARAMETRIC_SKEW.value,
        metrics_registry.MEDIAN.value,
        metrics_registry.FIRST_QUARTILE.value,
        metrics_registry.THIRD_QUARTILE.value,
        metrics_registry.MEAN.value,
        metrics_registry.COUNT.value,
        metrics_registry.DISTINCT_COUNT.value,
        metrics_registry.DISTINCT_RATIO.value,
        metrics_registry.MIN.value,
        metrics_registry.MAX.value,
        metrics_registry.NULL_COUNT.value,
        metrics_registry.NULL_RATIO.value,
        metrics_registry.STDDEV.value,
        metrics_registry.SUM.value,
        metrics_registry.UNIQUE_COUNT.value,
        metrics_registry.UNIQUE_RATIO.value,
        metrics_registry.IQR.value,
        metrics_registry.HISTOGRAM.value,
        metrics_registry.NON_PARAMETRIC_SKEW.value,
    ]


@@ -74,12 +73,14 @@ class DefaultProfiler(Profiler):
    def __init__(
        self,
        profiler_interface: ProfilerInterface,
        metrics_registry: Type[MetricRegistry],
        include_columns: Optional[List[ColumnProfilerConfig]] = None,
        exclude_columns: Optional[List[str]] = None,
        global_profiler_configuration: Optional[ProfilerConfiguration] = None,
        global_profiler_configuration: Optional[Settings] = None,
        db_service=None,
    ):
        _metrics = get_default_metrics(
            metrics_registry=metrics_registry,
            table=profiler_interface.table,
            ometa_client=profiler_interface.ometa_client,
            db_service=db_service,
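A quick usage sketch of the new `get_default_metrics` signature (the `User` ORM class is a hypothetical stand-in; any SQLAlchemy `DeclarativeMeta` table works):

```python
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.default import get_default_metrics

Base = declarative_base()


class User(Base):  # hypothetical table, used only for illustration
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(256))


# The registry is now an explicit argument instead of the hardcoded Metrics
# enum, so alternative registries can produce the same default metric list.
metrics = get_default_metrics(metrics_registry=Metrics, table=User)
```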
@@ -33,22 +33,35 @@ from metadata.profiler.metrics.core import (
    SystemMetric,
    TMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.converter.converter_registry import converter_registry
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    DependencyNotFoundError,
    Inject,
    inject,
)
from metadata.utils.sqa_like_column import SQALikeColumn


class MetricFilter:
    """Metric filter class for profiler"""

    @inject
    def __init__(
        self,
        metrics: Tuple[Type[TMetric]],
        global_profiler_config: Optional[ProfilerConfiguration] = None,
        table_profiler_config: Optional[TableProfilerConfig] = None,
        column_profiler_config: Optional[List[ColumnProfilerConfig]] = None,
        metrics_registry: Inject[Type[MetricRegistry]] = None,
    ):
        if metrics_registry is None:
            raise DependencyNotFoundError(
                "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
            )

        self.metrics = metrics
        self.metrics_registry = metrics_registry
        self.global_profiler_config = global_profiler_config
        self.table_profiler_config = table_profiler_config
        self.column_profiler_config = column_profiler_config
@@ -196,7 +209,7 @@ class MetricFilter:

        metrics = [
            Metric.value
            for Metric in Metrics
            for Metric in self.metrics_registry
            if Metric.value.name() in {mtrc.value for mtrc in col_dtype_config.metrics}
            and Metric.value in metrics
        ]
@@ -240,7 +253,7 @@ class MetricFilter:

        metrics = [
            Metric.value
            for Metric in Metrics
            for Metric in self.metrics_registry
            if Metric.value.name().lower() in {mtrc.lower() for mtrc in metric_names}
            and Metric.value in metrics
        ]
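Because the injected kwarg has a default, construction is unchanged for existing callers. A sketch (the `MetricFilter` import path is assumed from the repository layout):

```python
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.processor.metric_filter import MetricFilter  # path assumed

# metrics_registry is resolved from the container by @inject; passing it
# explicitly as a keyword argument overrides the injected value (e.g. in tests).
metric_filter = MetricFilter(metrics=(RowCount,))
assert metric_filter.metrics_registry is not None
```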
@@ -13,20 +13,30 @@
Models to map profiler definitions
JSON workflows to the profiler
"""
from typing import List, Optional
from typing import List, Optional, Type

from pydantic import BaseModel, BeforeValidator
from typing_extensions import Annotated

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    DependencyNotFoundError,
    Inject,
    inject,
)


def valid_metric(value: str):
@inject
def valid_metric(value: str, metrics: Inject[Type[MetricRegistry]] = None):
    """
    Validate that the input metrics are correctly named
    and can be found in the Registry
    """
    if not Metrics.get(value.upper()):
    if metrics is None:
        raise DependencyNotFoundError(
            "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
        )
    if not metrics.get(value.upper()):
        raise ValueError(
            f"Metric name {value} is not a proper metric name from the Registry"
        )
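For context, this validator is typically wired into the pydantic models via the `Annotated`/`BeforeValidator` imports shown above. A simplified sketch, assuming `valid_metric` returns the validated value (its return statement falls outside the hunk shown):

```python
from typing import List

from pydantic import BaseModel, BeforeValidator
from typing_extensions import Annotated

# valid_metric is the function from the hunk above.
ValidMetric = Annotated[str, BeforeValidator(valid_metric)]


class ProfilerDef(BaseModel):  # simplified stand-in for the real model
    metrics: List[ValidMetric]


# Each name passes through valid_metric, which now checks the injected
# registry instead of the hardcoded Metrics enum.
ProfilerDef(metrics=["rowCount"])
```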
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import Tuple, Type

from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.service_spec.service_spec import (
    import_profiler_class,
    import_sampler_class,
)


class ProfilerResolver(ABC):
    """Abstract class for the profiler resolver"""

    @staticmethod
    @abstractmethod
    def resolve(
        processing_engine: str, service_type: ServiceType, source_type: str
    ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
        """Resolve the sampler and profiler based on the processing engine."""
        raise NotImplementedError


class DefaultProfilerResolver(ProfilerResolver):
    """Default profiler resolver"""

    @staticmethod
    def resolve(
        processing_engine: str, service_type: ServiceType, source_type: str
    ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
        """Resolve the sampler and profiler based on the processing engine."""
        sampler_class = import_sampler_class(service_type, source_type=source_type)
        profiler_class = import_profiler_class(service_type, source_type=source_type)
        return sampler_class, profiler_class
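Since the resolver is registered under `Type[ProfilerResolver]` in the package `__init__`, a different processing engine can swap it without touching the source code. A hedged sketch of a custom resolver (the `MyEngineResolver` name and dispatch logic are hypothetical; `override` is the container method shown later in this diff):

```python
from typing import Tuple, Type

from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.source.database.base.profiler_resolver import ProfilerResolver
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer


class MyEngineResolver(ProfilerResolver):  # hypothetical resolver
    @staticmethod
    def resolve(
        processing_engine: str, service_type: ServiceType, source_type: str
    ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
        # Dispatch on processing_engine here; a real implementation would
        # fall back to the default import_* helpers for the "Native" engine.
        raise NotImplementedError


# Replace the default registration for the current process.
DependencyContainer().override(Type[ProfilerResolver], lambda: MyEngineResolver)
```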
@@ -14,7 +14,7 @@ Base source for the profiler used to instantiate a profiler runner with
its interface
"""
from copy import deepcopy
from typing import Optional, cast
from typing import Optional, Type, cast

from metadata.generated.schema.configuration.profilerConfiguration import (
    ProfilerConfiguration,
@@ -22,10 +22,7 @@ from metadata.generated.schema.configuration.profilerConfiguration import (
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import (
    DatabaseConnection,
    DatabaseService,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
    DatabaseServiceProfilerPipeline,
@@ -36,9 +33,10 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.api.models import ProfilerProcessorConfig, TableConfig
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.core import Profiler
from metadata.profiler.processor.default import DefaultProfiler, get_default_metrics
from metadata.profiler.registry import MetricRegistry
from metadata.profiler.source.database.base.profiler_resolver import ProfilerResolver
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
from metadata.sampler.config import (
    get_config_for_table,
@@ -47,12 +45,13 @@ from metadata.sampler.config import (
)
from metadata.sampler.models import SampleConfig
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.dependency_injector.dependency_injector import (
    DependencyNotFoundError,
    Inject,
    inject,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.profiler_utils import get_context_entities
from metadata.utils.service_spec.service_spec import (
    import_profiler_class,
    import_sampler_class,
)

logger = profiler_logger()

@@ -77,8 +76,7 @@ class ProfilerSource(ProfilerSourceInterface):
        self.ometa_client = ometa_client
        self._interface_type: str = config.source.type.lower()
        self._interface = None
        # We define this in create_profiler_interface to help us reuse
        # this method for the sampler, which does not have a DatabaseServiceProfilerPipeline
        self.source_config = None
        self.global_profiler_configuration = global_profiler_configuration

@@ -122,25 +120,34 @@ class ProfilerSource(ProfilerSourceInterface):

        return config_copy

    @inject
    def create_profiler_interface(
        self,
        entity: Table,
        config: Optional[TableConfig],
        profiler_config: Optional[ProfilerProcessorConfig],
        schema_entity: Optional[DatabaseSchema],
        database_entity: Optional[Database],
        db_service: Optional[DatabaseService],
        schema_entity: DatabaseSchema,
        database_entity: Database,
        profiler_resolver: Inject[Type[ProfilerResolver]] = None,
    ) -> ProfilerInterface:
        """Create sqlalchemy profiler interface"""
        """Create the appropriate profiler interface based on processing engine."""
        if profiler_resolver is None:
            raise DependencyNotFoundError(
                "ProfilerResolver dependency not found. Please ensure the ProfilerResolver is properly registered."
            )

        # NOTE: For some reason I do not understand, if we instantiate this on the __init__ method, we break the
        # autoclassification workflow. This should be fixed. There should not be an impact on AutoClassification.
        # We have an issue to track this here: https://github.com/open-metadata/OpenMetadata/issues/21790
        self.source_config = DatabaseServiceProfilerPipeline.model_validate(
            self.config.source.sourceConfig.config
        )
        profiler_class = import_profiler_class(
            ServiceType.Database, source_type=self._interface_type
        )
        sampler_class = import_sampler_class(
            ServiceType.Database, source_type=self._interface_type

        sampler_class, profiler_class = profiler_resolver.resolve(
            processing_engine=self.get_processing_engine(self.source_config),
            service_type=ServiceType.Database,
            source_type=self._interface_type,
        )

        # This is shared between the sampler and profiler interfaces
        sampler_interface: SamplerInterface = sampler_class.create(
            service_connection_config=self.service_conn_config,
@@ -155,6 +162,8 @@ class ProfilerSource(ProfilerSourceInterface):
                samplingMethodType=self.source_config.samplingMethodType,
                randomizedSample=self.source_config.randomizedSample,
            ),
            # TODO: Change this when we have the processing engine configuration implemented. Right now it does nothing.
            processing_engine=self.get_processing_engine(self.source_config),
        )

        profiler_interface: ProfilerInterface = profiler_class.create(
@@ -168,28 +177,33 @@ class ProfilerSource(ProfilerSourceInterface):
        self.interface = profiler_interface
        return self.interface

    @inject
    def get_profiler_runner(
        self, entity: Table, profiler_config: ProfilerProcessorConfig
        self,
        entity: Table,
        profiler_config: ProfilerProcessorConfig,
        metrics_registry: Inject[Type[MetricRegistry]] = None,
    ) -> Profiler:
        """
        Returns the runner for the profiler
        """
        if metrics_registry is None:
            raise DependencyNotFoundError(
                "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
            )

        table_config = get_config_for_table(entity, profiler_config)
        schema_entity, database_entity, db_service = get_context_entities(
            entity=entity, metadata=self.ometa_client
        )
        profiler_interface = self.create_profiler_interface(
            entity,
            table_config,
            profiler_config,
            schema_entity,
            database_entity,
            db_service,
            entity, table_config, schema_entity, database_entity
        )

        if not profiler_config.profiler:
            return DefaultProfiler(
                profiler_interface=profiler_interface,
                metrics_registry=metrics_registry,
                include_columns=get_include_columns(entity, table_config),
                exclude_columns=get_exclude_columns(entity, table_config),
                global_profiler_configuration=self.global_profiler_configuration,
@@ -197,9 +211,10 @@ class ProfilerSource(ProfilerSourceInterface):
            )

        metrics = (
            [Metrics.get(name) for name in profiler_config.profiler.metrics]
            [metrics_registry.get(name) for name in profiler_config.profiler.metrics]
            if profiler_config.profiler.metrics
            else get_default_metrics(
                metrics_registry=metrics_registry,
                table=profiler_interface.table,
                ometa_client=self.ometa_client,
                db_service=db_service,
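Both decorated methods resolve their dependency at call time, so a test can redirect either one for the duration of a run. A sketch using the container's `override`/`remove_override` pair from this diff (the `source` variable is a hypothetical `ProfilerSource` instance):

```python
from typing import Type

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

container = DependencyContainer()
container.override(Type[MetricRegistry], lambda: Metrics)
try:
    # Any get_profiler_runner() call in this window sees the override, e.g.:
    # runner = source.get_profiler_runner(entity, profiler_config)
    pass
finally:
    container.remove_override(Type[MetricRegistry])
```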
@@ -16,6 +16,9 @@ Class defining the interface for the profiler source
from abc import ABC, abstractmethod
from typing import Optional

from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
    DatabaseServiceProfilerPipeline,
)
from metadata.profiler.interface.profiler_interface import ProfilerInterface


@@ -34,20 +37,12 @@ class ProfilerSourceInterface(ABC):
        """Set the interface"""
        raise NotImplementedError

    @abstractmethod
    def create_profiler_interface(
        self,
        entity,
        config,
        profiler_config,
        schema_entity,
        database_entity,
        db_service,
    ) -> ProfilerInterface:
        """Create the profiler interface"""
        raise NotImplementedError

    @abstractmethod
    def get_profiler_runner(self, entity, profiler_config):
        """Get the profiler runner"""
        raise NotImplementedError

    @staticmethod
    def get_processing_engine(config: DatabaseServiceProfilerPipeline):
        """Get the processing engine based on the configuration."""
        return "Native"
@@ -13,12 +13,13 @@ Interface for sampler
"""
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Union
from typing import List, Optional, Set, Union

from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import (
    ColumnProfilerConfig,
    PartitionProfilerConfig,
    Table,
    TableData,
)
@@ -65,20 +66,20 @@ class SamplerInterface(ABC):
        include_columns: Optional[List[ColumnProfilerConfig]] = None,
        exclude_columns: Optional[List[str]] = None,
        sample_config: SampleConfig = SampleConfig(),
        partition_details: Optional[Dict] = None,
        partition_details: Optional[PartitionProfilerConfig] = None,
        sample_query: Optional[str] = None,
        storage_config: DataStorageConfig = None,
        storage_config: Optional[DataStorageConfig] = None,
        sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
        **__,
    ):
        self.ometa_client = ometa_client
        self._sample = None
        self._columns: Optional[List[SQALikeColumn]] = None
        self._columns: List[SQALikeColumn] = []
        self.sample_config = sample_config

        self.entity = entity
        self.include_columns = include_columns
        self.exclude_columns = exclude_columns
        self.include_columns = include_columns or []
        self.exclude_columns = exclude_columns or []
        self.sample_query = sample_query
        self.sample_limit = sample_data_count
        self.partition_details = partition_details
@@ -99,7 +100,7 @@ class SamplerInterface(ABC):
        table_config: Optional[TableConfig] = None,
        storage_config: Optional[DataStorageConfig] = None,
        default_sample_config: Optional[SampleConfig] = None,
        default_sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
        default_sample_data_count: int = SAMPLE_DATA_DEFAULT_COUNT,
        **kwargs,
    ) -> "SamplerInterface":
        """Create sampler"""
@@ -165,16 +166,20 @@ class SamplerInterface(ABC):

        return self._columns

    def _get_excluded_columns(self) -> Optional[Set[str]]:
    def _get_excluded_columns(self) -> Set[str]:
        """Get excluded columns for table being profiled"""
        if self.exclude_columns:
            return set(self.exclude_columns)
        return set()

    def _get_included_columns(self) -> Optional[Set[str]]:
    def _get_included_columns(self) -> Set[str]:
        """Get include columns for table being profiled"""
        if self.include_columns:
            return {include_col.columnName for include_col in self.include_columns}
            return {
                include_col.columnName
                for include_col in self.include_columns
                if include_col.columnName
            }
        return set()

    @property
@@ -151,4 +151,39 @@ The dependency container is thread-safe and uses a reentrant lock (RLock) to sup
2. The system uses type names as keys, so different types with the same name will conflict
3. Circular dependencies are not supported
4. Dependencies are always treated as optional and can be overridden
5. Dependencies can't be passed as positional *args; they must be passed as **kwargs

### Class-Level Dependency Injection

For cases where you want to share dependencies across all instances of a class, you can use the `@inject_class_attributes` decorator:

```python
from typing import ClassVar

@inject_class_attributes
class UserService:
    db: ClassVar[Inject[Database]]
    cache: ClassVar[Inject[Cache]]

    @classmethod
    def get_user(cls, user_id: int) -> dict:
        cache_key = f"user:{user_id}"
        cached = cls.cache.get(cache_key)
        if cached:
            return cached
        return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")

# The dependencies are shared across all instances and accessed via class methods
user = UserService.get_user(user_id=1)
```

The `@inject_class_attributes` decorator will:
1. Look for class attributes annotated with `ClassVar[Inject[Type]]`
2. Automatically inject the dependencies at the class level
3. Make the dependencies available to all instances and class methods
4. Raise `DependencyNotFoundError` if a required dependency is not registered

This is particularly useful for:
- Utility classes that don't need instance-specific state
- Services that should share the same dependencies across all instances
- Performance optimization when the same dependencies are used by multiple instances
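The README snippet assumes `Database` and `Cache` were registered beforehand; a minimal setup sketch using the same example classes:

```python
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

container = DependencyContainer()
container.register(Database, lambda: Database("postgresql://localhost:5432"))
container.register(Cache, lambda: Cache("localhost"))
# Registration must happen before the @inject_class_attributes class is
# defined, because class-level dependencies are resolved at definition time.
```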
@@ -132,6 +132,15 @@ class DependencyContainer:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_key(self, dependency_type: Type[Any]) -> str:
        """
        Get the key for a dependency.
        """
        if get_origin(dependency_type) is type:
            inner_type = get_args(dependency_type)[0]
            return f"Type[{inner_type.__name__}]"
        return dependency_type.__name__

    def register(
        self, dependency_type: Type[Any], dependency: Callable[[], Any]
    ) -> None:
@@ -145,10 +154,11 @@ class DependencyContainer:
        Example:
            ```python
            container.register(Database, lambda: Database("postgresql://localhost:5432"))
            container.register(Type[Metrics], lambda: Metrics)  # For registering types themselves
            ```
        """
        with self._lock:
            self._dependencies[dependency_type.__name__] = dependency
            self._dependencies[self.get_key(dependency_type)] = dependency

    def override(
        self, dependency_type: Type[Any], dependency: Callable[[], Any]
@@ -169,7 +179,7 @@ class DependencyContainer:
            ```
        """
        with self._lock:
            self._overrides[dependency_type.__name__] = dependency
            self._overrides[self.get_key(dependency_type)] = dependency

    def remove_override(self, dependency_type: Type[T]) -> None:
        """
@@ -184,7 +194,7 @@ class DependencyContainer:
            ```
        """
        with self._lock:
            self._overrides.pop(dependency_type.__name__, None)
            self._overrides.pop(self.get_key(dependency_type), None)

    def get(self, dependency_type: Type[Any]) -> Optional[Any]:
        """
@@ -206,10 +216,9 @@ class DependencyContainer:
            ```
        """
        with self._lock:
            type_name = dependency_type.__name__
            factory = self._overrides.get(type_name) or self._dependencies.get(
                type_name
            )
            factory = self._overrides.get(
                self.get_key(dependency_type)
            ) or self._dependencies.get(self.get_key(dependency_type))
            if factory is None:
                return None
            return factory()
@@ -243,8 +252,10 @@ class DependencyContainer:
            print("Database dependency is registered")
            ```"""
        with self._lock:
            type_name = dependency_type.__name__
            return type_name in self._overrides or type_name in self._dependencies
            return (
                self.get_key(dependency_type) in self._overrides
                or self.get_key(dependency_type) in self._dependencies
            )


def inject(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -341,3 +352,52 @@ def extract_inject_arg(tp: Any) -> Any:
        f"Type {tp} is not Inject or Optional[Inject]. "
        f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
    )


def inject_class_attributes(cls: Type[Any]) -> Type[Any]:
    """
    Decorator to inject dependencies into class-level (static) attributes based on type hints.

    This decorator automatically injects dependencies into class attributes
    based on their type hints. The dependencies are shared across all instances
    of the class.

    Args:
        cls: The class to inject dependencies into

    Returns:
        A class with dependencies injected into its class-level attributes

    Example:
        ```python
        @inject_class_attributes
        class UserService:
            db: ClassVar[Inject[Database]]
            cache: ClassVar[Inject[Cache]]

            @classmethod
            def get_user(cls, user_id: int) -> dict:
                return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
        ```
    """
    container = DependencyContainer()
    type_hints = get_type_hints(cls, include_extras=True)

    # Inject dependencies into class attributes
    for attr_name, attr_type in type_hints.items():
        # Skip if attribute is already set
        if hasattr(cls, attr_name):
            continue

        # Check if it's an Inject type
        if is_inject_type(attr_type):
            dependency_type = extract_inject_arg(attr_type)
            dependency = container.get(dependency_type)
            if dependency is None:
                raise DependencyNotFoundError(
                    f"Dependency of type {dependency_type} not found in container. "
                    f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
                )
            setattr(cls, attr_name, dependency)

    return cls
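A quick illustration of the new keying scheme, with behavior read directly off `get_key` above:

```python
from typing import Type

from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

container = DependencyContainer()

# Plain classes keep their name as the key...
assert container.get_key(MetricRegistry) == "MetricRegistry"
# ...while Type[...] generics are keyed as "Type[<inner name>]", so a class
# and "the class itself as a value" never collide in the container.
assert container.get_key(Type[MetricRegistry]) == "Type[MetricRegistry]"
```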
@@ -1,3 +1,5 @@
from typing import Optional

import pytest

from metadata.utils.dependency_injector.dependency_injector import (
@@ -5,6 +7,7 @@ from metadata.utils.dependency_injector.dependency_injector import (
    DependencyNotFoundError,
    Inject,
    inject,
    inject_class_attributes,
)


@@ -21,8 +24,10 @@ class Cache:
    def __init__(self, host: str):
        self.host = host

    def get(self, key: str) -> str:
        return f"Cache hit for {key}"
    def get(self, key: str) -> Optional[str]:
        if key == "user:1":
            return "Cache hit for user:1"
        return None


# Test functions for injection
@@ -139,3 +144,99 @@ class TestInjectDecorator:
        custom_db = Database("postgresql://custom:5432")
        result = get_user(user_id=1, db=custom_db)
        assert result == "Executed: SELECT * FROM users WHERE id = 1"


class TestInjectClassAttributes:
    def test_inject_class_attributes_single_dependency(self):
        container = DependencyContainer()
        db_factory = lambda: Database("postgresql://localhost:5432")
        container.register(Database, db_factory)

        @inject_class_attributes
        class UserService:
            db: Inject[Database]

            @classmethod
            def get_user(cls, user_id: int) -> str:
                if cls.db is None:
                    raise DependencyNotFoundError("Database dependency not found")
                return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")

        result = UserService.get_user(user_id=1)
        assert result == "Executed: SELECT * FROM users WHERE id = 1"

    def test_inject_class_attributes_multiple_dependencies(self):
        container = DependencyContainer()
        db_factory = lambda: Database("postgresql://localhost:5432")
        cache_factory = lambda: Cache("localhost")

        container.register(Database, db_factory)
        container.register(Cache, cache_factory)

        @inject_class_attributes
        class UserService:
            db: Inject[Database]
            cache: Inject[Cache]

            @classmethod
            def get_user(cls, user_id: int) -> str:
                if cls.db is None:
                    raise DependencyNotFoundError("Database dependency not found")
                if cls.cache is None:
                    raise DependencyNotFoundError("Cache dependency not found")
                cache_key = f"user:{user_id}"
                cached = cls.cache.get(cache_key)
                if cached:
                    return cached
                return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")

        result = UserService.get_user(user_id=1)
        assert result == "Cache hit for user:1"

    def test_inject_class_attributes_missing_dependency(self):
        container = DependencyContainer()
        container.clear()  # Ensure no dependencies are registered

        with pytest.raises(DependencyNotFoundError):

            @inject_class_attributes
            class UserService:
                db: Inject[Database]

                @classmethod
                def get_user(cls, user_id: int) -> str:
                    if cls.db is None:
                        raise DependencyNotFoundError("Database dependency not found")
                    return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")

    def test_inject_class_attributes_shared_dependencies(self):
        container = DependencyContainer()
        db_factory = lambda: Database("postgresql://localhost:5432")
        container.register(Database, db_factory)

        @inject_class_attributes
        class UserService:
            db: Inject[Database]
            counter: int = 0

            @classmethod
            def increment_counter(cls) -> int:
                cls.counter += 1
                return cls.counter

            @classmethod
            def get_user(cls, user_id: int) -> str:
                if cls.db is None:
                    raise DependencyNotFoundError("Database dependency not found")
                return f"Counter: {cls.counter}, {cls.db.query(f'SELECT * FROM users WHERE id = {user_id}')}"

        # First call
        result1 = UserService.get_user(user_id=1)
        assert result1 == "Counter: 0, Executed: SELECT * FROM users WHERE id = 1"

        # Increment counter
        UserService.increment_counter()

        # Second call - counter should be shared
        result2 = UserService.get_user(user_id=1)
        assert result2 == "Counter: 1, Executed: SELECT * FROM users WHERE id = 1"
@@ -201,6 +201,7 @@ class ProfilerTest(TestCase):
        """
        simple = DefaultProfiler(
            profiler_interface=self.datalake_profiler_interface,
            metrics_registry=Metrics,
        )
        simple.compute_metrics()

@@ -337,6 +338,7 @@ class ProfilerTest(TestCase):

        default_profiler = DefaultProfiler(
            profiler_interface=self.datalake_profiler_interface,
            metrics_registry=Metrics,
        )
        column_metrics = default_profiler._prepare_column_metrics()
        for metric in column_metrics:
@@ -50,6 +50,7 @@ from metadata.profiler.metrics.core import (
    QueryMetric,
    StaticMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.processor.default import get_default_metrics
from metadata.sampler.pandas.sampler import DatalakeSampler
@@ -197,7 +198,7 @@ class PandasInterfaceTest(TestCase):
        """

        cls.table = User
        cls.metrics = get_default_metrics(cls.table)
        cls.metrics = get_default_metrics(Metrics, cls.table)
        cls.static_metrics = [
            metric for metric in cls.metrics if issubclass(metric, StaticMetric)
        ]
@@ -139,7 +139,7 @@ class ProfilerTest(TestCase):
        Check our pre-cooked profiler
        """
        simple = DefaultProfiler(
            profiler_interface=self.sqa_profiler_interface,
            profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
        )
        simple.compute_metrics()

@@ -297,7 +297,7 @@ class ProfilerTest(TestCase):
        )

        simple = DefaultProfiler(
            profiler_interface=sqa_profiler_interface,
            profiler_interface=sqa_profiler_interface, metrics_registry=Metrics
        )

        with pytest.raises(TimeoutError):
@@ -316,7 +316,7 @@ class ProfilerTest(TestCase):
        )  # type: ignore

        default_profiler = DefaultProfiler(
            profiler_interface=self.sqa_profiler_interface,
            profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
        )

        column_metrics = default_profiler._prepare_column_metrics()
@@ -15,10 +15,10 @@ Test SQA Interface

import os
from datetime import datetime
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlalchemy import TEXT, Column, Integer, String, inspect
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm.session import Session
@@ -49,6 +49,7 @@ from metadata.profiler.metrics.core import (
    QueryMetric,
    StaticMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.processor.default import get_default_metrics
from metadata.sampler.sqlalchemy.sampler import SQASampler
@@ -64,46 +65,9 @@ class User(declarative_base()):
    age = Column(Integer)


class SQAInterfaceTest(TestCase):
    def setUp(self) -> None:
        table_entity = Table(
            id=uuid4(),
            name="user",
            columns=[
                EntityColumn(
                    name=ColumnName("id"),
                    dataType=DataType.INT,
                )
            ],
        )
        sqlite_conn = SQLiteConnection(
            scheme=SQLiteScheme.sqlite_pysqlite,
        )

        with patch.object(SQASampler, "build_table_orm", return_value=User):
            sampler = SQASampler(
                service_connection_config=sqlite_conn,
                ometa_client=None,
                entity=None,
            )

        with patch.object(SQASampler, "build_table_orm", return_value=User):
            self.sqa_profiler_interface = SQAProfilerInterface(
                sqlite_conn, None, table_entity, None, sampler, 5, 43200
            )
        self.table = User

    def test_init_interface(self):
        """Test we can instantiate our interface object correctly"""

        assert isinstance(self.sqa_profiler_interface.session, Session)

    def tearDown(self) -> None:
        self.sqa_profiler_interface._sampler = None


class SQAInterfaceTestMultiThread(TestCase):
    table_entity = Table(
@pytest.fixture
def table_entity():
    return Table(
        id=uuid4(),
        name="user",
        columns=[
@@ -113,12 +77,17 @@ class SQAInterfaceTestMultiThread(TestCase):
            )
        ],
    )
    db_path = os.path.join(os.path.dirname(__file__), "test.db")
    sqlite_conn = SQLiteConnection(


@pytest.fixture
def sqlite_conn():
    return SQLiteConnection(
        scheme=SQLiteScheme.sqlite_pysqlite,
        databaseMode=db_path + "?check_same_thread=False",
    )


@pytest.fixture
def sqa_profiler_interface(table_entity, sqlite_conn):
    with patch.object(SQASampler, "build_table_orm", return_value=User):
        sampler = SQASampler(
            service_connection_config=sqlite_conn,
@@ -126,140 +95,205 @@ class SQAInterfaceTestMultiThread(TestCase):
            entity=None,
        )

    sqa_profiler_interface = SQAProfilerInterface(
        sqlite_conn,
    with patch.object(SQASampler, "build_table_orm", return_value=User):
        interface = SQAProfilerInterface(
            sqlite_conn, None, table_entity, None, sampler, 5, 43200
        )
    return interface


def test_init_interface(sqa_profiler_interface):
    """Test we can instantiate our interface object correctly"""
    assert isinstance(sqa_profiler_interface.session, Session)


@pytest.fixture(scope="class")
def db_path():
    return os.path.join(os.path.dirname(__file__), "test.db")


@pytest.fixture(scope="class")
def class_sqlite_conn(db_path):
    return SQLiteConnection(
        scheme=SQLiteScheme.sqlite_pysqlite,
        databaseMode=db_path + "?check_same_thread=False",
    )


@pytest.fixture(scope="class")
def class_table_entity():
    return Table(
        id=uuid4(),
        name="user",
        columns=[
            EntityColumn(
                name=ColumnName("id"),
                dataType=DataType.INT,
            )
        ],
    )


@pytest.fixture(scope="class")
def class_sqa_profiler_interface(class_sqlite_conn, class_table_entity):
    with patch.object(SQASampler, "build_table_orm", return_value=User):
        sampler = SQASampler(
            service_connection_config=class_sqlite_conn,
            ometa_client=None,
            entity=None,
        )

    interface = SQAProfilerInterface(
        class_sqlite_conn,
        None,
        table_entity,
        class_table_entity,
        None,
        sampler,
        5,
        43200,
    )
    return interface

    @classmethod
    def setUpClass(cls) -> None:
        """
        Prepare Ingredients
        """
        User.__table__.create(bind=cls.sqa_profiler_interface.session.get_bind())

        data = [
            User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
            User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
        ]
        cls.sqa_profiler_interface.session.add_all(data)
        cls.sqa_profiler_interface.session.commit()
        cls.table = User
        cls.metrics = get_default_metrics(cls.table)
        cls.static_metrics = [
            metric for metric in cls.metrics if issubclass(metric, StaticMetric)
        ]
        cls.composed_metrics = [
            metric for metric in cls.metrics if issubclass(metric, ComposedMetric)
        ]
        cls.window_metrics = [
@pytest.fixture(scope="class", autouse=True)
def setup_database(class_sqa_profiler_interface):
    """Setup test database and tables"""
    try:
        # Drop the table if it exists
        User.__table__.drop(
            bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
        )
        # Create the table
        User.__table__.create(bind=class_sqa_profiler_interface.session.get_bind())
    except Exception as e:
        print(f"Error during table setup: {str(e)}")
        raise e

    data = [
        User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
        User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
    ]
    class_sqa_profiler_interface.session.add_all(data)
    class_sqa_profiler_interface.session.commit()

    yield

    # Cleanup
    try:
        User.__table__.drop(
            bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
        )
        class_sqa_profiler_interface.session.close()
    except Exception as e:
        print(f"Error during cleanup: {str(e)}")
        raise e


@pytest.fixture(scope="class")
def metrics(class_sqa_profiler_interface):
    metrics = get_default_metrics(Metrics, User)
    return {
        "all": metrics,
        "static": [metric for metric in metrics if issubclass(metric, StaticMetric)],
        "composed": [
            metric for metric in metrics if issubclass(metric, ComposedMetric)
        ],
        "window": [
            metric
            for metric in cls.metrics
            for metric in metrics
            if issubclass(metric, StaticMetric) and metric.is_window_metric()
        ]
    cls.query_metrics = [
        ],
        "query": [
            metric
            for metric in cls.metrics
            for metric in metrics
            if issubclass(metric, QueryMetric) and metric.is_col_metric()
        ]
        ],
    }

    def test_init_interface(self):
        """Test we can instantiate our interface object correctly"""

        assert isinstance(self.sqa_profiler_interface.session, Session)
def test_init_interface_multi_thread(class_sqa_profiler_interface):
    """Test we can instantiate our interface object correctly"""
    assert isinstance(class_sqa_profiler_interface.session, Session)

    def test_get_all_metrics(self):
        table_metrics = [

def test_get_all_metrics(class_sqa_profiler_interface, metrics):
    table_metrics = [
        ThreadPoolMetrics(
            metrics=[
                metric
                for metric in metrics["all"]
                if (not metric.is_col_metric() and not metric.is_system_metrics())
            ],
            metric_type=MetricTypes.Table,
            column=None,
            table=User,
        )
    ]
    column_metrics = []
    query_metrics = []
    window_metrics = []
    for col in inspect(User).c:
        column_metrics.append(
            ThreadPoolMetrics(
                metrics=[
                    metric
                    for metric in self.metrics
                    if (not metric.is_col_metric() and not metric.is_system_metrics())
                    for metric in metrics["static"]
                    if metric.is_col_metric() and not metric.is_window_metric()
                ],
                metric_type=MetricTypes.Table,
                column=None,
                table=self.table,
                metric_type=MetricTypes.Static,
                column=col,
                table=User,
            )
        ]
        column_metrics = []
        query_metrics = []
        window_metrics = []
        for col in inspect(User).c:
            column_metrics.append(
        )
        for query_metric in metrics["query"]:
            query_metrics.append(
                ThreadPoolMetrics(
                    metrics=[
                        metric
                        for metric in self.static_metrics
                        if metric.is_col_metric() and not metric.is_window_metric()
                    ],
                    metric_type=MetricTypes.Static,
                    metrics=query_metric,
                    metric_type=MetricTypes.Query,
                    column=col,
                    table=self.table,
                    table=User,
                )
            )
        for query_metric in self.query_metrics:
            query_metrics.append(
                ThreadPoolMetrics(
                    metrics=query_metric,
                    metric_type=MetricTypes.Query,
                    column=col,
                    table=self.table,
                )
            )
            window_metrics.append(
                ThreadPoolMetrics(
                    metrics=[
                        metric
                        for metric in self.window_metrics
                        if metric.is_window_metric()
                    ],
                    metric_type=MetricTypes.Window,
                    column=col,
                    table=self.table,
                )
        window_metrics.append(
            ThreadPoolMetrics(
                metrics=[
                    metric for metric in metrics["window"] if metric.is_window_metric()
                ],
                metric_type=MetricTypes.Window,
                column=col,
                table=User,
            )

        all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]

        profile_results = self.sqa_profiler_interface.get_all_metrics(
            all_metrics,
        )

        column_profile = [
            ColumnProfile(**profile_results["columns"].get(col.name))
            for col in inspect(User).c
            if profile_results["columns"].get(col.name)
        ]
    all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]

        table_profile = TableProfile(
            columnCount=profile_results["table"].get("columnCount"),
            rowCount=profile_results["table"].get(RowCount.name()),
            timestamp=Timestamp(int(datetime.now().timestamp())),
        )
    profile_results = class_sqa_profiler_interface.get_all_metrics(
        all_metrics,
    )

        profile_request = CreateTableProfileRequest(
            tableProfile=table_profile, columnProfile=column_profile
        )
    column_profile = [
        ColumnProfile(**profile_results["columns"].get(col.name))
        for col in inspect(User).c
        if profile_results["columns"].get(col.name)
    ]

        assert profile_request.tableProfile.columnCount == 6
        assert profile_request.tableProfile.rowCount == 2
        name_column_profile = [
            profile
            for profile in profile_request.columnProfile
            if profile.name == "name"
        ][0]
        id_column_profile = [
            profile for profile in profile_request.columnProfile if profile.name == "id"
        ][0]
        assert name_column_profile.nullCount == 0
        assert id_column_profile.median == 1.0
    table_profile = TableProfile(
        columnCount=profile_results["table"].get("columnCount"),
        rowCount=profile_results["table"].get(RowCount.name()),
        timestamp=Timestamp(int(datetime.now().timestamp())),
    )

    @classmethod
    def tearDownClass(cls) -> None:
        os.remove(cls.db_path)
        return super().tearDownClass()
    profile_request = CreateTableProfileRequest(
        tableProfile=table_profile, columnProfile=column_profile
    )

    assert profile_request.tableProfile.columnCount == 6
    assert profile_request.tableProfile.rowCount == 2
    name_column_profile = [
        profile for profile in profile_request.columnProfile if profile.name == "name"
    ][0]
    id_column_profile = [
        profile for profile in profile_request.columnProfile if profile.name == "id"
    ][0]
    assert name_column_profile.nullCount == 0
    assert id_column_profile.median == 1.0