MINOR - Fix column to match set test (#15186)

* fix: column value test for SQA types

* style: ran python linting
This commit is contained in:
Teddy 2024-02-23 16:35:58 +01:00 committed by GitHub
parent bdf27458e5
commit ba8208222e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 18 additions and 11 deletions

View File

@@ -16,6 +16,7 @@ Validator for table column to match set test case
import collections import collections
import traceback import traceback
from abc import abstractmethod from abc import abstractmethod
from typing import List
from metadata.data_quality.validations.base_test_handler import BaseTestValidator from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import ( from metadata.generated.schema.tests.basic import (
@@ -93,5 +94,5 @@ class BaseTableColumnToMatchSetValidator(BaseTestValidator):
) )
@abstractmethod @abstractmethod
def _run_results(self): def _run_results(self) -> List[str]:
raise NotImplementedError raise NotImplementedError

View File

@@ -13,6 +13,7 @@
Validator for table column name to match set test case Validator for table column name to match set test case
""" """
from typing import List
from metadata.data_quality.validations.mixins.pandas_validator_mixin import ( from metadata.data_quality.validations.mixins.pandas_validator_mixin import (
PandasValidatorMixin, PandasValidatorMixin,
@@ -30,7 +31,7 @@ class TableColumnToMatchSetValidator(
): ):
"""Validator table column name to match set test case""" """Validator table column name to match set test case"""
def _run_results(self): def _run_results(self) -> List[str]:
"""compute result of the test case""" """compute result of the test case"""
names = list(self.runner[0].columns) names = list(self.runner[0].columns)
if not names: if not names:

View File

@@ -14,9 +14,10 @@ Validator for table column name to match set test case
""" """
from typing import Optional from typing import List, cast
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.sql.base import ColumnCollection
from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin, SQAValidatorMixin,
@@ -34,11 +35,15 @@ class TableColumnToMatchSetValidator(
): ):
"""Validator for table column name to match set test case""" """Validator for table column name to match set test case"""
def _run_results(self) -> Optional[int]: def _run_results(self) -> List[str]:
"""compute result of the test case""" """compute result of the test case"""
names = inspect(self.runner.table).c names = inspect(self.runner.table).c
if not names: if not names:
raise ValueError( raise ValueError(
f"Column names for test case {self.test_case.name} returned None" f"Column names for test case {self.test_case.name} returned None"
) )
names = cast(
ColumnCollection, names
) # satisfy type checker for names.keys() access
names = list(names.keys())
return names return names

View File

@@ -158,10 +158,10 @@ class BigtableSource(CommonNoSQLSource, MultiDBSource):
records = [{"row_key": b"row_key"}] records = [{"row_key": b"row_key"}]
# In order to get a "good" sample of data, we try to distribute the sampling # In order to get a "good" sample of data, we try to distribute the sampling
# across multiple column families. # across multiple column families.
for cf in list(column_families.keys())[:MAX_COLUMN_FAMILIES]: for column_family in list(column_families.keys())[:MAX_COLUMN_FAMILIES]:
records.extend( records.extend(
self._get_records_for_column_family( self._get_records_for_column_family(
table, cf, SAMPLES_PER_COLUMN_FAMILY table, column_family, SAMPLES_PER_COLUMN_FAMILY
) )
) )
if len(records) >= GLOBAL_SAMPLE_SIZE: if len(records) >= GLOBAL_SAMPLE_SIZE:

View File

@@ -39,22 +39,22 @@ class Row(BaseModel):
@classmethod @classmethod
def from_partial_row(cls, row: PartialRowData): def from_partial_row(cls, row: PartialRowData):
cells = {} cells = {}
for cf, cf_cells in row.cells.items(): for column_family, cf_cells in row.cells.items():
cells.setdefault(cf, {}) cells.setdefault(column_family, {})
for column, cell in cf_cells.items(): for column, cell in cf_cells.items():
cells[cf][column] = Cell( cells[column_family][column] = Cell(
values=[Value(timestamp=c.timestamp, value=c.value) for c in cell] values=[Value(timestamp=c.timestamp, value=c.value) for c in cell]
) )
return cls(cells=cells, row_key=row.row_key) return cls(cells=cells, row_key=row.row_key)
def to_record(self) -> Dict[str, bytes]: def to_record(self) -> Dict[str, bytes]:
record = {} record = {}
for cf, cells in self.cells.items(): for column_family, cells in self.cells.items():
for column, cell in cells.items(): for column, cell in cells.items():
# Since each cell can have multiple values and the API returns them in descending order # Since each cell can have multiple values and the API returns them in descending order
# from latest to oldest, we only take the latest value. This probably does not matter since # from latest to oldest, we only take the latest value. This probably does not matter since
# all we care about is data types and all data stored in BigTable is of type `bytes`. # all we care about is data types and all data stored in BigTable is of type `bytes`.
record[f"{cf}.{column.decode()}"] = cell.values[0].value record[f"{column_family}.{column.decode()}"] = cell.values[0].value
record["row_key"] = self.row_key record["row_key"] = self.row_key
return record return record