feat(ingest/unity): Add usage extraction; add TableReference (#7910)

- Adds usage extraction to the Unity Catalog source and a TableReference object for handling references to tables.
Also makes the following refactors:
- Adds a UsageAggregator class to usage_common, since the same aggregation logic already appears in multiple sources.
- Allows a customizable user_urn_builder in usage_common, as not all Unity Catalog users are emails. Other connectors like Redshift and Snowflake construct emails from a default email_domain config, which may no longer be necessary.
- Creates TableReference for Unity Catalog and adds it to the Table dataclass to manage string references to tables; replaces existing logic, especially in lineage extraction, with these references (see the sketch below).
- Adds gen_dataset_urn and gen_user_urn to the Unity Catalog source to reduce duplicated code.
- Breaks up proxy.py into implementation and types modules.
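
For illustration, a minimal sketch of the new TableReference in isolation, using hypothetical metastore/catalog/schema/table names; the string form is what gen_dataset_urn uses as the dataset name, while qualified_table_name is what gets checked against table_pattern:

```python
from datahub.ingestion.source.unity.proxy_types import TableReference

# Hypothetical identifiers; field order matches the frozen dataclass added below.
ref = TableReference(
    metastore_id="metastore-1234",
    catalog="main",
    schema="sales",
    table="orders",
)

assert str(ref) == "metastore-1234.main.sales.orders"   # dataset name used when building URNs
assert ref.qualified_table_name == "main.sales.orders"  # name checked against table_pattern
```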
Andrew Sikowitz 2023-05-01 14:30:09 -04:00 committed by GitHub
parent cd05f5b174
commit 5b290c9bc5
13 changed files with 719 additions and 184 deletions


@ -1,4 +1,9 @@
### Prerequisites
- Generate a Databricks Personal Access token following the guide here: https://docs.databricks.com/dev-tools/api/latest/authentication.html#generate-a-personal-access-token
- Get your workspace Id where Unity Catalog is following: https://docs.databricks.com/workspace/workspace-details.html#workspace-instance-names-urls-and-ids
- Check the starter recipe below and replace Token and Workspace Id with the ones above.
- Generate a Databricks Personal Access token following the guide here:
https://docs.databricks.com/dev-tools/api/latest/authentication.html#generate-a-personal-access-token
- Get your catalog's workspace id by following:
https://docs.databricks.com/workspace/workspace-details.html#workspace-instance-names-urls-and-ids
- To enable usage ingestion, ensure the account associated with your access token has
`CAN_MANAGE` permissions on any SQL Warehouses you want to ingest:
https://docs.databricks.com/security/auth-authz/access-control/sql-endpoint-acl.html
- Check the starter recipe below and replace Token and Workspace id with the ones above.
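
As a hedged supplement to the starter recipe (not reproduced here), the same settings can be expressed directly against the source config class added in this commit; the token and workspace URL below are placeholders:

```python
from datetime import datetime, timedelta, timezone

from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig

config = UnityCatalogSourceConfig.parse_obj(
    {
        "token": "<personal-access-token>",  # placeholder
        "workspace_url": "https://<workspace-id>.cloud.databricks.com",  # placeholder
        "include_usage_statistics": True,
        # Usage comes from the query history API, which retains only 30 days of queries,
        # so start_time must be within the last 30 days.
        "start_time": datetime.now(timezone.utc) - timedelta(days=7),
    }
)
```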


@ -240,9 +240,7 @@ usage_common = {
"sqlparse",
}
databricks_cli = {
"databricks-cli==0.17.3",
}
databricks_cli = {"databricks-cli==0.17.3", "pyspark"}
# Note: for all of these, framework_common will be added.
plugins: Dict[str, Set[str]] = {


@ -451,7 +451,9 @@ class BigQueryUsageExtractor:
user_freq=entry.user_freq,
column_freq=entry.column_freq,
bucket_duration=self.config.bucket_duration,
urn_builder=lambda resource: resource.to_urn(self.config.env),
resource_urn_builder=lambda resource: resource.to_urn(
self.config.env
),
top_n_queries=self.config.usage.top_n_queries,
format_sql_queries=self.config.usage.format_sql_queries,
)


@ -1,3 +1,4 @@
from datetime import datetime, timezone
from typing import Dict, Optional
import pydantic
@ -12,9 +13,12 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionConfigBase,
)
from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
class UnityCatalogSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin):
class UnityCatalogSourceConfig(
StatefulIngestionConfigBase, BaseUsageConfig, DatasetSourceConfigMixin
):
token: str = pydantic.Field(description="Databricks personal access token")
workspace_url: str = pydantic.Field(description="Databricks workspace url")
workspace_name: Optional[str] = pydantic.Field(
@ -65,6 +69,17 @@ class UnityCatalogSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigM
description="Option to enable/disable lineage generation. Currently we have to call a rest call per column to get column level lineage due to the Databrick api which can slow down ingestion. ",
)
include_usage_statistics: bool = Field(
default=True,
description="Generate usage statistics.",
)
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field(
default=None, description="Unity Catalog Stateful Ingestion Config."
)
@pydantic.validator("start_time")
def within_thirty_days(cls, v: datetime) -> datetime:
if (datetime.now(timezone.utc) - v).days > 30:
raise ValueError("Query history is only maintained for 30 days.")
return v


@ -1,128 +1,32 @@
"""
Manage the communication with DataBricks Server and provide equivalent dataclasses for dependent modules
"""
import datetime
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional
from databricks_cli.sdk.api_client import ApiClient
from databricks_cli.unity_catalog.api import UnityCatalogApi
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.metadata.schema_classes import (
ArrayTypeClass,
BooleanTypeClass,
BytesTypeClass,
DateTypeClass,
MapTypeClass,
NullTypeClass,
NumberTypeClass,
RecordTypeClass,
SchemaFieldDataTypeClass,
StringTypeClass,
TimeTypeClass,
from datahub.ingestion.source.unity.proxy_types import (
ALLOWED_STATEMENT_TYPES,
DATA_TYPE_REGISTRY,
Catalog,
Column,
Metastore,
Query,
QueryStatus,
Schema,
ServicePrincipal,
StatementType,
Table,
TableReference,
)
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.metadata.schema_classes import SchemaFieldDataTypeClass
logger: logging.Logger = logging.getLogger(__name__)
# Supported types are available at
# https://api-docs.databricks.com/rest/latest/unity-catalog-api-specification-2-1.html?_ga=2.151019001.1795147704.1666247755-2119235717.1666247755
DATA_TYPE_REGISTRY: dict = {
"BOOLEAN": BooleanTypeClass,
"BYTE": BytesTypeClass,
"DATE": DateTypeClass,
"SHORT": NumberTypeClass,
"INT": NumberTypeClass,
"LONG": NumberTypeClass,
"FLOAT": NumberTypeClass,
"DOUBLE": NumberTypeClass,
"TIMESTAMP": TimeTypeClass,
"STRING": StringTypeClass,
"BINARY": BytesTypeClass,
"DECIMAL": NumberTypeClass,
"INTERVAL": TimeTypeClass,
"ARRAY": ArrayTypeClass,
"STRUCT": RecordTypeClass,
"MAP": MapTypeClass,
"CHAR": StringTypeClass,
"NULL": NullTypeClass,
}
@dataclass
class CommonProperty:
id: str
name: str
type: str
comment: Optional[str]
@dataclass
class Metastore(CommonProperty):
metastore_id: str
owner: Optional[str]
@dataclass
class Catalog(CommonProperty):
metastore: Metastore
owner: Optional[str]
@dataclass
class Schema(CommonProperty):
catalog: Catalog
owner: Optional[str]
@dataclass
class Column(CommonProperty):
type_text: str
type_name: SchemaFieldDataTypeClass
type_precision: int
type_scale: int
position: int
nullable: bool
comment: Optional[str]
@dataclass
class ColumnLineage:
source: str
destination: str
@dataclass
class ServicePrincipal:
id: str
application_id: str # uuid used to reference the service principal
display_name: str
active: Optional[bool]
@dataclass
class Table(CommonProperty):
schema: Schema
columns: List[Column]
storage_location: Optional[str]
data_source_format: Optional[str]
comment: Optional[str]
table_type: str
owner: Optional[str]
generation: int
created_at: datetime.datetime
created_by: str
updated_at: Optional[datetime.datetime]
updated_by: Optional[str]
table_id: str
view_definition: Optional[str]
properties: Dict[str, str]
upstreams: Dict[str, Dict[str, List[str]]] = field(default_factory=dict)
# lineage: Optional[Lineage]
class UnityCatalogApiProxy:
_unity_catalog_api: UnityCatalogApi
@ -197,7 +101,9 @@ class UnityCatalogApiProxy:
)
if response.get("tables") is None:
logger.info(f"Tables not found for schema {schema.name}")
logger.info(
f"Tables not found for schema {schema.catalog.name}.{schema.name}"
)
return []
for table in response["tables"]:
@ -217,6 +123,38 @@ class UnityCatalogApiProxy:
for principal in response["Resources"]:
yield self._create_service_principal(principal)
def query_history(
self,
start_time: datetime,
end_time: datetime,
) -> Iterable[Query]:
# This is a _complete_ hack. The source code of perform_query
# bundles data into query params if method == "GET", but we need it passed as the body.
# To get around this, we set method to lowercase "get".
# I still prefer this over duplicating the code in perform_query.
method = "get"
path = "/sql/history/queries"
data: Dict[str, Any] = {
"include_metrics": False,
"max_results": 1000, # Max batch size
}
filter_by = {
"query_start_time_range": {
"start_time_ms": start_time.timestamp() * 1000,
"end_time_ms": end_time.timestamp() * 1000,
},
"statuses": [QueryStatus.FINISHED.value],
"statement_types": list(ALLOWED_STATEMENT_TYPES),
}
response: dict = self._unity_catalog_api.client.client.perform_query(
method, path, {**data, "filter_by": filter_by}
)
yield from self._create_queries(response["res"])
while response["has_next_page"]:
response = self._unity_catalog_api.client.client.perform_query(
method, path, {**data, "next_page_token": response["next_page_token"]}
)
yield from self._create_queries(response["res"])
def list_lineages_by_table(self, table_name=None, headers=None):
"""
List table lineage by table name
@ -259,7 +197,12 @@ class UnityCatalogApiProxy:
table_name=f"{table.schema.catalog.name}.{table.schema.name}.{table.name}"
)
table.upstreams = {
f"{item['catalog_name']}.{item['schema_name']}.{item['name']}": {}
TableReference(
table.schema.catalog.metastore.id,
item["catalog_name"],
item["schema_name"],
item["name"],
): {}
for item in response.get("upstream_tables", [])
}
except Exception as e:
@ -277,17 +220,15 @@ class UnityCatalogApiProxy:
column_name=column.name,
)
for item in response.get("upstream_cols", []):
table_name = f"{item['catalog_name']}.{item['schema_name']}.{item['table_name']}"
col_name = item["name"]
if not table.upstreams.get(table_name):
table.upstreams[table_name] = {column.name: [col_name]}
else:
if column.name in table.upstreams[table_name]:
table.upstreams[table_name][column.name].append(
col_name
)
else:
table.upstreams[table_name][column.name] = [col_name]
table_ref = TableReference(
table.schema.catalog.metastore.id,
item["catalog_name"],
item["schema_name"],
item["table_name"],
)
table.upstreams.setdefault(table_ref, {}).setdefault(
column.name, []
).append(item["name"])
except Exception as e:
logger.error(f"Error getting lineage: {e}")
@ -366,9 +307,9 @@ class UnityCatalogApiProxy:
properties=obj.get("properties", {}),
owner=obj.get("owner"),
generation=obj["generation"],
created_at=datetime.datetime.utcfromtimestamp(obj["created_at"] / 1000),
created_at=datetime.utcfromtimestamp(obj["created_at"] / 1000),
created_by=obj["created_by"],
updated_at=datetime.datetime.utcfromtimestamp(obj["updated_at"] / 1000)
updated_at=datetime.utcfromtimestamp(obj["updated_at"] / 1000)
if "updated_at" in obj
else None,
updated_by=obj.get("updated_by", None),
@ -384,3 +325,25 @@ class UnityCatalogApiProxy:
application_id=obj["applicationId"],
active=obj.get("active"),
)
def _create_queries(self, lst: List[dict]) -> Iterable[Query]:
for obj in lst:
try:
yield self._create_query(obj)
except Exception as e:
logger.warning(f"Error parsing query: {e}")
self.report.report_warning("query-parse", str(e))
@staticmethod
def _create_query(obj: dict) -> Query:
return Query(
query_id=obj["query_id"],
query_text=obj["query_text"],
statement_type=StatementType(obj["statement_type"]),
start_time=datetime.utcfromtimestamp(obj["query_start_time_ms"] / 1000),
end_time=datetime.utcfromtimestamp(obj["query_end_time_ms"] / 1000),
user_id=obj["user_id"],
user_name=obj["user_name"],
executed_as_user_id=obj["executed_as_user_id"],
executed_as_user_name=obj["executed_as_user_name"],
)


@ -0,0 +1,208 @@
# Supported types are available at
# https://api-docs.databricks.com/rest/latest/unity-catalog-api-specification-2-1.html?_ga=2.151019001.1795147704.1666247755-2119235717.1666247755
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from datahub.metadata.schema_classes import (
ArrayTypeClass,
BooleanTypeClass,
BytesTypeClass,
DateTypeClass,
MapTypeClass,
NullTypeClass,
NumberTypeClass,
OperationTypeClass,
RecordTypeClass,
SchemaFieldDataTypeClass,
StringTypeClass,
TimeTypeClass,
)
DATA_TYPE_REGISTRY: dict = {
"BOOLEAN": BooleanTypeClass,
"BYTE": BytesTypeClass,
"DATE": DateTypeClass,
"SHORT": NumberTypeClass,
"INT": NumberTypeClass,
"LONG": NumberTypeClass,
"FLOAT": NumberTypeClass,
"DOUBLE": NumberTypeClass,
"TIMESTAMP": TimeTypeClass,
"STRING": StringTypeClass,
"BINARY": BytesTypeClass,
"DECIMAL": NumberTypeClass,
"INTERVAL": TimeTypeClass,
"ARRAY": ArrayTypeClass,
"STRUCT": RecordTypeClass,
"MAP": MapTypeClass,
"CHAR": StringTypeClass,
"NULL": NullTypeClass,
}
class StatementType(str, Enum):
OTHER = "OTHER"
ALTER = "ALTER"
ANALYZE = "ANALYZE"
COPY = "COPY"
CREATE = "CREATE"
DELETE = "DELETE"
DESCRIBE = "DESCRIBE"
DROP = "DROP"
EXPLAIN = "EXPLAIN"
GRANT = "GRANT"
INSERT = "INSERT"
MERGE = "MERGE"
OPTIMIZE = "OPTIMIZE"
REFRESH = "REFRESH"
REPLACE = "REPLACE"
REVOKE = "REVOKE"
SELECT = "SELECT"
SET = "SET"
SHOW = "SHOW"
TRUNCATE = "TRUNCATE"
UPDATE = "UPDATE"
USE = "USE"
# Besides the operation statement types below, only SELECT statements are parsed
OPERATION_STATEMENT_TYPES = {
StatementType.INSERT: OperationTypeClass.INSERT,
StatementType.COPY: OperationTypeClass.INSERT,
StatementType.UPDATE: OperationTypeClass.UPDATE,
StatementType.MERGE: OperationTypeClass.UPDATE,
StatementType.DELETE: OperationTypeClass.DELETE,
StatementType.TRUNCATE: OperationTypeClass.DELETE,
StatementType.CREATE: OperationTypeClass.CREATE,
StatementType.REPLACE: OperationTypeClass.CREATE,
StatementType.ALTER: OperationTypeClass.ALTER,
StatementType.DROP: OperationTypeClass.DROP,
StatementType.OTHER: OperationTypeClass.UNKNOWN,
}
ALLOWED_STATEMENT_TYPES = {*OPERATION_STATEMENT_TYPES.keys(), StatementType.SELECT}
@dataclass
class CommonProperty:
id: str
name: str
type: str
comment: Optional[str]
@dataclass
class Metastore(CommonProperty):
metastore_id: str
owner: Optional[str]
@dataclass
class Catalog(CommonProperty):
metastore: Metastore
owner: Optional[str]
@dataclass
class Schema(CommonProperty):
catalog: Catalog
owner: Optional[str]
@dataclass
class Column(CommonProperty):
type_text: str
type_name: SchemaFieldDataTypeClass
type_precision: int
type_scale: int
position: int
nullable: bool
comment: Optional[str]
@dataclass
class ColumnLineage:
source: str
destination: str
@dataclass
class ServicePrincipal:
id: str
application_id: str # uuid used to reference the service principal
display_name: str
active: Optional[bool]
@dataclass(frozen=True, order=True)
class TableReference:
metastore_id: str
catalog: str
schema: str
table: str
@classmethod
def create(cls, table: "Table") -> "TableReference":
return cls(
table.schema.catalog.metastore.id,
table.schema.catalog.name,
table.schema.name,
table.name,
)
def __str__(self) -> str:
return f"{self.metastore_id}.{self.catalog}.{self.schema}.{self.table}"
@property
def qualified_table_name(self) -> str:
return f"{self.catalog}.{self.schema}.{self.table}"
@dataclass
class Table(CommonProperty):
schema: Schema
columns: List[Column]
storage_location: Optional[str]
data_source_format: Optional[str]
comment: Optional[str]
table_type: str
owner: Optional[str]
generation: int
created_at: datetime
created_by: str
updated_at: Optional[datetime]
updated_by: Optional[str]
table_id: str
view_definition: Optional[str]
properties: Dict[str, str]
upstreams: Dict[TableReference, Dict[str, List[str]]] = field(default_factory=dict)
ref: TableReference = field(init=False)
# lineage: Optional[Lineage]
def __post_init__(self):
self.ref = TableReference.create(self)
class QueryStatus(str, Enum):
FINISHED = "FINISHED"
RUNNING = "RUNNING"
QUEUED = "QUEUED"
FAILED = "FAILED"
CANCELED = "CANCELED"
@dataclass
class Query:
query_id: str
query_text: str
statement_type: StatementType
start_time: datetime
end_time: datetime
# User who ran the query
user_id: str
user_name: str # Email or username
# User whose credentials were used to run the query
executed_as_user_id: str
executed_as_user_name: str


@ -12,3 +12,12 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport):
catalogs: EntityFilterReport = EntityFilterReport.field(type="catalog")
schemas: EntityFilterReport = EntityFilterReport.field(type="schema")
tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
num_queries: int = 0
num_queries_dropped_parse_failure: int = 0
num_queries_dropped_missing_table: int = 0 # Can be due to pattern filter
num_queries_dropped_duplicate_table: int = 0
num_queries_parsed_by_spark_plan: int = 0
num_operational_stats_workunits_emitted: int = 0
num_usage_workunits_emitted: int = 0


@ -1,7 +1,7 @@
import logging
import re
import time
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Set
from datahub.emitter.mce_builder import (
make_data_platform_urn,
@ -53,7 +53,9 @@ from datahub.ingestion.source.unity.proxy import (
Schema,
ServicePrincipal,
)
from datahub.ingestion.source.unity.proxy_types import TableReference
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.ingestion.source.unity.usage import UnityCatalogUsageExtractor
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
FineGrainedLineage,
FineGrainedLineageUpstreamType,
@ -93,10 +95,11 @@ logger: logging.Logger = logging.getLogger(__name__)
@capability(SourceCapability.DESCRIPTIONS, "Enabled by default")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
@capability(SourceCapability.LINEAGE_FINE, "Enabled by default")
@capability(SourceCapability.USAGE_STATS, "Enabled by default")
@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default")
@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field")
@capability(SourceCapability.CONTAINERS, "Enabled by default")
@capability(SourceCapability.OWNERSHIP, "Supported via the `include_ownership` configs")
@capability(SourceCapability.OWNERSHIP, "Supported via the `include_ownership` config")
@capability(
SourceCapability.DELETION_DETECTION,
"Optionally enabled via `stateful_ingestion.remove_stale_metadata`",
@ -151,6 +154,8 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
# Global map of service principal application id -> ServicePrincipal
self.service_principals: Dict[str, ServicePrincipal] = {}
# Global set of table refs
self.table_refs: Set[TableReference] = set()
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
@ -190,6 +195,16 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.build_service_principal_map()
yield from self.process_metastores()
if self.config.include_usage_statistics:
usage_extractor = UnityCatalogUsageExtractor(
config=self.config,
report=self.report,
proxy=self.unity_catalog_api_proxy,
table_urn_builder=self.gen_dataset_urn,
user_urn_builder=self.gen_user_urn,
)
yield from usage_extractor.run(self.table_refs)
def build_service_principal_map(self) -> None:
try:
for sp in self.unity_catalog_api_proxy.service_principals():
@ -244,26 +259,18 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
def process_tables(self, schema: proxy.Schema) -> Iterable[MetadataWorkUnit]:
for table in self.unity_catalog_api_proxy.tables(schema=schema):
filter_table_name = (
f"{table.schema.catalog.name}.{table.schema.name}.{table.name}"
)
if not self.config.table_pattern.allowed(filter_table_name):
if not self.config.table_pattern.allowed(table.ref.qualified_table_name):
self.report.tables.dropped(table.id, type=table.type)
continue
self.table_refs.add(table.ref)
yield from self.process_table(table, schema)
self.report.tables.processed(table.id, type=table.type)
def process_table(
self, table: proxy.Table, schema: proxy.Schema
) -> Iterable[MetadataWorkUnit]:
dataset_urn: str = make_dataset_urn_with_platform_instance(
platform=self.platform,
platform_instance=self.platform_instance_name,
name=table.id,
)
dataset_urn = self.gen_dataset_urn(table.ref)
yield from self.add_table_to_dataset_container(dataset_urn, schema)
table_props = self._create_table_property_aspect(table)
@ -314,24 +321,23 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
) -> Optional[UpstreamLineageClass]:
upstreams: List[UpstreamClass] = []
finegrained_lineages: List[FineGrainedLineage] = []
for upstream in sorted(table.upstreams.keys()):
upstream_urn = make_dataset_urn_with_platform_instance(
self.platform,
f"{table.schema.catalog.metastore.id}.{upstream}",
self.platform_instance_name,
)
for upstream_ref, downstream_to_upstream_cols in sorted(
table.upstreams.items()
):
upstream_urn = self.gen_dataset_urn(upstream_ref)
for col in sorted(table.upstreams[upstream].keys()):
fl = FineGrainedLineage(
finegrained_lineages.extend(
FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=[
make_schema_field_urn(upstream_urn, upstream_col)
for upstream_col in sorted(table.upstreams[upstream][col])
for upstream_col in sorted(u_cols)
],
downstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
downstreams=[make_schema_field_urn(dataset_urn, col)],
downstreams=[make_schema_field_urn(dataset_urn, d_col)],
)
finegrained_lineages.append(fl)
for d_col, u_cols in sorted(downstream_to_upstream_cols.items())
)
upstream_table = UpstreamClass(
upstream_urn,
@ -374,13 +380,23 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
return None
return DomainsClass(domains=[domain_urn])
def gen_user_urn(self, user: Optional[str]) -> Optional[str]:
def get_owner_urn(self, user: Optional[str]) -> Optional[str]:
if self.config.include_ownership and user is not None:
if user in self.service_principals:
user = self.service_principals[user].display_name
return make_user_urn(user)
return self.gen_user_urn(user)
return None
def gen_user_urn(self, user: str) -> str:
if user in self.service_principals:
user = self.service_principals[user].display_name
return make_user_urn(user)
def gen_dataset_urn(self, table_ref: TableReference) -> str:
return make_dataset_urn_with_platform_instance(
platform=self.platform,
platform_instance=self.platform_instance_name,
name=str(table_ref),
)
def gen_schema_containers(self, schema: Schema) -> Iterable[MetadataWorkUnit]:
domain_urn = self._gen_domain_urn(f"{schema.catalog.name}.{schema.name}")
@ -392,7 +408,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
parent_container_key=self.gen_catalog_key(catalog=schema.catalog),
domain_urn=domain_urn,
description=schema.comment,
owner_urn=self.gen_user_urn(schema.owner),
owner_urn=self.get_owner_urn(schema.owner),
)
def gen_metastore_containers(
@ -407,7 +423,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
sub_types=[DatasetContainerSubTypes.DATABRICKS_METASTORE],
domain_urn=domain_urn,
description=metastore.comment,
owner_urn=self.gen_user_urn(metastore.owner),
owner_urn=self.get_owner_urn(metastore.owner),
)
def gen_catalog_containers(self, catalog: Catalog) -> Iterable[MetadataWorkUnit]:
@ -422,7 +438,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
domain_urn=domain_urn,
parent_container_key=metastore_container_key,
description=catalog.comment,
owner_urn=self.gen_user_urn(catalog.owner),
owner_urn=self.get_owner_urn(catalog.owner),
)
def gen_schema_key(self, schema: Schema) -> PlatformKey:
@ -540,7 +556,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
def _create_table_ownership_aspect(
self, table: proxy.Table
) -> Optional[OwnershipClass]:
owner_urn = self.gen_user_urn(table.owner)
owner_urn = self.get_owner_urn(table.owner)
if owner_urn is not None:
return OwnershipClass(
owners=[


@ -0,0 +1,217 @@
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Set, TypeVar
import pyspark
from sqllineage.runner import LineageRunner
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
from datahub.ingestion.source.unity.proxy_types import (
OPERATION_STATEMENT_TYPES,
Query,
StatementType,
TableReference,
)
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.ingestion.source.usage.usage_common import UsageAggregator
from datahub.metadata.schema_classes import OperationClass
logger = logging.getLogger(__name__)
TableMap = Dict[str, List[TableReference]]
T = TypeVar("T")
@dataclass # Dataclass over NamedTuple to support generic type annotations
class GenericTableInfo(Generic[T]):
source_tables: List[T]
target_tables: List[T]
StringTableInfo = GenericTableInfo[str]
QueryTableInfo = GenericTableInfo[TableReference]
@dataclass(eq=False)
class UnityCatalogUsageExtractor:
config: UnityCatalogSourceConfig
report: UnityCatalogReport
proxy: UnityCatalogApiProxy
table_urn_builder: Callable[[TableReference], str]
user_urn_builder: Callable[[str], str]
def __post_init__(self):
self.usage_aggregator = UsageAggregator[TableReference](self.config)
self._spark_sql_parser: Optional[Any] = None
@property
def spark_sql_parser(self):
"""Lazily initializes the Spark SQL parser."""
if self._spark_sql_parser is None:
spark_context = pyspark.SparkContext.getOrCreate()
spark_session = pyspark.sql.SparkSession(spark_context)
self._spark_sql_parser = (
spark_session._jsparkSession.sessionState().sqlParser()
)
return self._spark_sql_parser
def run(self, table_refs: Set[TableReference]) -> Iterable[MetadataWorkUnit]:
try:
table_map = defaultdict(list)
for ref in table_refs:
table_map[ref.table].append(ref)
table_map[f"{ref.schema}.{ref.table}"].append(ref)
table_map[ref.qualified_table_name].append(ref)
yield from self._generate_workunits(table_map)
except Exception as e:
logger.error("Error processing usage", exc_info=True)
self.report.report_warning("usage-extraction", str(e))
def _generate_workunits(self, table_map: TableMap) -> Iterable[MetadataWorkUnit]:
for query in self._get_queries():
self.report.num_queries += 1
table_info = self._parse_query(query, table_map)
if table_info is not None:
if self.config.include_operational_stats:
yield from self._generate_operation_workunit(query, table_info)
for source_table in table_info.source_tables:
self.usage_aggregator.aggregate_event(
resource=source_table,
start_time=query.start_time,
query=query.query_text,
user=query.user_name,
fields=[],
)
for wu in self.usage_aggregator.generate_workunits(
resource_urn_builder=self.table_urn_builder,
user_urn_builder=self.user_urn_builder,
):
self.report.num_usage_workunits_emitted += 1
yield wu
def _generate_operation_workunit(
self, query: Query, table_info: QueryTableInfo
) -> Iterable[MetadataWorkUnit]:
if query.statement_type not in OPERATION_STATEMENT_TYPES:
return None
# Not sure about behavior when there are multiple target tables. This is a best attempt.
for target_table in table_info.target_tables:
operation_aspect = OperationClass(
timestampMillis=int(time.time() * 1000),
lastUpdatedTimestamp=int(query.end_time.timestamp() * 1000),
actor=self.user_urn_builder(query.user_name),
operationType=OPERATION_STATEMENT_TYPES[query.statement_type],
affectedDatasets=[
self.table_urn_builder(table) for table in table_info.source_tables
],
)
self.report.num_operational_stats_workunits_emitted += 1
yield MetadataChangeProposalWrapper(
entityUrn=self.table_urn_builder(target_table), aspect=operation_aspect
).as_workunit()
def _get_queries(self) -> Iterable[Query]:
try:
yield from self.proxy.query_history(
self.config.start_time, self.config.end_time
)
except Exception as e:
logger.warning("Error getting queries", exc_info=True)
self.report.report_warning("get-queries", str(e))
def _parse_query(
self, query: Query, table_map: TableMap
) -> Optional[QueryTableInfo]:
table_info = self._parse_query_via_lineage_runner(query.query_text)
if table_info is None and query.statement_type == StatementType.SELECT:
table_info = self._parse_query_via_spark_sql_plan(query.query_text)
if table_info is None:
self.report.num_queries_dropped_parse_failure += 1
return None
else:
return QueryTableInfo(
source_tables=self._resolve_tables(table_info.source_tables, table_map),
target_tables=self._resolve_tables(table_info.target_tables, table_map),
)
def _parse_query_via_lineage_runner(self, query: str) -> Optional[StringTableInfo]:
try:
runner = LineageRunner(query)
return GenericTableInfo(
source_tables=[
self._parse_sqllineage_table(table)
for table in runner.source_tables
],
target_tables=[
self._parse_sqllineage_table(table)
for table in runner.target_tables
],
)
except Exception:
logger.info(
f"Could not parse query via lineage runner, {query}", exc_info=True
)
return None
@staticmethod
def _parse_sqllineage_table(sqllineage_table: object) -> str:
full_table_name = str(sqllineage_table)
default_schema = "<default>."
if full_table_name.startswith(default_schema):
return full_table_name[len(default_schema) :]
else:
return full_table_name
def _parse_query_via_spark_sql_plan(self, query: str) -> Optional[StringTableInfo]:
"""Parse query source tables via Spark SQL plan. This is a fallback option."""
# Would be more effective if we upgrade pyspark
# Does not work with CTEs or non-SELECT statements
try:
plan = json.loads(self.spark_sql_parser.parsePlan(query).toJSON())
tables = [self._parse_plan_item(item) for item in plan]
self.report.num_queries_parsed_by_spark_plan += 1
return GenericTableInfo(
source_tables=[t for t in tables if t], target_tables=[]
)
except Exception:
logger.info(f"Could not parse query via spark plan, {query}", exc_info=True)
return None
@staticmethod
def _parse_plan_item(item: dict) -> Optional[str]:
if item["class"] == "org.apache.spark.sql.catalyst.analysis.UnresolvedRelation":
return ".".join(item["multipartIdentifier"].strip("[]").split(", "))
return None
def _resolve_tables(
self, tables: List[str], table_map: TableMap
) -> List[TableReference]:
"""Resolve tables to TableReferences, filtering out unrecognized or unresolvable table names."""
output = []
for table in tables:
table = str(table)
if table not in table_map:
logger.debug(f"Dropping query with unrecognized table: {table}")
self.report.num_queries_dropped_missing_table += 1
else:
refs = table_map[table]
if len(refs) == 1:
output.append(refs[0])
else:
logger.warning(
f"Could not resolve table ref for {table}: {len(refs)} duplicates."
)
self.report.num_queries_dropped_duplicate_table += 1
return output
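
A hedged sketch of the lookup map that run builds from table refs and of how _resolve_tables treats ambiguous names; the catalog, schema, and table names are illustrative:

```python
from collections import defaultdict

from datahub.ingestion.source.unity.proxy_types import TableReference

# Two hypothetical tables that share a schema.table name across catalogs.
ref_a = TableReference("metastore-1234", "prod", "sales", "orders")
ref_b = TableReference("metastore-1234", "dev", "sales", "orders")

table_map = defaultdict(list)
for ref in (ref_a, ref_b):
    table_map[ref.table].append(ref)                    # "orders"
    table_map[f"{ref.schema}.{ref.table}"].append(ref)  # "sales.orders"
    table_map[ref.qualified_table_name].append(ref)     # "prod.sales.orders" / "dev.sales.orders"

# A fully qualified name resolves to a single ref; a short, ambiguous name does not,
# and the query is counted under num_queries_dropped_duplicate_table.
assert table_map["prod.sales.orders"] == [ref_a]
assert len(table_map["sales.orders"]) == 2
```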


@ -1,8 +1,18 @@
import collections
import dataclasses
import logging
from collections import defaultdict
from datetime import datetime
from typing import Callable, Counter, Generic, List, Optional, Tuple, TypeVar
from typing import (
Callable,
Counter,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
TypeVar,
)
import pydantic
from pydantic.fields import Field
@ -12,6 +22,7 @@ from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import (
BaseTimeWindowConfig,
BucketDuration,
get_time_bucket,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
@ -31,6 +42,10 @@ ResourceType = TypeVar("ResourceType")
TOTAL_BUDGET_FOR_QUERY_LIST = 24000
def default_user_urn_builder(email: str) -> str:
return builder.make_user_urn(email.split("@")[0])
def make_usage_workunit(
bucket_start_time: datetime,
resource: ResourceType,
@ -39,12 +54,16 @@ def make_usage_workunit(
user_freq: List[Tuple[str, int]],
column_freq: List[Tuple[str, int]],
bucket_duration: BucketDuration,
urn_builder: Callable[[ResourceType], str],
resource_urn_builder: Callable[[ResourceType], str],
top_n_queries: int,
format_sql_queries: bool,
user_urn_builder: Optional[Callable[[str], str]] = None,
total_budget_for_query_list: int = TOTAL_BUDGET_FOR_QUERY_LIST,
query_trimmer_string: str = " ...",
) -> MetadataWorkUnit:
if user_urn_builder is None:
user_urn_builder = default_user_urn_builder
top_sql_queries: Optional[List[str]] = None
if query_freq is not None:
budget_per_query: int = int(total_budget_for_query_list / top_n_queries)
@ -67,11 +86,11 @@ def make_usage_workunit(
topSqlQueries=top_sql_queries,
userCounts=[
DatasetUserUsageCountsClass(
user=builder.make_user_urn(user_email.split("@")[0]),
user=user_urn_builder(user),
count=count,
userEmail=user_email,
userEmail=user if "@" in user else None,
)
for user_email, count in user_freq
for user, count in user_freq
],
fieldCounts=[
DatasetFieldUsageCountsClass(
@ -83,7 +102,7 @@ def make_usage_workunit(
)
return MetadataChangeProposalWrapper(
entityUrn=urn_builder(resource),
entityUrn=resource_urn_builder(resource),
aspect=usageStats,
).as_workunit()
@ -96,9 +115,9 @@ class GenericAggregatedDataset(Generic[ResourceType]):
readCount: int = 0
queryCount: int = 0
queryFreq: Counter[str] = dataclasses.field(default_factory=collections.Counter)
userFreq: Counter[str] = dataclasses.field(default_factory=collections.Counter)
columnFreq: Counter[str] = dataclasses.field(default_factory=collections.Counter)
queryFreq: Counter[str] = dataclasses.field(default_factory=Counter)
userFreq: Counter[str] = dataclasses.field(default_factory=Counter)
columnFreq: Counter[str] = dataclasses.field(default_factory=Counter)
def add_read_entry(
self,
@ -122,10 +141,11 @@ class GenericAggregatedDataset(Generic[ResourceType]):
def make_usage_workunit(
self,
bucket_duration: BucketDuration,
urn_builder: Callable[[ResourceType], str],
resource_urn_builder: Callable[[ResourceType], str],
top_n_queries: int,
format_sql_queries: bool,
include_top_n_queries: bool,
user_urn_builder: Optional[Callable[[str], str]] = None,
total_budget_for_query_list: int = TOTAL_BUDGET_FOR_QUERY_LIST,
query_trimmer_string: str = " ...",
) -> MetadataWorkUnit:
@ -140,7 +160,8 @@ class GenericAggregatedDataset(Generic[ResourceType]):
user_freq=self.userFreq.most_common(),
column_freq=self.columnFreq.most_common(),
bucket_duration=bucket_duration,
urn_builder=urn_builder,
resource_urn_builder=resource_urn_builder,
user_urn_builder=user_urn_builder,
top_n_queries=top_n_queries,
format_sql_queries=format_sql_queries,
total_budget_for_query_list=total_budget_for_query_list,
@ -182,3 +203,51 @@ class BaseUsageConfig(BaseTimeWindowConfig):
f"top_n_queries is set to {v} but it can be maximum {max_queries}"
)
return v
class UsageAggregator(Generic[ResourceType]):
# TODO: Move over other connectors to use this class
def __init__(self, config: BaseUsageConfig):
self.config = config
self.aggregation: Dict[
datetime, Dict[ResourceType, GenericAggregatedDataset[ResourceType]]
] = defaultdict(dict)
def aggregate_event(
self,
*,
resource: ResourceType,
start_time: datetime,
query: Optional[str],
user: str,
fields: List[str],
) -> None:
floored_ts: datetime = get_time_bucket(start_time, self.config.bucket_duration)
self.aggregation[floored_ts].setdefault(
resource,
GenericAggregatedDataset[ResourceType](
bucket_start_time=floored_ts,
resource=resource,
),
).add_read_entry(
user,
query,
fields,
)
def generate_workunits(
self,
resource_urn_builder: Callable[[ResourceType], str],
user_urn_builder: Optional[Callable[[str], str]] = None,
) -> Iterable[MetadataWorkUnit]:
for time_bucket in self.aggregation.values():
for aggregate in time_bucket.values():
yield aggregate.make_usage_workunit(
bucket_duration=self.config.bucket_duration,
top_n_queries=self.config.top_n_queries,
format_sql_queries=self.config.format_sql_queries,
include_top_n_queries=self.config.include_top_n_queries,
resource_urn_builder=resource_urn_builder,
user_urn_builder=user_urn_builder,
)
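
A minimal sketch of the intended calling pattern for the new UsageAggregator, loosely mirroring how the Unity Catalog usage extractor drives it; the config, resource key, user, and URN builders below are placeholders:

```python
from datetime import datetime, timezone

import datahub.emitter.mce_builder as builder
from datahub.ingestion.source.usage.usage_common import BaseUsageConfig, UsageAggregator

aggregator = UsageAggregator[str](BaseUsageConfig())

# One read event per (resource, query, user) observation.
aggregator.aggregate_event(
    resource="main.sales.orders",  # placeholder resource key
    start_time=datetime(2023, 5, 1, tzinfo=timezone.utc),
    query="SELECT * FROM main.sales.orders",
    user="alice@example.com",  # placeholder user
    fields=["order_id"],
)

# Emit bucketed usage statistics work units at the end of ingestion.
workunits = list(
    aggregator.generate_workunits(
        resource_urn_builder=lambda name: builder.make_dataset_urn("databricks", name),
        user_urn_builder=lambda user: builder.make_user_urn(user.split("@")[0]),
    )
)
```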


@ -204,7 +204,7 @@ def register_mock_data(unity_catalog_api_instance):
}
def mock_perform_query(method, path, **kwargs):
def mock_perform_query(method, path, *args, **kwargs):
if method == "GET" and path == "/account/scim/v2/ServicePrincipals":
return {
"Resources": [
@ -226,6 +226,8 @@ def mock_perform_query(method, path, **kwargs):
"itemsPerPage": 2,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
}
elif method == "get" and path == "/sql/history/queries":
return {"next_page_token": "next_token", "has_next_page": False, "res": []}
else:
return {}


@ -0,0 +1,31 @@
from datetime import datetime, timedelta
import pytest
from freezegun import freeze_time
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
FROZEN_TIME = datetime.fromisoformat("2023-01-01 00:00:00+00:00")
@freeze_time(FROZEN_TIME)
def test_within_thirty_days():
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "token",
"workspace_url": "workspace_url",
"include_usage_statistics": True,
"start_time": FROZEN_TIME - timedelta(days=30),
}
)
assert config.start_time == FROZEN_TIME - timedelta(days=30)
with pytest.raises(ValueError):
UnityCatalogSourceConfig.parse_obj(
{
"token": "token",
"workspace_url": "workspace_url",
"include_usage_statistics": True,
"start_time": FROZEN_TIME - timedelta(days=31),
}
)


@ -165,7 +165,7 @@ def test_make_usage_workunit():
)
wu: MetadataWorkUnit = ta.make_usage_workunit(
bucket_duration=BucketDuration.DAY,
urn_builder=_simple_urn_builder,
resource_urn_builder=_simple_urn_builder,
top_n_queries=10,
format_sql_queries=False,
include_top_n_queries=True,
@ -200,7 +200,7 @@ def test_query_formatting():
)
wu: MetadataWorkUnit = ta.make_usage_workunit(
bucket_duration=BucketDuration.DAY,
urn_builder=_simple_urn_builder,
resource_urn_builder=_simple_urn_builder,
top_n_queries=10,
format_sql_queries=True,
include_top_n_queries=True,
@ -233,7 +233,7 @@ def test_query_trimming():
)
wu: MetadataWorkUnit = ta.make_usage_workunit(
bucket_duration=BucketDuration.DAY,
urn_builder=_simple_urn_builder,
resource_urn_builder=_simple_urn_builder,
top_n_queries=top_n_queries,
format_sql_queries=False,
include_top_n_queries=True,
@ -276,7 +276,7 @@ def test_make_usage_workunit_include_top_n_queries():
)
wu: MetadataWorkUnit = ta.make_usage_workunit(
bucket_duration=BucketDuration.DAY,
urn_builder=_simple_urn_builder,
resource_urn_builder=_simple_urn_builder,
top_n_queries=10,
format_sql_queries=False,
include_top_n_queries=False,