feat(ingest): add config to extractor interface (#5761)

Harshal Sheth 2022-08-30 02:16:17 +00:00 (committed by GitHub)
parent 80edf7c8ed
commit 24e4ee1746

5 changed files with 74 additions and 22 deletions

metadata-ingestion/src/datahub/ingestion/api/source.py

@@ -11,15 +11,19 @@ from typing import (
     Iterator,
     List,
     Optional,
+    Type,
     TypeVar,
     Union,
+    cast,
 )
 
 from pydantic import BaseModel
 
+from datahub.configuration.common import ConfigModel
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
 from datahub.ingestion.api.report import Report
+from datahub.utilities.type_annotations import get_class_from_annotation
 
 
 class SourceCapability(Enum):
@@ -136,12 +140,26 @@ class TestConnectionReport(Report):
 
 
 WorkUnitType = TypeVar("WorkUnitType", bound=WorkUnit)
+ExtractorConfig = TypeVar("ExtractorConfig", bound=ConfigModel)
 
 
-class Extractor(Generic[WorkUnitType], Closeable, metaclass=ABCMeta):
-    @abstractmethod
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        pass
+class Extractor(Generic[WorkUnitType, ExtractorConfig], Closeable, metaclass=ABCMeta):
+    ctx: PipelineContext
+    config: ExtractorConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[ExtractorConfig]:
+        config_class = get_class_from_annotation(cls, Extractor, ConfigModel)
+        assert config_class, "Extractor subclasses must define a config class"
+        return cast(Type[ExtractorConfig], config_class)
+
+    def __init__(self, config_dict: dict, ctx: PipelineContext) -> None:
+        super().__init__()
+
+        config_class = self.get_config_class()
+
+        self.ctx = ctx
+        self.config = config_class.parse_obj(config_dict)
 
     @abstractmethod
     def get_records(self, workunit: WorkUnitType) -> Iterable[RecordEnvelope]:
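
Note: a minimal sketch of how a subclass is expected to use the reworked interface (MyExtractorConfig and MyExtractor are hypothetical names). The config class is declared through the second Generic type parameter and recovered by get_config_class(), so subclasses no longer override configure():

```python
from typing import Iterable

from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.source import Extractor
from datahub.ingestion.api.workunit import MetadataWorkUnit


class MyExtractorConfig(ConfigModel):
    enabled: bool = True


class MyExtractor(Extractor[MetadataWorkUnit, MyExtractorConfig]):
    def get_records(self, workunit: MetadataWorkUnit) -> Iterable[RecordEnvelope]:
        return []

    def close(self) -> None:
        pass


# The base-class __init__ resolves MyExtractorConfig from the Generic
# annotation and parses the raw dict into it.
extractor = MyExtractor({"enabled": False}, PipelineContext(run_id="demo"))
assert type(extractor.config) is MyExtractorConfig
assert extractor.config.enabled is False
```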

metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py

@@ -1,9 +1,9 @@
 from typing import Iterable, Union
 
+from datahub.configuration.common import ConfigModel
 from datahub.emitter.mce_builder import get_sys_time
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api import RecordEnvelope
-from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.source import Extractor, WorkUnit
 from datahub.ingestion.api.workunit import MetadataWorkUnit, UsageStatsWorkUnit
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
@@ -19,14 +19,15 @@ except ImportError:
     black = None  # type: ignore
 
 
-class WorkUnitRecordExtractor(Extractor):
+class WorkUnitRecordExtractorConfig(ConfigModel):
+    set_system_metadata = True
+
+
+class WorkUnitRecordExtractor(
+    Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig]
+):
     """An extractor that simply returns the data inside workunits back as records."""
-
-    ctx: PipelineContext
-
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        self.ctx = ctx
 
     def get_records(
         self, workunit: WorkUnit
     ) -> Iterable[
@@ -48,6 +49,7 @@ class WorkUnitRecordExtractor(Extractor):
                     MetadataChangeProposalWrapper,
                 ),
             ):
-                workunit.metadata.systemMetadata = SystemMetadata(
-                    lastObserved=get_sys_time(), runId=self.ctx.run_id
-                )
+                if self.config.set_system_metadata:
+                    workunit.metadata.systemMetadata = SystemMetadata(
+                        lastObserved=get_sys_time(), runId=self.ctx.run_id
+                    )
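
Note: with the new config class, system-metadata stamping can be switched off at construction time instead of being unconditional. A hedged usage sketch (the run ID is illustrative):

```python
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.extractor.mce_extractor import (
    WorkUnitRecordExtractor,
    WorkUnitRecordExtractorConfig,
)

ctx = PipelineContext(run_id="demo-run")

# The raw dict is parsed into WorkUnitRecordExtractorConfig by the base class.
extractor = WorkUnitRecordExtractor({"set_system_metadata": False}, ctx)
assert isinstance(extractor.config, WorkUnitRecordExtractorConfig)
assert extractor.config.set_system_metadata is False  # systemMetadata skipped
```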

metadata-ingestion/src/datahub/ingestion/run/pipeline.py

@@ -110,6 +110,7 @@ class Pipeline:
     config: PipelineConfig
     ctx: PipelineContext
     source: Source
+    extractor: Extractor
     sink: Sink
     transformers: List[Transformer]
 
@@ -185,6 +186,7 @@ class Pipeline:
                 self.config.source.dict().get("config", {}), self.ctx
             )
             logger.debug(f"Source type:{source_type},{source_class} configured")
+            logger.info("Source configured successfully.")
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure source ({source_type})"
@@ -192,7 +194,10 @@ class Pipeline:
             return
 
         try:
-            self.extractor_class = extractor_registry.get(self.config.source.extractor)
+            extractor_class = extractor_registry.get(self.config.source.extractor)
+            self.extractor = extractor_class(
+                self.config.source.extractor_config, self.ctx
+            )
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure extractor ({self.config.source.extractor})"
@@ -330,20 +335,17 @@ class Pipeline:
                     self.ctx, self.config.failure_log.log_config
                 )
             )
-        extractor: Extractor = self.extractor_class()
+
         for wu in itertools.islice(
             self.source.get_workunits(),
             self.preview_workunits if self.preview_mode else None,
         ):
             if self._time_to_print():
                 self.pretty_print_summary(currently_running=True)
-            # TODO: change extractor interface
-            extractor.configure({}, self.ctx)
-
             if not self.dry_run:
                 self.sink.handle_work_unit_start(wu)
             try:
-                record_envelopes = extractor.get_records(wu)
+                record_envelopes = self.extractor.get_records(wu)
                 for record_envelope in self.transform(record_envelopes):
                     if not self.dry_run:
                         self.sink.write_record_async(record_envelope, callback)
@@ -355,7 +357,7 @@ class Pipeline:
             except Exception as e:
                 logger.error("Failed to process some records. Continuing.", e)
 
-            extractor.close()
+            self.extractor.close()
             if not self.dry_run:
                 self.sink.handle_work_unit_end(wu)
         self.source.close()
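
Note: the extractor now follows the same lifecycle as the source and sink — built once during pipeline initialization, reused for every work unit, and closed at the end — instead of being re-configured inside the run loop. A hedged end-to-end sketch showing the recipe shape that reaches the new constructor (source/sink types and the filename are illustrative assumptions):

```python
from datahub.ingestion.run.pipeline import Pipeline

# extractor_config is threaded from the recipe into the extractor constructor.
pipeline = Pipeline.create(
    {
        "source": {
            "type": "file",
            "config": {"filename": "./mces.json"},
            "extractor": "generic",
            "extractor_config": {"set_system_metadata": False},
        },
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()
```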

metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py

@@ -13,9 +13,13 @@ from datahub.ingestion.sink.file import FileSinkConfig
 
 logger = logging.getLogger(__name__)
 
+# Sentinel value used to check if the run ID is the default value.
+DEFAULT_RUN_ID = "__DEFAULT_RUN_ID"
+
 
 class SourceConfig(DynamicTypedConfig):
     extractor: str = "generic"
+    extractor_config: dict = Field(default_factory=dict)
 
 
 class ReporterConfig(DynamicTypedConfig):
@@ -42,7 +46,7 @@ class PipelineConfig(ConfigModel):
     sink: DynamicTypedConfig
     transformers: Optional[List[DynamicTypedConfig]]
     reporting: List[ReporterConfig] = []
-    run_id: str = "__DEFAULT_RUN_ID"
+    run_id: str = DEFAULT_RUN_ID
     datahub_api: Optional[DatahubClientConfig] = None
     pipeline_name: Optional[str] = None
     failure_log: FailureLoggingConfig = FailureLoggingConfig()
@@ -55,7 +59,7 @@ class PipelineConfig(ConfigModel):
     def run_id_should_be_semantic(
         cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any
     ) -> str:
-        if v == "__DEFAULT_RUN_ID":
+        if v == DEFAULT_RUN_ID:
             if "source" in values and hasattr(values["source"], "type"):
                 source_type = values["source"].type
                 current_time = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
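Note: the validator behavior is unchanged — the string literal is just de-duplicated behind the DEFAULT_RUN_ID constant. A hedged sketch of the effect, assuming the existing validator runs on the default value and derives the run ID from the source type plus a timestamp (source type and config are illustrative):

```python
from datahub.ingestion.run.pipeline_config import DEFAULT_RUN_ID, PipelineConfig

config = PipelineConfig.parse_obj(
    {
        "source": {"type": "mysql", "config": {}},
        "sink": {"type": "console"},
        # run_id omitted -> stays at the DEFAULT_RUN_ID sentinel, so the
        # validator replaces it with e.g. "mysql-2022_08_30-02_16_17".
    }
)
assert config.run_id != DEFAULT_RUN_ID
assert config.run_id.startswith("mysql-")
```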

metadata-ingestion/src/datahub/utilities/type_annotations.py (new file)

@@ -0,0 +1,26 @@
+import inspect
+from typing import Optional, Type, TypeVar
+
+import typing_inspect
+
+TargetClass = TypeVar("TargetClass")
+
+
+def get_class_from_annotation(
+    derived_cls: Type, super_class: Type, target_class: Type[TargetClass]
+) -> Optional[Type[TargetClass]]:
+    """
+    Attempts to find an instance of target_class in the type annotations of derived_cls.
+    We assume that super_class inherits from typing.Generic and that derived_cls
+    inherits from super_class.
+    """
+
+    # Modified from https://stackoverflow.com/q/69085037/5004662.
+    for base in inspect.getmro(derived_cls):
+        for generic_base in getattr(base, "__orig_bases__", []):
+            generic_origin = typing_inspect.get_origin(generic_base)
+            if generic_origin and issubclass(generic_origin, super_class):
+                for arg in typing_inspect.get_args(generic_base):
+                    if issubclass(arg, target_class):
+                        return arg
+
+    return None
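
Note: a small self-contained sketch of what this helper does (Shape, Circle, Container, and CircleContainer are hypothetical classes), resolving the concrete class bound to a Generic type parameter — the same mechanism Extractor.get_config_class uses to find its config class:

```python
from typing import Generic, TypeVar

from datahub.utilities.type_annotations import get_class_from_annotation

T = TypeVar("T")


class Shape:
    pass


class Circle(Shape):
    pass


class Container(Generic[T]):
    pass


class CircleContainer(Container[Circle]):
    pass


# Walks CircleContainer's MRO, finds the Container[Circle] generic base, and
# returns the type argument that subclasses Shape.
assert get_class_from_annotation(CircleContainer, Container, Shape) is Circle
```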