Mirror of https://github.com/datahub-project/datahub.git, synced 2025-11-02 03:39:03 +00:00
feat(sdk): use discriminated unions for Filter types (#14127)
This commit is contained in:
parent 14b9bed58d
commit c92aa0842e
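This change declares the SDK's Filter type as a pydantic tagged ("discriminated") union with a callable discriminator, which pydantic supports from v2.5 onward; on older versions the code falls back to a plain Union. As background, here is a minimal standalone sketch of that pydantic pattern, assuming pydantic >= 2.5 and using illustrative Cat/Dog models and a _discriminator helper that are not part of the SDK:

from typing import Annotated, Any, Optional, Union

import pydantic
from pydantic import Discriminator, Tag


class Cat(pydantic.BaseModel):
    meow: str


class Dog(pydantic.BaseModel):
    bark: str


def _discriminator(v: Any) -> Optional[str]:
    # Dict input: the single key selects the union member.
    if isinstance(v, dict) and len(v) == 1:
        return next(iter(v))
    # Model input: use the instance's only field name.
    if isinstance(v, pydantic.BaseModel):
        return next(iter(type(v).model_fields))
    return None


Pet = Annotated[
    Union[
        Annotated[Cat, Tag("meow")],
        Annotated[Dog, Tag("bark")],
    ],
    Discriminator(_discriminator),
]

# Validation routes the input to exactly one member, so a bad payload
# produces one targeted error instead of one failure per union member.
assert pydantic.TypeAdapter(Pet).validate_python({"meow": "hi"}) == Cat(meow="hi")

The diff below applies the same idea to the SDK's filter classes, deriving each tag from the filter's single field name.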
@@ -1,12 +1,13 @@
 import pydantic.version
 from packaging.version import Version
 
-PYDANTIC_VERSION_2: bool
-if Version(pydantic.version.VERSION) >= Version("2.0"):
-    PYDANTIC_VERSION_2 = True
-else:
-    PYDANTIC_VERSION_2 = False
+_pydantic_version = Version(pydantic.version.VERSION)
+
+PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0")
+
+# The pydantic.Discriminator type was added in v2.5.0.
+# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13
+PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0")
 
 # This can be used to silence deprecation warnings while we migrate.
 if PYDANTIC_VERSION_2:
@@ -50,6 +51,7 @@ class v1_ConfigModel(v1_BaseModel):
 
 __all__ = [
     "PYDANTIC_VERSION_2",
+    "PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR",
    "PydanticDeprecatedSince20",
    "GenericModel",
    "v1_ConfigModel",
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import abc
 from typing import (
+    TYPE_CHECKING,
+    Annotated,
     Any,
     ClassVar,
     Iterator,
@@ -15,7 +17,10 @@ from typing import (
 import pydantic
 
 from datahub.configuration.common import ConfigModel
-from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
+from datahub.configuration.pydantic_migration_helpers import (
+    PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
+    PYDANTIC_VERSION_2,
+)
 from datahub.ingestion.graph.client import flexible_entity_type_to_graphql
 from datahub.ingestion.graph.filters import (
     FilterOperator,
@@ -42,12 +47,29 @@ class _BaseFilter(ConfigModel):
         populate_by_name = True
 
     @abc.abstractmethod
-    def compile(self) -> _OrFilters:
-        pass
+    def compile(self) -> _OrFilters: ...
 
     def dfs(self) -> Iterator[_BaseFilter]:
         yield self
 
+    @classmethod
+    def _field_discriminator(cls) -> str:
+        if cls is _BaseFilter:
+            raise ValueError("Cannot get discriminator for _BaseFilter")
+        if PYDANTIC_VERSION_2:
+            fields: dict = cls.model_fields  # type: ignore
+        else:
+            fields = cls.__fields__  # type: ignore
+
+        # Assumes that there's only one field name per filter.
+        # If that's not the case, this method should be overridden.
+        if len(fields.keys()) != 1:
+            raise ValueError(
+                f"Found multiple fields that could be the discriminator for this filter: {list(fields.keys())}"
+            )
+        name, field = next(iter(fields.items()))
+        return field.alias or name  # type: ignore
+
 
 class _EntityTypeFilter(_BaseFilter):
     """Filter for specific entity types.
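Per the new test_field_discriminator test added further down, the discriminator for each concrete filter resolves to its single field's alias or name, for example:

from datahub.sdk.search_filters import FilterDsl as F

assert F.entity_type("dataset")._field_discriminator() == "entity_type"
# "not" is the alias of the `_Not.not_` field.
assert F.not_(F.entity_subtype("Table"))._field_discriminator() == "not"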
@@ -74,15 +96,19 @@ class _EntityTypeFilter(_BaseFilter):
 
 
 class _EntitySubtypeFilter(_BaseFilter):
-    entity_subtype: str = pydantic.Field(
+    entity_subtype: List[str] = pydantic.Field(
         description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
     )
 
+    @pydantic.validator("entity_subtype", pre=True)
+    def validate_entity_subtype(cls, v: str) -> List[str]:
+        return [v] if not isinstance(v, list) else v
+
     def _build_rule(self) -> SearchFilterRule:
         return SearchFilterRule(
             field="typeNames",
             condition="EQUAL",
-            values=[self.entity_subtype],
+            values=self.entity_subtype,
         )
 
     def compile(self) -> _OrFilters:
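The pre=True validator keeps the old single-string spelling working. As the new test_entity_subtype_filter test below shows, both spellings load to the same filter:

from datahub.sdk.search_filters import FilterDsl as F, load_filters

assert load_filters({"entity_subtype": "Table"}) == load_filters({"entity_subtype": ["Table"]})
assert load_filters({"entity_subtype": ["Table"]}) == F.entity_subtype("Table")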
@@ -196,6 +222,10 @@ class _CustomCondition(_BaseFilter):
         )
         return [{"and": [rule]}]
 
+    @classmethod
+    def _field_discriminator(cls) -> str:
+        return "_custom"
+
 
 class _And(_BaseFilter):
     """Represents an AND conjunction of filters."""
@@ -302,31 +332,69 @@ class _Not(_BaseFilter):
         yield from self.not_.dfs()
 
 
-# TODO: With pydantic 2, we can use a RootModel with a
-# discriminated union to make the error messages more informative.
-Filter = Union[
-    _And,
-    _Or,
-    _Not,
-    _EntityTypeFilter,
-    _EntitySubtypeFilter,
-    _StatusFilter,
-    _PlatformFilter,
-    _DomainFilter,
-    _EnvFilter,
-    _CustomCondition,
-]
+def _filter_discriminator(v: Any) -> Optional[str]:
+    if isinstance(v, _BaseFilter):
+        return v._field_discriminator()
+
+    if not isinstance(v, dict):
+        return None
+
+    keys = list(v.keys())
+    if len(keys) == 1:
+        return keys[0]
+    elif set(keys).issuperset({"field", "condition"}):
+        return _CustomCondition._field_discriminator()
+
+    return None
 
 
-# Required to resolve forward references to "Filter"
-if PYDANTIC_VERSION_2:
-    _And.model_rebuild()  # type: ignore
-    _Or.model_rebuild()  # type: ignore
-    _Not.model_rebuild()  # type: ignore
-else:
+if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR:
+    # The `not TYPE_CHECKING` bit is required to make the linter happy,
+    # since we currently only run mypy with pydantic v1.
+    Filter = Union[
+        _And,
+        _Or,
+        _Not,
+        _EntityTypeFilter,
+        _EntitySubtypeFilter,
+        _StatusFilter,
+        _PlatformFilter,
+        _DomainFilter,
+        _EnvFilter,
+        _CustomCondition,
+    ]
+
     _And.update_forward_refs()
     _Or.update_forward_refs()
     _Not.update_forward_refs()
+else:
+    from pydantic import Discriminator, Tag
+
+    # TODO: Once we're fully on pydantic 2, we can use a RootModel here.
+    # That way we'd be able to attach methods to the Filter type.
+    # e.g. replace load_filters(...) with Filter.load(...)
+    Filter = Annotated[
+        Union[
+            Annotated[_And, Tag(_And._field_discriminator())],
+            Annotated[_Or, Tag(_Or._field_discriminator())],
+            Annotated[_Not, Tag(_Not._field_discriminator())],
+            Annotated[_EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator())],
+            Annotated[
+                _EntitySubtypeFilter, Tag(_EntitySubtypeFilter._field_discriminator())
+            ],
+            Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())],
+            Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())],
+            Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())],
+            Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())],
+            Annotated[_CustomCondition, Tag(_CustomCondition._field_discriminator())],
+        ],
+        Discriminator(_filter_discriminator),
+    ]
+
+    # Required to resolve forward references to "Filter"
+    _And.model_rebuild()  # type: ignore
+    _Or.model_rebuild()  # type: ignore
+    _Not.model_rebuild()  # type: ignore
 
 
 def load_filters(obj: Any) -> Filter:
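On pydantic >= 2.5, _filter_discriminator routes a single-key dict to the filter variant matching that key, while a dict shaped like a raw search condition (with "field" and "condition" keys) falls through to the "_custom" tag. Roughly, as exercised by the new tests below:

from datahub.sdk.search_filters import FilterDsl as F, load_filters

assert load_filters({"platform": ["snowflake"]}) == F.platform("snowflake")
assert load_filters(
    {"field": "custom_field", "condition": "GREATER_THAN_OR_EQUAL_TO", "values": ["5"]}
) == F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"])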
@@ -1,6 +1,6 @@
 import pathlib
 from typing import Dict, List, Optional, Sequence, Set, cast
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock, patch
 
 import pytest
 
@@ -141,7 +141,6 @@ def test_infer_lineage_from_sql(client: DataHubClient) -> None:
             "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
         ],
         query_type=QueryType.SELECT,
-        debug_info=MagicMock(error=None, table_error=None),
     )
 
     query_text = (
@@ -197,7 +196,6 @@ def test_infer_lineage_from_sql_with_multiple_upstreams(
             ),
         ],
         query_type=QueryType.SELECT,
-        debug_info=MagicMock(error=None, table_error=None),
     )
 
     query_text = """
@@ -1,3 +1,4 @@
+import re
 import unittest
 import unittest.mock
 from io import StringIO
@@ -6,6 +7,9 @@ import pytest
 import yaml
 from pydantic import ValidationError
 
+from datahub.configuration.pydantic_migration_helpers import (
+    PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
+)
 from datahub.ingestion.graph.filters import (
     RemovedStatusFilter,
     SearchFilterRule,
@@ -14,7 +18,14 @@ from datahub.ingestion.graph.filters import (
 from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
 from datahub.sdk.main_client import DataHubClient
 from datahub.sdk.search_client import compile_filters, compute_entity_types
-from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters
+from datahub.sdk.search_filters import (
+    Filter,
+    FilterDsl as F,
+    _BaseFilter,
+    _CustomCondition,
+    _filter_discriminator,
+    load_filters,
+)
 from datahub.utilities.urns.error import InvalidUrnError
 from tests.test_helpers.graph_helpers import MockDataHubGraph
 
@@ -170,6 +181,137 @@ and:
 ]
 
 
+def test_entity_subtype_filter() -> None:
+    filter_obj_1: Filter = load_filters({"entity_subtype": ["Table"]})
+    assert filter_obj_1 == F.entity_subtype("Table")
+
+    # Ensure it works without the list wrapper to maintain backwards compatibility.
+    filter_obj_2: Filter = load_filters({"entity_subtype": "Table"})
+    assert filter_obj_1 == filter_obj_2
+
+
+def test_filters_all_types() -> None:
+    filter_obj: Filter = load_filters(
+        {
+            "and": [
+                {
+                    "or": [
+                        {"entity_type": ["dataset"]},
+                        {"entity_type": ["chart", "dashboard"]},
+                    ]
+                },
+                {"not": {"entity_subtype": ["Table"]}},
+                {"platform": ["snowflake"]},
+                {"domain": ["urn:li:domain:marketing"]},
+                {"env": ["PROD"]},
+                {"status": "NOT_SOFT_DELETED"},
+                {
+                    "field": "custom_field",
+                    "condition": "GREATER_THAN_OR_EQUAL_TO",
+                    "values": ["5"],
+                },
+            ]
+        }
+    )
+    assert filter_obj == F.and_(
+        F.or_(
+            F.entity_type("dataset"),
+            F.entity_type(["chart", "dashboard"]),
+        ),
+        F.not_(F.entity_subtype("Table")),
+        F.platform("snowflake"),
+        F.domain("urn:li:domain:marketing"),
+        F.env("PROD"),
+        F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED),
+        F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]),
+    )
+
+
+def test_field_discriminator() -> None:
+    with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"):
+        _BaseFilter._field_discriminator()
+
+    assert F.entity_type("dataset")._field_discriminator() == "entity_type"
+    assert F.not_(F.entity_subtype("Table"))._field_discriminator() == "not"
+    assert (
+        F.custom_filter(
+            "custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]
+        )._field_discriminator()
+        == _CustomCondition._field_discriminator()
+    )
+
+    class _BadFilter(_BaseFilter):
+        field1: str
+        field2: str
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']"
+        ),
+    ):
+        _BadFilter._field_discriminator()
+
+
+def test_filter_discriminator() -> None:
+    # Simple filter discriminator extraction.
+    assert _filter_discriminator(F.entity_type("dataset")) == "entity_type"
+    assert _filter_discriminator({"entity_type": "dataset"}) == "entity_type"
+    assert _filter_discriminator({"not": {"entity_subtype": "Table"}}) == "not"
+    assert _filter_discriminator({"unknown_field": 6}) == "unknown_field"
+    assert _filter_discriminator({"field1": 6, "field2": 7}) is None
+    assert _filter_discriminator({}) is None
+    assert _filter_discriminator(6) is None
+
+    # Special case for custom conditions.
+    assert (
+        _filter_discriminator(
+            {
+                "field": "custom_field",
+                "condition": "GREATER_THAN_OR_EQUAL_TO",
+                "values": ["5"],
+            }
+        )
+        == "_custom"
+    )
+    assert (
+        _filter_discriminator(
+            {
+                "field": "custom_field",
+                "condition": "EXISTS",
+            }
+        )
+        == "_custom"
+    )
+
+
+@pytest.mark.skipif(
+    not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
+    reason="Tagged union w/ callable discriminator is not supported by the current pydantic version",
+)
+def test_tagged_union_error_messages() -> None:
+    # With pydantic v1, we'd get 10+ validation errors and it'd be hard to
+    # understand what went wrong. With v2, we get a single simple error message.
+    with pytest.raises(
+        ValidationError,
+        match=re.compile(
+            r"1 validation error.*entity_type\.entity_type.*Input should be a valid list",
+            re.DOTALL,
+        ),
+    ):
+        load_filters({"entity_type": 6})
+
+    # Even when within an "and" clause, we get a single error message.
+    with pytest.raises(
+        ValidationError,
+        match=re.compile(
+            r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid",
+            re.DOTALL,
+        ),
+    ):
+        load_filters({"and": [{"unknown_field": 6}]})
+
+
 def test_invalid_filter() -> None:
     with pytest.raises(InvalidUrnError):
         F.domain("marketing")