fix(ingest/json-schema): handle property inheritance in unions (#8121)

This commit is contained in:
Harshal Sheth 2023-05-30 22:59:28 -07:00 committed by GitHub
parent b42f518255
commit a29b576daa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 156 additions and 3 deletions

View File

@ -1,6 +1,6 @@
import json
import logging
import unittest
import unittest.mock
from hashlib import md5
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
@ -455,7 +455,25 @@ class JsonSchemaTranslator:
(union_category, union_category_schema) = [
(k, v) for k, v in union_category_map.items() if v
][0]
if not field_path.has_field_name() and len(union_category_schema) == 1:
# Special case: If this is a top-level field AND there is only one type in the
# union, we collapse down the union to avoid extra nesting.
union_schema = union_category_schema[0]
merged_union_schema = (
JsonSchemaTranslator._retain_parent_schema_props_in_union(
union_schema=union_schema, parent_schema=schema
)
)
yield from JsonSchemaTranslator.get_fields(
JsonSchemaTranslator._get_type_from_schema(merged_union_schema),
merged_union_schema,
required=required,
base_field_path=field_path,
)
return # this one is done
if field_path.has_field_name():
# The frontend expects the top-level field to be a record, so we only
# include the UnionTypeClass if we're not at the top level.
yield SchemaField(
fieldPath=field_path.expand_type("union", schema).as_string(),
type=type_override or SchemaFieldDataTypeClass(UnionTypeClass()),
@ -481,9 +499,14 @@ class JsonSchemaTranslator:
union_field_path._set_parent_type_if_not_exists(
DataHubType(type=UnionTypeClass, nested_type=union_type)
)
merged_union_schema = (
JsonSchemaTranslator._retain_parent_schema_props_in_union(
union_schema=union_schema, parent_schema=schema
)
)
yield from JsonSchemaTranslator.get_fields(
JsonSchemaTranslator._get_type_from_schema(union_schema),
union_schema,
JsonSchemaTranslator._get_type_from_schema(merged_union_schema),
merged_union_schema,
required=required,
base_field_path=union_field_path,
specific_type=union_type,
@ -491,6 +514,26 @@ class JsonSchemaTranslator:
else:
raise Exception(f"Unhandled type {datahub_field_type}")
@staticmethod
def _retain_parent_schema_props_in_union(
union_schema: Dict, parent_schema: Dict
) -> Dict:
"""Merge the "properties" and the "required" fields from the parent schema into the child union schema."""
union_schema = union_schema.copy()
if "properties" in parent_schema:
union_schema["properties"] = {
**parent_schema["properties"],
**union_schema.get("properties", {}),
}
if "required" in parent_schema:
union_schema["required"] = [
*parent_schema["required"],
*union_schema.get("required", []),
]
return union_schema
@staticmethod
def get_type_mapping(json_type: str) -> Type:
return JsonSchemaTranslator.field_type_mapping.get(json_type, NullTypeClass)

View File

@ -710,3 +710,113 @@ def test_required_field():
assert fields[1].nullable is True
assert json.loads(fields[1].jsonProps or "{}")["required"] is True
assert json.loads(fields[0].jsonProps or "{}")["required"] is False
def test_anyof_with_properties():
# We expect the event / timestamp fields to be included in both branches of the anyOf.
schema = {
"$id": "test",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"additionalProperties": False,
"anyOf": [{"required": ["anonymousId"]}, {"required": ["userId"]}],
"properties": {
"anonymousId": {
"description": "A temporary user id, used before a user logs in.",
"format": "uuid",
"type": "string",
},
"userId": {
"description": "Unique user id.",
"type": "string",
},
"event": {"description": "Unique name of the event.", "type": "string"},
"timestamp": {
"description": "Timestamp of when the message itself took place.",
"type": "string",
},
},
"required": ["event"],
"type": "object",
}
fields = list(JsonSchemaTranslator.get_fields_from_schema(schema))
expected_field_paths: List[str] = [
"[version=2.0].[type=union].[type=union_0].[type=string(uuid)].anonymousId",
"[version=2.0].[type=union].[type=union_0].[type=string].userId",
"[version=2.0].[type=union].[type=union_0].[type=string].event",
"[version=2.0].[type=union].[type=union_0].[type=string].timestamp",
"[version=2.0].[type=union].[type=union_1].[type=string(uuid)].anonymousId",
"[version=2.0].[type=union].[type=union_1].[type=string].userId",
"[version=2.0].[type=union].[type=union_1].[type=string].event",
"[version=2.0].[type=union].[type=union_1].[type=string].timestamp",
]
assert_field_paths_match(fields, expected_field_paths)
assert_fields_are_valid(fields)
# In the first one, the anonymousId is required, but the userId is not.
assert json.loads(fields[0].jsonProps or "{}")["required"] is True
assert json.loads(fields[1].jsonProps or "{}")["required"] is False
# In the second one, the userId is required, but the anonymousId is not.
assert json.loads(fields[4].jsonProps or "{}")["required"] is False
assert json.loads(fields[5].jsonProps or "{}")["required"] is True
# The event field is required in both branches.
assert json.loads(fields[2].jsonProps or "{}")["required"] is True
assert json.loads(fields[6].jsonProps or "{}")["required"] is True
# The timestamp field is not required in either branch.
assert json.loads(fields[3].jsonProps or "{}")["required"] is False
assert json.loads(fields[7].jsonProps or "{}")["required"] is False
def test_top_level_trival_allof():
schema = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "event-wrapper",
"type": "object",
"allOf": [
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "event",
"properties": {
"userId": {
"description": "Unique user id.",
"type": "string",
},
"event": {
"description": "Unique name of the event.",
"type": "string",
},
"timestamp": {
"description": "Timestamp of when the message itself took place.",
"type": "string",
},
},
"required": ["event"],
"type": "object",
"additionalProperties": False,
},
],
"properties": {
"extra-top-level-property": {
"type": "string",
},
},
}
fields = list(JsonSchemaTranslator.get_fields_from_schema(schema))
expected_field_paths: List[str] = [
"[version=2.0].[type=object].[type=string].extra-top-level-property",
"[version=2.0].[type=object].[type=string].userId",
"[version=2.0].[type=object].[type=string].event",
"[version=2.0].[type=object].[type=string].timestamp",
]
assert_field_paths_match(fields, expected_field_paths)
assert_fields_are_valid(fields)
assert json.loads(fields[0].jsonProps or "{}")["required"] is False
assert json.loads(fields[1].jsonProps or "{}")["required"] is False
assert json.loads(fields[2].jsonProps or "{}")["required"] is True
assert json.loads(fields[3].jsonProps or "{}")["required"] is False