Fix #13149: Multiple Project Id for Datalake GCS (#14846)

* Fix multiple project ID handling for Datalake GCS

* Optimize logic

* Fix tests

* Add Datalake GCS tests

* Add multiple project ID GCS test
Ayush Shah 2024-01-25 15:22:16 +05:30 committed by GitHub
parent 951917bf6d
commit 1552aeb2de
5 changed files with 244 additions and 44 deletions
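
For context, the fix lets a GCS Datalake connection carry either a single GCP project ID or a list of them. A minimal sketch of the relevant part of the connection payload, mirroring the mock used in the new unit tests (all values here are placeholders):

    gcs_config_source = {
        "securityConfig": {
            "gcpConfig": {
                "type": "service_account",
                # A plain string parses as SingleProjectId; a list parses as MultipleProjectId.
                "projectId": ["project_id", "project_id2"],
                "privateKeyId": "private_key_id",
                "privateKey": "private_key",
                "clientEmail": "gcpuser@project_id.iam.gserviceaccount.com",
            }
        }
    }

Each entry in the list is surfaced as its own database during ingestion, as the connector changes below show.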

View File

@@ -210,7 +210,6 @@ class BigquerySource(
         # as per service connection config, which would result in an error.
         self.test_connection = lambda: None
         super().__init__(config, metadata)
-        self.temp_credentials = None
         self.client = None
         # Used to delete temp json file created while initializing bigquery client
         self.temp_credentials_file_path = []
@@ -366,18 +365,18 @@ class BigquerySource(
                 schema_name=schema_name,
             ),
         )
-        dataset_obj = self.client.get_dataset(schema_name)
-        if dataset_obj.labels and self.source_config.includeTags:
-            database_schema_request_obj.tags = []
-            for label_classification, label_tag_name in dataset_obj.labels.items():
-                tag_label = get_tag_label(
-                    metadata=self.metadata,
-                    tag_name=label_tag_name,
-                    classification_name=label_classification,
-                )
-                if tag_label:
-                    database_schema_request_obj.tags.append(tag_label)
+        if self.source_config.includeTags:
+            dataset_obj = self.client.get_dataset(schema_name)
+            if dataset_obj.labels:
+                database_schema_request_obj.tags = []
+                for label_classification, label_tag_name in dataset_obj.labels.items():
+                    tag_label = get_tag_label(
+                        metadata=self.metadata,
+                        tag_name=label_tag_name,
+                        classification_name=label_classification,
+                    )
+                    if tag_label:
+                        database_schema_request_obj.tags.append(tag_label)
         yield Either(right=database_schema_request_obj)

     def get_table_obj(self, table_name: str):
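
The reordering above gates the dataset-label lookup behind includeTags, so the extra get_dataset call is skipped when tag ingestion is disabled. In isolation, the guard amounts to this (a sketch with a generic BigQuery-style client, not the connector code itself):

    def collect_schema_labels(client, schema_name: str, include_tags: bool):
        """Only call the API for labels when tag ingestion is requested."""
        if not include_tags:
            return []
        dataset = client.get_dataset(schema_name)
        return list((dataset.labels or {}).items())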
@@ -530,8 +529,6 @@ class BigquerySource(
     def close(self):
         super().close()
-        if self.temp_credentials:
-            os.unlink(self.temp_credentials)
         os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
         if isinstance(
             self.service_connection.credentials.gcpConfig, GcpCredentialsValues
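
With the single temp_credentials attribute gone, cleanup relies entirely on the temp_credentials_file_path list that the BigQuery source (and, below, the Datalake source) appends to. A small sketch of that bookkeeping pattern, assuming the GOOGLE_CREDENTIALS constant resolves to the GOOGLE_APPLICATION_CREDENTIALS environment variable:

    import os

    GOOGLE_CREDENTIALS = "GOOGLE_APPLICATION_CREDENTIALS"  # assumption for this sketch

    # Record every temp credentials file the connector materializes...
    temp_credentials_file_path = []
    if GOOGLE_CREDENTIALS in os.environ:
        temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])

    # ...and delete them all when the source is closed.
    for temp_file_path in temp_credentials_file_path:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)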

View File

@@ -12,10 +12,14 @@
 """
 Source connection handler
 """
+import os
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial, singledispatch
 from typing import Optional

+from google.cloud import storage
+
 from metadata.generated.schema.entity.automations.workflow import (
     Workflow as AutomationWorkflow,
 )
@@ -31,9 +35,13 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
 from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
     DatalakeConnection,
 )
+from metadata.generated.schema.security.credentials.gcpValues import (
+    MultipleProjectId,
+    SingleProjectId,
+)
 from metadata.ingestion.connections.test_connections import test_connection_steps
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.utils.credentials import set_google_credentials
+from metadata.utils.credentials import GOOGLE_CREDENTIALS, set_google_credentials

 # Only import specific datalake dependencies if necessary
@@ -65,9 +73,15 @@ def _(config: S3Config):
 @get_datalake_client.register
 def _(config: GCSConfig):
-    from google.cloud import storage
-
-    set_google_credentials(gcp_credentials=config.securityConfig)
+    gcs_config = deepcopy(config)
+    if hasattr(config.securityConfig, "gcpConfig") and isinstance(
+        config.securityConfig.gcpConfig.projectId, MultipleProjectId
+    ):
+        gcs_config: GCSConfig = deepcopy(config)
+        gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
+            gcs_config.securityConfig.gcpConfig.projectId.__root__[0]
+        )
+    set_google_credentials(gcp_credentials=gcs_config.securityConfig)

     gcs_client = storage.Client()
     return gcs_client
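
The registration above relies on functools.singledispatch: get_datalake_client dispatches on the type of its config argument, so each config class gets its own client factory. A self-contained illustration of that mechanism (the config classes here are stand-ins, not the real GCSConfig/S3Config models):

    from dataclasses import dataclass
    from functools import singledispatch


    @dataclass
    class FakeGCSConfig:  # stand-in for GCSConfig
        project_id: str


    @dataclass
    class FakeS3Config:  # stand-in for S3Config
        bucket: str


    @singledispatch
    def get_client(config):
        raise NotImplementedError(f"No client registered for {type(config).__name__}")


    @get_client.register
    def _(config: FakeGCSConfig):
        return f"gcs-client for {config.project_id}"


    @get_client.register
    def _(config: FakeS3Config):
        return f"s3-client for {config.bucket}"


    print(get_client(FakeGCSConfig(project_id="project_id")))  # gcs-client for project_id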
@@ -96,6 +110,15 @@ def _(config: AzureConfig):
     )


+def set_gcs_datalake_client(config: GCSConfig, project_id: str):
+    gcs_config = deepcopy(config)
+    if hasattr(gcs_config.securityConfig, "gcpConfig"):
+        gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
+            project_id
+        )
+    return get_datalake_client(config=gcs_config)
+
+
 def get_connection(connection: DatalakeConnection) -> DatalakeClient:
     """
     Create connection.
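
set_gcs_datalake_client re-dispatches through get_datalake_client with projectId overridden, so each call yields a storage client bound to a single project. A rough usage sketch, assuming an already-parsed GCSConfig named gcs_config and placeholder project IDs:

    # Placeholder project IDs; gcs_config is assumed to be a parsed GCSConfig.
    for project_id in ["project_id", "project_id2"]:
        client = set_gcs_datalake_client(config=gcs_config, project_id=project_id)
        print(project_id, [bucket.name for bucket in client.list_buckets()])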
@@ -125,6 +148,10 @@ def test_connection(
             func = partial(connection.client.get_bucket, connection.config.bucketName)
         else:
             func = connection.client.list_buckets
+        os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
+        if GOOGLE_CREDENTIALS in os.environ:
+            os.remove(os.environ[GOOGLE_CREDENTIALS])
+            del os.environ[GOOGLE_CREDENTIALS]

     if isinstance(config, S3Config):
         if connection.config.bucketName:

View File

@@ -13,6 +13,7 @@
 DataLake connector to fetch metadata from a files stored s3, gcs and Hdfs
 """
 import json
+import os
 import traceback
 from typing import Any, Iterable, Tuple, Union
@@ -53,12 +54,18 @@ from metadata.generated.schema.metadataIngestion.storage.containerMetadataConfig
 from metadata.generated.schema.metadataIngestion.workflow import (
     Source as WorkflowSource,
 )
+from metadata.generated.schema.security.credentials.gcpValues import (
+    GcpCredentialsValues,
+)
 from metadata.ingestion.api.models import Either
 from metadata.ingestion.api.steps import InvalidSourceException
 from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.connections import get_connection
 from metadata.ingestion.source.database.database_service import DatabaseServiceSource
+from metadata.ingestion.source.database.datalake.connection import (
+    set_gcs_datalake_client,
+)
 from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
 from metadata.ingestion.source.storage.storage_service import (
     OPENMETADATA_TEMPLATE_FILE_NAME,
@@ -68,12 +75,13 @@ from metadata.readers.file.base import ReadException
 from metadata.readers.file.config_source_factory import get_reader
 from metadata.utils import fqn
 from metadata.utils.constants import DEFAULT_DATABASE
+from metadata.utils.credentials import GOOGLE_CREDENTIALS
 from metadata.utils.datalake.datalake_utils import (
     fetch_dataframe,
     get_columns,
     get_file_format_type,
 )
-from metadata.utils.filters import filter_by_schema, filter_by_table
+from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
 from metadata.utils.logger import ingestion_logger
 from metadata.utils.s3_utils import list_s3_objects
@@ -96,8 +104,10 @@ class DatalakeSource(DatabaseServiceSource):
         )
         self.metadata = metadata
         self.service_connection = self.config.serviceConnection.__root__.config
+        self.temp_credentials_file_path = []
         self.connection = get_connection(self.service_connection)
+        if GOOGLE_CREDENTIALS in os.environ:
+            self.temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])
         self.client = self.connection.client
         self.table_constraints = None
         self.database_source_state = set()
@@ -125,8 +135,47 @@ class DatalakeSource(DatabaseServiceSource):
         Sources with multiple databases should overwrite this and
         apply the necessary filters.
         """
-        database_name = self.service_connection.databaseName or DEFAULT_DATABASE
-        yield database_name
+        if isinstance(self.config_source, GCSConfig):
+            project_id_list = (
+                self.service_connection.configSource.securityConfig.gcpConfig.projectId.__root__
+            )
+            if not isinstance(
+                project_id_list,
+                list,
+            ):
+                project_id_list = [project_id_list]
+            for project_id in project_id_list:
+                database_fqn = fqn.build(
+                    self.metadata,
+                    entity_type=Database,
+                    service_name=self.context.database_service,
+                    database_name=project_id,
+                )
+                if filter_by_database(
+                    self.source_config.databaseFilterPattern,
+                    database_fqn
+                    if self.source_config.useFqnForFiltering
+                    else project_id,
+                ):
+                    self.status.filter(database_fqn, "Database Filtered out")
+                else:
+                    try:
+                        self.client = set_gcs_datalake_client(
+                            config=self.config_source, project_id=project_id
+                        )
+                        if GOOGLE_CREDENTIALS in os.environ:
+                            self.temp_credentials_file_path.append(
+                                os.environ[GOOGLE_CREDENTIALS]
+                            )
+                        yield project_id
+                    except Exception as exc:
+                        logger.debug(traceback.format_exc())
+                        logger.error(
+                            f"Error trying to connect to database {project_id}: {exc}"
+                        )
+        else:
+            database_name = self.service_connection.databaseName or DEFAULT_DATABASE
+            yield database_name

     def yield_database(
         self, database_name: str
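
Because every configured project ID is now exposed as a database, the usual databaseFilterPattern applies to the project list. A hedged example of a sourceConfig snippet that ingests only one of the configured projects (pattern values are placeholders, written in the same dict form the unit-test mocks use):

    source_config = {
        "config": {
            "type": "DatabaseMetadata",
            "databaseFilterPattern": {"includes": ["project_id$"]},
        }
    }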
@@ -135,6 +184,8 @@ class DatalakeSource(DatabaseServiceSource):
         From topology.
         Prepare a database request and pass it to the sink
         """
+        if isinstance(self.config_source, GCSConfig):
+            database_name = self.client.project
         yield Either(
             right=CreateDatabaseRequest(
                 name=database_name,
@@ -143,24 +194,42 @@ class DatalakeSource(DatabaseServiceSource):
         )

     def fetch_gcs_bucket_names(self):
-        for bucket in self.client.list_buckets():
-            schema_fqn = fqn.build(
-                self.metadata,
-                entity_type=DatabaseSchema,
-                service_name=self.context.database_service,
-                database_name=self.context.database,
-                schema_name=bucket.name,
-            )
-            if filter_by_schema(
-                self.config.sourceConfig.config.schemaFilterPattern,
-                schema_fqn
-                if self.config.sourceConfig.config.useFqnForFiltering
-                else bucket.name,
-            ):
-                self.status.filter(schema_fqn, "Bucket Filtered Out")
-                continue
-            yield bucket.name
+        """
+        Fetch Google cloud storage buckets
+        """
+        try:
+            # List all the buckets in the project
+            for bucket in self.client.list_buckets():
+                # Build a fully qualified name (FQN) for each bucket
+                schema_fqn = fqn.build(
+                    self.metadata,
+                    entity_type=DatabaseSchema,
+                    service_name=self.context.database_service,
+                    database_name=self.context.database,
+                    schema_name=bucket.name,
+                )

+                # Check if the bucket matches a certain filter pattern
+                if filter_by_schema(
+                    self.config.sourceConfig.config.schemaFilterPattern,
+                    schema_fqn
+                    if self.config.sourceConfig.config.useFqnForFiltering
+                    else bucket.name,
+                ):
+                    # If it does not match, the bucket is filtered out
+                    self.status.filter(schema_fqn, "Bucket Filtered Out")
+                    continue

+                # If it does match, the bucket name is yielded
+                yield bucket.name
+        except Exception as exc:
+            yield Either(
+                left=StackTraceError(
+                    name="Bucket",
+                    error=f"Unexpected exception to yield bucket: {exc}",
+                    stackTrace=traceback.format_exc(),
+                )
+            )

     def fetch_s3_bucket_names(self):
         for bucket in self.client.list_buckets()["Buckets"]:
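
The reworked fetch_gcs_bucket_names only needs client.list_buckets() to return objects carrying a name attribute, which is exactly what the unit tests below exploit by stubbing the client. A minimal sketch of that approach, assuming a fully initialized DatalakeSource named datalake_source and invented bucket names:

    from types import SimpleNamespace

    # Only .name is read from each bucket object.
    fake_buckets = [SimpleNamespace(name="raw_zone"), SimpleNamespace(name="curated_zone")]
    datalake_source.client.list_buckets = lambda: fake_buckets
    bucket_names = list(datalake_source.fetch_gcs_bucket_names())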
@@ -434,3 +503,12 @@ class DatalakeSource(DatabaseServiceSource):
     def close(self):
         if isinstance(self.config_source, AzureConfig):
             self.client.close()
+        if isinstance(self.config_source, GCSConfig):
+            os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
+            if isinstance(self.service_connection, GcpCredentialsValues) and (
+                GOOGLE_CREDENTIALS in os.environ
+            ):
+                del os.environ[GOOGLE_CREDENTIALS]
+            for temp_file_path in self.temp_credentials_file_path:
+                if os.path.exists(temp_file_path):
+                    os.remove(temp_file_path)

View File

@@ -97,6 +97,13 @@ class TestEntityLink(TestCase):
                 "<#E::table::随机的>",
                 ["table", "随机的"],
             ),
+            EntityLinkTest(
+                '<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">',
+                [
+                    "table",
+                    'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
+                ],
+            ),
         ]
         for x in xs:
             x.validate(entity_link.split(x.entitylink), x.split_list)

View File

@@ -94,6 +94,7 @@ MOCK_GCS_SCHEMA = [
 ]

 EXPECTED_SCHEMA = ["my_bucket"]
+EXPECTED_GCS_SCHEMA = ["test_datalake", "test_gcs", "s3_test", "my_bucket"]

 MOCK_DATABASE_SERVICE = DatabaseService(
@@ -427,10 +428,6 @@ class DatalakeUnitTest(TestCase):
         self.datalake_source.client.list_buckets = lambda: MOCK_S3_SCHEMA
         assert list(self.datalake_source.fetch_s3_bucket_names()) == EXPECTED_SCHEMA

-    def test_gcs_schema_filer(self):
-        self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
-        assert list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_SCHEMA
-
     def test_json_file_parse(self):
         """
         Test json data files
@@ -479,3 +476,97 @@
         columns = AvroDataFrameReader.read_from_avro(AVRO_DATA_FILE)
         assert EXPECTED_AVRO_COL_2 == columns.columns  # pylint: disable=no-member
+
+
+mock_datalake_gcs_config = {
+    "source": {
+        "type": "datalake",
+        "serviceName": "local_datalake",
+        "serviceConnection": {
+            "config": {
+                "type": "Datalake",
+                "configSource": {
+                    "securityConfig": {
+                        "gcpConfig": {
+                            "type": "service_account",
+                            "projectId": "project_id",
+                            "privateKeyId": "private_key_id",
+                            "privateKey": "private_key",
+                            "clientEmail": "gcpuser@project_id.iam.gserviceaccount.com",
+                            "clientId": "client_id",
+                            "authUri": "https://accounts.google.com/o/oauth2/auth",
+                            "tokenUri": "https://oauth2.googleapis.com/token",
+                            "authProviderX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
+                            "clientX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
+                        }
+                    }
+                },
+                "bucketName": "bucket name",
+                "prefix": "prefix",
+            }
+        },
+        "sourceConfig": {"config": {"type": "DatabaseMetadata"}},
+    },
+    "sink": {"type": "metadata-rest", "config": {}},
+    "workflowConfig": {
+        "loggerLevel": "DEBUG",
+        "openMetadataServerConfig": {
+            "hostPort": "http://localhost:8585/api",
+            "authProvider": "openmetadata",
+            "securityConfig": {
+                "jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
+            },
+        },
+    },
+}
+
+mock_multiple_project_id = deepcopy(mock_datalake_gcs_config)
+mock_multiple_project_id["source"]["serviceConnection"]["config"]["configSource"][
+    "securityConfig"
+]["gcpConfig"]["projectId"] = ["project_id", "project_id2"]
+
+
+class DatalakeGCSUnitTest(TestCase):
+    """
+    Datalake Source Unit Tests
+    """
+
+    @patch(
+        "metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
+    )
+    @patch("metadata.utils.credentials.validate_private_key")
+    @patch("google.cloud.storage.Client")
+    def __init__(self, methodName, _, __, test_connection) -> None:
+        super().__init__(methodName)
+        test_connection.return_value = False
+        self.config = OpenMetadataWorkflowConfig.parse_obj(mock_datalake_gcs_config)
+        self.datalake_source = DatalakeSource.create(
+            mock_datalake_gcs_config["source"],
+            self.config.workflowConfig.openMetadataServerConfig,
+        )
+        self.datalake_source.context.__dict__["database"] = MOCK_DATABASE.name.__root__
+        self.datalake_source.context.__dict__[
+            "database_service"
+        ] = MOCK_DATABASE_SERVICE.name.__root__
+
+    @patch(
+        "metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
+    )
+    @patch("google.cloud.storage.Client")
+    @patch("metadata.utils.credentials.validate_private_key")
+    def test_multiple_project_id_implementation(
+        self, validate_private_key, storage_client, test_connection
+    ):
+        self.datalake_source_multiple_project_id = DatalakeSource.create(
+            mock_multiple_project_id["source"],
+            OpenMetadataWorkflowConfig.parse_obj(
+                mock_multiple_project_id
+            ).workflowConfig.openMetadataServerConfig,
+        )
+
+    def test_gcs_schema_filer(self):
+        self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
+        assert (
+            list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_GCS_SCHEMA
+        )