diff --git a/profiler/src/openmetadata/common/database.py b/profiler/src/openmetadata/common/database.py index eed87d7dd6e..126e1b162ae 100644 --- a/profiler/src/openmetadata/common/database.py +++ b/profiler/src/openmetadata/common/database.py @@ -26,13 +26,18 @@ class Database(Closeable, metaclass=ABCMeta): def create(cls, config_dict: dict) -> "Database": pass + @property + @abstractmethod + def columns(self): + pass + @property @abstractmethod def sql_exprs(self): pass @abstractmethod - def table_metadata_query(self, table_name: str) -> str: + def table_column_metadata(self, table: str, schema: str): pass @abstractmethod diff --git a/profiler/src/openmetadata/common/database_common.py b/profiler/src/openmetadata/common/database_common.py index 52a65b8932a..65f7c89ae6d 100644 --- a/profiler/src/openmetadata/common/database_common.py +++ b/profiler/src/openmetadata/common/database_common.py @@ -14,11 +14,14 @@ import re from abc import abstractmethod from datetime import date, datetime from numbers import Number -from typing import List, Optional +from typing import List, Optional, Type from urllib.parse import quote_plus from pydantic import BaseModel from sqlalchemy import create_engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.inspection import inspect +from sqlalchemy.sql import sqltypes as types from openmetadata.common.config import ConfigModel, IncludeFilterPattern from openmetadata.common.database import Database @@ -58,33 +61,27 @@ class SQLConnectionConfig(ConfigModel): _numeric_types = [ - "SMALLINT", - "INTEGER", - "BIGINT", - "DECIMAL", - "NUMERIC", - "REAL", - "DOUBLE PRECISION", - "SMALLSERIAL", - "SERIAL", - "BIGSERIAL", + types.Integer, + types.Numeric, ] -_text_types = ["CHARACTER VARYING", "CHARACTER", "CHAR", "VARCHAR" "TEXT"] +_text_types = [ + types.VARCHAR, + types.String, +] _time_types = [ - "TIMESTAMP", - "DATE", - "TIME", - "TIMESTAMP WITH TIME ZONE", - "TIMESTAMP WITHOUT TIME ZONE", - "TIME WITH TIME ZONE", - 
"TIME WITHOUT TIME ZONE", + types.Date, + types.DATE, + types.Time, + types.DateTime, + types.DATETIME, + types.TIMESTAMP, ] def register_custom_type( - data_types: List[str], type_category: SupportedDataType + data_types: List[Type[types.TypeEngine]], type_category: SupportedDataType ) -> None: if type_category == SupportedDataType.TIME: _time_types.extend(data_types) @@ -242,34 +239,85 @@ class DatabaseCommon(Database): data_type_date = "DATE" config: SQLConnectionConfig = None sql_exprs: SQLExpressions = SQLExpressions() + columns: List[Column] = [] def __init__(self, config: SQLConnectionConfig): self.config = config self.connection_string = self.config.get_connection_url() self.engine = create_engine(self.connection_string, **self.config.options) self.connection = self.engine.raw_connection() + self.inspector = inspect(self.engine) @classmethod def create(cls, config_dict: dict): pass - def table_metadata_query(self, table_name: str) -> str: - pass - def qualify_table_name(self, table_name: str) -> str: return table_name def qualify_column_name(self, column_name: str): return column_name - def is_text(self, column_type: str): - return column_type.upper() in _text_types + def is_text(self, column_type: Type[types.TypeEngine]): + for sql_type in _text_types: + if isinstance(column_type, sql_type): + return True + return False - def is_number(self, column_type: str): - return column_type.upper() in _numeric_types + def is_number(self, column_type: Type[types.TypeEngine]): + for sql_type in _numeric_types: + if isinstance(column_type, sql_type): + return True + return False - def is_time(self, column_type: str): - return column_type.upper() in _time_types + def is_time(self, column_type: Type[types.TypeEngine]): + for sql_type in _time_types: + if isinstance(column_type, sql_type): + return True + return False + + def table_column_metadata(self, table: str, schema: str): + table = self.qualify_table_name(table) + pk_constraints = 
self.inspector.get_pk_constraint(table, schema)
+        pk_columns = (
+            pk_constraints["constrained_columns"]
+            if len(pk_constraints) > 0 and "constrained_columns" in pk_constraints.keys()
+            else []
+        )
+        unique_constraints = []
+        try:
+            unique_constraints = self.inspector.get_unique_constraints(table, schema)
+        except NotImplementedError:
+            pass
+        unique_columns = []
+        for constraint in unique_constraints:
+            if "column_names" in constraint.keys():
+                unique_columns = constraint["column_names"]
+        columns = self.inspector.get_columns(table, schema)
+        for column in columns:
+            name = column["name"]
+            data_type = column["type"]
+            nullable = True
+            if not column["nullable"] or column["name"] in pk_columns:
+                nullable = False
+
+            if self.is_number(data_type):
+                logical_type = SupportedDataType.NUMERIC
+            elif self.is_time(data_type):
+                logical_type = SupportedDataType.TIME
+            elif self.is_text(data_type):
+                logical_type = SupportedDataType.TEXT
+            else:
+                logger.info(f" {name} ({data_type}) not supported.")
+                continue
+            self.columns.append(
+                Column(
+                    name=name,
+                    data_type=data_type,
+                    nullable=nullable,
+                    logical_type=logical_type,
+                )
+            )
 
     def execute_query(self, sql: str) -> tuple:
         return self.execute_query_columns(sql)[0]
