Fixes #23010: BigQuery Project Selection In Profiler & AutoClassification Workflow (#23233)

* fix: added code for separate engine and session for each project in profiler and classification, and refactored the billing project approach

* fix: added entity.database check, bigquery sampling tests

* fix: system metrics logic when bigquery billing project is provided
Keshav Mohta 2025-09-05 14:09:14 +05:30 committed by GitHub
parent 972afc375a
commit 103857f90c
9 changed files with 87 additions and 24 deletions

View File

@@ -56,7 +56,10 @@ def get_connection_args_common(connection) -> Dict[str, Any]:
def create_generic_db_connection(
connection, get_connection_url_fn: Callable, get_connection_args_fn: Callable
connection,
get_connection_url_fn: Callable,
get_connection_args_fn: Callable,
**kwargs,
) -> Engine:
"""
Generic Engine creation from connection object
@@ -75,6 +78,7 @@ def create_generic_db_connection(
pool_reset_on_return=None, # https://docs.sqlalchemy.org/en/14/core/pooling.html#reset-on-return
echo=False,
max_overflow=-1,
**kwargs,
)
attach_query_tracker(engine)
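
For context, a minimal sketch of the pattern this hunk introduces: extra keyword arguments flow from the caller through create_generic_db_connection into SQLAlchemy's create_engine, so dialect-specific options (here, BigQuery's billing project) can be attached without widening the generic builder's signature. The body is simplified; the pooling options and query tracker of the real function are omitted, and the function name is illustrative.

```python
from typing import Any, Callable, Dict

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


def create_generic_db_connection_sketch(
    connection: Any,
    get_connection_url_fn: Callable[[Any], str],
    get_connection_args_fn: Callable[[Any], Dict[str, Any]],
    **kwargs: Any,
) -> Engine:
    """Build an engine, forwarding any extra dialect-specific kwargs."""
    return create_engine(
        get_connection_url_fn(connection),
        connect_args=get_connection_args_fn(connection),
        # Options such as billing_project_id pass straight through to the dialect.
        **kwargs,
    )
```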

View File

@@ -31,6 +31,7 @@ from metadata.generated.schema.entity.services.connections.testConnectionResult
TestConnectionResult,
)
from metadata.generated.schema.security.credentials.gcpCredentials import (
GcpADC,
GcpCredentialsPath,
)
from metadata.generated.schema.security.credentials.gcpValues import (
@@ -68,25 +69,21 @@ def get_connection_url(connection: BigQueryConnection) -> str:
connection.credentials.gcpConfig.projectId, SingleProjectId
):
if not connection.credentials.gcpConfig.projectId.root:
return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root or ''}"
return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root or ''}"
if (
not connection.credentials.gcpConfig.privateKey
and connection.credentials.gcpConfig.projectId.root
):
project_id = connection.credentials.gcpConfig.projectId.root
os.environ["GOOGLE_CLOUD_PROJECT"] = (
connection.billingProjectId or project_id
)
return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root}"
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}"
elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId):
for project_id in connection.credentials.gcpConfig.projectId.root:
if not connection.credentials.gcpConfig.privateKey and project_id:
# Setting environment variable based on project id given by user / set in ADC
os.environ["GOOGLE_CLOUD_PROJECT"] = (
connection.billingProjectId or project_id
)
return f"{connection.scheme.value}://{connection.billingProjectId or project_id}"
return f"{connection.scheme.value}://{connection.billingProjectId or ''}"
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
return f"{connection.scheme.value}://{project_id}"
return f"{connection.scheme.value}://"
# If gcpConfig is the JSON key path and projectId is defined, we use it by default
elif (
@@ -96,13 +93,27 @@ def get_connection_url(connection: BigQueryConnection) -> str:
if isinstance( # pylint: disable=no-else-return
connection.credentials.gcpConfig.projectId, SingleProjectId
):
return f"{connection.scheme.value}://{connection.billingProjectId or connection.credentials.gcpConfig.projectId.root}"
return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}"
elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId):
for project_id in connection.credentials.gcpConfig.projectId.root:
return f"{connection.scheme.value}://{connection.billingProjectId or project_id}"
return f"{connection.scheme.value}://{project_id}"
return f"{connection.scheme.value}://{connection.billingProjectId or ''}"
# If gcpConfig is the GCP ADC and projectId is defined, we use it by default
elif (
isinstance(connection.credentials.gcpConfig, GcpADC)
and connection.credentials.gcpConfig.projectId
):
if isinstance( # pylint: disable=no-else-return
connection.credentials.gcpConfig.projectId, SingleProjectId
):
return f"{connection.scheme.value}://{connection.credentials.gcpConfig.projectId.root}"
elif isinstance(connection.credentials.gcpConfig.projectId, MultipleProjectId):
for project_id in connection.credentials.gcpConfig.projectId.root:
return f"{connection.scheme.value}://{project_id}"
return f"{connection.scheme.value}://"
def get_connection(connection: BigQueryConnection) -> Engine:
@@ -110,10 +121,15 @@ def get_connection(connection: BigQueryConnection) -> Engine:
Prepare the engine and the GCP credentials
"""
set_google_credentials(gcp_credentials=connection.credentials)
kwargs = {}
if connection.billingProjectId:
kwargs["billing_project_id"] = connection.billingProjectId
return create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url,
get_connection_args_fn=get_connection_args_common,
**kwargs,
)
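
Taken together with the builder change above, the connection URL now always carries the data project, while the billing project (when configured) travels as an engine keyword argument. A hedged usage sketch, assuming the installed BigQuery SQLAlchemy dialect accepts a billing_project_id engine option as the new kwargs forwarding implies; project IDs are illustrative.

```python
from sqlalchemy import create_engine

data_project = "my-data-project"        # illustrative, not from the diff
billing_project = "my-billing-project"  # connection.billingProjectId

engine_kwargs = {}
if billing_project:
    # get_connection() forwards this via create_generic_db_connection(**kwargs)
    engine_kwargs["billing_project_id"] = billing_project

# The URL keeps pointing at the data project; billing is a separate engine option.
engine = create_engine(f"bigquery://{data_project}", **engine_kwargs)
```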

View File

@@ -211,13 +211,18 @@ class BigQueryQueryResult(BaseModel):
usage_location: str,
dataset_id: str,
project_id: str,
billing_project_id: Optional[str] = None,
):
# Use billing project for the INFORMATION_SCHEMA query if provided
query_project_id = billing_project_id or project_id
rows = session.execute(
text(
JOBS.format(
usage_location=usage_location,
dataset_id=dataset_id,
project_id=project_id,
query_project_id=query_project_id,
insert=DatabaseDMLOperations.INSERT.value,
update=DatabaseDMLOperations.UPDATE.value,
delete=DatabaseDMLOperations.DELETE.value,
@@ -240,7 +245,7 @@ JOBS = """
dml_statistics.deleted_row_count as deleted_row_count,
dml_statistics.updated_row_count as updated_row_count
FROM
`region-{usage_location}`.INFORMATION_SCHEMA.JOBS
`{query_project_id}`.`region-{usage_location}`.INFORMATION_SCHEMA.JOBS
WHERE
DATE(creation_time) >= CURRENT_DATE() - 1 AND
destination_table.dataset_id = '{dataset_id}' AND
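
For reference, a small sketch of how the reworked template resolves: INFORMATION_SCHEMA.JOBS is now fully qualified with the billing project when one is supplied, and falls back to the data project otherwise. Names below are illustrative.

```python
# Illustrative values; the real ones come from the connection and the runner.
project_id = "data-project"
billing_project_id = "billing-project"  # may be None
usage_location = "us"

query_project_id = billing_project_id or project_id
jobs_table = f"`{query_project_id}`.`region-{usage_location}`.INFORMATION_SCHEMA.JOBS"
# -> `billing-project`.`region-us`.INFORMATION_SCHEMA.JOBS
```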

View File

@@ -13,11 +13,13 @@
Interfaces with database for all database engine
supporting sqlalchemy abstraction layer
"""
from copy import deepcopy
from typing import List, Type, cast
from sqlalchemy import Column, inspect
from metadata.generated.schema.entity.data.table import SystemProfile
from metadata.generated.schema.security.credentials.gcpValues import SingleProjectId
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
@@ -27,6 +29,7 @@ from metadata.profiler.metrics.system.bigquery.system import (
from metadata.profiler.metrics.system.system import System
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.logger import profiler_interface_registry_logger
from metadata.utils.ssl_manager import get_ssl_connection
logger = profiler_interface_registry_logger()
@@ -34,6 +37,19 @@ logger = profiler_interface_registry_logger()
class BigQueryProfilerInterface(SQAProfilerInterface):
"""BigQuery profiler interface"""
def create_session(self):
connection_config = deepcopy(self.service_connection_config)
# Create a modified connection for BigQuery with the correct project ID
if (
hasattr(connection_config.credentials.gcpConfig, "projectId")
and self.table_entity.database
):
connection_config.credentials.gcpConfig.projectId = SingleProjectId(
root=self.table_entity.database.name
)
self.connection = get_ssl_connection(connection_config)
return super().create_session()
def _compute_system_metrics(
self,
metrics: Type[System],
@@ -49,6 +65,7 @@ class BigQueryProfilerInterface(SQAProfilerInterface):
session=self.session,
runner=runner,
usage_location=self.service_connection_config.usageLocation,
billing_project_id=self.service_connection_config.billingProjectId,
)
return instance.get_system_metrics()
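
The create_session override above deep-copies the shared service connection config before swapping its projectId for the table entity's database name, so building an engine for one project never mutates the config used for the next table. A condensed sketch of that pattern; engine_for_entity and build_engine are hypothetical names, and the SingleProjectId wrapper is elided.

```python
from copy import deepcopy


def engine_for_entity(service_connection_config, table_entity, build_engine):
    """Return an engine scoped to the project that owns table_entity."""
    config = deepcopy(service_connection_config)  # never touch the shared config
    if table_entity.database is not None:
        # For BigQuery, the OpenMetadata "database" is the GCP project.
        config.credentials.gcpConfig.projectId = table_entity.database.name
    return build_engine(config)  # stands in for get_ssl_connection()
```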

View File

@@ -1,6 +1,6 @@
"""BigQuery system metric source"""
from typing import List
from typing import List, Optional
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
@@ -30,12 +30,14 @@ class BigQuerySystemMetricsComputer(SystemMetricsComputer, CacheProvider):
session: Session,
runner: QueryRunner,
usage_location: str,
billing_project_id: Optional[str] = None,
):
self.session = session
self.table = runner.table_name
self.project_id = runner.session.get_bind().url.host
self.dataset_id = runner.schema_name
self.usage_location = usage_location
self.billing_project_id = billing_project_id or self.project_id
def get_deletes(self) -> List[SystemProfile]:
return self.get_system_profile(
@@ -116,6 +118,7 @@ class BigQuerySystemMetricsComputer(SystemMetricsComputer, CacheProvider):
usage_location=usage_location,
project_id=project_id,
dataset_id=dataset_id,
billing_project_id=self.billing_project_id,
)
@staticmethod
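
One note on the fallback introduced here: when no billing project is configured, billing_project_id collapses to the project already resolved from the engine URL, so the JOBS lookup behaves exactly as before. Values below are illustrative stand-ins.

```python
# Stand-ins; in the computer these come from the runner's engine URL host
# and the service connection config.
project_id = "data-project"    # runner.session.get_bind().url.host
billing_project_id = None      # connection.billingProjectId not set

effective_project = billing_project_id or project_id
assert effective_project == "data-project"  # unchanged when no billing project is set
```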

View File

@@ -304,8 +304,7 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer):
]
where_clause = [
Column("project_id")
== self.conn_config.credentials.gcpConfig.projectId.root,
Column("project_id") == self._entity.database.name,
Column("table_schema") == self.schema_name,
Column("table_name") == self.table_name,
]
@@ -338,17 +337,14 @@ class BigQueryTableMetricComputer(BaseTableMetricComputer):
*self._get_col_names_and_count(),
]
where_clause = [
Column("project_id")
== self.conn_config.credentials.gcpConfig.projectId.root,
Column("project_id") == self._entity.database.name,
Column("dataset_id") == self.schema_name,
Column("table_id") == self.table_name,
]
schema = (
self.schema_name.startswith(
f"{self.conn_config.credentials.gcpConfig.projectId.root}."
)
self.schema_name.startswith(f"{self._entity.database.name}.")
and self.schema_name
or f"{self.conn_config.credentials.gcpConfig.projectId.root}.{self.schema_name}"
or f"{self._entity.database.name}.{self.schema_name}"
)
query = self._build_query(
columns,
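
The metric computer now keys its project_id filters on the entity's own database name (the GCP project that owns the table) rather than the single projectId from the connection config, which is what keeps multi-project runs aligned. A sketch of the schema resolution with illustrative names:

```python
# Illustrative: resolve the fully qualified "<project>.<dataset>" the query targets.
entity_project = "data-project"  # self._entity.database.name
schema_name = "my_dataset"       # may already arrive as "data-project.my_dataset"

schema = (
    schema_name
    if schema_name.startswith(f"{entity_project}.")
    else f"{entity_project}.{schema_name}"
)
# -> "data-project.my_dataset"
```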

View File

@@ -12,6 +12,7 @@
Helper module to handle data sampling
for the profiler
"""
from copy import deepcopy
from typing import Dict, Optional, Union
from sqlalchemy import Column
@@ -31,11 +32,14 @@ from metadata.generated.schema.entity.services.connections.database.datalakeConn
DatalakeConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.security.credentials.gcpValues import SingleProjectId
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.sampler.models import SampleConfig
from metadata.sampler.sqlalchemy.sampler import SQASampler
from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
from metadata.utils.logger import profiler_interface_registry_logger
from metadata.utils.ssl_manager import get_ssl_connection
logger = profiler_interface_registry_logger()
@@ -74,6 +78,19 @@ class BigQuerySampler(SQASampler):
)
self.raw_dataset_type: Optional[TableType] = entity.tableType
connection_config = deepcopy(service_connection_config)
# Create a modified connection for BigQuery with the correct project ID
if (
hasattr(connection_config.credentials.gcpConfig, "projectId")
and self.entity.database
):
connection_config.credentials.gcpConfig.projectId = SingleProjectId(
root=self.entity.database.name
)
self.connection = get_ssl_connection(connection_config)
self.session_factory = create_and_bind_thread_safe_session(self.connection)
def set_tablesample(self, selectable: SqaTable):
"""Set the TABLESAMPLE clause for BigQuery
Args:
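
The sampler applies the same per-entity project swap as the profiler interface and then binds a thread-safe session factory to the freshly built engine, so sampling queries run against the project that owns the table. A rough sketch: create_and_bind_thread_safe_session is approximated with a scoped_session, and the BigQuery URL assumes the sqlalchemy-bigquery dialect is installed.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

# Stand-in for get_ssl_connection(connection_config) after projectId has been
# repointed at the entity's own project.
engine = create_engine("bigquery://data-project")

# Approximation of create_and_bind_thread_safe_session(): each thread gets its
# own session bound to this per-project engine.
session_factory = scoped_session(sessionmaker(bind=engine))
session = session_factory()
```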

View File

@@ -8,6 +8,7 @@ source:
taxonomyProjectID:
- $E2E_BQ_PROJECT_ID
- $E2E_BQ_PROJECT_ID2
billingProjectId: $E2E_BQ_PROJECT_ID2
credentials:
gcpConfig:
type: service_account

View File

@@ -22,6 +22,7 @@ from metadata.generated.schema.security.credentials.gcpCredentials import GCPCre
from metadata.generated.schema.security.credentials.gcpValues import (
GcpCredentialsValues,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
@@ -69,6 +70,7 @@ class SampleTest(TestCase):
dataType=DataType.INT,
),
],
database=EntityReference(id=uuid4(), name="myproject", type="database"),
)
cls.bq_conn = BigQueryConnection(
@@ -137,6 +139,7 @@
),
],
tableType=TableType.View,
database=EntityReference(id=uuid4(), name="myproject", type="database"),
)
sampler = BigQuerySampler(
@@ -172,6 +175,7 @@
),
],
tableType=TableType.View,
database=EntityReference(id=uuid4(), name="myproject", type="database"),
)
sampler = BigQuerySampler(