feat(cli): Make consistent use of DatahubClientConfig (#10466)

Deprecates get_url_and_token() in favor of a more complete option, load_client_config(), which returns a full DatahubClientConfig.
The change is propagated across all previous usages of get_url_and_token() so that client connections to the DataHub server respect the full breadth of configuration specified by DatahubClientConfig.

For example, you can now set disable_ssl_verification: true in your ~/.datahubenv file so that all CLI calls to the server work against deployments where SSL certificate verification is disabled.
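A minimal ~/.datahubenv along these lines should have that effect (an illustrative sketch: keys under gms mirror DatahubClientConfig, and the server URL and token are placeholders):

```yaml
# ~/.datahubenv -- illustrative sketch; keys under `gms` mirror DatahubClientConfig
gms:
  server: https://datahub.example.com:8080  # placeholder GMS URL
  token: "<personal-access-token>"          # optional
  disable_ssl_verification: true            # honored by every CLI call to the server
```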

Fixes #9705
Pedro Silva 2024-07-25 20:06:14 +01:00 committed by GitHub
parent 5a2fc3c58e
commit dd732d0d46
26 changed files with 396 additions and 504 deletions

View File

@ -80,6 +80,7 @@ New (optional fields `systemMetadata` and `headers`):
### Deprecations
### Other Notable Changes
- #10466 - Extends configuration in `~/.datahubenv` to match the `DatahubClientConfig` object definition. See the full configuration in https://datahubproject.io/docs/python-sdk/clients/. The CLI should now respect the updated configuration specified in `~/.datahubenv` across its functions and utilities. This means that for systems where SSL certificate verification is disabled, setting `disable_ssl_verification: true` in `~/.datahubenv` will apply to all CLI calls.
## 0.13.1

View File

@ -2,15 +2,12 @@ import json
import logging
import os
import os.path
import sys
import typing
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import click
import requests
from deprecated import deprecated
from requests.models import Response
from requests.sessions import Session
import datahub
@ -28,46 +25,14 @@ from datahub.utilities.urns.urn import Urn, guess_entity_type
log = logging.getLogger(__name__)
ENV_METADATA_HOST_URL = "DATAHUB_GMS_URL"
ENV_METADATA_HOST = "DATAHUB_GMS_HOST"
ENV_METADATA_PORT = "DATAHUB_GMS_PORT"
ENV_METADATA_PROTOCOL = "DATAHUB_GMS_PROTOCOL"
ENV_METADATA_TOKEN = "DATAHUB_GMS_TOKEN"
ENV_DATAHUB_SYSTEM_CLIENT_ID = "DATAHUB_SYSTEM_CLIENT_ID"
ENV_DATAHUB_SYSTEM_CLIENT_SECRET = "DATAHUB_SYSTEM_CLIENT_SECRET"
config_override: Dict = {}
# TODO: Many of the methods in this file duplicate logic that already lives
# in the DataHubGraph client. We should refactor this to use the client instead.
# For the methods that aren't duplicates, that logic should be moved to the client.
def set_env_variables_override_config(url: str, token: Optional[str]) -> None:
"""Should be used to override the config when using rest emitter"""
config_override[ENV_METADATA_HOST_URL] = url
if token is not None:
config_override[ENV_METADATA_TOKEN] = token
def get_details_from_env() -> Tuple[Optional[str], Optional[str]]:
host = os.environ.get(ENV_METADATA_HOST)
port = os.environ.get(ENV_METADATA_PORT)
token = os.environ.get(ENV_METADATA_TOKEN)
protocol = os.environ.get(ENV_METADATA_PROTOCOL, "http")
url = os.environ.get(ENV_METADATA_HOST_URL)
if port is not None:
url = f"{protocol}://{host}:{port}"
return url, token
# The reason for using host as URL is backward compatibility
# If port is not being used we assume someone is using host env var as URL
if url is None and host is not None:
log.warning(
f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead"
)
return url or host, token
def first_non_null(ls: List[Optional[str]]) -> Optional[str]:
return next((el for el in ls if el is not None and el.strip() != ""), None)
@ -80,72 +45,6 @@ def get_system_auth() -> Optional[str]:
return None
def get_url_and_token():
gms_host_env, gms_token_env = get_details_from_env()
if len(config_override.keys()) > 0:
gms_host = config_override.get(ENV_METADATA_HOST_URL)
gms_token = config_override.get(ENV_METADATA_TOKEN)
elif config_utils.should_skip_config():
gms_host = gms_host_env
gms_token = gms_token_env
else:
config_utils.ensure_datahub_config()
gms_host_conf, gms_token_conf = config_utils.get_details_from_config()
gms_host = first_non_null([gms_host_env, gms_host_conf])
gms_token = first_non_null([gms_token_env, gms_token_conf])
return gms_host, gms_token
def get_token():
return get_url_and_token()[1]
def get_session_and_host():
session = requests.Session()
gms_host, gms_token = get_url_and_token()
if gms_host is None or gms_host.strip() == "":
log.error(
f"GMS Host is not set. Use datahub init command or set {ENV_METADATA_HOST_URL} env var"
)
return None, None
session.headers.update(
{
"X-RestLi-Protocol-Version": "2.0.0",
"Content-Type": "application/json",
}
)
if isinstance(gms_token, str) and len(gms_token) > 0:
session.headers.update(
{"Authorization": f"Bearer {gms_token.format(**os.environ)}"}
)
return session, gms_host
def test_connection():
(session, host) = get_session_and_host()
url = f"{host}/config"
response = session.get(url)
response.raise_for_status()
def test_connectivity_complain_exit(operation_name: str) -> None:
"""Test connectivity to metadata-service, log operation name and exit"""
# First test connectivity
try:
test_connection()
except Exception as e:
click.secho(
f"Failed to connect to DataHub server at {get_session_and_host()[1]}. Run with datahub --debug {operation_name} ... to get more information.",
fg="red",
)
log.debug(f"Failed to connect with {e}")
sys.exit(1)
def parse_run_restli_response(response: requests.Response) -> dict:
response_json = response.json()
if response.status_code != 200:
@ -195,10 +94,11 @@ def format_aspect_summaries(summaries: list) -> typing.List[typing.List[str]]:
def post_rollback_endpoint(
session: Session,
gms_host: str,
payload_obj: dict,
path: str,
) -> typing.Tuple[typing.List[typing.List[str]], int, int, int, int, typing.List[dict]]:
session, gms_host = get_session_and_host()
url = gms_host + path
payload = json.dumps(payload_obj)
@ -229,212 +129,13 @@ def post_rollback_endpoint(
)
@deprecated(reason="Use DataHubGraph.get_urns_by_filter instead")
def get_urns_by_filter(
platform: Optional[str],
env: Optional[str] = None,
entity_type: str = "dataset",
search_query: str = "*",
include_removed: bool = False,
only_soft_deleted: Optional[bool] = None,
) -> Iterable[str]:
# TODO: Replace with DataHubGraph call
session, gms_host = get_session_and_host()
endpoint: str = "/entities?action=search"
url = gms_host + endpoint
filter_criteria = []
entity_type_lower = entity_type.lower()
if env and entity_type_lower != "container":
filter_criteria.append({"field": "origin", "value": env, "condition": "EQUAL"})
if (
platform is not None
and entity_type_lower == "dataset"
or entity_type_lower == "dataflow"
or entity_type_lower == "datajob"
or entity_type_lower == "container"
):
filter_criteria.append(
{
"field": "platform.keyword",
"value": f"urn:li:dataPlatform:{platform}",
"condition": "EQUAL",
}
)
if platform is not None and entity_type_lower in {"chart", "dashboard"}:
filter_criteria.append(
{
"field": "tool",
"value": platform,
"condition": "EQUAL",
}
)
if only_soft_deleted:
filter_criteria.append(
{
"field": "removed",
"value": "true",
"condition": "EQUAL",
}
)
elif include_removed:
filter_criteria.append(
{
"field": "removed",
"value": "", # accept anything regarding removed property (true, false, non-existent)
"condition": "EQUAL",
}
)
search_body = {
"input": search_query,
"entity": entity_type,
"start": 0,
"count": 10000,
"filter": {"or": [{"and": filter_criteria}]},
}
payload = json.dumps(search_body)
log.debug(payload)
response: Response = session.post(url, payload)
if response.status_code == 200:
assert response._content
results = json.loads(response._content)
num_entities = results["value"]["numEntities"]
entities_yielded: int = 0
for x in results["value"]["entities"]:
entities_yielded += 1
log.debug(f"yielding {x['entity']}")
yield x["entity"]
if entities_yielded != num_entities:
log.warning(
f"Discrepancy in entities yielded {entities_yielded} and num entities {num_entities}. This means all entities may not have been deleted."
)
else:
log.error(f"Failed to execute search query with {str(response.content)}")
response.raise_for_status()
def get_container_ids_by_filter(
env: Optional[str],
entity_type: str = "container",
search_query: str = "*",
) -> Iterable[str]:
session, gms_host = get_session_and_host()
endpoint: str = "/entities?action=search"
url = gms_host + endpoint
container_filters = []
for container_subtype in ["Database", "Schema", "Project", "Dataset"]:
filter_criteria = []
filter_criteria.append(
{
"field": "customProperties",
"value": f"instance={env}",
"condition": "EQUAL",
}
)
filter_criteria.append(
{
"field": "typeNames",
"value": container_subtype,
"condition": "EQUAL",
}
)
container_filters.append({"and": filter_criteria})
search_body = {
"input": search_query,
"entity": entity_type,
"start": 0,
"count": 10000,
"filter": {"or": container_filters},
}
payload = json.dumps(search_body)
log.debug(payload)
response: Response = session.post(url, payload)
if response.status_code == 200:
assert response._content
log.debug(response._content)
results = json.loads(response._content)
num_entities = results["value"]["numEntities"]
entities_yielded: int = 0
for x in results["value"]["entities"]:
entities_yielded += 1
log.debug(f"yielding {x['entity']}")
yield x["entity"]
assert (
entities_yielded == num_entities
), "Did not delete all entities, try running this command again!"
else:
log.error(f"Failed to execute search query with {str(response.content)}")
response.raise_for_status()
def batch_get_ids(
ids: List[str],
) -> Iterable[Dict]:
session, gms_host = get_session_and_host()
endpoint: str = "/entitiesV2"
url = gms_host + endpoint
ids_to_get = [Urn.url_encode(id) for id in ids]
response = session.get(
f"{url}?ids=List({','.join(ids_to_get)})",
)
if response.status_code == 200:
assert response._content
log.debug(response._content)
results = json.loads(response._content)
num_entities = len(results["results"])
entities_yielded: int = 0
for x in results["results"].values():
entities_yielded += 1
log.debug(f"yielding {x}")
yield x
assert (
entities_yielded == num_entities
), "Did not delete all entities, try running this command again!"
else:
log.error(f"Failed to execute batch get with {str(response.content)}")
response.raise_for_status()
def get_incoming_relationships(urn: str, types: List[str]) -> Iterable[Dict]:
yield from get_relationships(urn=urn, types=types, direction="INCOMING")
def get_outgoing_relationships(urn: str, types: List[str]) -> Iterable[Dict]:
yield from get_relationships(urn=urn, types=types, direction="OUTGOING")
def get_relationships(urn: str, types: List[str], direction: str) -> Iterable[Dict]:
session, gms_host = get_session_and_host()
encoded_urn: str = Urn.url_encode(urn)
types_param_string = "List(" + ",".join(types) + ")"
endpoint: str = f"{gms_host}/relationships?urn={encoded_urn}&direction={direction}&types={types_param_string}"
response: Response = session.get(endpoint)
if response.status_code == 200:
results = response.json()
log.debug(f"Relationship response: {results}")
num_entities = results["count"]
entities_yielded: int = 0
for x in results["relationships"]:
entities_yielded += 1
yield x
if entities_yielded != num_entities:
log.warn("Yielded entities differ from num entities")
else:
log.error(f"Failed to execute relationships query with {str(response.content)}")
response.raise_for_status()
def get_entity(
session: Session,
gms_host: str,
urn: str,
aspect: Optional[List] = None,
cached_session_host: Optional[Tuple[Session, str]] = None,
) -> Dict:
session, gms_host = cached_session_host or get_session_and_host()
if urn.startswith("urn%3A"):
# we assume the urn is already encoded
encoded_urn: str = urn
@ -457,6 +158,8 @@ def get_entity(
def post_entity(
session: Session,
gms_host: str,
urn: str,
entity_type: str,
aspect_name: str,
@ -464,7 +167,6 @@ def post_entity(
cached_session_host: Optional[Tuple[Session, str]] = None,
is_async: Optional[str] = "false",
) -> int:
session, gms_host = cached_session_host or get_session_and_host()
endpoint: str = "/aspects/?action=ingestProposal"
proposal = {
@ -502,11 +204,12 @@ def _get_pydantic_class_from_aspect_name(aspect_name: str) -> Optional[Type[_Asp
def get_latest_timeseries_aspect_values(
session: Session,
gms_host: str,
entity_urn: str,
timeseries_aspect_name: str,
cached_session_host: Optional[Tuple[Session, str]],
) -> Dict:
session, gms_host = cached_session_host or get_session_and_host()
query_body = {
"urn": entity_urn,
"entity": guess_entity_type(entity_urn),
@ -524,6 +227,8 @@ def get_latest_timeseries_aspect_values(
def get_aspects_for_entity(
session: Session,
gms_host: str,
entity_urn: str,
aspects: List[str],
typed: bool = False,
@ -533,7 +238,7 @@ def get_aspects_for_entity(
# Process non-timeseries aspects
non_timeseries_aspects = [a for a in aspects if a not in TIMESERIES_ASPECT_MAP]
entity_response = get_entity(
entity_urn, non_timeseries_aspects, cached_session_host
session, gms_host, entity_urn, non_timeseries_aspects, cached_session_host
)
aspect_list: Dict[str, dict] = entity_response["aspects"]
@ -541,7 +246,7 @@ def get_aspects_for_entity(
timeseries_aspects: List[str] = [a for a in aspects if a in TIMESERIES_ASPECT_MAP]
for timeseries_aspect in timeseries_aspects:
timeseries_response: Dict = get_latest_timeseries_aspect_values(
entity_urn, timeseries_aspect, cached_session_host
session, gms_host, entity_urn, timeseries_aspect, cached_session_host
)
values: List[Dict] = timeseries_response.get("value", {}).get("values", [])
if values:
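The practical effect of the cli_utils changes above is that helpers such as get_aspects_for_entity and post_rollback_endpoint no longer resolve their own connection; callers pass in the session and host from a configured graph client. A hedged sketch of the new call pattern (the dataset URN is only illustrative):

```python
from datahub.cli.cli_utils import get_aspects_for_entity
from datahub.ingestion.graph.client import get_default_graph

# The graph client is built from the full DatahubClientConfig resolved by the CLI.
client = get_default_graph()

# Helpers now receive the session and GMS host explicitly instead of calling
# the removed get_session_and_host() internally.
aspects = get_aspects_for_entity(
    client._session,
    client.config.server,
    entity_urn="urn:li:dataset:(urn:li:dataPlatform:hive,fct_users,PROD)",  # illustrative URN
    aspects=["status"],
    typed=False,
)
```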

View File

@ -4,12 +4,10 @@ For helper methods to contain manipulation of the config file in local system.
import logging
import os
import sys
from typing import Optional, Union
from typing import Optional
import click
import yaml
from pydantic import BaseModel, ValidationError
from datahub.cli.env_utils import get_boolean_env_variable
@ -22,82 +20,20 @@ DATAHUB_ROOT_FOLDER = os.path.expanduser("~/.datahub")
ENV_SKIP_CONFIG = "DATAHUB_SKIP_CONFIG"
class GmsConfig(BaseModel):
server: str
token: Optional[str] = None
class DatahubConfig(BaseModel):
gms: GmsConfig
def persist_datahub_config(config: dict) -> None:
with open(DATAHUB_CONFIG_PATH, "w+") as outfile:
yaml.dump(config, outfile, default_flow_style=False)
return None
def write_gms_config(
host: str, token: Optional[str], merge_with_previous: bool = True
) -> None:
config = DatahubConfig(gms=GmsConfig(server=host, token=token))
if merge_with_previous:
try:
previous_config = get_client_config(as_dict=True)
assert isinstance(previous_config, dict)
except Exception as e:
# ok to fail on this
previous_config = {}
log.debug(
f"Failed to retrieve config from file {DATAHUB_CONFIG_PATH}: {e}. This isn't fatal."
)
config_dict = {**previous_config, **config.dict()}
else:
config_dict = config.dict()
persist_datahub_config(config_dict)
def get_details_from_config():
datahub_config = get_client_config(as_dict=False)
assert isinstance(datahub_config, DatahubConfig)
if datahub_config is not None:
gms_config = datahub_config.gms
gms_host = gms_config.server
gms_token = gms_config.token
return gms_host, gms_token
else:
return None, None
def should_skip_config() -> bool:
return get_boolean_env_variable(ENV_SKIP_CONFIG, False)
def ensure_datahub_config() -> None:
if not os.path.isfile(DATAHUB_CONFIG_PATH):
click.secho(
f"No {CONDENSED_DATAHUB_CONFIG_PATH} file found, generating one for you...",
bold=True,
)
write_gms_config(DEFAULT_GMS_HOST, None)
def get_client_config(as_dict: bool = False) -> Union[Optional[DatahubConfig], dict]:
def get_client_config() -> Optional[dict]:
with open(DATAHUB_CONFIG_PATH) as stream:
try:
config_json = yaml.safe_load(stream)
if as_dict:
return config_json
try:
datahub_config = DatahubConfig.parse_obj(config_json)
return datahub_config
except ValidationError as e:
click.echo(
f"Received error, please check your {CONDENSED_DATAHUB_CONFIG_PATH}"
)
click.echo(e, err=True)
sys.exit(1)
return yaml.safe_load(stream)
except yaml.YAMLError as exc:
click.secho(f"{DATAHUB_CONFIG_PATH} malformed, error: {exc}", bold=True)
return None

View File

@ -123,6 +123,8 @@ def by_registry(
Delete all metadata written using the given registry id and version pair.
"""
client = get_default_graph()
if soft and not dry_run:
raise click.UsageError(
"Soft-deleting with a registry-id is not yet supported. Try --dry-run to see what you will be deleting, before issuing a hard-delete using the --hard flag"
@ -138,7 +140,10 @@ def by_registry(
unsafe_entity_count,
unsafe_entities,
) = cli_utils.post_rollback_endpoint(
registry_delete, "/entities?action=deleteAll"
client._session,
client.config.server,
registry_delete,
"/entities?action=deleteAll",
)
if not dry_run:

View File

@ -6,6 +6,7 @@ import click
from click_default_group import DefaultGroup
from datahub.cli.cli_utils import get_aspects_for_entity
from datahub.ingestion.graph.client import get_default_graph
from datahub.telemetry import telemetry
from datahub.upgrade import upgrade
@ -44,10 +45,17 @@ def urn(ctx: Any, urn: Optional[str], aspect: List[str], details: bool) -> None:
raise click.UsageError("Nothing for me to get. Maybe provide an urn?")
urn = ctx.args[0]
logger.debug(f"Using urn from args {urn}")
client = get_default_graph()
click.echo(
json.dumps(
get_aspects_for_entity(
entity_urn=urn, aspects=aspect, typed=False, details=details
session=client._session,
gms_host=client.config.server,
entity_urn=urn,
aspects=aspect,
typed=False,
),
sort_keys=True,
indent=2,

View File

@ -427,7 +427,9 @@ def mcps(path: str) -> None:
def list_runs(page_offset: int, page_size: int, include_soft_deletes: bool) -> None:
"""List recent ingestion runs to datahub"""
session, gms_host = cli_utils.get_session_and_host()
client = get_default_graph()
session = client._session
gms_host = client.config.server
url = f"{gms_host}/runs?action=list"
@ -476,7 +478,9 @@ def show(
run_id: str, start: int, count: int, include_soft_deletes: bool, show_aspect: bool
) -> None:
"""Describe a provided ingestion run to datahub"""
session, gms_host = cli_utils.get_session_and_host()
client = get_default_graph()
session = client._session
gms_host = client.config.server
url = f"{gms_host}/runs?action=describe"
@ -524,8 +528,7 @@ def rollback(
run_id: str, force: bool, dry_run: bool, safe: bool, report_dir: str
) -> None:
"""Rollback a provided ingestion run to datahub"""
cli_utils.test_connectivity_complain_exit("ingest")
client = get_default_graph()
if not force and not dry_run:
click.confirm(
@ -541,7 +544,9 @@ def rollback(
aspects_affected,
unsafe_entity_count,
unsafe_entities,
) = cli_utils.post_rollback_endpoint(payload_obj, "/runs?action=rollback")
) = cli_utils.post_rollback_endpoint(
client._session, client.config.server, payload_obj, "/runs?action=rollback"
)
click.echo(
"Rolling back deletes the entities created by a run and reverts the updated aspects"

View File

@ -11,12 +11,12 @@ from click_default_group import DefaultGroup
from datahub.cli.config_utils import (
DATAHUB_ROOT_FOLDER,
DatahubConfig,
get_client_config,
persist_datahub_config,
)
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.sink import NoopWriteCallback
from datahub.ingestion.graph.client import DatahubConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.sink.file import FileSink, FileSinkConfig
from datahub.lite.duckdb_lite_config import DuckDBLiteConfig
@ -45,7 +45,7 @@ class LiteCliConfig(DatahubConfig):
def get_lite_config() -> LiteLocalConfig:
client_config_dict = get_client_config(as_dict=True)
client_config_dict = get_client_config()
lite_config = LiteCliConfig.parse_obj(client_config_dict)
return lite_config.lite
@ -309,7 +309,7 @@ def search(
def write_lite_config(lite_config: LiteLocalConfig) -> None:
cli_config = get_client_config(as_dict=True)
cli_config = get_client_config()
assert isinstance(cli_config, dict)
cli_config["lite"] = lite_config.dict()
persist_datahub_config(cli_config)

View File

@ -1,7 +1,8 @@
import json
import logging
import random
import uuid
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, Iterable, List, Tuple, Union
import click
import progressbar
@ -23,7 +24,11 @@ from datahub.emitter.mcp_builder import (
SchemaKey,
)
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
from datahub.ingestion.graph.client import (
DataHubGraph,
RelatedEntity,
get_default_graph,
)
from datahub.metadata.schema_classes import (
ContainerKeyClass,
ContainerPropertiesClass,
@ -31,6 +36,7 @@ from datahub.metadata.schema_classes import (
SystemMetadataClass,
)
from datahub.telemetry import telemetry
from datahub.utilities.urns.urn import Urn
log = logging.getLogger(__name__)
@ -143,15 +149,17 @@ def dataplatform2instance_func(
graph = get_default_graph()
urns_to_migrate = []
urns_to_migrate: List[str] = []
# we first calculate all the urns we will be migrating
for src_entity_urn in cli_utils.get_urns_by_filter(platform=platform, env=env):
for src_entity_urn in graph.get_urns_by_filter(platform=platform, env=env):
key = dataset_urn_to_key(src_entity_urn)
assert key
# Does this urn already have a platform instance associated with it?
response = cli_utils.get_aspects_for_entity(
entity_urn=src_entity_urn, aspects=["dataPlatformInstance"], typed=True
response = graph.get_aspects_for_entity(
entity_urn=src_entity_urn,
aspects=["dataPlatformInstance"],
aspect_types=[DataPlatformInstanceClass],
)
if "dataPlatformInstance" in response:
assert isinstance(
@ -229,14 +237,14 @@ def dataplatform2instance_func(
migration_report.on_entity_create(new_urn, "dataPlatformInstance")
for relationship in relationships:
target_urn = relationship["entity"]
target_urn = relationship.urn
entity_type = _get_type_from_urn(target_urn)
relationshipType = relationship["type"]
relationshipType = relationship.relationship_type
aspect_name = migration_utils.get_aspect_name_from_relationship(
relationshipType, entity_type
)
aspect_map = cli_utils.get_aspects_for_entity(
target_urn, aspects=[aspect_name], typed=True
graph._session, graph.config.server, target_urn, aspects=[aspect_name]
)
if aspect_name in aspect_map:
aspect = aspect_map[aspect_name]
@ -378,13 +386,16 @@ def migrate_containers(
def get_containers_for_migration(env: str) -> List[Any]:
containers_to_migrate = list(cli_utils.get_container_ids_by_filter(env=env))
client = get_default_graph()
containers_to_migrate = list(
client.get_urns_by_filter(entity_types=["container"], env=env)
)
containers = []
increment = 20
for i in range(0, len(containers_to_migrate), increment):
for container in cli_utils.batch_get_ids(
containers_to_migrate[i : i + increment]
for container in batch_get_ids(
client, containers_to_migrate[i : i + increment]
):
log.debug(container)
containers.append(container)
@ -392,6 +403,37 @@ def get_containers_for_migration(env: str) -> List[Any]:
return containers
def batch_get_ids(
client: DataHubGraph,
ids: List[str],
) -> Iterable[Dict]:
session = client._session
gms_host = client.config.server
endpoint: str = "/entitiesV2"
url = gms_host + endpoint
ids_to_get = [Urn.url_encode(id) for id in ids]
response = session.get(
f"{url}?ids=List({','.join(ids_to_get)})",
)
if response.status_code == 200:
assert response._content
log.debug(response._content)
results = json.loads(response._content)
num_entities = len(results["results"])
entities_yielded: int = 0
for x in results["results"].values():
entities_yielded += 1
log.debug(f"yielding {x}")
yield x
assert (
entities_yielded == num_entities
), "Did not delete all entities, try running this command again!"
else:
log.error(f"Failed to execute batch get with {str(response.content)}")
response.raise_for_status()
def process_container_relationships(
container_id_map: Dict[str, str],
dry_run: bool,
@ -400,22 +442,29 @@ def process_container_relationships(
migration_report: MigrationReport,
rest_emitter: DatahubRestEmitter,
) -> None:
relationships = migration_utils.get_incoming_relationships(urn=src_urn)
relationships: Iterable[RelatedEntity] = migration_utils.get_incoming_relationships(
urn=src_urn
)
client = get_default_graph()
for relationship in relationships:
log.debug(f"Incoming Relationship: {relationship}")
target_urn = relationship["entity"]
target_urn: str = relationship.urn
# We should use the new id if we already migrated it
if target_urn in container_id_map:
target_urn = container_id_map.get(target_urn)
target_urn = container_id_map[target_urn]
entity_type = _get_type_from_urn(target_urn)
relationshipType = relationship["type"]
relationshipType = relationship.relationship_type
aspect_name = migration_utils.get_aspect_name_from_relationship(
relationshipType, entity_type
)
aspect_map = cli_utils.get_aspects_for_entity(
target_urn, aspects=[aspect_name], typed=True
client._session,
client.config.server,
target_urn,
aspects=[aspect_name],
typed=True,
)
if aspect_name in aspect_map:
aspect = aspect_map[aspect_name]

View File

@ -1,12 +1,17 @@
import logging
import uuid
from typing import Dict, Iterable, List
from typing import Iterable, List
from avrogen.dict_wrapper import DictWrapper
from datahub.cli import cli_utils
from datahub.emitter.mce_builder import Aspect
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import (
DataHubGraph,
RelatedEntity,
get_default_graph,
)
from datahub.metadata.schema_classes import (
ChartInfoClass,
ContainerClass,
@ -238,8 +243,13 @@ def clone_aspect(
run_id: str = str(uuid.uuid4()),
dry_run: bool = False,
) -> Iterable[MetadataChangeProposalWrapper]:
client = get_default_graph()
aspect_map = cli_utils.get_aspects_for_entity(
entity_urn=src_urn, aspects=aspect_names, typed=True
client._session,
client.config.server,
entity_urn=src_urn,
aspects=aspect_names,
typed=True,
)
if aspect_names is not None:
@ -263,10 +273,11 @@ def clone_aspect(
log.debug(f"did not find aspect {a} in response, continuing...")
def get_incoming_relationships(urn: str) -> Iterable[Dict]:
yield from cli_utils.get_incoming_relationships(
urn,
types=[
def get_incoming_relationships(urn: str) -> Iterable[RelatedEntity]:
client = get_default_graph()
yield from client.get_related_entities(
entity_urn=urn,
relationship_types=[
"DownstreamOf",
"Consumes",
"Produces",
@ -274,13 +285,15 @@ def get_incoming_relationships(urn: str) -> Iterable[Dict]:
"DerivedFrom",
"IsPartOf",
],
direction=DataHubGraph.RelationshipDirection.INCOMING,
)
def get_outgoing_relationships(urn: str) -> Iterable[Dict]:
yield from cli_utils.get_outgoing_relationships(
urn,
types=[
def get_outgoing_relationships(urn: str) -> Iterable[RelatedEntity]:
client = get_default_graph()
yield from client.get_related_entities(
entity_urn=urn,
relationship_types=[
"DownstreamOf",
"Consumes",
"Produces",
@ -288,4 +301,5 @@ def get_outgoing_relationships(urn: str) -> Iterable[Dict]:
"DerivedFrom",
"IsPartOf",
],
direction=DataHubGraph.RelationshipDirection.OUTGOING,
)

View File

@ -46,7 +46,12 @@ def aspect(urn: str, aspect: str, aspect_data: str) -> None:
aspect_data, allow_stdin=True, resolve_env_vars=False, process_directives=False
)
client = get_default_graph()
# TODO: Replace with client.emit, requires figuring out the correct subclass of _Aspect to create from the data
status = post_entity(
client._session,
client.config.server,
urn=urn,
aspect_name=aspect,
entity_type=entity_type,

View File

@ -8,8 +8,8 @@ import click
from requests import Response
from termcolor import colored
import datahub.cli.cli_utils
from datahub.emitter.mce_builder import dataset_urn_to_key, schema_field_urn_to_key
from datahub.ingestion.graph.client import get_default_graph
from datahub.telemetry import telemetry
from datahub.upgrade import upgrade
from datahub.utilities.urns.urn import Urn
@ -63,7 +63,9 @@ def get_timeline(
end_time: Optional[int],
diff: bool,
) -> Any:
session, host = datahub.cli.cli_utils.get_session_and_host()
client = get_default_graph()
session = client._session
host = client.config.server
if urn.startswith("urn%3A"):
# we assume the urn is already encoded
encoded_urn: str = urn

View File

@ -13,11 +13,7 @@ from datahub.cli.cli_utils import (
generate_access_token,
make_shim_command,
)
from datahub.cli.config_utils import (
DATAHUB_CONFIG_PATH,
get_boolean_env_variable,
write_gms_config,
)
from datahub.cli.config_utils import DATAHUB_CONFIG_PATH, get_boolean_env_variable
from datahub.cli.delete_cli import delete
from datahub.cli.docker_cli import docker
from datahub.cli.exists_cli import exists
@ -37,7 +33,7 @@ from datahub.cli.state_cli import state
from datahub.cli.telemetry import telemetry as telemetry_cli
from datahub.cli.timeline_cli import timeline
from datahub.configuration.common import should_show_stack_trace
from datahub.ingestion.graph.client import get_default_graph
from datahub.ingestion.graph.client import get_default_graph, write_gms_config
from datahub.telemetry import telemetry
from datahub.utilities._custom_package_loader import model_version_name
from datahub.utilities.logging_manager import configure_logging

View File

@ -3,6 +3,8 @@ import enum
import functools
import json
import logging
import os
import sys
import textwrap
import time
from dataclasses import dataclass
@ -22,12 +24,13 @@ from typing import (
Union,
)
import click
from avro.schema import RecordSchema
from deprecated import deprecated
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from requests.models import HTTPError
from datahub.cli.cli_utils import get_url_and_token
from datahub.cli import config_utils
from datahub.configuration.common import ConfigModel, GraphError, OperationalError
from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP
from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect
@ -87,6 +90,12 @@ logger = logging.getLogger(__name__)
_MISSING_SERVER_ID = "missing"
_GRAPH_DUMMY_RUN_ID = "__datahub-graph-client"
ENV_METADATA_HOST_URL = "DATAHUB_GMS_URL"
ENV_METADATA_TOKEN = "DATAHUB_GMS_TOKEN"
ENV_METADATA_HOST = "DATAHUB_GMS_HOST"
ENV_METADATA_PORT = "DATAHUB_GMS_PORT"
ENV_METADATA_PROTOCOL = "DATAHUB_GMS_PROTOCOL"
class DatahubClientConfig(ConfigModel):
"""Configuration class for holding connectivity to datahub gms"""
@ -583,6 +592,9 @@ class DataHubGraph(DatahubRestEmitter):
def _aspect_count_endpoint(self):
return f"{self.config.server}/aspects?action=getCount"
# def _session(self) -> Session:
# return super()._session
def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]:
"""Retrieve a domain urn based on its name. Returns None if there is no match found"""
@ -1763,7 +1775,88 @@ class DataHubGraph(DatahubRestEmitter):
def get_default_graph() -> DataHubGraph:
(url, token) = get_url_and_token()
graph = DataHubGraph(DatahubClientConfig(server=url, token=token))
graph_config = load_client_config()
graph = DataHubGraph(graph_config)
graph.test_connection()
return graph
class DatahubConfig(BaseModel):
gms: DatahubClientConfig
config_override: Dict = {}
def get_details_from_env() -> Tuple[Optional[str], Optional[str]]:
host = os.environ.get(ENV_METADATA_HOST)
port = os.environ.get(ENV_METADATA_PORT)
token = os.environ.get(ENV_METADATA_TOKEN)
protocol = os.environ.get(ENV_METADATA_PROTOCOL, "http")
url = os.environ.get(ENV_METADATA_HOST_URL)
if port is not None:
url = f"{protocol}://{host}:{port}"
return url, token
# The reason for using host as URL is backward compatibility
# If port is not being used we assume someone is using host env var as URL
if url is None and host is not None:
logger.warning(
f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead"
)
return url or host, token
def load_client_config() -> DatahubClientConfig:
try:
ensure_datahub_config()
client_config_dict = config_utils.get_client_config()
datahub_config: DatahubClientConfig = DatahubConfig.parse_obj(
client_config_dict
).gms
except ValidationError as e:
click.echo(
f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}"
)
click.echo(e, err=True)
sys.exit(1)
# Override gms & token configs if specified.
if len(config_override.keys()) > 0:
datahub_config.server = str(config_override.get(ENV_METADATA_HOST_URL))
datahub_config.token = config_override.get(ENV_METADATA_TOKEN)
elif config_utils.should_skip_config():
gms_host_env, gms_token_env = get_details_from_env()
if gms_host_env:
datahub_config.server = gms_host_env
datahub_config.token = gms_token_env
return datahub_config
def ensure_datahub_config() -> None:
if not os.path.isfile(config_utils.DATAHUB_CONFIG_PATH):
click.secho(
f"No {config_utils.CONDENSED_DATAHUB_CONFIG_PATH} file found, generating one for you...",
bold=True,
)
write_gms_config(config_utils.DEFAULT_GMS_HOST, None)
def write_gms_config(
host: str, token: Optional[str], merge_with_previous: bool = True
) -> None:
config = DatahubConfig(gms=DatahubClientConfig(server=host, token=token))
if merge_with_previous:
try:
previous_config = config_utils.get_client_config()
assert isinstance(previous_config, dict)
except Exception as e:
# ok to fail on this
previous_config = {}
logger.debug(
f"Failed to retrieve config from file {config_utils.DATAHUB_CONFIG_PATH}: {e}. This isn't fatal."
)
config_dict = {**previous_config, **config.dict()}
else:
config_dict = config.dict()
config_utils.persist_datahub_config(config_dict)
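
Put together, CLI code paths now resolve a single DatahubClientConfig rather than a (url, token) pair. A rough sketch of the resulting flow, assuming the functions introduced above:

```python
from datahub.ingestion.graph.client import get_default_graph, load_client_config

# load_client_config() reads ~/.datahubenv (creating it with defaults if missing),
# then applies config_override and, when DATAHUB_SKIP_CONFIG is set, the
# DATAHUB_GMS_URL / DATAHUB_GMS_TOKEN environment variables.
config = load_client_config()
print(config.server, config.token is not None)

# get_default_graph() builds a DataHubGraph from that same config, so options such as
# disable_ssl_verification apply to every CLI call.
graph = get_default_graph()
session, gms_host = graph._session, graph.config.server  # what the CLI helpers now receive
```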

View File

@ -1,12 +1,14 @@
import datetime
import logging
import os
import uuid
from typing import Any, Dict, List, Optional
from pydantic import Field, validator
from pydantic import Field, root_validator, validator
from datahub.configuration import config_loader
from datahub.configuration.common import ConfigModel, DynamicTypedConfig
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.ingestion.graph.client import DatahubClientConfig, load_client_config
from datahub.ingestion.sink.file import FileSinkConfig
logger = logging.getLogger(__name__)
@ -101,6 +103,34 @@ class PipelineConfig(ConfigModel):
assert v is not None
return v
@root_validator(pre=True)
def default_sink_is_datahub_rest(cls, values: Dict[str, Any]) -> Any:
if "sink" not in values:
config = load_client_config()
# update this
default_sink_config = {
"type": "datahub-rest",
"config": config.dict(exclude_defaults=True),
}
# resolve env variables if present
default_sink_config = config_loader.resolve_env_variables(
default_sink_config, environ=os.environ
)
values["sink"] = default_sink_config
return values
@validator("datahub_api", always=True)
def datahub_api_should_use_rest_sink_as_default(
cls, v: Optional[DatahubClientConfig], values: Dict[str, Any], **kwargs: Any
) -> Optional[DatahubClientConfig]:
if v is None and "sink" in values and hasattr(values["sink"], "type"):
sink_type = values["sink"].type
if sink_type == "datahub-rest":
sink_config = values["sink"].config
v = DatahubClientConfig.parse_obj_allow_extras(sink_config)
return v
@classmethod
def from_dict(
cls, resolved_dict: dict, raw_dict: Optional[dict] = None
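
The new default_sink_is_datahub_rest root validator above means an ingestion recipe that omits the sink section now inherits the resolved client config as a datahub-rest sink. A hedged sketch (the demo source and its config are placeholders):

```python
from datahub.ingestion.run.pipeline import Pipeline

# No "sink" block: the root validator injects
#   {"type": "datahub-rest", "config": load_client_config().dict(exclude_defaults=True)}
# so the run targets whatever server/token/SSL settings ~/.datahubenv resolves to.
pipeline = Pipeline.create(
    {
        "source": {"type": "demo-data", "config": {}},  # placeholder source
    }
)
pipeline.run()
pipeline.pretty_print_summary()
```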

View File

@ -9,7 +9,6 @@ import uuid
from enum import auto
from typing import List, Optional, Tuple, Union
from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
ConfigEnum,
ConfigurationError,
@ -120,7 +119,6 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
)
self.report.max_threads = self.config.max_threads
logger.debug("Setting env variables to override config")
set_env_variables_override_config(self.config.server, self.config.token)
logger.debug("Setting gms config")
set_gms_config(gms_config)

View File

@ -35,6 +35,7 @@ from datahub.ingestion.api.source_helpers import (
auto_workunit_reporter,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import get_default_graph
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
FineGrainedLineageDownstreamType,
FineGrainedLineageUpstreamType,
@ -209,7 +210,12 @@ def _get_lineage_mcp(
# extract the old lineage and save it for the new mcp
if preserve_upstream:
client = get_default_graph()
old_upstream_lineage = get_aspects_for_entity(
client._session,
client.config.server,
entity_urn=entity_urn,
aspects=["upstreamLineage"],
typed=True,

View File

@ -12,8 +12,7 @@ from pydantic import BaseModel
from termcolor import colored
from datahub import __version__
from datahub.cli import cli_utils
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.graph.client import DataHubGraph, load_client_config
log = logging.getLogger(__name__)
@ -101,16 +100,18 @@ async def get_github_stats():
return (latest_server_version, latest_server_date)
async def get_server_config(gms_url: str, token: str) -> dict:
async def get_server_config(gms_url: str, token: Optional[str]) -> dict:
import aiohttp
async with aiohttp.ClientSession(
headers={
"X-RestLi-Protocol-Version": "2.0.0",
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
}
) as session:
headers = {
"X-RestLi-Protocol-Version": "2.0.0",
"Content-Type": "application/json",
}
if token:
headers["Authorization"] = f"Bearer {token}"
async with aiohttp.ClientSession() as session:
config_endpoint = f"{gms_url}/config"
async with session.get(config_endpoint) as dh_response:
dh_response_json = await dh_response.json()
@ -126,7 +127,9 @@ async def get_server_version_stats(
if not server:
try:
# let's get the server from the cli config
host, token = cli_utils.get_url_and_token()
client_config = load_client_config()
host = client_config.server
token = client_config.token
server_config = await get_server_config(host, token)
log.debug(f"server_config:{server_config}")
except Exception as e:

View File

@ -2,6 +2,7 @@ import os
from unittest import mock
from datahub.cli import cli_utils
from datahub.ingestion.graph.client import get_details_from_env
def test_first_non_null():
@ -16,14 +17,14 @@ def test_first_non_null():
@mock.patch.dict(os.environ, {"DATAHUB_GMS_HOST": "http://localhost:9092"})
def test_correct_url_when_gms_host_in_old_format():
assert cli_utils.get_details_from_env() == ("http://localhost:9092", None)
assert get_details_from_env() == ("http://localhost:9092", None)
@mock.patch.dict(
os.environ, {"DATAHUB_GMS_HOST": "localhost", "DATAHUB_GMS_PORT": "8080"}
)
def test_correct_url_when_gms_host_and_port_set():
assert cli_utils.get_details_from_env() == ("http://localhost:8080", None)
assert get_details_from_env() == ("http://localhost:8080", None)
@mock.patch.dict(
@ -35,7 +36,7 @@ def test_correct_url_when_gms_host_and_port_set():
},
)
def test_correct_url_when_gms_host_port_url_set():
assert cli_utils.get_details_from_env() == ("http://localhost:8080", None)
assert get_details_from_env() == ("http://localhost:8080", None)
@mock.patch.dict(
@ -48,7 +49,7 @@ def test_correct_url_when_gms_host_port_url_set():
},
)
def test_correct_url_when_gms_host_port_url_protocol_set():
assert cli_utils.get_details_from_env() == ("https://localhost:8080", None)
assert get_details_from_env() == ("https://localhost:8080", None)
@mock.patch.dict(
@ -58,7 +59,7 @@ def test_correct_url_when_gms_host_port_url_protocol_set():
},
)
def test_correct_url_when_url_set():
assert cli_utils.get_details_from_env() == ("https://example.com", None)
assert get_details_from_env() == ("https://example.com", None)
def test_fixup_gms_url():

View File

@ -1,7 +1,11 @@
import json
import pytest
from datahub.cli.cli_utils import get_aspects_for_entity, get_session_and_host
from datahub.ingestion.graph.client import get_default_graph
from datahub.metadata.schema_classes import (
BrowsePathsV2Class,
EditableDatasetPropertiesClass,
)
from tests.utils import ingest_file_via_rest, wait_for_writes_to_sync
@ -22,23 +26,19 @@ def test_setup():
env = "PROD"
dataset_urn = f"urn:li:dataset:({platform},{dataset_name},{env})"
session, gms_host = get_session_and_host()
client = get_default_graph()
session = client._session
gms_host = client.config.server
assert "browsePathsV2" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False
)
assert "editableDatasetProperties" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
)
assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None
assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is None
ingested_dataset_run_id = ingest_file_via_rest(
"tests/cli/cli_test_data.json"
).config.run_id
print("Setup ingestion id: " + ingested_dataset_run_id)
assert "browsePathsV2" in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False
)
assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is not None
yield
@ -58,12 +58,8 @@ def test_setup():
),
)
assert "browsePathsV2" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False
)
assert "editableDatasetProperties" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
)
assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None
assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is None
@pytest.mark.dependency()
@ -75,13 +71,14 @@ def test_rollback_editable():
env = "PROD"
dataset_urn = f"urn:li:dataset:({platform},{dataset_name},{env})"
session, gms_host = get_session_and_host()
client = get_default_graph()
session = client._session
gms_host = client.config.server
print("Ingested dataset id:", ingested_dataset_run_id)
# Assert that second data ingestion worked
assert "browsePathsV2" in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False
)
assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is not None
# Make editable change
ingested_editable_run_id = ingest_file_via_rest(
@ -89,9 +86,8 @@ def test_rollback_editable():
).config.run_id
print("ingested editable id:", ingested_editable_run_id)
# Assert that second data ingestion worked
assert "editableDatasetProperties" in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
)
assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is not None
# rollback ingestion 1
rollback_url = f"{gms_host}/runs?action=rollback"
@ -107,10 +103,7 @@ def test_rollback_editable():
wait_for_writes_to_sync()
# EditableDatasetProperties should still be part of the entity that was soft deleted.
assert "editableDatasetProperties" in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
)
assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is not None
# But first ingestion aspects should not be present
assert "browsePathsV2" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False
)
assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None

View File

@ -1,3 +1,5 @@
from typing import Optional
import pytest
import tenacity
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
@ -36,9 +38,9 @@ def test_healthchecks(wait_for_healthchecks):
@pytest.mark.dependency(depends=["test_healthchecks"])
def test_get_aspect_v2(frontend_session, ingest_cleanup_data):
graph: DataHubGraph = DataHubGraph(DatahubClientConfig(server=get_gms_url()))
client: DataHubGraph = DataHubGraph(DatahubClientConfig(server=get_gms_url()))
urn = "urn:li:dataset:(urn:li:dataPlatform:kafka,test-rollback,PROD)"
schema_metadata: SchemaMetadataClass = graph.get_aspect_v2(
schema_metadata: Optional[SchemaMetadataClass] = client.get_aspect_v2(
urn, aspect="schemaMetadata", aspect_type=SchemaMetadataClass
)

View File

@ -2,7 +2,7 @@ import json
import os
import pytest
from datahub.cli.cli_utils import get_aspects_for_entity, get_session_and_host
from datahub.cli.cli_utils import get_aspects_for_entity
from tests.utils import (
delete_urns_from_file,
@ -38,14 +38,24 @@ def test_setup():
env = "PROD"
dataset_urn = f"urn:li:dataset:({platform},{dataset_name},{env})"
session, gms_host = get_session_and_host()
client = get_datahub_graph()
session = client._session
gms_host = client.config.server
try:
assert "institutionalMemory" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False
session,
gms_host,
entity_urn=dataset_urn,
aspects=["institutionalMemory"],
typed=False,
)
assert "editableDatasetProperties" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
session,
gms_host,
entity_urn=dataset_urn,
aspects=["editableDatasetProperties"],
typed=False,
)
except Exception as e:
delete_urns_from_file("tests/delete/cli_test_data.json")
@ -56,7 +66,11 @@ def test_setup():
).config.run_id
assert "institutionalMemory" in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False
session,
gms_host,
entity_urn=dataset_urn,
aspects=["institutionalMemory"],
typed=False,
)
yield
@ -71,10 +85,18 @@ def test_setup():
wait_for_writes_to_sync()
assert "institutionalMemory" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False
session,
gms_host,
entity_urn=dataset_urn,
aspects=["institutionalMemory"],
typed=False,
)
assert "editableDatasetProperties" not in get_aspects_for_entity(
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False
session,
gms_host,
entity_urn=dataset_urn,
aspects=["editableDatasetProperties"],
typed=False,
)

View File

@ -6,10 +6,8 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import datahub.emitter.mce_builder as builder
import networkx as nx
import pytest
from datahub.cli.cli_utils import get_url_and_token
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph # get_default_graph,
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
from datahub.metadata.schema_classes import (
AuditStampClass,
ChangeAuditStampsClass,
@ -847,10 +845,7 @@ def test_lineage_via_node(
)
# Create an emitter to the GMS REST API.
(url, token) = get_url_and_token()
with DataHubGraph(
DatahubClientConfig(server=url, token=token, retry_max_times=0)
) as graph:
with get_default_graph() as graph:
emitter = graph
# emitter = DataHubConsoleEmitter()
@ -891,14 +886,11 @@ def destination_urn_fixture():
def ingest_multipath_metadata(
chart_urn_fixture, intermediates_fixture, destination_urn_fixture
):
(url, token) = get_url_and_token()
fake_auditstamp = AuditStampClass(
time=int(time.time() * 1000),
actor="urn:li:corpuser:datahub",
)
with DataHubGraph(
DatahubClientConfig(server=url, token=token, retry_max_times=0)
) as graph:
with get_default_graph() as graph:
chart_urn = chart_urn_fixture
intermediates = intermediates_fixture
destination_urn = destination_urn_fixture

View File

@ -3,7 +3,7 @@ from typing import Dict, Optional
from datahub.emitter.mce_builder import make_dataset_urn, make_tag_urn, make_term_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph, DataHubGraphConfig
from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
from datahub.metadata.schema_classes import (
DatasetLineageTypeClass,
DatasetPropertiesClass,
@ -72,13 +72,16 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks):
)
mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=upstream_lineage)
with DataHubGraph(DataHubGraphConfig()) as graph:
with get_default_graph() as graph:
graph.emit_mcp(mcpw)
upstream_lineage_read = graph.get_aspect_v2(
entity_urn=dataset_urn,
aspect_type=UpstreamLineageClass,
aspect="upstreamLineage",
)
assert upstream_lineage_read is not None
assert len(upstream_lineage_read.upstreams) > 0
assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn
for patch_mcp in (
@ -94,6 +97,8 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks):
aspect_type=UpstreamLineageClass,
aspect="upstreamLineage",
)
assert upstream_lineage_read is not None
assert len(upstream_lineage_read.upstreams) == 2
assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn
assert upstream_lineage_read.upstreams[1].dataset == patch_dataset_urn
@ -111,6 +116,8 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks):
aspect_type=UpstreamLineageClass,
aspect="upstreamLineage",
)
assert upstream_lineage_read is not None
assert len(upstream_lineage_read.upstreams) == 1
assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn
@ -148,7 +155,7 @@ def test_field_terms_patch(wait_for_healthchecks):
)
mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=editable_field)
with DataHubGraph(DataHubGraphConfig()) as graph:
with get_default_graph() as graph:
graph.emit_mcp(mcpw)
field_info = get_field_info(graph, dataset_urn, field_path)
assert field_info
@ -209,7 +216,7 @@ def test_field_tags_patch(wait_for_healthchecks):
)
mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=editable_field)
with DataHubGraph(DataHubGraphConfig()) as graph:
with get_default_graph() as graph:
graph.emit_mcp(mcpw)
field_info = get_field_info(graph, dataset_urn, field_path)
assert field_info
@ -299,7 +306,7 @@ def test_custom_properties_patch(wait_for_healthchecks):
base_aspect=orig_dataset_properties,
)
with DataHubGraph(DataHubGraphConfig()) as graph:
with get_default_graph() as graph:
# Patch custom properties along with name
for patch_mcp in (
DatasetPatchBuilder(dataset_urn)

View File

@ -1,6 +1,7 @@
import json
from datahub.cli.cli_utils import get_aspects_for_entity
from datahub.ingestion.graph.client import get_default_graph
def test_no_client_id():
@ -9,8 +10,16 @@ def test_no_client_id():
"clientId"
] # this is checking for the removal of the invalid aspect RemoveClientIdAspectStep.java
client = get_default_graph()
res_data = json.dumps(
get_aspects_for_entity(entity_urn=client_id_urn, aspects=aspect, typed=False)
get_aspects_for_entity(
session=client._session,
gms_host=client.config.server,
entity_urn=client_id_urn,
aspects=aspect,
typed=False,
)
)
assert res_data == "{}"
@ -19,7 +28,15 @@ def test_no_telemetry_client_id():
client_id_urn = "urn:li:telemetry:clientId"
aspect = ["telemetryClientId"] # telemetry expected to be disabled for tests
client = get_default_graph()
res_data = json.dumps(
get_aspects_for_entity(entity_urn=client_id_urn, aspects=aspect, typed=False)
get_aspects_for_entity(
session=client._session,
gms_host=client.config.server,
entity_urn=client_id_urn,
aspects=aspect,
typed=False,
)
)
assert res_data == "{}"

View File

@ -179,11 +179,13 @@ def test_ownership():
def put(urn: str, aspect: str, aspect_data: str) -> None:
"""Update a single aspect of an entity"""
client = get_datahub_graph()
entity_type = guess_entity_type(urn)
with open(aspect_data) as fp:
aspect_obj = json.load(fp)
post_entity(
session=client._session,
gms_host=client.config.server,
urn=urn,
aspect_name=aspect,
entity_type=entity_type,

View File

@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Tuple
from datahub.cli import cli_utils, env_utils
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
from datahub.ingestion.run.pipeline import Pipeline
from joblib import Parallel, delayed
@ -120,7 +120,7 @@ def ingest_file_via_rest(filename: str) -> Pipeline:
@functools.lru_cache(maxsize=1)
def get_datahub_graph() -> DataHubGraph:
return DataHubGraph(DatahubClientConfig(server=get_gms_url()))
return get_default_graph()
def delete_urn(urn: str) -> None: