From e79c54e6a59a9075c09bae037e3d87dd82339d2f Mon Sep 17 00:00:00 2001 From: IceS2 Date: Tue, 17 Jun 2025 19:01:00 +0200 Subject: [PATCH] MINOR: Add injection to profiler (#21738) * Initial implementation for our Connection Class * Implement the Initial Connection class * Add Unit Tests * Implement Dependency Injection for the Ingestion Framework * Fix Test * Fix Profile Test Connection * Add Injection to Metrics in Profiler * Add Injection to the Profiler * Fix UnitTests * Fix Pytests * Fix Tests * Fix types --- ingestion/src/metadata/__init__.py | 9 + .../tableRowInsertedCountToBeBetween.py | 19 +- .../metadata/ingestion/source/connections.py | 3 + .../trino/profiler/system_tables_profiler.py | 43 ++- .../orm/functions/table_metric_computer.py | 30 +- .../src/metadata/profiler/processor/core.py | 4 +- .../metadata/profiler/processor/default.py | 57 +-- .../profiler/processor/metric_filter.py | 19 +- .../src/metadata/profiler/processor/models.py | 18 +- .../source/database/base/profiler_resolver.py | 35 ++ .../source/database/base/profiler_source.py | 75 ++-- .../source/profiler_source_interface.py | 21 +- .../src/metadata/sampler/sampler_interface.py | 25 +- .../utils/dependency_injector/README.md | 37 +- .../dependency_injector.py | 78 +++- .../test_dependency_injector.py | 105 +++++- .../unit/profiler/pandas/test_profiler.py | 2 + .../pandas/test_profiler_interface.py | 3 +- .../unit/profiler/sqlalchemy/test_profiler.py | 6 +- .../sqlalchemy/test_sqa_profiler_interface.py | 334 ++++++++++-------- 20 files changed, 644 insertions(+), 279 deletions(-) create mode 100644 ingestion/src/metadata/profiler/source/database/base/profiler_resolver.py diff --git a/ingestion/src/metadata/__init__.py b/ingestion/src/metadata/__init__.py index 4a794f00c76..5b6ca801771 100644 --- a/ingestion/src/metadata/__init__.py +++ b/ingestion/src/metadata/__init__.py @@ -11,7 +11,14 @@ """ OpenMetadata package initialization. """ +from typing import Type +from metadata.profiler.metrics.registry import Metrics +from metadata.profiler.registry import MetricRegistry +from metadata.profiler.source.database.base.profiler_resolver import ( + DefaultProfilerResolver, + ProfilerResolver, +) from metadata.utils.dependency_injector.dependency_injector import DependencyContainer from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader @@ -20,3 +27,5 @@ container = DependencyContainer() # Register the source loader container.register(SourceLoader, DefaultSourceLoader) +container.register(Type[MetricRegistry], lambda: Metrics) +container.register(Type[ProfilerResolver], lambda: DefaultProfilerResolver) diff --git a/ingestion/src/metadata/data_quality/validations/table/pandas/tableRowInsertedCountToBeBetween.py b/ingestion/src/metadata/data_quality/validations/table/pandas/tableRowInsertedCountToBeBetween.py index 9f0a849c5df..d6f5d79bc8f 100644 --- a/ingestion/src/metadata/data_quality/validations/table/pandas/tableRowInsertedCountToBeBetween.py +++ b/ingestion/src/metadata/data_quality/validations/table/pandas/tableRowInsertedCountToBeBetween.py @@ -34,7 +34,7 @@ class TableRowInsertedCountToBeBetweenValidator( """Validator for table row inserted count to be between test case""" @staticmethod - def _get_threshold_date(range_type: str, range_interval: int): + def get_threshold_date_dt(range_type: str, range_interval: int) -> datetime: """returns the threshold datetime in utc to count the numbers of rows inserted Args: @@ -55,7 +55,22 @@ class TableRowInsertedCountToBeBetweenValidator( threshold_date = threshold_date.replace( hour=0, minute=0, second=0, microsecond=0 ) - return threshold_date.strftime("%Y%m%d%H%M%S") + return threshold_date + + @staticmethod + def _get_threshold_date( + range_type: str, range_interval: int, date_format: str = "%Y%m%d%H%M%S" + ): + """returns the threshold datetime in utc as string to count the numbers of rows inserted + + Args: + range_type (str): type of range (i.e. HOUR, DAY, MONTH, YEAR) + range_interval (int): interval of range (i.e. 1, 2, 3, 4) + date_format (str): format of the date (i.e. %Y%m%d%H%M%S, %Y-%m-%d %H:%M:%S) + """ + return TableRowInsertedCountToBeBetweenValidator.get_threshold_date_dt( + range_type, range_interval + ).strftime(date_format) def _get_column_name(self): """returns the column name to be validated""" diff --git a/ingestion/src/metadata/ingestion/source/connections.py b/ingestion/src/metadata/ingestion/source/connections.py index 2f794955a2a..e203563fe45 100644 --- a/ingestion/src/metadata/ingestion/source/connections.py +++ b/ingestion/src/metadata/ingestion/source/connections.py @@ -113,6 +113,9 @@ def get_test_connection_fn(connection: BaseModel) -> Callable: ) +# ------------------------------------------------------------ + + def get_connection(connection: BaseModel) -> Any: """ Main method to prepare a connection from diff --git a/ingestion/src/metadata/ingestion/source/database/trino/profiler/system_tables_profiler.py b/ingestion/src/metadata/ingestion/source/database/trino/profiler/system_tables_profiler.py index 82f8ff98a8e..69275080950 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino/profiler/system_tables_profiler.py +++ b/ingestion/src/metadata/ingestion/source/database/trino/profiler/system_tables_profiler.py @@ -13,7 +13,7 @@ System table profiler """ from datetime import datetime from decimal import Decimal -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Type, Union from more_itertools import partition from pydantic import field_validator @@ -25,7 +25,11 @@ from metadata.profiler.interface.sqlalchemy.stored_statistics_profiler import ( StoredStatisticsSource, ) from metadata.profiler.metrics.core import Metric -from metadata.profiler.metrics.registry import Metrics +from metadata.profiler.registry import MetricRegistry +from metadata.utils.dependency_injector.dependency_injector import ( + Inject, + inject_class_attributes, +) from metadata.utils.logger import profiler_logger from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache from metadata.utils.ssl_manager import get_ssl_connection @@ -61,23 +65,28 @@ class TableStats(BaseModel): columns: Dict[str, ColumnStats] = {} +@inject_class_attributes class TrinoStoredStatisticsSource(StoredStatisticsSource): """Trino system profile source""" - metric_stats_map: Dict[Metrics, str] = { - Metrics.NULL_RATIO: "nulls_fractions", - Metrics.DISTINCT_COUNT: "distinct_values_count", - Metrics.ROW_COUNT: "row_count", - Metrics.MAX: "high_value", - Metrics.MIN: "low_value", - } + metrics: Inject[Type[MetricRegistry]] - metric_stats_by_name: Dict[str, str] = { - k.name: v for k, v in metric_stats_map.items() - } + @classmethod + def get_metric_stats_map(cls) -> Dict[MetricRegistry, str]: + return { + cls.metrics.NULL_RATIO: "nulls_fractions", + cls.metrics.DISTINCT_COUNT: "distinct_values_count", + cls.metrics.ROW_COUNT: "row_count", + cls.metrics.MAX: "high_value", + cls.metrics.MIN: "low_value", + } - def get_statistics_metrics(self) -> Set[Metrics]: - return set(self.metric_stats_map.keys()) + @classmethod + def get_metric_stats_by_name(cls) -> Dict[str, str]: + return {k.name: v for k, v in cls.get_metric_stats_map().items()} + + def get_statistics_metrics(self) -> Set[MetricRegistry]: + return set(self.get_metric_stats_map().keys()) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -96,7 +105,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource): f"Column {column} not found in table {table_name}. Statistics might be stale or missing." ) result = { - m.name(): getattr(column_stats, self.metric_stats_by_name[m.name()]) + m.name(): getattr(column_stats, self.get_metric_stats_by_name()[m.name()]) for m in metric } result.update(self.get_hybrid_statistics(table_stats, column_stats)) @@ -108,7 +117,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource): ) -> dict: table_stats = self._get_cached_stats(schema, table_name) return { - m.name(): getattr(table_stats, self.metric_stats_by_name[m.name()]) + m.name(): getattr(table_stats, self.get_metric_stats_by_name()[m.name()]) for m in metric } @@ -159,7 +168,7 @@ class TrinoStoredStatisticsSource(StoredStatisticsSource): ) -> Dict[str, Any]: return { # trino stats are in fractions, so we need to convert them to counts (unlike our default profiler) - Metrics.NULL_COUNT.name: ( + self.metrics.NULL_COUNT.name: ( int(table_stats.row_count * column_stats.nulls_fraction) if None not in [table_stats.row_count, column_stats.nulls_fraction] else None diff --git a/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py index ff4ddfcfc74..514e4d7d8d5 100644 --- a/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py +++ b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py @@ -16,7 +16,7 @@ Run profiler metrics on the table import traceback from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Type from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select from sqlalchemy.sql.expression import ColumnOperators, and_, cte @@ -27,13 +27,29 @@ from metadata.generated.schema.entity.data.table import TableType from metadata.profiler.metrics.registry import Metrics from metadata.profiler.orm.registry import Dialects from metadata.profiler.processor.runner import QueryRunner +from metadata.profiler.registry import MetricRegistry +from metadata.utils.dependency_injector.dependency_injector import ( + DependencyNotFoundError, + Inject, + inject, +) from metadata.utils.logger import profiler_interface_registry_logger logger = profiler_interface_registry_logger() + +@inject +def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None): + if metrics is None: + raise DependencyNotFoundError( + "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered." + ) + return metrics.ROW_COUNT().name() + + COLUMN_COUNT = "columnCount" COLUMN_NAMES = "columnNames" -ROW_COUNT = Metrics.ROW_COUNT().name() +ROW_COUNT = get_row_count_metric() SIZE_IN_BYTES = "sizeInBytes" CREATE_DATETIME = "createDateTime" @@ -362,9 +378,15 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer): class MySQLTableMetricComputer(BaseTableMetricComputer): """MySQL Table Metric Computer""" - def compute(self): + @inject + def compute(self, metrics: Inject[Type[MetricRegistry]] = None): """compute table metrics for mysql""" + if metrics is None: + raise DependencyNotFoundError( + "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered." + ) + columns = [ Column("TABLE_ROWS").label(ROW_COUNT), (Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES), @@ -390,7 +412,7 @@ class MySQLTableMetricComputer(BaseTableMetricComputer): res = res._asdict() # innodb row count is an estimate we need to patch the row count with COUNT(*) # https://dev.mysql.com/doc/refman/8.3/en/information-schema-innodb-tablestats-table.html - row_count = self.runner.select_first_from_table(Metrics.ROW_COUNT().fn()) + row_count = self.runner.select_first_from_table(metrics.ROW_COUNT().fn()) res.update({ROW_COUNT: row_count.rowCount}) return res diff --git a/ingestion/src/metadata/profiler/processor/core.py b/ingestion/src/metadata/profiler/processor/core.py index 67958e07cbd..1ef91c0e038 100644 --- a/ingestion/src/metadata/profiler/processor/core.py +++ b/ingestion/src/metadata/profiler/processor/core.py @@ -16,7 +16,7 @@ from __future__ import annotations import traceback from datetime import datetime, timezone -from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type +from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, cast from pydantic import ValidationError from sqlalchemy import Column @@ -94,7 +94,7 @@ class Profiler(Generic[TMetric]): :param profile_sample: % of rows to use for sampling column metrics """ self.global_profiler_configuration: Optional[ProfilerConfiguration] = ( - global_profiler_configuration.config_value + cast(ProfilerConfiguration, global_profiler_configuration.config_value) if global_profiler_configuration else None ) diff --git a/ingestion/src/metadata/profiler/processor/default.py b/ingestion/src/metadata/profiler/processor/default.py index 9fa2af8078b..eb4ed7db877 100644 --- a/ingestion/src/metadata/profiler/processor/default.py +++ b/ingestion/src/metadata/profiler/processor/default.py @@ -12,55 +12,54 @@ """ Default simple profiler to use """ -from typing import List, Optional +from typing import List, Optional, Type from sqlalchemy.orm import DeclarativeMeta -from metadata.generated.schema.configuration.profilerConfiguration import ( - ProfilerConfiguration, -) from metadata.generated.schema.entity.data.table import ColumnProfilerConfig from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.settings.settings import Settings from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.profiler.interface.profiler_interface import ProfilerInterface from metadata.profiler.metrics.core import Metric, add_props -from metadata.profiler.metrics.registry import Metrics from metadata.profiler.processor.core import Profiler +from metadata.profiler.registry import MetricRegistry def get_default_metrics( + metrics_registry: Type[MetricRegistry], table: DeclarativeMeta, ometa_client: Optional[OpenMetadata] = None, db_service: Optional[DatabaseService] = None, ) -> List[Metric]: return [ # Table Metrics - Metrics.ROW_COUNT.value, - add_props(table=table)(Metrics.COLUMN_COUNT.value), - add_props(table=table)(Metrics.COLUMN_NAMES.value), + metrics_registry.ROW_COUNT.value, + add_props(table=table)(metrics_registry.COLUMN_COUNT.value), + add_props(table=table)(metrics_registry.COLUMN_NAMES.value), # We'll use the ometa_client & db_service in case we need to fetch info to ES add_props(table=table, ometa_client=ometa_client, db_service=db_service)( - Metrics.SYSTEM.value + metrics_registry.SYSTEM.value ), # Column Metrics - Metrics.MEDIAN.value, - Metrics.FIRST_QUARTILE.value, - Metrics.THIRD_QUARTILE.value, - Metrics.MEAN.value, - Metrics.COUNT.value, - Metrics.DISTINCT_COUNT.value, - Metrics.DISTINCT_RATIO.value, - Metrics.MIN.value, - Metrics.MAX.value, - Metrics.NULL_COUNT.value, - Metrics.NULL_RATIO.value, - Metrics.STDDEV.value, - Metrics.SUM.value, - Metrics.UNIQUE_COUNT.value, - Metrics.UNIQUE_RATIO.value, - Metrics.IQR.value, - Metrics.HISTOGRAM.value, - Metrics.NON_PARAMETRIC_SKEW.value, + metrics_registry.MEDIAN.value, + metrics_registry.FIRST_QUARTILE.value, + metrics_registry.THIRD_QUARTILE.value, + metrics_registry.MEAN.value, + metrics_registry.COUNT.value, + metrics_registry.DISTINCT_COUNT.value, + metrics_registry.DISTINCT_RATIO.value, + metrics_registry.MIN.value, + metrics_registry.MAX.value, + metrics_registry.NULL_COUNT.value, + metrics_registry.NULL_RATIO.value, + metrics_registry.STDDEV.value, + metrics_registry.SUM.value, + metrics_registry.UNIQUE_COUNT.value, + metrics_registry.UNIQUE_RATIO.value, + metrics_registry.IQR.value, + metrics_registry.HISTOGRAM.value, + metrics_registry.NON_PARAMETRIC_SKEW.value, ] @@ -74,12 +73,14 @@ class DefaultProfiler(Profiler): def __init__( self, profiler_interface: ProfilerInterface, + metrics_registry: Type[MetricRegistry], include_columns: Optional[List[ColumnProfilerConfig]] = None, exclude_columns: Optional[List[str]] = None, - global_profiler_configuration: Optional[ProfilerConfiguration] = None, + global_profiler_configuration: Optional[Settings] = None, db_service=None, ): _metrics = get_default_metrics( + metrics_registry=metrics_registry, table=profiler_interface.table, ometa_client=profiler_interface.ometa_client, db_service=db_service, diff --git a/ingestion/src/metadata/profiler/processor/metric_filter.py b/ingestion/src/metadata/profiler/processor/metric_filter.py index f0d5d012f94..3915df6d87d 100644 --- a/ingestion/src/metadata/profiler/processor/metric_filter.py +++ b/ingestion/src/metadata/profiler/processor/metric_filter.py @@ -33,22 +33,35 @@ from metadata.profiler.metrics.core import ( SystemMetric, TMetric, ) -from metadata.profiler.metrics.registry import Metrics from metadata.profiler.orm.converter.converter_registry import converter_registry +from metadata.profiler.registry import MetricRegistry +from metadata.utils.dependency_injector.dependency_injector import ( + DependencyNotFoundError, + Inject, + inject, +) from metadata.utils.sqa_like_column import SQALikeColumn class MetricFilter: """Metric filter class for profiler""" + @inject def __init__( self, metrics: Tuple[Type[TMetric]], global_profiler_config: Optional[ProfilerConfiguration] = None, table_profiler_config: Optional[TableProfilerConfig] = None, column_profiler_config: Optional[List[ColumnProfilerConfig]] = None, + metrics_registry: Inject[Type[MetricRegistry]] = None, ): + if metrics_registry is None: + raise DependencyNotFoundError( + "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered." + ) + self.metrics = metrics + self.metrics_registry = metrics_registry self.global_profiler_config = global_profiler_config self.table_profiler_config = table_profiler_config self.column_profiler_config = column_profiler_config @@ -196,7 +209,7 @@ class MetricFilter: metrics = [ Metric.value - for Metric in Metrics + for Metric in self.metrics_registry if Metric.value.name() in {mtrc.value for mtrc in col_dtype_config.metrics} and Metric.value in metrics ] @@ -240,7 +253,7 @@ class MetricFilter: metrics = [ Metric.value - for Metric in Metrics + for Metric in self.metrics_registry if Metric.value.name().lower() in {mtrc.lower() for mtrc in metric_names} and Metric.value in metrics ] diff --git a/ingestion/src/metadata/profiler/processor/models.py b/ingestion/src/metadata/profiler/processor/models.py index e12f743598a..1faee249d76 100644 --- a/ingestion/src/metadata/profiler/processor/models.py +++ b/ingestion/src/metadata/profiler/processor/models.py @@ -13,20 +13,30 @@ Models to map profiler definitions JSON workflows to the profiler """ -from typing import List, Optional +from typing import List, Optional, Type from pydantic import BaseModel, BeforeValidator from typing_extensions import Annotated -from metadata.profiler.metrics.registry import Metrics +from metadata.profiler.registry import MetricRegistry +from metadata.utils.dependency_injector.dependency_injector import ( + DependencyNotFoundError, + Inject, + inject, +) -def valid_metric(value: str): +@inject +def valid_metric(value: str, metrics: Inject[Type[MetricRegistry]] = None): """ Validate that the input metrics are correctly named and can be found in the Registry """ - if not Metrics.get(value.upper()): + if metrics is None: + raise DependencyNotFoundError( + "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered." + ) + if not metrics.get(value.upper()): raise ValueError( f"Metric name {value} is not a proper metric name from the Registry" ) diff --git a/ingestion/src/metadata/profiler/source/database/base/profiler_resolver.py b/ingestion/src/metadata/profiler/source/database/base/profiler_resolver.py new file mode 100644 index 00000000000..d308ea8a0b5 --- /dev/null +++ b/ingestion/src/metadata/profiler/source/database/base/profiler_resolver.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Tuple, Type + +from metadata.generated.schema.entity.services.serviceType import ServiceType +from metadata.profiler.interface.profiler_interface import ProfilerInterface +from metadata.sampler.sampler_interface import SamplerInterface +from metadata.utils.service_spec.service_spec import ( + import_profiler_class, + import_sampler_class, +) + + +class ProfilerResolver(ABC): + """Abstract class for the profiler resolver""" + + @staticmethod + @abstractmethod + def resolve( + processing_engine: str, service_type: ServiceType, source_type: str + ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]: + """Resolve the sampler and profiler based on the processing engine.""" + raise NotImplementedError + + +class DefaultProfilerResolver(ProfilerResolver): + """Default profiler resolver""" + + @staticmethod + def resolve( + processing_engine: str, service_type: ServiceType, source_type: str + ) -> Tuple[Type[SamplerInterface], Type[ProfilerInterface]]: + """Resolve the sampler and profiler based on the processing engine.""" + sampler_class = import_sampler_class(service_type, source_type=source_type) + profiler_class = import_profiler_class(service_type, source_type=source_type) + return sampler_class, profiler_class diff --git a/ingestion/src/metadata/profiler/source/database/base/profiler_source.py b/ingestion/src/metadata/profiler/source/database/base/profiler_source.py index a50619b0ee4..e878da248eb 100644 --- a/ingestion/src/metadata/profiler/source/database/base/profiler_source.py +++ b/ingestion/src/metadata/profiler/source/database/base/profiler_source.py @@ -14,7 +14,7 @@ Base source for the profiler used to instantiate a profiler runner with its interface """ from copy import deepcopy -from typing import Optional, cast +from typing import Optional, Type, cast from metadata.generated.schema.configuration.profilerConfiguration import ( ProfilerConfiguration, @@ -22,10 +22,7 @@ from metadata.generated.schema.configuration.profilerConfiguration import ( from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.entity.services.databaseService import ( - DatabaseConnection, - DatabaseService, -) +from metadata.generated.schema.entity.services.databaseService import DatabaseConnection from metadata.generated.schema.entity.services.serviceType import ServiceType from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import ( DatabaseServiceProfilerPipeline, @@ -36,9 +33,10 @@ from metadata.generated.schema.metadataIngestion.workflow import ( from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.profiler.api.models import ProfilerProcessorConfig, TableConfig from metadata.profiler.interface.profiler_interface import ProfilerInterface -from metadata.profiler.metrics.registry import Metrics from metadata.profiler.processor.core import Profiler from metadata.profiler.processor.default import DefaultProfiler, get_default_metrics +from metadata.profiler.registry import MetricRegistry +from metadata.profiler.source.database.base.profiler_resolver import ProfilerResolver from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface from metadata.sampler.config import ( get_config_for_table, @@ -47,12 +45,13 @@ from metadata.sampler.config import ( ) from metadata.sampler.models import SampleConfig from metadata.sampler.sampler_interface import SamplerInterface +from metadata.utils.dependency_injector.dependency_injector import ( + DependencyNotFoundError, + Inject, + inject, +) from metadata.utils.logger import profiler_logger from metadata.utils.profiler_utils import get_context_entities -from metadata.utils.service_spec.service_spec import ( - import_profiler_class, - import_sampler_class, -) logger = profiler_logger() @@ -77,8 +76,7 @@ class ProfilerSource(ProfilerSourceInterface): self.ometa_client = ometa_client self._interface_type: str = config.source.type.lower() self._interface = None - # We define this in create_profiler_interface to help us reuse - # this method for the sampler, which does not have a DatabaseServiceProfilerPipeline + self.source_config = None self.global_profiler_configuration = global_profiler_configuration @@ -122,25 +120,34 @@ class ProfilerSource(ProfilerSourceInterface): return config_copy + @inject def create_profiler_interface( self, entity: Table, config: Optional[TableConfig], - profiler_config: Optional[ProfilerProcessorConfig], - schema_entity: Optional[DatabaseSchema], - database_entity: Optional[Database], - db_service: Optional[DatabaseService], + schema_entity: DatabaseSchema, + database_entity: Database, + profiler_resolver: Inject[Type[ProfilerResolver]] = None, ) -> ProfilerInterface: - """Create sqlalchemy profiler interface""" + """Create the appropriate profiler interface based on processing engine.""" + if profiler_resolver is None: + raise DependencyNotFoundError( + "ProfilerResolver dependency not found. Please ensure the ProfilerResolver is properly registered." + ) + + # NOTE: For some reason I do not understand, if we instantiate this on the __init__ method, we break the + # autoclassification workflow. This should be fixed. There should not be an impact on AutoClassification. + # We have an issue to track this here: https://github.com/open-metadata/OpenMetadata/issues/21790 self.source_config = DatabaseServiceProfilerPipeline.model_validate( self.config.source.sourceConfig.config ) - profiler_class = import_profiler_class( - ServiceType.Database, source_type=self._interface_type - ) - sampler_class = import_sampler_class( - ServiceType.Database, source_type=self._interface_type + + sampler_class, profiler_class = profiler_resolver.resolve( + processing_engine=self.get_processing_engine(self.source_config), + service_type=ServiceType.Database, + source_type=self._interface_type, ) + # This is shared between the sampler and profiler interfaces sampler_interface: SamplerInterface = sampler_class.create( service_connection_config=self.service_conn_config, @@ -155,6 +162,8 @@ class ProfilerSource(ProfilerSourceInterface): samplingMethodType=self.source_config.samplingMethodType, randomizedSample=self.source_config.randomizedSample, ), + # TODO: Change this when we have the processing engine configuration implemented. Right now it does nothing. + processing_engine=self.get_processing_engine(self.source_config), ) profiler_interface: ProfilerInterface = profiler_class.create( @@ -168,28 +177,33 @@ class ProfilerSource(ProfilerSourceInterface): self.interface = profiler_interface return self.interface + @inject def get_profiler_runner( - self, entity: Table, profiler_config: ProfilerProcessorConfig + self, + entity: Table, + profiler_config: ProfilerProcessorConfig, + metrics_registry: Inject[Type[MetricRegistry]] = None, ) -> Profiler: """ Returns the runner for the profiler """ + if metrics_registry is None: + raise DependencyNotFoundError( + "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered." + ) + table_config = get_config_for_table(entity, profiler_config) schema_entity, database_entity, db_service = get_context_entities( entity=entity, metadata=self.ometa_client ) profiler_interface = self.create_profiler_interface( - entity, - table_config, - profiler_config, - schema_entity, - database_entity, - db_service, + entity, table_config, schema_entity, database_entity ) if not profiler_config.profiler: return DefaultProfiler( profiler_interface=profiler_interface, + metrics_registry=metrics_registry, include_columns=get_include_columns(entity, table_config), exclude_columns=get_exclude_columns(entity, table_config), global_profiler_configuration=self.global_profiler_configuration, @@ -197,9 +211,10 @@ class ProfilerSource(ProfilerSourceInterface): ) metrics = ( - [Metrics.get(name) for name in profiler_config.profiler.metrics] + [metrics_registry.get(name) for name in profiler_config.profiler.metrics] if profiler_config.profiler.metrics else get_default_metrics( + metrics_registry=metrics_registry, table=profiler_interface.table, ometa_client=self.ometa_client, db_service=db_service, diff --git a/ingestion/src/metadata/profiler/source/profiler_source_interface.py b/ingestion/src/metadata/profiler/source/profiler_source_interface.py index b088e1b2a41..515f05f463e 100644 --- a/ingestion/src/metadata/profiler/source/profiler_source_interface.py +++ b/ingestion/src/metadata/profiler/source/profiler_source_interface.py @@ -16,6 +16,9 @@ Class defining the interface for the profiler source from abc import ABC, abstractmethod from typing import Optional +from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import ( + DatabaseServiceProfilerPipeline, +) from metadata.profiler.interface.profiler_interface import ProfilerInterface @@ -34,20 +37,12 @@ class ProfilerSourceInterface(ABC): """Set the interface""" raise NotImplementedError - @abstractmethod - def create_profiler_interface( - self, - entity, - config, - profiler_config, - schema_entity, - database_entity, - db_service, - ) -> ProfilerInterface: - """Create the profiler interface""" - raise NotImplementedError - @abstractmethod def get_profiler_runner(self, entity, profiler_config): """Get the profiler runner""" raise NotImplementedError + + @staticmethod + def get_processing_engine(config: DatabaseServiceProfilerPipeline): + """Get the processing engine based on the configuration.""" + return "Native" diff --git a/ingestion/src/metadata/sampler/sampler_interface.py b/ingestion/src/metadata/sampler/sampler_interface.py index 6bee6ef5de4..874ebdcef89 100644 --- a/ingestion/src/metadata/sampler/sampler_interface.py +++ b/ingestion/src/metadata/sampler/sampler_interface.py @@ -13,12 +13,13 @@ Interface for sampler """ import traceback from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Set, Union +from typing import List, Optional, Set, Union from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.data.table import ( ColumnProfilerConfig, + PartitionProfilerConfig, Table, TableData, ) @@ -65,20 +66,20 @@ class SamplerInterface(ABC): include_columns: Optional[List[ColumnProfilerConfig]] = None, exclude_columns: Optional[List[str]] = None, sample_config: SampleConfig = SampleConfig(), - partition_details: Optional[Dict] = None, + partition_details: Optional[PartitionProfilerConfig] = None, sample_query: Optional[str] = None, - storage_config: DataStorageConfig = None, + storage_config: Optional[DataStorageConfig] = None, sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT, **__, ): self.ometa_client = ometa_client self._sample = None - self._columns: Optional[List[SQALikeColumn]] = None + self._columns: List[SQALikeColumn] = [] self.sample_config = sample_config self.entity = entity - self.include_columns = include_columns - self.exclude_columns = exclude_columns + self.include_columns = include_columns or [] + self.exclude_columns = exclude_columns or [] self.sample_query = sample_query self.sample_limit = sample_data_count self.partition_details = partition_details @@ -99,7 +100,7 @@ class SamplerInterface(ABC): table_config: Optional[TableConfig] = None, storage_config: Optional[DataStorageConfig] = None, default_sample_config: Optional[SampleConfig] = None, - default_sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT, + default_sample_data_count: int = SAMPLE_DATA_DEFAULT_COUNT, **kwargs, ) -> "SamplerInterface": """Create sampler""" @@ -165,16 +166,20 @@ class SamplerInterface(ABC): return self._columns - def _get_excluded_columns(self) -> Optional[Set[str]]: + def _get_excluded_columns(self) -> Set[str]: """Get excluded columns for table being profiled""" if self.exclude_columns: return set(self.exclude_columns) return set() - def _get_included_columns(self) -> Optional[Set[str]]: + def _get_included_columns(self) -> Set[str]: """Get include columns for table being profiled""" if self.include_columns: - return {include_col.columnName for include_col in self.include_columns} + return { + include_col.columnName + for include_col in self.include_columns + if include_col.columnName + } return set() @property diff --git a/ingestion/src/metadata/utils/dependency_injector/README.md b/ingestion/src/metadata/utils/dependency_injector/README.md index f709e484a54..c0c5a6910bf 100644 --- a/ingestion/src/metadata/utils/dependency_injector/README.md +++ b/ingestion/src/metadata/utils/dependency_injector/README.md @@ -151,4 +151,39 @@ The dependency container is thread-safe and uses a reentrant lock (RLock) to sup 2. The system uses type names as keys, so different types with the same name will conflict 3. Circular dependencies are not supported 4. Dependencies are always treated as optional and can be overridden -5. Dependencies can't be passed as *arg. Must be passed as *kwargs \ No newline at end of file +5. Dependencies can't be passed as *arg. Must be passed as *kwargs + +### Class-Level Dependency Injection + +For cases where you want to share dependencies across all instances of a class, you can use the `@inject_class_attributes` decorator: + +```python +from typing import ClassVar + +@inject_class_attributes +class UserService: + db: ClassVar[Inject[Database]] + cache: ClassVar[Inject[Cache]] + + @classmethod + def get_user(cls, user_id: int) -> dict: + cache_key = f"user:{user_id}" + cached = cls.cache.get(cache_key) + if cached: + return cached + return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}") + +# The dependencies are shared across all instances and accessed via class methods +user = UserService.get_user(user_id=1) +``` + +The `@inject_class_attributes` decorator will: +1. Look for class attributes annotated with `ClassVar[Inject[Type]]` +2. Automatically inject the dependencies at the class level +3. Make the dependencies available to all instances and class methods +4. Raise `DependencyNotFoundError` if a required dependency is not registered + +This is particularly useful for: +- Utility classes that don't need instance-specific state +- Services that should share the same dependencies across all instances +- Performance optimization when the same dependencies are used by multiple instances \ No newline at end of file diff --git a/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py b/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py index 7065dd9eb21..ddc4ca21ad8 100644 --- a/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py +++ b/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py @@ -132,6 +132,15 @@ class DependencyContainer: cls._instance = super().__new__(cls) return cls._instance + def get_key(self, dependency_type: Type[Any]) -> str: + """ + Get the key for a dependency. + """ + if get_origin(dependency_type) is type: + inner_type = get_args(dependency_type)[0] + return f"Type[{inner_type.__name__}]" + return dependency_type.__name__ + def register( self, dependency_type: Type[Any], dependency: Callable[[], Any] ) -> None: @@ -145,10 +154,11 @@ class DependencyContainer: Example: ```python container.register(Database, lambda: Database("postgresql://localhost:5432")) + container.register(Type[Metrics], lambda: Metrics) # For registering types themselves ``` """ with self._lock: - self._dependencies[dependency_type.__name__] = dependency + self._dependencies[self.get_key(dependency_type)] = dependency def override( self, dependency_type: Type[Any], dependency: Callable[[], Any] @@ -169,7 +179,7 @@ class DependencyContainer: ``` """ with self._lock: - self._overrides[dependency_type.__name__] = dependency + self._overrides[self.get_key(dependency_type)] = dependency def remove_override(self, dependency_type: Type[T]) -> None: """ @@ -184,7 +194,7 @@ class DependencyContainer: ``` """ with self._lock: - self._overrides.pop(dependency_type.__name__, None) + self._overrides.pop(self.get_key(dependency_type), None) def get(self, dependency_type: Type[Any]) -> Optional[Any]: """ @@ -206,10 +216,9 @@ class DependencyContainer: ``` """ with self._lock: - type_name = dependency_type.__name__ - factory = self._overrides.get(type_name) or self._dependencies.get( - type_name - ) + factory = self._overrides.get( + self.get_key(dependency_type) + ) or self._dependencies.get(self.get_key(dependency_type)) if factory is None: return None return factory() @@ -243,8 +252,10 @@ class DependencyContainer: print("Database dependency is registered") ```""" with self._lock: - type_name = dependency_type.__name__ - return type_name in self._overrides or type_name in self._dependencies + return ( + self.get_key(dependency_type) in self._overrides + or self.get_key(dependency_type) in self._dependencies + ) def inject(func: Callable[..., Any]) -> Callable[..., Any]: @@ -341,3 +352,52 @@ def extract_inject_arg(tp: Any) -> Any: f"Type {tp} is not Inject or Optional[Inject]. " f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection." ) + + +def inject_class_attributes(cls: Type[Any]) -> Type[Any]: + """ + Decorator to inject dependencies into class-level (static) attributes based on type hints. + + This decorator automatically injects dependencies into class attributes + based on their type hints. The dependencies are shared across all instances + of the class. + + Args: + cls: The class to inject dependencies into + + Returns: + A class with dependencies injected into its class-level attributes + + Example: + ```python + @inject_class_attributes + class UserService: + db: ClassVar[Inject[Database]] + cache: ClassVar[Inject[Cache]] + + @classmethod + def get_user(cls, user_id: int) -> dict: + return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}") + ``` + """ + container = DependencyContainer() + type_hints = get_type_hints(cls, include_extras=True) + + # Inject dependencies into class attributes + for attr_name, attr_type in type_hints.items(): + # Skip if attribute is already set + if hasattr(cls, attr_name): + continue + + # Check if it's an Inject type + if is_inject_type(attr_type): + dependency_type = extract_inject_arg(attr_type) + dependency = container.get(dependency_type) + if dependency is None: + raise DependencyNotFoundError( + f"Dependency of type {dependency_type} not found in container. " + f"Make sure to register it using container.register({dependency_type.__name__}, ...)" + ) + setattr(cls, attr_name, dependency) + + return cls diff --git a/ingestion/tests/unit/metadata/utils/dependency_injector/test_dependency_injector.py b/ingestion/tests/unit/metadata/utils/dependency_injector/test_dependency_injector.py index 8a8aa27e777..834d3001c4b 100644 --- a/ingestion/tests/unit/metadata/utils/dependency_injector/test_dependency_injector.py +++ b/ingestion/tests/unit/metadata/utils/dependency_injector/test_dependency_injector.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from metadata.utils.dependency_injector.dependency_injector import ( @@ -5,6 +7,7 @@ from metadata.utils.dependency_injector.dependency_injector import ( DependencyNotFoundError, Inject, inject, + inject_class_attributes, ) @@ -21,8 +24,10 @@ class Cache: def __init__(self, host: str): self.host = host - def get(self, key: str) -> str: - return f"Cache hit for {key}" + def get(self, key: str) -> Optional[str]: + if key == "user:1": + return "Cache hit for user:1" + return None # Test functions for injection @@ -139,3 +144,99 @@ class TestInjectDecorator: custom_db = Database("postgresql://custom:5432") result = get_user(user_id=1, db=custom_db) assert result == "Executed: SELECT * FROM users WHERE id = 1" + + +class TestInjectClassAttributes: + def test_inject_class_attributes_single_dependency(self): + container = DependencyContainer() + db_factory = lambda: Database("postgresql://localhost:5432") + container.register(Database, db_factory) + + @inject_class_attributes + class UserService: + db: Inject[Database] + + @classmethod + def get_user(cls, user_id: int) -> str: + if cls.db is None: + raise DependencyNotFoundError("Database dependency not found") + return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}") + + result = UserService.get_user(user_id=1) + assert result == "Executed: SELECT * FROM users WHERE id = 1" + + def test_inject_class_attributes_multiple_dependencies(self): + container = DependencyContainer() + db_factory = lambda: Database("postgresql://localhost:5432") + cache_factory = lambda: Cache("localhost") + + container.register(Database, db_factory) + container.register(Cache, cache_factory) + + @inject_class_attributes + class UserService: + db: Inject[Database] + cache: Inject[Cache] + + @classmethod + def get_user(cls, user_id: int) -> str: + if cls.db is None: + raise DependencyNotFoundError("Database dependency not found") + if cls.cache is None: + raise DependencyNotFoundError("Cache dependency not found") + cache_key = f"user:{user_id}" + cached = cls.cache.get(cache_key) + if cached: + return cached + return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}") + + result = UserService.get_user(user_id=1) + assert result == "Cache hit for user:1" + + def test_inject_class_attributes_missing_dependency(self): + container = DependencyContainer() + container.clear() # Ensure no dependencies are registered + + with pytest.raises(DependencyNotFoundError): + + @inject_class_attributes + class UserService: + db: Inject[Database] + + @classmethod + def get_user(cls, user_id: int) -> str: + if cls.db is None: + raise DependencyNotFoundError("Database dependency not found") + return cls.db.query(f"SELECT * FROM users WHERE id = {user_id}") + + def test_inject_class_attributes_shared_dependencies(self): + container = DependencyContainer() + db_factory = lambda: Database("postgresql://localhost:5432") + container.register(Database, db_factory) + + @inject_class_attributes + class UserService: + db: Inject[Database] + counter: int = 0 + + @classmethod + def increment_counter(cls) -> int: + cls.counter += 1 + return cls.counter + + @classmethod + def get_user(cls, user_id: int) -> str: + if cls.db is None: + raise DependencyNotFoundError("Database dependency not found") + return f"Counter: {cls.counter}, {cls.db.query(f'SELECT * FROM users WHERE id = {user_id}')}" + + # First call + result1 = UserService.get_user(user_id=1) + assert result1 == "Counter: 0, Executed: SELECT * FROM users WHERE id = 1" + + # Increment counter + UserService.increment_counter() + + # Second call - counter should be shared + result2 = UserService.get_user(user_id=1) + assert result2 == "Counter: 1, Executed: SELECT * FROM users WHERE id = 1" diff --git a/ingestion/tests/unit/profiler/pandas/test_profiler.py b/ingestion/tests/unit/profiler/pandas/test_profiler.py index a2ac0e97f9f..4be76570ebc 100644 --- a/ingestion/tests/unit/profiler/pandas/test_profiler.py +++ b/ingestion/tests/unit/profiler/pandas/test_profiler.py @@ -201,6 +201,7 @@ class ProfilerTest(TestCase): """ simple = DefaultProfiler( profiler_interface=self.datalake_profiler_interface, + metrics_registry=Metrics, ) simple.compute_metrics() @@ -337,6 +338,7 @@ class ProfilerTest(TestCase): default_profiler = DefaultProfiler( profiler_interface=self.datalake_profiler_interface, + metrics_registry=Metrics, ) column_metrics = default_profiler._prepare_column_metrics() for metric in column_metrics: diff --git a/ingestion/tests/unit/profiler/pandas/test_profiler_interface.py b/ingestion/tests/unit/profiler/pandas/test_profiler_interface.py index 4f1c82deb13..3bfe092a962 100644 --- a/ingestion/tests/unit/profiler/pandas/test_profiler_interface.py +++ b/ingestion/tests/unit/profiler/pandas/test_profiler_interface.py @@ -50,6 +50,7 @@ from metadata.profiler.metrics.core import ( QueryMetric, StaticMetric, ) +from metadata.profiler.metrics.registry import Metrics from metadata.profiler.metrics.static.row_count import RowCount from metadata.profiler.processor.default import get_default_metrics from metadata.sampler.pandas.sampler import DatalakeSampler @@ -197,7 +198,7 @@ class PandasInterfaceTest(TestCase): """ cls.table = User - cls.metrics = get_default_metrics(cls.table) + cls.metrics = get_default_metrics(Metrics, cls.table) cls.static_metrics = [ metric for metric in cls.metrics if issubclass(metric, StaticMetric) ] diff --git a/ingestion/tests/unit/profiler/sqlalchemy/test_profiler.py b/ingestion/tests/unit/profiler/sqlalchemy/test_profiler.py index bc6cee296d7..d0c239ac950 100644 --- a/ingestion/tests/unit/profiler/sqlalchemy/test_profiler.py +++ b/ingestion/tests/unit/profiler/sqlalchemy/test_profiler.py @@ -139,7 +139,7 @@ class ProfilerTest(TestCase): Check our pre-cooked profiler """ simple = DefaultProfiler( - profiler_interface=self.sqa_profiler_interface, + profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics ) simple.compute_metrics() @@ -297,7 +297,7 @@ class ProfilerTest(TestCase): ) simple = DefaultProfiler( - profiler_interface=sqa_profiler_interface, + profiler_interface=sqa_profiler_interface, metrics_registry=Metrics ) with pytest.raises(TimeoutError): @@ -316,7 +316,7 @@ class ProfilerTest(TestCase): ) # type: ignore default_profiler = DefaultProfiler( - profiler_interface=self.sqa_profiler_interface, + profiler_interface=self.sqa_profiler_interface, metrics_registry=Metrics ) column_metrics = default_profiler._prepare_column_metrics() diff --git a/ingestion/tests/unit/profiler/sqlalchemy/test_sqa_profiler_interface.py b/ingestion/tests/unit/profiler/sqlalchemy/test_sqa_profiler_interface.py index 0c0be95c387..fc2dd10a344 100644 --- a/ingestion/tests/unit/profiler/sqlalchemy/test_sqa_profiler_interface.py +++ b/ingestion/tests/unit/profiler/sqlalchemy/test_sqa_profiler_interface.py @@ -15,10 +15,10 @@ Test SQA Interface import os from datetime import datetime -from unittest import TestCase from unittest.mock import patch from uuid import uuid4 +import pytest from sqlalchemy import TEXT, Column, Integer, String, inspect from sqlalchemy.orm import declarative_base from sqlalchemy.orm.session import Session @@ -49,6 +49,7 @@ from metadata.profiler.metrics.core import ( QueryMetric, StaticMetric, ) +from metadata.profiler.metrics.registry import Metrics from metadata.profiler.metrics.static.row_count import RowCount from metadata.profiler.processor.default import get_default_metrics from metadata.sampler.sqlalchemy.sampler import SQASampler @@ -64,46 +65,9 @@ class User(declarative_base()): age = Column(Integer) -class SQAInterfaceTest(TestCase): - def setUp(self) -> None: - table_entity = Table( - id=uuid4(), - name="user", - columns=[ - EntityColumn( - name=ColumnName("id"), - dataType=DataType.INT, - ) - ], - ) - sqlite_conn = SQLiteConnection( - scheme=SQLiteScheme.sqlite_pysqlite, - ) - - with patch.object(SQASampler, "build_table_orm", return_value=User): - sampler = SQASampler( - service_connection_config=sqlite_conn, - ometa_client=None, - entity=None, - ) - - with patch.object(SQASampler, "build_table_orm", return_value=User): - self.sqa_profiler_interface = SQAProfilerInterface( - sqlite_conn, None, table_entity, None, sampler, 5, 43200 - ) - self.table = User - - def test_init_interface(self): - """Test we can instantiate our interface object correctly""" - - assert isinstance(self.sqa_profiler_interface.session, Session) - - def tearDown(self) -> None: - self.sqa_profiler_interface._sampler = None - - -class SQAInterfaceTestMultiThread(TestCase): - table_entity = Table( +@pytest.fixture +def table_entity(): + return Table( id=uuid4(), name="user", columns=[ @@ -113,12 +77,17 @@ class SQAInterfaceTestMultiThread(TestCase): ) ], ) - db_path = os.path.join(os.path.dirname(__file__), "test.db") - sqlite_conn = SQLiteConnection( + + +@pytest.fixture +def sqlite_conn(): + return SQLiteConnection( scheme=SQLiteScheme.sqlite_pysqlite, - databaseMode=db_path + "?check_same_thread=False", ) + +@pytest.fixture +def sqa_profiler_interface(table_entity, sqlite_conn): with patch.object(SQASampler, "build_table_orm", return_value=User): sampler = SQASampler( service_connection_config=sqlite_conn, @@ -126,140 +95,205 @@ class SQAInterfaceTestMultiThread(TestCase): entity=None, ) - sqa_profiler_interface = SQAProfilerInterface( - sqlite_conn, + with patch.object(SQASampler, "build_table_orm", return_value=User): + interface = SQAProfilerInterface( + sqlite_conn, None, table_entity, None, sampler, 5, 43200 + ) + return interface + + +def test_init_interface(sqa_profiler_interface): + """Test we can instantiate our interface object correctly""" + assert isinstance(sqa_profiler_interface.session, Session) + + +@pytest.fixture(scope="class") +def db_path(): + return os.path.join(os.path.dirname(__file__), "test.db") + + +@pytest.fixture(scope="class") +def class_sqlite_conn(db_path): + return SQLiteConnection( + scheme=SQLiteScheme.sqlite_pysqlite, + databaseMode=db_path + "?check_same_thread=False", + ) + + +@pytest.fixture(scope="class") +def class_table_entity(): + return Table( + id=uuid4(), + name="user", + columns=[ + EntityColumn( + name=ColumnName("id"), + dataType=DataType.INT, + ) + ], + ) + + +@pytest.fixture(scope="class") +def class_sqa_profiler_interface(class_sqlite_conn, class_table_entity): + with patch.object(SQASampler, "build_table_orm", return_value=User): + sampler = SQASampler( + service_connection_config=class_sqlite_conn, + ometa_client=None, + entity=None, + ) + + interface = SQAProfilerInterface( + class_sqlite_conn, None, - table_entity, + class_table_entity, None, sampler, 5, 43200, ) + return interface - @classmethod - def setUpClass(cls) -> None: - """ - Prepare Ingredients - """ - User.__table__.create(bind=cls.sqa_profiler_interface.session.get_bind()) - data = [ - User(name="John", fullname="John Doe", nickname="johnny b goode", age=30), - User(name="Jane", fullname="Jone Doe", nickname=None, age=31), - ] - cls.sqa_profiler_interface.session.add_all(data) - cls.sqa_profiler_interface.session.commit() - cls.table = User - cls.metrics = get_default_metrics(cls.table) - cls.static_metrics = [ - metric for metric in cls.metrics if issubclass(metric, StaticMetric) - ] - cls.composed_metrics = [ - metric for metric in cls.metrics if issubclass(metric, ComposedMetric) - ] - cls.window_metrics = [ +@pytest.fixture(scope="class", autouse=True) +def setup_database(class_sqa_profiler_interface): + """Setup test database and tables""" + try: + # Drop the table if it exists + User.__table__.drop( + bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True + ) + # Create the table + User.__table__.create(bind=class_sqa_profiler_interface.session.get_bind()) + except Exception as e: + print(f"Error during table setup: {str(e)}") + raise e + + data = [ + User(name="John", fullname="John Doe", nickname="johnny b goode", age=30), + User(name="Jane", fullname="Jone Doe", nickname=None, age=31), + ] + class_sqa_profiler_interface.session.add_all(data) + class_sqa_profiler_interface.session.commit() + + yield + + # Cleanup + try: + User.__table__.drop( + bind=class_sqa_profiler_interface.session.get_bind(), checkfirst=True + ) + class_sqa_profiler_interface.session.close() + except Exception as e: + print(f"Error during cleanup: {str(e)}") + raise e + + +@pytest.fixture(scope="class") +def metrics(class_sqa_profiler_interface): + metrics = get_default_metrics(Metrics, User) + return { + "all": metrics, + "static": [metric for metric in metrics if issubclass(metric, StaticMetric)], + "composed": [ + metric for metric in metrics if issubclass(metric, ComposedMetric) + ], + "window": [ metric - for metric in cls.metrics + for metric in metrics if issubclass(metric, StaticMetric) and metric.is_window_metric() - ] - cls.query_metrics = [ + ], + "query": [ metric - for metric in cls.metrics + for metric in metrics if issubclass(metric, QueryMetric) and metric.is_col_metric() - ] + ], + } - def test_init_interface(self): - """Test we can instantiate our interface object correctly""" - assert isinstance(self.sqa_profiler_interface.session, Session) +def test_init_interface_multi_thread(class_sqa_profiler_interface): + """Test we can instantiate our interface object correctly""" + assert isinstance(class_sqa_profiler_interface.session, Session) - def test_get_all_metrics(self): - table_metrics = [ + +def test_get_all_metrics(class_sqa_profiler_interface, metrics): + table_metrics = [ + ThreadPoolMetrics( + metrics=[ + metric + for metric in metrics["all"] + if (not metric.is_col_metric() and not metric.is_system_metrics()) + ], + metric_type=MetricTypes.Table, + column=None, + table=User, + ) + ] + column_metrics = [] + query_metrics = [] + window_metrics = [] + for col in inspect(User).c: + column_metrics.append( ThreadPoolMetrics( metrics=[ metric - for metric in self.metrics - if (not metric.is_col_metric() and not metric.is_system_metrics()) + for metric in metrics["static"] + if metric.is_col_metric() and not metric.is_window_metric() ], - metric_type=MetricTypes.Table, - column=None, - table=self.table, + metric_type=MetricTypes.Static, + column=col, + table=User, ) - ] - column_metrics = [] - query_metrics = [] - window_metrics = [] - for col in inspect(User).c: - column_metrics.append( + ) + for query_metric in metrics["query"]: + query_metrics.append( ThreadPoolMetrics( - metrics=[ - metric - for metric in self.static_metrics - if metric.is_col_metric() and not metric.is_window_metric() - ], - metric_type=MetricTypes.Static, + metrics=query_metric, + metric_type=MetricTypes.Query, column=col, - table=self.table, + table=User, ) ) - for query_metric in self.query_metrics: - query_metrics.append( - ThreadPoolMetrics( - metrics=query_metric, - metric_type=MetricTypes.Query, - column=col, - table=self.table, - ) - ) - window_metrics.append( - ThreadPoolMetrics( - metrics=[ - metric - for metric in self.window_metrics - if metric.is_window_metric() - ], - metric_type=MetricTypes.Window, - column=col, - table=self.table, - ) + window_metrics.append( + ThreadPoolMetrics( + metrics=[ + metric for metric in metrics["window"] if metric.is_window_metric() + ], + metric_type=MetricTypes.Window, + column=col, + table=User, ) - - all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics] - - profile_results = self.sqa_profiler_interface.get_all_metrics( - all_metrics, ) - column_profile = [ - ColumnProfile(**profile_results["columns"].get(col.name)) - for col in inspect(User).c - if profile_results["columns"].get(col.name) - ] + all_metrics = [*table_metrics, *column_metrics, *query_metrics, *window_metrics] - table_profile = TableProfile( - columnCount=profile_results["table"].get("columnCount"), - rowCount=profile_results["table"].get(RowCount.name()), - timestamp=Timestamp(int(datetime.now().timestamp())), - ) + profile_results = class_sqa_profiler_interface.get_all_metrics( + all_metrics, + ) - profile_request = CreateTableProfileRequest( - tableProfile=table_profile, columnProfile=column_profile - ) + column_profile = [ + ColumnProfile(**profile_results["columns"].get(col.name)) + for col in inspect(User).c + if profile_results["columns"].get(col.name) + ] - assert profile_request.tableProfile.columnCount == 6 - assert profile_request.tableProfile.rowCount == 2 - name_column_profile = [ - profile - for profile in profile_request.columnProfile - if profile.name == "name" - ][0] - id_column_profile = [ - profile for profile in profile_request.columnProfile if profile.name == "id" - ][0] - assert name_column_profile.nullCount == 0 - assert id_column_profile.median == 1.0 + table_profile = TableProfile( + columnCount=profile_results["table"].get("columnCount"), + rowCount=profile_results["table"].get(RowCount.name()), + timestamp=Timestamp(int(datetime.now().timestamp())), + ) - @classmethod - def tearDownClass(cls) -> None: - os.remove(cls.db_path) - return super().tearDownClass() + profile_request = CreateTableProfileRequest( + tableProfile=table_profile, columnProfile=column_profile + ) + + assert profile_request.tableProfile.columnCount == 6 + assert profile_request.tableProfile.rowCount == 2 + name_column_profile = [ + profile for profile in profile_request.columnProfile if profile.name == "name" + ][0] + id_column_profile = [ + profile for profile in profile_request.columnProfile if profile.name == "id" + ][0] + assert name_column_profile.nullCount == 0 + assert id_column_profile.median == 1.0