feat(sdk): support urns in other urn constructors (#12311)

This commit is contained in:
Harshal Sheth 2025-01-10 09:45:31 -08:00 committed by GitHub
parent 208447d20e
commit 5f63f3fba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 14 deletions

View File

@ -346,7 +346,7 @@ def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None:
code = """
# This file contains classes corresponding to entity URNs.
from typing import ClassVar, List, Optional, Type, TYPE_CHECKING
from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union
import functools
from deprecated.sphinx import deprecated as _sphinx_deprecated
@ -547,10 +547,31 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
assert fields[0]["type"] == ["null", "string"]
fields[0]["type"] = "string"
field_urn_type_classes = {}
for field in fields:
# Figure out if urn types are valid for each field.
field_urn_type_class = None
if field_name(field) == "platform":
field_urn_type_class = "DataPlatformUrn"
elif field.get("Urn"):
if len(field.get("entityTypes", [])) == 1:
field_entity_type = field["entityTypes"][0]
field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn"
else:
field_urn_type_class = "Urn"
field_urn_type_classes[field_name(field)] = field_urn_type_class
_init_arg_parts: List[str] = []
for field in fields:
field_urn_type_class = field_urn_type_classes[field_name(field)]
default = '"PROD"' if field_name(field) == "env" else None
_arg_part = f"{field_name(field)}: {field_type(field)}"
type_hint = field_type(field)
if field_urn_type_class:
type_hint = f'Union["{field_urn_type_class}", str]'
_arg_part = f"{field_name(field)}: {type_hint}"
if default:
_arg_part += f" = {default}"
_init_arg_parts.append(_arg_part)
@ -579,16 +600,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
init_validation += f'if not {field_name(field)}:\n raise InvalidUrnError("{class_name} {field_name(field)} cannot be empty")\n'
# Generalized mechanism for validating embedded urns.
field_urn_type_class = None
if field_name(field) == "platform":
field_urn_type_class = "DataPlatformUrn"
elif field.get("Urn"):
if len(field.get("entityTypes", [])) == 1:
field_entity_type = field["entityTypes"][0]
field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn"
else:
field_urn_type_class = "Urn"
field_urn_type_class = field_urn_type_classes[field_name(field)]
if field_urn_type_class:
init_validation += f"{field_name(field)} = str({field_name(field)})\n"
init_validation += (
@ -608,7 +620,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
init_coercion += " platform_name = DataPlatformUrn.from_string(platform_name).platform_name\n"
if field_name(field) == "platform":
init_coercion += "platform = DataPlatformUrn(platform).urn()\n"
init_coercion += "platform = platform.urn() if isinstance(platform, DataPlatformUrn) else DataPlatformUrn(platform).urn()\n"
elif field_urn_type_class is None:
# For all non-urns, run the value through the UrnEncoder.
init_coercion += (

View File

@ -4,7 +4,13 @@ from typing import List
import pytest
from datahub.metadata.urns import CorpUserUrn, DatasetUrn, Urn
from datahub.metadata.urns import (
CorpUserUrn,
DataPlatformUrn,
DatasetUrn,
SchemaFieldUrn,
Urn,
)
from datahub.utilities.urns.error import InvalidUrnError
pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")
@ -60,6 +66,20 @@ def test_urn_coercion() -> None:
assert urn == Urn.from_string(urn.urn())
def test_urns_in_init() -> None:
platform = DataPlatformUrn("abc")
assert platform.urn() == "urn:li:dataPlatform:abc"
dataset_urn = DatasetUrn(platform, "def", "PROD")
assert dataset_urn.urn() == "urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)"
schema_field = SchemaFieldUrn(dataset_urn, "foo")
assert (
schema_field.urn()
== "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD),foo)"
)
def test_urn_type_dispatch_1() -> None:
urn = Urn.from_string("urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)")
assert isinstance(urn, DatasetUrn)