feat(ingest): Add memory leak detection capability to the datahub cli command. (#4363)

This commit is contained in:
Ravindra Lanka 2022-03-09 17:08:44 -08:00 committed by GitHub
parent 6da3e28c33
commit dc62feb1e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 2 deletions

View File

@ -20,6 +20,7 @@ from datahub.configuration import SensitiveError
from datahub.configuration.config_loader import load_config_file from datahub.configuration.config_loader import load_config_file
from datahub.ingestion.run.pipeline import Pipeline from datahub.ingestion.run.pipeline import Pipeline
from datahub.telemetry import telemetry from datahub.telemetry import telemetry
from datahub.utilities import memory_leak_detector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,8 +64,12 @@ def ingest() -> None:
default=False, default=False,
help="If enabled, ingestion runs with warnings will yield a non-zero error code", help="If enabled, ingestion runs with warnings will yield a non-zero error code",
) )
@click.pass_context
@telemetry.with_telemetry @telemetry.with_telemetry
def run(config: str, dry_run: bool, preview: bool, strict_warnings: bool) -> None: @memory_leak_detector.with_leak_detection
def run(
ctx: click.Context, config: str, dry_run: bool, preview: bool, strict_warnings: bool
) -> None:
"""Ingest metadata into DataHub.""" """Ingest metadata into DataHub."""
logger.info("DataHub CLI version: %s", datahub_package.nice_version_name()) logger.info("DataHub CLI version: %s", datahub_package.nice_version_name())

View File

@ -48,7 +48,16 @@ MAX_CONTENT_WIDTH = 120
version=datahub_package.nice_version_name(), version=datahub_package.nice_version_name(),
prog_name=datahub_package.__package_name__, prog_name=datahub_package.__package_name__,
) )
def datahub(debug: bool) -> None: @click.option(
"-dl",
"--detect-memory-leaks",
type=bool,
is_flag=True,
default=False,
help="Run memory leak detection.",
)
@click.pass_context
def datahub(ctx: click.Context, debug: bool, detect_memory_leaks: bool) -> None:
# Insulate 'datahub' and all child loggers from inadvertent changes to the # Insulate 'datahub' and all child loggers from inadvertent changes to the
# root logger by the external site packages that we import. # root logger by the external site packages that we import.
# (Eg: https://github.com/reata/sqllineage/commit/2df027c77ea0a8ea4909e471dcd1ecbf4b8aeb2f#diff-30685ea717322cd1e79c33ed8d37903eea388e1750aa00833c33c0c5b89448b3R11 # (Eg: https://github.com/reata/sqllineage/commit/2df027c77ea0a8ea4909e471dcd1ecbf4b8aeb2f#diff-30685ea717322cd1e79c33ed8d37903eea388e1750aa00833c33c0c5b89448b3R11
@ -74,6 +83,9 @@ def datahub(debug: bool) -> None:
datahub_logger.setLevel(logging.INFO) datahub_logger.setLevel(logging.INFO)
# loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] # loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# print(loggers) # print(loggers)
# Setup the context for the memory_leak_detector decorator.
ctx.ensure_object(dict)
ctx.obj["detect_memory_leaks"] = detect_memory_leaks
@datahub.command() @datahub.command()

View File

@ -0,0 +1,111 @@
import fnmatch
import gc
import logging
import sys
import tracemalloc
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Dict, List, TypeVar, Union, cast
logger = logging.getLogger(__name__)
T = TypeVar("T")
def _trace_has_file(trace: tracemalloc.Traceback, file_pattern: str) -> bool:
    """Report whether any frame of an allocation traceback matches a file pattern.

    Args:
        trace: Allocation traceback whose frames are inspected.
        file_pattern: Shell-style wildcard pattern (``fnmatch`` syntax) that is
            tested against the filename of each frame.

    Returns:
        True if at least one frame's filename matches the pattern, False
        otherwise (including when the traceback has no frames).
    """
    return any(fnmatch.fnmatch(frame.filename, file_pattern) for frame in trace)
def _init_leak_detection() -> None:
    """Switch on the instrumentation that the leak report relies on.

    Starts ``tracemalloc`` so that allocations are recorded together with
    their call stack, and puts the garbage collector into leak-debugging mode.
    """
    # Number of stack frames tracemalloc keeps for every allocation.
    max_traced_frames = 25
    tracemalloc.start(max_traced_frames)

    # reset_peak() only exists on Python 3.9 and later; where it is available,
    # start the run with the recorded peak at zero.
    if sys.version_info >= (3, 9):
        tracemalloc.reset_peak()

    # DEBUG_LEAK includes DEBUG_SAVEALL, so unreachable objects are kept in
    # gc.garbage instead of being freed and can be inspected afterwards.
    gc.set_debug(gc.DEBUG_LEAK)
