Fix #946: Add support for MySQL and Snowflake for profiler (#947)

This commit is contained in:
Sriharsha Chintalapani 2021-10-27 05:37:34 -07:00 committed by GitHub
parent d9a35e4b5c
commit 47d3d5c77a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 203 additions and 34 deletions

View File

@ -58,6 +58,14 @@ plugins: Dict[str, Set[str]] = {
"GeoAlchemy2",
},
"postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
"mysql": {"pymysql>=1.0.2"},
"snowflake": {"snowflake-sqlalchemy<=1.2.4"},
"hive": {
"openmetadata-sqlalchemy-hive==0.2.0",
"thrift~=0.13.0",
"sasl==0.3.1",
"thrift-sasl==0.4.3",
},
}
# NOTE(review): presumably consumed by freeze/bundle tooling to force-include
# the _cffi_backend C extension, which dynamic import scanning misses — confirm.
build_options = {"includes": ["_cffi_backend"]}

View File

@ -10,6 +10,7 @@
import io
import json
import pathlib
import re
from abc import ABC, abstractmethod
from typing import IO, Any, List, Optional

View File

@ -56,17 +56,17 @@ class Database(Closeable, metaclass=ABCMeta):
pass
@abstractmethod
def sql_fetchone(self, sql) -> tuple:
def execute_query(self, sql) -> tuple:
pass
@abstractmethod
def sql_fetchone_description(self, sql) -> tuple:
def execute_query_columns(self, sql) -> tuple:
pass
@abstractmethod
def sql_fetchall(self, sql) -> List[tuple]:
def execute_query_all(self, sql) -> List[tuple]:
pass
@abstractmethod
def sql_fetchall_description(self, sql) -> tuple:
def execute_query_all_columns(self, sql) -> tuple:
pass

View File

@ -271,51 +271,35 @@ class DatabaseCommon(Database):
def is_time(self, column_type: str):
return column_type.upper() in _time_types
def sql_fetchone(self, sql: str) -> tuple:
"""
Only returns the tuple obtained by cursor.fetchone()
"""
return self.sql_fetchone_description(sql)[0]
def execute_query(self, sql: str) -> tuple:
    """Run *sql* and return only the first result row, discarding the
    cursor description."""
    first_row, _description = self.execute_query_columns(sql)
    return first_row
def sql_fetchone_description(self, sql: str) -> tuple:
"""
Returns a tuple with 2 elements:
1) the tuple obtained by cursor.fetchone()
2) the cursor.description
"""
def execute_query_columns(self, sql: str) -> tuple:
    """Execute *sql* and return a 2-tuple of (first row, cursor.description).

    The span as extracted contained both the pre- and post-commit
    ``logger.debug`` lines (a diff-rendering artifact that would log each
    message twice); only the post-commit lines are kept here.

    The cursor is always closed, even if execution raises.
    """
    cursor = self.connection.cursor()
    try:
        logger.debug(f"SQL query: \n{sql}")
        start = datetime.now()
        cursor.execute(sql)
        row_tuple = cursor.fetchone()
        description = cursor.description
        # Log wall-clock duration of the query for profiling visibility.
        delta = datetime.now() - start
        logger.debug(f"SQL duration {str(delta)}")
        return row_tuple, description
    finally:
        cursor.close()
def sql_fetchall(self, sql: str) -> List[tuple]:
"""
Only returns the tuples obtained by cursor.fetchall()
"""
return self.sql_fetchall_description(sql)[0]
def execute_query_all(self, sql: str) -> List[tuple]:
    """Run *sql* and return every result row, discarding the cursor
    description."""
    all_rows, _description = self.execute_query_all_columns(sql)
    return all_rows
