refactor(ingest): Add helper DataHubGraph methods (#7851)

Adds:
- get_urns_by_filter(), using scroll by entities
- get_latest_pipeline_checkpoint()
- soft_delete_urn()
This commit is contained in:
Andrew Sikowitz 2023-04-20 13:16:33 -04:00 committed by GitHub
parent a9a80b8c70
commit 1ff6949e36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 32 deletions

View File

@ -369,6 +369,7 @@ def get_urns_by_filter(
include_removed: bool = False,
only_soft_deleted: Optional[bool] = None,
) -> Iterable[str]:
# TODO: Replace with DataHubGraph call
session, gms_host = get_session_and_host()
endpoint: str = "/entities?action=search"
url = gms_host + endpoint

View File

@ -5,14 +5,6 @@ import click
from click_default_group import DefaultGroup
from datahub.ingestion.graph.client import get_default_graph
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionCheckpointingProvider,
)
from datahub.telemetry import telemetry
from datahub.upgrade import upgrade
@ -37,20 +29,9 @@ def inspect(pipeline_name: str, platform: str) -> None:
"""
datahub_graph = get_default_graph()
checkpoint_provider = DatahubIngestionCheckpointingProvider(datahub_graph, "cli")
job_name = StaleEntityRemovalHandler.compute_job_id(platform)
raw_checkpoint = checkpoint_provider.get_latest_checkpoint(pipeline_name, job_name)
if not raw_checkpoint:
checkpoint = datahub_graph.get_latest_pipeline_checkpoint(pipeline_name, platform)
if not checkpoint:
click.secho("No ingestion state found.", fg="red")
exit(1)
checkpoint = Checkpoint.create_from_checkpoint_aspect(
job_name=job_name,
checkpoint_aspect=raw_checkpoint,
state_class=GenericCheckpointState,
)
assert checkpoint
click.echo(json.dumps(checkpoint.state.urns, indent=2))

View File

@ -1,21 +1,25 @@
import json
import logging
import time
from dataclasses import dataclass
from enum import Enum
from json.decoder import JSONDecodeError
from typing import Any, Dict, Iterable, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union
from avro.schema import RecordSchema
from deprecated import deprecated
from requests.adapters import Response
from requests.models import HTTPError
from typing_extensions import Literal
from datahub.cli.cli_utils import get_boolean_env_variable, get_url_and_token
from datahub.configuration.common import ConfigModel, GraphError, OperationalError
from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP
from datahub.emitter.mce_builder import Aspect
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.serialization_helper import post_json_transform
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.metadata.schema_classes import (
BrowsePathsClass,
DatasetPropertiesClass,
@ -26,10 +30,18 @@ from datahub.metadata.schema_classes import (
GlossaryTermsClass,
OwnershipClass,
SchemaMetadataClass,
StatusClass,
SystemMetadataClass,
TelemetryClientIdClass,
)
from datahub.utilities.urns.urn import Urn, guess_entity_type
if TYPE_CHECKING:
from datahub.ingestion.source.state.entity_removal_state import (
GenericCheckpointState,
)
logger = logging.getLogger(__name__)
@ -289,14 +301,14 @@ class DataHubGraph(DatahubRestEmitter):
"urn": entity_urn,
"entity": guess_entity_type(entity_urn),
"aspect": aspect_type.ASPECT_NAME,
"latestValue": True,
"limit": 1,
"filter": {"or": [{"and": filter_criteria}]},
}
end_point = f"{self.config.server}/aspects?action=getTimeseriesAspectValues"
resp: Dict = self._post_generic(end_point, query_body)
values: list = resp.get("value", {}).get("values")
if values:
assert len(values) == 1
assert len(values) == 1, len(values)
aspect_json: str = values[0].get("aspect", {}).get("value")
if aspect_json:
return aspect_type.from_obj(json.loads(aspect_json), tuples=False)
@ -358,15 +370,22 @@ class DataHubGraph(DatahubRestEmitter):
return result
def _get_search_endpoint(self):
@property
def _search_endpoint(self) -> str:
    """Return the ``entities?action=search`` URL under ``self.config.server``."""
    return f"{self.config.server}/entities?action=search"
def _get_relationships_endpoint(self):
@property
def _relationships_endpoint(self) -> str:
    """Return the ``openapi/relationships/v1/`` URL under ``self.config.server``."""
    return f"{self.config.server}/openapi/relationships/v1/"
def _get_aspect_count_endpoint(self):
@property
def _aspect_count_endpoint(self) -> str:
    """Return the ``aspects?action=getCount`` URL under ``self.config.server``."""
    return f"{self.config.server}/aspects?action=getCount"
@property
def _scroll_across_entities_endpoint(self) -> str:
    """Return the ``entities?action=scrollAcrossEntities`` URL under ``self.config.server``."""
    return f"{self.config.server}/entities?action=scrollAcrossEntities"
def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]:
"""Retrieve a domain urn based on its name. Returns None if there is no match found"""
@ -387,7 +406,7 @@ class DataHubGraph(DatahubRestEmitter):
"count": 10,
"filter": {"or": filters},
}
results: Dict = self._post_generic(self._get_search_endpoint(), search_body)
results: Dict = self._post_generic(self._search_endpoint, search_body)
num_entities = results.get("value", {}).get("numEntities", 0)
if num_entities > 1:
logger.warning(
@ -407,7 +426,7 @@ class DataHubGraph(DatahubRestEmitter):
search_query: str = "*",
) -> Iterable[str]:
"""Return container urns that match based on query"""
url = self._get_search_endpoint()
url = self._search_endpoint
container_filters = []
for container_subtype in ["Database", "Schema", "Project", "Dataset"]:
@ -445,18 +464,91 @@ class DataHubGraph(DatahubRestEmitter):
logger.debug(f"yielding {x['entity']}")
yield x["entity"]
def get_urns_by_filter(
    self, platform: str, batch_size: int = 10000
) -> Iterable[str]:
    """Yield the urn of every entity whose platform matches ``platform``.

    Pages through the ``scrollAcrossEntities`` GraphQL query, requesting
    ``batch_size`` results per call and following ``nextScrollId`` until the
    server stops handing one back.

    Args:
        platform: Bare platform name; it is expanded to
            ``urn:li:dataPlatform:<platform>`` before being sent.
        batch_size: Number of search results requested per page.

    Yields:
        Entity urns, in the order the server returns them.
    """
    # Does not filter on env, because env is missing in dashboard / chart urns and custom properties
    # For containers, use { field: "customProperties", values: ["instance=env"], condition:EQUAL }
    # For others, use { field: "origin", values: ["env"], condition:EQUAL }
    query = """
