feat(sdk): use discriminated unions for Filter types (#14127)

This commit is contained in:
Harshal Sheth 2025-07-21 08:23:07 -07:00 committed by GitHub
parent 14b9bed58d
commit c92aa0842e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 244 additions and 34 deletions

View File

@@ -1,12 +1,13 @@
import pydantic.version import pydantic.version
from packaging.version import Version from packaging.version import Version
PYDANTIC_VERSION_2: bool _pydantic_version = Version(pydantic.version.VERSION)
if Version(pydantic.version.VERSION) >= Version("2.0"):
PYDANTIC_VERSION_2 = True
else:
PYDANTIC_VERSION_2 = False
# True when the installed pydantic is v2.x; drives the v1/v2 compat paths below.
PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0")
# The pydantic.Discriminator type was added in v2.5.0.
# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0")
# This can be used to silence deprecation warnings while we migrate. # This can be used to silence deprecation warnings while we migrate.
if PYDANTIC_VERSION_2: if PYDANTIC_VERSION_2:
@ -50,6 +51,7 @@ class v1_ConfigModel(v1_BaseModel):
__all__ = [ __all__ = [
"PYDANTIC_VERSION_2", "PYDANTIC_VERSION_2",
"PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR",
"PydanticDeprecatedSince20", "PydanticDeprecatedSince20",
"GenericModel", "GenericModel",
"v1_ConfigModel", "v1_ConfigModel",

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import abc import abc
from typing import ( from typing import (
TYPE_CHECKING,
Annotated,
Any, Any,
ClassVar, ClassVar,
Iterator, Iterator,
@ -15,7 +17,10 @@ from typing import (
import pydantic import pydantic
from datahub.configuration.common import ConfigModel from datahub.configuration.common import ConfigModel
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 from datahub.configuration.pydantic_migration_helpers import (
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
PYDANTIC_VERSION_2,
)
from datahub.ingestion.graph.client import flexible_entity_type_to_graphql from datahub.ingestion.graph.client import flexible_entity_type_to_graphql
from datahub.ingestion.graph.filters import ( from datahub.ingestion.graph.filters import (
FilterOperator, FilterOperator,
@ -42,12 +47,29 @@ class _BaseFilter(ConfigModel):
populate_by_name = True populate_by_name = True
@abc.abstractmethod @abc.abstractmethod
def compile(self) -> _OrFilters: def compile(self) -> _OrFilters: ...
pass
def dfs(self) -> Iterator[_BaseFilter]: def dfs(self) -> Iterator[_BaseFilter]:
yield self yield self
@classmethod
def _field_discriminator(cls) -> str:
    """Return the tag used to discriminate this filter type in the union.

    The tag is the (aliased) name of the filter's single pydantic field.
    """
    if cls is _BaseFilter:
        raise ValueError("Cannot get discriminator for _BaseFilter")
    # Field metadata lives in a different attribute on pydantic v1 vs v2.
    field_map: dict = (
        cls.model_fields if PYDANTIC_VERSION_2 else cls.__fields__  # type: ignore
    )
    # Assumes that there's only one field name per filter.
    # If that's not the case, this method should be overridden.
    if len(field_map) != 1:
        raise ValueError(
            f"Found multiple fields that could be the discriminator for this filter: {list(field_map.keys())}"
        )
    field_name, field_info = next(iter(field_map.items()))
    return field_info.alias or field_name  # type: ignore
class _EntityTypeFilter(_BaseFilter): class _EntityTypeFilter(_BaseFilter):
"""Filter for specific entity types. """Filter for specific entity types.
@ -74,15 +96,19 @@ class _EntityTypeFilter(_BaseFilter):
class _EntitySubtypeFilter(_BaseFilter): class _EntitySubtypeFilter(_BaseFilter):
entity_subtype: str = pydantic.Field( entity_subtype: List[str] = pydantic.Field(
description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.", description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
) )
@pydantic.validator("entity_subtype", pre=True)
def validate_entity_subtype(cls, v: Union[str, List[str]]) -> List[str]:
    """Coerce a bare string into a one-element list.

    Runs pre-validation, so ``v`` is the raw user input: either a plain
    string (legacy format, kept for backwards compatibility) or a list.
    """
    return [v] if not isinstance(v, list) else v
def _build_rule(self) -> SearchFilterRule: def _build_rule(self) -> SearchFilterRule:
return SearchFilterRule( return SearchFilterRule(
field="typeNames", field="typeNames",
condition="EQUAL", condition="EQUAL",
values=[self.entity_subtype], values=self.entity_subtype,
) )
def compile(self) -> _OrFilters: def compile(self) -> _OrFilters:
@ -196,6 +222,10 @@ class _CustomCondition(_BaseFilter):
) )
return [{"and": [rule]}] return [{"and": [rule]}]
@classmethod
def _field_discriminator(cls) -> str:
    """Return a fixed tag for custom conditions.

    Custom conditions carry several fields (field/condition/values), so the
    default single-field derivation does not apply; use a sentinel tag.
    """
    return "_custom"
class _And(_BaseFilter): class _And(_BaseFilter):
"""Represents an AND conjunction of filters.""" """Represents an AND conjunction of filters."""
@ -302,31 +332,69 @@ class _Not(_BaseFilter):
yield from self.not_.dfs() yield from self.not_.dfs()
# TODO: With pydantic 2, we can use a RootModel with a def _filter_discriminator(v: Any) -> Optional[str]:
# discriminated union to make the error messages more informative. if isinstance(v, _BaseFilter):
Filter = Union[ return v._field_discriminator()
_And,
_Or, if not isinstance(v, dict):
_Not, return None
_EntityTypeFilter,
_EntitySubtypeFilter, keys = list(v.keys())
_StatusFilter, if len(keys) == 1:
_PlatformFilter, return keys[0]
_DomainFilter, elif set(keys).issuperset({"field", "condition"}):
_EnvFilter, return _CustomCondition._field_discriminator()
_CustomCondition,
] return None
# Required to resolve forward references to "Filter" if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR:
if PYDANTIC_VERSION_2: # The `not TYPE_CHECKING` bit is required to make the linter happy,
_And.model_rebuild() # type: ignore # since we currently only run mypy with pydantic v1.
_Or.model_rebuild() # type: ignore Filter = Union[
_Not.model_rebuild() # type: ignore _And,
else: _Or,
_Not,
_EntityTypeFilter,
_EntitySubtypeFilter,
_StatusFilter,
_PlatformFilter,
_DomainFilter,
_EnvFilter,
_CustomCondition,
]
_And.update_forward_refs() _And.update_forward_refs()
_Or.update_forward_refs() _Or.update_forward_refs()
_Not.update_forward_refs() _Not.update_forward_refs()
else:
    from pydantic import Discriminator, Tag

    # TODO: Once we're fully on pydantic 2, we can use a RootModel here.
    # That way we'd be able to attach methods to the Filter type.
    # e.g. replace load_filters(...) with Filter.load(...)
    #
    # Each variant is tagged with its filter's field-name discriminator; the
    # callable Discriminator selects the matching variant up front, producing
    # a single targeted validation error instead of one error per variant.
    Filter = Annotated[
        Union[
            Annotated[_And, Tag(_And._field_discriminator())],
            Annotated[_Or, Tag(_Or._field_discriminator())],
            Annotated[_Not, Tag(_Not._field_discriminator())],
            Annotated[_EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator())],
            Annotated[
                _EntitySubtypeFilter, Tag(_EntitySubtypeFilter._field_discriminator())
            ],
            Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())],
            Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())],
            Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())],
            Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())],
            Annotated[_CustomCondition, Tag(_CustomCondition._field_discriminator())],
        ],
        Discriminator(_filter_discriminator),
    ]

    # Required to resolve forward references to "Filter"
    _And.model_rebuild()  # type: ignore
    _Or.model_rebuild()  # type: ignore
    _Not.model_rebuild()  # type: ignore
