fix(ingest/gcs): Fix GCS URI mismatch that caused files to be incorrectly filtered out during ingestion (#14006)

Tamas Nemeth 2025-07-09 21:22:22 +01:00 committed by GitHub
parent 4807946d46
commit b1354abcba
8 changed files with 390 additions and 15 deletions

View File

@@ -531,7 +531,7 @@ plugins: Dict[str, Set[str]] = {
| {"db-dtypes"} # Pandas extension data types
| cachetools_lib,
"s3": {*s3_base, *data_lake_profiling},
"gcs": {*s3_base, *data_lake_profiling},
"gcs": {*s3_base, *data_lake_profiling, "smart-open[gcs]>=5.2.1"},
"abs": {*abs_base, *data_lake_profiling},
"sagemaker": aws_common,
"salesforce": {"simple-salesforce", *cachetools_lib},

View File

@@ -17,10 +17,9 @@ def _load_lineage_data() -> Dict:
Load lineage data from the autogenerated lineage.json file.
Returns:
Dict containing the lineage information
Dict containing the lineage information, or empty dict if file doesn't exist
Raises:
FileNotFoundError: If lineage.json doesn't exist
json.JSONDecodeError: If lineage.json is malformed
"""
global _lineage_data
@@ -33,16 +32,23 @@ def _load_lineage_data() -> Dict:
lineage_file = current_file.parent / "lineage.json"
if not lineage_file.exists():
raise FileNotFoundError(f"Lineage file not found: {lineage_file}")
logger.warning(
f"Lineage file not found: {lineage_file}. "
"This may indicate a packaging issue. Lineage detection will be disabled."
)
_lineage_data = {}
return _lineage_data
try:
with open(lineage_file, "r") as f:
_lineage_data = json.load(f)
return _lineage_data
except json.JSONDecodeError as e:
raise json.JSONDecodeError(
f"Failed to parse lineage.json: {e}", e.doc, e.pos
) from e
logger.error(
f"Failed to parse lineage.json: {e}. Lineage detection will be disabled."
)
_lineage_data = {}
return _lineage_data
def _get_fields(entity_type: str, aspect_name: str) -> List[Dict]:
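
A small illustration of what the new contract buys downstream code (the helper below and the nested layout of lineage.json are hypothetical, written only to show the idea): a missing or malformed lineage.json now surfaces as an empty dict, so lookups can degrade to "no lineage" instead of forcing callers to catch exceptions.

```python
from typing import Dict, List


def get_lineage_fields(lineage_data: Dict, entity_type: str, aspect_name: str) -> List[Dict]:
    """Hypothetical consumer: an empty dict means 'lineage detection disabled'."""
    if not lineage_data:
        return []
    # Assumed layout: {entity_type: {aspect_name: {"fields": [...]}}}
    return lineage_data.get(entity_type, {}).get(aspect_name, {}).get("fields", [])


# With the old behavior this required catching FileNotFoundError / JSONDecodeError;
# with the fallback, the empty dict keeps the caller a one-liner.
print(get_lineage_fields({}, "dataset", "upstreamLineage"))  # []
```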

View File

@@ -519,6 +519,13 @@ class ObjectStoreSourceAdapter:
"get_external_url",
lambda table_data: self.get_gcs_external_url(table_data),
)
# Fix URI mismatch issue in pattern matching
self.register_customization(
"_normalize_uri_for_pattern_matching",
self._normalize_gcs_uri_for_pattern_matching,
)
# Fix URI handling in schema extraction - override strip_s3_prefix for GCS
self.register_customization("strip_s3_prefix", self._strip_gcs_prefix)
elif platform == "s3":
self.register_customization("is_s3_platform", lambda: True)
self.register_customization("create_s3_path", self.create_s3_path)
@@ -612,6 +619,39 @@ class ObjectStoreSourceAdapter:
return self.get_abs_external_url(table_data)
return None
def _normalize_gcs_uri_for_pattern_matching(self, uri: str) -> str:
"""
Normalize GCS URI for pattern matching.
This method converts gs:// URIs to s3:// URIs for pattern matching purposes,
fixing the URI mismatch issue in GCS ingestion.
Args:
uri: The URI to normalize
Returns:
The normalized URI for pattern matching
"""
if uri.startswith("gs://"):
return uri.replace("gs://", "s3://", 1)
return uri
def _strip_gcs_prefix(self, uri: str) -> str:
"""
Strip GCS prefix from URI.
This method removes the gs:// prefix from GCS URIs for path processing.
Args:
uri: The URI to strip the prefix from
Returns:
The URI without the gs:// prefix
"""
if uri.startswith("gs://"):
return uri[5:] # Remove "gs://" prefix
return uri
# Factory function to create an adapter for a specific platform
def create_object_store_adapter(
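
To make the two new helpers concrete, a short usage sketch (not part of the commit; the example URI is made up, while `create_object_store_adapter` and both method names come from the diff above):

```python
from datahub.ingestion.source.data_lake_common.object_store import (
    create_object_store_adapter,
)

adapter = create_object_store_adapter("gcs")
uri = "gs://my-bucket/events/food_parquet/part-0000.parquet"

# gs:// is rewritten to s3:// only for pattern matching, so path_specs that were
# converted to s3://... still match GCS files.
assert (
    adapter._normalize_gcs_uri_for_pattern_matching(uri)
    == "s3://my-bucket/events/food_parquet/part-0000.parquet"
)

# For browse paths / schema extraction, the gs:// prefix is stripped instead,
# mirroring what strip_s3_prefix does for s3:// URIs.
assert adapter._strip_gcs_prefix(uri) == "my-bucket/events/food_parquet/part-0000.parquet"
```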

View File

@@ -112,6 +112,7 @@ class GCSSource(StatefulIngestionSourceBase):
env=self.config.env,
max_rows=self.config.max_rows,
number_of_files_to_sample=self.config.number_of_files_to_sample,
platform=PLATFORM_GCS, # Ensure GCS platform is used for correct container subtypes
)
return s3_config
@@ -138,7 +139,9 @@
def create_equivalent_s3_source(self, ctx: PipelineContext) -> S3Source:
config = self.create_equivalent_s3_config()
s3_source = S3Source(config, PipelineContext(ctx.run_id))
# Create a new context for S3 source without graph to avoid duplicate checkpointer registration
s3_ctx = PipelineContext(run_id=ctx.run_id, pipeline_name=ctx.pipeline_name)
s3_source = S3Source(config, s3_ctx)
return self.s3_source_overrides(s3_source)
def s3_source_overrides(self, source: S3Source) -> S3Source:
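
A rough sketch of the wrapping flow this hunk changes (import paths are assumed to mirror the test module later in this PR, the config dict is illustrative, and the final assert assumes `PLATFORM_GCS` resolves to the string "gcs"):

```python
from unittest import mock

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.gcs.gcs_source import GCSSource

graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="gcs-run", graph=graph, pipeline_name="gcs-pipeline")

source_dict = {
    "path_specs": [
        {"include": "gs://test-bucket/data/{table}/*.parquet", "table_name": "{table}"}
    ],
    "credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}

# GCSSource.create builds the wrapped S3Source against a fresh PipelineContext
# (same run_id / pipeline_name, no graph), so only the outer GCS source registers
# a stateful-ingestion checkpointer.
gcs_source = GCSSource.create(source_dict, ctx)

# The equivalent S3 config is pinned to the GCS platform so container subtypes
# and URNs stay on gcs rather than s3.
assert gcs_source.s3_source.source_config.platform == "gcs"
```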

View File

@@ -682,7 +682,7 @@ class S3Source(StatefulIngestionSourceBase):
logger.info(f"Extracting table schema from file: {table_data.full_path}")
browse_path: str = (
strip_s3_prefix(table_data.table_path)
self.strip_s3_prefix(table_data.table_path)
if self.is_s3_platform()
else table_data.table_path.strip("/")
)
@@ -949,7 +949,10 @@
"""
def _is_allowed_path(path_spec_: PathSpec, s3_uri: str) -> bool:
allowed = path_spec_.allowed(s3_uri)
# Normalize URI for pattern matching
normalized_uri = self._normalize_uri_for_pattern_matching(s3_uri)
allowed = path_spec_.allowed(normalized_uri)
if not allowed:
logger.debug(f"File {s3_uri} not allowed and skipping")
self.report.report_file_dropped(s3_uri)
@@ -1394,8 +1397,13 @@
)
table_dict: Dict[str, TableData] = {}
for browse_path in file_browser:
# Normalize URI for pattern matching
normalized_file_path = self._normalize_uri_for_pattern_matching(
browse_path.file
)
if not path_spec.allowed(
browse_path.file,
normalized_file_path,
ignore_ext=self.is_s3_platform()
and self.source_config.use_s3_content_type,
):
@@ -1471,5 +1479,13 @@
def is_s3_platform(self):
return self.source_config.platform == "s3"
def strip_s3_prefix(self, s3_uri: str) -> str:
"""Strip S3 prefix from URI. Can be overridden by adapters for other platforms."""
return strip_s3_prefix(s3_uri)
def _normalize_uri_for_pattern_matching(self, uri: str) -> str:
"""Normalize URI for pattern matching. Can be overridden by adapters for other platforms."""
return uri
def get_report(self):
return self.report
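
These two small methods exist so platform adapters can swap in their own behavior per instance. A minimal, self-contained sketch of that override pattern, using hypothetical stand-in classes rather than the real S3Source/adapter (the real `register_customization` may attach the adapter's bound method in a similar way):

```python
class FakeSource:
    """Stand-in for S3Source: the hook defaults to a pass-through."""

    def _normalize_uri_for_pattern_matching(self, uri: str) -> str:
        return uri


class FakeGcsAdapter:
    """Stand-in for the GCS adapter: rewrites gs:// to s3:// for matching only."""

    def normalize(self, uri: str) -> str:
        if uri.startswith("gs://"):
            return "s3://" + uri[len("gs://"):]
        return uri

    def apply(self, source: FakeSource) -> None:
        # Instance-level override, analogous to register_customization():
        # the instance attribute shadows the class-level default.
        source._normalize_uri_for_pattern_matching = self.normalize


src = FakeSource()
print(src._normalize_uri_for_pattern_matching("gs://b/k"))  # gs://b/k (base no-op)

FakeGcsAdapter().apply(src)
print(src._normalize_uri_for_pattern_matching("gs://b/k"))  # s3://b/k (overridden)
```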

View File

@@ -101,11 +101,13 @@ class TestLineageHelper:
def test_load_lineage_data_file_not_found(self, monkeypatch):
self.setup_mock_file_operations(monkeypatch, "", exists=False)
with pytest.raises(FileNotFoundError):
_load_lineage_data()
# Should return empty dict instead of raising exception
result = _load_lineage_data()
assert result == {}
def test_load_lineage_data_invalid_json(self, monkeypatch):
self.setup_mock_file_operations(monkeypatch, "invalid json", exists=True)
with pytest.raises(json.JSONDecodeError):
_load_lineage_data()
# Should return empty dict instead of raising exception
result = _load_lineage_data()
assert result == {}

View File

@@ -1,6 +1,9 @@
import pathlib
import unittest
from unittest.mock import MagicMock
import pytest
from datahub.ingestion.source.data_lake_common.object_store import (
ABSObjectStore,
GCSObjectStore,
@@ -451,5 +454,123 @@ class TestCreateObjectStoreAdapter(unittest.TestCase):
self.assertEqual(adapter.platform_name, "Unknown (unknown)")
# Parametrized tests for GCS URI normalization
@pytest.mark.parametrize(
"input_uri,expected",
[
("gs://bucket/path/to/file.parquet", "s3://bucket/path/to/file.parquet"),
("s3://bucket/path/to/file.parquet", "s3://bucket/path/to/file.parquet"),
("", ""),
("gs://bucket/", "s3://bucket/"),
("gs://bucket/nested/path/file.json", "s3://bucket/nested/path/file.json"),
],
)
def test_gcs_uri_normalization_for_pattern_matching(input_uri, expected):
"""Test that GCS URIs are normalized to S3 URIs for pattern matching."""
gcs_adapter = create_object_store_adapter("gcs")
result = gcs_adapter._normalize_gcs_uri_for_pattern_matching(input_uri)
assert result == expected
@pytest.mark.parametrize(
"input_uri,expected",
[
("gs://bucket/path/to/file.parquet", "bucket/path/to/file.parquet"),
("s3://bucket/path/to/file.parquet", "s3://bucket/path/to/file.parquet"),
("", ""),
("gs://bucket/", "bucket/"),
("gs://bucket/nested/path/file.json", "bucket/nested/path/file.json"),
],
)
def test_gcs_prefix_stripping(input_uri, expected):
"""Test that GCS prefixes are stripped correctly."""
gcs_adapter = create_object_store_adapter("gcs")
result = gcs_adapter._strip_gcs_prefix(input_uri)
assert result == expected
class TestGCSURINormalization(unittest.TestCase):
"""Tests for the GCS URI normalization fix."""
def test_gcs_adapter_customizations(self):
"""Test that GCS adapter registers the expected customizations."""
gcs_adapter = create_object_store_adapter("gcs")
# Check that the required customizations are registered
expected_customizations = [
"is_s3_platform",
"create_s3_path",
"get_external_url",
"_normalize_uri_for_pattern_matching",
"strip_s3_prefix",
]
for customization in expected_customizations:
self.assertIn(customization, gcs_adapter.customizations)
def test_gcs_adapter_applied_to_mock_source(self):
"""Test that GCS adapter customizations are applied to a mock source."""
gcs_adapter = create_object_store_adapter("gcs")
# Create a mock S3 source
mock_source = MagicMock()
mock_source.source_config = MagicMock()
# Apply customizations
gcs_adapter.apply_customizations(mock_source)
# Check that the customizations were applied
self.assertTrue(hasattr(mock_source, "_normalize_uri_for_pattern_matching"))
self.assertTrue(hasattr(mock_source, "strip_s3_prefix"))
self.assertTrue(hasattr(mock_source, "create_s3_path"))
# Test that the URI normalization method works on the mock source
test_uri = "gs://bucket/path/file.parquet"
normalized = mock_source._normalize_uri_for_pattern_matching(test_uri)
self.assertEqual(normalized, "s3://bucket/path/file.parquet")
# Test that the prefix stripping method works on the mock source
stripped = mock_source.strip_s3_prefix(test_uri)
self.assertEqual(stripped, "bucket/path/file.parquet")
def test_gcs_path_creation_via_adapter(self):
"""Test that GCS paths are created correctly via the adapter."""
gcs_adapter = create_object_store_adapter("gcs")
# Create a mock source and apply customizations
mock_source = MagicMock()
mock_source.source_config = MagicMock()
gcs_adapter.apply_customizations(mock_source)
# Test that create_s3_path now creates GCS paths
gcs_path = mock_source.create_s3_path("bucket", "path/to/file.parquet")
self.assertEqual(gcs_path, "gs://bucket/path/to/file.parquet")
def test_pattern_matching_scenario(self):
"""Test the actual pattern matching scenario that was failing."""
gcs_adapter = create_object_store_adapter("gcs")
# Simulate the scenario where:
# 1. Path spec pattern is s3://bucket/path/{table}/*.parquet
# 2. File URI is gs://bucket/path/food_parquet/file.parquet
path_spec_pattern = "s3://bucket/path/{table}/*.parquet"
file_uri = "gs://bucket/path/food_parquet/file.parquet"
# Normalize the file URI for pattern matching
normalized_file_uri = gcs_adapter._normalize_gcs_uri_for_pattern_matching(
file_uri
)
# The normalized URI should now be compatible with the pattern
self.assertEqual(
normalized_file_uri, "s3://bucket/path/food_parquet/file.parquet"
)
# Test that the normalized URI would match the pattern (simplified test)
glob_pattern = path_spec_pattern.replace("{table}", "*")
self.assertTrue(pathlib.PurePath(normalized_file_uri).match(glob_pattern))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,3 +1,5 @@
import pathlib
import re
from unittest import mock
import pytest
@@ -81,3 +83,188 @@ def test_data_lake_incorrect_config_raises_error():
}
with pytest.raises(ValidationError, match=r"\*\*"):
GCSSource.create(source, ctx)
def test_gcs_uri_normalization_fix():
"""Test that GCS URIs are normalized correctly for pattern matching."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source with a path spec that includes table templating
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/year={partition[0]}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Check that the S3 source has the URI normalization method
assert hasattr(gcs_source.s3_source, "_normalize_uri_for_pattern_matching")
# Check that strip_s3_prefix is overridden for GCS
assert hasattr(gcs_source.s3_source, "strip_s3_prefix")
# Test URI normalization
gs_uri = "gs://test-bucket/data/food_parquet/year=2023/file.parquet"
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(gs_uri)
assert normalized_uri == "s3://test-bucket/data/food_parquet/year=2023/file.parquet"
# Test prefix stripping
stripped_uri = gcs_source.s3_source.strip_s3_prefix(gs_uri)
assert stripped_uri == "test-bucket/data/food_parquet/year=2023/file.parquet"
@pytest.mark.parametrize(
"gs_uri,expected_normalized,expected_stripped",
[
(
"gs://test-bucket/data/food_parquet/year=2023/file.parquet",
"s3://test-bucket/data/food_parquet/year=2023/file.parquet",
"test-bucket/data/food_parquet/year=2023/file.parquet",
),
(
"gs://my-bucket/simple/file.json",
"s3://my-bucket/simple/file.json",
"my-bucket/simple/file.json",
),
(
"gs://bucket/nested/deep/path/data.csv",
"s3://bucket/nested/deep/path/data.csv",
"bucket/nested/deep/path/data.csv",
),
],
)
def test_gcs_uri_transformations(gs_uri, expected_normalized, expected_stripped):
"""Test GCS URI normalization and prefix stripping with various inputs."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Test URI normalization
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(gs_uri)
assert normalized_uri == expected_normalized
# Test prefix stripping
stripped_uri = gcs_source.s3_source.strip_s3_prefix(gs_uri)
assert stripped_uri == expected_stripped
def test_gcs_path_spec_pattern_matching():
"""Test that GCS path specs correctly match files after URI normalization."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/year={partition[0]}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Get the path spec that was converted to S3 format
s3_path_spec = gcs_source.s3_source.source_config.path_specs[0]
# The path spec should have been converted to S3 format
assert (
s3_path_spec.include
== "s3://test-bucket/data/{table}/year={partition[0]}/*.parquet"
)
# Test that a GCS file URI would be normalized for pattern matching
gs_file_uri = "gs://test-bucket/data/food_parquet/year=2023/file.parquet"
normalized_uri = gcs_source.s3_source._normalize_uri_for_pattern_matching(
gs_file_uri
)
# Convert the path spec pattern to glob format (similar to what PathSpec.glob_include does)
glob_pattern = re.sub(r"\{[^}]+\}", "*", s3_path_spec.include)
assert pathlib.PurePath(normalized_uri).match(glob_pattern)
def test_gcs_source_preserves_gs_uris():
"""Test that GCS source preserves gs:// URIs in the final output."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
# Create a GCS source
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Test that create_s3_path creates GCS URIs
gcs_path = gcs_source.s3_source.create_s3_path(
"test-bucket", "data/food_parquet/file.parquet"
)
assert gcs_path == "gs://test-bucket/data/food_parquet/file.parquet"
# Test that the platform is correctly set
assert gcs_source.s3_source.source_config.platform == PLATFORM_GCS
# Test that container subtypes are correctly set for GCS
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
container_creator = gcs_source.s3_source.container_WU_creator
assert container_creator.get_sub_types() == DatasetContainerSubTypes.GCS_BUCKET
def test_gcs_container_subtypes():
"""Test that GCS containers use 'GCS bucket' subtype instead of 'S3 bucket'."""
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")
source = {
"path_specs": [
{
"include": "gs://test-bucket/data/{table}/*.parquet",
"table_name": "{table}",
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
}
gcs_source = GCSSource.create(source, ctx)
# Verify the platform is set correctly
assert gcs_source.s3_source.source_config.platform == PLATFORM_GCS
# Verify container subtypes use GCS bucket, not S3 bucket
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
container_creator = gcs_source.s3_source.container_WU_creator
# Should return "GCS bucket" for GCS platform
assert container_creator.get_sub_types() == DatasetContainerSubTypes.GCS_BUCKET
assert container_creator.get_sub_types() == "GCS bucket"
# Should NOT return "S3 bucket"
assert container_creator.get_sub_types() != DatasetContainerSubTypes.S3_BUCKET
assert container_creator.get_sub_types() != "S3 bucket"