Fix #961: Profiler: Add support for Hive and Athena (#962)

This commit is contained in:
Sriharsha Chintalapani 2021-10-27 12:10:59 -07:00 committed by GitHub
parent 721372719c
commit e5dfcf9f15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 141 additions and 199 deletions

View File

@ -26,13 +26,18 @@ class Database(Closeable, metaclass=ABCMeta):
def create(cls, config_dict: dict) -> "Database": def create(cls, config_dict: dict) -> "Database":
pass pass
@property
@abstractmethod
def columns(self):
pass
@property @property
@abstractmethod @abstractmethod
def sql_exprs(self): def sql_exprs(self):
pass pass
@abstractmethod @abstractmethod
def table_metadata_query(self, table_name: str) -> str: def table_column_metadata(self, table: str, schema: str):
pass pass
@abstractmethod @abstractmethod

View File

@ -14,11 +14,14 @@ import re
from abc import abstractmethod from abc import abstractmethod
from datetime import date, datetime from datetime import date, datetime
from numbers import Number from numbers import Number
from typing import List, Optional from typing import List, Optional, Type
from urllib.parse import quote_plus from urllib.parse import quote_plus
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import sqltypes as types
from openmetadata.common.config import ConfigModel, IncludeFilterPattern from openmetadata.common.config import ConfigModel, IncludeFilterPattern
from openmetadata.common.database import Database from openmetadata.common.database import Database
@ -58,33 +61,27 @@ class SQLConnectionConfig(ConfigModel):
_numeric_types = [ _numeric_types = [
"SMALLINT", types.Integer,
"INTEGER", types.Numeric,
"BIGINT",
"DECIMAL",
"NUMERIC",
"REAL",
"DOUBLE PRECISION",
"SMALLSERIAL",
"SERIAL",
"BIGSERIAL",
] ]
_text_types = ["CHARACTER VARYING", "CHARACTER", "CHAR", "VARCHAR" "TEXT"] _text_types = [
types.VARCHAR,
types.String,
]
_time_types = [ _time_types = [
"TIMESTAMP", types.Date,
"DATE", types.DATE,
"TIME", types.Time,
"TIMESTAMP WITH TIME ZONE", types.DateTime,
"TIMESTAMP WITHOUT TIME ZONE", types.DATETIME,
"TIME WITH TIME ZONE", types.TIMESTAMP,
"TIME WITHOUT TIME ZONE",
] ]
def register_custom_type( def register_custom_type(
data_types: List[str], type_category: SupportedDataType data_types: List[Type[types.TypeEngine]], type_category: SupportedDataType
) -> None: ) -> None:
if type_category == SupportedDataType.TIME: if type_category == SupportedDataType.TIME:
_time_types.extend(data_types) _time_types.extend(data_types)
@ -242,34 +239,85 @@ class DatabaseCommon(Database):
data_type_date = "DATE" data_type_date = "DATE"
config: SQLConnectionConfig = None config: SQLConnectionConfig = None
sql_exprs: SQLExpressions = SQLExpressions() sql_exprs: SQLExpressions = SQLExpressions()
columns: List[Column] = []
def __init__(self, config: SQLConnectionConfig): def __init__(self, config: SQLConnectionConfig):
self.config = config self.config = config
self.connection_string = self.config.get_connection_url() self.connection_string = self.config.get_connection_url()
self.engine = create_engine(self.connection_string, **self.config.options) self.engine = create_engine(self.connection_string, **self.config.options)
self.connection = self.engine.raw_connection() self.connection = self.engine.raw_connection()
self.inspector = inspect(self.engine)
@classmethod @classmethod
def create(cls, config_dict: dict): def create(cls, config_dict: dict):
pass pass
def table_metadata_query(self, table_name: str) -> str:
pass
def qualify_table_name(self, table_name: str) -> str: def qualify_table_name(self, table_name: str) -> str:
return table_name return table_name
def qualify_column_name(self, column_name: str): def qualify_column_name(self, column_name: str):
return column_name return column_name
def is_text(self, column_type: str): def is_text(self, column_type: Type[types.TypeEngine]):
return column_type.upper() in _text_types for sql_type in _text_types:
if isinstance(column_type, sql_type):
return True
return False
def is_number(self, column_type: str): def is_number(self, column_type: Type[types.TypeEngine]):
return column_type.upper() in _numeric_types for sql_type in _numeric_types:
if isinstance(column_type, sql_type):
return True
return False
def is_time(self, column_type: str): def is_time(self, column_type: Type[types.TypeEngine]):
return column_type.upper() in _time_types for sql_type in _time_types:
if isinstance(column_type, sql_type):
return True
return False
def table_column_metadata(self, table: str, schema: str):
    """Populate ``self.columns`` with profiler metadata for ``table``.

    Reads column names/types and primary-key information through the
    SQLAlchemy inspector and maps each SQL type onto a SupportedDataType.
    Columns whose type is not numeric/time/text are logged and skipped.

    :param table: unqualified table name
    :param schema: schema name (may be None for dialects without schemas)
    """
    table = self.qualify_table_name(table)
    # Reset before collecting: the class-level `columns` list is shared, so
    # appending without clearing would accumulate columns across calls and
    # across instances. Assigning here creates a per-instance list.
    self.columns = []
    pk_constraints = self.inspector.get_pk_constraint(table, schema)
    # SQLAlchemy's get_pk_constraint() reports PK columns under the key
    # "constrained_columns" (the previous "column_constraints" key never
    # exists, so primary keys were silently never detected).
    pk_columns = pk_constraints.get("constrained_columns") or []
    unique_constraints = []
    try:
        unique_constraints = self.inspector.get_unique_constraints(table, schema)
    except NotImplementedError:
        # Some dialects (e.g. Hive, Athena) do not implement unique
        # constraints; treat that as "no constraints".
        pass
    # NOTE(review): unique_columns is collected but not used below yet;
    # kept for parity with the original implementation.
    unique_columns = []
    for constraint in unique_constraints:
        if "column_names" in constraint:
            unique_columns = constraint["column_names"]
    # `table` was already qualified above — do not qualify it a second
    # time — and pass the schema through like the other inspector calls.
    for column in self.inspector.get_columns(table, schema):
        name = column["name"]
        data_type = column["type"]
        # A column is nullable unless the dialect says otherwise or it is
        # part of the primary key.
        nullable = bool(column["nullable"]) and name not in pk_columns
        if self.is_number(data_type):
            logical_type = SupportedDataType.NUMERIC
        elif self.is_time(data_type):
            logical_type = SupportedDataType.TIME
        elif self.is_text(data_type):
            logical_type = SupportedDataType.TEXT
        else:
            logger.info(f" {name} ({data_type}) not supported.")
            continue
        self.columns.append(
            Column(
                name=name,
                data_type=data_type,
                nullable=nullable,
                logical_type=logical_type,
            )
        )
