Fix #17098: Fixed case sensitive partition column name in Bigquery (#17104)

* Fixed case sensitive partition col name bigquery

* update test
This commit is contained in:
Onkar Ravgan 2024-07-22 12:36:41 +05:30 committed by GitHub
parent d3ea1ead01
commit b19b7f59a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 126 additions and 37 deletions

View File

@@ -641,49 +641,82 @@ class BigquerySource(
logger.warning("Schema definition not implemented")
return None
def _get_partition_column_name(
    self, columns: List[Dict], partition_field_name: str
):
    """
    Resolve the exact (case-sensitive) spelling of a partition column.

    BigQuery reports the partition field name in a fixed case, while the
    inspector may return the column name with different casing; compare
    case-insensitively and return the inspector's spelling so downstream
    FQN lookups match the ingested column.

    Args:
        columns: inspector column dicts, each expected to carry a "name" key
        partition_field_name: partition field name reported by BigQuery

    Returns:
        The matching column name as reported by the inspector, or the raw
        partition_field_name when no match is found or lookup fails — so
        callers never receive None for a required columnName.
    """
    try:
        for column in columns or []:
            column_name = column.get("name")
            if column_name and (
                column_name.lower() == partition_field_name.lower()
            ):
                return column_name
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error getting partition column name for {partition_field_name}: {exc}"
        )
    # Fall back to the raw field name rather than None: columnName is a
    # required field on PartitionColumnDetails, and the original behavior
    # of returning None would propagate an invalid value downstream.
    return partition_field_name
def get_table_partition_details(
self, table_name: str, schema_name: str, inspector: Inspector
) -> Tuple[bool, Optional[TablePartition]]:
"""
check if the table is partitioned table and return the partition details
"""
database = self.context.get().database
table = self.client.get_table(fqn._build(database, schema_name, table_name))
if table.time_partitioning is not None:
if table.time_partitioning.field:
table_partition = TablePartition(
try:
database = self.context.get().database
table = self.client.get_table(fqn._build(database, schema_name, table_name))
columns = inspector.get_columns(table_name, schema_name, db_name=database)
if table.time_partitioning is not None:
if table.time_partitioning.field:
table_partition = TablePartition(
columns=[
PartitionColumnDetails(
columnName=self._get_partition_column_name(
columns=columns,
partition_field_name=table.time_partitioning.field,
),
interval=str(table.time_partitioning.type_),
intervalType=PartitionIntervalTypes.TIME_UNIT,
)
]
)
return True, table_partition
return True, TablePartition(
columns=[
PartitionColumnDetails(
columnName=table.time_partitioning.field,
columnName="_PARTITIONTIME"
if table.time_partitioning.type_ == "HOUR"
else "_PARTITIONDATE",
interval=str(table.time_partitioning.type_),
intervalType=PartitionIntervalTypes.TIME_UNIT,
intervalType=PartitionIntervalTypes.INGESTION_TIME,
)
]
)
return True, table_partition
return True, TablePartition(
columns=[
PartitionColumnDetails(
columnName="_PARTITIONTIME"
if table.time_partitioning.type_ == "HOUR"
else "_PARTITIONDATE",
interval=str(table.time_partitioning.type_),
intervalType=PartitionIntervalTypes.INGESTION_TIME,
)
]
if table.range_partitioning:
table_partition = PartitionColumnDetails(
columnName=self._get_partition_column_name(
columns=columns,
partition_field_name=table.range_partitioning.field,
),
intervalType=PartitionIntervalTypes.INTEGER_RANGE,
interval=None,
)
if hasattr(table.range_partitioning, "range_") and hasattr(
table.range_partitioning.range_, "interval"
):
table_partition.interval = table.range_partitioning.range_.interval
table_partition.columnName = table.range_partitioning.field
return True, TablePartition(columns=[table_partition])
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error getting table partition details for {table_name}: {exc}"
)
if table.range_partitioning:
table_partition = PartitionColumnDetails(
columnName=table.range_partitioning.field,
intervalType=PartitionIntervalTypes.INTEGER_RANGE,
interval=None,
)
if hasattr(table.range_partitioning, "range_") and hasattr(
table.range_partitioning.range_, "interval"
):
table_partition.interval = table.range_partitioning.range_.interval
table_partition.columnName = table.range_partitioning.field
return True, TablePartition(columns=[table_partition])
return False, None
def clean_raw_data_type(self, raw_data_type):

View File

@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import unittest
from typing import Optional
from unittest import TestCase
@ -18,6 +17,7 @@ from unittest.mock import Mock, patch
from google.cloud.bigquery import PartitionRange, RangePartitioning, TimePartitioning
from google.cloud.bigquery.table import Table
from pydantic import BaseModel
from sqlalchemy import Integer, String
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import (
@ -92,6 +92,61 @@ MOCK_RANGE_PARTITIONING = RangePartitioning(
field="test_column", range_=PartitionRange(end=100, interval=10, start=0)
)
def _mock_inspector_column(name, sa_type, system_data_type):
    """Build one inspector-style column dict for the mocked BigQuery table."""
    return {
        "name": name,
        "type": sa_type,
        "nullable": True,
        "comment": None,
        "default": None,
        "precision": None,
        "scale": None,
        "max_length": None,
        "system_data_type": system_data_type,
        "is_complex": False,
        "policy_tags": None,
    }


# Columns returned by the patched inspector.get_columns in the tests below;
# "test_column" matches the mocked partition field definitions.
MOCK_COLUMN_DATA = [
    _mock_inspector_column("customer_id", Integer(), "INTEGER"),
    _mock_inspector_column("first_name", String(), "VARCHAR"),
    _mock_inspector_column("last_name", String(), "VARCHAR"),
    _mock_inspector_column("test_column", String(), "VARCHAR"),
]
class BigqueryUnitTest(TestCase):
@patch("google.cloud.bigquery.Client")
@ -127,7 +182,9 @@ class BigqueryUnitTest(TestCase):
"database"
] = MOCK_DATABASE.fullyQualifiedName.root
self.bigquery_source.client = client
self.inspector = types.SimpleNamespace()
self.bigquery_source.inspector.get_columns = (
lambda table_name, schema, db_name: MOCK_COLUMN_DATA
)
unittest.mock.patch.object(Table, "object")
@ -138,7 +195,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector,
inspector=self.bigquery_source.inspector,
)
assert partition.columns == [
@ -162,7 +219,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector,
inspector=self.bigquery_source.inspector,
)
self.assertIsInstance(partition.columns, list)
@ -177,11 +234,10 @@ class BigqueryUnitTest(TestCase):
self.bigquery_source.client.get_table = lambda fqn: MockTable(
time_partitioning=None, range_partitioning=MOCK_RANGE_PARTITIONING
)
bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector,
inspector=self.bigquery_source.inspector,
)
self.assertIsInstance(partition.columns, list)
@ -200,7 +256,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector,
inspector=self.bigquery_source.inspector,
)
assert not bool_resp