feat(ingest/snowflake): initialize schema resolver from datahub for l… (#8903)

Mayuri Nehate 2023-10-04 16:23:31 +05:30 committed by GitHub
parent a300b39f15
commit e3780c2d75
6 changed files with 33 additions and 24 deletions

View File

@@ -7,7 +7,7 @@ import time
 from dataclasses import dataclass
 from datetime import datetime
 from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
 
 from avro.schema import RecordSchema
 from deprecated import deprecated
@@ -1010,14 +1010,13 @@ class DataHubGraph(DatahubRestEmitter):
     def initialize_schema_resolver_from_datahub(
         self, platform: str, platform_instance: Optional[str], env: str
-    ) -> Tuple["SchemaResolver", Set[str]]:
+    ) -> "SchemaResolver":
         logger.info("Initializing schema resolver")
         schema_resolver = self._make_schema_resolver(
             platform, platform_instance, env, include_graph=False
         )
 
         logger.info(f"Fetching schemas for platform {platform}, env {env}")
-        urns = []
         count = 0
         with PerfTimer() as timer:
             for urn, schema_info in self._bulk_fetch_schema_info_by_filter(
@@ -1026,7 +1025,6 @@ class DataHubGraph(DatahubRestEmitter):
                 env=env,
             ):
                 try:
-                    urns.append(urn)
                     schema_resolver.add_graphql_schema_metadata(urn, schema_info)
                     count += 1
                 except Exception:
@@ -1041,7 +1039,7 @@ class DataHubGraph(DatahubRestEmitter):
         )
 
         logger.info("Finished initializing schema resolver")
-        return schema_resolver, set(urns)
+        return schema_resolver
 
     def parse_sql_lineage(
         self,

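Note: callers that previously unpacked the (SchemaResolver, Set[str]) tuple now take the resolver alone and recover the URN set from it. A minimal sketch of the new call pattern, assuming a connected DataHubGraph client named `graph` (names are illustrative):

resolver = graph.initialize_schema_resolver_from_datahub(
    platform="snowflake",
    platform_instance=None,
    env="PROD",
)
# get_urns() replaces the second element of the old return tuple.
urns = resolver.get_urns()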
View File

@@ -458,7 +458,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
                 platform=self.platform,
                 platform_instance=self.config.platform_instance,
                 env=self.config.env,
-            )[0]
+            )
         else:
             logger.warning(
                 "Failed to load schema info from DataHub as DataHubGraph is missing.",

View File

@@ -101,8 +101,8 @@ class SnowflakeV2Config(
     )
 
     include_view_column_lineage: bool = Field(
-        default=False,
-        description="Populates view->view and table->view column lineage.",
+        default=True,
+        description="Populates view->view and table->view column lineage using DataHub's sql parser.",
     )
 
     _check_role_grants_removed = pydantic_removed_field("check_role_grants")

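Note: with the default flipped to True, view->view and table->view column lineage is now produced unless a recipe sets include_view_column_lineage: false. A small sketch verifying the new default via the pydantic v1 field metadata (the import path is an assumption):

from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config

# Inspect the Field default without constructing a full (credentialed) config.
field = SnowflakeV2Config.__fields__["include_view_column_lineage"]
assert field.default is True  # previously False; opt out per-recipe if undesired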
View File

@@ -301,14 +301,11 @@ class SnowflakeV2Source(
         # Caches tables for a single database. Consider moving to disk or S3 when possible.
         self.db_tables: Dict[str, List[SnowflakeTable]] = {}
 
-        self.sql_parser_schema_resolver = SchemaResolver(
-            platform=self.platform,
-            platform_instance=self.config.platform_instance,
-            env=self.config.env,
-        )
         self.view_definitions: FileBackedDict[str] = FileBackedDict()
 
         self.add_config_to_report()
+        self.sql_parser_schema_resolver = self._init_schema_resolver()
 
     @classmethod
     def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source":
         config = SnowflakeV2Config.parse_obj(config_dict)
@@ -493,6 +490,24 @@ class SnowflakeV2Source(
         return _report
 
+    def _init_schema_resolver(self) -> SchemaResolver:
+        if not self.config.include_technical_schema and self.config.parse_view_ddl:
+            if self.ctx.graph:
+                return self.ctx.graph.initialize_schema_resolver_from_datahub(
+                    platform=self.platform,
+                    platform_instance=self.config.platform_instance,
+                    env=self.config.env,
+                )
+            else:
+                logger.warning(
+                    "Failed to load schema info from DataHub as DataHubGraph is missing.",
+                )
+        return SchemaResolver(
+            platform=self.platform,
+            platform_instance=self.config.platform_instance,
+            env=self.config.env,
+        )
+
     def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         return [
             *super().get_workunit_processors(),
@@ -764,7 +779,7 @@ class SnowflakeV2Source(
             )
             self.db_tables[schema_name] = tables
 
-            if self.config.include_technical_schema or self.config.parse_view_ddl:
+            if self.config.include_technical_schema:
                 for table in tables:
                     yield from self._process_table(table, schema_name, db_name)
@@ -776,7 +791,7 @@ class SnowflakeV2Source(
                 if view.view_definition:
                     self.view_definitions[key] = view.view_definition
 
-            if self.config.include_technical_schema or self.config.parse_view_ddl:
+            if self.config.include_technical_schema:
                 for view in views:
                     yield from self._process_view(view, schema_name, db_name)
@@ -892,8 +907,6 @@ class SnowflakeV2Source(
                 yield from self._process_tag(tag)
 
             yield from self.gen_dataset_workunits(table, schema_name, db_name)
-        elif self.config.parse_view_ddl:
-            self.gen_schema_metadata(table, schema_name, db_name)
 
     def fetch_sample_data_for_classification(
         self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
@@ -1004,8 +1017,6 @@ class SnowflakeV2Source(
                 yield from self._process_tag(tag)
 
             yield from self.gen_dataset_workunits(view, schema_name, db_name)
-        elif self.config.parse_view_ddl:
-            self.gen_schema_metadata(view, schema_name, db_name)
 
     def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
         tag_identifier = tag.identifier()

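Note: the new _init_schema_resolver picks one of three outcomes. A condensed, illustrative restatement of the decision logic (not the literal implementation):

def pick_schema_resolver_source(
    include_technical_schema: bool, parse_view_ddl: bool, has_graph: bool
) -> str:
    # Only the "parse view DDL without ingesting technical schema" combination
    # lacks locally extracted schemas, so it pre-populates from DataHub.
    if not include_technical_schema and parse_view_ddl:
        if has_graph:
            return "resolver pre-populated from DataHub"
        return "empty resolver (DataHubGraph missing; warning logged)"
    # Otherwise the resolver starts empty and fills as Snowflake schemas are extracted.
    return "empty resolver, populated during extraction"

For example, include_technical_schema=False with parse_view_ddl=True and a connected graph bulk-fetches the platform's schemas for the configured env up front; every other combination keeps the previous behavior of registering schemas during extraction.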
View File

@@ -103,13 +103,12 @@ class SqlQueriesSource(Source):
         self.builder = SqlParsingBuilder(usage_config=self.config.usage)
 
         if self.config.use_schema_resolver:
-            schema_resolver, urns = self.graph.initialize_schema_resolver_from_datahub(
+            self.schema_resolver = self.graph.initialize_schema_resolver_from_datahub(
                 platform=self.config.platform,
                 platform_instance=self.config.platform_instance,
                 env=self.config.env,
             )
-            self.schema_resolver = schema_resolver
-            self.urns = urns
+            self.urns = self.schema_resolver.get_urns()
         else:
             self.schema_resolver = self.graph._make_schema_resolver(
                 platform=self.config.platform,

View File

@@ -283,6 +283,9 @@ class SchemaResolver(Closeable):
             shared_connection=shared_conn,
         )
 
+    def get_urns(self) -> Set[str]:
+        return set(self._schema_cache.keys())
+
     def get_urn_for_table(self, table: _TableName, lower: bool = False) -> str:
         # TODO: Validate that this is the correct 2/3 layer hierarchy for the platform.
@@ -397,8 +400,6 @@ class SchemaResolver(Closeable):
             )
         }
 
-    # TODO add a method to load all from graphql
-
     def close(self) -> None:
         self._schema_cache.close()
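Note: get_urns() simply exposes the keys of the resolver's internal schema cache. A small usage sketch, assuming a resolver constructed as in the sources above:

resolver = SchemaResolver(platform="snowflake", platform_instance=None, env="PROD")

# A fresh resolver starts empty; URNs appear as schemas are registered,
# e.g. via add_graphql_schema_metadata() inside initialize_schema_resolver_from_datahub().
assert resolver.get_urns() == set()

resolver.close()  # releases the underlying schema cache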