def execute_query(self, sql: str) -> tuple: def execute_query(self, sql: str) -> tuple:
return self.execute_query_columns(sql)[0] return self.execute_query_columns(sql)[0]

View File

@ -0,0 +1,48 @@
import json
from typing import Optional
from pyhive import hive # noqa: F401
from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# PyHive's dialect-specific SQLAlchemy types are not in the common type
# registries, so map them onto the profiler's logical categories at import
# time: dates/timestamps profile as TIME, decimals as NUMERIC.
register_custom_type([HiveDate, HiveTimestamp], SupportedDataType.TIME)
register_custom_type([HiveDecimal], SupportedDataType.NUMERIC)
class HiveConfig(SQLConnectionConfig):
    """Connection settings for a Hive profiler source."""

    scheme = "hive"
    # Extra connection options (e.g. authentication settings) appended to
    # the SQLAlchemy URL after a ';' separator.
    auth_options: Optional[str] = None
    service_type = "Hive"

    def get_connection_url(self):
        """Build the Hive connection URL, appending auth options if set."""
        base_url = super().get_connection_url()
        return base_url if self.auth_options is None else f"{base_url};{self.auth_options}"
class HiveSQLExpressions(SQLExpressions):
    """Hive-dialect overrides for the profiler's SQL expression templates."""

    # HiveQL exposes population standard deviation as STDDEV_POP.
    stddev_expr = "STDDEV_POP({})"
    # HiveQL regex matching uses `rlike`; cast to string so non-string
    # columns can still be matched.
    regex_like_pattern_expr = "cast({} as string) rlike '{}'"
