Fix #946: Add support for MySQL and Snowflake for profiler (#947)

This commit is contained in:
Sriharsha Chintalapani 2021-10-27 05:37:34 -07:00 committed by GitHub
parent d9a35e4b5c
commit 47d3d5c77a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 203 additions and 34 deletions

View File

@ -58,6 +58,14 @@ plugins: Dict[str, Set[str]] = {
"GeoAlchemy2", "GeoAlchemy2",
}, },
"postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"}, "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
"mysql": {"pymysql>=1.0.2"},
"snowflake": {"snowflake-sqlalchemy<=1.2.4"},
"hive": {
"openmetadata-sqlalchemy-hive==0.2.0",
"thrift~=0.13.0",
"sasl==0.3.1",
"thrift-sasl==0.4.3",
},
} }
build_options = {"includes": ["_cffi_backend"]} build_options = {"includes": ["_cffi_backend"]}

View File

@ -10,6 +10,7 @@
import io import io
import json import json
import pathlib import pathlib
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import IO, Any, List, Optional from typing import IO, Any, List, Optional

View File

@ -56,17 +56,17 @@ class Database(Closeable, metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def sql_fetchone(self, sql) -> tuple: def execute_query(self, sql) -> tuple:
pass pass
@abstractmethod @abstractmethod
def sql_fetchone_description(self, sql) -> tuple: def execute_query_columns(self, sql) -> tuple:
pass pass
@abstractmethod @abstractmethod
def sql_fetchall(self, sql) -> List[tuple]: def execute_query_all(self, sql) -> List[tuple]:
pass pass
@abstractmethod @abstractmethod
def sql_fetchall_description(self, sql) -> tuple: def execute_query_all_columns(self, sql) -> tuple:
pass pass

View File

@ -271,51 +271,35 @@ class DatabaseCommon(Database):
def is_time(self, column_type: str): def is_time(self, column_type: str):
return column_type.upper() in _time_types return column_type.upper() in _time_types
def sql_fetchone(self, sql: str) -> tuple: def execute_query(self, sql: str) -> tuple:
""" return self.execute_query_columns(sql)[0]
Only returns the tuple obtained by cursor.fetchone()
"""
return self.sql_fetchone_description(sql)[0]
def sql_fetchone_description(self, sql: str) -> tuple: def execute_query_columns(self, sql: str) -> tuple:
"""
Returns a tuple with 2 elements:
1) the tuple obtained by cursor.fetchone()
2) the cursor.description
"""
cursor = self.connection.cursor() cursor = self.connection.cursor()
try: try:
logger.debug(f"Executing SQL query: \n{sql}") logger.debug(f"SQL query: \n{sql}")
start = datetime.now() start = datetime.now()
cursor.execute(sql) cursor.execute(sql)
row_tuple = cursor.fetchone() row_tuple = cursor.fetchone()
description = cursor.description description = cursor.description
delta = datetime.now() - start delta = datetime.now() - start
logger.debug(f"SQL took {str(delta)}") logger.debug(f"SQL duration {str(delta)}")
return row_tuple, description return row_tuple, description
finally: finally:
cursor.close() cursor.close()
def sql_fetchall(self, sql: str) -> List[tuple]: def execute_query_all(self, sql: str) -> List[tuple]:
""" return self.execute_query_all_columns(sql)[0]
Only returns the tuples obtained by cursor.fetchall()
"""
return self.sql_fetchall_description(sql)[0]
def sql_fetchall_description(self, sql: str) -> tuple: def execute_query_all_columns(self, sql: str) -> tuple:
"""
Returns a tuple with 2 elements:
1) the tuples obtained by cursor.fetchall()
2) the cursor.description
"""
cursor = self.connection.cursor() cursor = self.connection.cursor()
try: try:
logger.debug(f"Executing SQL query: \n{sql}") logger.debug(f"SQL query: \n{sql}")
start = datetime.now() start = datetime.now()
cursor.execute(sql) cursor.execute(sql)
rows = cursor.fetchall() rows = cursor.fetchall()
delta = datetime.now() - start delta = datetime.now() - start
logger.debug(f"SQL took {str(delta)}") logger.debug(f"SQL duration {str(delta)}")
return rows, cursor.description return rows, cursor.description
finally: finally:
cursor.close() cursor.close()

View File

@@ -0,0 +1,77 @@
from typing import Optional
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Map MySQL's string and binary column types onto the profiler's generic
# TEXT category so text-oriented metrics are computed for them.
register_custom_type(
    ["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"],
    SupportedDataType.TEXT,
)

# Numeric family, including MySQL-specific aliases (DEC/FIXED are synonyms
# for DECIMAL; MEDIUMINT is MySQL-only).
register_custom_type(
    [
        "INTEGER",
        "INT",
        "SMALLINT",
        "TINYINT",
        "MEDIUMINT",
        "BIGINT",
        "DECIMAL",
        "NUMERIC",
        "FLOAT",
        "DOUBLE",
        "REAL",
        "DOUBLE PRECISION",
        "DEC",
        "FIXED",
    ],
    SupportedDataType.NUMERIC,
)

# Temporal types (YEAR is MySQL-specific).
register_custom_type(
    ["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"],
    SupportedDataType.TIME,
)
class MySQLConnectionConfig(SQLConnectionConfig):
    """Connection settings for profiling a MySQL service.

    Defaults target a local server via the pymysql driver; everything else
    (user, password, database, ...) comes from the SQLConnectionConfig base.
    """

    host_port = "localhost:3306"
    scheme = "mysql+pymysql"
    service_type = "MySQL"

    def get_connection_url(self):
        """Build the SQLAlchemy URL; the base implementation needs no MySQL extras."""
        url = super().get_connection_url()
        return url
class MySQLExpressions(SQLExpressions):
    """MySQL-dialect overrides for the SQL fragments the profiler emits."""

    # COUNT over a CASE so only rows satisfying the interpolated condition
    # are counted; the alias "_" keeps the column anonymous.
    count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _"
    # MySQL spells regex matching with the REGEXP operator.
    regex_like_pattern_expr = "{} regexp '{}'"
class MySQL(DatabaseCommon):
    """Profiler database backend for MySQL."""

    config: MySQLConnectionConfig = None
    sql_exprs: MySQLExpressions = MySQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor: build a MySQL backend from a raw config mapping."""
        config = MySQLConnectionConfig.parse_obj(config_dict)
        return cls(config)

    def table_metadata_query(self, table_name: str) -> str:
        """Return SQL that fetches (column_name, data_type, is_nullable) for *table_name*.

        The comparison lower-cases both sides: the query wraps the catalog
        column in lower(), so the Python-side value must be lowered too —
        otherwise a mixed-case table name could never match (this mirrors the
        Snowflake implementation).
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name.lower()}'"
        )
        if self.config.database:
            # BUG FIX: was `self.database`, which is not an attribute of this
            # class and raised AttributeError whenever a database was set; the
            # guard above already reads it from self.config. In MySQL's
            # information_schema, table_schema holds the database name.
            sql += f" \n AND table_schema = '{self.config.database}'"
        return sql

