diff --git a/metadata-ingestion/src/datahub/cli/ingest_cli.py b/metadata-ingestion/src/datahub/cli/ingest_cli.py index 33c6254832..c76434acdc 100644 --- a/metadata-ingestion/src/datahub/cli/ingest_cli.py +++ b/metadata-ingestion/src/datahub/cli/ingest_cli.py @@ -20,6 +20,7 @@ from datahub.configuration import SensitiveError from datahub.configuration.config_loader import load_config_file from datahub.ingestion.run.pipeline import Pipeline from datahub.telemetry import telemetry +from datahub.utilities import memory_leak_detector logger = logging.getLogger(__name__) @@ -63,8 +64,12 @@ def ingest() -> None: default=False, help="If enabled, ingestion runs with warnings will yield a non-zero error code", ) +@click.pass_context @telemetry.with_telemetry -def run(config: str, dry_run: bool, preview: bool, strict_warnings: bool) -> None: +@memory_leak_detector.with_leak_detection +def run( + ctx: click.Context, config: str, dry_run: bool, preview: bool, strict_warnings: bool +) -> None: """Ingest metadata into DataHub.""" logger.info("DataHub CLI version: %s", datahub_package.nice_version_name()) diff --git a/metadata-ingestion/src/datahub/entrypoints.py b/metadata-ingestion/src/datahub/entrypoints.py index b8360c203e..f7b60622d3 100644 --- a/metadata-ingestion/src/datahub/entrypoints.py +++ b/metadata-ingestion/src/datahub/entrypoints.py @@ -48,7 +48,16 @@ MAX_CONTENT_WIDTH = 120 version=datahub_package.nice_version_name(), prog_name=datahub_package.__package_name__, ) -def datahub(debug: bool) -> None: +@click.option( + "-dl", + "--detect-memory-leaks", + type=bool, + is_flag=True, + default=False, + help="Run memory leak detection.", +) +@click.pass_context +def datahub(ctx: click.Context, debug: bool, detect_memory_leaks: bool) -> None: # Insulate 'datahub' and all child loggers from inadvertent changes to the # root logger by the external site packages that we import. 
"""Opt-in memory leak detection for CLI commands.

The public entry point is :func:`with_leak_detection`, a decorator for command
callbacks whose first positional argument is a context object exposing a
mapping-like ``obj`` attribute.  When ``obj["detect_memory_leaks"]`` is truthy
the decorated command runs with ``tracemalloc`` and the garbage collector's
leak debugging enabled, and a report of potentially leaking objects is logged
once the command finishes.
"""
import fnmatch
import gc
import logging
import sys
import tracemalloc
from collections import defaultdict
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

logger = logging.getLogger(__name__)
T = TypeVar("T")

# Number of stack frames recorded per allocation and printed per leaking group.
_TRACE_FRAME_LIMIT = 25
# Number of stack frames printed for each referrer of a leaking object.
_REFERRER_FRAME_LIMIT = 5
# Only allocations whose stack touches a file matching this pattern are reported.
_DATAHUB_FILE_PATTERN = "*datahub/*.py"
# Key looked up in the context's ``obj`` mapping to decide whether to run detection.
_CONTEXT_FLAG_KEY = "detect_memory_leaks"

# GC debug flags that were active before detection switched on DEBUG_LEAK, so
# that they can be restored afterwards.  None while detection is not running.
_saved_gc_debug_flags: Optional[int] = None

_TraceKey = Union[tracemalloc.Traceback, int]


def _trace_has_file(trace: tracemalloc.Traceback, file_pattern: str) -> bool:
    """Return True if any frame of *trace* is in a file matching *file_pattern*.

    *file_pattern* is an ``fnmatch``-style glob applied to each frame's filename.
    """
    return any(fnmatch.fnmatch(frame.filename, file_pattern) for frame in trace)


def _init_leak_detection() -> None:
    """Start allocation tracing and enable leak debugging in the collector."""
    global _saved_gc_debug_flags

    # Initialize trace malloc to track up to _TRACE_FRAME_LIMIT stack frames.
    tracemalloc.start(_TRACE_FRAME_LIMIT)
    if sys.version_info >= (3, 9):
        # Nice to reset peak to 0. Available for versions >= 3.9.
        tracemalloc.reset_peak()
    # Remember the current flags: DEBUG_LEAK includes DEBUG_SAVEALL, which keeps
    # every unreachable object alive in gc.garbage until it is switched off.
    _saved_gc_debug_flags = gc.get_debug()
    # Enable leak debugging in the garbage collector.
    gc.set_debug(gc.DEBUG_LEAK)


def _teardown_leak_detection() -> None:
    """Undo everything `_init_leak_detection` switched on."""
    global _saved_gc_debug_flags

    tracemalloc.stop()
    gc.set_debug(_saved_gc_debug_flags or 0)
    _saved_gc_debug_flags = None
    # The saved objects were only needed for the report; holding on to them
    # would turn the detector itself into a leak.
    gc.garbage.clear()


def _group_garbage_by_trace(garbage: List[object]) -> Dict[_TraceKey, List[object]]:
    """Group garbage objects by allocation traceback.

    Objects whose traceback does not touch a datahub file are left out; objects
    without any recorded traceback are keyed individually by their ``id``.
    """
    groups: Dict[_TraceKey, List[object]] = defaultdict(list)
    for obj in garbage:
        obj_trace = tracemalloc.get_object_traceback(obj)
        if obj_trace is None:
            groups[id(obj)].append(obj)
        elif _trace_has_file(obj_trace, _DATAHUB_FILE_PATTERN):
            # Leaking object
            groups[obj_trace].append(obj)
    return groups