def load_filters(obj: Any) -> Filter: def load_filters(obj: Any) -> Filter:

View File

@ -1,6 +1,6 @@
import pathlib import pathlib
from typing import Dict, List, Optional, Sequence, Set, cast from typing import Dict, List, Optional, Sequence, Set, cast
from unittest.mock import MagicMock, Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -141,7 +141,6 @@ def test_infer_lineage_from_sql(client: DataHubClient) -> None:
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
], ],
query_type=QueryType.SELECT, query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
) )
query_text = ( query_text = (
@ -197,7 +196,6 @@ def test_infer_lineage_from_sql_with_multiple_upstreams(
), ),
], ],
query_type=QueryType.SELECT, query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
) )
query_text = """ query_text = """

View File

@@ -1,3 +1,4 @@
import re
import unittest import unittest
import unittest.mock import unittest.mock
from io import StringIO from io import StringIO
@ -6,6 +7,9 @@ import pytest
import yaml import yaml
from pydantic import ValidationError from pydantic import ValidationError
from datahub.configuration.pydantic_migration_helpers import (
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
)
from datahub.ingestion.graph.filters import ( from datahub.ingestion.graph.filters import (
RemovedStatusFilter, RemovedStatusFilter,
SearchFilterRule, SearchFilterRule,
@ -14,7 +18,14 @@ from datahub.ingestion.graph.filters import (
from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
from datahub.sdk.main_client import DataHubClient from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_client import compile_filters, compute_entity_types from datahub.sdk.search_client import compile_filters, compute_entity_types
from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters from datahub.sdk.search_filters import (
Filter,
FilterDsl as F,
_BaseFilter,
_CustomCondition,
_filter_discriminator,
load_filters,
)
from datahub.utilities.urns.error import InvalidUrnError from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers.graph_helpers import MockDataHubGraph from tests.test_helpers.graph_helpers import MockDataHubGraph
@ -170,6 +181,137 @@ and:
] ]
def test_entity_subtype_filter() -> None:
    """entity_subtype accepts both the list form and the legacy bare string."""
    expected = F.entity_subtype("Table")
    parsed: Filter = load_filters({"entity_subtype": ["Table"]})
    assert parsed == expected

    # Ensure it works without the list wrapper to maintain backwards compatibility.
    legacy: Filter = load_filters({"entity_subtype": "Table"})
    assert legacy == parsed
def test_filters_all_types() -> None:
    """Round-trip one spec containing every filter variant through load_filters."""
    raw_spec = {
        "and": [
            {
                "or": [
                    {"entity_type": ["dataset"]},
                    {"entity_type": ["chart", "dashboard"]},
                ]
            },
            {"not": {"entity_subtype": ["Table"]}},
            {"platform": ["snowflake"]},
            {"domain": ["urn:li:domain:marketing"]},
            {"env": ["PROD"]},
            {"status": "NOT_SOFT_DELETED"},
            {
                "field": "custom_field",
                "condition": "GREATER_THAN_OR_EQUAL_TO",
                "values": ["5"],
            },
        ]
    }
    expected = F.and_(
        F.or_(
            F.entity_type("dataset"),
            F.entity_type(["chart", "dashboard"]),
        ),
        F.not_(F.entity_subtype("Table")),
        F.platform("snowflake"),
        F.domain("urn:li:domain:marketing"),
        F.env("PROD"),
        F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED),
        F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]),
    )

    parsed: Filter = load_filters(raw_spec)
    assert parsed == expected
