feat(sdk): add search client (#12754)

Harshal Sheth 2025-03-03 10:05:26 -08:00 committed by GitHub
parent ccf4412078
commit 12eb0cd1a7
18 changed files with 780 additions and 102 deletions

View File

@@ -323,6 +323,7 @@ ASPECT_NAME_MAP: Dict[str, Type[_Aspect]] = {{
for aspect in ASPECT_CLASSES
}}
+from typing import Literal
from typing_extensions import TypedDict
class AspectBag(TypedDict, total=False):
@@ -332,6 +333,13 @@ class AspectBag(TypedDict, total=False):
KEY_ASPECTS: Dict[str, Type[_Aspect]] = {{
{f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
}}
+ENTITY_TYPE_NAMES: List[str] = [
+{f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
+]
+EntityTypeName = Literal[
+{f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
+]
"""
)
@@ -346,7 +354,7 @@ def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None:
code = """
# This file contains classes corresponding to entity URNs.
-from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union
+from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union, Literal
import functools
from deprecated.sphinx import deprecated as _sphinx_deprecated
@@ -672,7 +680,7 @@ if TYPE_CHECKING:
from datahub.metadata.schema_classes import {key_aspect_class}
class {class_name}(_SpecificUrn):
-ENTITY_TYPE: ClassVar[str] = "{entity_type}"
+ENTITY_TYPE: ClassVar[Literal["{entity_type}"]] = "{entity_type}"
_URN_PARTS: ClassVar[int] = {arg_count}
def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
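
For context on what this codegen change produces: each generated URN class now pins ENTITY_TYPE to a Literal type rather than a plain str, so type checkers can narrow on it. A minimal sketch of the generated shape, with a stub standing in for the real _SpecificUrn base class:

from typing import ClassVar, Literal

class _SpecificUrn:  # minimal stand-in for datahub's real urn base class
    pass

class DatasetUrn(_SpecificUrn):
    # Shape of the generated code after this change: a Literal instead of a plain str,
    # so type checkers can narrow comparisons like urn.ENTITY_TYPE == "dataset".
    ENTITY_TYPE: ClassVar[Literal["dataset"]] = "dataset"
    _URN_PARTS: ClassVar[int] = 2

assert DatasetUrn.ENTITY_TYPE == "dataset"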

View File

@@ -16,6 +16,7 @@ from typing import (
List,
Literal,
Optional,
+Sequence,
Tuple,
Type,
Union,
@@ -42,8 +43,8 @@ from datahub.ingestion.graph.connections import (
)
from datahub.ingestion.graph.entity_versioning import EntityVersioningAPI
from datahub.ingestion.graph.filters import (
+RawSearchFilterRule,
RemovedStatusFilter,
-SearchFilterRule,
generate_filter,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
@@ -105,7 +106,7 @@ class RelatedEntity:
via: Optional[str] = None
-def _graphql_entity_type(entity_type: str) -> str:
+def entity_type_to_graphql(entity_type: str) -> str:
"""Convert the entity types into GraphQL "EntityType" enum values."""
# Hard-coded special cases.
@@ -797,13 +798,13 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
container: Optional[str] = None,
status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED,
batch_size: int = 100,
-extraFilters: Optional[List[SearchFilterRule]] = None,
+extraFilters: Optional[List[RawSearchFilterRule]] = None,
) -> Iterable[Tuple[str, "GraphQLSchemaMetadata"]]:
"""Fetch schema info for datasets that match all of the given filters.
:return: An iterable of (urn, schema info) tuple that match the filters.
"""
-types = [_graphql_entity_type("dataset")]
+types = [entity_type_to_graphql("dataset")]
# Add the query default of * if no query is specified.
query = query or "*"
@@ -865,7 +866,7 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
def get_urns_by_filter(
self,
*,
-entity_types: Optional[List[str]] = None,
+entity_types: Optional[Sequence[str]] = None,
platform: Optional[str] = None,
platform_instance: Optional[str] = None,
env: Optional[str] = None,
@@ -873,8 +874,8 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
container: Optional[str] = None,
status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED,
batch_size: int = 10000,
-extraFilters: Optional[List[SearchFilterRule]] = None,
-extra_or_filters: Optional[List[Dict[str, List[SearchFilterRule]]]] = None,
+extraFilters: Optional[List[RawSearchFilterRule]] = None,
+extra_or_filters: Optional[List[Dict[str, List[RawSearchFilterRule]]]] = None,
) -> Iterable[str]:
"""Fetch all urns that match all of the given filters.
@@ -965,8 +966,8 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
container: Optional[str] = None,
status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED,
batch_size: int = 10000,
-extra_and_filters: Optional[List[SearchFilterRule]] = None,
-extra_or_filters: Optional[List[Dict[str, List[SearchFilterRule]]]] = None,
+extra_and_filters: Optional[List[RawSearchFilterRule]] = None,
+extra_or_filters: Optional[List[Dict[str, List[RawSearchFilterRule]]]] = None,
extra_source_fields: Optional[List[str]] = None,
skip_cache: bool = False,
) -> Iterable[dict]:
@@ -1109,7 +1110,8 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
f"Scrolling to next scrollAcrossEntities page: {scroll_id}"
)
-def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]:
+@classmethod
+def _get_types(cls, entity_types: Optional[Sequence[str]]) -> Optional[List[str]]:
types: Optional[List[str]] = None
if entity_types is not None:
if not entity_types:
@@ -1117,7 +1119,9 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
"entity_types cannot be an empty list; use None for all entities"
)
-types = [_graphql_entity_type(entity_type) for entity_type in entity_types]
+types = [
+entity_type_to_graphql(entity_type) for entity_type in entity_types
+]
return types
def get_latest_pipeline_checkpoint(
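
For callers of the lower-level graph API, the rename from SearchFilterRule to RawSearchFilterRule is type-level only: extraFilters and extra_or_filters still accept plain dicts. A hedged sketch of that usage, with a placeholder server URL:

from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))  # placeholder server

# RawSearchFilterRule is just Dict[str, Any]; the dict shape matches the rules
# built elsewhere in this diff.
snowflake_rule = {
    "field": "platform.keyword",
    "condition": "EQUAL",
    "values": ["urn:li:dataPlatform:snowflake"],
}

for urn in graph.get_urns_by_filter(
    entity_types=["dataset"],
    extraFilters=[snowflake_rule],
):
    print(urn)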

View File

@@ -1,3 +1,4 @@
+import dataclasses
import enum
from typing import Any, Dict, List, Optional
@@ -7,7 +8,31 @@ from datahub.emitter.mce_builder import (
)
from datahub.utilities.urns.urn import guess_entity_type
-SearchFilterRule = Dict[str, Any]
+RawSearchFilterRule = Dict[str, Any]
+@dataclasses.dataclass
+class SearchFilterRule:
+field: str
+condition: str # TODO: convert to an enum
+values: List[str]
+negated: bool = False
+def to_raw(self) -> RawSearchFilterRule:
+return {
+"field": self.field,
+"condition": self.condition,
+"values": self.values,
+"negated": self.negated,
+}
+def negate(self) -> "SearchFilterRule":
+return SearchFilterRule(
+field=self.field,
+condition=self.condition,
+values=self.values,
+negated=not self.negated,
+)
class RemovedStatusFilter(enum.Enum):
@@ -29,9 +54,9 @@ def generate_filter(
env: Optional[str],
container: Optional[str],
status: RemovedStatusFilter,
-extra_filters: Optional[List[SearchFilterRule]],
-extra_or_filters: Optional[List[SearchFilterRule]] = None,
-) -> List[Dict[str, List[SearchFilterRule]]]:
+extra_filters: Optional[List[RawSearchFilterRule]],
+extra_or_filters: Optional[List[RawSearchFilterRule]] = None,
+) -> List[Dict[str, List[RawSearchFilterRule]]]:
"""
Generate a search filter based on the provided parameters.
:param platform: The platform to filter by.
@@ -43,30 +68,32 @@ def generate_filter(
:param extra_or_filters: Extra OR filters to apply. These are combined with
the AND filters using an OR at the top level.
"""
-and_filters: List[SearchFilterRule] = []
+and_filters: List[RawSearchFilterRule] = []
# Platform filter.
if platform:
-and_filters.append(_get_platform_filter(platform))
+and_filters.append(_get_platform_filter(platform).to_raw())
# Platform instance filter.
if platform_instance:
-and_filters.append(_get_platform_instance_filter(platform, platform_instance))
+and_filters.append(
+_get_platform_instance_filter(platform, platform_instance).to_raw()
+)
# Browse path v2 filter.
if container:
-and_filters.append(_get_container_filter(container))
+and_filters.append(_get_container_filter(container).to_raw())
# Status filter.
status_filter = _get_status_filter(status)
if status_filter:
-and_filters.append(status_filter)
+and_filters.append(status_filter.to_raw())
# Extra filters.
if extra_filters:
and_filters += extra_filters
-or_filters: List[Dict[str, List[SearchFilterRule]]] = [{"and": and_filters}]
+or_filters: List[Dict[str, List[RawSearchFilterRule]]] = [{"and": and_filters}]
# Env filter
if env:
@@ -89,7 +116,7 @@ def generate_filter(
return or_filters
-def _get_env_filters(env: str) -> List[SearchFilterRule]:
+def _get_env_filters(env: str) -> List[RawSearchFilterRule]:
# The env filter is a bit more tricky since it's not always stored
# in the same place in ElasticSearch.
return [
@@ -125,19 +152,19 @@ def _get_status_filter(status: RemovedStatusFilter) -> Optional[SearchFilterRule
# removed field is simply not present in the ElasticSearch document. Ideally this
# would be a "removed" : "false" filter, but that doesn't work. Instead, we need to
# use a negated filter.
-return {
-"field": "removed",
-"values": ["true"],
-"condition": "EQUAL",
-"negated": True,
-}
+return SearchFilterRule(
+field="removed",
+values=["true"],
+condition="EQUAL",
+negated=True,
+)
elif status == RemovedStatusFilter.ONLY_SOFT_DELETED:
-return {
-"field": "removed",
-"values": ["true"],
-"condition": "EQUAL",
-}
+return SearchFilterRule(
+field="removed",
+values=["true"],
+condition="EQUAL",
+)
elif status == RemovedStatusFilter.ALL:
# We don't need to add a filter for this case.
@@ -152,11 +179,11 @@ def _get_container_filter(container: str) -> SearchFilterRule:
if guess_entity_type(container) != "container":
raise ValueError(f"Invalid container urn: {container}")
-return {
-"field": "browsePathV2",
-"values": [container],
-"condition": "CONTAIN",
-}
+return SearchFilterRule(
+field="browsePathV2",
+values=[container],
+condition="CONTAIN",
+)
def _get_platform_instance_filter(
@@ -171,16 +198,16 @@ def _get_platform_instance_filter(
if guess_entity_type(platform_instance) != "dataPlatformInstance":
raise ValueError(f"Invalid data platform instance urn: {platform_instance}")
-return {
-"field": "platformInstance",
-"values": [platform_instance],
-"condition": "EQUAL",
-}
+return SearchFilterRule(
+field="platformInstance",
+condition="EQUAL",
+values=[platform_instance],
+)
def _get_platform_filter(platform: str) -> SearchFilterRule:
-return {
-"field": "platform.keyword",
-"values": [make_data_platform_urn(platform)],
-"condition": "EQUAL",
-}
+return SearchFilterRule(
+field="platform.keyword",
+condition="EQUAL",
+values=[make_data_platform_urn(platform)],
+)
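
A small usage sketch of the new SearchFilterRule dataclass introduced above, showing the to_raw() and negate() behavior that generate_filter now relies on:

from datahub.ingestion.graph.filters import SearchFilterRule

rule = SearchFilterRule(field="removed", condition="EQUAL", values=["true"], negated=True)

# to_raw() emits the RawSearchFilterRule dict shape that generate_filter sends over the wire.
assert rule.to_raw() == {
    "field": "removed",
    "condition": "EQUAL",
    "values": ["true"],
    "negated": True,
}

# negate() returns a flipped copy and leaves the original rule untouched.
assert rule.negate().negated is False
assert rule.negated is True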

View File

@@ -59,9 +59,9 @@ from datahub.metadata.schema_classes import (
UpstreamLineageClass,
ViewPropertiesClass,
)
-from datahub.sdk._entity import Entity
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
+from datahub.sdk.entity import Entity
logger = logging.getLogger(__name__)

View File

@@ -20,6 +20,7 @@ from datahub.metadata.urns import (
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
from datahub.sdk.main_client import DataHubClient
+from datahub.sdk.search_filters import Filter, FilterDsl
# We want to print out the warning if people do `from datahub.sdk import X`.
# But we don't want to print out warnings if they're doing a more direct

View File

@@ -1,8 +1,8 @@
from typing import Dict, List, Type
-from datahub.sdk._entity import Entity
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
+from datahub.sdk.entity import Entity
# TODO: Is there a better way to declare this?
ENTITY_CLASSES_LIST: List[Type[Entity]] = [

View File

@@ -36,8 +36,8 @@ from datahub.metadata.urns import (
TagUrn,
Urn,
)
-from datahub.sdk._entity import Entity
from datahub.sdk._utils import add_list_unique, remove_list_unique
+from datahub.sdk.entity import Entity
from datahub.utilities.urns.error import InvalidUrnError
if TYPE_CHECKING:

View File

@@ -16,7 +16,6 @@ from datahub.metadata.urns import (
ContainerUrn,
Urn,
)
-from datahub.sdk._entity import Entity, ExtraAspectsType
from datahub.sdk._shared import (
DomainInputType,
HasContainer,
@@ -33,6 +32,7 @@ from datahub.sdk._shared import (
make_time_stamp,
parse_time_stamp,
)
+from datahub.sdk.entity import Entity, ExtraAspectsType
from datahub.utilities.sentinels import Auto, auto

View File

@@ -18,7 +18,6 @@ from datahub.errors import (
from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.urns import DatasetUrn, SchemaFieldUrn, Urn
from datahub.sdk._attribution import is_ingestion_attribution
-from datahub.sdk._entity import Entity, ExtraAspectsType
from datahub.sdk._shared import (
DatasetUrnOrStr,
DomainInputType,
@@ -39,6 +38,7 @@ from datahub.sdk._shared import (
parse_time_stamp,
)
from datahub.sdk._utils import add_list_unique, remove_list_unique
+from datahub.sdk.entity import Entity, ExtraAspectsType
from datahub.utilities.sentinels import Unset, unset
SchemaFieldInputType: TypeAlias = Union[

View File

@@ -56,6 +56,10 @@ class Entity:
@abc.abstractmethod
def get_urn_type(cls) -> Type[_SpecificUrn]: ...
+@classmethod
+def entity_type_name(cls) -> str:
+return cls.get_urn_type().ENTITY_TYPE
@property
def urn(self) -> _SpecificUrn:
return self._urn
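
A quick sketch of how the new helper behaves, assuming the Dataset entity from this SDK (its urn class's ENTITY_TYPE is "dataset"):

from datahub.sdk.dataset import Dataset

# entity_type_name() just forwards to the urn class's ENTITY_TYPE constant,
# which is "dataset" for DatasetUrn.
assert Dataset.entity_type_name() == "dataset"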

View File

@@ -14,10 +14,10 @@ from datahub.metadata.urns import (
Urn,
)
from datahub.sdk._all_entities import ENTITY_CLASSES
-from datahub.sdk._entity import Entity
from datahub.sdk._shared import UrnOrStr
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
+from datahub.sdk.entity import Entity
if TYPE_CHECKING:
from datahub.sdk.main_client import DataHubClient

View File

@@ -7,6 +7,7 @@ from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
from datahub.ingestion.graph.config import DatahubClientConfig
from datahub.sdk.entity_client import EntityClient
from datahub.sdk.resolver_client import ResolverClient
+from datahub.sdk.search_client import SearchClient
class DataHubClient:
@@ -39,6 +40,8 @@ class DataHubClient:
self._graph = graph
+# TODO: test connection
@classmethod
def from_env(cls) -> "DataHubClient":
"""Initialize a DataHubClient from the environment variables or ~/.datahubenv file.
@@ -69,5 +72,8 @@ class DataHubClient:
def resolve(self) -> ResolverClient:
return ResolverClient(self)
-# TODO: search client
+@property
+def search(self) -> SearchClient:
+return SearchClient(self)
# TODO: lineage client
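
The new search property gives DataHubClient callers direct access to the SearchClient added in this PR. A hedged end-to-end sketch, using the from_env() constructor shown above:

from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

# from_env() picks up connection details from env vars or ~/.datahubenv.
client = DataHubClient.from_env()

# client.search is the new SearchClient; get_urns lazily yields matching Urn objects.
for urn in client.search.get_urns(
    query="orders",
    filter=F.and_(F.entity_type("dataset"), F.platform("snowflake")),
):
    print(urn)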

View File

@@ -9,6 +9,7 @@ from datahub.metadata.urns import (
DomainUrn,
GlossaryTermUrn,
)
+from datahub.sdk.search_filters import Filter, FilterDsl as F
if TYPE_CHECKING:
from datahub.sdk.main_client import DataHubClient
@@ -38,37 +39,28 @@ class ResolverClient:
self, *, name: Optional[str] = None, email: Optional[str] = None
) -> CorpUserUrn:
filter_explanation: str
-filters = []
+filter: Filter
if name is not None:
if email is not None:
raise SdkUsageError("Cannot specify both name and email for auto_user")
-# TODO: do we filter on displayName or fullName?
+# We're filtering on both fullName and displayName. It's not clear
+# what the right behavior is here.
filter_explanation = f"with name {name}"
-filters.append(
-{
-"field": "fullName",
-"values": [name],
-"condition": "EQUAL",
-}
-)
+filter = F.or_(
+F.custom_filter("fullName", "EQUAL", [name]),
+F.custom_filter("displayName", "EQUAL", [name]),
+)
elif email is not None:
filter_explanation = f"with email {email}"
-filters.append(
-{
-"field": "email",
-"values": [email],
-"condition": "EQUAL",
-}
-)
+filter = F.custom_filter("email", "EQUAL", [email])
else:
raise SdkUsageError("Must specify either name or email for auto_user")
-users = list(
-self._graph.get_urns_by_filter(
-entity_types=[CorpUserUrn.ENTITY_TYPE],
-extraFilters=filters,
-)
-)
+filter = F.and_(
+F.entity_type(CorpUserUrn.ENTITY_TYPE),
+filter,
+)
+users = list(self._client.search.get_urns(filter=filter))
if len(users) == 0:
# TODO: In auto methods, should we just create the user/domain/etc if it doesn't exist?
raise ItemNotFoundError(f"User {filter_explanation} not found")
@@ -82,15 +74,11 @@ class ResolverClient:
def term(self, *, name: str) -> GlossaryTermUrn:
# TODO: Add some limits on the graph fetch
terms = list(
-self._graph.get_urns_by_filter(
-entity_types=[GlossaryTermUrn.ENTITY_TYPE],
-extraFilters=[
-{
-"field": "id",
-"values": [name],
-"condition": "EQUAL",
-}
-],
+self._client.search.get_urns(
+filter=F.and_(
+F.entity_type(GlossaryTermUrn.ENTITY_TYPE),
+F.custom_filter("name", "EQUAL", [name]),
+),
)
)
if len(terms) == 0:

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
)
from datahub.ingestion.graph.filters import RawSearchFilterRule
from datahub.metadata.urns import Urn
from datahub.sdk.search_filters import Filter
if TYPE_CHECKING:
from datahub.sdk.main_client import DataHubClient
def compile_filters(
filter: Optional[Filter],
) -> Optional[List[Dict[str, List[RawSearchFilterRule]]]]:
# TODO: Not every filter type is supported for every entity type.
# If we can detect issues with the filters at compile time, we should
# raise an error.
if filter is None:
return None
initial_filters = filter.compile()
return [
{"and": [rule.to_raw() for rule in andClause["and"]]}
for andClause in initial_filters
]
class SearchClient:
def __init__(self, client: DataHubClient):
self._client = client
def get_urns(
self,
query: Optional[str] = None,
filter: Optional[Filter] = None,
) -> Iterable[Urn]:
# TODO: Add better limit / pagination support.
for urn in self._client._graph.get_urns_by_filter(
query=query,
extra_or_filters=compile_filters(filter),
):
yield Urn.from_string(urn)
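
To make the wire format concrete: compile_filters turns a Filter tree into the extra_or_filters structure that get_urns_by_filter already understands. A sketch of what it produces for a simple platform filter, mirroring the unit tests later in this diff:

from datahub.sdk.search_client import compile_filters
from datahub.sdk.search_filters import FilterDsl as F

# A single platform filter compiles to one OR clause containing one raw rule.
compiled = compile_filters(F.platform("snowflake"))
assert compiled == [
    {
        "and": [
            {
                "field": "platform.keyword",
                "condition": "EQUAL",
                "values": ["urn:li:dataPlatform:snowflake"],
                "negated": False,
            }
        ]
    }
]

# With no filter, there is nothing to pass through to the graph call.
assert compile_filters(None) is None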

View File

@@ -0,0 +1,374 @@
from __future__ import annotations
import abc
from typing import (
Any,
List,
Sequence,
TypedDict,
Union,
)
import pydantic
from datahub.configuration.common import ConfigModel
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.ingestion.graph.client import entity_type_to_graphql
from datahub.ingestion.graph.filters import SearchFilterRule
from datahub.metadata.schema_classes import EntityTypeName
from datahub.metadata.urns import DataPlatformUrn, DomainUrn
_AndSearchFilterRule = TypedDict(
"_AndSearchFilterRule", {"and": List[SearchFilterRule]}
)
_OrFilters = List[_AndSearchFilterRule]
class _BaseFilter(ConfigModel):
class Config:
# We can't wrap this in a TYPE_CHECKING block because the pydantic plugin
# doesn't recognize it properly. So unfortunately we'll need to live
# with the deprecation warning w/ pydantic v2.
allow_population_by_field_name = True
if PYDANTIC_VERSION_2:
populate_by_name = True
@abc.abstractmethod
def compile(self) -> _OrFilters:
pass
def _flexible_entity_type_to_graphql(entity_type: str) -> str:
if entity_type.upper() == entity_type:
# Assume that we were passed a graphql EntityType enum value,
# so no conversion is needed.
return entity_type
return entity_type_to_graphql(entity_type)
class _EntityTypeFilter(_BaseFilter):
entity_type: List[str] = pydantic.Field(
description="The entity type to filter on. Can be 'dataset', 'chart', 'dashboard', 'corpuser', etc.",
)
def _build_rule(self) -> SearchFilterRule:
return SearchFilterRule(
field="_entityType",
condition="EQUAL",
values=[_flexible_entity_type_to_graphql(t) for t in self.entity_type],
)
def compile(self) -> _OrFilters:
return [{"and": [self._build_rule()]}]
class _EntitySubtypeFilter(_BaseFilter):
entity_type: str
entity_subtype: str = pydantic.Field(
description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
)
def compile(self) -> _OrFilters:
rules = [
SearchFilterRule(
field="_entityType",
condition="EQUAL",
values=[_flexible_entity_type_to_graphql(self.entity_type)],
),
SearchFilterRule(
field="typeNames",
condition="EQUAL",
values=[self.entity_subtype],
),
]
return [{"and": rules}]
class _PlatformFilter(_BaseFilter):
platform: List[str]
# TODO: Add validator to convert string -> list of strings
@pydantic.validator("platform", each_item=True)
def validate_platform(cls, v: str) -> str:
# Subtle - we use the constructor instead of the from_string method
# because coercion is acceptable here.
return str(DataPlatformUrn(v))
def _build_rule(self) -> SearchFilterRule:
return SearchFilterRule(
field="platform.keyword",
condition="EQUAL",
values=self.platform,
)
def compile(self) -> _OrFilters:
return [{"and": [self._build_rule()]}]
class _DomainFilter(_BaseFilter):
domain: List[str]
@pydantic.validator("domain", each_item=True)
def validate_domain(cls, v: str) -> str:
return str(DomainUrn.from_string(v))
def _build_rule(self) -> SearchFilterRule:
return SearchFilterRule(
field="domains",
condition="EQUAL",
values=self.domain,
)
def compile(self) -> _OrFilters:
return [{"and": [self._build_rule()]}]
class _EnvFilter(_BaseFilter):
# Note that not all entity types have an env (e.g. dashboards / charts).
# If the env filter is specified, these will be excluded.
env: List[str]
def compile(self) -> _OrFilters:
return [
# For most entity types, we look at the origin field.
{
"and": [
SearchFilterRule(
field="origin",
condition="EQUAL",
values=self.env,
),
]
},
# For containers, we now have an "env" property as of
# https://github.com/datahub-project/datahub/pull/11214
# Prior to this, we put "env" in the customProperties. But we're
# not bothering with that here.
{
"and": [
SearchFilterRule(
field="env",
condition="EQUAL",
values=self.env,
),
]
},
]
class _CustomCondition(_BaseFilter):
"""Represents a single field condition"""
field: str
condition: str
values: List[str]
def compile(self) -> _OrFilters:
rule = SearchFilterRule(
field=self.field,
condition=self.condition,
values=self.values,
)
return [{"and": [rule]}]
class _And(_BaseFilter):
"""Represents an AND conjunction of filters"""
and_: Sequence["Filter"] = pydantic.Field(alias="and")
# TODO: Add validator to ensure that the "and" field is not empty
def compile(self) -> _OrFilters:
# The "and" operator must be implemented by doing a Cartesian product
# of the OR clauses.
# Example 1:
# (A or B) and (C or D) ->
# (A and C) or (A and D) or (B and C) or (B and D)
# Example 2:
# (A or B) and (C or D) and (E or F) ->
# (A and C and E) or (A and C and F) or (A and D and E) or (A and D and F) or
# (B and C and E) or (B and C and F) or (B and D and E) or (B and D and F)
# Start with the first filter's OR clauses
result = self.and_[0].compile()
# For each subsequent filter
for filter in self.and_[1:]:
new_result = []
# Get its OR clauses
other_clauses = filter.compile()
# Create Cartesian product
for existing_clause in result:
for other_clause in other_clauses:
# Merge the AND conditions from both clauses
new_result.append(self._merge_ands(existing_clause, other_clause))
result = new_result
return result
@classmethod
def _merge_ands(
cls, a: _AndSearchFilterRule, b: _AndSearchFilterRule
) -> _AndSearchFilterRule:
return {
"and": [
*a["and"],
*b["and"],
]
}
class _Or(_BaseFilter):
"""Represents an OR conjunction of filters"""
or_: Sequence["Filter"] = pydantic.Field(alias="or")
# TODO: Add validator to ensure that the "or" field is not empty
def compile(self) -> _OrFilters:
merged_filter = []
for filter in self.or_:
merged_filter.extend(filter.compile())
return merged_filter
class _Not(_BaseFilter):
"""Represents a NOT filter"""
not_: "Filter" = pydantic.Field(alias="not")
@pydantic.validator("not_", pre=False)
def validate_not(cls, v: "Filter") -> "Filter":
inner_filter = v.compile()
if len(inner_filter) != 1:
raise ValueError(
"Cannot negate a filter with multiple OR clauses [not yet supported]"
)
return v
def compile(self) -> _OrFilters:
# TODO: Eventually we'll want to implement a full DNF normalizer.
# https://en.wikipedia.org/wiki/Disjunctive_normal_form#Conversion_to_DNF
inner_filter = self.not_.compile()
assert len(inner_filter) == 1 # validated above
# ¬(A and B) -> (¬A) OR (¬B)
and_filters = inner_filter[0]["and"]
final_filters: _OrFilters = []
for rule in and_filters:
final_filters.append({"and": [rule.negate()]})
return final_filters
# TODO: With pydantic 2, we can use a RootModel with a
# discriminated union to make the error messages more informative.
Filter = Union[
_And,
_Or,
_Not,
_EntityTypeFilter,
_EntitySubtypeFilter,
_PlatformFilter,
_DomainFilter,
_EnvFilter,
_CustomCondition,
]
# Required to resolve forward references to "Filter"
if PYDANTIC_VERSION_2:
_And.model_rebuild() # type: ignore
_Or.model_rebuild() # type: ignore
_Not.model_rebuild() # type: ignore
else:
_And.update_forward_refs()
_Or.update_forward_refs()
_Not.update_forward_refs()
def load_filters(obj: Any) -> Filter:
if PYDANTIC_VERSION_2:
return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore
else:
return pydantic.parse_obj_as(Filter, obj) # type: ignore
# We need FilterDsl for two reasons:
# 1. To provide wrapper methods around lots of filters while avoid bloating the
# yaml spec.
# 2. Pydantic models in general don't support positional arguments, making the
# calls feel repetitive (e.g. Platform(platform=...)).
# See https://github.com/pydantic/pydantic/issues/6792
# We also considered using dataclasses / pydantic dataclasses, but
# ultimately decided that they didn't quite suit our requirements,
# particularly with regards to the field aliases for and/or/not.
class FilterDsl:
@staticmethod
def and_(*args: "Filter") -> _And:
return _And(and_=list(args))
@staticmethod
def or_(*args: "Filter") -> _Or:
return _Or(or_=list(args))
@staticmethod
def not_(arg: "Filter") -> _Not:
return _Not(not_=arg)
@staticmethod
def entity_type(
entity_type: Union[EntityTypeName, Sequence[EntityTypeName]],
) -> _EntityTypeFilter:
return _EntityTypeFilter(
entity_type=(
[entity_type] if isinstance(entity_type, str) else list(entity_type)
)
)
@staticmethod
def entity_subtype(entity_type: str, subtype: str) -> _EntitySubtypeFilter:
return _EntitySubtypeFilter(
entity_type=entity_type,
entity_subtype=subtype,
)
@staticmethod
def platform(platform: Union[str, List[str]], /) -> _PlatformFilter:
return _PlatformFilter(
platform=[platform] if isinstance(platform, str) else platform
)
# TODO: Add a platform_instance filter
@staticmethod
def domain(domain: Union[str, List[str]], /) -> _DomainFilter:
return _DomainFilter(domain=[domain] if isinstance(domain, str) else domain)
@staticmethod
def env(env: Union[str, List[str]], /) -> _EnvFilter:
return _EnvFilter(env=[env] if isinstance(env, str) else env)
@staticmethod
def has_custom_property(key: str, value: str) -> _CustomCondition:
return _CustomCondition(
field="customProperties",
condition="EQUAL",
values=[f"{key}={value}"],
)
# TODO: Add a soft-deletion status filter
# TODO: add a container / browse path filter
# TODO add shortcut for custom filters
@staticmethod
def custom_filter(
field: str, condition: str, values: List[str]
) -> _CustomCondition:
return _CustomCondition(
field=field,
condition=condition,
values=values,
)
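
The compile() step on _And is where the DNF expansion described in the comments above happens: the env filter already compiles to two OR clauses (origin vs. env), so AND-ing it with anything multiplies the clause count. A short sketch tying FilterDsl, load_filters, and compile() together, with illustrative values:

from datahub.sdk.search_filters import FilterDsl as F, load_filters

# The dict/YAML form and the DSL form build the same filter tree.
filter_obj = load_filters({"and": [{"env": ["PROD"]}, {"platform": ["snowflake"]}]})
assert filter_obj == F.and_(F.env("PROD"), F.platform("snowflake"))

# env compiles to two OR clauses (the "origin" and "env" fields), so AND-ing it
# with a single platform rule yields 2 x 1 = 2 OR clauses after the Cartesian product.
compiled = filter_obj.compile()
assert len(compiled) == 2
assert all(len(clause["and"]) == 2 for clause in compiled)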

View File

@@ -1,6 +1,6 @@
import pathlib
-from datahub.sdk._entity import Entity
+from datahub.sdk.entity import Entity
from tests.test_helpers import mce_helpers

View File

@@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
from datahub.ingestion.graph.client import (
DatahubClientConfig,
DataHubGraph,
-_graphql_entity_type,
+entity_type_to_graphql,
)
from datahub.metadata.schema_classes import CorpUserEditableInfoClass
@@ -26,20 +26,22 @@ def test_get_aspect(mock_test_connection):
assert editable is not None
-def test_graphql_entity_types():
+def test_graphql_entity_types() -> None:
# FIXME: This is a subset of all the types, but it's enough to get us ok coverage.
-assert _graphql_entity_type("domain") == "DOMAIN"
-assert _graphql_entity_type("dataset") == "DATASET"
-assert _graphql_entity_type("dashboard") == "DASHBOARD"
-assert _graphql_entity_type("chart") == "CHART"
-assert _graphql_entity_type("corpuser") == "CORP_USER"
-assert _graphql_entity_type("corpGroup") == "CORP_GROUP"
-assert _graphql_entity_type("dataFlow") == "DATA_FLOW"
-assert _graphql_entity_type("dataJob") == "DATA_JOB"
-assert _graphql_entity_type("glossaryNode") == "GLOSSARY_NODE"
-assert _graphql_entity_type("glossaryTerm") == "GLOSSARY_TERM"
-assert _graphql_entity_type("dataHubExecutionRequest") == "EXECUTION_REQUEST"
+known_mappings = {
+"domain": "DOMAIN",
+"dataset": "DATASET",
+"dashboard": "DASHBOARD",
+"chart": "CHART",
+"corpuser": "CORP_USER",
+"corpGroup": "CORP_GROUP",
+"dataFlow": "DATA_FLOW",
+"dataJob": "DATA_JOB",
+"glossaryNode": "GLOSSARY_NODE",
+"glossaryTerm": "GLOSSARY_TERM",
+"dataHubExecutionRequest": "EXECUTION_REQUEST",
+}
+for entity_type, graphql_type in known_mappings.items():
+assert entity_type_to_graphql(entity_type) == graphql_type

View File

@@ -0,0 +1,214 @@
from io import StringIO
import pytest
import yaml
from pydantic import ValidationError
from datahub.ingestion.graph.filters import SearchFilterRule
from datahub.sdk.search_client import compile_filters
from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters
from datahub.utilities.urns.error import InvalidUrnError
def test_filters_simple() -> None:
yaml_dict = {"platform": ["snowflake", "bigquery"]}
filter_obj: Filter = load_filters(yaml_dict)
assert filter_obj == F.platform(["snowflake", "bigquery"])
assert filter_obj.compile() == [
{
"and": [
SearchFilterRule(
field="platform.keyword",
condition="EQUAL",
values=[
"urn:li:dataPlatform:snowflake",
"urn:li:dataPlatform:bigquery",
],
)
]
}
]
def test_filters_and() -> None:
yaml_dict = {
"and": [
{"env": ["PROD"]},
{"platform": ["snowflake", "bigquery"]},
]
}
filter_obj: Filter = load_filters(yaml_dict)
assert filter_obj == F.and_(
F.env("PROD"),
F.platform(["snowflake", "bigquery"]),
)
platform_rule = SearchFilterRule(
field="platform.keyword",
condition="EQUAL",
values=[
"urn:li:dataPlatform:snowflake",
"urn:li:dataPlatform:bigquery",
],
)
assert filter_obj.compile() == [
{
"and": [
SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
platform_rule,
]
},
{
"and": [
SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
platform_rule,
]
},
]
def test_filters_complex() -> None:
yaml_dict = yaml.safe_load(
StringIO("""\
and:
- env: [PROD]
- or:
- platform: [ snowflake, bigquery ]
- and:
- platform: [postgres]
- not:
domain: [urn:li:domain:analytics]
- field: customProperties
condition: EQUAL
values: ["dbt_unique_id=source.project.name"]
""")
)
filter_obj: Filter = load_filters(yaml_dict)
assert filter_obj == F.and_(
F.env("PROD"),
F.or_(
F.platform(["snowflake", "bigquery"]),
F.and_(
F.platform("postgres"),
F.not_(F.domain("urn:li:domain:analytics")),
),
F.has_custom_property("dbt_unique_id", "source.project.name"),
),
)
warehouse_rule = SearchFilterRule(
field="platform.keyword",
condition="EQUAL",
values=["urn:li:dataPlatform:snowflake", "urn:li:dataPlatform:bigquery"],
)
postgres_rule = SearchFilterRule(
field="platform.keyword",
condition="EQUAL",
values=["urn:li:dataPlatform:postgres"],
)
domain_rule = SearchFilterRule(
field="domains",
condition="EQUAL",
values=["urn:li:domain:analytics"],
negated=True,
)
custom_property_rule = SearchFilterRule(
field="customProperties",
condition="EQUAL",
values=["dbt_unique_id=source.project.name"],
)
# There's one OR clause in the original filter with 3 clauses,
# and one hidden in the env filter with 2 clauses.
# The final result should have 3 * 2 = 6 OR clauses.
assert filter_obj.compile() == [
{
"and": [
SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
warehouse_rule,
],
},
{
"and": [
SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
postgres_rule,
domain_rule,
],
},
{
"and": [
SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
custom_property_rule,
],
},
{
"and": [
SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
warehouse_rule,
],
},
{
"and": [
SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
postgres_rule,
domain_rule,
],
},
{
"and": [
SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
custom_property_rule,
],
},
]
def test_invalid_filter() -> None:
with pytest.raises(InvalidUrnError):
F.domain("marketing")
def test_unsupported_not() -> None:
env_filter = F.env("PROD")
with pytest.raises(
ValidationError,
match="Cannot negate a filter with multiple OR clauses",
):
F.not_(env_filter)
def test_compile_filters() -> None:
filter = F.and_(F.env("PROD"), F.platform("snowflake"))
expected_filters = [
{
"and": [
{
"field": "origin",
"condition": "EQUAL",
"values": ["PROD"],
"negated": False,
},
{
"field": "platform.keyword",
"condition": "EQUAL",
"values": ["urn:li:dataPlatform:snowflake"],
"negated": False,
},
]
},
{
"and": [
{
"field": "env",
"condition": "EQUAL",
"values": ["PROD"],
"negated": False,
},
{
"field": "platform.keyword",
"condition": "EQUAL",
"values": ["urn:li:dataPlatform:snowflake"],
"negated": False,
},
]
},
]
assert compile_filters(filter) == expected_filters