feat(ingest): utilities for query logs (#10036)

Harshal Sheth 2024-03-12 23:20:46 -07:00 committed by GitHub
parent 4535f2adfd
commit b0163c4885
10 changed files with 861 additions and 569 deletions

View File

@@ -99,7 +99,7 @@ usage_common = {
 sqlglot_lib = {
     # Using an Acryl fork of sqlglot.
     # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
-    "acryl-sqlglot==22.3.1.dev3",
+    "acryl-sqlglot==22.4.1.dev4",
 }
 classification_lib = {

View File

@@ -1,4 +1,7 @@
+import dataclasses
+import json
 import logging
+import pathlib
 import pprint
 import shutil
 import tempfile
@@ -17,6 +20,7 @@ from datahub.ingestion.sink.sink_registry import sink_registry
 from datahub.ingestion.source.source_registry import source_registry
 from datahub.ingestion.transformer.transform_registry import transform_registry
 from datahub.telemetry import telemetry
+from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

 logger = logging.getLogger(__name__)
@@ -339,3 +343,28 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
             f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}"
         )
         raise e
+
+
+@check.command()
+@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
+@click.option("--output", type=click.Path())
+def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
+    """Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
+    from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
+
+    assert dataclasses.is_dataclass(LoggedQuery)
+
+    shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
+    query_log = FileBackedList[LoggedQuery](
+        shared_connection=shared_connection, tablename="stored_queries"
+    )
+    logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
+    queries = [dataclasses.asdict(query) for query in query_log]
+
+    if output:
+        with open(output, "w") as f:
+            json.dump(queries, f, indent=2)
+        logger.info(f"Extracted {len(queries)} queries to {output}")
+    else:
+        click.echo(json.dumps(queries, indent=2))
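
For reference, the stored query log can also be read directly from Python. The sketch below is not part of the commit; it simply mirrors what extract-sql-agg-log does, and the query_log.db path is illustrative (the aggregator reports the real location when query logging is enabled).

import dataclasses
import json
import pathlib

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

# Illustrative path; point this at the sqlite file produced by the aggregator.
query_log_db = pathlib.Path("query_log.db")

with ConnectionWrapper(filename=query_log_db) as shared_connection:
    # Same table name the aggregator writes to.
    query_log = FileBackedList[LoggedQuery](
        shared_connection=shared_connection, tablename="stored_queries"
    )
    print(json.dumps([dataclasses.asdict(q) for q in query_log], indent=2))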

View File

@@ -1,8 +1,10 @@
+import contextlib
 import dataclasses
 import enum
 import itertools
 import json
 import logging
+import os
 import pathlib
 import tempfile
 import uuid
@@ -15,6 +17,7 @@ import datahub.metadata.schema_classes as models
 from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.sql_parsing_builder import compute_upstream_fields
+from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.report import Report
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.graph.client import DataHubGraph
@@ -53,9 +56,6 @@ logger = logging.getLogger(__name__)
 QueryId = str
 UrnStr = str

-_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
-_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
-

 class QueryLogSetting(enum.Enum):
     DISABLED = "DISABLED"
@@ -63,6 +63,23 @@ class QueryLogSetting(enum.Enum):
     STORE_FAILED = "STORE_FAILED"


+_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
+_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
+_DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
+    os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
+]
+
+
+@dataclasses.dataclass
+class LoggedQuery:
+    query: str
+    session_id: Optional[str]
+    timestamp: Optional[datetime]
+    user: Optional[UrnStr]
+    default_db: Optional[str]
+    default_schema: Optional[str]
+
+
 @dataclasses.dataclass
 class ViewDefinition:
     view_definition: str
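
Because _DEFAULT_QUERY_LOG_SETTING is read from the environment at import time, query logging can now be enabled without changing any call sites. Below is a minimal sketch, not from the commit, of how the DATAHUB_SQL_AGG_QUERY_LOG variable is expected to be used; note the variable must be set before the module is imported.

import os

# One of DISABLED, STORE_ALL, STORE_FAILED; set before the import below.
os.environ["DATAHUB_SQL_AGG_QUERY_LOG"] = "STORE_ALL"

from datahub.sql_parsing.sql_parsing_aggregator import (
    QueryLogSetting,
    SqlParsingAggregator,
)

aggregator = SqlParsingAggregator(
    platform="redshift",
    generate_lineage=True,
    generate_usage_statistics=False,
    generate_operations=False,
)
# The constructor default now reflects the environment variable.
assert aggregator.query_log == QueryLogSetting.STORE_ALL
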
@@ -170,7 +187,7 @@ class SqlAggregatorReport(Report):
         return super().compute_stats()


-class SqlParsingAggregator:
+class SqlParsingAggregator(Closeable):
     def __init__(
         self,
         *,
@@ -185,7 +202,7 @@ class SqlParsingAggregator:
         usage_config: Optional[BaseUsageConfig] = None,
         is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
         format_queries: bool = True,
-        query_log: QueryLogSetting = QueryLogSetting.DISABLED,
+        query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
     ) -> None:
         self.platform = DataPlatformUrn(platform)
         self.platform_instance = platform_instance
@@ -210,13 +227,18 @@ class SqlParsingAggregator:
         self.format_queries = format_queries
         self.query_log = query_log

+        # The exit stack helps ensure that we close all the resources we open.
+        self._exit_stack = contextlib.ExitStack()
+
         # Set up the schema resolver.
         self._schema_resolver: SchemaResolver
         if graph is None:
-            self._schema_resolver = SchemaResolver(
-                platform=self.platform.platform_name,
-                platform_instance=self.platform_instance,
-                env=self.env,
+            self._schema_resolver = self._exit_stack.enter_context(
+                SchemaResolver(
+                    platform=self.platform.platform_name,
+                    platform_instance=self.platform_instance,
+                    env=self.env,
+                )
             )
         else:
             self._schema_resolver = None  # type: ignore
@@ -235,27 +257,33 @@ class SqlParsingAggregator:
             # By providing a filename explicitly here, we also ensure that the file
             # is not automatically deleted on exit.
-            self._shared_connection = ConnectionWrapper(filename=query_log_path)
+            self._shared_connection = self._exit_stack.enter_context(
+                ConnectionWrapper(filename=query_log_path)
+            )

         # Stores the logged queries.
-        self._logged_queries = FileBackedList[str](
+        self._logged_queries = FileBackedList[LoggedQuery](
             shared_connection=self._shared_connection, tablename="stored_queries"
         )
+        self._exit_stack.push(self._logged_queries)

         # Map of query_id -> QueryMetadata
         self._query_map = FileBackedDict[QueryMetadata](
             shared_connection=self._shared_connection, tablename="query_map"
         )
+        self._exit_stack.push(self._query_map)

         # Map of downstream urn -> { query ids }
         self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
             shared_connection=self._shared_connection, tablename="lineage_map"
         )
+        self._exit_stack.push(self._lineage_map)

         # Map of view urn -> view definition
         self._view_definitions = FileBackedDict[ViewDefinition](
             shared_connection=self._shared_connection, tablename="view_definitions"
         )
+        self._exit_stack.push(self._view_definitions)

         # Map of session ID -> {temp table name -> query id}
         # Needs to use the query_map to find the info about the query.
@@ -263,16 +291,20 @@ class SqlParsingAggregator:
         self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
             shared_connection=self._shared_connection, tablename="temp_lineage_map"
         )
+        self._exit_stack.push(self._temp_lineage_map)

         # Map of query ID -> schema fields, only for query IDs that generate temp tables.
         self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
-            shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
+            shared_connection=self._shared_connection,
+            tablename="inferred_temp_schemas",
         )
+        self._exit_stack.push(self._inferred_temp_schemas)

         # Map of table renames, from original UrnStr to new UrnStr.
         self._table_renames = FileBackedDict[UrnStr](
             shared_connection=self._shared_connection, tablename="table_renames"
         )
+        self._exit_stack.push(self._table_renames)

         # Usage aggregator. This will only be initialized if usage statistics are enabled.
         # TODO: Replace with FileBackedDict.
@@ -281,6 +313,9 @@ class SqlParsingAggregator:
             assert self.usage_config is not None
             self._usage_aggregator = UsageAggregator(config=self.usage_config)

+    def close(self) -> None:
+        self._exit_stack.close()
+
     @property
     def _need_schemas(self) -> bool:
         return self.generate_lineage or self.generate_usage_statistics
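
With SqlParsingAggregator now a Closeable backed by an ExitStack, callers can scope its SQLite-backed state to a with block. A rough sketch follows, not part of the commit, using contextlib.closing, which relies only on the close() method added above; the query values mirror the test at the end of this diff.

import contextlib

from datahub.sql_parsing.sql_parsing_aggregator import (
    QueryLogSetting,
    SqlParsingAggregator,
)

with contextlib.closing(
    SqlParsingAggregator(
        platform="redshift",
        generate_lineage=True,
        generate_usage_statistics=False,
        generate_operations=False,
        query_log=QueryLogSetting.STORE_ALL,
    )
) as aggregator:
    aggregator.add_observed_query(
        query="create table foo as select a, b from bar",
        default_db="dev",
        default_schema="public",
    )
# On exit, close() unwinds the exit stack, closing the file-backed maps and the
# shared sqlite connection.
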
@@ -499,6 +534,9 @@ class SqlParsingAggregator:
             default_db=default_db,
             default_schema=default_schema,
             schema_resolver=schema_resolver,
+            session_id=session_id,
+            timestamp=query_timestamp,
+            user=user,
         )
         if parsed.debug_info.error:
             self.report.observed_query_parse_failures.append(
@@ -700,6 +738,9 @@ class SqlParsingAggregator:
         default_db: Optional[str],
         default_schema: Optional[str],
         schema_resolver: SchemaResolverInterface,
+        session_id: str = _MISSING_SESSION_ID,
+        timestamp: Optional[datetime] = None,
+        user: Optional[CorpUserUrn] = None,
     ) -> SqlParsingResult:
         parsed = sqlglot_lineage(
             query,
@@ -712,7 +753,15 @@ class SqlParsingAggregator:
         if self.query_log == QueryLogSetting.STORE_ALL or (
             self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
         ):
-            self._logged_queries.append(query)
+            query_log_entry = LoggedQuery(
+                query=query,
+                session_id=session_id if session_id != _MISSING_SESSION_ID else None,
+                timestamp=timestamp,
+                user=user.urn() if user else None,
+                default_db=default_db,
+                default_schema=default_schema,
+            )
+            self._logged_queries.append(query_log_entry)

             # Also add some extra logging.
             if parsed.debug_info.error:

View File

@@ -62,9 +62,13 @@ def assert_metadata_files_equal(
         # We have to "normalize" the golden file by reading and writing it back out.
         # This will clean up nulls, double serialization, and other formatting issues.
         with tempfile.NamedTemporaryFile() as temp:
-            golden_metadata = read_metadata_file(pathlib.Path(golden_path))
-            write_metadata_file(pathlib.Path(temp.name), golden_metadata)
-            golden = load_json_file(temp.name)
+            try:
+                golden_metadata = read_metadata_file(pathlib.Path(golden_path))
+                write_metadata_file(pathlib.Path(temp.name), golden_metadata)
+                golden = load_json_file(temp.name)
+            except (ValueError, AssertionError) as e:
+                logger.info(f"Error reformatting golden file as MCP/MCEs: {e}")
+                golden = load_json_file(golden_path)

     diff = diff_metadata_json(output, golden, ignore_paths, ignore_order=ignore_order)
     if diff and update_golden:
@@ -107,7 +111,7 @@ def diff_metadata_json(
         # if ignore_order is False, always use DeepDiff
     except CannotCompareMCPs as e:
         logger.info(f"{e}, falling back to MCE diff")
-    except AssertionError as e:
+    except (AssertionError, ValueError) as e:
         logger.warning(f"Reverting to old diff method: {e}")
         logger.debug("Error with new diff method", exc_info=True)

View File

@@ -126,6 +126,7 @@ class ConnectionWrapper:
     def close(self) -> None:
         for obj in self._dependent_objects:
             obj.close()
+        self._dependent_objects.clear()
         with self.conn_lock:
             self.conn.close()
         if self._temp_directory:
@@ -440,7 +441,7 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]):
         self.close()


-class FileBackedList(Generic[_VT]):
+class FileBackedList(Generic[_VT], Closeable):
     """An append-only, list-like object that stores its contents in a SQLite database."""

     _len: int = field(default=0)
@@ -456,7 +457,6 @@ class FileBackedList(Generic[_VT]):
         cache_max_size: Optional[int] = None,
         cache_eviction_batch_size: Optional[int] = None,
     ) -> None:
-        self._len = 0
         self._dict = FileBackedDict[_VT](
             shared_connection=shared_connection,
             tablename=tablename,
@@ -468,6 +468,12 @@ class FileBackedList(Generic[_VT]):
             or _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE,
         )

+        if shared_connection:
+            shared_connection._dependent_objects.append(self)
+
+        # In case we're reusing an existing list, we need to run a query to get the length.
+        self._len = len(self._dict)
+
     @property
     def tablename(self) -> str:
         return self._dict.tablename
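
Registering the list with its shared connection and recomputing _len from the underlying table is what lets a FileBackedList re-attach to an existing database, which the new extract-sql-agg-log command depends on. A rough sketch of that reuse pattern, not part of the commit; the db_file path and the "items" table name are illustrative, and it assumes the default string serialization.

import pathlib
import tempfile

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

db_file = pathlib.Path(tempfile.mkdtemp()) / "list.db"

# First connection: append two entries, then close. Closing the connection also
# closes the dependent list, flushing its cache to sqlite.
with ConnectionWrapper(filename=db_file) as conn:
    items = FileBackedList[str](shared_connection=conn, tablename="items")
    items.append("first")
    items.append("second")

# Second connection over the same file: the reopened list recovers its length
# from the existing table instead of starting at zero.
with ConnectionWrapper(filename=db_file) as conn:
    reopened = FileBackedList[str](shared_connection=conn, tablename="items")
    assert len(reopened) == 2
    for entry in reopened:
        print(entry)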

View File

@ -0,0 +1,10 @@
[
{
"query": "create table foo as select a, b from bar",
"session_id": null,
"timestamp": null,
"user": null,
"default_db": "dev",
"default_schema": "public"
}
]

View File

@@ -83,7 +83,7 @@
         "aspect": {
           "json": {
             "statement": {
-              "value": "create table #temp2 as select b, c from upstream2;\n\ncreate table #temp1 as select a, 2*b as b from upstream1;\n\ncreate temp table staging_foo as select up1.a, up1.b, up2.c from #temp1 up1 left join #temp2 up2 on up1.b = up2.b where up1.b > 0;\n\ninsert into table prod_foo\nselect * from staging_foo",
+              "value": "CREATE TABLE #temp2 AS\nSELECT\n b,\n c\nFROM upstream2;\n\nCREATE TABLE #temp1 AS\nSELECT\n a,\n 2 * b AS b\nFROM upstream1;\n\nCREATE TEMPORARY TABLE staging_foo AS\nSELECT\n up1.a,\n up1.b,\n up2.c\nFROM #temp1 AS up1\nLEFT JOIN #temp2 AS up2\n ON up1.b = up2.b\nWHERE\n up1.b > 0;\n\nINSERT INTO prod_foo\nSELECT\n *\nFROM staging_foo",
               "language": "SQL"
             },
             "source": "SYSTEM",

View File

@@ -13,6 +13,7 @@ from datahub.sql_parsing.sql_parsing_aggregator import (
 from datahub.sql_parsing.sql_parsing_common import QueryType
 from datahub.sql_parsing.sqlglot_lineage import ColumnLineageInfo, ColumnRef
 from tests.test_helpers import mce_helpers
+from tests.test_helpers.click_helpers import run_datahub_cmd

 RESOURCE_DIR = pathlib.Path(__file__).parent / "aggregator_goldens"
 FROZEN_TIME = "2024-02-06 01:23:45"
@@ -23,12 +24,13 @@ def _ts(ts: int) -> datetime:

 @freeze_time(FROZEN_TIME)
-def test_basic_lineage(pytestconfig: pytest.Config) -> None:
+def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
     aggregator = SqlParsingAggregator(
         platform="redshift",
         generate_lineage=True,
         generate_usage_statistics=False,
         generate_operations=False,
+        query_log=QueryLogSetting.STORE_ALL,
     )

     aggregator.add_observed_query(
@@ -45,6 +47,23 @@ def test_basic_lineage(pytestconfig: pytest.Config) -> None:
         golden_path=RESOURCE_DIR / "test_basic_lineage.json",
     )

+    # This test also validates the query log storage functionality.
+    aggregator.close()
+    query_log_db = aggregator.report.query_log_path
+    query_log_json = tmp_path / "query_log.json"
+    run_datahub_cmd(
+        [
+            "check",
+            "extract-sql-agg-log",
+            str(query_log_db),
+            "--output",
+            str(query_log_json),
+        ]
+    )
+    mce_helpers.check_golden_file(
+        pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
+    )
+

 @freeze_time(FROZEN_TIME)
 def test_overlapping_inserts(pytestconfig: pytest.Config) -> None: