feat(ingest): add config to extractor interface (#5761)

Author: Harshal Sheth (committed by GitHub), 2022-08-30 02:16:17 +00:00
Parent: 80edf7c8ed
Commit: 24e4ee1746
5 changed files with 74 additions and 22 deletions

metadata-ingestion/src/datahub/ingestion/api/source.py

@@ -11,15 +11,19 @@ from typing import (
     Iterator,
     List,
     Optional,
+    Type,
     TypeVar,
     Union,
+    cast,
 )

 from pydantic import BaseModel

+from datahub.configuration.common import ConfigModel
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
 from datahub.ingestion.api.report import Report
+from datahub.utilities.type_annotations import get_class_from_annotation


 class SourceCapability(Enum):
@@ -136,12 +140,26 @@ class TestConnectionReport(Report):
 WorkUnitType = TypeVar("WorkUnitType", bound=WorkUnit)
+ExtractorConfig = TypeVar("ExtractorConfig", bound=ConfigModel)


-class Extractor(Generic[WorkUnitType], Closeable, metaclass=ABCMeta):
-    @abstractmethod
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        pass
+class Extractor(Generic[WorkUnitType, ExtractorConfig], Closeable, metaclass=ABCMeta):
+    ctx: PipelineContext
+    config: ExtractorConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[ExtractorConfig]:
+        config_class = get_class_from_annotation(cls, Extractor, ConfigModel)
+        assert config_class, "Extractor subclasses must define a config class"
+        return cast(Type[ExtractorConfig], config_class)
+
+    def __init__(self, config_dict: dict, ctx: PipelineContext) -> None:
+        super().__init__()
+        config_class = self.get_config_class()
+        self.ctx = ctx
+        self.config = config_class.parse_obj(config_dict)

     @abstractmethod
     def get_records(self, workunit: WorkUnitType) -> Iterable[RecordEnvelope]:
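For illustration, a subclass now binds its config type through the generic parameters, and the base __init__ parses the raw dict. A minimal sketch of the pattern (MyExtractorConfig and MyExtractor are hypothetical names, not part of this commit; ctx is assumed to be an existing PipelineContext):

    from typing import Iterable

    from datahub.configuration.common import ConfigModel
    from datahub.ingestion.api.common import RecordEnvelope
    from datahub.ingestion.api.source import Extractor
    from datahub.ingestion.api.workunit import MetadataWorkUnit

    class MyExtractorConfig(ConfigModel):
        strip_nulls: bool = False

    class MyExtractor(Extractor[MetadataWorkUnit, MyExtractorConfig]):
        def get_records(self, workunit: MetadataWorkUnit) -> Iterable[RecordEnvelope]:
            # self.config is already a parsed MyExtractorConfig instance here,
            # and self.ctx is the PipelineContext passed to __init__.
            yield from []

    # get_config_class() resolves MyExtractorConfig from the Extractor[...]
    # annotation via get_class_from_annotation.
    extractor = MyExtractor({"strip_nulls": True}, ctx)
    assert extractor.get_config_class() is MyExtractorConfig
    assert extractor.config.strip_nulls is True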

metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py

@@ -1,9 +1,9 @@
 from typing import Iterable, Union

+from datahub.configuration.common import ConfigModel
 from datahub.emitter.mce_builder import get_sys_time
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api import RecordEnvelope
-from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.source import Extractor, WorkUnit
 from datahub.ingestion.api.workunit import MetadataWorkUnit, UsageStatsWorkUnit
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
@@ -19,14 +19,15 @@ except ImportError:
     black = None  # type: ignore


-class WorkUnitRecordExtractor(Extractor):
+class WorkUnitRecordExtractorConfig(ConfigModel):
+    set_system_metadata = True
+
+
+class WorkUnitRecordExtractor(
+    Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig]
+):
     """An extractor that simply returns the data inside workunits back as records."""

-    ctx: PipelineContext
-
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        self.ctx = ctx
-
     def get_records(
         self, workunit: WorkUnit
     ) -> Iterable[
@@ -48,9 +49,10 @@ class WorkUnitRecordExtractor(Extractor):
                     MetadataChangeProposalWrapper,
                 ),
             ):
-                workunit.metadata.systemMetadata = SystemMetadata(
-                    lastObserved=get_sys_time(), runId=self.ctx.run_id
-                )
+                if self.config.set_system_metadata:
+                    workunit.metadata.systemMetadata = SystemMetadata(
+                        lastObserved=get_sys_time(), runId=self.ctx.run_id
+                    )
             if (
                 isinstance(workunit.metadata, MetadataChangeEvent)
                 and len(workunit.metadata.proposedSnapshot.aspects) == 0
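The practical effect, as a sketch (ctx is assumed to be an existing PipelineContext):

    # With the flag disabled, get_records() skips the systemMetadata stamping branch.
    extractor = WorkUnitRecordExtractor({"set_system_metadata": False}, ctx)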

metadata-ingestion/src/datahub/ingestion/run/pipeline.py

@@ -110,6 +110,7 @@ class Pipeline:
     config: PipelineConfig
     ctx: PipelineContext
     source: Source
+    extractor: Extractor
     sink: Sink
     transformers: List[Transformer]
@@ -185,6 +186,7 @@
                 self.config.source.dict().get("config", {}), self.ctx
             )
             logger.debug(f"Source type:{source_type},{source_class} configured")
+            logger.info("Source configured successfully.")
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure source ({source_type})"
@@ -192,7 +194,10 @@
             return

         try:
-            self.extractor_class = extractor_registry.get(self.config.source.extractor)
+            extractor_class = extractor_registry.get(self.config.source.extractor)
+            self.extractor = extractor_class(
+                self.config.source.extractor_config, self.ctx
+            )
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure extractor ({self.config.source.extractor})"
@@ -330,20 +335,17 @@
                     self.ctx, self.config.failure_log.log_config
                 )
             )

-        extractor: Extractor = self.extractor_class()
         for wu in itertools.islice(
             self.source.get_workunits(),
             self.preview_workunits if self.preview_mode else None,
         ):
             if self._time_to_print():
                 self.pretty_print_summary(currently_running=True)
-            # TODO: change extractor interface
-            extractor.configure({}, self.ctx)

             if not self.dry_run:
                 self.sink.handle_work_unit_start(wu)
             try:
-                record_envelopes = extractor.get_records(wu)
+                record_envelopes = self.extractor.get_records(wu)
                 for record_envelope in self.transform(record_envelopes):
                     if not self.dry_run:
                         self.sink.write_record_async(record_envelope, callback)
@@ -355,7 +357,7 @@
             except Exception as e:
                 logger.error("Failed to process some records. Continuing.", e)

-            extractor.close()
+            self.extractor.close()
             if not self.dry_run:
                 self.sink.handle_work_unit_end(wu)
         self.source.close()
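Design note: the extractor is now constructed once during pipeline initialization, with its config parsed up front, instead of being instantiated and re-configured with an empty dict inside the work-unit loop. A rough sketch of the resulting lifecycle (simplified from the diff above):

    # At init time: parse extractor_config once.
    self.extractor = extractor_class(self.config.source.extractor_config, self.ctx)

    # In run(): no per-work-unit configure() call remains.
    for wu in self.source.get_workunits():
        record_envelopes = self.extractor.get_records(wu)
        ...
        self.extractor.close()  # still invoked per work unit, as before
    self.source.close()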

metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py

@@ -13,9 +13,13 @@ from datahub.ingestion.sink.file import FileSinkConfig
 logger = logging.getLogger(__name__)

+# Sentinel value used to check if the run ID is the default value.
+DEFAULT_RUN_ID = "__DEFAULT_RUN_ID"
+

 class SourceConfig(DynamicTypedConfig):
     extractor: str = "generic"
+    extractor_config: dict = Field(default_factory=dict)


 class ReporterConfig(DynamicTypedConfig):
@@ -42,7 +46,7 @@ class PipelineConfig(ConfigModel):
     sink: DynamicTypedConfig
     transformers: Optional[List[DynamicTypedConfig]]
     reporting: List[ReporterConfig] = []
-    run_id: str = "__DEFAULT_RUN_ID"
+    run_id: str = DEFAULT_RUN_ID
     datahub_api: Optional[DatahubClientConfig] = None
     pipeline_name: Optional[str] = None
     failure_log: FailureLoggingConfig = FailureLoggingConfig()
@@ -55,7 +59,7 @@
     def run_id_should_be_semantic(
         cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any
     ) -> str:
-        if v == "__DEFAULT_RUN_ID":
+        if v == DEFAULT_RUN_ID:
             if "source" in values and hasattr(values["source"], "type"):
                 source_type = values["source"].type
                 current_time = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
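A sketch of how the new key might be used in a pipeline config (the "file" source and sink parameters are illustrative assumptions; "extractor" and "extractor_config" are the fields defined above):

    from datahub.ingestion.run.pipeline import Pipeline

    pipeline = Pipeline.create(
        {
            "source": {
                "type": "file",
                "config": {"filename": "input_mces.json"},
                "extractor": "generic",
                "extractor_config": {"set_system_metadata": False},
            },
            "sink": {
                "type": "file",
                "config": {"filename": "output_mces.json"},
            },
        }
    )
    pipeline.run()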

metadata-ingestion/src/datahub/utilities/type_annotations.py (new file)

@@ -0,0 +1,26 @@
+import inspect
+from typing import Optional, Type, TypeVar
+
+import typing_inspect
+
+TargetClass = TypeVar("TargetClass")
+
+
+def get_class_from_annotation(
+    derived_cls: Type, super_class: Type, target_class: Type[TargetClass]
+) -> Optional[Type[TargetClass]]:
+    """
+    Attempts to find an instance of target_class in the type annotations of derived_class.
+    We assume that super_class inherits from typing.Generic and that derived_class inherits from super_class.
+    """
+
+    # Modified from https://stackoverflow.com/q/69085037/5004662.
+    for base in inspect.getmro(derived_cls):
+        for generic_base in getattr(base, "__orig_bases__", []):
+            generic_origin = typing_inspect.get_origin(generic_base)
+            if generic_origin and issubclass(generic_origin, super_class):
+                for arg in typing_inspect.get_args(generic_base):
+                    if issubclass(arg, target_class):
+                        return arg
+
+    return None
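A self-contained sketch of the helper's behavior (all names below are hypothetical):

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class BaseConfig: ...
    class MyConfig(BaseConfig): ...

    class Base(Generic[T]): ...
    class Derived(Base[MyConfig]): ...

    # Walks Derived's MRO, finds the Base[MyConfig] annotation, and returns
    # MyConfig because it is a subclass of the requested target class.
    assert get_class_from_annotation(Derived, Base, BaseConfig) is MyConfig

    # Returns None when no type argument matches the target class.
    assert get_class_from_annotation(Derived, Base, dict) is None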