test(ingest/unity): Unity catalog data generation (#8949)

Andrew Sikowitz 2023-12-05 12:33:00 -05:00 committed by GitHub
parent 3ee82590cd
commit 806f09ae23
8 changed files with 382 additions and 100 deletions

View File

@@ -262,6 +262,7 @@ databricks = {
"databricks-sdk>=0.9.0",
"pyspark~=3.3.0",
"requests",
"databricks-sql-connector",
}
mysql = sql_common | {"pymysql>=1.0.2"}

View File

@@ -2,7 +2,7 @@ import dataclasses
import random
import uuid
from collections import defaultdict
from typing import Dict, Iterable, List, cast
from typing import Dict, Iterable, List, Set
from typing_extensions import get_args
@@ -15,7 +15,7 @@ from datahub.ingestion.source.bigquery_v2.bigquery_audit import (
)
from datahub.ingestion.source.bigquery_v2.bigquery_config import BigQueryV2Config
from datahub.ingestion.source.bigquery_v2.usage import OPERATION_STATEMENT_TYPES
from tests.performance.data_model import Query, StatementType, Table, View
from tests.performance.data_model import Query, StatementType, Table
# https://cloud.google.com/bigquery/docs/reference/auditlogs/rest/Shared.Types/BigQueryAuditMetadata.TableDataRead.Reason
READ_REASONS = [
@@ -86,7 +86,7 @@ def generate_events(
ref_from_table(parent, table_to_project)
for field in query.fields_accessed
if field.table.is_view()
for parent in cast(View, field.table).parents
for parent in field.table.upstreams
)
),
referencedViews=referencedViews,
@@ -96,7 +96,7 @@ def generate_events(
query_on_view=True if referencedViews else False,
)
)
table_accesses = defaultdict(set)
table_accesses: Dict[BigQueryTableRef, Set[str]] = defaultdict(set)
for field in query.fields_accessed:
if not field.table.is_view():
table_accesses[ref_from_table(field.table, table_to_project)].add(
@@ -104,7 +104,7 @@ def generate_events(
)
else:
# assuming that the same fields are accessed in parent tables
for parent in cast(View, field.table).parents:
for parent in field.table.upstreams:
table_accesses[ref_from_table(parent, table_to_project)].add(
field.column
)

View File

@@ -8,16 +8,16 @@ We could also get more human data by using Faker.
This is a work in progress, built piecemeal as needed.
"""
import random
import uuid
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Iterable, List, TypeVar, Union, cast
from typing import Collection, Iterable, List, Optional, TypeVar, Union, cast
from faker import Faker
from tests.performance.data_model import (
Column,
ColumnMapping,
ColumnType,
Container,
FieldAccess,
@@ -40,17 +40,46 @@ OPERATION_TYPES: List[StatementType] = [
"UNKNOWN",
]
ID_COLUMN = "id" # Used to allow joins between all tables
class Distribution(metaclass=ABCMeta):
@abstractmethod
def _sample(self) -> int:
raise NotImplementedError
def sample(
self, *, floor: Optional[int] = None, ceiling: Optional[int] = None
) -> int:
value = self._sample()
if floor is not None:
value = max(value, floor)
if ceiling is not None:
value = min(value, ceiling)
return value
@dataclass(frozen=True)
class NormalDistribution:
class NormalDistribution(Distribution):
mu: float
sigma: float
def sample(self) -> int:
def _sample(self) -> int:
return int(random.gauss(mu=self.mu, sigma=self.sigma))
def sample_with_floor(self, floor: int = 1) -> int:
return max(int(random.gauss(mu=self.mu, sigma=self.sigma)), floor)
@dataclass(frozen=True)
class LomaxDistribution(Distribution):
"""See https://en.wikipedia.org/wiki/Lomax_distribution.
Equivalent to pareto(scale, shape) - scale; scale * beta_prime(1, shape)
"""
scale: float
shape: float
def _sample(self) -> int:
return int(self.scale * (random.paretovariate(self.shape) - 1))
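For intuition, a minimal sketch (not part of this commit, assuming LomaxDistribution is in scope from this module) that empirically reproduces the percentiles quoted for generate_lineage below; the closed-form Lomax quantile is scale * ((1 - p) ** (-1 / shape) - 1):
# Illustrative only: check the quantiles of the LomaxDistribution(scale=3, shape=5)
# default that generate_lineage uses for upstream counts.
dist = LomaxDistribution(scale=3, shape=5)
samples = sorted(dist.sample() for _ in range(100_000))
for p in (0.75, 0.80, 0.95, 0.99):
    # Expect roughly 0, 1, 2, 4, matching the comment on generate_lineage.
    print(p, samples[int(p * (len(samples) - 1))])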
@dataclass
@@ -72,9 +101,9 @@ def generate_data(
num_containers: Union[List[int], int],
num_tables: int,
num_views: int,
columns_per_table: NormalDistribution = NormalDistribution(5, 2),
parents_per_view: NormalDistribution = NormalDistribution(2, 1),
view_definition_length: NormalDistribution = NormalDistribution(150, 50),
columns_per_table: Distribution = NormalDistribution(5, 2),
parents_per_view: Distribution = NormalDistribution(2, 1),
view_definition_length: Distribution = NormalDistribution(150, 50),
time_range: timedelta = timedelta(days=14),
) -> SeedMetadata:
# Assemble containers
@@ -85,43 +114,32 @@ def generate_data(
for i, num_in_layer in enumerate(num_containers):
layer = [
Container(
f"{i}-container-{j}",
f"{_container_type(i)}_{j}",
parent=random.choice(containers[-1]) if containers else None,
)
for j in range(num_in_layer)
]
containers.append(layer)
# Assemble tables
# Assemble tables and views, lineage, and definitions
tables = [
Table(
f"table-{i}",
container=random.choice(containers[-1]),
columns=[
f"column-{j}-{uuid.uuid4()}"
for j in range(columns_per_table.sample_with_floor())
],
column_mapping=None,
)
for i in range(num_tables)
_generate_table(i, containers[-1], columns_per_table) for i in range(num_tables)
]
views = [
View(
f"view-{i}",
container=random.choice(containers[-1]),
columns=[
f"column-{j}-{uuid.uuid4()}"
for j in range(columns_per_table.sample_with_floor())
],
column_mapping=None,
definition=f"{uuid.uuid4()}-{'*' * view_definition_length.sample_with_floor(10)}",
parents=random.sample(tables, parents_per_view.sample_with_floor()),
**{ # type: ignore
**_generate_table(i, containers[-1], columns_per_table).__dict__,
"name": f"view_{i}",
"definition": f"--{'*' * view_definition_length.sample(floor=0)}",
},
)
for i in range(num_views)
]
for table in tables + views:
_generate_column_mapping(table)
for view in views:
view.upstreams = random.sample(tables, k=parents_per_view.sample(floor=1))
generate_lineage(tables, views)
now = datetime.now(tz=timezone.utc)
return SeedMetadata(
@@ -133,6 +151,33 @@ def generate_data(
)
def generate_lineage(
tables: Collection[Table],
views: Collection[Table],
# Percentiles: 75th=0, 80th=1, 95th=2, 99th=4, 99.99th=15
upstream_distribution: Distribution = LomaxDistribution(scale=3, shape=5),
) -> None:
num_upstreams = [upstream_distribution.sample(ceiling=100) for _ in tables]
# Prioritize tables with a lot of upstreams themselves
factor = 1 + len(tables) // 10
table_weights = [1 + (num_upstreams[i] * factor) for i in range(len(tables))]
view_weights = [1] * len(views)
# TODO: Python 3.9 use random.sample with counts
sample = []
for table, weight in zip(tables, table_weights):
for _ in range(weight):
sample.append(table)
for view, weight in zip(views, view_weights):
for _ in range(weight):
sample.append(view)
for i, table in enumerate(tables):
table.upstreams = random.sample( # type: ignore
sample,
k=num_upstreams[i],
)
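The manual weighted-sampling loop above exists because random.sample only gained its counts parameter in Python 3.9; on 3.9+ the TODO's form would look roughly like this (a sketch, not part of the commit, using the same population and weights):
# Sketch of the Python 3.9+ equivalent referenced by the TODO above.
population = list(tables) + list(views)
counts = table_weights + view_weights
for i, table in enumerate(tables):
    table.upstreams = random.sample(population, counts=counts, k=num_upstreams[i])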
def generate_queries(
seed_metadata: SeedMetadata,
num_selects: int,
@@ -146,12 +191,12 @@ def generate_queries(
) -> Iterable[Query]:
faker = Faker()
query_texts = [
faker.paragraph(query_length.sample_with_floor(30) // 30)
faker.paragraph(query_length.sample(floor=30) // 30)
for _ in range(num_unique_queries)
]
all_tables = seed_metadata.tables + seed_metadata.views
users = [f"user-{i}@xyz.com" for i in range(num_users)]
users = [f"user_{i}@xyz.com" for i in range(num_users)]
for i in range(num_selects): # Pure SELECT statements
tables = _sample_list(all_tables, tables_per_select)
all_columns = [
@@ -191,21 +236,43 @@ def generate_queries(
)
def _generate_column_mapping(table: Table) -> ColumnMapping:
d = {}
for column in table.columns:
d[column] = Column(
name=column,
def _container_type(i: int) -> str:
if i == 0:
return "database"
elif i == 1:
return "schema"
else:
return f"{i}container"
def _generate_table(
i: int, parents: List[Container], columns_per_table: Distribution
) -> Table:
num_columns = columns_per_table.sample(floor=1)
columns = OrderedDict({ID_COLUMN: Column(ID_COLUMN, ColumnType.INTEGER, False)})
for j in range(num_columns):
name = f"column_{j}"
columns[name] = Column(
name=name,
type=random.choice(list(ColumnType)),
nullable=random.random() < 0.1, # Fixed 10% chance for now
)
table.column_mapping = d
return d
return Table(
f"table_{i}",
container=random.choice(parents),
columns=columns,
upstreams=[],
)
def _sample_list(lst: List[T], dist: NormalDistribution, floor: int = 1) -> List[T]:
return random.sample(lst, min(dist.sample_with_floor(floor), len(lst)))
return random.sample(lst, min(dist.sample(floor=floor), len(lst)))
def _random_time_between(start: datetime, end: datetime) -> datetime:
return start + timedelta(seconds=(end - start).total_seconds() * random.random())
if __name__ == "__main__":
z = generate_data(10, 1000, 10)

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass
import typing
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from typing_extensions import Literal
@@ -37,29 +39,63 @@ class ColumnType(str, Enum):
@dataclass
class Column:
name: str
type: ColumnType
nullable: bool
type: ColumnType = ColumnType.STRING
nullable: bool = False
ColumnRef = str
ColumnMapping = Dict[ColumnRef, Column]
@dataclass
@dataclass(init=False)
class Table:
name: str
container: Container
columns: List[ColumnRef]
column_mapping: Optional[ColumnMapping]
columns: typing.OrderedDict[ColumnRef, Column] = field(repr=False)
upstreams: List["Table"] = field(repr=False)
def __init__(
self,
name: str,
container: Container,
columns: Union[List[str], Dict[str, Column]],
upstreams: List["Table"],
):
self.name = name
self.container = container
self.upstreams = upstreams
if isinstance(columns, list):
self.columns = OrderedDict((col, Column(col)) for col in columns)
elif isinstance(columns, dict):
self.columns = OrderedDict(columns)
@property
def name_components(self) -> List[str]:
lst = [self.name]
container: Optional[Container] = self.container
while container:
lst.append(container.name)
container = container.parent
return lst[::-1]
def is_view(self) -> bool:
return False
@dataclass
@dataclass(init=False)
class View(Table):
definition: str
parents: List[Table]
def __init__(
self,
name: str,
container: Container,
columns: Union[List[str], Dict[str, Column]],
upstreams: List["Table"],
definition: str,
):
super().__init__(name, container, columns, upstreams)
self.definition = definition
def is_view(self) -> bool:
return True
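A quick illustration of the reworked model (a sketch, not part of this commit): the new constructor accepts either bare column names or full Column objects, name_components walks the container chain, and upstreams replaces the old View.parents:
# Illustrative only; uses Container, Column, ColumnType, Table, View from this module.
db = Container("db")
t1 = Table("t1", db, columns=["id", "name"], upstreams=[])
t2 = Table("t2", db, columns={"id": Column("id", ColumnType.INTEGER)}, upstreams=[t1])
v1 = View("v1", db, columns=["id"], upstreams=[t1, t2], definition="-- ...")
assert t1.columns["name"].type == ColumnType.STRING  # defaults from Column
assert t2.name_components == ["db", "t2"]
assert v1.is_view() and not t1.is_view()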

View File

@@ -0,0 +1,177 @@
import logging
import random
import string
from concurrent.futures import ThreadPoolExecutor, wait
from datetime import datetime
from typing import Callable, List, TypeVar, Union
from urllib.parse import urlparse
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service.catalog import ColumnTypeName
from performance.data_generation import Distribution, LomaxDistribution, SeedMetadata
from performance.data_model import ColumnType, Container, Table, View
from performance.databricks.unity_proxy_mock import _convert_column_type
from sqlalchemy import create_engine
from datahub.ingestion.source.sql.sql_config import make_sqlalchemy_uri
logger = logging.getLogger(__name__)
T = TypeVar("T")
MAX_WORKERS = 200
class DatabricksDataGenerator:
def __init__(self, host: str, token: str, warehouse_id: str):
self.client = WorkspaceClient(host=host, token=token)
self.warehouse_id = warehouse_id
url = make_sqlalchemy_uri(
scheme="databricks",
username="token",
password=token,
at=urlparse(host).netloc,
db=None,
uri_opts={"http_path": f"/sql/1.0/warehouses/{warehouse_id}"},
)
engine = create_engine(
url, connect_args={"timeout": 600}, pool_size=MAX_WORKERS
)
self.connection = engine.connect()
def clear_data(self, seed_metadata: SeedMetadata) -> None:
for container in seed_metadata.containers[0]:
try:
self.client.catalogs.delete(container.name, force=True)
except DatabricksError:
pass
def create_data(
self,
seed_metadata: SeedMetadata,
# Percentiles: 1st=0, 10th=7, 25th=21, 50th=58, 75th=152, 90th=364, 99th=2063, 99.99th=46316
num_rows_distribution: Distribution = LomaxDistribution(scale=100, shape=1.5),
) -> None:
"""Create data in Databricks based on SeedMetadata."""
for container in seed_metadata.containers[0]:
self._create_catalog(container)
for container in seed_metadata.containers[1]:
self._create_schema(container)
_thread_pool_execute("create tables", seed_metadata.tables, self._create_table)
_thread_pool_execute("create views", seed_metadata.views, self._create_view)
_thread_pool_execute(
"populate tables",
seed_metadata.tables,
lambda t: self._populate_table(
t, num_rows_distribution.sample(ceiling=1_000_000)
),
)
_thread_pool_execute(
"create table lineage", seed_metadata.tables, self._create_table_lineage
)
def _create_catalog(self, catalog: Container) -> None:
try:
self.client.catalogs.get(catalog.name)
except DatabricksError:
self.client.catalogs.create(catalog.name)
def _create_schema(self, schema: Container) -> None:
try:
self.client.schemas.get(f"{schema.parent.name}.{schema.name}")
except DatabricksError:
self.client.schemas.create(schema.name, schema.parent.name)
def _create_table(self, table: Table) -> None:
try:
self.client.tables.delete(".".join(table.name_components))
except DatabricksError:
pass
columns = ", ".join(
f"{name} {_convert_column_type(column.type).value}"
for name, column in table.columns.items()
)
self._execute_sql(f"CREATE TABLE {_quote_table(table)} ({columns})")
self._assert_table_exists(table)
def _create_view(self, view: View) -> None:
self._execute_sql(_generate_view_definition(view))
self._assert_table_exists(view)
def _assert_table_exists(self, table: Table) -> None:
self.client.tables.get(".".join(table.name_components))
def _populate_table(self, table: Table, num_rows: int) -> None:
values = [
", ".join(
str(_generate_value(column.type)) for column in table.columns.values()
)
for _ in range(num_rows)
]
values_str = ", ".join(f"({value})" for value in values)
self._execute_sql(f"INSERT INTO {_quote_table(table)} VALUES {values_str}")
def _create_table_lineage(self, table: Table) -> None:
for upstream in table.upstreams:
self._execute_sql(_generate_insert_lineage(table, upstream))
def _execute_sql(self, sql: str) -> None:
print(sql)
self.connection.execute(sql)
def _thread_pool_execute(desc: str, lst: List[T], fn: Callable[[T], None]) -> None:
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
futures = [executor.submit(fn, item) for item in lst]
wait(futures)
for future in futures:
try:
future.result()
except Exception as e:
logger.error(f"Error executing '{desc}': {e}", exc_info=True)
def _generate_value(t: ColumnType) -> Union[int, float, str, bool]:
ctn = _convert_column_type(t)
if ctn == ColumnTypeName.INT:
return random.randint(-(2**31), 2**31 - 1)
elif ctn == ColumnTypeName.DOUBLE:
return random.uniform(-(2**31), 2**31 - 1)
elif ctn == ColumnTypeName.STRING:
return (
"'" + "".join(random.choice(string.ascii_letters) for _ in range(8)) + "'"
)
elif ctn == ColumnTypeName.BOOLEAN:
return random.choice([True, False])
elif ctn == ColumnTypeName.TIMESTAMP:
return random.randint(0, int(datetime.now().timestamp()))
else:
raise NotImplementedError(f"Unsupported type {ctn}")
def _generate_insert_lineage(table: Table, upstream: Table) -> str:
select = []
for column in table.columns.values():
matching_cols = [c for c in upstream.columns.values() if c.type == column.type]
if matching_cols:
upstream_col = random.choice(matching_cols)
select.append(f"{upstream_col.name} AS {column.name}")
else:
select.append(f"{_generate_value(column.type)} AS {column.name}")
return f"INSERT INTO {_quote_table(table)} SELECT {', '.join(select)} FROM {_quote_table(upstream)}"
def _generate_view_definition(view: View) -> str:
from_statement = f"FROM {_quote_table(view.upstreams[0])} t0"
join_statement = " ".join(
f"JOIN {_quote_table(upstream)} t{i+1} ON t0.id = t{i+1}.id"
for i, upstream in enumerate(view.upstreams[1:])
)
return f"CREATE VIEW {_quote_table(view)} AS SELECT * {from_statement} {join_statement} {view.definition}"
def _quote_table(table: Table) -> str:
return ".".join(f"`{component}`" for component in table.name_components)

View File

@@ -88,22 +88,21 @@ class UnityCatalogApiProxyMock:
def tables(self, schema: Schema) -> Iterable[Table]:
for table in self._schema_to_table[schema.name]:
columns = []
if table.column_mapping:
for i, col_name in enumerate(table.columns):
column = table.column_mapping[col_name]
columns.append(
Column(
id=column.name,
name=column.name,
type_name=self._convert_column_type(column.type),
type_text=column.type.value,
nullable=column.nullable,
position=i,
comment=None,
type_precision=0,
type_scale=0,
)
for i, col_name in enumerate(table.columns):
column = table.columns[col_name]
columns.append(
Column(
id=column.name,
name=column.name,
type_name=_convert_column_type(column.type),
type_text=column.type.value,
nullable=column.nullable,
position=i,
comment=None,
type_precision=0,
type_scale=0,
)
)
yield Table(
id=f"{schema.id}.{table.name}",
@@ -145,7 +144,7 @@ class UnityCatalogApiProxyMock:
yield Query(
query_id=str(i),
query_text=query.text,
statement_type=self._convert_statement_type(query.type),
statement_type=_convert_statement_type(query.type),
start_time=query.timestamp,
end_time=query.timestamp,
user_id=hash(query.actor),
@@ -160,24 +159,24 @@ class UnityCatalogApiProxyMock:
def get_column_lineage(self, table: Table) -> None:
pass
@staticmethod
def _convert_column_type(t: ColumnType) -> ColumnTypeName:
if t == ColumnType.INTEGER:
return ColumnTypeName.INT
elif t == ColumnType.FLOAT:
return ColumnTypeName.DOUBLE
elif t == ColumnType.STRING:
return ColumnTypeName.STRING
elif t == ColumnType.BOOLEAN:
return ColumnTypeName.BOOLEAN
elif t == ColumnType.DATETIME:
return ColumnTypeName.TIMESTAMP
else:
raise ValueError(f"Unknown column type: {t}")
@staticmethod
def _convert_statement_type(t: StatementType) -> QueryStatementType:
if t == "CUSTOM" or t == "UNKNOWN":
return QueryStatementType.OTHER
else:
return QueryStatementType[t]
def _convert_column_type(t: ColumnType) -> ColumnTypeName:
if t == ColumnType.INTEGER:
return ColumnTypeName.INT
elif t == ColumnType.FLOAT:
return ColumnTypeName.DOUBLE
elif t == ColumnType.STRING:
return ColumnTypeName.STRING
elif t == ColumnType.BOOLEAN:
return ColumnTypeName.BOOLEAN
elif t == ColumnType.DATETIME:
return ColumnTypeName.TIMESTAMP
else:
raise ValueError(f"Unknown column type: {t}")
def _convert_statement_type(t: StatementType) -> QueryStatementType:
if t == "CUSTOM" or t == "UNKNOWN":
return QueryStatementType.OTHER
else:
return QueryStatementType[t]
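These converters are now module-level functions (rather than static methods) so the new Databricks data generator can import them; their behavior is unchanged, e.g. (illustrative only):
# Illustrative only; mirrors the mappings above.
assert _convert_column_type(ColumnType.DATETIME) == ColumnTypeName.TIMESTAMP
assert _convert_column_type(ColumnType.FLOAT) == ColumnTypeName.DOUBLE
assert _convert_statement_type("CUSTOM") == QueryStatementType.OTHER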

View File

@@ -324,7 +324,7 @@ def test_get_projects_list_failure(
{"project_id_pattern": {"deny": ["^test-project$"]}}
)
source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test"))
caplog.records.clear()
caplog.clear()
with caplog.at_level(logging.ERROR):
projects = source._get_projects()
assert len(caplog.records) == 1

View File

@@ -1,7 +1,7 @@
import logging
import random
from datetime import datetime, timedelta, timezone
from typing import Iterable, cast
from typing import Iterable
from unittest.mock import MagicMock, patch
import pytest
@@ -45,15 +45,16 @@ ACTOR_1, ACTOR_1_URN = "a@acryl.io", "urn:li:corpuser:a"
ACTOR_2, ACTOR_2_URN = "b@acryl.io", "urn:li:corpuser:b"
DATABASE_1 = Container("database_1")
DATABASE_2 = Container("database_2")
TABLE_1 = Table("table_1", DATABASE_1, ["id", "name", "age"], None)
TABLE_2 = Table("table_2", DATABASE_1, ["id", "table_1_id", "value"], None)
TABLE_1 = Table("table_1", DATABASE_1, columns=["id", "name", "age"], upstreams=[])
TABLE_2 = Table(
"table_2", DATABASE_1, columns=["id", "table_1_id", "value"], upstreams=[]
)
VIEW_1 = View(
name="view_1",
container=DATABASE_1,
columns=["id", "name", "total"],
definition="VIEW DEFINITION 1",
parents=[TABLE_1, TABLE_2],
column_mapping=None,
upstreams=[TABLE_1, TABLE_2],
)
ALL_TABLES = [TABLE_1, TABLE_2, VIEW_1]
@@ -842,6 +843,7 @@ def test_usage_counts_no_columns(
)
),
]
caplog.clear()
with caplog.at_level(logging.WARNING):
workunits = usage_extractor._get_workunits_internal(
events, [TABLE_REFS[TABLE_1.name]]
@@ -938,7 +940,7 @@ def test_operational_stats(
).to_urn("PROD")
for field in query.fields_accessed
if field.table.is_view()
for parent in cast(View, field.table).parents
for parent in field.table.upstreams
)
),
),