diff --git a/profiler/src/openmetadata/databases/hive.py b/profiler/src/openmetadata/databases/hive.py
new file mode 100644
index 00000000000..3a9c58d65cc
--- /dev/null
+++ b/profiler/src/openmetadata/databases/hive.py
@@ -0,0 +1,48 @@
+import json
+from typing import Optional
+
+from pyhive import hive # noqa: F401
+from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
+
+from openmetadata.common.database_common import (
+    DatabaseCommon,
+    SQLConnectionConfig,
+    SQLExpressions,
+    register_custom_type,
+)
+from openmetadata.profiler.profiler_metadata import SupportedDataType
+
+register_custom_type([HiveDate, HiveTimestamp], SupportedDataType.TIME)
+register_custom_type([HiveDecimal], SupportedDataType.NUMERIC)
+
+
+class 
HiveConfig(SQLConnectionConfig): + scheme = "hive" + auth_options: Optional[str] = None + service_type = "Hive" + + def get_connection_url(self): + url = super().get_connection_url() + if self.auth_options is not None: + return f"{url};{self.auth_options}" + else: + return url + + +class HiveSQLExpressions(SQLExpressions): + stddev_expr = "STDDEV_POP({})" + regex_like_pattern_expr = "cast({} as string) rlike '{}'" + + +class Hive(DatabaseCommon): + config: HiveConfig = None + sql_exprs: HiveSQLExpressions = HiveSQLExpressions() + + def __init__(self, config): + super().__init__(config) + self.config = config + + @classmethod + def create(cls, config_dict): + config = HiveConfig.parse_obj(config_dict) + return cls(config) diff --git a/profiler/src/openmetadata/databases/mysql.py b/profiler/src/openmetadata/databases/mysql.py index 472104b7af3..e35e0551e9c 100644 --- a/profiler/src/openmetadata/databases/mysql.py +++ b/profiler/src/openmetadata/databases/mysql.py @@ -4,38 +4,6 @@ from openmetadata.common.database_common import ( DatabaseCommon, SQLConnectionConfig, SQLExpressions, - register_custom_type, -) -from openmetadata.profiler.profiler_metadata import SupportedDataType - -register_custom_type( - ["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"], - SupportedDataType.TEXT, -) - -register_custom_type( - [ - "INTEGER", - "INT", - "SMALLINT", - "TINYINT", - "MEDIUMINT", - "BIGINT", - "DECIMAL", - "NUMERIC", - "FLOAT", - "DOUBLE", - "REAL", - "DOUBLE PRECISION", - "DEC", - "FIXED", - ], - SupportedDataType.NUMERIC, -) - -register_custom_type( - ["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"], - SupportedDataType.TIME, ) @@ -65,13 +33,3 @@ class MySQL(DatabaseCommon): def create(cls, config_dict): config = MySQLConnectionConfig.parse_obj(config_dict) return cls(config) - - def table_metadata_query(self, table_name: str) -> str: - sql = ( - f"SELECT column_name, data_type, is_nullable \n" - f"FROM information_schema.columns \n" - f"WHERE 
lower(table_name) = '{table_name}'" - ) - if self.config.database: - sql += f" \n AND table_schema = '{self.database}'" - return sql diff --git a/profiler/src/openmetadata/databases/postgres.py b/profiler/src/openmetadata/databases/postgres.py index f4f4e63960f..f085264d116 100644 --- a/profiler/src/openmetadata/databases/postgres.py +++ b/profiler/src/openmetadata/databases/postgres.py @@ -43,18 +43,6 @@ class Postgres(DatabaseCommon): config = PostgresConnectionConfig.parse_obj(config_dict) return cls(config) - def table_metadata_query(self, table_name: str) -> str: - sql = ( - f"SELECT column_name, data_type, is_nullable \n" - f"FROM information_schema.columns \n" - f"WHERE lower(table_name) = '{table_name}'" - ) - if self.config.database: - sql += f" \n AND table_catalog = '{self.config.database}'" - if self.config.db_schema: - sql += f" \n AND table_schema = '{self.config.db_schema}'" - return sql - def qualify_table_name(self, table_name: str) -> str: if self.config.db_schema: return f'"{self.config.db_schema}"."{table_name}"' diff --git a/profiler/src/openmetadata/databases/redshift.py b/profiler/src/openmetadata/databases/redshift.py index 6e598ebfee9..b45e64c0f89 100644 --- a/profiler/src/openmetadata/databases/redshift.py +++ b/profiler/src/openmetadata/databases/redshift.py @@ -13,56 +13,6 @@ from openmetadata.common.database_common import ( DatabaseCommon, SQLConnectionConfig, SQLExpressions, - register_custom_type, -) -from openmetadata.profiler.profiler_metadata import SupportedDataType - -register_custom_type( - [ - "CHAR", - "CHARACTER", - "BPCHAR", - "VARCHAR", - "CHARACTER VARYING", - "NVARCHAR", - "TEXT", - ], - SupportedDataType.TEXT, -) - -register_custom_type( - [ - "SMALLINT", - "INT2", - "INTEGER", - "INT", - "INT4", - "BIGINT", - "INT8", - "DECIMAL", - "NUMERIC", - "REAL", - "FLOAT4", - "DOUBLE PRECISION", - "FLOAT8", - "FLOAT", - ], - SupportedDataType.NUMERIC, -) - -register_custom_type( - [ - "DATE", - "TIMESTAMP", - "TIMESTAMP WITHOUT 
TIME ZONE", - "TIMESTAMPTZ", - "TIMESTAMP WITH TIME ZONE", - "TIME", - "TIME WITHOUT TIME ZONE", - "TIMETZ", - "TIME WITH TIME ZONE", - ], - SupportedDataType.TIME, ) @@ -94,15 +44,3 @@ class Redshift(DatabaseCommon): def create(cls, config_dict): config = RedshiftConnectionConfig.parse_obj(config_dict) return cls(config) - - def table_metadata_query(self, table_name: str) -> str: - sql = ( - f"SELECT column_name, data_type, is_nullable \n" - f"FROM information_schema.columns \n" - f"WHERE lower(table_name) = '{table_name}'" - ) - if self.config.database: - sql += f" \n AND table_catalog = '{self.config.database}'" - if self.config.db_schema: - sql += f" \n AND table_schema = '{self.config.db_schema}'" - return sql diff --git a/profiler/src/openmetadata/databases/snowflake.py b/profiler/src/openmetadata/databases/snowflake.py index b0e0f0df343..eb563160138 100644 --- a/profiler/src/openmetadata/databases/snowflake.py +++ b/profiler/src/openmetadata/databases/snowflake.py @@ -85,15 +85,3 @@ class Snowflake(DatabaseCommon): def create(cls, config_dict): config = SnowflakeConnectionConfig.parse_obj(config_dict) return cls(config) - - def table_metadata_query(self, table_name: str) -> str: - sql = ( - f"SELECT column_name, data_type, is_nullable \n" - f"FROM information_schema.columns \n" - f"WHERE lower(table_name) = '{table_name.lower()}'" - ) - if self.config.database: - sql += f" \n AND lower(table_catalog) = '{self.config.database.lower()}'" - if self.config.db_schema: - sql += f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'" - return sql diff --git a/profiler/src/openmetadata/profiler/profiler.py b/profiler/src/openmetadata/profiler/profiler.py index 7c41252fd63..ed8ede4e207 100644 --- a/profiler/src/openmetadata/profiler/profiler.py +++ b/profiler/src/openmetadata/profiler/profiler.py @@ -16,7 +16,6 @@ from typing import List from openmetadata.common.database import Database from openmetadata.common.metric import Metric from 
openmetadata.profiler.profiler_metadata import ( - Column, ColumnProfileResult, MetricMeasurement, ProfileResult, @@ -44,7 +43,6 @@ class Profiler: self.time = profile_time self.qualified_table_name = self.database.qualify_table_name(table_name) self.scan_reference = None - self.columns: List[Column] = [] self.start_time = None self.queries_executed = 0 @@ -57,7 +55,9 @@ class Profiler: self.start_time = datetime.now() try: - self._table_metadata() + self.database.table_column_metadata(self.table.name, None) + logger.debug(str(len(self.database.columns)) + " columns:") + self.profiler_result.table_result.col_count = len(self.database.columns) self._profile_aggregations() self._query_group_by_value() self._query_histograms() @@ -72,37 +72,6 @@ class Profiler: return self.profiler_result - def _table_metadata(self): - sql = self.database.table_metadata_query(self.table.name) - columns = self.database.execute_query_all(sql) - self.queries_executed += 1 - self.table_columns = [] - for column in columns: - name = column[0] - data_type = column[1] - nullable = "YES" == column[2].upper() - if self.database.is_number(data_type): - logical_type = SupportedDataType.NUMERIC - elif self.database.is_time(data_type): - logical_type = SupportedDataType.TIME - elif self.database.is_text(data_type): - logical_type = SupportedDataType.TEXT - else: - logger.info(f" {name} ({data_type}) not supported.") - continue - self.columns.append( - Column( - name=name, - data_type=data_type, - nullable=nullable, - logical_type=logical_type, - ) - ) - - self.column_names: List[str] = [column.name for column in self.columns] - logger.debug(str(len(self.columns)) + " columns:") - self.profiler_result.table_result.col_count = len(self.columns) - def _profile_aggregations(self): measurements: List[MetricMeasurement] = [] fields: List[str] = [] @@ -113,7 +82,7 @@ class Profiler: column_metric_indices = {} try: - for column in self.columns: + for column in self.database.columns: metric_indices = {} 
column_metric_indices[column.name.lower()] = metric_indices column_name = column.name @@ -210,7 +179,7 @@ class Profiler: if row_count_measurement: row_count = row_count_measurement.value self.profiler_result.table_result.row_count = row_count - for column in self.columns: + for column in self.database.columns: column_name = column.name metric_indices = column_metric_indices[column_name.lower()] non_missing_index = metric_indices.get("non_missing") @@ -288,7 +257,7 @@ class Profiler: logger.error(f"Exception during aggregation query", exc_info=e) def _query_group_by_value(self): - for column in self.columns: + for column in self.database.columns: try: measurements = [] column_name = column.name @@ -346,7 +315,7 @@ class Profiler: ) def _query_histograms(self): - for column in self.columns: + for column in self.database.columns: column_name = column.name try: if column.is_number(): diff --git a/profiler/src/openmetadata/profiler/profiler_metadata.py b/profiler/src/openmetadata/profiler/profiler_metadata.py index a295868f545..9773d814838 100644 --- a/profiler/src/openmetadata/profiler/profiler_metadata.py +++ b/profiler/src/openmetadata/profiler/profiler_metadata.py @@ -8,7 +8,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -58,7 +58,7 @@ class Column(BaseModel): name: str nullable: bool = None - data_type: str + data_type: Any logical_type: SupportedDataType def is_text(self) -> bool: @@ -79,7 +79,7 @@ class Table(BaseModel): class GroupValue(BaseModel): - """Metrinc Group Values""" + """Metric Group Values""" group: Dict = {} value: object