feat(ingest): schema-aware SQL parsing for column-level lineage (#8334)

Harshal Sheth 2023-07-07 16:24:35 -07:00 committed by GitHub
parent 1f84bf5b2b
commit 3e47b3d228
27 changed files with 2334 additions and 10 deletions

View File

@ -133,6 +133,12 @@ sqllineage_lib = {
"sqlparse==0.4.3",
}
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==16.7.6.dev6",
}
aws_common = {
# AWS Python SDK
"boto3",
@ -269,6 +275,8 @@ plugins: Dict[str, Set[str]] = {
"gql[requests]>=3.3.0",
},
"great-expectations": sql_common | sqllineage_lib,
# Misc plugins.
"sql-parser": sqlglot_lib,
# Source plugins
# PyAthena is pinned with exact version because we use private method in PyAthena
"athena": sql_common | {"PyAthena[SQLAlchemy]==2.4.1"},
@ -276,7 +284,9 @@ plugins: Dict[str, Set[str]] = {
"bigquery": sql_common
| bigquery_common
| {
# TODO: I doubt we need all three sql parsing libraries.
*sqllineage_lib,
*sqlglot_lib,
"sql_metadata",
"sqlalchemy-bigquery>=1.4.1",
"google-cloud-datacatalog-lineage==0.2.2",
@ -285,6 +295,7 @@ plugins: Dict[str, Set[str]] = {
| bigquery_common
| {
*sqllineage_lib,
*sqlglot_lib,
"sql_metadata",
"sqlalchemy-bigquery>=1.4.1",
}, # deprecated, but keeping the extra for backwards compatibility
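Since the new sql-parser extra only wires up the acryl-sqlglot dependency, a quick sanity check after running pip install 'acryl-datahub[sql-parser]' (following the pattern the plugins CLI prints) is simply importing the parser module. This is a minimal sketch, not part of the diff:

# Assumes the sql-parser extra was installed, e.g. pip install 'acryl-datahub[sql-parser]'.
from datahub.utilities.sqlglot_lineage import sqlglot_lineage  # noqa: F401

print("sqlglot-based SQL parser is available")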

View File

@ -1,14 +1,21 @@
import logging
import shutil
import tempfile
from typing import Optional
import click
from datahub import __package_name__
from datahub.cli.json_file import check_mce_file
from datahub.emitter.mce_builder import DEFAULT_ENV
from datahub.ingestion.graph.client import get_default_graph
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.sink.sink_registry import sink_registry
from datahub.ingestion.source.source_registry import source_registry
from datahub.ingestion.transformer.transform_registry import transform_registry
from datahub.telemetry import telemetry
logger = logging.getLogger(__name__)
@click.group()
@ -28,6 +35,7 @@ def check() -> None:
@click.option(
"--unpack-mces", default=False, is_flag=True, help="Converts MCEs into MCPs"
)
@telemetry.with_telemetry()
def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None:
"""Check the schema of a metadata (MCE or MCP) JSON file."""
@ -70,6 +78,7 @@ def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None:
default=False,
help="Include extra information for each plugin.",
)
@telemetry.with_telemetry()
def plugins(verbose: bool) -> None:
"""List the enabled ingestion plugins."""
@ -87,3 +96,68 @@ def plugins(verbose: bool) -> None:
click.echo(
f"If a plugin is disabled, try running: pip install '{__package_name__}[<plugin>]'"
)
@check.command()
@click.option(
"--sql",
type=str,
required=True,
help="The SQL query to parse",
)
@click.option(
"--platform",
type=str,
required=True,
help="The SQL dialect e.g. bigquery or snowflake",
)
@click.option(
"--platform-instance",
type=str,
help="The specific platform_instance the SQL query was run in",
)
@click.option(
"--env",
type=str,
default=DEFAULT_ENV,
help=f"The environment the SQL query was run in, defaults to {DEFAULT_ENV}",
)
@click.option(
"--default-db",
type=str,
help="The default database to use for unqualified table names",
)
@click.option(
"--default-schema",
type=str,
help="The default schema to use for unqualified table names",
)
@telemetry.with_telemetry()
def sql_lineage(
sql: str,
platform: str,
default_db: Optional[str],
default_schema: Optional[str],
platform_instance: Optional[str],
env: str,
) -> None:
"""Parse the lineage of a SQL query.
This performs schema-aware parsing in order to generate column-level lineage.
If the relevant tables are not in DataHub, this will be less accurate.
"""
graph = get_default_graph()
lineage = graph.parse_sql_lineage(
sql,
platform=platform,
platform_instance=platform_instance,
env=env,
default_db=default_db,
default_schema=default_schema,
)
logger.debug("Sql parsing debug info: %s", lineage.debug_info)
click.echo(lineage.json(indent=4))
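As a rough sketch of how this command could be exercised from a test with click's CliRunner: the sql-lineage command name assumes click's default underscore-to-dash conversion, the datahub.cli.check_cli module path is an assumption, and a reachable DataHub instance is needed because the command calls get_default_graph().

from click.testing import CliRunner

from datahub.cli.check_cli import check  # assumed module path for this CLI group

runner = CliRunner()
result = runner.invoke(
    check,
    [
        "sql-lineage",  # assumes click derives the name from the sql_lineage function
        "--platform", "snowflake",
        "--default-db", "snowflake_sample_data",
        "--default-schema", "tpch_sf1",
        "--sql", "SELECT SUM(totalprice) AS total_agg FROM orders",
    ],
)
print(result.output)  # the pretty-printed SqlParsingResult JSON echoed by the command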

View File

@ -1,4 +1,5 @@
import enum
import functools
import json
import logging
import textwrap
@ -16,7 +17,7 @@ from datahub.cli.cli_utils import get_url_and_token
from datahub.configuration.common import ConfigModel, GraphError, OperationalError
from datahub.configuration.validate_field_removal import pydantic_removed_field
from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP
from datahub.emitter.mce_builder import Aspect, make_data_platform_urn
from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect, make_data_platform_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.serialization_helper import post_json_transform
@ -44,6 +45,7 @@ if TYPE_CHECKING:
from datahub.ingestion.source.state.entity_removal_state import (
GenericCheckpointState,
)
from datahub.utilities.sqlglot_lineage import SchemaResolver, SqlParsingResult
logger = logging.getLogger(__name__)
@ -955,6 +957,46 @@ class DataHubGraph(DatahubRestEmitter):
related_aspects = response.get("relatedAspects", [])
return reference_count, related_aspects
@functools.lru_cache()
def _make_schema_resolver(
self, platform: str, platform_instance: Optional[str], env: str
) -> "SchemaResolver":
from datahub.utilities.sqlglot_lineage import SchemaResolver
return SchemaResolver(
platform=platform,
platform_instance=platform_instance,
env=env,
graph=self,
)
def parse_sql_lineage(
self,
sql: str,
*,
platform: str,
platform_instance: Optional[str] = None,
env: str = DEFAULT_ENV,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> "SqlParsingResult":
from datahub.utilities.sqlglot_lineage import sqlglot_lineage
# Cache the schema resolver to make bulk parsing faster.
schema_resolver = self._make_schema_resolver(
platform=platform,
platform_instance=platform_instance,
env=env,
)
return sqlglot_lineage(
sql,
platform=platform,
schema_resolver=schema_resolver,
default_db=default_db,
default_schema=default_schema,
)
def get_default_graph() -> DataHubGraph:
(url, token) = get_url_and_token()
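For reference, a minimal sketch of calling the new parse_sql_lineage API from client code; it assumes datahub init has been run so get_default_graph() can connect, and the query, platform, and database names are illustrative:

from datahub.ingestion.graph.client import get_default_graph

graph = get_default_graph()

result = graph.parse_sql_lineage(
    "SELECT max(col1, col2) AS max_col FROM mytable",
    platform="mysql",
    default_db="my_db",  # used to qualify the bare table name
)

# SqlParsingResult carries table-level and column-level lineage plus debug info.
print(result.in_tables)              # upstream dataset urns
print(result.out_tables)             # downstream dataset urns (empty for a plain SELECT)
print(result.debug_info.confidence)  # lower when table schemas could not be resolved
print(result.json(indent=4))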

View File

@ -62,6 +62,8 @@ BigQueryAuditMetadata = Any
logger: logging.Logger = logging.getLogger(__name__)
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = "((.+)[_$])?(\\d{8})$"
@dataclass(frozen=True, order=True)
class BigqueryTableIdentifier:
@ -70,7 +72,12 @@ class BigqueryTableIdentifier:
table: str
invalid_chars: ClassVar[Set[str]] = {"$", "@"}
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{8})$"
# Note: this regex may get overwritten by the sharded_table_pattern config.
# The class-level constant, however, will not be overwritten.
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[
str
] = _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX
_BIGQUERY_WILDCARD_REGEX: ClassVar[str] = "((_(\\d+)?)\\*$)|\\*$"
_BQ_SHARDED_TABLE_SUFFIX: str = "_yyyymmdd"

View File

@ -0,0 +1,6 @@
"""Testing utilities for the datahub package.
These modules are included in the `datahub.testing` namespace, but aren't
part of the public API. They're placed here, rather than in the `tests`
directory, so they can be used by other packages that depend on `datahub`.
"""

View File

@ -0,0 +1,90 @@
import logging
import os
import pathlib
from typing import Any, Dict, Optional
import deepdiff
from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier
from datahub.utilities.sqlglot_lineage import (
SchemaInfo,
SchemaResolver,
SqlParsingResult,
sqlglot_lineage,
)
logger = logging.getLogger(__name__)
# TODO: Hook this into the standard --update-golden-files mechanism.
UPDATE_FILES = os.environ.get("UPDATE_SQLPARSER_FILES", "false").lower() == "true"
def assert_sql_result_with_resolver(
sql: str,
*,
dialect: str,
expected_file: pathlib.Path,
schema_resolver: SchemaResolver,
**kwargs: Any,
) -> None:
# HACK: Our BigQuery source overwrites this value and doesn't undo it.
# As such, we need to handle that here.
BigqueryTableIdentifier._BQ_SHARDED_TABLE_SUFFIX = "_yyyymmdd"
res = sqlglot_lineage(
sql,
platform=dialect,
schema_resolver=schema_resolver,
**kwargs,
)
if res.debug_info.column_error:
logger.warning(
f"SQL parser column error: {res.debug_info.column_error}",
exc_info=res.debug_info.column_error,
)
txt = res.json(indent=4)
if UPDATE_FILES:
expected_file.write_text(txt)
return
if not expected_file.exists():
expected_file.write_text(txt)
raise AssertionError(
f"Expected file {expected_file} does not exist. "
"Created it with the expected output. Please verify it."
)
expected = SqlParsingResult.parse_raw(expected_file.read_text())
full_diff = deepdiff.DeepDiff(
expected.dict(),
res.dict(),
exclude_regex_paths=[
r"root.column_lineage\[\d+\].logic",
],
)
assert not full_diff, full_diff
def assert_sql_result(
sql: str,
*,
dialect: str,
expected_file: pathlib.Path,
schemas: Optional[Dict[str, SchemaInfo]] = None,
**kwargs: Any,
) -> None:
schema_resolver = SchemaResolver(platform=dialect)
if schemas:
for urn, schema in schemas.items():
schema_resolver.add_raw_schema_info(urn, schema)
assert_sql_result_with_resolver(
sql,
dialect=dialect,
expected_file=expected_file,
schema_resolver=schema_resolver,
**kwargs,
)

View File

@ -55,18 +55,30 @@ class ConnectionWrapper:
conn: sqlite3.Connection
filename: pathlib.Path
_directory: Optional[tempfile.TemporaryDirectory]
_allow_table_name_reuse: bool
def __init__(self, filename: Optional[pathlib.Path] = None):
self._directory = None
# Warning: If filename is provided, the file will not be automatically cleaned up
if not filename:
# In the normal case, we do not use "IF NOT EXISTS" in our create table statements
# because creating the same table twice indicates a client usage error.
# However, if you're trying to persist a file-backed dict across multiple runs,
# which happens when filename is passed explicitly, then we need to allow table name reuse.
allow_table_name_reuse = False
# Warning: If filename is provided, the file will not be automatically cleaned up.
if filename:
allow_table_name_reuse = True
else:
self._directory = tempfile.TemporaryDirectory()
filename = pathlib.Path(self._directory.name) / _DEFAULT_FILE_NAME
self.conn = sqlite3.connect(filename, isolation_level=None)
self.conn.row_factory = sqlite3.Row
self.filename = filename
self._allow_table_name_reuse = allow_table_name_reuse
# These settings are optimized for performance.
# See https://www.sqlite.org/pragma.html for more information.
@ -80,13 +92,13 @@ class ConnectionWrapper:
def execute(
self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = ()
) -> sqlite3.Cursor:
logger.debug(f"Executing <{sql}> ({parameters})")
# logger.debug(f"Executing <{sql}> ({parameters})")
return self.conn.execute(sql, parameters)
def executemany(
self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = ()
) -> sqlite3.Cursor:
logger.debug(f"Executing many <{sql}> ({parameters})")
# logger.debug(f"Executing many <{sql}> ({parameters})")
return self.conn.executemany(sql, parameters)
def close(self) -> None:
@ -184,10 +196,10 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]):
# a poor-man's LRU cache.
self._active_object_cache = collections.OrderedDict()
# Create the table. We're not using "IF NOT EXISTS" because creating
# the same table twice indicates a client usage error.
# Create the table.
if_not_exists = "IF NOT EXISTS" if self._conn._allow_table_name_reuse else ""
self._conn.execute(
f"""CREATE TABLE {self.tablename} (
f"""CREATE TABLE {if_not_exists} {self.tablename} (
key TEXT PRIMARY KEY,
value BLOB
{''.join(f', {column_name} BLOB' for column_name in self.extra_columns.keys())}
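A small sketch of the behavior this change enables: passing an explicit filename keeps the sqlite file around and flips _allow_table_name_reuse, so re-opening the same file issues CREATE TABLE IF NOT EXISTS instead of failing. Constructing FileBackedDict with only shared_connection relies on its other arguments having defaults, which isn't shown in this hunk and is an assumption.

import pathlib

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict

cache_path = pathlib.Path("schema_cache.db")

# First run: the backing table is created and the file is not cleaned up on close.
conn = ConnectionWrapper(filename=cache_path)
cache: FileBackedDict[str] = FileBackedDict(shared_connection=conn)
cache["some-key"] = "some value"
cache.close()

# Second run: because a filename was passed, the CREATE TABLE uses IF NOT EXISTS,
# so constructing the dict against the existing file no longer raises.
conn2 = ConnectionWrapper(filename=cache_path)
cache2: FileBackedDict[str] = FileBackedDict(shared_connection=conn2)
cache2["another-key"] = "another value"
cache2.close()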

View File

@ -0,0 +1,754 @@
import contextlib
import enum
import functools
import itertools
import logging
import pathlib
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
import pydantic
import pydantic.dataclasses
import sqlglot
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
from pydantic import BaseModel
from datahub.emitter.mce_builder import (
DEFAULT_ENV,
make_dataset_urn_with_platform_instance,
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier
from datahub.metadata.schema_classes import SchemaMetadataClass
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict
from datahub.utilities.urns.dataset_urn import DatasetUrn
logger = logging.getLogger(__name__)
Urn = str
# A lightweight table schema: column -> type mapping.
SchemaInfo = Dict[str, str]
class QueryType(enum.Enum):
CREATE = "CREATE"
SELECT = "SELECT"
INSERT = "INSERT"
UPDATE = "UPDATE"
DELETE = "DELETE"
MERGE = "MERGE"
UNKNOWN = "UNKNOWN"
def get_query_type_of_sql(expression: sqlglot.exp.Expression) -> QueryType:
mapping = {
sqlglot.exp.Create: QueryType.CREATE,
sqlglot.exp.Select: QueryType.SELECT,
sqlglot.exp.Insert: QueryType.INSERT,
sqlglot.exp.Update: QueryType.UPDATE,
sqlglot.exp.Delete: QueryType.DELETE,
sqlglot.exp.Merge: QueryType.MERGE,
}
for cls, query_type in mapping.items():
if isinstance(expression, cls):
return query_type
return QueryType.UNKNOWN
@functools.total_ordering
class _FrozenModel(BaseModel, frozen=True):
def __lt__(self, other: "_FrozenModel") -> bool:
for field in self.__fields__:
self_v = getattr(self, field)
other_v = getattr(other, field)
if self_v != other_v:
return self_v < other_v
return False
class _TableName(_FrozenModel):
database: Optional[str]
db_schema: Optional[str]
table: str
def as_sqlglot_table(self) -> sqlglot.exp.Table:
return sqlglot.exp.Table(
catalog=self.database, db=self.db_schema, this=self.table
)
def qualified(
self,
dialect: str,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> "_TableName":
database = self.database or default_db
db_schema = self.db_schema or default_schema
return _TableName(
database=database,
db_schema=db_schema,
table=self.table,
)
@classmethod
def from_sqlglot_table(
cls,
table: sqlglot.exp.Table,
dialect: str,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> "_TableName":
return cls(
database=table.catalog or default_db,
db_schema=table.db or default_schema,
table=table.this.name,
)
class _ColumnRef(_FrozenModel):
table: _TableName
column: str
class ColumnRef(BaseModel):
table: Urn
column: str
class _DownstreamColumnRef(BaseModel):
table: Optional[_TableName]
column: str
class DownstreamColumnRef(BaseModel):
table: Optional[Urn]
column: str
class _ColumnLineageInfo(BaseModel):
downstream: _DownstreamColumnRef
upstreams: List[_ColumnRef]
logic: Optional[str]
class ColumnLineageInfo(BaseModel):
downstream: DownstreamColumnRef
upstreams: List[ColumnRef]
# Logic for this column, as a SQL expression.
logic: Optional[str] = pydantic.Field(default=None, exclude=True)
class SqlParsingDebugInfo(BaseModel, arbitrary_types_allowed=True):
confidence: float
tables_discovered: int
table_schemas_resolved: int
column_error: Optional[Exception]
class SqlParsingResult(BaseModel):
query_type: QueryType = QueryType.UNKNOWN
in_tables: List[Urn]
out_tables: List[Urn]
column_lineage: Optional[List[ColumnLineageInfo]]
# TODO include formatted original sql logic
# TODO include list of referenced columns
debug_info: SqlParsingDebugInfo = pydantic.Field(
default_factory=lambda: SqlParsingDebugInfo(
confidence=0,
tables_discovered=0,
table_schemas_resolved=0,
column_error=None,
),
exclude=True,
)
def _parse_statement(sql: str, dialect: str) -> sqlglot.Expression:
statement = sqlglot.parse_one(
sql, read=dialect, error_level=sqlglot.ErrorLevel.RAISE
)
return statement
def _table_level_lineage(
statement: sqlglot.Expression,
dialect: str,
) -> Tuple[Set[_TableName], Set[_TableName]]:
def _raw_table_name(table: sqlglot.exp.Table) -> _TableName:
return _TableName.from_sqlglot_table(table, dialect=dialect)
# Generate table-level lineage.
modified = {
_raw_table_name(expr.this)
for expr in statement.find_all(
sqlglot.exp.Create,
sqlglot.exp.Insert,
sqlglot.exp.Update,
sqlglot.exp.Delete,
sqlglot.exp.Merge,
)
# In some cases like "MERGE ... then INSERT (col1, col2) VALUES (col1, col2)",
# the `this` on the INSERT part isn't a table.
if isinstance(expr.this, sqlglot.exp.Table)
}
tables = (
{_raw_table_name(table) for table in statement.find_all(sqlglot.exp.Table)}
# ignore references created in this query
- modified
# ignore CTEs created in this statement
- {
_TableName(database=None, db_schema=None, table=cte.alias_or_name)
for cte in statement.find_all(sqlglot.exp.CTE)
}
)
# TODO: If a CTAS has "LIMIT 0", it's not really lineage, just copying the schema.
return tables, modified
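As a rough illustration of the sqlglot primitives this function builds on (the query and dialect below are illustrative, not taken from the diff):

import sqlglot

statement = sqlglot.parse_one(
    "INSERT INTO sales.daily SELECT * FROM sales.raw_orders JOIN sales.dim_date USING (order_date)",
    read="snowflake",
)

# Tables being written: the `this` argument of INSERT/CREATE/etc. nodes.
modified = {expr.this.sql() for expr in statement.find_all(sqlglot.exp.Insert)}

# Tables being read: every table reference, minus the ones being written.
referenced = {t.sql() for t in statement.find_all(sqlglot.exp.Table)} - modified

print(sorted(modified))    # ['sales.daily']
print(sorted(referenced))  # ['sales.dim_date', 'sales.raw_orders']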
class SchemaResolver:
def __init__(
self,
*,
platform: str,
platform_instance: Optional[str] = None,
env: str = DEFAULT_ENV,
graph: Optional[DataHubGraph] = None,
_cache_filename: Optional[pathlib.Path] = None,
):
# TODO handle platforms when prefixed with urn:li:dataPlatform:
self.platform = platform
self.platform_instance = platform_instance
self.env = env
self.graph = graph
# Init cache, potentially restoring from a previous run.
shared_conn = None
if _cache_filename:
shared_conn = ConnectionWrapper(filename=_cache_filename)
self._schema_cache: FileBackedDict[Optional[SchemaInfo]] = FileBackedDict(
shared_connection=shared_conn,
)
def get_urn_for_table(self, table: _TableName, lower: bool = False) -> str:
# TODO: Validate that this is the correct 2/3 layer hierarchy for the platform.
table_name = ".".join(
filter(None, [table.database, table.db_schema, table.table])
)
if lower:
table_name = table_name.lower()
if self.platform == "bigquery":
# Normalize shard numbers and other BigQuery weirdness.
# TODO check that this is the right way to do it
with contextlib.suppress(IndexError):
table_name = BigqueryTableIdentifier.from_string_name(
table_name
).get_table_name()
urn = make_dataset_urn_with_platform_instance(
platform=self.platform,
platform_instance=self.platform_instance,
env=self.env,
name=table_name,
)
return urn
def resolve_table(self, table: _TableName) -> Tuple[str, Optional[SchemaInfo]]:
urn = self.get_urn_for_table(table)
schema_info = self._resolve_schema_info(urn)
if schema_info:
return urn, schema_info
urn_lower = self.get_urn_for_table(table, lower=True)
if urn_lower != urn:
schema_info = self._resolve_schema_info(urn_lower)
if schema_info:
return urn_lower, schema_info
return urn_lower, None
def _resolve_schema_info(self, urn: str) -> Optional[SchemaInfo]:
if urn in self._schema_cache:
return self._schema_cache[urn]
# TODO: For bigquery partitioned tables, add the pseudo-column _PARTITIONTIME
# or _PARTITIONDATE where appropriate.
if self.graph:
schema_info = self._fetch_schema_info(self.graph, urn)
if schema_info:
self._save_to_cache(urn, schema_info)
return schema_info
self._save_to_cache(urn, None)
return None
def add_schema_metadata(
self, urn: str, schema_metadata: SchemaMetadataClass
) -> None:
schema_info = self._convert_schema_aspect_to_info(schema_metadata)
self._save_to_cache(urn, schema_info)
def add_raw_schema_info(self, urn: str, schema_info: SchemaInfo) -> None:
self._save_to_cache(urn, schema_info)
def _save_to_cache(self, urn: str, schema_info: Optional[SchemaInfo]) -> None:
self._schema_cache[urn] = schema_info
def _fetch_schema_info(self, graph: DataHubGraph, urn: str) -> Optional[SchemaInfo]:
aspect = graph.get_aspect(urn, SchemaMetadataClass)
if not aspect:
return None
return self._convert_schema_aspect_to_info(aspect)
@classmethod
def _convert_schema_aspect_to_info(
cls, schema_metadata: SchemaMetadataClass
) -> SchemaInfo:
return {
DatasetUrn._get_simple_field_path_from_v2_field_path(col.fieldPath): (
# The actual types are more of a "nice to have".
col.nativeDataType
or "str"
)
for col in schema_metadata.fields
# TODO: We can't generate lineage to columns nested within structs yet.
if "."
not in DatasetUrn._get_simple_field_path_from_v2_field_path(col.fieldPath)
}
# TODO add a method to load all from graphql
# TODO: Once PEP 604 is supported (Python 3.10), we can unify these into a
# single type. See https://peps.python.org/pep-0604/#isinstance-and-issubclass.
_SupportedColumnLineageTypes = Union[
# Note that Select and Union inherit from Subqueryable.
sqlglot.exp.Subqueryable,
# For actual subqueries, the statement type might also be DerivedTable.
sqlglot.exp.DerivedTable,
]
_SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable)
class UnsupportedStatementTypeError(TypeError):
pass
class SqlUnderstandingError(Exception):
# Usually hit when we need schema info for a given statement but don't have it.
pass
# TODO: Break this up into smaller functions.
def _column_level_lineage( # noqa: C901
statement: sqlglot.exp.Expression,
dialect: str,
input_tables: Dict[_TableName, SchemaInfo],
output_table: Optional[_TableName],
default_db: Optional[str],
default_schema: Optional[str],
) -> List[_ColumnLineageInfo]:
if not isinstance(
statement,
_SupportedColumnLineageTypesTuple,
):
raise UnsupportedStatementTypeError(
f"Can only generate column-level lineage for select-like inner statements, not {type(statement)}"
)
use_case_insensitive_cols = dialect in {
# Column identifiers are case-insensitive in BigQuery, so we need to
# do a normalization step beforehand to make sure it's resolved correctly.
"bigquery",
# Our snowflake source lowercases column identifiers, so we are forced
# to do fuzzy (case-insensitive) resolution instead of exact resolution.
"snowflake",
}
sqlglot_db_schema = sqlglot.MappingSchema(
dialect=dialect,
# We do our own normalization, so don't let sqlglot do it.
normalize=False,
)
table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict(
dict
)
for table, table_schema in input_tables.items():
normalized_table_schema: SchemaInfo = {}
for col, col_type in table_schema.items():
if use_case_insensitive_cols:
col_normalized = (
# This is required to match Sqlglot's behavior.
col.upper()
if dialect in {"snowflake"}
else col.lower()
)
else:
col_normalized = col
table_schema_normalized_mapping[table][col_normalized] = col
normalized_table_schema[col_normalized] = col_type
sqlglot_db_schema.add_table(
table.as_sqlglot_table(),
column_mapping=normalized_table_schema,
)
if use_case_insensitive_cols:
def _sqlglot_force_column_normalizer(
node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None
) -> sqlglot.exp.Expression:
if isinstance(node, sqlglot.exp.Column):
node.this.set("quoted", False)
return node
# logger.debug(
# "Prior to case normalization sql %s",
# statement.sql(pretty=True, dialect=dialect),
# )
statement = statement.transform(
_sqlglot_force_column_normalizer, dialect, copy=False
)
# logger.debug(
# "Sql after casing normalization %s",
# statement.sql(pretty=True, dialect=dialect),
# )
# Optimize the statement + qualify column references.
logger.debug(
"Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect)
)
try:
# Second time running qualify, this time with:
# - the select instead of the full outer statement
# - schema info
# - column qualification enabled
# logger.debug("Schema: %s", sqlglot_db_schema.mapping)
statement = sqlglot.optimizer.qualify.qualify(
statement,
dialect=dialect,
schema=sqlglot_db_schema,
validate_qualify_columns=False,
identify=True,
# sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
catalog=default_db,
db=default_schema,
)
except (sqlglot.errors.OptimizeError, ValueError) as e:
raise SqlUnderstandingError(
f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
) from e
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
column_lineage = []
try:
assert isinstance(statement, _SupportedColumnLineageTypesTuple)
# List output columns.
output_columns = [
(select_col.alias_or_name, select_col) for select_col in statement.selects
]
logger.debug("output columns: %s", [col[0] for col in output_columns])
output_col: str
for output_col, original_col_expression in output_columns:
# print(f"output column: {output_col}")
if output_col == "*":
# If schema information is available, the * will be expanded to the actual columns.
# Otherwise, we can't process it.
continue
if dialect == "bigquery" and output_col.lower() in {
"_partitiontime",
"_partitiondate",
}:
# These are not real columns, just a way to filter by partition.
# TODO: We should add these columns to the schema info instead.
# Once that's done, we should actually generate lineage for these
# if they appear in the output.
continue
lineage_node = sqlglot.lineage.lineage(
output_col,
statement,
dialect=dialect,
schema=sqlglot_db_schema,
)
# pathlib.Path("sqlglot.html").write_text(
# str(lineage_node.to_html(dialect=dialect))
# )
# Generate SELECT lineage.
# Using a set here to deduplicate upstreams.
direct_col_upstreams: Set[_ColumnRef] = set()
for node in lineage_node.walk():
if node.downstream:
# We only want the leaf nodes.
pass
elif isinstance(node.expression, sqlglot.exp.Table):
table_ref = _TableName.from_sqlglot_table(
node.expression, dialect=dialect
)
# Parse the column name out of the node name.
# Sqlglot calls .sql(), so we have to do the inverse.
normalized_col = sqlglot.parse_one(node.name).this.name
if node.subfield:
normalized_col = f"{normalized_col}.{node.subfield}"
col = table_schema_normalized_mapping[table_ref].get(
normalized_col, normalized_col
)
direct_col_upstreams.add(_ColumnRef(table=table_ref, column=col))
else:
# This branch doesn't matter. For example, a count(*) column would go here, and
# we don't get any column-level lineage for that.
pass
# column_logic = lineage_node.source
if output_col.startswith("_col_"):
# This is the format sqlglot uses for unnamed columns e.g. 'count(id)' -> 'count(id) AS _col_0'
# This is a bit jank since we're relying on sqlglot internals, but it seems to be
# the best way to do it.
output_col = original_col_expression.this.sql(dialect=dialect)
if not direct_col_upstreams:
logger.debug(f' "{output_col}" has no upstreams')
column_lineage.append(
_ColumnLineageInfo(
downstream=_DownstreamColumnRef(
table=output_table, column=output_col
),
upstreams=sorted(direct_col_upstreams),
# logic=column_logic.sql(pretty=True, dialect=dialect),
)
)
# TODO: Also extract referenced columns (e.g. non-SELECT lineage)
except (sqlglot.errors.OptimizeError, ValueError) as e:
raise SqlUnderstandingError(
f"sqlglot failed to compute some lineage: {e}"
) from e
return column_lineage
def _extract_select_from_create(
statement: sqlglot.exp.Create,
) -> sqlglot.exp.Expression:
# TODO: Validate that this properly includes WITH clauses in all dialects.
inner = statement.expression
if inner:
return inner
else:
return statement
def _try_extract_select(
statement: sqlglot.exp.Expression,
) -> sqlglot.exp.Expression:
# Try to extract the core select logic from a more complex statement.
# If it fails, just return the original statement.
if isinstance(statement, sqlglot.exp.Merge):
# TODO Need to map column renames in the expressions part of the statement.
# Likely need to use the named_selects attr.
statement = statement.args["using"]
if isinstance(statement, sqlglot.exp.Table):
# If we're querying a table directly, wrap it in a SELECT.
statement = sqlglot.exp.Select().select("*").from_(statement)
elif isinstance(statement, sqlglot.exp.Insert):
# TODO Need to map column renames in the expressions part of the statement.
statement = statement.expression
elif isinstance(statement, sqlglot.exp.Create):
# TODO May need to map column renames.
# Assumption: the output table is already captured in the modified tables list.
statement = _extract_select_from_create(statement)
if isinstance(statement, sqlglot.exp.Subquery):
statement = statement.unnest()
return statement
def _translate_internal_column_lineage(
table_name_urn_mapping: Dict[_TableName, str],
raw_column_lineage: _ColumnLineageInfo,
) -> ColumnLineageInfo:
downstream_urn = None
if raw_column_lineage.downstream.table:
downstream_urn = table_name_urn_mapping[raw_column_lineage.downstream.table]
return ColumnLineageInfo(
downstream=DownstreamColumnRef(
table=downstream_urn,
column=raw_column_lineage.downstream.column,
),
upstreams=[
ColumnRef(
table=table_name_urn_mapping[upstream.table],
column=upstream.column,
)
for upstream in raw_column_lineage.upstreams
],
logic=raw_column_lineage.logic,
)
def sqlglot_lineage(
sql: str,
platform: str,
schema_resolver: SchemaResolver,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> SqlParsingResult:
# TODO: convert datahub platform names to sqlglot dialect
dialect = platform
if dialect == "snowflake":
# in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if default_db:
default_db = default_db.upper()
if default_schema:
default_schema = default_schema.upper()
logger.debug("Parsing lineage from sql statement: %s", sql)
statement = _parse_statement(sql, dialect=dialect)
original_statement = statement.copy()
# logger.debug(
# "Formatted sql statement: %s",
# original_statement.sql(pretty=True, dialect=dialect),
# )
# Make sure the tables are resolved with the default db / schema.
# This only works for Unionable statements. For other types of statements,
# we have to do it manually afterwards, but that's slightly lower accuracy
# because of CTEs.
statement = sqlglot.optimizer.qualify.qualify(
statement,
dialect=dialect,
# sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
catalog=default_db,
db=default_schema,
# At this stage we only want to qualify the table names. The columns will be dealt with later.
qualify_columns=False,
validate_qualify_columns=False,
# Only insert quotes where necessary.
identify=False,
)
# Generate table-level lineage.
tables, modified = _table_level_lineage(statement, dialect=dialect)
# Prep for generating column-level lineage.
downstream_table: Optional[_TableName] = None
if len(modified) == 1:
downstream_table = next(iter(modified))
# Fetch schema info for the relevant tables.
table_name_urn_mapping: Dict[_TableName, str] = {}
table_name_schema_mapping: Dict[_TableName, SchemaInfo] = {}
for table, is_input in itertools.chain(
[(table, True) for table in tables],
[(table, False) for table in modified],
):
# For select statements, qualification will be a no-op. For other statements, this
# is where the qualification actually happens.
qualified_table = table.qualified(
dialect=dialect, default_db=default_db, default_schema=default_schema
)
urn, schema_info = schema_resolver.resolve_table(qualified_table)
table_name_urn_mapping[qualified_table] = urn
if is_input and schema_info:
table_name_schema_mapping[qualified_table] = schema_info
# Also include the original, non-qualified table name in the urn mapping.
table_name_urn_mapping[table] = urn
debug_info = SqlParsingDebugInfo(
confidence=0.9 if len(tables) == len(table_name_schema_mapping)
# If we're missing any schema info, our confidence will be in the 0.2-0.5 range depending
# on how many tables we were able to resolve.
else 0.2 + 0.3 * len(table_name_schema_mapping) / len(tables),
tables_discovered=len(tables),
table_schemas_resolved=len(table_name_schema_mapping),
)
logger.debug(
f"Resolved {len(table_name_schema_mapping)} of {len(tables)} table schemas"
)
# Simplify the input statement for column-level lineage generation.
select_statement = _try_extract_select(statement)
# Generate column-level lineage.
column_lineage: Optional[List[_ColumnLineageInfo]] = None
try:
column_lineage = _column_level_lineage(
select_statement,
dialect=dialect,
input_tables=table_name_schema_mapping,
output_table=downstream_table,
default_db=default_db,
default_schema=default_schema,
)
except UnsupportedStatementTypeError as e:
# Inject details about the outer statement type too.
e.args = (f"{e.args[0]} (outer statement type: {type(statement)})",)
debug_info.column_error = e
logger.debug(debug_info.column_error)
except SqlUnderstandingError as e:
logger.debug(f"Failed to generate column-level lineage: {e}", exc_info=True)
debug_info.column_error = e
# TODO: Can we generate a common JOIN tables / keys section?
# TODO: Can we generate a common WHERE clauses section?
# Convert TableName to urns.
in_urns = sorted(set(table_name_urn_mapping[table] for table in tables))
out_urns = sorted(set(table_name_urn_mapping[table] for table in modified))
column_lineage_urns = None
if column_lineage:
column_lineage_urns = [
_translate_internal_column_lineage(
table_name_urn_mapping, internal_col_lineage
)
for internal_col_lineage in column_lineage
]
return SqlParsingResult(
query_type=get_query_type_of_sql(original_statement),
in_tables=in_urns,
out_tables=out_urns,
column_lineage=column_lineage_urns,
debug_info=debug_info,
)
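Putting the pieces together without a live DataHub connection, a minimal sketch of offline parsing against a hand-populated SchemaResolver (the urn, columns, and query are illustrative):

from datahub.utilities.sqlglot_lineage import SchemaResolver, sqlglot_lineage

resolver = SchemaResolver(platform="snowflake")
resolver.add_raw_schema_info(
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
    {"orderkey": "NUMBER", "totalprice": "FLOAT"},
)

result = sqlglot_lineage(
    "SELECT orderkey, SUM(totalprice) AS total_agg FROM orders GROUP BY orderkey",
    platform="snowflake",
    schema_resolver=resolver,
    default_db="snowflake_sample_data",
    default_schema="tpch_sf1",
)

print(result.in_tables)  # the orders urn registered above
for cl in result.column_lineage or []:
    print(cl.downstream.column, "<-", [u.column for u in cl.upstreams])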

View File

@ -0,0 +1,61 @@
{
"query_type": "CREATE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col5"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)",
"column": "col5"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col1"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)",
"column": "col1"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col2"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)",
"column": "col2"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col3"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)",
"column": "col3"
}
]
}
]
}

View File

@ -0,0 +1,33 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "col1"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)",
"column": "col1"
}
]
},
{
"downstream": {
"table": null,
"column": "col2"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)",
"column": "col2"
}
]
}
]
}

View File

@ -0,0 +1,33 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "col1"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)",
"column": "col1"
}
]
},
{
"downstream": {
"table": null,
"column": "col2"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)",
"column": "col2"
}
]
}
]
}

View File

@ -0,0 +1,33 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "col1"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)",
"column": "col1"
}
]
},
{
"downstream": {
"table": null,
"column": "col2"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)",
"column": "col2"
}
]
}
]
}

View File

@ -0,0 +1,90 @@
{
"query_type": "SELECT",
"in_tables": [
{
"database": "bq-proj",
"db_schema": "dataset",
"table": "table1"
},
{
"database": "bq-proj",
"db_schema": "dataset",
"table": "table2"
}
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "day"
},
"upstreams": [
{
"table": {
"database": "bq-proj",
"db_schema": "dataset",
"table": "table1"
},
"column": "reporting_day"
}
]
},
{
"downstream": {
"table": null,
"column": "product"
},
"upstreams": [
{
"table": {
"database": "bq-proj",
"db_schema": "dataset",
"table": "table1"
},
"column": "by_product.product_code"
},
{
"table": {
"database": "bq-proj",
"db_schema": "dataset",
"table": "table2"
},
"column": "other_field"
}
]
},
{
"downstream": {
"table": null,
"column": "other_field"
},
"upstreams": [
{
"table": {
"database": "bq-proj",
"db_schema": "dataset",
"table": "table2"
},
"column": "other_field"
}
]
},
{
"downstream": {
"table": null,
"column": "daily_active_users"
},
"upstreams": [
{
"table": {
"database": "bq-proj",
"db_schema": "dataset",
"table": "table1"
},
"column": "by_product.product_code_dau"
}
]
}
]
}

View File

@ -0,0 +1,42 @@
{
"query_type": "CREATE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)",
"column": "Department"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)",
"column": "deptno"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)",
"column": "Employees"
},
"upstreams": []
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)",
"column": "Salary"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)",
"column": "sal"
}
]
}
]
}

View File

@ -0,0 +1,129 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "TOTAL_AGG"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "totalprice"
}
]
},
{
"downstream": {
"table": null,
"column": "ORDERKEY"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "orderkey"
}
]
},
{
"downstream": {
"table": null,
"column": "CUSTKEY"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "custkey"
}
]
},
{
"downstream": {
"table": null,
"column": "ORDERSTATUS"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "orderstatus"
}
]
},
{
"downstream": {
"table": null,
"column": "TOTALPRICE"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "totalprice"
}
]
},
{
"downstream": {
"table": null,
"column": "ORDERDATE"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "orderdate"
}
]
},
{
"downstream": {
"table": null,
"column": "ORDERPRIORITY"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "orderpriority"
}
]
},
{
"downstream": {
"table": null,
"column": "CLERK"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "clerk"
}
]
},
{
"downstream": {
"table": null,
"column": "SHIPPRIORITY"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "shippriority"
}
]
},
{
"downstream": {
"table": null,
"column": "COMMENT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "comment"
}
]
}
]
}

View File

@ -0,0 +1,76 @@
{
"query_type": "INSERT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:hive,catalog_returns,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,catalog_sales,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,customer_demographics,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,date_dim,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,household_demographics,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,inventory,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,item,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:hive,warehouse,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "i_item_desc"
},
"upstreams": []
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "w_warehouse_name"
},
"upstreams": []
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "d_week_seq"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,date_dim,PROD)",
"column": "d_week_seq"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "no_promo"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)",
"column": "p_promo_sk"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "promo"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)",
"column": "p_promo_sk"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)",
"column": "total_cnt"
},
"upstreams": []
}
]
}

View File

@ -0,0 +1,11 @@
{
"query_type": "MERGE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.prep_from_ios,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.prep_from_web,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.base_union,PROD)"
],
"column_lineage": null
}

View File

@ -0,0 +1,21 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:mysql,something_prd.fact_complaint_snapshot,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "COUNT(`fact_complaint_snapshot`.`etl_data_dt_id`)"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:mysql,something_prd.fact_complaint_snapshot,PROD)",
"column": "etl_data_dt_id"
}
]
}
]
}

View File

@ -0,0 +1,49 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "post_id"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)",
"column": "post_id"
}
]
},
{
"downstream": {
"table": null,
"column": "id"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)",
"column": "widget.asset.id"
}
]
},
{
"downstream": {
"table": null,
"column": "min_metric"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)",
"column": "widget.metric.metricA"
},
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)",
"column": "widget.metric.metric_b"
}
]
}
]
}

View File

@ -0,0 +1,33 @@
{
"query_type": "UNKNOWN",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "LABEL"
},
"upstreams": []
},
{
"downstream": {
"table": null,
"column": "TOTAL_AGG"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)",
"column": "totalprice"
},
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)",
"column": "totalprice"
}
]
}
]
}

View File

@ -0,0 +1,25 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "max_col"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)",
"column": "col1"
},
{
"table": "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)",
"column": "col2"
}
]
}
]
}

View File

@ -0,0 +1,34 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:oracle,table1,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:oracle,table2,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "col1"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,table1,PROD)",
"column": "col1"
}
]
},
{
"downstream": {
"table": null,
"column": "col3"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:oracle,table2,PROD)",
"column": "col3"
}
]
}
]
}

View File

@ -0,0 +1,33 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "post_id"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)",
"column": "post_id"
}
]
},
{
"downstream": {
"table": null,
"column": "id"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)",
"column": "widget.asset.id"
}
]
}
]
}

View File

@ -0,0 +1,57 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "TOTAL_AGG"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "TotalPrice"
}
]
},
{
"downstream": {
"table": null,
"column": "TOTAL_AVG"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "TotalPrice"
}
]
},
{
"downstream": {
"table": null,
"column": "TOTAL_MIN"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "TotalPrice"
}
]
},
{
"downstream": {
"table": null,
"column": "TOTAL_MAX"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
"column": "TotalPrice"
}
]
}
]
}

View File

@ -0,0 +1,88 @@
{
"query_type": "CREATE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "USER_FK"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)",
"column": "USER_FK"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "EMAIL"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)",
"column": "EMAIL"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "LAST_PURCHASE_DATE"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)",
"column": "LAST_PURCHASE_DATE"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "LIFETIME_PURCHASE_AMOUNT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)",
"column": "purchase_amount"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "LIFETIME_PURCHASE_COUNT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)",
"column": "pk"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)",
"column": "AVERAGE_PURCHASE_AMOUNT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)",
"column": "pk"
},
{
"table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)",
"column": "purchase_amount"
}
]
}
]
}

View File

@ -0,0 +1,479 @@
import pathlib
import pytest
from datahub.testing.check_sql_parser_result import assert_sql_result
RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens"
def test_select_max():
# The COL2 should get normalized to col2.
assert_sql_result(
"""
SELECT max(col1, COL2) as max_col
FROM mytable
""",
dialect="mysql",
expected_file=RESOURCE_DIR / "test_select_max.json",
)
def test_select_max_with_schema():
# Note that `this_will_not_resolve` will be dropped from the result because it's not in the schema.
assert_sql_result(
"""
SELECT max(`col1`, COL2, `this_will_not_resolve`) as max_col
FROM mytable
""",
dialect="mysql",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)": {
"col1": "NUMBER",
"col2": "NUMBER",
},
},
# Shared with the test above.
expected_file=RESOURCE_DIR / "test_select_max.json",
)
def test_select_count():
assert_sql_result(
"""
SELECT
COUNT(etl_data_dt_id)
FROM something_prd.fact_complaint_snapshot
WHERE
etl_data_dt_id = 20230317
""",
dialect="mysql",
expected_file=RESOURCE_DIR / "test_select_count.json",
)
def test_select_with_ctes():
assert_sql_result(
"""
WITH cte1 AS (
SELECT col1, col2
FROM table1
WHERE col1 = 'value1'
), cte2 AS (
SELECT col3, col4
FROM table2
WHERE col2 = 'value2'
)
SELECT cte1.col1, cte2.col3
FROM cte1
JOIN cte2 ON cte1.col2 = cte2.col4
""",
dialect="oracle",
expected_file=RESOURCE_DIR / "test_select_with_ctes.json",
)
def test_create_view_as_select():
assert_sql_result(
"""
CREATE VIEW vsal
AS
SELECT a.deptno "Department",
a.num_emp / b.total_count "Employees",
a.sal_sum / b.total_sal "Salary"
FROM (SELECT deptno,
Count() num_emp,
SUM(sal) sal_sum
FROM scott.emp
WHERE city = 'NYC'
GROUP BY deptno) a,
(SELECT Count() total_count,
SUM(sal) total_sal
FROM scott.emp
WHERE city = 'NYC') b
;
""",
dialect="oracle",
expected_file=RESOURCE_DIR / "test_create_view_as_select.json",
)
def test_insert_as_select():
# Note: this also tests lineage with case statements.
assert_sql_result(
"""
insert into query72
select i_item_desc
, w_warehouse_name
, d1.d_week_seq
, sum(case when promotion.p_promo_sk is null then 1 else 0 end) no_promo
, sum(case when promotion.p_promo_sk is not null then 1 else 0 end) promo
, count(*) total_cnt
from catalog_sales
join inventory on (cs_item_sk = inv_item_sk)
join warehouse on (w_warehouse_sk = inv_warehouse_sk)
join item on (i_item_sk = cs_item_sk)
join customer_demographics on (cs_bill_cdemo_sk = cd_demo_sk)
join household_demographics on (cs_bill_hdemo_sk = hd_demo_sk)
join date_dim d1 on (cs_sold_date_sk = d1.d_date_sk)
join date_dim d2 on (inv_date_sk = d2.d_date_sk)
join date_dim d3 on (cs_ship_date_sk = d3.d_date_sk)
left outer join promotion on (cs_promo_sk = p_promo_sk)
left outer join catalog_returns on (cr_item_sk = cs_item_sk and cr_order_number = cs_order_number)
where d1.d_week_seq = d2.d_week_seq
and inv_quantity_on_hand < cs_quantity
and hd_buy_potential = '>10000'
and cd_marital_status = 'D'
group by i_item_desc, w_warehouse_name, d1.d_week_seq
order by total_cnt desc, i_item_desc, w_warehouse_name, d_week_seq
limit 100;
""",
dialect="hive",
expected_file=RESOURCE_DIR / "test_insert_as_select.json",
)
def test_select_with_full_col_name():
# In this case, `widget` is a struct column.
# This also tests the `default_db` functionality.
assert_sql_result(
"""
SELECT distinct post_id , widget.asset.id
FROM data_reporting.abcde_transformed
WHERE post_id LIKE '%268662%'
""",
dialect="bigquery",
default_db="my-bq-ProjectName",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)": {
"post_id": "NUMBER",
"widget": "struct",
},
},
expected_file=RESOURCE_DIR / "test_select_with_full_col_name.json",
)
def test_select_from_struct_subfields():
# In this case, `widget` is a column name.
assert_sql_result(
"""
SELECT distinct post_id ,
widget.asset.id,
min(widget.metric.metricA, widget.metric.metric_b) as min_metric
FROM data_reporting.abcde_transformed
WHERE post_id LIKE '%12345%'
""",
dialect="bigquery",
default_db="my-bq-proj",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)": {
"post_id": "NUMBER",
"widget": "struct",
"widget.asset.id": "int",
"widget.metric.metricA": "int",
"widget.metric.metric_b": "int",
},
},
expected_file=RESOURCE_DIR / "test_select_from_struct_subfields.json",
)
def test_select_from_union():
assert_sql_result(
"""
SELECT
'orders_10' as label,
SUM(totalprice) as total_agg
FROM snowflake_sample_data.tpch_sf10.orders
UNION ALL
SELECT
'orders_100' as label,
SUM(totalprice) as total_agg,
FROM snowflake_sample_data.tpch_sf100.orders
""",
dialect="snowflake",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)": {
"orderkey": "NUMBER",
"totalprice": "FLOAT",
},
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)": {
"orderkey": "NUMBER",
"totalprice": "FLOAT",
},
},
expected_file=RESOURCE_DIR / "test_select_from_union.json",
)
def test_merge_from_union():
# TODO: We don't support merge statements yet, but the union should still get handled.
assert_sql_result(
"""
merge into `demo-pipelines-stg`.`referrer`.`base_union` as DBT_INTERNAL_DEST
using (
SELECT * FROM `demo-pipelines-stg`.`referrer`.`prep_from_ios` WHERE partition_time = "2018-03-03"
UNION ALL
SELECT * FROM `demo-pipelines-stg`.`referrer`.`prep_from_web` WHERE partition_time = "2018-03-03"
) as DBT_INTERNAL_SOURCE
on FALSE
when not matched by source
and timestamp_trunc(DBT_INTERNAL_DEST.partition_time, day) in (
timestamp('2018-03-03')
)
then delete
when not matched then insert
(`platform`, `pageview_id`, `query`, `referrer`, `partition_time`)
values
(`platform`, `pageview_id`, `query`, `referrer`, `partition_time`)
""",
dialect="bigquery",
expected_file=RESOURCE_DIR / "test_merge_from_union.json",
)
def test_expand_select_star_basic():
assert_sql_result(
"""
SELECT
SUM(totalprice) as total_agg,
*
FROM snowflake_sample_data.tpch_sf1.orders
WHERE orderdate = '1992-01-01'
""",
dialect="snowflake",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": {
"orderkey": "NUMBER",
"custkey": "NUMBER",
"orderstatus": "TEXT",
"totalprice": "FLOAT",
"orderdate": "DATE",
"orderpriority": "TEXT",
"clerk": "TEXT",
"shippriority": "NUMBER",
"comment": "TEXT",
},
},
expected_file=RESOURCE_DIR / "test_expand_select_star_basic.json",
)
def test_snowflake_column_normalization():
# Technically speaking this is incorrect since the column names are different and both quoted.
assert_sql_result(
"""
SELECT
SUM(o."totalprice") as total_agg,
AVG("TotalPrice") as total_avg,
MIN("TOTALPRICE") as total_min,
MAX(TotalPrice) as total_max
FROM snowflake_sample_data.tpch_sf1.orders o
""",
dialect="snowflake",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": {
"orderkey": "NUMBER",
"TotalPrice": "FLOAT",
},
},
expected_file=RESOURCE_DIR / "test_snowflake_column_normalization.json",
)
@pytest.mark.skip(reason="We don't handle the unnest lineage correctly")
def test_bigquery_unnest_columns():
assert_sql_result(
"""
SELECT
DATE(reporting_day) AS day,
CASE
WHEN p.product_code IN ('A', 'B', 'C')
THEN pr.other_field
ELSE 'Other'
END AS product,
pr.other_field AS other_field,
SUM(p.product_code_dau) AS daily_active_users
FROM `bq-proj`.dataset.table1
LEFT JOIN UNNEST(by_product) AS p
LEFT JOIN (
SELECT DISTINCT
product_code,
other_field
FROM `bq-proj`.dataset.table2
) AS pr
-- USING (product_code)
ON p.product_code = pr.product_code
""",
dialect="bigquery",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)": {
"reporting_day": "DATE",
"by_product": "ARRAY<STRUCT<product_code STRING, product_code_dau INT64>>",
},
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table2,PROD)": {
"product_code": "STRING",
"other_field": "STRING",
},
},
expected_file=RESOURCE_DIR / "test_bigquery_unnest_columns.json",
)
def test_bigquery_create_view_with_cte():
assert_sql_result(
"""
CREATE VIEW `my-proj-2`.dataset.my_view AS
WITH cte1 AS (
SELECT *
FROM dataset.table1
WHERE col1 = 'value1'
), cte2 AS (
SELECT col3, col4 as join_key
FROM dataset.table2
WHERE col3 = 'value2'
)
SELECT col5, cte1.*, col3
FROM dataset.table3
JOIN cte1 ON table3.col5 = cte1.col2
JOIN cte2 USING (join_key)
""",
dialect="bigquery",
default_db="my-proj-1",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)": {
"col1": "STRING",
"col2": "STRING",
},
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)": {
"col3": "STRING",
"col4": "STRING",
},
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)": {
"col5": "STRING",
"join_key": "STRING",
},
},
expected_file=RESOURCE_DIR / "test_bigquery_create_view_with_cte.json",
)
def test_bigquery_nested_subqueries():
assert_sql_result(
"""
SELECT *
FROM (
SELECT *
FROM (
SELECT *
FROM `bq-proj`.dataset.table1
)
)
""",
dialect="bigquery",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)": {
"col1": "STRING",
"col2": "STRING",
},
},
expected_file=RESOURCE_DIR / "test_bigquery_nested_subqueries.json",
)
def test_bigquery_sharded_table_normalization():
assert_sql_result(
"""
SELECT *
FROM `bq-proj.dataset.table_20230101`
""",
dialect="bigquery",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)": {
"col1": "STRING",
"col2": "STRING",
},
},
expected_file=RESOURCE_DIR / "test_bigquery_sharded_table_normalization.json",
)
def test_bigquery_from_sharded_table_wildcard():
assert_sql_result(
"""
SELECT *
FROM `bq-proj.dataset.table_2023*`
""",
dialect="bigquery",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)": {
"col1": "STRING",
"col2": "STRING",
},
},
expected_file=RESOURCE_DIR / "test_bigquery_from_sharded_table_wildcard.json",
)
def test_snowflake_default_normalization():
assert_sql_result(
"""
create table active_customer_ltv as (
with active_customers as (
select * from customer_last_purchase_date
where
last_purchase_date >= current_date - interval '90 days'
)
, purchases as (
select * from ecommerce.purchases
)
select
active_customers.user_fk
, active_customers.email
, active_customers.last_purchase_date
, sum(purchases.purchase_amount) as lifetime_purchase_amount
, count(distinct(purchases.pk)) as lifetime_purchase_count
, sum(purchases.purchase_amount) / count(distinct(purchases.pk)) as average_purchase_amount
from
active_customers
join
purchases
on active_customers.user_fk = purchases.user_fk
group by 1,2,3
)
""",
dialect="snowflake",
default_db="long_tail_companions",
default_schema="analytics",
schemas={
"urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)": {
"pk": "NUMBER(38,0)",
"USER_FK": "NUMBER(38,0)",
"status": "VARCHAR(16777216)",
"purchase_amount": "NUMBER(10,2)",
"tax_AMOUNT": "NUMBER(10,2)",
"TOTAL_AMOUNT": "NUMBER(10,2)",
"CREATED_AT": "TIMESTAMP_NTZ",
"UPDATED_AT": "TIMESTAMP_NTZ",
},
"urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)": {
"USER_FK": "NUMBER(38,0)",
"EMAIL": "VARCHAR(16777216)",
"LAST_PURCHASE_DATE": "DATE",
},
},
expected_file=RESOURCE_DIR / "test_snowflake_default_normalization.json",
)
# TODO: Add a test for setting platform_instance or env

View File

@ -13,6 +13,7 @@ from google.cloud.bigquery.table import Row, TableListItem
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.bigquery_v2.bigquery import BigqueryV2Source
from datahub.ingestion.source.bigquery_v2.bigquery_audit import (
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX,
BigqueryTableIdentifier,
BigQueryTableRef,
)
@ -652,7 +653,7 @@ def test_get_table_and_shard_default(
) -> None:
with patch(
"datahub.ingestion.source.bigquery_v2.bigquery_audit.BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX",
"((.+)[_$])?(\\d{8})$",
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX,
):
assert BigqueryTableIdentifier.get_table_and_shard(table_name) == (
expected_table_prefix,