class Hive(DatabaseCommon):
    """Hive implementation of the profiler Database interface."""

    # Narrowed types for the attributes declared on DatabaseCommon.
    config: HiveConfig = None
    sql_exprs: HiveSQLExpressions = HiveSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        # Re-assign after the base __init__ so self.config is the concrete
        # HiveConfig rather than the base SQLConnectionConfig.
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Factory: validate the raw config dict and build a Hive source."""
        config = HiveConfig.parse_obj(config_dict)
        return cls(config)

View File

@ -4,38 +4,6 @@ from openmetadata.common.database_common import (
DatabaseCommon, DatabaseCommon,
SQLConnectionConfig, SQLConnectionConfig,
SQLExpressions, SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
register_custom_type(
["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"],
SupportedDataType.TEXT,
)
register_custom_type(
[
"INTEGER",
"INT",
"SMALLINT",
"TINYINT",
"MEDIUMINT",
"BIGINT",
"DECIMAL",
"NUMERIC",
"FLOAT",
"DOUBLE",
"REAL",
"DOUBLE PRECISION",
"DEC",
"FIXED",
],
SupportedDataType.NUMERIC,
)
register_custom_type(
["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"],
SupportedDataType.TIME,
) )
@ -65,13 +33,3 @@ class MySQL(DatabaseCommon):
def create(cls, config_dict): def create(cls, config_dict):
config = MySQLConnectionConfig.parse_obj(config_dict) config = MySQLConnectionConfig.parse_obj(config_dict)
return cls(config) return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_schema = '{self.database}'"
return sql

View File

@ -43,18 +43,6 @@ class Postgres(DatabaseCommon):
config = PostgresConnectionConfig.parse_obj(config_dict) config = PostgresConnectionConfig.parse_obj(config_dict)
return cls(config) return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_catalog = '{self.config.database}'"
if self.config.db_schema:
sql += f" \n AND table_schema = '{self.config.db_schema}'"
return sql
def qualify_table_name(self, table_name: str) -> str: def qualify_table_name(self, table_name: str) -> str:
if self.config.db_schema: if self.config.db_schema:
return f'"{self.config.db_schema}"."{table_name}"' return f'"{self.config.db_schema}"."{table_name}"'

View File

@ -13,56 +13,6 @@ from openmetadata.common.database_common import (
DatabaseCommon, DatabaseCommon,
SQLConnectionConfig, SQLConnectionConfig,
SQLExpressions, SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
register_custom_type(
[
"CHAR",
"CHARACTER",
"BPCHAR",
"VARCHAR",
"CHARACTER VARYING",
"NVARCHAR",
"TEXT",
],
SupportedDataType.TEXT,
)
register_custom_type(
[
"SMALLINT",
"INT2",
"INTEGER",
"INT",
"INT4",
"BIGINT",
"INT8",
"DECIMAL",
"NUMERIC",
"REAL",
"FLOAT4",
"DOUBLE PRECISION",
"FLOAT8",
"FLOAT",
],
SupportedDataType.NUMERIC,
)
register_custom_type(
[
"DATE",
"TIMESTAMP",
"TIMESTAMP WITHOUT TIME ZONE",
"TIMESTAMPTZ",
"TIMESTAMP WITH TIME ZONE",
"TIME",
"TIME WITHOUT TIME ZONE",
"TIMETZ",
"TIME WITH TIME ZONE",
],
SupportedDataType.TIME,
) )
@ -94,15 +44,3 @@ class Redshift(DatabaseCommon):
def create(cls, config_dict): def create(cls, config_dict):
config = RedshiftConnectionConfig.parse_obj(config_dict) config = RedshiftConnectionConfig.parse_obj(config_dict)
return cls(config) return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_catalog = '{self.config.database}'"
if self.config.db_schema:
sql += f" \n AND table_schema = '{self.config.db_schema}'"
return sql

View File

@ -85,15 +85,3 @@ class Snowflake(DatabaseCommon):
def create(cls, config_dict): def create(cls, config_dict):
config = SnowflakeConnectionConfig.parse_obj(config_dict) config = SnowflakeConnectionConfig.parse_obj(config_dict)
return cls(config) return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name.lower()}'"
)
if self.config.database:
sql += f" \n AND lower(table_catalog) = '{self.config.database.lower()}'"
if self.config.db_schema:
sql += f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'"
return sql

View File

@ -16,7 +16,6 @@ from typing import List
from openmetadata.common.database import Database from openmetadata.common.database import Database
from openmetadata.common.metric import Metric from openmetadata.common.metric import Metric
from openmetadata.profiler.profiler_metadata import ( from openmetadata.profiler.profiler_metadata import (
Column,
ColumnProfileResult, ColumnProfileResult,
MetricMeasurement, MetricMeasurement,
ProfileResult, ProfileResult,
@ -44,7 +43,6 @@ class Profiler:
self.time = profile_time self.time = profile_time
self.qualified_table_name = self.database.qualify_table_name(table_name) self.qualified_table_name = self.database.qualify_table_name(table_name)
self.scan_reference = None self.scan_reference = None
self.columns: List[Column] = []
self.start_time = None self.start_time = None
self.queries_executed = 0 self.queries_executed = 0
@ -57,7 +55,9 @@ class Profiler:
self.start_time = datetime.now() self.start_time = datetime.now()
try: try:
self._table_metadata() self.database.table_column_metadata(self.table.name, None)
logger.debug(str(len(self.database.columns)) + " columns:")
self.profiler_result.table_result.col_count = len(self.database.columns)
self._profile_aggregations() self._profile_aggregations()
self._query_group_by_value() self._query_group_by_value()
self._query_histograms() self._query_histograms()
@ -72,37 +72,6 @@ class Profiler:
return self.profiler_result return self.profiler_result
def _table_metadata(self):
sql = self.database.table_metadata_query(self.table.name)
columns = self.database.execute_query_all(sql)
self.queries_executed += 1
self.table_columns = []
for column in columns:
name = column[0]
data_type = column[1]
nullable = "YES" == column[2].upper()
if self.database.is_number(data_type):
logical_type = SupportedDataType.NUMERIC
elif self.database.is_time(data_type):
logical_type = SupportedDataType.TIME
elif self.database.is_text(data_type):
logical_type = SupportedDataType.TEXT
else:
logger.info(f" {name} ({data_type}) not supported.")
continue
self.columns.append(
Column(
name=name,
data_type=data_type,
nullable=nullable,
logical_type=logical_type,
)
)
self.column_names: List[str] = [column.name for column in self.columns]
logger.debug(str(len(self.columns)) + " columns:")
self.profiler_result.table_result.col_count = len(self.columns)
def _profile_aggregations(self): def _profile_aggregations(self):
measurements: List[MetricMeasurement] = [] measurements: List[MetricMeasurement] = []
fields: List[str] = [] fields: List[str] = []
@ -113,7 +82,7 @@ class Profiler:
column_metric_indices = {} column_metric_indices = {}
try: try:
for column in self.columns: for column in self.database.columns:
metric_indices = {} metric_indices = {}
column_metric_indices[column.name.lower()] = metric_indices column_metric_indices[column.name.lower()] = metric_indices
column_name = column.name column_name = column.name
@ -210,7 +179,7 @@ class Profiler:
if row_count_measurement: if row_count_measurement:
row_count = row_count_measurement.value row_count = row_count_measurement.value
self.profiler_result.table_result.row_count = row_count self.profiler_result.table_result.row_count = row_count
for column in self.columns: for column in self.database.columns:
column_name = column.name column_name = column.name
metric_indices = column_metric_indices[column_name.lower()] metric_indices = column_metric_indices[column_name.lower()]
non_missing_index = metric_indices.get("non_missing") non_missing_index = metric_indices.get("non_missing")
@ -288,7 +257,7 @@ class Profiler:
logger.error(f"Exception during aggregation query", exc_info=e) logger.error(f"Exception during aggregation query", exc_info=e)
def _query_group_by_value(self): def _query_group_by_value(self):
for column in self.columns: for column in self.database.columns:
try: try:
measurements = [] measurements = []
column_name = column.name column_name = column.name
@ -346,7 +315,7 @@ class Profiler:
) )
def _query_histograms(self): def _query_histograms(self):
for column in self.columns: for column in self.database.columns:
column_name = column.name column_name = column.name
try: try:
if column.is_number(): if column.is_number():

View File

@ -8,7 +8,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -58,7 +58,7 @@ class Column(BaseModel):
name: str name: str
nullable: bool = None nullable: bool = None
data_type: str data_type: Any
logical_type: SupportedDataType logical_type: SupportedDataType
def is_text(self) -> bool: def is_text(self) -> bool:
@ -79,7 +79,7 @@ class Table(BaseModel):
class GroupValue(BaseModel): class GroupValue(BaseModel):
"""Metrinc Group Values""" """Metric Group Values"""
group: Dict = {} group: Dict = {}
value: object value: object