def test_field_discriminator() -> None:
    """Each filter type derives its discriminator tag from its single field."""
    # The abstract base has no field to derive a tag from.
    with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"):
        _BaseFilter._field_discriminator()

    assert F.entity_type("dataset")._field_discriminator() == "entity_type"
    assert F.not_(F.entity_subtype("Table"))._field_discriminator() == "not"

    custom = F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"])
    assert custom._field_discriminator() == _CustomCondition._field_discriminator()

    # A filter with two fields is ambiguous and must be rejected.
    class _BadFilter(_BaseFilter):
        field1: str
        field2: str

    expected_msg = re.escape(
        "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']"
    )
    with pytest.raises(ValueError, match=expected_msg):
        _BadFilter._field_discriminator()
def test_filter_discriminator() -> None:
    """_filter_discriminator extracts a tag from filter objects and raw input."""
    # Simple filter discriminator extraction.
    assert _filter_discriminator(F.entity_type("dataset")) == "entity_type"

    dict_cases = [
        ({"entity_type": "dataset"}, "entity_type"),
        ({"not": {"entity_subtype": "Table"}}, "not"),
        ({"unknown_field": 6}, "unknown_field"),
        ({"field1": 6, "field2": 7}, None),
        ({}, None),
    ]
    for raw, expected in dict_cases:
        assert _filter_discriminator(raw) == expected
    # Non-dict, non-filter inputs have no discriminator.
    assert _filter_discriminator(6) is None

    # Special case for custom conditions: recognized by field/condition keys.
    full_custom = {
        "field": "custom_field",
        "condition": "GREATER_THAN_OR_EQUAL_TO",
        "values": ["5"],
    }
    assert _filter_discriminator(full_custom) == "_custom"

    valueless_custom = {
        "field": "custom_field",
        "condition": "EXISTS",
    }
    assert _filter_discriminator(valueless_custom) == "_custom"
@pytest.mark.skipif(
    not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
    reason="Tagged union w/ callable discriminator is not supported by the current pydantic version",
)
def test_tagged_union_error_messages() -> None:
    """The tagged union should surface one targeted validation error."""
    # With pydantic v1, we'd get 10+ validation errors and it'd be hard to
    # understand what went wrong. With v2, we get a single simple error message.
    bad_value_pattern = re.compile(
        r"1 validation error.*entity_type\.entity_type.*Input should be a valid list",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=bad_value_pattern):
        load_filters({"entity_type": 6})

    # Even when within an "and" clause, we get a single error message.
    bad_tag_pattern = re.compile(
        r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=bad_tag_pattern):
        load_filters({"and": [{"unknown_field": 6}]})
def test_invalid_filter() -> None: def test_invalid_filter() -> None:
with pytest.raises(InvalidUrnError): with pytest.raises(InvalidUrnError):
F.domain("marketing") F.domain("marketing")