fix(ingest/bigquery): Increase batch size in metadata extraction if no partitioned table involved (#7252)

Tamas Nemeth 2023-02-17 11:49:47 +01:00 committed by GitHub
parent 751289c9e3
commit aa388f04c2
3 changed files with 95 additions and 31 deletions
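
In short: table metadata is still fetched from BigQuery in batches, but the batch may now grow to a larger default (500 tables) and is flushed early only once a smaller limit of partitioned tables (80) is reached, because it is the partitions system view that cannot be queried for too many tables at once. A minimal sketch of that two-threshold flush, with hypothetical names rather than the actual DataHub code:

def batch_tables(tables, batch_size=500, partitioned_batch_size=80):
    # Sketch only: illustrates the flush condition this commit introduces.
    batch, partitioned_in_batch = [], 0
    for table in tables:
        batch.append(table)
        if table.get("is_partitioned"):
            partitioned_in_batch += 1
        # Flush when the overall batch is full, or earlier once too many
        # partitioned tables have accumulated (the expensive case).
        if len(batch) == batch_size or partitioned_in_batch == partitioned_batch_size:
            yield batch
            batch, partitioned_in_batch = [], 0
    if batch:
        yield batch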


@@ -199,6 +199,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
self.config: BigQueryV2Config = config
self.report: BigQueryV2Report = BigQueryV2Report()
self.platform: str = "bigquery"
BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = (
self.config.sharded_table_pattern
)
@@ -725,6 +726,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
project_id=project_id,
dataset_name=dataset_name,
column_limit=self.config.column_limit,
run_optimized_column_query=self.config.run_optimized_column_query,
)
if self.config.include_tables:
@@ -736,7 +738,10 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
table_columns = columns.get(table.name, []) if columns else []
yield from self._process_table(
conn, table, table_columns, project_id, dataset_name
table=table,
columns=table_columns,
project_id=project_id,
dataset_name=dataset_name,
)
if self.config.include_views:
@@ -747,7 +752,10 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
for view in db_views[dataset_name]:
view_columns = columns.get(view.name, []) if columns else []
yield from self._process_view(
view, view_columns, project_id, dataset_name
view=view,
columns=view_columns,
project_id=project_id,
dataset_name=dataset_name,
)
# This method is used to generate the ignore list for datatypes the profiler doesn't support we have to do it here
@@ -764,13 +772,12 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
def _process_table(
self,
conn: bigquery.Client,
table: BigqueryTable,
columns: List[BigqueryColumn],
project_id: str,
schema_name: str,
dataset_name: str,
) -> Iterable[MetadataWorkUnit]:
table_identifier = BigqueryTableIdentifier(project_id, schema_name, table.name)
table_identifier = BigqueryTableIdentifier(project_id, dataset_name, table.name)
self.report.report_entity_scanned(table_identifier.raw_table_name())
@@ -805,7 +812,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
None,
)
yield from self.gen_table_dataset_workunits(
table, columns, project_id, schema_name
table, columns, project_id, dataset_name
)
def _process_view(
@@ -1128,7 +1135,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
# In bigquery there is no way to query all tables in a Project id
with PerfTimer() as timer:
bigquery_tables = []
table_count: int = 0
partitioned_table_count_in_this_batch: int = 0
table_items: Dict[str, TableListItem] = {}
# Dict to store sharded table and the last seen max shard id
sharded_tables: Dict[str, TableListItem] = defaultdict()
@@ -1161,35 +1168,42 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
# When table is only a shard we use dataset_name as table_name
sharded_tables[table_name] = table
continue
else:
stored_table_identifier = BigqueryTableIdentifier(
project_id=project_id,
dataset=dataset_name,
table=sharded_tables[table_name].table_id,
)
(
_,
stored_shard,
) = BigqueryTableIdentifier.get_table_and_shard(
stored_table_identifier.raw_table_name()
)
# When table is none, we use dataset_name as table_name
assert stored_shard
if stored_shard < shard:
sharded_tables[table_name] = table
continue
else:
table_count = table_count + 1
table_items[table.table_id] = table
if str(table_identifier).startswith(
stored_table_identifier = BigqueryTableIdentifier(
project_id=project_id,
dataset=dataset_name,
table=sharded_tables[table_name].table_id,
)
_, stored_shard = BigqueryTableIdentifier.get_table_and_shard(
stored_table_identifier.raw_table_name()
)
# When table is none, we use dataset_name as table_name
assert stored_shard
if stored_shard < shard:
sharded_tables[table_name] = table
continue
elif str(table_identifier).startswith(
self.config.temp_table_dataset_prefix
):
logger.debug(f"Dropping temporary table {table_identifier.table}")
self.report.report_dropped(table_identifier.raw_table_name())
continue
else:
if (
table.time_partitioning
or "range_partitioning" in table._properties
):
partitioned_table_count_in_this_batch += 1
if table_count % self.config.number_of_datasets_process_in_batch == 0:
table_items[table.table_id] = table
if (
len(table_items) % self.config.number_of_datasets_process_in_batch
== 0
) or (
partitioned_table_count_in_this_batch
== self.config.number_of_partitioned_datasets_process_in_batch
):
bigquery_tables.extend(
BigQueryDataDictionary.get_tables_for_dataset(
conn,
@@ -1199,6 +1213,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
with_data_read_permission=self.config.profiling.enabled,
)
)
partitioned_table_count_in_this_batch = 0
table_items.clear()
# Sharded tables don't have partition keys, so it is safe to add to the list as
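
The hunk above also keeps only the newest shard of a date-sharded table by comparing the stored shard suffix with the one just seen (stored_shard < shard). A self-contained sketch of that selection with hypothetical helper names, not the actual source code:

import re

_SHARD_RE = re.compile(r"^(?P<prefix>.+)_(?P<shard>\d{8})$")  # e.g. events_20230216

def keep_latest_shards(table_ids):
    # Sketch only: map each sharded-table prefix to its highest (latest) shard.
    latest = {}
    for table_id in table_ids:
        match = _SHARD_RE.match(table_id)
        if not match:
            continue  # not a sharded table
        prefix, shard = match.group("prefix"), match.group("shard")
        if prefix not in latest or latest[prefix] < shard:
            latest[prefix] = shard
    return [f"{prefix}_{shard}" for prefix, shard in latest.items()]

# keep_latest_shards(["events_20230215", "events_20230216"]) -> ["events_20230216"]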


@@ -78,9 +78,17 @@ class BigQueryV2Config(
)
number_of_datasets_process_in_batch: int = Field(
default=80,
hidden_from_schema=True,
default=500,
description="Number of table queried in batch when getting metadata. This is a low level config property which should be touched with care. This restriction is needed because we query partitions system view which throws error if we try to touch too many tables.",
)
number_of_partitioned_datasets_process_in_batch: int = Field(
hidden_from_schema=True,
default=80,
description="Number of partitioned table queried in batch when getting metadata. This is a low level config property which should be touched with care. This restriction is needed because we query partitions system view which throws error if we try to touch too many tables.",
)
column_limit: int = Field(
default=300,
description="Maximum number of columns to process in a table. This is a low level config property which should be touched with care. This restriction is needed because excessively wide tables can result in failure to ingest the schema.",
@@ -171,6 +179,12 @@ class BigQueryV2Config(
description="Useful for debugging lineage information. Set to True to see the raw lineage created internally.",
)
run_optimized_column_query: bool = Field(
hidden_from_schema=True,
default=False,
description="Run optimized column query to get column information. This is an experimental feature and may not work for all cases.",
)
def __init__(self, **data: Any):
super().__init__(**data)
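
For reference, the two batch sizes and the experimental column-query flag are ordinary config fields; a minimal sketch of overriding them in a source config dict (connection settings omitted, values are simply the defaults shown in this diff):

batch_tuning = {
    # Overall number of tables fetched per metadata batch.
    "number_of_datasets_process_in_batch": 500,
    # Early-flush limit once this many partitioned tables are in the batch.
    "number_of_partitioned_datasets_process_in_batch": 80,
    # Experimental optimized column query, disabled by default.
    "run_optimized_column_query": False,
}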


@@ -280,6 +280,34 @@ from
ORDER BY
table_catalog, table_schema, table_name, ordinal_position ASC, data_type DESC"""
optimized_columns_for_dataset: str = """
select * from
(select
c.table_catalog as table_catalog,
c.table_schema as table_schema,
c.table_name as table_name,
c.column_name as column_name,
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
CASE WHEN CONTAINS_SUBSTR(field_path, ".") THEN NULL ELSE c.data_type END as data_type,
description as comment,
c.is_hidden as is_hidden,
c.is_partitioning_column as is_partitioning_column,
-- We count the columns to be able limit it later
row_number() over (partition by c.table_catalog, c.table_schema, c.table_name order by c.ordinal_position asc, c.data_type DESC) as column_num,
-- Getting the maximum shard for each table
row_number() over (partition by c.table_catalog, c.table_schema, ifnull(REGEXP_EXTRACT(c.table_name, r'(.*)_\\d{{8}}$'), c.table_name), cfp.field_path order by c.table_catalog, c.table_schema asc, c.table_name desc) as shard_num
from
`{project_id}`.`{dataset_name}`.INFORMATION_SCHEMA.COLUMNS c
join `{project_id}`.`{dataset_name}`.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS as cfp on cfp.table_name = c.table_name
and cfp.column_name = c.column_name
)
-- We filter column limit + 1 to make sure we warn about the limit being reached but not reading too much data
where column_num <= {column_limit} and shard_num = 1
ORDER BY
table_catalog, table_schema, table_name, ordinal_position, column_num ASC, table_name, data_type DESC"""
columns_for_table: str = """
select
c.table_catalog as table_catalog,
@@ -456,7 +484,8 @@ class BigQueryDataDictionary:
conn: bigquery.Client,
project_id: str,
dataset_name: str,
column_limit: Optional[int] = None,
column_limit: int,
run_optimized_column_query: bool = False,
) -> Optional[Dict[str, List[BigqueryColumn]]]:
columns: Dict[str, List[BigqueryColumn]] = defaultdict(list)
try:
@@ -464,6 +493,12 @@ class BigQueryDataDictionary:
conn,
BigqueryQuery.columns_for_dataset.format(
project_id=project_id, dataset_name=dataset_name
)
if not run_optimized_column_query
else BigqueryQuery.optimized_columns_for_dataset.format(
project_id=project_id,
dataset_name=dataset_name,
column_limit=column_limit,
),
)
except Exception as e:
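
The changed call above picks between the two query templates and fills them with str.format; the doubled braces in the shard regex (\\d{{8}}) are there so that formatting leaves a literal {8} in the SQL. A small sketch of that selection (assuming BigqueryQuery from the module in this diff is importable; project and dataset values are hypothetical):

run_optimized_column_query = True  # hypothetical flag value
if run_optimized_column_query:
    query = BigqueryQuery.optimized_columns_for_dataset.format(
        project_id="my-project",
        dataset_name="my_dataset",
        column_limit=300,
    )
else:
    query = BigqueryQuery.columns_for_dataset.format(
        project_id="my-project", dataset_name="my_dataset"
    )
# After .format, the shard regex in the optimized query reads r'(.*)_\d{8}$'.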