datahub/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
Aditya Malik 92b1cfa194
feat(ingest): Support for JSONL in s3 source with max_rows support (#9921)
Co-authored-by: Aditya <aditya.malik@quillbot.com>
Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
2024-02-28 15:05:30 +01:00

138 lines
4.2 KiB
Python

import tempfile
from typing import List, Type
import pandas as pd
import ujson
from avro import schema as avro_schema
from avro.datafile import DataFileWriter
from avro.io import DatumWriter
from datahub.ingestion.source.schema_inference import csv_tsv, json, parquet
from datahub.ingestion.source.schema_inference.avro import AvroInferrer
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
BooleanTypeClass,
NumberTypeClass,
SchemaField,
StringTypeClass,
)
from tests.unit.test_schema_util import assert_field_paths_match
# Column names every non-Avro inferrer is expected to report, in column order.
expected_field_paths = ["integer_field", "boolean_field", "string_field"]

# Avro paths carry the full type-annotated prefix emitted by the Avro inferrer.
expected_field_paths_avro = [
    "[version=2.0].[type=test].[type=int].integer_field",
    "[version=2.0].[type=test].[type=boolean].boolean_field",
    "[version=2.0].[type=test].[type=string].string_field",
]

# DataHub type class expected for each column, in the same order as the paths.
expected_field_types = [NumberTypeClass, BooleanTypeClass, StringTypeClass]

# Small fixture table serialized into each format under test.
test_table = pd.DataFrame.from_dict(
    {
        "integer_field": [1, 2, 3],
        "boolean_field": [True, False, True],
        "string_field": ["a", "b", "c"],
    }
)
def assert_field_types_match(
    fields: List[SchemaField], expected_field_types: List[Type]
) -> None:
    """Assert each inferred field's type class matches the expected class, pairwise."""
    assert len(fields) == len(expected_field_types)
    for actual_field, expected_cls in zip(fields, expected_field_types):
        assert isinstance(actual_field.type.type, expected_cls)
def test_infer_schema_csv():
    """CSV inference recovers the fixture table's field paths and types."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        csv_payload = test_table.to_csv(index=False, header=True)
        file.write(csv_payload.encode("utf-8"))
        file.seek(0)
        fields = csv_tsv.CsvInferrer(max_rows=100).infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
def test_infer_schema_tsv():
    """TSV inference (tab separator) recovers the same field paths and types."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        tsv_payload = test_table.to_csv(index=False, header=True, sep="\t")
        file.write(tsv_payload.encode("utf-8"))
        file.seek(0)
        fields = csv_tsv.TsvInferrer(max_rows=100).infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
def test_infer_schema_jsonl():
    """Line-delimited JSON inference works when the inferrer is told format='jsonl'."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        jsonl_payload = test_table.to_json(orient="records", lines=True)
        file.write(jsonl_payload.encode("utf-8"))
        file.seek(0)
        fields = json.JsonInferrer(max_rows=100, format="jsonl").infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
def test_infer_schema_json():
    """Plain JSON-array inference uses the inferrer's default settings."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        json_payload = test_table.to_json(orient="records")
        file.write(json_payload.encode("utf-8"))
        file.seek(0)
        fields = json.JsonInferrer().infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
def test_infer_schema_parquet():
    """Parquet inference recovers the fixture table's field paths and types."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        # pandas writes binary parquet directly to the file object.
        test_table.to_parquet(file)
        file.seek(0)
        fields = parquet.ParquetInferrer().infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
def test_infer_schema_avro():
    """Avro inference reports fully type-annotated paths plus the expected types."""
    # Record schema mirroring the fixture table's three columns.
    avro_record_schema = {
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "integer_field", "type": "int"},
            {"name": "boolean_field", "type": "boolean"},
            {"name": "string_field", "type": "string"},
        ],
    }
    with tempfile.TemporaryFile(mode="w+b") as file:
        schema = avro_schema.parse(ujson.dumps(avro_record_schema))
        writer = DataFileWriter(file, DatumWriter(), schema)
        for record in test_table.to_dict(orient="records"):
            writer.append(record)
        # sync() flushes the pending block; writer.close() is deliberately not
        # called because it would also close the temporary file we still read.
        writer.sync()
        file.seek(0)
        fields = AvroInferrer().infer_schema(file)
        assert_field_paths_match(fields, expected_field_paths_avro)
        assert_field_types_match(fields, expected_field_types)