diff --git a/metadata-ingestion/scripts/avro_codegen.py b/metadata-ingestion/scripts/avro_codegen.py index 2841985ad0..0fe79a2c6a 100644 --- a/metadata-ingestion/scripts/avro_codegen.py +++ b/metadata-ingestion/scripts/avro_codegen.py @@ -346,7 +346,7 @@ def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None: code = """ # This file contains classes corresponding to entity URNs. -from typing import ClassVar, List, Optional, Type, TYPE_CHECKING +from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union import functools from deprecated.sphinx import deprecated as _sphinx_deprecated @@ -547,10 +547,31 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str: assert fields[0]["type"] == ["null", "string"] fields[0]["type"] = "string" + field_urn_type_classes = {} + for field in fields: + # Figure out if urn types are valid for each field. + field_urn_type_class = None + if field_name(field) == "platform": + field_urn_type_class = "DataPlatformUrn" + elif field.get("Urn"): + if len(field.get("entityTypes", [])) == 1: + field_entity_type = field["entityTypes"][0] + field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn" + else: + field_urn_type_class = "Urn" + + field_urn_type_classes[field_name(field)] = field_urn_type_class + _init_arg_parts: List[str] = [] for field in fields: + field_urn_type_class = field_urn_type_classes[field_name(field)] + default = '"PROD"' if field_name(field) == "env" else None - _arg_part = f"{field_name(field)}: {field_type(field)}" + + type_hint = field_type(field) + if field_urn_type_class: + type_hint = f'Union["{field_urn_type_class}", str]' + _arg_part = f"{field_name(field)}: {type_hint}" if default: _arg_part += f" = {default}" _init_arg_parts.append(_arg_part) @@ -579,16 +600,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str: init_validation += f'if not {field_name(field)}:\n raise InvalidUrnError("{class_name} {field_name(field)} cannot be empty")\n' # Generalized mechanism for validating embedded urns. - field_urn_type_class = None - if field_name(field) == "platform": - field_urn_type_class = "DataPlatformUrn" - elif field.get("Urn"): - if len(field.get("entityTypes", [])) == 1: - field_entity_type = field["entityTypes"][0] - field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn" - else: - field_urn_type_class = "Urn" - + field_urn_type_class = field_urn_type_classes[field_name(field)] if field_urn_type_class: init_validation += f"{field_name(field)} = str({field_name(field)})\n" init_validation += ( @@ -608,7 +620,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str: init_coercion += " platform_name = DataPlatformUrn.from_string(platform_name).platform_name\n" if field_name(field) == "platform": - init_coercion += "platform = DataPlatformUrn(platform).urn()\n" + init_coercion += "platform = platform.urn() if isinstance(platform, DataPlatformUrn) else DataPlatformUrn(platform).urn()\n" elif field_urn_type_class is None: # For all non-urns, run the value through the UrnEncoder. init_coercion += ( diff --git a/metadata-ingestion/tests/unit/urns/test_urn.py b/metadata-ingestion/tests/unit/urns/test_urn.py index 0c362473c0..bee80ec331 100644 --- a/metadata-ingestion/tests/unit/urns/test_urn.py +++ b/metadata-ingestion/tests/unit/urns/test_urn.py @@ -4,7 +4,13 @@ from typing import List import pytest -from datahub.metadata.urns import CorpUserUrn, DatasetUrn, Urn +from datahub.metadata.urns import ( + CorpUserUrn, + DataPlatformUrn, + DatasetUrn, + SchemaFieldUrn, + Urn, +) from datahub.utilities.urns.error import InvalidUrnError pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") @@ -60,6 +66,20 @@ def test_urn_coercion() -> None: assert urn == Urn.from_string(urn.urn()) +def test_urns_in_init() -> None: + platform = DataPlatformUrn("abc") + assert platform.urn() == "urn:li:dataPlatform:abc" + + dataset_urn = DatasetUrn(platform, "def", "PROD") + assert dataset_urn.urn() == "urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)" + + schema_field = SchemaFieldUrn(dataset_urn, "foo") + assert ( + schema_field.urn() + == "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD),foo)" + ) + + def test_urn_type_dispatch_1() -> None: urn = Urn.from_string("urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)") assert isinstance(urn, DatasetUrn)