fix(ingest): bigquery-beta - turning sql parsing off in lineage extraction (#6163)

Co-authored-by: Shirshanka Das <shirshanka@apache.org>
Tamas Nemeth 2022-10-11 05:30:29 +02:00 committed by GitHub
parent d569734193
commit 128e3a8970
8 changed files with 300 additions and 92 deletions

View File

@ -540,6 +540,9 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
)
return
self.report.num_project_datasets_to_scan[project_id] = len(
bigquery_project.datasets
)
for bigquery_dataset in bigquery_project.datasets:
if not self.config.dataset_pattern.allowed(bigquery_dataset.name):
@ -619,7 +622,9 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
self.report.report_dropped(table_identifier.raw_table_name())
return
table.columns = self.get_columns_for_table(conn, table_identifier)
table.columns = self.get_columns_for_table(
conn, table_identifier, self.config.column_limit
)
if not table.columns:
logger.warning(f"Unable to get columns for table: {table_identifier}")
@ -653,7 +658,9 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
self.report.report_dropped(table_identifier.raw_table_name())
return
view.columns = self.get_columns_for_table(conn, table_identifier)
view.columns = self.get_columns_for_table(
conn, table_identifier, column_limit=self.config.column_limit
)
lineage_info: Optional[Tuple[UpstreamLineage, Dict[str, str]]] = None
if self.config.include_table_lineage:
@ -877,8 +884,8 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
_COMPLEX_TYPE = re.compile("^(struct|array)")
last_id = -1
for col in columns:
if _COMPLEX_TYPE.match(col.data_type.lower()):
# If col.data_type is empty, this column is part of a complex type
if col.data_type is None or _COMPLEX_TYPE.match(col.data_type.lower()):
# If we have already seen this ordinal position, we most probably already processed this complex type
if last_id != col.ordinal_position:
schema_fields.extend(
@ -1099,7 +1106,10 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
return views.get(dataset_name, [])
def get_columns_for_table(
self, conn: bigquery.Client, table_identifier: BigqueryTableIdentifier
self,
conn: bigquery.Client,
table_identifier: BigqueryTableIdentifier,
column_limit: Optional[int] = None,
) -> List[BigqueryColumn]:
if (
@ -1110,6 +1120,7 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
conn,
project_id=table_identifier.project_id,
dataset_name=table_identifier.dataset,
column_limit=column_limit,
)
self.schema_columns[
(table_identifier.project_id, table_identifier.dataset)
@ -1125,7 +1136,9 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
logger.warning(
f"Couldn't get columns on the dataset level for {table_identifier}. Trying to get on table level..."
)
return BigQueryDataDictionary.get_columns_for_table(conn, table_identifier)
return BigQueryDataDictionary.get_columns_for_table(
conn, table_identifier, self.config.column_limit
)
# Access to table but none of its columns - is this possible ?
return columns.get(table_identifier.table, [])
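A minimal standalone sketch (not the source's exact code) of the complex-type check above: the reworked column query returns NULL as the data_type for nested sub-fields, so a missing data_type is treated the same as an explicit struct/array prefix.

import re

_COMPLEX_TYPE = re.compile("^(struct|array)")

def is_complex_type(data_type):
    # Nested fields such as "address.city" (hypothetical) now arrive with data_type=None,
    # because the modified INFORMATION_SCHEMA query blanks the type for sub-fields.
    return data_type is None or bool(_COMPLEX_TYPE.match(data_type.lower()))

assert is_complex_type(None)
assert is_complex_type("STRUCT<city STRING>")
assert not is_complex_type("STRING")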

View File

@ -121,7 +121,7 @@ class BigqueryTableIdentifier:
]
if invalid_chars_in_table_name:
raise ValueError(
f"Cannot handle {self} - poorly formatted table name, contains {invalid_chars_in_table_name}"
f"Cannot handle {self.raw_table_name()} - poorly formatted table name, contains {invalid_chars_in_table_name}"
)
return table_name
@ -207,6 +207,7 @@ class QueryEvent:
actor_email: str
query: str
statementType: str
project_id: str
job_name: Optional[str] = None
destinationTable: Optional[BigQueryTableRef] = None
@ -238,6 +239,15 @@ class QueryEvent:
return f"projects/{project}/jobs/{jobId}"
return None
@staticmethod
def _get_project_id_from_job_name(job_name: str) -> str:
project_id_pattern = r"projects\/(.*)\/jobs\/.*"
matches = re.match(project_id_pattern, job_name, re.MULTILINE)
if matches:
return matches.group(1)
else:
raise ValueError(f"Unable to get project_id from job name: {job_name}")
@classmethod
def from_entry(
cls, entry: AuditLogEntry, debug_include_full_payloads: bool = False
@ -253,6 +263,7 @@ class QueryEvent:
job.get("jobName", {}).get("projectId"),
job.get("jobName", {}).get("jobId"),
),
project_id=job.get("jobName", {}).get("projectId"),
default_dataset=job_query_conf["defaultDataset"]
if job_query_conf["defaultDataset"]
else None,
@ -331,6 +342,7 @@ class QueryEvent:
actor_email=payload["authenticationInfo"]["principalEmail"],
query=query_config["query"],
job_name=job["jobName"],
project_id=QueryEvent._get_project_id_from_job_name(job["jobName"]),
default_dataset=query_config["defaultDataset"]
if query_config.get("defaultDataset")
else None,
@ -392,6 +404,7 @@ class QueryEvent:
# basic query_event
query_event = QueryEvent(
job_name=job["jobName"],
project_id=QueryEvent._get_project_id_from_job_name(job["jobName"]),
timestamp=row.timestamp,
actor_email=payload["authenticationInfo"]["principalEmail"],
query=query_config["query"],
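For illustration, a standalone version of the job-name parsing added to QueryEvent, assuming audit-log job names of the form projects/<project>/jobs/<job_id>; the job name below is hypothetical.

import re

def project_id_from_job_name(job_name: str) -> str:
    # Same idea as QueryEvent._get_project_id_from_job_name: the path segment
    # after "projects/" identifies the project that ran the job.
    match = re.match(r"projects\/(.*)\/jobs\/.*", job_name)
    if match:
        return match.group(1)
    raise ValueError(f"Unable to get project_id from job name: {job_name}")

print(project_id_from_job_name("projects/my-gcp-project/jobs/bquxjob_123"))  # my-gcp-project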

View File

@ -32,6 +32,7 @@ class BigQueryV2Config(BigQueryConfig):
usage: BigQueryUsageConfig = Field(
default=BigQueryUsageConfig(), description="Usage related configs"
)
include_usage_statistics: bool = Field(
default=True,
description="Generate usage statistics",
@ -56,7 +57,10 @@ class BigQueryV2Config(BigQueryConfig):
default=50,
description="Number of tables queried in a batch when getting metadata. This is a low-level config property that should be touched with care. This restriction is needed because we query the partitions system view, which throws an error if we try to touch too many tables.",
)
column_limit: int = Field(
default=1000,
description="Maximum number of columns to process in a table",
)
# The inheritance hierarchy is wonky here, but these options need modifications.
project_id: Optional[str] = Field(
default=None,
@ -64,6 +68,11 @@ class BigQueryV2Config(BigQueryConfig):
)
storage_project_id: None = Field(default=None, exclude=True)
lineage_use_sql_parser: bool = Field(
default=False,
description="Experimental. Use the SQL parser to resolve view/table lineage. When a view is referenced, BigQuery reports both the view and its underlying tables in the references, with no distinction between direct and base objects accessed, so SQL parsing is used to keep only the directly accessed objects for lineage.",
)
@root_validator(pre=False)
def profile_default_settings(cls, values: Dict) -> Dict:
# Extra default SQLAlchemy option for better connection pooling and threading.
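An illustrative (not authoritative) view of the options added or touched here, written as the Python dict a recipe's config block would be parsed into; the project id is hypothetical and the values shown are the new defaults.

bigquery_source_config = {
    "project_id": "my-gcp-project",    # hypothetical project
    "include_usage_statistics": True,  # generate usage statistics
    "column_limit": 1000,              # maximum number of columns processed per table
    "lineage_use_sql_parser": False,   # SQL-parser-based lineage pruning is now opt-in
}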

View File

@ -8,27 +8,40 @@ import pydantic
from datahub.ingestion.source.sql.sql_common import SQLSourceReport
from datahub.utilities.lossy_collections import LossyDict, LossyList
from datahub.utilities.stats_collections import TopKDict
@dataclass
class BigQueryV2Report(SQLSourceReport):
num_total_lineage_entries: Optional[int] = None
num_skipped_lineage_entries_missing_data: Optional[int] = None
num_skipped_lineage_entries_not_allowed: Optional[int] = None
num_lineage_entries_sql_parser_failure: Optional[int] = None
num_skipped_lineage_entries_other: Optional[int] = None
num_total_log_entries: Optional[int] = None
num_parsed_log_entires: Optional[int] = None
num_total_audit_entries: Optional[int] = None
num_parsed_audit_entires: Optional[int] = None
num_total_lineage_entries: TopKDict[str, int] = field(default_factory=TopKDict)
num_skipped_lineage_entries_missing_data: TopKDict[str, int] = field(
default_factory=TopKDict
)
num_skipped_lineage_entries_not_allowed: TopKDict[str, int] = field(
default_factory=TopKDict
)
num_lineage_entries_sql_parser_failure: TopKDict[str, int] = field(
default_factory=TopKDict
)
num_lineage_entries_sql_parser_success: TopKDict[str, int] = field(
default_factory=TopKDict
)
num_skipped_lineage_entries_other: TopKDict[str, int] = field(
default_factory=TopKDict
)
num_total_log_entries: TopKDict[str, int] = field(default_factory=TopKDict)
num_parsed_log_entries: TopKDict[str, int] = field(default_factory=TopKDict)
num_total_audit_entries: TopKDict[str, int] = field(default_factory=TopKDict)
num_parsed_audit_entries: TopKDict[str, int] = field(default_factory=TopKDict)
bigquery_audit_metadata_datasets_missing: Optional[bool] = None
lineage_failed_extraction: LossyList[str] = field(default_factory=LossyList)
lineage_metadata_entries: Optional[int] = None
lineage_mem_size: Optional[str] = None
lineage_extraction_sec: Dict[str, float] = field(default_factory=dict)
usage_extraction_sec: Dict[str, float] = field(default_factory=dict)
lineage_metadata_entries: TopKDict[str, int] = field(default_factory=TopKDict)
lineage_mem_size: Dict[str, str] = field(default_factory=TopKDict)
lineage_extraction_sec: Dict[str, float] = field(default_factory=TopKDict)
usage_extraction_sec: Dict[str, float] = field(default_factory=TopKDict)
usage_failed_extraction: LossyList[str] = field(default_factory=LossyList)
metadata_extraction_sec: Dict[str, float] = field(default_factory=dict)
num_project_datasets_to_scan: Dict[str, int] = field(default_factory=TopKDict)
metadata_extraction_sec: Dict[str, float] = field(default_factory=TopKDict)
include_table_lineage: Optional[bool] = None
use_date_sharded_audit_log_tables: Optional[bool] = None
log_page_size: Optional[pydantic.PositiveInt] = None
@ -40,10 +53,10 @@ class BigQueryV2Report(SQLSourceReport):
audit_start_time: Optional[str] = None
audit_end_time: Optional[str] = None
upstream_lineage: LossyDict = field(default_factory=LossyDict)
partition_info: Dict[str, str] = field(default_factory=dict)
profile_table_selection_criteria: Dict[str, str] = field(default_factory=dict)
selected_profile_tables: Dict[str, List[str]] = field(default_factory=dict)
invalid_partition_ids: Dict[str, str] = field(default_factory=dict)
partition_info: Dict[str, str] = field(default_factory=TopKDict)
profile_table_selection_criteria: Dict[str, str] = field(default_factory=TopKDict)
selected_profile_tables: Dict[str, List[str]] = field(default_factory=TopKDict)
invalid_partition_ids: Dict[str, str] = field(default_factory=TopKDict)
allow_pattern: Optional[str] = None
deny_pattern: Optional[str] = None
num_usage_workunits_emitted: Optional[int] = None
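A small sketch of the reporting change: counters that used to be single integers are now dictionaries keyed by project id (TopKDict), so multi-project runs keep separate totals. Project names below are hypothetical.

from datahub.utilities.stats_collections import TopKDict

num_total_log_entries: TopKDict = TopKDict()
for project in ["proj-a", "proj-a", "proj-b"]:
    # Mirrors how the lineage code bumps a per-project counter.
    num_total_log_entries[project] = num_total_log_entries.get(project, 0) + 1

print(num_total_log_entries)  # {'proj-a': 2, 'proj-b': 1}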

View File

@ -227,7 +227,7 @@ select
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
CASE WHEN CONTAINS_SUBSTR(field_path, ".") THEN NULL ELSE c.data_type END as data_type,
description as comment,
c.is_hidden as is_hidden,
c.is_partitioning_column as is_partitioning_column
@ -236,7 +236,7 @@ from
join `{project_id}`.`{dataset_name}`.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS as cfp on cfp.table_name = c.table_name
and cfp.column_name = c.column_name
ORDER BY
ordinal_position"""
table_catalog, table_schema, table_name, ordinal_position ASC, data_type DESC"""
columns_for_table: str = """
select
@ -247,7 +247,7 @@ select
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
CASE WHEN CONTAINS_SUBSTR(field_path, ".") THEN NULL ELSE c.data_type END as data_type,
c.is_hidden as is_hidden,
c.is_partitioning_column as is_partitioning_column,
description as comment
@ -258,7 +258,7 @@ from
where
c.table_name = '{table_identifier.table}'
ORDER BY
ordinal_position"""
table_catalog, table_schema, table_name, ordinal_position ASC, data_type DESC"""
class BigQueryDataDictionary:
@ -419,7 +419,10 @@ class BigQueryDataDictionary:
@staticmethod
def get_columns_for_dataset(
conn: bigquery.Client, project_id: str, dataset_name: str
conn: bigquery.Client,
project_id: str,
dataset_name: str,
column_limit: Optional[int] = None,
) -> Optional[Dict[str, List[BigqueryColumn]]]:
columns: Dict[str, List[BigqueryColumn]] = defaultdict(list)
try:
@ -435,24 +438,38 @@ class BigQueryDataDictionary:
# Please repeat query with more selective predicates.
return None
last_seen_table: str = ""
for column in cur:
columns[column.table_name].append(
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
field_path=column.field_path,
is_nullable=column.is_nullable == "YES",
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
if (
column_limit
and column.table_name in columns
and len(columns[column.table_name]) >= column_limit
):
if last_seen_table != column.table_name:
logger.warning(
f"{project_id}.{dataset_name}.{column.table_name} contains more than {column_limit} columns, only processing {column_limit} columns"
)
last_seen_table = column.table_name
else:
columns[column.table_name].append(
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
field_path=column.field_path,
is_nullable=column.is_nullable == "YES",
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
)
)
)
return columns
@staticmethod
def get_columns_for_table(
conn: bigquery.Client, table_identifier: BigqueryTableIdentifier
conn: bigquery.Client,
table_identifier: BigqueryTableIdentifier,
column_limit: Optional[int],
) -> List[BigqueryColumn]:
cur = BigQueryDataDictionary.get_query_result(
@ -460,15 +477,31 @@ class BigQueryDataDictionary:
BigqueryQuery.columns_for_table.format(table_identifier=table_identifier),
)
return [
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
is_nullable=column.is_nullable == "YES",
field_path=column.field_path,
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
)
for column in cur
]
columns: List[BigqueryColumn] = []
last_seen_table: str = ""
for column in cur:
if column_limit and len(columns) >= column_limit:
if last_seen_table != column.table_name:
logger.warning(
f"{table_identifier.project_id}.{table_identifier.dataset}.{column.table_name} contains more than {column_limit} columns, only processing {column_limit} columns"
)
last_seen_table = column.table_name
else:
columns.append(
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
is_nullable=column.is_nullable == "YES",
field_path=column.field_path,
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
)
)
last_seen_table = column.table_name
return columns
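A compact sketch of the truncation rule implemented above, with a hypothetical cap_columns helper: once column_limit entries have been collected for a table, further rows are dropped and a single warning is logged per table.

import logging
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

logger = logging.getLogger(__name__)

def cap_columns(
    rows: Iterable[Tuple[str, str]],  # hypothetical (table_name, column_name) pairs
    column_limit: Optional[int],
) -> Dict[str, List[str]]:
    columns: Dict[str, List[str]] = defaultdict(list)
    warned = set()
    for table_name, column_name in rows:
        if column_limit and len(columns[table_name]) >= column_limit:
            if table_name not in warned:
                logger.warning(
                    f"{table_name} contains more than {column_limit} columns, "
                    f"only processing {column_limit} columns"
                )
                warned.add(table_name)
            continue
        columns[table_name].append(column_name)
    return columns

print(dict(cap_columns([("t", "a"), ("t", "b"), ("t", "c")], column_limit=2)))  # {'t': ['a', 'b']}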

View File

@ -1,6 +1,5 @@
import collections
import logging
import sys
import textwrap
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
@ -30,6 +29,7 @@ from datahub.metadata.schema_classes import (
UpstreamClass,
UpstreamLineageClass,
)
from datahub.utilities import memory_footprint
from datahub.utilities.bigquery_sql_parser import BigQuerySQLParser
from datahub.utilities.perf_timer import PerfTimer
@ -145,7 +145,7 @@ timestamp < "{end_time}"
return textwrap.dedent(query)
def compute_bigquery_lineage_via_gcp_logging(
self, project_id: Optional[str]
self, project_id: str
) -> Dict[str, Set[str]]:
logger.info(f"Populating lineage info via GCP audit logs for {project_id}")
try:
@ -154,6 +154,7 @@ timestamp < "{end_time}"
log_entries: Iterable[AuditLogEntry] = self._get_bigquery_log_entries(
clients
)
logger.info("Log Entries loaded")
parsed_entries: Iterable[QueryEvent] = self._parse_bigquery_log_entries(
log_entries
)
@ -193,7 +194,7 @@ timestamp < "{end_time}"
def _get_bigquery_log_entries(
self, client: GCPLoggingClient, limit: Optional[int] = None
) -> Union[Iterable[AuditLogEntry], Iterable[BigQueryAuditMetadata]]:
self.report.num_total_log_entries = 0
self.report.num_total_log_entries[client.project] = 0
# Add a buffer to start and end time to account for delays in logging events.
start_time = (self.config.start_time - self.config.max_query_duration).strftime(
BQ_DATETIME_FORMAT
@ -225,12 +226,22 @@ timestamp < "{end_time}"
entries = client.list_entries(
filter_=filter, page_size=self.config.log_page_size, max_results=limit
)
logger.info(
f"Start iterating over log entries from BigQuery for {client.project}"
)
for entry in entries:
self.report.num_total_log_entries += 1
self.report.num_total_log_entries[client.project] += 1
if self.report.num_total_log_entries[client.project] % 1000 == 0:
logger.info(
f"{self.report.num_total_log_entries[client.project]} log entries loaded for project {client.project} so far..."
)
yield entry
logger.info(
f"Finished loading {self.report.num_total_log_entries} log entries from BigQuery project {client.project} so far"
f"Finished loading {self.report.num_total_log_entries[client.project]} log entries from BigQuery project {client.project}"
)
def _get_exported_bigquery_audit_metadata(
@ -294,7 +305,6 @@ timestamp < "{end_time}"
self,
entries: Union[Iterable[AuditLogEntry], Iterable[BigQueryAuditMetadata]],
) -> Iterable[QueryEvent]:
self.report.num_parsed_log_entires = 0
for entry in entries:
event: Optional[QueryEvent] = None
@ -318,21 +328,15 @@ timestamp < "{end_time}"
f"Unable to parse log missing {missing_entry}, missing v2 {missing_entry_v2} for {entry}",
)
else:
self.report.num_parsed_log_entires += 1
self.report.num_parsed_log_entries[event.project_id] = (
self.report.num_parsed_log_entries.get(event.project_id, 0) + 1
)
yield event
logger.info(
"Parsing BigQuery log entries: "
f"number of log entries successfully parsed={self.report.num_parsed_log_entires}"
)
def _parse_exported_bigquery_audit_metadata(
self, audit_metadata_rows: Iterable[BigQueryAuditMetadata]
) -> Iterable[QueryEvent]:
self.report.num_total_audit_entries = 0
self.report.num_parsed_audit_entires = 0
for audit_metadata in audit_metadata_rows:
self.report.num_total_audit_entries += 1
event: Optional[QueryEvent] = None
missing_exported_audit = (
@ -353,34 +357,62 @@ timestamp < "{end_time}"
f"Unable to parse audit metadata missing {missing_exported_audit} for {audit_metadata}",
)
else:
self.report.num_parsed_audit_entires += 1
self.report.num_parsed_audit_entries[event.project_id] = (
self.report.num_parsed_audit_entries.get(event.project_id, 0) + 1
)
self.report.num_total_audit_entries[event.project_id] = (
self.report.num_total_audit_entries.get(event.project_id, 0) + 1
)
yield event
def _create_lineage_map(self, entries: Iterable[QueryEvent]) -> Dict[str, Set[str]]:
logger.info("Entering create lineage map function")
lineage_map: Dict[str, Set[str]] = collections.defaultdict(set)
self.report.num_total_lineage_entries = 0
self.report.num_skipped_lineage_entries_missing_data = 0
self.report.num_skipped_lineage_entries_not_allowed = 0
self.report.num_skipped_lineage_entries_other = 0
self.report.num_lineage_entries_sql_parser_failure = 0
for e in entries:
self.report.num_total_lineage_entries += 1
self.report.num_total_lineage_entries[e.project_id] = (
self.report.num_total_lineage_entries.get(e.project_id, 0) + 1
)
if e.destinationTable is None or not (
e.referencedTables or e.referencedViews
):
self.report.num_skipped_lineage_entries_missing_data += 1
self.report.num_skipped_lineage_entries_missing_data[e.project_id] = (
self.report.num_skipped_lineage_entries_missing_data.get(
e.project_id, 0
)
+ 1
)
continue
# Skip if schema/table pattern don't allow the destination table
destination_table_str = str(e.destinationTable.get_sanitized_table_ref())
destination_table_str_parts = destination_table_str.split("/")
try:
destination_table = e.destinationTable.get_sanitized_table_ref()
except Exception:
self.report.num_skipped_lineage_entries_missing_data[e.project_id] = (
self.report.num_skipped_lineage_entries_missing_data.get(
e.project_id, 0
)
+ 1
)
continue
destination_table_str = destination_table.table_identifier.get_table_name()
if not self.config.dataset_pattern.allowed(
destination_table_str_parts[3]
) or not self.config.table_pattern.allowed(destination_table_str_parts[-1]):
self.report.num_skipped_lineage_entries_not_allowed += 1
destination_table.table_identifier.dataset
) or not self.config.table_pattern.allowed(
destination_table.table_identifier.get_table_name()
):
self.report.num_skipped_lineage_entries_not_allowed[e.project_id] = (
self.report.num_skipped_lineage_entries_not_allowed.get(
e.project_id, 0
)
+ 1
)
continue
has_table = False
for ref_table in e.referencedTables:
ref_table_str = str(ref_table.get_sanitized_table_ref())
ref_table_str = (
ref_table.get_sanitized_table_ref().table_identifier.get_table_name()
)
if ref_table_str != destination_table_str:
lineage_map[destination_table_str].add(ref_table_str)
has_table = True
@ -390,7 +422,7 @@ timestamp < "{end_time}"
if ref_view_str != destination_table_str:
lineage_map[destination_table_str].add(ref_view_str)
has_view = True
if has_table and has_view:
if self.config.lineage_use_sql_parser and has_table and has_view:
# If there is a view being referenced then bigquery sends both the view as well as underlying table
# in the references. There is no distinction between direct/base objects accessed. So doing sql parsing
# to ensure we only use direct objects accessed for lineage
@ -399,11 +431,22 @@ timestamp < "{end_time}"
referenced_objs = set(
map(lambda x: x.split(".")[-1], parser.get_tables())
)
self.report.num_lineage_entries_sql_parser_success[e.project_id] = (
self.report.num_lineage_entries_sql_parser_success.get(
e.project_id, 0
)
+ 1
)
except Exception as ex:
logger.debug(
f"Sql Parser failed on query: {e.query}. It won't cause any issue except table/view lineage can't be detected reliably. The error was {ex}."
)
self.report.num_lineage_entries_sql_parser_failure += 1
self.report.num_lineage_entries_sql_parser_failure[e.project_id] = (
self.report.num_lineage_entries_sql_parser_failure.get(
e.project_id, 0
)
+ 1
)
continue
curr_lineage_str = lineage_map[destination_table_str]
new_lineage_str = set()
@ -413,12 +456,15 @@ timestamp < "{end_time}"
new_lineage_str.add(lineage_str)
lineage_map[destination_table_str] = new_lineage_str
if not (has_table or has_view):
self.report.num_skipped_lineage_entries_other += 1
self.report.num_skipped_lineage_entries_other[e.project_id] = (
self.report.num_skipped_lineage_entries_other.get(e.project_id, 0)
+ 1
)
logger.info("Exiting create lineage map function")
return lineage_map
def _compute_bigquery_lineage(
self, project_id: Optional[str] = None
) -> Dict[str, Set[str]]:
def _compute_bigquery_lineage(self, project_id: str) -> Dict[str, Set[str]]:
lineage_extractor: BigqueryLineageExtractor = BigqueryLineageExtractor(
config=self.config, report=self.report
)
@ -448,10 +494,10 @@ timestamp < "{end_time}"
if lineage_metadata is None:
lineage_metadata = {}
self.report.lineage_mem_size = humanfriendly.format_size(
sys.getsizeof(lineage_metadata)
self.report.lineage_mem_size[project_id] = humanfriendly.format_size(
memory_footprint.total_size(lineage_metadata)
)
self.report.lineage_metadata_entries = len(lineage_metadata)
self.report.lineage_metadata_entries[project_id] = len(lineage_metadata)
logger.info(f"Built lineage map containing {len(lineage_metadata)} entries.")
logger.debug(f"lineage metadata is {lineage_metadata}")
return lineage_metadata
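The crux of the change, restated as a runnable sketch with hypothetical names: SQL-parser-based pruning of references only happens when lineage_use_sql_parser is enabled and an entry references both tables and views; with the new default (off), the referenced objects from the audit log are used as-is.

from typing import Set

def direct_upstreams(
    referenced_tables: Set[str],
    referenced_views: Set[str],
    parsed_tables: Set[str],
    lineage_use_sql_parser: bool,
) -> Set[str]:
    # When a view is referenced, BigQuery also lists the view's base tables,
    # so with the parser enabled only objects that appear in the parsed query are kept.
    referenced = referenced_tables | referenced_views
    if lineage_use_sql_parser and referenced_tables and referenced_views:
        return {obj for obj in referenced if obj in parsed_tables}
    return referenced

print(direct_upstreams({"base_table"}, {"my_view"}, {"my_view"}, lineage_use_sql_parser=False))  # both objects
print(direct_upstreams({"base_table"}, {"my_view"}, {"my_view"}, lineage_use_sql_parser=True))   # {'my_view'}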

View File

@ -0,0 +1,45 @@
from collections import deque
from itertools import chain
from sys import getsizeof
from typing import Any, Dict
def total_size(o: Any, handlers: Any = {}) -> int:
"""Returns the approximate memory footprint of an object and all of its contents.
Automatically finds the contents of the following builtin containers and
their subclasses: tuple, list, deque, dict, set and frozenset.
To search other containers, add handlers to iterate over their contents:
handlers = {SomeContainerClass: iter,
OtherContainerClass: OtherContainerClass.get_elements}
Based on https://github.com/ActiveState/recipe-577504-compute-mem-footprint/blob/master/recipe.py
"""
def dict_handler(d: Dict) -> chain[Any]:
return chain.from_iterable(d.items())
all_handlers = {
tuple: iter,
list: iter,
deque: iter,
dict: dict_handler,
set: iter,
frozenset: iter,
}
all_handlers.update(handlers) # user handlers take precedence
seen = set()  # track which object ids have already been seen
default_size = getsizeof(0) # estimate sizeof object without __sizeof__
def sizeof(o: Any) -> int:
if id(o) in seen: # do not double count the same object
return 0
seen.add(id(o))
s = getsizeof(o, default_size)
for typ, handler in all_handlers.items():
if isinstance(o, typ):
s += sum(map(sizeof, handler(o))) # type: ignore
break
return s
return sizeof(o)
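A quick usage sketch contrasting total_size with sys.getsizeof, which it replaces for the lineage-map size reported above; the lineage map shown is hypothetical.

import sys
from datahub.utilities.memory_footprint import total_size

lineage_map = {"proj.ds.dest": {"proj.ds.upstream1", "proj.ds.upstream2"}}

print(sys.getsizeof(lineage_map))  # size of the outer dict object only
print(total_size(lineage_map))     # also counts the keys, the set and its strings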

View File

@ -0,0 +1,36 @@
from typing import Any, Dict, TypeVar, Union
T = TypeVar("T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
class TopKDict(Dict[_KT, _VT]):
"""A structure that only prints the top K items from the dictionary. Not lossy."""
def __init__(self, top_k: int = 10) -> None:
super().__init__()
self.top_k = top_k
def __repr__(self) -> str:
return repr(self.as_obj())
def __str__(self) -> str:
return self.__repr__()
@staticmethod
def _trim_dictionary(big_dict: Dict[str, Any], top_k: int = 10) -> Dict[str, Any]:
if big_dict is not None and len(big_dict) > top_k:
dict_as_tuples = [(k, v) for k, v in big_dict.items()]
sorted_tuples = sorted(dict_as_tuples, key=lambda x: x[1], reverse=True)
dict_as_tuples = sorted_tuples[:top_k]
trimmed_dict = {k: v for k, v in dict_as_tuples}
trimmed_dict[f"... top({top_k}) of total {len(big_dict)} entries"] = ""
print(f"Dropping entries {sorted_tuples[top_k:]}")
return trimmed_dict
return big_dict
def as_obj(self) -> Dict[Union[_KT, str], Union[_VT, str]]:
base_dict: Dict[Union[_KT, str], Union[_VT, str]] = super().copy() # type: ignore
return self._trim_dictionary(base_dict, self.top_k)  # type: ignore
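Usage sketch for TopKDict: it stores every entry, but its repr only renders the K largest values plus a summary entry; keys below are hypothetical.

from datahub.utilities.stats_collections import TopKDict

counts = TopKDict()  # defaults to showing the top 10 values
for i in range(25):
    counts[f"project-{i}"] = i

print(len(counts))  # 25 - nothing is dropped
print(counts)       # the 10 largest entries plus a "... top(10) of total 25 entries" marker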