Fix #961: Profiler: Add support for Hive and Athena (#962)

This commit is contained in:
Sriharsha Chintalapani 2021-10-27 12:10:59 -07:00 committed by GitHub
parent 721372719c
commit e5dfcf9f15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 141 additions and 199 deletions

View File

@ -26,13 +26,18 @@ class Database(Closeable, metaclass=ABCMeta):
def create(cls, config_dict: dict) -> "Database":
pass
@property
@abstractmethod
def columns(self):
pass
@property
@abstractmethod
def sql_exprs(self):
pass
@abstractmethod
def table_metadata_query(self, table_name: str) -> str:
def table_column_metadata(self, table: str, schema: str):
pass
@abstractmethod

View File

@ -14,11 +14,14 @@ import re
from abc import abstractmethod
from datetime import date, datetime
from numbers import Number
from typing import List, Optional
from typing import List, Optional, Type
from urllib.parse import quote_plus
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import sqltypes as types
from openmetadata.common.config import ConfigModel, IncludeFilterPattern
from openmetadata.common.database import Database
@ -58,33 +61,27 @@ class SQLConnectionConfig(ConfigModel):
_numeric_types = [
"SMALLINT",
"INTEGER",
"BIGINT",
"DECIMAL",
"NUMERIC",
"REAL",
"DOUBLE PRECISION",
"SMALLSERIAL",
"SERIAL",
"BIGSERIAL",
types.Integer,
types.Numeric,
]
_text_types = ["CHARACTER VARYING", "CHARACTER", "CHAR", "VARCHAR" "TEXT"]
_text_types = [
types.VARCHAR,
types.String,
]
_time_types = [
"TIMESTAMP",
"DATE",
"TIME",
"TIMESTAMP WITH TIME ZONE",
"TIMESTAMP WITHOUT TIME ZONE",
"TIME WITH TIME ZONE",
"TIME WITHOUT TIME ZONE",
types.Date,
types.DATE,
types.Time,
types.DateTime,
types.DATETIME,
types.TIMESTAMP,
]
def register_custom_type(
data_types: List[str], type_category: SupportedDataType
data_types: List[Type[types.TypeEngine]], type_category: SupportedDataType
) -> None:
if type_category == SupportedDataType.TIME:
_time_types.extend(data_types)
@ -242,34 +239,85 @@ class DatabaseCommon(Database):
data_type_date = "DATE"
config: SQLConnectionConfig = None
sql_exprs: SQLExpressions = SQLExpressions()
columns: List[Column] = []
def __init__(self, config: SQLConnectionConfig):
    # Hold the dialect-specific connection config; concrete subclasses
    # (Hive, MySQL, ...) re-assign their narrowed config type on top of this.
    self.config = config
    # SQLAlchemy URL is built by the dialect's config object.
    self.connection_string = self.config.get_connection_url()
    # NOTE(review): config.options is presumably a dict of extra
    # create_engine kwargs — confirm against SQLConnectionConfig.
    self.engine = create_engine(self.connection_string, **self.config.options)
    # Raw DBAPI connection, kept alongside the engine — presumably used by
    # the query-execution helpers; TODO confirm it is ever closed.
    self.connection = self.engine.raw_connection()
    # Inspector drives schema reflection (see table_column_metadata).
    self.inspector = inspect(self.engine)
@classmethod
def create(cls, config_dict: dict):
    # Intentional no-op stub: each concrete dialect (Hive, MySQL, Redshift,
    # Snowflake, ...) overrides this factory to parse config_dict into its
    # own config model and return an instance. Calling it on the base class
    # returns None — TODO confirm no caller relies on that.
    pass
def table_metadata_query(self, table_name: str) -> str:
pass
def qualify_table_name(self, table_name: str) -> str:
    """Default table qualification: the bare name, unchanged.

    Dialect subclasses that need schema qualification (e.g. Postgres)
    override this to prepend the configured schema.
    """
    return table_name
def qualify_column_name(self, column_name: str):
    """Default column qualification: the bare name, unchanged.

    Subclasses may override to add quoting/escaping for their dialect.
    """
    return column_name
