2023-11-30 18:11:36 -05:00

205 lines
7.0 KiB
Python

from datetime import datetime, timezone
from typing import Dict, List
import pydantic
import pytest
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
from datahub.ingestion.source.state.usage_common_state import (
BaseTimeWindowCheckpointState,
)
from datahub.metadata.schema_classes import (
DatahubIngestionCheckpointClass,
IngestionCheckpointStateClass,
)
# 1. Setup common test param values.
test_pipeline_name: str = "test_pipeline"
test_job_name: str = "test_job_1"
test_run_id: str = "test_run_1"
def _assert_checkpoint_deserialization(
serialized_checkpoint_state: IngestionCheckpointStateClass,
expected_checkpoint_state: CheckpointStateBase,
) -> Checkpoint:
# Serialize a checkpoint aspect with the previous state.
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.now(tz=timezone.utc).timestamp() * 1000),
pipelineName=test_pipeline_name,
platformInstanceId="this-can-be-anything-and-will-be-ignored",
config="this-is-also-ignored",
state=serialized_checkpoint_state,
runId=test_run_id,
)
# 2. Create the checkpoint from the raw checkpoint aspect and validate.
checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(expected_checkpoint_state),
)
expected_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
run_id=test_run_id,
state=expected_checkpoint_state,
)
assert checkpoint_obj == expected_checkpoint_obj
return checkpoint_obj
# 2. Create the params for parametrized tests.
def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState:
# Note that the urns here purposely use a lowercase env, even though it's
# technically incorrect. This is purely for backwards compatibility testing, but
# all existing code uses correctly formed envs.
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="table", urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.t1,prod)"
)
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="view", urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.v1,prod)"
)
return base_sql_alchemy_checkpoint_state_obj
def _make_usage_checkpoint_state() -> BaseTimeWindowCheckpointState:
base_usage_checkpoint_state_obj = BaseTimeWindowCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
return base_usage_checkpoint_state_obj
_checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = {
# An instance of BaseSQLAlchemyCheckpointState.
"BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(),
# An instance of BaseTimeWindowCheckpointState.
"BaseTimeWindowCheckpointState": _make_usage_checkpoint_state(),
}
# 3. Define the test with the params
@pytest.mark.parametrize(
"state_obj",
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_checkpoint_serde(state_obj: CheckpointStateBase) -> None:
"""
Tests CheckpointStateBase.to_bytes() and Checkpoint.create_from_checkpoint_aspect().
"""
# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)
_assert_checkpoint_deserialization(checkpoint_state, state_obj)
@pytest.mark.parametrize(
"state_obj",
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_serde_idempotence(state_obj):
"""
Verifies that Serialization + Deserialization reconstructs the original object fully.
"""
# 1. Construct the initial checkpoint object
orig_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
run_id=test_run_id,
state=state_obj,
)
# 2. Convert it to the aspect form.
checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect(
max_allowed_state_size=2**20
)
assert checkpoint_aspect is not None
# 3. Reconstruct from the aspect form and verify that it matches the original.
serde_checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
)
assert orig_checkpoint_obj == serde_checkpoint_obj
def test_supported_encodings():
"""
Tests utf-8 and base85-bz2-json encodings
"""
test_state = BaseTimeWindowCheckpointState(
version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
# 1. Test UTF-8 encoding
test_state.serde = "utf-8"
test_serde_idempotence(test_state)
# 2. Test Base85 encoding
test_state.serde = "base85-bz2-json"
test_serde_idempotence(test_state)
def test_base85_upgrade_pickle_to_json():
"""Verify that base85 (pickle) encoding is transitioned to base85-bz2-json."""
base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra"
checkpoint_state = IngestionCheckpointStateClass(
formatVersion="1.0", serde="base85", payload=base85_payload
)
checkpoint = _assert_checkpoint_deserialization(
checkpoint_state, _checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"]
)
assert checkpoint.state.serde == "base85-bz2-json"
assert len(checkpoint.state.to_bytes()) < len(base85_payload)
@pytest.mark.parametrize(
"serde",
["utf-8", "base85-bz2-json"],
)
def test_state_forward_compatibility(serde: str) -> None:
class PrevState(CheckpointStateBase):
list_a: List[str]
list_b: List[str]
class NextState(CheckpointStateBase):
list_stuff: List[str]
@pydantic.root_validator(pre=True, allow_reuse=True)
def _migrate(cls, values: dict) -> dict:
values.setdefault("list_stuff", [])
values["list_stuff"] += values.pop("list_a", [])
values["list_stuff"] += values.pop("list_b", [])
return values
prev_state = PrevState(list_a=["a", "b"], list_b=["c", "d"], serde=serde)
expected_next_state = NextState(list_stuff=["a", "b", "c", "d"], serde=serde)
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=prev_state.version,
serde=prev_state.serde,
payload=prev_state.to_bytes(),
)
_assert_checkpoint_deserialization(checkpoint_state, expected_next_state)