# datahub/metadata-ingestion/tests/unit/test_gcs_source.py
import pathlib
import re
from unittest import mock

import pytest
from pydantic import ValidationError

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
from datahub.ingestion.source.data_lake_common.data_lake_utils import PLATFORM_GCS
from datahub.ingestion.source.gcs.gcs_source import GCSSource


def test_gcs_source_setup():
graph = mock.MagicMock(spec=DataHubGraph)
    ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")

    # Baseline: valid config
source: dict = {
"path_specs": [
{
"include": "gs://bucket_name/{table}/year={partition[0]}/month={partition[1]}/day={partition[1]}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
"stateful_ingestion": {"enabled": "true"},
}
gcs = GCSSource.create(source, ctx)
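    # GCSSource wraps an S3 source internally; the wrapped source should still report the GCS platform.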
assert gcs.s3_source.source_config.platform == PLATFORM_GCS
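    # create_s3_path should keep the gs:// scheme and decode the %3D-encoded "=" in the partition segments.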
assert (
gcs.s3_source.create_s3_path(
"bucket-name", "food_parquet/year%3D2023/month%3D4/day%3D24/part1.parquet"
)
== "gs://bucket-name/food_parquet/year=2023/month=4/day=24/part1.parquet"
    )


def test_data_lake_incorrect_config_raises_error():
    ctx = PipelineContext(run_id="test-gcs")

    # Case 1: named variable in table name is not present in include
source = {
"path_specs": [{"include": "gs://a/b/c/d/{table}.*", "table_name": "{table1}"}],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
with pytest.raises(ValidationError, match="table_name"):
        GCSSource.create(source, ctx)

    # Case 2: named variable in exclude is not allowed
source = {
"path_specs": [
{
"include": "gs://a/b/c/d/{table}/*.*",
"exclude": ["gs://a/b/c/d/a-{exclude}/**"],
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
with pytest.raises(ValidationError, match=r"exclude.*named variable"):
        GCSSource.create(source, ctx)

    # Case 3: unsupported file type not allowed
source = {
"path_specs": [
{
"include": "gs://a/b/c/d/{table}/*.hd5",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
with pytest.raises(ValidationError, match="file type"):
        GCSSource.create(source, ctx)

    # Case 4: ** in include not allowed
source = {
"path_specs": [
{
"include": "gs://a/b/c/d/**/*.*",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
with pytest.raises(ValidationError, match=r"\*\*"):
        GCSSource.create(source, ctx)


def test_gcs_uri_normalization_fix():
"""Test that GCS URIs are normalized correctly for pattern matching."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source with a path spec that includes table templating
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/year={partition[0]}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Check that the S3 source has the URI normalization method
assert hasattr(gcs_source.s3_source, "_normalize_uri_for_pattern_matching")
# Check that strip_s3_prefix is overridden for GCS
assert hasattr(gcs_source.s3_source, "strip_s3_prefix")
# Test URI normalization
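    # gs:// URIs are rewritten to the s3:// scheme so they can be matched against the
    # path specs, which the wrapped S3 source stores in s3:// form.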
gs_uri = "gs://test-bucket/data/food_parquet/year=2023/file.parquet"
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(gs_uri)
assert normalized_uri == "s3://test-bucket/data/food_parquet/year=2023/file.parquet"
# Test prefix stripping
stripped_uri = gcs_source.s3_source.strip_s3_prefix(gs_uri)
    assert stripped_uri == "test-bucket/data/food_parquet/year=2023/file.parquet"


@pytest.mark.parametrize(
"gs_uri,expected_normalized,expected_stripped",
[
(
"gs://test-bucket/data/food_parquet/year=2023/file.parquet",
"s3://test-bucket/data/food_parquet/year=2023/file.parquet",
"test-bucket/data/food_parquet/year=2023/file.parquet",
),
(
"gs://my-bucket/simple/file.json",
"s3://my-bucket/simple/file.json",
"my-bucket/simple/file.json",
),
(
"gs://bucket/nested/deep/path/data.csv",
"s3://bucket/nested/deep/path/data.csv",
"bucket/nested/deep/path/data.csv",
),
],
)
def test_gcs_uri_transformations(gs_uri, expected_normalized, expected_stripped):
"""Test GCS URI normalization and prefix stripping with various inputs."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Test URI normalization
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(gs_uri)
assert normalized_uri == expected_normalized
# Test prefix stripping
stripped_uri = gcs_source.s3_source.strip_s3_prefix(gs_uri)
    assert stripped_uri == expected_stripped


def test_gcs_path_spec_pattern_matching():
"""Test that GCS path specs correctly match files after URI normalization."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/year={partition[0]}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Get the path spec that was converted to S3 format
s3_path_spec = gcs_source.s3_source.source_config.path_specs[0]
# The path spec should have been converted to S3 format
assert (
s3_path_spec.include
== "s3://test-bucket/data/{table}/year={partition[0]}/*.parquet"
)
# Test that a GCS file URI would be normalized for pattern matching
gs_file_uri = "gs://test-bucket/data/food_parquet/year=2023/file.parquet"
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(
gs_file_uri
)
# Convert the path spec pattern to glob format (similar to what PathSpec.glob_include does)
glob_pattern = re.sub(r"\{[^}]+\}", "*", s3_path_spec.include)
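    # e.g. "s3://test-bucket/data/{table}/year={partition[0]}/*.parquet" -> "s3://test-bucket/data/*/year=*/*.parquet"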
    assert pathlib.PurePath(normalized_uri).match(glob_pattern)


def test_gcs_source_preserves_gs_uris():
"""Test that GCS source preserves gs:// URIs in the final output."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Test that create_s3_path creates GCS URIs
gcs_path = gcs_source.s3_source.create_s3_path(
"test-bucket", "data/food_parquet/file.parquet"
)
assert gcs_path == "gs://test-bucket/data/food_parquet/file.parquet"
# Test that the platform is correctly set
assert gcs_source.s3_source.source_config.platform == PLATFORM_GCS
# Test that container subtypes are correctly set for GCS
    container_creator = gcs_source.s3_source.container_WU_creator
    assert container_creator.get_sub_types() == DatasetContainerSubTypes.GCS_BUCKET


def test_gcs_container_subtypes():
"""Test that GCS containers use 'GCS bucket' subtype instead of 'S3 bucket'."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Verify the platform is set correctly
assert gcs_source.s3_source.source_config.platform == PLATFORM_GCS
# Verify container subtypes use GCS bucket, not S3 bucket
container_creator = gcs_source.s3_source.container_WU_creator
# Should return "GCS bucket" for GCS platform
assert container_creator.get_sub_types() == DatasetContainerSubTypes.GCS_BUCKET
assert container_creator.get_sub_types() == "GCS bucket"
# Should NOT return "S3 bucket"
assert container_creator.get_sub_types() != DatasetContainerSubTypes.S3_BUCKET
assert container_creator.get_sub_types() != "S3 bucket"