View File

@@ -0,0 +1,99 @@
from typing import Optional
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Map Snowflake's string types onto the profiler's generic TEXT category
# (STRING/TEXT are Snowflake synonyms for VARCHAR).
register_custom_type(
    ["VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"],
    SupportedDataType.TEXT,
)

# Numeric family; in Snowflake the integer aliases (INT, BIGINT, TINYINT,
# BYTEINT, ...) are all synonyms for NUMBER, and the float aliases for FLOAT.
register_custom_type(
    [
        "NUMBER",
        "INT",
        "INTEGER",
        "BIGINT",
        "SMALLINT",
        "TINYINT",
        "BYTEINT",
        "FLOAT",
        "FLOAT4",
        "FLOAT8",
        "DOUBLE",
        "DOUBLE PRECISION",
        "REAL",
    ],
    SupportedDataType.NUMERIC,
)

# Temporal types, including Snowflake's timezone-variant timestamps
# (LTZ = local tz, NTZ = no tz, TZ = with tz).
register_custom_type(
    [
        "DATE",
        "DATETIME",
        "TIME",
        "TIMESTAMP",
        "TIMESTAMP_LTZ",
        "TIMESTAMP_NTZ",
        "TIMESTAMP_TZ",
    ],
    SupportedDataType.TIME,
)
class SnowflakeConnectionConfig(SQLConnectionConfig):
    """Connection settings for profiling a Snowflake service."""

    scheme = "snowflake"
    account: str
    database: str  # database is required
    warehouse: Optional[str]
    role: Optional[str]
    duration: Optional[int]
    service_type = "Snowflake"

    def get_connection_url(self):
        """Build the SQLAlchemy URL, appending account/warehouse/role as query params.

        Unset (falsy) options are omitted; with no options at all the base
        URL is returned untouched.
        """
        base_url = super().get_connection_url()
        query_parts = []
        for key, value in (
            ("account", self.account),
            ("warehouse", self.warehouse),
            ("role", self.role),
        ):
            if value:
                query_parts.append(f"{key}={value}")
        if query_parts:
            return f"{base_url}?{'&'.join(query_parts)}"
        return base_url
