refactor(ingest): defer ctx.graph initialization (#10504)

Harshal Sheth 2024-05-21 17:01:35 -07:00 committed by GitHub
parent f2b5875632
commit b8023a93a4
11 changed files with 121 additions and 113 deletions

View File

@@ -131,10 +131,10 @@ def run(
             pipeline.run()
         except Exception as e:
             logger.info(
-                f"Source ({pipeline.config.source.type}) report:\n{pipeline.source.get_report().as_string()}"
+                f"Source ({pipeline.source_type}) report:\n{pipeline.source.get_report().as_string()}"
             )
             logger.info(
-                f"Sink ({pipeline.config.sink.type}) report:\n{pipeline.sink.get_report().as_string()}"
+                f"Sink ({pipeline.sink_type}) report:\n{pipeline.sink.get_report().as_string()}"
             )
             raise e
         else:

View File

@@ -152,6 +152,10 @@ class TransformerSemanticsConfigModel(ConfigModel):
 class DynamicTypedConfig(ConfigModel):
+    # Once support for discriminated unions gets merged into Pydantic, we can
+    # simplify this configuration and validation.
+    # See https://github.com/samuelcolvin/pydantic/pull/2336.
     type: str = Field(
         description="The type of the dynamic object",
     )
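
Note: the Pydantic PR referenced in the comment above shipped as "discriminated unions" (`Field(discriminator=...)`, Pydantic 1.9+). A rough sketch of the simplification it would allow; the sink models here are illustrative, not DataHub's actual classes:

    from typing import Literal, Union

    from pydantic import BaseModel, Field


    class RestSink(BaseModel):
        type: Literal["datahub-rest"]
        server: str


    class FileSink(BaseModel):
        type: Literal["file"]
        filename: str


    class Recipe(BaseModel):
        # Pydantic dispatches on the "type" tag directly, replacing the
        # manual registry lookup + validation that DynamicTypedConfig does.
        sink: Union[RestSink, FileSink] = Field(discriminator="type")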

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict

 from datahub.ingestion.api.common import PipelineContext
+from datahub.ingestion.api.sink import Sink


 class PipelineRunListener(ABC):
@@ -21,6 +22,11 @@ class PipelineRunListener(ABC):
     @classmethod
     @abstractmethod
-    def create(cls, config_dict: dict, ctx: PipelineContext) -> "PipelineRunListener":
+    def create(
+        cls,
+        config_dict: Dict[str, Any],
+        ctx: PipelineContext,
+        sink: Sink,
+    ) -> "PipelineRunListener":
         # Creation and initialization.
         pass
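
Note: every `PipelineRunListener` implementation now receives the pipeline's already-constructed sink instead of re-deriving one from config. A minimal conforming listener, as a sketch (the class name is hypothetical and the remaining abstract methods are elided):

    from typing import Any, Dict

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.api.pipeline_run_listener import PipelineRunListener
    from datahub.ingestion.api.sink import Sink


    class StdoutReporter(PipelineRunListener):
        # (remaining abstract methods elided for brevity)

        @classmethod
        def create(
            cls,
            config_dict: Dict[str, Any],
            ctx: PipelineContext,
            sink: Sink,  # new argument: the pipeline's live sink instance
        ) -> "PipelineRunListener":
            print(f"reporting alongside sink {type(sink).__name__}")
            return cls()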

View File

@@ -184,6 +184,12 @@ class PluginRegistry(Generic[T]):
             # If it's not an exception, then it's a registered type.
             return tp

+    def get_optional(self, key: str) -> Optional[Type[T]]:
+        try:
+            return self.get(key)
+        except Exception:
+            return None
+
     def summary(
         self, verbose: bool = True, col_width: int = 15, verbose_col_width: int = 20
     ) -> str:
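
Note: `get_optional` converts a failed lookup — an unregistered name, or a plugin whose optional dependencies failed to import — into `None` instead of an exception. A hedged usage sketch:

    from datahub.ingestion.sink.sink_registry import sink_registry

    rest_sink_class = sink_registry.get_optional("datahub-rest")
    if rest_sink_class is None:
        # plugin missing or failed to import; degrade gracefully
        pass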

View File

@@ -63,7 +63,10 @@ from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.urns.urn import Urn, guess_entity_type

 if TYPE_CHECKING:
-    from datahub.ingestion.sink.datahub_rest import DatahubRestSink
+    from datahub.ingestion.sink.datahub_rest import (
+        DatahubRestSink,
+        DatahubRestSinkConfig,
+    )
     from datahub.ingestion.source.state.entity_removal_state import (
         GenericCheckpointState,
     )
@@ -202,13 +205,8 @@ class DataHubGraph(DatahubRestEmitter):
     def _post_generic(self, url: str, payload_dict: Dict) -> Dict:
         return self._send_restli_request("POST", url, json=payload_dict)

-    @contextlib.contextmanager
-    def make_rest_sink(
-        self, run_id: str = _GRAPH_DUMMY_RUN_ID
-    ) -> Iterator["DatahubRestSink"]:
-        from datahub.ingestion.api.common import PipelineContext
+    def _make_rest_sink_config(self) -> "DatahubRestSinkConfig":
         from datahub.ingestion.sink.datahub_rest import (
-            DatahubRestSink,
             DatahubRestSinkConfig,
             SyncOrAsync,
         )
@@ -218,10 +216,16 @@ class DataHubGraph(DatahubRestEmitter):
         # TODO: We should refactor out the multithreading functionality of the sink
         # into a separate class that can be used by both the sink and the graph client
         # e.g. a DatahubBulkRestEmitter that both the sink and the graph client use.
-        sink_config = DatahubRestSinkConfig(
-            **self.config.dict(), mode=SyncOrAsync.ASYNC
-        )
+        return DatahubRestSinkConfig(**self.config.dict(), mode=SyncOrAsync.ASYNC)
+
+    @contextlib.contextmanager
+    def make_rest_sink(
+        self, run_id: str = _GRAPH_DUMMY_RUN_ID
+    ) -> Iterator["DatahubRestSink"]:
+        from datahub.ingestion.api.common import PipelineContext
+        from datahub.ingestion.sink.datahub_rest import DatahubRestSink
+
+        sink_config = self._make_rest_sink_config()
         with DatahubRestSink(PipelineContext(run_id=run_id), sink_config) as sink:
             yield sink
         if sink.report.failures:
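
Note: extracting `_make_rest_sink_config()` lets the pipeline's new default sink (see `_make_default_rest_sink` below) reuse the exact async rest-sink settings the graph client uses. Callers of the context manager are unchanged; a usage sketch, assuming a reachable server and that `emit_async` is the sink's enqueue method:

    from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

    graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

    mcps = [...]  # some MetadataChangeProposalWrapper objects
    with graph.make_rest_sink() as sink:
        for mcp in mcps:
            sink.emit_async(mcp)
    # on exit the sink flushes, and accumulated failures raise (per the check above)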

View File

@@ -79,36 +79,33 @@ class DatahubIngestionRunSummaryProvider(PipelineRunListener):
         cls,
         config_dict: Dict[str, Any],
         ctx: PipelineContext,
+        sink: Sink,
     ) -> PipelineRunListener:
-        sink_config_holder: Optional[DynamicTypedConfig] = None
-
         reporter_config = DatahubIngestionRunSummaryProviderConfig.parse_obj(
             config_dict or {}
         )
         if reporter_config.sink:
-            sink_config_holder = reporter_config.sink
-
-        if sink_config_holder is None:
-            # Populate sink from global recipe
-            assert ctx.pipeline_config
-            sink_config_holder = ctx.pipeline_config.sink
-
-            # Global instances are safe to use only if the types are datahub-rest and datahub-kafka
-            # Re-using a shared file sink will result in clobbering the events
-            if sink_config_holder.type not in ["datahub-rest", "datahub-kafka"]:
+            sink_class = sink_registry.get(reporter_config.sink.type)
+            sink_config = reporter_config.sink.config or {}
+            sink = sink_class.create(sink_config, ctx)
+        else:
+            if not isinstance(
+                sink,
+                tuple(
+                    [
+                        kls
+                        for kls in [
+                            sink_registry.get_optional("datahub-rest"),
+                            sink_registry.get_optional("datahub-kafka"),
+                        ]
+                        if kls
+                    ]
+                ),
+            ):
                 raise IgnorableError(
-                    f"Datahub ingestion reporter will be disabled because sink type {sink_config_holder.type} is not supported"
+                    f"Datahub ingestion reporter will be disabled because sink type {type(sink)} is not supported"
                 )

-        sink_type = sink_config_holder.type
-        sink_class = sink_registry.get(sink_type)
-        sink_config = sink_config_holder.dict().get("config") or {}
-        if sink_type == "datahub-rest":
-            # for the rest emitter we want to use sync mode to emit
-            # regardless of the default sink config since that makes it
-            # immune to process shutdown related failures
-            sink_config["mode"] = "SYNC"
-        sink: Sink = sink_class.create(sink_config, ctx)
         return cls(sink, reporter_config.report_recipe, ctx)

     def __init__(self, sink: Sink, report_recipe: bool, ctx: PipelineContext) -> None:
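
Note: behaviorally, the old path always constructed a fresh sink (forcing SYNC mode for rest sinks), while the new path shares the pipeline's live sink object unless the reporter config names its own. A sketch of a recipe fragment that opts into a dedicated reporting sink (values illustrative):

    reporting = [
        {
            "type": "datahub",
            "config": {
                # Optional dedicated sink for run-summary events; if omitted,
                # the pipeline's own datahub-rest/datahub-kafka sink is reused.
                "sink": {
                    "type": "datahub-rest",
                    "config": {"server": "http://localhost:8080"},
                },
            },
        },
    ]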

View File

@@ -7,6 +7,7 @@ from pydantic import validator
 from datahub.configuration.common import ConfigModel
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.pipeline_run_listener import PipelineRunListener
+from datahub.ingestion.api.sink import Sink

 logger = logging.getLogger(__name__)
@@ -30,6 +31,7 @@ class FileReporter(PipelineRunListener):
         cls,
         config_dict: Dict[str, Any],
         ctx: PipelineContext,
+        sink: Sink,
     ) -> PipelineRunListener:
         reporter_config = FileReporterConfig.parse_obj(config_dict)
         return cls(reporter_config)

View File

@@ -28,11 +28,12 @@ from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
-from datahub.ingestion.graph.client import DataHubGraph
+from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
 from datahub.ingestion.run.pipeline_config import PipelineConfig, ReporterConfig
+from datahub.ingestion.sink.datahub_rest import DatahubRestSink
 from datahub.ingestion.sink.file import FileSink, FileSinkConfig
 from datahub.ingestion.sink.sink_registry import sink_registry
 from datahub.ingestion.source.source_registry import source_registry
@@ -182,11 +183,19 @@ class CliReport(Report):
         return super().compute_stats()


+def _make_default_rest_sink(ctx: PipelineContext) -> DatahubRestSink:
+    graph = get_default_graph()
+    sink_config = graph._make_rest_sink_config()
+
+    return DatahubRestSink(ctx, sink_config)
+
+
 class Pipeline:
     config: PipelineConfig
     ctx: PipelineContext
     source: Source
     extractor: Extractor
+    sink_type: str
     sink: Sink[ConfigModel, SinkReport]
     transformers: List[Transformer]
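
Note: `get_default_graph()` builds a `DataHubGraph` from the ambient CLI credentials (the same `~/.datahubenv` / environment-variable resolution the `datahub` CLI uses), so the implicit default sink now shares one configuration path with the graph client. A sketch; `_make_rest_sink_config` is private and shown only for illustration:

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.graph.client import get_default_graph
    from datahub.ingestion.sink.datahub_rest import DatahubRestSink

    graph = get_default_graph()  # resolves server/token from the CLI environment
    sink = DatahubRestSink(PipelineContext(run_id="demo"), graph._make_rest_sink_config())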
@@ -217,9 +226,6 @@ class Pipeline:
             self.graph = DataHubGraph(self.config.datahub_api)
             self.graph.test_connection()

-            telemetry.telemetry_instance.update_capture_exception_context(
-                server=self.graph
-            )

         with _add_init_error_context("set up framework context"):
             self.ctx = PipelineContext(
                 run_id=self.config.run_id,
@@ -230,31 +236,43 @@ class Pipeline:
                 pipeline_config=self.config,
             )

-        sink_type = self.config.sink.type
-        with _add_init_error_context(f"find a registered sink for type {sink_type}"):
-            sink_class = sink_registry.get(sink_type)
+        if self.config.sink is None:
+            with _add_init_error_context("configure the default rest sink"):
+                self.sink_type = "datahub-rest"
+                self.sink = _make_default_rest_sink(self.ctx)
+        else:
+            self.sink_type = self.config.sink.type
+            with _add_init_error_context(
+                f"find a registered sink for type {self.sink_type}"
+            ):
+                sink_class = sink_registry.get(self.sink_type)

-        with _add_init_error_context(f"configure the sink ({sink_type})"):
-            sink_config = self.config.sink.dict().get("config") or {}
-            self.sink = sink_class.create(sink_config, self.ctx)
-            logger.debug(f"Sink type {self.config.sink.type} ({sink_class}) configured")
-            logger.info(f"Sink configured successfully. {self.sink.configured()}")
+            with _add_init_error_context(f"configure the sink ({self.sink_type})"):
+                sink_config = self.config.sink.dict().get("config") or {}
+                self.sink = sink_class.create(sink_config, self.ctx)
+                logger.debug(f"Sink type {self.sink_type} ({sink_class}) configured")
+                logger.info(f"Sink configured successfully. {self.sink.configured()}")
+
+        if self.graph is None and isinstance(self.sink, DatahubRestSink):
+            with _add_init_error_context("setup default datahub client"):
+                self.graph = self.sink.emitter.to_graph()
+        self.ctx.graph = self.graph
+        telemetry.telemetry_instance.update_capture_exception_context(server=self.graph)

         # once a sink is configured, we can configure reporting immediately to get observability
         with _add_init_error_context("configure reporters"):
             self._configure_reporting(report_to, no_default_report)

-        source_type = self.config.source.type
         with _add_init_error_context(
-            f"find a registered source for type {source_type}"
+            f"find a registered source for type {self.source_type}"
         ):
-            source_class = source_registry.get(source_type)
+            source_class = source_registry.get(self.source_type)

-        with _add_init_error_context(f"configure the source ({source_type})"):
+        with _add_init_error_context(f"configure the source ({self.source_type})"):
             self.source = source_class.create(
                 self.config.source.dict().get("config", {}), self.ctx
             )
-            logger.debug(f"Source type {source_type} ({source_class}) configured")
+            logger.debug(f"Source type {self.source_type} ({source_class}) configured")
             logger.info("Source configured successfully.")

         extractor_type = self.config.source.extractor
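
Note: the net effect of this block: `__init__` constructs a `DataHubGraph` eagerly only when `datahub_api` is configured; otherwise the graph is deferred and, for rest sinks, recovered from the sink's emitter, so `ctx.graph` is still populated before reporters, source, and transformers initialize. A behavior sketch (assumes a reachable DataHub instance behind the default credentials):

    from datahub.ingestion.run.pipeline import Pipeline

    pipeline = Pipeline.create(
        {
            "source": {"type": "file", "config": {"path": "events.json"}},
            # no "sink" and no "datahub_api": both are now derived lazily
        }
    )
    assert pipeline.sink_type == "datahub-rest"
    assert pipeline.ctx.graph is not None  # derived via sink.emitter.to_graph()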
@@ -267,6 +285,10 @@ class Pipeline:
         with _add_init_error_context("configure transformers"):
             self._configure_transforms()

+    @property
+    def source_type(self) -> str:
+        return self.config.source.type
+
     def _configure_transforms(self) -> None:
         self.transformers = []
         if self.config.transformers is not None:
@@ -310,6 +332,7 @@ class Pipeline:
                 reporter_class.create(
                     config_dict=reporter_config_dict,
                     ctx=self.ctx,
+                    sink=self.sink,
                 )
             )
             logger.debug(
@@ -552,8 +575,8 @@ class Pipeline:
         telemetry.telemetry_instance.ping(
             "ingest_stats",
             {
-                "source_type": self.config.source.type,
-                "sink_type": self.config.sink.type,
+                "source_type": self.source_type,
+                "sink_type": self.sink_type,
                 "transformer_types": [
                     transformer.type for transformer in self.config.transformers or []
                 ],
@@ -602,17 +625,22 @@ class Pipeline:
         click.echo()
         click.secho("Cli report:", bold=True)
         click.secho(self.cli_report.as_string())
-        click.secho(f"Source ({self.config.source.type}) report:", bold=True)
+        click.secho(f"Source ({self.source_type}) report:", bold=True)
         click.echo(self.source.get_report().as_string())
-        click.secho(f"Sink ({self.config.sink.type}) report:", bold=True)
+        click.secho(f"Sink ({self.sink_type}) report:", bold=True)
         click.echo(self.sink.get_report().as_string())
         global_warnings = get_global_warnings()
         if len(global_warnings) > 0:
             click.secho("Global Warnings:", bold=True)
             click.echo(global_warnings)
         click.echo()
-        workunits_produced = self.source.get_report().events_produced
+        workunits_produced = self.sink.get_report().total_records_written
+
         duration_message = f"in {humanfriendly.format_timespan(self.source.get_report().running_time)}."

+        if currently_running:
+            message_template = f"⏳ Pipeline running {{status}} so far; produced {workunits_produced} events {duration_message}"
+        else:
+            message_template = f"Pipeline finished {{status}}; produced {workunits_produced} events {duration_message}"

         if self.source.get_report().failures or self.sink.get_report().failures:
             num_failures_source = self._approx_all_vals(
@@ -620,11 +648,11 @@ class Pipeline:
             )
             num_failures_sink = len(self.sink.get_report().failures)
             click.secho(
-                f"{'⏳' if currently_running else ''} Pipeline {'running' if currently_running else 'finished'} with at least {num_failures_source+num_failures_sink} failures{' so far' if currently_running else ''}; produced {workunits_produced} events {duration_message}",
+                message_template.format(
+                    status=f"with at least {num_failures_source+num_failures_sink} failures"
+                ),
                 fg=self._get_text_color(
-                    running=currently_running,
-                    failures=True,
-                    warnings=False,
+                    running=currently_running, failures=True, warnings=False
                 ),
                 bold=True,
             )
@@ -638,7 +666,9 @@ class Pipeline:
             num_warn_sink = len(self.sink.get_report().warnings)
             num_warn_global = len(global_warnings)
             click.secho(
-                f"{'⏳' if currently_running else ''} Pipeline {'running' if currently_running else 'finished'} with at least {num_warn_source+num_warn_sink+num_warn_global} warnings{' so far' if currently_running else ''}; produced {workunits_produced} events {duration_message}",
+                message_template.format(
+                    status=f"with at least {num_warn_source+num_warn_sink+num_warn_global} warnings"
+                ),
                 fg=self._get_text_color(
                     running=currently_running, failures=False, warnings=True
                 ),
@@ -647,7 +677,7 @@ class Pipeline:
             return 1 if warnings_as_failure else 0
         else:
             click.secho(
-                f"{'⏳' if currently_running else ''} Pipeline {'running' if currently_running else 'finished'} successfully{' so far' if currently_running else ''}; produced {workunits_produced} events {duration_message}",
+                message_template.format(status="successfully"),
                 fg=self._get_text_color(
                     running=currently_running, failures=False, warnings=False
                 ),
@@ -659,11 +689,11 @@ class Pipeline:
         return {
             "cli": self.cli_report.as_obj(),
             "source": {
-                "type": self.config.source.type,
+                "type": self.source_type,
                 "report": self.source.get_report().as_obj(),
             },
             "sink": {
-                "type": self.config.sink.type,
+                "type": self.sink_type,
                 "report": self.sink.get_report().as_obj(),
             },
         }

View File

@@ -1,13 +1,10 @@
 import datetime
 import logging
-import os
 import uuid
 from typing import Any, Dict, List, Optional

-from pydantic import Field, root_validator, validator
+from pydantic import Field, validator

-from datahub.cli.cli_utils import get_url_and_token
-from datahub.configuration import config_loader
 from datahub.configuration.common import ConfigModel, DynamicTypedConfig
 from datahub.ingestion.graph.client import DatahubClientConfig
 from datahub.ingestion.sink.file import FileSinkConfig
@@ -67,12 +64,8 @@ class FlagsConfig(ConfigModel):

 class PipelineConfig(ConfigModel):
-    # Once support for discriminated unions gets merged into Pydantic, we can
-    # simplify this configuration and validation.
-    # See https://github.com/samuelcolvin/pydantic/pull/2336.
-
     source: SourceConfig
-    sink: DynamicTypedConfig
+    sink: Optional[DynamicTypedConfig] = None
     transformers: Optional[List[DynamicTypedConfig]] = None
     flags: FlagsConfig = Field(default=FlagsConfig(), hidden_from_docs=True)
     reporting: List[ReporterConfig] = []
@@ -100,36 +93,6 @@ class PipelineConfig(ConfigModel):
         assert v is not None
         return v

-    @root_validator(pre=True)
-    def default_sink_is_datahub_rest(cls, values: Dict[str, Any]) -> Any:
-        if "sink" not in values:
-            gms_host, gms_token = get_url_and_token()
-            default_sink_config = {
-                "type": "datahub-rest",
-                "config": {
-                    "server": gms_host,
-                    "token": gms_token,
-                },
-            }
-            # resolve env variables if present
-            default_sink_config = config_loader.resolve_env_variables(
-                default_sink_config, environ=os.environ
-            )
-            values["sink"] = default_sink_config
-
-        return values
-
-    @validator("datahub_api", always=True)
-    def datahub_api_should_use_rest_sink_as_default(
-        cls, v: Optional[DatahubClientConfig], values: Dict[str, Any], **kwargs: Any
-    ) -> Optional[DatahubClientConfig]:
-        if v is None and "sink" in values and hasattr(values["sink"], "type"):
-            sink_type = values["sink"].type
-            if sink_type == "datahub-rest":
-                sink_config = values["sink"].config
-                v = DatahubClientConfig.parse_obj_allow_extras(sink_config)
-        return v
-
     @classmethod
     def from_dict(
         cls, resolved_dict: dict, raw_dict: Optional[dict] = None
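
Note: with both validators deleted, parsing a recipe no longer reads `~/.datahubenv` or injects a default sink as a side effect; `sink` simply stays `None` and the fallback lives in `Pipeline.__init__`. A before/after sketch:

    from datahub.ingestion.run.pipeline_config import PipelineConfig

    config = PipelineConfig.parse_obj(
        {"source": {"type": "file", "config": {"path": "events.json"}}}
    )
    # Before: a root_validator had already rewritten the "sink" value to a
    # datahub-rest sink pointing at the CLI's default server and token.
    # After: parsing is side-effect free.
    assert config.sink is None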

View File

@@ -57,10 +57,7 @@ class DataHubGcSource(Source):
         self.ctx = ctx
         self.config = config
         self.report = DataHubGcSourceReport()
-        self.graph = ctx.graph
-        assert (
-            self.graph is not None
-        ), "DataHubGc source requires a graph. Please either use datahub-rest sink or set datahub_api"
+        self.graph = ctx.require_graph("The DataHubGc source")

     @classmethod
     def create(cls, config_dict, ctx):
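
Note: `ctx.require_graph(...)` centralizes the old assert-with-message pattern; its body is not shown in this diff. A plausible reconstruction of the `PipelineContext` helper, offered as an assumption rather than the verbatim source:

    # Hypothetical sketch of PipelineContext.require_graph:
    def require_graph(self, operation: Optional[str] = None) -> DataHubGraph:
        if not self.graph:
            raise ConfigurationError(
                f"{operation or 'This operation'} requires a graph, but none was set. "
                "Please either use a datahub-rest sink or set the datahub_api config."
            )
        return self.graph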

View File

@@ -12,6 +12,7 @@ from datahub.ingestion.api.source import Source, SourceReport
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
+from datahub.ingestion.sink.datahub_rest import DatahubRestSink
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import SystemMetadata
 from datahub.metadata.schema_classes import (
     DatasetPropertiesClass,
@@ -67,11 +68,9 @@ class TestPipeline:
                 },
             }
         )
-        # assert that the default sink config is for a DatahubRestSink
-        assert isinstance(pipeline.config.sink, DynamicTypedConfig)
-        assert pipeline.config.sink.type == "datahub-rest"
-        assert isinstance(pipeline.config.sink.config, dict)
-        assert pipeline.config.sink.config["server"] == "http://localhost:8080"
+        # assert that the default sink is a DatahubRestSink
+        assert isinstance(pipeline.sink, DatahubRestSink)
+        assert pipeline.sink.config.server == "http://localhost:8080"
         # token value is read from ~/.datahubenv which may be None or not

     @freeze_time(FROZEN_TIME)