feat(ingest): utilities for query logs (#10036)

Harshal Sheth 2024-03-12 23:20:46 -07:00 committed by GitHub
parent 4535f2adfd
commit b0163c4885
10 changed files with 861 additions and 569 deletions

View File

@ -99,7 +99,7 @@ usage_common = {
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==22.3.1.dev3",
"acryl-sqlglot==22.4.1.dev4",
}
classification_lib = {

View File

@ -1,4 +1,7 @@
import dataclasses
import json
import logging
import pathlib
import pprint
import shutil
import tempfile
@ -17,6 +20,7 @@ from datahub.ingestion.sink.sink_registry import sink_registry
from datahub.ingestion.source.source_registry import source_registry
from datahub.ingestion.transformer.transform_registry import transform_registry
from datahub.telemetry import telemetry
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
logger = logging.getLogger(__name__)
@ -339,3 +343,28 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}"
)
raise e
@check.command()
@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
@click.option("--output", type=click.Path())
def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
"""Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
assert dataclasses.is_dataclass(LoggedQuery)
shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
query_log = FileBackedList[LoggedQuery](
shared_connection=shared_connection, tablename="stored_queries"
)
logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
queries = [dataclasses.asdict(query) for query in query_log]
if output:
with open(output, "w") as f:
json.dump(queries, f, indent=2)
logger.info(f"Extracted {len(queries)} queries to {output}")
else:
click.echo(json.dumps(queries, indent=2))
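The command above is registered under the existing `datahub check` group, so once this change is installed it can be invoked as, for example, `datahub check extract-sql-agg-log /path/to/query_log.db --output queries.json` (the paths here are illustrative); without `--output` the JSON is echoed to stdout.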

View File

@ -1,8 +1,10 @@
import contextlib
import dataclasses
import enum
import itertools
import json
import logging
import os
import pathlib
import tempfile
import uuid
@ -15,6 +17,7 @@ import datahub.metadata.schema_classes as models
from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.sql_parsing_builder import compute_upstream_fields
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
@ -53,9 +56,6 @@ logger = logging.getLogger(__name__)
QueryId = str
UrnStr = str
_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
class QueryLogSetting(enum.Enum):
DISABLED = "DISABLED"
@ -63,6 +63,23 @@ class QueryLogSetting(enum.Enum):
STORE_FAILED = "STORE_FAILED"
_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
_DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
]
@dataclasses.dataclass
class LoggedQuery:
query: str
session_id: Optional[str]
timestamp: Optional[datetime]
user: Optional[UrnStr]
default_db: Optional[str]
default_schema: Optional[str]
@dataclasses.dataclass
class ViewDefinition:
view_definition: str
@ -170,7 +187,7 @@ class SqlAggregatorReport(Report):
return super().compute_stats()
class SqlParsingAggregator:
class SqlParsingAggregator(Closeable):
def __init__(
self,
*,
@ -185,7 +202,7 @@ class SqlParsingAggregator:
usage_config: Optional[BaseUsageConfig] = None,
is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
format_queries: bool = True,
query_log: QueryLogSetting = QueryLogSetting.DISABLED,
query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
) -> None:
self.platform = DataPlatformUrn(platform)
self.platform_instance = platform_instance
@ -210,13 +227,18 @@ class SqlParsingAggregator:
self.format_queries = format_queries
self.query_log = query_log
# The exit stack helps ensure that we close all the resources we open.
self._exit_stack = contextlib.ExitStack()
# Set up the schema resolver.
self._schema_resolver: SchemaResolver
if graph is None:
self._schema_resolver = SchemaResolver(
platform=self.platform.platform_name,
platform_instance=self.platform_instance,
env=self.env,
self._schema_resolver = self._exit_stack.enter_context(
SchemaResolver(
platform=self.platform.platform_name,
platform_instance=self.platform_instance,
env=self.env,
)
)
else:
self._schema_resolver = None # type: ignore
@ -235,27 +257,33 @@ class SqlParsingAggregator:
# By providing a filename explicitly here, we also ensure that the file
# is not automatically deleted on exit.
self._shared_connection = ConnectionWrapper(filename=query_log_path)
self._shared_connection = self._exit_stack.enter_context(
ConnectionWrapper(filename=query_log_path)
)
# Stores the logged queries.
self._logged_queries = FileBackedList[str](
self._logged_queries = FileBackedList[LoggedQuery](
shared_connection=self._shared_connection, tablename="stored_queries"
)
self._exit_stack.push(self._logged_queries)
# Map of query_id -> QueryMetadata
self._query_map = FileBackedDict[QueryMetadata](
shared_connection=self._shared_connection, tablename="query_map"
)
self._exit_stack.push(self._query_map)
# Map of downstream urn -> { query ids }
self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
shared_connection=self._shared_connection, tablename="lineage_map"
)
self._exit_stack.push(self._lineage_map)
# Map of view urn -> view definition
self._view_definitions = FileBackedDict[ViewDefinition](
shared_connection=self._shared_connection, tablename="view_definitions"
)
self._exit_stack.push(self._view_definitions)
# Map of session ID -> {temp table name -> query id}
# Needs to use the query_map to find the info about the query.
@ -263,16 +291,20 @@ class SqlParsingAggregator:
self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
shared_connection=self._shared_connection, tablename="temp_lineage_map"
)
self._exit_stack.push(self._temp_lineage_map)
# Map of query ID -> schema fields, only for query IDs that generate temp tables.
self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
shared_connection=self._shared_connection,
tablename="inferred_temp_schemas",
)
self._exit_stack.push(self._inferred_temp_schemas)
# Map of table renames, from original UrnStr to new UrnStr.
self._table_renames = FileBackedDict[UrnStr](
shared_connection=self._shared_connection, tablename="table_renames"
)
self._exit_stack.push(self._table_renames)
# Usage aggregator. This will only be initialized if usage statistics are enabled.
# TODO: Replace with FileBackedDict.
@ -281,6 +313,9 @@ class SqlParsingAggregator:
assert self.usage_config is not None
self._usage_aggregator = UsageAggregator(config=self.usage_config)
def close(self) -> None:
self._exit_stack.close()
@property
def _need_schemas(self) -> bool:
return self.generate_lineage or self.generate_usage_statistics
@ -499,6 +534,9 @@ class SqlParsingAggregator:
default_db=default_db,
default_schema=default_schema,
schema_resolver=schema_resolver,
session_id=session_id,
timestamp=query_timestamp,
user=user,
)
if parsed.debug_info.error:
self.report.observed_query_parse_failures.append(
@ -700,6 +738,9 @@ class SqlParsingAggregator:
default_db: Optional[str],
default_schema: Optional[str],
schema_resolver: SchemaResolverInterface,
session_id: str = _MISSING_SESSION_ID,
timestamp: Optional[datetime] = None,
user: Optional[CorpUserUrn] = None,
) -> SqlParsingResult:
parsed = sqlglot_lineage(
query,
@ -712,7 +753,15 @@ class SqlParsingAggregator:
if self.query_log == QueryLogSetting.STORE_ALL or (
self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
):
self._logged_queries.append(query)
query_log_entry = LoggedQuery(
query=query,
session_id=session_id if session_id != _MISSING_SESSION_ID else None,
timestamp=timestamp,
user=user.urn() if user else None,
default_db=default_db,
default_schema=default_schema,
)
self._logged_queries.append(query_log_entry)
# Also add some extra logging.
if parsed.debug_info.error:

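Taken together, the aggregator changes make query logging configurable (via the constructor or the DATAHUB_SQL_AGG_QUERY_LOG environment variable), store structured LoggedQuery records instead of bare SQL strings, and tie all file-backed state to an exit stack released by close(). A minimal usage sketch follows; the literal query and db/schema values mirror the test golden file below, and the add_observed_query keyword names are abbreviated from that test rather than the full signature:

from datahub.sql_parsing.sql_parsing_aggregator import (
    QueryLogSetting,
    SqlParsingAggregator,
)

aggregator = SqlParsingAggregator(
    platform="redshift",
    generate_lineage=True,
    generate_usage_statistics=False,
    generate_operations=False,
    query_log=QueryLogSetting.STORE_ALL,  # or export DATAHUB_SQL_AGG_QUERY_LOG=STORE_ALL
)
aggregator.add_observed_query(
    query="create table foo as select a, b from bar",
    default_db="dev",
    default_schema="public",
)
aggregator.close()  # the exit stack closes the shared connection and file-backed maps

# report.query_log_path points at the sqlite db that
# `datahub check extract-sql-agg-log` converts to JSON.
print(aggregator.report.query_log_path)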
View File

@ -62,9 +62,13 @@ def assert_metadata_files_equal(
# We have to "normalize" the golden file by reading and writing it back out.
# This will clean up nulls, double serialization, and other formatting issues.
with tempfile.NamedTemporaryFile() as temp:
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
golden = load_json_file(temp.name)
try:
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
golden = load_json_file(temp.name)
except (ValueError, AssertionError) as e:
logger.info(f"Error reformatting golden file as MCP/MCEs: {e}")
golden = load_json_file(golden_path)
diff = diff_metadata_json(output, golden, ignore_paths, ignore_order=ignore_order)
if diff and update_golden:
@ -107,7 +111,7 @@ def diff_metadata_json(
# if ignore_order is False, always use DeepDiff
except CannotCompareMCPs as e:
logger.info(f"{e}, falling back to MCE diff")
except AssertionError as e:
except (AssertionError, ValueError) as e:
logger.warning(f"Reverting to old diff method: {e}")
logger.debug("Error with new diff method", exc_info=True)

View File

@ -126,6 +126,7 @@ class ConnectionWrapper:
def close(self) -> None:
for obj in self._dependent_objects:
obj.close()
self._dependent_objects.clear()
with self.conn_lock:
self.conn.close()
if self._temp_directory:
@ -440,7 +441,7 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]):
self.close()
class FileBackedList(Generic[_VT]):
class FileBackedList(Generic[_VT], Closeable):
"""An append-only, list-like object that stores its contents in a SQLite database."""
_len: int = field(default=0)
@ -456,7 +457,6 @@ class FileBackedList(Generic[_VT]):
cache_max_size: Optional[int] = None,
cache_eviction_batch_size: Optional[int] = None,
) -> None:
self._len = 0
self._dict = FileBackedDict[_VT](
shared_connection=shared_connection,
tablename=tablename,
@ -468,6 +468,12 @@ class FileBackedList(Generic[_VT]):
or _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE,
)
if shared_connection:
shared_connection._dependent_objects.append(self)
# In case we're reusing an existing list, we need to run a query to get the length.
self._len = len(self._dict)
@property
def tablename(self) -> str:
return self._dict.tablename
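With FileBackedList now implementing Closeable and registering itself against the shared connection, closing the ConnectionWrapper cleans up every collection built on top of it, and a list created over an existing table recomputes its length instead of assuming it is empty. A small self-contained sketch (the temp path and table name are illustrative):

import pathlib
import tempfile

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

db_path = pathlib.Path(tempfile.mkdtemp()) / "example.db"
conn = ConnectionWrapper(filename=db_path)
items = FileBackedList[str](shared_connection=conn, tablename="items")
items.append("hello")
assert len(items) == 1

# ConnectionWrapper.close() now clears its dependent objects after closing them,
# so the list above is closed automatically along with the underlying sqlite connection.
conn.close()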

View File

@ -0,0 +1,10 @@
[
{
"query": "create table foo as select a, b from bar",
"session_id": null,
"timestamp": null,
"user": null,
"default_db": "dev",
"default_schema": "public"
}
]

View File

@ -83,7 +83,7 @@
"aspect": {
"json": {
"statement": {
"value": "create table #temp2 as select b, c from upstream2;\n\ncreate table #temp1 as select a, 2*b as b from upstream1;\n\ncreate temp table staging_foo as select up1.a, up1.b, up2.c from #temp1 up1 left join #temp2 up2 on up1.b = up2.b where up1.b > 0;\n\ninsert into table prod_foo\nselect * from staging_foo",
"value": "CREATE TABLE #temp2 AS\nSELECT\n b,\n c\nFROM upstream2;\n\nCREATE TABLE #temp1 AS\nSELECT\n a,\n 2 * b AS b\nFROM upstream1;\n\nCREATE TEMPORARY TABLE staging_foo AS\nSELECT\n up1.a,\n up1.b,\n up2.c\nFROM #temp1 AS up1\nLEFT JOIN #temp2 AS up2\n ON up1.b = up2.b\nWHERE\n up1.b > 0;\n\nINSERT INTO prod_foo\nSELECT\n *\nFROM staging_foo",
"language": "SQL"
},
"source": "SYSTEM",

View File

@ -13,6 +13,7 @@ from datahub.sql_parsing.sql_parsing_aggregator import (
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import ColumnLineageInfo, ColumnRef
from tests.test_helpers import mce_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
RESOURCE_DIR = pathlib.Path(__file__).parent / "aggregator_goldens"
FROZEN_TIME = "2024-02-06 01:23:45"
@ -23,12 +24,13 @@ def _ts(ts: int) -> datetime:
@freeze_time(FROZEN_TIME)
def test_basic_lineage(pytestconfig: pytest.Config) -> None:
def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
aggregator = SqlParsingAggregator(
platform="redshift",
generate_lineage=True,
generate_usage_statistics=False,
generate_operations=False,
query_log=QueryLogSetting.STORE_ALL,
)
aggregator.add_observed_query(
@ -45,6 +47,23 @@ def test_basic_lineage(pytestconfig: pytest.Config) -> None:
golden_path=RESOURCE_DIR / "test_basic_lineage.json",
)
# This test also validates the query log storage functionality.
aggregator.close()
query_log_db = aggregator.report.query_log_path
query_log_json = tmp_path / "query_log.json"
run_datahub_cmd(
[
"check",
"extract-sql-agg-log",
str(query_log_db),
"--output",
str(query_log_json),
]
)
mce_helpers.check_golden_file(
pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
)
@freeze_time(FROZEN_TIME)
def test_overlapping_inserts(pytestconfig: pytest.Config) -> None: