import contextlib
import enum
import functools
import json
import logging
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from json.decoder import JSONDecodeError
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

from avro.schema import RecordSchema
from deprecated import deprecated
from pydantic import BaseModel
from requests.models import HTTPError

from datahub.cli.cli_utils import get_url_and_token
from datahub.configuration.common import ConfigModel, GraphError, OperationalError
from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP
from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.serialization_helper import post_json_transform
from datahub.ingestion.graph.connections import (
    connections_gql,
    get_id_from_connection_urn,
)
from datahub.ingestion.graph.filters import (
    RemovedStatusFilter,
    SearchFilterRule,
    generate_filter,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
    MetadataChangeEvent,
    MetadataChangeProposal,
)
from datahub.metadata.schema_classes import (
    ASPECT_NAME_MAP,
    KEY_ASPECTS,
    AspectBag,
    BrowsePathsClass,
    DatasetPropertiesClass,
    DatasetUsageStatisticsClass,
    DomainPropertiesClass,
    DomainsClass,
    GlobalTagsClass,
    GlossaryTermsClass,
    OwnershipClass,
    SchemaMetadataClass,
    StatusClass,
    SystemMetadataClass,
    TelemetryClientIdClass,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.urns.urn import Urn, guess_entity_type

if TYPE_CHECKING:
    from datahub.ingestion.sink.datahub_rest import (
        DatahubRestSink,
        DatahubRestSinkConfig,
    )
    from datahub.ingestion.source.state.entity_removal_state import (
        GenericCheckpointState,
    )
    from datahub.sql_parsing.schema_resolver import (
        GraphQLSchemaMetadata,
        SchemaResolver,
    )
    from datahub.sql_parsing.sqlglot_lineage import SqlParsingResult


logger = logging.getLogger(__name__)

_MISSING_SERVER_ID = "missing"
_GRAPH_DUMMY_RUN_ID = "__datahub-graph-client"


class DatahubClientConfig(ConfigModel):
    """Configuration class for holding connectivity to a DataHub GMS instance."""

    server: str = "http://localhost:8080"
    token: Optional[str] = None
    timeout_sec: Optional[int] = None
    retry_status_codes: Optional[List[int]] = None
    retry_max_times: Optional[int] = None
    extra_headers: Optional[Dict[str, str]] = None
    ca_certificate_path: Optional[str] = None
    client_certificate_path: Optional[str] = None
    disable_ssl_verification: bool = False


# Alias for backwards compatibility.
# DEPRECATION: Remove in v0.10.2.
DataHubGraphConfig = DatahubClientConfig


@dataclass
class RelatedEntity:
    urn: str
    relationship_type: str
    via: Optional[str] = None


def _graphql_entity_type(entity_type: str) -> str:
    """Convert an entity type name into the GraphQL "EntityType" enum value."""

    # Hard-coded special cases.
    if entity_type == "corpuser":
        return "CORP_USER"

    # Convert camelCase to UPPER_UNDERSCORE.
    entity_type = (
        "".join(["_" + c.lower() if c.isupper() else c for c in entity_type])
        .lstrip("_")
        .upper()
    )

    # Strip the "DATA_HUB_" prefix.
    if entity_type.startswith("DATA_HUB_"):
        entity_type = entity_type[len("DATA_HUB_") :]

    return entity_type
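
# Sanity check of the conversion above (a hedged illustration, derived from the
# camelCase -> UPPER_UNDERSCORE logic itself, not from an exhaustive GMS enum list):
#
#   _graphql_entity_type("dataset")       -> "DATASET"
#   _graphql_entity_type("dataFlow")      -> "DATA_FLOW"
#   _graphql_entity_type("corpuser")      -> "CORP_USER"  (special case)
#   _graphql_entity_type("dataHubPolicy") -> "POLICY"     ("DATA_HUB_" prefix stripped)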


class DataHubGraph(DatahubRestEmitter):
    def __init__(self, config: DatahubClientConfig) -> None:
        self.config = config
        super().__init__(
            gms_server=self.config.server,
            token=self.config.token,
            connect_timeout_sec=self.config.timeout_sec,  # reuse timeout_sec for connect timeout
            read_timeout_sec=self.config.timeout_sec,
            retry_status_codes=self.config.retry_status_codes,
            retry_max_times=self.config.retry_max_times,
            extra_headers=self.config.extra_headers,
            ca_certificate_path=self.config.ca_certificate_path,
            client_certificate_path=self.config.client_certificate_path,
            disable_ssl_verification=self.config.disable_ssl_verification,
        )

        self.server_id = _MISSING_SERVER_ID

    def test_connection(self) -> None:
        super().test_connection()

        # Cache the server id for telemetry.
        from datahub.telemetry.telemetry import telemetry_instance

        if not telemetry_instance.enabled:
            self.server_id = _MISSING_SERVER_ID
            return
        try:
            client_id: Optional[TelemetryClientIdClass] = self.get_aspect(
                "urn:li:telemetry:clientId", TelemetryClientIdClass
            )
            self.server_id = client_id.clientId if client_id else _MISSING_SERVER_ID
        except Exception as e:
            self.server_id = _MISSING_SERVER_ID
            logger.debug(f"Failed to get server id due to {e}")

    @classmethod
    def from_emitter(cls, emitter: DatahubRestEmitter) -> "DataHubGraph":
        return cls(
            DatahubClientConfig(
                server=emitter._gms_server,
                token=emitter._token,
                timeout_sec=emitter._read_timeout_sec,
                retry_status_codes=emitter._retry_status_codes,
                retry_max_times=emitter._retry_max_times,
                extra_headers=emitter._session.headers,
                disable_ssl_verification=emitter._session.verify is False,
                # TODO: Support these headers.
                # ca_certificate_path=emitter._ca_certificate_path,
                # client_certificate_path=emitter._client_certificate_path,
            )
        )

    def _send_restli_request(self, method: str, url: str, **kwargs: Any) -> Dict:
        try:
            response = self._session.request(method, url, **kwargs)
            response.raise_for_status()
            return response.json()
        except HTTPError as e:
            try:
                info = response.json()
                raise OperationalError(
                    "Unable to get metadata from DataHub", info
                ) from e
            except JSONDecodeError:
                # If we can't parse the JSON, just raise the original error.
                raise OperationalError(
                    "Unable to get metadata from DataHub", {"message": str(e)}
                ) from e

    def _get_generic(self, url: str, params: Optional[Dict] = None) -> Dict:
        return self._send_restli_request("GET", url, params=params)

    def _post_generic(self, url: str, payload_dict: Dict) -> Dict:
        return self._send_restli_request("POST", url, json=payload_dict)

    def _make_rest_sink_config(self) -> "DatahubRestSinkConfig":
        from datahub.ingestion.sink.datahub_rest import (
            DatahubRestSinkConfig,
            RestSinkMode,
        )

        # This is a bit convoluted - this DataHubGraph class is a subclass of DatahubRestEmitter,
        # but initializing the rest sink creates another rest emitter.
        # TODO: We should refactor out the multithreading functionality of the sink
        # into a separate class that can be used by both the sink and the graph client
        # e.g. a DatahubBulkRestEmitter that both the sink and the graph client use.
        return DatahubRestSinkConfig(**self.config.dict(), mode=RestSinkMode.ASYNC)

    @contextlib.contextmanager
    def make_rest_sink(
        self, run_id: str = _GRAPH_DUMMY_RUN_ID
    ) -> Iterator["DatahubRestSink"]:
        from datahub.ingestion.api.common import PipelineContext
        from datahub.ingestion.sink.datahub_rest import DatahubRestSink

        sink_config = self._make_rest_sink_config()
        with DatahubRestSink(PipelineContext(run_id=run_id), sink_config) as sink:
            yield sink
        if sink.report.failures:
            raise OperationalError(
                f"Failed to emit {len(sink.report.failures)} records",
                info=sink.report.as_obj(),
            )

    def emit_all(
        self,
        items: Iterable[
            Union[
                MetadataChangeEvent,
                MetadataChangeProposal,
                MetadataChangeProposalWrapper,
            ]
        ],
        run_id: str = _GRAPH_DUMMY_RUN_ID,
    ) -> None:
        """Emit all items in the iterable using multiple threads."""

        # The context manager also ensures that we raise an error if a failure occurs.
        with self.make_rest_sink(run_id=run_id) as sink:
            for item in items:
                sink.emit_async(item)
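
    # Usage sketch (not executed here; assumes a connected DataHubGraph named
    # `graph`, and the urn is illustrative): bulk-emit a batch of MCPs through
    # the threaded rest sink.
    #
    #   mcps = [
    #       MetadataChangeProposalWrapper(
    #           entityUrn="urn:li:dataset:(urn:li:dataPlatform:hive,db.tbl,PROD)",
    #           aspect=StatusClass(removed=False),
    #       ),
    #   ]
    #   graph.emit_all(mcps)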

    def get_aspect(
        self,
        entity_urn: str,
        aspect_type: Type[Aspect],
        version: int = 0,
    ) -> Optional[Aspect]:
        """
        Get an aspect for an entity.

        :param entity_urn: The urn of the entity
        :param aspect_type: The type class of the aspect being requested (e.g. datahub.metadata.schema_classes.DatasetProperties)
        :param version: The version of the aspect to retrieve. The default of 0 means latest. Versions > 0 go from oldest to newest, so 1 is the oldest.
        :return: the typed aspect object if present, or None if no aspect was found (HTTP status 404)

        :raises TypeError: if the aspect type is a timeseries aspect
        :raises HTTPError: if the HTTP response is not a 200 or a 404
        """

        aspect = aspect_type.ASPECT_NAME
        if aspect in TIMESERIES_ASPECT_MAP:
            raise TypeError(
                'Cannot get a timeseries aspect using "get_aspect". Use "get_latest_timeseries_value" instead.'
            )

        url: str = f"{self._gms_server}/aspects/{Urn.url_encode(entity_urn)}?aspect={aspect}&version={version}"
        response = self._session.get(url)
        if response.status_code == 404:
            # not found
            return None
        response.raise_for_status()
        response_json = response.json()

        # Figure out what field to look in.
        record_schema: RecordSchema = aspect_type.RECORD_SCHEMA
        aspect_type_name = record_schema.fullname.replace(".pegasus2avro", "")

        # Deserialize the aspect json into the aspect type.
        aspect_json = response_json.get("aspect", {}).get(aspect_type_name)
        if aspect_json is not None:
            # need to apply a transform to the response to match rest.li and avro serialization
            post_json_obj = post_json_transform(aspect_json)
            return aspect_type.from_obj(post_json_obj)
        else:
            raise GraphError(
                f"Failed to find {aspect_type_name} in response {response_json}"
            )
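
    # Usage sketch (illustrative urn): fetch the latest datasetProperties
    # aspect; a None result means the entity has no such aspect (HTTP 404).
    #
    #   props = graph.get_aspect(
    #       "urn:li:dataset:(urn:li:dataPlatform:hive,db.tbl,PROD)",
    #       DatasetPropertiesClass,
    #   )
    #   description = props.description if props else None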

    @deprecated(reason="Use get_aspect instead, which makes the aspect string name optional")
    def get_aspect_v2(
        self,
        entity_urn: str,
        aspect_type: Type[Aspect],
        aspect: str,
        aspect_type_name: Optional[str] = None,
        version: int = 0,
    ) -> Optional[Aspect]:
        assert aspect_type.ASPECT_NAME == aspect
        return self.get_aspect(
            entity_urn=entity_urn,
            aspect_type=aspect_type,
            version=version,
        )

    def get_config(self) -> Dict[str, Any]:
        return self._get_generic(f"{self.config.server}/config")

    def get_ownership(self, entity_urn: str) -> Optional[OwnershipClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=OwnershipClass)

    def get_schema_metadata(self, entity_urn: str) -> Optional[SchemaMetadataClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=SchemaMetadataClass)

    @deprecated(reason="Use get_aspect directly.")
    def get_domain_properties(self, entity_urn: str) -> Optional[DomainPropertiesClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=DomainPropertiesClass)

    def get_dataset_properties(
        self, entity_urn: str
    ) -> Optional[DatasetPropertiesClass]:
        return self.get_aspect(
            entity_urn=entity_urn, aspect_type=DatasetPropertiesClass
        )

    def get_tags(self, entity_urn: str) -> Optional[GlobalTagsClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=GlobalTagsClass)

    def get_glossary_terms(self, entity_urn: str) -> Optional[GlossaryTermsClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=GlossaryTermsClass)

    def get_domain(self, entity_urn: str) -> Optional[DomainsClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=DomainsClass)

    @deprecated(reason="Use get_aspect directly.")
    def get_browse_path(self, entity_urn: str) -> Optional[BrowsePathsClass]:
        return self.get_aspect(entity_urn=entity_urn, aspect_type=BrowsePathsClass)

    def get_usage_aspects_from_urn(
        self, entity_urn: str, start_timestamp: int, end_timestamp: int
    ) -> Optional[List[DatasetUsageStatisticsClass]]:
        payload = {
            "urn": entity_urn,
            "entity": "dataset",
            "aspect": "datasetUsageStatistics",
            "startTimeMillis": start_timestamp,
            "endTimeMillis": end_timestamp,
        }
        headers: Dict[str, Any] = {}
        url = f"{self._gms_server}/aspects?action=getTimeseriesAspectValues"
        try:
            usage_aspects: List[DatasetUsageStatisticsClass] = []
            response = self._session.post(
                url, data=json.dumps(payload), headers=headers
            )
            if response.status_code != 200:
                logger.debug(
                    f"Non 200 status found while fetching usage aspects - {response.status_code}"
                )
                return None
            json_resp = response.json()
            all_aspects = json_resp.get("value", {}).get("values", [])
            for aspect in all_aspects:
                if aspect.get("aspect") and aspect.get("aspect").get("value"):
                    usage_aspects.append(
                        DatasetUsageStatisticsClass.from_obj(
                            json.loads(aspect.get("aspect").get("value")), tuples=True
                        )
                    )
            return usage_aspects
        except Exception as e:
            logger.error("Error while getting usage aspects.", exc_info=e)
            return None

    def list_all_entity_urns(
        self, entity_type: str, start: int, count: int
    ) -> Optional[List[str]]:
        url = f"{self._gms_server}/entities?action=listUrns"
        payload = {"entity": entity_type, "start": start, "count": count}
        headers = {
            "X-RestLi-Protocol-Version": "2.0.0",
            "Content-Type": "application/json",
        }
        try:
            response = self._session.post(
                url, data=json.dumps(payload), headers=headers
            )
            if response.status_code != 200:
                logger.debug(
                    f"Non 200 status found while fetching entity urns - {response.status_code}"
                )
                return None
            json_resp = response.json()
            return json_resp.get("value", {}).get("entities")
        except Exception as e:
            logger.error("Error while fetching entity urns.", exc_info=e)
            return None

    def get_latest_timeseries_value(
        self,
        entity_urn: str,
        aspect_type: Type[Aspect],
        filter_criteria_map: Dict[str, str],
    ) -> Optional[Aspect]:
        filter_criteria = [
            {"field": k, "value": v, "condition": "EQUAL"}
            for k, v in filter_criteria_map.items()
        ]
        filter = {"or": [{"and": filter_criteria}]}

        values = self.get_timeseries_values(
            entity_urn=entity_urn, aspect_type=aspect_type, filter=filter, limit=1
        )
        if not values:
            return None

        assert len(values) == 1, len(values)
        return values[0]
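
    # Usage sketch: each map entry becomes one EQUAL criterion, and the
    # criteria are ANDed together against the timeseries index. Field names
    # depend on the aspect; the checkpoint lookup below mirrors how ingestion
    # checkpoints are fetched and is illustrative only (job_urn and
    # pipeline_name are assumed variables).
    #
    #   from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
    #
    #   checkpoint = graph.get_latest_timeseries_value(
    #       entity_urn=job_urn,
    #       aspect_type=DatahubIngestionCheckpointClass,
    #       filter_criteria_map={"pipelineName": pipeline_name},
    #   )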

    def get_timeseries_values(
        self,
        entity_urn: str,
        aspect_type: Type[Aspect],
        filter: Dict[str, Any],
        limit: int = 10,
    ) -> List[Aspect]:
        query_body = {
            "urn": entity_urn,
            "entity": guess_entity_type(entity_urn),
            "aspect": aspect_type.ASPECT_NAME,
            "limit": limit,
            "filter": filter,
        }
        end_point = f"{self.config.server}/aspects?action=getTimeseriesAspectValues"
        resp: Dict = self._post_generic(end_point, query_body)

        values: Optional[List] = resp.get("value", {}).get("values")
        aspects: List[Aspect] = []
        for value in values or []:
            aspect_json: str = value.get("aspect", {}).get("value")
            if aspect_json:
                aspects.append(
                    aspect_type.from_obj(json.loads(aspect_json), tuples=False)
                )
            else:
                raise GraphError(
                    f"Failed to find {aspect_type} in response {aspect_json}"
                )
        return aspects

    def get_entity_raw(
        self, entity_urn: str, aspects: Optional[List[str]] = None
    ) -> Dict:
        endpoint: str = f"{self.config.server}/entitiesV2/{Urn.url_encode(entity_urn)}"
        if aspects is not None:
            assert aspects, "if provided, aspects must be a non-empty list"
            endpoint = f"{endpoint}?aspects=List(" + ",".join(aspects) + ")"

        response = self._session.get(endpoint)
        response.raise_for_status()
        return response.json()

    @deprecated(
        reason="Use get_aspect for a single aspect or get_entity_semityped for a full entity."
    )
    def get_aspects_for_entity(
        self,
        entity_urn: str,
        aspects: List[str],
        aspect_types: List[Type[Aspect]],
    ) -> Dict[str, Optional[Aspect]]:
        """
        Get multiple aspects for an entity.

        Deprecated in favor of `get_aspect` (single aspect) or `get_entity_semityped` (full
        entity without manually specifying a list of aspects).

        Warning: Do not use this method to determine if an entity exists!
        This method will always return an entity, even if it doesn't exist. This is an issue with how DataHub server
        responds to these calls, and will be fixed automatically when the server-side issue is fixed.

        :param str entity_urn: The urn of the entity
        :param List[str] aspects: List of aspect names being requested (e.g. [schemaMetadata, datasetProperties])
        :param List[Type[Aspect]] aspect_types: List of aspect type classes being requested (e.g. [datahub.metadata.schema_classes.DatasetProperties])
        :return: A map of aspect_name to deserialized aspect value; the value will be None if that aspect was not found
        :raises HTTPError: if the HTTP response is not a 200
        """
        assert len(aspects) == len(
            aspect_types
        ), f"number of aspects requested ({len(aspects)}) should be the same as number of aspect types provided ({len(aspect_types)})"

        # TODO: generate aspects list from type classes
        response_json = self.get_entity_raw(entity_urn, aspects)

        result: Dict[str, Optional[Aspect]] = {}
        for aspect_type in aspect_types:
            aspect_type_name = aspect_type.get_aspect_name()

            aspect_json = response_json.get("aspects", {}).get(aspect_type_name)
            if aspect_json:
                # need to apply a transform to the response to match rest.li and avro serialization
                post_json_obj = post_json_transform(aspect_json)
                result[aspect_type_name] = aspect_type.from_obj(post_json_obj["value"])
            else:
                result[aspect_type_name] = None

        return result

    def get_entity_semityped(self, entity_urn: str) -> AspectBag:
        """Get all non-timeseries aspects for an entity (experimental).

        This method is called "semityped" because it returns aspects as a dictionary of
        properly typed objects. While the returned dictionary is constrained using a TypedDict,
        the return type is still fairly loose.

        Warning: Do not use this method to determine if an entity exists! This method will always return
        something, even if the entity doesn't actually exist in DataHub.

        :param entity_urn: The urn of the entity
        :returns: A dictionary of aspect name to aspect value. If an aspect is not found, it will
            not be present in the dictionary. The entity's key aspect will always be present.
        """

        response_json = self.get_entity_raw(entity_urn)

        # Now, we parse the response into proper aspect objects.
        result: AspectBag = {}
        for aspect_name, aspect_json in response_json.get("aspects", {}).items():
            aspect_type = ASPECT_NAME_MAP.get(aspect_name)
            if aspect_type is None:
                logger.warning(f"Ignoring unknown aspect type {aspect_name}")
                continue

            post_json_obj = post_json_transform(aspect_json)
            aspect_value = aspect_type.from_obj(post_json_obj["value"])
            result[aspect_name] = aspect_value  # type: ignore

        return result
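
    # Usage sketch: iterate over whatever aspects the entity happens to have.
    # An isinstance check is the usual way to narrow the loose AspectBag.
    #
    #   bag = graph.get_entity_semityped(dataset_urn)
    #   ownership = bag.get("ownership")
    #   if isinstance(ownership, OwnershipClass):
    #       print([o.owner for o in ownership.owners])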

    @property
    def _search_endpoint(self):
        return f"{self.config.server}/entities?action=search"

    @property
    def _relationships_endpoint(self):
        return f"{self.config.server}/openapi/relationships/v1/"

    @property
    def _aspect_count_endpoint(self):
        return f"{self.config.server}/aspects?action=getCount"

    def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]:
        """Retrieve a domain urn based on its name. Returns None if no match is found."""

        filters = []
        filter_criteria = [
            {
                "field": "name",
                "value": domain_name,
                "condition": "EQUAL",
            }
        ]

        filters.append({"and": filter_criteria})
        search_body = {
            "input": "*",
            "entity": "domain",
            "start": 0,
            "count": 10,
            "filter": {"or": filters},
        }
        results: Dict = self._post_generic(self._search_endpoint, search_body)
        num_entities = results.get("value", {}).get("numEntities", 0)
        if num_entities > 1:
            logger.warning(
                f"Got {num_entities} results for domain name {domain_name}. Will return the first match."
            )
        entities_yielded: int = 0
        entities = []
        for x in results["value"]["entities"]:
            entities_yielded += 1
            logger.debug(f"yielding {x['entity']}")
            entities.append(x["entity"])
        return entities[0] if entities_yielded else None

    def get_connection_json(self, urn: str) -> Optional[dict]:
        """Retrieve a connection config.

        This is only supported with Acryl Cloud.

        Args:
            urn: The urn of the connection.

        Returns:
            The connection config as a dictionary, or None if the connection was not found.
        """

        # TODO: This should be capable of resolving secrets.

        res = self.execute_graphql(
            query=connections_gql,
            operation_name="GetConnection",
            variables={"urn": urn},
        )

        if not res["connection"]:
            return None

        connection_type = res["connection"]["details"]["type"]
        if connection_type != "JSON":
            logger.error(
                f"Expected connection details type to be 'JSON', but got {connection_type}"
            )
            return None

        blob = res["connection"]["details"]["json"]["blob"]
        obj = json.loads(blob)

        name = res["connection"]["details"].get("name")
        logger.info(f"Loaded connection {name or urn}")

        return obj

    def set_connection_json(
        self,
        urn: str,
        *,
        platform_urn: str,
        config: Union[ConfigModel, BaseModel, dict],
        name: Optional[str] = None,
    ) -> None:
        """Set a connection config.

        This is only supported with Acryl Cloud.

        Args:
            urn: The urn of the connection.
            platform_urn: The urn of the platform.
            config: The connection config as a dictionary or a ConfigModel.
            name: The name of the connection.
        """

        if isinstance(config, (ConfigModel, BaseModel)):
            blob = config.json()
        else:
            blob = json.dumps(config)

        id = get_id_from_connection_urn(urn)

        res = self.execute_graphql(
            query=connections_gql,
            operation_name="SetConnection",
            variables={
                "id": id,
                "platformUrn": platform_urn,
                "name": name,
                "blob": blob,
            },
        )

        assert res["upsertConnection"]["urn"] == urn

    @deprecated(
        reason='Use get_urns_by_filter(entity_types=["container"], ...) instead'
    )
    def get_container_urns_by_filter(
        self,
        env: Optional[str] = None,
        search_query: str = "*",
    ) -> Iterable[str]:
        """Return container urns that match based on query"""
        url = self._search_endpoint

        container_filters = []
        for container_subtype in ["Database", "Schema", "Project", "Dataset"]:
            filter_criteria = []

            filter_criteria.append(
                {
                    "field": "customProperties",
                    "value": f"instance={env}",
                    "condition": "EQUAL",
                }
            )

            filter_criteria.append(
                {
                    "field": "typeNames",
                    "value": container_subtype,
                    "condition": "EQUAL",
                }
            )
            container_filters.append({"and": filter_criteria})
        search_body = {
            "input": search_query,
            "entity": "container",
            "start": 0,
            "count": 10000,
            "filter": {"or": container_filters},
        }
        results: Dict = self._post_generic(url, search_body)
        num_entities = results["value"]["numEntities"]
        logger.debug(f"Matched {num_entities} containers")
        entities_yielded: int = 0
        for x in results["value"]["entities"]:
            entities_yielded += 1
            logger.debug(f"yielding {x['entity']}")
            yield x["entity"]

    def _bulk_fetch_schema_info_by_filter(
        self,
        *,
        platform: Optional[str] = None,
        platform_instance: Optional[str] = None,
        env: Optional[str] = None,
        query: Optional[str] = None,
        container: Optional[str] = None,
        status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED,
        batch_size: int = 100,
        extraFilters: Optional[List[SearchFilterRule]] = None,
    ) -> Iterable[Tuple[str, "GraphQLSchemaMetadata"]]:
        """Fetch schema info for datasets that match all of the given filters.

        :return: An iterable of (urn, schema info) tuples that match the filters.
        """
        types = [_graphql_entity_type("dataset")]

        # Add the query default of * if no query is specified.
        query = query or "*"

        orFilters = generate_filter(
            platform, platform_instance, env, container, status, extraFilters
        )

        graphql_query = textwrap.dedent(
            """
            query scrollUrnsWithFilters(
                $types: [EntityType!],
                $query: String!,
                $orFilters: [AndFilterInput!],
                $batchSize: Int!,
                $scrollId: String) {

                scrollAcrossEntities(input: {
                    query: $query,
                    count: $batchSize,
                    scrollId: $scrollId,
                    types: $types,
                    orFilters: $orFilters,
                    searchFlags: {
                        skipHighlighting: true
                        skipAggregates: true
                    }
                }) {
                    nextScrollId
                    searchResults {
                        entity {
                            urn
                            ... on Dataset {
                                schemaMetadata(version: 0) {
                                    fields {
                                        fieldPath
                                        nativeDataType
                                    }
                                }
                            }
                        }
                    }
                }
            }
            """
        )

        variables = {
            "types": types,
            "query": query,
            "orFilters": orFilters,
            "batchSize": batch_size,
        }

        for entity in self._scroll_across_entities(graphql_query, variables):
            if entity.get("schemaMetadata"):
                yield entity["urn"], entity["schemaMetadata"]

    def get_urns_by_filter(
        self,
        *,
        entity_types: Optional[List[str]] = None,
        platform: Optional[str] = None,
        platform_instance: Optional[str] = None,
        env: Optional[str] = None,
        query: Optional[str] = None,
        container: Optional[str] = None,
        status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED,
        batch_size: int = 10000,
        extraFilters: Optional[List[SearchFilterRule]] = None,
    ) -> Iterable[str]:
        """Fetch all urns that match all of the given filters.

        Filters are combined conjunctively. If multiple filters are specified, the results will match all of them.
        Note that specifying a platform filter will automatically exclude all entity types that do not have a platform.
        The same goes for the env filter.

        :param entity_types: List of entity types to include. If None, all entity types will be returned.
        :param platform: Platform to filter on. If None, all platforms will be returned.
        :param platform_instance: Platform instance to filter on. If None, all platform instances will be returned.
        :param env: Environment (e.g. PROD, DEV) to filter on. If None, all environments will be returned.
        :param query: Query string to filter on. If None, all entities will be returned.
        :param container: A container urn that entities must be within.
            This works recursively, so it will include entities within sub-containers as well.
            If None, all entities will be returned.
            Note that this requires browsePathV2 aspects (added in 0.10.4+).
        :param status: Filter on the deletion status of the entity. The default is to return only non-soft-deleted entities.
        :param extraFilters: Additional filters to apply. If specified, the results will match all of the filters.

        :return: An iterable of urns that match the filters.
        """

        types = self._get_types(entity_types)

        # Add the query default of * if no query is specified.
        query = query or "*"

        # Env filter.
        orFilters = generate_filter(
            platform, platform_instance, env, container, status, extraFilters
        )

        graphql_query = textwrap.dedent(
            """
            query scrollUrnsWithFilters(
                $types: [EntityType!],
                $query: String!,
                $orFilters: [AndFilterInput!],
                $batchSize: Int!,
                $scrollId: String) {

                scrollAcrossEntities(input: {
                    query: $query,
                    count: $batchSize,
                    scrollId: $scrollId,
                    types: $types,
                    orFilters: $orFilters,
                    searchFlags: {
                        skipHighlighting: true
                        skipAggregates: true
                    }
                }) {
                    nextScrollId
                    searchResults {
                        entity {
                            urn
                        }
                    }
                }
            }
            """
        )

        variables = {
            "types": types,
            "query": query,
            "orFilters": orFilters,
            "batchSize": batch_size,
        }

        for entity in self._scroll_across_entities(graphql_query, variables):
            yield entity["urn"]
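
    # Usage sketch: stream all non-soft-deleted Snowflake dataset urns in
    # PROD. The values are illustrative; any combination of filters works,
    # and they are ANDed together as described above.
    #
    #   for urn in graph.get_urns_by_filter(
    #       entity_types=["dataset"],
    #       platform="snowflake",
    #       env="PROD",
    #   ):
    #       print(urn)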

    def _scroll_across_entities(
        self, graphql_query: str, variables_orig: dict
    ) -> Iterable[dict]:
        variables = variables_orig.copy()
        first_iter = True
        scroll_id: Optional[str] = None
        while first_iter or scroll_id:
            first_iter = False
            variables["scrollId"] = scroll_id

            response = self.execute_graphql(
                graphql_query,
                variables=variables,
            )
            data = response["scrollAcrossEntities"]
            scroll_id = data["nextScrollId"]
            for entry in data["searchResults"]:
                yield entry["entity"]

            if scroll_id:
                logger.debug(
                    f"Scrolling to next scrollAcrossEntities page: {scroll_id}"
                )

    def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]:
        types: Optional[List[str]] = None
        if entity_types is not None:
            if not entity_types:
                raise ValueError(
                    "entity_types cannot be an empty list; use None for all entities"
                )

            types = [_graphql_entity_type(entity_type) for entity_type in entity_types]
        return types

    def get_latest_pipeline_checkpoint(
        self, pipeline_name: str, platform: str
    ) -> Optional[Checkpoint["GenericCheckpointState"]]:
        from datahub.ingestion.source.state.entity_removal_state import (
            GenericCheckpointState,
        )
        from datahub.ingestion.source.state.stale_entity_removal_handler import (
            StaleEntityRemovalHandler,
        )
        from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
            DatahubIngestionCheckpointingProvider,
        )

        checkpoint_provider = DatahubIngestionCheckpointingProvider(self)
        job_name = StaleEntityRemovalHandler.compute_job_id(platform)

        raw_checkpoint = checkpoint_provider.get_latest_checkpoint(
            pipeline_name, job_name
        )
        if not raw_checkpoint:
            return None

        return Checkpoint.create_from_checkpoint_aspect(
            job_name=job_name,
            checkpoint_aspect=raw_checkpoint,
            state_class=GenericCheckpointState,
        )

    def get_search_results(
        self, start: int = 0, count: int = 1, entity: str = "dataset"
    ) -> Dict:
        search_body = {"input": "*", "entity": entity, "start": start, "count": count}
        results: Dict = self._post_generic(self._search_endpoint, search_body)
        return results

    def get_aspect_counts(self, aspect: str, urn_like: Optional[str] = None) -> int:
        args = {"aspect": aspect}
        if urn_like is not None:
            args["urnLike"] = urn_like
        results = self._post_generic(self._aspect_count_endpoint, args)
        return results["value"]

    def execute_graphql(
        self,
        query: str,
        variables: Optional[Dict] = None,
        operation_name: Optional[str] = None,
    ) -> Dict:
        url = f"{self.config.server}/api/graphql"

        body: Dict = {
            "query": query,
        }
        if variables:
            body["variables"] = variables
        if operation_name:
            body["operationName"] = operation_name

        logger.debug(
            f"Executing {operation_name or ''} graphql query: {query} with variables: {json.dumps(variables)}"
        )
        result = self._post_generic(url, body)
        if result.get("errors"):
            raise GraphError(f"Error executing graphql query: {result['errors']}")

        return result["data"]
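
    # Usage sketch: a minimal raw GraphQL call. The query body is an
    # assumption for illustration; execute_graphql returns the "data"
    # payload and raises GraphError if "errors" is populated.
    #
    #   data = graph.execute_graphql(
    #       query="query q($urn: String!) { dataset(urn: $urn) { urn } }",
    #       variables={"urn": dataset_urn},
    #   )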

    class RelationshipDirection(str, enum.Enum):
        # FIXME: Upgrade to enum.StrEnum when we drop support for Python 3.10

        INCOMING = "INCOMING"
        OUTGOING = "OUTGOING"

    def get_related_entities(
        self,
        entity_urn: str,
        relationship_types: List[str],
        direction: RelationshipDirection,
    ) -> Iterable[RelatedEntity]:
        relationship_endpoint = self._relationships_endpoint
        done = False
        start = 0
        while not done:
            response = self._get_generic(
                url=relationship_endpoint,
                params={
                    "urn": entity_urn,
                    "direction": direction.value,
                    "relationshipTypes": relationship_types,
                    "start": start,
                },
            )
            for related_entity in response.get("entities", []):
                yield RelatedEntity(
                    urn=related_entity["urn"],
                    relationship_type=related_entity["relationshipType"],
                    via=related_entity.get("via"),
                )
            done = response.get("count", 0) == 0 or response.get("count", 0) < len(
                response.get("entities", [])
            )
            start = start + response.get("count", 0)
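
    # Usage sketch: walk one hop of dataset lineage. With INCOMING, the
    # relationships endpoint returns entities that declare a "DownstreamOf"
    # edge pointing at this urn, i.e. its downstream consumers. This is an
    # illustration; other relationship type names work the same way.
    #
    #   for rel in graph.get_related_entities(
    #       dataset_urn,
    #       relationship_types=["DownstreamOf"],
    #       direction=DataHubGraph.RelationshipDirection.INCOMING,
    #   ):
    #       print(rel.urn, rel.relationship_type)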

    def exists(self, entity_urn: str) -> bool:
        entity_urn_parsed: Urn = Urn.from_string(entity_urn)
        try:
            key_aspect_class = KEY_ASPECTS.get(entity_urn_parsed.entity_type)
            if key_aspect_class:
                result = self.get_aspect(entity_urn, key_aspect_class)
                return result is not None
            else:
                raise Exception(
                    f"Failed to find key class for entity type {entity_urn_parsed.entity_type} for urn {entity_urn}"
                )
        except Exception as e:
            logger.debug(
                f"Failed to check for existence of urn {entity_urn}", exc_info=e
            )
            raise

    def soft_delete_entity(
        self,
        urn: str,
        run_id: str = _GRAPH_DUMMY_RUN_ID,
        deletion_timestamp: Optional[int] = None,
    ) -> None:
        """Soft-delete an entity by urn.

        Args:
            urn: The urn of the entity to soft-delete.
            run_id: The run id to attribute the deletion to.
            deletion_timestamp: The timestamp (in millis) to record for the deletion; defaults to now.
        """

        assert urn

        deletion_timestamp = deletion_timestamp or int(time.time() * 1000)
        self.emit(
            MetadataChangeProposalWrapper(
                entityUrn=urn,
                aspect=StatusClass(removed=True),
                systemMetadata=SystemMetadataClass(
                    runId=run_id, lastObserved=deletion_timestamp
                ),
            )
        )

    def hard_delete_entity(
        self,
        urn: str,
    ) -> Tuple[int, int]:
        """Hard delete an entity by urn.

        Args:
            urn: The urn of the entity to hard delete.

        Returns:
            A tuple of (rows_affected, timeseries_rows_affected).
        """

        assert urn

        payload_obj: Dict = {"urn": urn}
        summary = self._post_generic(
            f"{self._gms_server}/entities?action=delete", payload_obj
        ).get("value", {})

        rows_affected: int = summary.get("rows", 0)
        timeseries_rows_affected: int = summary.get("timeseriesRows", 0)
        return rows_affected, timeseries_rows_affected

    def delete_entity(self, urn: str, hard: bool = False) -> None:
        """Delete an entity by urn.

        Args:
            urn: The urn of the entity to delete.
            hard: Whether to hard delete the entity. If False (default), the entity will be soft deleted.
        """

        if hard:
            rows_affected, timeseries_rows_affected = self.hard_delete_entity(urn)
            logger.debug(
                f"Hard deleted entity {urn} with {rows_affected} rows affected and {timeseries_rows_affected} timeseries rows affected"
            )
        else:
            self.soft_delete_entity(urn)
            logger.debug(f"Soft deleted entity {urn}")
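
    # Usage sketch: a soft delete only flips the Status aspect and is
    # reversible; a hard delete removes rows from both the versioned and
    # timeseries stores.
    #
    #   graph.delete_entity(dataset_urn)             # soft delete (default)
    #   graph.delete_entity(dataset_urn, hard=True)  # hard delete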

    # TODO: Create hard_delete_aspect once we support that in GMS.

    def hard_delete_timeseries_aspect(
        self,
        urn: str,
        aspect_name: str,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
    ) -> int:
        """Hard delete timeseries aspects of an entity.

        Args:
            urn: The urn of the entity.
            aspect_name: The name of the timeseries aspect to delete.
            start_time: The start time of the timeseries data to delete. If not specified, defaults to the beginning of time.
            end_time: The end time of the timeseries data to delete. If not specified, defaults to the end of time.

        Returns:
            The number of timeseries rows affected.
        """

        assert urn
        assert aspect_name in TIMESERIES_ASPECT_MAP, "must be a timeseries aspect"

        payload_obj: Dict = {
            "urn": urn,
            "aspectName": aspect_name,
        }
        if start_time:
            payload_obj["startTimeMillis"] = int(start_time.timestamp() * 1000)
        if end_time:
            payload_obj["endTimeMillis"] = int(end_time.timestamp() * 1000)

        summary = self._post_generic(
            f"{self._gms_server}/entities?action=delete", payload_obj
        ).get("value", {})

        timeseries_rows_affected: int = summary.get("timeseriesRows", 0)
        return timeseries_rows_affected

    def delete_references_to_urn(
        self, urn: str, dry_run: bool = False
    ) -> Tuple[int, List[Dict]]:
        """Delete references to a given entity.

        This is useful for cleaning up references to an entity that is about to be deleted.
        For example, when deleting a tag, you might use this to remove that tag from all other
        entities that reference it.

        This does not delete the entity itself.

        Args:
            urn: The urn of the entity to delete references to.
            dry_run: If True, do not actually delete the references, just return the count of
                references and the list of related aspects.

        Returns:
            A tuple of (reference_count, sample of related_aspects).
        """

        assert urn

        payload_obj = {"urn": urn, "dryRun": dry_run}

        response = self._post_generic(
            f"{self._gms_server}/entities?action=deleteReferences", payload_obj
        ).get("value", {})
        reference_count = response.get("total", 0)
        related_aspects = response.get("relatedAspects", [])
        return reference_count, related_aspects

    @functools.lru_cache
    def _make_schema_resolver(
        self,
        platform: str,
        platform_instance: Optional[str],
        env: str,
        include_graph: bool = True,
    ) -> "SchemaResolver":
        from datahub.sql_parsing.schema_resolver import SchemaResolver

        return SchemaResolver(
            platform=platform,
            platform_instance=platform_instance,
            env=env,
            graph=self if include_graph else None,
        )

    def initialize_schema_resolver_from_datahub(
        self,
        platform: str,
        platform_instance: Optional[str],
        env: str,
        batch_size: int = 100,
    ) -> "SchemaResolver":
        logger.info("Initializing schema resolver")
        schema_resolver = self._make_schema_resolver(
            platform, platform_instance, env, include_graph=False
        )

        logger.info(f"Fetching schemas for platform {platform}, env {env}")
        count = 0
        with PerfTimer() as timer:
            for urn, schema_info in self._bulk_fetch_schema_info_by_filter(
                platform=platform,
                platform_instance=platform_instance,
                env=env,
                batch_size=batch_size,
            ):
                try:
                    schema_resolver.add_graphql_schema_metadata(urn, schema_info)
                    count += 1
                except Exception:
                    logger.warning("Failed to add schema info", exc_info=True)

                if count % 1000 == 0:
                    logger.debug(
                        f"Loaded {count} schema info in {timer.elapsed_seconds()} seconds"
                    )
            logger.info(
                f"Finished loading total {count} schema info in {timer.elapsed_seconds()} seconds"
            )

        logger.info("Finished initializing schema resolver")
        return schema_resolver

    def parse_sql_lineage(
        self,
        sql: str,
        *,
        platform: str,
        platform_instance: Optional[str] = None,
        env: str = DEFAULT_ENV,
        default_db: Optional[str] = None,
        default_schema: Optional[str] = None,
        default_dialect: Optional[str] = None,
    ) -> "SqlParsingResult":
        from datahub.sql_parsing.sqlglot_lineage import sqlglot_lineage

        # Cache the schema resolver to make bulk parsing faster.
        schema_resolver = self._make_schema_resolver(
            platform=platform, platform_instance=platform_instance, env=env
        )

        return sqlglot_lineage(
            sql,
            schema_resolver=schema_resolver,
            default_db=default_db,
            default_schema=default_schema,
            default_dialect=default_dialect,
        )
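
    # Usage sketch: resolve table-level lineage for a query. The SQL and
    # platform are illustrative; schema-aware column lineage additionally
    # requires the referenced schemas to already be ingested into DataHub.
    #
    #   result = graph.parse_sql_lineage(
    #       "INSERT INTO db.dst SELECT id, name FROM db.src",
    #       platform="snowflake",
    #       default_db="db",
    #   )
    #   print(result.in_tables, result.out_tables)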

    def create_tag(self, tag_name: str) -> str:
        graph_query: str = """
            mutation($tag_detail: CreateTagInput!) {
                createTag(input: $tag_detail)
            }
        """

        variables = {
            "tag_detail": {
                "name": tag_name,
                "id": tag_name,
            },
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        # return urn
        return res["createTag"]

    def remove_tag(self, tag_urn: str, resource_urn: str) -> bool:
        graph_query = f"""
            mutation removeTag {{
                removeTag(
                    input: {{
                        tagUrn: "{tag_urn}",
                        resourceUrn: "{resource_urn}"
                    }})
            }}
        """

        res = self.execute_graphql(query=graph_query)
        return res["removeTag"]
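
    # Usage sketch: create_tag returns the new tag's urn, which remove_tag
    # can then detach from a specific entity. The tag name and resulting urn
    # below are illustrative.
    #
    #   tag_urn = graph.create_tag("pii")   # -> "urn:li:tag:pii"
    #   graph.remove_tag(tag_urn, dataset_urn)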

    def _assertion_result_shared(self) -> str:
        fragment: str = """
            fragment assertionResult on AssertionResult {
                type
                rowCount
                missingCount
                unexpectedCount
                actualAggValue
                externalUrl
                nativeResults {
                    value
                }
                error {
                    type
                    properties {
                        value
                    }
                }
            }
        """
        return fragment

    def _run_assertion_result_shared(self) -> str:
        fragment: str = """
            fragment runAssertionResult on RunAssertionResult {
                assertion {
                    urn
                }
                result {
                    ... assertionResult
                }
            }
        """
        return fragment

    def _run_assertion_build_params(
        self, params: Optional[Dict[str, str]] = {}
    ) -> List[Any]:
        if params is None:
            return []

        results = []
        for key, value in params.items():
            result = {
                "key": key,
                "value": value,
            }
            results.append(result)

        return results

    def run_assertion(
        self,
        urn: str,
        save_result: bool = True,
        parameters: Optional[Dict[str, str]] = {},
        async_flag: bool = False,
    ) -> Dict:
        params = self._run_assertion_build_params(parameters)
        graph_query: str = """
            %s
            mutation runAssertion($assertionUrn: String!, $saveResult: Boolean, $parameters: [StringMapEntryInput!], $async: Boolean!) {
                runAssertion(urn: $assertionUrn, saveResult: $saveResult, parameters: $parameters, async: $async) {
                    ... assertionResult
                }
            }
        """ % (
            self._assertion_result_shared()
        )

        variables = {
            "assertionUrn": urn,
            "saveResult": save_result,
            "parameters": params,
            "async": async_flag,
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        return res["runAssertion"]

    def run_assertions(
        self,
        urns: List[str],
        save_result: bool = True,
        parameters: Optional[Dict[str, str]] = {},
        async_flag: bool = False,
    ) -> Dict:
        params = self._run_assertion_build_params(parameters)
        graph_query: str = """
            %s
            %s
            mutation runAssertions($assertionUrns: [String!]!, $saveResult: Boolean, $parameters: [StringMapEntryInput!], $async: Boolean!) {
                runAssertions(urns: $assertionUrns, saveResults: $saveResult, parameters: $parameters, async: $async) {
                    passingCount
                    failingCount
                    errorCount
                    results {
                        ... runAssertionResult
                    }
                }
            }
        """ % (
            self._assertion_result_shared(),
            self._run_assertion_result_shared(),
        )

        variables = {
            "assertionUrns": urns,
            "saveResult": save_result,
            "parameters": params,
            "async": async_flag,
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        return res["runAssertions"]

    def run_assertions_for_asset(
        self,
        urn: str,
        tag_urns: Optional[List[str]] = [],
        parameters: Optional[Dict[str, str]] = {},
        async_flag: bool = False,
    ) -> Dict:
        params = self._run_assertion_build_params(parameters)
        graph_query: str = """
            %s
            %s
            mutation runAssertionsForAsset($assetUrn: String!, $tagUrns: [String!], $parameters: [StringMapEntryInput!], $async: Boolean!) {
                runAssertionsForAsset(urn: $assetUrn, tagUrns: $tagUrns, parameters: $parameters, async: $async) {
                    passingCount
                    failingCount
                    errorCount
                    results {
                        ... runAssertionResult
                    }
                }
            }
        """ % (
            self._assertion_result_shared(),
            self._run_assertion_result_shared(),
        )

        variables = {
            "assetUrn": urn,
            "tagUrns": tag_urns,
            "parameters": params,
            "async": async_flag,
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        return res["runAssertionsForAsset"]

    def get_entities_v2(
        self,
        entity_name: str,
        urns: List[str],
        aspects: List[str] = [],
        with_system_metadata: bool = False,
    ) -> Dict[str, Any]:
        payload = {
            "urns": urns,
            "aspectNames": aspects,
            "withSystemMetadata": with_system_metadata,
        }
        headers: Dict[str, Any] = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        url = f"{self.config.server}/openapi/v2/entity/batch/{entity_name}"
        response = self._session.post(url, data=json.dumps(payload), headers=headers)
        response.raise_for_status()

        json_resp = response.json()
        entities = json_resp.get("entities", [])
        aspects_set = set(aspects)
        retval: Dict[str, Any] = {}

        for entity in entities:
            entity_aspects = entity.get("aspects", {})
            entity_urn = entity.get("urn", None)

            if entity_urn is None:
                continue
            for aspect_key, aspect_value in entity_aspects.items():
                # Include all aspects if the aspect filter is empty
                if len(aspects) == 0 or aspect_key in aspects_set:
                    retval.setdefault(entity_urn, {})
                    retval[entity_urn][aspect_key] = aspect_value
        return retval

    def upsert_custom_assertion(
        self,
        urn: Optional[str],
        entity_urn: str,
        type: str,
        description: str,
        platform_name: Optional[str] = None,
        platform_urn: Optional[str] = None,
        field_path: Optional[str] = None,
        external_url: Optional[str] = None,
        logic: Optional[str] = None,
    ) -> Dict:
        graph_query: str = """
            mutation upsertCustomAssertion(
                $assertionUrn: String,
                $entityUrn: String!,
                $type: String!,
                $description: String!,
                $fieldPath: String,
                $platformName: String,
                $platformUrn: String,
                $externalUrl: String,
                $logic: String
            ) {
                upsertCustomAssertion(urn: $assertionUrn, input: {
                    entityUrn: $entityUrn
                    type: $type
                    description: $description
                    fieldPath: $fieldPath
                    platform: {
                        urn: $platformUrn
                        name: $platformName
                    }
                    externalUrl: $externalUrl
                    logic: $logic
                }) {
                    urn
                }
            }
        """

        variables = {
            "assertionUrn": urn,
            "entityUrn": entity_urn,
            "type": type,
            "description": description,
            "fieldPath": field_path,
            "platformName": platform_name,
            "platformUrn": platform_urn,
            "externalUrl": external_url,
            "logic": logic,
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        return res["upsertCustomAssertion"]

    def report_assertion_result(
        self,
        urn: str,
        timestamp_millis: int,
        type: Literal["SUCCESS", "FAILURE", "ERROR", "INIT"],
        properties: Optional[List[Dict[str, str]]] = None,
        external_url: Optional[str] = None,
        error_type: Optional[str] = None,
        error_message: Optional[str] = None,
    ) -> bool:
        graph_query: str = """
            mutation reportAssertionResult(
                $assertionUrn: String!,
                $timestampMillis: Long!,
                $type: AssertionResultType!,
                $properties: [StringMapEntryInput!],
                $externalUrl: String,
                $error: AssertionResultErrorInput,
            ) {
                reportAssertionResult(urn: $assertionUrn, result: {
                    timestampMillis: $timestampMillis
                    type: $type
                    properties: $properties
                    externalUrl: $externalUrl
                    error: $error
                })
            }
        """

        variables = {
            "assertionUrn": urn,
            "timestampMillis": timestamp_millis,
            "type": type,
            "properties": properties,
            "externalUrl": external_url,
            "error": {"type": error_type, "message": error_message}
            if error_type
            else None,
        }

        res = self.execute_graphql(
            query=graph_query,
            variables=variables,
        )

        return res["reportAssertionResult"]

    def close(self) -> None:
        self._make_schema_resolver.cache_clear()
        super().close()


def get_default_graph() -> DataHubGraph:
    (url, token) = get_url_and_token()
    graph = DataHubGraph(DatahubClientConfig(server=url, token=token))
    graph.test_connection()
    return graph
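
# Typical entry point (usage sketch): build a client from the credentials
# saved by `datahub init` and verify connectivity before use. The urn below
# is illustrative.
#
#   graph = get_default_graph()
#   assert graph.exists("urn:li:corpuser:datahub")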