diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 688a9f7a43..87cd2bbbd1 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -133,6 +133,12 @@ sqllineage_lib = { "sqlparse==0.4.3", } +sqlglot_lib = { + # Using an Acryl fork of sqlglot. + # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1 + "acryl-sqlglot==16.7.6.dev6", +} + aws_common = { # AWS Python SDK "boto3", @@ -269,6 +275,8 @@ plugins: Dict[str, Set[str]] = { "gql[requests]>=3.3.0", }, "great-expectations": sql_common | sqllineage_lib, + # Misc plugins. + "sql-parser": sqlglot_lib, # Source plugins # PyAthena is pinned with exact version because we use private method in PyAthena "athena": sql_common | {"PyAthena[SQLAlchemy]==2.4.1"}, @@ -276,7 +284,9 @@ plugins: Dict[str, Set[str]] = { "bigquery": sql_common | bigquery_common | { + # TODO: I doubt we need all three sql parsing libraries. *sqllineage_lib, + *sqlglot_lib, "sql_metadata", "sqlalchemy-bigquery>=1.4.1", "google-cloud-datacatalog-lineage==0.2.2", @@ -285,6 +295,7 @@ plugins: Dict[str, Set[str]] = { | bigquery_common | { *sqllineage_lib, + *sqlglot_lib, "sql_metadata", "sqlalchemy-bigquery>=1.4.1", }, # deprecated, but keeping the extra for backwards compatibility diff --git a/metadata-ingestion/src/datahub/cli/check_cli.py b/metadata-ingestion/src/datahub/cli/check_cli.py index 3293c4ac96..5d34967d49 100644 --- a/metadata-ingestion/src/datahub/cli/check_cli.py +++ b/metadata-ingestion/src/datahub/cli/check_cli.py @@ -1,14 +1,21 @@ +import logging import shutil import tempfile +from typing import Optional import click from datahub import __package_name__ from datahub.cli.json_file import check_mce_file +from datahub.emitter.mce_builder import DEFAULT_ENV +from datahub.ingestion.graph.client import get_default_graph from datahub.ingestion.run.pipeline import Pipeline from datahub.ingestion.sink.sink_registry import sink_registry from datahub.ingestion.source.source_registry import source_registry from datahub.ingestion.transformer.transform_registry import transform_registry +from datahub.telemetry import telemetry + +logger = logging.getLogger(__name__) @click.group() @@ -28,6 +35,7 @@ def check() -> None: @click.option( "--unpack-mces", default=False, is_flag=True, help="Converts MCEs into MCPs" ) +@telemetry.with_telemetry() def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None: """Check the schema of a metadata (MCE or MCP) JSON file.""" @@ -70,6 +78,7 @@ def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None: default=False, help="Include extra information for each plugin.", ) +@telemetry.with_telemetry() def plugins(verbose: bool) -> None: """List the enabled ingestion plugins.""" @@ -87,3 +96,68 @@ def plugins(verbose: bool) -> None: click.echo( f"If a plugin is disabled, try running: pip install '{__package_name__}[]'" ) + + +@check.command() +@click.option( + "--sql", + type=str, + required=True, + help="The SQL query to parse", +) +@click.option( + "--platform", + type=str, + required=True, + help="The SQL dialect e.g. 
bigquery or snowflake", +) +@click.option( + "--platform-instance", + type=str, + help="The specific platform_instance the SQL query was run in", +) +@click.option( + "--env", + type=str, + default=DEFAULT_ENV, + help=f"The environment the SQL query was run in, defaults to {DEFAULT_ENV}", +) +@click.option( + "--default-db", + type=str, + help="The default database to use for unqualified table names", +) +@click.option( + "--default-schema", + type=str, + help="The default schema to use for unqualified table names", +) +@telemetry.with_telemetry() +def sql_lineage( + sql: str, + platform: str, + default_db: Optional[str], + default_schema: Optional[str], + platform_instance: Optional[str], + env: str, +) -> None: + """Parse the lineage of a SQL query. + + This performs schema-aware parsing in order to generate column-level lineage. + If the relevant tables are not in DataHub, this will be less accurate. + """ + + graph = get_default_graph() + + lineage = graph.parse_sql_lineage( + sql, + platform=platform, + platform_instance=platform_instance, + env=env, + default_db=default_db, + default_schema=default_schema, + ) + + logger.debug("Sql parsing debug info: %s", lineage.debug_info) + + click.echo(lineage.json(indent=4)) diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index c3fb5dd4e3..ba07ea70c9 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -1,4 +1,5 @@ import enum +import functools import json import logging import textwrap @@ -16,7 +17,7 @@ from datahub.cli.cli_utils import get_url_and_token from datahub.configuration.common import ConfigModel, GraphError, OperationalError from datahub.configuration.validate_field_removal import pydantic_removed_field from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP -from datahub.emitter.mce_builder import Aspect, make_data_platform_urn +from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect, make_data_platform_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.emitter.rest_emitter import DatahubRestEmitter from datahub.emitter.serialization_helper import post_json_transform @@ -44,6 +45,7 @@ if TYPE_CHECKING: from datahub.ingestion.source.state.entity_removal_state import ( GenericCheckpointState, ) + from datahub.utilities.sqlglot_lineage import SchemaResolver, SqlParsingResult logger = logging.getLogger(__name__) @@ -955,6 +957,46 @@ class DataHubGraph(DatahubRestEmitter): related_aspects = response.get("relatedAspects", []) return reference_count, related_aspects + @functools.lru_cache() + def _make_schema_resolver( + self, platform: str, platform_instance: Optional[str], env: str + ) -> "SchemaResolver": + from datahub.utilities.sqlglot_lineage import SchemaResolver + + return SchemaResolver( + platform=platform, + platform_instance=platform_instance, + env=env, + graph=self, + ) + + def parse_sql_lineage( + self, + sql: str, + *, + platform: str, + platform_instance: Optional[str] = None, + env: str = DEFAULT_ENV, + default_db: Optional[str] = None, + default_schema: Optional[str] = None, + ) -> "SqlParsingResult": + from datahub.utilities.sqlglot_lineage import sqlglot_lineage + + # Cache the schema resolver to make bulk parsing faster. 
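+        # _make_schema_resolver is wrapped in functools.lru_cache above, so repeated
+        # calls on the same graph with the same (platform, platform_instance, env)
+        # reuse a single SchemaResolver and its schema cache.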
+ schema_resolver = self._make_schema_resolver( + platform=platform, + platform_instance=platform_instance, + env=env, + ) + + return sqlglot_lineage( + sql, + platform=platform, + schema_resolver=schema_resolver, + default_db=default_db, + default_schema=default_schema, + ) + def get_default_graph() -> DataHubGraph: (url, token) = get_url_and_token() diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py index 379a773e24..0f9b37c93f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py @@ -62,6 +62,8 @@ BigQueryAuditMetadata = Any logger: logging.Logger = logging.getLogger(__name__) +_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = "((.+)[_$])?(\\d{8})$" + @dataclass(frozen=True, order=True) class BigqueryTableIdentifier: @@ -70,7 +72,12 @@ class BigqueryTableIdentifier: table: str invalid_chars: ClassVar[Set[str]] = {"$", "@"} - _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{8})$" + + # Note: this regex may get overwritten by the sharded_table_pattern config. + # The class-level constant, however, will not be overwritten. + _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[ + str + ] = _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX _BIGQUERY_WILDCARD_REGEX: ClassVar[str] = "((_(\\d+)?)\\*$)|\\*$" _BQ_SHARDED_TABLE_SUFFIX: str = "_yyyymmdd" diff --git a/metadata-ingestion/src/datahub/testing/__init__.py b/metadata-ingestion/src/datahub/testing/__init__.py new file mode 100644 index 0000000000..88eef33282 --- /dev/null +++ b/metadata-ingestion/src/datahub/testing/__init__.py @@ -0,0 +1,6 @@ +"""Testing utilities for the datahub package. + +These modules are included in the `datahub.testing` namespace, but aren't +part of the public API. They're placed here, rather than in the `tests` +directory, so they can be used by other packages that depend on `datahub`. +""" diff --git a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py new file mode 100644 index 0000000000..b73af0478a --- /dev/null +++ b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py @@ -0,0 +1,90 @@ +import logging +import os +import pathlib +from typing import Any, Dict, Optional + +import deepdiff + +from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier +from datahub.utilities.sqlglot_lineage import ( + SchemaInfo, + SchemaResolver, + SqlParsingResult, + sqlglot_lineage, +) + +logger = logging.getLogger(__name__) + +# TODO: Hook this into the standard --update-golden-files mechanism. +UPDATE_FILES = os.environ.get("UPDATE_SQLPARSER_FILES", "false").lower() == "true" + + +def assert_sql_result_with_resolver( + sql: str, + *, + dialect: str, + expected_file: pathlib.Path, + schema_resolver: SchemaResolver, + **kwargs: Any, +) -> None: + # HACK: Our BigQuery source overwrites this value and doesn't undo it. + # As such, we need to handle that here. 
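+    # "_yyyymmdd" is the class-level default declared in bigquery_audit.py; resetting
+    # it keeps these parser tests independent of whether a BigQuery source ran first.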
+ BigqueryTableIdentifier._BQ_SHARDED_TABLE_SUFFIX = "_yyyymmdd" + + res = sqlglot_lineage( + sql, + platform=dialect, + schema_resolver=schema_resolver, + **kwargs, + ) + + if res.debug_info.column_error: + logger.warning( + f"SQL parser column error: {res.debug_info.column_error}", + exc_info=res.debug_info.column_error, + ) + + txt = res.json(indent=4) + if UPDATE_FILES: + expected_file.write_text(txt) + return + + if not expected_file.exists(): + expected_file.write_text(txt) + raise AssertionError( + f"Expected file {expected_file} does not exist. " + "Created it with the expected output. Please verify it." + ) + + expected = SqlParsingResult.parse_raw(expected_file.read_text()) + + full_diff = deepdiff.DeepDiff( + expected.dict(), + res.dict(), + exclude_regex_paths=[ + r"root.column_lineage\[\d+\].logic", + ], + ) + assert not full_diff, full_diff + + +def assert_sql_result( + sql: str, + *, + dialect: str, + expected_file: pathlib.Path, + schemas: Optional[Dict[str, SchemaInfo]] = None, + **kwargs: Any, +) -> None: + schema_resolver = SchemaResolver(platform=dialect) + if schemas: + for urn, schema in schemas.items(): + schema_resolver.add_raw_schema_info(urn, schema) + + assert_sql_result_with_resolver( + sql, + dialect=dialect, + expected_file=expected_file, + schema_resolver=schema_resolver, + **kwargs, + ) diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py index 58e693b315..74e97e5104 100644 --- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py +++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py @@ -55,18 +55,30 @@ class ConnectionWrapper: conn: sqlite3.Connection filename: pathlib.Path + _directory: Optional[tempfile.TemporaryDirectory] + _allow_table_name_reuse: bool def __init__(self, filename: Optional[pathlib.Path] = None): self._directory = None - # Warning: If filename is provided, the file will not be automatically cleaned up - if not filename: + + # In the normal case, we do not use "IF NOT EXISTS" in our create table statements + # because creating the same table twice indicates a client usage error. + # However, if you're trying to persist a file-backed dict across multiple runs, + # which happens when filename is passed explicitly, then we need to allow table name reuse. + allow_table_name_reuse = False + + # Warning: If filename is provided, the file will not be automatically cleaned up. + if filename: + allow_table_name_reuse = True + else: self._directory = tempfile.TemporaryDirectory() filename = pathlib.Path(self._directory.name) / _DEFAULT_FILE_NAME self.conn = sqlite3.connect(filename, isolation_level=None) self.conn.row_factory = sqlite3.Row self.filename = filename + self._allow_table_name_reuse = allow_table_name_reuse # These settings are optimized for performance. # See https://www.sqlite.org/pragma.html for more information. 
@@ -80,13 +92,13 @@ class ConnectionWrapper: def execute( self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = () ) -> sqlite3.Cursor: - logger.debug(f"Executing <{sql}> ({parameters})") + # logger.debug(f"Executing <{sql}> ({parameters})") return self.conn.execute(sql, parameters) def executemany( self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = () ) -> sqlite3.Cursor: - logger.debug(f"Executing many <{sql}> ({parameters})") + # logger.debug(f"Executing many <{sql}> ({parameters})") return self.conn.executemany(sql, parameters) def close(self) -> None: @@ -184,10 +196,10 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]): # a poor-man's LRU cache. self._active_object_cache = collections.OrderedDict() - # Create the table. We're not using "IF NOT EXISTS" because creating - # the same table twice indicates a client usage error. + # Create the table. + if_not_exists = "IF NOT EXISTS" if self._conn._allow_table_name_reuse else "" self._conn.execute( - f"""CREATE TABLE {self.tablename} ( + f"""CREATE TABLE {if_not_exists} {self.tablename} ( key TEXT PRIMARY KEY, value BLOB {''.join(f', {column_name} BLOB' for column_name in self.extra_columns.keys())} diff --git a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py new file mode 100644 index 0000000000..189c5bb269 --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py @@ -0,0 +1,754 @@ +import contextlib +import enum +import functools +import itertools +import logging +import pathlib +from collections import defaultdict +from typing import Dict, List, Optional, Set, Tuple, Union + +import pydantic +import pydantic.dataclasses +import sqlglot +import sqlglot.errors +import sqlglot.lineage +import sqlglot.optimizer.qualify +import sqlglot.optimizer.qualify_columns +from pydantic import BaseModel + +from datahub.emitter.mce_builder import ( + DEFAULT_ENV, + make_dataset_urn_with_platform_instance, +) +from datahub.ingestion.graph.client import DataHubGraph +from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier +from datahub.metadata.schema_classes import SchemaMetadataClass +from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict +from datahub.utilities.urns.dataset_urn import DatasetUrn + +logger = logging.getLogger(__name__) + +Urn = str + +# A lightweight table schema: column -> type mapping. 
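+# Keys are simple (v1) field paths and values are native data types,
+# e.g. {"col1": "NUMBER", "col2": "NUMBER"}.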
+SchemaInfo = Dict[str, str] + + +class QueryType(enum.Enum): + CREATE = "CREATE" + SELECT = "SELECT" + INSERT = "INSERT" + UPDATE = "UPDATE" + DELETE = "DELETE" + MERGE = "MERGE" + + UNKNOWN = "UNKNOWN" + + +def get_query_type_of_sql(expression: sqlglot.exp.Expression) -> QueryType: + mapping = { + sqlglot.exp.Create: QueryType.CREATE, + sqlglot.exp.Select: QueryType.SELECT, + sqlglot.exp.Insert: QueryType.INSERT, + sqlglot.exp.Update: QueryType.UPDATE, + sqlglot.exp.Delete: QueryType.DELETE, + sqlglot.exp.Merge: QueryType.MERGE, + } + + for cls, query_type in mapping.items(): + if isinstance(expression, cls): + return query_type + return QueryType.UNKNOWN + + +@functools.total_ordering +class _FrozenModel(BaseModel, frozen=True): + def __lt__(self, other: "_FrozenModel") -> bool: + for field in self.__fields__: + self_v = getattr(self, field) + other_v = getattr(other, field) + if self_v != other_v: + return self_v < other_v + + return False + + +class _TableName(_FrozenModel): + database: Optional[str] + db_schema: Optional[str] + table: str + + def as_sqlglot_table(self) -> sqlglot.exp.Table: + return sqlglot.exp.Table( + catalog=self.database, db=self.db_schema, this=self.table + ) + + def qualified( + self, + dialect: str, + default_db: Optional[str] = None, + default_schema: Optional[str] = None, + ) -> "_TableName": + database = self.database or default_db + db_schema = self.db_schema or default_schema + + return _TableName( + database=database, + db_schema=db_schema, + table=self.table, + ) + + @classmethod + def from_sqlglot_table( + cls, + table: sqlglot.exp.Table, + dialect: str, + default_db: Optional[str] = None, + default_schema: Optional[str] = None, + ) -> "_TableName": + return cls( + database=table.catalog or default_db, + db_schema=table.db or default_schema, + table=table.this.name, + ) + + +class _ColumnRef(_FrozenModel): + table: _TableName + column: str + + +class ColumnRef(BaseModel): + table: Urn + column: str + + +class _DownstreamColumnRef(BaseModel): + table: Optional[_TableName] + column: str + + +class DownstreamColumnRef(BaseModel): + table: Optional[Urn] + column: str + + +class _ColumnLineageInfo(BaseModel): + downstream: _DownstreamColumnRef + upstreams: List[_ColumnRef] + + logic: Optional[str] + + +class ColumnLineageInfo(BaseModel): + downstream: DownstreamColumnRef + upstreams: List[ColumnRef] + + # Logic for this column, as a SQL expression. 
+ logic: Optional[str] = pydantic.Field(default=None, exclude=True) + + +class SqlParsingDebugInfo(BaseModel, arbitrary_types_allowed=True): + confidence: float + + tables_discovered: int + table_schemas_resolved: int + + column_error: Optional[Exception] + + +class SqlParsingResult(BaseModel): + query_type: QueryType = QueryType.UNKNOWN + + in_tables: List[Urn] + out_tables: List[Urn] + + column_lineage: Optional[List[ColumnLineageInfo]] + + # TODO include formatted original sql logic + # TODO include list of referenced columns + + debug_info: SqlParsingDebugInfo = pydantic.Field( + default_factory=lambda: SqlParsingDebugInfo( + confidence=0, + tables_discovered=0, + table_schemas_resolved=0, + column_error=None, + ), + exclude=True, + ) + + +def _parse_statement(sql: str, dialect: str) -> sqlglot.Expression: + statement = sqlglot.parse_one( + sql, read=dialect, error_level=sqlglot.ErrorLevel.RAISE + ) + return statement + + +def _table_level_lineage( + statement: sqlglot.Expression, + dialect: str, +) -> Tuple[Set[_TableName], Set[_TableName]]: + def _raw_table_name(table: sqlglot.exp.Table) -> _TableName: + return _TableName.from_sqlglot_table(table, dialect=dialect) + + # Generate table-level lineage. + modified = { + _raw_table_name(expr.this) + for expr in statement.find_all( + sqlglot.exp.Create, + sqlglot.exp.Insert, + sqlglot.exp.Update, + sqlglot.exp.Delete, + sqlglot.exp.Merge, + ) + # In some cases like "MERGE ... then INSERT (col1, col2) VALUES (col1, col2)", + # the `this` on the INSERT part isn't a table. + if isinstance(expr.this, sqlglot.exp.Table) + } + + tables = ( + {_raw_table_name(table) for table in statement.find_all(sqlglot.exp.Table)} + # ignore references created in this query + - modified + # ignore CTEs created in this statement + - { + _TableName(database=None, schema=None, table=cte.alias_or_name) + for cte in statement.find_all(sqlglot.exp.CTE) + } + ) + # TODO: If a CTAS has "LIMIT 0", it's not really lineage, just copying the schema. + + return tables, modified + + +class SchemaResolver: + def __init__( + self, + *, + platform: str, + platform_instance: Optional[str] = None, + env: str = DEFAULT_ENV, + graph: Optional[DataHubGraph] = None, + _cache_filename: Optional[pathlib.Path] = None, + ): + # TODO handle platforms when prefixed with urn:li:dataPlatform: + self.platform = platform + self.platform_instance = platform_instance + self.env = env + + self.graph = graph + + # Init cache, potentially restoring from a previous run. + shared_conn = None + if _cache_filename: + shared_conn = ConnectionWrapper(filename=_cache_filename) + self._schema_cache: FileBackedDict[Optional[SchemaInfo]] = FileBackedDict( + shared_connection=shared_conn, + ) + + def get_urn_for_table(self, table: _TableName, lower: bool = False) -> str: + # TODO: Validate that this is the correct 2/3 layer hierarchy for the platform. + + table_name = ".".join( + filter(None, [table.database, table.db_schema, table.table]) + ) + if lower: + table_name = table_name.lower() + + if self.platform == "bigquery": + # Normalize shard numbers and other BigQuery weirdness. 
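+            # e.g. a date-sharded table like "dataset.table_20230101" is expected to
+            # resolve to "dataset.table_yyyymmdd" (cf. _BQ_SHARDED_TABLE_SUFFIX).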
+ # TODO check that this is the right way to do it + with contextlib.suppress(IndexError): + table_name = BigqueryTableIdentifier.from_string_name( + table_name + ).get_table_name() + + urn = make_dataset_urn_with_platform_instance( + platform=self.platform, + platform_instance=self.platform_instance, + env=self.env, + name=table_name, + ) + return urn + + def resolve_table(self, table: _TableName) -> Tuple[str, Optional[SchemaInfo]]: + urn = self.get_urn_for_table(table) + + schema_info = self._resolve_schema_info(urn) + if schema_info: + return urn, schema_info + + urn_lower = self.get_urn_for_table(table, lower=True) + if urn_lower != urn: + schema_info = self._resolve_schema_info(urn_lower) + if schema_info: + return urn_lower, schema_info + + return urn_lower, None + + def _resolve_schema_info(self, urn: str) -> Optional[SchemaInfo]: + if urn in self._schema_cache: + return self._schema_cache[urn] + + # TODO: For bigquery partitioned tables, add the pseudo-column _PARTITIONTIME + # or _PARTITIONDATE where appropriate. + + if self.graph: + schema_info = self._fetch_schema_info(self.graph, urn) + if schema_info: + self._save_to_cache(urn, schema_info) + return schema_info + + self._save_to_cache(urn, None) + return None + + def add_schema_metadata( + self, urn: str, schema_metadata: SchemaMetadataClass + ) -> None: + schema_info = self._convert_schema_aspect_to_info(schema_metadata) + self._save_to_cache(urn, schema_info) + + def add_raw_schema_info(self, urn: str, schema_info: SchemaInfo) -> None: + self._save_to_cache(urn, schema_info) + + def _save_to_cache(self, urn: str, schema_info: Optional[SchemaInfo]) -> None: + self._schema_cache[urn] = schema_info + + def _fetch_schema_info(self, graph: DataHubGraph, urn: str) -> Optional[SchemaInfo]: + aspect = graph.get_aspect(urn, SchemaMetadataClass) + if not aspect: + return None + + return self._convert_schema_aspect_to_info(aspect) + + @classmethod + def _convert_schema_aspect_to_info( + cls, schema_metadata: SchemaMetadataClass + ) -> SchemaInfo: + return { + DatasetUrn._get_simple_field_path_from_v2_field_path(col.fieldPath): ( + # The actual types are more of a "nice to have". + col.nativeDataType + or "str" + ) + for col in schema_metadata.fields + # TODO: We can't generate lineage to columns nested within structs yet. + if "." + not in DatasetUrn._get_simple_field_path_from_v2_field_path(col.fieldPath) + } + + # TODO add a method to load all from graphql + + +# TODO: Once PEP 604 is supported (Python 3.10), we can unify these into a +# single type. See https://peps.python.org/pep-0604/#isinstance-and-issubclass. +_SupportedColumnLineageTypes = Union[ + # Note that Select and Union inherit from Subqueryable. + sqlglot.exp.Subqueryable, + # For actual subqueries, the statement type might also be DerivedTable. + sqlglot.exp.DerivedTable, +] +_SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable) + + +class UnsupportedStatementTypeError(TypeError): + pass + + +class SqlUnderstandingError(Exception): + # Usually hit when we need schema info for a given statement but don't have it. + pass + + +# TODO: Break this up into smaller functions. 
+def _column_level_lineage( # noqa: C901 + statement: sqlglot.exp.Expression, + dialect: str, + input_tables: Dict[_TableName, SchemaInfo], + output_table: Optional[_TableName], + default_db: Optional[str], + default_schema: Optional[str], +) -> List[_ColumnLineageInfo]: + if not isinstance( + statement, + _SupportedColumnLineageTypesTuple, + ): + raise UnsupportedStatementTypeError( + f"Can only generate column-level lineage for select-like inner statements, not {type(statement)}" + ) + + use_case_insensitive_cols = dialect in { + # Column identifiers are case-insensitive in BigQuery, so we need to + # do a normalization step beforehand to make sure it's resolved correctly. + "bigquery", + # Our snowflake source lowercases column identifiers, so we are forced + # to do fuzzy (case-insensitive) resolution instead of exact resolution. + "snowflake", + } + + sqlglot_db_schema = sqlglot.MappingSchema( + dialect=dialect, + # We do our own normalization, so don't let sqlglot do it. + normalize=False, + ) + table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict( + dict + ) + for table, table_schema in input_tables.items(): + normalized_table_schema: SchemaInfo = {} + for col, col_type in table_schema.items(): + if use_case_insensitive_cols: + col_normalized = ( + # This is required to match Sqlglot's behavior. + col.upper() + if dialect in {"snowflake"} + else col.lower() + ) + else: + col_normalized = col + + table_schema_normalized_mapping[table][col_normalized] = col + normalized_table_schema[col_normalized] = col_type + + sqlglot_db_schema.add_table( + table.as_sqlglot_table(), + column_mapping=normalized_table_schema, + ) + + if use_case_insensitive_cols: + + def _sqlglot_force_column_normalizer( + node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None + ) -> sqlglot.exp.Expression: + if isinstance(node, sqlglot.exp.Column): + node.this.set("quoted", False) + + return node + + # logger.debug( + # "Prior to case normalization sql %s", + # statement.sql(pretty=True, dialect=dialect), + # ) + statement = statement.transform( + _sqlglot_force_column_normalizer, dialect, copy=False + ) + # logger.debug( + # "Sql after casing normalization %s", + # statement.sql(pretty=True, dialect=dialect), + # ) + + # Optimize the statement + qualify column references. + logger.debug( + "Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect) + ) + try: + # Second time running qualify, this time with: + # - the select instead of the full outer statement + # - schema info + # - column qualification enabled + + # logger.debug("Schema: %s", sqlglot_db_schema.mapping) + statement = sqlglot.optimizer.qualify.qualify( + statement, + dialect=dialect, + schema=sqlglot_db_schema, + validate_qualify_columns=False, + identify=True, + # sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table". + catalog=default_db, + db=default_schema, + ) + except (sqlglot.errors.OptimizeError, ValueError) as e: + raise SqlUnderstandingError( + f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}" + ) from e + logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect)) + + column_lineage = [] + + try: + assert isinstance(statement, _SupportedColumnLineageTypesTuple) + + # List output columns. 
+ output_columns = [ + (select_col.alias_or_name, select_col) for select_col in statement.selects + ] + logger.debug("output columns: %s", [col[0] for col in output_columns]) + output_col: str + for output_col, original_col_expression in output_columns: + # print(f"output column: {output_col}") + if output_col == "*": + # If schema information is available, the * will be expanded to the actual columns. + # Otherwise, we can't process it. + continue + + if dialect == "bigquery" and output_col.lower() in { + "_partitiontime", + "_partitiondate", + }: + # These are not real columns, just a way to filter by partition. + # TODO: We should add these columns to the schema info instead. + # Once that's done, we should actually generate lineage for these + # if they appear in the output. + continue + + lineage_node = sqlglot.lineage.lineage( + output_col, + statement, + dialect=dialect, + schema=sqlglot_db_schema, + ) + # pathlib.Path("sqlglot.html").write_text( + # str(lineage_node.to_html(dialect=dialect)) + # ) + + # Generate SELECT lineage. + # Using a set here to deduplicate upstreams. + direct_col_upstreams: Set[_ColumnRef] = set() + for node in lineage_node.walk(): + if node.downstream: + # We only want the leaf nodes. + pass + + elif isinstance(node.expression, sqlglot.exp.Table): + table_ref = _TableName.from_sqlglot_table( + node.expression, dialect=dialect + ) + + # Parse the column name out of the node name. + # Sqlglot calls .sql(), so we have to do the inverse. + normalized_col = sqlglot.parse_one(node.name).this.name + if node.subfield: + normalized_col = f"{normalized_col}.{node.subfield}" + col = table_schema_normalized_mapping[table_ref].get( + normalized_col, normalized_col + ) + + direct_col_upstreams.add(_ColumnRef(table=table_ref, column=col)) + else: + # This branch doesn't matter. For example, a count(*) column would go here, and + # we don't get any column-level lineage for that. + pass + + # column_logic = lineage_node.source + + if output_col.startswith("_col_"): + # This is the format sqlglot uses for unnamed columns e.g. 'count(id)' -> 'count(id) AS _col_0' + # This is a bit jank since we're relying on sqlglot internals, but it seems to be + # the best way to do it. + output_col = original_col_expression.this.sql(dialect=dialect) + if not direct_col_upstreams: + logger.debug(f' "{output_col}" has no upstreams') + column_lineage.append( + _ColumnLineageInfo( + downstream=_DownstreamColumnRef( + table=output_table, column=output_col + ), + upstreams=sorted(direct_col_upstreams), + # logic=column_logic.sql(pretty=True, dialect=dialect), + ) + ) + + # TODO: Also extract referenced columns (e.g. non-SELECT lineage) + except (sqlglot.errors.OptimizeError, ValueError) as e: + raise SqlUnderstandingError( + f"sqlglot failed to compute some lineage: {e}" + ) from e + + return column_lineage + + +def _extract_select_from_create( + statement: sqlglot.exp.Create, +) -> sqlglot.exp.Expression: + # TODO: Validate that this properly includes WITH clauses in all dialects. + inner = statement.expression + + if inner: + return inner + else: + return statement + + +def _try_extract_select( + statement: sqlglot.exp.Expression, +) -> sqlglot.exp.Expression: + # Try to extract the core select logic from a more complex statement. + # If it fails, just return the original statement. + + if isinstance(statement, sqlglot.exp.Merge): + # TODO Need to map column renames in the expressions part of the statement. + # Likely need to use the named_selects attr. 
+ statement = statement.args["using"] + if isinstance(statement, sqlglot.exp.Table): + # If we're querying a table directly, wrap it in a SELECT. + statement = sqlglot.exp.Select().select("*").from_(statement) + elif isinstance(statement, sqlglot.exp.Insert): + # TODO Need to map column renames in the expressions part of the statement. + statement = statement.expression + elif isinstance(statement, sqlglot.exp.Create): + # TODO May need to map column renames. + # Assumption: the output table is already captured in the modified tables list. + statement = _extract_select_from_create(statement) + + if isinstance(statement, sqlglot.exp.Subquery): + statement = statement.unnest() + + return statement + + +def _translate_internal_column_lineage( + table_name_urn_mapping: Dict[_TableName, str], + raw_column_lineage: _ColumnLineageInfo, +) -> ColumnLineageInfo: + downstream_urn = None + if raw_column_lineage.downstream.table: + downstream_urn = table_name_urn_mapping[raw_column_lineage.downstream.table] + return ColumnLineageInfo( + downstream=DownstreamColumnRef( + table=downstream_urn, + column=raw_column_lineage.downstream.column, + ), + upstreams=[ + ColumnRef( + table=table_name_urn_mapping[upstream.table], + column=upstream.column, + ) + for upstream in raw_column_lineage.upstreams + ], + logic=raw_column_lineage.logic, + ) + + +def sqlglot_lineage( + sql: str, + platform: str, + schema_resolver: SchemaResolver, + default_db: Optional[str] = None, + default_schema: Optional[str] = None, +) -> SqlParsingResult: + # TODO: convert datahub platform names to sqlglot dialect + dialect = platform + + if dialect == "snowflake": + # in snowflake, table identifiers must be uppercased to match sqlglot's behavior. + if default_db: + default_db = default_db.upper() + if default_schema: + default_schema = default_schema.upper() + + logger.debug("Parsing lineage from sql statement: %s", sql) + statement = _parse_statement(sql, dialect=dialect) + + original_statement = statement.copy() + # logger.debug( + # "Formatted sql statement: %s", + # original_statement.sql(pretty=True, dialect=dialect), + # ) + + # Make sure the tables are resolved with the default db / schema. + # This only works for Unionable statements. For other types of statements, + # we have to do it manually afterwards, but that's slightly lower accuracy + # because of CTEs. + statement = sqlglot.optimizer.qualify.qualify( + statement, + dialect=dialect, + # sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table". + catalog=default_db, + db=default_schema, + # At this stage we only want to qualify the table names. The columns will be dealt with later. + qualify_columns=False, + validate_qualify_columns=False, + # Only insert quotes where necessary. + identify=False, + ) + + # Generate table-level lineage. + tables, modified = _table_level_lineage(statement, dialect=dialect) + + # Prep for generating column-level lineage. + downstream_table: Optional[_TableName] = None + if len(modified) == 1: + downstream_table = next(iter(modified)) + + # Fetch schema info for the relevant tables. + table_name_urn_mapping: Dict[_TableName, str] = {} + table_name_schema_mapping: Dict[_TableName, SchemaInfo] = {} + for table, is_input in itertools.chain( + [(table, True) for table in tables], + [(table, False) for table in modified], + ): + # For select statements, qualification will be a no-op. For other statements, this + # is where the qualification actually happens. 
+ qualified_table = table.qualified( + dialect=dialect, default_db=default_db, default_schema=default_schema + ) + + urn, schema_info = schema_resolver.resolve_table(qualified_table) + + table_name_urn_mapping[qualified_table] = urn + if is_input and schema_info: + table_name_schema_mapping[qualified_table] = schema_info + + # Also include the original, non-qualified table name in the urn mapping. + table_name_urn_mapping[table] = urn + + debug_info = SqlParsingDebugInfo( + confidence=0.9 if len(tables) == len(table_name_schema_mapping) + # If we're missing any schema info, our confidence will be in the 0.2-0.5 range depending + # on how many tables we were able to resolve. + else 0.2 + 0.3 * len(table_name_schema_mapping) / len(tables), + tables_discovered=len(tables), + table_schemas_resolved=len(table_name_schema_mapping), + ) + logger.debug( + f"Resolved {len(table_name_schema_mapping)} of {len(tables)} table schemas" + ) + + # Simplify the input statement for column-level lineage generation. + select_statement = _try_extract_select(statement) + + # Generate column-level lineage. + column_lineage: Optional[List[_ColumnLineageInfo]] = None + try: + column_lineage = _column_level_lineage( + select_statement, + dialect=dialect, + input_tables=table_name_schema_mapping, + output_table=downstream_table, + default_db=default_db, + default_schema=default_schema, + ) + except UnsupportedStatementTypeError as e: + # Inject details about the outer statement type too. + e.args = (f"{e.args[0]} (outer statement type: {type(statement)})",) + debug_info.column_error = e + logger.debug(debug_info.column_error) + except SqlUnderstandingError as e: + logger.debug(f"Failed to generate column-level lineage: {e}", exc_info=True) + debug_info.column_error = e + + # TODO: Can we generate a common JOIN tables / keys section? + # TODO: Can we generate a common WHERE clauses section? + + # Convert TableName to urns. 
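+    # Both the qualified and the original (unqualified) _TableName were added to
+    # table_name_urn_mapping above, so every entry in `tables` and `modified`
+    # resolves to a urn here.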
+ in_urns = sorted(set(table_name_urn_mapping[table] for table in tables)) + out_urns = sorted(set(table_name_urn_mapping[table] for table in modified)) + column_lineage_urns = None + if column_lineage: + column_lineage_urns = [ + _translate_internal_column_lineage( + table_name_urn_mapping, internal_col_lineage + ) + for internal_col_lineage in column_lineage + ] + + return SqlParsingResult( + query_type=get_query_type_of_sql(original_statement), + in_tables=in_urns, + out_tables=out_urns, + column_lineage=column_lineage_urns, + debug_info=debug_info, + ) diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_create_view_with_cte.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_create_view_with_cte.json new file mode 100644 index 0000000000..e50d944ce7 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_create_view_with_cte.json @@ -0,0 +1,61 @@ +{ + "query_type": "CREATE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)", + "column": "col5" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)", + "column": "col5" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)", + "column": "col1" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)", + "column": "col1" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)", + "column": "col2" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)", + "column": "col2" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)", + "column": "col3" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)", + "column": "col3" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_from_sharded_table_wildcard.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_from_sharded_table_wildcard.json new file mode 100644 index 0000000000..78591286fe --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_from_sharded_table_wildcard.json @@ -0,0 +1,33 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "col1" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)", + "column": "col1" + } + ] + }, + { + "downstream": { + "table": null, + "column": "col2" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)", + "column": "col2" + } + ] + } + ] +} \ No newline at end of file diff --git 
a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_nested_subqueries.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_nested_subqueries.json new file mode 100644 index 0000000000..0e93d31fbb --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_nested_subqueries.json @@ -0,0 +1,33 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "col1" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)", + "column": "col1" + } + ] + }, + { + "downstream": { + "table": null, + "column": "col2" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)", + "column": "col2" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_sharded_table_normalization.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_sharded_table_normalization.json new file mode 100644 index 0000000000..78591286fe --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_sharded_table_normalization.json @@ -0,0 +1,33 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "col1" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)", + "column": "col1" + } + ] + }, + { + "downstream": { + "table": null, + "column": "col2" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)", + "column": "col2" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_unnest_columns.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_unnest_columns.json new file mode 100644 index 0000000000..69eb6f4ea6 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_bigquery_unnest_columns.json @@ -0,0 +1,90 @@ +{ + "query_type": "SELECT", + "in_tables": [ + { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table1" + }, + { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table2" + } + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "day" + }, + "upstreams": [ + { + "table": { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table1" + }, + "column": "reporting_day" + } + ] + }, + { + "downstream": { + "table": null, + "column": "product" + }, + "upstreams": [ + { + "table": { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table1" + }, + "column": "by_product.product_code" + }, + { + "table": { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table2" + }, + "column": "other_field" + } + ] + }, + { + "downstream": { + "table": null, + "column": "other_field" + }, + "upstreams": [ + { + "table": { + "database": "bq-proj", + "db_schema": "dataset", + "table": "table2" + }, + "column": "other_field" + } + ] + }, + { + "downstream": { + "table": null, + "column": "daily_active_users" + }, + "upstreams": [ + { + "table": { + "database": "bq-proj", + "db_schema": "dataset", + 
"table": "table1" + }, + "column": "by_product.product_code_dau" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_create_view_as_select.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_create_view_as_select.json new file mode 100644 index 0000000000..22bb78dc86 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_create_view_as_select.json @@ -0,0 +1,42 @@ +{ + "query_type": "CREATE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)", + "column": "Department" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)", + "column": "deptno" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)", + "column": "Employees" + }, + "upstreams": [] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,vsal,PROD)", + "column": "Salary" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,scott.emp,PROD)", + "column": "sal" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_expand_select_star_basic.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_expand_select_star_basic.json new file mode 100644 index 0000000000..e456e4450c --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_expand_select_star_basic.json @@ -0,0 +1,129 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "TOTAL_AGG" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "totalprice" + } + ] + }, + { + "downstream": { + "table": null, + "column": "ORDERKEY" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "orderkey" + } + ] + }, + { + "downstream": { + "table": null, + "column": "CUSTKEY" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "custkey" + } + ] + }, + { + "downstream": { + "table": null, + "column": "ORDERSTATUS" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "orderstatus" + } + ] + }, + { + "downstream": { + "table": null, + "column": "TOTALPRICE" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "totalprice" + } + ] + }, + { + "downstream": { + "table": null, + "column": "ORDERDATE" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "orderdate" + } + ] + }, + { + "downstream": { + "table": null, + "column": "ORDERPRIORITY" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "orderpriority" + } + ] + }, + { + "downstream": { + "table": null, + 
"column": "CLERK" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "clerk" + } + ] + }, + { + "downstream": { + "table": null, + "column": "SHIPPRIORITY" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "shippriority" + } + ] + }, + { + "downstream": { + "table": null, + "column": "COMMENT" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "comment" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_as_select.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_as_select.json new file mode 100644 index 0000000000..d7264fd2db --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_as_select.json @@ -0,0 +1,76 @@ +{ + "query_type": "INSERT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:hive,catalog_returns,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,catalog_sales,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,customer_demographics,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,date_dim,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,household_demographics,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,inventory,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,item,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:hive,warehouse,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "i_item_desc" + }, + "upstreams": [] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "w_warehouse_name" + }, + "upstreams": [] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "d_week_seq" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,date_dim,PROD)", + "column": "d_week_seq" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "no_promo" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)", + "column": "p_promo_sk" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "promo" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,promotion,PROD)", + "column": "p_promo_sk" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:hive,query72,PROD)", + "column": "total_cnt" + }, + "upstreams": [] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_merge_from_union.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_merge_from_union.json new file mode 100644 index 0000000000..ec8599353f --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_merge_from_union.json @@ -0,0 +1,11 @@ +{ + "query_type": "MERGE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.prep_from_ios,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.prep_from_web,PROD)" + ], + 
"out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-pipelines-stg.referrer.base_union,PROD)" + ], + "column_lineage": null +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_count.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_count.json new file mode 100644 index 0000000000..9f6eeae46c --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_count.json @@ -0,0 +1,21 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:mysql,something_prd.fact_complaint_snapshot,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "COUNT(`fact_complaint_snapshot`.`etl_data_dt_id`)" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:mysql,something_prd.fact_complaint_snapshot,PROD)", + "column": "etl_data_dt_id" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_struct_subfields.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_struct_subfields.json new file mode 100644 index 0000000000..109de96180 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_struct_subfields.json @@ -0,0 +1,49 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "post_id" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)", + "column": "post_id" + } + ] + }, + { + "downstream": { + "table": null, + "column": "id" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)", + "column": "widget.asset.id" + } + ] + }, + { + "downstream": { + "table": null, + "column": "min_metric" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)", + "column": "widget.metric.metricA" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)", + "column": "widget.metric.metric_b" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_union.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_union.json new file mode 100644 index 0000000000..8e1fd453ce --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_from_union.json @@ -0,0 +1,33 @@ +{ + "query_type": "UNKNOWN", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "LABEL" + }, + "upstreams": [] + }, + { + "downstream": { + "table": null, + "column": "TOTAL_AGG" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)", + "column": "totalprice" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)", + "column": "totalprice" + } + ] + } + ] +} \ No newline at end of 
file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_max.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_max.json new file mode 100644 index 0000000000..326c07d332 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_max.json @@ -0,0 +1,25 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "max_col" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)", + "column": "col1" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)", + "column": "col2" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_ctes.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_ctes.json new file mode 100644 index 0000000000..4647b27934 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_ctes.json @@ -0,0 +1,34 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:oracle,table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:oracle,table2,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "col1" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,table1,PROD)", + "column": "col1" + } + ] + }, + { + "downstream": { + "table": null, + "column": "col3" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:oracle,table2,PROD)", + "column": "col3" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_full_col_name.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_full_col_name.json new file mode 100644 index 0000000000..c12ad23b2f --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_select_with_full_col_name.json @@ -0,0 +1,33 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "post_id" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)", + "column": "post_id" + } + ] + }, + { + "downstream": { + "table": null, + "column": "id" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)", + "column": "widget.asset.id" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_column_normalization.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_column_normalization.json new file mode 100644 index 0000000000..694bec3800 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_column_normalization.json @@ -0,0 +1,57 @@ +{ + "query_type": "SELECT", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)" + ], + "out_tables": [], + "column_lineage": [ + { + "downstream": { + "table": null, + "column": "TOTAL_AGG" + }, + "upstreams": [ + { + "table": 
"urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "TotalPrice" + } + ] + }, + { + "downstream": { + "table": null, + "column": "TOTAL_AVG" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "TotalPrice" + } + ] + }, + { + "downstream": { + "table": null, + "column": "TOTAL_MIN" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "TotalPrice" + } + ] + }, + { + "downstream": { + "table": null, + "column": "TOTAL_MAX" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "TotalPrice" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_default_normalization.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_default_normalization.json new file mode 100644 index 0000000000..1577458541 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_default_normalization.json @@ -0,0 +1,88 @@ +{ + "query_type": "CREATE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + "column": "USER_FK" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)", + "column": "USER_FK" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + "column": "EMAIL" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)", + "column": "EMAIL" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + "column": "LAST_PURCHASE_DATE" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)", + "column": "LAST_PURCHASE_DATE" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + "column": "LIFETIME_PURCHASE_AMOUNT" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)", + "column": "purchase_amount" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + "column": "LIFETIME_PURCHASE_COUNT" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)", + "column": "pk" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.active_customer_ltv,PROD)", + 
"column": "AVERAGE_PURCHASE_AMOUNT" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)", + "column": "pk" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)", + "column": "purchase_amount" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py new file mode 100644 index 0000000000..b7ef62ba82 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py @@ -0,0 +1,479 @@ +import pathlib + +import pytest + +from datahub.testing.check_sql_parser_result import assert_sql_result + +RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens" + + +def test_select_max(): + # The COL2 should get normalized to col2. + assert_sql_result( + """ +SELECT max(col1, COL2) as max_col +FROM mytable +""", + dialect="mysql", + expected_file=RESOURCE_DIR / "test_select_max.json", + ) + + +def test_select_max_with_schema(): + # Note that `this_will_not_resolve` will be dropped from the result because it's not in the schema. + assert_sql_result( + """ +SELECT max(`col1`, COL2, `this_will_not_resolve`) as max_col +FROM mytable +""", + dialect="mysql", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:mysql,mytable,PROD)": { + "col1": "NUMBER", + "col2": "NUMBER", + }, + }, + # Shared with the test above. + expected_file=RESOURCE_DIR / "test_select_max.json", + ) + + +def test_select_count(): + assert_sql_result( + """ +SELECT + COUNT(etl_data_dt_id) +FROM something_prd.fact_complaint_snapshot +WHERE + etl_data_dt_id = 20230317 +""", + dialect="mysql", + expected_file=RESOURCE_DIR / "test_select_count.json", + ) + + +def test_select_with_ctes(): + assert_sql_result( + """ +WITH cte1 AS ( + SELECT col1, col2 + FROM table1 + WHERE col1 = 'value1' +), cte2 AS ( + SELECT col3, col4 + FROM table2 + WHERE col2 = 'value2' +) +SELECT cte1.col1, cte2.col3 +FROM cte1 +JOIN cte2 ON cte1.col2 = cte2.col4 +""", + dialect="oracle", + expected_file=RESOURCE_DIR / "test_select_with_ctes.json", + ) + + +def test_create_view_as_select(): + assert_sql_result( + """ +CREATE VIEW vsal +AS + SELECT a.deptno "Department", + a.num_emp / b.total_count "Employees", + a.sal_sum / b.total_sal "Salary" + FROM (SELECT deptno, + Count() num_emp, + SUM(sal) sal_sum + FROM scott.emp + WHERE city = 'NYC' + GROUP BY deptno) a, + (SELECT Count() total_count, + SUM(sal) total_sal + FROM scott.emp + WHERE city = 'NYC') b +; +""", + dialect="oracle", + expected_file=RESOURCE_DIR / "test_create_view_as_select.json", + ) + + +def test_insert_as_select(): + # Note: this also tests lineage with case statements. 
+ + assert_sql_result( + """ +insert into query72 +select i_item_desc + , w_warehouse_name + , d1.d_week_seq + , sum(case when promotion.p_promo_sk is null then 1 else 0 end) no_promo + , sum(case when promotion.p_promo_sk is not null then 1 else 0 end) promo + , count(*) total_cnt +from catalog_sales + join inventory on (cs_item_sk = inv_item_sk) + join warehouse on (w_warehouse_sk = inv_warehouse_sk) + join item on (i_item_sk = cs_item_sk) + join customer_demographics on (cs_bill_cdemo_sk = cd_demo_sk) + join household_demographics on (cs_bill_hdemo_sk = hd_demo_sk) + join date_dim d1 on (cs_sold_date_sk = d1.d_date_sk) + join date_dim d2 on (inv_date_sk = d2.d_date_sk) + join date_dim d3 on (cs_ship_date_sk = d3.d_date_sk) + left outer join promotion on (cs_promo_sk = p_promo_sk) + left outer join catalog_returns on (cr_item_sk = cs_item_sk and cr_order_number = cs_order_number) +where d1.d_week_seq = d2.d_week_seq + and inv_quantity_on_hand < cs_quantity + and hd_buy_potential = '>10000' + and cd_marital_status = 'D' +group by i_item_desc, w_warehouse_name, d1.d_week_seq +order by total_cnt desc, i_item_desc, w_warehouse_name, d_week_seq +limit 100; +""", + dialect="hive", + expected_file=RESOURCE_DIR / "test_insert_as_select.json", + ) + + +def test_select_with_full_col_name(): + # In this case, `widget` is a struct column. + # This also tests the `default_db` functionality. + assert_sql_result( + """ +SELECT distinct post_id , widget.asset.id +FROM data_reporting.abcde_transformed +WHERE post_id LIKE '%268662%' +""", + dialect="bigquery", + default_db="my-bq-ProjectName", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-ProjectName.data_reporting.abcde_transformed,PROD)": { + "post_id": "NUMBER", + "widget": "struct", + }, + }, + expected_file=RESOURCE_DIR / "test_select_with_full_col_name.json", + ) + + +def test_select_from_struct_subfields(): + # In this case, `widget` is a column name. + assert_sql_result( + """ +SELECT distinct post_id , + widget.asset.id, + min(widget.metric.metricA, widget.metric.metric_b) as min_metric +FROM data_reporting.abcde_transformed +WHERE post_id LIKE '%12345%' +""", + dialect="bigquery", + default_db="my-bq-proj", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-bq-proj.data_reporting.abcde_transformed,PROD)": { + "post_id": "NUMBER", + "widget": "struct", + "widget.asset.id": "int", + "widget.metric.metricA": "int", + "widget.metric.metric_b": "int", + }, + }, + expected_file=RESOURCE_DIR / "test_select_from_struct_subfields.json", + ) + + +def test_select_from_union(): + assert_sql_result( + """ +SELECT + 'orders_10' as label, + SUM(totalprice) as total_agg +FROM snowflake_sample_data.tpch_sf10.orders +UNION ALL +SELECT + 'orders_100' as label, + SUM(totalprice) as total_agg, +FROM snowflake_sample_data.tpch_sf100.orders +""", + dialect="snowflake", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf10.orders,PROD)": { + "orderkey": "NUMBER", + "totalprice": "FLOAT", + }, + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf100.orders,PROD)": { + "orderkey": "NUMBER", + "totalprice": "FLOAT", + }, + }, + expected_file=RESOURCE_DIR / "test_select_from_union.json", + ) + + +def test_merge_from_union(): + # TODO: We don't support merge statements yet, but the union should still get handled. 
+ + assert_sql_result( + """ + merge into `demo-pipelines-stg`.`referrer`.`base_union` as DBT_INTERNAL_DEST + using ( +SELECT * FROM `demo-pipelines-stg`.`referrer`.`prep_from_ios` WHERE partition_time = "2018-03-03" +UNION ALL +SELECT * FROM `demo-pipelines-stg`.`referrer`.`prep_from_web` WHERE partition_time = "2018-03-03" + ) as DBT_INTERNAL_SOURCE + on FALSE + + when not matched by source + and timestamp_trunc(DBT_INTERNAL_DEST.partition_time, day) in ( + timestamp('2018-03-03') + ) + then delete + + when not matched then insert + (`platform`, `pageview_id`, `query`, `referrer`, `partition_time`) + values + (`platform`, `pageview_id`, `query`, `referrer`, `partition_time`) +""", + dialect="bigquery", + expected_file=RESOURCE_DIR / "test_merge_from_union.json", + ) + + +def test_expand_select_star_basic(): + assert_sql_result( + """ +SELECT + SUM(totalprice) as total_agg, + * +FROM snowflake_sample_data.tpch_sf1.orders +WHERE orderdate = '1992-01-01' +""", + dialect="snowflake", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": { + "orderkey": "NUMBER", + "custkey": "NUMBER", + "orderstatus": "TEXT", + "totalprice": "FLOAT", + "orderdate": "DATE", + "orderpriority": "TEXT", + "clerk": "TEXT", + "shippriority": "NUMBER", + "comment": "TEXT", + }, + }, + expected_file=RESOURCE_DIR / "test_expand_select_star_basic.json", + ) + + +def test_snowflake_column_normalization(): + # Technically speaking this is incorrect, since the quoted column names differ in case and quoted identifiers are case-sensitive in Snowflake. + + assert_sql_result( + """ +SELECT + SUM(o."totalprice") as total_agg, + AVG("TotalPrice") as total_avg, + MIN("TOTALPRICE") as total_min, + MAX(TotalPrice) as total_max +FROM snowflake_sample_data.tpch_sf1.orders o +""", + dialect="snowflake", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": { + "orderkey": "NUMBER", + "TotalPrice": "FLOAT", + }, + }, + expected_file=RESOURCE_DIR / "test_snowflake_column_normalization.json", + ) + + +@pytest.mark.skip(reason="We don't handle the unnest lineage correctly") +def test_bigquery_unnest_columns(): + assert_sql_result( + """ +SELECT + DATE(reporting_day) AS day, + CASE + WHEN p.product_code IN ('A', 'B', 'C') + THEN pr.other_field + ELSE 'Other' + END AS product, + pr.other_field AS other_field, + SUM(p.product_code_dau) AS daily_active_users +FROM `bq-proj`.dataset.table1 +LEFT JOIN UNNEST(by_product) AS p +LEFT JOIN ( + SELECT DISTINCT + product_code, + other_field + FROM `bq-proj`.dataset.table2 +) AS pr +-- USING (product_code) +ON p.product_code = pr.product_code +""", + dialect="bigquery", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)": { + "reporting_day": "DATE", + "by_product": "ARRAY<STRUCT<product_code STRING, product_code_dau INT64>>", + }, + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table2,PROD)": { + "product_code": "STRING", + "other_field": "STRING", + }, + }, + expected_file=RESOURCE_DIR / "test_bigquery_unnest_columns.json", + ) + + +def test_bigquery_create_view_with_cte(): + assert_sql_result( + """ +CREATE VIEW `my-proj-2`.dataset.my_view AS +WITH cte1 AS ( + SELECT * + FROM dataset.table1 + WHERE col1 = 'value1' +), cte2 AS ( + SELECT col3, col4 as join_key + FROM dataset.table2 + WHERE col3 = 'value2' +) +SELECT col5, cte1.*, col3 +FROM dataset.table3 +JOIN cte1 ON table3.col5 = cte1.col2 +JOIN cte2 USING (join_key) +""", + dialect="bigquery", + default_db="my-proj-1", + schemas={ + 
"urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table1,PROD)": { + "col1": "STRING", + "col2": "STRING", + }, + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table2,PROD)": { + "col3": "STRING", + "col4": "STRING", + }, + "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-1.dataset.table3,PROD)": { + "col5": "STRING", + "join_key": "STRING", + }, + }, + expected_file=RESOURCE_DIR / "test_bigquery_create_view_with_cte.json", + ) + + +def test_bigquery_nested_subqueries(): + assert_sql_result( + """ +SELECT * +FROM ( + SELECT * + FROM ( + SELECT * + FROM `bq-proj`.dataset.table1 + ) +) +""", + dialect="bigquery", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table1,PROD)": { + "col1": "STRING", + "col2": "STRING", + }, + }, + expected_file=RESOURCE_DIR / "test_bigquery_nested_subqueries.json", + ) + + +def test_bigquery_sharded_table_normalization(): + assert_sql_result( + """ +SELECT * +FROM `bq-proj.dataset.table_20230101` +""", + dialect="bigquery", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)": { + "col1": "STRING", + "col2": "STRING", + }, + }, + expected_file=RESOURCE_DIR / "test_bigquery_sharded_table_normalization.json", + ) + + +def test_bigquery_from_sharded_table_wildcard(): + assert_sql_result( + """ +SELECT * +FROM `bq-proj.dataset.table_2023*` +""", + dialect="bigquery", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:bigquery,bq-proj.dataset.table_yyyymmdd,PROD)": { + "col1": "STRING", + "col2": "STRING", + }, + }, + expected_file=RESOURCE_DIR / "test_bigquery_from_sharded_table_wildcard.json", + ) + + +def test_snowflake_default_normalization(): + assert_sql_result( + """ +create table active_customer_ltv as ( + +with active_customers as ( + select * from customer_last_purchase_date + where + last_purchase_date >= current_date - interval '90 days' +) + +, purchases as ( + select * from ecommerce.purchases +) + +select + active_customers.user_fk + , active_customers.email + , active_customers.last_purchase_date + , sum(purchases.purchase_amount) as lifetime_purchase_amount + , count(distinct(purchases.pk)) as lifetime_purchase_count + , sum(purchases.purchase_amount) / count(distinct(purchases.pk)) as average_purchase_amount +from + active_customers +join + purchases + on active_customers.user_fk = purchases.user_fk +group by 1,2,3 + +) +""", + dialect="snowflake", + default_db="long_tail_companions", + default_schema="analytics", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.ecommerce.purchases,PROD)": { + "pk": "NUMBER(38,0)", + "USER_FK": "NUMBER(38,0)", + "status": "VARCHAR(16777216)", + "purchase_amount": "NUMBER(10,2)", + "tax_AMOUNT": "NUMBER(10,2)", + "TOTAL_AMOUNT": "NUMBER(10,2)", + "CREATED_AT": "TIMESTAMP_NTZ", + "UPDATED_AT": "TIMESTAMP_NTZ", + }, + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.analytics.customer_last_purchase_date,PROD)": { + "USER_FK": "NUMBER(38,0)", + "EMAIL": "VARCHAR(16777216)", + "LAST_PURCHASE_DATE": "DATE", + }, + }, + expected_file=RESOURCE_DIR / "test_snowflake_default_normalization.json", + ) + + +# TODO: Add a test for setting platform_instance or env diff --git a/metadata-ingestion/tests/unit/test_bigquery_source.py b/metadata-ingestion/tests/unit/test_bigquery_source.py index 49dc66b232..3efca8d088 100644 --- a/metadata-ingestion/tests/unit/test_bigquery_source.py +++ b/metadata-ingestion/tests/unit/test_bigquery_source.py @@ -13,6 +13,7 @@ from 
google.cloud.bigquery.table import Row, TableListItem from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.source.bigquery_v2.bigquery import BigqueryV2Source from datahub.ingestion.source.bigquery_v2.bigquery_audit import ( + _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX, BigqueryTableIdentifier, BigQueryTableRef, ) @@ -652,7 +653,7 @@ def test_get_table_and_shard_default( ) -> None: with patch( "datahub.ingestion.source.bigquery_v2.bigquery_audit.BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX", - "((.+)[_$])?(\\d{8})$", + _BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX, ): assert BigqueryTableIdentifier.get_table_and_shard(table_name) == ( expected_table_prefix,