"""
|
|
Manage the communication with DataBricks Server and provide equivalent dataclasses for dependent modules
|
|
"""

import dataclasses
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, cast
from unittest.mock import patch

import cachetools
from cachetools import cached
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import (
    CatalogInfo,
    ColumnInfo,
    GetMetastoreSummaryResponse,
    MetastoreInfo,
    SchemaInfo,
    TableInfo,
)
from databricks.sdk.service.iam import ServicePrincipal as DatabricksServicePrincipal
from databricks.sdk.service.sql import (
    QueryFilter,
    QueryInfo,
    QueryStatementType,
    QueryStatus,
)
from databricks.sdk.service.workspace import ObjectType
from databricks.sql import connect
from databricks.sql.types import Row
from typing_extensions import assert_never

from datahub._version import nice_version_name
from datahub.api.entities.external.unity_catalog_external_entites import UnityCatalogTag
from datahub.emitter.mce_builder import parse_ts_millis
from datahub.ingestion.source.unity.config import (
    LineageDataSource,
)
from datahub.ingestion.source.unity.hive_metastore_proxy import HiveMetastoreProxy
from datahub.ingestion.source.unity.proxy_profiling import (
    UnityCatalogProxyProfilingMixin,
)
from datahub.ingestion.source.unity.proxy_types import (
    ALLOWED_STATEMENT_TYPES,
    Catalog,
    Column,
    CustomCatalogType,
    ExternalTableReference,
    Metastore,
    Notebook,
    NotebookReference,
    Query,
    Schema,
    ServicePrincipal,
    Table,
    TableReference,
)
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.utilities.file_backed_collections import FileBackedDict

logger: logging.Logger = logging.getLogger(__name__)

# It is enough to keep the cache size at 1, since we only process one catalog at a time.
# We need to change this if we want to support parallel processing of multiple catalogs.
_MAX_CONCURRENT_CATALOGS = 1


@dataclasses.dataclass
class TableInfoWithGeneration(TableInfo):
    generation: Optional[int] = None

    def as_dict(self) -> dict:
        return {**super().as_dict(), "generation": self.generation}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "TableInfoWithGeneration":
        table_info: TableInfoWithGeneration = cast(
            TableInfoWithGeneration,
            super().from_dict(d),
        )
        table_info.generation = d.get("generation")
        return table_info
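
# Note: TableInfoWithGeneration is patched in for the SDK's TableInfo inside
# tables() below, presumably so the extra "generation" field returned by the
# API survives deserialization. Illustrative round-trip (hypothetical values):
#   info = TableInfoWithGeneration.from_dict({"name": "t", "generation": 3})
#   assert info.as_dict()["generation"] == 3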


@dataclasses.dataclass
class QueryFilterWithStatementTypes(QueryFilter):
    statement_types: List[QueryStatementType] = dataclasses.field(default_factory=list)

    def as_dict(self) -> dict:
        return {**super().as_dict(), "statement_types": self.statement_types}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "QueryFilterWithStatementTypes":
        v: QueryFilterWithStatementTypes = cast(
            QueryFilterWithStatementTypes,
            super().from_dict(d),
        )
        v.statement_types = d["statement_types"]
        return v


@dataclasses.dataclass
class TableUpstream:
    table_name: str
    source_type: str
    last_updated: Optional[datetime] = None


@dataclasses.dataclass
class ExternalUpstream:
    path: str
    source_type: str
    last_updated: Optional[datetime] = None


@dataclasses.dataclass
class TableLineageInfo:
    upstreams: List[TableUpstream] = dataclasses.field(default_factory=list)
    external_upstreams: List[ExternalUpstream] = dataclasses.field(default_factory=list)
    upstream_notebooks: List[NotebookReference] = dataclasses.field(
        default_factory=list
    )
    downstream_notebooks: List[NotebookReference] = dataclasses.field(
        default_factory=list
    )


class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
    _workspace_client: WorkspaceClient
    _workspace_url: str
    report: UnityCatalogReport
    warehouse_id: str

    def __init__(
        self,
        workspace_url: str,
        personal_access_token: str,
        warehouse_id: Optional[str],
        report: UnityCatalogReport,
        hive_metastore_proxy: Optional[HiveMetastoreProxy] = None,
        lineage_data_source: LineageDataSource = LineageDataSource.AUTO,
    ):
        self._workspace_client = WorkspaceClient(
            host=workspace_url,
            token=personal_access_token,
            product="datahub",
            product_version=nice_version_name(),
        )
        self.warehouse_id = warehouse_id or ""
        self.report = report
        self.hive_metastore_proxy = hive_metastore_proxy
        self.lineage_data_source = lineage_data_source
        self._sql_connection_params = {
            "server_hostname": self._workspace_client.config.host.replace(
                "https://", ""
            ),
            "http_path": f"/sql/1.0/warehouses/{self.warehouse_id}",
            "access_token": self._workspace_client.config.token,
        }
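
    # Construction sketch (placeholder values, not a working configuration);
    # in practice the Unity Catalog ingestion source builds this proxy from its
    # parsed config rather than instantiating it directly:
    #   proxy = UnityCatalogApiProxy(
    #       workspace_url="https://example.cloud.databricks.com",
    #       personal_access_token="<pat>",
    #       warehouse_id="<warehouse-id>",
    #       report=UnityCatalogReport(),
    #   )
    #   assert proxy.check_basic_connectivity()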

    def check_basic_connectivity(self) -> bool:
        return bool(self._workspace_client.catalogs.list(include_browse=True))

    def assigned_metastore(self) -> Optional[Metastore]:
        response = self._workspace_client.metastores.summary()
        return self._create_metastore(response)

    def catalogs(self, metastore: Optional[Metastore]) -> Iterable[Catalog]:
        if self.hive_metastore_proxy:
            yield self.hive_metastore_proxy.hive_metastore_catalog(metastore)

        response = self._workspace_client.catalogs.list(include_browse=True)
        if not response:
            logger.info("Catalogs not found")
            return
        for catalog in response:
            optional_catalog = self._create_catalog(metastore, catalog)
            if optional_catalog:
                yield optional_catalog

    def catalog(
        self, catalog_name: str, metastore: Optional[Metastore]
    ) -> Optional[Catalog]:
        response = self._workspace_client.catalogs.get(
            catalog_name, include_browse=True
        )
        if not response:
            logger.info(f"Catalog {catalog_name} not found")
            return None
        optional_catalog = self._create_catalog(metastore, response)
        if optional_catalog:
            return optional_catalog

        return None

    def schemas(self, catalog: Catalog) -> Iterable[Schema]:
        if (
            self.hive_metastore_proxy
            and catalog.type == CustomCatalogType.HIVE_METASTORE_CATALOG
        ):
            yield from self.hive_metastore_proxy.hive_metastore_schemas(catalog)
            return
        response = self._workspace_client.schemas.list(
            catalog_name=catalog.name, include_browse=True
        )
        if not response:
            logger.info(f"Schemas not found for catalog {catalog.id}")
            return
        for schema in response:
            optional_schema = self._create_schema(catalog, schema)
            if optional_schema:
                yield optional_schema

    def tables(self, schema: Schema) -> Iterable[Table]:
        if (
            self.hive_metastore_proxy
            and schema.catalog.type == CustomCatalogType.HIVE_METASTORE_CATALOG
        ):
            yield from self.hive_metastore_proxy.hive_metastore_tables(schema)
            return
        with patch("databricks.sdk.service.catalog.TableInfo", TableInfoWithGeneration):
            response = self._workspace_client.tables.list(
                catalog_name=schema.catalog.name,
                schema_name=schema.name,
                include_browse=True,
            )
            if not response:
                logger.info(f"Tables not found for schema {schema.id}")
                return
            for table in response:
                try:
                    optional_table = self._create_table(
                        schema, cast(TableInfoWithGeneration, table)
                    )
                    if optional_table:
                        yield optional_table
                except Exception as e:
                    logger.warning(f"Error parsing table: {e}")
                    self.report.report_warning("table-parse", str(e))

    def service_principals(self) -> Iterable[ServicePrincipal]:
        for principal in self._workspace_client.service_principals.list():
            optional_sp = self._create_service_principal(principal)
            if optional_sp:
                yield optional_sp

    def groups(self):
        """
        Fetch the groups that belong to the workspace, using the workspace client,
        and return a list of the groups' display names.
        """
        group_list: List[Optional[str]] = []
        for group in self._workspace_client.groups.list():
            group_list.append(group.display_name)
        return group_list

    def workspace_notebooks(self) -> Iterable[Notebook]:
        for obj in self._workspace_client.workspace.list("/", recursive=True):
            if obj.object_type == ObjectType.NOTEBOOK and obj.object_id and obj.path:
                yield Notebook(
                    id=obj.object_id,
                    path=obj.path,
                    language=obj.language,
                    created_at=parse_ts_millis(obj.created_at),
                    modified_at=parse_ts_millis(obj.modified_at),
                )

    def query_history(
        self,
        start_time: datetime,
        end_time: datetime,
    ) -> Iterable[Query]:
        """Returns all queries that were run between start_time and end_time with relevant statement_type.

        Raises:
            DatabricksError: If the query history API returns an error.
        """
        filter_by = QueryFilterWithStatementTypes.from_dict(
            {
                "query_start_time_range": {
                    "start_time_ms": start_time.timestamp() * 1000,
                    "end_time_ms": end_time.timestamp() * 1000,
                },
                "statuses": [QueryStatus.FINISHED],
                "statement_types": [typ.value for typ in ALLOWED_STATEMENT_TYPES],
            }
        )
        for query_info in self._query_history(filter_by=filter_by):
            try:
                optional_query = self._create_query(query_info)
                if optional_query:
                    yield optional_query
            except Exception as e:
                logger.warning(f"Error parsing query: {e}")
                self.report.report_warning("query-parse", str(e))

    def _query_history(
        self,
        filter_by: QueryFilterWithStatementTypes,
        max_results: int = 1000,
        include_metrics: bool = False,
    ) -> Iterable[QueryInfo]:
        """Manual implementation of the query_history.list() endpoint.

        Needed because:
        - WorkspaceClient incorrectly passes params as query params, not body
        - It does not paginate correctly -- needs to remove filter_by argument
        Remove if these issues are fixed.
        """
        method = "GET"
        path = "/api/2.0/sql/history/queries"
        body: Dict[str, Any] = {
            "include_metrics": include_metrics,
            "max_results": max_results,  # Max batch size
        }

        response: dict = self._workspace_client.api_client.do(  # type: ignore
            method, path, body={**body, "filter_by": filter_by.as_dict()}
        )
        # We use the default raw=False in the request above, so we always get a dict back.
        while True:
            if "res" not in response or not response["res"]:
                return
            for v in response["res"]:
                yield QueryInfo.from_dict(v)
            if not response.get("next_page_token"):  # last page
                return
            response = self._workspace_client.api_client.do(  # type: ignore
                method, path, body={**body, "page_token": response["next_page_token"]}
            )
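
    # Pagination note: only the first request carries filter_by; follow-up
    # requests send just the page_token, which is why the docstring above says
    # the filter_by argument needs to be removed for correct pagination. The
    # loop stops when a response has no "next_page_token".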

    def _build_datetime_where_conditions(
        self, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None
    ) -> str:
        """Build datetime filtering conditions for lineage queries."""
        conditions = []
        if start_time:
            conditions.append(f"event_time >= '{start_time.isoformat()}'")
        if end_time:
            conditions.append(f"event_time <= '{end_time.isoformat()}'")
        return " AND " + " AND ".join(conditions) if conditions else ""

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_catalog_table_lineage_via_system_tables(
        self,
        catalog: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> FileBackedDict[TableLineageInfo]:
        """Get table lineage for all tables in a catalog using system tables."""
        logger.info(f"Fetching table lineage for catalog: {catalog}")
        try:
            additional_where = self._build_datetime_where_conditions(
                start_time, end_time
            )

            query = f"""
                SELECT
                    entity_type, entity_id,
                    source_table_full_name, source_type,
                    target_table_full_name, target_type,
                    max(event_time) as last_updated
                FROM system.access.table_lineage
                WHERE
                    (target_table_catalog = %s or source_table_catalog = %s)
                    {additional_where}
                GROUP BY
                    entity_type, entity_id,
                    source_table_full_name, source_type,
                    target_table_full_name, target_type
            """
            rows = self._execute_sql_query(query, [catalog, catalog])

            result_dict: FileBackedDict[TableLineageInfo] = FileBackedDict()
            for row in rows:
                entity_type = row["entity_type"]
                entity_id = row["entity_id"]
                source_full_name = row["source_table_full_name"]
                target_full_name = row["target_table_full_name"]
                source_type = row["source_type"]
                last_updated = row["last_updated"]

                # Initialize TableLineageInfo for both source and target tables if they're in our catalog
                for table_name in [source_full_name, target_full_name]:
                    if (
                        table_name
                        and table_name.startswith(f"{catalog}.")
                        and table_name not in result_dict
                    ):
                        result_dict[table_name] = TableLineageInfo()

                # Process upstream relationships (target table gets upstreams)
                if target_full_name and target_full_name.startswith(f"{catalog}."):
                    # Handle table upstreams
                    if (
                        source_type in ["TABLE", "VIEW"]
                        and source_full_name != target_full_name
                    ):
                        upstream = TableUpstream(
                            table_name=source_full_name,
                            source_type=source_type,
                            last_updated=last_updated,
                        )
                        result_dict[target_full_name].upstreams.append(upstream)

                    # Handle external upstreams (PATH type)
                    elif source_type == "PATH":
                        external_upstream = ExternalUpstream(
                            path=source_full_name,
                            source_type=source_type,
                            last_updated=last_updated,
                        )
                        result_dict[target_full_name].external_upstreams.append(
                            external_upstream
                        )

                    # Handle upstream notebooks (notebook -> table)
                    elif entity_type == "NOTEBOOK":
                        notebook_ref = NotebookReference(
                            id=entity_id,
                            last_updated=last_updated,
                        )
                        result_dict[target_full_name].upstream_notebooks.append(
                            notebook_ref
                        )

                # Process downstream relationships (source table gets downstream notebooks)
                if (
                    entity_type == "NOTEBOOK"
                    and source_full_name
                    and source_full_name.startswith(f"{catalog}.")
                ):
                    notebook_ref = NotebookReference(
                        id=entity_id,
                        last_updated=last_updated,
                    )
                    result_dict[source_full_name].downstream_notebooks.append(
                        notebook_ref
                    )

            return result_dict
        except Exception as e:
            logger.warning(
                f"Error getting table lineage for catalog {catalog}: {e}",
                exc_info=True,
            )
            return FileBackedDict()
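
    # Caching note: @cached with a FIFOCache of size _MAX_CONCURRENT_CATALOGS (1)
    # means the full-catalog lineage map is fetched once and reused for every
    # table in that catalog; a call with a different catalog or time window
    # evicts the previous entry. The same pattern is used for the column-lineage
    # and tag lookups below.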

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_catalog_column_lineage_via_system_tables(
        self,
        catalog: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> FileBackedDict[Dict[str, dict]]:
        """Get column lineage for all tables in a catalog using system tables."""
        logger.info(f"Fetching column lineage for catalog: {catalog}")
        try:
            additional_where = self._build_datetime_where_conditions(
                start_time, end_time
            )

            query = f"""
                SELECT
                    source_table_catalog, source_table_schema, source_table_name, source_column_name, source_type,
                    target_table_schema, target_table_name, target_column_name,
                    max(event_time) as last_updated
                FROM system.access.column_lineage
                WHERE
                    target_table_catalog = %s
                    AND target_table_schema IS NOT NULL
                    AND target_table_name IS NOT NULL
                    AND target_column_name IS NOT NULL
                    AND source_table_catalog IS NOT NULL
                    AND source_table_schema IS NOT NULL
                    AND source_table_name IS NOT NULL
                    AND source_column_name IS NOT NULL
                    {additional_where}
                GROUP BY
                    source_table_catalog, source_table_schema, source_table_name, source_column_name, source_type,
                    target_table_schema, target_table_name, target_column_name
            """
            rows = self._execute_sql_query(query, [catalog])

            result_dict: FileBackedDict[Dict[str, dict]] = FileBackedDict()
            for row in rows:
                result_dict.setdefault(row["target_table_schema"], {}).setdefault(
                    row["target_table_name"], {}
                ).setdefault(row["target_column_name"], []).append(
                    # make fields look like the response from the older HTTP API
                    {
                        "catalog_name": row["source_table_catalog"],
                        "schema_name": row["source_table_schema"],
                        "table_name": row["source_table_name"],
                        "name": row["source_column_name"],
                        "last_updated": row["last_updated"],
                    }
                )

            return result_dict
        except Exception as e:
            logger.warning(
                f"Error getting column lineage for catalog {catalog}: {e}",
                exc_info=True,
            )
            return FileBackedDict()
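
    # Shape of the returned mapping (as built by the loop above):
    #   {schema_name: {table_name: {column_name: [upstream_column_dict, ...]}}}
    # where each upstream_column_dict mirrors the older HTTP API's fields
    # ("catalog_name", "schema_name", "table_name", "name", "last_updated"),
    # so get_column_lineage() can consume either source the same way.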

    def list_lineages_by_table_via_http_api(
        self, table_name: str, include_entity_lineage: bool
    ) -> dict:
        """List table lineage by table name."""
        logger.debug(f"Getting table lineage for {table_name}")
        return self._workspace_client.api_client.do(  # type: ignore
            method="GET",
            path="/api/2.0/lineage-tracking/table-lineage",
            body={
                "table_name": table_name,
                "include_entity_lineage": include_entity_lineage,
            },
        )

    def list_lineages_by_column_via_http_api(
        self, table_name: str, column_name: str
    ) -> list:
        """List column lineage by table name and column name."""
        logger.debug(f"Getting column lineage for {table_name}.{column_name}")
        try:
            return (
                self._workspace_client.api_client.do(  # type: ignore
                    "GET",
                    "/api/2.0/lineage-tracking/column-lineage",
                    body={"table_name": table_name, "column_name": column_name},
                ).get("upstream_cols")
                or []
            )
        except Exception as e:
            logger.warning(
                f"Error getting column lineage on table {table_name}, column {column_name}: {e}",
                exc_info=True,
            )
            return []

    def table_lineage(
        self,
        table: Table,
        include_entity_lineage: bool,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> None:
        if table.schema.catalog.type == CustomCatalogType.HIVE_METASTORE_CATALOG:
            # Lineage is not available for Hive Metastore Tables.
            return None

        try:
            # Determine lineage data source based on config
            use_system_tables = False
            if self.lineage_data_source == LineageDataSource.SYSTEM_TABLES:
                use_system_tables = True
            elif self.lineage_data_source == LineageDataSource.API:
                use_system_tables = False
            elif self.lineage_data_source == LineageDataSource.AUTO:
                # Use the newer system tables if we have a SQL warehouse, otherwise fall back
                # to the older (and slower) HTTP API.
                use_system_tables = bool(self.warehouse_id)
            else:
                assert_never(self.lineage_data_source)

            if use_system_tables:
                self._process_system_table_lineage(table, start_time, end_time)
            else:
                self._process_table_lineage_via_http_api(table, include_entity_lineage)
        except Exception as e:
            logger.warning(
                f"Error getting lineage on table {table.ref}: {e}", exc_info=True
            )

    def _process_system_table_lineage(
        self,
        table: Table,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> None:
        """Process table lineage using system.access.table_lineage table."""
        catalog_lineage = self.get_catalog_table_lineage_via_system_tables(
            table.ref.catalog, start_time, end_time
        )
        table_full_name = table.ref.qualified_table_name

        lineage_info = catalog_lineage.get(table_full_name, TableLineageInfo())

        # Process table upstreams
        for upstream in lineage_info.upstreams:
            upstream_table_name = upstream.table_name
            # Parse catalog.schema.table format
            parts = upstream_table_name.split(".")
            if len(parts) == 3:
                catalog_name, schema_name, table_name = parts[0], parts[1], parts[2]
                table_ref = TableReference(
                    metastore=table.schema.catalog.metastore.id
                    if table.schema.catalog.metastore
                    else None,
                    catalog=catalog_name,
                    schema=schema_name,
                    table=table_name,
                    last_updated=upstream.last_updated,
                )
                table.upstreams[table_ref] = {}
            else:
                logger.warning(
                    f"Unexpected upstream table format: {upstream_table_name} for table {table_full_name}"
                )
                continue

        # Process external upstreams
        for external_upstream in lineage_info.external_upstreams:
            external_ref = ExternalTableReference(
                path=external_upstream.path,
                has_permission=True,
                name=None,
                type=None,
                storage_location=external_upstream.path,
                last_updated=external_upstream.last_updated,
            )
            table.external_upstreams.add(external_ref)

        # Process upstream notebook lineage
        for notebook_ref in lineage_info.upstream_notebooks:
            existing_ref = table.upstream_notebooks.get(notebook_ref.id)
            if existing_ref is None or (
                notebook_ref.last_updated
                and existing_ref.last_updated
                and notebook_ref.last_updated > existing_ref.last_updated
            ):
                table.upstream_notebooks[notebook_ref.id] = notebook_ref

        # Process downstream notebook lineage
        for notebook_ref in lineage_info.downstream_notebooks:
            existing_ref = table.downstream_notebooks.get(notebook_ref.id)
            if existing_ref is None or (
                notebook_ref.last_updated
                and existing_ref.last_updated
                and notebook_ref.last_updated > existing_ref.last_updated
            ):
                table.downstream_notebooks[notebook_ref.id] = notebook_ref

    def _process_table_lineage_via_http_api(
        self, table: Table, include_entity_lineage: bool
    ) -> None:
        """Process table lineage using the HTTP API (legacy fallback)."""
        response: dict = self.list_lineages_by_table_via_http_api(
            table_name=table.ref.qualified_table_name,
            include_entity_lineage=include_entity_lineage,
        )

        for item in response.get("upstreams") or []:
            if "tableInfo" in item:
                table_ref = TableReference.create_from_lineage(
                    item["tableInfo"], table.schema.catalog.metastore
                )
                if table_ref:
                    table.upstreams[table_ref] = {}
            elif "fileInfo" in item:
                external_ref = ExternalTableReference.create_from_lineage(
                    item["fileInfo"]
                )
                if external_ref:
                    table.external_upstreams.add(external_ref)

            for notebook in item.get("notebookInfos") or []:
                notebook_ref = NotebookReference(
                    id=notebook["notebook_id"],
                )
                table.upstream_notebooks[notebook_ref.id] = notebook_ref

        for item in response.get("downstreams") or []:
            for notebook in item.get("notebookInfos") or []:
                notebook_ref = NotebookReference(
                    id=notebook["notebook_id"],
                )
                table.downstream_notebooks[notebook_ref.id] = notebook_ref

    def get_column_lineage(
        self,
        table: Table,
        column_names: List[str],
        *,
        max_workers: Optional[int] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> None:
        try:
            # Determine lineage data source based on config
            use_system_tables = False
            if self.lineage_data_source == LineageDataSource.SYSTEM_TABLES:
                use_system_tables = True
            elif self.lineage_data_source == LineageDataSource.API:
                use_system_tables = False
            elif self.lineage_data_source == LineageDataSource.AUTO:
                # Use the newer system tables if we have a SQL warehouse, otherwise fall back
                # to the older (and slower) HTTP API.
                use_system_tables = bool(self.warehouse_id)
            else:
                assert_never(self.lineage_data_source)

            if use_system_tables:
                lineage = (
                    self.get_catalog_column_lineage_via_system_tables(
                        table.ref.catalog, start_time, end_time
                    )
                    .get(table.ref.schema, {})
                    .get(table.ref.table, {})
                )
            else:
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = [
                        executor.submit(
                            self.list_lineages_by_column_via_http_api,
                            table.ref.qualified_table_name,
                            column_name,
                        )
                        for column_name in column_names
                    ]
                    lineage = {
                        column_name: future.result()
                        for column_name, future in zip(column_names, futures)
                    }

            for column_name in column_names:
                for item in lineage.get(column_name) or []:
                    table_ref = TableReference.create_from_lineage(
                        item,
                        table.schema.catalog.metastore,
                    )
                    if table_ref:
                        table.upstreams.setdefault(table_ref, {}).setdefault(
                            column_name, []
                        ).append(item["name"])

        except Exception as e:
            logger.warning(
                f"Error getting column lineage on table {table.ref}: {e}",
                exc_info=True,
            )

    @staticmethod
    def _escape_sequence(value: str) -> str:
        return value.replace(" ", "_")
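
    # For example, _escape_sequence("my catalog") returns "my_catalog"; it is
    # used below to build stable dot-separated ids from names that may contain
    # spaces.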

    @staticmethod
    def _create_metastore(
        obj: Union[GetMetastoreSummaryResponse, MetastoreInfo],
    ) -> Optional[Metastore]:
        if not obj.name:
            return None
        return Metastore(
            name=obj.name,
            id=UnityCatalogApiProxy._escape_sequence(obj.name),
            global_metastore_id=obj.global_metastore_id,
            metastore_id=obj.metastore_id,
            owner=obj.owner,
            region=obj.region,
            cloud=obj.cloud,
            comment=None,
        )

    def _create_catalog(
        self, metastore: Optional[Metastore], obj: CatalogInfo
    ) -> Optional[Catalog]:
        if not obj.name:
            self.report.num_catalogs_missing_name += 1
            return None
        catalog_name = self._escape_sequence(obj.name)
        return Catalog(
            name=obj.name,
            id=f"{metastore.id}.{catalog_name}" if metastore else catalog_name,
            metastore=metastore,
            comment=obj.comment,
            owner=obj.owner,
            type=obj.catalog_type,
        )

    def _create_schema(self, catalog: Catalog, obj: SchemaInfo) -> Optional[Schema]:
        if not obj.name:
            self.report.num_schemas_missing_name += 1
            return None
        return Schema(
            name=obj.name,
            id=f"{catalog.id}.{self._escape_sequence(obj.name)}",
            catalog=catalog,
            comment=obj.comment,
            owner=obj.owner,
        )

    def _create_column(self, table_id: str, obj: ColumnInfo) -> Optional[Column]:
        if not obj.name:
            self.report.num_columns_missing_name += 1
            return None
        return Column(
            name=obj.name,
            id=f"{table_id}.{self._escape_sequence(obj.name)}",
            type_text=obj.type_text or "",
            type_name=obj.type_name,
            type_scale=obj.type_scale,
            type_precision=obj.type_precision,
            position=obj.position,
            nullable=obj.nullable,
            comment=obj.comment,
        )

    def _create_table(
        self, schema: Schema, obj: TableInfoWithGeneration
    ) -> Optional[Table]:
        if not obj.name:
            self.report.num_tables_missing_name += 1
            return None
        table_id = f"{schema.id}.{self._escape_sequence(obj.name)}"
        return Table(
            name=obj.name,
            id=table_id,
            table_type=obj.table_type,
            schema=schema,
            storage_location=obj.storage_location,
            data_source_format=obj.data_source_format,
            columns=(
                list(self._extract_columns(obj.columns, table_id))
                if obj.columns
                else []
            ),
            view_definition=obj.view_definition or None,
            properties=obj.properties or {},
            owner=obj.owner,
            generation=obj.generation,
            created_at=(parse_ts_millis(obj.created_at) if obj.created_at else None),
            created_by=obj.created_by,
            updated_at=(parse_ts_millis(obj.updated_at) if obj.updated_at else None),
            updated_by=obj.updated_by,
            table_id=obj.table_id,
            comment=obj.comment,
        )

    def _extract_columns(
        self, columns: List[ColumnInfo], table_id: str
    ) -> Iterable[Column]:
        for column in columns:
            optional_column = self._create_column(table_id, column)
            if optional_column:
                yield optional_column

    def _create_service_principal(
        self, obj: DatabricksServicePrincipal
    ) -> Optional[ServicePrincipal]:
        if not obj.display_name or not obj.application_id:
            return None
        return ServicePrincipal(
            id=f"{obj.id}.{self._escape_sequence(obj.display_name)}",
            display_name=obj.display_name,
            application_id=obj.application_id,
            active=obj.active,
        )

    def _create_query(self, info: QueryInfo) -> Optional[Query]:
        if (
            not info.query_text
            or not info.query_start_time_ms
            or not info.query_end_time_ms
        ):
            self.report.num_queries_missing_info += 1
            return None
        return Query(
            query_id=info.query_id,
            query_text=info.query_text,
            statement_type=info.statement_type,
            start_time=parse_ts_millis(info.query_start_time_ms),
            end_time=parse_ts_millis(info.query_end_time_ms),
            user_id=info.user_id,
            user_name=info.user_name,
            executed_as_user_id=info.executed_as_user_id,
            executed_as_user_name=info.executed_as_user_name,
        )

    def _execute_sql_query(self, query: str, params: Sequence[Any] = ()) -> List[Row]:
        """Execute a SQL query using the databricks-sql connector for better performance."""
        try:
            with (
                connect(**self._sql_connection_params) as connection,
                connection.cursor() as cursor,
            ):
                cursor.execute(query, list(params))
                return cursor.fetchall()

        except Exception as e:
            logger.warning(f"Failed to execute SQL query: {e}")
            return []
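
    # Usage sketch (hypothetical query; mirrors how the lineage queries above
    # bind %s placeholders from params):
    #   rows = self._execute_sql_query(
    #       "SELECT * FROM system.access.table_lineage WHERE target_table_catalog = %s",
    #       ["my_catalog"],
    #   )
    # This path requires a configured SQL warehouse, since _sql_connection_params
    # points http_path at /sql/1.0/warehouses/{warehouse_id}.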

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_schema_tags(self, catalog: str) -> Dict[str, List[UnityCatalogTag]]:
        """Optimized version using databricks-sql"""
        logger.info(f"Fetching schema tags for catalog: `{catalog}`")

        query = f"SELECT * FROM `{catalog}`.information_schema.schema_tags"
        rows = self._execute_sql_query(query)

        result_dict: Dict[str, List[UnityCatalogTag]] = {}

        for row in rows:
            catalog_name, schema_name, tag_name, tag_value = row
            schema_key = f"{catalog_name}.{schema_name}"

            if schema_key not in result_dict:
                result_dict[schema_key] = []

            result_dict[schema_key].append(
                UnityCatalogTag(key=tag_name, value=tag_value)
            )

        return result_dict
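
    # get_schema_tags above and the catalog/table/column tag helpers below all
    # follow the same pattern: read the corresponding `information_schema` tag
    # view for the catalog, key the result by the fully qualified name
    # (catalog, catalog.schema, catalog.schema.table, or
    # catalog.schema.table.column), and cache one catalog's worth of results at
    # a time via the FIFO cache.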

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_catalog_tags(self, catalog: str) -> Dict[str, List[UnityCatalogTag]]:
        """Optimized version using databricks-sql"""
        logger.info(f"Fetching catalog tags for catalog: `{catalog}`")

        query = f"SELECT * FROM `{catalog}`.information_schema.catalog_tags"
        rows = self._execute_sql_query(query)

        result_dict: Dict[str, List[UnityCatalogTag]] = {}

        for row in rows:
            catalog_name, tag_name, tag_value = row

            if catalog_name not in result_dict:
                result_dict[catalog_name] = []

            result_dict[catalog_name].append(
                UnityCatalogTag(key=tag_name, value=tag_value)
            )

        return result_dict

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_table_tags(self, catalog: str) -> Dict[str, List[UnityCatalogTag]]:
        """Optimized version using databricks-sql"""
        logger.info(f"Fetching table tags for catalog: `{catalog}`")

        query = f"SELECT * FROM `{catalog}`.information_schema.table_tags"
        rows = self._execute_sql_query(query)

        result_dict: Dict[str, List[UnityCatalogTag]] = {}

        for row in rows:
            catalog_name, schema_name, table_name, tag_name, tag_value = row
            table_key = f"{catalog_name}.{schema_name}.{table_name}"

            if table_key not in result_dict:
                result_dict[table_key] = []

            result_dict[table_key].append(
                UnityCatalogTag(key=tag_name, value=tag_value if tag_value else None)
            )

        return result_dict

    @cached(cachetools.FIFOCache(maxsize=_MAX_CONCURRENT_CATALOGS))
    def get_column_tags(self, catalog: str) -> Dict[str, List[UnityCatalogTag]]:
        """Optimized version using databricks-sql"""
        logger.info(f"Fetching column tags for catalog: `{catalog}`")

        query = f"SELECT * FROM `{catalog}`.information_schema.column_tags"
        rows = self._execute_sql_query(query)

        result_dict: Dict[str, List[UnityCatalogTag]] = {}

        for row in rows:
            catalog_name, schema_name, table_name, column_name, tag_name, tag_value = (
                row
            )
            column_key = f"{catalog_name}.{schema_name}.{table_name}.{column_name}"

            if column_key not in result_dict:
                result_dict[column_key] = []

            result_dict[column_key].append(
                UnityCatalogTag(key=tag_name, value=tag_value if tag_value else None)
            )

        return result_dict