def is_text(self, column_type: str):
return column_type.upper() in _text_types
def is_text(self, column_type: Type[types.TypeEngine]) -> bool:
    """Return True when *column_type* is an instance of a textual SQL type.

    NOTE(review): despite the ``Type[...]`` annotation, callers pass a type
    *instance* (SQLAlchemy column reflection yields instances), which is why
    ``isinstance`` is used. Annotation kept as-is for interface parity.
    """
    # any() + generator replaces the manual for/return-True/return-False loop.
    return any(isinstance(column_type, text_type) for text_type in _text_types)
def is_number(self, column_type: str):
return column_type.upper() in _numeric_types
def is_number(self, column_type: Type[types.TypeEngine]) -> bool:
    """Return True when *column_type* is an instance of a numeric SQL type.

    NOTE(review): callers pass a type *instance*, not a class, despite the
    ``Type[...]`` annotation — see is_text. Annotation kept for parity.
    """
    return any(isinstance(column_type, numeric_type) for numeric_type in _numeric_types)
def is_time(self, column_type: str):
return column_type.upper() in _time_types
def is_time(self, column_type: Type[types.TypeEngine]) -> bool:
    """Return True when *column_type* is an instance of a date/time SQL type.

    NOTE(review): callers pass a type *instance*, not a class, despite the
    ``Type[...]`` annotation — see is_text. Annotation kept for parity.
    """
    return any(isinstance(column_type, time_type) for time_type in _time_types)
def table_column_metadata(self, table: str, schema: str):
    """Reflect *table* and populate ``self.columns`` with profiling metadata.

    Each reflected column is classified into NUMERIC / TIME / TEXT via the
    is_number/is_time/is_text helpers; unsupported types are logged and
    skipped. Primary-key columns are forced to NOT NULL.
    """
    # Reset rather than append: ``columns`` is declared as a class-level
    # list, so without this reset repeated calls (or multiple instances)
    # would share and accumulate entries.
    self.columns = []
    qualified_table = self.qualify_table_name(table)
    # Fix: SQLAlchemy's get_pk_constraint reports the PK columns under
    # "constrained_columns" — the previous lookup key ("column_constraints")
    # never matched, so PK columns were silently treated as nullable.
    pk_constraints = self.inspector.get_pk_constraint(qualified_table, schema)
    pk_columns = (pk_constraints or {}).get("constrained_columns") or []
    # NOTE(review): the original also fetched unique constraints here but
    # never used the result; that dead code was removed.
    # Fix: pass the schema through (the original dropped it) and avoid
    # qualifying the already-qualified table name a second time.
    for column in self.inspector.get_columns(qualified_table, schema):
        name = column["name"]
        data_type = column["type"]
        # PK membership implies NOT NULL even if the driver says nullable.
        nullable = bool(column["nullable"]) and name not in pk_columns
        if self.is_number(data_type):
            logical_type = SupportedDataType.NUMERIC
        elif self.is_time(data_type):
            logical_type = SupportedDataType.TIME
        elif self.is_text(data_type):
            logical_type = SupportedDataType.TEXT
        else:
            logger.info(f" {name} ({data_type}) not supported.")
            continue
        self.columns.append(
            Column(
                name=name,
                data_type=data_type,
                nullable=nullable,
                logical_type=logical_type,
            )
        )
def execute_query(self, sql: str) -> tuple:
    """Run *sql* and return only the result rows, discarding column info.

    Thin wrapper over execute_query_columns, which returns a
    (rows, columns) pair.
    """
    rows, _columns = self.execute_query_columns(sql), None
    return rows[0]

View File

