diff --git a/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py b/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py index bd931abe2e..0d24b88615 100644 --- a/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py +++ b/metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py @@ -1,12 +1,13 @@ import pydantic.version from packaging.version import Version -PYDANTIC_VERSION_2: bool -if Version(pydantic.version.VERSION) >= Version("2.0"): - PYDANTIC_VERSION_2 = True -else: - PYDANTIC_VERSION_2 = False +_pydantic_version = Version(pydantic.version.VERSION) +PYDANTIC_VERSION_2 = _pydantic_version >= Version("2.0") + +# The pydantic.Discriminator type was added in v2.5.0. +# https://docs.pydantic.dev/latest/changelog/#v250-2023-11-13 +PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR = _pydantic_version >= Version("2.5.0") # This can be used to silence deprecation warnings while we migrate. if PYDANTIC_VERSION_2: @@ -50,6 +51,7 @@ class v1_ConfigModel(v1_BaseModel): __all__ = [ "PYDANTIC_VERSION_2", + "PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR", "PydanticDeprecatedSince20", "GenericModel", "v1_ConfigModel", diff --git a/metadata-ingestion/src/datahub/sdk/search_filters.py b/metadata-ingestion/src/datahub/sdk/search_filters.py index d395e7b009..34de40c34c 100644 --- a/metadata-ingestion/src/datahub/sdk/search_filters.py +++ b/metadata-ingestion/src/datahub/sdk/search_filters.py @@ -2,6 +2,8 @@ from __future__ import annotations import abc from typing import ( + TYPE_CHECKING, + Annotated, Any, ClassVar, Iterator, @@ -15,7 +17,10 @@ from typing import ( import pydantic from datahub.configuration.common import ConfigModel -from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 +from datahub.configuration.pydantic_migration_helpers import ( + PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, + PYDANTIC_VERSION_2, +) from datahub.ingestion.graph.client import 
flexible_entity_type_to_graphql from datahub.ingestion.graph.filters import ( FilterOperator, @@ -42,12 +47,29 @@ class _BaseFilter(ConfigModel): populate_by_name = True @abc.abstractmethod - def compile(self) -> _OrFilters: - pass + def compile(self) -> _OrFilters: ... def dfs(self) -> Iterator[_BaseFilter]: yield self + @classmethod + def _field_discriminator(cls) -> str: + if cls is _BaseFilter: + raise ValueError("Cannot get discriminator for _BaseFilter") + if PYDANTIC_VERSION_2: + fields: dict = cls.model_fields # type: ignore + else: + fields = cls.__fields__ # type: ignore + + # Assumes that there's only one field name per filter. + # If that's not the case, this method should be overridden. + if len(fields.keys()) != 1: + raise ValueError( + f"Found multiple fields that could be the discriminator for this filter: {list(fields.keys())}" + ) + name, field = next(iter(fields.items())) + return field.alias or name # type: ignore + class _EntityTypeFilter(_BaseFilter): """Filter for specific entity types. @@ -74,15 +96,19 @@ class _EntityTypeFilter(_BaseFilter): class _EntitySubtypeFilter(_BaseFilter): - entity_subtype: str = pydantic.Field( + entity_subtype: List[str] = pydantic.Field( description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. 
depending on the native platform's concepts.", ) + @pydantic.validator("entity_subtype", pre=True) + def validate_entity_subtype(cls, v: str) -> List[str]: + return [v] if not isinstance(v, list) else v + def _build_rule(self) -> SearchFilterRule: return SearchFilterRule( field="typeNames", condition="EQUAL", - values=[self.entity_subtype], + values=self.entity_subtype, ) def compile(self) -> _OrFilters: @@ -196,6 +222,10 @@ class _CustomCondition(_BaseFilter): ) return [{"and": [rule]}] + @classmethod + def _field_discriminator(cls) -> str: + return "_custom" + class _And(_BaseFilter): """Represents an AND conjunction of filters.""" @@ -302,31 +332,69 @@ class _Not(_BaseFilter): yield from self.not_.dfs() -# TODO: With pydantic 2, we can use a RootModel with a -# discriminated union to make the error messages more informative. -Filter = Union[ - _And, - _Or, - _Not, - _EntityTypeFilter, - _EntitySubtypeFilter, - _StatusFilter, - _PlatformFilter, - _DomainFilter, - _EnvFilter, - _CustomCondition, -] +def _filter_discriminator(v: Any) -> Optional[str]: + if isinstance(v, _BaseFilter): + return v._field_discriminator() + + if not isinstance(v, dict): + return None + + keys = list(v.keys()) + if len(keys) == 1: + return keys[0] + elif set(keys).issuperset({"field", "condition"}): + return _CustomCondition._field_discriminator() + + return None -# Required to resolve forward references to "Filter" -if PYDANTIC_VERSION_2: - _And.model_rebuild() # type: ignore - _Or.model_rebuild() # type: ignore - _Not.model_rebuild() # type: ignore -else: +if TYPE_CHECKING or not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR: + # The `TYPE_CHECKING` bit is required to make the linter happy, + # since we currently only run mypy with pydantic v1.
+ Filter = Union[ + _And, + _Or, + _Not, + _EntityTypeFilter, + _EntitySubtypeFilter, + _StatusFilter, + _PlatformFilter, + _DomainFilter, + _EnvFilter, + _CustomCondition, + ] + _And.update_forward_refs() _Or.update_forward_refs() _Not.update_forward_refs() +else: + from pydantic import Discriminator, Tag + + # TODO: Once we're fully on pydantic 2, we can use a RootModel here. + # That way we'd be able to attach methods to the Filter type. + # e.g. replace load_filters(...) with Filter.load(...) + Filter = Annotated[ + Union[ + Annotated[_And, Tag(_And._field_discriminator())], + Annotated[_Or, Tag(_Or._field_discriminator())], + Annotated[_Not, Tag(_Not._field_discriminator())], + Annotated[_EntityTypeFilter, Tag(_EntityTypeFilter._field_discriminator())], + Annotated[ + _EntitySubtypeFilter, Tag(_EntitySubtypeFilter._field_discriminator()) + ], + Annotated[_StatusFilter, Tag(_StatusFilter._field_discriminator())], + Annotated[_PlatformFilter, Tag(_PlatformFilter._field_discriminator())], + Annotated[_DomainFilter, Tag(_DomainFilter._field_discriminator())], + Annotated[_EnvFilter, Tag(_EnvFilter._field_discriminator())], + Annotated[_CustomCondition, Tag(_CustomCondition._field_discriminator())], + ], + Discriminator(_filter_discriminator), + ] + + # Required to resolve forward references to "Filter" + _And.model_rebuild() # type: ignore + _Or.model_rebuild() # type: ignore + _Not.model_rebuild() # type: ignore def load_filters(obj: Any) -> Filter: diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py index d3c9bbe206..e0e75e6d3e 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py @@ -1,6 +1,6 @@ import pathlib from typing import Dict, List, Optional, Sequence, Set, cast -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest @@ -141,7 +141,6 @@ def 
test_infer_lineage_from_sql(client: DataHubClient) -> None: "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" ], query_type=QueryType.SELECT, - debug_info=MagicMock(error=None, table_error=None), ) query_text = ( @@ -197,7 +196,6 @@ def test_infer_lineage_from_sql_with_multiple_upstreams( ), ], query_type=QueryType.SELECT, - debug_info=MagicMock(error=None, table_error=None), ) query_text = """ diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py index 862b2f1961..fa56c7fe12 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py @@ -1,3 +1,4 @@ +import re import unittest import unittest.mock from io import StringIO @@ -6,6 +7,9 @@ import pytest import yaml from pydantic import ValidationError +from datahub.configuration.pydantic_migration_helpers import ( + PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, +) from datahub.ingestion.graph.filters import ( RemovedStatusFilter, SearchFilterRule, @@ -14,7 +18,14 @@ from datahub.ingestion.graph.filters import ( from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn from datahub.sdk.main_client import DataHubClient from datahub.sdk.search_client import compile_filters, compute_entity_types -from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters +from datahub.sdk.search_filters import ( + Filter, + FilterDsl as F, + _BaseFilter, + _CustomCondition, + _filter_discriminator, + load_filters, +) from datahub.utilities.urns.error import InvalidUrnError from tests.test_helpers.graph_helpers import MockDataHubGraph @@ -170,6 +181,137 @@ and: ] +def test_entity_subtype_filter() -> None: + filter_obj_1: Filter = load_filters({"entity_subtype": ["Table"]}) + assert filter_obj_1 == F.entity_subtype("Table") + + # Ensure it works without the list wrapper to maintain backwards compatibility. 
+ filter_obj_2: Filter = load_filters({"entity_subtype": "Table"}) + assert filter_obj_1 == filter_obj_2 + + +def test_filters_all_types() -> None: + filter_obj: Filter = load_filters( + { + "and": [ + { + "or": [ + {"entity_type": ["dataset"]}, + {"entity_type": ["chart", "dashboard"]}, + ] + }, + {"not": {"entity_subtype": ["Table"]}}, + {"platform": ["snowflake"]}, + {"domain": ["urn:li:domain:marketing"]}, + {"env": ["PROD"]}, + {"status": "NOT_SOFT_DELETED"}, + { + "field": "custom_field", + "condition": "GREATER_THAN_OR_EQUAL_TO", + "values": ["5"], + }, + ] + } + ) + assert filter_obj == F.and_( + F.or_( + F.entity_type("dataset"), + F.entity_type(["chart", "dashboard"]), + ), + F.not_(F.entity_subtype("Table")), + F.platform("snowflake"), + F.domain("urn:li:domain:marketing"), + F.env("PROD"), + F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED), + F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]), + ) + + +def test_field_discriminator() -> None: + with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"): + _BaseFilter._field_discriminator() + + assert F.entity_type("dataset")._field_discriminator() == "entity_type" + assert F.not_(F.entity_subtype("Table"))._field_discriminator() == "not" + assert ( + F.custom_filter( + "custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"] + )._field_discriminator() + == _CustomCondition._field_discriminator() + ) + + class _BadFilter(_BaseFilter): + field1: str + field2: str + + with pytest.raises( + ValueError, + match=re.escape( + "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']" + ), + ): + _BadFilter._field_discriminator() + + +def test_filter_discriminator() -> None: + # Simple filter discriminator extraction. 
+ assert _filter_discriminator(F.entity_type("dataset")) == "entity_type" + assert _filter_discriminator({"entity_type": "dataset"}) == "entity_type" + assert _filter_discriminator({"not": {"entity_subtype": "Table"}}) == "not" + assert _filter_discriminator({"unknown_field": 6}) == "unknown_field" + assert _filter_discriminator({"field1": 6, "field2": 7}) is None + assert _filter_discriminator({}) is None + assert _filter_discriminator(6) is None + + # Special case for custom conditions. + assert ( + _filter_discriminator( + { + "field": "custom_field", + "condition": "GREATER_THAN_OR_EQUAL_TO", + "values": ["5"], + } + ) + == "_custom" + ) + assert ( + _filter_discriminator( + { + "field": "custom_field", + "condition": "EXISTS", + } + ) + == "_custom" + ) + + +@pytest.mark.skipif( + not PYDANTIC_SUPPORTS_CALLABLE_DISCRIMINATOR, + reason="Tagged union w/ callable discriminator is not supported by the current pydantic version", +) +def test_tagged_union_error_messages() -> None: + # With pydantic v1, we'd get 10+ validation errors and it'd be hard to + # understand what went wrong. With v2, we get a single simple error message. + with pytest.raises( + ValidationError, + match=re.compile( + r"1 validation error.*entity_type\.entity_type.*Input should be a valid list", + re.DOTALL, + ), + ): + load_filters({"entity_type": 6}) + + # Even when within an "and" clause, we get a single error message. + with pytest.raises( + ValidationError, + match=re.compile( + r"1 validation error.*Input tag 'unknown_field' found using .+ does not match any of the expected tags:.+union_tag_invalid", + re.DOTALL, + ), + ): + load_filters({"and": [{"unknown_field": 6}]}) + + def test_invalid_filter() -> None: with pytest.raises(InvalidUrnError): F.domain("marketing")