diff --git a/ingestion/src/metadata/ingestion/connections/builders.py b/ingestion/src/metadata/ingestion/connections/builders.py index b648a9788ac..449aec3e939 100644 --- a/ingestion/src/metadata/ingestion/connections/builders.py +++ b/ingestion/src/metadata/ingestion/connections/builders.py @@ -56,7 +56,10 @@ def get_connection_args_common(connection) -> Dict[str, Any]: def create_generic_db_connection( - connection, get_connection_url_fn: Callable, get_connection_args_fn: Callable + connection, + get_connection_url_fn: Callable, + get_connection_args_fn: Callable, + **kwargs, ) -> Engine: """ Generic Engine creation from connection object @@ -75,6 +78,7 @@ def create_generic_db_connection( pool_reset_on_return=None, # https://docs.sqlalchemy.org/en/14/core/pooling.html#reset-on-return echo=False, max_overflow=-1, + **kwargs, ) attach_query_tracker(engine) diff --git a/ingestion/src/metadata/ingestion/source/database/bigquery/connection.py b/ingestion/src/metadata/ingestion/source/database/bigquery/connection.py index 85d9e929d54..e56a6756d63 100644 --- a/ingestion/src/metadata/ingestion/source/database/bigquery/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/bigquery/connection.py @@ -31,6 +31,7 @@ from metadata.generated.schema.entity.services.connections.testConnectionResult TestConnectionResult, ) from metadata.generated.schema.security.credentials.gcpCredentials import ( + GcpADC, GcpCredentialsPath, ) from metadata.generated.schema.security.credentials.gcpValues import ( @@ -68,25 +69,21 @@ def get_connection_url(connection: BigQueryConnection) -> str: connection.credentials.gcpConfig.projectId, SingleProjectId ): if not connection.credentials.gcpConfig.projectId.root: - return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root or ''}" + return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root or ''}" if ( not connection.credentials.gcpConfig.privateKey and connection.credentials.gcpConfig.projectId.root ): project_id = connection.credentials.gcpConfig.projectId.root - os.environ["GOOGLE_CLOUD_PROJECT"] = ( - connection.billingProjectId or project_id - ) - return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root}" + os.environ["GOOGLE_CLOUD_PROJECT"] = project_id + return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}" elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId): for project_id in connection.credentials.gcpConfig.projectId.root: if not connection.credentials.gcpConfig.privateKey and project_id: # Setting environment variable based on project id given by user / set in ADC - os.environ["GOOGLE_CLOUD_PROJECT"] = ( - connection.billingProjectId or project_id - ) - return f"{connection.scheme.value}://{connection.billingProjectId or project_id}" - return f"{connection.scheme.value}://{connection.billingProjectId or ''}" + os.environ["GOOGLE_CLOUD_PROJECT"] = project_id + return f"{connection.scheme.value}://{project_id}" + return f"{connection.scheme.value}://" # If gcpConfig is the JSON key path and projectId is defined, we use it by default elif ( @@ -96,13 +93,27 @@ def get_connection_url(connection: BigQueryConnection) -> str: if isinstance( # pylint: disable=no-else-return connection.credentials.gcpConfig.projectId, SingleProjectId ): - return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root}" + return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}" elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId): for project_id in connection.credentials.gcpConfig.projectId.root: - return f"{connection.scheme.value}://{connection.billingProjectId or project_id}" + return f"{connection.scheme.value}://{project_id}" - return f"{connection.scheme.value}://{connection.billingProjectId or ''}" + # If gcpConfig is the GCP ADC and projectId is defined, we use it by default + elif ( + isinstance(connection.credentials.gcpConfig, GcpADC) + and connection.credentials.gcpConfig.projectId + ): + if isinstance( # pylint: disable=no-else-return + connection.credentials.gcpConfig.projectId, SingleProjectId + ): + return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}" + + elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId): + for project_id in connection.credentials.gcpConfig.projectId.root: + return f"{connection.scheme.value}://{project_id}" + + return f"{connection.scheme.value}://" def get_connection(connection: BigQueryConnection) -> Engine: @@ -110,10 +121,15 @@ def get_connection(connection: BigQueryConnection) -> Engine: Prepare the engine and the GCP credentials """ set_google_credentials(gcp_credentials=connection.credentials) + kwargs = {} + if connection.billingProjectId: + kwargs["billing_project_id"] = connection.billingProjectId + return create_generic_db_connection( connection=connection, get_connection_url_fn=get_connection_url, get_connection_args_fn=get_connection_args_common, + **kwargs, ) diff --git a/ingestion/src/metadata/ingestion/source/database/bigquery/queries.py b/ingestion/src/metadata/ingestion/source/database/bigquery/queries.py index 6676e3d51f6..9d5b1cb77e4 100644 --- a/ingestion/src/metadata/ingestion/source/database/bigquery/queries.py +++ b/ingestion/src/metadata/ingestion/source/database/bigquery/queries.py @@ -211,13 +211,18 @@ class BigQueryQueryResult(BaseModel): usage_location: str, dataset_id: str, project_id: str, + billing_project_id: Optional[str] = None, ): + # Use billing project for the INFORMATION_SCHEMA query if provided + query_project_id = billing_project_id or project_id + rows = session.execute( text( JOBS.format( usage_location=usage_location, dataset_id=dataset_id, project_id=project_id, + query_project_id=query_project_id, insert=DatabaseDMLOperations.INSERT.value, update=DatabaseDMLOperations.UPDATE.value, delete=DatabaseDMLOperations.DELETE.value, @@ -240,7 +245,7 @@ JOBS = """ dml_statistics.deleted_row_count as deleted_row_count, dml_statistics.updated_row_count as updated_row_count FROM - `region-{usage_location}`.INFORMATION_SCHEMA.JOBS + `{query_project_id}`.`region-{usage_location}`.INFORMATION_SCHEMA.JOBS WHERE DATE(creation_time) >= CURRENT_DATE() - 1 AND destination_table.dataset_id = '{dataset_id}' AND diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/bigquery/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/bigquery/profiler_interface.py index 4a87c238a46..f7577ffdae5 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/bigquery/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/bigquery/profiler_interface.py @@ -13,11 +13,13 @@ Interfaces with database for all database engine supporting sqlalchemy abstraction layer """ +from copy import deepcopy from typing import List, Type, cast from sqlalchemy import Column, inspect from metadata.generated.schema.entity.data.table import SystemProfile +from metadata.generated.schema.security.credentials.gcpValues import SingleProjectId from metadata.profiler.interface.sqlalchemy.profiler_interface import ( SQAProfilerInterface, ) @@ -27,6 +29,7 @@ from metadata.profiler.metrics.system.bigquery.system import ( from metadata.profiler.metrics.system.system import System from metadata.profiler.processor.runner import QueryRunner from metadata.utils.logger import profiler_interface_registry_logger +from metadata.utils.ssl_manager import get_ssl_connection logger = profiler_interface_registry_logger() @@ -34,6 +37,19 @@ logger = profiler_interface_registry_logger() class BigQueryProfilerInterface(SQAProfilerInterface): """BigQuery profiler interface""" + def create_session(self): + connection_config = deepcopy(self.service_connection_config) + # Create a modified connection for BigQuery with the correct project ID + if ( + hasattr(connection_config.credentials.gcpConfig, "projectId") + and self.table_entity.database + ): + connection_config.credentials.gcpConfig.projectId = SingleProjectId( + root=self.table_entity.database.name + ) + self.connection = get_ssl_connection(connection_config) + return super().create_session() + def _compute_system_metrics( self, metrics: Type[System], @@ -49,6 +65,7 @@ class BigQueryProfilerInterface(SQAProfilerInterface): session=self.session, runner=runner, usage_location=self.service_connection_config.usageLocation, + billing_project_id=self.service_connection_config.billingProjectId, ) return instance.get_system_metrics() diff --git a/ingestion/src/metadata/profiler/metrics/system/bigquery/system.py b/ingestion/src/metadata/profiler/metrics/system/bigquery/system.py index f6b5230058e..064515a21d0 100644 --- a/ingestion/src/metadata/profiler/metrics/system/bigquery/system.py +++ b/ingestion/src/metadata/profiler/metrics/system/bigquery/system.py @@ -1,6 +1,6 @@ """BigQuery system metric source""" -from typing import List +from typing import List, Optional from pydantic import TypeAdapter from sqlalchemy.orm import Session @@ -30,12 +30,14 @@ class BigQuerySystemMetricsComputer(SystemMetricsComputer, CacheProvider): session: Session, runner: QueryRunner, usage_location: str, + billing_project_id: Optional[str] = None, ): self.session = session self.table = runner.table_name self.project_id = runner.session.get_bind().url.host self.dataset_id = runner.schema_name self.usage_location = usage_location + self.billing_project_id = billing_project_id or self.project_id def get_deletes(self) -> List[SystemProfile]: return self.get_system_profile( @@ -116,6 +118,7 @@ class BigQuerySystemMetricsComputer(SystemMetricsComputer, CacheProvider): usage_location=usage_location, project_id=project_id, dataset_id=dataset_id, + billing_project_id=self.billing_project_id, ) @staticmethod diff --git a/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py index ff0d03cdbf9..6454e6159ce 100644 --- a/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py +++ b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py @@ -304,8 +304,7 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer): ] where_clause = [ - Column("project_id") - == self.conn_config.credentials.gcpConfig.projectId.root, + Column("project_id") == self._entity.database.name, Column("table_schema") == self.schema_name, Column("table_name") == self.table_name, ] @@ -338,17 +337,14 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer): *self._get_col_names_and_count(), ] where_clause = [ - Column("project_id") - == self.conn_config.credentials.gcpConfig.projectId.root, + Column("project_id") == self._entity.database.name, Column("dataset_id") == self.schema_name, Column("table_id") == self.table_name, ] schema = ( - self.schema_name.startswith( - f"{self.conn_config.credentials.gcpConfig.projectId.root}." - ) + self.schema_name.startswith(f"{self._entity.database.name}.") and self.schema_name - or f"{self.conn_config.credentials.gcpConfig.projectId.root}.{self.schema_name}" + or f"{self._entity.database.name}.{self.schema_name}" ) query = self._build_query( columns, diff --git a/ingestion/src/metadata/sampler/sqlalchemy/bigquery/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/bigquery/sampler.py index b74c375169f..2427285d96b 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/bigquery/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/bigquery/sampler.py @@ -12,6 +12,7 @@ Helper module to handle data sampling for the profiler """ +from copy import deepcopy from typing import Dict, Optional, Union from sqlalchemy import Column @@ -31,11 +32,14 @@ from metadata.generated.schema.entity.services.connections.database.datalakeConn DatalakeConnection, ) from metadata.generated.schema.entity.services.databaseService import DatabaseConnection +from metadata.generated.schema.security.credentials.gcpValues import SingleProjectId +from metadata.ingestion.connections.session import create_and_bind_thread_safe_session from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.sampler.models import SampleConfig from metadata.sampler.sqlalchemy.sampler import SQASampler from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT from metadata.utils.logger import profiler_interface_registry_logger +from metadata.utils.ssl_manager import get_ssl_connection logger = profiler_interface_registry_logger() @@ -74,6 +78,19 @@ class BigQuerySampler(SQASampler): ) self.raw_dataset_type: Optional[TableType] = entity.tableType + connection_config = deepcopy(service_connection_config) + # Create a modified connection for BigQuery with the correct project ID + if ( + hasattr(connection_config.credentials.gcpConfig, "projectId") + and self.entity.database + ): + connection_config.credentials.gcpConfig.projectId = SingleProjectId( + root=self.entity.database.name + ) + self.connection = get_ssl_connection(connection_config) + + self.session_factory = create_and_bind_thread_safe_session(self.connection) + def set_tablesample(self, selectable: SqaTable): """Set the TABLESAMPLE clause for BigQuery Args: diff --git a/ingestion/tests/cli_e2e/database/bigquery/bigquery.yaml b/ingestion/tests/cli_e2e/database/bigquery/bigquery.yaml index 88a8be126b4..05d72656b85 100644 --- a/ingestion/tests/cli_e2e/database/bigquery/bigquery.yaml +++ b/ingestion/tests/cli_e2e/database/bigquery/bigquery.yaml @@ -8,6 +8,7 @@ source: taxonomyProjectID: - $E2E_BQ_PROJECT_ID - $E2E_BQ_PROJECT_ID2 + billingProjectId: $E2E_BQ_PROJECT_ID2 credentials: gcpConfig: type: service_account diff --git a/ingestion/tests/unit/profiler/sqlalchemy/bigquery/test_bigquery_sampling.py b/ingestion/tests/unit/profiler/sqlalchemy/bigquery/test_bigquery_sampling.py index 5c719d7fc75..9a7fbf20991 100644 --- a/ingestion/tests/unit/profiler/sqlalchemy/bigquery/test_bigquery_sampling.py +++ b/ingestion/tests/unit/profiler/sqlalchemy/bigquery/test_bigquery_sampling.py @@ -22,6 +22,7 @@ from metadata.generated.schema.security.credentials.gcpCredentials import GCPCre from metadata.generated.schema.security.credentials.gcpValues import ( GcpCredentialsValues, ) +from metadata.generated.schema.type.entityReference import EntityReference from metadata.profiler.interface.sqlalchemy.profiler_interface import ( SQAProfilerInterface, ) @@ -69,6 +70,7 @@ class SampleTest(TestCase): dataType=DataType.INT, ), ], + database=EntityReference(id=uuid4(), name="myproject", type="database"), ) cls.bq_conn = BigQueryConnection( @@ -137,6 +139,7 @@ class SampleTest(TestCase): ), ], tableType=TableType.View, + database=EntityReference(id=uuid4(), name="myproject", type="database"), ) sampler = BigQuerySampler( @@ -172,6 +175,7 @@ class SampleTest(TestCase): ), ], tableType=TableType.View, + database=EntityReference(id=uuid4(), name="myproject", type="database"), ) sampler = BigQuerySampler(