# datahub/smoke-test/tests/test_stateful_ingestion.py

from typing import Any, Dict, Optional, cast

from sqlalchemy import create_engine
from sqlalchemy.sql import text

from datahub.ingestion.api.committable import StatefulCommittable
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
)

from tests.utils import get_mysql_password, get_mysql_url, get_mysql_username


def test_stateful_ingestion(auth_session):
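    """
    Smoke test for stateful ingestion with stale metadata removal.

    Ingests two MySQL tables into DataHub, drops one, re-runs the pipeline,
    and verifies that the dropped table's dataset URN appears in the diff
    between the two checkpoint states and that all state providers committed
    successfully.
    """
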
    def create_mysql_engine(mysql_source_config_dict: Dict[str, Any]) -> Any:
        mysql_config = MySQLConfig.parse_obj(mysql_source_config_dict)
        url = mysql_config.get_sql_alchemy_url()
        return create_engine(url)

    def create_table(engine: Any, name: str, defn: str) -> None:
        create_table_query = text(f"CREATE TABLE IF NOT EXISTS {name}{defn};")
        engine.execute(create_table_query)

    def drop_table(engine: Any, table_name: str) -> None:
        drop_table_query = text(f"DROP TABLE {table_name};")
        engine.execute(drop_table_query)

    def run_and_get_pipeline(pipeline_config_dict: Dict[str, Any]) -> Pipeline:
        pipeline = Pipeline.create(pipeline_config_dict)
        pipeline.run()
        pipeline.raise_from_status()
        return pipeline
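
    # The pipeline is expected to register exactly one committable provider,
    # and it must be a StatefulCommittable that committed successfully and
    # has non-empty state to commit.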
    def validate_all_providers_have_committed_successfully(pipeline: Pipeline) -> None:
        provider_count: int = 0
        for _, provider in pipeline.ctx.get_committables():
            provider_count += 1
            assert isinstance(provider, StatefulCommittable)
            stateful_committable = cast(StatefulCommittable, provider)
            assert stateful_committable.has_successfully_committed()
            assert stateful_committable.state_to_commit
        assert provider_count == 1
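
    # Fetches the checkpoint of the stale-entity-removal job from the source's
    # state provider, so the states of the two pipeline runs can be compared.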
    def get_current_checkpoint_from_pipeline(
        auth_session,
        pipeline: Pipeline,
    ) -> Optional[Checkpoint[GenericCheckpointState]]:
        # TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
        mysql_source = cast(MySQLSource, pipeline.source)
        return mysql_source.state_provider.get_current_checkpoint(
            StaleEntityRemovalHandler.compute_job_id(
                getattr(mysql_source, "platform", "default")
            )
        )
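
    # MySQL source config with stateful ingestion and stale metadata removal
    # enabled, so entities dropped between runs are detected as stale.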
    source_config_dict: Dict[str, Any] = {
        "host_port": get_mysql_url(),
        "username": get_mysql_username(),
        "password": get_mysql_password(),
        "database": "datahub",
        "stateful_ingestion": {
            "enabled": True,
            "remove_stale_metadata": True,
            "fail_safe_threshold": 100.0,
        },
    }
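
    # Pipeline config: ingest from MySQL into the test GMS via the
    # datahub-rest sink, with DataHub reporting enabled.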
    pipeline_config_dict: Dict[str, Any] = {
        "source": {
            "type": "mysql",
            "config": source_config_dict,
        },
        "sink": {
            "type": "datahub-rest",
            "config": {
                "server": auth_session.gms_url(),
                "token": auth_session.gms_token(),
            },
        },
        "pipeline_name": "mysql_stateful_ingestion_smoke_test_pipeline",
        "reporting": [
            {
                "type": "datahub",
            }
        ],
    }

    # 1. Set up the SQL engine.
    mysql_engine = create_mysql_engine(source_config_dict)

    # 2. Create test tables for the first run of the pipeline.
    table_prefix = "stateful_ingestion_test"
    table_defs = {
        f"{table_prefix}_t1": "(id INT, name VARCHAR(10))",
        f"{table_prefix}_t2": "(id INT)",
    }
    table_names = sorted(table_defs.keys())
    for table_name, defn in table_defs.items():
        create_table(mysql_engine, table_name, defn)

    # 3. Do the first run of the pipeline and get the default job's checkpoint.
    pipeline_run1 = run_and_get_pipeline(pipeline_config_dict)
    checkpoint1 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run1)
    assert checkpoint1
    assert checkpoint1.state

    # 4. Drop table t1 (created in step 2), rerun the pipeline, and get the new checkpoint state.
    drop_table(mysql_engine, table_names[0])
    pipeline_run2 = run_and_get_pipeline(pipeline_config_dict)
    checkpoint2 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run2)
    assert checkpoint2
    assert checkpoint2.state

    # 5. Perform all assertions on the states.
    state1 = checkpoint1.state
    state2 = checkpoint2.state
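    # URNs present in the first run's state but missing from the second
    # correspond to entities removed between runs; only the dropped table t1
    # should show up in the diff.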
    difference_urns = list(
        state1.get_urns_not_in(type="*", other_checkpoint_state=state2)
    )
    assert len(difference_urns) == 1
    assert (
        difference_urns[0]
        == "urn:li:dataset:(urn:li:dataPlatform:mysql,datahub.stateful_ingestion_test_t1,PROD)"
    )

    # 6. Clean up table t2 as well, so leftover test data doesn't affect other tests in the smoke-test environment.
    drop_table(mysql_engine, table_names[1])

    # 7. Validate that all providers have committed successfully.
    # NOTE: This also asserts that checkpoint state is present and validates reporting.
    validate_all_providers_have_committed_successfully(pipeline_run1)
    validate_all_providers_have_committed_successfully(pipeline_run2)