Fixes #15565 : Advanced Avro schema recursion depth issue (#17683)

* Fix: Avro schema recursion depth issue

* py_format

* Addressed comments
This commit is contained in:
Suman Maharana 2024-09-10 18:26:42 +05:30 committed by GitHub
parent 8decd2338f
commit 094bae7097
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 127 additions and 16 deletions

View File

@ -30,15 +30,19 @@ RECORD_DATATYPE_NAME = "RECORD"
def _parse_array_children(
arr_item: Schema, cls: Type[BaseModel] = FieldModel
arr_item: Schema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
) -> Tuple[str, Optional[Union[FieldModel, Column]]]:
if isinstance(arr_item, ArraySchema):
display_type, children = _parse_array_children(arr_item.items, cls=cls)
display_type, children = _parse_array_children(
arr_item.items, cls=cls, already_parsed=already_parsed
)
return f"ARRAY<{display_type}>", children
if isinstance(arr_item, UnionSchema):
display_type, children = _parse_union_children(
parent=None, union_field=arr_item, cls=cls
parent=None, union_field=arr_item, cls=cls, already_parsed=already_parsed
)
return f"UNION<{display_type}>", children
@ -46,7 +50,7 @@ def _parse_array_children(
child_obj = cls(
name=arr_item.name,
dataType=str(arr_item.type).upper(),
children=get_avro_fields(arr_item, cls),
children=get_avro_fields(arr_item, cls, already_parsed=already_parsed),
description=arr_item.doc,
)
return str(arr_item.type), child_obj
@ -55,7 +59,9 @@ def _parse_array_children(
def parse_array_fields(
field: ArraySchema, cls: Type[BaseModel] = FieldModel
field: ArraySchema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
) -> Optional[List[Union[FieldModel, Column]]]:
"""
Parse array field for avro schema
@ -93,7 +99,9 @@ def parse_array_fields(
description=field.doc,
)
display, children = _parse_array_children(arr_item=field.type.items, cls=cls)
display, children = _parse_array_children(
arr_item=field.type.items, cls=cls, already_parsed=already_parsed
)
obj.dataTypeDisplay = f"ARRAY<{display}>"
if cls == Column:
@ -109,6 +117,7 @@ def _parse_union_children(
parent: Optional[Schema],
union_field: UnionSchema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
) -> Tuple[str, Optional[Union[FieldModel, Column]]]:
non_null_schema = [
(i, schema)
@ -120,7 +129,9 @@ def _parse_union_children(
field = non_null_schema[0][1]
if isinstance(field, ArraySchema):
display, children = _parse_array_children(arr_item=field.items, cls=cls)
display, children = _parse_array_children(
arr_item=field.items, cls=cls, already_parsed=already_parsed
)
sub_type = [None, None]
sub_type[non_null_schema[0][0]] = f"ARRAY<{display}>"
sub_type[non_null_schema[0][0] ^ 1] = "null"
@ -131,7 +142,9 @@ def _parse_union_children(
children = cls(
name=field.name,
dataType=str(field.type).upper(),
children=None if field == parent else get_avro_fields(field, cls),
children=None
if field == parent
else get_avro_fields(field, cls, already_parsed),
description=field.doc,
)
return sub_type, children
@ -139,7 +152,11 @@ def _parse_union_children(
return sub_type, None
def parse_record_fields(field: RecordSchema, cls: Type[BaseModel] = FieldModel):
def parse_record_fields(
field: RecordSchema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
):
"""
Parse the nested record fields for avro
"""
@ -150,7 +167,7 @@ def parse_record_fields(field: RecordSchema, cls: Type[BaseModel] = FieldModel):
cls(
name=field.type.name,
dataType=RECORD_DATATYPE_NAME,
children=get_avro_fields(field.type, cls),
children=get_avro_fields(field.type, cls, already_parsed),
description=field.type.doc,
)
],
@ -163,6 +180,7 @@ def parse_union_fields(
parent: Optional[Schema],
union_field: Schema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
) -> Optional[List[Union[FieldModel, Column]]]:
"""
Parse union field for avro schema
@ -202,7 +220,7 @@ def parse_union_fields(
description=union_field.doc,
)
sub_type, children = _parse_union_children(
union_field=field_type, cls=cls, parent=parent
union_field=field_type, cls=cls, parent=parent, already_parsed=already_parsed
)
obj.dataTypeDisplay = f"UNION<{sub_type}>"
if children and cls == FieldModel:
@ -237,7 +255,7 @@ def parse_avro_schema(
cls(
name=parsed_schema.name,
dataType=str(parsed_schema.type).upper(),
children=get_avro_fields(parsed_schema, cls),
children=get_avro_fields(parsed_schema, cls, {}),
description=parsed_schema.doc,
)
]
@ -249,23 +267,40 @@ def parse_avro_schema(
def get_avro_fields(
parsed_schema: Schema, cls: Type[BaseModel] = FieldModel
parsed_schema: Schema,
cls: Type[BaseModel] = FieldModel,
already_parsed: Optional[dict] = None,
) -> Optional[List[Union[FieldModel, Column]]]:
"""
Recursively convert the parsed schema into required models
"""
field_models = []
if parsed_schema.name in already_parsed:
if already_parsed[parsed_schema.name] == parsed_schema.type:
return None
else:
already_parsed.update({parsed_schema.name: parsed_schema.type})
for field in parsed_schema.fields:
try:
if isinstance(field.type, ArraySchema):
field_models.append(parse_array_fields(field, cls=cls))
field_models.append(
parse_array_fields(field, cls=cls, already_parsed=already_parsed)
)
elif isinstance(field.type, UnionSchema):
field_models.append(
parse_union_fields(union_field=field, cls=cls, parent=parsed_schema)
parse_union_fields(
union_field=field,
cls=cls,
parent=parsed_schema,
already_parsed=already_parsed,
)
)
elif isinstance(field.type, RecordSchema):
field_models.append(parse_record_fields(field, cls=cls))
field_models.append(
parse_record_fields(field, cls=cls, already_parsed=already_parsed)
)
else:
field_models.append(parse_single_field(field, cls=cls))
except Exception as exc: # pylint: disable=broad-except

View File

@ -530,6 +530,55 @@ RECORD_INSIDE_RECORD = """
}
"""
RECURSION_ISSUE_SAMPLE = """
{
"type": "record",
"name": "RecursionIssue",
"namespace": "com.issue.recursion",
"doc": "Schema with recursion issue",
"fields": [
{
"name": "issue",
"type": {
"type": "record",
"name": "Issue",
"doc": "Global Schema Name",
"fields": [
{
"name": "itemList",
"default": null,
"type": [
"null",
{
"type": "array",
"items": {
"type": "record",
"name": "Item",
"doc": "Item List - Array of Sub Schema",
"fields": [
{
"name": "itemList",
"type": [
"null",
{
"type": "array",
"items": "Item"
}
],
"default": null
}
]
}
}
]
}
]
}
}
]
}
"""
class AvroParserTests(TestCase):
"""
@ -747,3 +796,30 @@ class AvroParserTests(TestCase):
.children[0]
.children
)
def test_recursive_issue_parsing(self):
recur_parsed_schema = parse_avro_schema(RECURSION_ISSUE_SAMPLE)
self.assertEqual(
recur_parsed_schema[0]
.children[0]
.children[0]
.children[0]
.children[0]
.name.root,
"Item",
)
self.assertEqual(
recur_parsed_schema[0].children[0].children[0].children[0].name.root,
"itemList",
)
self.assertIsNone(
recur_parsed_schema[0]
.children[0]
.children[0]
.children[0]
.children[0]
.children[0]
.children[0]
.children
)