From 8dba4a5ac23aa767924f1809a12177cfa9ed6ea1 Mon Sep 17 00:00:00 2001 From: Akash Verma <138790903+akashverma0786@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:47:28 +0530 Subject: [PATCH] Fixes: snowflake lowercase issue (#19486) --- .../source/database/snowflake/utils.py | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py b/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py index 4f2e7f007f6..a1660f13758 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py @@ -12,10 +12,12 @@ """ Module to define overriden dialect methods """ - +import operator +from functools import reduce from typing import Dict, Optional import sqlalchemy.types as sqltypes +from snowflake.sqlalchemy.snowdialect import SnowflakeDialect from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util from sqlalchemy.engine import reflection @@ -52,6 +54,7 @@ from metadata.utils.sqlalchemy_utils import ( get_table_comment_wrapper, ) +dialect = SnowflakeDialect() Query = str QueryMap = Dict[str, Query] @@ -83,6 +86,20 @@ VIEW_QUERY_MAPS = { } +def _denormalize_quote_join(*idents): + ip = dialect.identifier_preparer + split_idents = reduce( + operator.add, + [ip._split_schema_by_dot(ids) for ids in idents if ids is not None], + ) + quoted_identifiers = ip._quote_free_identifiers(*split_idents) + normalized_identifiers = ( + item if item.startswith('"') and item.endswith('"') else f'"{item}"' + for item in quoted_identifiers + ) + return ".".join(normalized_identifiers) + + def _quoted_name(entity_name: Optional[str]) -> Optional[str]: if entity_name: return fqn.quote_name(entity_name) @@ -256,17 +273,16 @@ def get_schema_columns(self, connection, schema, **kw): None, as it is cacheable and is an unexpected return type for this function""" ans = {} current_database, _ = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join( - current_database, fqn.quote_name(schema) - ) + full_schema_name = _denormalize_quote_join(current_database, fqn.quote_name(schema)) try: schema_primary_keys = self._get_schema_primary_keys( connection, full_schema_name, **kw ) + # removing " " from schema name because schema name is in the WHERE clause of a query + table_schema = self.denormalize_name(fqn.unquote_name(schema)) + table_schema = table_schema.lower() if schema.islower() else table_schema result = connection.execute( - text(SNOWFLAKE_GET_SCHEMA_COLUMNS), - {"table_schema": self.denormalize_name(fqn.unquote_name(schema))} - # removing " " from schema name because schema name is in the WHERE clause of a query + text(SNOWFLAKE_GET_SCHEMA_COLUMNS), {"table_schema": table_schema} ) except sa_exc.ProgrammingError as p_err: @@ -362,9 +378,10 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): schema = schema or self.default_schema_name schema = _quoted_name(entity_name=schema) current_database, current_schema = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join( + full_schema_name = _denormalize_quote_join( current_database, schema if schema else current_schema ) + return self._get_schema_primary_keys( connection, self.denormalize_name(full_schema_name), **kw ).get(table_name, {"constrained_columns": [], "name": None}) @@ -378,7 +395,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): schema = schema or self.default_schema_name schema = _quoted_name(entity_name=schema) current_database, current_schema = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join( + full_schema_name = _denormalize_quote_join( current_database, schema if schema else current_schema ) @@ -452,9 +469,10 @@ def get_unique_constraints(self, connection, table_name, schema, **kw): schema = schema or self.default_schema_name schema = _quoted_name(entity_name=schema) current_database, current_schema = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join( + full_schema_name = _denormalize_quote_join( current_database, schema if schema else current_schema ) + return self._get_schema_unique_constraints( connection, self.denormalize_name(full_schema_name), **kw ).get(table_name, [])