From 010da3c48071772be872c99a50553dc7ddce6621 Mon Sep 17 00:00:00 2001
From: Jonny Dixon <45681293+acrylJonny@users.noreply.github.com>
Date: Wed, 16 Jul 2025 18:44:31 +0100
Subject: [PATCH] fix(ingestion/abs): updated deprecated azure sdk parameter to supported parameter and uri prefix support of https (#14106)

---
 .../datahub/ingestion/source/abs/source.py    |   2 +-
 .../ingestion/source/azure/azure_common.py    |   4 +-
 .../source/data_lake_common/object_store.py   | 100 +-
 .../tests/unit/abs/test_abs_source.py         | 224 +++++
 .../tests/unit/data_lake/test_object_store.py | 882 ++++++++++++------
 5 files changed, 888 insertions(+), 324 deletions(-)
 create mode 100644 metadata-ingestion/tests/unit/abs/test_abs_source.py

diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py
index 586e7a3af3..fa9e07b36a 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py
@@ -533,7 +533,7 @@ class ABSSource(StatefulIngestionSourceBase):
             )
             path_spec.sample_files = False
             for obj in container_client.list_blobs(
-                prefix=f"{prefix}", results_per_page=PAGE_SIZE
+                name_starts_with=f"{prefix}", results_per_page=PAGE_SIZE
             ):
                 abs_path = self.create_abs_path(obj.name)
                 logger.debug(f"Path: {abs_path}")
