datahub/metadata-ingestion/tests/unit/test_glue_source.py
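"""Unit tests for the AWS Glue ingestion source.

Covers Hive-to-Avro column type mapping, a stubbed end-to-end ingestion run
diffed against golden MCE files, and the platform-name resolution rules.
"""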


import json
from pathlib import Path
from typing import Dict, Optional, Tuple, Type

import pytest
from botocore.stub import Stubber
from freezegun import freeze_time

from datahub.configuration.common import ConfigurationError
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
from datahub.ingestion.source.aws.glue import GlueSource, GlueSourceConfig
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
    ArrayTypeClass,
    MapTypeClass,
    RecordTypeClass,
    StringTypeClass,
)
from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from tests.test_helpers import mce_helpers
from tests.test_helpers.type_helpers import PytestConfig
from tests.unit.test_glue_source_stubs import (
    get_bucket_tagging,
    get_databases_response,
    get_dataflow_graph_response_1,
    get_dataflow_graph_response_2,
    get_jobs_response,
    get_object_body_1,
    get_object_body_2,
    get_object_response_1,
    get_object_response_2,
    get_object_tagging,
    get_tables_response_1,
    get_tables_response_2,
)

FROZEN_TIME = "2020-04-14 07:00:00"
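

# Helper to construct a GlueSource configured for transform extraction and
# S3 bucket/object tag ingestion against a fixed test region.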
def glue_source(platform_instance: Optional[str] = None) -> GlueSource:
    return GlueSource(
        ctx=PipelineContext(run_id="glue-source-test"),
        config=GlueSourceConfig(
            aws_region="us-west-2",
            extract_transforms=True,
            platform_instance=platform_instance,
            use_s3_bucket_tags=True,
            use_s3_object_tags=True,
        ),
    )
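

# Hive column type strings mapped to the schema field type classes we expect
# the Avro conversion to produce.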
column_type_test_cases: Dict[str, Tuple[str, Type]] = {
    "char": ("char", StringTypeClass),
    "array": ("array<int>", ArrayTypeClass),
    "map": ("map<string, int>", MapTypeClass),
    "struct": ("struct<a:int, b:string>", RecordTypeClass),
}
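

# Each Hive type should round-trip through get_avro_schema_for_hive_column /
# avro_schema_to_mce_fields and surface as the expected field type class.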
@pytest.mark.parametrize(
    "hive_column_type, expected_type",
    column_type_test_cases.values(),
    ids=column_type_test_cases.keys(),
)
def test_column_type(hive_column_type: str, expected_type: Type) -> None:
    avro_schema = get_avro_schema_for_hive_column(
        f"test_column_{hive_column_type}", hive_column_type
    )
    schema_fields = avro_schema_to_mce_fields(json.dumps(avro_schema))
    actual_schema_field_type = schema_fields[0].type
    assert type(actual_schema_field_type.type) == expected_type
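

# End-to-end ingestion test: stub every Glue and S3 API call the source makes,
# run it, and diff the emitted MCEs against a golden file, both with and
# without a platform instance configured.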
@pytest.mark.parametrize(
    "platform_instance, mce_file, mce_golden_file",
    [
        (None, "glue_mces.json", "glue_mces_golden.json"),
        (
            "some_instance_name",
            "glue_mces_platform_instance.json",
            "glue_mces_platform_instance_golden.json",
        ),
    ],
)
@freeze_time(FROZEN_TIME)
def test_glue_ingest(
    tmp_path: Path,
    pytestconfig: PytestConfig,
    platform_instance: Optional[str],
    mce_file: str,
    mce_golden_file: str,
) -> None:
    glue_source_instance = glue_source(platform_instance=platform_instance)

    with Stubber(glue_source_instance.glue_client) as glue_stubber:
        # Stub the Glue API responses in the order the source will request them.
        glue_stubber.add_response("get_databases", get_databases_response, {})
        glue_stubber.add_response(
            "get_tables",
            get_tables_response_1,
            {"DatabaseName": "flights-database"},
        )
        glue_stubber.add_response(
            "get_tables",
            get_tables_response_2,
            {"DatabaseName": "test-database"},
        )
        glue_stubber.add_response("get_jobs", get_jobs_response, {})
        glue_stubber.add_response(
            "get_dataflow_graph",
            get_dataflow_graph_response_1,
            {"PythonScript": get_object_body_1},
        )
        glue_stubber.add_response(
            "get_dataflow_graph",
            get_dataflow_graph_response_2,
            {"PythonScript": get_object_body_2},
        )

        with Stubber(glue_source_instance.s3_client) as s3_stubber:
            # One bucket-tagging and one object-tagging response per table.
            for _ in range(
                len(get_tables_response_1["TableList"])
                + len(get_tables_response_2["TableList"])
            ):
                s3_stubber.add_response(
                    "get_bucket_tagging",
                    get_bucket_tagging(),
                )
                s3_stubber.add_response(
                    "get_object_tagging",
                    get_object_tagging(),
                )

            # The source then fetches each job's script from S3.
            s3_stubber.add_response(
                "get_object",
                get_object_response_1(),
                {
                    "Bucket": "aws-glue-assets-123412341234-us-west-2",
                    "Key": "scripts/job-1.py",
                },
            )
            s3_stubber.add_response(
                "get_object",
                get_object_response_2(),
                {
                    "Bucket": "aws-glue-assets-123412341234-us-west-2",
                    "Key": "scripts/job-2.py",
                },
            )

            mce_objects = [
                wu.metadata.to_obj() for wu in glue_source_instance.get_workunits()
            ]

            glue_stubber.assert_no_pending_responses()
            s3_stubber.assert_no_pending_responses()

    with open(str(tmp_path / mce_file), "w") as f:
        json.dump(mce_objects, f, indent=2)

    # Verify the output.
    test_resources_dir = pytestconfig.rootpath / "tests/unit/glue"
    mce_helpers.check_golden_file(
        pytestconfig,
        output_path=tmp_path / mce_file,
        golden_path=test_resources_dir / mce_golden_file,
    )
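

# Platform-name resolution: an explicit `platform` wins over the legacy
# `underlying_platform` setting, which in turn defaults to "glue"; invalid
# values for either are rejected with a ConfigurationError.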
def test_underlying_platform_takes_precedence():
    source = GlueSource(
        ctx=PipelineContext(run_id="glue-source-test"),
        config=GlueSourceConfig(aws_region="us-west-2", underlying_platform="athena"),
    )
    assert source.platform == "athena"


def test_platform_takes_precedence_over_underlying_platform():
    source = GlueSource(
        ctx=PipelineContext(run_id="glue-source-test"),
        config=GlueSourceConfig(
            aws_region="us-west-2", platform="athena", underlying_platform="glue"
        ),
    )
    assert source.platform == "athena"


def test_underlying_platform_must_be_valid():
    with pytest.raises(ConfigurationError):
        GlueSource(
            ctx=PipelineContext(run_id="glue-source-test"),
            config=GlueSourceConfig(
                aws_region="us-west-2", underlying_platform="data-warehouse"
            ),
        )


def test_platform_must_be_valid():
    with pytest.raises(ConfigurationError):
        GlueSource(
            ctx=PipelineContext(run_id="glue-source-test"),
            config=GlueSourceConfig(aws_region="us-west-2", platform="data-warehouse"),
        )


def test_without_underlying_platform():
    source = GlueSource(
        ctx=PipelineContext(run_id="glue-source-test"),
        config=GlueSourceConfig(aws_region="us-west-2"),
    )
    assert source.platform == "glue"
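

# Golden files live under tests/unit/glue/. If an intentional output change
# requires regenerating them, the repo's golden-file update flow applies
# (typically a pytest flag such as --update-golden-files; check the current
# test helpers before relying on the exact name).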