diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py b/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
index a0d748f1e08..b288cb00d61 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
@@ -65,6 +65,8 @@ from metadata.utils.filters import filter_by_database
 from metadata.utils.logger import ingestion_logger
 from metadata.utils.sqlalchemy_utils import (
     get_all_view_definitions,
+    get_table_comment_result_wrapper,
+    get_table_comment_results,
     get_view_definition_wrapper,
 )
 from metadata.utils.tag_utils import get_ometa_tag_and_classification
@@ -110,11 +112,50 @@ _type_map.update(
     }
 )
 
+# This method is from hive dialect originally but
+# is overridden to optimize DESCRIBE query execution
+def _get_table_columns(self, connection, table_name, schema, db_name):
+    full_table = table_name
+    if schema:
+        full_table = schema + "." + table_name
+    # TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
+    # Using DESCRIBE works but is uglier.
+    try:
+        # This needs the table name to be unescaped (no backticks).
+        query = DATABRICKS_GET_TABLE_COMMENTS.format(
+            database_name=db_name, schema_name=schema, table_name=table_name
+        )
+        cursor = get_table_comment_result(
+            self,
+            connection=connection,
+            query=query,
+            database=db_name,
+            table_name=table_name,
+            schema=schema,
+        )
 
-def _get_column_rows(self, connection, table_name, schema):
+        rows = cursor.fetchall()
+
+    except exc.OperationalError as e:
+        # Does the table exist?
+        regex_fmt = r"TExecuteStatementResp.*SemanticException.*Table not found {}"
+        regex = regex_fmt.format(re.escape(full_table))
+        if re.search(regex, e.args[0]):
+            raise exc.NoSuchTableError(full_table)
+        else:
+            raise
+    else:
+        # Hive quirk: DESCRIBE on a missing table returns a "Table ... does not exist" row
+        regex = r"Table .* does not exist"
+        if len(rows) == 1 and re.match(regex, rows[0].col_name):
+            raise exc.NoSuchTableError(full_table)
+        return rows
+
+
+def _get_column_rows(self, connection, table_name, schema, db_name):
     # get columns and strip whitespace
