Fix #19633: Fix databricks schema not found (#19646)

This commit is contained in:
Mayur Singal 2025-02-04 11:42:11 +05:30 committed by GitHub
parent 636a83514d
commit 208c40be09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 41 deletions

View File

@@ -113,11 +113,6 @@ _type_map.update(
) )
def format_schema_name(schema):
# Adds back quotes(``) if hyphen(-) in schema name
return f"`{schema}`" if "-" in schema else schema
# This method is from hive dialect originally but # This method is from hive dialect originally but
# is overridden to optimize DESCRIBE query execution # is overridden to optimize DESCRIBE query execution
def _get_table_columns(self, connection, table_name, schema, db_name): def _get_table_columns(self, connection, table_name, schema, db_name):
@@ -158,7 +153,6 @@ def _get_table_columns(self, connection, table_name, schema, db_name):
def _get_column_rows(self, connection, table_name, schema, db_name): def _get_column_rows(self, connection, table_name, schema, db_name):
# get columns and strip whitespace # get columns and strip whitespace
schema = format_schema_name(schema=schema)
table_columns = _get_table_columns( # pylint: disable=protected-access table_columns = _get_table_columns( # pylint: disable=protected-access
self, connection, table_name, schema, db_name self, connection, table_name, schema, db_name
) )
@@ -212,11 +206,10 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"system_data_type": raw_col_type, "system_data_type": raw_col_type,
} }
if col_type in {"array", "struct", "map"}: if col_type in {"array", "struct", "map"}:
col_name = f"`{col_name}`" if "." in col_name else col_name
try: try:
rows = dict( rows = dict(
connection.execute( connection.execute(
f"DESCRIBE TABLE {kw.get('db_name')}.{schema}.{table_name} {col_name}" f"DESCRIBE TABLE `{kw.get('db_name')}`.`{schema}`.`{table_name}` `{col_name}`"
).fetchall() ).fetchall()
) )
col_info["system_data_type"] = rows["data_type"] col_info["system_data_type"] = rows["data_type"]
@@ -393,8 +386,7 @@ def get_table_type(self, connection, database, schema, table):
database_name=database, schema_name=schema, table_name=table database_name=database, schema_name=schema, table_name=table
) )
else: else:
schema = format_schema_name(schema=schema) query = f"DESCRIBE TABLE EXTENDED `{schema}`.`{table}`"
query = f"DESCRIBE TABLE EXTENDED {schema}.{table}"
rows = get_table_comment_result( rows = get_table_comment_result(
self, self,
connection=connection, connection=connection,
@@ -761,7 +753,6 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
) -> str: ) -> str:
description = None description = None
try: try:
schema_name = format_schema_name(schema=schema_name)
query = DATABRICKS_GET_TABLE_COMMENTS.format( query = DATABRICKS_GET_TABLE_COMMENTS.format(
database_name=self.context.get().database, database_name=self.context.get().database,
schema_name=schema_name, schema_name=schema_name,
@@ -816,9 +807,7 @@ class DatabricksSource(ExternalTableLineageMixin, CommonDbSourceService, MultiDB
try: try:
query = DATABRICKS_GET_TABLE_COMMENTS.format( query = DATABRICKS_GET_TABLE_COMMENTS.format(
database_name=self.context.get().database, database_name=self.context.get().database,
schema_name=format_schema_name( schema_name=self.context.get().database_schema,
schema=self.context.get().database_schema
),
table_name=table_name, table_name=table_name,
) )
result = self.inspector.dialect.get_table_comment_result( result = self.inspector.dialect.get_table_comment_result(

View File

@@ -25,27 +25,27 @@ DATABRICKS_VIEW_DEFINITIONS = textwrap.dedent(
) )
DATABRICKS_GET_TABLE_COMMENTS = ( DATABRICKS_GET_TABLE_COMMENTS = (
"DESCRIBE TABLE EXTENDED {database_name}.{schema_name}.{table_name}" "DESCRIBE TABLE EXTENDED `{database_name}`.`{schema_name}`.`{table_name}`"
) )
DATABRICKS_GET_CATALOGS = "SHOW CATALOGS" DATABRICKS_GET_CATALOGS = "SHOW CATALOGS"
DATABRICKS_GET_CATALOGS_TAGS = textwrap.dedent( DATABRICKS_GET_CATALOGS_TAGS = textwrap.dedent(
"""SELECT * FROM {database_name}.information_schema.catalog_tags;""" """SELECT * FROM `{database_name}`.information_schema.catalog_tags;"""
) )
DATABRICKS_GET_SCHEMA_TAGS = textwrap.dedent( DATABRICKS_GET_SCHEMA_TAGS = textwrap.dedent(
""" """
SELECT SELECT
* *
FROM {database_name}.information_schema.schema_tags""" FROM `{database_name}`.information_schema.schema_tags"""
) )
DATABRICKS_GET_TABLE_TAGS = textwrap.dedent( DATABRICKS_GET_TABLE_TAGS = textwrap.dedent(
""" """
SELECT SELECT
* *
FROM {database_name}.information_schema.table_tags FROM `{database_name}`.information_schema.table_tags
""" """
) )
@@ -53,8 +53,8 @@ DATABRICKS_GET_COLUMN_TAGS = textwrap.dedent(
""" """
SELECT SELECT
* *
FROM {database_name}.information_schema.column_tags FROM `{database_name}`.information_schema.column_tags
""" """
) )
DATABRICKS_DDL = "SHOW CREATE TABLE {table_name}" DATABRICKS_DDL = "SHOW CREATE TABLE `{table_name}`"

View File

@ -1,21 +0,0 @@
import pytest
from metadata.ingestion.source.database.databricks.metadata import format_schema_name
@pytest.mark.parametrize(
"input_schema, expected_schema",
[
("test_schema-name", "`test_schema-name`"),
("test_schema_name", "test_schema_name"),
("schema-with-hyphen", "`schema-with-hyphen`"),
("schema_with_underscore", "schema_with_underscore"),
("validSchema", "validSchema"),
],
)
def test_schema_name_sanitization(input_schema, expected_schema):
"""
Test sanitization of schema names by adding backticks only around hyphenated names.
"""
sanitized_schema = format_schema_name(input_schema)
assert sanitized_schema == expected_schema