class SnowflakeSQLExpressions(SQLExpressions):
    """Snowflake-dialect overrides for the SQL fragments the profiler emits."""

    # COUNT over a CASE so only rows satisfying the interpolated condition
    # are counted; the alias "_" keeps the column anonymous.
    count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _"
    # Snowflake supports the REGEXP operator for regex matching.
    regex_like_pattern_expr = "{} regexp '{}'"
class Snowflake(DatabaseCommon):
    """Profiler database backend for Snowflake."""

    config: SnowflakeConnectionConfig = None
    sql_exprs: SnowflakeSQLExpressions = SnowflakeSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor: build a Snowflake backend from a raw config mapping."""
        parsed = SnowflakeConnectionConfig.parse_obj(config_dict)
        return cls(parsed)

    def table_metadata_query(self, table_name: str) -> str:
        """Return SQL that fetches (column_name, data_type, is_nullable) for *table_name*.

        Both sides of every comparison are lower-cased so matching is
        case-insensitive; catalog and schema filters are added only when
        configured.
        """
        parts = [
            "SELECT column_name, data_type, is_nullable \n"
            "FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name.lower()}'"
        ]
        if self.config.database:
            parts.append(
                f" \n AND lower(table_catalog) = '{self.config.database.lower()}'"
            )
        if self.config.db_schema:
            parts.append(
                f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'"
            )
        return "".join(parts)

View File

@ -74,7 +74,7 @@ class Profiler:
def _table_metadata(self): def _table_metadata(self):
sql = self.database.table_metadata_query(self.table.name) sql = self.database.table_metadata_query(self.table.name)
columns = self.database.sql_fetchall(sql) columns = self.database.execute_query_all(sql)
self.queries_executed += 1 self.queries_executed += 1
self.table_columns = [] self.table_columns = []
for column in columns: for column in columns:
@ -194,7 +194,7 @@ class Profiler:
"SELECT \n " + ",\n ".join(fields) + " \n" "SELECT \n " + ",\n ".join(fields) + " \n"
"FROM " + self.qualified_table_name "FROM " + self.qualified_table_name
) )
query_result_tuple = self.database.sql_fetchone(sql) query_result_tuple = self.database.execute_query(sql)
self.queries_executed += 1 self.queries_executed += 1
for i in range(0, len(measurements)): for i in range(0, len(measurements)):
@ -305,7 +305,7 @@ class Profiler:
f"FROM group_by_value" f"FROM group_by_value"
) )
query_result_tuple = self.database.sql_fetchone(sql) query_result_tuple = self.database.execute_query(sql)
self.queries_executed += 1 self.queries_executed += 1
distinct_count = query_result_tuple[0] distinct_count = query_result_tuple[0]
@ -407,7 +407,7 @@ class Profiler:
f"FROM group_by_value" f"FROM group_by_value"
) )
row = self.database.sql_fetchone(sql) row = self.database.execute_query(sql)
self.queries_executed += 1 self.queries_executed += 1
# Process the histogram query # Process the histogram query