feat(ingest): framework - client side changes for monitoring and reporting (#3807)

Ravindra Lanka 2022-02-02 13:19:15 -08:00 committed by GitHub
parent 78d35f95cf
commit f20382f956
32 changed files with 2100 additions and 559 deletions

View File

@ -312,8 +312,11 @@ entry_points = {
"datahub-kafka = datahub.ingestion.sink.datahub_kafka:DatahubKafkaSink",
"datahub-rest = datahub.ingestion.sink.datahub_rest:DatahubRestSink",
],
"datahub.ingestion.state_provider.plugins": [
"datahub = datahub.ingestion.source.state_provider.datahub_ingestion_state_provider:DatahubIngestionStateProvider",
"datahub.ingestion.checkpointing_provider.plugins": [
"datahub = datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider:DatahubIngestionCheckpointingProvider",
],
"datahub.ingestion.reporting_provider.plugins": [
"datahub = datahub.ingestion.reporting.datahub_ingestion_reporting_provider:DatahubIngestionReportingProvider",
],
"apache_airflow_provider": ["provider_info=datahub_provider:get_provider_info"],
}

Binary file not shown (new image added; 615 KiB).

View File

@ -0,0 +1,95 @@
# Datahub's Reporting Framework for Ingestion Job Telemetry
Datahub's reporting framework allows reporting providers to be configured on ingestion pipelines so that
telemetry about ingestion job runs can be sent to external systems for monitoring purposes. It is powered by
Datahub's stateful ingestion framework. The `datahub` reporting provider comes with the standard client installation
and allows ingestion job telemetry to be reported to the datahub backend as the destination.
**_NOTE_**: This feature requires the server to be `statefulIngestion` capable.
This feature is available in the metadata service with version >= `0.8.20`.
To check if you are running a stateful ingestion capable server:
```console
curl http://<datahub-gms-endpoint>/config
{
models: { },
statefulIngestionCapable: true, # <-- this should be present and true
retention: "true",
noCode: "true"
}
```
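The same capability flag can also be checked programmatically through the Python client that ships with the `datahub` package. The snippet below is a minimal sketch; the GMS endpoint is a placeholder.
```python
# Minimal sketch: read the GMS /config endpoint via the DataHubGraph client and
# check the statefulIngestionCapable flag. The server URL is a placeholder.
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
server_config = graph.get_config()
print(server_config.get("statefulIngestionCapable", False))  # expect True
```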
## Config details
The ingestion reporting providers are specified as a list of reporting provider configurations under the `reporting` config
param of the pipeline, with each reporting provider configuration being a type and config pair object. The telemetry data will
be sent to all the reporting providers in this list.
Note that a `.` is used to denote nested fields, and `[idx]` is used to denote an element of an array of objects in the YAML recipe.
| Field | Required | Default | Description |
|-------------------------| -------- |------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|
| `reporting[idx].type` | ✅ | `datahub` | The type of the ingestion reporting provider registered with datahub. |
| `reporting[idx].config` | | The `datahub_api` config if set at pipeline level. Otherwise, the default `DatahubClientConfig`. See the [defaults](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/graph/client.py#L19) here. | The configuration required for initializing the datahub reporting provider. |
| `pipeline_name` | ✅ | | The name of the ingestion pipeline. This is used as a part of the identifying key for the telemetry data reported by each job in the ingestion pipeline. |
#### Supported sources
* All SQL-based sources.
* snowflake_usage.
#### Sample configuration
```yaml
source:
type: "snowflake"
config:
username: <user_name>
password: <password>
role: <role>
host_port: <host_port>
warehouse: <ware_house>
# Rest of the source specific params ...
# This is mandatory. Changing it will cause old telemetry correlation to be lost.
pipeline_name: "my_snowflake_pipeline_1"
# Pipeline-level datahub_api configuration.
datahub_api: # Optional. But if provided, this config will be used by the "datahub" ingestion state provider.
server: "http://localhost:8080"
sink:
type: "datahub-rest"
config:
server: 'http://localhost:8080'
reporting:
- type: "datahub" # Required
config: # Optional.
datahub_api: # default value
server: "http://localhost:8080"
```
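The recipe above can be run with the `datahub` CLI as usual, or programmatically through the `Pipeline` API. The sketch below is illustrative only; all credentials and endpoints are placeholders.
```python
# Illustrative sketch: run a recipe that has a "datahub" reporting provider attached.
# Every credential/URL below is a placeholder.
from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        # Mandatory for reporting/stateful ingestion; used to correlate telemetry across runs.
        "pipeline_name": "my_snowflake_pipeline_1",
        "datahub_api": {"server": "http://localhost:8080"},
        "source": {
            "type": "snowflake",
            "config": {
                "username": "<user_name>",
                "password": "<password>",
                "host_port": "<host_port>",
            },
        },
        "sink": {"type": "datahub-rest", "config": {"server": "http://localhost:8080"}},
        "reporting": [{"type": "datahub"}],
    }
)
pipeline.run()
pipeline.raise_from_status()
```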
## Reporting Ingestion State Provider (Developer Guide)
An ingestion reporting state provider is responsible for saving and retrieving the ingestion telemetry
associated with the ingestion runs of various jobs inside the source connector of the ingestion pipeline.
The data model used for capturing the telemetry is [DatahubIngestionRunSummary](https://github.com/linkedin/datahub/blob/master/metadata-models/src/main/pegasus/com/linkedin/datajob/datahub/DatahubIngestionRunSummary.pdl).
A reporting ingestion state provider needs to implement the [IngestionReportingProviderBase](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_reporting_provider_base.py)
interface and register itself with datahub by adding an entry under the `datahub.ingestion.reporting_provider.plugins`
key of the entry_points section in [setup.py](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/setup.py)
with its type and implementation class, as shown below.
```python
entry_points = {
    # <snip other keys>
    "datahub.ingestion.reporting_provider.plugins": [
        "datahub = datahub.ingestion.reporting.datahub_ingestion_reporting_provider:DatahubIngestionReportingProvider",
    ],
}
```
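For reference, the sketch below outlines what a custom reporting provider might look like. It is illustrative only: the class name `LoggingReportingProvider` and its behavior are hypothetical, and a real provider would persist the run summaries to an external system instead of logging them.
```python
# Hypothetical sketch of a custom reporting provider that only logs run summaries.
import logging
from typing import Any, Dict, List, Optional

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_reporting_provider_base import (
    IngestionReportingProviderBase,
    JobStateFilterType,
    JobStateKey,
    ReportingJobStatesMap,
)

logger = logging.getLogger(__name__)


class LoggingReportingProvider(IngestionReportingProviderBase):
    @classmethod
    def create(
        cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
    ) -> "LoggingReportingProvider":
        return cls(name)

    def get_previous_states(
        self,
        state_key: JobStateKey,
        last_only: bool = True,
        filter_opt: Optional[JobStateFilterType] = None,
    ) -> List[ReportingJobStatesMap]:
        # This toy provider does not persist anything, so there is no history to return.
        return []

    def commit(self) -> None:
        # The framework populates state_to_commit with per-job run summaries.
        for job_name, run_summary in self.state_to_commit.items():
            logger.info("Run summary for job %s: %s", job_name, run_summary)
        self.committed = True
```
Such a provider would then be registered under its own package's `datahub.ingestion.reporting_provider.plugins` entry point, exactly as in the `setup.py` snippet above.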
### Datahub Reporting Ingestion State Provider
This is the reporting state provider implementation that is available out of the box in datahub. Its type is `datahub` and it is implemented on top
of the `datahub_api` client and the timeseries aspect capabilities of the datahub-backend.
#### Config details
Note that a `.` is used to denote nested fields in the YAML recipe.
| Field | Required | Default | Description |
|----------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------|
| `type` | ✅ | `datahub` | The type of the ingestion reporting provider registered with datahub. |
| `config` | | The `datahub_api` config if set at pipeline level. Otherwise, the default `DatahubClientConfig`. See the [defaults](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/graph/client.py#L19) here. | The configuration required for initializing the datahub reporting provider. |

View File

@ -35,8 +35,11 @@ NOTE: If either `dry-run` or `preview` mode are set, stateful ingestion will be
## Use-cases powered by stateful ingestion.
Following is the list of current use-cases powered by stateful ingestion in datahub.
### Removal of stale tables and views.
Stateful ingestion can be used to automatically soft delete the tables and views that are seen in a previous run
Stateful ingestion can be used to automatically soft-delete the tables and views that are seen in a previous run
but absent in the current run (they are either deleted or no longer desired).
![Stale Metadata Deletion](./images/stale_metadata_deletion.png)
#### Supported sources
* All SQL-based sources.
#### Additional config details
@ -124,22 +127,22 @@ sink:
server: 'http://localhost:8080'
```
## The Ingestion State Provider
The ingestion state provider is responsible for saving and retrieving the ingestion state associated with the ingestion runs
of various jobs inside the source connector of the ingestion pipeline. An ingestion state provider needs to implement the
[IngestionStateProvider](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/api/ingestion_state_provider.py) interface and
register itself with datahub by adding an entry under `datahub.ingestion.state_provider.plugins` key of the entry_points section in [setup.py](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/setup.py) with its type and implementation class as shown below.
## The Checkpointing Ingestion State Provider (Developer Guide)
The ingestion checkpointing state provider is responsible for saving and retrieving the ingestion checkpoint state associated with the ingestion runs
of various jobs inside the source connector of the ingestion pipeline. The checkpointing data model is [DatahubIngestionCheckpoint](https://github.com/linkedin/datahub/blob/master/metadata-models/src/main/pegasus/com/linkedin/datajob/datahub/DatahubIngestionCheckpoint.pdl) and it supports any custom state to be stored using the [IngestionCheckpointState](https://github.com/linkedin/datahub/blob/master/metadata-models/src/main/pegasus/com/linkedin/datajob/datahub/IngestionCheckpointState.pdl#L9). A checkpointing ingestion state provider needs to implement the
[IngestionCheckpointingProviderBase](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_checkpointing_provider_base.py) interface and
register itself with datahub by adding an entry under `datahub.ingestion.checkpointing_provider.plugins` key of the entry_points section in [setup.py](https://github.com/linkedin/datahub/blob/master/metadata-ingestion/setup.py) with its type and implementation class as shown below.
```python
entry_points = {
# <snip other keys>"
"datahub.ingestion.state_provider.plugins": [
"datahub = datahub.ingestion.source.state_provider.datahub_ingestion_state_provider:DatahubIngestionStateProvider",
]
"datahub.ingestion.checkpointing_provider.plugins": [
"datahub = datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider:DatahubIngestionCheckpointingProvider",
],
}
```
### Datahub Ingestion State Provider
This is the state provider implementation that is avialble out of the box. It's type is `datahub` and it is implemented on top
### Datahub Checkpointing Ingestion State Provider
This is the state provider implementation that is available out of the box. Its type is `datahub` and it is implemented on top
of the `datahub_api` client and the timeseries aspect capabilities of the datahub-backend.
#### Config details

View File

@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Generic, List, Optional, TypeVar
class CommitPolicy(Enum):
ALWAYS = auto()
ON_NO_ERRORS = auto()
ON_NO_ERRORS_AND_NO_WARNINGS = auto()
@dataclass
class _CommittableConcrete:
name: str
commit_policy: CommitPolicy
committed: bool
# The concrete portion Committable is separated from the abstract portion due to
# https://github.com/python/mypy/issues/5374#issuecomment-568335302.
class Committable(_CommittableConcrete, ABC):
def __init__(self, name: str, commit_policy: CommitPolicy):
super(Committable, self).__init__(name, commit_policy, committed=False)
@abstractmethod
def commit(self) -> None:
pass
StateKeyType = TypeVar("StateKeyType")
StateType = TypeVar("StateType")
# TODO: Add a better alternative to a string for the filter.
FilterType = TypeVar("FilterType")
@dataclass
class _StatefulCommittableConcrete(Generic[StateType]):
state_to_commit: StateType
class StatefulCommittable(
Committable,
_StatefulCommittableConcrete[StateType],
Generic[StateKeyType, StateType, FilterType],
):
def __init__(
self, name: str, commit_policy: CommitPolicy, state_to_commit: StateType
):
# _CommittableConcrete will be the first in the __mro__ from this class.
super(StatefulCommittable, self).__init__(
name=name, commit_policy=commit_policy
)
# _StatefulCommittableConcrete will be after _CommittableConcrete in the __mro__.
super(_CommittableConcrete, self).__init__(state_to_commit=state_to_commit)
def has_successfully_committed(self) -> bool:
return not self.state_to_commit or self.committed
@abstractmethod
def get_previous_states(
self,
state_key: StateKeyType,
last_only: bool = True,
filter_opt: Optional[FilterType] = None,
) -> List[StateType]:
pass
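A toy illustration of the contract above (the class name `ToyCommittable` and its behavior are hypothetical; it assumes the `Committable` and `CommitPolicy` definitions in this file):
```python
# Hypothetical subclass, building on Committable/CommitPolicy defined above.
import logging

logger = logging.getLogger(__name__)


class ToyCommittable(Committable):
    """A committable whose commit() simply logs; a real one would persist state."""

    def __init__(self, name: str):
        super().__init__(name=name, commit_policy=CommitPolicy.ON_NO_ERRORS)

    def commit(self) -> None:
        logger.info("Committing %s", self.name)
        self.committed = True
```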

View File

@ -1,7 +1,8 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar
from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar
from datahub.ingestion.api.committable import Committable
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
T = TypeVar("T")
@ -41,3 +42,29 @@ class PipelineContext:
self.pipeline_name = pipeline_name
self.dry_run_mode = dry_run
self.preview_mode = preview_mode
self.reporters: Dict[str, Committable] = dict()
self.checkpointers: Dict[str, Committable] = dict()
def register_checkpointer(self, committable: Committable) -> None:
if committable.name in self.checkpointers:
raise IndexError(
f"Checkpointing provider {committable.name} already registered."
)
self.checkpointers[committable.name] = committable
def register_reporter(self, committable: Committable) -> None:
if committable.name in self.reporters:
raise IndexError(
f"Reporting provider {committable.name} already registered."
)
self.reporters[committable.name] = committable
def get_reporters(self) -> Iterable[Committable]:
for committable in self.reporters.values():
yield committable
def get_committables(self) -> Iterable[Tuple[str, Committable]]:
for reporting_item_commitable in self.reporters.items():
yield reporting_item_commitable
for checkpointing_item_commitable in self.checkpointers.items():
yield checkpointing_item_commitable

View File

@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from datahub.ingestion.api.committable import CommitPolicy
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_state_provider import (
IngestionJobStateProvider,
IngestionJobStateProviderConfig,
JobId,
JobStateFilterType,
JobStateKey,
JobStatesMap,
)
from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
#
# Common type exports
#
JobId = JobId
JobStateKey = JobStateKey
JobStateFilterType = JobStateFilterType
#
# Checkpoint state specific types
#
CheckpointJobStateType = DatahubIngestionCheckpointClass
CheckpointJobStatesMap = JobStatesMap[CheckpointJobStateType]
class IngestionCheckpointingProviderConfig(IngestionJobStateProviderConfig):
pass
@dataclass()
class IngestionCheckpointingProviderBase(
IngestionJobStateProvider[CheckpointJobStateType]
):
"""
The base class (non-abstract) for all checkpointing state provider implementations.
This class is implemented this way as a concrete class is needed to work with the registry,
but we don't want to implement any of the functionality yet.
"""
def __init__(
self, name: str, commit_policy: CommitPolicy = CommitPolicy.ON_NO_ERRORS
):
super(IngestionCheckpointingProviderBase, self).__init__(name, commit_policy)
@classmethod
def create(
cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> "IngestionJobStateProvider":
raise NotImplementedError("Sub-classes must override this method.")
def get_previous_states(
self,
state_key: JobStateKey,
last_only: bool = True,
filter_opt: Optional[JobStateFilterType] = None,
) -> List[CheckpointJobStatesMap]:
raise NotImplementedError("Sub-classes must override this method.")
def commit(self) -> None:
raise NotImplementedError("Sub-classes must override this method.")

View File

@ -0,0 +1,60 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from datahub.ingestion.api.committable import CommitPolicy
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_state_provider import (
IngestionJobStateProvider,
IngestionJobStateProviderConfig,
JobId,
JobStateFilterType,
JobStateKey,
JobStatesMap,
)
from datahub.metadata.schema_classes import DatahubIngestionRunSummaryClass
#
# Common type exports
#
JobId = JobId
JobStateKey = JobStateKey
JobStateFilterType = JobStateFilterType
#
# Reporting state specific types
#
ReportingJobStateType = DatahubIngestionRunSummaryClass
ReportingJobStatesMap = JobStatesMap[ReportingJobStateType]
class IngestionReportingProviderConfig(IngestionJobStateProviderConfig):
pass
@dataclass()
class IngestionReportingProviderBase(IngestionJobStateProvider[ReportingJobStateType]):
"""
The base class (non-abstract) for all reporting state provider implementations.
This class is implemented this way as a concrete class is needed to work with the registry,
but we don't want to implement any of the functionality yet.
"""
def __init__(self, name: str, commit_policy: CommitPolicy = CommitPolicy.ALWAYS):
super(IngestionReportingProviderBase, self).__init__(name, commit_policy)
@classmethod
def create(
cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> "IngestionJobStateProvider":
raise NotImplementedError("Sub-classes must override this method.")
def get_previous_states(
self,
state_key: JobStateKey,
last_only: bool = True,
filter_opt: Optional[JobStateFilterType] = None,
) -> List[ReportingJobStatesMap]:
raise NotImplementedError("Sub-classes must override this method.")
def commit(self) -> None:
raise NotImplementedError("Sub-classes must override this method.")

View File

@ -0,0 +1,68 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, NewType, Optional, TypeVar
import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.committable import CommitPolicy, StatefulCommittable
from datahub.ingestion.api.common import PipelineContext
JobId = NewType("JobId", str)
JobState = TypeVar("JobState")
JobStatesMap = Dict[JobId, JobState]
# TODO: We need a first-class representation of a search filter in the python code. Using str for now.
JobStateFilterType = NewType("JobStateFilterType", str)
@dataclass
class JobStateKey:
pipeline_name: str
platform_instance_id: str
job_names: List[JobId]
class IngestionJobStateProviderConfig(ConfigModel):
pass
class IngestionJobStateProvider(
StatefulCommittable[JobStateKey, JobStatesMap, JobStateFilterType],
Generic[JobState],
):
"""
Abstract base class for all ingestion state providers.
This introduces the notion of ingestion pipelines and jobs for committable state providers.
"""
def __init__(self, name: str, commit_policy: CommitPolicy):
super(IngestionJobStateProvider, self).__init__(name, commit_policy, dict())
@classmethod
@abstractmethod
def create(
cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> "IngestionJobStateProvider":
"""Concrete sub-classes must throw an exception if this fails."""
pass
def get_last_state(self, state_key: JobStateKey) -> Optional[JobStatesMap]:
previous_states = self.get_previous_states(
state_key=state_key, last_only=True, filter_opt=None
)
if previous_states:
return previous_states[0]
return None
@staticmethod
def get_data_job_urn(
orchestrator: str,
pipeline_name: str,
job_name: JobId,
platform_instance_id: str,
) -> str:
"""
Standardizes datajob urn minting for all ingestion job state providers.
"""
return builder.make_data_job_urn(
orchestrator, f"{pipeline_name}_{platform_instance_id}", job_name
)
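A brief, hypothetical illustration of the urn convention established by `get_data_job_urn` (it assumes `builder.make_data_job_urn` uses its default `prod` cluster; all names are made up):
```python
# Hypothetical values, shown only to illustrate the urn convention above.
urn = IngestionJobStateProvider.get_data_job_urn(
    orchestrator="datahub",
    pipeline_name="my_pipeline_1",
    job_name=JobId("my_job_1"),
    platform_instance_id="instance_1",
)
# Expected shape (with the default cluster):
# urn:li:dataJob:(urn:li:dataFlow:(datahub,my_pipeline_1_instance_1,prod),my_job_1)
print(urn)
```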

View File

@ -0,0 +1,172 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from datahub.configuration.common import ConfigurationError
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_reporting_provider_base import (
IngestionReportingProviderBase,
IngestionReportingProviderConfig,
JobId,
JobStateFilterType,
JobStateKey,
ReportingJobStatesMap,
)
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.schema_classes import (
ChangeTypeClass,
DatahubIngestionRunSummaryClass,
)
logger = logging.getLogger(__name__)
class DatahubIngestionReportingProviderConfig(IngestionReportingProviderConfig):
datahub_api: Optional[DatahubClientConfig] = DatahubClientConfig()
class DatahubIngestionReportingProvider(IngestionReportingProviderBase):
orchestrator_name: str = "datahub"
def __init__(self, graph: DataHubGraph, name: str):
super().__init__(name)
self.graph = graph
if not self._is_server_stateful_ingestion_capable():
raise ConfigurationError(
"Datahub server is not capable of supporting stateful ingestion."
" Please consider upgrading to the latest server version to use this feature."
)
@classmethod
def create(
cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> IngestionReportingProviderBase:
if ctx.graph:
return cls(ctx.graph, name)
elif config_dict is None:
raise ConfigurationError("Missing provider configuration.")
else:
provider_config = DatahubIngestionReportingProviderConfig.parse_obj(
config_dict
)
if provider_config.datahub_api:
graph = DataHubGraph(provider_config.datahub_api)
ctx.graph = graph
return cls(graph, name)
else:
raise ConfigurationError(
"Missing datahub_api. Provide either a global one or under the state_provider."
)
def _is_server_stateful_ingestion_capable(self) -> bool:
server_config = self.graph.get_config() if self.graph else None
if server_config and server_config.get("statefulIngestionCapable"):
return True
return False
def get_latest_run_summary(
self,
pipeline_name: str,
platform_instance_id: str,
job_name: JobId,
) -> Optional[DatahubIngestionRunSummaryClass]:
logger.info(
f"Querying for the latest ingestion run summary for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}'"
)
data_job_urn = self.get_data_job_urn(
self.orchestrator_name, pipeline_name, job_name, platform_instance_id
)
latest_run_summary: Optional[
DatahubIngestionRunSummaryClass
] = self.graph.get_latest_timeseries_value(
entity_urn=data_job_urn,
aspect_name="datahubIngestionRunSummary",
filter_criteria_map={
"pipelineName": pipeline_name,
"platformInstanceId": platform_instance_id,
},
aspect_type=DatahubIngestionRunSummaryClass,
)
if latest_run_summary:
logger.info(
f"The latest saved run summary for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found with start_time:"
f" {datetime.fromtimestamp(latest_run_summary.timestampMillis/1000, tz=timezone.utc)} and a"
f" bucket duration of {latest_run_summary.eventGranularity}."
)
return latest_run_summary
else:
logger.info(
f"No committed ingestion run summary for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found"
)
return None
def get_previous_states(
self,
state_key: JobStateKey,
last_only: bool = True,
filter_opt: Optional[JobStateFilterType] = None,
) -> List[ReportingJobStatesMap]:
if not last_only:
raise NotImplementedError(
"Currently supports retrieving only the last commited state."
)
if filter_opt is not None:
raise NotImplementedError(
"Support for optional filters is not implemented yet."
)
job_run_summaries: List[ReportingJobStatesMap] = []
last_job_run_summary_map: ReportingJobStatesMap = {}
for job_name in state_key.job_names:
last_job_run_summary = self.get_latest_run_summary(
state_key.pipeline_name, state_key.platform_instance_id, job_name
)
if last_job_run_summary is not None:
last_job_run_summary_map[job_name] = last_job_run_summary
job_run_summaries.append(last_job_run_summary_map)
return job_run_summaries
def commit(self) -> None:
if not self.state_to_commit:
# Useful to track source types for which the reporting provider needs to be enabled.
logger.info(f"No state to commit for {self.name}")
return None
for job_name, run_summary in self.state_to_commit.items():
# Emit the ingestion state for each job
logger.info(
f"Committing ingestion run summary for pipeline:'{run_summary.pipelineName}',"
f"instance:'{run_summary.platformInstanceId}', job:'{job_name}'"
)
self.committed = False
datajob_urn = self.get_data_job_urn(
self.orchestrator_name,
run_summary.pipelineName,
job_name,
run_summary.platformInstanceId,
)
self.graph.emit_mcp(
MetadataChangeProposalWrapper(
entityType="dataJob",
entityUrn=datajob_urn,
aspectName="datahubIngestionRunSummary",
aspect=run_summary,
changeType=ChangeTypeClass.UPSERT,
)
)
self.committed = True
logger.info(
f"Committed ingestion run summary for pipeline:'{run_summary.pipelineName}',"
f"instance:'{run_summary.platformInstanceId}', job:'{job_name}'"
)

View File

@ -0,0 +1,12 @@
from datahub.ingestion.api.ingestion_job_reporting_provider_base import (
IngestionReportingProviderBase,
)
from datahub.ingestion.api.registry import PluginRegistry
reporting_provider_registry = PluginRegistry[IngestionReportingProviderBase]()
reporting_provider_registry.register_from_entrypoint(
"datahub.ingestion.reporting_provider.plugins"
)
# These providers are always enabled
assert reporting_provider_registry.get("datahub")

View File

@ -13,12 +13,16 @@ from datahub.configuration.common import (
DynamicTypedConfig,
PipelineExecutionError,
)
from datahub.ingestion.api.committable import CommitPolicy
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.sink import Sink, WriteCallback
from datahub.ingestion.api.source import Extractor, Source
from datahub.ingestion.api.transform import Transformer
from datahub.ingestion.extractor.extractor_registry import extractor_registry
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.ingestion.reporting.reporting_provider_registry import (
reporting_provider_registry,
)
from datahub.ingestion.sink.sink_registry import sink_registry
from datahub.ingestion.source.source_registry import source_registry
from datahub.ingestion.transformer.transform_registry import transform_registry
@ -39,6 +43,7 @@ class PipelineConfig(ConfigModel):
source: SourceConfig
sink: DynamicTypedConfig
transformers: Optional[List[DynamicTypedConfig]]
reporting: Optional[List[DynamicTypedConfig]] = None
run_id: str = "__DEFAULT_RUN_ID"
datahub_api: Optional[DatahubClientConfig] = None
pipeline_name: Optional[str] = None
@ -127,6 +132,7 @@ class Pipeline:
self.extractor_class = extractor_registry.get(self.config.source.extractor)
self._configure_transforms()
self._configure_reporting()
def _configure_transforms(self) -> None:
self.transformers = []
@ -142,6 +148,25 @@ class Pipeline:
f"Transformer type:{transformer_type},{transformer_class} configured"
)
def _configure_reporting(self) -> None:
if self.config.reporting is None:
return
for reporter in self.config.reporting:
reporter_type = reporter.type
reporter_class = reporting_provider_registry.get(reporter_type)
reporter_config_dict = reporter.dict().get("config", {})
self.ctx.register_reporter(
reporter_class.create(
config_dict=reporter_config_dict,
ctx=self.ctx,
name=reporter_class.__name__,
)
)
logger.debug(
f"Transformer type:{reporter_type},{reporter_class} configured"
)
@classmethod
def create(
cls, config_dict: dict, dry_run: bool = False, preview_mode: bool = False
@ -169,17 +194,9 @@ class Pipeline:
extractor.close()
if not self.dry_run:
self.sink.handle_work_unit_end(wu)
self.source.close()
self.sink.close()
# Temporary hack to prevent committing state if there are failures during the pipeline run.
try:
self.raise_from_status()
except Exception:
logger.warning(
"Pipeline failed. Not closing the source to prevent bad commits."
)
else:
self.source.close()
self.process_commits()
def transform(self, records: Iterable[RecordEnvelope]) -> Iterable[RecordEnvelope]:
"""
@ -192,6 +209,46 @@ class Pipeline:
return records
def process_commits(self) -> None:
"""
Evaluates the commit_policy for each committable in the context and triggers the commit operation
on the committable if its required commit policies are satisfied.
"""
has_errors: bool = (
True
if self.source.get_report().failures or self.sink.get_report().failures
else False
)
has_warnings: bool = (
True
if self.source.get_report().warnings or self.sink.get_report().warnings
else False
)
for name, committable in self.ctx.get_committables():
commit_policy: CommitPolicy = committable.commit_policy
logger.info(
f"Processing commit request for {name}. Commit policy = {commit_policy},"
f" has_errors={has_errors}, has_warnings={has_warnings}"
)
if (
commit_policy == CommitPolicy.ON_NO_ERRORS_AND_NO_WARNINGS
and (has_errors or has_warnings)
) or (commit_policy == CommitPolicy.ON_NO_ERRORS and has_errors):
logger.warning(
f"Skipping commit request for {name} since policy requirements are not met."
)
continue
try:
committable.commit()
except Exception as e:
logger.error(f"Failed to commit changes for {name}.", e)
raise e
else:
logger.info(f"Successfully committed changes for {name}.")
def raise_from_status(self, raise_warnings: bool = False) -> None:
if self.source.get_report().failures:
raise PipelineExecutionError(

View File

@ -66,6 +66,7 @@ from datahub.metadata.schema_classes import (
ChangeTypeClass,
DataPlatformInstanceClass,
DatasetPropertiesClass,
JobStatusClass,
)
from datahub.telemetry import telemetry
from datahub.utilities.sqlalchemy_query_combiner import SQLAlchemyQueryCombinerReport
@ -442,6 +443,18 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
)
return None
def update_default_job_run_summary(self) -> None:
summary = self.get_job_run_summary(self.get_default_ingestion_job_id())
if summary is not None:
# For now just add the config and the report.
summary.config = self.config.json()
summary.custom_summary = self.report.as_string()
summary.runStatus = (
JobStatusClass.FAILED
if self.get_report().failures
else JobStatusClass.COMPLETED
)
def get_schema_names(self, inspector):
return inspector.get_schema_names()
@ -997,6 +1010,5 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
return self.report
def close(self):
if self.is_stateful_ingestion_configured():
# Commit the checkpoints for this run
self.commit_checkpoints()
self.update_default_job_run_summary()
self.prepare_for_commit()

View File

@ -33,7 +33,10 @@ class CheckpointStateBase(ConfigModel):
compressor: Callable[[bytes], bytes] = functools.partial(
bz2.compress, compresslevel=9
),
max_allowed_state_size: int = 2**22, # 4MB
# fmt: off
# 4 MB
max_allowed_state_size: int = 2**22,
# fmt: on
) -> bytes:
"""
NOTE: Binary compression cannot be turned on yet as the current MCPs encode the GeneralizedAspect
@ -91,34 +94,43 @@ class Checkpoint:
# Construct the config
config_as_dict = json.loads(checkpoint_aspect.config)
config_obj = config_class.parse_obj(config_as_dict)
# Construct the state
state_as_dict = (
CheckpointStateBase.from_bytes_to_dict(checkpoint_aspect.state.payload)
if checkpoint_aspect.state.payload is not None
else {}
)
state_as_dict["version"] = checkpoint_aspect.state.formatVersion
state_as_dict["serde"] = checkpoint_aspect.state.serde
state_obj = state_class.parse_obj(state_as_dict)
except Exception as e:
logger.error(
"Failed to construct checkpoint class from checkpoint aspect.", e
# Failure to load config is probably okay...config structure has changed.
logger.warning(
"Failed to construct checkpoint's config from checkpoint aspect.", e
)
else:
# Construct the deserialized Checkpoint object from the raw aspect.
checkpoint = cls(
job_name=job_name,
pipeline_name=checkpoint_aspect.pipelineName,
platform_instance_id=checkpoint_aspect.platformInstanceId,
run_id=checkpoint_aspect.runId,
config=config_obj,
state=state_obj,
)
logger.info(
f"Successfully constructed last checkpoint state for job {job_name}"
)
return checkpoint
try:
# Construct the state
state_as_dict = (
CheckpointStateBase.from_bytes_to_dict(
checkpoint_aspect.state.payload
)
if checkpoint_aspect.state.payload is not None
else {}
)
state_as_dict["version"] = checkpoint_aspect.state.formatVersion
state_as_dict["serde"] = checkpoint_aspect.state.serde
state_obj = state_class.parse_obj(state_as_dict)
except Exception as e:
logger.error(
"Failed to construct checkpoint class from checkpoint aspect.", e
)
raise e
else:
# Construct the deserialized Checkpoint object from the raw aspect.
checkpoint = cls(
job_name=job_name,
pipeline_name=checkpoint_aspect.pipelineName,
platform_instance_id=checkpoint_aspect.platformInstanceId,
run_id=checkpoint_aspect.runId,
config=config_obj,
state=state_obj,
)
logger.info(
f"Successfully constructed last checkpoint state for job {job_name}"
)
return checkpoint
return None
def to_checkpoint_aspect(

View File

@ -51,7 +51,7 @@ class BaseSQLAlchemyCheckpointState(CheckpointStateBase):
self, checkpoint: "BaseSQLAlchemyCheckpointState"
) -> Iterable[str]:
yield from self._get_urns_not_in(
checkpoint.encoded_view_urns, self.encoded_view_urns
self.encoded_view_urns, checkpoint.encoded_view_urns
)
def add_table_urn(self, table_urn: str) -> None:

View File

@ -1,6 +1,9 @@
import logging
from typing import Any, Dict, Optional, Type
import platform
from datetime import datetime
from typing import Any, Dict, Optional, Type, cast
import psutil
import pydantic
from datahub.configuration.common import (
@ -10,16 +13,23 @@ from datahub.configuration.common import (
)
from datahub.configuration.source_common import DatasetSourceConfigBase
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_state_provider import IngestionStateProvider, JobId
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
IngestionCheckpointingProviderBase,
JobId,
)
from datahub.ingestion.api.ingestion_job_reporting_provider_base import (
IngestionReportingProviderBase,
)
from datahub.ingestion.api.source import Source
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state_provider.datahub_ingestion_state_provider import (
DatahubIngestionStateProviderConfig,
)
from datahub.ingestion.source.state_provider.state_provider_registry import (
ingestion_state_provider_registry,
ingestion_checkpoint_provider_registry,
)
from datahub.metadata.schema_classes import (
DatahubIngestionCheckpointClass,
DatahubIngestionRunSummaryClass,
JobStatusClass,
)
from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
logger: logging.Logger = logging.getLogger(__name__)
@ -30,10 +40,10 @@ class StatefulIngestionConfig(ConfigModel):
"""
enabled: bool = False
max_checkpoint_state_size: int = 2**24 # 16MB
state_provider: Optional[DynamicTypedConfig] = DynamicTypedConfig(
type="datahub", config=DatahubIngestionStateProviderConfig()
)
# fmt: off
max_checkpoint_state_size: pydantic.PositiveInt = 2**24 # 16MB
# fmt: on
state_provider: Optional[DynamicTypedConfig] = None
ignore_old_state: bool = False
ignore_new_state: bool = False
@ -41,8 +51,8 @@ class StatefulIngestionConfig(ConfigModel):
def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("enabled"):
if values.get("state_provider") is None:
raise ConfigurationError(
"Must specify state_provider configuration if stateful ingestion is enabled."
values["state_provider"] = DynamicTypedConfig(
type="datahub", config=None
)
return values
@ -68,10 +78,17 @@ class StatefulIngestionSourceBase(Source):
self.source_config_type = type(config)
self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self._initialize_state_provider()
self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
self._initialize_checkpointing_state_provider()
def _initialize_state_provider(self) -> None:
self.ingestion_state_provider: Optional[IngestionStateProvider] = None
#
# Checkpointing specific support.
#
def _initialize_checkpointing_state_provider(self) -> None:
self.ingestion_checkpointing_state_provider: Optional[
IngestionCheckpointingProviderBase
] = None
if (
self.stateful_ingestion_config is not None
and self.stateful_ingestion_config.state_provider is not None
@ -81,13 +98,26 @@ class StatefulIngestionSourceBase(Source):
raise ConfigurationError(
"pipeline_name must be provided if stateful ingestion is enabled."
)
state_provider_class = ingestion_state_provider_registry.get(
self.stateful_ingestion_config.state_provider.type
checkpointing_state_provider_class = (
ingestion_checkpoint_provider_registry.get(
self.stateful_ingestion_config.state_provider.type
)
)
self.ingestion_state_provider = state_provider_class.create(
if checkpointing_state_provider_class is None:
raise ConfigurationError(
f"Cannot find checkpoint provider class of type={self.stateful_ingestion_config.state_provider.type} "
" in the registry! Please check the type of the checkpointing provider in your config."
)
config_dict: Dict[str, Any] = cast(
Dict[str, Any],
self.stateful_ingestion_config.state_provider.dict().get("config", {}),
self.ctx,
)
self.ingestion_checkpointing_state_provider = checkpointing_state_provider_class.create( # type: ignore
config_dict=config_dict,
ctx=self.ctx,
name=checkpointing_state_provider_class.__name__,
)
assert self.ingestion_checkpointing_state_provider
if self.stateful_ingestion_config.ignore_old_state:
logger.warning(
"The 'ignore_old_state' config is True. The old checkpoint state will not be provided."
@ -96,6 +126,8 @@ class StatefulIngestionSourceBase(Source):
logger.warning(
"The 'ignore_new_state' config is True. The new checkpoint state will not be created."
)
# Add the checkpointing state provider to the pipeline context.
self.ctx.register_checkpointer(self.ingestion_checkpointing_state_provider)
logger.debug(
f"Successfully created {self.stateful_ingestion_config.state_provider.type} state provider."
@ -105,7 +137,7 @@ class StatefulIngestionSourceBase(Source):
if (
self.stateful_ingestion_config is not None
and self.stateful_ingestion_config.enabled
and self.ingestion_state_provider is not None
and self.ingestion_checkpointing_state_provider is not None
):
return True
return False
@ -134,7 +166,7 @@ class StatefulIngestionSourceBase(Source):
last_checkpoint: Optional[Checkpoint] = None
if self.is_stateful_ingestion_configured():
# Obtain the latest checkpoint from GMS for this job.
last_checkpoint_aspect = self.ingestion_state_provider.get_latest_checkpoint( # type: ignore
last_checkpoint_aspect = self.ingestion_checkpointing_state_provider.get_latest_checkpoint( # type: ignore
pipeline_name=self.ctx.pipeline_name, # type: ignore
platform_instance_id=self.get_platform_instance_id(),
job_name=job_id,
@ -176,7 +208,8 @@ class StatefulIngestionSourceBase(Source):
)
return self.cur_checkpoints[job_id]
def commit_checkpoints(self) -> None:
def _prepare_checkpoint_states_for_commit(self) -> None:
# Perform validations
if not self.is_stateful_ingestion_configured():
return None
if (
@ -193,6 +226,8 @@ class StatefulIngestionSourceBase(Source):
f" or preview_mode(={self.ctx.preview_mode})."
)
return None
# Prepare the state the checkpointing provider should commit.
job_checkpoint_aspects: Dict[JobId, DatahubIngestionCheckpointClass] = {}
for job_name, job_checkpoint in self.cur_checkpoints.items():
if job_checkpoint is None:
@ -210,6 +245,70 @@ class StatefulIngestionSourceBase(Source):
if checkpoint_aspect is not None:
job_checkpoint_aspects[job_name] = checkpoint_aspect
self.ingestion_state_provider.commit_checkpoints( # type: ignore
job_checkpoints=job_checkpoint_aspects
# Set the state to commit in the provider.
assert self.ingestion_checkpointing_state_provider
self.ingestion_checkpointing_state_provider.state_to_commit.update(
job_checkpoint_aspects
)
#
# Reporting specific support.
#
def _is_reporting_enabled(self):
for rc in self.ctx.get_reporters():
assert rc is not None
return True
return False
def _create_default_job_run_summary(self) -> DatahubIngestionRunSummaryClass:
assert self.ctx.pipeline_name
job_run_summary_default = DatahubIngestionRunSummaryClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=self.ctx.pipeline_name,
platformInstanceId=self.get_platform_instance_id(),
runId=self.ctx.run_id,
runStatus=JobStatusClass.COMPLETED,
)
# Add system specific info
job_run_summary_default.systemHostName = platform.node()
job_run_summary_default.operatingSystemName = platform.system()
job_run_summary_default.numProcessors = psutil.cpu_count(logical=True)
vmem = psutil.virtual_memory()
job_run_summary_default.availableMemory = getattr(vmem, "available", None)
job_run_summary_default.totalMemory = getattr(vmem, "total", None)
# Sources can add config in config + source report in custom_value.
# and also populate other source specific metrics.
return job_run_summary_default
def get_job_run_summary(
self, job_id: JobId
) -> Optional[DatahubIngestionRunSummaryClass]:
"""
Get the cached/newly created job run summary for this job if reporting is configured.
"""
if not self._is_reporting_enabled():
return None
if job_id not in self.run_summaries_to_report:
self.run_summaries_to_report[
job_id
] = self._create_default_job_run_summary()
return self.run_summaries_to_report[job_id]
#
# Commit handoff to provider for both checkpointing and reporting.
#
def _prepare_job_run_summaries_for_commit(self) -> None:
for reporting_committable in self.ctx.get_reporters():
if isinstance(reporting_committable, IngestionReportingProviderBase):
reporting_provider = cast(
IngestionReportingProviderBase, reporting_committable
)
reporting_provider.state_to_commit.update(self.run_summaries_to_report)
logger.info(
f"Successfully handed-off job run summaries to {reporting_provider.name}."
)
def prepare_for_commit(self) -> None:
"""NOTE: Sources should call this method from their close method."""
self._prepare_checkpoint_states_for_commit()
self._prepare_job_run_summaries_for_commit()

View File

@ -1,3 +1,5 @@
import pydantic
from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
@ -8,5 +10,5 @@ class BaseUsageCheckpointState(CheckpointStateBase):
Subclasses can define additional state as appropriate.
"""
begin_timestamp_millis: int
end_timestamp_millis: int
begin_timestamp_millis: pydantic.PositiveInt
end_timestamp_millis: pydantic.PositiveInt

View File

@ -1,33 +1,36 @@
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel, ConfigurationError
from datahub.configuration.common import ConfigurationError
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_state_provider import IngestionStateProvider, JobId
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
CheckpointJobStatesMap,
IngestionCheckpointingProviderBase,
IngestionCheckpointingProviderConfig,
JobId,
JobStateFilterType,
JobStateKey,
)
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.schema_classes import (
CalendarIntervalClass,
ChangeTypeClass,
DatahubIngestionCheckpointClass,
DatahubIngestionRunSummaryClass,
TimeWindowSizeClass,
)
logger = logging.getLogger(__name__)
class DatahubIngestionStateProviderConfig(ConfigModel):
class DatahubIngestionStateProviderConfig(IngestionCheckpointingProviderConfig):
datahub_api: Optional[DatahubClientConfig] = DatahubClientConfig()
class DatahubIngestionStateProvider(IngestionStateProvider):
class DatahubIngestionCheckpointingProvider(IngestionCheckpointingProviderBase):
orchestrator_name: str = "datahub"
def __init__(self, graph: DataHubGraph):
def __init__(self, graph: DataHubGraph, name: str):
super().__init__(name)
self.graph = graph
if not self._is_server_stateful_ingestion_capable():
raise ConfigurationError(
@ -37,17 +40,18 @@ class DatahubIngestionStateProvider(IngestionStateProvider):
@classmethod
def create(
cls, config_dict: Dict[str, Any], ctx: PipelineContext
) -> IngestionStateProvider:
cls, config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> IngestionCheckpointingProviderBase:
if ctx.graph:
return cls(ctx.graph)
# Use the pipeline-level graph if set
return cls(ctx.graph, name)
elif config_dict is None:
raise ConfigurationError("Missing provider configuration")
raise ConfigurationError("Missing provider configuration.")
else:
provider_config = DatahubIngestionStateProviderConfig.parse_obj(config_dict)
if provider_config.datahub_api:
graph = DataHubGraph(provider_config.datahub_api)
return cls(graph)
return cls(graph, name)
else:
raise ConfigurationError(
"Missing datahub_api. Provide either a global one or under the state_provider."
@ -71,8 +75,8 @@ class DatahubIngestionStateProvider(IngestionStateProvider):
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}'"
)
data_job_urn = builder.make_data_job_urn(
self.orchestrator_name, pipeline_name, job_name
data_job_urn = self.get_data_job_urn(
self.orchestrator_name, pipeline_name, job_name, platform_instance_id
)
latest_checkpoint: Optional[
DatahubIngestionCheckpointClass
@ -101,20 +105,50 @@ class DatahubIngestionStateProvider(IngestionStateProvider):
return None
def commit_checkpoints(
self, job_checkpoints: Dict[JobId, DatahubIngestionCheckpointClass]
) -> None:
for job_name, checkpoint in job_checkpoints.items():
def get_previous_states(
self,
state_key: JobStateKey,
last_only: bool = True,
filter_opt: Optional[JobStateFilterType] = None,
) -> List[CheckpointJobStatesMap]:
if not last_only:
raise NotImplementedError(
"Currently supports retrieving only the last commited state."
)
if filter_opt is not None:
raise NotImplementedError(
"Support for optional filters is not implemented yet."
)
checkpoints: List[CheckpointJobStatesMap] = []
last_job_checkpoint_map: CheckpointJobStatesMap = {}
for job_name in state_key.job_names:
last_job_checkpoint = self.get_latest_checkpoint(
state_key.pipeline_name, state_key.platform_instance_id, job_name
)
if last_job_checkpoint is not None:
last_job_checkpoint_map[job_name] = last_job_checkpoint
checkpoints.append(last_job_checkpoint_map)
return checkpoints
def commit(self) -> None:
if not self.state_to_commit:
logger.warning(f"No state available to commit for {self.name}")
return None
for job_name, checkpoint in self.state_to_commit.items():
# Emit the ingestion state for each job
logger.info(
f"Committing ingestion checkpoint for pipeline:'{checkpoint.pipelineName}',"
f"instance:'{checkpoint.platformInstanceId}', job:'{job_name}'"
)
datajob_urn = builder.make_data_job_urn(
self.committed = False
datajob_urn = self.get_data_job_urn(
self.orchestrator_name,
checkpoint.pipelineName,
job_name,
checkpoint.platformInstanceId,
)
self.graph.emit_mcp(
@ -127,59 +161,9 @@ class DatahubIngestionStateProvider(IngestionStateProvider):
)
)
self.committed = True
logger.info(
f"Committed ingestion checkpoint for pipeline:'{checkpoint.pipelineName}',"
f"instance:'{checkpoint.platformInstanceId}', job:'{job_name}'"
)
@staticmethod
def get_end_time(ingestion_state: DatahubIngestionRunSummaryClass) -> int:
start_time_millis = ingestion_state.timestampMillis
granularity = ingestion_state.eventGranularity
granularity_millis = (
DatahubIngestionStateProvider.get_granularity_to_millis(granularity)
if granularity is not None
else 0
)
return start_time_millis + granularity_millis
@staticmethod
def get_time_window_size(interval_str: str) -> TimeWindowSizeClass:
to_calendar_interval: Dict[str, str] = {
"s": CalendarIntervalClass.SECOND,
"m": CalendarIntervalClass.MINUTE,
"h": CalendarIntervalClass.HOUR,
"d": CalendarIntervalClass.DAY,
"W": CalendarIntervalClass.WEEK,
"M": CalendarIntervalClass.MONTH,
"Q": CalendarIntervalClass.QUARTER,
"Y": CalendarIntervalClass.YEAR,
}
interval_pattern = re.compile(r"(\d+)([s|m|h|d|W|M|Q|Y])")
token_search = interval_pattern.search(interval_str)
if token_search is None:
raise ValueError("Invalid interval string:", interval_str)
(multiples_str, unit_str) = (token_search.group(1), token_search.group(2))
if not multiples_str or not unit_str:
raise ValueError("Invalid interval string:", interval_str)
unit = to_calendar_interval.get(unit_str)
if not unit:
raise ValueError("Invalid time unit token:", unit_str)
return TimeWindowSizeClass(unit=unit, multiple=int(multiples_str))
@staticmethod
def get_granularity_to_millis(granularity: TimeWindowSizeClass) -> int:
to_millis_from_interval: Dict[str, int] = {
CalendarIntervalClass.SECOND: 1000,
CalendarIntervalClass.MINUTE: 60 * 1000,
CalendarIntervalClass.HOUR: 60 * 60 * 1000,
CalendarIntervalClass.DAY: 24 * 60 * 60 * 1000,
CalendarIntervalClass.WEEK: 7 * 24 * 60 * 60 * 1000,
CalendarIntervalClass.MONTH: 31 * 7 * 24 * 60 * 60 * 1000,
CalendarIntervalClass.QUARTER: 90 * 7 * 24 * 60 * 60 * 1000,
CalendarIntervalClass.YEAR: 365 * 7 * 24 * 60 * 60 * 1000,
}
units_to_millis = to_millis_from_interval.get(str(granularity.unit), None)
if not units_to_millis:
raise ValueError("Invalid unit", granularity.unit)
return granularity.multiple * units_to_millis

View File

@ -1,10 +1,15 @@
from datahub.ingestion.api.ingestion_state_provider import IngestionStateProvider
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
IngestionCheckpointingProviderBase,
)
from datahub.ingestion.api.registry import PluginRegistry
ingestion_state_provider_registry = PluginRegistry[IngestionStateProvider]()
ingestion_state_provider_registry.register_from_entrypoint(
"datahub.ingestion.state_provider.plugins"
ingestion_checkpoint_provider_registry = PluginRegistry[
IngestionCheckpointingProviderBase
]()
ingestion_checkpoint_provider_registry.register_from_entrypoint(
"datahub.ingestion.checkpointing_provider.plugins"
)
# These sinks are always enabled
assert ingestion_state_provider_registry.get("datahub")
# These providers are always enabled
assert ingestion_checkpoint_provider_registry.get("datahub")

View File

@ -32,8 +32,10 @@ from datahub.ingestion.source.usage.usage_common import (
)
from datahub.metadata.schema_classes import (
ChangeTypeClass,
JobStatusClass,
OperationClass,
OperationTypeClass,
TimeWindowSizeClass,
)
logger = logging.getLogger(__name__)
@ -168,6 +170,7 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
super(SnowflakeUsageSource, self).__init__(config, ctx)
self.config: SnowflakeUsageConfig = config
self.report: SourceReport = SourceReport()
self.should_skip_this_run = self._should_skip_this_run()
@classmethod
def create(cls, config_dict, ctx):
@ -252,9 +255,26 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
def _init_checkpoints(self):
self.get_current_checkpoint(self.get_default_ingestion_job_id())
def update_default_job_summary(self) -> None:
summary = self.get_job_run_summary(self.get_default_ingestion_job_id())
if summary is not None:
summary.runStatus = (
JobStatusClass.SKIPPED
if self.should_skip_this_run
else JobStatusClass.COMPLETED
)
summary.messageId = datetime.now().strftime("%m-%d-%Y,%H:%M:%S")
summary.eventGranularity = TimeWindowSizeClass(
unit=self.config.bucket_duration, multiple=1
)
summary.numWarnings = len(self.report.warnings)
summary.numErrors = len(self.report.failures)
summary.numEntities = self.report.workunits_produced
summary.config = self.config.json()
summary.custom_summary = self.report.as_string()
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
skip_this_run: bool = self._should_skip_this_run()
if not skip_this_run:
if not self.should_skip_this_run:
# Initialize the checkpoints
self._init_checkpoints()
# Generate the workunits.
@ -486,5 +506,5 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
return self.report
def close(self):
# Checkpoint this run
self.commit_checkpoints()
self.update_default_job_summary()
self.prepare_for_commit()

View File

@ -0,0 +1,189 @@
import types
import unittest
from typing import Dict, List, Optional, Type
from unittest.mock import MagicMock, patch
from avrogen.dict_wrapper import DictWrapper
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
CheckpointJobStatesMap,
CheckpointJobStateType,
IngestionCheckpointingProviderBase,
JobId,
JobStateKey,
)
from datahub.ingestion.source.sql.mysql import MySQLConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionCheckpointingProvider,
)
class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
# Static members for the tests
pipeline_name: str = "test_pipeline"
platform_instance_id: str = "test_platform_instance_1"
job_names: List[JobId] = [JobId("job1"), JobId("job2")]
run_id: str = "test_run"
job_state_key: JobStateKey = JobStateKey(
pipeline_name=pipeline_name,
platform_instance_id=platform_instance_id,
job_names=job_names,
)
def setUp(self) -> None:
self._setup_mock_graph()
self.provider = self._create_provider()
assert self.provider
def _setup_mock_graph(self) -> None:
"""
Setup monkey-patched graph client.
"""
self.patcher = patch(
"datahub.ingestion.graph.client.DataHubGraph", autospec=True
)
self.addCleanup(self.patcher.stop)
self.mock_graph = self.patcher.start()
# Make server stateful ingestion capable
self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True}
# Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits.
self.mock_graph.emit_mcp = types.MethodType(
self.monkey_patch_emit_mcp, self.mock_graph
)
# Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value
self.mock_graph.get_latest_timeseries_value = types.MethodType(
self.monkey_patch_get_latest_timeseries_value, self.mock_graph
)
# Tracking for emitted mcps.
self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}
def _create_provider(self) -> IngestionCheckpointingProviderBase:
ctx: PipelineContext = PipelineContext(
run_id=self.run_id, pipeline_name=self.pipeline_name
)
ctx.graph = self.mock_graph
return DatahubIngestionCheckpointingProvider.create(
{}, ctx, name=DatahubIngestionCheckpointingProvider.__name__
)
def monkey_patch_emit_mcp(
self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper
) -> None:
"""
Monkey-patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory.
"""
self.assertIsNotNone(graph_ref)
self.assertEqual(mcpw.entityType, "dataJob")
self.assertEqual(mcpw.aspectName, "datahubIngestionCheckpoint")
# Cache the mcpw against the entityUrn
assert mcpw.entityUrn is not None
self.mcps_emitted[mcpw.entityUrn] = mcpw
def monkey_patch_get_latest_timeseries_value(
self,
graph_ref: MagicMock,
entity_urn: str,
aspect_name: str,
aspect_type: Type[DictWrapper],
filter_criteria_map: Dict[str, str],
) -> Optional[DictWrapper]:
"""
Monkey patched implementation of DatahubGraph.get_latest_timeseries_value that returns the latest cached aspect
for a given entity urn.
"""
self.assertIsNotNone(graph_ref)
self.assertEqual(aspect_name, "datahubIngestionCheckpoint")
self.assertEqual(aspect_type, CheckpointJobStateType)
self.assertEqual(
filter_criteria_map,
{
"pipelineName": self.pipeline_name,
"platformInstanceId": self.platform_instance_id,
},
)
# Retrieve the cached mcpw and return its aspect value.
mcpw = self.mcps_emitted.get(entity_urn)
if mcpw:
return mcpw.aspect
return None
def test_provider(self):
# 1. Create the individual job checkpoints with appropriate states.
# Job1 - Checkpoint with a BaseSQLAlchemyCheckpointState state
job1_state_obj = BaseSQLAlchemyCheckpointState()
job1_checkpoint = Checkpoint(
job_name=self.job_names[0],
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
config=MySQLConfig(),
state=job1_state_obj,
)
# Job2 - Checkpoint with a BaseUsageCheckpointState state
job2_state_obj = BaseUsageCheckpointState(
begin_timestamp_millis=10, end_timestamp_millis=100
)
job2_checkpoint = Checkpoint(
job_name=self.job_names[1],
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
config=MySQLConfig(),
state=job2_state_obj,
)
# 2. Set the provider's state_to_commit.
self.provider.state_to_commit = {
# NOTE: state_to_commit accepts only the aspect version of the checkpoint.
self.job_names[0]: job1_checkpoint.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
),
self.job_names[1]: job2_checkpoint.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
),
}
# 3. Perform the commit
# NOTE: This will commit the state to the in-memory self.mcps_emitted because of the monkey-patching.
self.provider.commit()
self.assertTrue(self.provider.committed)
# 4. Get last committed state. This must match what has been committed earlier.
# NOTE: This will retrieve from in-memory self.mcps_emitted because of the monkey-patching.
last_state: Optional[CheckpointJobStatesMap] = self.provider.get_last_state(
self.job_state_key
)
assert last_state is not None
self.assertEqual(len(last_state), 2)
# 5. Validate individual job checkpoint state values that have been committed and retrieved
# against the original values.
self.assertIsNotNone(last_state[self.job_names[0]])
job1_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
job_name=self.job_names[0],
checkpoint_aspect=last_state[self.job_names[0]],
state_class=type(job1_state_obj),
config_class=type(job1_checkpoint.config),
)
self.assertEqual(job1_last_checkpoint, job1_checkpoint)
self.assertIsNotNone(last_state[self.job_names[1]])
job2_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
job_name=self.job_names[1],
checkpoint_aspect=last_state[self.job_names[1]],
state_class=type(job2_state_obj),
config_class=type(job2_checkpoint.config),
)
self.assertEqual(job2_last_checkpoint, job2_checkpoint)

View File

@ -0,0 +1,156 @@
import types
import unittest
from datetime import datetime
from typing import Dict, List, Optional, Type
from unittest.mock import MagicMock, patch
from avrogen.dict_wrapper import DictWrapper
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_reporting_provider_base import (
IngestionReportingProviderBase,
JobId,
JobStateKey,
ReportingJobStatesMap,
ReportingJobStateType,
)
from datahub.ingestion.reporting.datahub_ingestion_reporting_provider import (
DatahubIngestionReportingProvider,
)
from datahub.ingestion.source.sql.mysql import MySQLConfig
from datahub.metadata.schema_classes import JobStatusClass
class TestDatahubIngestionReportingProvider(unittest.TestCase):
# Static members for the tests
pipeline_name: str = "test_pipeline"
platform_instance_id: str = "test_platform_instance_1"
job_names: List[JobId] = [JobId("job1"), JobId("job2")]
run_id: str = "test_run"
job_state_key: JobStateKey = JobStateKey(
pipeline_name=pipeline_name,
platform_instance_id=platform_instance_id,
job_names=job_names,
)
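# job_state_key identifies the (pipeline, platform instance, jobs) combination whose
# previously committed state is fetched back via get_last_state() in the test below.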
def setUp(self) -> None:
self._setup_mock_graph()
self.provider = self._create_provider()
assert self.provider
def _setup_mock_graph(self) -> None:
"""
Set up the monkey-patched graph client.
"""
self.patcher = patch(
"datahub.ingestion.graph.client.DataHubGraph", autospec=True
)
self.addCleanup(self.patcher.stop)
self.mock_graph = self.patcher.start()
# Make server stateful ingestion capable
self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True}
# Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits.
self.mock_graph.emit_mcp = types.MethodType(
self.monkey_patch_emit_mcp, self.mock_graph
)
# Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value
self.mock_graph.get_latest_timeseries_value = types.MethodType(
self.monkey_patch_get_latest_timeseries_value, self.mock_graph
)
# Tracking for emitted mcps.
self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}
def _create_provider(self) -> IngestionReportingProviderBase:
ctx: PipelineContext = PipelineContext(
run_id=self.run_id, pipeline_name=self.pipeline_name
)
ctx.graph = self.mock_graph
return DatahubIngestionReportingProvider.create(
{}, ctx, name=DatahubIngestionReportingProvider.__name__
)
def monkey_patch_emit_mcp(
self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper
) -> None:
"""
Monkey patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory.
"""
self.assertIsNotNone(graph_ref)
self.assertEqual(mcpw.entityType, "dataJob")
self.assertEqual(mcpw.aspectName, "datahubIngestionRunSummary")
# Cache the mcpw against the entityUrn
assert mcpw.entityUrn is not None
self.mcps_emitted[mcpw.entityUrn] = mcpw
def monkey_patch_get_latest_timeseries_value(
self,
graph_ref: MagicMock,
entity_urn: str,
aspect_name: str,
aspect_type: Type[DictWrapper],
filter_criteria_map: Dict[str, str],
) -> Optional[DictWrapper]:
"""
Monkey patched implementation of DatahubGraph.get_latest_timeseries_value that returns the latest cached aspect
for a given entity urn.
"""
self.assertIsNotNone(graph_ref)
self.assertEqual(aspect_name, "datahubIngestionRunSummary")
self.assertEqual(aspect_type, ReportingJobStateType)
self.assertEqual(
filter_criteria_map,
{
"pipelineName": self.pipeline_name,
"platformInstanceId": self.platform_instance_id,
},
)
# Retrieve the cached mcpw and return its aspect value.
mcpw = self.mcps_emitted.get(entity_urn)
if mcpw:
return mcpw.aspect
return None
def test_provider(self):
# 1. Create the job reports
job_reports: Dict[JobId, ReportingJobStateType] = {
# A completed job
self.job_names[0]: ReportingJobStateType(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=self.pipeline_name,
platformInstanceId=self.platform_instance_id,
runId=self.run_id,
runStatus=JobStatusClass.COMPLETED,
config=MySQLConfig().json(),
),
# A skipped job
self.job_names[1]: ReportingJobStateType(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=self.pipeline_name,
platformInstanceId=self.platform_instance_id,
runId=self.run_id,
runStatus=JobStatusClass.SKIPPED,
config=MySQLConfig().json(),
),
}
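# Each report above is a ReportingJobStateType aspect; the monkey-patched emit_mcp expects
# it to be emitted under the "datahubIngestionRunSummary" aspect name.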
# 2. Set the provider's state_to_commit.
self.provider.state_to_commit = job_reports
# 3. Perform the commit
# NOTE: This will commit the state to the in-memory self.mcps_emitted because of the monkey-patching.
self.provider.commit()
self.assertTrue(self.provider.committed)
# 4. Get last committed state. This must match what has been committed earlier.
# NOTE: This will retrieve from in-memory self.mcps_emitted because of the monkey-patching.
last_state: Optional[ReportingJobStatesMap] = self.provider.get_last_state(
self.job_state_key
)
assert last_state is not None
self.assertEqual(len(last_state), 2)
# 5. Validate individual job report values that have been committed and retrieved
# against the original values.
self.assertEqual(last_state, job_reports)

View File

@ -0,0 +1,130 @@
from datetime import datetime
from typing import Dict
import pytest
from datahub.emitter.mce_builder import make_dataset_urn
from datahub.ingestion.source.sql.mysql import MySQLConfig
from datahub.ingestion.source.sql.sql_common import BasicSQLAlchemyConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState
from datahub.metadata.schema_classes import (
DatahubIngestionCheckpointClass,
IngestionCheckpointStateClass,
)
# 1. Setup common test param values.
test_pipeline_name: str = "test_pipeline"
test_platform_instance_id: str = "test_platform_instance_1"
test_job_name: str = "test_job_1"
test_run_id: str = "test_run_1"
test_source_config: BasicSQLAlchemyConfig = MySQLConfig()
# 2. Create the params for parametrized tests.
# 2.1 Create and add an instance of BaseSQLAlchemyCheckpointState.
test_checkpoint_serde_params: Dict[str, CheckpointStateBase] = {}
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_table_urn(
make_dataset_urn("mysql", "db1.t1", "prod")
)
base_sql_alchemy_checkpoint_state_obj.add_view_urn(
make_dataset_urn("mysql", "db1.v1", "prod")
)
test_checkpoint_serde_params[
"BaseSQLAlchemyCheckpointState"
] = base_sql_alchemy_checkpoint_state_obj
# 2.2 Create and add an instance of BaseUsageCheckpointState.
base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
test_checkpoint_serde_params[
"BaseUsageCheckpointState"
] = base_usage_checkpoint_state_obj
# 3. Define the test with the params
@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
)
def test_create_from_checkpoint_aspect(state_obj):
"""
Tests the Checkpoint class API 'create_from_checkpoint_aspect' with the state_obj parameter as the state.
"""
# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=test_pipeline_name,
platformInstanceId=test_platform_instance_id,
config=test_source_config.json(),
state=checkpoint_state,
runId=test_run_id,
)
# 2. Create the checkpoint from the raw checkpoint aspect and validate.
checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
config_class=MySQLConfig,
)
expected_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
config=test_source_config,
state=state_obj,
)
assert checkpoint_obj == expected_checkpoint_obj
@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
)
def test_serde_idempotence(state_obj):
"""
Verifies that Serialization + Deserialization reconstructs the original object fully.
"""
# 1. Construct the initial checkpoint object
orig_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
config=test_source_config,
state=state_obj,
)
# 2. Convert it to the aspect form.
checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
)
assert checkpoint_aspect is not None
# 3. Reconstruct from the aspect form and verify that it matches the original.
serde_checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
config_class=MySQLConfig,
)
assert orig_checkpoint_obj == serde_checkpoint_obj

View File

@ -0,0 +1,20 @@
from datahub.emitter.mce_builder import make_dataset_urn
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
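# Verifies that BaseSQLAlchemyCheckpointState tracks table and view urns and that the
# get_*_urns_not_in helpers report urns missing from another state.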
def test_sql_common_state() -> None:
state1 = BaseSQLAlchemyCheckpointState()
test_table_urn = make_dataset_urn("test_platform", "db1.test_table1", "test")
state1.add_table_urn(test_table_urn)
test_view_urn = make_dataset_urn("test_platform", "db1.test_view1", "test")
state1.add_view_urn(test_view_urn)
state2 = BaseSQLAlchemyCheckpointState()
table_urns_diff = list(state1.get_table_urns_not_in(state2))
assert len(table_urns_diff) == 1 and table_urns_diff[0] == test_table_urn
view_urns_diff = list(state1.get_view_urns_not_in(state2))
assert len(view_urns_diff) == 1 and view_urns_diff[0] == test_view_urn

View File

@ -0,0 +1,263 @@
from typing import Any, Dict, Optional, Tuple, Type, cast
import pytest
from pydantic import ValidationError
from datahub.configuration.common import ConfigModel, DynamicTypedConfig
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.ingestion.reporting.datahub_ingestion_reporting_provider import (
DatahubIngestionReportingProviderConfig,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionConfig,
)
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionStateProviderConfig,
)
# 0. Common client configs.
datahub_client_configs: Dict[str, Any] = {
"full": {
"server": "http://localhost:8080",
"token": "dummy_test_tok",
"timeout_sec": 10,
"extra_headers": {},
"max_threads": 1,
},
"simple": {},
"default": {},
"none": None,
}
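# The "none" entry feeds the *_bad_config cases below, which are expected to fail pydantic validation.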
# 1. Datahub Checkpointing State Provider Config test params
checkpointing_provider_config_test_params: Dict[
str,
Tuple[
Type[DatahubIngestionStateProviderConfig],
Dict[str, Any],
Optional[DatahubIngestionStateProviderConfig],
bool,
],
] = {
# Full custom-config
"checkpointing_valid_full_config": (
DatahubIngestionStateProviderConfig,
{
"datahub_api": datahub_client_configs["full"],
},
DatahubIngestionStateProviderConfig(
datahub_api=DatahubClientConfig(
server="http://localhost:8080",
token="dummy_test_tok",
timeout_sec=10,
extra_headers={},
max_threads=1,
),
),
False,
),
# Simple config
"checkpointing_valid_simple_config": (
DatahubIngestionStateProviderConfig,
{
"datahub_api": datahub_client_configs["simple"],
},
DatahubIngestionStateProviderConfig(
datahub_api=DatahubClientConfig(
server="http://localhost:8080",
),
),
False,
),
# Default
"checkpointing_default": (
DatahubIngestionStateProviderConfig,
{
"datahub_api": datahub_client_configs["default"],
},
DatahubIngestionStateProviderConfig(
datahub_api=DatahubClientConfig(),
),
False,
),
# None
"checkpointing_bad_config": (
DatahubIngestionStateProviderConfig,
datahub_client_configs["none"],
None,
True,
),
}
# 2. Datahub Reporting Provider Config test params
reporting_provider_config_test_params: Dict[
str,
Tuple[
Type[DatahubIngestionReportingProviderConfig],
Dict[str, Any],
Optional[DatahubIngestionReportingProviderConfig],
bool,
],
] = {
# Full custom-config
"reporting_valid_full_config": (
DatahubIngestionReportingProviderConfig,
{
"datahub_api": datahub_client_configs["full"],
},
DatahubIngestionReportingProviderConfig(
datahub_api=DatahubClientConfig(
server="http://localhost:8080",
token="dummy_test_tok",
timeout_sec=10,
extra_headers={},
max_threads=1,
),
),
False,
),
# Simple config
"reporting_valid_simple_config": (
DatahubIngestionReportingProviderConfig,
{
"datahub_api": datahub_client_configs["simple"],
},
DatahubIngestionReportingProviderConfig(
datahub_api=DatahubClientConfig(
server="http://localhost:8080",
),
),
False,
),
# Default
"reporting_default": (
DatahubIngestionReportingProviderConfig,
{
"datahub_api": datahub_client_configs["default"],
},
DatahubIngestionReportingProviderConfig(
datahub_api=DatahubClientConfig(),
),
False,
),
# None
"reporting_bad_config": (
DatahubIngestionReportingProviderConfig,
datahub_client_configs["none"],
None,
True,
),
}
# 3. StatefulIngestion Config test params
stateful_ingestion_config_test_params: Dict[
str,
Tuple[
Type[StatefulIngestionConfig],
Dict[str, Any],
Optional[StatefulIngestionConfig],
bool,
],
] = {
# Full custom-config
"stateful_ingestion_full_custom": (
StatefulIngestionConfig,
{
"enabled": True,
"max_checkpoint_state_size": 1024,
"state_provider": {
"type": "datahub",
"config": datahub_client_configs["full"],
},
"ignore_old_state": True,
"ignore_new_state": True,
},
StatefulIngestionConfig(
enabled=True,
max_checkpoint_state_size=1024,
ignore_old_state=True,
ignore_new_state=True,
state_provider=DynamicTypedConfig(
type="datahub",
config=datahub_client_configs["full"],
),
),
False,
),
# Default disabled
"stateful_ingestion_default_disabled": (
StatefulIngestionConfig,
{},
StatefulIngestionConfig(
enabled=False,
# fmt: off
max_checkpoint_state_size=2**24,
# fmt: on
ignore_old_state=False,
ignore_new_state=False,
state_provider=None,
),
False,
),
# Default enabled
"stateful_ingestion_default_enabled": (
StatefulIngestionConfig,
{"enabled": True},
StatefulIngestionConfig(
enabled=True,
# fmt: off
max_checkpoint_state_size=2**24,
# fmt: on
ignore_old_state=False,
ignore_new_state=False,
state_provider=DynamicTypedConfig(type="datahub", config=None),
),
False,
),
# Bad Config- throws ValidationError
"stateful_ingestion_bad_config": (
StatefulIngestionConfig,
{"enabled": True, "state_provider": {}},
None,
True,
),
}
# 4. Combine all of the config params from 1, 2 & 3 above for the common parametrized test.
CombinedTestConfigType = Dict[
str,
Tuple[
Type[ConfigModel],
Dict[str, Any],
Optional[ConfigModel],
bool,
],
]
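# Each test case is a tuple of (config class, raw config dict, expected parsed config or None,
# whether parsing should raise a ValidationError).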
combined_test_configs = {
**cast(CombinedTestConfigType, checkpointing_provider_config_test_params),
**cast(CombinedTestConfigType, reporting_provider_config_test_params),
**cast(CombinedTestConfigType, stateful_ingestion_config_test_params),
}
@pytest.mark.parametrize(
"config_class, config_dict, expected, raises_exception",
combined_test_configs.values(),
ids=combined_test_configs.keys(),
)
def test_state_provider_configs(
config_class: Type[ConfigModel],
config_dict: Dict[str, Any],
expected: Optional[ConfigModel],
raises_exception: bool,
) -> None:
if raises_exception:
with pytest.raises(ValidationError):
assert expected is None
config_class.parse_obj(config_dict)
else:
config = config_class.parse_obj(config_dict)
assert config == expected

View File

@ -1,3 +1,4 @@
pytest>=6.2
pytest-dependency>=0.5.1
-e ../metadata-ingestion[datahub-rest,datahub-kafka,mysql]
psutil
-e ../metadata-ingestion[datahub-rest,datahub-kafka,mysql]

File diff suppressed because it is too large

View File

@ -0,0 +1,129 @@
from typing import Any, Dict, Optional, cast
from datahub.ingestion.api.committable import StatefulCommittable
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
from datahub.ingestion.source.sql.sql_common import (
    BaseSQLAlchemyCheckpointState,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from sqlalchemy import create_engine
from sqlalchemy.sql import text
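# End-to-end smoke test: run a stateful MySQL ingestion pipeline twice against a live DataHub
# instance, drop a table between runs, and verify the checkpoint state diff as well as the
# commits from the checkpointing and reporting providers.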
def test_stateful_ingestion(wait_for_healthchecks):
def create_mysql_engine(mysql_source_config_dict: Dict[str, Any]) -> Any:
mysql_config = MySQLConfig.parse_obj(mysql_source_config_dict)
url = mysql_config.get_sql_alchemy_url()
return create_engine(url)
def create_table(engine: Any, name: str, defn: str) -> None:
create_table_query = text(f"CREATE TABLE IF NOT EXISTS {name}{defn};")
engine.execute(create_table_query)
def drop_table(engine: Any, table_name: str) -> None:
drop_table_query = text(f"DROP TABLE {table_name};")
engine.execute(drop_table_query)
def run_and_get_pipeline(pipeline_config_dict: Dict[str, Any]) -> Pipeline:
pipeline = Pipeline.create(pipeline_config_dict)
pipeline.run()
pipeline.raise_from_status()
return pipeline
def validate_all_providers_have_committed_successfully(pipeline: Pipeline) -> None:
provider_count: int = 0
for name, provider in pipeline.ctx.get_committables():
provider_count += 1
assert isinstance(provider, StatefulCommittable)
stateful_committable = cast(StatefulCommittable, provider)
assert stateful_committable.has_successfully_committed()
assert stateful_committable.state_to_commit
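# Exactly two committables are expected: the datahub checkpointing provider and the
# datahub reporting provider configured in the pipeline.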
assert provider_count == 2
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint]:
mysql_source = cast(MySQLSource, pipeline.source)
return mysql_source.get_current_checkpoint(
mysql_source.get_default_ingestion_job_id()
)
source_config_dict: Dict[str, Any] = {
"username": "datahub",
"password": "datahub",
"database": "datahub",
"stateful_ingestion": {
"enabled": True,
"remove_stale_metadata": True,
"state_provider": {
"type": "datahub",
"config": {"datahub_api": {"server": "http://localhost:8080"}},
},
},
}
pipeline_config_dict: Dict[str, Any] = {
"source": {
"type": "mysql",
"config": source_config_dict,
},
"sink": {
"type": "datahub-rest",
"config": {"server": "http://localhost:8080"},
},
"pipeline_name": "mysql_stateful_ingestion_smoke_test_pipeline",
"reporting": [
{
"type": "datahub",
"config": {"datahub_api": {"server": "http://localhost:8080"}},
}
],
}
# 1. Setup the SQL engine
mysql_engine = create_mysql_engine(source_config_dict)
# 2. Create test tables for first run of the pipeline.
table_prefix = "stateful_ingestion_test"
table_defs = {
f"{table_prefix}_t1": "(id INT, name VARCHAR(10))",
f"{table_prefix}_t2": "(id INT)",
}
table_names = sorted(table_defs.keys())
for table_name, defn in table_defs.items():
create_table(mysql_engine, table_name, defn)
# 3. Do the first run of the pipeline and get the default job's checkpoint.
pipeline_run1 = run_and_get_pipeline(pipeline_config_dict)
checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
assert checkpoint1
assert checkpoint1.state
# 4. Drop table t1 created during step 2, rerun the pipeline, and get the checkpoint state.
drop_table(mysql_engine, table_names[0])
pipeline_run2 = run_and_get_pipeline(pipeline_config_dict)
checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
assert checkpoint2
assert checkpoint2.state
# 5. Perform all assertions on the states
state1 = cast(BaseSQLAlchemyCheckpointState, checkpoint1.state)
state2 = cast(BaseSQLAlchemyCheckpointState, checkpoint2.state)
difference_urns = list(state1.get_table_urns_not_in(state2))
assert len(difference_urns) == 1
assert (
difference_urns[0]
== "urn:li:dataset:(urn:li:dataPlatform:mysql,datahub.stateful_ingestion_test_t1,PROD)"
)
# 6. Perform all assertions on the config.
assert checkpoint1.config == checkpoint2.config
# 7. Clean up table t2 as well so it does not affect other tests that rely on the data in the smoke-test world.
drop_table(mysql_engine, table_names[1])
# 8. Validate that all providers have committed successfully.
# NOTE: The following validation also asserts that state was committed,
# covering the reporting provider as well as checkpointing.
validate_all_providers_have_committed_successfully(pipeline_run1)
validate_all_providers_have_committed_successfully(pipeline_run2)