mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-02 11:49:23 +00:00
192 lines
6.9 KiB
Python
192 lines
6.9 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from datahub.ingestion.source.data_lake_common.data_lake_utils import (
|
|
add_partition_columns_to_schema,
|
|
)
|
|
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
|
|
from datahub.metadata.schema_classes import (
|
|
NumberTypeClass,
|
|
SchemaFieldClass,
|
|
SchemaFieldDataTypeClass,
|
|
StringTypeClass,
|
|
)
|
|
|
|
|
|
class TestAddPartitionColumnsToSchema:
    """Tests for ``add_partition_columns_to_schema``.

    Every test passes a mocked ``PathSpec`` and a list of schema fields to the
    helper and then inspects that same list, i.e. the helper is expected to
    append partition columns to the list it is given.
    """

    def create_mock_path_spec(
        self, partition_result: Optional[List[Tuple[str, str]]] = None
    ) -> PathSpec:
        """Return a ``PathSpec`` mock with a canned partition lookup result.

        Args:
            partition_result: Value that ``get_partition_from_path`` returns
                for any path; ``(key, value)`` pairs, an empty list, or
                ``None``.
        """
        mock_path_spec = MagicMock(spec=PathSpec)
        mock_path_spec.get_partition_from_path.return_value = partition_result
        return mock_path_spec

    def create_schema_field(
        self,
        field_path: str,
        field_type: Optional[Union[StringTypeClass, NumberTypeClass]] = None,
    ) -> SchemaFieldClass:
        """Build a non-nullable, non-recursive schema field for test input.

        Args:
            field_path: Value used for the field's ``fieldPath``.
            field_type: Type instance wrapped in ``SchemaFieldDataTypeClass``.
                When omitted, a new ``StringTypeClass`` is created per call.
        """
        # Instantiate the default here rather than in the signature so that
        # fields built by different calls never share one type object.
        if field_type is None:
            field_type = StringTypeClass()
        return SchemaFieldClass(
            fieldPath=field_path,
            nativeDataType="string",
            type=SchemaFieldDataTypeClass(field_type),
            nullable=False,
            recursive=False,
        )

    @pytest.mark.parametrize(
        "partition_keys,expected_field_paths",
        [
            # Simple partition
            ([("year", "2023")], ["year"]),
            # Multiple partitions
            ([("year", "2023"), ("month", "01")], ["year", "month"]),
            # Complex partition keys with underscores
            (
                [("partition_0", "value1"), ("partition_1", "value2")],
                ["partition_0", "partition_1"],
            ),
            # User's complex case
            (
                [("date", "2023-01-01"), ("region", "us-east"), ("category", "sales")],
                ["date", "region", "category"],
            ),
        ],
    )
    def test_add_partition_columns_basic(self, partition_keys, expected_field_paths):
        """One string partition column is appended per partition key, in order."""
        # Setup
        path_spec = self.create_mock_path_spec(partition_keys)
        fields = [
            self.create_schema_field("existing_field1"),
            self.create_schema_field("existing_field2"),
        ]
        original_field_count = len(fields)

        # Execute
        add_partition_columns_to_schema(path_spec, "/test/path", fields)

        # Assert
        assert len(fields) == original_field_count + len(expected_field_paths)

        # Check that partition fields were added correctly; the length check
        # above guarantees both sequences line up one-to-one.
        partition_fields = fields[original_field_count:]
        for field, expected_path in zip(partition_fields, expected_field_paths):
            assert field.fieldPath == expected_path
            assert field.isPartitioningKey is True
            assert field.nullable is False
            assert isinstance(field.type.type, StringTypeClass)
            assert field.nativeDataType == "string"

    @pytest.mark.parametrize(
        "existing_fields,expected_v2_format",
        [
            # v1 format only
            (["regular_field", "another_field"], False),
            # v2 format detected
            (["[version=2.0].[type=string].existing_field", "regular_field"], True),
            # Mixed with v2 present
            (["regular_field", "[version=2.0].[type=int].id", "another_field"], True),
            # Empty fields list
            ([], False),
        ],
    )
    def test_fieldpath_version_detection(self, existing_fields, expected_v2_format):
        """The appended field path uses the v2 prefix iff a v2 path already exists."""
        path_spec = self.create_mock_path_spec([("year", "2023")])
        fields = [
            self.create_schema_field(field_path) for field_path in existing_fields
        ]
        original_field_count = len(fields)

        add_partition_columns_to_schema(path_spec, "/test/path", fields)

        # Check the count first so a missing column fails as an assertion
        # instead of an IndexError on the lookup below.
        assert len(fields) == original_field_count + 1

        expected_path = (
            "[version=2.0].[type=string].year" if expected_v2_format else "year"
        )
        assert fields[original_field_count].fieldPath == expected_path

    @pytest.mark.parametrize(
        "partition_result",
        [
            None,  # No partitions detected
            [],  # Empty partition list
        ],
    )
    def test_no_partitions_detected(self, partition_result):
        """The field list is left unchanged when no partitions are reported."""
        path_spec = self.create_mock_path_spec(partition_result)
        fields = [
            self.create_schema_field("existing_field1"),
            self.create_schema_field("existing_field2"),
        ]
        # Shallow snapshot: detects added, removed or replaced list entries.
        original_fields = fields.copy()

        add_partition_columns_to_schema(path_spec, "/test/path", fields)

        assert fields == original_fields

    def test_preserves_existing_fields(self):
        """Pre-existing fields keep their position and attributes."""
        path_spec = self.create_mock_path_spec([("year", "2023")])
        original_field = self.create_schema_field("existing_field", NumberTypeClass())
        original_field.isPartitioningKey = False
        fields = [original_field]

        add_partition_columns_to_schema(path_spec, "/test/path", fields)

        # The original field is still first and untouched.
        assert fields[0] == original_field
        assert fields[0].fieldPath == "existing_field"
        assert isinstance(fields[0].type.type, NumberTypeClass)
        assert fields[0].isPartitioningKey is False

        # The partition column follows it.
        assert len(fields) == 2
        assert fields[1].fieldPath == "year"
        assert fields[1].isPartitioningKey is True

    def test_empty_fields_list(self):
        """A partition column is added even when the schema starts empty."""
        path_spec = self.create_mock_path_spec([("year", "2023")])
        fields: List[SchemaFieldClass] = []

        add_partition_columns_to_schema(path_spec, "/test/path", fields)

        assert len(fields) == 1
        assert fields[0].fieldPath == "year"
        assert fields[0].isPartitioningKey is True

    def test_real_world_complex_partition_scenario(self):
        """Three key=value partitions on a realistic blob URL yield three columns."""
        # This simulates the user's path spec:
        # /{partition_key[0]}={partition[0]}/{partition_key[1]}={partition[1]}/{partition_key[2]}={partition[2]}/
        partition_keys = [
            ("date", "2023-12-01"),
            ("region", "us-east"),
            ("category", "sales"),
        ]

        # The PathSpec is mocked, so the partitions come from the list above
        # rather than from parsing the URL passed below.
        path_spec = self.create_mock_path_spec(partition_keys)
        fields = [
            self.create_schema_field("customer_id"),
            self.create_schema_field("amount"),
            self.create_schema_field("transaction_date"),
        ]
        original_field_count = len(fields)

        add_partition_columns_to_schema(
            path_spec,
            "https://odedmdatacataloggold.blob.core.windows.net/settler/transactions/partitioned/date=2023-12-01/region=us-east/category=sales/data.parquet",
            fields,
        )

        assert len(fields) == original_field_count + 3

        partition_fields = fields[original_field_count:]
        expected_partitions = ["date", "region", "category"]

        # Lengths were asserted equal above, so zip pairs every new field.
        for field, expected_name in zip(partition_fields, expected_partitions):
            assert field.fieldPath == expected_name
            assert field.isPartitioningKey is True
            assert field.nullable is False
            assert isinstance(field.type.type, StringTypeClass)
|