diff --git a/profiler/setup.py b/profiler/setup.py index b838272b676..dbe4893fd38 100644 --- a/profiler/setup.py +++ b/profiler/setup.py @@ -58,6 +58,14 @@ plugins: Dict[str, Set[str]] = { "GeoAlchemy2", }, "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"}, + "mysql": {"pymysql>=1.0.2"}, + "snowflake": {"snowflake-sqlalchemy<=1.2.4"}, + "hive": { + "openmetadata-sqlalchemy-hive==0.2.0", + "thrift~=0.13.0", + "sasl==0.3.1", + "thrift-sasl==0.4.3", + }, } build_options = {"includes": ["_cffi_backend"]} diff --git a/profiler/src/openmetadata/common/config.py b/profiler/src/openmetadata/common/config.py index c3b2d520731..719deb692dd 100644 --- a/profiler/src/openmetadata/common/config.py +++ b/profiler/src/openmetadata/common/config.py @@ -10,6 +10,7 @@ import io import json import pathlib +import re from abc import ABC, abstractmethod from typing import IO, Any, List, Optional diff --git a/profiler/src/openmetadata/common/database.py b/profiler/src/openmetadata/common/database.py index b7ff945cd17..eed87d7dd6e 100644 --- a/profiler/src/openmetadata/common/database.py +++ b/profiler/src/openmetadata/common/database.py @@ -56,17 +56,17 @@ class Database(Closeable, metaclass=ABCMeta): pass @abstractmethod - def sql_fetchone(self, sql) -> tuple: + def execute_query(self, sql) -> tuple: pass @abstractmethod - def sql_fetchone_description(self, sql) -> tuple: + def execute_query_columns(self, sql) -> tuple: pass @abstractmethod - def sql_fetchall(self, sql) -> List[tuple]: + def execute_query_all(self, sql) -> List[tuple]: pass @abstractmethod - def sql_fetchall_description(self, sql) -> tuple: + def execute_query_all_columns(self, sql) -> tuple: pass diff --git a/profiler/src/openmetadata/common/database_common.py b/profiler/src/openmetadata/common/database_common.py index 8c9ea2b4c98..52a65b8932a 100644 --- a/profiler/src/openmetadata/common/database_common.py +++ b/profiler/src/openmetadata/common/database_common.py @@ -271,51 +271,35 @@ class 
DatabaseCommon(Database): def is_time(self, column_type: str): return column_type.upper() in _time_types - def sql_fetchone(self, sql: str) -> tuple: - """ - Only returns the tuple obtained by cursor.fetchone() - """ - return self.sql_fetchone_description(sql)[0] + def execute_query(self, sql: str) -> tuple: + return self.execute_query_columns(sql)[0] - def sql_fetchone_description(self, sql: str) -> tuple: - """ - Returns a tuple with 2 elements: - 1) the tuple obtained by cursor.fetchone() - 2) the cursor.description - """ + def execute_query_columns(self, sql: str) -> tuple: cursor = self.connection.cursor() try: - logger.debug(f"Executing SQL query: \n{sql}") + logger.debug(f"SQL query: \n{sql}") start = datetime.now() cursor.execute(sql) row_tuple = cursor.fetchone() description = cursor.description delta = datetime.now() - start - logger.debug(f"SQL took {str(delta)}") + logger.debug(f"SQL duration {str(delta)}") return row_tuple, description finally: cursor.close() - def sql_fetchall(self, sql: str) -> List[tuple]: - """ - Only returns the tuples obtained by cursor.fetchall() - """ - return self.sql_fetchall_description(sql)[0] + def execute_query_all(self, sql: str) -> List[tuple]: + return self.execute_query_all_columns(sql)[0] - def sql_fetchall_description(self, sql: str) -> tuple: - """ - Returns a tuple with 2 elements: - 1) the tuples obtained by cursor.fetchall() - 2) the cursor.description - """ + def execute_query_all_columns(self, sql: str) -> tuple: cursor = self.connection.cursor() try: - logger.debug(f"Executing SQL query: \n{sql}") + logger.debug(f"SQL query: \n{sql}") start = datetime.now() cursor.execute(sql) rows = cursor.fetchall() delta = datetime.now() - start - logger.debug(f"SQL took {str(delta)}") + logger.debug(f"SQL duration {str(delta)}") return rows, cursor.description finally: cursor.close() diff --git a/profiler/src/openmetadata/databases/mysql.py b/profiler/src/openmetadata/databases/mysql.py new file mode 100644 index 
00000000000..472104b7af3 --- /dev/null +++ b/profiler/src/openmetadata/databases/mysql.py @@ -0,0 +1,77 @@ +from typing import Optional + +from openmetadata.common.database_common import ( + DatabaseCommon, + SQLConnectionConfig, + SQLExpressions, + register_custom_type, +) +from openmetadata.profiler.profiler_metadata import SupportedDataType + +register_custom_type( + ["CHAR", "VARCHAR", "BINARY", "VARBINARY", "BLOB", "TEXT", "ENUM", "SET"], + SupportedDataType.TEXT, +) + +register_custom_type( + [ + "INTEGER", + "INT", + "SMALLINT", + "TINYINT", + "MEDIUMINT", + "BIGINT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DOUBLE", + "REAL", + "DOUBLE PRECISION", + "DEC", + "FIXED", + ], + SupportedDataType.NUMERIC, +) + +register_custom_type( + ["TIMESTAMP", "DATE", "DATETIME", "YEAR", "TIME"], + SupportedDataType.TIME, +) + + +class MySQLConnectionConfig(SQLConnectionConfig): + host_port = "localhost:3306" + scheme = "mysql+pymysql" + service_type = "MySQL" + + def get_connection_url(self): + return super().get_connection_url() + + +class MySQLExpressions(SQLExpressions): + count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _" + regex_like_pattern_expr = "{} regexp '{}'" + + +class MySQL(DatabaseCommon): + config: MySQLConnectionConfig = None + sql_exprs: MySQLExpressions = MySQLExpressions() + + def __init__(self, config): + super().__init__(config) + self.config = config + + @classmethod + def create(cls, config_dict): + config = MySQLConnectionConfig.parse_obj(config_dict) + return cls(config) + + def table_metadata_query(self, table_name: str) -> str: + sql = ( + f"SELECT column_name, data_type, is_nullable \n" + f"FROM information_schema.columns \n" + f"WHERE lower(table_name) = '{table_name.lower()}'" + ) + if self.config.database: + sql += f" \n AND table_schema = '{self.config.database}'" + return sql diff --git a/profiler/src/openmetadata/databases/snowflake.py b/profiler/src/openmetadata/databases/snowflake.py new file mode 100644 index 00000000000..b0e0f0df343 ---
/dev/null +++ b/profiler/src/openmetadata/databases/snowflake.py @@ -0,0 +1,99 @@ +from typing import Optional + +from openmetadata.common.database_common import ( + DatabaseCommon, + SQLConnectionConfig, + SQLExpressions, + register_custom_type, +) +from openmetadata.profiler.profiler_metadata import SupportedDataType + +register_custom_type( + ["VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"], + SupportedDataType.TEXT, +) + +register_custom_type( + [ + "NUMBER", + "INT", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "BYTEINT", + "FLOAT", + "FLOAT4", + "FLOAT8", + "DOUBLE", + "DOUBLE PRECISION", + "REAL", + ], + SupportedDataType.NUMERIC, +) + +register_custom_type( + [ + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + ], + SupportedDataType.TIME, +) + + +class SnowflakeConnectionConfig(SQLConnectionConfig): + scheme = "snowflake" + account: str + database: str # database is required + warehouse: Optional[str] + role: Optional[str] + duration: Optional[int] + service_type = "Snowflake" + + def get_connection_url(self): + connect_string = super().get_connection_url() + options = { + "account": self.account, + "warehouse": self.warehouse, + "role": self.role, + } + params = "&".join(f"{key}={value}" for (key, value) in options.items() if value) + if params: + connect_string = f"{connect_string}?{params}" + return connect_string + + +class SnowflakeSQLExpressions(SQLExpressions): + count_conditional_expr = "COUNT(CASE WHEN {} THEN 1 END) AS _" + regex_like_pattern_expr = "{} regexp '{}'" + + +class Snowflake(DatabaseCommon): + config: SnowflakeConnectionConfig = None + sql_exprs: SnowflakeSQLExpressions = SnowflakeSQLExpressions() + + def __init__(self, config): + super().__init__(config) + self.config = config + + @classmethod + def create(cls, config_dict): + config = SnowflakeConnectionConfig.parse_obj(config_dict) + return cls(config) + + def table_metadata_query(self, table_name: str) -> str: + 
sql = ( + f"SELECT column_name, data_type, is_nullable \n" + f"FROM information_schema.columns \n" + f"WHERE lower(table_name) = '{table_name.lower()}'" + ) + if self.config.database: + sql += f" \n AND lower(table_catalog) = '{self.config.database.lower()}'" + if self.config.db_schema: + sql += f" \n AND lower(table_schema) = '{self.config.db_schema.lower()}'" + return sql diff --git a/profiler/src/openmetadata/profiler/profiler.py b/profiler/src/openmetadata/profiler/profiler.py index ae9e8858405..7c41252fd63 100644 --- a/profiler/src/openmetadata/profiler/profiler.py +++ b/profiler/src/openmetadata/profiler/profiler.py @@ -74,7 +74,7 @@ class Profiler: def _table_metadata(self): sql = self.database.table_metadata_query(self.table.name) - columns = self.database.sql_fetchall(sql) + columns = self.database.execute_query_all(sql) self.queries_executed += 1 self.table_columns = [] for column in columns: @@ -194,7 +194,7 @@ class Profiler: "SELECT \n " + ",\n ".join(fields) + " \n" "FROM " + self.qualified_table_name ) - query_result_tuple = self.database.sql_fetchone(sql) + query_result_tuple = self.database.execute_query(sql) self.queries_executed += 1 for i in range(0, len(measurements)): @@ -305,7 +305,7 @@ class Profiler: f"FROM group_by_value" ) - query_result_tuple = self.database.sql_fetchone(sql) + query_result_tuple = self.database.execute_query(sql) self.queries_executed += 1 distinct_count = query_result_tuple[0] @@ -407,7 +407,7 @@ class Profiler: f"FROM group_by_value" ) - row = self.database.sql_fetchone(sql) + row = self.database.execute_query(sql) self.queries_executed += 1 # Process the histogram query