mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-03 04:10:43 +00:00
feat(sdk): use discriminated unions for Filter types (#14127)
This commit is contained in:
parent
14b9bed58d
commit
c92aa0842e
@ -1,12 +1,13 @@
|
|||||||
import pydantic.version
|
import pydantic.version
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
PYDANTIC_VERSION_2: bool
|
_pydantic_version = Version(pydantic.version.VERSION)
|
||||||
if Version(pydantic.version.VERSION) >= Version("2.0"):
|
|
||||||
PYDANTIC_VERSION_2 = True
|
|
||||||
else:
|
|
||||||
PYDANTIC_VERSION_2 = False
|
|
||||||
|
|
||||||
|
PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0")
|
||||||
|
|
||||||
|
# The pydantic.Discriminator type was added in v2.5.0.
|
||||||
|
# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13
|
||||||
|
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0")
|
||||||
|
|
||||||
# This can be used to silence deprecation warnings while we migrate.
|
# This can be used to silence deprecation warnings while we migrate.
|
||||||
if PYDANTIC_VERSION_2:
|
if PYDANTIC_VERSION_2:
|
||||||
@ -50,6 +51,7 @@ class v1_ConfigModel(v1_BaseModel):
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PYDANTIC_VERSION_2",
|
"PYDANTIC_VERSION_2",
|
||||||
|
"PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR",
|
||||||
"PydanticDeprecatedSince20",
|
"PydanticDeprecatedSince20",
|
||||||
"GenericModel",
|
"GenericModel",
|
||||||
"v1_ConfigModel",
|
"v1_ConfigModel",
|
||||||
|
|||||||
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -15,7 +17,10 @@ from typing import (
|
|||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from datahub.configuration.common import ConfigModel
|
from datahub.configuration.common import ConfigModel
|
||||||
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
|
from datahub.configuration.pydantic_migration_helpers import (
|
||||||
|
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
|
||||||
|
PYDANTIC_VERSION_2,
|
||||||
|
)
|
||||||
from datahub.ingestion.graph.client import flexible_entity_type_to_graphql
|
from datahub.ingestion.graph.client import flexible_entity_type_to_graphql
|
||||||
from datahub.ingestion.graph.filters import (
|
from datahub.ingestion.graph.filters import (
|
||||||
FilterOperator,
|
FilterOperator,
|
||||||
@ -42,12 +47,29 @@ class _BaseFilter(ConfigModel):
|
|||||||
populate_by_name = True
|
populate_by_name = True
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(self) -> _OrFilters:
|
def compile(self) -> _OrFilters: ...
|
||||||
pass
|
|
||||||
|
|
||||||
def dfs(self) -> Iterator[_BaseFilter]:
|
def dfs(self) -> Iterator[_BaseFilter]:
|
||||||
yield self
|
yield self
|
||||||
|
|
||||||
|
@classmethod
def _field_discriminator(cls) -> str:
    """Return the tag that identifies this filter class in the Filter union.

    The tag is the alias of the class's only declared field, or the field
    name when no alias is set. Subclasses that declare more than one field
    must override this method.

    Raises:
        ValueError: If called on ``_BaseFilter`` itself, or if the class
            does not declare exactly one field.
    """
    if cls is _BaseFilter:
        raise ValueError("Cannot get discriminator for _BaseFilter")
    if PYDANTIC_VERSION_2:
        fields: dict = cls.model_fields  # type: ignore
    else:
        fields = cls.__fields__  # type: ignore

    # Assumes that there's only one field name per filter.
    # If that's not the case, this method should be overridden.
    if not fields:
        # Report the empty case separately; "multiple fields" followed by
        # an empty list would be misleading.
        raise ValueError(
            f"Found no fields that could be the discriminator for this filter: {cls.__name__}"
        )
    if len(fields) > 1:
        raise ValueError(
            f"Found multiple fields that could be the discriminator for this filter: {list(fields.keys())}"
        )
    name, field = next(iter(fields.items()))
    return field.alias or name  # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class _EntityTypeFilter(_BaseFilter):
|
class _EntityTypeFilter(_BaseFilter):
|
||||||
"""Filter for specific entity types.
|
"""Filter for specific entity types.
|
||||||
@ -74,15 +96,19 @@ class _EntityTypeFilter(_BaseFilter):
|
|||||||
|
|
||||||
|
|
||||||
class _EntitySubtypeFilter(_BaseFilter):
|
class _EntitySubtypeFilter(_BaseFilter):
|
||||||
entity_subtype: str = pydantic.Field(
|
entity_subtype: List[str] = pydantic.Field(
|
||||||
description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
|
description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pydantic.validator("entity_subtype", pre=True)
|
||||||
|
def validate_entity_subtype(cls, v: str) -> List[str]:
|
||||||
|
return [v] if not isinstance(v, list) else v
|
||||||
|
|
||||||
def _build_rule(self) -> SearchFilterRule:
|
def _build_rule(self) -> SearchFilterRule:
|
||||||
return SearchFilterRule(
|
return SearchFilterRule(
|
||||||
field="typeNames",
|
field="typeNames",
|
||||||
condition="EQUAL",
|
condition="EQUAL",
|
||||||
values=[self.entity_subtype],
|
values=self.entity_subtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compile(self) -> _OrFilters:
|
def compile(self) -> _OrFilters:
|
||||||
@ -196,6 +222,10 @@ class _CustomCondition(_BaseFilter):
|
|||||||
)
|
)
|
||||||
return [{"and": [rule]}]
|
return [{"and": [rule]}]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _field_discriminator(cls) -> str:
|
||||||
|
return "_custom"
|
||||||
|
|
||||||
|
|
||||||
class _And(_BaseFilter):
|
class _And(_BaseFilter):
|
||||||
"""Represents an AND conjunction of filters."""
|
"""Represents an AND conjunction of filters."""
|
||||||
@ -302,31 +332,69 @@ class _Not(_BaseFilter):
|
|||||||
yield from self.not_.dfs()
|
yield from self.not_.dfs()
|
||||||
|
|
||||||
|
|
||||||
# TODO: With pydantic 2, we can use a RootModel with a
|
def _filter_discriminator(v: Any) -> Optional[str]:
|
||||||
# discriminated union to make the error messages more informative.
|
if isinstance(v, _BaseFilter):
|
||||||
Filter = Union[
|
return v._field_discriminator()
|
||||||
_And,
|
|
||||||
_Or,
|
if not isinstance(v, dict):
|
||||||
_Not,
|
return None
|
||||||
_EntityTypeFilter,
|
|
||||||
_EntitySubtypeFilter,
|
keys = list(v.keys())
|
||||||
_StatusFilter,
|
if len(keys) == 1:
|
||||||
_PlatformFilter,
|
return keys[0]
|
||||||
_DomainFilter,
|
elif set(keys).issuperset({"field", "condition"}):
|
||||||
_EnvFilter,
|
return _CustomCondition._field_discriminator()
|
||||||
_CustomCondition,
|
|
||||||
]
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Required to resolve forward references to "Filter"
|
if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR:
|
||||||
if PYDANTIC_VERSION_2:
|
# The `not TYPE_CHECKING` bit is required to make the linter happy,
|
||||||
_And.model_rebuild() # type: ignore
|
# since we currently only run mypy with pydantic v1.
|
||||||
_Or.model_rebuild() # type: ignore
|
Filter = Union[
|
||||||
_Not.model_rebuild() # type: ignore
|
_And,
|
||||||
else:
|
_Or,
|
||||||
|
_Not,
|
||||||
|
_EntityTypeFilter,
|
||||||
|
_EntitySubtypeFilter,
|
||||||
|
_StatusFilter,
|
||||||
|
_PlatformFilter,
|
||||||
|
_DomainFilter,
|
||||||
|
_EnvFilter,
|
||||||
|
_CustomCondition,
|
||||||
|
]
|
||||||
|
|
||||||
_And.update_forward_refs()
|
_And.update_forward_refs()
|
||||||
_Or.update_forward_refs()
|
_Or.update_forward_refs()
|
||||||
_Not.update_forward_refs()
|
_Not.update_forward_refs()
|
||||||
|
else:
|
||||||
|
from pydantic import Discriminator, Tag
|
||||||
|
|
||||||
|
# TODO: Once we're fully on pydantic 2, we can use a RootModel here.
|
||||||
|
# That way we'd be able to attach methods to the Filter type.
|
||||||
|
# e.g. replace load_filters(...) with Filter.load(...)
|
||||||
|
Filter = Annotated[
|
||||||
|
Union[
|
||||||
|
Annotated[_And, Tag(_And._field_discriminator())],
|
||||||
|
Annotated[_Or, Tag(_Or._field_discriminator())],
|
||||||
|
Annotated[_Not, Tag(_Not._field_discriminator())],
|
||||||
|
Annotated[_EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator())],
|
||||||
|
Annotated[
|
||||||
|
_EntitySubtypeFilter, Tag(_EntitySubtypeFilter._field_discriminator())
|
||||||
|
],
|
||||||
|
Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())],
|
||||||
|
Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())],
|
||||||
|
Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())],
|
||||||
|
Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())],
|
||||||
|
Annotated[_CustomCondition, Tag(_CustomCondition._field_discriminator())],
|
||||||
|
],
|
||||||
|
Discriminator(_filter_discriminator),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Required to resolve forward references to "Filter"
|
||||||
|
_And.model_rebuild() # type: ignore
|
||||||
|
_Or.model_rebuild() # type: ignore
|
||||||
|
_Not.model_rebuild() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def load_filters(obj: Any) -> Filter:
|
def load_filters(obj: Any) -> Filter:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
from typing import Dict, List, Optional, Sequence, Set, cast
|
from typing import Dict, List, Optional, Sequence, Set, cast
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -141,7 +141,6 @@ def test_infer_lineage_from_sql(client: DataHubClient) -> None:
|
|||||||
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
|
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
|
||||||
],
|
],
|
||||||
query_type=QueryType.SELECT,
|
query_type=QueryType.SELECT,
|
||||||
debug_info=MagicMock(error=None, table_error=None),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_text = (
|
query_text = (
|
||||||
@ -197,7 +196,6 @@ def test_infer_lineage_from_sql_with_multiple_upstreams(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
query_type=QueryType.SELECT,
|
query_type=QueryType.SELECT,
|
||||||
debug_info=MagicMock(error=None, table_error=None),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_text = """
|
query_text = """
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@ -6,6 +7,9 @@ import pytest
|
|||||||
import yaml
|
import yaml
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from datahub.configuration.pydantic_migration_helpers import (
|
||||||
|
PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
|
||||||
|
)
|
||||||
from datahub.ingestion.graph.filters import (
|
from datahub.ingestion.graph.filters import (
|
||||||
RemovedStatusFilter,
|
RemovedStatusFilter,
|
||||||
SearchFilterRule,
|
SearchFilterRule,
|
||||||
@ -14,7 +18,14 @@ from datahub.ingestion.graph.filters import (
|
|||||||
from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
|
from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
|
||||||
from datahub.sdk.main_client import DataHubClient
|
from datahub.sdk.main_client import DataHubClient
|
||||||
from datahub.sdk.search_client import compile_filters, compute_entity_types
|
from datahub.sdk.search_client import compile_filters, compute_entity_types
|
||||||
from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters
|
from datahub.sdk.search_filters import (
|
||||||
|
Filter,
|
||||||
|
FilterDsl as F,
|
||||||
|
_BaseFilter,
|
||||||
|
_CustomCondition,
|
||||||
|
_filter_discriminator,
|
||||||
|
load_filters,
|
||||||
|
)
|
||||||
from datahub.utilities.urns.error import InvalidUrnError
|
from datahub.utilities.urns.error import InvalidUrnError
|
||||||
from tests.test_helpers.graph_helpers import MockDataHubGraph
|
from tests.test_helpers.graph_helpers import MockDataHubGraph
|
||||||
|
|
||||||
@ -170,6 +181,137 @@ and:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_subtype_filter() -> None:
    """Check that entity_subtype loads from both a list and a bare string."""
    expected = F.entity_subtype("Table")

    from_list: Filter = load_filters({"entity_subtype": ["Table"]})
    assert from_list == expected

    # Ensure it works without the list wrapper to maintain backwards compatibility.
    from_scalar: Filter = load_filters({"entity_subtype": "Table"})
    assert from_list == from_scalar
|
||||||
|
|
||||||
|
|
||||||
|
def test_filters_all_types() -> None:
    """Load a nested dict using every filter shape and compare it to the DSL form."""
    raw_filter = {
        "and": [
            {
                "or": [
                    {"entity_type": ["dataset"]},
                    {"entity_type": ["chart", "dashboard"]},
                ]
            },
            {"not": {"entity_subtype": ["Table"]}},
            {"platform": ["snowflake"]},
            {"domain": ["urn:li:domain:marketing"]},
            {"env": ["PROD"]},
            {"status": "NOT_SOFT_DELETED"},
            {
                "field": "custom_field",
                "condition": "GREATER_THAN_OR_EQUAL_TO",
                "values": ["5"],
            },
        ]
    }
    filter_obj: Filter = load_filters(raw_filter)

    # The same tree, built through the FilterDsl helpers.
    entity_type_clause = F.or_(
        F.entity_type("dataset"),
        F.entity_type(["chart", "dashboard"]),
    )
    expected = F.and_(
        entity_type_clause,
        F.not_(F.entity_subtype("Table")),
        F.platform("snowflake"),
        F.domain("urn:li:domain:marketing"),
        F.env("PROD"),
        F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED),
        F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]),
    )
    assert filter_obj == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_field_discriminator() -> None:
    """Check the tag each filter class reports, including the error cases."""
    # The base class itself has no tag.
    with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"):
        _BaseFilter._field_discriminator()

    entity_type_filter = F.entity_type("dataset")
    assert entity_type_filter._field_discriminator() == "entity_type"

    negated_filter = F.not_(F.entity_subtype("Table"))
    assert negated_filter._field_discriminator() == "not"

    custom_filter = F.custom_filter(
        "custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]
    )
    custom_tag = _CustomCondition._field_discriminator()
    assert custom_filter._field_discriminator() == custom_tag

    # A filter with two fields and no override is rejected.
    class _TwoFieldFilter(_BaseFilter):
        field1: str
        field2: str

    multiple_fields_error = re.escape(
        "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']"
    )
    with pytest.raises(ValueError, match=multiple_fields_error):
        _TwoFieldFilter._field_discriminator()
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_discriminator() -> None:
    """Check tag extraction from filter objects, dicts, and unsupported inputs."""
    # Simple filter discriminator extraction.
    tagged_inputs = [
        (F.entity_type("dataset"), "entity_type"),
        ({"entity_type": "dataset"}, "entity_type"),
        ({"not": {"entity_subtype": "Table"}}, "not"),
        ({"unknown_field": 6}, "unknown_field"),
    ]
    for raw, expected_tag in tagged_inputs:
        assert _filter_discriminator(raw) == expected_tag

    # Inputs for which no tag is produced.
    for raw in ({"field1": 6, "field2": 7}, {}, 6):
        assert _filter_discriminator(raw) is None

    # Special case for custom conditions.
    custom_inputs = [
        {
            "field": "custom_field",
            "condition": "GREATER_THAN_OR_EQUAL_TO",
            "values": ["5"],
        },
        {
            "field": "custom_field",
            "condition": "EXISTS",
        },
    ]
    for raw in custom_inputs:
        assert _filter_discriminator(raw) == "_custom"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR,
    reason="Tagged union w/ callable discriminator is not supported by the current pydantic version",
)
def test_tagged_union_error_messages() -> None:
    """Check that invalid filters produce a single validation error."""
    # With pydantic v1, we'd get 10+ validation errors and it'd be hard to
    # understand what went wrong. With v2, we get a single simple error message.
    wrong_value_type = re.compile(
        r"1 validation error.*entity_type\.entity_type.*Input should be a valid list",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=wrong_value_type):
        load_filters({"entity_type": 6})

    # Even when within an "and" clause, we get a single error message.
    unknown_tag = re.compile(
        r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid",
        re.DOTALL,
    )
    with pytest.raises(ValidationError, match=unknown_tag):
        load_filters({"and": [{"unknown_field": 6}]})
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_filter() -> None:
|
def test_invalid_filter() -> None:
|
||||||
with pytest.raises(InvalidUrnError):
|
with pytest.raises(InvalidUrnError):
|
||||||
F.domain("marketing")
|
F.domain("marketing")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user