Mirror of https://github.com/datahub-project/datahub.git (synced 2025-12-12 10:35:51 +00:00)
feat(ingest): replace base85's pickle with json (#6178)
This commit is contained in: parent 034f2e9ff3, commit d08f5f7cdd
@@ -18,6 +18,8 @@ from datahub.metadata.schema_classes import (
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+DEFAULT_MAX_STATE_SIZE = 2**22  # 4MB
+
 
 class CheckpointStateBase(ConfigModel):
     """
@@ -28,17 +30,14 @@ class CheckpointStateBase(ConfigModel):
     """
 
     version: str = pydantic.Field(default="1.0")
-    serde: str = pydantic.Field(default="base85")
+    serde: str = pydantic.Field(default="base85-bz2-json")
 
     def to_bytes(
         self,
         compressor: Callable[[bytes], bytes] = functools.partial(
             bz2.compress, compresslevel=9
         ),
-        # fmt: off
-        # 4 MB
-        max_allowed_state_size: int = 2**22,
-        # fmt: on
+        max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE,
     ) -> bytes:
         """
         NOTE: Binary compression cannot be turned on yet as the current MCPs encode the GeneralizedAspect
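A note on the new default serde name: "base85-bz2-json" spells out the write pipeline end to end. Below is a minimal standard-library sketch of that pipeline; encode_state and decode_state are illustrative helper names, not part of the DataHub API.

import base64
import bz2
import json


def encode_state(state_as_dict: dict, compresslevel: int = 9) -> bytes:
    # JSON-serialize, bz2-compress, then base85-encode the checkpoint state.
    raw = json.dumps(state_as_dict).encode("utf-8")
    return base64.b85encode(bz2.compress(raw, compresslevel=compresslevel))


def decode_state(payload: bytes) -> dict:
    # Inverse pipeline: base85-decode, bz2-decompress, then JSON-parse.
    return json.loads(bz2.decompress(base64.b85decode(payload)).decode("utf-8"))


sample = {"filenames": ["a.csv", "b.csv"]}
assert decode_state(encode_state(sample)) == sample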
@@ -50,7 +49,13 @@ class CheckpointStateBase(ConfigModel):
         if self.serde == "utf-8":
             encoded_bytes = CheckpointStateBase._to_bytes_utf8(self)
         elif self.serde == "base85":
-            encoded_bytes = CheckpointStateBase._to_bytes_base85(self, compressor)
+            # The original base85 implementation used pickle, which would cause
+            # issues with deserialization if we ever changed the state class definition.
+            raise ValueError(
+                "Cannot write base85 encoded bytes. Use base85-bz2-json instead."
+            )
+        elif self.serde == "base85-bz2-json":
+            encoded_bytes = CheckpointStateBase._to_bytes_base85_json(self, compressor)
         else:
             raise ValueError(f"Unknown serde: {self.serde}")
 
@@ -66,10 +71,10 @@ class CheckpointStateBase(ConfigModel):
         return model.json(exclude={"version", "serde"}).encode("utf-8")
 
     @staticmethod
-    def _to_bytes_base85(
+    def _to_bytes_base85_json(
         model: ConfigModel, compressor: Callable[[bytes], bytes]
     ) -> bytes:
-        return base64.b85encode(compressor(pickle.dumps(model)))
+        return base64.b85encode(compressor(CheckpointStateBase._to_bytes_utf8(model)))
 
     def prepare_for_commit(self) -> None:
         """
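The reason for moving off pickle: a pickled payload can only be reloaded against the exact class definition that produced it, whereas a JSON payload carries only field values and tolerates schema changes. A rough sketch of the difference with pydantic v1-style models; StateV1 and StateV2 are hypothetical names, not DataHub classes.

import json

import pydantic


class StateV1(pydantic.BaseModel):
    urns: list = []


class StateV2(pydantic.BaseModel):
    # Hypothetical later revision of the same checkpoint state class.
    urns: list = []
    last_run_id: str = ""


# A JSON payload written by StateV1 still parses into StateV2, because it
# only contains field values; a pickle of a StateV1 instance would instead be
# bound to the old class definition.
payload = StateV1(urns=["urn:li:dataset:x"]).json()
restored = StateV2.parse_obj(json.loads(payload))
assert restored.urns == ["urn:li:dataset:x"]
assert restored.last_run_id == ""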
@@ -125,6 +130,12 @@ class Checkpoint(Generic[StateType]):
                 state_obj = Checkpoint._from_base85_bytes(
                     checkpoint_aspect, functools.partial(bz2.decompress)
                 )
+            elif checkpoint_aspect.state.serde == "base85-bz2-json":
+                state_obj = Checkpoint._from_base85_json_bytes(
+                    checkpoint_aspect,
+                    functools.partial(bz2.decompress),
+                    state_class,
+                )
             else:
                 raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}")
         except Exception as e:
@@ -167,10 +178,32 @@ class Checkpoint(Generic[StateType]):
         checkpoint_aspect: DatahubIngestionCheckpointClass,
         decompressor: Callable[[bytes], bytes],
     ) -> StateType:
-        return pickle.loads(
+        state: StateType = pickle.loads(
             decompressor(base64.b85decode(checkpoint_aspect.state.payload))  # type: ignore
         )
+
+        # Because the base85 method is deprecated in favor of base85-bz2-json,
+        # we will automatically switch the serde.
+        state.serde = "base85-bz2-json"
+
+        return state
+
+    @staticmethod
+    def _from_base85_json_bytes(
+        checkpoint_aspect: DatahubIngestionCheckpointClass,
+        decompressor: Callable[[bytes], bytes],
+        state_class: Type[StateType],
+    ) -> StateType:
+        state_uncompressed = decompressor(
+            base64.b85decode(checkpoint_aspect.state.payload)
+            if checkpoint_aspect.state.payload is not None
+            else b"{}"
+        )
+        state_as_dict = json.loads(state_uncompressed.decode("utf-8"))
+        state_as_dict["version"] = checkpoint_aspect.state.formatVersion
+        state_as_dict["serde"] = checkpoint_aspect.state.serde
+        return state_class.parse_obj(state_as_dict)
 
     def to_checkpoint_aspect(
         self, max_allowed_state_size: int
     ) -> Optional[DatahubIngestionCheckpointClass]:
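The read path mirrors the write path: decode base85, decompress with bz2, parse the JSON, then re-attach version and serde, which to_bytes() excluded from the payload. Below is a self-contained approximation of that logic; from_base85_json is an illustrative name, and the missing-payload handling here is simplified rather than copied from the implementation.

import base64
import bz2
import json
from typing import Optional, Type, TypeVar

import pydantic

StateT = TypeVar("StateT", bound=pydantic.BaseModel)


def from_base85_json(
    payload: Optional[bytes],
    state_class: Type[StateT],
    format_version: str,
    serde: str,
) -> StateT:
    # A missing payload is treated as an empty JSON object; otherwise decode
    # base85, decompress with bz2, and parse the JSON back into a dict.
    if payload is None:
        state_as_dict: dict = {}
    else:
        raw = bz2.decompress(base64.b85decode(payload))
        state_as_dict = json.loads(raw.decode("utf-8"))
    # version and serde were excluded when the payload was written, so they
    # are restored from the surrounding checkpoint aspect before parsing.
    state_as_dict["version"] = format_version
    state_as_dict["serde"] = serde
    return state_class.parse_obj(state_as_dict)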
@@ -1,6 +1,7 @@
 from datetime import datetime
-from typing import Dict
+from typing import Dict, List
 
+import pydantic
 import pytest
 
 from datahub.emitter.mce_builder import make_dataset_urn
@@ -23,54 +24,18 @@ test_job_name: str = "test_job_1"
 test_run_id: str = "test_run_1"
 test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")
 
-# 2. Create the params for parametrized tests.
-
-# 2.1 Create and add an instance of BaseSQLAlchemyCheckpointState.
-test_checkpoint_serde_params: Dict[str, CheckpointStateBase] = {}
-base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
-base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
-    type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
-)
-base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
-    type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
-)
-test_checkpoint_serde_params[
-    "BaseSQLAlchemyCheckpointState"
-] = base_sql_alchemy_checkpoint_state_obj
-
-# 2.2 Create and add an instance of BaseUsageCheckpointState.
-base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
-    version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
-)
-test_checkpoint_serde_params[
-    "BaseUsageCheckpointState"
-] = base_usage_checkpoint_state_obj
-
-
-# 3. Define the test with the params
-
-
-@pytest.mark.parametrize(
-    "state_obj",
-    test_checkpoint_serde_params.values(),
-    ids=test_checkpoint_serde_params.keys(),
-)
-def test_create_from_checkpoint_aspect(state_obj):
-    """
-    Tests the Checkpoint class API 'create_from_checkpoint_aspect' with the state_obj parameter as the state.
-    """
-    # 1. Construct the raw aspect object with the state
-    checkpoint_state = IngestionCheckpointStateClass(
-        formatVersion=state_obj.version,
-        serde=state_obj.serde,
-        payload=state_obj.to_bytes(),
-    )
+
+def _assert_checkpoint_deserialization(
+    serialized_checkpoint_state: IngestionCheckpointStateClass,
+    expected_checkpoint_state: CheckpointStateBase,
+) -> Checkpoint:
+    # Serialize a checkpoint aspect with the previous state.
     checkpoint_aspect = DatahubIngestionCheckpointClass(
         timestampMillis=int(datetime.utcnow().timestamp() * 1000),
         pipelineName=test_pipeline_name,
         platformInstanceId=test_platform_instance_id,
         config=test_source_config.json(),
-        state=checkpoint_state,
+        state=serialized_checkpoint_state,
         runId=test_run_id,
     )
@@ -78,7 +43,7 @@ def test_create_from_checkpoint_aspect(state_obj):
     checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
         job_name=test_job_name,
         checkpoint_aspect=checkpoint_aspect,
-        state_class=type(state_obj),
+        state_class=type(expected_checkpoint_state),
         config_class=PostgresConfig,
     )
 
@@ -88,15 +53,69 @@ def test_create_from_checkpoint_aspect(state_obj):
         platform_instance_id=test_platform_instance_id,
         run_id=test_run_id,
         config=test_source_config,
-        state=state_obj,
+        state=expected_checkpoint_state,
     )
     assert checkpoint_obj == expected_checkpoint_obj
 
+    return checkpoint_obj
+
+
+# 2. Create the params for parametrized tests.
+
+
+def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState:
+    base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
+    base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
+        type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
+    )
+    base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
+        type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
+    )
+    return base_sql_alchemy_checkpoint_state_obj
+
+
+def _make_usage_checkpoint_state() -> BaseUsageCheckpointState:
+    base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
+        version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
+    )
+    return base_usage_checkpoint_state_obj
+
+
+_checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = {
+    # An instance of BaseSQLAlchemyCheckpointState.
+    "BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(),
+    # An instance of BaseUsageCheckpointState.
+    "BaseUsageCheckpointState": _make_usage_checkpoint_state(),
+}
+
+
+# 3. Define the test with the params
+
+
 @pytest.mark.parametrize(
     "state_obj",
-    test_checkpoint_serde_params.values(),
-    ids=test_checkpoint_serde_params.keys(),
+    _checkpoint_aspect_test_cases.values(),
+    ids=_checkpoint_aspect_test_cases.keys(),
 )
+def test_checkpoint_serde(state_obj: CheckpointStateBase) -> None:
+    """
+    Tests CheckpointStateBase.to_bytes() and Checkpoint.create_from_checkpoint_aspect().
+    """
+
+    # 1. Construct the raw aspect object with the state
+    checkpoint_state = IngestionCheckpointStateClass(
+        formatVersion=state_obj.version,
+        serde=state_obj.serde,
+        payload=state_obj.to_bytes(),
+    )
+
+    _assert_checkpoint_deserialization(checkpoint_state, state_obj)
+
+
+@pytest.mark.parametrize(
+    "state_obj",
+    _checkpoint_aspect_test_cases.values(),
+    ids=_checkpoint_aspect_test_cases.keys(),
+)
 def test_serde_idempotence(state_obj):
     """
@@ -114,9 +133,7 @@ def test_serde_idempotence(state_obj):
 
     # 2. Convert it to the aspect form.
     checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect(
-        # fmt: off
         max_allowed_state_size=2**20
-        # fmt: on
     )
     assert checkpoint_aspect is not None
 
@@ -132,7 +149,7 @@ def test_serde_idempotence(state_obj):
 
 def test_supported_encodings():
     """
-    Tests utf-8 and base85 encodings
+    Tests utf-8 and base85-bz2-json encodings
     """
     test_state = BaseUsageCheckpointState(
         version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100
@@ -143,5 +160,51 @@ def test_supported_encodings():
     test_serde_idempotence(test_state)
 
     # 2. Test Base85 encoding
-    test_state.serde = "base85"
+    test_state.serde = "base85-bz2-json"
     test_serde_idempotence(test_state)
+
+
+def test_base85_upgrade_pickle_to_json():
+    """Verify that base85 (pickle) encoding is transitioned to base85-bz2-json."""
+
+    base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra"
+    checkpoint_state = IngestionCheckpointStateClass(
+        formatVersion="1.0", serde="base85", payload=base85_payload
+    )
+
+    checkpoint = _assert_checkpoint_deserialization(
+        checkpoint_state, _checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"]
+    )
+    assert checkpoint.state.serde == "base85-bz2-json"
+    assert len(checkpoint.state.to_bytes()) < len(base85_payload)
+
+
+@pytest.mark.parametrize(
+    "serde",
+    ["utf-8", "base85-bz2-json"],
+)
+def test_state_forward_compatibility(serde: str) -> None:
+    class PrevState(CheckpointStateBase):
+        list_a: List[str]
+        list_b: List[str]
+
+    class NextState(CheckpointStateBase):
+        list_stuff: List[str]
+
+        @pydantic.root_validator(pre=True, allow_reuse=True)
+        def _migrate(cls, values: dict) -> dict:
+            values.setdefault("list_stuff", [])
+            values["list_stuff"] += values.pop("list_a", [])
+            values["list_stuff"] += values.pop("list_b", [])
+            return values
+
+    prev_state = PrevState(list_a=["a", "b"], list_b=["c", "d"], serde=serde)
+    expected_next_state = NextState(list_stuff=["a", "b", "c", "d"], serde=serde)
+
+    checkpoint_state = IngestionCheckpointStateClass(
+        formatVersion=prev_state.version,
+        serde=prev_state.serde,
+        payload=prev_state.to_bytes(),
+    )
+
+    _assert_checkpoint_deserialization(checkpoint_state, expected_next_state)
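The forward-compatibility test above works because the payload is now plain JSON: a newer state class can merge or rename fields by rewriting the raw dict in a pre-validation hook before field validation runs. A standalone sketch of the same pattern, assuming pydantic v1-style validators; MigratedState and the old_items_* field names are hypothetical.

from typing import List

import pydantic


class MigratedState(pydantic.BaseModel):
    items: List[str] = []

    @pydantic.root_validator(pre=True, allow_reuse=True)
    def _merge_legacy_fields(cls, values: dict) -> dict:
        # Fold fields written by an older revision of the state class into
        # the new field before normal per-field validation runs.
        values.setdefault("items", [])
        values["items"] += values.pop("old_items_a", [])
        values["items"] += values.pop("old_items_b", [])
        return values


# A JSON payload produced by the older revision still parses cleanly.
old_payload = {"old_items_a": ["x"], "old_items_b": ["y"]}
assert MigratedState.parse_obj(old_payload).items == ["x", "y"]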