import collections
import copy
import json
import re
import textwrap
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import avro.schema
import click
import pydantic
import yaml
from avrogen import write_schema_files

ENTITY_CATEGORY_UNSET = "_unset_"


class EntityType(pydantic.BaseModel):
    name: str
    doc: Optional[str] = None
    category: str = ENTITY_CATEGORY_UNSET

    keyAspect: str
    aspects: List[str]


def load_entity_registry(entity_registry_file: Path) -> List[EntityType]:
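    """Parse the entity registry YAML file into a list of EntityType models."""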
    with entity_registry_file.open() as f:
        raw_entity_registry = yaml.safe_load(f)

    entities = pydantic.parse_obj_as(List[EntityType], raw_entity_registry["entities"])
    return entities


def load_schema_file(schema_file: Union[str, Path]) -> dict:
    raw_schema_text = Path(schema_file).read_text()
    return json.loads(raw_schema_text)


def load_schemas(schemas_path: str) -> Dict[str, dict]:
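    """Load the required top-level schemas, plus every schema that carries an "Aspect" annotation."""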
    required_schema_files = {
        "mxe/MetadataChangeEvent.avsc",
        "mxe/MetadataChangeProposal.avsc",
        "usage/UsageAggregation.avsc",
        "mxe/MetadataChangeLog.avsc",
        "mxe/PlatformEvent.avsc",
        "platform/event/v1/EntityChangeEvent.avsc",
        "metadata/query/filter/Filter.avsc",  # temporarily added to test reserved keywords support
    }

    # Find all the aspect schemas / other important schemas.
    schema_files: List[Path] = []
    for schema_file in Path(schemas_path).glob("**/*.avsc"):
        relative_path = schema_file.relative_to(schemas_path).as_posix()
        if relative_path in required_schema_files:
            schema_files.append(schema_file)
            required_schema_files.remove(relative_path)
        elif load_schema_file(schema_file).get("Aspect"):
            schema_files.append(schema_file)

    assert not required_schema_files, f"Schema files not found: {required_schema_files}"

    schemas: Dict[str, dict] = {}
    for schema_file in schema_files:
        schema = load_schema_file(schema_file)
        schemas[Path(schema_file).stem] = schema

    return schemas


def patch_schemas(schemas: Dict[str, dict], pdl_path: Path) -> Dict[str, dict]:
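    """Annotate urn fields in every schema, consulting the PDL sources for urn arrays."""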
    # We can easily find normal urn types using the generated avro schema,
    # but for arrays of urns there's nothing in the avro schema and hence
    # we have to look in the PDL files instead.
    urn_arrays: Dict[
        str, List[Tuple[str, str]]
    ] = {}  # schema name -> list of (field name, type)

    # First, we need to load the PDL files and find all urn arrays.
    for pdl_file in Path(pdl_path).glob("**/*.pdl"):
        pdl_text = pdl_file.read_text()

        # TRICKY: We assume that all urn types end with "Urn".
        arrays = re.findall(
            r"^\s*(\w+)\s*:\s*(?:optional\s+)?array\[(\w*Urn)\]",
            pdl_text,
            re.MULTILINE,
        )
        if arrays:
            schema_name = pdl_file.stem
            urn_arrays[schema_name] = [(item[0], item[1]) for item in arrays]

    # Then, we can patch each schema.
    patched_schemas = {}
    for name, schema in schemas.items():
        patched_schemas[name] = patch_schema(schema, urn_arrays)

    return patched_schemas


def patch_schema(schema: dict, urn_arrays: Dict[str, List[Tuple[str, str]]]) -> dict:
    """
    This method patches the schema to add an "Urn" property to all urn fields.
    Because the inner type in an array is not a named Avro schema, for urn arrays
    we annotate the array field and add an "urn_is_array" property.
    """

    # We're using Names() to generate a full list of embedded schemas.
    all_schemas = avro.schema.Names()
    patched = avro.schema.make_avsc_object(schema, names=all_schemas)

    for nested in all_schemas.names.values():
        if isinstance(nested, (avro.schema.EnumSchema, avro.schema.FixedSchema)):
            continue
        assert isinstance(nested, avro.schema.RecordSchema)

        # Patch normal urn types.
        field: avro.schema.Field
        for field in nested.fields:
            field_props: dict = field.props  # type: ignore
            java_props: dict = field_props.get("java", {})
            java_class: Optional[str] = java_props.get("class")
            if java_class and java_class.startswith(
                "com.linkedin.pegasus2avro.common.urn."
            ):
                type = java_class.split(".")[-1]
                entity_types = field_props.get("Relationship", {}).get(
                    "entityTypes", []
                )

                field.set_prop("Urn", type)
                if entity_types:
                    field.set_prop("entityTypes", entity_types)

        # Patch array urn types.
        if nested.name in urn_arrays:
            mapping = urn_arrays[nested.name]

            for field_name, type in mapping:
                field = nested.fields_dict[field_name]
                field.set_prop("Urn", type)
                field.set_prop("urn_is_array", True)

    return patched.to_json()  # type: ignore


def merge_schemas(schemas_obj: List[dict]) -> str:
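    """Combine all schemas into a single union schema, failing if a class name appears in multiple namespaces."""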
    # Combine schemas as a "union" of all of the types.
    merged = ["null"] + schemas_obj

    # Check that we don't have the same class name in multiple namespaces.
    names_to_spaces: Dict[str, str] = {}

    # Patch add_name method to NOT complain about duplicate names.
    class NamesWithDups(avro.schema.Names):
        def add_name(self, name_attr, space_attr, new_schema):
            to_add = avro.schema.Name(name_attr, space_attr, self.default_namespace)
            assert to_add.name
            assert to_add.space
            assert to_add.fullname

            if to_add.name in names_to_spaces:
                if names_to_spaces[to_add.name] != to_add.space:
                    raise ValueError(
                        f"Duplicate name {to_add.name} in namespaces {names_to_spaces[to_add.name]} and {to_add.space}. "
                        "This will cause conflicts in the generated code."
                    )
            else:
                names_to_spaces[to_add.name] = to_add.space

            self.names[to_add.fullname] = new_schema
            return to_add

    cleaned_schema = avro.schema.make_avsc_object(merged, names=NamesWithDups())

    # Convert back to an Avro schema JSON representation.
    out_schema = cleaned_schema.to_json()
    encoded = json.dumps(out_schema, indent=2)
    return encoded


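# Header and footer added to every generated file: they suppress linters and
# mark the file as autogenerated.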
autogen_header = """# mypy: ignore-errors
# flake8: noqa

# This file is autogenerated by /metadata-ingestion/scripts/avro_codegen.py
# Do not modify manually!

# pylint: skip-file
# fmt: off
# isort: skip_file
"""
autogen_footer = """
# fmt: on
"""


def suppress_checks_in_file(filepath: Union[str, Path]) -> None:
    """
    Adds a couple of lines to the top of an autogenerated file:
        - Comments to suppress flake8 and black.
        - A note stating that the file was autogenerated.
    """

    with open(filepath, "r+") as f:
        contents = f.read()

        f.seek(0, 0)
        f.write(autogen_header)
        f.write(contents)
        f.write(autogen_footer)


def add_avro_python3_warning(filepath: Path) -> None:
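    """Prepend a runtime check that warns if the conflicting avro-python3 package is installed."""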
    contents = filepath.read_text()

    contents = f"""
# The SchemaFromJSONData method only exists in avro-python3, but is called make_avsc_object in avro.
# We can use this fact to detect conflicts between the two packages. Pip won't detect those conflicts
# because both are namespace packages, and hence are allowed to overwrite files from each other.
# This means that installation order matters, which is a pretty unintuitive outcome.
# See https://github.com/pypa/pip/issues/4625 for details.
try:
    from avro.schema import SchemaFromJSONData  # type: ignore
    import warnings

    warnings.warn("It seems like 'avro-python3' is installed, which conflicts with the 'avro' package used by datahub. "
                + "Try running `pip uninstall avro-python3 && pip install --upgrade --force-reinstall avro` to fix this issue.")
except ImportError:
    pass

{contents}
    """

    filepath.write_text(contents)


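# Templates for the generated schemas package's __init__.py: a cached file
# loader plus one accessor method per saved schema.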
load_schema_method = """
import functools
import pathlib

@functools.lru_cache(maxsize=None)
def _load_schema(schema_name: str) -> str:
    return (pathlib.Path(__file__).parent / f"{schema_name}.avsc").read_text()
"""
individual_schema_method = """
def get{schema_name}Schema() -> str:
    return _load_schema("{schema_name}")
"""


def make_load_schema_methods(schemas: Iterable[str]) -> str:
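    """Render the cached loader plus a get<Name>Schema() accessor for each schema."""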
    return load_schema_method + "".join(
        individual_schema_method.format(schema_name=schema) for schema in schemas
    )


def save_raw_schemas(schema_save_dir: Path, schemas: Dict[str, dict]) -> None:
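    """Write each schema to an .avsc file and generate an __init__.py with accessor methods."""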
    # Save raw avsc files.
    for name, schema in schemas.items():
        (schema_save_dir / f"{name}.avsc").write_text(json.dumps(schema, indent=2))

    # Add getXSchema methods.
    with open(schema_save_dir / "__init__.py", "w") as schema_dir_init:
        schema_dir_init.write(make_load_schema_methods(schemas.keys()))


def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
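    """Rewrite the generated schema_classes.py in place: make each aspect class
    inherit from _Aspect, attach ASPECT_NAME / ASPECT_TYPE / ASPECT_INFO class
    attributes, and append lookup tables covering all aspects."""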
    schema_classes_lines = schema_class_file.read_text().splitlines()
    line_lookup_table = {line: i for i, line in enumerate(schema_classes_lines)}

    # Import the _Aspect class.
    schema_classes_lines[
        line_lookup_table["__SCHEMAS: Dict[str, RecordSchema] = {}"]
    ] += """

from datahub._codegen.aspect import _Aspect
"""

    for aspect in aspects:
        className = f"{aspect['name']}Class"
        aspectName = aspect["Aspect"]["name"]
        class_def_original = f"class {className}(DictWrapper):"

        # Make the aspects inherit from the Aspect class.
        class_def_line = line_lookup_table[class_def_original]
        schema_classes_lines[class_def_line] = f"class {className}(_Aspect):"

        # Define the ASPECT_NAME class attribute.
        # There's always an empty line between the docstring and the RECORD_SCHEMA class attribute.
        # We need to find it and insert our line of code there.
        empty_line = class_def_line + 1
        while not (
            schema_classes_lines[empty_line].strip() == ""
            and schema_classes_lines[empty_line + 1]
            .strip()
            .startswith("RECORD_SCHEMA = ")
        ):
            empty_line += 1
        schema_classes_lines[empty_line] = "\n"
        schema_classes_lines[empty_line] += f"\n    ASPECT_NAME = '{aspectName}'"
        if "type" in aspect["Aspect"]:
            schema_classes_lines[empty_line] += (
                f"\n    ASPECT_TYPE = '{aspect['Aspect']['type']}'"
            )

        aspect_info = {
            k: v for k, v in aspect["Aspect"].items() if k not in {"name", "type"}
        }
        schema_classes_lines[empty_line] += f"\n    ASPECT_INFO = {aspect_info}"

        schema_classes_lines[empty_line + 1] += "\n"

    # Finally, generate a big list of all available aspects.
    newline = "\n"
    schema_classes_lines.append(
        f"""
ASPECT_CLASSES: List[Type[_Aspect]] = [
    {f",{newline}    ".join(f"{aspect['name']}Class" for aspect in aspects)}
]

ASPECT_NAME_MAP: Dict[str, Type[_Aspect]] = {{
    aspect.get_aspect_name(): aspect
    for aspect in ASPECT_CLASSES
}}

from typing import Literal
from typing_extensions import TypedDict

class AspectBag(TypedDict, total=False):
    {f"{newline}    ".join(f"{aspect['Aspect']['name']}: {aspect['name']}Class" for aspect in aspects)}


KEY_ASPECTS: Dict[str, Type[_Aspect]] = {{
    {f",{newline}    ".join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
}}

ENTITY_TYPE_NAMES: List[str] = [
    {f",{newline}    ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
]
EntityTypeName = Literal[
    {f",{newline}    ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
]
"""
    )

    schema_class_file.write_text("\n".join(schema_classes_lines))


def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None:
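    """Write the _urns package: an __init__.py plus urn_defs.py containing one
    URN class per entity key aspect."""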
    urn_dir.mkdir()

    (urn_dir / "__init__.py").write_text("\n# This file is intentionally left empty.")

    code = """
# This file contains classes corresponding to entity URNs.

from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union, Literal

import functools
from deprecated.sphinx import deprecated as _sphinx_deprecated

from datahub.utilities.urn_encoder import UrnEncoder
from datahub.utilities.urns._urn_base import _SpecificUrn, Urn
from datahub.utilities.urns.error import InvalidUrnError

deprecated = functools.partial(_sphinx_deprecated, version="0.12.0.2")
"""

    for aspect in key_aspects:
        entity_type = aspect["Aspect"]["keyForEntity"]
        code += generate_urn_class(entity_type, aspect)

    (urn_dir / "urn_defs.py").write_text(code)


def capitalize_entity_name(entity_name: str) -> str:
    # Examples:
    # corpuser -> CorpUser
    # corpGroup -> CorpGroup
    # mlModelDeployment -> MlModelDeployment

    if entity_name == "corpuser":
        return "CorpUser"

    return f"{entity_name[0].upper()}{entity_name[1:]}"


def python_type(avro_type: str) -> str:
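    """Map an Avro field type to the Python type used in generated URN signatures."""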
    if avro_type == "string":
        return "str"
    elif (
        isinstance(avro_type, dict)
        and avro_type.get("type") == "enum"
        and avro_type.get("name") == "FabricType"
    ):
        # TODO: make this stricter using an enum
        return "str"
    raise ValueError(f"unknown type {avro_type}")


def field_type(field: dict) -> str:
    return python_type(field["type"])


def field_name(field: dict) -> str:
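    """Compute the Python argument name for a key aspect field: apply manual
    aliases first, then convert mixed case to snake case."""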
    manual_mapping = {
        "origin": "env",
        "platformName": "platform_name",
    }

    name: str = field["name"]
    if name in manual_mapping:
        return manual_mapping[name]

    # If the name is mixed case, convert to snake case.
    if name.lower() != name:
        # Inject an underscore before each capital letter, and then convert to lowercase.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

    return name


_create_from_id = """
@classmethod
@deprecated(reason="Use the constructor instead")
def create_from_id(cls, id: str) -> "{class_name}":
    return cls(id)
"""
_extra_urn_methods: Dict[str, List[str]] = {
    "corpGroup": [_create_from_id.format(class_name="CorpGroupUrn")],
    "corpuser": [_create_from_id.format(class_name="CorpUserUrn")],
    "dataFlow": [
        """
@classmethod
def create_from_ids(
    cls,
    orchestrator: str,
    flow_id: str,
    env: str,
    platform_instance: Optional[str] = None,
) -> "DataFlowUrn":
    return cls(
        orchestrator=orchestrator,
        flow_id=f"{platform_instance}.{flow_id}" if platform_instance else flow_id,
        cluster=env,
    )

@deprecated(reason="Use .orchestrator instead")
def get_orchestrator_name(self) -> str:
    return self.orchestrator

@deprecated(reason="Use .flow_id instead")
def get_flow_id(self) -> str:
    return self.flow_id

@deprecated(reason="Use .cluster instead")
def get_env(self) -> str:
    return self.cluster
""",
    ],
    "dataJob": [
        """
@classmethod
def create_from_ids(cls, data_flow_urn: str, job_id: str) -> "DataJobUrn":
    return cls(data_flow_urn, job_id)

def get_data_flow_urn(self) -> "DataFlowUrn":
    return DataFlowUrn.from_string(self.flow)

@deprecated(reason="Use .job_id instead")
def get_job_id(self) -> str:
    return self.job_id
"""
    ],
    "dataPlatform": [_create_from_id.format(class_name="DataPlatformUrn")],
    "dataProcessInstance": [
        _create_from_id.format(class_name="DataProcessInstanceUrn"),
        """
@deprecated(reason="Use .id instead")
def get_dataprocessinstance_id(self) -> str:
    return self.id
""",
    ],
    "dataset": [
        """
@classmethod
def create_from_ids(
    cls,
    platform_id: str,
    table_name: str,
    env: str,
    platform_instance: Optional[str] = None,
) -> "DatasetUrn":
    return DatasetUrn(
        platform=platform_id,
        name=f"{platform_instance}.{table_name}" if platform_instance else table_name,
        env=env,
    )

from datahub.utilities.urns.field_paths import get_simple_field_path_from_v2_field_path as _get_simple_field_path_from_v2_field_path

get_simple_field_path_from_v2_field_path = staticmethod(deprecated(reason='Use the function from the field_paths module instead')(_get_simple_field_path_from_v2_field_path))

def get_data_platform_urn(self) -> "DataPlatformUrn":
    return DataPlatformUrn.from_string(self.platform)

@deprecated(reason="Use .name instead")
def get_dataset_name(self) -> str:
    return self.name

@deprecated(reason="Use .env instead")
def get_env(self) -> str:
    return self.env
"""
    ],
    "domain": [_create_from_id.format(class_name="DomainUrn")],
    "notebook": [
        """
@deprecated(reason="Use .notebook_tool instead")
def get_platform_id(self) -> str:
    return self.notebook_tool

@deprecated(reason="Use .notebook_id instead")
def get_notebook_id(self) -> str:
    return self.notebook_id
"""
    ],
    "tag": [_create_from_id.format(class_name="TagUrn")],
    "chart": [
        """
@classmethod
def create_from_ids(
    cls,
    platform: str,
    name: str,
    platform_instance: Optional[str] = None,
) -> "ChartUrn":
    return ChartUrn(
        dashboard_tool=platform,
        chart_id=f"{platform_instance}.{name}" if platform_instance else name,
    )
        """
    ],
    "dashboard": [
        """
@classmethod
def create_from_ids(
    cls,
    platform: str,
    name: str,
    platform_instance: Optional[str] = None,
) -> "DashboardUrn":
    return DashboardUrn(
        dashboard_tool=platform,
        dashboard_id=f"{platform_instance}.{name}" if platform_instance else name,
    )
        """
    ],
}


def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
    """Generate a class definition for this entity.

    The class definition has the following structure:
    - A class attribute ENTITY_TYPE, which is the entity type string.
    - A class attribute _URN_PARTS, which is the number of parts in the URN.
    - A constructor that takes the URN parts as arguments. The field names
      will match the key aspect's field names. It will also have an _allow_coercion
      flag, which will allow for some normalization (e.g. upper case env).
      Then, each part will be validated (including nested calls for urn subparts).
    - Utilities for converting to/from the key aspect.
    - Any additional methods that are required for this entity type, defined above.
      These are primarily for backwards compatibility.
    - Getter methods for each field.
    """

    class_name = f"{capitalize_entity_name(entity_type)}Urn"

    fields = copy.deepcopy(key_aspect["fields"])
    if entity_type == "container":
        # The annotations say guid is optional, but it is required.
        # This is a quick fix of the annotations.
        assert field_name(fields[0]) == "guid"
        assert fields[0]["type"] == ["null", "string"]
        fields[0]["type"] = "string"
    arg_count = len(fields)

    field_urn_type_classes = {}
    for field in fields:
        # Figure out if urn types are valid for each field.
        field_urn_type_class = None
        if field_name(field) == "platform":
            field_urn_type_class = "DataPlatformUrn"
        elif field.get("Urn"):
            if len(field.get("entityTypes", [])) == 1:
                field_entity_type = field["entityTypes"][0]
                field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn"
            else:
                field_urn_type_class = "Urn"

        field_urn_type_classes[field_name(field)] = field_urn_type_class
    if arg_count == 1:
        field = fields[0]

        if field_urn_type_classes[field_name(field)] is None:
            # All single-arg urn types should accept themselves.
            field_urn_type_classes[field_name(field)] = class_name

    _init_arg_parts: List[str] = []
    for field in fields:
        field_urn_type_class = field_urn_type_classes[field_name(field)]

        default = '"PROD"' if field_name(field) == "env" else None

        type_hint = field_type(field)
        if field_urn_type_class:
            type_hint = f'Union["{field_urn_type_class}", str]'
        _arg_part = f"{field_name(field)}: {type_hint}"
        if default:
            _arg_part += f" = {default}"
        _init_arg_parts.append(_arg_part)
    init_args = ", ".join(_init_arg_parts)

    super_init_args = ", ".join(field_name(field) for field in fields)

    parse_ids_mapping = ", ".join(
        f"{field_name(field)}=entity_ids[{i}]" for i, field in enumerate(fields)
    )

    key_aspect_class = f"{key_aspect['name']}Class"
    to_key_aspect_args = ", ".join(
        # The LHS bypasses any field name aliases.
        f"{field['name']}=self.{field_name(field)}"
        for field in fields
    )
    from_key_aspect_args = ", ".join(
        f"{field_name(field)}=key_aspect.{field['name']}" for field in fields
    )

    init_coercion = ""
    init_validation = ""
    for field in fields:
        init_validation += f'if not {field_name(field)}:\n    raise InvalidUrnError("{class_name} {field_name(field)} cannot be empty")\n'

        # Generalized mechanism for validating embedded urns.
        field_urn_type_class = field_urn_type_classes[field_name(field)]
        if field_urn_type_class and field_urn_type_class == class_name:
            # First, we need to extract the main piece from the urn type.
            init_validation += (
                f"if isinstance({field_name(field)}, {field_urn_type_class}):\n"
                f"    {field_name(field)} = {field_name(field)}.{field_name(field)}\n"
            )

            # If it's still an urn type, that's a problem.
            init_validation += (
                f"elif isinstance({field_name(field)}, Urn):\n"
                f"    raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
            )

            # Then, we do character validation as normal.
            init_validation += (
                f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
                f"    raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
            )
        elif field_urn_type_class:
            init_validation += f"{field_name(field)} = str({field_name(field)})  # convert urn type to str\n"
            init_validation += (
                f"assert {field_urn_type_class}.from_string({field_name(field)})\n"
            )
        else:
            init_validation += (
                f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
                f"    raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
            )
        # TODO add ALL_ENV_TYPES validation

        # Field coercion logic.
        if field_name(field) == "env":
            init_coercion += "env = env.upper()\n"
        elif field_name(field) == "platform":
            # For platform names in particular, we also qualify them when they don't have the prefix.
            # We can rely on the DataPlatformUrn constructor to do this prefixing.
            init_coercion += "platform = DataPlatformUrn(platform).urn()\n"
        elif field_urn_type_class is not None:
            # For urn types, we need to parse them into urn types where appropriate.
            # Otherwise, we just need to encode special characters.
            init_coercion += (
                f"if isinstance({field_name(field)}, str):\n"
                f"    if {field_name(field)}.startswith('urn:li:'):\n"
                f"        try:\n"
                f"            {field_name(field)} = {field_urn_type_class}.from_string({field_name(field)})\n"
                f"        except InvalidUrnError:\n"
                f"            raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
                f"    else:\n"
                f"        {field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
            )
        else:
            # For all non-urns, run the value through the UrnEncoder.
            init_coercion += (
                f"{field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
            )
    if not init_coercion:
        init_coercion = "pass"

    # TODO include the docs for each field

    code = f"""
if TYPE_CHECKING:
    from datahub.metadata.schema_classes import {key_aspect_class}

class {class_name}(_SpecificUrn):
    ENTITY_TYPE: ClassVar[Literal["{entity_type}"]] = "{entity_type}"
    _URN_PARTS: ClassVar[int] = {arg_count}

    def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
        if _allow_coercion:
            # Field coercion logic (if any is required).
{textwrap.indent(init_coercion.strip(), prefix=" " * 4 * 3)}

        # Validation logic.
{textwrap.indent(init_validation.strip(), prefix=" " * 4 * 2)}

        super().__init__(self.ENTITY_TYPE, [{super_init_args}])

    @classmethod
    def _parse_ids(cls, entity_ids: List[str]) -> "{class_name}":
        if len(entity_ids) != cls._URN_PARTS:
            raise InvalidUrnError(f"{class_name} should have {{cls._URN_PARTS}} parts, got {{len(entity_ids)}}: {{entity_ids}}")
        return cls({parse_ids_mapping}, _allow_coercion=False)

    @classmethod
    def underlying_key_aspect_type(cls) -> Type["{key_aspect_class}"]:
        from datahub.metadata.schema_classes import {key_aspect_class}

        return {key_aspect_class}

    def to_key_aspect(self) -> "{key_aspect_class}":
        from datahub.metadata.schema_classes import {key_aspect_class}

        return {key_aspect_class}({to_key_aspect_args})

    @classmethod
    def from_key_aspect(cls, key_aspect: "{key_aspect_class}") -> "{class_name}":
        return cls({from_key_aspect_args})
"""

    for extra_method in _extra_urn_methods.get(entity_type, []):
        code += textwrap.indent(extra_method, prefix=" " * 4)

    for i, field in enumerate(fields):
        code += f"""
    @property
    def {field_name(field)}(self) -> {field_type(field)}:
        return self._entity_ids[{i}]
"""

    return code


@click.command()
@click.argument(
    "entity_registry", type=click.Path(exists=True, dir_okay=False), required=True
)
@click.argument(
    "pdl_path", type=click.Path(exists=True, file_okay=False), required=True
)
@click.argument(
    "schemas_path", type=click.Path(exists=True, file_okay=False), required=True
)
@click.argument("outdir", type=click.Path(), required=True)
@click.option("--check-unused-aspects", is_flag=True, default=False)
@click.option("--enable-custom-loader", is_flag=True, default=True)
def generate(
    entity_registry: str,
    pdl_path: str,
    schemas_path: str,
    outdir: str,
    check_unused_aspects: bool,
    enable_custom_loader: bool,
) -> None:
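    """Generate the Python metadata model: schema classes, URN classes, and raw schema copies."""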
    entities = load_entity_registry(Path(entity_registry))
    schemas = load_schemas(schemas_path)

    # Patch the avsc files.
    schemas = patch_schemas(schemas, Path(pdl_path))

    # Special handling for aspects.
    aspects = {
        schema["Aspect"]["name"]: schema
        for schema in schemas.values()
        if schema.get("Aspect")
    }

    # Copy entity registry info into the corresponding key aspect.
    for entity in entities:
        # This implicitly requires that all keyAspects are resolvable.
        aspect = aspects[entity.keyAspect]

        # This requires that entities cannot share a keyAspect.
        if (
            "keyForEntity" in aspect["Aspect"]
            and aspect["Aspect"]["keyForEntity"] != entity.name
        ):
            raise ValueError(
                f"Entity key {entity.keyAspect} is used by {aspect['Aspect']['keyForEntity']} and {entity.name}"
            )

        # Also require that the aspect list is deduplicated.
        duplicate_aspects = collections.Counter(entity.aspects) - collections.Counter(
            set(entity.aspects)
        )
        if duplicate_aspects:
            raise ValueError(
                f"Entity {entity.name} has duplicate aspects: {duplicate_aspects}"
            )

        aspect["Aspect"]["keyForEntity"] = entity.name
        aspect["Aspect"]["entityCategory"] = entity.category
        aspect["Aspect"]["entityAspects"] = entity.aspects
        if entity.doc is not None:
            aspect["Aspect"]["entityDoc"] = entity.doc

    # Check for unused aspects. We currently have quite a few.
    if check_unused_aspects:
        unused_aspects = set(aspects.keys()) - set().union(
            {entity.keyAspect for entity in entities},
            *(set(entity.aspects) for entity in entities),
        )
        if unused_aspects:
            raise ValueError(f"Unused aspects: {unused_aspects}")

    merged_schema = merge_schemas(list(schemas.values()))
    write_schema_files(merged_schema, outdir)

    # Schema files post-processing.
    (Path(outdir) / "__init__.py").write_text("# This file is intentionally empty.\n")
    add_avro_python3_warning(Path(outdir) / "schema_classes.py")
    annotate_aspects(
        list(aspects.values()),
        Path(outdir) / "schema_classes.py",
    )

    if enable_custom_loader:
        # Move schema_classes.py -> _schema_classes.py
        # and add a custom loader.
        (Path(outdir) / "_schema_classes.py").write_text(
            (Path(outdir) / "schema_classes.py").read_text()
        )
        (Path(outdir) / "schema_classes.py").write_text(
            """
# This is a specialized shim layer that allows us to dynamically load custom models from elsewhere.

import importlib
from typing import TYPE_CHECKING

from datahub._codegen.aspect import _Aspect as _Aspect
from datahub.utilities.docs_build import IS_SPHINX_BUILD
from datahub.utilities._custom_package_loader import get_custom_models_package

_custom_package_path = get_custom_models_package()

if TYPE_CHECKING or not _custom_package_path:
    from ._schema_classes import *

    # Required explicitly because __all__ doesn't include _ prefixed names.
    from ._schema_classes import __SCHEMA_TYPES

    if IS_SPHINX_BUILD:
        # Set __module__ to the current module so that Sphinx will document the
        # classes as belonging to this module instead of the custom package.
        for _cls in list(globals().values()):
            if hasattr(_cls, "__module__") and "datahub.metadata._schema_classes" in _cls.__module__:
                _cls.__module__ = __name__
else:
    _custom_package = importlib.import_module(_custom_package_path)
    globals().update(_custom_package.__dict__)
"""
        )

        (Path(outdir) / "urns.py").write_text(
            """
# This is a specialized shim layer that allows us to dynamically load custom URN types from elsewhere.

import importlib
from typing import TYPE_CHECKING

from datahub.utilities.docs_build import IS_SPHINX_BUILD
from datahub.utilities._custom_package_loader import get_custom_urns_package
from datahub.utilities.urns._urn_base import Urn as Urn  # noqa: F401

_custom_package_path = get_custom_urns_package()

if TYPE_CHECKING or not _custom_package_path:
    from ._urns.urn_defs import *  # noqa: F401

    if IS_SPHINX_BUILD:
        # Set __module__ to the current module so that Sphinx will document the
        # classes as belonging to this module instead of the custom package.
        for _cls in list(globals().values()):
            if hasattr(_cls, "__module__") and ("datahub.metadata._urns.urn_defs" in _cls.__module__ or _cls is Urn):
                _cls.__module__ = __name__
else:
    _custom_package = importlib.import_module(_custom_package_path)
    globals().update(_custom_package.__dict__)
"""
        )

    # Generate URN classes.
    urn_dir = Path(outdir) / "_urns"
    write_urn_classes(
        [aspect for aspect in aspects.values() if aspect["Aspect"].get("keyForEntity")],
        urn_dir,
    )

    # Save raw schema files in codegen as well.
    schema_save_dir = Path(outdir) / "schemas"
    schema_save_dir.mkdir()
    for schema_out_file, schema in schemas.items():
        (schema_save_dir / f"{schema_out_file}.avsc").write_text(
            json.dumps(schema, indent=2)
        )

    # Keep a copy of a few raw avsc files.
    required_avsc_schemas = {"MetadataChangeEvent", "MetadataChangeProposal"}
    save_raw_schemas(
        schema_save_dir,
        {
            name: schema
            for name, schema in schemas.items()
            if name in required_avsc_schemas
        },
    )

    # Add headers for all generated files.
    generated_files = Path(outdir).glob("**/*.py")
    for file in generated_files:
        suppress_checks_in_file(file)


if __name__ == "__main__":
    generate()