# Code-viewer page residue (not part of the script): 128 lines, 3.6 KiB, Python,
# "Raw / Normal View / History" controls.

import json
import types
import unittest.mock
from pathlib import Path
from typing import Dict, Iterable, List, Union
import avro.schema
import click
# 2021-02-11 23:14:20 -08:00  (blame timestamp from the code viewer, not code)
from avrogen import write_schema_files
def load_schema_file(schema_file: str) -> str:
    """Read a JSON schema file and return it re-serialized with 2-space indent.

    Args:
        schema_file: Path of the file to read; its content must be valid JSON.

    Returns:
        The parsed document dumped back to a string with ``indent=2``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding.
    with open(schema_file, encoding="utf-8") as f:
        raw_schema_text = f.read()

    # Round-trip through the JSON parser purely to normalize the whitespace.
    return json.dumps(json.loads(raw_schema_text), indent=2)
def merge_schemas(schemas: List[str]) -> str:
    """Merge several Avro schema JSON documents into a single schema string.

    Every input string is parsed as JSON and the parsed documents are placed
    in a list headed by ``"null"``. That list is handed to
    ``avro.schema.SchemaFromJSONData`` while ``avro.schema.Names.Register``
    is temporarily patched so that a schema whose ``fullname`` is already
    registered is skipped and the first registration is kept.

    Returns:
        The ``to_json()`` form of the resulting schema, serialized as JSON
        with ``indent=2``.
    """
    union_members = ["null"]
    union_members.extend(json.loads(text) for text in schemas)

    def _register_unless_known(self, schema):
        # Stand-in for Names.Register: store a schema only when nothing is
        # stored under its fullname yet; repeated names are ignored.
        if schema.fullname not in self._names:
            self._names[schema.fullname] = schema

    with unittest.mock.patch("avro.schema.Names.Register", _register_unless_known):
        parsed = avro.schema.SchemaFromJSONData(union_members)

    class MappingProxyEncoder(json.JSONEncoder):
        # Serializes mappingproxy objects as plain dicts; every other type is
        # left to the stock JSONEncoder.
        def default(self, obj):
            if not isinstance(obj, types.MappingProxyType):
                return super().default(obj)
            return dict(obj)

    return json.dumps(parsed.to_json(), cls=MappingProxyEncoder, indent=2)
# Text placed at the very top of each generated Python file: a flake8
# suppression, a note that the file is autogenerated, and the opening
# "fmt: off" marker.
autogen_header = """# flake8: noqa
# This file is autogenerated by /metadata-ingestion/scripts/avro_codegen.py
# Do not modify manually!
# fmt: off
"""
# Text placed at the very end of each generated Python file; it closes the
# "fmt: off" region opened by the header.
autogen_footer = "# fmt: on\n"
def suppress_checks_in_file(filepath: Union[str, Path]) -> None:
    """
    Wrap an autogenerated file, in place, between the autogen header and footer.

    The header carries the flake8/black suppression comments and the
    "autogenerated" notice; the footer re-enables formatting. The existing
    file content is kept unchanged between the two.
    """
    with open(filepath, "r+") as handle:
        body = handle.read()
        # Rewind so the wrapped text overwrites the file from its first byte.
        # The new text is always longer than the old, so nothing stale remains.
        handle.seek(0)
        handle.write("".join((autogen_header, body, autogen_footer)))
# Python source emitted once at the start of the generated schema loader
# module: its imports plus a helper that reads "<schema_name>.avsc" from the
# directory containing that module. The function body inside the literal must
# be indented, otherwise the generated module fails to import.
load_schema_method = """
import functools
import pathlib
def _load_schema(schema_name: str) -> str:
    return (pathlib.Path(__file__).parent / f"{schema_name}.avsc").read_text()
"""
# Template for one cached accessor per schema; "{schema_name}" is filled in
# with str.format, producing a function named get<schema_name>Schema.
individual_schema_method = """
@functools.lru_cache(maxsize=None)
def get{schema_name}Schema() -> str:
    return _load_schema("{schema_name}")
"""
def make_load_schema_methods(schemas: Iterable[str]) -> str:
    """Build the Python source for the schema loader module.

    Args:
        schemas: Schema names; each one yields one ``get<name>Schema``
            accessor rendered from ``individual_schema_method``.

    Returns:
        ``load_schema_method`` followed by one rendered accessor per name,
        in iteration order.
    """
    return load_schema_method + "".join(
        individual_schema_method.format(schema_name=schema) for schema in schemas
    )
@click.command()
@click.argument("schema_files", type=click.Path(exists=True), nargs=-1, required=True)
@click.argument("outdir", type=click.Path(), required=True)
def generate(schema_files: List[str], outdir: str) -> None:
    # Command entry point: load every schema file, merge them, write the
    # generated classes into outdir, store the raw schemas next to them with
    # loader functions, then stamp every generated .py file with the autogen
    # header/footer.
    # (Documented with comments rather than a docstring because click would
    # display a docstring as the command's help text.)

    # Keyed by file stem; a later file with the same stem replaces an earlier one.
    schemas: Dict[str, str] = {
        Path(path).stem: load_schema_file(path) for path in schema_files
    }

    write_schema_files(merge_schemas(list(schemas.values())), outdir)

    # Opening in "w" mode and closing straight away leaves an empty __init__.py.
    open(f"{outdir}/__init__.py", "w").close()

    # Keep the raw schema documents alongside the generated code.
    raw_dir = Path(outdir) / "schemas"
    raw_dir.mkdir()
    for stem, text in schemas.items():
        raw_dir.joinpath(f"{stem}.avsc").write_text(text)

    # Append the loader functions to the schema directory's package file.
    with open(raw_dir / "__init__.py", "a") as init_file:
        init_file.write(make_load_schema_methods(schemas))

    # Wrap every generated Python file in the autogen header and footer.
    for py_file in Path(outdir).glob("**/*.py"):
        suppress_checks_in_file(py_file)
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    generate()