def sql_fetchall_description(self, sql: str) -> tuple:
"""
Returns a tuple with 2 elements:
1) the tuples obtained by cursor.fetchall()
2) the cursor.description
"""
def execute_query_all_columns(self, sql: str) -> tuple:
    """Execute *sql* and return a 2-tuple of (all rows, cursor.description).

    The span as extracted contained both the pre- and post-commit
    ``logger.debug`` lines (a diff-rendering artifact that would log each
    message twice); only the post-commit lines are kept here.

    The cursor is always closed, even if execution raises.
    """
    cursor = self.connection.cursor()
    try:
        logger.debug(f"SQL query: \n{sql}")
        start = datetime.now()
        cursor.execute(sql)
        rows = cursor.fetchall()
        # Log wall-clock duration of the query for profiling visibility.
        delta = datetime.now() - start
        logger.debug(f"SQL duration {str(delta)}")
        return rows, cursor.description
    finally:
        cursor.close()

View File

@ -0,0 +1,77 @@
from typing import Optional
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Map MySQL column type names onto the profiler's generic type buckets so the
# profiler knows which measurements apply to each column.

# Textual / binary-string types.
register_custom_type(
    ["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"],
    SupportedDataType.TEXT,
)
# Numeric types, including MySQL aliases (DEC, FIXED are synonyms of DECIMAL).
register_custom_type(
    [
        "INTEGER",
        "INT",
        "SMALLINT",
        "TINYINT",
        "MEDIUMINT",
        "BIGINT",
        "DECIMAL",
        "NUMERIC",
        "FLOAT",
        "DOUBLE",
        "REAL",
        "DOUBLE PRECISION",
        "DEC",
        "FIXED",
    ],
    SupportedDataType.NUMERIC,
)
# Date/time types.
register_custom_type(
    ["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"],
    SupportedDataType.TIME,
)
class MySQLConnectionConfig(SQLConnectionConfig):
    # Default host:port of a local MySQL server.
    host_port = "localhost:3306"
    # SQLAlchemy dialect+driver; requires the pymysql package (see setup.py
    # "mysql" extra).
    scheme = "mysql+pymysql"
    service_type = "MySQL"

    def get_connection_url(self):
        # Pure delegation; presumably kept to satisfy an abstract method on
        # SQLConnectionConfig — TODO confirm before removing.
        return super().get_connection_url()
class MySQLExpressions(SQLExpressions):
    # MySQL-specific SQL fragments used by the profiler; the "{}" slots are
    # filled in by the caller (str.format-style templates).
    count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _"
    # MySQL uses the REGEXP operator for pattern matching.
    regex_like_pattern_expr = "{} regexp '{}'"
