Mirror of https://github.com/datahub-project/datahub.git (synced 2025-12-03 22:23:37 +00:00)
feat(ingest/snowflake): initialize schema resolver from datahub for l… (#8903)
parent a300b39f15
commit e3780c2d75
@@ -7,7 +7,7 @@ import time
 from dataclasses import dataclass
 from datetime import datetime
 from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type

 from avro.schema import RecordSchema
 from deprecated import deprecated
@@ -1010,14 +1010,13 @@ class DataHubGraph(DatahubRestEmitter):

     def initialize_schema_resolver_from_datahub(
         self, platform: str, platform_instance: Optional[str], env: str
-    ) -> Tuple["SchemaResolver", Set[str]]:
+    ) -> "SchemaResolver":
         logger.info("Initializing schema resolver")
         schema_resolver = self._make_schema_resolver(
             platform, platform_instance, env, include_graph=False
         )

         logger.info(f"Fetching schemas for platform {platform}, env {env}")
-        urns = []
         count = 0
         with PerfTimer() as timer:
             for urn, schema_info in self._bulk_fetch_schema_info_by_filter(
@@ -1026,7 +1025,6 @@ class DataHubGraph(DatahubRestEmitter):
                 env=env,
             ):
                 try:
-                    urns.append(urn)
                     schema_resolver.add_graphql_schema_metadata(urn, schema_info)
                     count += 1
                 except Exception:
@@ -1041,7 +1039,7 @@ class DataHubGraph(DatahubRestEmitter):
         )

         logger.info("Finished initializing schema resolver")
-        return schema_resolver, set(urns)
+        return schema_resolver

     def parse_sql_lineage(
         self,
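A minimal migration sketch for callers of this method, assuming the usual client import path and a placeholder GMS address: code that previously unpacked the returned tuple now asks the resolver for its URNs.

    from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

    # Placeholder server URL; point this at your own GMS instance.
    graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

    # Before this commit:
    #   schema_resolver, urns = graph.initialize_schema_resolver_from_datahub(...)
    schema_resolver = graph.initialize_schema_resolver_from_datahub(
        platform="snowflake", platform_instance=None, env="PROD"
    )
    urns = schema_resolver.get_urns()  # replaces the second tuple element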
@@ -458,7 +458,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
                     platform=self.platform,
                     platform_instance=self.config.platform_instance,
                     env=self.config.env,
-                )[0]
+                )
             else:
                 logger.warning(
                     "Failed to load schema info from DataHub as DataHubGraph is missing.",
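With the helper now returning the resolver directly, the BigQuery source simply drops the [0] index; its warning path for a missing DataHubGraph is unchanged.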
@@ -101,8 +101,8 @@ class SnowflakeV2Config(
     )

     include_view_column_lineage: bool = Field(
-        default=False,
-        description="Populates view->view and table->view column lineage.",
+        default=True,
+        description="Populates view->view and table->view column lineage using DataHub's sql parser.",
     )

     _check_role_grants_removed = pydantic_removed_field("check_role_grants")
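The flipped default can be checked without building a full connection config. A small sketch, assuming the usual module path and pydantic v1 (which DataHub used at this point):

    from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config

    # Inspect the field default without instantiating the heavily validated config.
    field = SnowflakeV2Config.__fields__["include_view_column_lineage"]
    assert field.default is True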
@@ -301,14 +301,11 @@ class SnowflakeV2Source(
         # Caches tables for a single database. Consider moving to disk or S3 when possible.
         self.db_tables: Dict[str, List[SnowflakeTable]] = {}

-        self.sql_parser_schema_resolver = SchemaResolver(
-            platform=self.platform,
-            platform_instance=self.config.platform_instance,
-            env=self.config.env,
-        )
         self.view_definitions: FileBackedDict[str] = FileBackedDict()
         self.add_config_to_report()

+        self.sql_parser_schema_resolver = self._init_schema_resolver()
+
     @classmethod
     def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source":
         config = SnowflakeV2Config.parse_obj(config_dict)
@@ -493,6 +490,24 @@ class SnowflakeV2Source(

         return _report

+    def _init_schema_resolver(self) -> SchemaResolver:
+        if not self.config.include_technical_schema and self.config.parse_view_ddl:
+            if self.ctx.graph:
+                return self.ctx.graph.initialize_schema_resolver_from_datahub(
+                    platform=self.platform,
+                    platform_instance=self.config.platform_instance,
+                    env=self.config.env,
+                )
+            else:
+                logger.warning(
+                    "Failed to load schema info from DataHub as DataHubGraph is missing.",
+                )
+        return SchemaResolver(
+            platform=self.platform,
+            platform_instance=self.config.platform_instance,
+            env=self.config.env,
+        )
+
     def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         return [
             *super().get_workunit_processors(),
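The new _init_schema_resolver helper only reaches out to DataHub when technical schema ingestion is disabled but parse_view_ddl is enabled: in that configuration schemas are not extracted from Snowflake itself, so the SQL parser needs them prefetched from the graph. In every other case, including a missing DataHubGraph (which only earns a warning), it falls back to an empty local SchemaResolver that fills up as schemas are ingested.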
@@ -764,7 +779,7 @@ class SnowflakeV2Source(
             )
         self.db_tables[schema_name] = tables

-        if self.config.include_technical_schema or self.config.parse_view_ddl:
+        if self.config.include_technical_schema:
             for table in tables:
                 yield from self._process_table(table, schema_name, db_name)

@@ -776,7 +791,7 @@ class SnowflakeV2Source(
             if view.view_definition:
                 self.view_definitions[key] = view.view_definition

-        if self.config.include_technical_schema or self.config.parse_view_ddl:
+        if self.config.include_technical_schema:
             for view in views:
                 yield from self._process_view(view, schema_name, db_name)

@@ -892,8 +907,6 @@ class SnowflakeV2Source(
                 yield from self._process_tag(tag)

             yield from self.gen_dataset_workunits(table, schema_name, db_name)
-        elif self.config.parse_view_ddl:
-            self.gen_schema_metadata(table, schema_name, db_name)

     def fetch_sample_data_for_classification(
         self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
@@ -1004,8 +1017,6 @@ class SnowflakeV2Source(
                 yield from self._process_tag(tag)

             yield from self.gen_dataset_workunits(view, schema_name, db_name)
-        elif self.config.parse_view_ddl:
-            self.gen_schema_metadata(view, schema_name, db_name)

     def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
         tag_identifier = tag.identifier()
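Together with the gating changes above, parse_view_ddl on its own no longer drives _process_table/_process_view or the side-effect-only gen_schema_metadata calls: when only view DDL parsing is requested, the schema information now comes from the graph-initialized resolver rather than being rebuilt during schema traversal.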
@@ -103,13 +103,12 @@ class SqlQueriesSource(Source):
         self.builder = SqlParsingBuilder(usage_config=self.config.usage)

         if self.config.use_schema_resolver:
-            schema_resolver, urns = self.graph.initialize_schema_resolver_from_datahub(
+            self.schema_resolver = self.graph.initialize_schema_resolver_from_datahub(
                 platform=self.config.platform,
                 platform_instance=self.config.platform_instance,
                 env=self.config.env,
             )
-            self.schema_resolver = schema_resolver
-            self.urns = urns
+            self.urns = self.schema_resolver.get_urns()
         else:
             self.schema_resolver = self.graph._make_schema_resolver(
                 platform=self.config.platform,
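SqlQueriesSource keeps its urns attribute but now derives it from the resolver via get_urns(), leaving a single source of truth for which datasets have schema information.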
@@ -283,6 +283,9 @@ class SchemaResolver(Closeable):
             shared_connection=shared_conn,
         )

+    def get_urns(self) -> Set[str]:
+        return set(self._schema_cache.keys())
+
     def get_urn_for_table(self, table: _TableName, lower: bool = False) -> str:
         # TODO: Validate that this is the correct 2/3 layer hierarchy for the platform.

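A small sketch of the new accessor on a locally built resolver. The module path and the add_raw_schema_info helper are assumptions based on this file's surrounding API; the diff itself only shows add_graphql_schema_metadata.

    from datahub.utilities.sqlglot_lineage import SchemaResolver

    resolver = SchemaResolver(platform="snowflake", platform_instance=None, env="PROD")
    resolver.add_raw_schema_info(  # assumed helper: registers a column -> type mapping
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.table,PROD)",
        {"id": "NUMBER", "name": "VARCHAR"},
    )
    assert resolver.get_urns() == {
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.table,PROD)"
    }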
@@ -397,8 +400,6 @@ class SchemaResolver(Closeable):
             )
         }

-    # TODO add a method to load all from graphql
-
     def close(self) -> None:
         self._schema_cache.close()
