diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py
index eafd0b5e33..4fd6f1f017 100644
--- a/metadata-ingestion/src/datahub/ingestion/api/source.py
+++ b/metadata-ingestion/src/datahub/ingestion/api/source.py
@@ -11,15 +11,19 @@ from typing import (
     Iterator,
     List,
     Optional,
+    Type,
     TypeVar,
     Union,
+    cast,
 )
 
 from pydantic import BaseModel
 
+from datahub.configuration.common import ConfigModel
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
 from datahub.ingestion.api.report import Report
+from datahub.utilities.type_annotations import get_class_from_annotation
 
 
 class SourceCapability(Enum):
@@ -136,12 +140,26 @@ class TestConnectionReport(Report):
 
 
 WorkUnitType = TypeVar("WorkUnitType", bound=WorkUnit)
+ExtractorConfig = TypeVar("ExtractorConfig", bound=ConfigModel)
 
 
-class Extractor(Generic[WorkUnitType], Closeable, metaclass=ABCMeta):
-    @abstractmethod
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        pass
+class Extractor(Generic[WorkUnitType, ExtractorConfig], Closeable, metaclass=ABCMeta):
+    ctx: PipelineContext
+    config: ExtractorConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[ExtractorConfig]:
+        config_class = get_class_from_annotation(cls, Extractor, ConfigModel)
+        assert config_class, "Extractor subclasses must define a config class"
+        return cast(Type[ExtractorConfig], config_class)
+
+    def __init__(self, config_dict: dict, ctx: PipelineContext) -> None:
+        super().__init__()
+
+        config_class = self.get_config_class()
+
+        self.ctx = ctx
+        self.config = config_class.parse_obj(config_dict)
 
     @abstractmethod
     def get_records(self, workunit: WorkUnitType) -> Iterable[RecordEnvelope]:
diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py b/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py
index e65da022d5..9a9e15fcac 100644
--- a/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py
+++ b/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py
@@ -1,9 +1,9 @@
 from typing import Iterable, Union
 
+from datahub.configuration.common import ConfigModel
 from datahub.emitter.mce_builder import get_sys_time
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api import RecordEnvelope
-from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.source import Extractor, WorkUnit
 from datahub.ingestion.api.workunit import MetadataWorkUnit, UsageStatsWorkUnit
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
@@ -19,14 +19,15 @@ except ImportError:
     black = None  # type: ignore
 
 
-class WorkUnitRecordExtractor(Extractor):
+class WorkUnitRecordExtractorConfig(ConfigModel):
+    set_system_metadata = True
+
+
+class WorkUnitRecordExtractor(
+    Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig]
+):
     """An extractor that simply returns the data inside workunits back as records."""
 
-    ctx: PipelineContext
-
-    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
-        self.ctx = ctx
-
     def get_records(
         self, workunit: WorkUnit
     ) -> Iterable[
@@ -48,9 +49,10 @@ class WorkUnitRecordExtractor(Extractor):
                 MetadataChangeProposalWrapper,
             ),
         ):
-            workunit.metadata.systemMetadata = SystemMetadata(
-                lastObserved=get_sys_time(), runId=self.ctx.run_id
-            )
+            if self.config.set_system_metadata:
+                workunit.metadata.systemMetadata = SystemMetadata(
+                    lastObserved=get_sys_time(), runId=self.ctx.run_id
+                )
             if (
                 isinstance(workunit.metadata, MetadataChangeEvent)
                 and len(workunit.metadata.proposedSnapshot.aspects) == 0
diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
index bc058086d8..9868e50d5d 100644
--- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
+++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
@@ -110,6 +110,7 @@ class Pipeline:
     config: PipelineConfig
     ctx: PipelineContext
    source: Source
+    extractor: Extractor
     sink: Sink
     transformers: List[Transformer]
 
@@ -185,6 +186,7 @@
                 self.config.source.dict().get("config", {}), self.ctx
             )
             logger.debug(f"Source type:{source_type},{source_class} configured")
+            logger.info("Source configured successfully.")
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure source ({source_type})"
@@ -192,7 +194,10 @@
             )
             return
 
         try:
-            self.extractor_class = extractor_registry.get(self.config.source.extractor)
+            extractor_class = extractor_registry.get(self.config.source.extractor)
+            self.extractor = extractor_class(
+                self.config.source.extractor_config, self.ctx
+            )
         except Exception as e:
             self._record_initialization_failure(
                 e, f"Failed to configure extractor ({self.config.source.extractor})"
             )
@@ -330,20 +335,17 @@
                 self.ctx, self.config.failure_log.log_config
             )
         )
-        extractor: Extractor = self.extractor_class()
 
         for wu in itertools.islice(
             self.source.get_workunits(),
             self.preview_workunits if self.preview_mode else None,
         ):
            if self._time_to_print():
                 self.pretty_print_summary(currently_running=True)
-            # TODO: change extractor interface
-            extractor.configure({}, self.ctx)
            if not self.dry_run:
                 self.sink.handle_work_unit_start(wu)
             try:
-                record_envelopes = extractor.get_records(wu)
+                record_envelopes = self.extractor.get_records(wu)
                 for record_envelope in self.transform(record_envelopes):
                     if not self.dry_run:
                         self.sink.write_record_async(record_envelope, callback)
@@ -355,7 +357,7 @@
             except Exception as e:
                 logger.error("Failed to process some records. Continuing.", e)
 
-            extractor.close()
+            self.extractor.close()
             if not self.dry_run:
                 self.sink.handle_work_unit_end(wu)
         self.source.close()
diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py
index 58e523a104..8e9e474fbb 100644
--- a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py
@@ -13,9 +13,13 @@ from datahub.ingestion.sink.file import FileSinkConfig
 
 logger = logging.getLogger(__name__)
 
+# Sentinel value used to check if the run ID is the default value.
+DEFAULT_RUN_ID = "__DEFAULT_RUN_ID"
+
 
 class SourceConfig(DynamicTypedConfig):
     extractor: str = "generic"
+    extractor_config: dict = Field(default_factory=dict)
 
 
 class ReporterConfig(DynamicTypedConfig):
@@ -42,7 +46,7 @@ class PipelineConfig(ConfigModel):
     sink: DynamicTypedConfig
     transformers: Optional[List[DynamicTypedConfig]]
     reporting: List[ReporterConfig] = []
-    run_id: str = "__DEFAULT_RUN_ID"
+    run_id: str = DEFAULT_RUN_ID
     datahub_api: Optional[DatahubClientConfig] = None
     pipeline_name: Optional[str] = None
     failure_log: FailureLoggingConfig = FailureLoggingConfig()
@@ -55,7 +59,7 @@
     def run_id_should_be_semantic(
         cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any
     ) -> str:
-        if v == "__DEFAULT_RUN_ID":
+        if v == DEFAULT_RUN_ID:
             if "source" in values and hasattr(values["source"], "type"):
                 source_type = values["source"].type
                 current_time = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
diff --git a/metadata-ingestion/src/datahub/utilities/type_annotations.py b/metadata-ingestion/src/datahub/utilities/type_annotations.py
new file mode 100644
index 0000000000..b139a0ed23
--- /dev/null
+++ b/metadata-ingestion/src/datahub/utilities/type_annotations.py
@@ -0,0 +1,26 @@
+import inspect
+from typing import Optional, Type, TypeVar
+
+import typing_inspect
+
+TargetClass = TypeVar("TargetClass")
+
+
+def get_class_from_annotation(
+    derived_cls: Type, super_class: Type, target_class: Type[TargetClass]
+) -> Optional[Type[TargetClass]]:
+    """
+    Attempts to find a subclass of target_class among the generic type arguments of derived_cls.
+    We assume that super_class inherits from typing.Generic and that derived_cls inherits from super_class.
+    """
+
+    # Modified from https://stackoverflow.com/q/69085037/5004662.
+    for base in inspect.getmro(derived_cls):
+        for generic_base in getattr(base, "__orig_bases__", []):
+            generic_origin = typing_inspect.get_origin(generic_base)
+            if generic_origin and issubclass(generic_origin, super_class):
+                for arg in typing_inspect.get_args(generic_base):
+                    if issubclass(arg, target_class):
+                        return arg
+
+    return None
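
For reference, the lookup that get_class_from_annotation performs can be demonstrated standalone. This is a minimal sketch, not part of the patch: it reimplements the same MRO walk under illustrative names (find_type_argument, DemoExtractor, DemoConfig, DemoWorkUnitExtractor) and assumes only typing_inspect and pydantic, both of which the patch already depends on. The isinstance(arg, type) guard, which skips bare TypeVars, is a small hardening that the patch itself does not include.

import inspect
from typing import Generic, Optional, Type, TypeVar

import typing_inspect
from pydantic import BaseModel

T = TypeVar("T")


def find_type_argument(
    derived_cls: Type, super_class: Type, target_class: Type[T]
) -> Optional[Type[T]]:
    # Walk the MRO and inspect each class's original (unsubscripted) bases,
    # looking for a parameterization of super_class whose type argument is a
    # subclass of target_class.
    for base in inspect.getmro(derived_cls):
        for generic_base in getattr(base, "__orig_bases__", []):
            origin = typing_inspect.get_origin(generic_base)
            if origin and issubclass(origin, super_class):
                for arg in typing_inspect.get_args(generic_base):
                    if isinstance(arg, type) and issubclass(arg, target_class):
                        return arg
    return None


ConfigT = TypeVar("ConfigT", bound=BaseModel)


class DemoExtractor(Generic[ConfigT]):
    pass


class DemoConfig(BaseModel):
    set_system_metadata: bool = True


class DemoWorkUnitExtractor(DemoExtractor[DemoConfig]):
    pass


# The config class is recovered purely from the subclass's type annotation.
assert find_type_argument(DemoWorkUnitExtractor, DemoExtractor, BaseModel) is DemoConfig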
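
The reworked Extractor API replaces the old configure() call with constructor-based configuration. A hedged usage sketch follows, assuming a datahub development environment; PipelineContext(run_id=...) mirrors how the ingestion tests construct contexts, but treat that call as an assumption.

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.extractor.mce_extractor import (
    WorkUnitRecordExtractor,
    WorkUnitRecordExtractorConfig,
)

ctx = PipelineContext(run_id="demo-run")  # assumed constructor signature

# get_config_class() resolves WorkUnitRecordExtractorConfig from the
# Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig] annotation,
# and __init__ validates the raw dict via parse_obj().
extractor = WorkUnitRecordExtractor({"set_system_metadata": False}, ctx)
assert extractor.get_config_class() is WorkUnitRecordExtractorConfig
assert extractor.config.set_system_metadata is False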
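
At the recipe level, the new source.extractor_config field is what feeds the extractor's constructor. Below is a sketch of wiring it through the existing Pipeline.create() entry point; the file source, console sink, and filename are illustrative placeholders, not part of the patch.

from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        "source": {
            "type": "file",
            "config": {"filename": "./mces.json"},  # placeholder path
            # "generic" is the default extractor; extractor_config is the
            # new field added to SourceConfig above.
            "extractor": "generic",
            "extractor_config": {"set_system_metadata": False},
        },
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()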