185 lines
6.0 KiB
Python

import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, Iterable, List, Optional
from databricks.sdk.service.catalog import ColumnTypeName
from databricks.sdk.service.sql import QueryStatementType
from datahub.ingestion.source.unity.proxy_types import (
Catalog,
CatalogType,
Column,
Metastore,
Query,
Schema,
ServicePrincipal,
Table,
TableType,
)
from tests.performance import data_model
from tests.performance.data_generation import SeedMetadata
from tests.performance.data_model import ColumnType, StatementType
class UnityCatalogApiProxyMock:
"""Mimics UnityCatalogApiProxy for performance testing."""
def __init__(
self,
seed_metadata: SeedMetadata,
queries: Iterable[data_model.Query] = (),
num_service_principals: int = 0,
) -> None:
self.seed_metadata = seed_metadata
self.queries = queries
self.num_service_principals = num_service_principals
self.warehouse_id = "invalid-warehouse-id"
# Cache for performance
self._schema_to_table: Dict[str, List[data_model.Table]] = defaultdict(list)
for table in seed_metadata.all_tables:
self._schema_to_table[table.container.name].append(table)
def check_basic_connectivity(self) -> bool:
return True
def assigned_metastore(self) -> Metastore:
container = self.seed_metadata.containers[0][0]
return Metastore(
id=container.name,
name=container.name,
global_metastore_id=container.name,
metastore_id=container.name,
comment=None,
owner=None,
cloud=None,
region=None,
)
def catalogs(self, metastore: Optional[Metastore]) -> Iterable[Catalog]:
for container in self.seed_metadata.containers[1]:
if not container.parent or (
metastore and metastore.name != container.parent.name
):
continue
yield Catalog(
id=f"{metastore.id}.{container.name}" if metastore else container.name,
name=container.name,
metastore=metastore,
comment=None,
owner=None,
type=CatalogType.MANAGED_CATALOG,
)
def schemas(self, catalog: Catalog) -> Iterable[Schema]:
for container in self.seed_metadata.containers[2]:
# Assumes all catalog names are unique
if not container.parent or catalog.name != container.parent.name:
continue
yield Schema(
id=f"{catalog.id}.{container.name}",
name=container.name,
catalog=catalog,
comment=None,
owner=None,
)
def tables(self, schema: Schema) -> Iterable[Table]:
for table in self._schema_to_table[schema.name]:
columns = []
for i, col_name in enumerate(table.columns):
column = table.columns[col_name]
columns.append(
Column(
id=column.name,
name=column.name,
type_name=_convert_column_type(column.type),
type_text=column.type.value,
nullable=column.nullable,
position=i,
comment=None,
type_precision=0,
type_scale=0,
)
)
yield Table(
id=f"{schema.id}.{table.name}",
name=table.name,
schema=schema,
table_type=TableType.VIEW if table.is_view() else TableType.MANAGED,
columns=columns,
created_at=datetime.now(tz=timezone.utc),
comment=None,
owner=None,
storage_location=None,
data_source_format=None,
generation=None,
created_by="",
updated_at=None,
updated_by=None,
table_id="",
view_definition=(
table.definition if isinstance(table, data_model.View) else None
),
properties={},
)
def service_principals(self) -> Iterable[ServicePrincipal]:
for i in range(self.num_service_principals):
yield ServicePrincipal(
id=str(i),
application_id=str(uuid.uuid4()),
display_name=f"user-{i}",
active=True,
)
def query_history(
self,
start_time: datetime,
end_time: datetime,
) -> Iterable[Query]:
for i, query in enumerate(self.queries):
yield Query(
query_id=str(i),
query_text=query.text,
statement_type=_convert_statement_type(query.type),
start_time=query.timestamp,
end_time=query.timestamp,
user_id=hash(query.actor),
user_name=query.actor,
executed_as_user_id=hash(query.actor),
executed_as_user_name=None,
)
def table_lineage(self, table: Table, include_entity_lineage: bool) -> None:
pass
def get_column_lineage(self, table: Table) -> None:
pass
def _convert_column_type(t: ColumnType) -> ColumnTypeName:
if t == ColumnType.INTEGER:
return ColumnTypeName.INT
elif t == ColumnType.FLOAT:
return ColumnTypeName.DOUBLE
elif t == ColumnType.STRING:
return ColumnTypeName.STRING
elif t == ColumnType.BOOLEAN:
return ColumnTypeName.BOOLEAN
elif t == ColumnType.DATETIME:
return ColumnTypeName.TIMESTAMP
else:
raise ValueError(f"Unknown column type: {t}")
def _convert_statement_type(t: StatementType) -> QueryStatementType:
if t == "CUSTOM" or t == "UNKNOWN":
return QueryStatementType.OTHER
else:
return QueryStatementType[t]