def _log_referrers(obj: object, garbage_ids: Set[int]) -> None:
    """Log every referrer of *obj* that is itself part of the collected garbage."""
    # Membership is decided by identity: an equality-based `in gc.garbage` scan
    # is linear per referrer and would call arbitrary __eq__ implementations on
    # objects that are being torn down.
    referrers = [r for r in gc.get_referrers(obj) if id(r) in garbage_ids]
    logger.info(
        f"Referrers[{len(referrers)}] for object(addr={hex(id(obj))}):'{obj}':"
    )
    for ref_index, referrer in enumerate(referrers):
        ref_trace = tracemalloc.get_object_traceback(referrer)
        logger.info(
            f"Referrer[{ref_index}] referrer_obj(addr={hex(id(referrer))}):{referrer}, RefTrace:\n\t\t"
            + "\n\t\t".join(
                ref_trace.format(limit=_REFERRER_FRAME_LIMIT) if ref_trace else []
            )
            + "\n"
        )


def _log_leak_report() -> None:
    """Force a collection and log the potentially leaking objects it found."""
    # Log potentially useful memory usage metrics
    logger.info(f"GC count before collect {gc.get_count()}")
    traced_memory_size, traced_memory_peak = tracemalloc.get_traced_memory()
    logger.info(f"Traced Memory: size={traced_memory_size}, peak={traced_memory_peak}")
    num_unreachable_objects = gc.collect()
    logger.info(f"Number of unreachable objects = {num_unreachable_objects}")
    logger.info(f"GC count after collect {gc.get_count()}")

    # Work on a snapshot: with DEBUG_SAVEALL an automatic collection triggered
    # while reporting may append to gc.garbage.  The snapshot also keeps the
    # objects alive, so the ids collected below stay valid for the whole report.
    garbage = list(gc.garbage)
    garbage_ids = {id(obj) for obj in garbage}

    # Collect unique traces of all live objects in the garbage - these have
    # potential leaks.  Sizes are computed once per group (sys.getsizeof is
    # shallow) and the groups are reported largest first.
    sized_groups: List[Tuple[_TraceKey, List[object], int]] = [
        (key, obj_list, sum(sys.getsizeof(obj) for obj in obj_list))
        for key, obj_list in _group_garbage_by_trace(garbage).items()
    ]
    sized_groups.sort(key=lambda group: group[2], reverse=True)

    logger.info("Potentially leaking objects start")
    for key, obj_list, total_size in sized_groups:
        summary = f"#Objects:{len(obj_list)}; Total memory:{total_size};"
        if isinstance(key, tracemalloc.Traceback):
            logger.info(
                summary
                + " Allocation Trace:\n\t"
                + "\n\t".join(key.format(limit=_TRACE_FRAME_LIMIT))
            )
        else:
            logger.info(summary + " No Allocation Trace available!")
        # Print details about the live referrers of each object in the obj_list (same trace).
        for obj in obj_list:
            _log_referrers(obj, garbage_ids)
    logger.info("Potentially leaking objects end")


def _perform_leak_detection() -> None:
    """Log the leak report, then always switch tracing and GC debugging off.

    The teardown runs even when building the report raises (formatting an
    arbitrary garbage object can fail), so tracing is never left enabled.
    """
    try:
        _log_leak_report()
    finally:
        _teardown_leak_detection()


def _leak_detection_requested(args: Tuple[Any, ...]) -> bool:
    """Return True if the first positional argument asks for leak detection.

    A missing first argument, or one whose ``obj`` attribute is absent or not a
    mapping, is treated as "detection not requested" rather than an error.
    """
    if not args:
        return False
    context_obj = getattr(args[0], "obj", None)
    if not isinstance(context_obj, Mapping):
        return False
    return bool(context_obj.get(_CONTEXT_FLAG_KEY, False))


# TODO: Transition to ParamSpec with the first arg being click.Context (using typing_extensions.Concatenate)
# once fully supported by mypy.
def with_leak_detection(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate a command so that it can run with memory leak detection.

    The decorated callable must receive the context as its first positional
    argument.  Detection is active only when the context's ``obj`` mapping has
    a truthy ``"detect_memory_leaks"`` entry; otherwise *func* is called as is.

    The return value and any exception of *func* are passed through unchanged.
    A failure while producing the leak report is logged and does not replace
    the outcome of the command.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        detect_leaks = _leak_detection_requested(args)
        if detect_leaks:
            logger.info(
                f"Initializing memory leak detection on command: {func.__module__}.{func.__name__}"
            )
            _init_leak_detection()

        try:
            return func(*args, **kwargs)
        finally:
            if detect_leaks:
                try:
                    _perform_leak_detection()
                except Exception:
                    # Diagnostics are best effort: never mask the command's
                    # own result or exception with a reporting failure.
                    logger.exception(
                        f"Memory leak detection failed on command: {func.__module__}.{func.__name__}"
                    )
                logger.info(
                    f"Finished memory leak detection on command: {func.__module__}.{func.__name__}"
                )

    return wrapper