class MySQL(DatabaseCommon):
    """MySQL backend for the profiler."""

    config: MySQLConnectionConfig = None
    sql_exprs: MySQLExpressions = MySQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor: build a MySQL backend from a raw config dict."""
        config = MySQLConnectionConfig.parse_obj(config_dict)
        return cls(config)

    def table_metadata_query(self, table_name: str) -> str:
        """Return SQL listing (column_name, data_type, is_nullable) for *table_name*.

        Both sides of the name comparison are lower-cased so the lookup is
        case-insensitive, matching the Snowflake implementation.
        """
        # NOTE: names are interpolated directly into the SQL; they come from
        # trusted operator configuration, not end-user input.
        sql = (
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            # Fix: lower-case the Python-side value too — the original compared
            # lower(table_name) against the raw string, so any mixed-case
            # table name could never match.
            f"WHERE lower(table_name) = '{table_name.lower()}'"
        )
        if self.config.database:
            # Fix: the original interpolated self.database while the guard
            # tested self.config.database; use the config value consistently
            # (as the Snowflake backend does).
            sql += f" \n AND table_schema = '{self.config.database}'"
        return sql

View File

@ -0,0 +1,99 @@
from typing import Optional
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Map Snowflake column type names onto the profiler's generic type buckets so
# the profiler knows which measurements apply to each column.

# Textual types (STRING/TEXT are Snowflake synonyms of VARCHAR).
register_custom_type(
    ["VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"],
    SupportedDataType.TEXT,
)
# Numeric types, including Snowflake-specific aliases (BYTEINT, FLOAT4/8).
register_custom_type(
    [
        "NUMBER",
        "INT",
        "INTEGER",
        "BIGINT",
        "SMALLINT",
        "TINYINT",
        "BYTEINT",
        "FLOAT",
        "FLOAT4",
        "FLOAT8",
        "DOUBLE",
        "DOUBLE PRECISION",
        "REAL",
    ],
    SupportedDataType.NUMERIC,
)
# Date/time types, including Snowflake's timezone-variant timestamps
# (_LTZ local, _NTZ no zone, _TZ with zone).
register_custom_type(
    [
        "DATE",
        "DATETIME",
        "TIME",
        "TIMESTAMP",
        "TIMESTAMP_LTZ",
        "TIMESTAMP_NTZ",
        "TIMESTAMP_TZ",
    ],
    SupportedDataType.TIME,
)
class SnowflakeConnectionConfig(SQLConnectionConfig):
    """Connection settings for a Snowflake account.

    ``account`` and ``database`` are mandatory; ``warehouse`` and ``role``
    are appended to the connection URL only when provided.
    """

    scheme = "snowflake"
    account: str
    database: str  # required for Snowflake
    warehouse: Optional[str]
    role: Optional[str]
    duration: Optional[int]
    service_type = "Snowflake"

    def get_connection_url(self):
        """Build the base URL, then append account/warehouse/role as
        query-string parameters (empty values are skipped)."""
        base_url = super().get_connection_url()
        query_parts = []
        for key, value in (
            ("account", self.account),
            ("warehouse", self.warehouse),
            ("role", self.role),
        ):
            if value:
                query_parts.append(f"{key}={value}")
        if query_parts:
            return f"{base_url}?{'&'.join(query_parts)}"
        return base_url
class SnowflakeSQLExpressions(SQLExpressions):
    # Snowflake-specific SQL fragments used by the profiler; the "{}" slots
    # are filled in by the caller (str.format-style templates).
    count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _"
    # Snowflake supports the REGEXP operator for pattern matching.
    regex_like_pattern_expr = "{} regexp '{}'"
class Snowflake(DatabaseCommon):
    """Snowflake backend for the profiler."""

    config: SnowflakeConnectionConfig = None
    sql_exprs: SnowflakeSQLExpressions = SnowflakeSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor: build a Snowflake backend from a raw config dict."""
        parsed = SnowflakeConnectionConfig.parse_obj(config_dict)
        return cls(parsed)

    def table_metadata_query(self, table_name: str) -> str:
        """Return SQL listing (column_name, data_type, is_nullable) for
        *table_name*, optionally scoped to the configured database
        (table_catalog) and schema. All comparisons are case-insensitive."""
        clauses = [
            "SELECT column_name, data_type, is_nullable \n",
            "FROM information_schema.columns \n",
            f"WHERE lower(table_name) = '{table_name.lower()}'",
        ]
        if self.config.database:
            clauses.append(
                f" \n AND lower(table_catalog) = '{self.config.database.lower()}'"
            )
        if self.config.db_schema:
            clauses.append(
                f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'"
            )
        return "".join(clauses)

View File

@ -74,7 +74,7 @@ class Profiler:
def _table_metadata(self):
sql = self.database.table_metadata_query(self.table.name)
columns = self.database.sql_fetchall(sql)
columns = self.database.execute_query_all(sql)
self.queries_executed += 1
self.table_columns = []
for column in columns:
@ -194,7 +194,7 @@ class Profiler:
"SELECT \n " + ",\n ".join(fields) + " \n"
"FROM " + self.qualified_table_name
)
query_result_tuple = self.database.sql_fetchone(sql)
query_result_tuple = self.database.execute_query(sql)
self.queries_executed += 1
for i in range(0, len(measurements)):
@ -305,7 +305,7 @@ class Profiler:
f"FROM group_by_value"
)
query_result_tuple = self.database.sql_fetchone(sql)
query_result_tuple = self.database.execute_query(sql)
self.queries_executed += 1
distinct_count = query_result_tuple[0]
@ -407,7 +407,7 @@ class Profiler:
f"FROM group_by_value"
)
row = self.database.sql_fetchone(sql)
row = self.database.execute_query(sql)
self.queries_executed += 1
# Process the histogram query