Fix #13149: Multiple Project Id for Datalake GCS (#14846)

* Fix Multiple Project Id for datalake gcs

* Optimize logic

* Fix Tests

* Add Datalake GCS Tests

* Add multiple project id gcs test
Ayush Shah authored on 2024-01-25 15:22:16 +05:30; committed by GitHub
parent 951917bf6d
commit 1552aeb2de
5 changed files with 244 additions and 44 deletions


@@ -210,7 +210,6 @@ class BigquerySource(
# as per service connection config, which would result in an error.
self.test_connection = lambda: None
super().__init__(config, metadata)
self.temp_credentials = None
self.client = None
# Used to delete temp json file created while initializing bigquery client
self.temp_credentials_file_path = []
@@ -366,9 +365,9 @@ class BigquerySource(
schema_name=schema_name,
),
)
if self.source_config.includeTags:
dataset_obj = self.client.get_dataset(schema_name)
if dataset_obj.labels and self.source_config.includeTags:
if dataset_obj.labels:
database_schema_request_obj.tags = []
for label_classification, label_tag_name in dataset_obj.labels.items():
tag_label = get_tag_label(
@@ -530,8 +529,6 @@ class BigquerySource(
def close(self):
super().close()
if self.temp_credentials:
os.unlink(self.temp_credentials)
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(
self.service_connection.credentials.gcpConfig, GcpCredentialsValues
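The close() hunk above, like the datalake close() later in this diff, boils down to one cleanup pattern: drop the GCP project variable from the environment and delete any temporary credential files written while building clients. A minimal, self-contained sketch of that pattern (the GOOGLE_CREDENTIALS value is an assumption, not taken from this diff):

# Sketch only: mirrors the cleanup steps visible in the diff, not the exact code.
import os

# Assumed value of the GOOGLE_CREDENTIALS constant imported from metadata.utils.credentials.
GOOGLE_CREDENTIALS = "GOOGLE_APPLICATION_CREDENTIALS"


def cleanup_gcp_environment(temp_credentials_file_paths):
    """Remove GCP-related env vars and any temporary service-account key files."""
    os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
    os.environ.pop(GOOGLE_CREDENTIALS, "")
    for path in temp_credentials_file_paths:
        if os.path.exists(path):
            os.remove(path)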


@@ -12,10 +12,14 @@
"""
Source connection handler
"""
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, singledispatch
from typing import Optional
from google.cloud import storage
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
@@ -31,9 +35,13 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.security.credentials.gcpValues import (
MultipleProjectId,
SingleProjectId,
)
from metadata.ingestion.connections.test_connections import test_connection_steps
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.credentials import set_google_credentials
from metadata.utils.credentials import GOOGLE_CREDENTIALS, set_google_credentials
# Only import specific datalake dependencies if necessary
@@ -65,9 +73,15 @@ def _(config: S3Config):
@get_datalake_client.register
def _(config: GCSConfig):
from google.cloud import storage
set_google_credentials(gcp_credentials=config.securityConfig)
gcs_config = deepcopy(config)
if hasattr(config.securityConfig, "gcpConfig") and isinstance(
config.securityConfig.gcpConfig.projectId, MultipleProjectId
):
gcs_config: GCSConfig = deepcopy(config)
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
gcs_config.securityConfig.gcpConfig.projectId.__root__[0]
)
set_google_credentials(gcp_credentials=gcs_config.securityConfig)
gcs_client = storage.Client()
return gcs_client
@@ -96,6 +110,15 @@ def _(config: AzureConfig):
)
def set_gcs_datalake_client(config: GCSConfig, project_id: str):
gcs_config = deepcopy(config)
if hasattr(gcs_config.securityConfig, "gcpConfig"):
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
project_id
)
return get_datalake_client(config=gcs_config)
def get_connection(connection: DatalakeConnection) -> DatalakeClient:
"""
Create connection.
@@ -125,6 +148,10 @@ def test_connection(
func = partial(connection.client.get_bucket, connection.config.bucketName)
else:
func = connection.client.list_buckets
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if GOOGLE_CREDENTIALS in os.environ:
os.remove(os.environ[GOOGLE_CREDENTIALS])
del os.environ[GOOGLE_CREDENTIALS]
if isinstance(config, S3Config):
if connection.config.bucketName:
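Both new code paths in this file share one step: copy the security config and pin its projectId to a single value before building the storage client. A minimal sketch of that step, assuming the pydantic v1 models named in the imports above (a simplification, not the handler itself):

# Sketch of the "pin to one project" step used by the GCSConfig handler and
# set_gcs_datalake_client above (assumed simplification).
from copy import deepcopy

from metadata.generated.schema.security.credentials.gcpValues import (
    MultipleProjectId,
    SingleProjectId,
)


def pin_to_single_project(security_config, project_id: str):
    """Return a copy of the security config bound to a single project id."""
    pinned = deepcopy(security_config)
    if hasattr(pinned, "gcpConfig") and isinstance(
        pinned.gcpConfig.projectId, MultipleProjectId
    ):
        pinned.gcpConfig.projectId = SingleProjectId.parse_obj(project_id)
    return pinned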


@@ -13,6 +13,7 @@
DataLake connector to fetch metadata from files stored in S3, GCS and HDFS
"""
import json
import os
import traceback
from typing import Any, Iterable, Tuple, Union
@@ -53,12 +54,18 @@ from metadata.generated.schema.metadataIngestion.storage.containerMetadataConfig
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.security.credentials.gcpValues import (
GcpCredentialsValues,
)
from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.datalake.connection import (
set_gcs_datalake_client,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.storage.storage_service import (
OPENMETADATA_TEMPLATE_FILE_NAME,
@@ -68,12 +75,13 @@ from metadata.readers.file.base import ReadException
from metadata.readers.file.config_source_factory import get_reader
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.credentials import GOOGLE_CREDENTIALS
from metadata.utils.datalake.datalake_utils import (
fetch_dataframe,
get_columns,
get_file_format_type,
)
from metadata.utils.filters import filter_by_schema, filter_by_table
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger
from metadata.utils.s3_utils import list_s3_objects
@@ -96,8 +104,10 @@ class DatalakeSource(DatabaseServiceSource):
)
self.metadata = metadata
self.service_connection = self.config.serviceConnection.__root__.config
self.temp_credentials_file_path = []
self.connection = get_connection(self.service_connection)
if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])
self.client = self.connection.client
self.table_constraints = None
self.database_source_state = set()
@@ -125,6 +135,45 @@ class DatalakeSource(DatabaseServiceSource):
Sources with multiple databases should overwrite this and
apply the necessary filters.
"""
if isinstance(self.config_source, GCSConfig):
project_id_list = (
self.service_connection.configSource.securityConfig.gcpConfig.projectId.__root__
)
if not isinstance(
project_id_list,
list,
):
project_id_list = [project_id_list]
for project_id in project_id_list:
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service,
database_name=project_id,
)
if filter_by_database(
self.source_config.databaseFilterPattern,
database_fqn
if self.source_config.useFqnForFiltering
else project_id,
):
self.status.filter(database_fqn, "Database Filtered out")
else:
try:
self.client = set_gcs_datalake_client(
config=self.config_source, project_id=project_id
)
if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(
os.environ[GOOGLE_CREDENTIALS]
)
yield project_id
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(
f"Error trying to connect to database {project_id}: {exc}"
)
else:
database_name = self.service_connection.databaseName or DEFAULT_DATABASE
yield database_name
@@ -135,6 +184,8 @@ class DatalakeSource(DatabaseServiceSource):
From topology.
Prepare a database request and pass it to the sink
"""
if isinstance(self.config_source, GCSConfig):
database_name = self.client.project
yield Either(
right=CreateDatabaseRequest(
name=database_name,
@ -143,7 +194,13 @@ class DatalakeSource(DatabaseServiceSource):
)
def fetch_gcs_bucket_names(self):
"""
Fetch Google cloud storage buckets
"""
try:
# List all the buckets in the project
for bucket in self.client.list_buckets():
# Build a fully qualified name (FQN) for each bucket
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
@@ -151,16 +208,28 @@
database_name=self.context.database,
schema_name=bucket.name,
)
# Check if the bucket matches a certain filter pattern
if filter_by_schema(
self.config.sourceConfig.config.schemaFilterPattern,
schema_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else bucket.name,
):
# If it does not match, the bucket is filtered out
self.status.filter(schema_fqn, "Bucket Filtered Out")
continue
# If it does match, the bucket name is yielded
yield bucket.name
except Exception as exc:
yield Either(
left=StackTraceError(
name="Bucket",
error=f"Unexpected exception to yield bucket: {exc}",
stackTrace=traceback.format_exc(),
)
)
def fetch_s3_bucket_names(self):
for bucket in self.client.list_buckets()["Buckets"]:
@@ -434,3 +503,12 @@ class DatalakeSource(DatabaseServiceSource):
def close(self):
if isinstance(self.config_source, AzureConfig):
self.client.close()
if isinstance(self.config_source, GCSConfig):
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(self.service_connection, GcpCredentialsValues) and (
GOOGLE_CREDENTIALS in os.environ
):
del os.environ[GOOGLE_CREDENTIALS]
for temp_file_path in self.temp_credentials_file_path:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
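Taken together, the GCS branch now surfaces every configured project id as a database and each bucket of the pinned client as a schema. A self-contained sketch of that mapping using plain google-cloud-storage calls (an assumed simplification with placeholder project ids; the connector itself goes through set_gcs_datalake_client and the filter patterns shown above):

# Sketch: database -> schema mapping for a multi-project GCS datalake.
from google.cloud import storage

for project_id in ["project_id", "project_id2"]:  # placeholder ids
    client = storage.Client(project=project_id)  # credentials come from the environment
    for bucket in client.list_buckets():
        print(f"{project_id}.{bucket.name}")  # database.schema in OpenMetadata terms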


@@ -97,6 +97,13 @@ class TestEntityLink(TestCase):
"<#E::table::随机的>",
["table", "随机的"],
),
EntityLinkTest(
'<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">',
[
"table",
'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
],
),
]
for x in xs:
x.validate(entity_link.split(x.entitylink), x.split_list)


@@ -94,6 +94,7 @@ MOCK_GCS_SCHEMA = [
]
EXPECTED_SCHEMA = ["my_bucket"]
EXPECTED_GCS_SCHEMA = ["test_datalake", "test_gcs", "s3_test", "my_bucket"]
MOCK_DATABASE_SERVICE = DatabaseService(
@@ -427,10 +428,6 @@ class DatalakeUnitTest(TestCase):
self.datalake_source.client.list_buckets = lambda: MOCK_S3_SCHEMA
assert list(self.datalake_source.fetch_s3_bucket_names()) == EXPECTED_SCHEMA
def test_gcs_schema_filer(self):
self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
assert list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_SCHEMA
def test_json_file_parse(self):
"""
Test json data files
@@ -479,3 +476,97 @@ class DatalakeUnitTest(TestCase):
columns = AvroDataFrameReader.read_from_avro(AVRO_DATA_FILE)
assert EXPECTED_AVRO_COL_2 == columns.columns # pylint: disable=no-member
mock_datalake_gcs_config = {
"source": {
"type": "datalake",
"serviceName": "local_datalake",
"serviceConnection": {
"config": {
"type": "Datalake",
"configSource": {
"securityConfig": {
"gcpConfig": {
"type": "service_account",
"projectId": "project_id",
"privateKeyId": "private_key_id",
"privateKey": "private_key",
"clientEmail": "gcpuser@project_id.iam.gserviceaccount.com",
"clientId": "client_id",
"authUri": "https://accounts.google.com/o/oauth2/auth",
"tokenUri": "https://oauth2.googleapis.com/token",
"authProviderX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
"clientX509CertUrl": "https://www.googleapis.com/oauth2/v1/certs",
}
}
},
"bucketName": "bucket name",
"prefix": "prefix",
}
},
"sourceConfig": {"config": {"type": "DatabaseMetadata"}},
},
"sink": {"type": "metadata-rest", "config": {}},
"workflowConfig": {
"loggerLevel": "DEBUG",
"openMetadataServerConfig": {
"hostPort": "http://localhost:8585/api",
"authProvider": "openmetadata",
"securityConfig": {
"jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
},
},
},
}
mock_multiple_project_id = deepcopy(mock_datalake_gcs_config)
mock_multiple_project_id["source"]["serviceConnection"]["config"]["configSource"][
"securityConfig"
]["gcpConfig"]["projectId"] = ["project_id", "project_id2"]
class DatalakeGCSUnitTest(TestCase):
"""
Datalake Source Unit Tests
"""
@patch(
"metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
)
@patch("metadata.utils.credentials.validate_private_key")
@patch("google.cloud.storage.Client")
def __init__(self, methodName, _, __, test_connection) -> None:
super().__init__(methodName)
test_connection.return_value = False
self.config = OpenMetadataWorkflowConfig.parse_obj(mock_datalake_gcs_config)
self.datalake_source = DatalakeSource.create(
mock_datalake_gcs_config["source"],
self.config.workflowConfig.openMetadataServerConfig,
)
self.datalake_source.context.__dict__["database"] = MOCK_DATABASE.name.__root__
self.datalake_source.context.__dict__[
"database_service"
] = MOCK_DATABASE_SERVICE.name.__root__
@patch(
"metadata.ingestion.source.database.datalake.metadata.DatalakeSource.test_connection"
)
@patch("google.cloud.storage.Client")
@patch("metadata.utils.credentials.validate_private_key")
def test_multiple_project_id_implementation(
self, validate_private_key, storage_client, test_connection
):
self.datalake_source_multiple_project_id = DatalakeSource.create(
mock_multiple_project_id["source"],
OpenMetadataWorkflowConfig.parse_obj(
mock_multiple_project_id
).workflowConfig.openMetadataServerConfig,
)
def test_gcs_schema_filer(self):
self.datalake_source.client.list_buckets = lambda: MOCK_GCS_SCHEMA
assert (
list(self.datalake_source.fetch_gcs_bucket_names()) == EXPECTED_GCS_SCHEMA
)