-    table_columns = self._get_table_columns(  # pylint: disable=protected-access
-        connection, table_name, schema
+    table_columns = _get_table_columns(  # pylint: disable=protected-access
+        self, connection, table_name, schema, db_name
     )
     column_rows = [
         [col.strip() if col else None for col in row] for row in table_columns
@@ -134,7 +175,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
 
         Databricks ingest config file.
     """
-    rows = _get_column_rows(self, connection, table_name, schema)
+    rows = _get_column_rows(self, connection, table_name, schema, kw.get("db_name"))
     result = []
     for col_name, col_type, _comment in rows:
         # Handle both oss hive and Databricks' hive partition header, respectively
@@ -142,6 +183,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
             "# Partition Information",
             "# Partitioning",
             "# Clustering Information",
+            "# Delta Statistics Columns",
+            "# Detailed Table Information",
         ):
             break
         # Take out the more detailed type information
@@ -225,12 +268,18 @@ def get_table_comment(  # pylint: disable=unused-argument
     """
     Returns comment of table
     """
-    cursor = connection.execute(
-        DATABRICKS_GET_TABLE_COMMENTS.format(
-            database_name=self.context.get().database,
-            schema_name=schema_name,
-            table_name=table_name,
-        )
+    query = DATABRICKS_GET_TABLE_COMMENTS.format(
+        database_name=self.context.get().database,
+        schema_name=schema_name,
+        table_name=table_name,
+    )
+    cursor = self.get_table_comment_result(
+        self,
+        connection=connection,
+        query=query,
+        database=self.context.get().database,
+        table_name=table_name,
+        schema=schema_name,
     )
     try:
         for result in list(cursor):
@@ -258,6 +307,26 @@ def get_view_definition(
     return None
 
 
+@reflection.cache
+def get_table_comment_result(
+    self,
+    connection,
+    query,
+    database,
+    table_name,
+    schema=None,
+    **kw,  # pylint: disable=unused-argument
+):
+    return get_table_comment_result_wrapper(
+        self,
+        connection,
+        query=query,
+        database=database,
+        table_name=table_name,
+        schema=schema,
+    )
+
+
 @reflection.cache
 def get_table_ddl(
     self, connection, table_name, schema=None, **kw
@@ -296,7 +365,7 @@
         table_name = row[0]
         if schema:
             database = kw.get("db_name")
-            table_type = get_table_type(connection, database, schema, table_name)
+            table_type = get_table_type(self, connection, database, schema, table_name)
             if not table_type or table_type == "FOREIGN":
                 # skip the table if it's foreign table / error in fetching table_type
                 logger.debug(
@@ -311,7 +380,7 @@
     return [table for table in tables if table not in views]
 
 
-def get_table_type(connection, database, schema, table):
+def get_table_type(self, connection, database, schema, table):
     """get table type (regular/foreign)"""
     try:
         if database:
@@ -320,7 +389,14 @@
             )
         else:
             query = f"DESCRIBE TABLE EXTENDED {schema}.{table}"
-        rows = connection.execute(query)
+        rows = get_table_comment_result(
+            self,
+            connection=connection,
+            query=query,
+            database=database,
+            table_name=table,
+            schema=schema,
+        )
         for row in rows:
             row_dict = dict(row)
             if row_dict.get("col_name") == "Type":
@@ -338,6 +414,8 @@ DatabricksDialect.get_schema_names = get_schema_names
 DatabricksDialect.get_view_definition = get_view_definition
 DatabricksDialect.get_table_names = get_table_names
 DatabricksDialect.get_all_view_definitions = get_all_view_definitions
+DatabricksDialect.get_table_comment_results = get_table_comment_results
+DatabricksDialect.get_table_comment_result = get_table_comment_result
 
 reflection.Inspector.get_schema_names = get_schema_names_reflection
 reflection.Inspector.get_table_ddl = get_table_ddl
@@ -677,12 +755,17 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
     ) -> str:
         description = None
         try:
-            cursor = self.connection.execute(
-                DATABRICKS_GET_TABLE_COMMENTS.format(
-                    database_name=self.context.get().database,
-                    schema_name=schema_name,
-                    table_name=table_name,
-                )
+            query = DATABRICKS_GET_TABLE_COMMENTS.format(
+                database_name=self.context.get().database,
+                schema_name=schema_name,
+                table_name=table_name,
+            )
+            cursor = inspector.dialect.get_table_comment_result(
+                connection=self.connection,
+                query=query,
+                database=self.context.get().database,
+                table_name=table_name,
+                schema=schema_name,
             )
             for result in list(cursor):
                 data = result.values()
@@ -729,7 +812,13 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
                 schema_name=self.context.get().database_schema,
                 table_name=table_name,
             )
-            result = self.connection.engine.execute(query)
+            result = self.inspector.dialect.get_table_comment_result(
+                connection=self.connection,
+                query=query,
+                database=self.context.get().database,
+                table_name=table_name,
+                schema=self.context.get().database_schema,
+            )
             owner = None
             for row in result:
                 row_dict = dict(row)
diff --git a/ingestion/src/metadata/utils/sqlalchemy_utils.py b/ingestion/src/metadata/utils/sqlalchemy_utils.py
index 642248ede87..1c1e33e7687 100644
--- a/ingestion/src/metadata/utils/sqlalchemy_utils.py
+++ b/ingestion/src/metadata/utils/sqlalchemy_utils.py
@@ -169,3 +169,28 @@
         table_name=table_name,
         schema=schema,
     )
+
+
+@reflection.cache
+def get_table_comment_results(
+    self, connection, query, database, table_name, schema=None
+):
+    """
+    Run the table-comment query and cache its result keyed by (table_name, schema)
+    """
+    self.table_comment_result: Dict[Tuple[str, str], str] = {}
+    self.current_db: str = database
+    result = connection.execute(query)
+    self.table_comment_result[(table_name, schema)] = result
+
+
+def get_table_comment_result_wrapper(
+    self, connection, query, database, table_name, schema=None
+):
+    if (
+        not hasattr(self, "table_comment_result")
+        or self.table_comment_result.get((table_name, schema)) is None
+        or self.current_db != database
+    ):
+        self.get_table_comment_results(connection, query, database, table_name, schema)
+    return self.table_comment_result.get((table_name, schema))