Fixes: snowflake lowercase issue (#19486)

This commit is contained in:
Akash Verma 2025-02-04 17:47:28 +05:30 committed by GitHub
parent cd03a60b74
commit 8dba4a5ac2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,10 +12,12 @@
"""
Module to define overriden dialect methods
"""
import operator
from functools import reduce
from typing import Dict, Optional
import sqlalchemy.types as sqltypes
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
from sqlalchemy.engine import reflection
@ -52,6 +54,7 @@ from metadata.utils.sqlalchemy_utils import (
get_table_comment_wrapper,
)
dialect = SnowflakeDialect()
Query = str
QueryMap = Dict[str, Query]
@ -83,6 +86,20 @@ VIEW_QUERY_MAPS = {
}
def _denormalize_quote_join(*idents):
ip = dialect.identifier_preparer
split_idents = reduce(
operator.add,
[ip._split_schema_by_dot(ids) for ids in idents if ids is not None],
)
quoted_identifiers = ip._quote_free_identifiers(*split_idents)
normalized_identifiers = (
item if item.startswith('"') and item.endswith('"') else f'"{item}"'
for item in quoted_identifiers
)
return ".".join(normalized_identifiers)
def _quoted_name(entity_name: Optional[str]) -> Optional[str]:
if entity_name:
return fqn.quote_name(entity_name)
@ -256,17 +273,16 @@ def get_schema_columns(self, connection, schema, **kw):
None, as it is cacheable and is an unexpected return type for this function"""
ans = {}
current_database, _ = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
current_database, fqn.quote_name(schema)
)
full_schema_name = _denormalize_quote_join(current_database, fqn.quote_name(schema))
try:
schema_primary_keys = self._get_schema_primary_keys(
connection, full_schema_name, **kw
)
# removing " " from schema name because schema name is in the WHERE clause of a query
table_schema = self.denormalize_name(fqn.unquote_name(schema))
table_schema = table_schema.lower() if schema.islower() else table_schema
result = connection.execute(
text(SNOWFLAKE_GET_SCHEMA_COLUMNS),
{"table_schema": self.denormalize_name(fqn.unquote_name(schema))}
# removing " " from schema name because schema name is in the WHERE clause of a query
text(SNOWFLAKE_GET_SCHEMA_COLUMNS), {"table_schema": table_schema}
)
except sa_exc.ProgrammingError as p_err:
@ -362,9 +378,10 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)
return self._get_schema_primary_keys(
connection, self.denormalize_name(full_schema_name), **kw
).get(table_name, {"constrained_columns": [], "name": None})
@ -378,7 +395,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)
@ -452,9 +469,10 @@ def get_unique_constraints(self, connection, table_name, schema, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)
return self._get_schema_unique_constraints(
connection, self.denormalize_name(full_schema_name), **kw
).get(table_name, [])