Fix #17098: Fixed case sensitive partition column name in Bigquery (#17104)

* Fixed case sensitive partition col name bigquery

* update test
This commit is contained in:
Onkar Ravgan 2024-07-22 12:36:41 +05:30 committed by GitHub
parent d3ea1ead01
commit b19b7f59a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 126 additions and 37 deletions

View File

@ -641,49 +641,82 @@ class BigquerySource(
logger.warning("Schema definition not implemented") logger.warning("Schema definition not implemented")
return None return None
def _get_partition_column_name(
self, columns: List[Dict], partition_field_name: str
):
"""
Method to get the correct partition column name
"""
try:
for column in columns or []:
column_name = column.get("name")
if column_name and (
column_name.lower() == partition_field_name.lower()
):
return column_name
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error getting partition column name for {partition_field_name}: {exc}"
)
return None
def get_table_partition_details( def get_table_partition_details(
self, table_name: str, schema_name: str, inspector: Inspector self, table_name: str, schema_name: str, inspector: Inspector
) -> Tuple[bool, Optional[TablePartition]]: ) -> Tuple[bool, Optional[TablePartition]]:
""" """
check if the table is partitioned table and return the partition details check if the table is partitioned table and return the partition details
""" """
database = self.context.get().database try:
table = self.client.get_table(fqn._build(database, schema_name, table_name)) database = self.context.get().database
if table.time_partitioning is not None: table = self.client.get_table(fqn._build(database, schema_name, table_name))
if table.time_partitioning.field: columns = inspector.get_columns(table_name, schema_name, db_name=database)
table_partition = TablePartition( if table.time_partitioning is not None:
if table.time_partitioning.field:
table_partition = TablePartition(
columns=[
PartitionColumnDetails(
columnName=self._get_partition_column_name(
columns=columns,
partition_field_name=table.time_partitioning.field,
),
interval=str(table.time_partitioning.type_),
intervalType=PartitionIntervalTypes.TIME_UNIT,
)
]
)
return True, table_partition
return True, TablePartition(
columns=[ columns=[
PartitionColumnDetails( PartitionColumnDetails(
columnName=table.time_partitioning.field, columnName="_PARTITIONTIME"
if table.time_partitioning.type_ == "HOUR"
else "_PARTITIONDATE",
interval=str(table.time_partitioning.type_), interval=str(table.time_partitioning.type_),
intervalType=PartitionIntervalTypes.TIME_UNIT, intervalType=PartitionIntervalTypes.INGESTION_TIME,
) )
] ]
) )
return True, table_partition if table.range_partitioning:
return True, TablePartition( table_partition = PartitionColumnDetails(
columns=[ columnName=self._get_partition_column_name(
PartitionColumnDetails( columns=columns,
columnName="_PARTITIONTIME" partition_field_name=table.range_partitioning.field,
if table.time_partitioning.type_ == "HOUR" ),
else "_PARTITIONDATE", intervalType=PartitionIntervalTypes.INTEGER_RANGE,
interval=str(table.time_partitioning.type_), interval=None,
intervalType=PartitionIntervalTypes.INGESTION_TIME, )
) if hasattr(table.range_partitioning, "range_") and hasattr(
] table.range_partitioning.range_, "interval"
):
table_partition.interval = table.range_partitioning.range_.interval
table_partition.columnName = table.range_partitioning.field
return True, TablePartition(columns=[table_partition])
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error getting table partition details for {table_name}: {exc}"
) )
if table.range_partitioning:
table_partition = PartitionColumnDetails(
columnName=table.range_partitioning.field,
intervalType=PartitionIntervalTypes.INTEGER_RANGE,
interval=None,
)
if hasattr(table.range_partitioning, "range_") and hasattr(
table.range_partitioning.range_, "interval"
):
table_partition.interval = table.range_partitioning.range_.interval
table_partition.columnName = table.range_partitioning.field
return True, TablePartition(columns=[table_partition])
return False, None return False, None
def clean_raw_data_type(self, raw_data_type): def clean_raw_data_type(self, raw_data_type):

