feat(ingestion/redshift): collapse lineage to permanent table (#9704)

Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
Co-authored-by: treff7es <treff7es@gmail.com>
sid-acryl 2024-02-02 02:17:09 +05:30 committed by GitHub
parent eb97120469
commit 533130408a
11 changed files with 1515 additions and 59 deletions

View File

@ -1,4 +1,5 @@
"""Convenience functions for creating MCEs"""
import hashlib
import json
import logging

View File

@ -64,7 +64,7 @@ class PipelineContext:
# TODO: Get rid of this function once lower-casing is the standard.
if self.graph:
server_config = self.graph.get_config()
if server_config and server_config.get("datasetUrnNameCasing"):
if server_config and server_config.get("datasetUrnNameCasing") is True:
set_dataset_urn_to_lower(True)
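A minimal sketch of why the stricter `is True` check matters, assuming a hypothetical server config where the flag arrives as a non-boolean value:

server_config = {"datasetUrnNameCasing": "false"}  # hypothetical non-boolean value from the server
assert bool(server_config.get("datasetUrnNameCasing")) is True  # old truthy check would trigger lower-casing
assert (server_config.get("datasetUrnNameCasing") is True) is False  # new strict check does not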
def register_checkpointer(self, committable: Committable) -> None:

View File

@ -94,10 +94,10 @@ class RedshiftConfig(
description="The default schema to use if the sql parser fails to parse the schema with `sql_based` lineage collector",
)
include_table_lineage: Optional[bool] = Field(
include_table_lineage: bool = Field(
default=True, description="Whether table lineage should be ingested."
)
include_copy_lineage: Optional[bool] = Field(
include_copy_lineage: bool = Field(
default=True,
description="Whether lineage should be collected from copy commands",
)
@ -107,17 +107,15 @@ class RedshiftConfig(
description="Generate usage statistic. email_domain config parameter needs to be set if enabled",
)
include_unload_lineage: Optional[bool] = Field(
include_unload_lineage: bool = Field(
default=True,
description="Whether lineage should be collected from unload commands",
)
capture_lineage_query_parser_failures: Optional[bool] = Field(
hide_from_schema=True,
include_table_rename_lineage: bool = Field(
default=False,
description="Whether to capture lineage query parser errors with dataset properties for debugging",
description="Whether we should follow `alter table ... rename to` statements when computing lineage. ",
)
table_lineage_mode: Optional[LineageMode] = Field(
default=LineageMode.STL_SCAN_BASED,
description="Which table lineage collector mode to use. Available modes are: [stl_scan_based, sql_based, mixed]",
@ -139,6 +137,11 @@ class RedshiftConfig(
description="When enabled, emits lineage as incremental to existing lineage already in DataHub. When disabled, re-states lineage on each run. This config works with rest-sink only.",
)
resolve_temp_table_in_lineage: bool = Field(
default=False,
description="Whether to resolve temp table appear in lineage to upstream permanent tables.",
)
@root_validator(pre=True)
def check_email_is_set_on_usage(cls, values):
if values.get("include_usage_statistics"):
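A minimal sketch, not part of the diff, of enabling the new temp-table collapsing and rename handling on a RedshiftConfig (import path and connection values mirror the unit tests further below; host/database are placeholders):

from datahub.ingestion.source.redshift.config import RedshiftConfig

config = RedshiftConfig(
    host_port="localhost:5439",  # placeholder connection details
    database="dev",
    resolve_temp_table_in_lineage=True,  # collapse temp tables to their permanent upstream tables
    include_table_rename_lineage=True,  # follow `alter table ... rename to` statements
)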

View File

@ -4,11 +4,12 @@ from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union, cast
from urllib.parse import urlparse
import humanfriendly
import redshift_connector
import sqlglot
import datahub.emitter.mce_builder as builder
import datahub.utilities.sqlglot_lineage as sqlglot_l
@ -24,17 +25,24 @@ from datahub.ingestion.source.redshift.redshift_schema import (
RedshiftSchema,
RedshiftTable,
RedshiftView,
TempTableRow,
)
from datahub.ingestion.source.redshift.report import RedshiftReport
from datahub.ingestion.source.state.redundant_run_skip_handler import (
RedundantLineageRunSkipHandler,
)
from datahub.metadata._schema_classes import SchemaFieldDataTypeClass
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
FineGrainedLineage,
FineGrainedLineageDownstreamType,
FineGrainedLineageUpstreamType,
UpstreamLineage,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
OtherSchema,
SchemaField,
SchemaMetadata,
)
from datahub.metadata.schema_classes import (
DatasetLineageTypeClass,
UpstreamClass,
@ -111,6 +119,34 @@ class LineageItem:
self.cll = self.cll or None
def parse_alter_table_rename(default_schema: str, query: str) -> Tuple[str, str, str]:
"""
Parses an ALTER TABLE ... RENAME TO ... query and returns the schema, previous table name, and new table name.
"""
parsed_query = sqlglot.parse_one(query, dialect="redshift")
assert isinstance(parsed_query, sqlglot.exp.AlterTable)
prev_name = parsed_query.this.name
rename_clause = parsed_query.args["actions"][0]
assert isinstance(rename_clause, sqlglot.exp.RenameTable)
new_name = rename_clause.this.name
schema = parsed_query.this.db or default_schema
return schema, prev_name, new_name
def split_qualified_table_name(urn: str) -> Tuple[str, str, str]:
qualified_table_name = dataset_urn.DatasetUrn.create_from_string(
urn
).get_entity_id()[1]
# -3 because platform instance is optional and that can cause the split to have more than 3 elements
db, schema, table = qualified_table_name.split(".")[-3:]
return db, schema, table
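A small illustration of the "-3" slice above, using a hypothetical qualified name that includes a platform instance prefix:

qualified_table_name = "my_instance.dev.public.orders"  # hypothetical name with a platform instance prefix
db, schema, table = qualified_table_name.split(".")[-3:]
assert (db, schema, table) == ("dev", "public", "orders")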
class RedshiftLineageExtractor:
def __init__(
self,
@ -130,6 +166,95 @@ class RedshiftLineageExtractor:
self.report.lineage_end_time,
) = self.get_time_window()
self.temp_tables: Dict[str, TempTableRow] = {}
def _init_temp_table_schema(
self, database: str, temp_tables: List[TempTableRow]
) -> None:
if self.context.graph is None: # to silence lint
return
schema_resolver: sqlglot_l.SchemaResolver = (
self.context.graph._make_schema_resolver(
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
env=self.config.env,
)
)
dataset_vs_columns: Dict[str, List[SchemaField]] = {}
# prepare a mapping of dataset_urn to its list of schema fields
for table in temp_tables:
logger.debug(
f"Processing temp table: {table.create_command} with query text {table.query_text}"
)
result = sqlglot_l.create_lineage_sql_parsed_result(
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
env=self.config.env,
default_db=database,
default_schema=self.config.default_schema,
query=table.query_text,
graph=self.context.graph,
)
if (
result is None
or result.column_lineage is None
or result.query_type != sqlglot_l.QueryType.CREATE
or not result.out_tables
):
logger.debug(f"Unsupported temp table query found: {table.query_text}")
continue
table.parsed_result = result
if result.column_lineage[0].downstream.table:
table.urn = result.column_lineage[0].downstream.table
self.temp_tables[result.out_tables[0]] = table
for table in self.temp_tables.values():
if (
table.parsed_result is None
or table.parsed_result.column_lineage is None
):
continue
for column_lineage in table.parsed_result.column_lineage:
if column_lineage.downstream.table not in dataset_vs_columns:
dataset_vs_columns[cast(str, column_lineage.downstream.table)] = []
# Initialise the temp table urn, we later need this to merge CLL
dataset_vs_columns[cast(str, column_lineage.downstream.table)].append(
SchemaField(
fieldPath=column_lineage.downstream.column,
type=cast(
SchemaFieldDataTypeClass,
column_lineage.downstream.column_type,
),
nativeDataType=cast(
str, column_lineage.downstream.native_column_type
),
)
)
# Add the datasets and their respective fields to schema_resolver so that schema_resolver can later
# correctly generate the upstreams for temporary tables
for urn in dataset_vs_columns:
db, schema, table_name = split_qualified_table_name(urn)
schema_resolver.add_schema_metadata(
urn=urn,
schema_metadata=SchemaMetadata(
schemaName=table_name,
platform=builder.make_data_platform_urn(
LineageDatasetPlatform.REDSHIFT.value
),
version=0,
hash="",
platformSchema=OtherSchema(rawSchema=""),
fields=dataset_vs_columns[urn],
),
)
def get_time_window(self) -> Tuple[datetime, datetime]:
if self.redundant_run_skip_handler:
self.report.stateful_lineage_ingestion_enabled = True
@ -157,25 +282,32 @@ class RedshiftLineageExtractor:
return path
def _get_sources_from_query(
self, db_name: str, query: str
self,
db_name: str,
query: str,
parsed_result: Optional[sqlglot_l.SqlParsingResult] = None,
) -> Tuple[List[LineageDataset], Optional[List[sqlglot_l.ColumnLineageInfo]]]:
sources: List[LineageDataset] = list()
parsed_result: Optional[
sqlglot_l.SqlParsingResult
] = sqlglot_l.create_lineage_sql_parsed_result(
query=query,
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
default_db=db_name,
default_schema=str(self.config.default_schema),
graph=self.context.graph,
env=self.config.env,
)
if parsed_result is None:
parsed_result = sqlglot_l.create_lineage_sql_parsed_result(
query=query,
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
default_db=db_name,
default_schema=str(self.config.default_schema),
graph=self.context.graph,
env=self.config.env,
)
if parsed_result is None:
logger.debug(f"native query parsing failed for {query}")
return sources, None
elif parsed_result.debug_info.table_error:
logger.debug(
f"native query parsing failed for {query} with error: {parsed_result.debug_info.table_error}"
)
return sources, None
logger.debug(f"parsed_result = {parsed_result}")
@ -277,7 +409,7 @@ class RedshiftLineageExtractor:
database: str,
lineage_type: LineageCollectorType,
connection: redshift_connector.Connection,
all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
all_tables_set: Dict[str, Dict[str, Set[str]]],
) -> None:
"""
This method generates table-level lineage based on the given query.
@ -292,7 +424,10 @@ class RedshiftLineageExtractor:
return: The method does not return anything, as it directly modifies the self._lineage_map property.
:rtype: None
"""
logger.info(f"Extracting {lineage_type.name} lineage for db {database}")
try:
logger.debug(f"Processing lineage query: {query}")
cll: Optional[List[sqlglot_l.ColumnLineageInfo]] = None
raw_db_name = database
alias_db_name = self.config.database
@ -301,11 +436,18 @@ class RedshiftLineageExtractor:
conn=connection, query=query
):
target = self._get_target_lineage(
alias_db_name, lineage_row, lineage_type
alias_db_name,
lineage_row,
lineage_type,
all_tables_set=all_tables_set,
)
if not target:
continue
logger.debug(
f"Processing {lineage_type.name} lineage row: {lineage_row}"
)
sources, cll = self._get_sources(
lineage_type,
alias_db_name,
@ -318,9 +460,12 @@ class RedshiftLineageExtractor:
target.upstreams.update(
self._get_upstream_lineages(
sources=sources,
all_tables=all_tables,
target_table=target.dataset.urn,
target_dataset_cll=cll,
all_tables_set=all_tables_set,
alias_db_name=alias_db_name,
raw_db_name=raw_db_name,
connection=connection,
)
)
target.cll = cll
@ -344,21 +489,50 @@ class RedshiftLineageExtractor:
)
self.report_status(f"extract-{lineage_type.name}", False)
def _update_lineage_map_for_table_renames(
self, table_renames: Dict[str, str]
) -> None:
if not table_renames:
return
logger.info(f"Updating lineage map for {len(table_renames)} table renames")
for new_table_urn, prev_table_urn in table_renames.items():
# This table was renamed from some other name, copy in the lineage
# for the previous name as well.
prev_table_lineage = self._lineage_map.get(prev_table_urn)
if prev_table_lineage:
logger.debug(
f"including lineage for {prev_table_urn} in {new_table_urn} due to table rename"
)
self._lineage_map[new_table_urn].merge_lineage(
upstreams=prev_table_lineage.upstreams,
cll=prev_table_lineage.cll,
)
def _get_target_lineage(
self,
alias_db_name: str,
lineage_row: LineageRow,
lineage_type: LineageCollectorType,
all_tables_set: Dict[str, Dict[str, Set[str]]],
) -> Optional[LineageItem]:
if (
lineage_type != LineageCollectorType.UNLOAD
and lineage_row.target_schema
and lineage_row.target_table
):
if not self.config.schema_pattern.allowed(
lineage_row.target_schema
) or not self.config.table_pattern.allowed(
f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}"
if (
not self.config.schema_pattern.allowed(lineage_row.target_schema)
or not self.config.table_pattern.allowed(
f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}"
)
) and not (
# We also check the all_tables_set, since this might be a renamed table
# that we don't want to drop lineage for.
alias_db_name in all_tables_set
and lineage_row.target_schema in all_tables_set[alias_db_name]
and lineage_row.target_table
in all_tables_set[alias_db_name][lineage_row.target_schema]
):
return None
# Target
@ -400,18 +574,19 @@ class RedshiftLineageExtractor:
def _get_upstream_lineages(
self,
sources: List[LineageDataset],
all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
target_table: str,
all_tables_set: Dict[str, Dict[str, Set[str]]],
alias_db_name: str,
raw_db_name: str,
connection: redshift_connector.Connection,
target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]],
) -> List[LineageDataset]:
targe_source = []
target_source = []
probable_temp_tables: List[str] = []
for source in sources:
if source.platform == LineageDatasetPlatform.REDSHIFT:
qualified_table_name = dataset_urn.DatasetUrn.create_from_string(
source.urn
).get_entity_id()[1]
# -3 because platform instance is optional and that can cause the split to have more than 3 elements
db, schema, table = qualified_table_name.split(".")[-3:]
db, schema, table = split_qualified_table_name(source.urn)
if db == raw_db_name:
db = alias_db_name
path = f"{db}.{schema}.{table}"
@ -427,19 +602,40 @@ class RedshiftLineageExtractor:
# Filter out tables which do not exist in Redshift:
# they may have been deleted in the meantime, the query parser may not have captured the table name correctly,
# or they might be temp tables
if (
db not in all_tables
or schema not in all_tables[db]
or not any(table == t.name for t in all_tables[db][schema])
db not in all_tables_set
or schema not in all_tables_set[db]
or table not in all_tables_set[db][schema]
):
logger.debug(
f"{source.urn} missing table, dropping from lineage.",
f"{source.urn} missing table. Adding it to temp table list for target table {target_table}.",
)
probable_temp_tables.append(f"{schema}.{table}")
self.report.num_lineage_tables_dropped += 1
continue
targe_source.append(source)
return targe_source
target_source.append(source)
if probable_temp_tables and self.config.resolve_temp_table_in_lineage:
self.report.num_lineage_processed_temp_tables += len(probable_temp_tables)
# Generate lineage dataset from temporary tables
number_of_permanent_dataset_found: int = (
self.update_table_and_column_lineage(
db_name=raw_db_name,
connection=connection,
temp_table_names=probable_temp_tables,
target_source_dataset=target_source,
target_dataset_cll=target_dataset_cll,
)
)
logger.debug(
f"Number of permanent datasets found for {target_table} = {number_of_permanent_dataset_found} in "
f"temp tables {probable_temp_tables}"
)
return target_source
def populate_lineage(
self,
@ -447,8 +643,27 @@ class RedshiftLineageExtractor:
connection: redshift_connector.Connection,
all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
) -> None:
if self.config.resolve_temp_table_in_lineage:
self._init_temp_table_schema(
database=database,
temp_tables=self.get_temp_tables(connection=connection),
)
populate_calls: List[Tuple[str, LineageCollectorType]] = []
all_tables_set: Dict[str, Dict[str, Set[str]]] = {
db: {schema: {t.name for t in tables} for schema, tables in schemas.items()}
for db, schemas in all_tables.items()
}
table_renames: Dict[str, str] = {}
if self.config.include_table_rename_lineage:
table_renames, all_tables_set = self._process_table_renames(
database=database,
connection=connection,
all_tables=all_tables_set,
)
if self.config.table_lineage_mode in {
LineageMode.STL_SCAN_BASED,
LineageMode.MIXED,
@ -504,9 +719,12 @@ class RedshiftLineageExtractor:
database=database,
lineage_type=lineage_type,
connection=connection,
all_tables=all_tables,
all_tables_set=all_tables_set,
)
# Handling for alter table statements.
self._update_lineage_map_for_table_renames(table_renames=table_renames)
self.report.lineage_mem_size[self.config.database] = humanfriendly.format_size(
memory_footprint.total_size(self._lineage_map)
)
@ -613,3 +831,271 @@ class RedshiftLineageExtractor:
def report_status(self, step: str, status: bool) -> None:
if self.redundant_run_skip_handler:
self.redundant_run_skip_handler.report_current_run_status(step, status)
def _process_table_renames(
self,
database: str,
connection: redshift_connector.Connection,
all_tables: Dict[str, Dict[str, Set[str]]],
) -> Tuple[Dict[str, str], Dict[str, Dict[str, Set[str]]]]:
logger.info(f"Processing table renames for db {database}")
# new urn -> prev urn
table_renames: Dict[str, str] = {}
query = RedshiftQuery.alter_table_rename_query(
db_name=database,
start_time=self.start_time,
end_time=self.end_time,
)
for rename_row in RedshiftDataDictionary.get_alter_table_commands(
connection, query
):
schema, prev_name, new_name = parse_alter_table_rename(
default_schema=self.config.default_schema,
query=rename_row.query_text,
)
prev_urn = make_dataset_urn_with_platform_instance(
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
name=f"{database}.{schema}.{prev_name}",
env=self.config.env,
)
new_urn = make_dataset_urn_with_platform_instance(
platform=LineageDatasetPlatform.REDSHIFT.value,
platform_instance=self.config.platform_instance,
name=f"{database}.{schema}.{new_name}",
env=self.config.env,
)
table_renames[new_urn] = prev_urn
# We want to generate lineage for the previous name too.
all_tables[database][schema].add(prev_name)
logger.info(f"Discovered {len(table_renames)} table renames")
return table_renames, all_tables
def get_temp_tables(
self, connection: redshift_connector.Connection
) -> List[TempTableRow]:
ddl_query: str = RedshiftQuery.temp_table_ddl_query(
start_time=self.config.start_time,
end_time=self.config.end_time,
)
logger.debug(f"Temporary table ddl query = {ddl_query}")
temp_table_rows: List[TempTableRow] = []
for row in RedshiftDataDictionary.get_temporary_rows(
conn=connection,
query=ddl_query,
):
temp_table_rows.append(row)
return temp_table_rows
def find_temp_tables(
self, temp_table_rows: List[TempTableRow], temp_table_names: List[str]
) -> List[TempTableRow]:
matched_temp_tables: List[TempTableRow] = []
for table_name in temp_table_names:
prefixes = RedshiftQuery.get_temp_table_clause(table_name)
prefixes.extend(
RedshiftQuery.get_temp_table_clause(table_name.split(".")[-1])
)
for row in temp_table_rows:
if any(
row.create_command.lower().startswith(prefix) for prefix in prefixes
):
matched_temp_tables.append(row)
return matched_temp_tables
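A minimal sketch of the prefix matching above, using the clause prefixes that `RedshiftQuery.get_temp_table_clause` produces for a hypothetical temp table name:

create_command = "CREATE TABLE #player_price"  # as captured from SVL_STATEMENTTEXT
prefixes = [
    "create table #player_price",
    "create temp table #player_price",
    "create temporary table #player_price",
]
assert any(create_command.lower().startswith(prefix) for prefix in prefixes)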
def resolve_column_refs(
self, column_refs: List[sqlglot_l.ColumnRef], depth: int = 0
) -> List[sqlglot_l.ColumnRef]:
"""
This method resolves column references to the original column references.
For example, if a column reference points to a temporary table, it is resolved to the corresponding column
of the original permanent table.
"""
max_depth = 10
resolved_column_refs: List[sqlglot_l.ColumnRef] = []
if not column_refs:
return column_refs
if depth >= max_depth:
logger.warning(
f"Max depth reached for resolving temporary columns: {column_refs}"
)
self.report.num_unresolved_temp_columns += 1
return column_refs
for ref in column_refs:
resolved = False
if ref.table in self.temp_tables:
table = self.temp_tables[ref.table]
if table.parsed_result and table.parsed_result.column_lineage:
for column_lineage in table.parsed_result.column_lineage:
if (
column_lineage.downstream.table == ref.table
and column_lineage.downstream.column == ref.column
):
resolved_column_refs.extend(
self.resolve_column_refs(
column_lineage.upstreams, depth=depth + 1
)
)
resolved = True
break
# If we reach here, it means that we were not able to resolve the column reference.
if resolved is False:
logger.warning(
f"Unable to resolve column reference {ref} to a permanent table"
)
else:
logger.debug(
f"Resolved column reference {ref} is not resolved because referenced table {ref.table} is not a temp table or not found. Adding reference as non-temp table. This is normal."
)
resolved_column_refs.append(ref)
return resolved_column_refs
def _update_target_dataset_cll(
self,
temp_table_urn: str,
target_dataset_cll: List[sqlglot_l.ColumnLineageInfo],
source_dataset_cll: List[sqlglot_l.ColumnLineageInfo],
) -> None:
for target_column_lineage in target_dataset_cll:
upstreams: List[sqlglot_l.ColumnRef] = []
# Look for temp_table_urn in the upstreams of column_lineage; if found, we need to replace it with
# the column of the permanent table
for target_column_ref in target_column_lineage.upstreams:
if target_column_ref.table == temp_table_urn:
# Look for column_ref.table and column_ref.column in downstream of source_dataset_cll.
# The source_dataset_cll contains CLL generated from create statement of temp table (temp_table_urn)
for source_column_lineage in source_dataset_cll:
if (
source_column_lineage.downstream.table
== target_column_ref.table
and source_column_lineage.downstream.column
== target_column_ref.column
):
resolved_columns = self.resolve_column_refs(
source_column_lineage.upstreams
)
# Add all upstream of above temporary column into upstream of target column
upstreams.extend(resolved_columns)
continue
upstreams.append(target_column_ref)
if upstreams:
# update the upstreams
target_column_lineage.upstreams = upstreams
def _add_permanent_datasets_recursively(
self,
db_name: str,
temp_table_rows: List[TempTableRow],
visited_tables: Set[str],
connection: redshift_connector.Connection,
permanent_lineage_datasets: List[LineageDataset],
target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]],
) -> None:
transitive_temp_tables: List[TempTableRow] = []
for temp_table in temp_table_rows:
logger.debug(
f"Processing temp table with transaction id: {temp_table.transaction_id} and query text {temp_table.query_text}"
)
intermediate_l_datasets, cll = self._get_sources_from_query(
db_name=db_name,
query=temp_table.query_text,
parsed_result=temp_table.parsed_result,
)
if (
temp_table.urn is not None
and target_dataset_cll is not None
and cll is not None
): # condition to silence the lint
self._update_target_dataset_cll(
temp_table_urn=temp_table.urn,
target_dataset_cll=target_dataset_cll,
source_dataset_cll=cll,
)
# Make sure the lineage datasets do not contain a temp table.
# If such a dataset is present, add it to transitive_temp_tables so it can be resolved to the original permanent table.
for lineage_dataset in intermediate_l_datasets:
db, schema, table = split_qualified_table_name(lineage_dataset.urn)
if table in visited_tables:
# The table is already processed
continue
# Check if the table we found is itself a temp table
repeated_temp_table: List[TempTableRow] = self.find_temp_tables(
temp_table_rows=list(self.temp_tables.values()),
temp_table_names=[table],
)
if not repeated_temp_table:
logger.debug(f"Unable to find table {table} in temp tables.")
if repeated_temp_table:
transitive_temp_tables.extend(repeated_temp_table)
visited_tables.add(table)
continue
permanent_lineage_datasets.append(lineage_dataset)
if transitive_temp_tables:
# recursive call
self._add_permanent_datasets_recursively(
db_name=db_name,
temp_table_rows=transitive_temp_tables,
visited_tables=visited_tables,
connection=connection,
permanent_lineage_datasets=permanent_lineage_datasets,
target_dataset_cll=target_dataset_cll,
)
def update_table_and_column_lineage(
self,
db_name: str,
temp_table_names: List[str],
connection: redshift_connector.Connection,
target_source_dataset: List[LineageDataset],
target_dataset_cll: Optional[List[sqlglot_l.ColumnLineageInfo]],
) -> int:
permanent_lineage_datasets: List[LineageDataset] = []
temp_table_rows: List[TempTableRow] = self.find_temp_tables(
temp_table_rows=list(self.temp_tables.values()),
temp_table_names=temp_table_names,
)
visited_tables: Set[str] = set(temp_table_names)
self._add_permanent_datasets_recursively(
db_name=db_name,
temp_table_rows=temp_table_rows,
visited_tables=visited_tables,
connection=connection,
permanent_lineage_datasets=permanent_lineage_datasets,
target_dataset_cll=target_dataset_cll,
)
target_source_dataset.extend(permanent_lineage_datasets)
return len(permanent_lineage_datasets)

View File

@ -1,9 +1,14 @@
from datetime import datetime
from typing import List
redshift_datetime_format = "%Y-%m-%d %H:%M:%S"
class RedshiftQuery:
CREATE_TEMP_TABLE_CLAUSE = "create temp table"
CREATE_TEMPORARY_TABLE_CLAUSE = "create temporary table"
CREATE_TABLE_CLAUSE = "create table"
list_databases: str = """SELECT datname FROM pg_database
WHERE (datname <> ('padb_harvest')::name)
AND (datname <> ('template0')::name)
@ -97,7 +102,7 @@ SELECT schemaname as schema_name,
NULL as table_description
FROM pg_catalog.svv_external_tables
ORDER BY "schema",
"relname";
"relname"
"""
list_columns: str = """
SELECT
@ -379,7 +384,8 @@ SELECT schemaname as schema_name,
target_schema,
target_table,
username,
querytxt as ddl
query as query_id,
LISTAGG(CASE WHEN LEN(RTRIM(querytxt)) = 0 THEN querytxt ELSE RTRIM(querytxt) END) WITHIN GROUP (ORDER BY sequence) as ddl
from
(
select
@ -388,7 +394,9 @@ SELECT schemaname as schema_name,
sti.table as target_table,
sti.database as cluster,
usename as username,
querytxt,
text as querytxt,
sq.query,
sequence,
si.starttime as starttime
from
stl_insert as si
@ -396,19 +404,20 @@ SELECT schemaname as schema_name,
sti.table_id = tbl
left join svl_user_info sui on
si.userid = sui.usesysid
left join stl_query sq on
left join STL_QUERYTEXT sq on
si.query = sq.query
left join stl_load_commits slc on
slc.query = si.query
where
sui.usename <> 'rdsdb'
and sq.aborted = 0
and slc.query IS NULL
and cluster = '{db_name}'
and si.starttime >= '{start_time}'
and si.starttime < '{end_time}'
and sequence < 320
) as target_tables
order by cluster, target_schema, target_table, starttime asc
group by cluster, query_id, target_schema, target_table, username, starttime
order by cluster, query_id, target_schema, target_table, starttime asc
""".format(
# We need the original database name for filtering
db_name=db_name,
@ -443,3 +452,118 @@ SELECT schemaname as schema_name,
start_time=start_time.strftime(redshift_datetime_format),
end_time=end_time.strftime(redshift_datetime_format),
)
@staticmethod
def get_temp_table_clause(table_name: str) -> List[str]:
return [
f"{RedshiftQuery.CREATE_TABLE_CLAUSE} {table_name}",
f"{RedshiftQuery.CREATE_TEMP_TABLE_CLAUSE} {table_name}",
f"{RedshiftQuery.CREATE_TEMPORARY_TABLE_CLAUSE} {table_name}",
]
@staticmethod
def temp_table_ddl_query(start_time: datetime, end_time: datetime) -> str:
start_time_str: str = start_time.strftime(redshift_datetime_format)
end_time_str: str = end_time.strftime(redshift_datetime_format)
return rf"""-- DataHub Redshift Source temp table DDL query
select
*
from
(
select
session_id,
transaction_id,
start_time,
userid,
REGEXP_REPLACE(REGEXP_SUBSTR(REGEXP_REPLACE(query_text,'\\\\n','\\n'), '(CREATE(?:[\\n\\s\\t]+(?:temp|temporary))?(?:[\\n\\s\\t]+)table(?:[\\n\\s\\t]+)[^\\n\\s\\t()-]+)', 0, 1, 'ipe'),'[\\n\\s\\t]+',' ',1,'p') as create_command,
query_text,
row_number() over (
partition by TRIM(query_text)
order by start_time desc
) rn
from
(
select
pid as session_id,
xid as transaction_id,
starttime as start_time,
type,
query_text,
userid
from
(
select
starttime,
pid,
xid,
type,
userid,
LISTAGG(case
when LEN(RTRIM(text)) = 0 then text
else RTRIM(text)
end,
'') within group (
order by sequence
) as query_text
from
SVL_STATEMENTTEXT
where
type in ('DDL', 'QUERY')
AND starttime >= '{start_time_str}'
AND starttime < '{end_time_str}'
-- See https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl-statementtext
AND sequence < 320
group by
starttime,
pid,
xid,
type,
userid
order by
starttime,
pid,
xid,
type,
userid
asc)
where
type in ('DDL', 'QUERY')
)
where
(create_command ilike 'create temp table %'
or create_command ilike 'create temporary table %'
-- we want to get all the create table statements and not just temp tables if non temp table is created and dropped in the same transaction
or create_command ilike 'create table %')
-- Redshift creates temp tables with the following names: volt_tt_%. We need to filter them out.
and query_text not ilike 'CREATE TEMP TABLE volt_tt_%'
and create_command not like 'CREATE TEMP TABLE volt_tt_'
-- We need to filter out our query and it was not possible earlier when we did not have any comment in the query
and query_text not ilike '%https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl-statementtext%'
)
where
rn = 1;
"""
@staticmethod
def alter_table_rename_query(
db_name: str, start_time: datetime, end_time: datetime
) -> str:
start_time_str: str = start_time.strftime(redshift_datetime_format)
end_time_str: str = end_time.strftime(redshift_datetime_format)
return f"""
SELECT transaction_id,
session_id,
start_time,
query_text
FROM sys_query_history SYS
WHERE SYS.status = 'success'
AND SYS.query_type = 'DDL'
AND SYS.database_name = '{db_name}'
AND SYS.start_time >= '{start_time_str}'
AND SYS.end_time < '{end_time_str}'
AND SYS.query_text ILIKE 'alter table % rename to %'
"""

View File

@ -9,6 +9,7 @@ from datahub.ingestion.source.redshift.query import RedshiftQuery
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
from datahub.utilities.hive_schema_to_avro import get_schema_fields_for_hive_column
from datahub.utilities.sqlglot_lineage import SqlParsingResult
logger: logging.Logger = logging.getLogger(__name__)
@ -80,6 +81,26 @@ class LineageRow:
filename: Optional[str]
@dataclass
class TempTableRow:
transaction_id: int
session_id: str
query_text: str
create_command: str
start_time: datetime
urn: Optional[str]
parsed_result: Optional[SqlParsingResult] = None
@dataclass
class AlterTableRow:
# TODO unify this type with TempTableRow
transaction_id: int
session_id: str
query_text: str
start_time: datetime
# This class is a proxy for querying Redshift
class RedshiftDataDictionary:
@staticmethod
@ -359,9 +380,62 @@ class RedshiftDataDictionary:
target_table=row[field_names.index("target_table")]
if "target_table" in field_names
else None,
ddl=row[field_names.index("ddl")] if "ddl" in field_names else None,
# See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html
# for why we need to replace the \\n.
ddl=row[field_names.index("ddl")].replace("\\n", "\n")
if "ddl" in field_names
else None,
filename=row[field_names.index("filename")]
if "filename" in field_names
else None,
)
rows = cursor.fetchmany()
@staticmethod
def get_temporary_rows(
conn: redshift_connector.Connection,
query: str,
) -> Iterable[TempTableRow]:
cursor = conn.cursor()
cursor.execute(query)
field_names = [i[0] for i in cursor.description]
rows = cursor.fetchmany()
while rows:
for row in rows:
yield TempTableRow(
transaction_id=row[field_names.index("transaction_id")],
session_id=row[field_names.index("session_id")],
# See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html
# for why we need to replace the \n with a newline.
query_text=row[field_names.index("query_text")].replace(
r"\n", "\n"
),
create_command=row[field_names.index("create_command")],
start_time=row[field_names.index("start_time")],
urn=None,
)
rows = cursor.fetchmany()
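A short sketch of the escaped-newline handling referenced in the comments above: the system tables return newlines as a literal backslash followed by "n", so the replace call turns them back into real newlines (the query text below is hypothetical):

raw_query_text = "CREATE TABLE #t AS\\nSELECT 1"  # literal backslash + n, as returned by Redshift
assert raw_query_text.replace("\\n", "\n") == "CREATE TABLE #t AS\nSELECT 1"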
@staticmethod
def get_alter_table_commands(
conn: redshift_connector.Connection,
query: str,
) -> Iterable[AlterTableRow]:
# TODO: unify this with get_temporary_rows
cursor = RedshiftDataDictionary.get_query_result(conn, query)
field_names = [i[0] for i in cursor.description]
rows = cursor.fetchmany()
while rows:
for row in rows:
yield AlterTableRow(
transaction_id=row[field_names.index("transaction_id")],
session_id=row[field_names.index("session_id")],
query_text=row[field_names.index("query_text")],
start_time=row[field_names.index("start_time")],
)
rows = cursor.fetchmany()

View File

@ -35,6 +35,7 @@ class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowRep
num_lineage_tables_dropped: int = 0
num_lineage_dropped_query_parser: int = 0
num_lineage_dropped_not_support_copy_path: int = 0
num_lineage_processed_temp_tables: int = 0
lineage_start_time: Optional[datetime] = None
lineage_end_time: Optional[datetime] = None
@ -43,6 +44,7 @@ class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowRep
usage_start_time: Optional[datetime] = None
usage_end_time: Optional[datetime] = None
stateful_usage_ingestion_enabled: bool = False
num_unresolved_temp_columns: int = 0
def report_dropped(self, key: str) -> None:
self.filtered.append(key)

View File

@ -140,7 +140,9 @@ class SnowflakeV2Config(
# This is required since access_history table does not capture whether the table was temporary table.
temporary_tables_pattern: List[str] = Field(
default=DEFAULT_TABLES_DENY_LIST,
description="[Advanced] Regex patterns for temporary tables to filter in lineage ingestion. Specify regex to match the entire table name in database.schema.table format. Defaults are to set in such a way to ignore the temporary staging tables created by known ETL tools.",
description="[Advanced] Regex patterns for temporary tables to filter in lineage ingestion. Specify regex to "
"match the entire table name in database.schema.table format. Defaults are to set in such a way "
"to ignore the temporary staging tables created by known ETL tools.",
)
rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field(
@ -150,13 +152,16 @@ class SnowflakeV2Config(
shares: Optional[Dict[str, SnowflakeShareConfig]] = Field(
default=None,
description="Required if current account owns or consumes snowflake share."
" If specified, connector creates lineage and siblings relationship between current account's database tables and consumer/producer account's database tables."
"If specified, connector creates lineage and siblings relationship between current account's database tables "
"and consumer/producer account's database tables."
" Map of share name -> details of share.",
)
email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is provided, generates email addresses for snowflake users with unset emails, based on their username.",
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)
@validator("convert_urns_to_lowercase")

View File

@ -1037,6 +1037,14 @@ def _sqlglot_lineage_inner(
default_db = default_db.upper()
if default_schema:
default_schema = default_schema.upper()
if _is_dialect_instance(dialect, "redshift") and not default_schema:
# On Redshift, there's no "USE SCHEMA <schema>" command. The default schema
# is public, and "current schema" is the one at the front of the search path.
# See https://docs.aws.amazon.com/redshift/latest/dg/r_search_path.html
# and https://stackoverflow.com/questions/9067335/how-does-the-search-path-influence-identifier-resolution-and-the-current-schema?noredirect=1&lq=1
# default_schema = "public"
# TODO: Re-enable this.
pass
logger.debug("Parsing lineage from sql statement: %s", sql)
statement = _parse_statement(sql, dialect=dialect)

View File

@ -0,0 +1,104 @@
from datetime import datetime
from unittest.mock import MagicMock
def mock_temp_table_cursor(cursor: MagicMock) -> None:
cursor.description = [
["transaction_id"],
["session_id"],
["query_text"],
["create_command"],
["start_time"],
]
cursor.fetchmany.side_effect = [
[
(
126,
"abc",
"CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price) AS "
"price_usd from player_activity group by player_id",
"CREATE TABLE #player_price",
datetime.now(),
)
],
[
# Empty result to stop the while loop
],
]
def mock_stl_insert_table_cursor(cursor: MagicMock) -> None:
cursor.description = [
["source_schema"],
["source_table"],
["target_schema"],
["target_table"],
["ddl"],
]
cursor.fetchmany.side_effect = [
[
(
"public",
"#player_price",
"public",
"player_price_with_hike_v6",
"INSERT INTO player_price_with_hike_v6 SELECT (price_usd + 0.2 * price_usd) as price, '20%' FROM "
"#player_price",
)
],
[
# Empty result to stop the while loop
],
]
query_vs_cursor_mocker = {
(
"-- DataHub Redshift Source temp table DDL query\n select\n *\n "
"from\n (\n select\n session_id,\n "
" transaction_id,\n start_time,\n userid,\n "
" REGEXP_REPLACE(REGEXP_SUBSTR(REGEXP_REPLACE(query_text,'\\\\\\\\n','\\\\n'), '(CREATE(?:["
"\\\\n\\\\s\\\\t]+(?:temp|temporary))?(?:[\\\\n\\\\s\\\\t]+)table(?:[\\\\n\\\\s\\\\t]+)["
"^\\\\n\\\\s\\\\t()-]+)', 0, 1, 'ipe'),'[\\\\n\\\\s\\\\t]+',' ',1,'p') as create_command,\n "
" query_text,\n row_number() over (\n partition "
"by TRIM(query_text)\n order by start_time desc\n ) rn\n "
" from\n (\n select\n pid "
"as session_id,\n xid as transaction_id,\n starttime "
"as start_time,\n type,\n query_text,\n "
" userid\n from\n (\n "
"select\n starttime,\n pid,\n "
" xid,\n type,\n userid,\n "
" LISTAGG(case\n when LEN(RTRIM(text)) = 0 then text\n "
" else RTRIM(text)\n end,\n "
" '') within group (\n order by sequence\n "
" ) as query_text\n from\n "
"SVL_STATEMENTTEXT\n where\n type in ('DDL', "
"'QUERY')\n AND starttime >= '2024-01-01 12:00:00'\n "
" AND starttime < '2024-01-10 12:00:00'\n -- See "
"https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl"
"-statementtext\n AND sequence < 320\n group by\n "
" starttime,\n pid,\n "
"xid,\n type,\n userid\n "
" order by\n starttime,\n pid,\n "
" xid,\n type,\n userid\n "
" asc)\n where\n type in ('DDL', "
"'QUERY')\n )\n where\n (create_command ilike "
"'create temp table %'\n or create_command ilike 'create temporary table %'\n "
" -- we want to get all the create table statements and not just temp tables "
"if non temp table is created and dropped in the same transaction\n or "
"create_command ilike 'create table %')\n -- Redshift creates temp tables with "
"the following names: volt_tt_%. We need to filter them out.\n and query_text not "
"ilike 'CREATE TEMP TABLE volt_tt_%'\n and create_command not like 'CREATE TEMP "
"TABLE volt_tt_'\n -- We need to filter out our query and it was not possible "
"earlier when we did not have any comment in the query\n and query_text not ilike "
"'%https://stackoverflow.com/questions/72770890/redshift-result-size-exceeds-listagg-limit-on-svl"
"-statementtext%'\n\n )\n where\n rn = 1;\n "
): mock_temp_table_cursor,
"select * from test_collapse_temp_lineage": mock_stl_insert_table_cursor,
}
def mock_cursor(cursor: MagicMock, query: str) -> None:
query_vs_cursor_mocker[query](cursor=cursor)

View File

@ -1,8 +1,31 @@
from datetime import datetime
from functools import partial
from typing import List
from unittest.mock import MagicMock
import datahub.utilities.sqlglot_lineage as sqlglot_l
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.redshift.config import RedshiftConfig
from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor
from datahub.ingestion.source.redshift.lineage import (
LineageCollectorType,
LineageDataset,
LineageDatasetPlatform,
LineageItem,
RedshiftLineageExtractor,
parse_alter_table_rename,
)
from datahub.ingestion.source.redshift.redshift_schema import TempTableRow
from datahub.ingestion.source.redshift.report import RedshiftReport
from datahub.utilities.sqlglot_lineage import ColumnLineageInfo, DownstreamColumnRef
from datahub.metadata._schema_classes import NumberTypeClass, SchemaFieldDataTypeClass
from datahub.utilities.sqlglot_lineage import (
ColumnLineageInfo,
DownstreamColumnRef,
QueryType,
SqlParsingDebugInfo,
SqlParsingResult,
)
from tests.unit.redshift_query_mocker import mock_cursor
def test_get_sources_from_query():
@ -120,16 +143,45 @@ def test_get_sources_from_query_with_only_table():
)
def test_cll():
config = RedshiftConfig(host_port="localhost:5439", database="test")
def test_parse_alter_table_rename():
assert parse_alter_table_rename("public", "alter table foo rename to bar") == (
"public",
"foo",
"bar",
)
assert parse_alter_table_rename(
"public", "alter table second_schema.storage_v2_stg rename to storage_v2; "
) == (
"second_schema",
"storage_v2_stg",
"storage_v2",
)
def get_lineage_extractor() -> RedshiftLineageExtractor:
config = RedshiftConfig(
host_port="localhost:5439",
database="test",
resolve_temp_table_in_lineage=True,
start_time=datetime(2024, 1, 1, 12, 0, 0).isoformat() + "Z",
end_time=datetime(2024, 1, 10, 12, 0, 0).isoformat() + "Z",
)
report = RedshiftReport()
lineage_extractor = RedshiftLineageExtractor(
config, report, PipelineContext(run_id="foo", graph=mock_graph())
)
return lineage_extractor
def test_cll():
test_query = """
select a,b,c from db.public.customer inner join db.public.order on db.public.customer.id = db.public.order.customer_id
"""
lineage_extractor = RedshiftLineageExtractor(
config, report, PipelineContext(run_id="foo")
)
lineage_extractor = get_lineage_extractor()
_, cll = lineage_extractor._get_sources_from_query(db_name="db", query=test_query)
assert cll == [
@ -149,3 +201,600 @@ def test_cll():
logic=None,
),
]
def cursor_execute_side_effect(cursor: MagicMock, query: str) -> None:
mock_cursor(cursor=cursor, query=query)
def mock_redshift_connection() -> MagicMock:
connection = MagicMock()
cursor = MagicMock()
connection.cursor.return_value = cursor
cursor.execute.side_effect = partial(cursor_execute_side_effect, cursor)
return connection
def mock_graph() -> DataHubGraph:
graph = MagicMock()
graph._make_schema_resolver.return_value = sqlglot_l.SchemaResolver(
platform="redshift",
env="PROD",
platform_instance=None,
graph=None,
)
return graph
def test_collapse_temp_lineage():
lineage_extractor = get_lineage_extractor()
connection: MagicMock = mock_redshift_connection()
lineage_extractor._init_temp_table_schema(
database=lineage_extractor.config.database,
temp_tables=lineage_extractor.get_temp_tables(connection=connection),
)
lineage_extractor._populate_lineage_map(
query="select * from test_collapse_temp_lineage",
database=lineage_extractor.config.database,
all_tables_set={
lineage_extractor.config.database: {"public": {"player_price_with_hike_v6"}}
},
connection=connection,
lineage_type=LineageCollectorType.QUERY_SQL_PARSER,
)
print(lineage_extractor._lineage_map)
target_urn: str = "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.player_price_with_hike_v6,PROD)"
assert lineage_extractor._lineage_map.get(target_urn) is not None
lineage_item: LineageItem = lineage_extractor._lineage_map[target_urn]
assert list(lineage_item.upstreams)[0].urn == (
"urn:li:dataset:(urn:li:dataPlatform:redshift,"
"test.public.player_activity,PROD)"
)
assert lineage_item.cll is not None
assert lineage_item.cll[0].downstream.table == (
"urn:li:dataset:(urn:li:dataPlatform:redshift,"
"test.public.player_price_with_hike_v6,PROD)"
)
assert lineage_item.cll[0].downstream.column == "price"
assert lineage_item.cll[0].upstreams[0].table == (
"urn:li:dataset:(urn:li:dataPlatform:redshift,"
"test.public.player_activity,PROD)"
)
assert lineage_item.cll[0].upstreams[0].column == "price"
def test_collapse_temp_recursive_cll_lineage():
lineage_extractor = get_lineage_extractor()
temp_table: TempTableRow = TempTableRow(
transaction_id=126,
query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd "
"from #player_activity_temp group by player_id",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_price",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price_usd",
)
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
temp_table_activity: TempTableRow = TempTableRow(
transaction_id=127,
query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd "
"from player_activity",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_activity_temp",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price_usd",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)",
column="price",
)
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
)
assert temp_table.urn
assert temp_table_activity.urn
lineage_extractor.temp_tables[temp_table.urn] = temp_table
lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity
target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
column="price",
column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
native_column_type="DOUBLE PRECISION",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
)
],
logic=None,
)
]
datasets = lineage_extractor._get_upstream_lineages(
sources=[
LineageDataset(
platform=LineageDatasetPlatform.REDSHIFT,
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
],
target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
raw_db_name="dev",
alias_db_name="dev",
all_tables_set={
"dev": {
"public": set(),
}
},
connection=MagicMock(),
target_dataset_cll=target_dataset_cll,
)
assert len(datasets) == 1
assert (
datasets[0].urn
== "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
)
assert target_dataset_cll[0].upstreams[0].table == (
"urn:li:dataset:(urn:li:dataPlatform:redshift,"
"dev.public.player_activity,PROD)"
)
assert target_dataset_cll[0].upstreams[0].column == "price"
def test_collapse_temp_recursive_with_compex_column_cll_lineage():
lineage_extractor = get_lineage_extractor()
temp_table: TempTableRow = TempTableRow(
transaction_id=126,
query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price+tax) AS price_usd "
"from #player_activity_temp group by player_id",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_price",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price",
),
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="tax",
),
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
temp_table_activity: TempTableRow = TempTableRow(
transaction_id=127,
query_text="CREATE TABLE #player_activity_temp SELECT player_id, price, tax "
"from player_activity",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_activity_temp",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)",
column="price",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="tax",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)",
column="tax",
)
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
)
assert temp_table.urn
assert temp_table_activity.urn
lineage_extractor.temp_tables[temp_table.urn] = temp_table
lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity
target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
column="price",
column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
native_column_type="DOUBLE PRECISION",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="player_id",
)
],
logic=None,
),
]
datasets = lineage_extractor._get_upstream_lineages(
sources=[
LineageDataset(
platform=LineageDatasetPlatform.REDSHIFT,
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
],
target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
raw_db_name="dev",
alias_db_name="dev",
all_tables_set={
"dev": {
"public": set(),
}
},
connection=MagicMock(),
target_dataset_cll=target_dataset_cll,
)
assert len(datasets) == 1
assert (
datasets[0].urn
== "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
)
assert target_dataset_cll[0].upstreams[0].table == (
"urn:li:dataset:(urn:li:dataPlatform:redshift,"
"dev.public.player_activity,PROD)"
)
assert target_dataset_cll[0].upstreams[0].column == "price"
assert target_dataset_cll[0].upstreams[1].column == "tax"
assert target_dataset_cll[1].upstreams[0].column == "player_id"
def test_collapse_temp_recursive_cll_lineage_with_circular_reference():
lineage_extractor = get_lineage_extractor()
temp_table: TempTableRow = TempTableRow(
transaction_id=126,
query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd "
"from #player_activity_temp group by player_id",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_price",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price_usd",
)
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
temp_table_activity: TempTableRow = TempTableRow(
transaction_id=127,
query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd "
"from #player_price",
start_time=datetime.now(),
session_id="abc",
create_command="CREATE TABLE #player_activity_temp",
parsed_result=SqlParsingResult(
query_type=QueryType.CREATE,
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
],
debug_info=SqlParsingDebugInfo(),
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="INTEGER",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="player_id",
)
],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price_usd",
column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
native_column_type="BIGINT",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
column="price_usd",
)
],
logic=None,
),
],
),
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)",
)
assert temp_table.urn
assert temp_table_activity.urn
lineage_extractor.temp_tables[temp_table.urn] = temp_table
lineage_extractor.temp_tables[temp_table_activity.urn] = temp_table_activity
target_dataset_cll: List[sqlglot_l.ColumnLineageInfo] = [
ColumnLineageInfo(
downstream=DownstreamColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
column="price",
column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
native_column_type="DOUBLE PRECISION",
),
upstreams=[
sqlglot_l.ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
column="price_usd",
)
],
logic=None,
)
]
datasets = lineage_extractor._get_upstream_lineages(
sources=[
LineageDataset(
platform=LineageDatasetPlatform.REDSHIFT,
urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)",
)
],
target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
raw_db_name="dev",
alias_db_name="dev",
all_tables_set={
"dev": {
"public": set(),
}
},
connection=MagicMock(),
target_dataset_cll=target_dataset_cll,
)
assert len(datasets) == 1
# Here we are only interested in whether it fails or not