fix(ingest): remove get_platform_instance_id from stateful ingestion (#7572)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
Harshal Sheth 2023-03-21 06:05:10 +05:30 committed by GitHub
parent cbd8e14b78
commit d54ff061a4
19 changed files with 28 additions and 140 deletions

View File

@@ -1,6 +1,5 @@
 import json
 import logging
-from typing import Optional
 
 import click
 from click_default_group import DefaultGroup
@@ -29,34 +28,20 @@ def state() -> None:
 
 @state.command()
 @click.option("--pipeline-name", required=True, type=str)
 @click.option("--platform", required=True, type=str)
-@click.option("--platform-instance", required=False, type=str)
 @upgrade.check_upgrade
 @telemetry.with_telemetry()
-def inspect(
-    pipeline_name: str, platform: str, platform_instance: Optional[str]
-) -> None:
+def inspect(pipeline_name: str, platform: str) -> None:
     """
     Get the latest stateful ingestion state for a given pipeline.
     Only works for state entity removal for now.
     """
-
-    # Note that the platform-instance argument is not generated consistently,
-    # and is not always equal to the platform_instance config.
-
     datahub_graph = get_default_graph()
     checkpoint_provider = DatahubIngestionCheckpointingProvider(datahub_graph, "cli")
     job_name = StaleEntityRemovalHandler.compute_job_id(platform)
 
     raw_checkpoint = checkpoint_provider.get_latest_checkpoint(pipeline_name, job_name)
-    if raw_checkpoint is None and platform_instance is not None:
-        logger.info(
-            "Failed to fetch state, but trying legacy URN format because platform_instance is provided."
-        )
-        raw_checkpoint = checkpoint_provider.get_latest_checkpoint(
-            pipeline_name, job_name, platform_instance_id=platform_instance
-        )
 
     if not raw_checkpoint:
         click.secho("No ingestion state found.", fg="red")
         exit(1)
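With the `--platform-instance` option and the legacy-URN fallback removed, the command takes just the two remaining options. A typical invocation, assuming the standard `datahub` CLI entry point (the pipeline name is a hypothetical placeholder):

    datahub state inspect --pipeline-name my_pipeline --platform snowflake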

View File

@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, NewType, Type, TypeVar
+from typing import Any, Dict, NewType, Optional, Type, TypeVar
 
 import datahub.emitter.mce_builder as builder
 from datahub.configuration.common import ConfigModel
@@ -43,6 +43,14 @@ class IngestionCheckpointingProviderBase(StatefulCommittable[CheckpointJobStates
     def commit(self) -> None:
         pass
 
+    @abstractmethod
+    def get_latest_checkpoint(
+        self,
+        pipeline_name: str,
+        job_name: JobId,
+    ) -> Optional[DatahubIngestionCheckpointClass]:
+        pass
+
     @staticmethod
     def get_data_job_urn(
         orchestrator: str,
@@ -53,14 +61,3 @@ class IngestionCheckpointingProviderBase(StatefulCommittable[CheckpointJobStates
         Standardizes datajob urn minting for all ingestion job state providers.
         """
         return builder.make_data_job_urn(orchestrator, pipeline_name, job_name)
-
-    @staticmethod
-    def get_data_job_legacy_urn(
-        orchestrator: str,
-        pipeline_name: str,
-        job_name: JobId,
-        platform_instance_id: str,
-    ) -> str:
-        return IngestionCheckpointingProviderBase.get_data_job_urn(
-            orchestrator, f"{pipeline_name}_{platform_instance_id}", job_name
-        )
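For context on what the removed `get_data_job_legacy_urn` produced: the legacy checkpoint URN embedded the platform instance id in the data flow id, which is why checkpoints written by older versions live under a different key. A quick sketch using `make_data_job_urn`, the builder that `get_data_job_urn` delegates to (assuming the provider's orchestrator name is `datahub`; the pipeline, job, and instance names are hypothetical):

    import datahub.emitter.mce_builder as builder

    # Current key: pipeline name and job name only.
    new_urn = builder.make_data_job_urn("datahub", "my_pipeline", "my_job")
    # -> urn:li:dataJob:(urn:li:dataFlow:(datahub,my_pipeline,prod),my_job)

    # Legacy key: the platform instance id was appended to the pipeline name.
    old_urn = builder.make_data_job_urn("datahub", "my_pipeline_my_instance", "my_job")
    # -> urn:li:dataJob:(urn:li:dataFlow:(datahub,my_pipeline_my_instance,prod),my_job)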

View File

@@ -1240,6 +1240,3 @@ class GlueSource(StatefulIngestionSourceBase):
 
     def get_report(self):
         return self.report
-
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.source_config.platform_instance or self.platform

View File

@@ -434,13 +434,6 @@ class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
         else:
             return None
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        """
-        The source identifier such as the specific source host address required for stateful ingestion.
-        Individual subclasses need to override this method appropriately.
-        """
-        return f"{self.platform}"
-
     def gen_dataset_key(self, db_name: str, schema: str) -> PlatformKey:
         return BigQueryDatasetKey(
             project_id=db_name,

View File

@@ -418,8 +418,3 @@ class DBTCloudSource(DBTSourceBase):
     def get_external_url(self, node: DBTNode) -> Optional[str]:
         # TODO: Once dbt Cloud supports deep linking to specific files, we can use that.
         return f"https://cloud.getdbt.com/next/accounts/{self.config.account_id}/projects/{self.config.project_id}/develop"
-
-    def get_platform_instance_id(self) -> Optional[str]:
-        """The DBT project identifier is used as platform instance."""
-        return f"{self.platform}_{self.config.project_id}"

View File

@@ -488,16 +488,3 @@ class DBTCoreSource(DBTSourceBase):
         if self.config.git_info and node.dbt_file_path:
             return self.config.git_info.get_url_for_file_path(node.dbt_file_path)
         return None
-
-    def get_platform_instance_id(self) -> Optional[str]:
-        """The DBT project identifier is used as platform instance."""
-        project_id = (
-            self.load_file_as_json(self.config.manifest_path)
-            .get("metadata", {})
-            .get("project_id")
-        )
-        if project_id is None:
-            raise ValueError("DBT project identifier is not found in manifest")
-        return f"{self.platform}_{project_id}"

View File

@@ -318,9 +318,6 @@ class IcebergSource(StatefulIngestionSourceBase):
             ],
         }
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.config.platform_instance
-
     def get_report(self) -> SourceReport:
         return self.report

View File

@@ -187,9 +187,6 @@ class KafkaSource(StatefulIngestionSourceBase):
                 f"Failed to create Kafka Admin Client due to error {e}.",
             )
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.source_config.platform_instance
-
     @classmethod
     def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource":
         config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict)

View File

@@ -288,13 +288,6 @@ class LDAPSource(StatefulIngestionSourceBase):
                 cookie = set_cookie(self.lc, pctrls)
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        """
-        The source identifier such as the specific source host address required for stateful ingestion.
-        Individual subclasses need to override this method appropriately.
-        """
-        return self.config.ldap_server
-
     def handle_user(self, dn: str, attrs: Dict[str, Any]) -> Iterable[MetadataWorkUnit]:
         """
         Handle a DN and attributes by adding manager info and constructing a

View File

@@ -1357,8 +1357,5 @@ class LookerDashboardSource(TestableSource, StatefulIngestionSourceBase):
     def get_report(self) -> SourceReport:
         return self.reporter
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.source_config.platform_instance or self.platform
-
     def close(self):
         self.prepare_for_commit()

View File

@@ -1778,8 +1778,5 @@ class LookMLSource(StatefulIngestionSourceBase):
     def get_report(self):
         return self.reporter
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.source_config.platform_instance or self.platform
-
     def close(self):
         self.prepare_for_commit()

View File

@@ -917,9 +917,6 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
             run_id=ctx.run_id,
         )
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.source_config.platform_name
-
     @classmethod
     def create(cls, config_dict, ctx):
         config = PowerBiDashboardSourceConfig.parse_obj(config_dict)

View File

@@ -224,9 +224,6 @@ class PulsarSource(StatefulIngestionSourceBase):
                 f"An ambiguous exception occurred while handling the request: {e}"
             )
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.config.platform_instance
-
     @classmethod
     def create(cls, config_dict, ctx):
         config = PulsarSourceConfig.parse_obj(config_dict)

View File

@@ -1403,10 +1403,6 @@ class SnowflakeV2Source(
         except Exception:
             self.report.edition = None
 
-    # Stateful Ingestion Overrides.
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.config.get_account()
-
     # Ideally we do not want null values in sample data for a column.
     # However that would require separate query per column and
     # that would be expensive, hence not done.

View File

@@ -392,16 +392,6 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
     def get_schema_names(self, inspector):
         return inspector.get_schema_names()
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        """
-        The source identifier such as the specific source host address required for stateful ingestion.
-        Individual subclasses need to override this method appropriately.
-        """
-        config_dict = self.config.dict()
-        host_port = config_dict.get("host_port", "no_host_port")
-        database = config_dict.get("database", "no_database")
-        return f"{self.platform}_{host_port}_{database}"
-
     def get_allowed_schemas(self, inspector: Inspector, db_name: str) -> Iterable[str]:
         # this function returns the schema names which are filtered by schema_pattern.
         for schema in self.get_schema_names(inspector):

View File

@@ -166,7 +166,9 @@ class StatefulIngestionSourceBase(Source):
     """
 
     def __init__(
-        self, config: StatefulIngestionConfigBase, ctx: PipelineContext
+        self,
+        config: StatefulIngestionConfigBase[StatefulIngestionConfig],
+        ctx: PipelineContext,
     ) -> None:
         super().__init__(ctx)
         self.stateful_ingestion_config = config.stateful_ingestion
@@ -278,12 +280,6 @@ class StatefulIngestionSourceBase(Source):
             raise ValueError(f"No use-case handler for job_id{job_id}")
         return self._usecase_handlers[job_id].is_checkpointing_enabled()
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        # This method is retained for backwards compatibility, but it is not
-        # required that new sources implement it. We mainly need it for the
-        # fallback logic in _get_last_checkpoint.
-        raise NotImplementedError("no platform_instance_id configured")
-
     def _get_last_checkpoint(
         self, job_id: JobId, checkpoint_state_class: Type[StateType]
     ) -> Optional[Checkpoint]:
@@ -292,27 +288,14 @@ class StatefulIngestionSourceBase(Source):
         """
         last_checkpoint: Optional[Checkpoint] = None
        if self.is_stateful_ingestion_configured():
-            # TRICKY: We currently don't include the platform_instance_id in the
-            # checkpoint urn, but we previously did. As such, we need to fallback
-            # and try the old urn format if the new format doesn't return anything.
-
             # Obtain the latest checkpoint from GMS for this job.
             assert self.ctx.pipeline_name
-            last_checkpoint_aspect = self.ingestion_checkpointing_state_provider.get_latest_checkpoint(  # type: ignore
-                pipeline_name=self.ctx.pipeline_name,
-                job_name=job_id,
-            )
-            if last_checkpoint_aspect is None:
-                # Try again with the platform_instance_id, if implemented.
-                try:
-                    platform_instance_id = self.get_platform_instance_id()
-                except NotImplementedError:
-                    pass
-                else:
-                    last_checkpoint_aspect = self.ingestion_checkpointing_state_provider.get_latest_checkpoint(  # type: ignore
-                        pipeline_name=self.ctx.pipeline_name,
-                        job_name=job_id,
-                        platform_instance_id=platform_instance_id,
-                    )
+            assert self.ingestion_checkpointing_state_provider
+            last_checkpoint_aspect = (
+                self.ingestion_checkpointing_state_provider.get_latest_checkpoint(
+                    pipeline_name=self.ctx.pipeline_name,
+                    job_name=job_id,
+                )
+            )
 
         # Convert it to a first-class Checkpoint object.
@@ -355,6 +338,8 @@ class StatefulIngestionSourceBase(Source):
         # Perform validations
         if not self.is_stateful_ingestion_configured():
             return None
+        assert self.stateful_ingestion_config
+
         if (
             self.stateful_ingestion_config
             and self.stateful_ingestion_config.ignore_new_state
@@ -378,7 +363,7 @@ class StatefulIngestionSourceBase(Source):
         job_checkpoint.prepare_for_commit()
         try:
             checkpoint_aspect = job_checkpoint.to_checkpoint_aspect(
-                self.stateful_ingestion_config.max_checkpoint_state_size  # type: ignore
+                self.stateful_ingestion_config.max_checkpoint_state_size
             )
         except Exception as e:
             logger.error(
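The added `assert` statements are what let the `# type: ignore` comments go away: asserting that an Optional attribute is set narrows its type for mypy at the call sites below it. A generic illustration of the pattern, not code from this repo:

    from typing import Optional

    def checkpoint_size(max_size: Optional[int]) -> int:
        # mypy narrows Optional[int] to int after the assert,
        # so the return type checks without a "# type: ignore".
        assert max_size is not None
        return max_size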

View File

@@ -64,21 +64,15 @@ class DatahubIngestionCheckpointingProvider(IngestionCheckpointingProviderBase):
         self,
         pipeline_name: str,
         job_name: JobId,
-        platform_instance_id: Optional[str] = None,
     ) -> Optional[DatahubIngestionCheckpointClass]:
         logger.debug(
             f"Querying for the latest ingestion checkpoint for pipelineName:'{pipeline_name}',"
-            f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}'"
+            f" job_name:'{job_name}'"
         )
-        if platform_instance_id is None:
-            data_job_urn = self.get_data_job_urn(
-                self.orchestrator_name, pipeline_name, job_name
-            )
-        else:
-            data_job_urn = self.get_data_job_legacy_urn(
-                self.orchestrator_name, pipeline_name, job_name, platform_instance_id
-            )
+        data_job_urn = self.get_data_job_urn(
+            self.orchestrator_name, pipeline_name, job_name
+        )
 
         latest_checkpoint: Optional[
             DatahubIngestionCheckpointClass
@@ -92,14 +86,14 @@ class DatahubIngestionCheckpointingProvider(IngestionCheckpointingProviderBase):
         if latest_checkpoint:
             logger.debug(
                 f"The last committed ingestion checkpoint for pipelineName:'{pipeline_name}',"
-                f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found with start_time:"
+                f" job_name:'{job_name}' found with start_time:"
                 f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis/1000)}"
             )
             return latest_checkpoint
         else:
             logger.debug(
                 f"No committed ingestion checkpoint for pipelineName:'{pipeline_name}',"
-                f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found"
+                f" job_name:'{job_name}' found"
             )
             return None
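With the legacy branch gone, fetching state through the provider is a single unconditional lookup. A minimal sketch of calling it directly, mirroring what the CLI's `inspect` command does (import paths follow the repository layout at the time and should be treated as assumptions; the pipeline and job names are hypothetical):

    from datahub.ingestion.graph.client import get_default_graph
    from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
        DatahubIngestionCheckpointingProvider,
    )

    graph = get_default_graph()
    provider = DatahubIngestionCheckpointingProvider(graph, "cli")

    # The checkpoint key is now just (pipeline_name, job_name).
    checkpoint = provider.get_latest_checkpoint(
        pipeline_name="my_pipeline",
        job_name="my_job",
    )
    if checkpoint is None:
        print("No ingestion state found.")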

View File

@@ -2264,6 +2264,3 @@ class TableauSource(StatefulIngestionSourceBase):
 
     def get_report(self) -> StaleEntityRemovalSourceReport:
         return self.report
-
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.config.platform_instance or self.platform

View File

@@ -160,9 +160,6 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         config = UnityCatalogSourceConfig.parse_obj(config_dict)
         return cls(ctx=ctx, config=config)
 
-    def get_platform_instance_id(self) -> Optional[str]:
-        return self.config.platform_instance or self.platform
-
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         return auto_stale_entity_removal(
             self.stale_entity_removal_handler,