From 1552aeb2de9d2cadb902d19bb17b43d9452d139f Mon Sep 17 00:00:00 2001
From: Ayush Shah
Date: Thu, 25 Jan 2024 15:22:16 +0530
Subject: [PATCH] Fix #13149: Multiple Project Id for Datalake GCS (#14846)

* Fix Multiple Project Id for datalake gcs

* Optimize logic

* Fix Tests

* Add Datalake GCS Tests

* Add multiple project id gcs test
---
 .../source/database/bigquery/metadata.py      |  27 ++--
 .../source/database/datalake/connection.py    |  35 ++++-
 .../source/database/datalake/metadata.py      | 120 +++++++++++++++---
 ingestion/tests/unit/test_entity_link.py      |   7 +
 .../unit/topology/database/test_datalake.py   |  99 ++++++++++++++-
 5 files changed, 244 insertions(+), 44 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
index a81a87e48d5..1abd2ab9b1a 100644
--- a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
@@ -210,7 +210,6 @@ class BigquerySource(
         # as per service connection config, which would result in an error.
         self.test_connection = lambda: None
         super().__init__(config, metadata)
-        self.temp_credentials = None
         self.client = None
         # Used to delete temp json file created while initializing bigquery client
         self.temp_credentials_file_path = []
@@ -366,18 +365,18 @@ class BigquerySource(
                 schema_name=schema_name,
             ),
         )
-
-        dataset_obj = self.client.get_dataset(schema_name)
-        if dataset_obj.labels and self.source_config.includeTags:
-            database_schema_request_obj.tags = []
-            for label_classification, label_tag_name in dataset_obj.labels.items():
-                tag_label = get_tag_label(
-                    metadata=self.metadata,
-                    tag_name=label_tag_name,
-                    classification_name=label_classification,
-                )
-                if tag_label:
-                    database_schema_request_obj.tags.append(tag_label)
+        if self.source_config.includeTags:
+            dataset_obj = self.client.get_dataset(schema_name)
+            if dataset_obj.labels:
+                database_schema_request_obj.tags = []
+                for label_classification, label_tag_name in dataset_obj.labels.items():
+                    tag_label = get_tag_label(
+                        metadata=self.metadata,
+                        tag_name=label_tag_name,
+                        classification_name=label_classification,
+                    )
+                    if tag_label:
+                        database_schema_request_obj.tags.append(tag_label)
         yield Either(right=database_schema_request_obj)
 
     def get_table_obj(self, table_name: str):
@@ -530,8 +529,6 @@ class BigquerySource(
 
     def close(self):
         super().close()
-        if self.temp_credentials:
-            os.unlink(self.temp_credentials)
         os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
         if isinstance(
             self.service_connection.credentials.gcpConfig, GcpCredentialsValues
diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py
index f43cd1b6c39..56e5315da5b 100644
--- a/ingestion/src/metadata/ingestion/source/database/datalake/connection.py
+++ b/ingestion/src/metadata/ingestion/source/database/datalake/connection.py
@@ -12,10 +12,14 @@
 """
 Source connection handler
 """
+import os
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial, singledispatch
 from typing import Optional
 
+from google.cloud import storage
+
 from metadata.generated.schema.entity.automations.workflow import (
     Workflow as AutomationWorkflow,
 )
@@ -31,9 +35,13 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
 from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
     DatalakeConnection,
 )
+from metadata.generated.schema.security.credentials.gcpValues import (
+    MultipleProjectId,
+    SingleProjectId,
+)
 from metadata.ingestion.connections.test_connections import test_connection_steps
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.utils.credentials import set_google_credentials
+from metadata.utils.credentials import GOOGLE_CREDENTIALS, set_google_credentials
 
 # Only import specific datalake dependencies if necessary
 
@@ -65,9 +73,15 @@ def _(config: S3Config):
 
 @get_datalake_client.register
 def _(config: GCSConfig):
-    from google.cloud import storage
-
-    set_google_credentials(gcp_credentials=config.securityConfig)
+    gcs_config = deepcopy(config)
+    if hasattr(config.securityConfig, "gcpConfig") and isinstance(
+        config.securityConfig.gcpConfig.projectId, MultipleProjectId
+    ):
+        gcs_config: GCSConfig = deepcopy(config)
+        gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
+            gcs_config.securityConfig.gcpConfig.projectId.__root__[0]
+        )
+    set_google_credentials(gcp_credentials=gcs_config.securityConfig)
     gcs_client = storage.Client()
     return gcs_client
 
@@ -96,6 +110,15 @@ def _(config: AzureConfig):
     )
 
 
+def set_gcs_datalake_client(config: GCSConfig, project_id: str):
+    gcs_config = deepcopy(config)
+    if hasattr(gcs_config.securityConfig, "gcpConfig"):
+        gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
+            project_id
+        )
+    return get_datalake_client(config=gcs_config)
+
+
 def get_connection(connection: DatalakeConnection) -> DatalakeClient:
     """
     Create connection.
@@ -125,6 +148,10 @@ def test_connection(
         func = partial(connection.client.get_bucket, connection.config.bucketName)
     else:
         func = connection.client.list_buckets
+    os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
+    if GOOGLE_CREDENTIALS in os.environ:
+        os.remove(os.environ[GOOGLE_CREDENTIALS])
+        del os.environ[GOOGLE_CREDENTIALS]
 
     if isinstance(config, S3Config):
         if connection.config.bucketName:
diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
index dcff7205e72..50316f8539a 100644
--- a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
@@ -13,6 +13,7 @@
 DataLake connector to fetch metadata from a files stored s3, gcs and Hdfs
 """
 import json
+import os
 import traceback
 from typing import Any, Iterable, Tuple, Union
 
@@ -53,12 +54,18 @@ from metadata.generated.schema.metadataIngestion.storage.containerMetadataConfig
 from metadata.generated.schema.metadataIngestion.workflow import (
     Source as WorkflowSource,
 )
+from metadata.generated.schema.security.credentials.gcpValues import (
+    GcpCredentialsValues,
+)
 from metadata.ingestion.api.models import Either
 from metadata.ingestion.api.steps import InvalidSourceException
 from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.connections import get_connection
 from metadata.ingestion.source.database.database_service import DatabaseServiceSource
+from metadata.ingestion.source.database.datalake.connection import (
+    set_gcs_datalake_client,
+)
 from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
 from metadata.ingestion.source.storage.storage_service import (
     OPENMETADATA_TEMPLATE_FILE_NAME,
@@ -68,12 +75,13 @@ from metadata.readers.file.base import ReadException
 from metadata.readers.file.config_source_factory import get_reader
 from metadata.utils import fqn
 from metadata.utils.constants import DEFAULT_DATABASE
+from metadata.utils.credentials import GOOGLE_CREDENTIALS
 from metadata.utils.datalake.datalake_utils import (
     fetch_dataframe,
     get_columns,
     get_file_format_type,
 )
-from metadata.utils.filters import filter_by_schema, filter_by_table
+from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
 from metadata.utils.logger import ingestion_logger
 from metadata.utils.s3_utils import list_s3_objects
 
@@ -96,8 +104,10 @@ class DatalakeSource(DatabaseServiceSource):
         )
         self.metadata = metadata
         self.service_connection = self.config.serviceConnection.__root__.config
+        self.temp_credentials_file_path = []
         self.connection = get_connection(self.service_connection)
-
+        if GOOGLE_CREDENTIALS in os.environ:
+            self.temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])
         self.client = self.connection.client
         self.table_constraints = None
         self.database_source_state = set()
@@ -125,8 +135,47 @@ class DatalakeSource(DatabaseServiceSource):
         Sources with multiple databases should overwrite this and
         apply the necessary filters.
         """
-        database_name = self.service_connection.databaseName or DEFAULT_DATABASE
-        yield database_name
+        if isinstance(self.config_source, GCSConfig):
+            project_id_list = (
+                self.service_connection.configSource.securityConfig.gcpConfig.projectId.__root__
+            )
+            if not isinstance(
+                project_id_list,
+                list,
+            ):
+                project_id_list = [project_id_list]
+            for project_id in project_id_list:
+                database_fqn = fqn.build(
+                    self.metadata,
+                    entity_type=Database,
+                    service_name=self.context.database_service,
+                    database_name=project_id,
+                )
+                if filter_by_database(
+                    self.source_config.databaseFilterPattern,
+                    database_fqn
+                    if self.source_config.useFqnForFiltering
+                    else project_id,
+                ):
+                    self.status.filter(database_fqn, "Database Filtered out")
+                else:
+                    try:
+                        self.client = set_gcs_datalake_client(
+                            config=self.config_source, project_id=project_id
+                        )
+                        if GOOGLE_CREDENTIALS in os.environ:
+                            self.temp_credentials_file_path.append(
+                                os.environ[GOOGLE_CREDENTIALS]
+                            )
+                        yield project_id
+                    except Exception as exc:
+                        logger.debug(traceback.format_exc())
+                        logger.error(
+                            f"Error trying to connect to database {project_id}: {exc}"
+                        )
+        else:
+            database_name = self.service_connection.databaseName or DEFAULT_DATABASE
+            yield database_name
 
     def yield_database(
         self, database_name: str
@@ -135,6 +184,8 @@ class DatalakeSource(DatabaseServiceSource):
         From topology.
         Prepare a database request and pass it to the sink
         """
+        if isinstance(self.config_source, GCSConfig):
+            database_name = self.client.project
         yield Either(
             right=CreateDatabaseRequest(
                 name=database_name,
@@ -143,24 +194,42 @@ class DatalakeSource(DatabaseServiceSource):
         )
 
     def fetch_gcs_bucket_names(self):
-        for bucket in self.client.list_buckets():
-            schema_fqn = fqn.build(
-                self.metadata,
-                entity_type=DatabaseSchema,
-                service_name=self.context.database_service,
-                database_name=self.context.database,
-                schema_name=bucket.name,
-            )
-            if filter_by_schema(
-                self.config.sourceConfig.config.schemaFilterPattern,
-                schema_fqn
-                if self.config.sourceConfig.config.useFqnForFiltering
-                else bucket.name,
-            ):
-                self.status.filter(schema_fqn, "Bucket Filtered Out")
-                continue
+        """
+        Fetch Google cloud storage buckets
+        """
+        try:
+            # List all the buckets in the project
+            for bucket in self.client.list_buckets():
+                # Build a fully qualified name (FQN) for each bucket
+                schema_fqn = fqn.build(
+                    self.metadata,
+                    entity_type=DatabaseSchema,
+                    service_name=self.context.database_service,
+                    database_name=self.context.database,
+                    schema_name=bucket.name,
+                )
 
-            yield bucket.name
+                # Check if the bucket matches a certain filter pattern
+                if filter_by_schema(
+                    self.config.sourceConfig.config.schemaFilterPattern,
+                    schema_fqn
+                    if self.config.sourceConfig.config.useFqnForFiltering
+                    else bucket.name,
+                ):
+                    # If it does not match, the bucket is filtered out
+                    self.status.filter(schema_fqn, "Bucket Filtered Out")
+                    continue
+
+                # If it does match, the bucket name is yielded
+                yield bucket.name
+        except Exception as exc:
+            yield Either(
+                left=StackTraceError(
+                    name="Bucket",
+                    error=f"Unexpected exception to yield bucket: {exc}",
+                    stackTrace=traceback.format_exc(),
+                )
+            )
 
     def fetch_s3_bucket_names(self):
         for bucket in self.client.list_buckets()["Buckets"]:
@@ -434,3 +503,12 @@ class DatalakeSource(DatabaseServiceSource):
     def close(self):
         if isinstance(self.config_source, AzureConfig):
             self.client.close()
+        if isinstance(self.config_source, GCSConfig):
+            os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
+            if isinstance(self.service_connection, GcpCredentialsValues) and (
+                GOOGLE_CREDENTIALS in os.environ
+            ):
+                del os.environ[GOOGLE_CREDENTIALS]
+            for temp_file_path in self.temp_credentials_file_path:
+                if os.path.exists(temp_file_path):
+                    os.remove(temp_file_path)
diff --git a/ingestion/tests/unit/test_entity_link.py b/ingestion/tests/unit/test_entity_link.py
index 18549b55dc2..ca45c858b1a 100644
--- a/ingestion/tests/unit/test_entity_link.py
+++ b/ingestion/tests/unit/test_entity_link.py
@@ -97,6 +97,13 @@ class TestEntityLink(TestCase):
                 "<#E::table::随机的>",
                 ["table", "随机的"],
             ),
+            EntityLinkTest(
+                '<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">',
+                [
+                    "table",
+                    'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
+                ],
+            ),
         ]
         for x in xs:
             x.validate(entity_link.split(x.entitylink), x.split_list)
diff --git a/ingestion/tests/unit/topology/database/test_datalake.py b/ingestion/tests/unit/topology/database/test_datalake.py
index 46c70e5e113..8c45c7fcfe6 100644
--- a/ingestion/tests/unit/topology/database/test_datalake.py
+++ b/ingestion/tests/unit/topology/database/test_datalake.py
@@ -94,6 +94,7 @@ MOCK_GCS_SCHEMA = [
 ]
 
 EXPECTED_SCHEMA = ["my_bucket"]
+EXPECTED_GCS_SCHEMA = ["test_datalake", "test_gcs", "s3_test", "my_bucket"]
 
 
 MOCK_DATABASE_SERVICE = DatabaseService(
@@ -427,10 +428,6 @@ class DatalakeUnitTest(TestCase):
         self.datalake_source.client.list_buckets = lambda: MOCK_S3_SCHEMA
         assert list(self.datalake_source.fetch_s3_bucket_names()) == EXPECTED_SCHEMA
 
-    def test_gcs_schema_filer(self):
-        self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
-        assert list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_SCHEMA
-
     def test_json_file_parse(self):
         """
         Test json data files
@@ -479,3 +476,97 @@ class DatalakeUnitTest(TestCase):
 
         columns = AvroDataFrameReader.read_from_avro(AVRO_DATA_FILE)
         assert EXPECTED_AVRO_COL_2 == columns.columns  # pylint: disable=no-member
+
+
+mock_datalake_gcs_config = {
+    "source": {
+        "type": "datalake",
+        "serviceName": "local_datalake",
+        "serviceConnection": {
+            "config": {
+                "type": "Datalake",
+                "configSource": {
+                    "securityConfig": {
+                        "gcpConfig": {
+                            "type": "service_account",
+                            "projectId": "project_id",
+                            "privateKeyId": "private_key_id",
+                            "privateKey": "private_key",
+                            "clientEmail": "gcpuser@project_id.iam.gserviceaccount.com",
+                            "clientId": "client_id",
+                            "authUri": "https://accounts.google.com/o/oauth2/auth",
+                            "tokenUri": "https://oauth2.googleapis.com/token",
+                            "authProviderX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
+                            "clientX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
+                        }
+                    }
+                },
+                "bucketName": "bucket name",
+                "prefix": "prefix",
+            }
+        },
+        "sourceConfig": {"config": {"type": "DatabaseMetadata"}},
+    },
+    "sink": {"type": "metadata-rest", "config": {}},
+    "workflowConfig": {
+        "loggerLevel": "DEBUG",
+        "openMetadataServerConfig": {
+            "hostPort": "http://localhost:8585/api",
+            "authProvider": "openmetadata",
+            "securityConfig": {
+                "jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
+            },
+        },
+    },
+}
+
+mock_multiple_project_id = deepcopy(mock_datalake_gcs_config)
+
+mock_multiple_project_id["source"]["serviceConnection"]["config"]["configSource"][
+    "securityConfig"
+]["gcpConfig"]["projectId"] = ["project_id", "project_id2"]
+
+
+class DatalakeGCSUnitTest(TestCase):
+    """
+    Datalake Source Unit Tests
+    """
+
+    @patch(
+        "metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
+    )
+    @patch("metadata.utils.credentials.validate_private_key")
+    @patch("google.cloud.storage.Client")
+    def __init__(self, methodName, _, __, test_connection) -> None:
+        super().__init__(methodName)
+        test_connection.return_value = False
+        self.config = OpenMetadataWorkflowConfig.parse_obj(mock_datalake_gcs_config)
+        self.datalake_source = DatalakeSource.create(
+            mock_datalake_gcs_config["source"],
+            self.config.workflowConfig.openMetadataServerConfig,
+        )
+        self.datalake_source.context.__dict__["database"] = MOCK_DATABASE.name.__root__
+        self.datalake_source.context.__dict__[
+            "database_service"
+        ] = MOCK_DATABASE_SERVICE.name.__root__
+
+    @patch(
+        "metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
+    )
+    @patch("google.cloud.storage.Client")
+    @patch("metadata.utils.credentials.validate_private_key")
+    def test_multiple_project_id_implementation(
+        self, validate_private_key, storage_client, test_connection
+    ):
+        self.datalake_source_multiple_project_id = DatalakeSource.create(
+            mock_multiple_project_id["source"],
+            OpenMetadataWorkflowConfig.parse_obj(
+                mock_multiple_project_id
+            ).workflowConfig.openMetadataServerConfig,
+        )
+
+    def test_gcs_schema_filer(self):
+        self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
+        assert (
+            list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_GCS_SCHEMA
+        )
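
The patch follows a single pattern for GCS: when gcpConfig.projectId holds a MultipleProjectId, the connector narrows it to one SingleProjectId at a time, rebuilds the storage client for that project (set_gcs_datalake_client), and yields each project id as its own database. The sketch below isolates that loop under simplifying assumptions: a plain string-or-list project_id instead of OpenMetadata's config classes, credentials resolved from the environment, and a hypothetical helper name iter_project_clients. It relies only on the public google.cloud.storage.Client(project=...) constructor and illustrates the approach; it is not the connector's actual code.

    from typing import Iterator, List, Tuple, Union

    from google.cloud import storage


    def iter_project_clients(
        project_id: Union[str, List[str]]
    ) -> Iterator[Tuple[str, storage.Client]]:
        """Yield one (project_id, client) pair per configured project.

        A single project id is treated as a one-element list, mirroring how the
        patch narrows MultipleProjectId down to SingleProjectId before building
        each client. Credentials are assumed to come from the environment
        (GOOGLE_APPLICATION_CREDENTIALS); error handling is trimmed for brevity.
        """
        project_ids = project_id if isinstance(project_id, list) else [project_id]
        for pid in project_ids:
            # One client per project: buckets listed through this client are scoped
            # to `pid`, just as each project becomes one "database" in the patch.
            yield pid, storage.Client(project=pid)


    if __name__ == "__main__":
        # Hypothetical usage, mirroring the ["project_id", "project_id2"] test fixture.
        for pid, client in iter_project_clients(["project_id", "project_id2"]):
            print(pid, [bucket.name for bucket in client.list_buckets()])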