@ -0,0 +1,48 @@
import json
from typing import Optional
from pyhive import hive # noqa: F401
from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Register pyhive's dialect-specific reflection types with the profiler's
# logical categories so is_time()/is_number() recognize Hive columns.
register_custom_type([HiveDate, HiveTimestamp], SupportedDataType.TIME)
register_custom_type([HiveDecimal], SupportedDataType.NUMERIC)
class HiveConfig(SQLConnectionConfig):
    """Connection configuration for the Hive profiler source."""

    scheme = "hive"
    # Extra auth parameters appended verbatim to the connection URL.
    auth_options: Optional[str] = None
    service_type = "Hive"

    def get_connection_url(self):
        """Build the SQLAlchemy URL, appending auth options when present."""
        base_url = super().get_connection_url()
        if self.auth_options is None:
            return base_url
        return f"{base_url};{self.auth_options}"
class HiveSQLExpressions(SQLExpressions):
    # Hive-specific SQL templates overriding the generic defaults.
    # Population standard deviation, spelled the Hive way.
    stddev_expr = "STDDEV_POP({})"
    # Hive regex matching uses ``rlike``; the operand is cast to string
    # first — presumably so non-string columns can be matched; TODO confirm.
    regex_like_pattern_expr = "cast({} as string) rlike '{}'"
