From b19b7f59a5f30f84b93acf13c6c5b6e50758e129 Mon Sep 17 00:00:00 2001
From: Onkar Ravgan
Date: Mon, 22 Jul 2024 12:36:41 +0530
Subject: [PATCH] Fix #17098: Fixed case sensitive partition column name in Bigquery (#17104)

* Fixed case sensitive partition col name bigquery

* update test
---
 .../source/database/bigquery/metadata.py      | 97 +++++++++++++------
 .../tests/unit/test_handle_partitions.py      | 70 ++++++++++++--
 2 files changed, 130 insertions(+), 37 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
index b3e3624f799..47c4be4c4cb 100644
--- a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py
@@ -641,49 +641,86 @@ class BigquerySource(
             logger.warning("Schema definition not implemented")
         return None
 
+    def _get_partition_column_name(
+        self, columns: List[Dict], partition_field_name: str
+    ) -> Optional[str]:
+        """
+        Resolve the BigQuery partition field to the matching table column name.
+
+        The partition field reported by BigQuery can differ in case from the
+        column name returned by the inspector, so names are compared case
+        insensitively and the inspector's spelling is returned. When no column
+        matches, the reported partition field name is returned unchanged.
+        """
+        try:
+            expected_name = partition_field_name.lower()
+            for column in columns or []:
+                column_name = column.get("name")
+                if column_name and column_name.lower() == expected_name:
+                    return column_name
+        except Exception as exc:
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Error getting partition column name for {partition_field_name}: {exc}"
+            )
+        return partition_field_name
+
     def get_table_partition_details(
         self, table_name: str, schema_name: str, inspector: Inspector
     ) -> Tuple[bool, Optional[TablePartition]]:
         """
         check if the table is partitioned table and return the partition details
         """