query scrollEntitiesForPlatform($platform: String!, $batchSize: Int!, $scrollId: String) {
scrollAcrossEntities(input: { query: "*", count:$batchSize,
scrollId: $scrollId,
orFilters: [
{and: [{
field: "platform.keyword",
values: [$platform],
condition: EQUAL,
}]}
]
}) {
nextScrollId
searchResults {
entity {
urn
}
}
}
}
"""
    platform_urn = f"urn:li:dataPlatform:{platform}"

    # A scroll id of None asks the server for the first page.
    scroll_id: Optional[str] = None
    while True:
        response = self.execute_graphql(
            query,
            variables={
                "platform": platform_urn,
                "batchSize": batch_size,
                "scrollId": scroll_id,
            },
        )
        data = response["scrollAcrossEntities"]
        for entry in data["searchResults"]:
            yield entry["entity"]["urn"]

        # Stop on a missing, null, or empty scroll id. Sending an empty id
        # back as None would restart the scroll from the first page and
        # never terminate.
        scroll_id = data.get("nextScrollId")
        if not scroll_id:
            break
def get_latest_pipeline_checkpoint(
    self, pipeline_name: str, platform: str
) -> Optional[Checkpoint["GenericCheckpointState"]]:
    """Fetch the latest checkpoint stored for a pipeline/platform pair.

    The job name is whatever ``StaleEntityRemovalHandler.compute_job_id``
    derives from ``platform``; the checkpoint aspect is looked up through a
    ``DatahubIngestionCheckpointingProvider`` bound to this graph client.

    Args:
        pipeline_name: Name of the ingestion pipeline to look up.
        platform: Platform used to compute the job name.

    Returns:
        The checkpoint rebuilt from the stored aspect with
        ``GenericCheckpointState`` as its state class, or ``None`` when the
        provider returns no checkpoint.
    """
    # Imported at call time rather than module level.
    # NOTE(review): presumably this avoids an import cycle with the
    # ingestion.source packages - confirm before hoisting.
    from datahub.ingestion.source.state.entity_removal_state import (
        GenericCheckpointState,
    )
    from datahub.ingestion.source.state.stale_entity_removal_handler import (
        StaleEntityRemovalHandler,
    )
    from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
        DatahubIngestionCheckpointingProvider,
    )

    provider = DatahubIngestionCheckpointingProvider(self, "graph")
    stale_removal_job = StaleEntityRemovalHandler.compute_job_id(platform)
    latest_aspect = provider.get_latest_checkpoint(pipeline_name, stale_removal_job)

    if latest_aspect:
        return Checkpoint.create_from_checkpoint_aspect(
            job_name=stale_removal_job,
            checkpoint_aspect=latest_aspect,
            state_class=GenericCheckpointState,
        )
    return None
def get_search_results(
    self, start: int = 0, count: int = 1, entity: str = "dataset"
) -> Dict:
    """Run a wildcard (``"*"``) search for one entity type.

    Args:
        start: Value sent as ``start`` in the search body.
        count: Value sent as ``count`` in the search body.
        entity: Entity type name sent as ``entity`` in the search body.

    Returns:
        The response dictionary from ``_post_generic``, unmodified.
    """
    search_body = {"input": "*", "entity": entity, "start": start, "count": count}
    # Post exactly once, to the `_search_endpoint` property; the former
    # `_get_search_endpoint()` method no longer exists.
    results: Dict = self._post_generic(self._search_endpoint, search_body)
    return results
def get_aspect_counts(self, aspect: str, urn_like: Optional[str] = None) -> int:
    """Return the count reported by the ``getCount`` aspects action.

    Args:
        aspect: Aspect name sent as ``aspect``.
        urn_like: Optional pattern sent as ``urnLike``; omitted from the
            request when ``None``.

    Returns:
        The ``"value"`` field of the response.
    """
    args = {"aspect": aspect}
    if urn_like is not None:
        args["urnLike"] = urn_like
    # Post exactly once, to the `_aspect_count_endpoint` property; the former
    # `_get_aspect_count_endpoint()` method no longer exists.
    results = self._post_generic(self._aspect_count_endpoint, args)
    return results["value"]
def execute_graphql(self, query: str, variables: Optional[Dict] = None) -> Dict:
@ -488,7 +580,7 @@ class DataHubGraph(DatahubRestEmitter):
relationship_types: List[str],
direction: RelationshipDirection,
) -> Iterable[RelatedEntity]:
relationship_endpoint = self._get_relationships_endpoint()
relationship_endpoint = self._relationships_endpoint
done = False
start = 0
while not done:
@ -511,6 +603,22 @@ class DataHubGraph(DatahubRestEmitter):
)
start = start + response.get("count", 0)
def soft_delete_urn(
    self,
    urn: str,
    run_id: str = "soft-delete-urns",
) -> None:
    """Emit a ``StatusClass(removed=True)`` aspect for ``urn``.

    Args:
        urn: Urn of the entity to mark as removed.
        run_id: Run id recorded in the proposal's system metadata.
    """
    # time.time() is in seconds; lastObserved is given in milliseconds.
    now_ms = int(time.time() * 1000)

    removed_status = StatusClass(removed=True)
    system_metadata = SystemMetadataClass(runId=run_id, lastObserved=now_ms)
    proposal = MetadataChangeProposalWrapper(
        entityUrn=urn,
        aspect=removed_status,
        systemMetadata=system_metadata,
    )
    self.emit_mcp(proposal)
def get_default_graph() -> DataHubGraph:
(url, token) = get_url_and_token()