mirror of
https://github.com/datahub-project/datahub.git
synced 2025-09-08 16:48:17 +00:00
fix(ingestion/abs): updated deprecated azure sdk parameter to supported parameter and uri prefix support of https (#14106)
This commit is contained in:
parent
ee5e280165
commit
010da3c480
@ -533,7 +533,7 @@ class ABSSource(StatefulIngestionSourceBase):
|
||||
)
|
||||
path_spec.sample_files = False
|
||||
for obj in container_client.list_blobs(
|
||||
prefix=f"{prefix}", results_per_page=PAGE_SIZE
|
||||
name_starts_with=f"{prefix}", results_per_page=PAGE_SIZE
|
||||
):
|
||||
abs_path = self.create_abs_path(obj.name)
|
||||
logger.debug(f"Path: {abs_path}")
|
||||
|
@ -61,13 +61,13 @@ class AzureConnectionConfig(ConfigModel):
|
||||
def get_blob_service_client(self):
|
||||
return BlobServiceClient(
|
||||
account_url=f"https://{self.account_name}.blob.core.windows.net",
|
||||
credential=f"{self.get_credentials()}",
|
||||
credential=self.get_credentials(),
|
||||
)
|
||||
|
||||
def get_data_lake_service_client(self) -> DataLakeServiceClient:
|
||||
return DataLakeServiceClient(
|
||||
account_url=f"https://{self.account_name}.dfs.core.windows.net",
|
||||
credential=f"{self.get_credentials()}",
|
||||
credential=self.get_credentials(),
|
||||
)
|
||||
|
||||
def get_credentials(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# Add imports for source customization
|
||||
@ -236,42 +237,76 @@ class ABSObjectStore(ObjectStoreInterface):
|
||||
"""Implementation of ObjectStoreInterface for Azure Blob Storage."""
|
||||
|
||||
PREFIX = "abfss://"
|
||||
HTTPS_REGEX = re.compile(r"(https?://[a-z0-9]{3,24}\.blob\.core\.windows\.net/)")
|
||||
|
||||
@classmethod
|
||||
def is_uri(cls, uri: str) -> bool:
|
||||
return uri.startswith(cls.PREFIX)
|
||||
return uri.startswith(cls.PREFIX) or bool(cls.HTTPS_REGEX.match(uri))
|
||||
|
||||
@classmethod
|
||||
def get_prefix(cls, uri: str) -> Optional[str]:
|
||||
if uri.startswith(cls.PREFIX):
|
||||
return cls.PREFIX
|
||||
|
||||
# Check for HTTPS format
|
||||
match = cls.HTTPS_REGEX.match(uri)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def strip_prefix(cls, uri: str) -> str:
|
||||
prefix = cls.get_prefix(uri)
|
||||
if not prefix:
|
||||
raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
|
||||
return uri[len(prefix) :]
|
||||
if uri.startswith(cls.PREFIX):
|
||||
return uri[len(cls.PREFIX) :]
|
||||
|
||||
# Handle HTTPS format
|
||||
match = cls.HTTPS_REGEX.match(uri)
|
||||
if match:
|
||||
return uri[len(match.group(1)) :]
|
||||
|
||||
raise ValueError(
|
||||
f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_bucket_name(cls, uri: str) -> str:
|
||||
if not cls.is_uri(uri):
|
||||
raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
|
||||
return cls.strip_prefix(uri).split("@")[0]
|
||||
raise ValueError(
|
||||
f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
|
||||
)
|
||||
|
||||
if uri.startswith(cls.PREFIX):
|
||||
# abfss://container@account.dfs.core.windows.net/path
|
||||
return cls.strip_prefix(uri).split("@")[0]
|
||||
else:
|
||||
# https://account.blob.core.windows.net/container/path
|
||||
return cls.strip_prefix(uri).split("/")[0]
|
||||
|
||||
@classmethod
|
||||
def get_object_key(cls, uri: str) -> str:
|
||||
if not cls.is_uri(uri):
|
||||
raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
|
||||
parts = cls.strip_prefix(uri).split("@", 1)
|
||||
if len(parts) < 2:
|
||||
return ""
|
||||
account_path = parts[1]
|
||||
path_parts = account_path.split("/", 1)
|
||||
if len(path_parts) < 2:
|
||||
return ""
|
||||
return path_parts[1]
|
||||
raise ValueError(
|
||||
f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
|
||||
)
|
||||
|
||||
if uri.startswith(cls.PREFIX):
|
||||
# abfss://container@account.dfs.core.windows.net/path
|
||||
parts = cls.strip_prefix(uri).split("@", 1)
|
||||
if len(parts) < 2:
|
||||
return ""
|
||||
account_path = parts[1]
|
||||
path_parts = account_path.split("/", 1)
|
||||
if len(path_parts) < 2:
|
||||
return ""
|
||||
return path_parts[1]
|
||||
else:
|
||||
# https://account.blob.core.windows.net/container/path
|
||||
stripped = cls.strip_prefix(uri)
|
||||
parts = stripped.split("/", 1)
|
||||
if len(parts) < 2:
|
||||
return ""
|
||||
return parts[1]
|
||||
|
||||
|
||||
# Registry of all object store implementations
|
||||
@ -331,6 +366,12 @@ def get_object_store_bucket_name(uri: str) -> str:
|
||||
return uri[prefix_length:].split("/")[0]
|
||||
elif uri.startswith(ABSObjectStore.PREFIX):
|
||||
return uri[len(ABSObjectStore.PREFIX) :].split("@")[0]
|
||||
elif ABSObjectStore.HTTPS_REGEX.match(uri):
|
||||
# Handle HTTPS Azure Blob Storage URLs
|
||||
match = ABSObjectStore.HTTPS_REGEX.match(uri)
|
||||
if match:
|
||||
stripped = uri[len(match.group(1)) :]
|
||||
return stripped.split("/")[0]
|
||||
|
||||
raise ValueError(f"Unsupported URI format: {uri}")
|
||||
|
||||
@ -470,18 +511,25 @@ class ObjectStoreSourceAdapter:
|
||||
if not ABSObjectStore.is_uri(table_data.table_path):
|
||||
return None
|
||||
|
||||
# Parse the ABS URI
|
||||
try:
|
||||
# URI format: abfss://container@account.dfs.core.windows.net/path
|
||||
path_without_prefix = ABSObjectStore.strip_prefix(table_data.table_path)
|
||||
parts = path_without_prefix.split("@", 1)
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
if table_data.table_path.startswith("abfss://"):
|
||||
# URI format: abfss://container@account.dfs.core.windows.net/path
|
||||
path_without_prefix = ABSObjectStore.strip_prefix(table_data.table_path)
|
||||
parts = path_without_prefix.split("@", 1)
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
|
||||
container_name = parts[0]
|
||||
account_parts = parts[1].split("/", 1)
|
||||
account_domain = account_parts[0]
|
||||
account_name = account_domain.split(".")[0]
|
||||
container_name = parts[0]
|
||||
account_parts = parts[1].split("/", 1)
|
||||
account_domain = account_parts[0]
|
||||
account_name = account_domain.split(".")[0]
|
||||
else:
|
||||
# Handle HTTPS format: https://account.blob.core.windows.net/container/path
|
||||
container_name = ABSObjectStore.get_bucket_name(table_data.table_path)
|
||||
if "blob.core.windows.net" in table_data.table_path:
|
||||
account_name = table_data.table_path.split("//")[1].split(".")[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Construct Azure portal URL
|
||||
return f"https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/{account_name}/containerName/{container_name}"
|
||||
|
224
metadata-ingestion/tests/unit/abs/test_abs_source.py
Normal file
224
metadata-ingestion/tests/unit/abs/test_abs_source.py
Normal file
@ -0,0 +1,224 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from azure.identity import ClientSecretCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.storage.filedatalake import DataLakeServiceClient
|
||||
|
||||
from datahub.ingestion.source.azure.azure_common import AzureConnectionConfig
|
||||
|
||||
|
||||
def test_service_principal_credentials_return_objects():
|
||||
"""Service principal credentials must return ClientSecretCredential objects, not strings"""
|
||||
config = AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
tenant_id="test-tenant-id",
|
||||
)
|
||||
|
||||
credential = config.get_credentials()
|
||||
|
||||
assert isinstance(credential, ClientSecretCredential)
|
||||
assert not isinstance(credential, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth_type,config_params,expected_type",
|
||||
[
|
||||
(
|
||||
"service_principal",
|
||||
{
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"tenant_id": "test-tenant-id",
|
||||
},
|
||||
ClientSecretCredential,
|
||||
),
|
||||
("account_key", {"account_key": "test-account-key"}, str),
|
||||
("sas_token", {"sas_token": "test-sas-token"}, str),
|
||||
],
|
||||
)
|
||||
def test_credential_types_by_auth_method(auth_type, config_params, expected_type):
|
||||
"""Test that different authentication methods return correct credential types"""
|
||||
base_config = {"account_name": "testaccount", "container_name": "testcontainer"}
|
||||
config = AzureConnectionConfig(**{**base_config, **config_params})
|
||||
|
||||
credential = config.get_credentials()
|
||||
assert isinstance(credential, expected_type)
|
||||
|
||||
|
||||
def test_credential_object_not_converted_to_string():
|
||||
"""Credential objects should not be accidentally converted to strings via f-string formatting"""
|
||||
config = AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
tenant_id="test-tenant-id",
|
||||
)
|
||||
|
||||
credential = config.get_credentials()
|
||||
credential_as_string = f"{credential}"
|
||||
|
||||
assert isinstance(credential, ClientSecretCredential)
|
||||
assert credential != credential_as_string
|
||||
assert "ClientSecretCredential" in str(credential)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_client_class,method_name",
|
||||
[
|
||||
(BlobServiceClient, "get_blob_service_client"),
|
||||
(DataLakeServiceClient, "get_data_lake_service_client"),
|
||||
],
|
||||
)
|
||||
def test_service_clients_receive_credential_objects(service_client_class, method_name):
|
||||
"""Both BlobServiceClient and DataLakeServiceClient should receive credential objects"""
|
||||
config = AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
tenant_id="test-tenant-id",
|
||||
)
|
||||
|
||||
with patch(
|
||||
f"datahub.ingestion.source.azure.azure_common.{service_client_class.__name__}"
|
||||
) as mock_client:
|
||||
getattr(config, method_name)()
|
||||
|
||||
mock_client.assert_called_once()
|
||||
credential = mock_client.call_args[1]["credential"]
|
||||
assert isinstance(credential, ClientSecretCredential)
|
||||
assert not isinstance(credential, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"deprecated_param,new_param",
|
||||
[
|
||||
("prefix", "name_starts_with"),
|
||||
],
|
||||
)
|
||||
def test_azure_sdk_parameter_deprecation(deprecated_param, new_param):
|
||||
"""Test that demonstrates the Azure SDK parameter deprecation issue"""
|
||||
# This test shows why the fix was needed - deprecated params cause errors
|
||||
mock_container_client = Mock()
|
||||
|
||||
def list_blobs_with_validation(**kwargs):
|
||||
if deprecated_param in kwargs:
|
||||
raise ValueError(
|
||||
f"Passing '{deprecated_param}' has no effect on filtering, please use the '{new_param}' parameter instead."
|
||||
)
|
||||
return []
|
||||
|
||||
mock_container_client.list_blobs.side_effect = list_blobs_with_validation
|
||||
|
||||
# Test that the deprecated parameter causes an error (this is what was happening before the fix)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
mock_container_client.list_blobs(
|
||||
**{deprecated_param: "test/path", "results_per_page": 1000}
|
||||
)
|
||||
|
||||
assert new_param in str(exc_info.value)
|
||||
assert deprecated_param in str(exc_info.value)
|
||||
|
||||
# Test that the new parameter works (this is what the fix implemented)
|
||||
mock_container_client.list_blobs.side_effect = None
|
||||
mock_container_client.list_blobs.return_value = []
|
||||
|
||||
result = mock_container_client.list_blobs(
|
||||
**{new_param: "test/path", "results_per_page": 1000}
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
@patch("datahub.ingestion.source.azure.azure_common.BlobServiceClient")
|
||||
def test_datahub_source_uses_correct_azure_parameters(mock_blob_service_client_class):
|
||||
"""Test that DataHub source code actually uses the correct Azure SDK parameters"""
|
||||
# This test verifies that the real DataHub code calls Azure SDK with correct parameters
|
||||
mock_container_client = Mock()
|
||||
mock_blob_service_client = Mock()
|
||||
mock_blob_service_client.get_container_client.return_value = mock_container_client
|
||||
mock_blob_service_client_class.return_value = mock_blob_service_client
|
||||
|
||||
# Mock the blob objects returned by list_blobs
|
||||
mock_blob = Mock()
|
||||
mock_blob.name = "test/path/file.csv"
|
||||
mock_blob.size = 1024
|
||||
mock_container_client.list_blobs.return_value = [mock_blob]
|
||||
|
||||
# Now test the REAL DataHub code
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
from datahub.ingestion.source.abs.config import DataLakeSourceConfig
|
||||
from datahub.ingestion.source.abs.source import ABSSource
|
||||
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
|
||||
|
||||
# Create real DataHub source
|
||||
source_config = DataLakeSourceConfig(
|
||||
platform="abs",
|
||||
azure_config=AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
tenant_id="test-tenant-id",
|
||||
),
|
||||
path_specs=[
|
||||
PathSpec(
|
||||
include="https://testaccount.blob.core.windows.net/testcontainer/test/*.*",
|
||||
exclude=[],
|
||||
file_types=["csv"],
|
||||
sample_files=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
pipeline_context = PipelineContext(run_id="test-run-id", pipeline_name="abs-source")
|
||||
pipeline_context.graph = Mock()
|
||||
source = ABSSource(source_config, pipeline_context)
|
||||
|
||||
# Call the REAL DataHub method
|
||||
with patch(
|
||||
"datahub.ingestion.source.abs.source.get_container_relative_path",
|
||||
return_value="test/path",
|
||||
):
|
||||
path_spec = source_config.path_specs[0]
|
||||
list(source.abs_browser(path_spec, 100))
|
||||
|
||||
# NOW verify the real DataHub code called Azure SDK with correct parameters
|
||||
mock_container_client.list_blobs.assert_called_once_with(
|
||||
name_starts_with="test/path", results_per_page=1000
|
||||
)
|
||||
|
||||
# Verify the fix worked - no deprecated 'prefix' parameter
|
||||
call_args = mock_container_client.list_blobs.call_args
|
||||
assert "name_starts_with" in call_args[1]
|
||||
assert "prefix" not in call_args[1]
|
||||
|
||||
|
||||
def test_account_key_authentication():
|
||||
"""Test that account key authentication returns string credentials"""
|
||||
config = AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
account_key="test-account-key",
|
||||
)
|
||||
|
||||
credential = config.get_credentials()
|
||||
assert isinstance(credential, str)
|
||||
assert credential == "test-account-key"
|
||||
|
||||
|
||||
def test_sas_token_authentication():
|
||||
"""Test that SAS token authentication returns string credentials"""
|
||||
config = AzureConnectionConfig(
|
||||
account_name="testaccount",
|
||||
container_name="testcontainer",
|
||||
sas_token="test-sas-token",
|
||||
)
|
||||
|
||||
credential = config.get_credentials()
|
||||
assert isinstance(credential, str)
|
||||
assert credential == "test-sas-token"
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user