def _group_garbage_by_trace() -> Dict[Union[tracemalloc.Traceback, int], List[object]]:
    """Bucket the objects currently in ``gc.garbage`` by allocation traceback.

    Objects whose traceback touches a DataHub source file are grouped under
    that traceback. Objects for which tracemalloc has no record are kept
    individually, keyed by their ``id()``. Objects with a traceback that never
    passes through DataHub code are left out.
    """
    grouped: Dict[Union[tracemalloc.Traceback, int], List[object]] = defaultdict(list)
    for obj in gc.garbage:
        obj_trace = tracemalloc.get_object_traceback(obj)
        if obj_trace is None:
            # Nothing to group on, so give the object a bucket of its own.
            grouped[id(obj)].append(obj)
        elif _trace_has_file(obj_trace, "*datahub/*.py"):
            # Leaking object
            grouped[obj_trace].append(obj)
    return grouped


def _log_leak_report(
    groups: Dict[Union[tracemalloc.Traceback, int], List[object]]
) -> None:
    """Log every group of suspected leaks, largest total shallow size first.

    For each object the report also lists those of its referrers that are
    themselves part of ``gc.garbage``.
    """
    # Membership in gc.garbage must be decided by identity. Using ``in`` on the
    # list would compare with ``==``, which is quadratic and can raise, recurse
    # or give false positives on the cyclic objects that make up the garbage.
    garbage_ids = {id(garbage_obj) for garbage_obj in gc.garbage}

    logger.info("Potentially leaking objects start")
    # sys.getsizeof is shallow; compute each group's total once and reuse it
    # for both the ordering and the log line.
    sized_groups = [
        (key, obj_list, sum(sys.getsizeof(o) for o in obj_list))
        for key, obj_list in groups.items()
    ]
    sized_groups.sort(key=lambda group: group[2], reverse=True)

    for key, obj_list, total_size in sized_groups:
        summary = f"#Objects:{len(obj_list)}; Total memory:{total_size};"
        if isinstance(key, tracemalloc.Traceback):
            logger.info(
                summary + " Allocation Trace:\n\t" + "\n\t".join(key.format(limit=25))
            )
        else:
            logger.info(summary + " No Allocation Trace available!")

        # Print details about the live referrers of each object in the obj_list (same trace).
        for obj in obj_list:
            referrers = [r for r in gc.get_referrers(obj) if id(r) in garbage_ids]
            logger.info(
                f"Referrers[{len(referrers)}] for object(addr={hex(id(obj))}):'{obj}':"
            )
            for ref_index, referrer in enumerate(referrers):
                ref_trace = tracemalloc.get_object_traceback(referrer)
                logger.info(
                    f"Referrer[{ref_index}] referrer_obj(addr={hex(id(referrer))}):{referrer}, RefTrace:\n\t\t"
                    + "\n\t\t".join(ref_trace.format(limit=5) if ref_trace else [])
                    + "\n"
                )
    logger.info("Potentially leaking objects end")


def _perform_leak_detection() -> None:
    """Run a garbage collection and log the objects that look like leaks.

    Logs the GC counters and the traced memory size/peak, forces a collection,
    then reports the collected objects grouped by allocation traceback. When
    the report is done (or fails) tracemalloc is stopped and the garbage
    collector's debug flags are cleared, so the process does not keep
    accumulating every collected object in ``gc.garbage``.
    """
    try:
        # Log potentially useful memory usage metrics
        logger.info(f"GC count before collect {gc.get_count()}")
        traced_memory_size, traced_memory_peak = tracemalloc.get_traced_memory()
        logger.info(
            f"Traced Memory: size={traced_memory_size}, peak={traced_memory_peak}"
        )
        num_unreacheable_objects = gc.collect()
        logger.info(f"Number of unreachable objects = {num_unreacheable_objects}")
        logger.info(f"GC count after collect {gc.get_count()}")

        # Collect unique traces of all live objects in the garbage - these have potential leaks.
        _log_leak_report(_group_garbage_by_trace())
    finally:
        # Always undo the instrumentation, even if building the report failed.
        tracemalloc.stop()
        gc.set_debug(0)
# TODO: Transition to ParamSpec with the first arg being click.Context (using typing_extensions.Concatenate)
# once fully supported by mypy.
def with_leak_detection(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate a command so that it can run under memory leak detection.

    The wrapped callable is expected to receive a context object as its first
    positional argument. Leak detection is performed only when that context's
    ``obj`` mapping holds a truthy ``"detect_memory_leaks"`` entry; in every
    other case the callable is simply invoked.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature that returns whatever ``func``
        returns. The leak report is produced even when ``func`` raises.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Leak detection is an opt-in diagnostic: a missing context, or a
        # context whose ``obj`` was never populated, must not break the command.
        try:
            detect_leaks = bool(args[0].obj.get("detect_memory_leaks", False))
        except (IndexError, AttributeError):
            detect_leaks = False

        if detect_leaks:
            logger.info(
                f"Initializing memory leak detection on command: {func.__module__}.{func.__name__}"
            )
            _init_leak_detection()

        try:
            return func(*args, **kwargs)
        finally:
            # Report in ``finally`` so that failed runs are analysed as well.
            if detect_leaks:
                _perform_leak_detection()
                logger.info(
                    f"Finished memory leak detection on command: {func.__module__}.{func.__name__}"
                )

    return wrapper