
# Copyright 2021 Acryl Data, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
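"""Acryl-specific helpers layered on top of the DataHub Python client.

This module wraps a ``DataHubGraph`` instance with convenience methods for
search, GraphQL queries, lineage traversal, and tag/term patches on datasets.
"""
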
import json
import logging
import urllib.parse
from typing import Any, Dict, List, Optional

from datahub.configuration.common import OperationalError
from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.schema_classes import (
    GlossaryTermAssociationClass,
    TagAssociationClass,
)
from datahub.specific.dataset import DatasetPatchBuilder

logger = logging.getLogger(__name__)


class AcrylDataHubGraph:
    """Convenience wrapper that layers Acryl-specific helpers over a DataHubGraph client."""

    def __init__(self, baseGraph: DataHubGraph):
        self.graph = baseGraph

    def get_by_query(
        self,
        query: str,
        entity: str,
        start: int = 0,
        count: int = 100,
        filters: Optional[Dict] = None,
    ) -> List[Dict]:
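        """Run a full-text search against the GMS ``/entities?action=search`` endpoint.

        Returns the matched entity list from the response, or an empty list
        on any failure.
        """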
url_frag = "/entities?action=search"
url = f"{self.graph._gms_server}{url_frag}"
payload = {"input": query, "start": start, "count": count, "entity": entity}
if filters is not None:
payload["filter"] = filters
headers = {
"X-RestLi-Protocol-Version": "2.0.0",
"Content-Type": "application/json",
}
try:
response = self.graph._session.post(
url, data=json.dumps(payload), headers=headers
)
if response.status_code != 200:
return []
json_resp = response.json()
return json_resp.get("value", {}).get("entities")
except Exception as e:
print(e)
return []

    def get_by_graphql_query(self, query: Dict) -> Dict:
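        """POST a GraphQL request body to ``/api/graphql``.

        Returns the response's ``data`` payload, or an empty dict on any
        failure.
        """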
url_frag = "/api/graphql"
url = f"{self.graph._gms_server}{url_frag}"
headers = {
"X-DataHub-Actor": "urn:li:corpuser:admin",
"Content-Type": "application/json",
}
try:
response = self.graph._session.post(
url, data=json.dumps(query), headers=headers
)
if response.status_code != 200:
return {}
json_resp = response.json()
return json_resp.get("data", {})
except Exception as e:
print(e)
return {}

    def query_constraints_for_dataset(self, dataset_id: str) -> List:
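        """Fetch the constraints attached to a dataset via GraphQL."""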
        resp = self.get_by_graphql_query(
            {
                "query": """
                    query dataset($input: String!) {
                        dataset(urn: $input) {
                            constraints {
                                type
                                displayName
                                description
                                params {
                                    hasGlossaryTermInNodeParams {
                                        nodeName
                                    }
                                }
                            }
                        }
                    }
                """,
                "variables": {"input": dataset_id},
            }
        )
        constraints: List = resp.get("dataset", {}).get("constraints", [])
        return constraints

    def query_execution_result_details(self, execution_id: str) -> Any:
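        """Fetch the task and arguments of an execution request by its id."""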
        resp = self.get_by_graphql_query(
            {
                "query": """
                    query executionRequest($urn: String!) {
                        executionRequest(urn: $urn) {
                            input {
                                task
                                arguments {
                                    key
                                    value
                                }
                            }
                        }
                    }
                """,
                "variables": {"urn": f"urn:li:dataHubExecutionRequest:{execution_id}"},
            }
        )
        return resp.get("executionRequest", {}).get("input", {})

    def query_ingestion_sources(self) -> List:
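        """List all ingestion sources, paging through the GraphQL API ten at a time."""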
        sources = []
        start, count = 0, 10
        while True:
            resp = self.get_by_graphql_query(
                {
                    "query": """
                        query listIngestionSources($input: ListIngestionSourcesInput!, $execution_start: Int!, $execution_count: Int!) {
                            listIngestionSources(input: $input) {
                                start
                                count
                                total
                                ingestionSources {
                                    urn
                                    type
                                    name
                                    executions(start: $execution_start, count: $execution_count) {
                                        start
                                        count
                                        total
                                        executionRequests {
                                            urn
                                        }
                                    }
                                }
                            }
                        }
                    """,
                    "variables": {
                        "input": {"start": start, "count": count},
                        "execution_start": 0,
                        "execution_count": 10,
                    },
                }
            )
            listIngestionSources = resp.get("listIngestionSources", {})
            sources.extend(listIngestionSources.get("ingestionSources", []))
            cur_total = listIngestionSources.get("total", 0)
            # Advance to the next page; comparing cur_total against count alone
            # would never terminate once the total exceeds a single page.
            start += count
            if start >= cur_total:
                break
        return sources

    def get_downstreams(
        self, entity_urn: str, max_downstreams: int = 3000
    ) -> List[str]:
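        """Return entities downstream of ``entity_urn``, up to ``max_downstreams``.

        Pages through incoming ``DownstreamOf`` edges on the ``/relationships``
        endpoint, 1000 entries at a time.
        """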
        start = 0
        count_per_page = 1000
        entities = []
        done = False
        total_downstreams = 0
        while not done:
            url_frag = (
                f"/relationships?direction=INCOMING&types=List(DownstreamOf)"
                f"&urn={urllib.parse.quote(entity_urn)}&count={count_per_page}&start={start}"
            )
            url = f"{self.graph._gms_server}{url_frag}"
            response = self.graph._get_generic(url)
            if response["count"] > 0:
                relnships = response["relationships"]
                entities.extend([x["entity"] for x in relnships])
                start += count_per_page
                total_downstreams += response["count"]
                if start >= response["total"] or total_downstreams >= max_downstreams:
                    done = True
            else:
                done = True
        return entities

    def get_upstreams(self, entity_urn: str, max_upstreams: int = 3000) -> List[str]:
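        """Return entities upstream of ``entity_urn``, up to ``max_upstreams``.

        Pages through outgoing ``DownstreamOf`` edges on the ``/relationships``
        endpoint, 100 entries at a time.
        """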
        start = 0
        count_per_page = 100
        entities = []
        done = False
        total_upstreams = 0
        while not done:
            url_frag = (
                f"/relationships?direction=OUTGOING&types=List(DownstreamOf)"
                f"&urn={urllib.parse.quote(entity_urn)}&count={count_per_page}&start={start}"
            )
            url = f"{self.graph._gms_server}{url_frag}"
            response = self.graph._get_generic(url)
            if response["count"] > 0:
                relnships = response["relationships"]
                entities.extend([x["entity"] for x in relnships])
                start += count_per_page
                total_upstreams += response["count"]
                if start >= response["total"] or total_upstreams >= max_upstreams:
                    done = True
            else:
                done = True
        return entities

    def get_relationships(
        self, entity_urn: str, direction: str, relationship_types: List[str]
    ) -> List[str]:
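        """Return entities related to ``entity_urn`` by the given relationship types (first page only)."""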
        url_frag = (
            f"/relationships?"
            f"direction={direction}"
            f"&types=List({','.join(relationship_types)})"
            f"&urn={urllib.parse.quote(entity_urn)}"
        )
        url = f"{self.graph._gms_server}{url_frag}"
        response = self.graph._get_generic(url)
        if response["count"] > 0:
            relnships = response["relationships"]
            entities = [x["entity"] for x in relnships]
            return entities
        return []

    def check_relationship(
        self, entity_urn: str, target_urn: str, relationship_type: str
    ) -> bool:
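        """Return True if ``target_urn`` is the source of an incoming ``relationship_type`` edge into ``entity_urn``."""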
url_frag = f"/relationships?direction=INCOMING&types=List({relationship_type})&urn={urllib.parse.quote(entity_urn)}"
url = f"{self.graph._gms_server}{url_frag}"
response = self.graph._get_generic(url)
if response["count"] > 0:
relnships = response["relationships"]
entities = [x["entity"] for x in relnships]
return target_urn in entities
return False

    def add_tags_to_dataset(
        self,
        entity_urn: str,
        dataset_tags: List[str],
        field_tags: Optional[Dict] = None,
        context: Optional[Dict] = None,
    ) -> None:
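        """Attach tags to a dataset and, optionally, to individual schema fields.

        ``field_tags`` maps field paths to lists of tag urns. ``context``, if
        given, is JSON-serialized into each tag association. Changes are
        emitted as dataset patches.
        """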
        if field_tags is None:
            field_tags = {}
        dataset = DatasetPatchBuilder(entity_urn)
        for t in dataset_tags:
            dataset.add_tag(
                tag=TagAssociationClass(
                    tag=t, context=json.dumps(context) if context else None
                )
            )
        for field_path, tags in field_tags.items():
            field_builder = dataset.for_field(field_path=field_path)
            for tag in tags:
                field_builder.add_tag(
                    tag=TagAssociationClass(
                        tag=tag, context=json.dumps(context) if context else None
                    )
                )
        for mcp in dataset.build():
            self.graph.emit(mcp)

    def add_terms_to_dataset(
        self,
        entity_urn: str,
        dataset_terms: List[str],
        field_terms: Optional[Dict] = None,
        context: Optional[Dict] = None,
    ) -> None:
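        """Attach glossary terms to a dataset and, optionally, to individual schema fields.

        ``field_terms`` maps field paths to lists of term urns. Changes are
        emitted as dataset patches.
        """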
        if field_terms is None:
            field_terms = {}
        dataset = DatasetPatchBuilder(urn=entity_urn)
        for term in dataset_terms:
            dataset.add_term(
                GlossaryTermAssociationClass(
                    term, context=json.dumps(context) if context else None
                )
            )
        for field_path, terms in field_terms.items():
            field_builder = dataset.for_field(field_path=field_path)
            for term in terms:
                field_builder.add_term(
                    GlossaryTermAssociationClass(
                        term, context=json.dumps(context) if context else None
                    )
                )
        for mcp in dataset.build():
            self.graph.emit(mcp)

    def get_corpuser_info(self, urn: str) -> Any:
        """Fetch the ``corpUserInfo`` aspect for a user urn."""
        return self.get_untyped_aspect(
            urn, "corpUserInfo", "com.linkedin.identity.CorpUserInfo"
        )

    def get_untyped_aspect(
        self,
        entity_urn: str,
        aspect: str,
        aspect_type_name: str,
    ) -> Any:
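        """Fetch the latest version of an aspect as raw JSON.

        Returns None if the aspect does not exist; raises OperationalError if
        ``aspect_type_name`` is missing from the response.
        """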
url = f"{self.graph._gms_server}/aspects/{urllib.parse.quote(entity_urn)}?aspect={aspect}&version=0"
response = self.graph._session.get(url)
if response.status_code == 404:
# not found
return None
response.raise_for_status()
response_json = response.json()
aspect_json = response_json.get("aspect", {}).get(aspect_type_name)
if aspect_json:
return aspect_json
else:
raise OperationalError(
f"Failed to find {aspect_type_name} in response {response_json}"
)

    def _get_entity_by_name(
        self,
        name: str,
        entity_type: str,
        indexed_fields: Optional[List[str]] = None,
    ) -> Optional[str]:
"""Retrieve an entity urn based on its name and type. Returns None if there is no match found"""
        if indexed_fields is None:
            indexed_fields = ["name", "displayName"]
        filters = []
        if len(indexed_fields) > 1:
            for indexed_field in indexed_fields:
                filter_criteria = [
                    {
                        "field": indexed_field,
                        "value": name,
                        "condition": "EQUAL",
                    }
                ]
                filters.append({"and": filter_criteria})
            search_body = {
                "input": "*",
                "entity": entity_type,
                "start": 0,
                "count": 10,
                "orFilters": [filters],
            }
        else:
            search_body = {
                "input": "*",
                "entity": entity_type,
                "start": 0,
                "count": 10,
                "filter": {
                    "or": [
                        {
                            "and": [
                                {
                                    "field": indexed_fields[0],
                                    "value": name,
                                    "condition": "EQUAL",
                                }
                            ]
                        }
                    ]
                },
            }
        results: Dict = self.graph._post_generic(
            self.graph._search_endpoint, search_body
        )
        num_entities = results.get("value", {}).get("numEntities", 0)
        if num_entities > 1:
            logger.warning(
                f"Got {num_entities} results for {entity_type} {name}. Will return the first match."
            )
        entities = []
        for x in results.get("value", {}).get("entities", []):
            logger.debug(f"Matched entity {x['entity']}")
            entities.append(x["entity"])
        return entities[0] if entities else None

    def get_glossary_term_urn_by_name(self, term_name: str) -> Optional[str]:
        """Retrieve a glossary term urn by its name. Returns None if no match is found."""
        return self._get_entity_by_name(
            term_name, "glossaryTerm", indexed_fields=["name"]
        )

    def get_glossary_node_urn_by_name(self, node_name: str) -> Optional[str]:
        """Retrieve a glossary node urn by its name. Returns None if no match is found."""
        return self._get_entity_by_name(node_name, "glossaryNode")
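

if __name__ == "__main__":
    # Minimal usage sketch: connect to a GMS instance and walk lineage.
    # The server URL and dataset urn below are illustrative placeholders,
    # not values taken from this module.
    from datahub.ingestion.graph.client import DatahubClientConfig

    graph = AcrylDataHubGraph(
        DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
    )
    downstreams = graph.get_downstreams(
        "urn:li:dataset:(urn:li:dataPlatform:hive,my_table,PROD)"
    )
    print(f"Found {len(downstreams)} downstream entities")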