import logging
from typing import Any, Optional

from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import DictCursor
from typing_extensions import Protocol

from datahub.configuration.common import MetaError
from datahub.configuration.pattern_utils import is_schema_allowed
from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance
from datahub.ingestion.source.snowflake.constants import (
    GENERIC_PERMISSION_ERROR_KEY,
    SNOWFLAKE_REGION_CLOUD_REGION_MAPPING,
    SnowflakeCloudProvider,
    SnowflakeObjectDomain,
)
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report

logger: logging.Logger = logging.getLogger(__name__)


class SnowflakePermissionError(MetaError):
    """A permission error has happened"""


# Required only for mypy, since we are using mixin classes, and not inheritance.
# Reference - https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
class SnowflakeLoggingProtocol(Protocol):
    logger: logging.Logger
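

# Illustrative sketch, not part of this module: a concrete source class is expected to
# satisfy these protocols structurally by mixing in the helper classes below and
# defining the shared attributes itself. The class and attribute wiring here is an
# assumption for illustration only:
#
#   class SnowflakeV2Source(SnowflakeQueryMixin, SnowflakeConnectionMixin, SnowflakeCommonMixin):
#       def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
#           self.config = config
#           self.report = report
#           self.logger = logger
#           self.connection = self.create_connection()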


class SnowflakeQueryProtocol(SnowflakeLoggingProtocol, Protocol):
    def get_connection(self) -> SnowflakeConnection:
        ...


class SnowflakeQueryMixin:
    def query(self: SnowflakeQueryProtocol, query: str) -> Any:
        try:
            self.logger.info(f"Query : {query}", stacklevel=2)
            resp = self.get_connection().cursor(DictCursor).execute(query)
            return resp

        except Exception as e:
            if is_permission_error(e):
                raise SnowflakePermissionError(e) from e
            raise
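

# Usage sketch, not part of this module: `query` executes through a DictCursor, so a
# mixin user iterates the returned cursor and receives each row as a dict keyed by
# column name. The query text and column name below are assumptions for illustration:
#
#   for row in self.query("SHOW DATABASES"):
#       database_name = row["name"]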


class SnowflakeCommonProtocol(SnowflakeLoggingProtocol, Protocol):
    platform: str = "snowflake"

    config: SnowflakeV2Config
    report: SnowflakeV2Report

    def get_dataset_identifier(
        self, table_name: str, schema_name: str, db_name: str
    ) -> str:
        ...

    def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str:
        ...

    def snowflake_identifier(self, identifier: str) -> str:
        ...

    def report_warning(self, key: str, reason: str) -> None:
        ...

    def report_error(self, key: str, reason: str) -> None:
        ...


class SnowflakeCommonMixin:
    platform = "snowflake"

    CLOUD_REGION_IDS_WITHOUT_CLOUD_SUFFIX = [
        "us-west-2",
        "us-east-1",
        "eu-west-1",
        "eu-central-1",
        "ap-southeast-1",
        "ap-southeast-2",
    ]

    @staticmethod
    def create_snowsight_base_url(
        account_locator: str,
        cloud_region_id: str,
        cloud: str,
        privatelink: bool = False,
    ) -> Optional[str]:
        if cloud:
            url_cloud_provider_suffix = f".{cloud}"

        if cloud == SnowflakeCloudProvider.AWS:
            # Some AWS regions do not have a cloud suffix. See the list at:
            # https://docs.snowflake.com/en/user-guide/admin-account-identifier#non-vps-account-locator-formats-by-cloud-platform-and-region
            if (
                cloud_region_id
                in SnowflakeCommonMixin.CLOUD_REGION_IDS_WITHOUT_CLOUD_SUFFIX
            ):
                url_cloud_provider_suffix = ""
            else:
                url_cloud_provider_suffix = f".{cloud}"
        if privatelink:
            url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.snowflakecomputing.com/"
        else:
            url = f"https://app.snowflake.com/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
        return url
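
    # Illustrative examples for create_snowsight_base_url, assuming SnowflakeCloudProvider.AWS
    # compares equal to the string "aws"; the account locators are made up:
    #   create_snowsight_base_url("xy12345", "us-east-1", "aws")
    #       -> "https://app.snowflake.com/us-east-1/xy12345/"   (region without cloud suffix)
    #   create_snowsight_base_url("xy12345", "us-east-2", "aws")
    #       -> "https://app.snowflake.com/us-east-2.aws/xy12345/"
    #   create_snowsight_base_url("xy12345", "us-east-1", "aws", privatelink=True)
    #       -> "https://app.xy12345.us-east-1.privatelink.snowflakecomputing.com/"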

    @staticmethod
    def get_cloud_region_from_snowflake_region_id(region):
        if region in SNOWFLAKE_REGION_CLOUD_REGION_MAPPING:
            cloud, cloud_region_id = SNOWFLAKE_REGION_CLOUD_REGION_MAPPING[region]
        elif region.startswith(("aws_", "gcp_", "azure_")):
            # e.g. aws_us_west_2, gcp_us_central1, azure_northeurope
            cloud, cloud_region_id = region.split("_", 1)
            cloud_region_id = cloud_region_id.replace("_", "-")
        else:
            raise Exception(f"Unknown snowflake region {region}")
        return cloud, cloud_region_id
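
    # Illustrative example for get_cloud_region_from_snowflake_region_id, using the prefix
    # fallback above (region ids present in SNOWFLAKE_REGION_CLOUD_REGION_MAPPING return
    # whatever that mapping holds):
    #   get_cloud_region_from_snowflake_region_id("aws_us_west_2") -> ("aws", "us-west-2")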

    def _is_dataset_pattern_allowed(
        self: SnowflakeCommonProtocol,
        dataset_name: Optional[str],
        dataset_type: Optional[str],
        is_upstream: bool = False,
    ) -> bool:
        if is_upstream and not self.config.validate_upstreams_against_patterns:
            return True
        if not dataset_type or not dataset_name:
            return True
        dataset_params = dataset_name.split(".")
        if dataset_type.lower() not in (
            SnowflakeObjectDomain.TABLE,
            SnowflakeObjectDomain.EXTERNAL_TABLE,
            SnowflakeObjectDomain.VIEW,
            SnowflakeObjectDomain.MATERIALIZED_VIEW,
        ):
            return False
        if len(dataset_params) != 3:
            self.report_warning(
                "invalid-dataset-pattern",
                f"Found {dataset_params} of type {dataset_type}",
            )
            # NOTE: this case returned `True` earlier when extracting lineage
            return False

        if not self.config.database_pattern.allowed(
            dataset_params[0].strip('"')
        ) or not is_schema_allowed(
            self.config.schema_pattern,
            dataset_params[1].strip('"'),
            dataset_params[0].strip('"'),
            self.config.match_fully_qualified_names,
        ):
            return False

        if dataset_type.lower() in {
            SnowflakeObjectDomain.TABLE
        } and not self.config.table_pattern.allowed(
            self.get_dataset_identifier_from_qualified_name(dataset_name)
        ):
            return False

        if dataset_type.lower() in {
            "view",
            "materialized_view",
        } and not self.config.view_pattern.allowed(
            self.get_dataset_identifier_from_qualified_name(dataset_name)
        ):
            return False

        return True
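
    # Illustrative behaviour of _is_dataset_pattern_allowed (dataset names are made up, and
    # identifier lower-casing depends on convert_urns_to_lowercase):
    #   ("SALES.PUBLIC.ORDERS", "table") -> checked against database_pattern ("SALES"),
    #       schema_pattern ("PUBLIC") and table_pattern (on the normalized identifier)
    #   ("SALES.PUBLIC", "table")        -> reported as "invalid-dataset-pattern", returns False
    #   ("SALES.PUBLIC.ORDERS", "stage") -> returns False (only tables and views are considered)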

    def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str:
        # to be in sync with older connector, convert name to lowercase
        if self.config.convert_urns_to_lowercase:
            return identifier.lower()
        return identifier

    def gen_dataset_urn(self: SnowflakeCommonProtocol, dataset_identifier: str) -> str:
        return make_dataset_urn_with_platform_instance(
            platform=self.platform,
            name=dataset_identifier,
            platform_instance=self.config.platform_instance,
            env=self.config.env,
        )
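
    # Illustrative example for gen_dataset_urn, assuming no platform_instance and the
    # default env "PROD":
    #   gen_dataset_urn("sales.public.orders")
    #       -> "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales.public.orders,PROD)"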

    @staticmethod
    def get_quoted_identifier_for_database(db_name):
        return f'"{db_name}"'

    @staticmethod
    def get_quoted_identifier_for_schema(db_name, schema_name):
        return f'"{db_name}"."{schema_name}"'

    @staticmethod
    def get_quoted_identifier_for_table(db_name, schema_name, table_name):
        return f'"{db_name}"."{schema_name}"."{table_name}"'
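
    # Illustrative example for the quoted-identifier helpers above:
    #   get_quoted_identifier_for_table("test-db", "test-schema", "table1")
    #       -> '"test-db"."test-schema"."table1"'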

    def get_dataset_identifier(
        self: SnowflakeCommonProtocol, table_name: str, schema_name: str, db_name: str
    ) -> str:
        return self.snowflake_identifier(f"{db_name}.{schema_name}.{table_name}")

    # Qualified object names from snowflake audit logs carry quotes around snowflake quoted
    # identifiers, for example "test-database"."test-schema".test_table,
    # whereas we generate urns without quotes even for quoted identifiers, both for backward
    # compatibility and because get_dataset_identifier above has no utility function to decide
    # whether the current table/schema/database name should be quoted.
    def get_dataset_identifier_from_qualified_name(
        self: SnowflakeCommonProtocol, qualified_name: str
    ) -> str:
        name_parts = qualified_name.split(".")
        if len(name_parts) != 3:
            self.report.report_warning(
                "invalid-dataset-pattern",
                f"Found non-parseable {name_parts} for {qualified_name}",
            )
            return self.snowflake_identifier(qualified_name.replace('"', ""))
        return self.get_dataset_identifier(
            name_parts[2].strip('"'), name_parts[1].strip('"'), name_parts[0].strip('"')
        )
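
    # Illustrative example for get_dataset_identifier_from_qualified_name (lower-casing
    # applies only when convert_urns_to_lowercase is enabled):
    #   '"test-database"."test-schema".test_table' -> "test-database.test-schema.test_table"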

    # Note - decide how to construct user urns.
    # Historically urns were created using the part before @ from the user's email.
    # Users without an email were skipped from both user entries as well as aggregates.
    # However email is not a mandatory field for a snowflake user, while user_name is always present.
    def get_user_identifier(
        self: SnowflakeCommonProtocol,
        user_name: str,
        user_email: Optional[str],
        email_as_user_identifier: bool,
    ) -> str:
        if user_email:
            return self.snowflake_identifier(
                user_email
                if email_as_user_identifier is True
                else user_email.split("@")[0]
            )
        return self.snowflake_identifier(user_name)
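
    # Illustrative examples for get_user_identifier (addresses are made up; outputs assume
    # convert_urns_to_lowercase is enabled):
    #   ("JDOE", "john.doe@acme.io", email_as_user_identifier=True)  -> "john.doe@acme.io"
    #   ("JDOE", "john.doe@acme.io", email_as_user_identifier=False) -> "john.doe"
    #   ("JDOE", None, email_as_user_identifier=True)                -> "jdoe"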

    # TODO: Revisit this after stateful ingestion can commit checkpoint
    # for failures that do not affect the checkpoint
    def warn_if_stateful_else_error(
        self: SnowflakeCommonProtocol, key: str, reason: str
    ) -> None:
        if (
            self.config.stateful_ingestion is not None
            and self.config.stateful_ingestion.enabled
        ):
            self.report_warning(key, reason)
        else:
            self.report_error(key, reason)

    def report_warning(self: SnowflakeCommonProtocol, key: str, reason: str) -> None:
        self.report.report_warning(key, reason)
        self.logger.warning(f"{key} => {reason}")

    def report_error(self: SnowflakeCommonProtocol, key: str, reason: str) -> None:
        self.report.report_failure(key, reason)
        self.logger.error(f"{key} => {reason}")


class SnowflakeConnectionProtocol(SnowflakeLoggingProtocol, Protocol):
    connection: Optional[SnowflakeConnection]
    config: SnowflakeV2Config
    report: SnowflakeV2Report

    def create_connection(self) -> Optional[SnowflakeConnection]:
        ...

    def report_error(self, key: str, reason: str) -> None:
        ...


class SnowflakeConnectionMixin:
    def get_connection(self: SnowflakeConnectionProtocol) -> SnowflakeConnection:
        if self.connection is None:
            # Ideally this is never called here
            self.logger.info("Did you forget to initialize connection for module?")
            self.connection = self.create_connection()

        # Connection is already present by the time it's used for a query.
        # Every module initializes the connection or fails and returns.
        assert self.connection is not None
        return self.connection

    # If the connection succeeds, return it; otherwise report the failure and return None.
    def create_connection(
        self: SnowflakeConnectionProtocol,
    ) -> Optional[SnowflakeConnection]:
        try:
            conn = self.config.get_connection()
        except Exception as e:
            logger.debug(e, exc_info=e)
            if "not granted to this user" in str(e):
                self.report_error(
                    GENERIC_PERMISSION_ERROR_KEY,
                    f"Failed to connect with snowflake due to error {e}",
                )
            else:
                self.report_error(
                    "snowflake-connection",
                    f"Failed to connect to snowflake instance due to error {e}.",
                )
            return None
        else:
            return conn

    def close(self: SnowflakeConnectionProtocol) -> None:
        if self.connection is not None and not self.connection.is_closed():
            self.connection.close()


def is_permission_error(e: Exception) -> bool:
    msg = str(e)
    # 002003 (02000): SQL compilation error: Database/SCHEMA 'XXXX' does not exist or not authorized.
    # Insufficient privileges to operate on database 'XXXX'
    return "Insufficient privileges" in msg or "not authorized" in msg