Fix #17778 : Databricks query run optimisation (#18467)

* Fix : Databricks query run optimization

* Fixed dialect error

* fix get columns

* py format

---------

Co-authored-by: ulixius9 <mayursingal9@gmail.com>
This commit is contained in:
Suman Maharana 2024-11-06 10:10:01 +05:30 committed by GitHub
parent fffd5e593e
commit 426ad2000b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 134 additions and 20 deletions

View File

@ -65,6 +65,8 @@ from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
get_all_view_definitions,
get_table_comment_result_wrapper,
get_table_comment_results,
get_view_definition_wrapper,
)
from metadata.utils.tag_utils import get_ometa_tag_and_classification
@ -110,11 +112,50 @@ _type_map.update(
}
)
# This method is from hive dialect originally but
# is overridden to optimize DESCRIBE query execution
def _get_table_columns(self, connection, table_name, schema, db_name):
full_table = table_name
if schema:
full_table = schema + "." + table_name
# TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
query = DATABRICKS_GET_TABLE_COMMENTS.format(
database_name=db_name, schema_name=schema, table_name=table_name
)
cursor = get_table_comment_result(
self,
connection=connection,
query=query,
database=db_name,
table_name=table_name,
schema=schema,
)
def _get_column_rows(self, connection, table_name, schema):
rows = cursor.fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r"TExecuteStatementResp.*SemanticException.*Table not found {}"
regex = regex_fmt.format(re.escape(full_table))
if re.search(regex, e.args[0]):
raise exc.NoSuchTableError(full_table)
else:
raise
else:
# Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist
regex = r"Table .* does not exist"
if len(rows) == 1 and re.match(regex, rows[0].col_name):
raise exc.NoSuchTableError(full_table)
return rows
def _get_column_rows(self, connection, table_name, schema, db_name):
# get columns and strip whitespace
table_columns = self._get_table_columns( # pylint: disable=protected-access
connection, table_name, schema
table_columns = _get_table_columns( # pylint: disable=protected-access
self, connection, table_name, schema, db_name
)
column_rows = [
[col.strip() if col else None for col in row] for row in table_columns
@ -134,7 +175,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
Databricks ingest config file.
"""
rows = _get_column_rows(self, connection, table_name, schema)
rows = _get_column_rows(self, connection, table_name, schema, kw.get("db_name"))
result = []
for col_name, col_type, _comment in rows:
# Handle both oss hive and Databricks' hive partition header, respectively
@ -142,6 +183,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"# Partition Information",
"# Partitioning",
"# Clustering Information",
"# Delta Statistics Columns",
"# Detailed Table Information",
):
break
# Take out the more detailed type information
@ -225,12 +268,18 @@ def get_table_comment( # pylint: disable=unused-argument
"""
Returns comment of table
"""
cursor = connection.execute(
DATABRICKS_GET_TABLE_COMMENTS.format(
query = DATABRICKS_GET_TABLE_COMMENTS.format(
database_name=self.context.get().database,
schema_name=schema_name,
table_name=table_name,
)
cursor = self.get_table_comment_result(
self,
connection=connection,
query=query,
database=self.context.get().database,
table_name=table_name,
schema=schema_name,
)
try:
for result in list(cursor):
@ -258,6 +307,26 @@ def get_view_definition(
return None
@reflection.cache
def get_table_comment_result(
    self,
    connection,
    query,
    database,
    table_name,
    schema=None,
    **kw,  # pylint: disable=unused-argument
):
    """
    Return the result of ``query`` for one table via the dialect's stored results.

    This is a thin dialect-level entry point: all the work is delegated to
    ``get_table_comment_result_wrapper``. Extra keyword arguments are accepted
    but are not forwarded to the wrapper.

    Args:
        connection: connection handed through to the wrapper
        query: statement text handed through to the wrapper
        database: database name handed through to the wrapper
        table_name: table the query refers to
        schema: schema of the table, ``None`` when not given

    Returns:
        Whatever ``get_table_comment_result_wrapper`` returns.
    """
    lookup_args = {
        "query": query,
        "database": database,
        "table_name": table_name,
        "schema": schema,
    }
    return get_table_comment_result_wrapper(self, connection, **lookup_args)
@reflection.cache
def get_table_ddl(
self, connection, table_name, schema=None, **kw
@ -296,7 +365,7 @@ def get_table_names(
table_name = row[0]
if schema:
database = kw.get("db_name")
table_type = get_table_type(connection, database, schema, table_name)
table_type = get_table_type(self, connection, database, schema, table_name)
if not table_type or table_type == "FOREIGN":
# skip the table if it's foreign table / error in fetching table_type
logger.debug(
@ -311,7 +380,7 @@ def get_table_names(
return [table for table in tables if table not in views]
def get_table_type(connection, database, schema, table):
def get_table_type(self, connection, database, schema, table):
"""get table type (regular/foreign)"""
try:
if database:
@ -320,7 +389,14 @@ def get_table_type(connection, database, schema, table):
)
else:
query = f"DESCRIBE TABLE EXTENDED {schema}.{table}"
rows = connection.execute(query)
rows = get_table_comment_result(
self,
connection=connection,
query=query,
database=database,
table_name=table,
schema=schema,
)
for row in rows:
row_dict = dict(row)
if row_dict.get("col_name") == "Type":
@ -338,6 +414,8 @@ DatabricksDialect.get_schema_names = get_schema_names
DatabricksDialect.get_view_definition = get_view_definition
DatabricksDialect.get_table_names = get_table_names
DatabricksDialect.get_all_view_definitions = get_all_view_definitions
DatabricksDialect.get_table_comment_results = get_table_comment_results
DatabricksDialect.get_table_comment_result = get_table_comment_result
reflection.Inspector.get_schema_names = get_schema_names_reflection
reflection.Inspector.get_table_ddl = get_table_ddl
@ -677,12 +755,17 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
) -> str:
description = None
try:
cursor = self.connection.execute(
DATABRICKS_GET_TABLE_COMMENTS.format(
query = DATABRICKS_GET_TABLE_COMMENTS.format(
database_name=self.context.get().database,
schema_name=schema_name,
table_name=table_name,
)
cursor = inspector.dialect.get_table_comment_result(
connection=self.connection,
query=query,
database=self.context.get().database,
table_name=table_name,
schema=schema_name,
)
for result in list(cursor):
data = result.values()
@ -729,7 +812,13 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
schema_name=self.context.get().database_schema,
table_name=table_name,
)
result = self.connection.engine.execute(query)
result = self.inspector.dialect.get_table_comment_result(
connection=self.connection,
query=query,
database=self.context.get().database,
table_name=table_name,
schema=self.context.get().database_schema,
)
owner = None
for row in result:
row_dict = dict(row)

View File

@ -169,3 +169,28 @@ def get_table_ddl(
table_name=table_name,
schema=schema,
)
@reflection.cache
def get_table_comment_results(
    self, connection, query, database, table_name, schema=None
):
    """
    Execute ``query`` and store its result on the dialect for one table.

    Every call starts from an empty ``self.table_comment_result`` mapping, so
    after the call the mapping holds exactly one entry, keyed by
    ``(table_name, schema)``. ``self.current_db`` is set to ``database``.
    Nothing is returned; callers read the mapping afterwards.
    """
    # Reset the stored state before running the query, so a failing execute
    # leaves an empty mapping behind rather than entries from an earlier call.
    self.table_comment_result: Dict[Tuple[str, str], str] = {}
    self.current_db: str = database
    cache_key = (table_name, schema)
    # NOTE(review): the stored value is whatever connection.execute returns,
    # not a str as the annotation above suggests -- confirm and adjust the hint.
    self.table_comment_result.update({cache_key: connection.execute(query)})
def get_table_comment_result_wrapper(
    self, connection, query, database, table_name, schema=None
):
    """
    Return the stored result for ``(table_name, schema)``, loading it if needed.

    ``self.get_table_comment_results`` is called first when any of these hold:
    the dialect has no ``table_comment_result`` attribute yet, the mapping has
    no non-``None`` entry for the key, or ``self.current_db`` differs from
    ``database``. The value is then read back from the mapping.
    """
    cache_key = (table_name, schema)
    already_loaded = (
        hasattr(self, "table_comment_result")
        and self.table_comment_result.get(cache_key) is not None
        and self.current_db == database
    )
    if not already_loaded:
        self.get_table_comment_results(connection, query, database, table_name, schema)
    return self.table_comment_result.get(cache_key)