Mask SQL Queries in Usage & Lineage Workflow (#18565)

This commit is contained in:
Mayur Singal 2024-11-11 11:44:47 +05:30 committed by GitHub
parent 2437d0124e
commit efed932d97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 267 additions and 29 deletions

View File

@@ -285,7 +285,7 @@ class MetadataUsageBulkSink(BulkSink):
# TODO: Clean up how we are passing dates from query parsing to here to use timestamps instead of strings # TODO: Clean up how we are passing dates from query parsing to here to use timestamps instead of strings
start_date = datetime.fromtimestamp(int(table_usage.date) / 1000) start_date = datetime.fromtimestamp(int(table_usage.date) / 1000)
table_joins: TableJoins = TableJoins( table_joins: TableJoins = TableJoins(
columnJoins=[], directTableJoins=[], startDate=start_date columnJoins=[], directTableJoins=[], startDate=start_date.date()
) )
column_joins_dict = {} column_joins_dict = {}
for column_join in table_usage.joins: for column_join in table_usage.joins:

View File

@@ -0,0 +1,131 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Query masking utilities
"""
import traceback
import sqlparse
from sqlfluff.core import Linter
from sqlparse.sql import Comparison
from sqlparse.tokens import Literal, Number, String
from metadata.ingestion.lineage.models import Dialect
MASK_TOKEN = "?"
def get_logger():
    """Return the shared ingestion-utils logger.

    The import is deferred to call time to avoid a circular import
    when this module is loaded.
    """
    from metadata.utils.logger import (  # pylint: disable=import-outside-toplevel
        utils_logger,
    )

    return utils_logger()
def mask_literals_with_sqlparse(query: str) -> str:
    """
    Mask literals in a query using sqlparse.

    Every string/number literal token found in the parsed statement is
    replaced with ``MASK_TOKEN``. On any parsing failure the original
    query is returned unchanged (best-effort masking: never break the
    pipeline because of an unparseable query).

    :param query: raw SQL text to mask
    :return: the masked query, or the original query on failure
    """
    logger = get_logger()
    try:
        parsed = sqlparse.parse(query)  # Parse the query
        if not parsed:
            return query
        statement = parsed[0]

        def mask_token(token):
            # Mask all literals: strings, numbers, or other literal values.
            # NOTE: the duplicate Literal.String.Single entry from the first
            # draft was removed; each ttype is listed once.
            if token.ttype in (
                String,
                Number,
                Literal.String.Single,
                Literal.String.Symbol,
                Literal.Number.Integer,
                Literal.Number.Float,
            ):
                token.value = MASK_TOKEN
            elif token.is_group:
                # Recursively process grouped tokens (parenthesis, functions, ...)
                for child in token.tokens:
                    mask_token(child)

        # Process all top-level tokens of the statement
        for token in statement.tokens:
            if isinstance(token, Comparison):
                # In comparisons, mask both sides if they are literals
                for child in token.tokens:
                    mask_token(child)
            else:
                mask_token(token)

        # Serializing the (mutated) parse tree yields the masked query
        return str(statement)
    except Exception as exc:
        logger.debug(f"Failed to mask query with sqlparse: {exc}")
        logger.debug(traceback.format_exc())
        return query
def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -> str:
    """
    Mask literals in a query using SQLFluff.

    Walks the SQLFluff parse tree and replaces every segment typed as a
    literal with MASK_TOKEN, reassembling the rest of the query from the
    segments' raw text. On any failure the original query is returned
    unchanged, so masking is strictly best-effort.

    :param query: raw SQL text to mask
    :param dialect: SQLFluff dialect name used for parsing (defaults to ANSI)
    :return: the masked query, or the original query on failure
    """
    logger = get_logger()
    try:
        # Initialize SQLFluff linter
        linter = Linter(dialect=dialect)
        # Parse the query
        parsed = linter.parse_string(query)

        def replace_literals(segment):
            """Recursively replace literals with placeholders."""
            # Literal leaf segments are collapsed to the mask token
            if segment.is_type("literal", "quoted_literal", "numeric_literal"):
                return MASK_TOKEN
            if segment.segments:
                # Recursively process sub-segments
                return "".join(
                    replace_literals(sub_seg) for sub_seg in segment.segments
                )
            # Non-literal leaf: keep its raw source text verbatim
            return segment.raw

        # Reconstruct the query with masked literals.
        # NOTE(review): if parsing produced no tree, parsed.tree is None and
        # the attribute access raises — caught below, returning the raw query.
        masked_query = "".join(
            replace_literals(segment) for segment in parsed.tree.segments
        )
        return masked_query
    except Exception as exc:
        logger.debug(f"Failed to mask query with sqlfluff: {exc}")
        logger.debug(traceback.format_exc())
        return query
def mask_query(query: str, dialect: str = Dialect.ANSI.value) -> str:
    """
    Mask all literal values in a SQL query.

    Runs both the SQLFluff- and sqlparse-based maskers and returns the
    result that masked the most literals (each engine handles different
    dialect quirks, so the one with more MASK_TOKEN occurrences is the
    more thorough masking). On any unexpected failure the original query
    is returned unchanged.

    :param query: raw SQL text to mask
    :param dialect: SQLFluff dialect name used for parsing (defaults to ANSI)
    :return: the masked query, or the original query on failure
    """
    logger = get_logger()
    try:
        sqlfluff_masked_query = mask_literals_with_sqlfluff(query, dialect)
        sqlparse_masked_query = mask_literals_with_sqlparse(query)
        # compare both masked queries and return the one with more masked tokens
        if sqlfluff_masked_query.count(MASK_TOKEN) >= sqlparse_masked_query.count(
            MASK_TOKEN
        ):
            return sqlfluff_masked_query
        return sqlparse_masked_query
    except Exception as exc:
        # Fixed copy-pasted message: this wrapper covers both maskers,
        # not only sqlfluff
        logger.debug(f"Failed to mask query: {exc}")
        logger.debug(traceback.format_exc())
        return query

View File

@ -26,6 +26,7 @@ from collate_sqllineage.runner import LineageRunner
from sqlparse.sql import Comparison, Identifier, Parenthesis, Statement from sqlparse.sql import Comparison, Identifier, Parenthesis, Statement
from metadata.generated.schema.type.tableUsageCount import TableColumn, TableColumnJoin from metadata.generated.schema.type.tableUsageCount import TableColumn, TableColumnJoin
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import Dialect from metadata.ingestion.lineage.models import Dialect
from metadata.utils.helpers import ( from metadata.utils.helpers import (
find_in_iter, find_in_iter,
@ -69,7 +70,10 @@ class LineageParser:
self.query = query self.query = query
self.query_parsing_success = True self.query_parsing_success = True
self.query_parsing_failure_reason = None self.query_parsing_failure_reason = None
self.dialect = dialect
self._masked_query = mask_query(self.query, dialect.value)
self._clean_query = self.clean_raw_query(query) self._clean_query = self.clean_raw_query(query)
self._masked_clean_query = mask_query(self._clean_query, dialect.value)
self.parser = self._evaluate_best_parser( self.parser = self._evaluate_best_parser(
self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds
) )
@ -91,7 +95,7 @@ class LineageParser:
except SQLLineageException as exc: except SQLLineageException as exc:
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())
logger.warning( logger.warning(
f"Cannot extract source table information from query [{self.query}]: {exc}" f"Cannot extract source table information from query [{self._masked_query}]: {exc}"
) )
return None return None
@ -333,7 +337,9 @@ class LineageParser:
logger.warning( logger.warning(
f"Can't extract table names when parsing JOIN information from {comparison}" f"Can't extract table names when parsing JOIN information from {comparison}"
) )
logger.debug(f"Query: {sql_statement}") logger.debug(
f"Query: {mask_query(sql_statement, self.dialect.value)}"
)
continue continue
left_table_column = TableColumn(table=table_left, column=column_left) left_table_column = TableColumn(table=table_left, column=column_left)
@ -430,14 +436,18 @@ class LineageParser:
f"Lineage with SqlFluff failed for the [{dialect.value}]. " f"Lineage with SqlFluff failed for the [{dialect.value}]. "
f"Parser has been running for more than {timeout_seconds} seconds." f"Parser has been running for more than {timeout_seconds} seconds."
) )
logger.debug(f"{self.query_parsing_failure_reason}] query: [{query}]") logger.debug(
f"{self.query_parsing_failure_reason}] query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None lr_sqlfluff = None
except Exception: except Exception:
self.query_parsing_success = False self.query_parsing_success = False
self.query_parsing_failure_reason = ( self.query_parsing_failure_reason = (
f"Lineage with SqlFluff failed for the [{dialect.value}]" f"Lineage with SqlFluff failed for the [{dialect.value}]"
) )
logger.debug(f"{self.query_parsing_failure_reason} query: [{query}]") logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None lr_sqlfluff = None
lr_sqlparser = LineageRunner(query) lr_sqlparser = LineageRunner(query)
@ -461,7 +471,9 @@ class LineageParser:
"Lineage computed with SqlFluff did not perform as expected " "Lineage computed with SqlFluff did not perform as expected "
f"for the [{dialect.value}]" f"for the [{dialect.value}]"
) )
logger.debug(f"{self.query_parsing_failure_reason} query: [{query}]") logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
return lr_sqlparser return lr_sqlparser
return lr_sqlfluff return lr_sqlfluff
return lr_sqlparser return lr_sqlparser

View File

@ -28,6 +28,7 @@ from metadata.generated.schema.type.entityLineage import (
from metadata.generated.schema.type.entityLineage import Source as LineageSource from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import ( from metadata.ingestion.lineage.models import (
Dialect, Dialect,
QueryParsingError, QueryParsingError,
@ -248,7 +249,7 @@ def _build_table_lineage(
to_entity: Table, to_entity: Table,
from_table_raw_name: str, from_table_raw_name: str,
to_table_raw_name: str, to_table_raw_name: str,
query: str, masked_query: str,
column_lineage_map: dict, column_lineage_map: dict,
lineage_source: LineageSource = LineageSource.QueryLineage, lineage_source: LineageSource = LineageSource.QueryLineage,
) -> Either[AddLineageRequest]: ) -> Either[AddLineageRequest]:
@ -275,7 +276,7 @@ def _build_table_lineage(
from_table_raw_name=str(from_table_raw_name), from_table_raw_name=str(from_table_raw_name),
column_lineage_map=column_lineage_map, column_lineage_map=column_lineage_map,
) )
lineage_details = LineageDetails(sqlQuery=query, source=lineage_source) lineage_details = LineageDetails(sqlQuery=masked_query, source=lineage_source)
if col_lineage: if col_lineage:
lineage_details.columnsLineage = col_lineage lineage_details.columnsLineage = col_lineage
lineage = AddLineageRequest( lineage = AddLineageRequest(
@ -311,7 +312,7 @@ def _create_lineage_by_table_name(
service_name: str, service_name: str,
database_name: Optional[str], database_name: Optional[str],
schema_name: Optional[str], schema_name: Optional[str],
query: str, masked_query: str,
column_lineage_map: dict, column_lineage_map: dict,
lineage_source: LineageSource = LineageSource.QueryLineage, lineage_source: LineageSource = LineageSource.QueryLineage,
) -> Iterable[Either[AddLineageRequest]]: ) -> Iterable[Either[AddLineageRequest]]:
@ -354,7 +355,7 @@ def _create_lineage_by_table_name(
from_entity=from_entity, from_entity=from_entity,
to_table_raw_name=to_table, to_table_raw_name=to_table,
from_table_raw_name=from_table, from_table_raw_name=from_table,
query=query, masked_query=masked_query,
column_lineage_map=column_lineage_map, column_lineage_map=column_lineage_map,
lineage_source=lineage_source, lineage_source=lineage_source,
) )
@ -417,9 +418,10 @@ def get_lineage_by_query(
""" """
column_lineage = {} column_lineage = {}
query_parsing_failures = QueryParsingFailures() query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)
try: try:
logger.debug(f"Running lineage with query: {query}") logger.debug(f"Running lineage with query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds) lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
raw_column_lineage = lineage_parser.column_lineage raw_column_lineage = lineage_parser.column_lineage
@ -434,7 +436,7 @@ def get_lineage_by_query(
service_name=service_name, service_name=service_name,
database_name=database_name, database_name=database_name,
schema_name=schema_name, schema_name=schema_name,
query=query, masked_query=masked_query,
column_lineage_map=column_lineage, column_lineage_map=column_lineage,
lineage_source=lineage_source, lineage_source=lineage_source,
) )
@ -446,7 +448,7 @@ def get_lineage_by_query(
service_name=service_name, service_name=service_name,
database_name=database_name, database_name=database_name,
schema_name=schema_name, schema_name=schema_name,
query=query, masked_query=masked_query,
column_lineage_map=column_lineage, column_lineage_map=column_lineage,
lineage_source=lineage_source, lineage_source=lineage_source,
) )
@ -460,14 +462,15 @@ def get_lineage_by_query(
service_name=service_name, service_name=service_name,
database_name=database_name, database_name=database_name,
schema_name=schema_name, schema_name=schema_name,
query=query, masked_query=masked_query,
column_lineage_map=column_lineage, column_lineage_map=column_lineage,
lineage_source=lineage_source, lineage_source=lineage_source,
) )
if not lineage_parser.query_parsing_success: if not lineage_parser.query_parsing_success:
query_parsing_failures.add( query_parsing_failures.add(
QueryParsingError( QueryParsingError(
query=query, error=lineage_parser.query_parsing_failure_reason query=masked_query,
error=lineage_parser.query_parsing_failure_reason,
) )
) )
except Exception as exc: except Exception as exc:
@ -494,9 +497,10 @@ def get_lineage_via_table_entity(
"""Get lineage from table entity""" """Get lineage from table entity"""
column_lineage = {} column_lineage = {}
query_parsing_failures = QueryParsingFailures() query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)
try: try:
logger.debug(f"Getting lineage via table entity using query: {query}") logger.debug(f"Getting lineage via table entity using query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds) lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
to_table_name = table_entity.name.root to_table_name = table_entity.name.root
@ -508,14 +512,15 @@ def get_lineage_via_table_entity(
service_name=service_name, service_name=service_name,
database_name=database_name, database_name=database_name,
schema_name=schema_name, schema_name=schema_name,
query=query, masked_query=masked_query,
column_lineage_map=column_lineage, column_lineage_map=column_lineage,
lineage_source=lineage_source, lineage_source=lineage_source,
) or [] ) or []
if not lineage_parser.query_parsing_success: if not lineage_parser.query_parsing_success:
query_parsing_failures.add( query_parsing_failures.add(
QueryParsingError( QueryParsingError(
query=query, error=lineage_parser.query_parsing_failure_reason query=masked_query,
error=lineage_parser.query_parsing_failure_reason,
) )
) )
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except

View File

@ -23,6 +23,7 @@ from metadata.generated.schema.entity.data.query import Query
from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.type.basic import Uuid from metadata.generated.schema.type.basic import Uuid
from metadata.generated.schema.type.entityReference import EntityReference from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.ometa.client import REST from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.utils import model_str from metadata.ingestion.ometa.utils import model_str
@ -60,6 +61,9 @@ class OMetaQueryMixin:
""" """
for create_query in queries: for create_query in queries:
if not create_query.exclude_usage: if not create_query.exclude_usage:
create_query.query.root = mask_query(
create_query.query.root, create_query.dialect
)
query = self._get_or_create_query(create_query) query = self._get_or_create_query(create_query)
if query: if query:
# Add Query Usage # Add Query Usage

View File

@ -61,6 +61,7 @@ def parse_sql_statement(record: TableQuery, dialect: Dialect) -> Optional[Parsed
sql=record.query, sql=record.query,
query_type=record.query_type, query_type=record.query_type,
exclude_usage=record.exclude_usage, exclude_usage=record.exclude_usage,
dialect=dialect.value,
userName=record.userName, userName=record.userName,
date=str(start_time), date=str(start_time),
serviceName=record.serviceName, serviceName=record.serviceName,

View File

@ -41,6 +41,7 @@ class AthenaLineageSource(AthenaQueryParserSource, LineageSource):
and query.Status.State.upper() == QUERY_SUCCESS_STATUS and query.Status.State.upper() == QUERY_SUCCESS_STATUS
): ):
yield TableQuery( yield TableQuery(
dialect=self.dialect.value,
query=query.Query, query=query.Query,
serviceName=self.config.serviceName, serviceName=self.config.serviceName,
) )

View File

@ -40,6 +40,7 @@ class AthenaUsageSource(AthenaQueryParserSource, UsageSource):
for query_list in self.get_queries() or []: for query_list in self.get_queries() or []:
queries = [ queries = [
TableQuery( TableQuery(
dialect=self.dialect.value,
query=query.Query, query=query.Query,
startTime=query.Status.SubmissionDateTime.isoformat( startTime=query.Status.SubmissionDateTime.isoformat(
DATETIME_SEPARATOR, DATETIME_TIME_SPEC DATETIME_SEPARATOR, DATETIME_TIME_SPEC

View File

@ -40,6 +40,7 @@ class DatabricksLineageSource(DatabricksQueryParserSource, LineageSource):
try: try:
if self.client.is_query_valid(row): if self.client.is_query_valid(row):
yield TableQuery( yield TableQuery(
dialect=self.dialect.value,
query=row.get("query_text"), query=row.get("query_text"),
userName=row.get("user_name"), userName=row.get("user_name"),
startTime=str(row.get("query_start_time_ms")), startTime=str(row.get("query_start_time_ms")),

View File

@ -45,6 +45,7 @@ class DatabricksUsageSource(DatabricksQueryParserSource, UsageSource):
if self.client.is_query_valid(row): if self.client.is_query_valid(row):
queries.append( queries.append(
TableQuery( TableQuery(
dialect=self.dialect.value,
query=row.get("query_text"), query=row.get("query_text"),
userName=row.get("user_name"), userName=row.get("user_name"),
startTime=str(row.get("query_start_time_ms")), startTime=str(row.get("query_start_time_ms")),

View File

@ -135,6 +135,7 @@ class LineageSource(QueryParserSource, ABC):
query_dict = dict(row) query_dict = dict(row)
try: try:
yield TableQuery( yield TableQuery(
dialect=self.dialect.value,
query=query_dict["query_text"], query=query_dict["query_text"],
databaseName=self.get_database_name(query_dict), databaseName=self.get_database_name(query_dict),
serviceName=self.config.serviceName, serviceName=self.config.serviceName,

View File

@ -82,6 +82,7 @@ class PostgresLineageSource(PostgresQueryParserSource, LineageSource):
row = dict(row) row = dict(row)
try: try:
yield TableQuery( yield TableQuery(
dialect=self.dialect.value,
query=row["query_text"], query=row["query_text"],
userName=row["usename"], userName=row["usename"],
analysisDate=DateTime(datetime.now()), analysisDate=DateTime(datetime.now()),

View File

@ -51,6 +51,7 @@ class PostgresUsageSource(PostgresQueryParserSource, UsageSource):
try: try:
queries.append( queries.append(
TableQuery( TableQuery(
dialect=self.dialect.value,
query=row["query_text"], query=row["query_text"],
userName=row["usename"], userName=row["usename"],
analysisDate=DateTime(datetime.now()), analysisDate=DateTime(datetime.now()),

View File

@ -23,6 +23,7 @@ from metadata.ingestion.api.steps import Source
from metadata.ingestion.connections.test_connections import ( from metadata.ingestion.connections.test_connections import (
raise_test_connection_exception, raise_test_connection_exception,
) )
from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_test_connection_fn from metadata.ingestion.source.connections import get_test_connection_fn
from metadata.utils.helpers import get_start_and_end from metadata.utils.helpers import get_start_and_end
@ -43,6 +44,7 @@ class QueryParserSource(Source, ABC):
""" """
sql_stmt: str sql_stmt: str
dialect: str
filters: str filters: str
database_field: str database_field: str
schema_field: str schema_field: str
@ -58,6 +60,8 @@ class QueryParserSource(Source, ABC):
self.metadata = metadata self.metadata = metadata
self.service_name = self.config.serviceName self.service_name = self.config.serviceName
self.service_connection = self.config.serviceConnection.root.config self.service_connection = self.config.serviceConnection.root.config
connection_type = self.service_connection.type.value
self.dialect = ConnectionTypeDialectMapper.dialect_of(connection_type)
self.source_config = self.config.sourceConfig.config self.source_config = self.config.sourceConfig.config
self.start, self.end = get_start_and_end(self.source_config.queryLogDuration) self.start, self.end = get_start_and_end(self.source_config.queryLogDuration)
self.engine = ( self.engine = (

View File

@ -83,6 +83,7 @@ class RedshiftLineageSource(
query_dict = dict(row) query_dict = dict(row)
try: try:
yield TableQuery( yield TableQuery(
dialect=self.dialect.value,
query=query_dict["query_text"] query=query_dict["query_text"]
.replace("\\n", "\n") .replace("\\n", "\n")
.replace("\\r", ""), .replace("\\r", ""),

View File

@ -17,14 +17,14 @@ import textwrap
SNOWFLAKE_SQL_STATEMENT = textwrap.dedent( SNOWFLAKE_SQL_STATEMENT = textwrap.dedent(
""" """
SELECT SELECT
query_type, query_type "query_type",
query_text, query_text "query_text",
user_name, user_name "user_name",
database_name, database_name "database_name",
schema_name, schema_name "schema_name",
start_time, start_time "start_time",
end_time, end_time "end_time",
total_elapsed_time duration total_elapsed_time "duration"
from snowflake.account_usage.query_history from snowflake.account_usage.query_history
WHERE query_text NOT LIKE '/* {{"app": "OpenMetadata", %%}} */%%' WHERE query_text NOT LIKE '/* {{"app": "OpenMetadata", %%}} */%%'
AND query_text NOT LIKE '/* {{"app": "dbt", %%}} */%%' AND query_text NOT LIKE '/* {{"app": "dbt", %%}} */%%'

View File

@ -117,8 +117,6 @@ class StoredProcedureLineageMixin(ABC):
for row in results: for row in results:
try: try:
print("*** " * 100)
print(dict(row))
query_by_procedure = QueryByProcedure.model_validate(dict(row)) query_by_procedure = QueryByProcedure.model_validate(dict(row))
procedure_name = ( procedure_name = (
query_by_procedure.procedure_name query_by_procedure.procedure_name

View File

@ -21,6 +21,7 @@ from typing import Iterable
from metadata.generated.schema.type.basic import DateTime from metadata.generated.schema.type.basic import DateTime
from metadata.generated.schema.type.tableQuery import TableQueries, TableQuery from metadata.generated.schema.type.tableQuery import TableQueries, TableQuery
from metadata.ingestion.api.models import Either from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.source.database.query_parser_source import QueryParserSource from metadata.ingestion.source.database.query_parser_source import QueryParserSource
from metadata.utils.logger import ingestion_logger from metadata.utils.logger import ingestion_logger
@ -65,6 +66,7 @@ class UsageSource(QueryParserSource, ABC):
) )
query_list.append( query_list.append(
TableQuery( TableQuery(
dialect=self.dialect.value,
query=query_dict["query_text"], query=query_dict["query_text"],
userName=query_dict.get("user_name", ""), userName=query_dict.get("user_name", ""),
startTime=query_dict.get("start_time", ""), startTime=query_dict.get("start_time", ""),
@ -119,6 +121,7 @@ class UsageSource(QueryParserSource, ABC):
for row in rows: for row in rows:
row = dict(row) row = dict(row)
try: try:
logger.debug(f"Processing row: {query}")
query_type = row.get("query_type") query_type = row.get("query_type")
query = self.format_query(row["query_text"]) query = self.format_query(row["query_text"])
queries.append( queries.append(
@ -128,6 +131,7 @@ class UsageSource(QueryParserSource, ABC):
exclude_usage=self.check_life_cycle_query( exclude_usage=self.check_life_cycle_query(
query_type=query_type, query_text=query query_type=query_type, query_text=query
), ),
dialect=self.dialect.value,
userName=row["user_name"], userName=row["user_name"],
startTime=str(row["start_time"]), startTime=str(row["start_time"]),
endTime=str(row["end_time"]), endTime=str(row["end_time"]),
@ -148,7 +152,7 @@ class UsageSource(QueryParserSource, ABC):
except Exception as exc: except Exception as exc:
if query: if query:
logger.debug( logger.debug(
f"###### USAGE QUERY #######\n{query}\n##########################" f"###### USAGE QUERY #######\n{mask_query(query, self.dialect.value)}\n##########################"
) )
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())
logger.error(f"Source usage processing error: {exc}") logger.error(f"Source usage processing error: {exc}")

View File

@ -114,6 +114,7 @@ class TableUsageStage(Stage):
exclude_usage=record.exclude_usage, exclude_usage=record.exclude_usage,
users=users, users=users,
queryDate=record.date, queryDate=record.date,
dialect=record.dialect,
usedBy=used_by, usedBy=used_by,
duration=record.duration, duration=record.duration,
service=record.serviceName, service=record.serviceName,
@ -128,6 +129,7 @@ class TableUsageStage(Stage):
users=users, users=users,
queryDate=record.date, queryDate=record.date,
usedBy=used_by, usedBy=used_by,
dialect=record.dialect,
duration=record.duration, duration=record.duration,
service=record.serviceName, service=record.serviceName,
) )

View File

@ -25,7 +25,10 @@ from metadata.data_quality.api.models import (
TestCaseResults, TestCaseResults,
) )
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.type.queryParserData import QueryParserData
from metadata.generated.schema.type.tableQuery import TableQueries
from metadata.ingestion.api.models import Entity from metadata.ingestion.api.models import Entity
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.models.delete_entity import DeleteEntity from metadata.ingestion.models.delete_entity import DeleteEntity
from metadata.ingestion.models.life_cycle import OMetaLifeCycleData from metadata.ingestion.models.life_cycle import OMetaLifeCycleData
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
@ -269,6 +272,24 @@ def _(record: PatchRequest) -> str:
return get_log_name(record.new_entity) return get_log_name(record.new_entity)
@get_log_name.register
def _(record: TableQueries) -> str:
"""Get the log of the TableQuery"""
queries = "\n------\n".join(
mask_query(query.query, query.dialect) for query in record.queries
)
return f"Table Queries [{queries}]"
@get_log_name.register
def _(record: QueryParserData) -> str:
"""Get the log of the ParsedData"""
queries = "\n------\n".join(
mask_query(query.sql, query.dialect) for query in record.parsedData
)
return f"Usage ParsedData [{queries}]"
def redacted_config(config: Dict[str, Union[str, dict]]) -> Dict[str, Union[str, dict]]: def redacted_config(config: Dict[str, Union[str, dict]]) -> Dict[str, Union[str, dict]]:
config_copy = deepcopy(config) config_copy = deepcopy(config)

View File

@ -34,6 +34,7 @@ with open(mock_file_path, encoding="utf-8") as file:
EXPECTED_DATABRICKS_DETAILS = [ EXPECTED_DATABRICKS_DETAILS = [
TableQuery( TableQuery(
dialect="databricks",
query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `test`', query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `test`',
userName="vijay@getcollate.io", userName="vijay@getcollate.io",
startTime="1665566128192", startTime="1665566128192",
@ -44,6 +45,7 @@ EXPECTED_DATABRICKS_DETAILS = [
databaseSchema=None, databaseSchema=None,
), ),
TableQuery( TableQuery(
dialect="databricks",
query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `test`', query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `test`',
userName="vijay@getcollate.io", userName="vijay@getcollate.io",
startTime="1665566127416", startTime="1665566127416",
@ -54,6 +56,7 @@ EXPECTED_DATABRICKS_DETAILS = [
databaseSchema=None, databaseSchema=None,
), ),
TableQuery( TableQuery(
dialect="databricks",
query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `default`', query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nSHOW TABLES IN `default`',
userName="vijay@getcollate.io", userName="vijay@getcollate.io",
startTime="1665566125414", startTime="1665566125414",
@ -64,6 +67,7 @@ EXPECTED_DATABRICKS_DETAILS = [
databaseSchema=None, databaseSchema=None,
), ),
TableQuery( TableQuery(
dialect="databricks",
query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nDESCRIBE default.view3', query=' /* {"app": "OpenMetadata", "version": "0.13.0.dev0"} */\nDESCRIBE default.view3',
userName="vijay@getcollate.io", userName="vijay@getcollate.io",
startTime="1665566124428", startTime="1665566124428",

View File

@ -18,6 +18,7 @@ from unittest import TestCase
import pytest import pytest
from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.table import Table
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import Dialect from metadata.ingestion.lineage.models import Dialect
from metadata.ingestion.lineage.parser import LineageParser from metadata.ingestion.lineage.parser import LineageParser
from metadata.ingestion.lineage.sql_lineage import ( from metadata.ingestion.lineage.sql_lineage import (
@ -225,3 +226,29 @@ class SqlLineageTest(TestCase):
self.assertEqual( self.assertEqual(
get_table_fqn_from_query_name(raw_query_name), (None, None, "tab") get_table_fqn_from_query_name(raw_query_name), (None, None, "tab")
) )
def test_query_masker(self):
query_list = [
"""SELECT * FROM user WHERE id=1234 AND name='Alice' AND birthdate=DATE '2023-01-01';""",
"""insert into user values ('mayur',123,'my random address 1'), ('mayur',123,'my random address 1');""",
"""SELECT * FROM user WHERE address = '5th street' and name = 'john';""",
"""INSERT INTO user VALUE ('John', '19', '5TH Street');""",
"""SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user;""",
"""with test as (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user) select * from test;""",
"""select * from (select * from (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user));""",
"""select * from users where id > 2 and name <> 'pere';""",
]
expected_query_list = [
"""SELECT * FROM user WHERE id=? AND name=? AND birthdate=DATE ?;""",
"""insert into user values (?,?,?), (?,?,?);""",
"""SELECT * FROM user WHERE address = ? and name = ?;""",
"""INSERT INTO user VALUE (?, ?, ?);""",
"""SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user;""",
"""with test as (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user) select * from test;""",
"""select * from (select * from (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user));""",
"""select * from users where id > ? and name <> ?;""",
]
for i, query in enumerate(query_list):
self.assertEqual(mask_query(query), expected_query_list[i])

View File

@ -77,6 +77,7 @@ def custom_query_compare(self, other):
EXPECTED_QUERIES = [ EXPECTED_QUERIES = [
TableQuery( TableQuery(
dialect="ansi",
query="select * from sales", query="select * from sales",
userName="", userName="",
startTime="", startTime="",
@ -88,6 +89,7 @@ EXPECTED_QUERIES = [
duration=None, duration=None,
), ),
TableQuery( TableQuery(
dialect="ansi",
query="select * from marketing", query="select * from marketing",
userName="", userName="",
startTime="", startTime="",
@ -99,6 +101,7 @@ EXPECTED_QUERIES = [
duration=None, duration=None,
), ),
TableQuery( TableQuery(
dialect="ansi",
query="insert into marketing select * from sales", query="insert into marketing select * from sales",
userName="", userName="",
startTime="", startTime="",
@ -112,6 +115,7 @@ EXPECTED_QUERIES = [
] ]
EXPECTED_QUERIES_FILE_2 = [ EXPECTED_QUERIES_FILE_2 = [
TableQuery( TableQuery(
dialect="ansi",
query="select * from product_data", query="select * from product_data",
userName="", userName="",
startTime="", startTime="",
@ -123,6 +127,7 @@ EXPECTED_QUERIES_FILE_2 = [
duration=None, duration=None,
), ),
TableQuery( TableQuery(
dialect="ansi",
query="select * from students where marks>=80", query="select * from students where marks>=80",
userName="", userName="",
startTime="", startTime="",

View File

@ -65,6 +65,10 @@
}, },
"uniqueItems": true "uniqueItems": true
}, },
"dialect": {
"description": "SQL dialect.",
"type": "string"
},
"queryDate": { "queryDate": {
"description": "Date on which the query ran.", "description": "Date on which the query ran.",
"$ref": "../../type/basic.json#/definitions/timestamp" "$ref": "../../type/basic.json#/definitions/timestamp"

View File

@ -34,6 +34,10 @@
"description": "SQL query", "description": "SQL query",
"type": "string" "type": "string"
}, },
"dialect": {
"description": "SQL dialect",
"type": "string"
},
"query_type": { "query_type": {
"description": "SQL query type", "description": "SQL query type",
"type": "string" "type": "string"

View File

@ -19,6 +19,10 @@
"description": "Flag to check if query is to be excluded while processing usage", "description": "Flag to check if query is to be excluded while processing usage",
"type": "boolean" "type": "boolean"
}, },
"dialect": {
"description": "SQL dialect",
"type": "string"
},
"userName": { "userName": {
"description": "Name of the user that executed the SQL query", "description": "Name of the user that executed the SQL query",
"type": "string" "type": "string"