View File

@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import types
import unittest import unittest
from typing import Optional from typing import Optional
from unittest import TestCase from unittest import TestCase
@ -18,6 +17,7 @@ from unittest.mock import Mock, patch
from google.cloud.bigquery import PartitionRange, RangePartitioning, TimePartitioning from google.cloud.bigquery import PartitionRange, RangePartitioning, TimePartitioning
from google.cloud.bigquery.table import Table from google.cloud.bigquery.table import Table
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Integer, String
from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import ( from metadata.generated.schema.entity.data.table import (
@ -92,6 +92,61 @@ MOCK_RANGE_PARTITIONING = RangePartitioning(
field="test_column", range_=PartitionRange(end=100, interval=10, start=0) field="test_column", range_=PartitionRange(end=100, interval=10, start=0)
) )
def _mock_column(name, col_type, system_data_type):
    """Build one inspector-style column entry with the shared default fields."""
    return {
        "name": name,
        "type": col_type,
        "nullable": True,
        "comment": None,
        "default": None,
        "precision": None,
        "scale": None,
        "max_length": None,
        "system_data_type": system_data_type,
        "is_complex": False,
        "policy_tags": None,
    }


# Column metadata as returned by the BigQuery inspector, used to exercise
# the case-insensitive partition-column lookup.
MOCK_COLUMN_DATA = [
    _mock_column("customer_id", Integer(), "INTEGER"),
    _mock_column("first_name", String(), "VARCHAR"),
    _mock_column("last_name", String(), "VARCHAR"),
    _mock_column("test_column", String(), "VARCHAR"),
]
class BigqueryUnitTest(TestCase): class BigqueryUnitTest(TestCase):
@patch("google.cloud.bigquery.Client") @patch("google.cloud.bigquery.Client")
@ -127,7 +182,9 @@ class BigqueryUnitTest(TestCase):
"database" "database"
] = MOCK_DATABASE.fullyQualifiedName.root ] = MOCK_DATABASE.fullyQualifiedName.root
self.bigquery_source.client = client self.bigquery_source.client = client
self.inspector = types.SimpleNamespace() self.bigquery_source.inspector.get_columns = (
lambda table_name, schema, db_name: MOCK_COLUMN_DATA
)
unittest.mock.patch.object(Table, "object") unittest.mock.patch.object(Table, "object")
@ -138,7 +195,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details( bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"), schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"), table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector, inspector=self.bigquery_source.inspector,
) )
assert partition.columns == [ assert partition.columns == [
@ -162,7 +219,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details( bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"), schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"), table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector, inspector=self.bigquery_source.inspector,
) )
self.assertIsInstance(partition.columns, list) self.assertIsInstance(partition.columns, list)
@ -177,11 +234,10 @@ class BigqueryUnitTest(TestCase):
self.bigquery_source.client.get_table = lambda fqn: MockTable( self.bigquery_source.client.get_table = lambda fqn: MockTable(
time_partitioning=None, range_partitioning=MOCK_RANGE_PARTITIONING time_partitioning=None, range_partitioning=MOCK_RANGE_PARTITIONING
) )
bool_resp, partition = self.bigquery_source.get_table_partition_details( bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"), schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"), table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector, inspector=self.bigquery_source.inspector,
) )
self.assertIsInstance(partition.columns, list) self.assertIsInstance(partition.columns, list)
@ -200,7 +256,7 @@ class BigqueryUnitTest(TestCase):
bool_resp, partition = self.bigquery_source.get_table_partition_details( bool_resp, partition = self.bigquery_source.get_table_partition_details(
schema_name=TEST_PARTITION.get("schema_name"), schema_name=TEST_PARTITION.get("schema_name"),
table_name=TEST_PARTITION.get("table_name"), table_name=TEST_PARTITION.get("table_name"),
inspector=self.inspector, inspector=self.bigquery_source.inspector,
) )
assert not bool_resp assert not bool_resp