mirror of
https://github.com/datahub-project/datahub.git
synced 2025-09-03 06:13:14 +00:00
feat(ingest): add config to extractor interface (#5761)
This commit is contained in:
parent
80edf7c8ed
commit
24e4ee1746
@ -11,15 +11,19 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from datahub.configuration.common import ConfigModel
|
||||
from datahub.ingestion.api.closeable import Closeable
|
||||
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
|
||||
from datahub.ingestion.api.report import Report
|
||||
from datahub.utilities.type_annotations import get_class_from_annotation
|
||||
|
||||
|
||||
class SourceCapability(Enum):
|
||||
@ -136,12 +140,26 @@ class TestConnectionReport(Report):
|
||||
|
||||
|
||||
WorkUnitType = TypeVar("WorkUnitType", bound=WorkUnit)
|
||||
ExtractorConfig = TypeVar("ExtractorConfig", bound=ConfigModel)
|
||||
|
||||
|
||||
class Extractor(Generic[WorkUnitType], Closeable, metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
class Extractor(Generic[WorkUnitType, ExtractorConfig], Closeable, metaclass=ABCMeta):
|
||||
ctx: PipelineContext
|
||||
config: ExtractorConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[ExtractorConfig]:
|
||||
config_class = get_class_from_annotation(cls, Extractor, ConfigModel)
|
||||
assert config_class, "Extractor subclasses must define a config class"
|
||||
return cast(Type[ExtractorConfig], config_class)
|
||||
|
||||
def __init__(self, config_dict: dict, ctx: PipelineContext) -> None:
|
||||
super().__init__()
|
||||
|
||||
config_class = self.get_config_class()
|
||||
|
||||
self.ctx = ctx
|
||||
self.config = config_class.parse_obj(config_dict)
|
||||
|
||||
@abstractmethod
|
||||
def get_records(self, workunit: WorkUnitType) -> Iterable[RecordEnvelope]:
|
||||
|
@ -1,9 +1,9 @@
|
||||
from typing import Iterable, Union
|
||||
|
||||
from datahub.configuration.common import ConfigModel
|
||||
from datahub.emitter.mce_builder import get_sys_time
|
||||
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
||||
from datahub.ingestion.api import RecordEnvelope
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
from datahub.ingestion.api.source import Extractor, WorkUnit
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit, UsageStatsWorkUnit
|
||||
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
|
||||
@ -19,14 +19,15 @@ except ImportError:
|
||||
black = None # type: ignore
|
||||
|
||||
|
||||
class WorkUnitRecordExtractor(Extractor):
|
||||
class WorkUnitRecordExtractorConfig(ConfigModel):
|
||||
set_system_metadata = True
|
||||
|
||||
|
||||
class WorkUnitRecordExtractor(
|
||||
Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig]
|
||||
):
|
||||
"""An extractor that simply returns the data inside workunits back as records."""
|
||||
|
||||
ctx: PipelineContext
|
||||
|
||||
def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
def get_records(
|
||||
self, workunit: WorkUnit
|
||||
) -> Iterable[
|
||||
@ -48,9 +49,10 @@ class WorkUnitRecordExtractor(Extractor):
|
||||
MetadataChangeProposalWrapper,
|
||||
),
|
||||
):
|
||||
workunit.metadata.systemMetadata = SystemMetadata(
|
||||
lastObserved=get_sys_time(), runId=self.ctx.run_id
|
||||
)
|
||||
if self.config.set_system_metadata:
|
||||
workunit.metadata.systemMetadata = SystemMetadata(
|
||||
lastObserved=get_sys_time(), runId=self.ctx.run_id
|
||||
)
|
||||
if (
|
||||
isinstance(workunit.metadata, MetadataChangeEvent)
|
||||
and len(workunit.metadata.proposedSnapshot.aspects) == 0
|
||||
|
@ -110,6 +110,7 @@ class Pipeline:
|
||||
config: PipelineConfig
|
||||
ctx: PipelineContext
|
||||
source: Source
|
||||
extractor: Extractor
|
||||
sink: Sink
|
||||
transformers: List[Transformer]
|
||||
|
||||
@ -185,6 +186,7 @@ class Pipeline:
|
||||
self.config.source.dict().get("config", {}), self.ctx
|
||||
)
|
||||
logger.debug(f"Source type:{source_type},{source_class} configured")
|
||||
logger.info("Source configured successfully.")
|
||||
except Exception as e:
|
||||
self._record_initialization_failure(
|
||||
e, f"Failed to configure source ({source_type})"
|
||||
@ -192,7 +194,10 @@ class Pipeline:
|
||||
return
|
||||
|
||||
try:
|
||||
self.extractor_class = extractor_registry.get(self.config.source.extractor)
|
||||
extractor_class = extractor_registry.get(self.config.source.extractor)
|
||||
self.extractor = extractor_class(
|
||||
self.config.source.extractor_config, self.ctx
|
||||
)
|
||||
except Exception as e:
|
||||
self._record_initialization_failure(
|
||||
e, f"Failed to configure extractor ({self.config.source.extractor})"
|
||||
@ -330,20 +335,17 @@ class Pipeline:
|
||||
self.ctx, self.config.failure_log.log_config
|
||||
)
|
||||
)
|
||||
extractor: Extractor = self.extractor_class()
|
||||
for wu in itertools.islice(
|
||||
self.source.get_workunits(),
|
||||
self.preview_workunits if self.preview_mode else None,
|
||||
):
|
||||
if self._time_to_print():
|
||||
self.pretty_print_summary(currently_running=True)
|
||||
# TODO: change extractor interface
|
||||
extractor.configure({}, self.ctx)
|
||||
|
||||
if not self.dry_run:
|
||||
self.sink.handle_work_unit_start(wu)
|
||||
try:
|
||||
record_envelopes = extractor.get_records(wu)
|
||||
record_envelopes = self.extractor.get_records(wu)
|
||||
for record_envelope in self.transform(record_envelopes):
|
||||
if not self.dry_run:
|
||||
self.sink.write_record_async(record_envelope, callback)
|
||||
@ -355,7 +357,7 @@ class Pipeline:
|
||||
except Exception as e:
|
||||
logger.error("Failed to process some records. Continuing.", e)
|
||||
|
||||
extractor.close()
|
||||
self.extractor.close()
|
||||
if not self.dry_run:
|
||||
self.sink.handle_work_unit_end(wu)
|
||||
self.source.close()
|
||||
|
@ -13,9 +13,13 @@ from datahub.ingestion.sink.file import FileSinkConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel value used to check if the run ID is the default value.
|
||||
DEFAULT_RUN_ID = "__DEFAULT_RUN_ID"
|
||||
|
||||
|
||||
class SourceConfig(DynamicTypedConfig):
|
||||
extractor: str = "generic"
|
||||
extractor_config: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ReporterConfig(DynamicTypedConfig):
|
||||
@ -42,7 +46,7 @@ class PipelineConfig(ConfigModel):
|
||||
sink: DynamicTypedConfig
|
||||
transformers: Optional[List[DynamicTypedConfig]]
|
||||
reporting: List[ReporterConfig] = []
|
||||
run_id: str = "__DEFAULT_RUN_ID"
|
||||
run_id: str = DEFAULT_RUN_ID
|
||||
datahub_api: Optional[DatahubClientConfig] = None
|
||||
pipeline_name: Optional[str] = None
|
||||
failure_log: FailureLoggingConfig = FailureLoggingConfig()
|
||||
@ -55,7 +59,7 @@ class PipelineConfig(ConfigModel):
|
||||
def run_id_should_be_semantic(
|
||||
cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any
|
||||
) -> str:
|
||||
if v == "__DEFAULT_RUN_ID":
|
||||
if v == DEFAULT_RUN_ID:
|
||||
if "source" in values and hasattr(values["source"], "type"):
|
||||
source_type = values["source"].type
|
||||
current_time = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
|
||||
|
26
metadata-ingestion/src/datahub/utilities/type_annotations.py
Normal file
26
metadata-ingestion/src/datahub/utilities/type_annotations.py
Normal file
@ -0,0 +1,26 @@
|
||||
import inspect
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import typing_inspect
|
||||
|
||||
TargetClass = TypeVar("TargetClass")
|
||||
|
||||
|
||||
def get_class_from_annotation(
|
||||
derived_cls: Type, super_class: Type, target_class: Type[TargetClass]
|
||||
) -> Optional[Type[TargetClass]]:
|
||||
"""
|
||||
Attempts to find an instance of target_class in the type annotations of derived_class.
|
||||
We assume that super_class inherits from typing.Generic and that derived_class inherits from super_class.
|
||||
"""
|
||||
|
||||
# Modified from https://stackoverflow.com/q/69085037/5004662.
|
||||
for base in inspect.getmro(derived_cls):
|
||||
for generic_base in getattr(base, "__orig_bases__", []):
|
||||
generic_origin = typing_inspect.get_origin(generic_base)
|
||||
if generic_origin and issubclass(generic_origin, super_class):
|
||||
for arg in typing_inspect.get_args(generic_base):
|
||||
if issubclass(arg, target_class):
|
||||
return arg
|
||||
|
||||
return None
|
Loading…
x
Reference in New Issue
Block a user