feat(sdk): use discriminated unions for Filter types (#14127)

This commit is contained in:
Harshal Sheth 2025-07-21 08:23:07 -07:00 committed by GitHub
parent 14b9bed58d
commit c92aa0842e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 244 additions and 34 deletions

View File

@ -1,12 +1,13 @@
import pydantic.version
from packaging.version import Version
PYDANTIC_VERSION_2: bool
if Version(pydantic.version.VERSION) >= Version("2.0"):
PYDANTIC_VERSION_2 = True
else:
PYDANTIC_VERSION_2 = False
_pydantic_version = Version(pydantic.version.VERSION)
PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0")
# The pydantic.Discriminator type was added in v2.5.0.
# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0")
# This can be used to silence deprecation warnings while we migrate.
if PYDANTIC_VERSION_2:
@ -50,6 +51,7 @@ class v1_ConfigModel(v1_BaseModel):
__all__ = [
"PYDANTIC_VERSION_2",
"PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR",
"PydanticDeprecatedSince20",
"GenericModel",
"v1_ConfigModel",

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import abc
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Iterator,
@ -15,7 +17,10 @@ from typing import (
import pydantic
from datahub.configuration.common import ConfigModel
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.configuration.pydantic_migration_helpers import (
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
PYDANTIC_VERSION_2,
)
from datahub.ingestion.graph.client import flexible_entity_type_to_graphql
from datahub.ingestion.graph.filters import (
FilterOperator,
@ -42,12 +47,29 @@ class _BaseFilter(ConfigModel):
populate_by_name = True
@abc.abstractmethod
def compile(self) -> _OrFilters:
pass
def compile(self) -> _OrFilters: ...
def dfs(self) -> Iterator[_BaseFilter]:
yield self
@classmethod
def _field_discriminator(cls) -> str:
    """Return the key that identifies this filter class.

    The key is the alias of the class's single declared field, or the
    field's own name when no alias is set. Subclasses that declare more
    than one field must override this method.

    Raises:
        ValueError: If called on ``_BaseFilter`` itself, or if the class
            does not declare exactly one field.
    """
    if cls is _BaseFilter:
        raise ValueError("Cannot get discriminator for _BaseFilter")

    # The field registry lives under a different attribute in pydantic v1 and v2.
    if PYDANTIC_VERSION_2:
        fields: dict = cls.model_fields  # type: ignore
    else:
        fields = cls.__fields__  # type: ignore

    # Assumes that there's only one field name per filter.
    # If that's not the case, this method should be overridden.
    if not fields:
        # Reported separately so that an empty class is not described as
        # having "multiple" candidate fields.
        raise ValueError(
            f"Found no fields that could be the discriminator for this filter: {cls.__name__}"
        )
    if len(fields) > 1:
        raise ValueError(
            f"Found multiple fields that could be the discriminator for this filter: {list(fields.keys())}"
        )

    name, field = next(iter(fields.items()))
    return field.alias or name  # type: ignore
class _EntityTypeFilter(_BaseFilter):
"""Filter for specific entity types.
@ -74,15 +96,19 @@ class _EntityTypeFilter(_BaseFilter):
class _EntitySubtypeFilter(_BaseFilter):
entity_subtype: str = pydantic.Field(
entity_subtype: List[str] = pydantic.Field(
description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
)
@pydantic.validator("entity_subtype", pre=True)
def validate_entity_subtype(cls, v: Union[str, List[str]]) -> List[str]:
    """Normalize ``entity_subtype`` input to a list.

    A list is returned unchanged; any other value is wrapped in a
    single-element list, so a bare subtype string is still accepted.
    The validator is registered with ``pre=True``.
    """
    if isinstance(v, list):
        return v
    return [v]
def _build_rule(self) -> SearchFilterRule:
return SearchFilterRule(
field="typeNames",
condition="EQUAL",
values=[self.entity_subtype],
values=self.entity_subtype,
)
def compile(self) -> _OrFilters:
@ -196,6 +222,10 @@ class _CustomCondition(_BaseFilter):
)
return [{"and": [rule]}]
@classmethod
def _field_discriminator(cls) -> str:
    """Return the fixed tag used for custom conditions.

    This overrides the base implementation with a constant instead of
    deriving the tag from the class's declared fields.
    """
    custom_tag = "_custom"
    return custom_tag
class _And(_BaseFilter):
"""Represents an AND conjunction of filters."""
@ -302,31 +332,69 @@ class _Not(_BaseFilter):
yield from self.not_.dfs()
# TODO: With pydantic 2, we can use a RootModel with a
# discriminated union to make the error messages more informative.
Filter = Union[
_And,
_Or,
_Not,
_EntityTypeFilter,
_EntitySubtypeFilter,
_StatusFilter,
_PlatformFilter,
_DomainFilter,
_EnvFilter,
_CustomCondition,
]
def _filter_discriminator(v: Any) -> Optional[str]:
    """Work out which filter tag an input value corresponds to.

    - A filter instance reports its own tag via ``_field_discriminator()``.
    - A dict with exactly one key uses that key as the tag.
    - A dict containing at least the keys ``field`` and ``condition`` maps
      to the custom-condition tag.
    - Anything else (non-dict values, empty dicts, other multi-key dicts)
      yields ``None``.
    """
    if isinstance(v, _BaseFilter):
        return v._field_discriminator()

    if isinstance(v, dict):
        key_names = list(v.keys())
        if len(key_names) == 1:
            return key_names[0]
        if {"field", "condition"}.issubset(key_names):
            return _CustomCondition._field_discriminator()

    return None
# Required to resolve forward references to "Filter"
if PYDANTIC_VERSION_2:
_And.model_rebuild() # type: ignore
_Or.model_rebuild() # type: ignore
_Not.model_rebuild() # type: ignore
else:
if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR:
# The `TYPE_CHECKING` bit is required to make the linter happy,
# since we currently only run mypy with pydantic v1.
Filter = Union[
_And,
_Or,
_Not,
_EntityTypeFilter,
_EntitySubtypeFilter,
_StatusFilter,
_PlatformFilter,
_DomainFilter,
_EnvFilter,
_CustomCondition,
]
_And.update_forward_refs()
_Or.update_forward_refs()
_Not.update_forward_refs()
else:
from pydantic import Discriminator, Tag
# TODO: Once we're fully on pydantic 2, we can use a RootModel here.
# That way we'd be able to attach methods to the Filter type.
# e.g. replace load_filters(...) with Filter.load(...)
Filter = Annotated[
Union[
Annotated[_And, Tag(_And._field_discriminator())],
Annotated[_Or, Tag(_Or._field_discriminator())],
Annotated[_Not, Tag(_Not._field_discriminator())],
Annotated[_EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator())],
Annotated[
_EntitySubtypeFilter, Tag(_EntitySubtypeFilter._field_discriminator())
],
Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())],
Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())],
Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())],
Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())],
Annotated[_CustomCondition, Tag(_CustomCondition._field_discriminator())],
],
Discriminator(_filter_discriminator),
]
# Required to resolve forward references to "Filter"
_And.model_rebuild() # type: ignore
_Or.model_rebuild() # type: ignore
_Not.model_rebuild() # type: ignore
def load_filters(obj: Any) -> Filter:

View File

@ -1,6 +1,6 @@
import pathlib
from typing import Dict, List, Optional, Sequence, Set, cast
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch
import pytest
@ -141,7 +141,6 @@ def test_infer_lineage_from_sql(client: DataHubClient) -> None:
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
],
query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
)
query_text = (
@ -197,7 +196,6 @@ def test_infer_lineage_from_sql_with_multiple_upstreams(
),
],
query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
)
query_text = """

View File

@ -1,3 +1,4 @@
import re
import unittest
import unittest.mock
from io import StringIO
@ -6,6 +7,9 @@ import pytest
import yaml
from pydantic import ValidationError
from datahub.configuration.pydantic_migration_helpers import (
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
)
from datahub.ingestion.graph.filters import (
RemovedStatusFilter,
SearchFilterRule,
@ -14,7 +18,14 @@ from datahub.ingestion.graph.filters import (
from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_client import compile_filters, compute_entity_types
from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters
from datahub.sdk.search_filters import (
Filter,
FilterDsl as F,
_BaseFilter,
_CustomCondition,
_filter_discriminator,
load_filters,
)
from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers.graph_helpers import MockDataHubGraph
@ -170,6 +181,137 @@ and:
]
def test_entity_subtype_filter() -> None:
    """A subtype given as a list or as a bare string loads to the same filter."""
    from_list: Filter = load_filters({"entity_subtype": ["Table"]})
    expected = F.entity_subtype("Table")
    assert from_list == expected

    # A bare string (no list wrapper) must keep working for backwards compatibility.
    from_scalar: Filter = load_filters({"entity_subtype": "Table"})
    assert from_list == from_scalar
def test_filters_all_types() -> None:
    """Load one nested document that uses every filter kind and compare it to the DSL form."""
    entity_type_clause = {
        "or": [
            {"entity_type": ["dataset"]},
            {"entity_type": ["chart", "dashboard"]},
        ]
    }
    custom_clause = {
        "field": "custom_field",
        "condition": "GREATER_THAN_OR_EQUAL_TO",
        "values": ["5"],
    }
    raw_filter = {
        "and": [
            entity_type_clause,
            {"not": {"entity_subtype": ["Table"]}},
            {"platform": ["snowflake"]},
            {"domain": ["urn:li:domain:marketing"]},
            {"env": ["PROD"]},
            {"status": "NOT_SOFT_DELETED"},
            custom_clause,
        ]
    }

    filter_obj: Filter = load_filters(raw_filter)

    # The same filter tree, built through the DSL helpers.
    expected = F.and_(
        F.or_(
            F.entity_type("dataset"),
            F.entity_type(["chart", "dashboard"]),
        ),
        F.not_(F.entity_subtype("Table")),
        F.platform("snowflake"),
        F.domain("urn:li:domain:marketing"),
        F.env("PROD"),
        F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED),
        F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]),
    )
    assert filter_obj == expected
def test_field_discriminator() -> None:
    """Check `_field_discriminator` on the base class, concrete filters, and a multi-field filter."""
    # The abstract base has no discriminator of its own.
    with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"):
        _BaseFilter._field_discriminator()

    entity_type_filter = F.entity_type("dataset")
    assert entity_type_filter._field_discriminator() == "entity_type"

    negated_filter = F.not_(F.entity_subtype("Table"))
    assert negated_filter._field_discriminator() == "not"

    # Custom conditions report the same tag from an instance as from the class.
    custom_filter = F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"])
    assert (
        custom_filter._field_discriminator() == _CustomCondition._field_discriminator()
    )

    class _BadFilter(_BaseFilter):
        field1: str
        field2: str

    # Two declared fields and no override: the default implementation must refuse.
    multiple_fields_message = re.escape(
        "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']"
    )
    with pytest.raises(ValueError, match=multiple_fields_message):
        _BadFilter._field_discriminator()
def test_filter_discriminator() -> None:
    """Check `_filter_discriminator` for filter objects, plain dicts, and inputs with no tag."""
    # Simple filter discriminator extraction.
    tagged_inputs = [
        (F.entity_type("dataset"), "entity_type"),
        ({"entity_type": "dataset"}, "entity_type"),
        ({"not": {"entity_subtype": "Table"}}, "not"),
        ({"unknown_field": 6}, "unknown_field"),
    ]
    for value, expected_tag in tagged_inputs:
        assert _filter_discriminator(value) == expected_tag

    # Inputs for which no tag can be determined.
    for value in ({"field1": 6, "field2": 7}, {}, 6):
        assert _filter_discriminator(value) is None

    # Special case for custom conditions.
    custom_inputs = [
        {
            "field": "custom_field",
            "condition": "GREATER_THAN_OR_EQUAL_TO",
            "values": ["5"],
        },
        {
            "field": "custom_field",
            "condition": "EXISTS",
        },
    ]
    for value in custom_inputs:
        assert _filter_discriminator(value) == "_custom"
@pytest.mark.skipif(
    not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
    reason="Tagged union w/ callable discriminator is not supported by the current pydantic version",
)
def test_tagged_union_error_messages() -> None:
    """Invalid input to the tagged union should produce exactly one validation error."""
    # With pydantic v1, we'd get 10+ validation errors and it'd be hard to
    # understand what went wrong. With v2, we get a single simple error message.
    wrong_value_type = re.compile(
        r"1 validation error.*entity_type\.entity_type.*Input should be a valid list",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=wrong_value_type):
        load_filters({"entity_type": 6})

    # Even when within an "and" clause, we get a single error message.
    unknown_tag = re.compile(
        r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=unknown_tag):
        load_filters({"and": [{"unknown_field": 6}]})
def test_invalid_filter() -> None:
with pytest.raises(InvalidUrnError):
F.domain("marketing")