MINOR: Add injection to profiler (#21738)

* Initial implementation for our Connection Class

* Implement the Initial Connection class

* Add Unit Tests

* Implement Dependency Injection for the Ingestion Framework

* Fix Test

* Fix Profile Test Connection

* Add Injection to Metrics in Profiler

* Add Injection to the Profiler

* Fix UnitTests

* Fix Pytests

* Fix Tests

* Fix types
IceS2 authored on 2025-06-17 19:01:00 +02:00, committed by GitHub
parent bd948de115
commit e79c54e6a5
20 changed files with 644 additions and 279 deletions


@@ -11,7 +11,14 @@
"""
OpenMetadata package initialization.
"""
from typing import Type
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.profiler.source.database.base.profiler_resolver import (
DefaultProfilerResolver,
ProfilerResolver,
)
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader
@@ -20,3 +27,5 @@ container = DependencyContainer()
# Register the source loader
container.register(SourceLoader, DefaultSourceLoader)
container.register(Type[MetricRegistry], lambda: Metrics)
container.register(Type[ProfilerResolver], lambda: DefaultProfilerResolver)
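
These registrations key dependencies by type, including the parameterized `Type[...]` keys this PR teaches the container to understand. A minimal sketch of how a consumer resolves them (the lookup site is illustrative; the container API is the one extended later in this diff):

```python
from typing import Type

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

# DependencyContainer is a singleton, so this sees the registrations above.
container = DependencyContainer()

# Resolving returns the factory's result: the Metrics registry class itself,
# not an instance of it.
metrics = container.get(Type[MetricRegistry])
assert metrics is Metrics
```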


@@ -34,7 +34,7 @@ class TableRowInsertedCountToBeBetweenValidator(
"""Validator for table row inserted count to be between test case"""
@staticmethod
def _get_threshold_date(range_type: str, range_interval: int):
def get_threshold_date_dt(range_type: str, range_interval: int) -> datetime:
"""returns the threshold datetime in utc to count the numbers of rows inserted
Args:
@@ -55,7 +55,22 @@
threshold_date = threshold_date.replace(
hour=0, minute=0, second=0, microsecond=0
)
return threshold_date.strftime("%Y%m%d%H%M%S")
return threshold_date
@staticmethod
def _get_threshold_date(
range_type: str, range_interval: int, date_format: str = "%Y%m%d%H%M%S"
):
"""returns the threshold datetime in utc as string to count the numbers of rows inserted
Args:
range_type (str): type of range (i.e. HOUR, DAY, MONTH, YEAR)
range_interval (int): interval of range (i.e. 1, 2, 3, 4)
date_format (str): format of the date (i.e. %Y%m%d%H%M%S, %Y-%m-%d %H:%M:%S)
"""
return TableRowInsertedCountToBeBetweenValidator.get_threshold_date_dt(
range_type, range_interval
).strftime(date_format)
def _get_column_name(self):
"""returns the column name to be validated"""


@@ -113,6 +113,9 @@ def get_test_connection_fn(connection: BaseModel) -> Callable:
)
# ------------------------------------------------------------
def get_connection(connection: BaseModel) -> Any:
"""
Main method to prepare a connection from


@@ -13,7 +13,7 @@ System table profiler
"""
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Type, Union
from more_itertools import partition
from pydantic import field_validator
@@ -25,7 +25,11 @@ from metadata.profiler.interface.sqlalchemy.stored_statistics_profiler import (
StoredStatisticsSource,
)
from metadata.profiler.metrics.core import Metric
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
Inject,
inject_class_attributes,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
from metadata.utils.ssl_manager import get_ssl_connection
@@ -61,23 +65,28 @@ class TableStats(BaseModel):
columns: Dict[str, ColumnStats] = {}
@inject_class_attributes
class TrinoStoredStatisticsSource(StoredStatisticsSource):
"""Trino system profile source"""
metric_stats_map: Dict[Metrics, str] = {
Metrics.NULL_RATIO: "nulls_fractions",
Metrics.DISTINCT_COUNT: "distinct_values_count",
Metrics.ROW_COUNT: "row_count",
Metrics.MAX: "high_value",
Metrics.MIN: "low_value",
}
metrics: Inject[Type[MetricRegistry]]
metric_stats_by_name: Dict[str, str] = {
k.name: v for k, v in metric_stats_map.items()
}
@classmethod
def get_metric_stats_map(cls) -> Dict[MetricRegistry, str]:
return {
cls.metrics.NULL_RATIO: "nulls_fractions",
cls.metrics.DISTINCT_COUNT: "distinct_values_count",
cls.metrics.ROW_COUNT: "row_count",
cls.metrics.MAX: "high_value",
cls.metrics.MIN: "low_value",
}
def get_statistics_metrics(self) -> Set[Metrics]:
return set(self.metric_stats_map.keys())
@classmethod
def get_metric_stats_by_name(cls) -> Dict[str, str]:
return {k.name: v for k, v in cls.get_metric_stats_map().items()}
def get_statistics_metrics(self) -> Set[MetricRegistry]:
return set(self.get_metric_stats_map().keys())
def __init__(self, **kwargs):
super().__init__(**kwargs)
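
The move from class-level dicts to classmethods is not cosmetic: `@inject_class_attributes` assigns `metrics` only after the class body has executed, so a class-body expression referencing it would fail at class-creation time. A stripped-down sketch of the constraint (class name hypothetical):

```python
from typing import Dict, Type

from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
    Inject,
    inject_class_attributes,
)


@inject_class_attributes
class StatsSource:
    metrics: Inject[Type[MetricRegistry]]

    # stats_map = {metrics.ROW_COUNT: "row_count"}  # NameError: not bound yet

    @classmethod
    def stats_map(cls) -> Dict[MetricRegistry, str]:
        # Safe: by call time the decorator has populated cls.metrics.
        return {cls.metrics.ROW_COUNT: "row_count"}
```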
@@ -96,7 +105,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
f"Column {column} not found in table {table_name}. Statistics might be stale or missing."
)
result = {
m.name(): getattr(column_stats, self.metric_stats_by_name[m.name()])
m.name(): getattr(column_stats, self.get_metric_stats_by_name()[m.name()])
for m in metric
}
result.update(self.get_hybrid_statistics(table_stats, column_stats))
@@ -108,7 +117,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
) -> dict:
table_stats = self._get_cached_stats(schema, table_name)
return {
m.name(): getattr(table_stats, self.metric_stats_by_name[m.name()])
m.name(): getattr(table_stats, self.get_metric_stats_by_name()[m.name()])
for m in metric
}
@@ -159,7 +168,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource):
) -> Dict[str, Any]:
return {
# trino stats are in fractions, so we need to convert them to counts (unlike our default profiler)
Metrics.NULL_COUNT.name: (
self.metrics.NULL_COUNT.name: (
int(table_stats.row_count * column_stats.nulls_fraction)
if None not in [table_stats.row_count, column_stats.nulls_fraction]
else None


@@ -16,7 +16,7 @@ Run profiler metrics on the table
import traceback
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Type
from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
@@ -27,13 +27,29 @@ from metadata.generated.schema.entity.data.table import TableType
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
DependencyNotFoundError,
Inject,
inject,
)
from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger()
@inject
def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None):
if metrics is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
return metrics.ROW_COUNT().name()
COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames"
ROW_COUNT = Metrics.ROW_COUNT().name()
ROW_COUNT = get_row_count_metric()
SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime"
@@ -362,9 +378,15 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer):
class MySQLTableMetricComputer(BaseTableMetricComputer):
"""MySQL Table Metric Computer"""
def compute(self):
@inject
def compute(self, metrics: Inject[Type[MetricRegistry]] = None):
"""compute table metrics for mysql"""
if metrics is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
columns = [
Column("TABLE_ROWS").label(ROW_COUNT),
(Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
@@ -390,7 +412,7 @@ class MySQLTableMetricComputer(BaseTableMetricComputer):
res = res._asdict()
# innodb row count is an estimate we need to patch the row count with COUNT(*)
# https://dev.mysql.com/doc/refman/8.3/en/information-schema-innodb-tablestats-table.html
row_count = self.runner.select_first_from_table(Metrics.ROW_COUNT().fn())
row_count = self.runner.select_first_from_table(metrics.ROW_COUNT().fn())
res.update({ROW_COUNT: row_count.rowCount})
return res
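
Note that `ROW_COUNT` is now resolved at import time through `get_row_count_metric()`, so `Type[MetricRegistry]` must be registered before this module is imported. For tests, the registry can be swapped via an override; a hedged sketch:

```python
from typing import Type

from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer

container = DependencyContainer()
# Any registry-like object exposing ROW_COUNT would do; we reuse Metrics here.
container.override(Type[MetricRegistry], lambda: Metrics)
try:
    ...  # exercise code that calls get_row_count_metric() or compute()
finally:
    container.remove_override(Type[MetricRegistry])
```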


@@ -16,7 +16,7 @@ from __future__ import annotations
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, cast
from pydantic import ValidationError
from sqlalchemy import Column
@@ -94,7 +94,7 @@ class Profiler(Generic[TMetric]):
:param profile_sample: % of rows to use for sampling column metrics
"""
self.global_profiler_configuration: Optional[ProfilerConfiguration] = (
global_profiler_configuration.config_value
cast(ProfilerConfiguration, global_profiler_configuration.config_value)
if global_profiler_configuration
else None
)


@@ -12,55 +12,54 @@
"""
Default simple profiler to use
"""
from typing import List, Optional
from typing import List, Optional, Type
from sqlalchemy.orm import DeclarativeMeta
from metadata.generated.schema.configuration.profilerConfiguration import (
ProfilerConfiguration,
)
from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.settings.settings import Settings
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import Metric, add_props
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.core import Profiler
from metadata.profiler.registry import MetricRegistry
def get_default_metrics(
metrics_registry: Type[MetricRegistry],
table: DeclarativeMeta,
ometa_client: Optional[OpenMetadata] = None,
db_service: Optional[DatabaseService] = None,
) -> List[Metric]:
return [
# Table Metrics
Metrics.ROW_COUNT.value,
add_props(table=table)(Metrics.COLUMN_COUNT.value),
add_props(table=table)(Metrics.COLUMN_NAMES.value),
metrics_registry.ROW_COUNT.value,
add_props(table=table)(metrics_registry.COLUMN_COUNT.value),
add_props(table=table)(metrics_registry.COLUMN_NAMES.value),
# We'll use the ometa_client & db_service in case we need to fetch info to ES
add_props(table=table, ometa_client=ometa_client, db_service=db_service)(
Metrics.SYSTEM.value
metrics_registry.SYSTEM.value
),
# Column Metrics
Metrics.MEDIAN.value,
Metrics.FIRST_QUARTILE.value,
Metrics.THIRD_QUARTILE.value,
Metrics.MEAN.value,
Metrics.COUNT.value,
Metrics.DISTINCT_COUNT.value,
Metrics.DISTINCT_RATIO.value,
Metrics.MIN.value,
Metrics.MAX.value,
Metrics.NULL_COUNT.value,
Metrics.NULL_RATIO.value,
Metrics.STDDEV.value,
Metrics.SUM.value,
Metrics.UNIQUE_COUNT.value,
Metrics.UNIQUE_RATIO.value,
Metrics.IQR.value,
Metrics.HISTOGRAM.value,
Metrics.NON_PARAMETRIC_SKEW.value,
metrics_registry.MEDIAN.value,
metrics_registry.FIRST_QUARTILE.value,
metrics_registry.THIRD_QUARTILE.value,
metrics_registry.MEAN.value,
metrics_registry.COUNT.value,
metrics_registry.DISTINCT_COUNT.value,
metrics_registry.DISTINCT_RATIO.value,
metrics_registry.MIN.value,
metrics_registry.MAX.value,
metrics_registry.NULL_COUNT.value,
metrics_registry.NULL_RATIO.value,
metrics_registry.STDDEV.value,
metrics_registry.SUM.value,
metrics_registry.UNIQUE_COUNT.value,
metrics_registry.UNIQUE_RATIO.value,
metrics_registry.IQR.value,
metrics_registry.HISTOGRAM.value,
metrics_registry.NON_PARAMETRIC_SKEW.value,
]
@@ -74,12 +73,14 @@ class DefaultProfiler(Profiler):
def __init__(
self,
profiler_interface: ProfilerInterface,
metrics_registry: Type[MetricRegistry],
include_columns: Optional[List[ColumnProfilerConfig]] = None,
exclude_columns: Optional[List[str]] = None,
global_profiler_configuration: Optional[ProfilerConfiguration] = None,
global_profiler_configuration: Optional[Settings] = None,
db_service=None,
):
_metrics = get_default_metrics(
metrics_registry=metrics_registry,
table=profiler_interface.table,
ometa_client=profiler_interface.ometa_client,
db_service=db_service,
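
Call sites must now hand the registry to the profiler explicitly, as the updated unit tests further down in this diff do. A sketch, assuming a `ProfilerInterface` instance built elsewhere:

```python
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.default import DefaultProfiler

profiler = DefaultProfiler(
    profiler_interface=profiler_interface,  # assumed to exist
    metrics_registry=Metrics,
)
profiler.compute_metrics()
```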


@@ -33,22 +33,35 @@ from metadata.profiler.metrics.core import (
SystemMetric,
TMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.converter.converter_registry import converter_registry
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
DependencyNotFoundError,
Inject,
inject,
)
from metadata.utils.sqa_like_column import SQALikeColumn
class MetricFilter:
"""Metric filter class for profiler"""
@inject
def __init__(
self,
metrics: Tuple[Type[TMetric]],
global_profiler_config: Optional[ProfilerConfiguration] = None,
table_profiler_config: Optional[TableProfilerConfig] = None,
column_profiler_config: Optional[List[ColumnProfilerConfig]] = None,
metrics_registry: Inject[Type[MetricRegistry]] = None,
):
if metrics_registry is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
self.metrics = metrics
self.metrics_registry = metrics_registry
self.global_profiler_config = global_profiler_config
self.table_profiler_config = table_profiler_config
self.column_profiler_config = column_profiler_config
@@ -196,7 +209,7 @@ class MetricFilter:
metrics = [
Metric.value
for Metric in Metrics
for Metric in self.metrics_registry
if Metric.value.name() in {mtrc.value for mtrc in col_dtype_config.metrics}
and Metric.value in metrics
]
@@ -240,7 +253,7 @@ class MetricFilter:
metrics = [
Metric.value
for Metric in Metrics
for Metric in self.metrics_registry
if Metric.value.name().lower() in {mtrc.lower() for mtrc in metric_names}
and Metric.value in metrics
]


@@ -13,20 +13,30 @@
Models to map profiler definitions
JSON workflows to the profiler
"""
from typing import List, Optional
from typing import List, Optional, Type
from pydantic import BaseModel, BeforeValidator
from typing_extensions import Annotated
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.registry import MetricRegistry
from metadata.utils.dependency_injector.dependency_injector import (
DependencyNotFoundError,
Inject,
inject,
)
def valid_metric(value: str):
@inject
def valid_metric(value: str, metrics: Inject[Type[MetricRegistry]] = None):
"""
Validate that the input metrics are correctly named
and can be found in the Registry
"""
if not Metrics.get(value.upper()):
if metrics is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
if not metrics.get(value.upper()):
raise ValueError(
f"Metric name {value} is not a proper metric name from the Registry"
)
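
Behaviorally the validator is unchanged as long as a registry is registered (the package `__init__` in this PR registers `Type[MetricRegistry]`); a sketch of both paths:

```python
# Accepted: the lookup is case-insensitive via value.upper(), and the
# injected registry knows ROW_COUNT.
valid_metric("row_count")

# Rejected: unknown names still raise, now against the injected registry.
try:
    valid_metric("not_a_metric")
except ValueError as exc:
    print(exc)
```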


@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import Tuple, Type
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.service_spec.service_spec import (
import_profiler_class,
import_sampler_class,
)
class ProfilerResolver(ABC):
"""Abstract class for the profiler resolver"""
@staticmethod
@abstractmethod
def resolve(
processing_engine: str, service_type: ServiceType, source_type: str
) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
"""Resolve the sampler and profiler based on the processing engine."""
raise NotImplementedError
class DefaultProfilerResolver(ProfilerResolver):
"""Default profiler resolver"""
@staticmethod
def resolve(
processing_engine: str, service_type: ServiceType, source_type: str
) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
"""Resolve the sampler and profiler based on the processing engine."""
sampler_class = import_sampler_class(service_type, source_type=source_type)
profiler_class = import_profiler_class(service_type, source_type=source_type)
return sampler_class, profiler_class
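
Swapping the resolver is now a single registration against the container. A hypothetical wrapper around the default, e.g. to log which engine is requested:

```python
from typing import Tuple, Type

from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.source.database.base.profiler_resolver import (
    DefaultProfilerResolver,
    ProfilerResolver,
)
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer


class LoggingProfilerResolver(ProfilerResolver):
    """Hypothetical resolver that wraps the default one."""

    @staticmethod
    def resolve(
        processing_engine: str, service_type: ServiceType, source_type: str
    ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]:
        # A real implementation could branch on processing_engine here.
        print(f"Resolving profiler for engine={processing_engine}")
        return DefaultProfilerResolver.resolve(
            processing_engine, service_type, source_type
        )


DependencyContainer().register(Type[ProfilerResolver], lambda: LoggingProfilerResolver)
```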


@@ -14,7 +14,7 @@ Base source for the profiler used to instantiate a profiler runner with
its interface
"""
from copy import deepcopy
from typing import Optional, cast
from typing import Optional, Type, cast
from metadata.generated.schema.configuration.profilerConfiguration import (
ProfilerConfiguration,
@@ -22,10 +22,7 @@ from metadata.generated.schema.configuration.profilerConfiguration import (
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseService,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
DatabaseServiceProfilerPipeline,
@@ -36,9 +33,10 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.api.models import ProfilerProcessorConfig, TableConfig
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.processor.core import Profiler
from metadata.profiler.processor.default import DefaultProfiler, get_default_metrics
from metadata.profiler.registry import MetricRegistry
from metadata.profiler.source.database.base.profiler_resolver import ProfilerResolver
from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
from metadata.sampler.config import (
get_config_for_table,
@@ -47,12 +45,13 @@ from metadata.sampler.config import (
)
from metadata.sampler.models import SampleConfig
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.dependency_injector.dependency_injector import (
DependencyNotFoundError,
Inject,
inject,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.profiler_utils import get_context_entities
from metadata.utils.service_spec.service_spec import (
import_profiler_class,
import_sampler_class,
)
logger = profiler_logger()
@@ -77,8 +76,7 @@ class ProfilerSource(ProfilerSourceInterface):
self.ometa_client = ometa_client
self._interface_type: str = config.source.type.lower()
self._interface = None
# We define this in create_profiler_interface to help us reuse
# this method for the sampler, which does not have a DatabaseServiceProfilerPipeline
self.source_config = None
self.global_profiler_configuration = global_profiler_configuration
@@ -122,25 +120,34 @@
return config_copy
@inject
def create_profiler_interface(
self,
entity: Table,
config: Optional[TableConfig],
profiler_config: Optional[ProfilerProcessorConfig],
schema_entity: Optional[DatabaseSchema],
database_entity: Optional[Database],
db_service: Optional[DatabaseService],
schema_entity: DatabaseSchema,
database_entity: Database,
profiler_resolver: Inject[Type[ProfilerResolver]] = None,
) -> ProfilerInterface:
"""Create sqlalchemy profiler interface"""
"""Create the appropriate profiler interface based on processing engine."""
if profiler_resolver is None:
raise DependencyNotFoundError(
"ProfilerResolver dependency not found. Please ensure the ProfilerResolver is properly registered."
)
# NOTE: For some reason I do not understand, if we instantiate this in the __init__ method, we break the
# autoclassification workflow. This should be fixed. There should not be an impact on AutoClassification.
# We have an issue to track this here: https://github.com/open-metadata/OpenMetadata/issues/21790
self.source_config = DatabaseServiceProfilerPipeline.model_validate(
self.config.source.sourceConfig.config
)
profiler_class = import_profiler_class(
ServiceType.Database, source_type=self._interface_type
)
sampler_class = import_sampler_class(
ServiceType.Database, source_type=self._interface_type
sampler_class, profiler_class = profiler_resolver.resolve(
processing_engine=self.get_processing_engine(self.source_config),
service_type=ServiceType.Database,
source_type=self._interface_type,
)
# This is shared between the sampler and profiler interfaces
sampler_interface: SamplerInterface = sampler_class.create(
service_connection_config=self.service_conn_config,
@@ -155,6 +162,8 @@
samplingMethodType=self.source_config.samplingMethodType,
randomizedSample=self.source_config.randomizedSample,
),
# TODO: Change this when we have the processing engine configuration implemented. Right now it does nothing.
processing_engine=self.get_processing_engine(self.source_config),
)
profiler_interface: ProfilerInterface = profiler_class.create(
@@ -168,28 +177,33 @@
self.interface = profiler_interface
return self.interface
@inject
def get_profiler_runner(
self, entity: Table, profiler_config: ProfilerProcessorConfig
self,
entity: Table,
profiler_config: ProfilerProcessorConfig,
metrics_registry: Inject[Type[MetricRegistry]] = None,
) -> Profiler:
"""
Returns the runner for the profiler
"""
if metrics_registry is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
table_config = get_config_for_table(entity, profiler_config)
schema_entity, database_entity, db_service = get_context_entities(
entity=entity, metadata=self.ometa_client
)
profiler_interface = self.create_profiler_interface(
entity,
table_config,
profiler_config,
schema_entity,
database_entity,
db_service,
entity, table_config, schema_entity, database_entity
)
if not profiler_config.profiler:
return DefaultProfiler(
profiler_interface=profiler_interface,
metrics_registry=metrics_registry,
include_columns=get_include_columns(entity, table_config),
exclude_columns=get_exclude_columns(entity, table_config),
global_profiler_configuration=self.global_profiler_configuration,
@@ -197,9 +211,10 @@ class ProfilerSource(ProfilerSourceInterface):
)
metrics = (
[Metrics.get(name) for name in profiler_config.profiler.metrics]
[metrics_registry.get(name) for name in profiler_config.profiler.metrics]
if profiler_config.profiler.metrics
else get_default_metrics(
metrics_registry=metrics_registry,
table=profiler_interface.table,
ometa_client=self.ometa_client,
db_service=db_service,


@@ -16,6 +16,9 @@ Class defining the interface for the profiler source
from abc import ABC, abstractmethod
from typing import Optional
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
DatabaseServiceProfilerPipeline,
)
from metadata.profiler.interface.profiler_interface import ProfilerInterface
@@ -34,20 +37,12 @@ class ProfilerSourceInterface(ABC):
"""Set the interface"""
raise NotImplementedError
@abstractmethod
def create_profiler_interface(
self,
entity,
config,
profiler_config,
schema_entity,
database_entity,
db_service,
) -> ProfilerInterface:
"""Create the profiler interface"""
raise NotImplementedError
@abstractmethod
def get_profiler_runner(self, entity, profiler_config):
"""Get the profiler runner"""
raise NotImplementedError
@staticmethod
def get_processing_engine(config: DatabaseServiceProfilerPipeline):
"""Get the processing engine based on the configuration."""
return "Native"


@@ -13,12 +13,13 @@ Interface for sampler
"""
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Union
from typing import List, Optional, Set, Union
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import (
ColumnProfilerConfig,
PartitionProfilerConfig,
Table,
TableData,
)
@@ -65,20 +66,20 @@ class SamplerInterface(ABC):
include_columns: Optional[List[ColumnProfilerConfig]] = None,
exclude_columns: Optional[List[str]] = None,
sample_config: SampleConfig = SampleConfig(),
partition_details: Optional[Dict] = None,
partition_details: Optional[PartitionProfilerConfig] = None,
sample_query: Optional[str] = None,
storage_config: DataStorageConfig = None,
storage_config: Optional[DataStorageConfig] = None,
sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
**__,
):
self.ometa_client = ometa_client
self._sample = None
self._columns: Optional[List[SQALikeColumn]] = None
self._columns: List[SQALikeColumn] = []
self.sample_config = sample_config
self.entity = entity
self.include_columns = include_columns
self.exclude_columns = exclude_columns
self.include_columns = include_columns or []
self.exclude_columns = exclude_columns or []
self.sample_query = sample_query
self.sample_limit = sample_data_count
self.partition_details = partition_details
@@ -99,7 +100,7 @@
table_config: Optional[TableConfig] = None,
storage_config: Optional[DataStorageConfig] = None,
default_sample_config: Optional[SampleConfig] = None,
default_sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
default_sample_data_count: int = SAMPLE_DATA_DEFAULT_COUNT,
**kwargs,
) -> "SamplerInterface":
"""Create sampler"""
@@ -165,16 +166,20 @@
return self._columns
def _get_excluded_columns(self) -> Optional[Set[str]]:
def _get_excluded_columns(self) -> Set[str]:
"""Get excluded columns for table being profiled"""
if self.exclude_columns:
return set(self.exclude_columns)
return set()
def _get_included_columns(self) -> Optional[Set[str]]:
def _get_included_columns(self) -> Set[str]:
"""Get include columns for table being profiled"""
if self.include_columns:
return {include_col.columnName for include_col in self.include_columns}
return {
include_col.columnName
for include_col in self.include_columns
if include_col.columnName
}
return set()
@property


@@ -152,3 +152,38 @@ The dependency container is thread-safe and uses a reentrant lock (RLock) to sup
3. Circular dependencies are not supported
4. Dependencies are always treated as optional and can be overridden
5. Dependencies can't be passed as positional arguments (*args); they must be passed as keyword arguments (**kwargs)
### Class-Level Dependency Injection
For cases where you want to share dependencies across all instances of a class, you can use the `@inject_class_attributes` decorator:
```python
from typing import ClassVar
@inject_class_attributes
class UserService:
db: ClassVar[Inject[Database]]
cache: ClassVar[Inject[Cache]]
@classmethod
def get_user(cls, user_id: int) -> dict:
cache_key = f"user:{user_id}"
cached = cls.cache.get(cache_key)
if cached:
return cached
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
# The dependencies are shared across all instances and accessed via class methods
user = UserService.get_user(user_id=1)
```
The `@inject_class_attributes` decorator will:
1. Look for class attributes annotated with `ClassVar[Inject[Type]]`
2. Automatically inject the dependencies at the class level
3. Make the dependencies available to all instances and class methods
4. Raise `DependencyNotFoundError` if a required dependency is not registered
This is particularly useful for:
- Utility classes that don't need instance-specific state
- Services that should share the same dependencies across all instances
- Performance optimization when the same dependencies are used by multiple instances
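
Because the attributes are resolved while the decorator runs, tests can redirect them by registering an override before the class is defined; a minimal sketch reusing the example types above:

```python
container = DependencyContainer()
container.override(Database, lambda: Database("sqlite://"))  # test double

@inject_class_attributes
class UserService:
    db: ClassVar[Inject[Database]]

# UserService.db was materialized from the override at decoration time.
```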


@@ -132,6 +132,15 @@ class DependencyContainer:
cls._instance = super().__new__(cls)
return cls._instance
def get_key(self, dependency_type: Type[Any]) -> str:
"""
Get the key for a dependency.
"""
if get_origin(dependency_type) is type:
inner_type = get_args(dependency_type)[0]
return f"Type[{inner_type.__name__}]"
return dependency_type.__name__
def register(
self, dependency_type: Type[Any], dependency: Callable[[], Any]
) -> None:
@@ -145,10 +154,11 @@
Example:
```python
container.register(Database, lambda: Database("postgresql://localhost:5432"))
container.register(Type[Metrics], lambda: Metrics) # For registering types themselves
```
"""
with self._lock:
self._dependencies[dependency_type.__name__] = dependency
self._dependencies[self.get_key(dependency_type)] = dependency
def override(
self, dependency_type: Type[Any], dependency: Callable[[], Any]
@@ -169,7 +179,7 @@
```
"""
with self._lock:
self._overrides[dependency_type.__name__] = dependency
self._overrides[self.get_key(dependency_type)] = dependency
def remove_override(self, dependency_type: Type[T]) -> None:
"""
@@ -184,7 +194,7 @@
```
"""
with self._lock:
self._overrides.pop(dependency_type.__name__, None)
self._overrides.pop(self.get_key(dependency_type), None)
def get(self, dependency_type: Type[Any]) -> Optional[Any]:
"""
@@ -206,10 +216,9 @@
```
"""
with self._lock:
type_name = dependency_type.__name__
factory = self._overrides.get(type_name) or self._dependencies.get(
type_name
)
factory = self._overrides.get(
self.get_key(dependency_type)
) or self._dependencies.get(self.get_key(dependency_type))
if factory is None:
return None
return factory()
@@ -243,8 +252,10 @@
print("Database dependency is registered")
```"""
with self._lock:
type_name = dependency_type.__name__
return type_name in self._overrides or type_name in self._dependencies
return (
self.get_key(dependency_type) in self._overrides
or self.get_key(dependency_type) in self._dependencies
)
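
The `get_key` indirection is what makes the new `Type[...]` registrations safe: a parameterized generic such as `Type[Metrics]` has no usable `__name__` of its own, so keying on `dependency_type.__name__` would either fail or collapse every `Type[...]` entry onto a single key. A standalone sketch of the logic:

```python
from typing import Any, Type, get_args, get_origin


def get_key(dependency_type: Type[Any]) -> str:
    # Mirrors the container method above: unwrap Type[X] into "Type[X]".
    if get_origin(dependency_type) is type:
        inner_type = get_args(dependency_type)[0]
        return f"Type[{inner_type.__name__}]"
    return dependency_type.__name__


class Metrics: ...
class Resolver: ...

assert get_key(Metrics) == "Metrics"
assert get_key(Type[Metrics]) == "Type[Metrics]"
assert get_key(Type[Resolver]) == "Type[Resolver]"  # distinct, no collision
```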
def inject(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -341,3 +352,52 @@ def extract_inject_arg(tp: Any) -> Any:
f"Type {tp} is not Inject or Optional[Inject]. "
f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
)
def inject_class_attributes(cls: Type[Any]) -> Type[Any]:
"""
Decorator to inject dependencies into class-level (static) attributes based on type hints.
This decorator automatically injects dependencies into class attributes
based on their type hints. The dependencies are shared across all instances
of the class.
Args:
cls: The class to inject dependencies into
Returns:
A class with dependencies injected into its class-level attributes
Example:
```python
@inject_class_attributes
class UserService:
db: ClassVar[Inject[Database]]
cache: ClassVar[Inject[Cache]]
@classmethod
def get_user(cls, user_id: int) -> dict:
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
```
"""
container = DependencyContainer()
type_hints = get_type_hints(cls, include_extras=True)
# Inject dependencies into class attributes
for attr_name, attr_type in type_hints.items():
# Skip if attribute is already set
if hasattr(cls, attr_name):
continue
# Check if it's an Inject type
if is_inject_type(attr_type):
dependency_type = extract_inject_arg(attr_type)
dependency = container.get(dependency_type)
if dependency is None:
raise DependencyNotFoundError(
f"Dependency of type {dependency_type} not found in container. "
f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
)
setattr(cls, attr_name, dependency)
return cls
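
One consequence worth noting: `container.get` runs while the decorator executes, so dependencies are materialized once per class, and overrides registered afterwards do not rebind an already-decorated class. A sketch using the docstring's example types (`ReportService` is hypothetical):

```python
from typing import ClassVar

container = DependencyContainer()
container.register(Database, lambda: Database("postgresql://prod:5432"))

@inject_class_attributes
class ReportService:
    db: ClassVar[Inject[Database]]

prod_db = ReportService.db  # resolved at decoration time

container.override(Database, lambda: Database("sqlite://"))
assert ReportService.db is prod_db  # the override arrived too late
```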


@@ -1,3 +1,5 @@
from typing import Optional
import pytest
from metadata.utils.dependency_injector.dependency_injector import (
@@ -5,6 +7,7 @@ from metadata.utils.dependency_injector.dependency_injector import (
DependencyNotFoundError,
Inject,
inject,
inject_class_attributes,
)
@@ -21,8 +24,10 @@ class Cache:
def __init__(self, host: str):
self.host = host
def get(self, key: str) -> str:
return f"Cache hit for {key}"
def get(self, key: str) -> Optional[str]:
if key == "user:1":
return "Cache hit for user:1"
return None
# Test functions for injection
@@ -139,3 +144,99 @@ class TestInjectDecorator:
custom_db = Database("postgresql://custom:5432")
result = get_user(user_id=1, db=custom_db)
assert result == "Executed: SELECT * FROM users WHERE id = 1"
class TestInjectClassAttributes:
def test_inject_class_attributes_single_dependency(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
container.register(Database, db_factory)
@inject_class_attributes
class UserService:
db: Inject[Database]
@classmethod
def get_user(cls, user_id: int) -> str:
if cls.db is None:
raise DependencyNotFoundError("Database dependency not found")
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
result = UserService.get_user(user_id=1)
assert result == "Executed: SELECT * FROM users WHERE id = 1"
def test_inject_class_attributes_multiple_dependencies(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
cache_factory = lambda: Cache("localhost")
container.register(Database, db_factory)
container.register(Cache, cache_factory)
@inject_class_attributes
class UserService:
db: Inject[Database]
cache: Inject[Cache]
@classmethod
def get_user(cls, user_id: int) -> str:
if cls.db is None:
raise DependencyNotFoundError("Database dependency not found")
if cls.cache is None:
raise DependencyNotFoundError("Cache dependency not found")
cache_key = f"user:{user_id}"
cached = cls.cache.get(cache_key)
if cached:
return cached
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
result = UserService.get_user(user_id=1)
assert result == "Cache hit for user:1"
def test_inject_class_attributes_missing_dependency(self):
container = DependencyContainer()
container.clear() # Ensure no dependencies are registered
with pytest.raises(DependencyNotFoundError):
@inject_class_attributes
class UserService:
db: Inject[Database]
@classmethod
def get_user(cls, user_id: int) -> str:
if cls.db is None:
raise DependencyNotFoundError("Database dependency not found")
return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}")
def test_inject_class_attributes_shared_dependencies(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
container.register(Database, db_factory)
@inject_class_attributes
class UserService:
db: Inject[Database]
counter: int = 0
@classmethod
def increment_counter(cls) -> int:
cls.counter += 1
return cls.counter
@classmethod
def get_user(cls, user_id: int) -> str:
if cls.db is None:
raise DependencyNotFoundError("Database dependency not found")
return f"Counter: {cls.counter}, {cls.db.query(f'SELECT * FROM users WHERE id = {user_id}')}"
# First call
result1 = UserService.get_user(user_id=1)
assert result1 == "Counter: 0, Executed: SELECT * FROM users WHERE id = 1"
# Increment counter
UserService.increment_counter()
# Second call - counter should be shared
result2 = UserService.get_user(user_id=1)
assert result2 == "Counter: 1, Executed: SELECT * FROM users WHERE id = 1"


@@ -201,6 +201,7 @@ class ProfilerTest(TestCase):
"""
simple = DefaultProfiler(
profiler_interface=self.datalake_profiler_interface,
metrics_registry=Metrics,
)
simple.compute_metrics()
@@ -337,6 +338,7 @@
default_profiler = DefaultProfiler(
profiler_interface=self.datalake_profiler_interface,
metrics_registry=Metrics,
)
column_metrics = default_profiler._prepare_column_metrics()
for metric in column_metrics:


@@ -50,6 +50,7 @@ from metadata.profiler.metrics.core import (
QueryMetric,
StaticMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.processor.default import get_default_metrics
from metadata.sampler.pandas.sampler import DatalakeSampler
@@ -197,7 +198,7 @@ class PandasInterfaceTest(TestCase):
"""
cls.table = User
cls.metrics = get_default_metrics(cls.table)
cls.metrics = get_default_metrics(Metrics, cls.table)
cls.static_metrics = [
metric for metric in cls.metrics if issubclass(metric, StaticMetric)
]


@@ -139,7 +139,7 @@ class ProfilerTest(TestCase):
Check our pre-cooked profiler
"""
simple = DefaultProfiler(
profiler_interface=self.sqa_profiler_interface,
profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
)
simple.compute_metrics()
@@ -297,7 +297,7 @@
)
simple = DefaultProfiler(
profiler_interface=sqa_profiler_interface,
profiler_interface=sqa_profiler_interface, metrics_registry=Metrics
)
with pytest.raises(TimeoutError):
@@ -316,7 +316,7 @@
) # type: ignore
default_profiler = DefaultProfiler(
profiler_interface=self.sqa_profiler_interface,
profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics
)
column_metrics = default_profiler._prepare_column_metrics()


@@ -15,10 +15,10 @@ Test SQA Interface
import os
from datetime import datetime
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import TEXT, Column, Integer, String, inspect
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm.session import Session
@@ -49,6 +49,7 @@ from metadata.profiler.metrics.core import (
QueryMetric,
StaticMetric,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.processor.default import get_default_metrics
from metadata.sampler.sqlalchemy.sampler import SQASampler
@@ -64,46 +65,9 @@ class User(declarative_base()):
age = Column(Integer)
class SQAInterfaceTest(TestCase):
def setUp(self) -> None:
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName("id"),
dataType=DataType.INT,
)
],
)
sqlite_conn = SQLiteConnection(
scheme=SQLiteScheme.sqlite_pysqlite,
)
with patch.object(SQASampler, "build_table_orm", return_value=User):
sampler = SQASampler(
service_connection_config=sqlite_conn,
ometa_client=None,
entity=None,
)
with patch.object(SQASampler, "build_table_orm", return_value=User):
self.sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn, None, table_entity, None, sampler, 5, 43200
)
self.table = User
def test_init_interface(self):
"""Test we can instantiate our interface object correctly"""
assert isinstance(self.sqa_profiler_interface.session, Session)
def tearDown(self) -> None:
self.sqa_profiler_interface._sampler = None
class SQAInterfaceTestMultiThread(TestCase):
table_entity = Table(
@pytest.fixture
def table_entity():
return Table(
id=uuid4(),
name="user",
columns=[
@@ -113,12 +77,17 @@
)
],
)
db_path = os.path.join(os.path.dirname(__file__), "test.db")
sqlite_conn = SQLiteConnection(
@pytest.fixture
def sqlite_conn():
return SQLiteConnection(
scheme=SQLiteScheme.sqlite_pysqlite,
databaseMode=db_path + "?check_same_thread=False",
)
@pytest.fixture
def sqa_profiler_interface(table_entity, sqlite_conn):
with patch.object(SQASampler, "build_table_orm", return_value=User):
sampler = SQASampler(
service_connection_config=sqlite_conn,
@@ -126,140 +95,205 @@
entity=None,
)
sqa_profiler_interface = SQAProfilerInterface(
sqlite_conn,
with patch.object(SQASampler, "build_table_orm", return_value=User):
interface = SQAProfilerInterface(
sqlite_conn, None, table_entity, None, sampler, 5, 43200
)
return interface
def test_init_interface(sqa_profiler_interface):
"""Test we can instantiate our interface object correctly"""
assert isinstance(sqa_profiler_interface.session, Session)
@pytest.fixture(scope="class")
def db_path():
return os.path.join(os.path.dirname(__file__), "test.db")
@pytest.fixture(scope="class")
def class_sqlite_conn(db_path):
return SQLiteConnection(
scheme=SQLiteScheme.sqlite_pysqlite,
databaseMode=db_path + "?check_same_thread=False",
)
@pytest.fixture(scope="class")
def class_table_entity():
return Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName("id"),
dataType=DataType.INT,
)
],
)
@pytest.fixture(scope="class")
def class_sqa_profiler_interface(class_sqlite_conn, class_table_entity):
with patch.object(SQASampler, "build_table_orm", return_value=User):
sampler = SQASampler(
service_connection_config=class_sqlite_conn,
ometa_client=None,
entity=None,
)
interface = SQAProfilerInterface(
class_sqlite_conn,
None,
table_entity,
class_table_entity,
None,
sampler,
5,
43200,
)
return interface
@classmethod
def setUpClass(cls) -> None:
"""
Prepare Ingredients
"""
User.__table__.create(bind=cls.sqa_profiler_interface.session.get_bind())
data = [
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
]
cls.sqa_profiler_interface.session.add_all(data)
cls.sqa_profiler_interface.session.commit()
cls.table = User
cls.metrics = get_default_metrics(cls.table)
cls.static_metrics = [
metric for metric in cls.metrics if issubclass(metric, StaticMetric)
]
cls.composed_metrics = [
metric for metric in cls.metrics if issubclass(metric, ComposedMetric)
]
cls.window_metrics = [
@pytest.fixture(scope="class", autouse=True)
def setup_database(class_sqa_profiler_interface):
"""Setup test database and tables"""
try:
# Drop the table if it exists
User.__table__.drop(
bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
)
# Create the table
User.__table__.create(bind=class_sqa_profiler_interface.session.get_bind())
except Exception as e:
print(f"Error during table setup: {str(e)}")
raise e
data = [
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30),
User(name="Jane", fullname="Jone Doe", nickname=None, age=31),
]
class_sqa_profiler_interface.session.add_all(data)
class_sqa_profiler_interface.session.commit()
yield
# Cleanup
try:
User.__table__.drop(
bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True
)
class_sqa_profiler_interface.session.close()
except Exception as e:
print(f"Error during cleanup: {str(e)}")
raise e
@pytest.fixture(scope="class")
def metrics(class_sqa_profiler_interface):
metrics = get_default_metrics(Metrics, User)
return {
"all": metrics,
"static": [metric for metric in metrics if issubclass(metric, StaticMetric)],
"composed": [
metric for metric in metrics if issubclass(metric, ComposedMetric)
],
"window": [
metric
for metric in cls.metrics
for metric in metrics
if issubclass(metric, StaticMetric) and metric.is_window_metric()
]
cls.query_metrics = [
],
"query": [
metric
for metric in cls.metrics
for metric in metrics
if issubclass(metric, QueryMetric) and metric.is_col_metric()
]
],
}
def test_init_interface(self):
"""Test we can instantiate our interface object correctly"""
assert isinstance(self.sqa_profiler_interface.session, Session)
def test_init_interface_multi_thread(class_sqa_profiler_interface):
"""Test we can instantiate our interface object correctly"""
assert isinstance(class_sqa_profiler_interface.session, Session)
def test_get_all_metrics(self):
table_metrics = [
def test_get_all_metrics(class_sqa_profiler_interface, metrics):
table_metrics = [
ThreadPoolMetrics(
metrics=[
metric
for metric in metrics["all"]
if (not metric.is_col_metric() and not metric.is_system_metrics())
],
metric_type=MetricTypes.Table,
column=None,
table=User,
)
]
column_metrics = []
query_metrics = []
window_metrics = []
for col in inspect(User).c:
column_metrics.append(
ThreadPoolMetrics(
metrics=[
metric
for metric in self.metrics
if (not metric.is_col_metric() and not metric.is_system_metrics())
for metric in metrics["static"]
if metric.is_col_metric() and not metric.is_window_metric()
],
metric_type=MetricTypes.Table,
column=None,
table=self.table,
metric_type=MetricTypes.Static,
column=col,
table=User,
)
]
column_metrics = []
query_metrics = []
window_metrics = []
for col in inspect(User).c:
column_metrics.append(
)
for query_metric in metrics["query"]:
query_metrics.append(
ThreadPoolMetrics(
metrics=[
metric
for metric in self.static_metrics
if metric.is_col_metric() and not metric.is_window_metric()
],
metric_type=MetricTypes.Static,
metrics=query_metric,
metric_type=MetricTypes.Query,
column=col,
table=self.table,
table=User,
)
)
for query_metric in self.query_metrics:
query_metrics.append(
ThreadPoolMetrics(
metrics=query_metric,
metric_type=MetricTypes.Query,
column=col,
table=self.table,
)
)
window_metrics.append(
ThreadPoolMetrics(
metrics=[
metric
for metric in self.window_metrics
if metric.is_window_metric()
],
metric_type=MetricTypes.Window,
column=col,
table=self.table,
)
window_metrics.append(
ThreadPoolMetrics(
metrics=[
metric for metric in metrics["window"] if metric.is_window_metric()
],
metric_type=MetricTypes.Window,
column=col,
table=User,
)
all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]
profile_results = self.sqa_profiler_interface.get_all_metrics(
all_metrics,
)
column_profile = [
ColumnProfile(**profile_results["columns"].get(col.name))
for col in inspect(User).c
if profile_results["columns"].get(col.name)
]
all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics]
table_profile = TableProfile(
columnCount=profile_results["table"].get("columnCount"),
rowCount=profile_results["table"].get(RowCount.name()),
timestamp=Timestamp(int(datetime.now().timestamp())),
)
profile_results = class_sqa_profiler_interface.get_all_metrics(
all_metrics,
)
profile_request = CreateTableProfileRequest(
tableProfile=table_profile, columnProfile=column_profile
)
column_profile = [
ColumnProfile(**profile_results["columns"].get(col.name))
for col in inspect(User).c
if profile_results["columns"].get(col.name)
]
assert profile_request.tableProfile.columnCount == 6
assert profile_request.tableProfile.rowCount == 2
name_column_profile = [
profile
for profile in profile_request.columnProfile
if profile.name == "name"
][0]
id_column_profile = [
profile for profile in profile_request.columnProfile if profile.name == "id"
][0]
assert name_column_profile.nullCount == 0
assert id_column_profile.median == 1.0
table_profile = TableProfile(
columnCount=profile_results["table"].get("columnCount"),
rowCount=profile_results["table"].get(RowCount.name()),
timestamp=Timestamp(int(datetime.now().timestamp())),
)
@classmethod
def tearDownClass(cls) -> None:
os.remove(cls.db_path)
return super().tearDownClass()
profile_request = CreateTableProfileRequest(
tableProfile=table_profile, columnProfile=column_profile
)
assert profile_request.tableProfile.columnCount == 6
assert profile_request.tableProfile.rowCount == 2
name_column_profile = [
profile for profile in profile_request.columnProfile if profile.name == "name"
][0]
id_column_profile = [
profile for profile in profile_request.columnProfile if profile.name == "id"
][0]
assert name_column_profile.nullCount == 0
assert id_column_profile.median == 1.0