OpenMetadata/ingestion/src/metadata/sampler/sampler_interface.py
Teddy 28bd01c471
MINOR: Remove default 100 when profileSample is None (#19672)
* fix: remove default 100% percent

* fix: use get_dataset

* fix: orm_profiler tests
2025-02-05 19:14:31 +01:00

249 lines
8.8 KiB
Python

# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interface for sampler
"""
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Union
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import (
ColumnProfilerConfig,
Table,
TableData,
)
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
DataStorageConfig,
)
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.api.models import TableConfig
from metadata.profiler.processor.sample_data_handler import upload_sample_data
from metadata.sampler.config import (
get_exclude_columns,
get_include_columns,
get_profile_sample_config,
get_sample_data_count_config,
get_sample_query,
)
from metadata.sampler.models import SampleConfig
from metadata.sampler.partition import get_partition_details
from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
from metadata.utils.execution_time_tracker import calculate_execution_time
from metadata.utils.logger import sampler_logger
from metadata.utils.sqa_like_column import SQALikeColumn
from metadata.utils.ssl_manager import get_ssl_connection
# Module-level logger shared by all sampler implementations in this module
logger = sampler_logger()
class SamplerInterface(ABC):
    """Sampler interface.

    This should be the entrypoint for computing any metrics that are required
    downstream for data quality, profiling, etc. Concrete subclasses supply the
    connection-specific pieces (client, dataset access, sample fetching).
    """

    # pylint: disable=too-many-instance-attributes, too-many-arguments
    def __init__(
        self,
        service_connection_config: Union[DatabaseConnection, DatalakeConnection],
        ometa_client: OpenMetadata,
        entity: Table,
        include_columns: Optional[List[ColumnProfilerConfig]] = None,
        exclude_columns: Optional[List[str]] = None,
        sample_config: Optional[SampleConfig] = None,
        partition_details: Optional[Dict] = None,
        sample_query: Optional[str] = None,
        storage_config: Optional[DataStorageConfig] = None,
        sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
        **__,
    ):
        """
        Args:
            service_connection_config: connection config of the service to sample from
            ometa_client: OpenMetadata API client
            entity: table entity being sampled
            include_columns: if set, restrict sampling to these columns
            exclude_columns: column names to skip when sampling
            sample_config: sampling configuration. A fresh ``SampleConfig()`` is
                built per instance when not provided — previously the default was a
                single ``SampleConfig()`` evaluated at function-definition time and
                shared (mutable) across every instance.
            partition_details: partitioning details for the table, if any
            sample_query: user-provided query to sample from instead of the table
            storage_config: where to upload sample data, if configured
            sample_data_count: max number of sample rows to fetch
        """
        self.ometa_client = ometa_client
        self._sample = None
        # None means "not computed yet"; an empty list is a valid cached result
        self._columns: Optional[List[SQALikeColumn]] = None
        self.sample_config = (
            sample_config if sample_config is not None else SampleConfig()
        )
        self.entity = entity
        self.include_columns = include_columns
        self.exclude_columns = exclude_columns
        self.sample_query = sample_query
        self.sample_limit = sample_data_count
        self.partition_details = partition_details
        self.storage_config = storage_config
        self.service_connection_config = service_connection_config
        self.connection = get_ssl_connection(self.service_connection_config)
        self.client = self.get_client()

    # pylint: disable=too-many-arguments, too-many-locals
    @classmethod
    def create(
        cls,
        service_connection_config: Union[DatabaseConnection, DatalakeConnection],
        ometa_client: OpenMetadata,
        entity: Table,
        schema_entity: DatabaseSchema,
        database_entity: Database,
        table_config: Optional[TableConfig] = None,
        storage_config: Optional[DataStorageConfig] = None,
        default_sample_config: Optional[SampleConfig] = None,
        default_sample_data_count: Optional[int] = SAMPLE_DATA_DEFAULT_COUNT,
        **kwargs,
    ) -> "SamplerInterface":
        """Create a sampler, resolving configuration with entity-level settings
        taking precedence over schema/database/default settings.

        Args:
            service_connection_config: connection config of the service to sample from
            ometa_client: OpenMetadata API client
            entity: table entity being sampled
            schema_entity: schema the table belongs to (fallback config source)
            database_entity: database the table belongs to (fallback config source)
            table_config: per-table profiler configuration, if any
            storage_config: where to upload sample data, if configured
            default_sample_config: fallback sampling configuration
            default_sample_data_count: fallback max number of sample rows

        Returns:
            A fully-configured sampler instance.
        """
        sample_data_count = get_sample_data_count_config(
            entity=entity,
            schema_entity=schema_entity,
            database_entity=database_entity,
            entity_config=table_config,
            default_sample_data_count=default_sample_data_count,
        )
        sample_config = get_profile_sample_config(
            entity=entity,
            schema_entity=schema_entity,
            database_entity=database_entity,
            entity_config=table_config,
            default_sample_config=default_sample_config,
        )
        sample_query = get_sample_query(entity=entity, entity_config=table_config)
        partition_details = get_partition_details(
            entity=entity, entity_config=table_config
        )
        include_columns = get_include_columns(entity, entity_config=table_config)
        exclude_columns = get_exclude_columns(entity, entity_config=table_config)

        return cls(
            service_connection_config=service_connection_config,
            ometa_client=ometa_client,
            entity=entity,
            include_columns=include_columns,
            exclude_columns=exclude_columns,
            sample_config=sample_config,
            partition_details=partition_details,
            sample_query=sample_query,
            storage_config=storage_config,
            sample_data_count=sample_data_count,
            **kwargs,
        )

    @property
    def columns(self) -> List[SQALikeColumn]:
        """Return the list of columns to profile, honoring the include/exclude
        configuration.

        Include columns win: when an include list is set, only those columns are
        kept and the exclude list is ignored (matching the original behavior).
        The result is cached; ``None`` marks "not yet computed" so that an empty
        column list is cached too instead of being recomputed (and re-fetching
        the columns) on every access.
        """
        if self._columns is not None:
            return self._columns

        included = self._get_included_columns()
        if included:
            self._columns = [
                column for column in self.get_columns() if column.name in included
            ]
        else:
            excluded = self._get_excluded_columns()
            self._columns = [
                column
                for column in self.get_columns()
                if column.name not in excluded
            ]
        return self._columns

    def _get_excluded_columns(self) -> Set[str]:
        """Get excluded column names for the table being profiled (empty set if none)."""
        return set(self.exclude_columns) if self.exclude_columns else set()

    def _get_included_columns(self) -> Set[str]:
        """Get included column names for the table being profiled (empty set if none)."""
        if self.include_columns:
            return {include_col.columnName for include_col in self.include_columns}
        return set()

    @property
    @abstractmethod
    def raw_dataset(self):
        """Table object to run the sampling"""
        raise NotImplementedError

    @abstractmethod
    def get_client(self):
        """Get client"""
        raise NotImplementedError

    @abstractmethod
    def _rdn_sample_from_user_query(self):
        """Get random sample from user query"""
        raise NotImplementedError

    @abstractmethod
    def _fetch_sample_data_from_user_query(self) -> TableData:
        """Fetch sample data from user query"""
        raise NotImplementedError

    @abstractmethod
    def get_dataset(self, **kwargs):
        """Get random sample"""
        raise NotImplementedError

    @abstractmethod
    def fetch_sample_data(self, columns: Optional[List[SQALikeColumn]]) -> TableData:
        """Fetch sample data

        Args:
            columns (Optional[List]): List of columns to fetch
        """
        raise NotImplementedError

    @abstractmethod
    def get_columns(self) -> List[SQALikeColumn]:
        """get columns"""
        raise NotImplementedError

    @calculate_execution_time(store=False)
    def generate_sample_data(self) -> Optional[TableData]:
        """Fetch and (optionally) ingest sample data, then truncate the rows
        to the configured limit.

        Returns:
            TableData: sample data

        Raises:
            Exception: re-raises any error from fetching/uploading after logging it.
        """
        try:
            logger.debug(
                f"Fetching sample data for {self.entity.fullyQualifiedName.root}..."
            )
            table_data = self.fetch_sample_data(self.columns)
            # Only store the data if configured to do so
            if self.storage_config:
                upload_sample_data(
                    data=table_data,
                    entity=self.entity,
                    sample_storage_config=self.storage_config,
                )
            # sample_limit is Optional; min(..., None) would raise TypeError,
            # so fall back to the default row count when no limit is set
            row_limit = (
                min(SAMPLE_DATA_DEFAULT_COUNT, self.sample_limit)
                if self.sample_limit is not None
                else SAMPLE_DATA_DEFAULT_COUNT
            )
            table_data.rows = table_data.rows[:row_limit]
            return table_data

        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error fetching sample data: {err}")
            raise  # bare raise keeps the original traceback intact