perf(ingest/bigquery): Reduce disk usage and improve speed of BigQuery usage extraction (#7825)

Andrew Sikowitz 2023-04-14 21:09:43 -04:00 committed by GitHub
parent e839ac4c40
commit 1ac1ccf26e
11 changed files with 186 additions and 68 deletions

View File

@ -80,11 +80,9 @@ markers =
integration: marks tests to only run in integration (deselect with '-m "not integration"')
integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1')
slow_integration: marks tests that are too slow to even run in integration (deselect with '-m "not slow_integration"')
performance: marks tests that are sparingly run to measure performance (deselect with '-m "not performance"')
testpaths =
testpaths =
tests/unit
tests/integration
tests/performance
[coverage:run]
# Because of some quirks in the way setup.cfg, coverage.py, pytest-cov,
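The setup.cfg hunk above registers a new `performance` pytest marker and adds `tests/performance` to `testpaths`. As a rough illustration (the module and test names here are invented), a module under `tests/performance/` opts in with a module-level `pytestmark`, the same mechanism that appears in the perf test file later in this diff:

```python
import pytest

# Module-level marker: every test in this file is collected as a "performance" test,
# matching the marker registered in setup.cfg above.
pytestmark = pytest.mark.performance


def test_example_throughput() -> None:
    # Run only these tests with `pytest -m performance`;
    # skip them with `pytest -m "not performance"`.
    assert True
```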

View File

@ -408,6 +408,7 @@ base_dev_requirements = {
# We should make an effort to keep it up to date.
"black==22.12.0",
"coverage>=5.1",
"faker>=18.4.0",
"flake8>=3.8.3", # DEPRECATION: Once we drop Python 3.7, we can pin to 6.x.
"flake8-tidy-imports>=4.3.0",
"flake8-bugbear==23.3.12",

View File

@ -516,7 +516,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
self.report.report_dropped(project_id.id)
continue
logger.info(f"Processing project: {project_id.id}")
self.report.set_project_state(project_id.id, "Metadata Extraction")
self.report.set_ingestion_stage(project_id.id, "Metadata Extraction")
yield from self._process_project(conn, project_id)
if self._should_ingest_usage():
@ -526,7 +526,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
if self._should_ingest_lineage():
for project in projects:
self.report.set_project_state(project.id, "Lineage Extraction")
self.report.set_ingestion_stage(project.id, "Lineage Extraction")
yield from self.generate_lineage(project.id)
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
@ -671,7 +671,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
if self.config.profiling.enabled:
logger.info(f"Starting profiling project {project_id}")
self.report.set_project_state(project_id, "Profiling")
self.report.set_ingestion_stage(project_id, "Profiling")
yield from self.profiler.get_workunits(
project_id=project_id,
tables=db_tables,

View File

@ -193,7 +193,7 @@ class BigQueryV2Config(
file_backed_cache_size: int = Field(
hidden_from_docs=True,
default=200,
default=2000,
description="Maximum number of entries for the in-memory caches of FileBacked data structures.",
)

View File

@ -2,7 +2,7 @@ import collections
import dataclasses
import logging
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timezone
from typing import Counter, Dict, List, Optional
import pydantic
@ -70,6 +70,7 @@ class BigQueryV2Report(ProfilingSqlReport):
num_query_events: int = 0
num_filtered_read_events: int = 0
num_filtered_query_events: int = 0
num_usage_query_hash_collisions: int = 0
num_operational_stats_workunits_emitted: int = 0
read_reasons_stat: Counter[str] = dataclasses.field(
default_factory=collections.Counter
@ -77,20 +78,24 @@ class BigQueryV2Report(ProfilingSqlReport):
operation_types_stat: Counter[str] = dataclasses.field(
default_factory=collections.Counter
)
current_project_status: Optional[str] = None
usage_state_size: Optional[str] = None
ingestion_stage: Optional[str] = None
ingestion_stage_durations: Dict[str, str] = field(default_factory=TopKDict)
timer: Optional[PerfTimer] = field(
_timer: Optional[PerfTimer] = field(
default=None, init=False, repr=False, compare=False
)
def set_project_state(self, project: str, stage: str) -> None:
if self.timer:
def set_ingestion_stage(self, project: str, stage: str) -> None:
if self._timer:
elapsed = f"{self._timer.elapsed_seconds():.2f}"
logger.info(
f"Time spent in stage <{self.current_project_status}>: "
f"{self.timer.elapsed_seconds():.2f} seconds"
f"Time spent in stage <{self.ingestion_stage}>: {elapsed} seconds"
)
if self.ingestion_stage:
self.ingestion_stage_durations[self.ingestion_stage] = elapsed
else:
self.timer = PerfTimer()
self._timer = PerfTimer()
self.current_project_status = f"{project}: {stage} at {datetime.now()}"
self.timer.start()
self.ingestion_stage = f"{project}: {stage} at {datetime.now(timezone.utc)}"
self._timer.start()
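The report changes above rename `set_project_state` to `set_ingestion_stage`, make the timer private, and record each completed stage's duration. A self-contained sketch of the same timing pattern, substituting `time.perf_counter` for DataHub's `PerfTimer` and a plain dict for `TopKDict` (both stand-ins are assumptions for illustration only):

```python
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, Optional


@dataclass
class StageReport:
    ingestion_stage: Optional[str] = None
    ingestion_stage_durations: Dict[str, str] = field(default_factory=dict)
    _start: Optional[float] = None

    def set_ingestion_stage(self, project: str, stage: str) -> None:
        # Close out the previous stage, if any, and record how long it took.
        if self._start is not None and self.ingestion_stage:
            elapsed = f"{time.perf_counter() - self._start:.2f}"
            self.ingestion_stage_durations[self.ingestion_stage] = elapsed
        self.ingestion_stage = f"{project}: {stage} at {datetime.now(timezone.utc)}"
        self._start = time.perf_counter()


report = StageReport()
report.set_ingestion_stage("project-1", "Metadata Extraction")
report.set_ingestion_stage("project-1", "Lineage Extraction")  # records the previous stage
```

A stage's duration is only recorded once the next stage begins, which is why the perf test further down closes with a final "Done" stage.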

View File

@ -1,5 +1,7 @@
import hashlib
import json
import logging
import os
import textwrap
import time
import uuid
@ -18,6 +20,7 @@ from typing import (
Union,
)
import humanfriendly
from google.cloud.bigquery import Client as BigQueryClient
from google.cloud.logging_v2.client import Client as GCPLoggingClient
from ratelimiter import RateLimiter
@ -44,7 +47,10 @@ from datahub.ingestion.source.bigquery_v2.common import (
_make_gcp_logging_client,
get_bigquery_client,
)
from datahub.ingestion.source.usage.usage_common import make_usage_workunit
from datahub.ingestion.source.usage.usage_common import (
TOTAL_BUDGET_FOR_QUERY_LIST,
make_usage_workunit,
)
from datahub.metadata.schema_classes import OperationClass, OperationTypeClass
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict
from datahub.utilities.perf_timer import PerfTimer
@ -79,6 +85,8 @@ OPERATION_STATEMENT_TYPES = {
}
READ_STATEMENT_TYPES: List[str] = ["SELECT"]
STRING_ENCODING = "utf-8"
MAX_QUERY_LENGTH = TOTAL_BUDGET_FOR_QUERY_LIST
@dataclass(frozen=True, order=True)
@ -157,6 +165,7 @@ class BigQueryUsageState(Closeable):
read_events: FileBackedDict[ReadEvent]
query_events: FileBackedDict[QueryEvent]
column_accesses: FileBackedDict[Tuple[str, str]]
queries: FileBackedDict[str]
def __init__(self, config: BigQueryV2Config):
self.conn = ConnectionWrapper()
@ -172,6 +181,10 @@ class BigQueryUsageState(Closeable):
"user": lambda e: e.actor_email,
},
cache_max_size=config.file_backed_cache_size,
# Evict entire cache to reduce db calls.
cache_eviction_batch_size=max(int(config.file_backed_cache_size * 0.9), 1),
delay_index_creation=True,
should_compress_value=True,
)
# Keyed by job_name
self.query_events = FileBackedDict[QueryEvent](
@ -182,6 +195,9 @@ class BigQueryUsageState(Closeable):
"is_read": lambda e: int(e.statementType in READ_STATEMENT_TYPES),
},
cache_max_size=config.file_backed_cache_size,
cache_eviction_batch_size=max(int(config.file_backed_cache_size * 0.9), 1),
delay_index_creation=True,
should_compress_value=True,
)
# Created just to store column accesses in sqlite for JOIN
self.column_accesses = FileBackedDict[Tuple[str, str]](
@ -189,7 +205,10 @@ class BigQueryUsageState(Closeable):
tablename="column_accesses",
extra_columns={"read_event": lambda p: p[0], "field": lambda p: p[1]},
cache_max_size=config.file_backed_cache_size,
cache_eviction_batch_size=max(int(config.file_backed_cache_size * 0.9), 1),
delay_index_creation=True,
)
self.queries = FileBackedDict[str](cache_max_size=config.file_backed_cache_size)
def close(self) -> None:
self.read_events.close()
@ -197,12 +216,23 @@ class BigQueryUsageState(Closeable):
self.column_accesses.close()
self.conn.close()
self.queries.close()
def create_indexes(self) -> None:
self.read_events.create_indexes()
self.query_events.create_indexes()
self.column_accesses.create_indexes()
def standalone_events(self) -> Iterable[AuditEvent]:
for read_event in self.read_events.values():
query = """
SELECT r.value, q.value
FROM read_events r
LEFT JOIN query_events q ON r.name = q.key
"""
for read_value, query_value in self.read_events.sql_query_iterator(query):
read_event = self.read_events.deserializer(read_value)
query_event = (
self.query_events.get(read_event.jobName)
if read_event.jobName
else None
self.query_events.deserializer(query_value) if query_value else None
)
yield AuditEvent(read_event=read_event, query_event=query_event)
for _, query_event in self.query_events.items_snapshot("NOT is_read"):
@ -293,6 +323,16 @@ class BigQueryUsageState(Closeable):
column_freq=json.loads(row["column_freq"] or "[]"),
)
def report_disk_usage(self, report: BigQueryV2Report) -> None:
report.usage_state_size = str(
{
"main": humanfriendly.format_size(os.path.getsize(self.conn.filename)),
"queries": humanfriendly.format_size(
os.path.getsize(self.queries._conn.filename)
),
}
)
class BigQueryUsageExtractor:
"""
@ -308,6 +348,8 @@ class BigQueryUsageExtractor:
def __init__(self, config: BigQueryV2Config, report: BigQueryV2Report):
self.config: BigQueryV2Config = config
self.report: BigQueryV2Report = report
# Replace hash of query with uuid if there are hash conflicts
self.uuid_to_query: Dict[str, str] = {}
def _is_table_allowed(self, table_ref: Optional[BigQueryTableRef]) -> bool:
return (
@ -328,6 +370,8 @@ class BigQueryUsageExtractor:
try:
with BigQueryUsageState(self.config) as usage_state:
self._ingest_events(events, table_refs, usage_state)
usage_state.create_indexes()
usage_state.report_disk_usage(self.report)
if self.config.usage.include_operational_stats:
yield from self._generate_operational_workunits(
@ -335,6 +379,7 @@ class BigQueryUsageExtractor:
)
yield from self._generate_usage_workunits(usage_state)
usage_state.report_disk_usage(self.report)
except Exception as e:
logger.error("Error processing usage", exc_info=True)
self.report.report_warning("usage-ingestion", str(e))
@ -362,7 +407,7 @@ class BigQueryUsageExtractor:
def _generate_operational_workunits(
self, usage_state: BigQueryUsageState, table_refs: Collection[str]
) -> Iterable[MetadataWorkUnit]:
self.report.set_project_state("All", "Usage Extraction Operational Stats")
self.report.set_ingestion_stage("*", "Usage Extraction Operational Stats")
for audit_event in usage_state.standalone_events():
try:
operational_wu = self._create_operation_workunit(
@ -381,7 +426,7 @@ class BigQueryUsageExtractor:
def _generate_usage_workunits(
self, usage_state: BigQueryUsageState
) -> Iterable[MetadataWorkUnit]:
self.report.set_project_state("All", "Usage Extraction Usage Aggregation")
self.report.set_ingestion_stage("*", "Usage Extraction Usage Aggregation")
top_n = (
self.config.usage.top_n_queries
if self.config.usage.include_top_n_queries
@ -389,11 +434,20 @@ class BigQueryUsageExtractor:
)
for entry in usage_state.usage_statistics(top_n=top_n):
try:
query_freq = [
(
self.uuid_to_query.get(
query_hash, usage_state.queries[query_hash]
),
count,
)
for query_hash, count in entry.query_freq
]
yield make_usage_workunit(
bucket_start_time=datetime.fromisoformat(entry.timestamp),
resource=BigQueryTableRef.from_string_name(entry.resource),
query_count=entry.query_count,
query_freq=entry.query_freq,
query_freq=query_freq,
user_freq=entry.user_freq,
column_freq=entry.column_freq,
bucket_duration=self.config.bucket_duration,
@ -416,7 +470,7 @@ class BigQueryUsageExtractor:
for project_id in projects:
with PerfTimer() as timer:
try:
self.report.set_project_state(
self.report.set_ingestion_stage(
project_id, "Usage Extraction Ingestion"
)
yield from self._get_parsed_bigquery_log_events(project_id)
@ -460,6 +514,16 @@ class BigQueryUsageExtractor:
usage_state.column_accesses[str(uuid.uuid4())] = key, field_read
return True
elif event.query_event and event.query_event.job_name:
query = event.query_event.query[:MAX_QUERY_LENGTH]
query_hash = hashlib.md5(query.encode(STRING_ENCODING)).hexdigest()
if usage_state.queries.get(query_hash, query) != query:
key = str(uuid.uuid4())
self.uuid_to_query[key] = query
event.query_event.query = key
self.report.num_usage_query_hash_collisions += 1
else:
usage_state.queries[query_hash] = query
event.query_event.query = query_hash
usage_state.query_events[event.query_event.job_name] = event.query_event
return True
return False
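The `_store_usage_event` change above stops storing full query text on every `QueryEvent`: it keeps one copy of each query keyed by an MD5 hash, and falls back to a UUID key held in memory when two different queries collide on the same hash (counted in `num_usage_query_hash_collisions`). A self-contained sketch of that interning pattern, using plain dicts where the source uses a `FileBackedDict` plus the in-memory `uuid_to_query` map; `MAX_QUERY_LENGTH` below is a stand-in value for `TOTAL_BUDGET_FOR_QUERY_LIST`:

```python
import hashlib
import uuid
from typing import Dict

MAX_QUERY_LENGTH = 1024  # stand-in; the source imports TOTAL_BUDGET_FOR_QUERY_LIST

queries: Dict[str, str] = {}        # hash -> canonical query text (FileBackedDict in the source)
uuid_to_query: Dict[str, str] = {}  # overflow map used when hashes collide


def intern_query(query: str) -> str:
    """Return a short key for `query`, storing each distinct text only once."""
    query = query[:MAX_QUERY_LENGTH]
    query_hash = hashlib.md5(query.encode("utf-8")).hexdigest()
    if queries.get(query_hash, query) != query:
        # A different query already owns this hash: keep this one under a fresh UUID.
        key = str(uuid.uuid4())
        uuid_to_query[key] = query
        return key
    queries[query_hash] = query
    return query_hash


key = intern_query("SELECT * FROM dataset.table")
# When emitting usage workunits, keys are resolved back to text, preferring the collision map:
text = uuid_to_query.get(key, queries.get(key, ""))
```

This mirrors how `_generate_usage_workunits` rebuilds `query_freq` from `uuid_to_query` and `usage_state.queries` before calling `make_usage_workunit`.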

View File

@ -1,4 +1,5 @@
import collections
import gzip
import logging
import pathlib
import pickle
@ -153,8 +154,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
cache_max_size: int = _DEFAULT_MEMORY_CACHE_MAX_SIZE
cache_eviction_batch_size: int = _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE
delay_index_creation: bool = False
should_compress_value: bool = False
_conn: ConnectionWrapper = field(init=False, repr=False)
indexes_created: bool = field(init=False, default=False)
# To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
# Maintains a dirty bit marking whether the value has been modified since it was persisted.
@ -190,12 +194,24 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
)"""
)
# The key column will automatically be indexed, but we need indexes
# for the extra columns.
if not self.delay_index_creation:
self.create_indexes()
if self.should_compress_value:
serializer = self.serializer
self.serializer = lambda value: gzip.compress(serializer(value)) # type: ignore
deserializer = self.deserializer
self.deserializer = lambda value: deserializer(gzip.decompress(value))
def create_indexes(self) -> None:
if self.indexes_created:
return
# The key column will automatically be indexed, but we need indexes for the extra columns.
for column_name in self.extra_columns.keys():
self._conn.execute(
f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
)
self.indexes_created = True
def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
self._active_object_cache[key] = value, dirty
@ -377,7 +393,7 @@ class FileBackedList(Generic[_VT]):
cache_eviction_batch_size: Optional[int] = None,
) -> None:
self._len = 0
self._dict = FileBackedDict(
self._dict = FileBackedDict[_VT](
shared_connection=connection,
serializer=serializer,
deserializer=deserializer,
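The `FileBackedDict` hunk above adds `delay_index_creation`, batched cache eviction, and a `should_compress_value` flag that wraps the configured serializer/deserializer in gzip before values reach SQLite. A minimal sketch of that wrapping, assuming pickle as the underlying serializer (the real class accepts caller-supplied functions):

```python
import gzip
import pickle
from typing import Any, Callable

serializer: Callable[[Any], bytes] = pickle.dumps
deserializer: Callable[[bytes], Any] = pickle.loads

# Equivalent of should_compress_value=True: compose gzip around the existing pair.
_inner_ser, _inner_de = serializer, deserializer
serializer = lambda value: gzip.compress(_inner_ser(value))
deserializer = lambda blob: _inner_de(gzip.decompress(blob))

blob = serializer({"query": "SELECT 1", "count": 42})
assert deserializer(blob) == {"query": "SELECT 1", "count": 42}
```

Compressing values trades a little CPU for a smaller on-disk database, which is what the new `report_disk_usage` in the usage extractor measures.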

View File

@ -1,7 +1,6 @@
# Performance Testing
This module provides a framework for performance testing our ingestion sources.
When running a performance test, make sure to output print statements and live logs:
```bash
pytest -s --log-cli-level=INFO -m performance tests/performance/<test_name>.py
python -m tests.performance.<test_name>
```

View File

@ -13,6 +13,8 @@ from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Iterable, List, TypeVar
from faker import Faker
from tests.performance.data_model import (
Container,
FieldAccess,
@ -106,12 +108,19 @@ def generate_queries(
seed_metadata: SeedMetadata,
num_selects: int,
num_operations: int,
num_unique_queries: int,
num_users: int,
tables_per_select: NormalDistribution = NormalDistribution(3, 5),
columns_per_select: NormalDistribution = NormalDistribution(10, 5),
upstream_tables_per_operation: NormalDistribution = NormalDistribution(2, 2),
query_length: NormalDistribution = NormalDistribution(100, 50),
) -> Iterable[Query]:
faker = Faker()
query_texts = [
faker.paragraph(query_length.sample_with_floor(30) // 30)
for _ in range(num_unique_queries)
]
all_tables = seed_metadata.tables + seed_metadata.views
users = [f"user-{i}@xyz.com" for i in range(num_users)]
for i in range(num_selects): # Pure SELECT statements
@ -120,7 +129,7 @@ def generate_queries(
FieldAccess(column, table) for table in tables for column in table.columns
]
yield Query(
text=f"{uuid.uuid4()}-{'*' * query_length.sample_with_floor(10)}",
text=random.choice(query_texts),
type="SELECT",
actor=random.choice(users),
timestamp=_random_time_between(
@ -141,7 +150,7 @@ def generate_queries(
for column in table.columns
]
yield Query(
text=f"{uuid.uuid4()}-{'*' * query_length.sample_with_floor(10)}",
text=random.choice(query_texts),
type=random.choice(OPERATION_TYPES),
actor=random.choice(users),
timestamp=_random_time_between(
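`generate_queries` now takes `num_unique_queries` and draws query text from a fixed pool of Faker paragraphs, so identical queries recur across events and exercise the hash-based deduplication above. A rough sketch of that pooling; `sample_with_floor` here approximates the repo's `NormalDistribution.sample_with_floor` with `random.gauss` plus a floor (an assumption for illustration):

```python
import random

from faker import Faker


def sample_with_floor(mean: float, stddev: float, floor: int) -> int:
    # Stand-in for NormalDistribution.sample_with_floor in tests.performance.data_generation.
    return max(int(random.gauss(mean, stddev)), floor)


faker = Faker()
num_unique_queries = 50
query_length_mean, query_length_stddev = 2000, 500

# The source turns a sampled character length into a Faker sentence count via // 30.
query_texts = [
    faker.paragraph(sample_with_floor(query_length_mean, query_length_stddev, 30) // 30)
    for _ in range(num_unique_queries)
]

# Each generated SELECT or operation then samples from the shared pool, allowing repeats.
text = random.choice(query_texts)
```

Sorting the generated queries by timestamp, as the updated perf test does, keeps the synthetic events closer to the time-ordered audit logs the extractor reads in production.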

View File

@ -2,41 +2,35 @@ import logging
import os
import random
from datetime import timedelta
from typing import Iterable, Tuple
import humanfriendly
import psutil
import pytest
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.bigquery_v2.bigquery_config import (
BigQueryUsageConfig,
BigQueryV2Config,
)
from datahub.ingestion.source.bigquery_v2.bigquery_report import (
BigQueryV2Report,
logger as report_logger,
)
from datahub.ingestion.source.bigquery_v2.bigquery_report import BigQueryV2Report
from datahub.ingestion.source.bigquery_v2.usage import BigQueryUsageExtractor
from datahub.utilities.perf_timer import PerfTimer
from tests.performance.bigquery import generate_events, ref_from_table
from tests.performance.data_generation import generate_data, generate_queries
pytestmark = pytest.mark.performance
from tests.performance.data_generation import (
NormalDistribution,
generate_data,
generate_queries,
)
@pytest.fixture(autouse=True)
def report_log_level_info(caplog):
with caplog.at_level(logging.INFO, logger=report_logger.name):
yield
def test_bigquery_usage(report_log_level_info):
def run_test():
report = BigQueryV2Report()
report.set_project_state("All", "Seed Data Generation")
report.set_ingestion_stage("All", "Seed Data Generation")
seed_metadata = generate_data(
num_containers=100,
num_tables=2500,
num_views=100,
time_range=timedelta(days=1),
num_containers=2000,
num_tables=20000,
num_views=2000,
time_range=timedelta(days=7),
)
all_tables = seed_metadata.tables + seed_metadata.views
@ -44,33 +38,64 @@ def test_bigquery_usage(report_log_level_info):
start_time=seed_metadata.start_time,
end_time=seed_metadata.end_time,
usage=BigQueryUsageConfig(include_top_n_queries=True, top_n_queries=10),
file_backed_cache_size=1000,
)
usage_extractor = BigQueryUsageExtractor(config, report)
report.set_project_state("All", "Event Generation")
report.set_ingestion_stage("All", "Event Generation")
num_projects = 5
num_projects = 100
projects = [f"project-{i}" for i in range(num_projects)]
table_to_project = {table.name: random.choice(projects) for table in all_tables}
table_refs = {str(ref_from_table(table, table_to_project)) for table in all_tables}
queries = generate_queries(
seed_metadata,
num_selects=30000,
num_operations=20000,
num_users=10,
queries = list(
generate_queries(
seed_metadata,
num_selects=240_000,
num_operations=800_000,
num_unique_queries=50_000,
num_users=2000,
query_length=NormalDistribution(2000, 500),
)
)
events = generate_events(queries, projects, table_to_project, config=config)
events = list(events)
queries.sort(key=lambda q: q.timestamp)
events = list(generate_events(queries, projects, table_to_project, config=config))
print(f"Events generated: {len(events)}")
pre_mem_usage = psutil.Process(os.getpid()).memory_info().rss
print(f"Test data size: {humanfriendly.format_size(pre_mem_usage)}")
report.set_project_state("All", "Event Ingestion")
report.set_ingestion_stage("All", "Event Ingestion")
with PerfTimer() as timer:
workunits = usage_extractor._run(events, table_refs)
num_workunits = sum(1 for _ in workunits)
report.set_project_state("All", "Done")
num_workunits, peak_memory_usage = workunit_sink(workunits)
report.set_ingestion_stage("All", "Done")
print(f"Workunits Generated: {num_workunits}")
print(f"Seconds Elapsed: {timer.elapsed_seconds():.2f} seconds")
print(
f"Memory Used: {humanfriendly.format_size(psutil.Process(os.getpid()).memory_info().rss)}"
f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}"
)
print(f"Disk Used: {report.usage_state_size}")
print(f"Hash collisions: {report.num_usage_query_hash_collisions}")
def workunit_sink(workunits: Iterable[MetadataWorkUnit]) -> Tuple[int, int]:
peak_memory_usage = psutil.Process(os.getpid()).memory_info().rss
i: int = 0
for i, wu in enumerate(workunits):
if i % 10_000 == 0:
peak_memory_usage = max(
peak_memory_usage, psutil.Process(os.getpid()).memory_info().rss
)
peak_memory_usage = max(
peak_memory_usage, psutil.Process(os.getpid()).memory_info().rss
)
return i, peak_memory_usage
if __name__ == "__main__":
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(logging.StreamHandler())
run_test()

View File

@ -396,8 +396,8 @@ def test_usage_counts_multiple_buckets_and_resources(
totalSqlQueries=4,
topSqlQueries=[
query_table_1_a().text,
query_table_1_b().text,
query_tables_1_and_2().text,
query_table_1_b().text,
],
uniqueUserCount=2,
userCounts=[
@ -471,7 +471,7 @@ def test_usage_counts_multiple_buckets_and_resources(
unit=BucketDuration.DAY, multiple=1
),
totalSqlQueries=2,
topSqlQueries=[query_table_2().text, query_tables_1_and_2().text],
topSqlQueries=[query_tables_1_and_2().text, query_table_2().text],
uniqueUserCount=1,
userCounts=[
DatasetUserUsageCountsClass(
@ -614,6 +614,7 @@ def test_operational_stats(
seed_metadata,
num_selects=10,
num_operations=20,
num_unique_queries=10,
num_users=3,
)
)