diff --git a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py
index 46de4e09d7..c1ea000b7d 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py
@@ -61,13 +61,13 @@ class AzureConnectionConfig(ConfigModel):
     def get_blob_service_client(self):
         return BlobServiceClient(
             account_url=f"https://{self.account_name}.blob.core.windows.net",
-            credential=f"{self.get_credentials()}",
+            credential=self.get_credentials(),
         )
 
     def get_data_lake_service_client(self) -> DataLakeServiceClient:
         return DataLakeServiceClient(
             account_url=f"https://{self.account_name}.dfs.core.windows.net",
-            credential=f"{self.get_credentials()}",
+            credential=self.get_credentials(),
         )
 
     def get_credentials(
diff --git a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/object_store.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/object_store.py
index 36076f524a..d1c36d8d96 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/object_store.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/object_store.py
@@ -1,3 +1,4 @@
+import re
 from abc import ABC, abstractmethod
 
 # Add imports for source customization
@@ -236,42 +237,76 @@ class ABSObjectStore(ObjectStoreInterface):
     """Implementation of ObjectStoreInterface for Azure Blob Storage."""
 
     PREFIX = "abfss://"
+    HTTPS_REGEX = re.compile(r"(https?://[a-z0-9]{3,24}\.blob\.core\.windows\.net/)")
 
     @classmethod
     def is_uri(cls, uri: str) -> bool:
-        return uri.startswith(cls.PREFIX)
+        return uri.startswith(cls.PREFIX) or bool(cls.HTTPS_REGEX.match(uri))
 
     @classmethod
     def get_prefix(cls, uri: str) -> Optional[str]:
         if uri.startswith(cls.PREFIX):
             return cls.PREFIX
+
+        # Check for HTTPS format
+        match = cls.HTTPS_REGEX.match(uri)
+        if match:
+            return match.group(1)
+
         return None
 
     @classmethod
     def strip_prefix(cls, uri: str) -> str:
-        prefix = cls.get_prefix(uri)
-        if not prefix:
-            raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
-        return uri[len(prefix) :]
+        if uri.startswith(cls.PREFIX):
+            return uri[len(cls.PREFIX) :]
+
+        # Handle HTTPS format
+        match = cls.HTTPS_REGEX.match(uri)
+        if match:
+            return uri[len(match.group(1)) :]
+
+        raise ValueError(
+            f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
+        )
 
     @classmethod
     def get_bucket_name(cls, uri: str) -> str:
         if not cls.is_uri(uri):
-            raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
-        return cls.strip_prefix(uri).split("@")[0]
+            raise ValueError(
+                f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
+            )
+
+        if uri.startswith(cls.PREFIX):
+            # abfss://container@account.dfs.core.windows.net/path
+            return cls.strip_prefix(uri).split("@")[0]
+        else:
+            # https://account.blob.core.windows.net/container/path
+            return cls.strip_prefix(uri).split("/")[0]
 
     @classmethod
     def get_object_key(cls, uri: str) -> str:
         if not cls.is_uri(uri):
-            raise ValueError(f"Not an ABS URI. Must start with prefix: {cls.PREFIX}")
-        parts = cls.strip_prefix(uri).split("@", 1)
-        if len(parts) < 2:
-            return ""
-        account_path = parts[1]
-        path_parts = account_path.split("/", 1)
-        if len(path_parts) < 2:
-            return ""
-        return path_parts[1]
+            raise ValueError(
+                f"Not an ABS URI. Must start with prefix: {cls.PREFIX} or match Azure Blob Storage HTTPS pattern"
+            )
+
+        if uri.startswith(cls.PREFIX):
+            # abfss://container@account.dfs.core.windows.net/path
+            parts = cls.strip_prefix(uri).split("@", 1)
+            if len(parts) < 2:
+                return ""
+            account_path = parts[1]
+            path_parts = account_path.split("/", 1)
+            if len(path_parts) < 2:
+                return ""
+            return path_parts[1]
+        else:
+            # https://account.blob.core.windows.net/container/path
+            stripped = cls.strip_prefix(uri)
+            parts = stripped.split("/", 1)
+            if len(parts) < 2:
+                return ""
+            return parts[1]
 
 
 # Registry of all object store implementations
@@ -331,6 +366,12 @@ def get_object_store_bucket_name(uri: str) -> str:
         return uri[prefix_length:].split("/")[0]
     elif uri.startswith(ABSObjectStore.PREFIX):
         return uri[len(ABSObjectStore.PREFIX) :].split("@")[0]
+    elif ABSObjectStore.HTTPS_REGEX.match(uri):
+        # Handle HTTPS Azure Blob Storage URLs
+        match = ABSObjectStore.HTTPS_REGEX.match(uri)
+        if match:
+            stripped = uri[len(match.group(1)) :]
+            return stripped.split("/")[0]
 
     raise ValueError(f"Unsupported URI format: {uri}")
 
@@ -470,18 +511,25 @@ class ObjectStoreSourceAdapter:
         if not ABSObjectStore.is_uri(table_data.table_path):
             return None
 
-        # Parse the ABS URI
         try:
-            # URI format: abfss://container@account.dfs.core.windows.net/path
-            path_without_prefix = ABSObjectStore.strip_prefix(table_data.table_path)
-            parts = path_without_prefix.split("@", 1)
-            if len(parts) < 2:
-                return None
+            if table_data.table_path.startswith("abfss://"):
+                # URI format: abfss://container@account.dfs.core.windows.net/path
+                path_without_prefix = ABSObjectStore.strip_prefix(table_data.table_path)
+                parts = path_without_prefix.split("@", 1)
+                if len(parts) < 2:
+                    return None
 
-            container_name = parts[0]
-            account_parts = parts[1].split("/", 1)
-            account_domain = account_parts[0]
-            account_name = account_domain.split(".")[0]
+                container_name = parts[0]
+                account_parts = parts[1].split("/", 1)
+                account_domain = account_parts[0]
+                account_name = account_domain.split(".")[0]
+            else:
+                # Handle HTTPS format: https://account.blob.core.windows.net/container/path
+                container_name = ABSObjectStore.get_bucket_name(table_data.table_path)
"blob.core.windows.net" in table_data.table_path: + account_name = table_data.table_path.split("//")[1].split(".")[0] + else: + return None # Construct Azure portal URL return f"https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/{account_name}/containerName/{container_name}" diff --git a/metadata-ingestion/tests/unit/abs/test_abs_source.py b/metadata-ingestion/tests/unit/abs/test_abs_source.py new file mode 100644 index 0000000000..065d359220 --- /dev/null +++ b/metadata-ingestion/tests/unit/abs/test_abs_source.py @@ -0,0 +1,224 @@ +from unittest.mock import Mock, patch + +import pytest +from azure.identity import ClientSecretCredential +from azure.storage.blob import BlobServiceClient +from azure.storage.filedatalake import DataLakeServiceClient + +from datahub.ingestion.source.azure.azure_common import AzureConnectionConfig + + +def test_service_principal_credentials_return_objects(): + """Service principal credentials must return ClientSecretCredential objects, not strings""" + config = AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ) + + credential = config.get_credentials() + + assert isinstance(credential, ClientSecretCredential) + assert not isinstance(credential, str) + + +@pytest.mark.parametrize( + "auth_type,config_params,expected_type", + [ + ( + "service_principal", + { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "tenant_id": "test-tenant-id", + }, + ClientSecretCredential, + ), + ("account_key", {"account_key": "test-account-key"}, str), + ("sas_token", {"sas_token": "test-sas-token"}, str), + ], +) +def test_credential_types_by_auth_method(auth_type, config_params, expected_type): + """Test that different authentication methods return correct credential types""" + base_config = {"account_name": "testaccount", "container_name": "testcontainer"} + config = AzureConnectionConfig(**{**base_config, **config_params}) + + credential = config.get_credentials() + assert isinstance(credential, expected_type) + + +def test_credential_object_not_converted_to_string(): + """Credential objects should not be accidentally converted to strings via f-string formatting""" + config = AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ) + + credential = config.get_credentials() + credential_as_string = f"{credential}" + + assert isinstance(credential, ClientSecretCredential) + assert credential != credential_as_string + assert "ClientSecretCredential" in str(credential) + + +@pytest.mark.parametrize( + "service_client_class,method_name", + [ + (BlobServiceClient, "get_blob_service_client"), + (DataLakeServiceClient, "get_data_lake_service_client"), + ], +) +def test_service_clients_receive_credential_objects(service_client_class, method_name): + """Both BlobServiceClient and DataLakeServiceClient should receive credential objects""" + config = AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ) + + with patch( + f"datahub.ingestion.source.azure.azure_common.{service_client_class.__name__}" + ) as mock_client: + getattr(config, method_name)() + + mock_client.assert_called_once() + credential = 
mock_client.call_args[1]["credential"] + assert isinstance(credential, ClientSecretCredential) + assert not isinstance(credential, str) + + +@pytest.mark.parametrize( + "deprecated_param,new_param", + [ + ("prefix", "name_starts_with"), + ], +) +def test_azure_sdk_parameter_deprecation(deprecated_param, new_param): + """Test that demonstrates the Azure SDK parameter deprecation issue""" + # This test shows why the fix was needed - deprecated params cause errors + mock_container_client = Mock() + + def list_blobs_with_validation(**kwargs): + if deprecated_param in kwargs: + raise ValueError( + f"Passing '{deprecated_param}' has no effect on filtering, please use the '{new_param}' parameter instead." + ) + return [] + + mock_container_client.list_blobs.side_effect = list_blobs_with_validation + + # Test that the deprecated parameter causes an error (this is what was happening before the fix) + with pytest.raises(ValueError) as exc_info: + mock_container_client.list_blobs( + **{deprecated_param: "test/path", "results_per_page": 1000} + ) + + assert new_param in str(exc_info.value) + assert deprecated_param in str(exc_info.value) + + # Test that the new parameter works (this is what the fix implemented) + mock_container_client.list_blobs.side_effect = None + mock_container_client.list_blobs.return_value = [] + + result = mock_container_client.list_blobs( + **{new_param: "test/path", "results_per_page": 1000} + ) + assert result == [] + + +@patch("datahub.ingestion.source.azure.azure_common.BlobServiceClient") +def test_datahub_source_uses_correct_azure_parameters(mock_blob_service_client_class): + """Test that DataHub source code actually uses the correct Azure SDK parameters""" + # This test verifies that the real DataHub code calls Azure SDK with correct parameters + mock_container_client = Mock() + mock_blob_service_client = Mock() + mock_blob_service_client.get_container_client.return_value = mock_container_client + mock_blob_service_client_class.return_value = mock_blob_service_client + + # Mock the blob objects returned by list_blobs + mock_blob = Mock() + mock_blob.name = "test/path/file.csv" + mock_blob.size = 1024 + mock_container_client.list_blobs.return_value = [mock_blob] + + # Now test the REAL DataHub code + from datahub.ingestion.api.common import PipelineContext + from datahub.ingestion.source.abs.config import DataLakeSourceConfig + from datahub.ingestion.source.abs.source import ABSSource + from datahub.ingestion.source.data_lake_common.path_spec import PathSpec + + # Create real DataHub source + source_config = DataLakeSourceConfig( + platform="abs", + azure_config=AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ), + path_specs=[ + PathSpec( + include="https://testaccount.blob.core.windows.net/testcontainer/test/*.*", + exclude=[], + file_types=["csv"], + sample_files=False, + ) + ], + ) + + pipeline_context = PipelineContext(run_id="test-run-id", pipeline_name="abs-source") + pipeline_context.graph = Mock() + source = ABSSource(source_config, pipeline_context) + + # Call the REAL DataHub method + with patch( + "datahub.ingestion.source.abs.source.get_container_relative_path", + return_value="test/path", + ): + path_spec = source_config.path_specs[0] + list(source.abs_browser(path_spec, 100)) + + # NOW verify the real DataHub code called Azure SDK with correct parameters + mock_container_client.list_blobs.assert_called_once_with( + 
name_starts_with="test/path", results_per_page=1000 + ) + + # Verify the fix worked - no deprecated 'prefix' parameter + call_args = mock_container_client.list_blobs.call_args + assert "name_starts_with" in call_args[1] + assert "prefix" not in call_args[1] + + +def test_account_key_authentication(): + """Test that account key authentication returns string credentials""" + config = AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + account_key="test-account-key", + ) + + credential = config.get_credentials() + assert isinstance(credential, str) + assert credential == "test-account-key" + + +def test_sas_token_authentication(): + """Test that SAS token authentication returns string credentials""" + config = AzureConnectionConfig( + account_name="testaccount", + container_name="testcontainer", + sas_token="test-sas-token", + ) + + credential = config.get_credentials() + assert isinstance(credential, str) + assert credential == "test-sas-token" diff --git a/metadata-ingestion/tests/unit/data_lake/test_object_store.py b/metadata-ingestion/tests/unit/data_lake/test_object_store.py index fb376f24bf..09dfc4d1bc 100644 --- a/metadata-ingestion/tests/unit/data_lake/test_object_store.py +++ b/metadata-ingestion/tests/unit/data_lake/test_object_store.py @@ -1,5 +1,4 @@ import pathlib -import unittest from unittest.mock import MagicMock import pytest @@ -16,308 +15,463 @@ from datahub.ingestion.source.data_lake_common.object_store import ( ) -class TestS3ObjectStore(unittest.TestCase): +class TestS3ObjectStore: """Tests for the S3ObjectStore implementation.""" - def test_is_uri(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", True), + ("s3n://bucket/path", True), + ("s3a://bucket/path", True), + ("gs://bucket/path", False), + ("abfss://container@account.dfs.core.windows.net/path", False), + ("https://account.blob.core.windows.net/container/path", False), + ("file:///path/to/file", False), + ], + ) + def test_is_uri(self, uri, expected): """Test the is_uri method with various URIs.""" - self.assertTrue(S3ObjectStore.is_uri("s3://bucket/path")) - self.assertTrue(S3ObjectStore.is_uri("s3n://bucket/path")) - self.assertTrue(S3ObjectStore.is_uri("s3a://bucket/path")) - self.assertFalse(S3ObjectStore.is_uri("gs://bucket/path")) - self.assertFalse( - S3ObjectStore.is_uri("abfss://container@account.dfs.core.windows.net/path") - ) - self.assertFalse(S3ObjectStore.is_uri("file:///path/to/file")) + assert S3ObjectStore.is_uri(uri) == expected - def test_get_prefix(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "s3://"), + ("s3n://bucket/path", "s3n://"), + ("s3a://bucket/path", "s3a://"), + ("gs://bucket/path", None), + ], + ) + def test_get_prefix(self, uri, expected): """Test the get_prefix method.""" - self.assertEqual(S3ObjectStore.get_prefix("s3://bucket/path"), "s3://") - self.assertEqual(S3ObjectStore.get_prefix("s3n://bucket/path"), "s3n://") - self.assertEqual(S3ObjectStore.get_prefix("s3a://bucket/path"), "s3a://") - self.assertIsNone(S3ObjectStore.get_prefix("gs://bucket/path")) + assert S3ObjectStore.get_prefix(uri) == expected - def test_strip_prefix(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "bucket/path"), + ("s3n://bucket/path", "bucket/path"), + ("s3a://bucket/path", "bucket/path"), + ], + ) + def test_strip_prefix(self, uri, expected): """Test the strip_prefix method.""" - self.assertEqual(S3ObjectStore.strip_prefix("s3://bucket/path"), "bucket/path") - 
self.assertEqual(S3ObjectStore.strip_prefix("s3n://bucket/path"), "bucket/path") - self.assertEqual(S3ObjectStore.strip_prefix("s3a://bucket/path"), "bucket/path") + assert S3ObjectStore.strip_prefix(uri) == expected - # Should raise ValueError for non-S3 URIs - with self.assertRaises(ValueError): + def test_strip_prefix_invalid_uri(self): + """Test strip_prefix with invalid URI.""" + with pytest.raises(ValueError): S3ObjectStore.strip_prefix("gs://bucket/path") - def test_get_bucket_name(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "bucket"), + ("s3n://my-bucket/path/to/file", "my-bucket"), + ("s3a://bucket.name/file.txt", "bucket.name"), + ], + ) + def test_get_bucket_name(self, uri, expected): """Test the get_bucket_name method.""" - self.assertEqual(S3ObjectStore.get_bucket_name("s3://bucket/path"), "bucket") - self.assertEqual( - S3ObjectStore.get_bucket_name("s3n://my-bucket/path/to/file"), "my-bucket" - ) - self.assertEqual( - S3ObjectStore.get_bucket_name("s3a://bucket.name/file.txt"), "bucket.name" - ) + assert S3ObjectStore.get_bucket_name(uri) == expected - # Should raise ValueError for non-S3 URIs - with self.assertRaises(ValueError): + def test_get_bucket_name_invalid_uri(self): + """Test get_bucket_name with invalid URI.""" + with pytest.raises(ValueError): S3ObjectStore.get_bucket_name("gs://bucket/path") - def test_get_object_key(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "path"), + ("s3://bucket/path/to/file.txt", "path/to/file.txt"), + ("s3://bucket/", ""), + ("s3://bucket", ""), + ], + ) + def test_get_object_key(self, uri, expected): """Test the get_object_key method.""" - self.assertEqual(S3ObjectStore.get_object_key("s3://bucket/path"), "path") - self.assertEqual( - S3ObjectStore.get_object_key("s3://bucket/path/to/file.txt"), - "path/to/file.txt", - ) - self.assertEqual(S3ObjectStore.get_object_key("s3://bucket/"), "") - self.assertEqual(S3ObjectStore.get_object_key("s3://bucket"), "") + assert S3ObjectStore.get_object_key(uri) == expected - # Should raise ValueError for non-S3 URIs - with self.assertRaises(ValueError): + def test_get_object_key_invalid_uri(self): + """Test get_object_key with invalid URI.""" + with pytest.raises(ValueError): S3ObjectStore.get_object_key("gs://bucket/path") -class TestGCSObjectStore(unittest.TestCase): +class TestGCSObjectStore: """Tests for the GCSObjectStore implementation.""" - def test_is_uri(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("gs://bucket/path", True), + ("s3://bucket/path", False), + ("abfss://container@account.dfs.core.windows.net/path", False), + ("https://account.blob.core.windows.net/container/path", False), + ], + ) + def test_is_uri(self, uri, expected): """Test the is_uri method with various URIs.""" - self.assertTrue(GCSObjectStore.is_uri("gs://bucket/path")) - self.assertFalse(GCSObjectStore.is_uri("s3://bucket/path")) - self.assertFalse( - GCSObjectStore.is_uri("abfss://container@account.dfs.core.windows.net/path") - ) + assert GCSObjectStore.is_uri(uri) == expected - def test_get_prefix(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("gs://bucket/path", "gs://"), + ("s3://bucket/path", None), + ], + ) + def test_get_prefix(self, uri, expected): """Test the get_prefix method.""" - self.assertEqual(GCSObjectStore.get_prefix("gs://bucket/path"), "gs://") - self.assertIsNone(GCSObjectStore.get_prefix("s3://bucket/path")) + assert GCSObjectStore.get_prefix(uri) == expected - def test_strip_prefix(self): + 
@pytest.mark.parametrize( + "uri,expected", + [ + ("gs://bucket/path", "bucket/path"), + ], + ) + def test_strip_prefix(self, uri, expected): """Test the strip_prefix method.""" - self.assertEqual(GCSObjectStore.strip_prefix("gs://bucket/path"), "bucket/path") + assert GCSObjectStore.strip_prefix(uri) == expected - # Should raise ValueError for non-GCS URIs - with self.assertRaises(ValueError): + def test_strip_prefix_invalid_uri(self): + """Test strip_prefix with invalid URI.""" + with pytest.raises(ValueError): GCSObjectStore.strip_prefix("s3://bucket/path") - def test_get_bucket_name(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("gs://bucket/path", "bucket"), + ("gs://my-bucket/path/to/file", "my-bucket"), + ], + ) + def test_get_bucket_name(self, uri, expected): """Test the get_bucket_name method.""" - self.assertEqual(GCSObjectStore.get_bucket_name("gs://bucket/path"), "bucket") - self.assertEqual( - GCSObjectStore.get_bucket_name("gs://my-bucket/path/to/file"), "my-bucket" - ) + assert GCSObjectStore.get_bucket_name(uri) == expected - # Should raise ValueError for non-GCS URIs - with self.assertRaises(ValueError): + def test_get_bucket_name_invalid_uri(self): + """Test get_bucket_name with invalid URI.""" + with pytest.raises(ValueError): GCSObjectStore.get_bucket_name("s3://bucket/path") - def test_get_object_key(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("gs://bucket/path", "path"), + ("gs://bucket/path/to/file.txt", "path/to/file.txt"), + ("gs://bucket/", ""), + ("gs://bucket", ""), + ], + ) + def test_get_object_key(self, uri, expected): """Test the get_object_key method.""" - self.assertEqual(GCSObjectStore.get_object_key("gs://bucket/path"), "path") - self.assertEqual( - GCSObjectStore.get_object_key("gs://bucket/path/to/file.txt"), - "path/to/file.txt", - ) - self.assertEqual(GCSObjectStore.get_object_key("gs://bucket/"), "") - self.assertEqual(GCSObjectStore.get_object_key("gs://bucket"), "") + assert GCSObjectStore.get_object_key(uri) == expected - # Should raise ValueError for non-GCS URIs - with self.assertRaises(ValueError): + def test_get_object_key_invalid_uri(self): + """Test get_object_key with invalid URI.""" + with pytest.raises(ValueError): GCSObjectStore.get_object_key("s3://bucket/path") -class TestABSObjectStore(unittest.TestCase): +class TestABSObjectStore: """Tests for the ABSObjectStore implementation.""" - def test_is_uri(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("abfss://container@account.dfs.core.windows.net/path", True), + ("https://account.blob.core.windows.net/container/path", True), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + True, + ), + ("s3://bucket/path", False), + ("gs://bucket/path", False), + ("https://example.com/path", False), + ], + ) + def test_is_uri(self, uri, expected): """Test the is_uri method with various URIs.""" - self.assertTrue( - ABSObjectStore.is_uri("abfss://container@account.dfs.core.windows.net/path") - ) - self.assertFalse(ABSObjectStore.is_uri("s3://bucket/path")) - self.assertFalse(ABSObjectStore.is_uri("gs://bucket/path")) + assert ABSObjectStore.is_uri(uri) == expected - def test_get_prefix(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("abfss://container@account.dfs.core.windows.net/path", "abfss://"), + ( + "https://account.blob.core.windows.net/container/path", + "https://account.blob.core.windows.net/", + ), + ( + 
"https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "https://odedmdatacatalog.blob.core.windows.net/", + ), + ("s3://bucket/path", None), + ("https://example.com/path", None), + ], + ) + def test_get_prefix(self, uri, expected): """Test the get_prefix method.""" - self.assertEqual( - ABSObjectStore.get_prefix( - "abfss://container@account.dfs.core.windows.net/path" - ), - "abfss://", - ) - self.assertIsNone(ABSObjectStore.get_prefix("s3://bucket/path")) + assert ABSObjectStore.get_prefix(uri) == expected - def test_strip_prefix(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ( + "abfss://container@account.dfs.core.windows.net/path", + "container@account.dfs.core.windows.net/path", + ), + ("https://account.blob.core.windows.net/container/path", "container/path"), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler/import_export_services/message_data_randomized.csv", + ), + ], + ) + def test_strip_prefix(self, uri, expected): """Test the strip_prefix method.""" - self.assertEqual( - ABSObjectStore.strip_prefix( - "abfss://container@account.dfs.core.windows.net/path" - ), - "container@account.dfs.core.windows.net/path", - ) + assert ABSObjectStore.strip_prefix(uri) == expected - # Should raise ValueError for non-ABS URIs - with self.assertRaises(ValueError): + def test_strip_prefix_invalid_uri(self): + """Test strip_prefix with invalid URI.""" + with pytest.raises(ValueError): ABSObjectStore.strip_prefix("s3://bucket/path") - def test_get_bucket_name(self): - """Test the get_bucket_name method.""" - self.assertEqual( - ABSObjectStore.get_bucket_name( - "abfss://container@account.dfs.core.windows.net/path" + @pytest.mark.parametrize( + "uri,expected", + [ + ("abfss://container@account.dfs.core.windows.net/path", "container"), + ("https://account.blob.core.windows.net/container/path", "container"), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler", ), - "container", - ) + ], + ) + def test_get_bucket_name(self, uri, expected): + """Test the get_bucket_name method.""" + assert ABSObjectStore.get_bucket_name(uri) == expected - # Should raise ValueError for non-ABS URIs - with self.assertRaises(ValueError): + def test_get_bucket_name_invalid_uri(self): + """Test get_bucket_name with invalid URI.""" + with pytest.raises(ValueError): ABSObjectStore.get_bucket_name("s3://bucket/path") - def test_get_object_key(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("abfss://container@account.dfs.core.windows.net/path", "path"), + ( + "abfss://container@account.dfs.core.windows.net/path/to/file.txt", + "path/to/file.txt", + ), + ("https://account.blob.core.windows.net/container/path", "path"), + ( + "https://account.blob.core.windows.net/container/path/to/file.txt", + "path/to/file.txt", + ), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "import_export_services/message_data_randomized.csv", + ), + ], + ) + def test_get_object_key(self, uri, expected): """Test the get_object_key method.""" - self.assertEqual( - ABSObjectStore.get_object_key( - "abfss://container@account.dfs.core.windows.net/path" - ), - "path", - ) - self.assertEqual( - ABSObjectStore.get_object_key( - "abfss://container@account.dfs.core.windows.net/path/to/file.txt" - ), - "path/to/file.txt", - ) + assert ABSObjectStore.get_object_key(uri) == 
expected - # Should raise ValueError for non-ABS URIs - with self.assertRaises(ValueError): + def test_get_object_key_invalid_uri(self): + """Test get_object_key with invalid URI.""" + with pytest.raises(ValueError): ABSObjectStore.get_object_key("s3://bucket/path") -class TestUtilityFunctions(unittest.TestCase): +class TestUtilityFunctions: """Tests for the utility functions in object_store module.""" - def test_get_object_store_for_uri(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", S3ObjectStore), + ("gs://bucket/path", GCSObjectStore), + ("abfss://container@account.dfs.core.windows.net/path", ABSObjectStore), + ("https://account.blob.core.windows.net/container/path", ABSObjectStore), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + ABSObjectStore, + ), + ("file:///path/to/file", None), + ], + ) + def test_get_object_store_for_uri(self, uri, expected): """Test the get_object_store_for_uri function.""" - self.assertEqual(get_object_store_for_uri("s3://bucket/path"), S3ObjectStore) - self.assertEqual(get_object_store_for_uri("gs://bucket/path"), GCSObjectStore) - self.assertEqual( - get_object_store_for_uri( - "abfss://container@account.dfs.core.windows.net/path" - ), - ABSObjectStore, - ) - self.assertIsNone(get_object_store_for_uri("file:///path/to/file")) + assert get_object_store_for_uri(uri) == expected - def test_get_object_store_bucket_name(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "bucket"), + ("gs://bucket/path", "bucket"), + ("abfss://container@account.dfs.core.windows.net/path", "container"), + ("https://account.blob.core.windows.net/container/path", "container"), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler", + ), + ], + ) + def test_get_object_store_bucket_name(self, uri, expected): """Test the get_object_store_bucket_name function.""" - self.assertEqual(get_object_store_bucket_name("s3://bucket/path"), "bucket") - self.assertEqual(get_object_store_bucket_name("gs://bucket/path"), "bucket") - self.assertEqual( - get_object_store_bucket_name( - "abfss://container@account.dfs.core.windows.net/path" - ), - "container", - ) + assert get_object_store_bucket_name(uri) == expected - # Should raise ValueError for unsupported URIs - with self.assertRaises(ValueError): + def test_get_object_store_bucket_name_invalid_uri(self): + """Test get_object_store_bucket_name with unsupported URI.""" + with pytest.raises(ValueError): get_object_store_bucket_name("file:///path/to/file") - def test_get_object_key(self): + @pytest.mark.parametrize( + "uri,expected", + [ + ("s3://bucket/path", "path"), + ("gs://bucket/path/to/file.txt", "path/to/file.txt"), + ("abfss://container@account.dfs.core.windows.net/path", "path"), + ("https://account.blob.core.windows.net/container/path", "path"), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "import_export_services/message_data_randomized.csv", + ), + ], + ) + def test_get_object_key(self, uri, expected): """Test the get_object_key function.""" - self.assertEqual(get_object_key("s3://bucket/path"), "path") - self.assertEqual( - get_object_key("gs://bucket/path/to/file.txt"), "path/to/file.txt" - ) - self.assertEqual( - get_object_key("abfss://container@account.dfs.core.windows.net/path"), - "path", - ) + assert get_object_key(uri) == expected - # Should raise ValueError 
for unsupported URIs - with self.assertRaises(ValueError): + def test_get_object_key_invalid_uri(self): + """Test get_object_key with unsupported URI.""" + with pytest.raises(ValueError): get_object_key("file:///path/to/file") -class TestObjectStoreSourceAdapter(unittest.TestCase): +class TestObjectStoreSourceAdapter: """Tests for the ObjectStoreSourceAdapter class.""" - def test_create_s3_path(self): + @pytest.mark.parametrize( + "bucket,key,expected", + [ + ("bucket", "path/to/file.txt", "s3://bucket/path/to/file.txt"), + ("my-bucket", "file.json", "s3://my-bucket/file.json"), + ], + ) + def test_create_s3_path(self, bucket, key, expected): """Test the create_s3_path static method.""" - self.assertEqual( - ObjectStoreSourceAdapter.create_s3_path("bucket", "path/to/file.txt"), - "s3://bucket/path/to/file.txt", - ) + assert ObjectStoreSourceAdapter.create_s3_path(bucket, key) == expected - def test_create_gcs_path(self): + @pytest.mark.parametrize( + "bucket,key,expected", + [ + ("bucket", "path/to/file.txt", "gs://bucket/path/to/file.txt"), + ("my-bucket", "file.json", "gs://my-bucket/file.json"), + ], + ) + def test_create_gcs_path(self, bucket, key, expected): """Test the create_gcs_path static method.""" - self.assertEqual( - ObjectStoreSourceAdapter.create_gcs_path("bucket", "path/to/file.txt"), - "gs://bucket/path/to/file.txt", - ) + assert ObjectStoreSourceAdapter.create_gcs_path(bucket, key) == expected - def test_create_abs_path(self): - """Test the create_abs_path static method.""" - self.assertEqual( - ObjectStoreSourceAdapter.create_abs_path( - "container", "path/to/file.txt", "storage" + @pytest.mark.parametrize( + "container,key,account,expected", + [ + ( + "container", + "path/to/file.txt", + "storage", + "abfss://container@storage.dfs.core.windows.net/path/to/file.txt", ), - "abfss://container@storage.dfs.core.windows.net/path/to/file.txt", + ( + "data", + "file.json", + "myaccount", + "abfss://data@myaccount.dfs.core.windows.net/file.json", + ), + ], + ) + def test_create_abs_path(self, container, key, account, expected): + """Test the create_abs_path static method.""" + assert ( + ObjectStoreSourceAdapter.create_abs_path(container, key, account) + == expected ) - def test_get_s3_external_url(self): + @pytest.mark.parametrize( + "table_path,region,expected", + [ + ( + "s3://bucket/path/to/file.txt", + None, + "https://us-east-1.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", + ), + ( + "s3://bucket/path/to/file.txt", + "us-west-2", + "https://us-west-2.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", + ), + ("gs://bucket/path", None, None), + ], + ) + def test_get_s3_external_url(self, table_path, region, expected): """Test the get_s3_external_url static method.""" mock_table_data = MagicMock() - mock_table_data.table_path = "s3://bucket/path/to/file.txt" - - # Test with default region - self.assertEqual( - ObjectStoreSourceAdapter.get_s3_external_url(mock_table_data), - "https://us-east-1.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", + mock_table_data.table_path = table_path + assert ( + ObjectStoreSourceAdapter.get_s3_external_url(mock_table_data, region) + == expected ) - # Test with custom region - self.assertEqual( - ObjectStoreSourceAdapter.get_s3_external_url(mock_table_data, "us-west-2"), - "https://us-west-2.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", - ) - - # Test with non-S3 URI - mock_table_data.table_path = "gs://bucket/path" - 
self.assertIsNone(ObjectStoreSourceAdapter.get_s3_external_url(mock_table_data)) - - def test_get_gcs_external_url(self): + @pytest.mark.parametrize( + "table_path,expected", + [ + ( + "gs://bucket/path/to/file.txt", + "https://console.cloud.google.com/storage/browser/bucket/path/to/file.txt", + ), + ("s3://bucket/path", None), + ], + ) + def test_get_gcs_external_url(self, table_path, expected): """Test the get_gcs_external_url static method.""" mock_table_data = MagicMock() - mock_table_data.table_path = "gs://bucket/path/to/file.txt" - - self.assertEqual( - ObjectStoreSourceAdapter.get_gcs_external_url(mock_table_data), - "https://console.cloud.google.com/storage/browser/bucket/path/to/file.txt", + mock_table_data.table_path = table_path + assert ( + ObjectStoreSourceAdapter.get_gcs_external_url(mock_table_data) == expected ) - # Test with non-GCS URI - mock_table_data.table_path = "s3://bucket/path" - self.assertIsNone( - ObjectStoreSourceAdapter.get_gcs_external_url(mock_table_data) - ) - - def test_get_abs_external_url(self): + @pytest.mark.parametrize( + "table_path,expected", + [ + ( + "abfss://container@account.dfs.core.windows.net/path/to/file.txt", + "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", + ), + ( + "https://account.blob.core.windows.net/container/path/to/file.txt", + "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", + ), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/odedmdatacatalog/containerName/settler", + ), + ("s3://bucket/path", None), + ], + ) + def test_get_abs_external_url(self, table_path, expected): """Test the get_abs_external_url static method.""" mock_table_data = MagicMock() - mock_table_data.table_path = ( - "abfss://container@account.dfs.core.windows.net/path/to/file.txt" - ) - - self.assertEqual( - ObjectStoreSourceAdapter.get_abs_external_url(mock_table_data), - "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", - ) - - # Test with non-ABS URI - mock_table_data.table_path = "s3://bucket/path" - self.assertIsNone( - ObjectStoreSourceAdapter.get_abs_external_url(mock_table_data) + mock_table_data.table_path = table_path + assert ( + ObjectStoreSourceAdapter.get_abs_external_url(mock_table_data) == expected ) def test_adapter_initialization(self): @@ -326,16 +480,16 @@ class TestObjectStoreSourceAdapter(unittest.TestCase): s3_adapter = ObjectStoreSourceAdapter( platform="s3", platform_name="Amazon S3", aws_region="us-west-2" ) - self.assertEqual(s3_adapter.platform, "s3") - self.assertEqual(s3_adapter.platform_name, "Amazon S3") - self.assertEqual(s3_adapter.aws_region, "us-west-2") + assert s3_adapter.platform == "s3" + assert s3_adapter.platform_name == "Amazon S3" + assert s3_adapter.aws_region == "us-west-2" # Test GCS adapter gcs_adapter = ObjectStoreSourceAdapter( platform="gcs", platform_name="Google Cloud Storage" ) - self.assertEqual(gcs_adapter.platform, "gcs") - self.assertEqual(gcs_adapter.platform_name, "Google Cloud Storage") + assert gcs_adapter.platform == "gcs" + assert gcs_adapter.platform_name == "Google Cloud Storage" # Test ABS adapter abs_adapter = ObjectStoreSourceAdapter( @@ -343,9 +497,9 @@ class 
TestObjectStoreSourceAdapter(unittest.TestCase): platform_name="Azure Blob Storage", azure_storage_account="myaccount", ) - self.assertEqual(abs_adapter.platform, "abs") - self.assertEqual(abs_adapter.platform_name, "Azure Blob Storage") - self.assertEqual(abs_adapter.azure_storage_account, "myaccount") + assert abs_adapter.platform == "abs" + assert abs_adapter.platform_name == "Azure Blob Storage" + assert abs_adapter.azure_storage_account == "myaccount" def test_register_customization(self): """Test registering customizations.""" @@ -357,8 +511,8 @@ class TestObjectStoreSourceAdapter(unittest.TestCase): adapter.register_customization("custom_method", custom_func) - self.assertIn("custom_method", adapter.customizations) - self.assertEqual(adapter.customizations["custom_method"], custom_func) + assert "custom_method" in adapter.customizations + assert adapter.customizations["custom_method"] == custom_func def test_apply_customizations(self): """Test applying customizations to a source.""" @@ -378,80 +532,190 @@ class TestObjectStoreSourceAdapter(unittest.TestCase): result = adapter.apply_customizations(mock_source) # Check that the platform was set - self.assertEqual(mock_source.source_config.platform, "s3") + assert mock_source.source_config.platform == "s3" # Check that the custom method was added - self.assertTrue(hasattr(mock_source, "custom_method")) - self.assertEqual(mock_source.custom_method, custom_func) + assert hasattr(mock_source, "custom_method") + assert mock_source.custom_method == custom_func # Check that the result is the same object - self.assertEqual(result, mock_source) + assert result == mock_source - def test_get_external_url(self): + @pytest.mark.parametrize( + "platform,table_path,expected_url", + [ + ( + "s3", + "s3://bucket/path/to/file.txt", + "https://us-east-1.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", + ), + ( + "gcs", + "gs://bucket/path/to/file.txt", + "https://console.cloud.google.com/storage/browser/bucket/path/to/file.txt", + ), + ( + "abs", + "abfss://container@account.dfs.core.windows.net/path/to/file.txt", + "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", + ), + ( + "abs", + "https://account.blob.core.windows.net/container/path/to/file.txt", + "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", + ), + ], + ) + def test_get_external_url(self, platform, table_path, expected_url): """Test the get_external_url method.""" mock_table_data = MagicMock() - mock_table_data.table_path = "s3://bucket/path/to/file.txt" + mock_table_data.table_path = table_path - # Test S3 adapter - s3_adapter = ObjectStoreSourceAdapter( - platform="s3", platform_name="Amazon S3", aws_region="us-west-2" - ) - self.assertEqual( - s3_adapter.get_external_url(mock_table_data), - "https://us-west-2.console.aws.amazon.com/s3/buckets/bucket?prefix=path/to/file.txt", - ) + if platform == "s3": + adapter = ObjectStoreSourceAdapter( + platform=platform, platform_name="Amazon S3", aws_region="us-east-1" + ) + elif platform == "gcs": + adapter = ObjectStoreSourceAdapter( + platform=platform, platform_name="Google Cloud Storage" + ) + elif platform == "abs": + adapter = ObjectStoreSourceAdapter( + platform=platform, platform_name="Azure Blob Storage" + ) - # Test GCS adapter - mock_table_data.table_path = "gs://bucket/path/to/file.txt" - gcs_adapter = ObjectStoreSourceAdapter( - platform="gcs", 
platform_name="Google Cloud Storage" - ) - self.assertEqual( - gcs_adapter.get_external_url(mock_table_data), - "https://console.cloud.google.com/storage/browser/bucket/path/to/file.txt", - ) - - # Test ABS adapter - mock_table_data.table_path = ( - "abfss://container@account.dfs.core.windows.net/path/to/file.txt" - ) - abs_adapter = ObjectStoreSourceAdapter( - platform="abs", platform_name="Azure Blob Storage" - ) - self.assertEqual( - abs_adapter.get_external_url(mock_table_data), - "https://portal.azure.com/#blade/Microsoft_Azure_Storage/ContainerMenuBlade/overview/storageAccountId/account/containerName/container", - ) + assert adapter.get_external_url(mock_table_data) == expected_url -class TestCreateObjectStoreAdapter(unittest.TestCase): +class TestCreateObjectStoreAdapter: """Tests for the create_object_store_adapter function.""" - def test_create_s3_adapter(self): - """Test creating an S3 adapter.""" - adapter = create_object_store_adapter("s3", aws_region="us-west-2") - self.assertEqual(adapter.platform, "s3") - self.assertEqual(adapter.platform_name, "Amazon S3") - self.assertEqual(adapter.aws_region, "us-west-2") + @pytest.mark.parametrize( + "platform,aws_region,azure_storage_account,expected_platform,expected_name", + [ + ("s3", "us-west-2", None, "s3", "Amazon S3"), + ("gcs", None, None, "gcs", "Google Cloud Storage"), + ("abs", None, "myaccount", "abs", "Azure Blob Storage"), + ("unknown", None, None, "unknown", "Unknown (unknown)"), + ], + ) + def test_create_adapter( + self, + platform, + aws_region, + azure_storage_account, + expected_platform, + expected_name, + ): + """Test creating adapters for different platforms.""" + adapter = create_object_store_adapter( + platform, aws_region=aws_region, azure_storage_account=azure_storage_account + ) + assert adapter.platform == expected_platform + assert adapter.platform_name == expected_name + if aws_region: + assert adapter.aws_region == aws_region + if azure_storage_account: + assert adapter.azure_storage_account == azure_storage_account - def test_create_gcs_adapter(self): - """Test creating a GCS adapter.""" - adapter = create_object_store_adapter("gcs") - self.assertEqual(adapter.platform, "gcs") - self.assertEqual(adapter.platform_name, "Google Cloud Storage") - def test_create_abs_adapter(self): - """Test creating an ABS adapter.""" - adapter = create_object_store_adapter("abs", azure_storage_account="myaccount") - self.assertEqual(adapter.platform, "abs") - self.assertEqual(adapter.platform_name, "Azure Blob Storage") - self.assertEqual(adapter.azure_storage_account, "myaccount") +class TestABSHTTPSSupport: + """Tests specifically for HTTPS Azure Blob Storage support.""" - def test_create_unknown_adapter(self): - """Test creating an adapter for an unknown platform.""" - adapter = create_object_store_adapter("unknown") - self.assertEqual(adapter.platform, "unknown") - self.assertEqual(adapter.platform_name, "Unknown (unknown)") + @pytest.mark.parametrize( + "uri,expected", + [ + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + True, + ), + ("https://account.blob.core.windows.net/container/path", True), + ("https://myaccount123.blob.core.windows.net/data/file.json", True), + ("https://google.com/path", False), + ("https://example.com/path", False), + ], + ) + def test_https_uri_detection(self, uri, expected): + """Test that HTTPS Azure Blob Storage URIs are detected correctly.""" + assert ABSObjectStore.is_uri(uri) == expected + + @pytest.mark.parametrize( + 
"uri,expected", + [ + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler", + ), + ("https://account.blob.core.windows.net/container/path", "container"), + ("https://myaccount123.blob.core.windows.net/data/file.json", "data"), + ], + ) + def test_https_container_extraction(self, uri, expected): + """Test container name extraction from HTTPS URIs.""" + assert ABSObjectStore.get_bucket_name(uri) == expected + + @pytest.mark.parametrize( + "uri,expected", + [ + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "import_export_services/message_data_randomized.csv", + ), + ("https://account.blob.core.windows.net/container/path", "path"), + ("https://myaccount123.blob.core.windows.net/data/file.json", "file.json"), + ], + ) + def test_https_object_key_extraction(self, uri, expected): + """Test object key extraction from HTTPS URIs.""" + assert ABSObjectStore.get_object_key(uri) == expected + + @pytest.mark.parametrize( + "uri,expected", + [ + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "https://odedmdatacatalog.blob.core.windows.net/", + ), + ( + "https://account.blob.core.windows.net/container/path", + "https://account.blob.core.windows.net/", + ), + ], + ) + def test_https_prefix_extraction(self, uri, expected): + """Test prefix extraction from HTTPS URIs.""" + assert ABSObjectStore.get_prefix(uri) == expected + + @pytest.mark.parametrize( + "uri,expected", + [ + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler", + ), + ("https://account.blob.core.windows.net/container/path", "container"), + ], + ) + def test_fallback_bucket_name_resolution(self, uri, expected): + """Test the fallback logic in get_object_store_bucket_name.""" + assert get_object_store_bucket_name(uri) == expected + + def test_mixed_format_compatibility(self): + """Test that both abfss:// and HTTPS formats work for the same container.""" + abfss_uri = "abfss://container@account.dfs.core.windows.net/path/file.txt" + https_uri = "https://account.blob.core.windows.net/container/path/file.txt" + + # Both should be recognized as ABS URIs + assert ABSObjectStore.is_uri(abfss_uri) + assert ABSObjectStore.is_uri(https_uri) + + # Both should extract the same container name + assert ABSObjectStore.get_bucket_name(abfss_uri) == "container" + assert ABSObjectStore.get_bucket_name(https_uri) == "container" + + # Both should extract the same object key + assert ABSObjectStore.get_object_key(abfss_uri) == "path/file.txt" + assert ABSObjectStore.get_object_key(https_uri) == "path/file.txt" # Parametrized tests for GCS URI normalization @@ -489,7 +753,37 @@ def test_gcs_prefix_stripping(input_uri, expected): assert result == expected -class TestGCSURINormalization(unittest.TestCase): +# Parametrized tests for ABS HTTPS URI handling +@pytest.mark.parametrize( + "input_uri,expected_container,expected_key", + [ + ( + "https://account.blob.core.windows.net/container/path/file.txt", + "container", + "path/file.txt", + ), + ( + "https://odedmdatacatalog.blob.core.windows.net/settler/import_export_services/message_data_randomized.csv", + "settler", + "import_export_services/message_data_randomized.csv", + ), + ( + "https://mystorageaccount.blob.core.windows.net/data/2023/logs/app.log", + "data", + "2023/logs/app.log", + ), + 
("https://account.blob.core.windows.net/container/", "container", ""), + ("https://account.blob.core.windows.net/container", "container", ""), + ], +) +def test_abs_https_uri_parsing(input_uri, expected_container, expected_key): + """Test that HTTPS ABS URIs are parsed correctly.""" + assert ABSObjectStore.is_uri(input_uri) + assert ABSObjectStore.get_bucket_name(input_uri) == expected_container + assert ABSObjectStore.get_object_key(input_uri) == expected_key + + +class TestGCSURINormalization: """Tests for the GCS URI normalization fix.""" def test_gcs_adapter_customizations(self): @@ -506,7 +800,7 @@ class TestGCSURINormalization(unittest.TestCase): ] for customization in expected_customizations: - self.assertIn(customization, gcs_adapter.customizations) + assert customization in gcs_adapter.customizations def test_gcs_adapter_applied_to_mock_source(self): """Test that GCS adapter customizations are applied to a mock source.""" @@ -520,18 +814,18 @@ class TestGCSURINormalization(unittest.TestCase): gcs_adapter.apply_customizations(mock_source) # Check that the customizations were applied - self.assertTrue(hasattr(mock_source, "_normalize_uri_for_pattern_matching")) - self.assertTrue(hasattr(mock_source, "strip_s3_prefix")) - self.assertTrue(hasattr(mock_source, "create_s3_path")) + assert hasattr(mock_source, "_normalize_uri_for_pattern_matching") + assert hasattr(mock_source, "strip_s3_prefix") + assert hasattr(mock_source, "create_s3_path") # Test that the URI normalization method works on the mock source test_uri = "gs://bucket/path/file.parquet" normalized = mock_source._normalize_uri_for_pattern_matching(test_uri) - self.assertEqual(normalized, "s3://bucket/path/file.parquet") + assert normalized == "s3://bucket/path/file.parquet" # Test that the prefix stripping method works on the mock source stripped = mock_source.strip_s3_prefix(test_uri) - self.assertEqual(stripped, "bucket/path/file.parquet") + assert stripped == "bucket/path/file.parquet" def test_gcs_path_creation_via_adapter(self): """Test that GCS paths are created correctly via the adapter.""" @@ -544,7 +838,7 @@ class TestGCSURINormalization(unittest.TestCase): # Test that create_s3_path now creates GCS paths gcs_path = mock_source.create_s3_path("bucket", "path/to/file.parquet") - self.assertEqual(gcs_path, "gs://bucket/path/to/file.parquet") + assert gcs_path == "gs://bucket/path/to/file.parquet" def test_pattern_matching_scenario(self): """Test the actual pattern matching scenario that was failing.""" @@ -563,14 +857,12 @@ class TestGCSURINormalization(unittest.TestCase): ) # The normalized URI should now be compatible with the pattern - self.assertEqual( - normalized_file_uri, "s3://bucket/path/food_parquet/file.parquet" - ) + assert normalized_file_uri == "s3://bucket/path/food_parquet/file.parquet" # Test that the normalized URI would match the pattern (simplified test) glob_pattern = path_spec_pattern.replace("{table}", "*") - self.assertTrue(pathlib.PurePath(normalized_file_uri).match(glob_pattern)) + assert pathlib.PurePath(normalized_file_uri).match(glob_pattern) if __name__ == "__main__": - unittest.main() + pytest.main([__file__])