# Mirror of https://github.com/datahub-project/datahub.git
# Synced 2025-07-09 10:12:20 +00:00
#
# Co-authored-by: Aditya <aditya.malik@quillbot.com>
# Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
#
# 138 lines, 4.2 KiB, Python
import tempfile
|
|
from typing import List, Type
|
|
|
|
import pandas as pd
|
|
import ujson
|
|
from avro import schema as avro_schema
|
|
from avro.datafile import DataFileWriter
|
|
from avro.io import DatumWriter
|
|
|
|
from datahub.ingestion.source.schema_inference import csv_tsv, json, parquet
|
|
from datahub.ingestion.source.schema_inference.avro import AvroInferrer
|
|
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
|
|
BooleanTypeClass,
|
|
NumberTypeClass,
|
|
SchemaField,
|
|
StringTypeClass,
|
|
)
|
|
from tests.unit.test_schema_util import assert_field_paths_match
|
|
|
|
# Simple field names shared by all non-Avro inference tests.
expected_field_paths = [
    "integer_field",
    "boolean_field",
    "string_field",
]

# Avro inference emits full v2 field paths (version/type-prefixed).
expected_field_paths_avro = [
    "[version=2.0].[type=test].[type=int].integer_field",
    "[version=2.0].[type=test].[type=boolean].boolean_field",
    "[version=2.0].[type=test].[type=string].string_field",
]

# Expected DataHub type classes, positionally aligned with the field lists above.
expected_field_types = [NumberTypeClass, BooleanTypeClass, StringTypeClass]

# Small fixture table that each test serializes into the format under test.
test_table = pd.DataFrame.from_dict(
    {
        "integer_field": [1, 2, 3],
        "boolean_field": [True, False, True],
        "string_field": ["a", "b", "c"],
    }
)
|
|
|
|
|
|
def assert_field_types_match(
    fields: List[SchemaField], expected_field_types: List[Type]
) -> None:
    """Assert that each field's type class matches the expected one, position by position.

    Raises AssertionError if the lists differ in length or if any field's
    ``type.type`` is not an instance of the corresponding expected class.
    """
    assert len(fields) == len(expected_field_types)
    for position, expected_type in enumerate(expected_field_types):
        assert isinstance(fields[position].type.type, expected_type)
|
|
|
|
|
|
def test_infer_schema_csv():
    """CsvInferrer recovers field names and type classes from a CSV dump of test_table."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        csv_payload = test_table.to_csv(index=False, header=True)
        file.write(csv_payload.encode("utf-8"))
        file.seek(0)

        fields = csv_tsv.CsvInferrer(max_rows=100).infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
|
|
|
|
|
|
def test_infer_schema_tsv():
    """TsvInferrer recovers field names and type classes from a tab-separated dump."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        tsv_payload = test_table.to_csv(index=False, header=True, sep="\t")
        file.write(tsv_payload.encode("utf-8"))
        file.seek(0)

        fields = csv_tsv.TsvInferrer(max_rows=100).infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
|
|
|
|
|
|
def test_infer_schema_jsonl():
    """JsonInferrer in jsonl mode handles newline-delimited JSON records."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        jsonl_payload = test_table.to_json(orient="records", lines=True)
        file.write(jsonl_payload.encode("utf-8"))
        file.seek(0)

        fields = json.JsonInferrer(max_rows=100, format="jsonl").infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
|
|
|
|
|
|
def test_infer_schema_json():
    """JsonInferrer with default settings handles a single JSON array of records."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        json_payload = test_table.to_json(orient="records")
        file.write(json_payload.encode("utf-8"))
        file.seek(0)

        fields = json.JsonInferrer().infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
|
|
|
|
|
|
def test_infer_schema_parquet():
    """ParquetInferrer recovers field names and type classes from a Parquet dump."""
    with tempfile.TemporaryFile(mode="w+b") as file:
        # pandas writes Parquet directly to the binary file object.
        test_table.to_parquet(file)
        file.seek(0)

        fields = parquet.ParquetInferrer().infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
|
|
|
|
|
|
def test_infer_schema_avro():
    """AvroInferrer produces full v2 field paths and the correct type classes."""
    record_schema = {
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "integer_field", "type": "int"},
            {"name": "boolean_field", "type": "boolean"},
            {"name": "string_field", "type": "string"},
        ],
    }
    with tempfile.TemporaryFile(mode="w+b") as file:
        schema = avro_schema.parse(ujson.dumps(record_schema))
        writer = DataFileWriter(file, DatumWriter(), schema)
        for record in test_table.to_dict(orient="records"):
            writer.append(record)
        # sync() flushes the buffered records to the file without closing the
        # underlying file object, which we still need to read back below.
        writer.sync()

        file.seek(0)

        fields = AvroInferrer().infer_schema(file)

        assert_field_paths_match(fields, expected_field_paths_avro)
        assert_field_types_match(fields, expected_field_types)