chore(ingest): remove pickle from stateful ingestion (#14531)

This commit is contained in:
Harshal Sheth 2025-08-21 16:12:47 -07:00 committed by GitHub
parent d64d296639
commit e6ac57f465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 35 deletions

View File

@ -1,10 +1,8 @@
import base64 import base64
import bz2 import bz2
import contextlib
import functools import functools
import json import json
import logging import logging
import pickle
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Callable, Generic, Optional, Type, TypeVar from typing import Callable, Generic, Optional, Type, TypeVar
@ -117,10 +115,9 @@ class Checkpoint(Generic[StateType]):
checkpoint_aspect, state_class checkpoint_aspect, state_class
) )
elif checkpoint_aspect.state.serde == "base85": elif checkpoint_aspect.state.serde == "base85":
state_obj = Checkpoint._from_base85_bytes( raise ValueError(
checkpoint_aspect, "The base85 encoding for stateful ingestion has been removed for security reasons. "
functools.partial(bz2.decompress), "You may need to temporarily set `ignore_previous_checkpoint` to true to ignore the outdated checkpoint object."
state_class,
) )
elif checkpoint_aspect.state.serde == "base85-bz2-json": elif checkpoint_aspect.state.serde == "base85-bz2-json":
state_obj = Checkpoint._from_base85_json_bytes( state_obj = Checkpoint._from_base85_json_bytes(
@ -164,28 +161,6 @@ class Checkpoint(Generic[StateType]):
state_as_dict["serde"] = checkpoint_aspect.state.serde state_as_dict["serde"] = checkpoint_aspect.state.serde
return state_class.parse_obj(state_as_dict) return state_class.parse_obj(state_as_dict)
@staticmethod
def _from_base85_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass,
decompressor: Callable[[bytes], bytes],
state_class: Type[StateType],
) -> StateType:
state: StateType = pickle.loads(
decompressor(base64.b85decode(checkpoint_aspect.state.payload)) # type: ignore
)
with contextlib.suppress(Exception):
# When loading from pickle, the pydantic validators don't run.
# By re-serializing and re-parsing, we ensure that the state is valid.
# However, we also suppress any exceptions to make sure this doesn't blow up.
state = state_class.parse_obj(state.dict())
# Because the base85 method is deprecated in favor of base85-bz2-json,
# we will automatically switch the serde.
state.serde = "base85-bz2-json"
return state
@staticmethod @staticmethod
def _from_base85_json_bytes( def _from_base85_json_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass, checkpoint_aspect: DatahubIngestionCheckpointClass,

View File

@ -158,19 +158,22 @@ def test_supported_encodings():
test_serde_idempotence(test_state) test_serde_idempotence(test_state)
def test_base85_upgrade_pickle_to_json(): def test_base85_is_removed():
"""Verify that base85 (pickle) encoding is transitioned to base85-bz2-json.""" """Verify that base85 encoding throws an error."""
base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra" base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra"
checkpoint_state = IngestionCheckpointStateClass( checkpoint_state = IngestionCheckpointStateClass(
formatVersion="1.0", serde="base85", payload=base85_payload formatVersion="1.0", serde="base85", payload=base85_payload
) )
checkpoint = _assert_checkpoint_deserialization( with pytest.raises(
checkpoint_state, _checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"] ValueError,
match=r"base85 encoding.*removed",
):
_assert_checkpoint_deserialization(
checkpoint_state,
_checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"],
) )
assert checkpoint.state.serde == "base85-bz2-json"
assert len(checkpoint.state.to_bytes()) < len(base85_payload)
@pytest.mark.parametrize( @pytest.mark.parametrize(