class Hive(DatabaseCommon):
    """Hive database source for the profiler."""

    config: HiveConfig = None
    sql_exprs: HiveSQLExpressions = HiveSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        # Narrow the stored config to the Hive-specific model.
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Factory: parse a raw config dict into a Hive instance."""
        hive_config = HiveConfig.parse_obj(config_dict)
        return cls(hive_config)

View File

@ -4,38 +4,6 @@ from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
register_custom_type(
["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"],
SupportedDataType.TEXT,
)
register_custom_type(
[
"INTEGER",
"INT",
"SMALLINT",
"TINYINT",
"MEDIUMINT",
"BIGINT",
"DECIMAL",
"NUMERIC",
"FLOAT",
"DOUBLE",
"REAL",
"DOUBLE PRECISION",
"DEC",
"FIXED",
],
SupportedDataType.NUMERIC,
)
register_custom_type(
["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"],
SupportedDataType.TIME,
)
@ -65,13 +33,3 @@ class MySQL(DatabaseCommon):
def create(cls, config_dict):
config = MySQLConnectionConfig.parse_obj(config_dict)
return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_schema = '{self.database}'"
return sql

View File

@ -43,18 +43,6 @@ class Postgres(DatabaseCommon):
config = PostgresConnectionConfig.parse_obj(config_dict)
return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_catalog = '{self.config.database}'"
if self.config.db_schema:
sql += f" \n AND table_schema = '{self.config.db_schema}'"
return sql
def qualify_table_name(self, table_name: str) -> str:
if self.config.db_schema:
return f'"{self.config.db_schema}"."{table_name}"'

View File

@ -13,56 +13,6 @@ from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
register_custom_type(
[
"CHAR",
"CHARACTER",
"BPCHAR",
"VARCHAR",
"CHARACTER VARYING",
"NVARCHAR",
"TEXT",
],
SupportedDataType.TEXT,
)
register_custom_type(
[
"SMALLINT",
"INT2",
"INTEGER",
"INT",
"INT4",
"BIGINT",
"INT8",
"DECIMAL",
"NUMERIC",
"REAL",
"FLOAT4",
"DOUBLE PRECISION",
"FLOAT8",
"FLOAT",
],
SupportedDataType.NUMERIC,
)
register_custom_type(
[
"DATE",
"TIMESTAMP",
"TIMESTAMP WITHOUT TIME ZONE",
"TIMESTAMPTZ",
"TIMESTAMP WITH TIME ZONE",
"TIME",
"TIME WITHOUT TIME ZONE",
"TIMETZ",
"TIME WITH TIME ZONE",
],
SupportedDataType.TIME,
)
@ -94,15 +44,3 @@ class Redshift(DatabaseCommon):
def create(cls, config_dict):
config = RedshiftConnectionConfig.parse_obj(config_dict)
return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name}'"
)
if self.config.database:
sql += f" \n AND table_catalog = '{self.config.database}'"
if self.config.db_schema:
sql += f" \n AND table_schema = '{self.config.db_schema}'"
return sql

View File

@ -85,15 +85,3 @@ class Snowflake(DatabaseCommon):
def create(cls, config_dict):
config = SnowflakeConnectionConfig.parse_obj(config_dict)
return cls(config)
def table_metadata_query(self, table_name: str) -> str:
sql = (
f"SELECT column_name, data_type, is_nullable \n"
f"FROM information_schema.columns \n"
f"WHERE lower(table_name) = '{table_name.lower()}'"
)
if self.config.database:
sql += f" \n AND lower(table_catalog) = '{self.config.database.lower()}'"
if self.config.db_schema:
sql += f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'"
return sql

View File

@ -16,7 +16,6 @@ from typing import List
from openmetadata.common.database import Database
from openmetadata.common.metric import Metric
from openmetadata.profiler.profiler_metadata import (
Column,
ColumnProfileResult,
MetricMeasurement,
ProfileResult,
@ -44,7 +43,6 @@ class Profiler:
self.time = profile_time
self.qualified_table_name = self.database.qualify_table_name(table_name)
self.scan_reference = None
self.columns: List[Column] = []
self.start_time = None
self.queries_executed = 0
@ -57,7 +55,9 @@ class Profiler:
self.start_time = datetime.now()
try:
self._table_metadata()
self.database.table_column_metadata(self.table.name, None)
logger.debug(str(len(self.database.columns)) + " columns:")
self.profiler_result.table_result.col_count = len(self.database.columns)
self._profile_aggregations()
self._query_group_by_value()
self._query_histograms()
@ -72,37 +72,6 @@ class Profiler:
return self.profiler_result
def _table_metadata(self):
sql = self.database.table_metadata_query(self.table.name)
columns = self.database.execute_query_all(sql)
self.queries_executed += 1
self.table_columns = []
for column in columns:
name = column[0]
data_type = column[1]
nullable = "YES" == column[2].upper()
if self.database.is_number(data_type):
logical_type = SupportedDataType.NUMERIC
elif self.database.is_time(data_type):
logical_type = SupportedDataType.TIME
elif self.database.is_text(data_type):
logical_type = SupportedDataType.TEXT
else:
logger.info(f" {name} ({data_type}) not supported.")
continue
self.columns.append(
Column(
name=name,
data_type=data_type,
nullable=nullable,
logical_type=logical_type,
)
)
self.column_names: List[str] = [column.name for column in self.columns]
logger.debug(str(len(self.columns)) + " columns:")
self.profiler_result.table_result.col_count = len(self.columns)
def _profile_aggregations(self):
measurements: List[MetricMeasurement] = []
fields: List[str] = []
@ -113,7 +82,7 @@ class Profiler:
column_metric_indices = {}
try:
for column in self.columns:
for column in self.database.columns:
metric_indices = {}
column_metric_indices[column.name.lower()] = metric_indices
column_name = column.name
@ -210,7 +179,7 @@ class Profiler:
if row_count_measurement:
row_count = row_count_measurement.value
self.profiler_result.table_result.row_count = row_count
for column in self.columns:
for column in self.database.columns:
column_name = column.name
metric_indices = column_metric_indices[column_name.lower()]
non_missing_index = metric_indices.get("non_missing")
@ -288,7 +257,7 @@ class Profiler:
logger.error(f"Exception during aggregation query", exc_info=e)
def _query_group_by_value(self):
for column in self.columns:
for column in self.database.columns:
try:
measurements = []
column_name = column.name
@ -346,7 +315,7 @@ class Profiler:
)
def _query_histograms(self):
for column in self.columns:
for column in self.database.columns:
column_name = column.name
try:
if column.is_number():

View File

@ -8,7 +8,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
@ -58,7 +58,7 @@ class Column(BaseModel):
name: str
nullable: bool = None
data_type: str
data_type: Any
logical_type: SupportedDataType
def is_text(self) -> bool:
@ -79,7 +79,7 @@ class Table(BaseModel):
class GroupValue(BaseModel):
"""Metrinc Group Values"""
"""Metric Group Values"""
group: Dict = {}
value: object