mirror of
https://github.com/datahub-project/datahub.git
synced 2025-08-31 04:25:29 +00:00
feat(ingest): utilities for query logs (#10036)
This commit is contained in:
parent
4535f2adfd
commit
b0163c4885
@ -99,7 +99,7 @@ usage_common = {
|
||||
sqlglot_lib = {
|
||||
# Using an Acryl fork of sqlglot.
|
||||
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
|
||||
"acryl-sqlglot==22.3.1.dev3",
|
||||
"acryl-sqlglot==22.4.1.dev4",
|
||||
}
|
||||
|
||||
classification_lib = {
|
||||
|
@ -1,4 +1,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import pprint
|
||||
import shutil
|
||||
import tempfile
|
||||
@ -17,6 +20,7 @@ from datahub.ingestion.sink.sink_registry import sink_registry
|
||||
from datahub.ingestion.source.source_registry import source_registry
|
||||
from datahub.ingestion.transformer.transform_registry import transform_registry
|
||||
from datahub.telemetry import telemetry
|
||||
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -339,3 +343,28 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
|
||||
f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@check.command()
|
||||
@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
|
||||
@click.option("--output", type=click.Path())
|
||||
def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
|
||||
"""Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
|
||||
|
||||
from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
|
||||
|
||||
assert dataclasses.is_dataclass(LoggedQuery)
|
||||
|
||||
shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
|
||||
query_log = FileBackedList[LoggedQuery](
|
||||
shared_connection=shared_connection, tablename="stored_queries"
|
||||
)
|
||||
logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
|
||||
queries = [dataclasses.asdict(query) for query in query_log]
|
||||
|
||||
if output:
|
||||
with open(output, "w") as f:
|
||||
json.dump(queries, f, indent=2)
|
||||
logger.info(f"Extracted {len(queries)} queries to {output}")
|
||||
else:
|
||||
click.echo(json.dumps(queries, indent=2))
|
||||
|
@ -1,8 +1,10 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import uuid
|
||||
@ -15,6 +17,7 @@ import datahub.metadata.schema_classes as models
|
||||
from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
|
||||
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
||||
from datahub.emitter.sql_parsing_builder import compute_upstream_fields
|
||||
from datahub.ingestion.api.closeable import Closeable
|
||||
from datahub.ingestion.api.report import Report
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
|
||||
from datahub.ingestion.graph.client import DataHubGraph
|
||||
@ -53,9 +56,6 @@ logger = logging.getLogger(__name__)
|
||||
QueryId = str
|
||||
UrnStr = str
|
||||
|
||||
_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
|
||||
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
|
||||
|
||||
|
||||
class QueryLogSetting(enum.Enum):
|
||||
DISABLED = "DISABLED"
|
||||
@ -63,6 +63,23 @@ class QueryLogSetting(enum.Enum):
|
||||
STORE_FAILED = "STORE_FAILED"
|
||||
|
||||
|
||||
_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
|
||||
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
|
||||
_DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
|
||||
os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
|
||||
]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LoggedQuery:
|
||||
query: str
|
||||
session_id: Optional[str]
|
||||
timestamp: Optional[datetime]
|
||||
user: Optional[UrnStr]
|
||||
default_db: Optional[str]
|
||||
default_schema: Optional[str]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ViewDefinition:
|
||||
view_definition: str
|
||||
@ -170,7 +187,7 @@ class SqlAggregatorReport(Report):
|
||||
return super().compute_stats()
|
||||
|
||||
|
||||
class SqlParsingAggregator:
|
||||
class SqlParsingAggregator(Closeable):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -185,7 +202,7 @@ class SqlParsingAggregator:
|
||||
usage_config: Optional[BaseUsageConfig] = None,
|
||||
is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
|
||||
format_queries: bool = True,
|
||||
query_log: QueryLogSetting = QueryLogSetting.DISABLED,
|
||||
query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
|
||||
) -> None:
|
||||
self.platform = DataPlatformUrn(platform)
|
||||
self.platform_instance = platform_instance
|
||||
@ -210,13 +227,18 @@ class SqlParsingAggregator:
|
||||
self.format_queries = format_queries
|
||||
self.query_log = query_log
|
||||
|
||||
# The exit stack helps ensure that we close all the resources we open.
|
||||
self._exit_stack = contextlib.ExitStack()
|
||||
|
||||
# Set up the schema resolver.
|
||||
self._schema_resolver: SchemaResolver
|
||||
if graph is None:
|
||||
self._schema_resolver = SchemaResolver(
|
||||
platform=self.platform.platform_name,
|
||||
platform_instance=self.platform_instance,
|
||||
env=self.env,
|
||||
self._schema_resolver = self._exit_stack.enter_context(
|
||||
SchemaResolver(
|
||||
platform=self.platform.platform_name,
|
||||
platform_instance=self.platform_instance,
|
||||
env=self.env,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._schema_resolver = None # type: ignore
|
||||
@ -235,27 +257,33 @@ class SqlParsingAggregator:
|
||||
|
||||
# By providing a filename explicitly here, we also ensure that the file
|
||||
# is not automatically deleted on exit.
|
||||
self._shared_connection = ConnectionWrapper(filename=query_log_path)
|
||||
self._shared_connection = self._exit_stack.enter_context(
|
||||
ConnectionWrapper(filename=query_log_path)
|
||||
)
|
||||
|
||||
# Stores the logged queries.
|
||||
self._logged_queries = FileBackedList[str](
|
||||
self._logged_queries = FileBackedList[LoggedQuery](
|
||||
shared_connection=self._shared_connection, tablename="stored_queries"
|
||||
)
|
||||
self._exit_stack.push(self._logged_queries)
|
||||
|
||||
# Map of query_id -> QueryMetadata
|
||||
self._query_map = FileBackedDict[QueryMetadata](
|
||||
shared_connection=self._shared_connection, tablename="query_map"
|
||||
)
|
||||
self._exit_stack.push(self._query_map)
|
||||
|
||||
# Map of downstream urn -> { query ids }
|
||||
self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
|
||||
shared_connection=self._shared_connection, tablename="lineage_map"
|
||||
)
|
||||
self._exit_stack.push(self._lineage_map)
|
||||
|
||||
# Map of view urn -> view definition
|
||||
self._view_definitions = FileBackedDict[ViewDefinition](
|
||||
shared_connection=self._shared_connection, tablename="view_definitions"
|
||||
)
|
||||
self._exit_stack.push(self._view_definitions)
|
||||
|
||||
# Map of session ID -> {temp table name -> query id}
|
||||
# Needs to use the query_map to find the info about the query.
|
||||
@ -263,16 +291,20 @@ class SqlParsingAggregator:
|
||||
self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
|
||||
shared_connection=self._shared_connection, tablename="temp_lineage_map"
|
||||
)
|
||||
self._exit_stack.push(self._temp_lineage_map)
|
||||
|
||||
# Map of query ID -> schema fields, only for query IDs that generate temp tables.
|
||||
self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
|
||||
shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
|
||||
shared_connection=self._shared_connection,
|
||||
tablename="inferred_temp_schemas",
|
||||
)
|
||||
self._exit_stack.push(self._inferred_temp_schemas)
|
||||
|
||||
# Map of table renames, from original UrnStr to new UrnStr.
|
||||
self._table_renames = FileBackedDict[UrnStr](
|
||||
shared_connection=self._shared_connection, tablename="table_renames"
|
||||
)
|
||||
self._exit_stack.push(self._table_renames)
|
||||
|
||||
# Usage aggregator. This will only be initialized if usage statistics are enabled.
|
||||
# TODO: Replace with FileBackedDict.
|
||||
@ -281,6 +313,9 @@ class SqlParsingAggregator:
|
||||
assert self.usage_config is not None
|
||||
self._usage_aggregator = UsageAggregator(config=self.usage_config)
|
||||
|
||||
def close(self) -> None:
|
||||
self._exit_stack.close()
|
||||
|
||||
@property
|
||||
def _need_schemas(self) -> bool:
|
||||
return self.generate_lineage or self.generate_usage_statistics
|
||||
@ -499,6 +534,9 @@ class SqlParsingAggregator:
|
||||
default_db=default_db,
|
||||
default_schema=default_schema,
|
||||
schema_resolver=schema_resolver,
|
||||
session_id=session_id,
|
||||
timestamp=query_timestamp,
|
||||
user=user,
|
||||
)
|
||||
if parsed.debug_info.error:
|
||||
self.report.observed_query_parse_failures.append(
|
||||
@ -700,6 +738,9 @@ class SqlParsingAggregator:
|
||||
default_db: Optional[str],
|
||||
default_schema: Optional[str],
|
||||
schema_resolver: SchemaResolverInterface,
|
||||
session_id: str = _MISSING_SESSION_ID,
|
||||
timestamp: Optional[datetime] = None,
|
||||
user: Optional[CorpUserUrn] = None,
|
||||
) -> SqlParsingResult:
|
||||
parsed = sqlglot_lineage(
|
||||
query,
|
||||
@ -712,7 +753,15 @@ class SqlParsingAggregator:
|
||||
if self.query_log == QueryLogSetting.STORE_ALL or (
|
||||
self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
|
||||
):
|
||||
self._logged_queries.append(query)
|
||||
query_log_entry = LoggedQuery(
|
||||
query=query,
|
||||
session_id=session_id if session_id != _MISSING_SESSION_ID else None,
|
||||
timestamp=timestamp,
|
||||
user=user.urn() if user else None,
|
||||
default_db=default_db,
|
||||
default_schema=default_schema,
|
||||
)
|
||||
self._logged_queries.append(query_log_entry)
|
||||
|
||||
# Also add some extra logging.
|
||||
if parsed.debug_info.error:
|
||||
|
@ -62,9 +62,13 @@ def assert_metadata_files_equal(
|
||||
# We have to "normalize" the golden file by reading and writing it back out.
|
||||
# This will clean up nulls, double serialization, and other formatting issues.
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
|
||||
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
|
||||
golden = load_json_file(temp.name)
|
||||
try:
|
||||
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
|
||||
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
|
||||
golden = load_json_file(temp.name)
|
||||
except (ValueError, AssertionError) as e:
|
||||
logger.info(f"Error reformatting golden file as MCP/MCEs: {e}")
|
||||
golden = load_json_file(golden_path)
|
||||
|
||||
diff = diff_metadata_json(output, golden, ignore_paths, ignore_order=ignore_order)
|
||||
if diff and update_golden:
|
||||
@ -107,7 +111,7 @@ def diff_metadata_json(
|
||||
# if ignore_order is False, always use DeepDiff
|
||||
except CannotCompareMCPs as e:
|
||||
logger.info(f"{e}, falling back to MCE diff")
|
||||
except AssertionError as e:
|
||||
except (AssertionError, ValueError) as e:
|
||||
logger.warning(f"Reverting to old diff method: {e}")
|
||||
logger.debug("Error with new diff method", exc_info=True)
|
||||
|
||||
|
@ -126,6 +126,7 @@ class ConnectionWrapper:
|
||||
def close(self) -> None:
|
||||
for obj in self._dependent_objects:
|
||||
obj.close()
|
||||
self._dependent_objects.clear()
|
||||
with self.conn_lock:
|
||||
self.conn.close()
|
||||
if self._temp_directory:
|
||||
@ -440,7 +441,7 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]):
|
||||
self.close()
|
||||
|
||||
|
||||
class FileBackedList(Generic[_VT]):
|
||||
class FileBackedList(Generic[_VT], Closeable):
|
||||
"""An append-only, list-like object that stores its contents in a SQLite database."""
|
||||
|
||||
_len: int = field(default=0)
|
||||
@ -456,7 +457,6 @@ class FileBackedList(Generic[_VT]):
|
||||
cache_max_size: Optional[int] = None,
|
||||
cache_eviction_batch_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self._len = 0
|
||||
self._dict = FileBackedDict[_VT](
|
||||
shared_connection=shared_connection,
|
||||
tablename=tablename,
|
||||
@ -468,6 +468,12 @@ class FileBackedList(Generic[_VT]):
|
||||
or _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE,
|
||||
)
|
||||
|
||||
if shared_connection:
|
||||
shared_connection._dependent_objects.append(self)
|
||||
|
||||
# In case we're reusing an existing list, we need to run a query to get the length.
|
||||
self._len = len(self._dict)
|
||||
|
||||
@property
|
||||
def tablename(self) -> str:
|
||||
return self._dict.tablename
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,10 @@
|
||||
[
|
||||
{
|
||||
"query": "create table foo as select a, b from bar",
|
||||
"session_id": null,
|
||||
"timestamp": null,
|
||||
"user": null,
|
||||
"default_db": "dev",
|
||||
"default_schema": "public"
|
||||
}
|
||||
]
|
@ -83,7 +83,7 @@
|
||||
"aspect": {
|
||||
"json": {
|
||||
"statement": {
|
||||
"value": "create table #temp2 as select b, c from upstream2;\n\ncreate table #temp1 as select a, 2*b as b from upstream1;\n\ncreate temp table staging_foo as select up1.a, up1.b, up2.c from #temp1 up1 left join #temp2 up2 on up1.b = up2.b where up1.b > 0;\n\ninsert into table prod_foo\nselect * from staging_foo",
|
||||
"value": "CREATE TABLE #temp2 AS\nSELECT\n b,\n c\nFROM upstream2;\n\nCREATE TABLE #temp1 AS\nSELECT\n a,\n 2 * b AS b\nFROM upstream1;\n\nCREATE TEMPORARY TABLE staging_foo AS\nSELECT\n up1.a,\n up1.b,\n up2.c\nFROM #temp1 AS up1\nLEFT JOIN #temp2 AS up2\n ON up1.b = up2.b\nWHERE\n up1.b > 0;\n\nINSERT INTO prod_foo\nSELECT\n *\nFROM staging_foo",
|
||||
"language": "SQL"
|
||||
},
|
||||
"source": "SYSTEM",
|
||||
|
@ -13,6 +13,7 @@ from datahub.sql_parsing.sql_parsing_aggregator import (
|
||||
from datahub.sql_parsing.sql_parsing_common import QueryType
|
||||
from datahub.sql_parsing.sqlglot_lineage import ColumnLineageInfo, ColumnRef
|
||||
from tests.test_helpers import mce_helpers
|
||||
from tests.test_helpers.click_helpers import run_datahub_cmd
|
||||
|
||||
RESOURCE_DIR = pathlib.Path(__file__).parent / "aggregator_goldens"
|
||||
FROZEN_TIME = "2024-02-06 01:23:45"
|
||||
@ -23,12 +24,13 @@ def _ts(ts: int) -> datetime:
|
||||
|
||||
|
||||
@freeze_time(FROZEN_TIME)
|
||||
def test_basic_lineage(pytestconfig: pytest.Config) -> None:
|
||||
def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
|
||||
aggregator = SqlParsingAggregator(
|
||||
platform="redshift",
|
||||
generate_lineage=True,
|
||||
generate_usage_statistics=False,
|
||||
generate_operations=False,
|
||||
query_log=QueryLogSetting.STORE_ALL,
|
||||
)
|
||||
|
||||
aggregator.add_observed_query(
|
||||
@ -45,6 +47,23 @@ def test_basic_lineage(pytestconfig: pytest.Config) -> None:
|
||||
golden_path=RESOURCE_DIR / "test_basic_lineage.json",
|
||||
)
|
||||
|
||||
# This test also validates the query log storage functionality.
|
||||
aggregator.close()
|
||||
query_log_db = aggregator.report.query_log_path
|
||||
query_log_json = tmp_path / "query_log.json"
|
||||
run_datahub_cmd(
|
||||
[
|
||||
"check",
|
||||
"extract-sql-agg-log",
|
||||
str(query_log_db),
|
||||
"--output",
|
||||
str(query_log_json),
|
||||
]
|
||||
)
|
||||
mce_helpers.check_golden_file(
|
||||
pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
|
||||
)
|
||||
|
||||
|
||||
@freeze_time(FROZEN_TIME)
|
||||
def test_overlapping_inserts(pytestconfig: pytest.Config) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user