-        database = self.context.get().database
-        table = self.client.get_table(fqn._build(database, schema_name, table_name))
-        if table.time_partitioning is not None:
-            if table.time_partitioning.field:
-                table_partition = TablePartition(
+        try:
+            database = self.context.get().database
+            table = self.client.get_table(fqn._build(database, schema_name, table_name))
+            columns = inspector.get_columns(table_name, schema_name, db_name=database)
+            if table.time_partitioning is not None:
+                if table.time_partitioning.field:
+                    table_partition = TablePartition(
+                        columns=[
+                            PartitionColumnDetails(
+                                columnName=self._get_partition_column_name(
+                                    columns=columns,
+                                    partition_field_name=table.time_partitioning.field,
+                                ),
+                                interval=str(table.time_partitioning.type_),
+                                intervalType=PartitionIntervalTypes.TIME_UNIT,
+                            )
+                        ]
+                    )
+                    return True, table_partition
+                return True, TablePartition(
                     columns=[
                         PartitionColumnDetails(
-                            columnName=table.time_partitioning.field,
+                            columnName="_PARTITIONTIME"
+                            if table.time_partitioning.type_ == "HOUR"
+                            else "_PARTITIONDATE",
                             interval=str(table.time_partitioning.type_),
-                            intervalType=PartitionIntervalTypes.TIME_UNIT,
+                            intervalType=PartitionIntervalTypes.INGESTION_TIME,
                         )
                     ]
                 )
-                return True, table_partition
-            return True, TablePartition(
-                columns=[
-                    PartitionColumnDetails(
-                        columnName="_PARTITIONTIME"
-                        if table.time_partitioning.type_ == "HOUR"
-                        else "_PARTITIONDATE",
-                        interval=str(table.time_partitioning.type_),
-                        intervalType=PartitionIntervalTypes.INGESTION_TIME,
-                    )
-                ]
+            if table.range_partitioning:
+                table_partition = PartitionColumnDetails(
+                    columnName=self._get_partition_column_name(
+                        columns=columns,
+                        partition_field_name=table.range_partitioning.field,
+                    ),
+                    intervalType=PartitionIntervalTypes.INTEGER_RANGE,
+                    interval=None,
+                )
+                # The range interval is optional; columnName is already resolved above
+                if hasattr(table.range_partitioning, "range_") and hasattr(
+                    table.range_partitioning.range_, "interval"
+                ):
+                    table_partition.interval = table.range_partitioning.range_.interval
+                return True, TablePartition(columns=[table_partition])
+        except Exception as exc:
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Error getting table partition details for {table_name}: {exc}"
             )
-        if table.range_partitioning:
-            table_partition = PartitionColumnDetails(
-                columnName=table.range_partitioning.field,
-                intervalType=PartitionIntervalTypes.INTEGER_RANGE,
-                interval=None,
-            )
-            if hasattr(table.range_partitioning, "range_") and hasattr(
-                table.range_partitioning.range_, "interval"
-            ):
-                table_partition.interval = table.range_partitioning.range_.interval
-            table_partition.columnName = table.range_partitioning.field
-            return True, TablePartition(columns=[table_partition])
         return False, None
 
     def clean_raw_data_type(self, raw_data_type):
diff --git a/ingestion/tests/unit/test_handle_partitions.py b/ingestion/tests/unit/test_handle_partitions.py
index d88027498e3..1a09f5b3b7d 100644
--- a/ingestion/tests/unit/test_handle_partitions.py
+++ b/ingestion/tests/unit/test_handle_partitions.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import types
 import unittest
 from typing import Optional
 from unittest import TestCase
@@ -18,6 +17,7 @@ from unittest.mock import Mock, patch
 from google.cloud.bigquery import PartitionRange, RangePartitioning, TimePartitioning
 from google.cloud.bigquery.table import Table
 from pydantic import BaseModel
+from sqlalchemy import Integer, String
 
 from metadata.generated.schema.entity.data.database import Database
 from metadata.generated.schema.entity.data.table import (
@@ -92,6 +92,61 @@ MOCK_RANGE_PARTITIONING = RangePartitioning(
     field="test_column", range_=PartitionRange(end=100, interval=10, start=0)
 )
 
+MOCK_COLUMN_DATA = [
+    {
+        "name": "customer_id",
+        "type": Integer(),
+        "nullable": True,
+        "comment": None,
+        "default": None,
+        "precision": None,
+        "scale": None,
+        "max_length": None,
+        "system_data_type": "INTEGER",
+        "is_complex": False,
+        "policy_tags": None,
+    },
+    {
+        "name": "first_name",
+        "type": String(),
+        "nullable": True,
+        "comment": None,
+        "default": None,
+        "precision": None,
+        "scale": None,
+        "max_length": None,
+        "system_data_type": "VARCHAR",
+        "is_complex": False,
+        "policy_tags": None,
+    },
+    {
+        "name": "last_name",
+        "type": String(),
+        "nullable": True,
+        "comment": None,
+        "default": None,
+        "precision": None,
+        "scale": None,
+        "max_length": None,
+        "system_data_type": "VARCHAR",
+        "is_complex": False,
+        "policy_tags": None,
+    },
+    {
+        "name": "test_column",
+        "type": String(),
+        "nullable": True,
+        "comment": None,
+        "default": None,
+        "precision": None,
+        "scale": None,
+        "max_length": None,
+        "system_data_type": "VARCHAR",
+        "is_complex": False,
+        "policy_tags": None,
+    },
+]
+
 
 class BigqueryUnitTest(TestCase):
     @patch("google.cloud.bigquery.Client")
@@ -127,7 +182,9 @@ class BigqueryUnitTest(TestCase):
             "database"
         ] = MOCK_DATABASE.fullyQualifiedName.root
         self.bigquery_source.client = client
-        self.inspector = types.SimpleNamespace()
+        self.bigquery_source.inspector.get_columns = (
+            lambda table_name, schema, db_name: MOCK_COLUMN_DATA
+        )
 
         unittest.mock.patch.object(Table, "object")
 
@@ -138,7 +195,7 @@ class BigqueryUnitTest(TestCase):
         bool_resp, partition = self.bigquery_source.get_table_partition_details(
             schema_name=TEST_PARTITION.get("schema_name"),
             table_name=TEST_PARTITION.get("table_name"),
-            inspector=self.inspector,
+            inspector=self.bigquery_source.inspector,
         )
 
         assert partition.columns == [
@@ -162,7 +219,7 @@ class BigqueryUnitTest(TestCase):
         bool_resp, partition = self.bigquery_source.get_table_partition_details(
             schema_name=TEST_PARTITION.get("schema_name"),
             table_name=TEST_PARTITION.get("table_name"),
-            inspector=self.inspector,
+            inspector=self.bigquery_source.inspector,
         )
 
         self.assertIsInstance(partition.columns, list)
@@ -177,11 +234,10 @@ class BigqueryUnitTest(TestCase):
         self.bigquery_source.client.get_table = lambda fqn: MockTable(
             time_partitioning=None, range_partitioning=MOCK_RANGE_PARTITIONING
         )
-
         bool_resp, partition = self.bigquery_source.get_table_partition_details(
             schema_name=TEST_PARTITION.get("schema_name"),
             table_name=TEST_PARTITION.get("table_name"),
-            inspector=self.inspector,
+            inspector=self.bigquery_source.inspector,
         )
 
         self.assertIsInstance(partition.columns, list)
@@ -200,7 +256,7 @@ class BigqueryUnitTest(TestCase):
         bool_resp, partition = self.bigquery_source.get_table_partition_details(
             schema_name=TEST_PARTITION.get("schema_name"),
             table_name=TEST_PARTITION.get("table_name"),
-            inspector=self.inspector,
+            inspector=self.bigquery_source.inspector,
         )
 
         assert not bool_resp