feat(actions): support pydantic v2 (#13378)

Harshal Sheth 2025-04-30 19:39:35 -07:00 committed by GitHub
parent d25d318233
commit 591b6ce0c9
15 changed files with 85 additions and 149 deletions

View File

@@ -6,16 +6,20 @@ on:
paths:
- ".github/workflows/python-build-pages.yml"
- "metadata-ingestion/**"
- "datahub-actions/**"
- "metadata-ingestion-modules/**"
- "metadata-models/**"
- "python-build/**"
pull_request:
branches:
- "**"
paths:
- ".github/workflows/python-build-pages.yml"
- "metadata-ingestion/**"
- "datahub-actions/**"
- "metadata-ingestion-modules/**"
- "metadata-models/**"
- "python-build/**"
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

View File

@@ -101,33 +101,13 @@ task installDevTest(type: Exec, dependsOn: [installDev]) {
"touch ${sentinel_file}"
}
task testQuick(type: Exec, dependsOn: installDevTest) {
// We can't enforce the coverage requirements if we run a subset of the tests.
task testFull(type: Exec, dependsOn: installDevTest) {
inputs.files(project.fileTree(dir: "src/", include: "**/*.py"))
inputs.files(project.fileTree(dir: "tests/"))
outputs.dir("${venv_name}")
commandLine 'bash', '-c',
"source ${venv_name}/bin/activate && set -x && " +
"pytest -vv ${get_coverage_args('quick')} --continue-on-collection-errors --junit-xml=junit.quick.xml"
}
def testFile = hasProperty('testFile') ? testFile : 'unknown'
task testSingle(dependsOn: [installDevTest]) {
doLast {
if (testFile != 'unknown') {
exec {
commandLine 'bash', '-x', '-c',
"source ${venv_name}/bin/activate && pytest ${testFile}"
}
} else {
throw new GradleException("No file provided. Use -PtestFile=<test_file>")
}
}
}
task testFull(type: Exec, dependsOn: [testQuick, installDevTest]) {
commandLine 'bash', '-x', '-c',
"source ${venv_name}/bin/activate && pytest -vv ${get_coverage_args('full')} --continue-on-collection-errors --junit-xml=junit.full.xml"
"pytest -vv ${get_coverage_args('full')} --continue-on-collection-errors --junit-xml=junit.full.xml"
}
task buildWheel(type: Exec, dependsOn: [environmentSetup]) {
@@ -172,7 +152,7 @@ docker {
build.dependsOn install
check.dependsOn lint
check.dependsOn testQuick
check.dependsOn testFull
clean {
delete venv_name

View File

@@ -21,6 +21,13 @@ package_metadata: dict = {}
with open("./src/datahub_actions/_version.py") as fp:
exec(fp.read(), package_metadata)
_version: str = package_metadata["__version__"]
_self_pin = (
f"=={_version}"
if not (_version.endswith(("dev0", "dev1")) or "docker" in _version)
else ""
)
def get_long_description():
root = os.path.dirname(__file__)
@@ -30,8 +37,6 @@ def get_long_description():
return description
acryl_datahub_min_version = os.environ.get("ACRYL_DATAHUB_MIN_VERSION") or "1.0.0"
lint_requirements = {
# This is pinned only to avoid spurious errors in CI.
# We should make an effort to keep it up to date.
@@ -40,18 +45,17 @@ lint_requirements = {
}
base_requirements = {
*lint_requirements,
f"acryl-datahub[datahub-kafka]>={acryl_datahub_min_version}",
f"acryl-datahub[datahub-kafka]{_self_pin}",
# Compatibility.
"typing_extensions>=3.7.4; python_version < '3.8'",
"mypy_extensions>=0.4.3",
# Actual dependencies.
"typing-inspect",
"pydantic<2",
"dictdiffer",
"pydantic>=1.10.21",
"ratelimit",
# Lower bounds on httpcore and h11 due to CVE-2025-43859.
"httpcore>=1.0.9",
"h11>=0.16"
"h11>=0.16",
}
framework_common = {
@@ -67,14 +71,6 @@ framework_common = {
"tenacity",
}
aws_common = {
# AWS Python SDK
"boto3",
# Deal with a version incompatibility between botocore (used by boto3) and urllib3.
# See https://github.com/boto/botocore/pull/2563.
"botocore!=1.23.0",
}
# Note: for all of these, framework_common will be added.
plugins: Dict[str, Set[str]] = {
# Source Plugins
@@ -94,7 +90,7 @@ plugins: Dict[str, Set[str]] = {
"tag_propagation": set(),
"term_propagation": set(),
"snowflake_tag_propagation": {
f"acryl-datahub[snowflake]>={acryl_datahub_min_version}"
f"acryl-datahub[snowflake-slim]{_self_pin}",
},
"doc_propagation": set(),
# Transformer Plugins (None yet)
@@ -115,10 +111,10 @@ mypy_stubs = {
"types-cachetools",
# versions 0.1.13 and 0.1.14 seem to have issues
"types-click==0.1.12",
"boto3-stubs[s3,glue,sagemaker]",
}
base_dev_requirements = {
*lint_requirements,
*base_requirements,
*framework_common,
*mypy_stubs,
@@ -169,6 +165,9 @@ full_test_dev_requirements = {
]
for dependency in plugins[plugin]
),
# In our tests, we want to always test against pydantic v2.
# However, we maintain compatibility with pydantic v1 for now.
"pydantic>2",
}
entry_points = {
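Note on the dependency changes above: the new _self_pin ties the acryl-datahub requirement to the exact version of datahub-actions being built, except for dev and docker builds where it stays unpinned. A minimal sketch of how the expression resolves (the version strings below are made up for illustration):

    def self_pin(version: str) -> str:
        # Mirrors the _self_pin expression in datahub-actions/setup.py above.
        if version.endswith(("dev0", "dev1")) or "docker" in version:
            return ""
        return f"=={version}"

    assert self_pin("1.1.0") == "==1.1.0"    # release build: pin acryl-datahub exactly
    assert self_pin("1.1.0.dev0") == ""      # dev build: leave acryl-datahub unpinned
    assert self_pin("1.1.0+docker") == ""    # docker build: leave unpinned

The test-only "pydantic>2" entry means CI exercises pydantic v2, while the runtime requirement stays at "pydantic>=1.10.21" to preserve v1 compatibility.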

View File

@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
from datahub.configuration import ConfigModel
from datahub.configuration.common import ConfigEnum
from datahub.ingestion.graph.client import DatahubClientConfig
class FailureMode(str, Enum):
class FailureMode(ConfigEnum):
# Log the failed event to the failed events log. Then throw a pipeline exception to stop the pipeline.
THROW = "THROW"
# Log the failed event to the failed events log. Then continue processing the event stream.
@@ -30,17 +30,17 @@ class FailureMode(str, Enum):
class SourceConfig(ConfigModel):
type: str
config: Optional[Dict[str, Any]]
config: Optional[Dict[str, Any]] = None
class TransformConfig(ConfigModel):
type: str
config: Optional[Dict[str, Any]]
config: Optional[Dict[str, Any]] = None
class FilterConfig(ConfigModel):
event_type: Union[str, List[str]]
event: Optional[Dict[str, Any]]
event: Optional[Dict[str, Any]] = None
class ActionConfig(ConfigModel):
@@ -49,12 +49,11 @@ class ActionConfig(ConfigModel):
class PipelineOptions(BaseModel):
retry_count: Optional[int]
failure_mode: Optional[FailureMode]
failed_events_dir: Optional[str] # The path where failed events should be logged.
class Config:
use_enum_values = True
retry_count: Optional[int] = None
failure_mode: Optional[FailureMode] = None
failed_events_dir: Optional[str] = (
None # The path where failed events should be logged.
)
class PipelineConfig(ConfigModel):
@@ -68,8 +67,8 @@ class PipelineConfig(ConfigModel):
name: str
enabled: bool = True
source: SourceConfig
filter: Optional[FilterConfig]
transform: Optional[List[TransformConfig]]
filter: Optional[FilterConfig] = None
transform: Optional[List[TransformConfig]] = None
action: ActionConfig
datahub: Optional[DatahubClientConfig]
options: Optional[PipelineOptions]
datahub: Optional[DatahubClientConfig] = None
options: Optional[PipelineOptions] = None
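Most of the field changes in this file follow the same pydantic v2 rule: a field annotated Optional[X] with no default is now required (it merely allows None), so each such field gains an explicit "= None". FailureMode also moves from a bare str/Enum mixin with use_enum_values to DataHub's ConfigEnum, presumably so enum values coerce consistently on both pydantic versions. A minimal sketch of the Optional-default behavior, runnable against pydantic v2:

    from typing import Optional
    from pydantic import BaseModel, ValidationError

    class V1Style(BaseModel):
        retry_count: Optional[int]  # v1: implicitly defaults to None; v2: required, just nullable

    class PortableStyle(BaseModel):
        retry_count: Optional[int] = None  # explicit default, accepted by both major versions

    PortableStyle()  # ok on v1 and v2
    try:
        V1Style()  # pydantic v2 raises "Field required"
    except ValidationError as err:
        print(err)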

View File

@@ -68,8 +68,8 @@ def import_path(path: str) -> Any:
class ExecutorConfig(BaseModel):
executor_id: Optional[str]
task_configs: Optional[List[TaskConfig]]
executor_id: Optional[str] = None
task_configs: Optional[List[TaskConfig]] = None
# Listens to new Execution Requests & dispatches them to the appropriate handler.
@@ -203,7 +203,10 @@ class ExecutorAction(Action):
SecretStoreConfig(type="env", config=dict({})),
SecretStoreConfig(
type="datahub",
config=DataHubSecretStoreConfig(graph_client=graph),
# TODO: Once SecretStoreConfig is updated to accept arbitrary types
# and not just dicts, we can just pass in the DataHubSecretStoreConfig
# object directly.
config=DataHubSecretStoreConfig(graph_client=graph).dict(),
),
],
graph_client=graph,
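The .dict() call above works around pydantic v2's stricter validation: a field typed as a plain dict no longer accepts a BaseModel instance (pydantic v1 would coerce it via dict(model)). A self-contained sketch of the same pattern, using hypothetical model names in place of SecretStoreConfig / DataHubSecretStoreConfig:

    from typing import Any, Dict
    from pydantic import BaseModel

    class InnerConfig(BaseModel):       # stand-in for DataHubSecretStoreConfig
        token: str

    class StoreConfig(BaseModel):       # stand-in for SecretStoreConfig
        type: str
        config: Dict[str, Any]          # only accepts real dicts under pydantic v2

    # Passing InnerConfig(...) directly fails v2 validation, so it is dumped first:
    cfg = StoreConfig(type="datahub", config=InnerConfig(token="t").dict())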

View File

@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class HelloWorldConfig(BaseModel):
# Whether to print the message in upper case.
to_upper: Optional[bool]
to_upper: Optional[bool] = None
# A basic example of a DataHub action that prints all

View File

@@ -19,13 +19,13 @@ logger = logging.getLogger(__name__)
class MetadataChangeEmitterConfig(BaseModel):
gms_server: Optional[str]
gms_auth_token: Optional[str]
aspects_to_exclude: Optional[List]
aspects_to_include: Optional[List]
gms_server: Optional[str] = None
gms_auth_token: Optional[str] = None
aspects_to_exclude: Optional[List] = None
aspects_to_include: Optional[List] = None
entity_type_to_exclude: List[str] = Field(default_factory=list)
extra_headers: Optional[Dict[str, str]]
urn_regex: Optional[str]
extra_headers: Optional[Dict[str, str]] = None
urn_regex: Optional[str] = None
class MetadataChangeSyncAction(Action):

View File

@@ -15,11 +15,11 @@
import json
import logging
import time
from enum import Enum
from typing import Iterable, List, Optional, Tuple
from pydantic import Field
from datahub.configuration.common import ConfigEnum
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import (
AuditStampClass,
@@ -60,7 +60,7 @@ class DocPropagationDirective(PropagationDirective):
)
class ColumnPropagationRelationships(str, Enum):
class ColumnPropagationRelationships(ConfigEnum):
UPSTREAM = "upstream"
DOWNSTREAM = "downstream"
SIBLING = "sibling"
@@ -82,18 +82,15 @@ class DocPropagationConfig(PropagationConfig):
enabled: bool = Field(
True,
description="Indicates whether documentation propagation is enabled or not.",
example=True,
)
columns_enabled: bool = Field(
True,
description="Indicates whether column documentation propagation is enabled or not.",
example=True,
)
# TODO: Currently this flag does nothing. Datasets are NOT supported for docs propagation.
datasets_enabled: bool = Field(
False,
description="Indicates whether dataset level documentation propagation is enabled or not.",
example=False,
)
column_propagation_relationships: List[ColumnPropagationRelationships] = Field(
[
@@ -102,11 +99,6 @@ class DocPropagationConfig(PropagationConfig):
ColumnPropagationRelationships.UPSTREAM,
],
description="Relationships for column documentation propagation.",
example=[
ColumnPropagationRelationships.UPSTREAM,
ColumnPropagationRelationships.SIBLING,
ColumnPropagationRelationships.DOWNSTREAM,
],
)

View File

@@ -15,13 +15,13 @@
import abc
import json
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Optional
import pydantic
from pydantic import BaseModel
from datahub.ingestion.api.report import Report, SupportsAsObj
from datahub.utilities.str_enum import StrEnum
from datahub_actions.action.action import Action
from datahub_actions.event.event_envelope import EventEnvelope
from datahub_actions.event.event_registry import (
@@ -114,7 +114,7 @@ class EventProcessingStats(BaseModel):
return json.dumps(self.dict(), indent=2)
class StageStatus(str, Enum):
class StageStatus(StrEnum):
SUCCESS = "success"
FAILURE = "failure"
RUNNING = "running"
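Switching StageStatus from a bare str/Enum mixin to datahub.utilities.str_enum.StrEnum (assumed to behave like Python 3.11's enum.StrEnum) keeps str()/format() of members stable across Python and pydantic versions, which matters when statuses end up in serialized reports. A rough stand-in, for illustration only:

    from enum import Enum

    class StrEnum(str, Enum):           # stand-in for datahub.utilities.str_enum.StrEnum
        def __str__(self) -> str:
            return self.value

    class StageStatus(StrEnum):
        SUCCESS = "success"
        FAILURE = "failure"
        RUNNING = "running"

    assert StageStatus.SUCCESS == "success"
    assert str(StageStatus.SUCCESS) == "success"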

View File

@@ -52,12 +52,13 @@ class TagPropagationConfig(ConfigModel):
enabled: bool = Field(
True,
description="Indicates whether tag propagation is enabled or not.",
example=True,
)
tag_prefixes: Optional[List[str]] = Field(
None,
description="Optional list of tag prefixes to restrict tag propagation.",
example=["urn:li:tag:classification"],
examples=[
"urn:li:tag:classification",
],
)
@validator("tag_prefixes", each_item=True)

View File

@@ -60,17 +60,21 @@ class TermPropagationConfig(ConfigModel):
enabled: bool = Field(
True,
description="Indicates whether term propagation is enabled or not.",
example=True,
)
target_terms: Optional[List[str]] = Field(
None,
description="Optional target terms to restrict term propagation to this and all terms related to these terms.",
example="[urn:li:glossaryTerm:Sensitive]",
examples=[
"urn:li:glossaryTerm:Sensitive",
],
)
term_groups: Optional[List[str]] = Field(
None,
description="Optional list of term groups to restrict term propagation.",
example=["Group1", "Group2"],
examples=[
"Group1",
"Group2",
],
)
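The example= keyword was never a declared Field() parameter; pydantic v1 silently stored unknown keyword arguments as field extras, and pydantic v2 deprecates them. The supported parameter is examples, which takes a list, hence this change and the matching edits in the tag/doc propagation configs above. A minimal sketch (class name hypothetical):

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class TermPropagationSketch(BaseModel):
        target_terms: Optional[List[str]] = Field(
            None,
            description="Optional target terms to restrict term propagation.",
            examples=["urn:li:glossaryTerm:Sensitive"],  # list-valued `examples`, not `example=`
        )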

View File

@@ -44,7 +44,7 @@ def build_entity_change_event(payload: GenericPayloadClass) -> EntityChangeEvent
class DataHubEventsSourceConfig(ConfigModel):
topic: str = PLATFORM_EVENT_TOPIC_NAME
consumer_id: Optional[str] # Used to store offset for the consumer.
consumer_id: Optional[str] = None # Used to store offset for the consumer.
lookback_days: Optional[int] = None
reset_offsets: Optional[bool] = False

View File

@@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
class EventConsumerState(BaseModel):
VERSION = 1 # Increment this version when the schema of EventConsumerState changes
VERSION: int = (
1 # Increment this version when the schema of EventConsumerState changes
)
offset_id: Optional[str] = None
timestamp: Optional[int] = None
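Pydantic v2 rejects non-annotated class attributes on a model ("A non-annotated attribute was detected"), so the VERSION constant needs a type annotation. The change above turns it into an ordinary field with a default; a ClassVar annotation would be the alternative if it should stay a pure class-level constant. A small sketch, assuming pydantic v2:

    from typing import ClassVar, Optional
    from pydantic import BaseModel

    class EventConsumerStateSketch(BaseModel):
        VERSION: ClassVar[int] = 1      # class constant, excluded from validation/serialization
        # VERSION: int = 1              # ...or, as in the diff above, a regular field with a default
        offset_id: Optional[str] = None
        timestamp: Optional[int] = None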

View File

@@ -12,67 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from datahub.configuration.common import ConfigurationError
from datahub.ingestion.api.registry import PluginRegistry
from datahub_actions.action.action import Action
from datahub_actions.action.action_registry import action_registry
from datahub_actions.plugin.action.hello_world.hello_world import HelloWorldAction
def test_registry_nonempty():
assert len(action_registry.mapping) > 0
def test_registry():
fake_registry = PluginRegistry[Action]()
fake_registry.register("hello_world", HelloWorldAction)
assert len(fake_registry.mapping) > 0
assert fake_registry.is_enabled("hello_world")
assert fake_registry.get("hello_world") == HelloWorldAction
assert (
fake_registry.get(
"datahub_actions.plugin.action.hello_world.hello_world.HelloWorldAction"
)
== HelloWorldAction
)
# Test lazy-loading capabilities.
fake_registry.register_lazy(
"lazy-hello-world",
"datahub_actions.plugin.action.hello_world.hello_world:HelloWorldAction",
)
assert fake_registry.get("lazy-hello-world") == HelloWorldAction
# Test Registry Errors
fake_registry.register_lazy("lazy-error", "thisdoesnot.exist")
with pytest.raises(ConfigurationError, match="disabled"):
fake_registry.get("lazy-error")
with pytest.raises(KeyError, match="special characters"):
fake_registry.register("thisdoesnotexist.otherthing", HelloWorldAction)
with pytest.raises(KeyError, match="in use"):
fake_registry.register("hello_world", HelloWorldAction)
with pytest.raises(KeyError, match="not find"):
fake_registry.get("thisdoesnotexist")
# Test error-checking on registered types.
with pytest.raises(ValueError, match="abstract"):
fake_registry.register("thisdoesnotexist", Action) # type: ignore
class DummyClass: # Does not extend Action.
pass
with pytest.raises(ValueError, match="derived"):
fake_registry.register("thisdoesnotexist", DummyClass) # type: ignore
# Test disabled actions
fake_registry.register_disabled("disabled", ModuleNotFoundError("disabled action"))
fake_registry.register_disabled(
"disabled-exception", Exception("second disabled action")
)
with pytest.raises(ConfigurationError, match="disabled"):
fake_registry.get("disabled")
with pytest.raises(ConfigurationError, match="disabled"):
fake_registry.get("disabled-exception")
def test_all_registry_plugins_enabled() -> None:
for plugin in action_registry.mapping.keys():
assert action_registry.is_enabled(plugin), f"Plugin {plugin} is not enabled"

View File

@@ -131,11 +131,14 @@ cachetools_lib = {
"cachetools",
}
sql_common_slim = {
# Required for all SQL sources.
# This is a temporary lower bound that we're open to loosening/tightening as requirements show up
"sqlalchemy>=1.4.39, <2",
}
sql_common = (
{
# Required for all SQL sources.
# This is a temporary lower bound that we're open to loosening/tightening as requirements show up
"sqlalchemy>=1.4.39, <2",
*sql_common_slim,
# Required for SQL profiling.
"great-expectations>=0.15.12, <=0.15.50",
*pydantic_no_v2, # because of great-expectations
@@ -220,8 +223,6 @@ redshift_common = {
}
snowflake_common = {
# Snowflake plugin utilizes sql common
*sql_common,
# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/350
"snowflake-sqlalchemy>=1.4.3",
"snowflake-connector-python>=3.4.0",
@@ -229,7 +230,7 @@ snowflake_common = {
"cryptography",
"msal",
*cachetools_lib,
} | classification_lib
}
trino = {
"trino[sqlalchemy]>=0.308",
@@ -400,6 +401,7 @@ plugins: Dict[str, Set[str]] = {
| {
"google-cloud-datacatalog-lineage==0.2.2",
},
"bigquery-slim": bigquery_common,
"bigquery-queries": sql_common | bigquery_common | sqlglot_lib,
"clickhouse": sql_common | clickhouse_common,
"clickhouse-usage": sql_common | usage_common | clickhouse_common,
@@ -502,9 +504,10 @@ plugins: Dict[str, Set[str]] = {
"abs": {*abs_base, *data_lake_profiling},
"sagemaker": aws_common,
"salesforce": {"simple-salesforce", *cachetools_lib},
"snowflake": snowflake_common | usage_common | sqlglot_lib,
"snowflake-summary": snowflake_common | usage_common | sqlglot_lib,
"snowflake-queries": snowflake_common | usage_common | sqlglot_lib,
"snowflake": snowflake_common | sql_common | usage_common | sqlglot_lib,
"snowflake-slim": snowflake_common,
"snowflake-summary": snowflake_common | sql_common | usage_common | sqlglot_lib,
"snowflake-queries": snowflake_common | sql_common | usage_common | sqlglot_lib,
"sqlalchemy": sql_common,
"sql-queries": usage_common | sqlglot_lib,
"slack": slack,
@@ -935,6 +938,8 @@ See the [DataHub docs](https://docs.datahub.com/docs/metadata-ingestion).
"sql-parser",
"iceberg",
"feast",
"bigquery-slim",
"snowflake-slim",
}
else set()
)