feat(ingest/snowflake): integrate snowflake-queries into main source (#10905)

This commit is contained in:
Harshal Sheth 2024-07-17 10:22:14 -07:00 committed by GitHub
parent 79e1e2eb58
commit bccfd8f0a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 614 additions and 480 deletions

View File

@ -1,3 +1,4 @@
import contextlib
import datetime
import logging
from abc import ABCMeta, abstractmethod
@ -10,6 +11,7 @@ from typing import (
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
@ -97,6 +99,7 @@ class StructuredLogs(Report):
context: Optional[str] = None,
exc: Optional[BaseException] = None,
log: bool = False,
stacklevel: int = 1,
) -> None:
"""
Report a user-facing warning for the ingestion run.
@ -109,7 +112,8 @@ class StructuredLogs(Report):
exc: The exception associated with the event. We'll show the stack trace when in debug mode.
"""
stacklevel = 2
# One for this method, and one for the containing report_* call.
stacklevel = stacklevel + 2
log_key = f"{title}-{message}"
entries = self._entries[level]
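An aside, not part of the diff: the stacklevel arithmetic above exists so that log records emitted inside report_log() point back at the code that called the public report_*() method, rather than at the reporting helpers themselves. A minimal, self-contained sketch of the mechanism (all names here are illustrative):

import logging

logging.basicConfig(format="%(funcName)s: %(message)s")
logger = logging.getLogger(__name__)


def report_log(message: str, stacklevel: int = 1) -> None:
    # One extra frame for this function, one for the report_* wrapper that called it.
    logger.warning(message, stacklevel=stacklevel + 2)


def warning(message: str) -> None:  # stand-in for SourceReport.warning
    report_log(message)


def ingest() -> None:
    warning("table skipped")  # the emitted record's funcName is "ingest", not "report_log"


ingest()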
@ -118,6 +122,8 @@ class StructuredLogs(Report):
context = f"{context[:_MAX_CONTEXT_STRING_LENGTH]} ..."
log_content = f"{message} => {context}" if context else message
if title:
log_content = f"{title}: {log_content}"
if exc:
log_content += f": {exc}"
@ -255,9 +261,10 @@ class SourceReport(Report):
context: Optional[str] = None,
title: Optional[LiteralString] = None,
exc: Optional[BaseException] = None,
log: bool = True,
) -> None:
self._structured_logs.report_log(
StructuredLogLevel.ERROR, message, title, context, exc, log=False
StructuredLogLevel.ERROR, message, title, context, exc, log=log
)
def failure(
@ -266,9 +273,10 @@ class SourceReport(Report):
context: Optional[str] = None,
title: Optional[LiteralString] = None,
exc: Optional[BaseException] = None,
log: bool = True,
) -> None:
self._structured_logs.report_log(
StructuredLogLevel.ERROR, message, title, context, exc, log=True
StructuredLogLevel.ERROR, message, title, context, exc, log=log
)
def info(
@ -277,11 +285,30 @@ class SourceReport(Report):
context: Optional[str] = None,
title: Optional[LiteralString] = None,
exc: Optional[BaseException] = None,
log: bool = True,
) -> None:
self._structured_logs.report_log(
StructuredLogLevel.INFO, message, title, context, exc, log=True
StructuredLogLevel.INFO, message, title, context, exc, log=log
)
@contextlib.contextmanager
def report_exc(
self,
message: LiteralString,
title: Optional[LiteralString] = None,
context: Optional[str] = None,
level: StructuredLogLevel = StructuredLogLevel.ERROR,
) -> Iterator[None]:
# Convenience method that helps avoid boilerplate try/except blocks.
# TODO: I'm not super happy with the naming here - it's not obvious that this
# suppresses the exception in addition to reporting it.
try:
yield
except Exception as exc:
self._structured_logs.report_log(
level, message=message, title=title, context=context, exc=exc
)
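Usage illustration (not part of the commit): report_exc lets a source replace the usual try/except-and-report boilerplate; a rough sketch, where fetch_assertions is a hypothetical fallible helper:

from datahub.ingestion.api.source import SourceReport

report = SourceReport()


def fetch_assertions() -> None:  # hypothetical helper that may raise
    raise RuntimeError("boom")


# Before: explicit try/except around every fallible call.
try:
    fetch_assertions()
except Exception as e:
    report.failure("Error fetching assertions", exc=e)

# After: the context manager reports the exception (at ERROR level by default) and suppresses it.
with report.report_exc("Error fetching assertions"):
    fetch_assertions()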
def __post_init__(self) -> None:
self.start_time = datetime.datetime.now()
self.running_time: datetime.timedelta = datetime.timedelta(seconds=0)

View File

@ -11,14 +11,13 @@ from datahub.emitter.mce_builder import (
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeIdentifierConfig,
SnowflakeV2Config,
)
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_connection import SnowflakeConnection
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeIdentifierMixin
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeIdentifierBuilder,
)
from datahub.metadata.com.linkedin.pegasus2avro.assertion import (
AssertionResult,
AssertionResultType,
@ -40,23 +39,20 @@ class DataQualityMonitoringResult(BaseModel):
VALUE: int
class SnowflakeAssertionsHandler(SnowflakeIdentifierMixin):
class SnowflakeAssertionsHandler:
def __init__(
self,
config: SnowflakeV2Config,
report: SnowflakeV2Report,
connection: SnowflakeConnection,
identifiers: SnowflakeIdentifierBuilder,
) -> None:
self.config = config
self.report = report
self.logger = logger
self.connection = connection
self.identifiers = identifiers
self._urns_processed: List[str] = []
@property
def identifier_config(self) -> SnowflakeIdentifierConfig:
return self.config
def get_assertion_workunits(
self, discovered_datasets: List[str]
) -> Iterable[MetadataWorkUnit]:
@ -80,10 +76,10 @@ class SnowflakeAssertionsHandler(SnowflakeIdentifierMixin):
return MetadataChangeProposalWrapper(
entityUrn=urn,
aspect=DataPlatformInstance(
platform=make_data_platform_urn(self.platform),
platform=make_data_platform_urn(self.identifiers.platform),
instance=(
make_dataplatform_instance_urn(
self.platform, self.config.platform_instance
self.identifiers.platform, self.config.platform_instance
)
if self.config.platform_instance
else None
@ -98,7 +94,7 @@ class SnowflakeAssertionsHandler(SnowflakeIdentifierMixin):
result = DataQualityMonitoringResult.parse_obj(result_row)
assertion_guid = result.METRIC_NAME.split("__")[-1].lower()
status = bool(result.VALUE) # 1 if PASS, 0 if FAIL
assertee = self.get_dataset_identifier(
assertee = self.identifiers.get_dataset_identifier(
result.TABLE_NAME, result.TABLE_SCHEMA, result.TABLE_DATABASE
)
if assertee in discovered_datasets:
@ -107,7 +103,7 @@ class SnowflakeAssertionsHandler(SnowflakeIdentifierMixin):
aspect=AssertionRunEvent(
timestampMillis=datetime_to_ts_millis(result.MEASUREMENT_TIME),
runId=result.MEASUREMENT_TIME.strftime("%Y-%m-%dT%H:%M:%SZ"),
asserteeUrn=self.gen_dataset_urn(assertee),
asserteeUrn=self.identifiers.gen_dataset_urn(assertee),
status=AssertionRunStatus.COMPLETE,
assertionUrn=make_assertion_urn(assertion_guid),
result=AssertionResult(

View File

@ -131,6 +131,7 @@ class SnowflakeIdentifierConfig(
# Changing default value here.
convert_urns_to_lowercase: bool = Field(
default=True,
description="Whether to convert dataset urns to lowercase.",
)
@ -210,8 +211,13 @@ class SnowflakeV2Config(
description="Populates view->view and table->view column lineage using DataHub's sql parser.",
)
lazy_schema_resolver: bool = Field(
use_queries_v2: bool = Field(
default=False,
description="If enabled, uses the new queries extractor to extract queries from snowflake.",
)
lazy_schema_resolver: bool = Field(
default=True,
description="If enabled, uses lazy schema resolver to resolve schemas for tables and views. "
"This is useful if you have a large number of schemas and want to avoid bulk fetching the schema for each table/view.",
)
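For orientation (not part of the diff), a hypothetical source-config fragment exercising the two fields above; the connection and credential fields that a real recipe requires are omitted:

source_config = {
    "use_queries_v2": True,        # opt in to the new queries extractor
    "lazy_schema_resolver": True,  # now the default
    # "account_id": "...", "username": "...", and other connection settings go here
}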

View File

@ -2,7 +2,7 @@ import json
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Collection, Iterable, List, Optional, Set, Tuple, Type
from typing import Any, Collection, Iterable, List, Optional, Set, Tuple, Type
from pydantic import BaseModel, validator
@ -21,7 +21,11 @@ from datahub.ingestion.source.snowflake.snowflake_connection import (
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeFilter,
SnowflakeIdentifierBuilder,
)
from datahub.ingestion.source.state.redundant_run_skip_handler import (
RedundantLineageRunSkipHandler,
)
@ -119,18 +123,19 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
config: SnowflakeV2Config,
report: SnowflakeV2Report,
connection: SnowflakeConnection,
dataset_urn_builder: Callable[[str], str],
filters: SnowflakeFilter,
identifiers: SnowflakeIdentifierBuilder,
redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler],
sql_aggregator: SqlParsingAggregator,
) -> None:
self.config = config
self.report = report
self.logger = logger
self.dataset_urn_builder = dataset_urn_builder
self.connection = connection
self.filters = filters
self.identifiers = identifiers
self.redundant_run_skip_handler = redundant_run_skip_handler
self.sql_aggregator = sql_aggregator
self.redundant_run_skip_handler = redundant_run_skip_handler
self.start_time, self.end_time = (
self.report.lineage_start_time,
self.report.lineage_end_time,
@ -210,7 +215,7 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
results: Iterable[UpstreamLineageEdge],
) -> None:
for db_row in results:
dataset_name = self.get_dataset_identifier_from_qualified_name(
dataset_name = self.identifiers.get_dataset_identifier_from_qualified_name(
db_row.DOWNSTREAM_TABLE_NAME
)
if dataset_name not in discovered_assets or not db_row.QUERIES:
@ -233,7 +238,7 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
if not db_row.UPSTREAM_TABLES:
return None
downstream_table_urn = self.dataset_urn_builder(dataset_name)
downstream_table_urn = self.identifiers.gen_dataset_urn(dataset_name)
known_lineage = KnownQueryLineageInfo(
query_text=query.query_text,
@ -288,7 +293,7 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
external_tables_query: str = SnowflakeQuery.show_external_tables()
try:
for db_row in self.connection.query(external_tables_query):
key = self.get_dataset_identifier(
key = self.identifiers.get_dataset_identifier(
db_row["name"], db_row["schema_name"], db_row["database_name"]
)
@ -299,16 +304,16 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
upstream_urn=make_s3_urn_for_lineage(
db_row["location"], self.config.env
),
downstream_urn=self.dataset_urn_builder(key),
downstream_urn=self.identifiers.gen_dataset_urn(key),
)
self.report.num_external_table_edges_scanned += 1
self.report.num_external_table_edges_scanned += 1
except Exception as e:
logger.debug(e, exc_info=e)
self.report_warning(
"external_lineage",
f"Populating external table lineage from Snowflake failed due to error {e}.",
self.structured_reporter.warning(
"Error populating external table lineage from Snowflake",
exc=e,
)
self.report_status(EXTERNAL_LINEAGE, False)
@ -328,41 +333,47 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
try:
for db_row in self.connection.query(query):
known_lineage_mapping = self._process_external_lineage_result_row(
db_row, discovered_tables
db_row, discovered_tables, identifiers=self.identifiers
)
if known_lineage_mapping:
self.report.num_external_table_edges_scanned += 1
yield known_lineage_mapping
except Exception as e:
if isinstance(e, SnowflakePermissionError):
error_msg = "Failed to get external lineage. Please grant imported privileges on SNOWFLAKE database. "
self.warn_if_stateful_else_error(LINEAGE_PERMISSION_ERROR, error_msg)
else:
logger.debug(e, exc_info=e)
self.report_warning(
"external_lineage",
f"Populating table external lineage from Snowflake failed due to error {e}.",
self.structured_reporter.warning(
"Error fetching external lineage from Snowflake",
exc=e,
)
self.report_status(EXTERNAL_LINEAGE, False)
@classmethod
def _process_external_lineage_result_row(
self, db_row: dict, discovered_tables: List[str]
cls,
db_row: dict,
discovered_tables: Optional[List[str]],
identifiers: SnowflakeIdentifierBuilder,
) -> Optional[KnownLineageMapping]:
# key is the downstream table name
key: str = self.get_dataset_identifier_from_qualified_name(
key: str = identifiers.get_dataset_identifier_from_qualified_name(
db_row["DOWNSTREAM_TABLE_NAME"]
)
if key not in discovered_tables:
if discovered_tables is not None and key not in discovered_tables:
return None
if db_row["UPSTREAM_LOCATIONS"] is not None:
external_locations = json.loads(db_row["UPSTREAM_LOCATIONS"])
loc: str
for loc in external_locations:
if loc.startswith("s3://"):
self.report.num_external_table_edges_scanned += 1
return KnownLineageMapping(
upstream_urn=make_s3_urn_for_lineage(loc, self.config.env),
downstream_urn=self.dataset_urn_builder(key),
upstream_urn=make_s3_urn_for_lineage(
loc, identifiers.identifier_config.env
),
downstream_urn=identifiers.gen_dataset_urn(key),
)
return None
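Illustration (not from the commit) of why this helper became a classmethod: it can now be reused outside the lineage extractor, as the queries extractor's fetch_copy_history does later in this diff. The row values below are made up; the call shape mirrors the one used in this commit:

from datahub.ingestion.api.source import SourceReport
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeIdentifierConfig
from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import SnowflakeLineageExtractor
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeIdentifierBuilder

identifiers = SnowflakeIdentifierBuilder(
    identifier_config=SnowflakeIdentifierConfig(), structured_reporter=SourceReport()
)
row = {
    "DOWNSTREAM_TABLE_NAME": "MY_DB.PUBLIC.ORDERS",
    "UPSTREAM_LOCATIONS": '["s3://my-bucket/orders/"]',
}
# Returns a KnownLineageMapping (s3 location -> dataset urn), or None if the row is filtered out.
mapping = SnowflakeLineageExtractor._process_external_lineage_result_row(
    row, discovered_tables=None, identifiers=identifiers
)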
@ -388,10 +399,9 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
error_msg = "Failed to get table/view to table lineage. Please grant imported privileges on SNOWFLAKE database. "
self.warn_if_stateful_else_error(LINEAGE_PERMISSION_ERROR, error_msg)
else:
logger.debug(e, exc_info=e)
self.report_warning(
"table-upstream-lineage",
f"Extracting lineage from Snowflake failed due to error {e}.",
self.structured_reporter.warning(
"Failed to extract table/view -> table lineage from Snowflake",
exc=e,
)
self.report_status(TABLE_LINEAGE, False)
@ -402,9 +412,10 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
return UpstreamLineageEdge.parse_obj(db_row)
except Exception as e:
self.report.num_upstream_lineage_edge_parsing_failed += 1
self.report_warning(
f"Parsing lineage edge failed due to error {e}",
db_row.get("DOWNSTREAM_TABLE_NAME") or "",
self.structured_reporter.warning(
"Failed to parse lineage edge",
context=db_row.get("DOWNSTREAM_TABLE_NAME") or None,
exc=e,
)
return None
@ -417,17 +428,21 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
for upstream_table in upstream_tables:
if upstream_table and upstream_table.query_id == query_id:
try:
upstream_name = self.get_dataset_identifier_from_qualified_name(
upstream_table.upstream_object_name
upstream_name = (
self.identifiers.get_dataset_identifier_from_qualified_name(
upstream_table.upstream_object_name
)
)
if upstream_name and (
not self.config.validate_upstreams_against_patterns
or self.is_dataset_pattern_allowed(
or self.filters.is_dataset_pattern_allowed(
upstream_name,
upstream_table.upstream_object_domain,
)
):
upstreams.append(self.dataset_urn_builder(upstream_name))
upstreams.append(
self.identifiers.gen_dataset_urn(upstream_name)
)
except Exception as e:
logger.debug(e, exc_info=e)
return upstreams
@ -491,7 +506,7 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
return None
column_lineage = ColumnLineageInfo(
downstream=DownstreamColumnRef(
table=dataset_urn, column=self.snowflake_identifier(col)
table=dataset_urn, column=self.identifiers.snowflake_identifier(col)
),
upstreams=sorted(column_upstreams),
)
@ -508,19 +523,23 @@ class SnowflakeLineageExtractor(SnowflakeCommonMixin, Closeable):
and upstream_col.column_name
and (
not self.config.validate_upstreams_against_patterns
or self.is_dataset_pattern_allowed(
or self.filters.is_dataset_pattern_allowed(
upstream_col.object_name,
upstream_col.object_domain,
)
)
):
upstream_dataset_name = self.get_dataset_identifier_from_qualified_name(
upstream_col.object_name
upstream_dataset_name = (
self.identifiers.get_dataset_identifier_from_qualified_name(
upstream_col.object_name
)
)
column_upstreams.append(
ColumnRef(
table=self.dataset_urn_builder(upstream_dataset_name),
column=self.snowflake_identifier(upstream_col.column_name),
table=self.identifiers.gen_dataset_urn(upstream_dataset_name),
column=self.identifiers.snowflake_identifier(
upstream_col.column_name
),
)
)
return column_upstreams

View File

@ -37,7 +37,6 @@ class SnowflakeProfiler(GenericProfiler, SnowflakeCommonMixin):
super().__init__(config, report, self.platform, state_handler)
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = report
self.logger = logger
self.database_default_schema: Dict[str, str] = dict()
def get_workunits(
@ -86,7 +85,7 @@ class SnowflakeProfiler(GenericProfiler, SnowflakeCommonMixin):
)
def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> str:
return self.get_dataset_identifier(table_name, schema_name, db_name)
return self.identifiers.get_dataset_identifier(table_name, schema_name, db_name)
def get_batch_kwargs(
self, table: BaseTable, schema_name: str, db_name: str

View File

@ -1,3 +1,4 @@
import dataclasses
import functools
import json
import logging
@ -11,6 +12,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import pydantic
from typing_extensions import Self
from datahub.configuration.common import ConfigModel
from datahub.configuration.time_window_config import (
BaseTimeWindowConfig,
BucketDuration,
@ -20,6 +22,7 @@ from datahub.ingestion.api.report import Report
from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
from datahub.ingestion.source.snowflake.snowflake_config import (
DEFAULT_TEMP_TABLES_PATTERNS,
@ -30,13 +33,18 @@ from datahub.ingestion.source.snowflake.snowflake_connection import (
SnowflakeConnection,
SnowflakeConnectionConfig,
)
from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import (
SnowflakeLineageExtractor,
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeFilterMixin,
SnowflakeIdentifierMixin,
SnowflakeFilter,
SnowflakeIdentifierBuilder,
SnowflakeStructuredReportMixin,
)
from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sql_parsing_aggregator import (
KnownLineageMapping,
PreparsedQuery,
@ -50,11 +58,12 @@ from datahub.sql_parsing.sqlglot_lineage import (
DownstreamColumnRef,
)
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
from datahub.utilities.perf_timer import PerfTimer
logger = logging.getLogger(__name__)
class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilterConfig):
class SnowflakeQueriesExtractorConfig(ConfigModel):
# TODO: Support stateful ingestion for the time windows.
window: BaseTimeWindowConfig = BaseTimeWindowConfig()
@ -76,12 +85,6 @@ class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilter
hidden_from_docs=True,
)
convert_urns_to_lowercase: bool = pydantic.Field(
# Override the default.
default=True,
description="Whether to convert dataset urns to lowercase.",
)
include_lineage: bool = True
include_queries: bool = True
include_usage_statistics: bool = True
@ -89,40 +92,56 @@ class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilter
include_operations: bool = True
class SnowflakeQueriesSourceConfig(SnowflakeQueriesExtractorConfig):
class SnowflakeQueriesSourceConfig(
SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig
):
connection: SnowflakeConnectionConfig
@dataclass
class SnowflakeQueriesExtractorReport(Report):
window: Optional[BaseTimeWindowConfig] = None
copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
sql_aggregator: Optional[SqlAggregatorReport] = None
@dataclass
class SnowflakeQueriesSourceReport(SourceReport):
window: Optional[BaseTimeWindowConfig] = None
queries_extractor: Optional[SnowflakeQueriesExtractorReport] = None
class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin):
def __init__(
self,
connection: SnowflakeConnection,
config: SnowflakeQueriesExtractorConfig,
structured_report: SourceReport,
filters: SnowflakeFilter,
identifiers: SnowflakeIdentifierBuilder,
graph: Optional[DataHubGraph] = None,
schema_resolver: Optional[SchemaResolver] = None,
discovered_tables: Optional[List[str]] = None,
):
self.connection = connection
self.config = config
self.report = SnowflakeQueriesExtractorReport()
self.filters = filters
self.identifiers = identifiers
self.discovered_tables = discovered_tables
self._structured_report = structured_report
self.aggregator = SqlParsingAggregator(
platform=self.platform,
platform_instance=self.config.platform_instance,
env=self.config.env,
# graph=self.ctx.graph,
platform=self.identifiers.platform,
platform_instance=self.identifiers.identifier_config.platform_instance,
env=self.identifiers.identifier_config.env,
schema_resolver=schema_resolver,
graph=graph,
eager_graph_load=False,
generate_lineage=self.config.include_lineage,
generate_queries=self.config.include_queries,
generate_usage_statistics=self.config.include_usage_statistics,
@ -144,14 +163,6 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
def structured_reporter(self) -> SourceReport:
return self._structured_report
@property
def filter_config(self) -> SnowflakeFilterConfig:
return self.config
@property
def identifier_config(self) -> SnowflakeIdentifierConfig:
return self.config
@functools.cached_property
def local_temp_path(self) -> pathlib.Path:
if self.config.local_temp_path:
@ -170,13 +181,16 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
)
def is_allowed_table(self, name: str) -> bool:
return self.is_dataset_pattern_allowed(name, SnowflakeObjectDomain.TABLE)
if self.discovered_tables and name not in self.discovered_tables:
return False
return self.filters.is_dataset_pattern_allowed(
name, SnowflakeObjectDomain.TABLE
)
def get_workunits_internal(
self,
) -> Iterable[MetadataWorkUnit]:
self.report.window = self.config.window
# TODO: Add some logic to check if the cached audit log is stale or not.
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()
@ -191,74 +205,90 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
shared_connection = ConnectionWrapper(audit_log_file)
queries = FileBackedList(shared_connection)
entry: Union[KnownLineageMapping, PreparsedQuery]
logger.info("Fetching audit log")
for entry in self.fetch_audit_log():
queries.append(entry)
with self.report.copy_history_fetch_timer:
for entry in self.fetch_copy_history():
queries.append(entry)
for query in queries:
self.aggregator.add(query)
# TODO: Add "show external tables" lineage to the main schema extractor.
# Because it's not a time-based thing, it doesn't really make sense in the snowflake-queries extractor.
with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log():
queries.append(entry)
with self.report.audit_log_load_timer:
for query in queries:
self.aggregator.add(query)
yield from auto_workunit(self.aggregator.gen_metadata())
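The flow above stages everything it fetches in a SQLite-backed list so a later run can reuse the on-disk audit log instead of re-querying Snowflake. A standalone sketch of that buffering pattern (illustrative; the real code stores KnownLineageMapping/PreparsedQuery entries rather than strings):

import pathlib

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

audit_log_file = pathlib.Path("/tmp/audit_log.sqlite")  # assumed local temp path
use_cached_audit_log = audit_log_file.exists()

shared_connection = ConnectionWrapper(audit_log_file)
queries = FileBackedList(shared_connection)

if not use_cached_audit_log:
    # The expensive Snowflake fetches only happen when there is no cache on disk.
    for entry in ("SELECT 1", "SELECT 2"):  # stand-ins for fetch_copy_history()/fetch_query_log()
        queries.append(entry)

for entry in queries:
    print(entry)  # stand-in for self.aggregator.add(entry)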
def fetch_audit_log(
self,
) -> Iterable[Union[KnownLineageMapping, PreparsedQuery]]:
"""
# TODO: we need to fetch this info from somewhere
discovered_tables = []
def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
# Derived from _populate_external_lineage_from_copy_history.
snowflake_lineage_v2 = SnowflakeLineageExtractor(
config=self.config, # type: ignore
report=self.report, # type: ignore
dataset_urn_builder=self.gen_dataset_urn,
redundant_run_skip_handler=None,
sql_aggregator=self.aggregator, # TODO this should be unused
query: str = SnowflakeQuery.copy_lineage_history(
start_time_millis=int(self.config.window.start_time.timestamp() * 1000),
end_time_millis=int(self.config.window.end_time.timestamp() * 1000),
downstreams_deny_pattern=self.config.temporary_tables_pattern,
)
for (
known_lineage_mapping
) in snowflake_lineage_v2._populate_external_lineage_from_copy_history(
discovered_tables=discovered_tables
with self.structured_reporter.report_exc(
"Error fetching copy history from Snowflake"
):
interim_results.append(known_lineage_mapping)
logger.info("Fetching copy history from Snowflake")
resp = self.connection.query(query)
for (
known_lineage_mapping
) in snowflake_lineage_v2._populate_external_lineage_from_show_query(
discovered_tables=discovered_tables
):
interim_results.append(known_lineage_mapping)
"""
for row in resp:
try:
result = (
SnowflakeLineageExtractor._process_external_lineage_result_row(
row,
discovered_tables=self.discovered_tables,
identifiers=self.identifiers,
)
)
except Exception as e:
self.structured_reporter.warning(
"Error parsing copy history row",
context=f"{row}",
exc=e,
)
else:
if result:
yield result
audit_log_query = _build_enriched_audit_log_query(
def fetch_query_log(
self,
) -> Iterable[PreparsedQuery]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
end_time=self.config.window.end_time,
bucket_duration=self.config.window.bucket_duration,
deny_usernames=self.config.deny_usernames,
)
resp = self.connection.query(audit_log_query)
with self.structured_reporter.report_exc(
"Error fetching query log from Snowflake"
):
logger.info("Fetching query log from Snowflake")
resp = self.connection.query(query_log_query)
for i, row in enumerate(resp):
if i % 1000 == 0:
logger.info(f"Processed {i} audit log rows")
for i, row in enumerate(resp):
if i % 1000 == 0:
logger.info(f"Processed {i} query log rows")
assert isinstance(row, dict)
try:
entry = self._parse_audit_log_row(row)
except Exception as e:
self.structured_reporter.warning(
"Error parsing audit log row",
context=f"{row}",
exc=e,
)
else:
yield entry
def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str:
# Copied from SnowflakeCommonMixin.
return self.snowflake_identifier(self.cleanup_qualified_name(qualified_name))
assert isinstance(row, dict)
try:
entry = self._parse_audit_log_row(row)
except Exception as e:
self.structured_reporter.warning(
"Error parsing query log row",
context=f"{row}",
exc=e,
)
else:
yield entry
def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery:
json_fields = {
@ -280,13 +310,17 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
column_usage = {}
for obj in direct_objects_accessed:
dataset = self.gen_dataset_urn(
self.get_dataset_identifier_from_qualified_name(obj["objectName"])
dataset = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
obj["objectName"]
)
)
columns = set()
for modified_column in obj["columns"]:
columns.add(self.snowflake_identifier(modified_column["columnName"]))
columns.add(
self.identifiers.snowflake_identifier(modified_column["columnName"])
)
upstreams.append(dataset)
column_usage[dataset] = columns
@ -301,8 +335,10 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
context=f"{row}",
)
downstream = self.gen_dataset_urn(
self.get_dataset_identifier_from_qualified_name(obj["objectName"])
downstream = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
obj["objectName"]
)
)
column_lineage = []
for modified_column in obj["columns"]:
@ -310,18 +346,18 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
ColumnLineageInfo(
downstream=DownstreamColumnRef(
dataset=downstream,
column=self.snowflake_identifier(
column=self.identifiers.snowflake_identifier(
modified_column["columnName"]
),
),
upstreams=[
ColumnRef(
table=self.gen_dataset_urn(
self.get_dataset_identifier_from_qualified_name(
table=self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
upstream["objectName"]
)
),
column=self.snowflake_identifier(
column=self.identifiers.snowflake_identifier(
upstream["columnName"]
),
)
@ -332,12 +368,9 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
)
)
# TODO: Support filtering the table names.
# if objects_modified:
# breakpoint()
# TODO implement email address mapping
user = CorpUserUrn(res["user_name"])
# TODO: Fetch email addresses from Snowflake to map user -> email
# TODO: Support email_domain fallback for generating user urns.
user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"]))
timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
@ -348,14 +381,18 @@ class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
)
entry = PreparsedQuery(
query_id=res["query_fingerprint"],
# Despite having Snowflake's fingerprints available, our own fingerprinting logic does a better
# job at eliminating redundant / repetitive queries. As such, we don't include the fingerprint
# here so that the aggregator auto-generates one.
# query_id=res["query_fingerprint"],
query_id=None,
query_text=res["query_text"],
upstreams=upstreams,
downstream=downstream,
column_lineage=column_lineage,
column_usage=column_usage,
inferred_schema=None,
confidence_score=1,
confidence_score=1.0,
query_count=res["query_count"],
user=user,
timestamp=timestamp,
@ -371,7 +408,14 @@ class SnowflakeQueriesSource(Source):
self.config = config
self.report = SnowflakeQueriesSourceReport()
self.platform = "snowflake"
self.filters = SnowflakeFilter(
filter_config=self.config,
structured_reporter=self.report,
)
self.identifiers = SnowflakeIdentifierBuilder(
identifier_config=self.config,
structured_reporter=self.report,
)
self.connection = self.config.connection.get_connection()
@ -379,6 +423,9 @@ class SnowflakeQueriesSource(Source):
connection=self.connection,
config=self.config,
structured_report=self.report,
filters=self.filters,
identifiers=self.identifiers,
graph=self.ctx.graph,
)
self.report.queries_extractor = self.queries_extractor.report
@ -388,6 +435,8 @@ class SnowflakeQueriesSource(Source):
return cls(ctx, config)
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.report.window = self.config.window
# TODO: Disable auto status processor?
return self.queries_extractor.get_workunits_internal()
@ -399,7 +448,7 @@ class SnowflakeQueriesSource(Source):
_MAX_TABLES_PER_QUERY = 20
def _build_enriched_audit_log_query(
def _build_enriched_query_log_query(
start_time: datetime,
end_time: datetime,
bucket_duration: BucketDuration,

View File

@ -15,6 +15,9 @@ from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport
from datahub.utilities.perf_timer import PerfTimer
if TYPE_CHECKING:
from datahub.ingestion.source.snowflake.snowflake_queries import (
SnowflakeQueriesExtractorReport,
)
from datahub.ingestion.source.snowflake.snowflake_schema import (
SnowflakeDataDictionary,
)
@ -113,6 +116,8 @@ class SnowflakeV2Report(
data_dictionary_cache: Optional["SnowflakeDataDictionary"] = None
queries_extractor: Optional["SnowflakeQueriesExtractorReport"] = None
# These will be non-zero if snowflake information_schema queries fail with error -
# "Information schema query returned too much data. Please repeat query with more selective predicates.""
# This will result in overall increase in time complexity

View File

@ -185,8 +185,6 @@ class _SnowflakeTagCache:
class SnowflakeDataDictionary(SupportsAsObj):
def __init__(self, connection: SnowflakeConnection) -> None:
self.logger = logger
self.connection = connection
def as_obj(self) -> Dict[str, Dict[str, int]]:
@ -514,7 +512,7 @@ class SnowflakeDataDictionary(SupportsAsObj):
)
else:
# This should never happen.
self.logger.error(f"Encountered an unexpected domain: {domain}")
logger.error(f"Encountered an unexpected domain: {domain}")
continue
return tags

View File

@ -1,6 +1,6 @@
import itertools
import logging
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union
from datahub.configuration.pattern_utils import is_schema_allowed
from datahub.emitter.mce_builder import (
@ -26,8 +26,6 @@ from datahub.ingestion.source.snowflake.constants import (
SnowflakeObjectDomain,
)
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeFilterConfig,
SnowflakeIdentifierConfig,
SnowflakeV2Config,
TagOption,
)
@ -52,8 +50,9 @@ from datahub.ingestion.source.snowflake.snowflake_schema import (
)
from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeFilterMixin,
SnowflakeIdentifierMixin,
SnowflakeFilter,
SnowflakeIdentifierBuilder,
SnowflakeStructuredReportMixin,
SnowsightUrlBuilder,
)
from datahub.ingestion.source.sql.sql_utils import (
@ -142,13 +141,16 @@ SNOWFLAKE_FIELD_TYPE_MAPPINGS = {
}
class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
class SnowflakeSchemaGenerator(SnowflakeStructuredReportMixin):
platform = "snowflake"
def __init__(
self,
config: SnowflakeV2Config,
report: SnowflakeV2Report,
connection: SnowflakeConnection,
dataset_urn_builder: Callable[[str], str],
filters: SnowflakeFilter,
identifiers: SnowflakeIdentifierBuilder,
domain_registry: Optional[DomainRegistry],
profiler: Optional[SnowflakeProfiler],
aggregator: Optional[SqlParsingAggregator],
@ -157,7 +159,8 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = report
self.connection: SnowflakeConnection = connection
self.dataset_urn_builder = dataset_urn_builder
self.filters: SnowflakeFilter = filters
self.identifiers: SnowflakeIdentifierBuilder = identifiers
self.data_dictionary: SnowflakeDataDictionary = SnowflakeDataDictionary(
connection=self.connection
@ -185,19 +188,17 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
def structured_reporter(self) -> SourceReport:
return self.report
@property
def filter_config(self) -> SnowflakeFilterConfig:
return self.config
def gen_dataset_urn(self, dataset_identifier: str) -> str:
return self.identifiers.gen_dataset_urn(dataset_identifier)
@property
def identifier_config(self) -> SnowflakeIdentifierConfig:
return self.config
def snowflake_identifier(self, identifier: str) -> str:
return self.identifiers.snowflake_identifier(identifier)
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.databases = []
for database in self.get_databases() or []:
self.report.report_entity_scanned(database.name, "database")
if not self.filter_config.database_pattern.allowed(database.name):
if not self.filters.filter_config.database_pattern.allowed(database.name):
self.report.report_dropped(f"{database.name}.*")
else:
self.databases.append(database)
@ -211,7 +212,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
yield from self._process_database(snowflake_db)
except SnowflakePermissionError as e:
self.report_error(GENERIC_PERMISSION_ERROR_KEY, str(e))
self.structured_reporter.failure(
GENERIC_PERMISSION_ERROR_KEY,
exc=e,
)
return
def get_databases(self) -> Optional[List[SnowflakeDatabase]]:
@ -220,10 +224,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
# whose information_schema can be queried to start with.
databases = self.data_dictionary.show_databases()
except Exception as e:
logger.debug(f"Failed to list databases due to error {e}", exc_info=e)
self.report_error(
"list-databases",
f"Failed to list databases due to error {e}",
self.structured_reporter.failure(
"Failed to list databases",
exc=e,
)
return None
else:
@ -232,7 +235,7 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
] = self.get_databases_from_ischema(databases)
if len(ischema_databases) == 0:
self.report_error(
self.structured_reporter.failure(
GENERIC_PERMISSION_ERROR_KEY,
"No databases found. Please check permissions.",
)
@ -275,7 +278,7 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
# This may happen if REFERENCE_USAGE permissions are set
# We can not run show queries on database in such case.
# This need not be a failure case.
self.report_warning(
self.structured_reporter.warning(
"Insufficient privileges to operate on database, skipping. Please grant USAGE permissions on database to extract its metadata.",
db_name,
)
@ -284,9 +287,8 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
f"Failed to use database {db_name} due to error {e}",
exc_info=e,
)
self.report_warning(
"Failed to get schemas for database",
db_name,
self.structured_reporter.warning(
"Failed to get schemas for database", db_name, exc=e
)
return
@ -342,10 +344,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
for schema in self.data_dictionary.get_schemas_for_database(db_name):
self.report.report_entity_scanned(schema.name, "schema")
if not is_schema_allowed(
self.filter_config.schema_pattern,
self.filters.filter_config.schema_pattern,
schema.name,
db_name,
self.filter_config.match_fully_qualified_names,
self.filters.filter_config.match_fully_qualified_names,
):
self.report.report_dropped(f"{db_name}.{schema.name}.*")
else:
@ -356,17 +358,14 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
# Ideal implementation would use PEP 678 Enriching Exceptions with Notes
raise SnowflakePermissionError(error_msg) from e.__cause__
else:
logger.debug(
f"Failed to get schemas for database {db_name} due to error {e}",
exc_info=e,
)
self.report_warning(
self.structured_reporter.warning(
"Failed to get schemas for database",
db_name,
exc=e,
)
if not schemas:
self.report_warning(
self.structured_reporter.warning(
"No schemas found in database. If schemas exist, please grant USAGE permissions on them.",
db_name,
)
@ -421,12 +420,12 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
and self.config.parse_view_ddl
):
for view in views:
view_identifier = self.get_dataset_identifier(
view_identifier = self.identifiers.get_dataset_identifier(
view.name, schema_name, db_name
)
if view.view_definition:
self.aggregator.add_view_definition(
view_urn=self.dataset_urn_builder(view_identifier),
view_urn=self.identifiers.gen_dataset_urn(view_identifier),
view_definition=view.view_definition,
default_db=db_name,
default_schema=schema_name,
@ -441,9 +440,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
yield from self._process_tag(tag)
if not snowflake_schema.views and not snowflake_schema.tables:
self.report_warning(
"No tables/views found in schema. If tables exist, please grant REFERENCES or SELECT permissions on them.",
f"{db_name}.{schema_name}",
self.structured_reporter.warning(
title="No tables/views found in schema",
message="If tables exist, please grant REFERENCES or SELECT permissions on them.",
context=f"{db_name}.{schema_name}",
)
def fetch_views_for_schema(
@ -452,11 +452,13 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
try:
views: List[SnowflakeView] = []
for view in self.get_views_for_schema(schema_name, db_name):
view_name = self.get_dataset_identifier(view.name, schema_name, db_name)
view_name = self.identifiers.get_dataset_identifier(
view.name, schema_name, db_name
)
self.report.report_entity_scanned(view_name, "view")
if not self.filter_config.view_pattern.allowed(view_name):
if not self.filters.filter_config.view_pattern.allowed(view_name):
self.report.report_dropped(view_name)
else:
views.append(view)
@ -469,13 +471,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
raise SnowflakePermissionError(error_msg) from e.__cause__
else:
logger.debug(
f"Failed to get views for schema {db_name}.{schema_name} due to error {e}",
exc_info=e,
)
self.report_warning(
self.structured_reporter.warning(
"Failed to get views for schema",
f"{db_name}.{schema_name}",
exc=e,
)
return []
@ -485,11 +484,13 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
try:
tables: List[SnowflakeTable] = []
for table in self.get_tables_for_schema(schema_name, db_name):
table_identifier = self.get_dataset_identifier(
table_identifier = self.identifiers.get_dataset_identifier(
table.name, schema_name, db_name
)
self.report.report_entity_scanned(table_identifier)
if not self.filter_config.table_pattern.allowed(table_identifier):
if not self.filters.filter_config.table_pattern.allowed(
table_identifier
):
self.report.report_dropped(table_identifier)
else:
tables.append(table)
@ -501,13 +502,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
error_msg = f"Failed to get tables for schema {db_name}.{schema_name}. Please check permissions."
raise SnowflakePermissionError(error_msg) from e.__cause__
else:
logger.debug(
f"Failed to get tables for schema {db_name}.{schema_name} due to error {e}",
exc_info=e,
)
self.report_warning(
self.structured_reporter.warning(
"Failed to get tables for schema",
f"{db_name}.{schema_name}",
exc=e,
)
return []
@ -526,7 +524,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
db_name: str,
) -> Iterable[MetadataWorkUnit]:
schema_name = snowflake_schema.name
table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name)
table_identifier = self.identifiers.get_dataset_identifier(
table.name, schema_name, db_name
)
try:
table.columns = self.get_columns_for_table(
@ -538,11 +538,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
table.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get columns for table {table_identifier} due to error {e}",
exc_info=e,
self.structured_reporter.warning(
"Failed to get columns for table", table_identifier, exc=e
)
self.report_warning("Failed to get columns for table", table_identifier)
if self.config.extract_tags != TagOption.skip:
table.tags = self.tag_extractor.get_tags_on_object(
@ -575,11 +573,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
table.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get foreign key for table {table_identifier} due to error {e}",
exc_info=e,
self.structured_reporter.warning(
"Failed to get foreign keys for table", table_identifier, exc=e
)
self.report_warning("Failed to get foreign key for table", table_identifier)
def fetch_pk_for_table(
self,
@ -593,11 +589,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
table.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get primary key for table {table_identifier} due to error {e}",
exc_info=e,
self.structured_reporter.warning(
"Failed to get primary key for table", table_identifier, exc=e
)
self.report_warning("Failed to get primary key for table", table_identifier)
def _process_view(
self,
@ -606,7 +600,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
db_name: str,
) -> Iterable[MetadataWorkUnit]:
schema_name = snowflake_schema.name
view_name = self.get_dataset_identifier(view.name, schema_name, db_name)
view_name = self.identifiers.get_dataset_identifier(
view.name, schema_name, db_name
)
try:
view.columns = self.get_columns_for_table(
@ -617,11 +613,9 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
view.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get columns for view {view_name} due to error {e}",
exc_info=e,
self.structured_reporter.warning(
"Failed to get columns for view", view_name, exc=e
)
self.report_warning("Failed to get columns for view", view_name)
if self.config.extract_tags != TagOption.skip:
view.tags = self.tag_extractor.get_tags_on_object(
@ -657,8 +651,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
for tag in table.column_tags[column_name]:
yield from self._process_tag(tag)
dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name)
dataset_urn = self.dataset_urn_builder(dataset_name)
dataset_name = self.identifiers.get_dataset_identifier(
table.name, schema_name, db_name
)
dataset_urn = self.identifiers.gen_dataset_urn(dataset_name)
status = Status(removed=False)
yield MetadataChangeProposalWrapper(
@ -799,8 +795,10 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
schema_name: str,
db_name: str,
) -> SchemaMetadata:
dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name)
dataset_urn = self.dataset_urn_builder(dataset_name)
dataset_name = self.identifiers.get_dataset_identifier(
table.name, schema_name, db_name
)
dataset_urn = self.identifiers.gen_dataset_urn(dataset_name)
foreign_keys: Optional[List[ForeignKeyConstraint]] = None
if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0:
@ -859,7 +857,7 @@ class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
for fk in table.foreign_keys:
foreign_dataset = make_dataset_urn_with_platform_instance(
platform=self.platform,
name=self.get_dataset_identifier(
name=self.identifiers.get_dataset_identifier(
fk.referred_table, fk.referred_schema, fk.referred_database
),
env=self.config.env,

View File

@ -1,5 +1,5 @@
import logging
from typing import Callable, Iterable, List
from typing import Iterable, List
from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance
from datahub.emitter.mcp import MetadataChangeProposalWrapper
@ -26,12 +26,9 @@ class SnowflakeSharesHandler(SnowflakeCommonMixin):
self,
config: SnowflakeV2Config,
report: SnowflakeV2Report,
dataset_urn_builder: Callable[[str], str],
) -> None:
self.config = config
self.report = report
self.logger = logger
self.dataset_urn_builder = dataset_urn_builder
def get_shares_workunits(
self, databases: List[SnowflakeDatabase]
@ -94,9 +91,10 @@ class SnowflakeSharesHandler(SnowflakeCommonMixin):
missing_dbs = [db for db in inbounds + outbounds if db not in db_names]
if missing_dbs and self.config.platform_instance:
self.report_warning(
"snowflake-shares",
f"Databases {missing_dbs} were not ingested. Siblings/Lineage will not be set for these.",
self.report.warning(
title="Extra Snowflake share configurations",
message="Some databases referenced by the share configs were not ingested. Siblings/lineage will not be set for these.",
context=f"{missing_dbs}",
)
elif missing_dbs:
logger.debug(
@ -113,15 +111,15 @@ class SnowflakeSharesHandler(SnowflakeCommonMixin):
) -> Iterable[MetadataWorkUnit]:
if not sibling_databases:
return
dataset_identifier = self.get_dataset_identifier(
dataset_identifier = self.identifiers.get_dataset_identifier(
table_name, schema_name, database_name
)
urn = self.dataset_urn_builder(dataset_identifier)
urn = self.identifiers.gen_dataset_urn(dataset_identifier)
sibling_urns = [
make_dataset_urn_with_platform_instance(
self.platform,
self.get_dataset_identifier(
self.identifiers.platform,
self.identifiers.get_dataset_identifier(
table_name, schema_name, sibling_db.database
),
sibling_db.platform_instance,
@ -141,14 +139,14 @@ class SnowflakeSharesHandler(SnowflakeCommonMixin):
table_name: str,
primary_sibling_db: DatabaseId,
) -> MetadataWorkUnit:
dataset_identifier = self.get_dataset_identifier(
dataset_identifier = self.identifiers.get_dataset_identifier(
table_name, schema_name, database_name
)
urn = self.dataset_urn_builder(dataset_identifier)
urn = self.identifiers.gen_dataset_urn(dataset_identifier)
upstream_urn = make_dataset_urn_with_platform_instance(
self.platform,
self.get_dataset_identifier(
self.identifiers.platform,
self.identifiers.get_dataset_identifier(
table_name, schema_name, primary_sibling_db.database
),
primary_sibling_db.platform_instance,

View File

@ -1,5 +1,4 @@
import dataclasses
import logging
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
@ -9,7 +8,10 @@ from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import SupportStatus, config_class, support_status
from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeFilterConfig
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeFilterConfig,
SnowflakeIdentifierConfig,
)
from datahub.ingestion.source.snowflake.snowflake_connection import (
SnowflakeConnectionConfig,
)
@ -17,6 +19,9 @@ from datahub.ingestion.source.snowflake.snowflake_schema import SnowflakeDatabas
from datahub.ingestion.source.snowflake.snowflake_schema_gen import (
SnowflakeSchemaGenerator,
)
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeIdentifierBuilder,
)
from datahub.ingestion.source_report.time_window import BaseTimeWindowReport
from datahub.utilities.lossy_collections import LossyList
@ -59,7 +64,6 @@ class SnowflakeSummarySource(Source):
super().__init__(ctx)
self.config: SnowflakeSummaryConfig = config
self.report: SnowflakeSummaryReport = SnowflakeSummaryReport()
self.logger = logging.getLogger(__name__)
self.connection = self.config.get_connection()
@ -69,7 +73,10 @@ class SnowflakeSummarySource(Source):
config=self.config, # type: ignore
report=self.report, # type: ignore
connection=self.connection,
dataset_urn_builder=lambda x: "",
identifiers=SnowflakeIdentifierBuilder(
identifier_config=SnowflakeIdentifierConfig(),
structured_reporter=self.report,
),
domain_registry=None,
profiler=None,
aggregator=None,

View File

@ -27,7 +27,6 @@ class SnowflakeTagExtractor(SnowflakeCommonMixin):
self.config = config
self.data_dictionary = data_dictionary
self.report = report
self.logger = logger
self.tag_cache: Dict[str, _SnowflakeTagCache] = {}
@ -69,16 +68,18 @@ class SnowflakeTagExtractor(SnowflakeCommonMixin):
) -> List[SnowflakeTag]:
identifier = ""
if domain == SnowflakeObjectDomain.DATABASE:
identifier = self.get_quoted_identifier_for_database(db_name)
identifier = self.identifiers.get_quoted_identifier_for_database(db_name)
elif domain == SnowflakeObjectDomain.SCHEMA:
assert schema_name is not None
identifier = self.get_quoted_identifier_for_schema(db_name, schema_name)
identifier = self.identifiers.get_quoted_identifier_for_schema(
db_name, schema_name
)
elif (
domain == SnowflakeObjectDomain.TABLE
): # Views belong to this domain as well.
assert schema_name is not None
assert table_name is not None
identifier = self.get_quoted_identifier_for_table(
identifier = self.identifiers.get_quoted_identifier_for_table(
db_name, schema_name, table_name
)
else:
@ -140,7 +141,7 @@ class SnowflakeTagExtractor(SnowflakeCommonMixin):
elif self.config.extract_tags == TagOption.with_lineage:
self.report.num_get_tags_on_columns_for_table_queries += 1
temp_column_tags = self.data_dictionary.get_tags_on_columns_for_table(
quoted_table_name=self.get_quoted_identifier_for_table(
quoted_table_name=self.identifiers.get_quoted_identifier_for_table(
db_name, schema_name, table_name
),
db_name=db_name,

View File

@ -2,7 +2,7 @@ import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple
import pydantic
@ -20,7 +20,11 @@ from datahub.ingestion.source.snowflake.snowflake_connection import (
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeFilter,
SnowflakeIdentifierBuilder,
)
from datahub.ingestion.source.state.redundant_run_skip_handler import (
RedundantUsageRunSkipHandler,
)
@ -112,13 +116,14 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
config: SnowflakeV2Config,
report: SnowflakeV2Report,
connection: SnowflakeConnection,
dataset_urn_builder: Callable[[str], str],
filter: SnowflakeFilter,
identifiers: SnowflakeIdentifierBuilder,
redundant_run_skip_handler: Optional[RedundantUsageRunSkipHandler],
) -> None:
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = report
self.dataset_urn_builder = dataset_urn_builder
self.logger = logger
self.filter = filter
self.identifiers = identifiers
self.connection = connection
self.redundant_run_skip_handler = redundant_run_skip_handler
@ -171,7 +176,7 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
bucket_duration=self.config.bucket_duration,
),
dataset_urns={
self.dataset_urn_builder(dataset_identifier)
self.identifiers.gen_dataset_urn(dataset_identifier)
for dataset_identifier in discovered_datasets
},
)
@ -232,7 +237,7 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
logger.debug(f"Processing usage row number {results.rownumber}")
logger.debug(self.report.usage_aggregation.as_string())
if not self.is_dataset_pattern_allowed(
if not self.filter.is_dataset_pattern_allowed(
row["OBJECT_NAME"],
row["OBJECT_DOMAIN"],
):
@ -242,7 +247,7 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
continue
dataset_identifier = (
self.get_dataset_identifier_from_qualified_name(
self.identifiers.get_dataset_identifier_from_qualified_name(
row["OBJECT_NAME"]
)
)
@ -279,7 +284,8 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
fieldCounts=self._map_field_counts(row["FIELD_COUNTS"]),
)
return MetadataChangeProposalWrapper(
entityUrn=self.dataset_urn_builder(dataset_identifier), aspect=stats
entityUrn=self.identifiers.gen_dataset_urn(dataset_identifier),
aspect=stats,
).as_workunit()
except Exception as e:
logger.debug(
@ -356,7 +362,9 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
return sorted(
[
DatasetFieldUsageCounts(
fieldPath=self.snowflake_identifier(field_count["col"]),
fieldPath=self.identifiers.snowflake_identifier(
field_count["col"]
),
count=field_count["total"],
)
for field_count in field_counts
@ -454,8 +462,10 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
for obj in event.objects_modified:
resource = obj.objectName
dataset_identifier = self.get_dataset_identifier_from_qualified_name(
resource
dataset_identifier = (
self.identifiers.get_dataset_identifier_from_qualified_name(
resource
)
)
if dataset_identifier not in discovered_datasets:
@ -476,7 +486,7 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
),
)
mcp = MetadataChangeProposalWrapper(
entityUrn=self.dataset_urn_builder(dataset_identifier),
entityUrn=self.identifiers.gen_dataset_urn(dataset_identifier),
aspect=operation_aspect,
)
wu = MetadataWorkUnit(
@ -561,7 +571,7 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
def _is_object_valid(self, obj: Dict[str, Any]) -> bool:
if self._is_unsupported_object_accessed(
obj
) or not self.is_dataset_pattern_allowed(
) or not self.filter.is_dataset_pattern_allowed(
obj.get("objectName"), obj.get("objectDomain")
):
return False

View File

@ -1,8 +1,7 @@
import abc
from functools import cached_property
from typing import ClassVar, Literal, Optional, Tuple
from typing_extensions import Protocol
from datahub.configuration.pattern_utils import is_schema_allowed
from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance
from datahub.ingestion.api.source import SourceReport
@ -25,42 +24,6 @@ class SnowflakeStructuredReportMixin(abc.ABC):
def structured_reporter(self) -> SourceReport:
...
# TODO: Eventually I want to deprecate these methods and use the structured_reporter directly.
def report_warning(self, key: str, reason: str) -> None:
self.structured_reporter.warning(key, reason)
def report_error(self, key: str, reason: str) -> None:
self.structured_reporter.failure(key, reason)
# Required only for mypy, since we are using mixin classes, and not inheritance.
# Reference - https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
class SnowflakeCommonProtocol(Protocol):
platform: str = "snowflake"
config: SnowflakeV2Config
report: SnowflakeV2Report
def get_dataset_identifier(
self, table_name: str, schema_name: str, db_name: str
) -> str:
...
def cleanup_qualified_name(self, qualified_name: str) -> str:
...
def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str:
...
def snowflake_identifier(self, identifier: str) -> str:
...
def report_warning(self, key: str, reason: str) -> None:
...
def report_error(self, key: str, reason: str) -> None:
...
class SnowsightUrlBuilder:
CLOUD_REGION_IDS_WITHOUT_CLOUD_SUFFIX: ClassVar = [
@ -140,17 +103,14 @@ class SnowsightUrlBuilder:
return f"{self.snowsight_base_url}#/data/databases/{db_name}/"
class SnowflakeFilterMixin(SnowflakeStructuredReportMixin):
@property
@abc.abstractmethod
def filter_config(self) -> SnowflakeFilterConfig:
...
class SnowflakeFilter:
def __init__(
self, filter_config: SnowflakeFilterConfig, structured_reporter: SourceReport
) -> None:
self.filter_config = filter_config
self.structured_reporter = structured_reporter
@staticmethod
def _combine_identifier_parts(
table_name: str, schema_name: str, db_name: str
) -> str:
return f"{db_name}.{schema_name}.{table_name}"
# TODO: Refactor remaining filtering logic into this class.
def is_dataset_pattern_allowed(
self,
@ -167,28 +127,35 @@ class SnowflakeFilterMixin(SnowflakeStructuredReportMixin):
SnowflakeObjectDomain.MATERIALIZED_VIEW,
):
return False
if len(dataset_params) != 3:
self.report_warning(
"invalid-dataset-pattern",
f"Found {dataset_params} of type {dataset_type}",
)
# NOTE: this case returned `True` earlier when extracting lineage
return False
if not self.filter_config.database_pattern.allowed(
dataset_params[0].strip('"')
) or not is_schema_allowed(
self.filter_config.schema_pattern,
dataset_params[1].strip('"'),
dataset_params[0].strip('"'),
self.filter_config.match_fully_qualified_names,
if len(dataset_params) != 3:
self.structured_reporter.info(
title="Unexpected dataset pattern",
message=f"Found a {dataset_type} with an unexpected number of parts. Database and schema filtering will not work as expected, but table filtering will still work.",
context=dataset_name,
)
# We fall through here so that table/view filtering still works.
if (
len(dataset_params) >= 1
and not self.filter_config.database_pattern.allowed(
dataset_params[0].strip('"')
)
) or (
len(dataset_params) >= 2
and not is_schema_allowed(
self.filter_config.schema_pattern,
dataset_params[1].strip('"'),
dataset_params[0].strip('"'),
self.filter_config.match_fully_qualified_names,
)
):
return False
if dataset_type.lower() in {
SnowflakeObjectDomain.TABLE
} and not self.filter_config.table_pattern.allowed(
self.cleanup_qualified_name(dataset_name)
_cleanup_qualified_name(dataset_name, self.structured_reporter)
):
return False
@ -196,41 +163,53 @@ class SnowflakeFilterMixin(SnowflakeStructuredReportMixin):
SnowflakeObjectDomain.VIEW,
SnowflakeObjectDomain.MATERIALIZED_VIEW,
} and not self.filter_config.view_pattern.allowed(
self.cleanup_qualified_name(dataset_name)
_cleanup_qualified_name(dataset_name, self.structured_reporter)
):
return False
return True
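# Illustrative sketch of the filtering behavior above; the pattern values, dataset
# names, and the AllowDenyPattern import are assumptions for this sketch only.
from datahub.configuration.common import AllowDenyPattern

example_filter = SnowflakeFilter(
    filter_config=SnowflakeFilterConfig(
        database_pattern=AllowDenyPattern(deny=["secret_db"])
    ),
    structured_reporter=SourceReport(),
)
example_filter.is_dataset_pattern_allowed("db1.schema1.table1", "table")
# expected True: the default patterns allow everything
example_filter.is_dataset_pattern_allowed("secret_db.schema1.table1", "table")
# expected False: the database pattern denies secret_db
example_filter.is_dataset_pattern_allowed("db1.schema1", "table")
# unexpected part count: reported as info, then falls through to table filtering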
# Qualified object names from Snowflake audit logs include quotes around Snowflake quoted identifiers,
# for example "test-database"."test-schema".test_table,
# whereas we generate urns without quotes even for quoted identifiers, both for backward compatibility
# and because there is no utility function to determine whether the current table/schema/database
# name should be quoted in the get_dataset_identifier method.
def cleanup_qualified_name(self, qualified_name: str) -> str:
name_parts = qualified_name.split(".")
if len(name_parts) != 3:
self.structured_reporter.report_warning(
title="Unexpected dataset pattern",
message="We failed to parse a Snowflake qualified name into its constituent parts. "
"DB/schema/table filtering may not work as expected on these entities.",
context=f"{qualified_name} has {len(name_parts)} parts",
)
return qualified_name.replace('"', "")
return SnowflakeFilterMixin._combine_identifier_parts(
table_name=name_parts[2].strip('"'),
schema_name=name_parts[1].strip('"'),
db_name=name_parts[0].strip('"'),
def _combine_identifier_parts(
*, table_name: str, schema_name: str, db_name: str
) -> str:
return f"{db_name}.{schema_name}.{table_name}"
# Qualified object names from Snowflake audit logs include quotes around Snowflake quoted identifiers,
# for example "test-database"."test-schema".test_table,
# whereas we generate urns without quotes even for quoted identifiers, both for backward compatibility
# and because there is no utility function to determine whether the current table/schema/database
# name should be quoted in the get_dataset_identifier method.
def _cleanup_qualified_name(
qualified_name: str, structured_reporter: SourceReport
) -> str:
name_parts = qualified_name.split(".")
if len(name_parts) != 3:
structured_reporter.info(
title="Unexpected dataset pattern",
message="We failed to parse a Snowflake qualified name into its constituent parts. "
"DB/schema/table filtering may not work as expected on these entities.",
context=f"{qualified_name} has {len(name_parts)} parts",
)
return qualified_name.replace('"', "")
return _combine_identifier_parts(
db_name=name_parts[0].strip('"'),
schema_name=name_parts[1].strip('"'),
table_name=name_parts[2].strip('"'),
)
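# Quick illustration of the helper above (the reporter is only needed for logging):
_example_reporter = SourceReport()
_cleanup_qualified_name('"test-database"."test-schema".test_table', _example_reporter)
# -> "test-database.test-schema.test_table" (quotes stripped from each part)
_cleanup_qualified_name("unqualified_name", _example_reporter)
# -> "unqualified_name" (unexpected part count reported as info, quotes stripped as a fallback)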
class SnowflakeIdentifierMixin(abc.ABC):
class SnowflakeIdentifierBuilder:
platform = "snowflake"
@property
@abc.abstractmethod
def identifier_config(self) -> SnowflakeIdentifierConfig:
...
def __init__(
self,
identifier_config: SnowflakeIdentifierConfig,
structured_reporter: SourceReport,
) -> None:
self.identifier_config = identifier_config
self.structured_reporter = structured_reporter
def snowflake_identifier(self, identifier: str) -> str:
# To stay in sync with the older connector, convert the name to lowercase
@ -242,7 +221,7 @@ class SnowflakeIdentifierMixin(abc.ABC):
self, table_name: str, schema_name: str, db_name: str
) -> str:
return self.snowflake_identifier(
SnowflakeCommonMixin._combine_identifier_parts(
_combine_identifier_parts(
table_name=table_name, schema_name=schema_name, db_name=db_name
)
)
@ -255,20 +234,10 @@ class SnowflakeIdentifierMixin(abc.ABC):
env=self.identifier_config.env,
)
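# Rough usage sketch for the builder above; constructing SnowflakeIdentifierConfig with
# all defaults and the lowercased outputs are assumptions based on snowflake_identifier.
example_identifiers = SnowflakeIdentifierBuilder(
    identifier_config=SnowflakeIdentifierConfig(),
    structured_reporter=SourceReport(),
)
example_identifiers.get_dataset_identifier("MY_TABLE", "MY_SCHEMA", "MY_DB")
# -> "my_db.my_schema.my_table" (parts combined, then normalized by snowflake_identifier)
example_identifiers.gen_dataset_urn("my_db.my_schema.my_table")
# -> a Snowflake dataset urn built with the configured platform_instance and env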
# TODO: We're most of the way there on fully removing SnowflakeCommonProtocol.
class SnowflakeCommonMixin(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
@property
def structured_reporter(self: SnowflakeCommonProtocol) -> SourceReport:
return self.report
@property
def filter_config(self: SnowflakeCommonProtocol) -> SnowflakeFilterConfig:
return self.config
@property
def identifier_config(self: SnowflakeCommonProtocol) -> SnowflakeIdentifierConfig:
return self.config
def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str:
return self.snowflake_identifier(
_cleanup_qualified_name(qualified_name, self.structured_reporter)
)
@staticmethod
def get_quoted_identifier_for_database(db_name):
@ -278,40 +247,51 @@ class SnowflakeCommonMixin(SnowflakeFilterMixin, SnowflakeIdentifierMixin):
def get_quoted_identifier_for_schema(db_name, schema_name):
return f'"{db_name}"."{schema_name}"'
def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str:
return self.snowflake_identifier(self.cleanup_qualified_name(qualified_name))
@staticmethod
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
return f'"{db_name}"."{schema_name}"."{table_name}"'
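# Quick illustration of the quoting helpers above (inputs are made up):
#   get_quoted_identifier_for_database("MY_DB")                       -> '"MY_DB"'
#   get_quoted_identifier_for_schema("MY_DB", "MY_SCHEMA")            -> '"MY_DB"."MY_SCHEMA"'
#   get_quoted_identifier_for_table("MY_DB", "MY_SCHEMA", "MY_TABLE") -> '"MY_DB"."MY_SCHEMA"."MY_TABLE"'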
class SnowflakeCommonMixin(SnowflakeStructuredReportMixin):
platform = "snowflake"
config: SnowflakeV2Config
report: SnowflakeV2Report
@property
def structured_reporter(self) -> SourceReport:
return self.report
@cached_property
def identifiers(self) -> SnowflakeIdentifierBuilder:
return SnowflakeIdentifierBuilder(self.config, self.report)
# Note - decide how to construct user urns.
# Historically, urns were created using the part of the user's email before the @.
# Users without an email were skipped from both user entries and aggregates.
# However, email is not a mandatory field for a Snowflake user, whereas user_name is always present.
def get_user_identifier(
self: SnowflakeCommonProtocol,
self,
user_name: str,
user_email: Optional[str],
email_as_user_identifier: bool,
) -> str:
if user_email:
return self.snowflake_identifier(
return self.identifiers.snowflake_identifier(
user_email
if email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.snowflake_identifier(user_name)
return self.identifiers.snowflake_identifier(user_name)
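# Sketch of the resolution order above, with made-up inputs; the lowercasing assumes
# the default snowflake_identifier behavior.
#   get_user_identifier("JDOE", "jdoe@example.com", email_as_user_identifier=True)
#       -> "jdoe@example.com"
#   get_user_identifier("JDOE", "jdoe@example.com", email_as_user_identifier=False)
#       -> "jdoe"  (the part of the email before the @)
#   get_user_identifier("JDOE", None, email_as_user_identifier=True)
#       -> "jdoe"  (falls back to user_name when no email is available)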
# TODO: Revisit this after stateful ingestion can commit checkpoint
# for failures that do not affect the checkpoint
def warn_if_stateful_else_error(
self: SnowflakeCommonProtocol, key: str, reason: str
) -> None:
# TODO: Add additional parameters to match the signature of the .warning and .failure methods
def warn_if_stateful_else_error(self, key: str, reason: str) -> None:
if (
self.config.stateful_ingestion is not None
and self.config.stateful_ingestion.enabled
):
self.report_warning(key, reason)
self.structured_reporter.warning(key, reason)
else:
self.report_error(key, reason)
self.structured_reporter.failure(key, reason)

View File

@ -25,6 +25,7 @@ from datahub.ingestion.api.source import (
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.constants import (
GENERIC_PERMISSION_ERROR_KEY,
@ -42,6 +43,10 @@ from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import (
SnowflakeLineageExtractor,
)
from datahub.ingestion.source.snowflake.snowflake_profiler import SnowflakeProfiler
from datahub.ingestion.source.snowflake.snowflake_queries import (
SnowflakeQueriesExtractor,
SnowflakeQueriesExtractorConfig,
)
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_schema import (
SnowflakeDataDictionary,
@ -56,6 +61,8 @@ from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
)
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeFilter,
SnowflakeIdentifierBuilder,
SnowsightUrlBuilder,
)
from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler
@ -72,6 +79,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
from datahub.ingestion.source_report.ingestion_stage import (
LINEAGE_EXTRACTION,
METADATA_EXTRACTION,
QUERIES_EXTRACTION,
)
from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator
from datahub.utilities.registries.domain_registry import DomainRegistry
@ -127,9 +135,13 @@ class SnowflakeV2Source(
super().__init__(config, ctx)
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = SnowflakeV2Report()
self.logger = logger
self.connection = self.config.get_connection()
self.filters = SnowflakeFilter(
filter_config=self.config, structured_reporter=self.report
)
self.identifiers = SnowflakeIdentifierBuilder(
identifier_config=self.config, structured_reporter=self.report
)
self.domain_registry: Optional[DomainRegistry] = None
if self.config.domain:
@ -137,28 +149,29 @@ class SnowflakeV2Source(
cached_domains=[k for k in self.config.domain], graph=self.ctx.graph
)
self.connection = self.config.get_connection()
# For database, schema, tables, views, etc
self.data_dictionary = SnowflakeDataDictionary(connection=self.connection)
self.lineage_extractor: Optional[SnowflakeLineageExtractor] = None
self.aggregator: Optional[SqlParsingAggregator] = None
if self.config.include_table_lineage:
if self.config.use_queries_v2 or self.config.include_table_lineage:
self.aggregator = SqlParsingAggregator(
platform=self.platform,
platform=self.identifiers.platform,
platform_instance=self.config.platform_instance,
env=self.config.env,
graph=(
graph=self.ctx.graph,
eager_graph_load=(
# If we're ingesting schema metadata for tables/views, then we will populate
# schemas into the resolver as we go. We only need to do a bulk fetch
# if we're not ingesting schema metadata as part of ingestion.
self.ctx.graph
if not (
(
self.config.include_technical_schema
and self.config.include_tables
and self.config.include_views
)
and not self.config.lazy_schema_resolver
else None
),
generate_usage_statistics=False,
generate_operations=False,
@ -166,6 +179,8 @@ class SnowflakeV2Source(
)
self.report.sql_aggregator = self.aggregator.report
if self.config.include_table_lineage:
assert self.aggregator is not None
redundant_lineage_run_skip_handler: Optional[
RedundantLineageRunSkipHandler
] = None
@ -180,7 +195,8 @@ class SnowflakeV2Source(
config,
self.report,
connection=self.connection,
dataset_urn_builder=self.gen_dataset_urn,
filters=self.filters,
identifiers=self.identifiers,
redundant_run_skip_handler=redundant_lineage_run_skip_handler,
sql_aggregator=self.aggregator,
)
@ -201,7 +217,8 @@ class SnowflakeV2Source(
config,
self.report,
connection=self.connection,
dataset_urn_builder=self.gen_dataset_urn,
filter=self.filters,
identifiers=self.identifiers,
redundant_run_skip_handler=redundant_usage_run_skip_handler,
)
@ -445,7 +462,8 @@ class SnowflakeV2Source(
profiler=self.profiler,
aggregator=self.aggregator,
snowsight_url_builder=snowsight_url_builder,
dataset_urn_builder=self.gen_dataset_urn,
filters=self.filters,
identifiers=self.identifiers,
)
self.report.set_ingestion_stage("*", METADATA_EXTRACTION)
@ -453,30 +471,28 @@ class SnowflakeV2Source(
databases = schema_extractor.databases
self.connection.close()
# TODO: The checkpoint state for stale entity detection can be committed here.
if self.config.shares:
yield from SnowflakeSharesHandler(
self.config, self.report, self.gen_dataset_urn
self.config, self.report
).get_shares_workunits(databases)
discovered_tables: List[str] = [
self.get_dataset_identifier(table_name, schema.name, db.name)
self.identifiers.get_dataset_identifier(table_name, schema.name, db.name)
for db in databases
for schema in db.schemas
for table_name in schema.tables
]
discovered_views: List[str] = [
self.get_dataset_identifier(table_name, schema.name, db.name)
self.identifiers.get_dataset_identifier(table_name, schema.name, db.name)
for db in databases
for schema in db.schemas
for table_name in schema.views
]
if len(discovered_tables) == 0 and len(discovered_views) == 0:
self.report_error(
self.structured_reporter.failure(
GENERIC_PERMISSION_ERROR_KEY,
"No tables/views found. Please check permissions.",
)
@ -484,33 +500,66 @@ class SnowflakeV2Source(
discovered_datasets = discovered_tables + discovered_views
if self.config.include_table_lineage and self.lineage_extractor:
self.report.set_ingestion_stage("*", LINEAGE_EXTRACTION)
yield from self.lineage_extractor.get_workunits(
discovered_tables=discovered_tables,
discovered_views=discovered_views,
if self.config.use_queries_v2:
self.report.set_ingestion_stage("*", "View Parsing")
assert self.aggregator is not None
yield from auto_workunit(self.aggregator.gen_metadata())
self.report.set_ingestion_stage("*", QUERIES_EXTRACTION)
schema_resolver = self.aggregator._schema_resolver
queries_extractor = SnowflakeQueriesExtractor(
connection=self.connection,
config=SnowflakeQueriesExtractorConfig(
window=self.config,
temporary_tables_pattern=self.config.temporary_tables_pattern,
include_lineage=self.config.include_table_lineage,
include_usage_statistics=self.config.include_usage_stats,
include_operations=self.config.include_operational_stats,
),
structured_report=self.report,
filters=self.filters,
identifiers=self.identifiers,
schema_resolver=schema_resolver,
)
if (
self.config.include_usage_stats or self.config.include_operational_stats
) and self.usage_extractor:
yield from self.usage_extractor.get_usage_workunits(discovered_datasets)
# TODO: This is slightly suboptimal because we create two SqlParsingAggregator instances with different configs
# but a shared schema resolver. That's fine for now though - once we remove the old lineage/usage extractors,
# it should be pretty straightforward to refactor this and only initialize the aggregator once.
self.report.queries_extractor = queries_extractor.report
yield from queries_extractor.get_workunits_internal()
else:
if self.config.include_table_lineage and self.lineage_extractor:
self.report.set_ingestion_stage("*", LINEAGE_EXTRACTION)
yield from self.lineage_extractor.get_workunits(
discovered_tables=discovered_tables,
discovered_views=discovered_views,
)
if (
self.config.include_usage_stats or self.config.include_operational_stats
) and self.usage_extractor:
yield from self.usage_extractor.get_usage_workunits(discovered_datasets)
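# Rough sketch of the recipe flags that drive the branch above; the flag names come
# from this change, while the account_id and values are illustrative.
#   config = SnowflakeV2Config.parse_obj(
#       {
#           "account_id": "abc12345",
#           "use_queries_v2": True,             # switch to SnowflakeQueriesExtractor
#           "include_table_lineage": True,      # -> include_lineage
#           "include_usage_stats": True,        # -> include_usage_statistics
#           "include_operational_stats": True,  # -> include_operations
#       }
#   )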
if self.config.include_assertion_results:
yield from SnowflakeAssertionsHandler(
self.config, self.report, self.connection
self.config, self.report, self.connection, self.identifiers
).get_assertion_workunits(discovered_datasets)
self.connection.close()
def report_warehouse_failure(self) -> None:
if self.config.warehouse is not None:
self.report_error(
self.structured_reporter.failure(
GENERIC_PERMISSION_ERROR_KEY,
f"Current role does not have permissions to use warehouse {self.config.warehouse}. Please update permissions.",
)
else:
self.report_error(
"no-active-warehouse",
"No default warehouse set for user. Either set default warehouse for user or configure warehouse in recipe.",
self.structured_reporter.failure(
"Could not use a Snowflake warehouse",
"No default warehouse set for user. Either set a default warehouse for the user or configure a warehouse in the recipe.",
)
def get_report(self) -> SourceReport:
@ -541,19 +590,28 @@ class SnowflakeV2Source(
for db_row in connection.query(SnowflakeQuery.current_version()):
self.report.saas_version = db_row["CURRENT_VERSION()"]
except Exception as e:
self.report_error("version", f"Error: {e}")
self.structured_reporter.failure(
"Could not determine the current Snowflake version",
exc=e,
)
try:
logger.info("Checking current role")
for db_row in connection.query(SnowflakeQuery.current_role()):
self.report.role = db_row["CURRENT_ROLE()"]
except Exception as e:
self.report_error("version", f"Error: {e}")
self.structured_reporter.failure(
"Could not determine the current Snowflake role",
exc=e,
)
try:
logger.info("Checking current warehouse")
for db_row in connection.query(SnowflakeQuery.current_warehouse()):
self.report.default_warehouse = db_row["CURRENT_WAREHOUSE()"]
except Exception as e:
self.report_error("current_warehouse", f"Error: {e}")
self.structured_reporter.failure(
"Could not determine the current Snowflake warehouse",
exc=e,
)
try:
logger.info("Checking current edition")

View File

@ -251,11 +251,6 @@ class AthenaConfig(SQLCommonConfig):
"queries executed by DataHub."
)
# overwrite default behavior of SQLAlchemyConfig
include_views: Optional[bool] = pydantic.Field(
default=True, description="Whether views should be ingested."
)
_s3_staging_dir_population = pydantic_renamed_field(
old_name="s3_staging_dir",
new_name="query_result_location",

View File

@ -83,10 +83,10 @@ class SQLCommonConfig(
description='Attach domains to databases, schemas or tables during ingestion using regex patterns. Domain key can be a guid like *urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba* or a string like "Marketing".) If you provide strings, then datahub will attempt to resolve this name to a guid, and will error out if this fails. There can be multiple domain keys specified.',
)
include_views: Optional[bool] = Field(
include_views: bool = Field(
default=True, description="Whether views should be ingested."
)
include_tables: Optional[bool] = Field(
include_tables: bool = Field(
default=True, description="Whether tables should be ingested."
)

View File

@ -14,6 +14,7 @@ LINEAGE_EXTRACTION = "Lineage Extraction"
USAGE_EXTRACTION_INGESTION = "Usage Extraction Ingestion"
USAGE_EXTRACTION_OPERATIONAL_STATS = "Usage Extraction Operational Stats"
USAGE_EXTRACTION_USAGE_AGGREGATION = "Usage Extraction Usage Aggregation"
QUERIES_EXTRACTION = "Queries Extraction"
PROFILING = "Profiling"

View File

@ -251,7 +251,9 @@ class SqlParsingAggregator(Closeable):
platform: str,
platform_instance: Optional[str] = None,
env: str = builder.DEFAULT_ENV,
schema_resolver: Optional[SchemaResolver] = None,
graph: Optional[DataHubGraph] = None,
eager_graph_load: bool = True,
generate_lineage: bool = True,
generate_queries: bool = True,
generate_query_subject_fields: bool = True,
@ -274,8 +276,12 @@ class SqlParsingAggregator(Closeable):
self.generate_usage_statistics = generate_usage_statistics
self.generate_query_usage_statistics = generate_query_usage_statistics
self.generate_operations = generate_operations
if self.generate_queries and not self.generate_lineage:
raise ValueError("Queries will only be generated if lineage is enabled")
if self.generate_queries and not (
self.generate_lineage or self.generate_query_usage_statistics
):
logger.warning(
"Queries will not be generated, as neither lineage nor query usage statistics are enabled"
)
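# Sketch of when the warning above fires (other required options omitted for brevity;
# values are illustrative assumptions):
#   SqlParsingAggregator(
#       platform="snowflake",
#       generate_lineage=False,
#       generate_query_usage_statistics=False,
#       # generate_queries is left at its default of True, but nothing would consume
#       # the generated queries, so the warning is logged.
#   )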
self.usage_config = usage_config
if (
@ -297,17 +303,29 @@ class SqlParsingAggregator(Closeable):
# Set up the schema resolver.
self._schema_resolver: SchemaResolver
if graph is None:
if schema_resolver is not None:
# If explicitly provided, use it.
assert self.platform.platform_name == schema_resolver.platform
assert self.platform_instance == schema_resolver.platform_instance
assert self.env == schema_resolver.env
self._schema_resolver = schema_resolver
elif graph is not None and eager_graph_load and self._need_schemas:
# Bulk load schemas using the graph client.
self._schema_resolver = graph.initialize_schema_resolver_from_datahub(
platform=self.platform.urn(),
platform_instance=self.platform_instance,
env=self.env,
)
else:
# Otherwise, use a lazy-loading schema resolver.
self._schema_resolver = self._exit_stack.enter_context(
SchemaResolver(
platform=self.platform.platform_name,
platform_instance=self.platform_instance,
env=self.env,
graph=graph,
)
)
else:
self._schema_resolver = None # type: ignore
self._initialize_schema_resolver_from_graph(graph)
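# Resolution order sketch for the block above: an explicitly passed schema_resolver wins;
# otherwise a graph plus eager_graph_load (when schemas are needed) triggers a bulk load
# from DataHub; otherwise a lazy SchemaResolver is created, optionally graph-backed.
# The constructor arguments below are illustrative assumptions.
#   aggregator = SqlParsingAggregator(
#       platform="snowflake",
#       schema_resolver=shared_resolver,  # e.g. reused from another aggregator instance
#       eager_graph_load=False,
#   )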
# Initialize internal data structures.
# This leans pretty heavily on our query fingerprinting capabilities.
@ -373,6 +391,8 @@ class SqlParsingAggregator(Closeable):
# Usage aggregator. This will only be initialized if usage statistics are enabled.
# TODO: Replace with FileBackedDict.
# TODO: The BaseUsageConfig class is much too broad for our purposes, and has a number of
# configs that won't be respected here. Using it is misleading.
self._usage_aggregator: Optional[UsageAggregator[UrnStr]] = None
if self.generate_usage_statistics:
assert self.usage_config is not None
@ -392,7 +412,13 @@ class SqlParsingAggregator(Closeable):
@property
def _need_schemas(self) -> bool:
return self.generate_lineage or self.generate_usage_statistics
# Unless the aggregator is totally disabled, we will need schema information.
return (
self.generate_lineage
or self.generate_usage_statistics
or self.generate_queries
or self.generate_operations
)
def register_schema(
self, urn: Union[str, DatasetUrn], schema: models.SchemaMetadataClass
@ -414,35 +440,6 @@ class SqlParsingAggregator(Closeable):
yield wu
def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
# requires a graph instance
# if no schemas are currently registered in the schema resolver
# and we need the schema resolver (e.g. lineage or usage is enabled)
# then use the graph instance to fetch all schemas for the
# platform/instance/env combo
if not self._need_schemas:
return
if (
self._schema_resolver is not None
and self._schema_resolver.schema_count() > 0
):
# TODO: Have a mechanism to override this, e.g. when table ingestion is enabled but view ingestion is not.
logger.info(
"Not fetching any schemas from the graph, since "
f"there are {self._schema_resolver.schema_count()} schemas already registered."
)
return
# TODO: The initialize_schema_resolver_from_datahub method should take in a SchemaResolver
# that it can populate or add to, rather than creating a new one and dropping any schemas
# that were already loaded into the existing one.
self._schema_resolver = graph.initialize_schema_resolver_from_datahub(
platform=self.platform.urn(),
platform_instance=self.platform_instance,
env=self.env,
)
def _maybe_format_query(self, query: str) -> str:
if self.format_queries:
with self.report.sql_formatting_timer:

View File

@ -102,9 +102,7 @@ def test_snowflake_shares_workunit_no_shares(
config = SnowflakeV2Config(account_id="abc12345", platform_instance="instance1")
report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x)
)
shares_handler = SnowflakeSharesHandler(config, report)
wus = list(shares_handler.get_shares_workunits(snowflake_databases))
@ -204,9 +202,7 @@ def test_snowflake_shares_workunit_inbound_share(
)
report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x, "instance1")
)
shares_handler = SnowflakeSharesHandler(config, report)
wus = list(shares_handler.get_shares_workunits(snowflake_databases))
@ -262,9 +258,7 @@ def test_snowflake_shares_workunit_outbound_share(
)
report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x, "instance1")
)
shares_handler = SnowflakeSharesHandler(config, report)
wus = list(shares_handler.get_shares_workunits(snowflake_databases))
@ -313,9 +307,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share(
)
report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x, "instance1")
)
shares_handler = SnowflakeSharesHandler(config, report)
wus = list(shares_handler.get_shares_workunits(snowflake_databases))
@ -376,9 +368,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instan
)
report = SnowflakeV2Report()
shares_handler = SnowflakeSharesHandler(
config, report, lambda x: make_snowflake_urn(x)
)
shares_handler = SnowflakeSharesHandler(config, report)
assert sorted(config.outbounds().keys()) == ["db1", "db2_main"]
assert sorted(config.inbounds().keys()) == [