Fixes #6795 - Implement profiler support for struct types (#10817)

* fix: raise more informative error message when service is not found

* fix: profiling for struct table

* fix: linting

* fix: added tests for struct and nested struct for get_columns
This commit is contained in:
Teddy 2023-03-29 12:06:34 +02:00 committed by GitHub
parent 63edc5d5ca
commit e1b193a719
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 184 additions and 30 deletions

View File

@ -497,10 +497,18 @@ class ProfilerWorkflow(WorkflowStatusMixin):
service_name,
),
)
if not service:
raise ConnectionError(
f"Could not retrieve service with name `{service_name}`. "
"Typically caused by the `serviceName` does not exists in OpenMetadata "
"or the JWT Token is invalid."
)
if service:
self.config.source.serviceConnection = ServiceConnection(
__root__=service.connection
)
except ConnectionError as exc:
raise exc
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(

View File

@ -14,7 +14,7 @@ Converter logic to transform an OpenMetadata Table Entity
to an SQLAlchemy ORM class.
"""
from typing import Optional
from typing import List, Optional, cast
import sqlalchemy
from sqlalchemy import MetaData
@ -105,7 +105,9 @@ def check_snowflake_case_sensitive(table_service_type, table_or_col) -> Optional
return None
def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Column:
def build_orm_col(
idx: int, col: Column, table_service_type, parent: Optional[str] = None
) -> sqlalchemy.Column:
"""
Cook the ORM column from our metadata instance
information.
@ -118,8 +120,13 @@ def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Colum
there is no impact for our read-only purposes.
"""
if parent:
name = f"{parent}.{col.name.__root__}"
else:
name = col.name.__root__
return sqlalchemy.Column(
name=str(col.name.__root__),
name=str(name),
type_=map_types(col, table_service_type),
primary_key=not bool(idx), # The first col seen is used as PK
quote=check_snowflake_case_sensitive(table_service_type, col.name.__root__),
@ -129,6 +136,40 @@ def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Colum
)
def get_columns(
    column_list: List[Column],
    service_type: databaseService.DatabaseServiceType,
    start: int = 0,
    parent: Optional[str] = None,
) -> dict:
    """Build a dictionary of ORM columns, flattening struct children
    into dotted names (e.g. ``parent.child``).

    Args:
        column_list (List[Column]): columns of the table (or children of a struct)
        service_type (DatabaseServiceType): database service type
        start (int): index offset used to define the primary key column
        parent (str): dotted name of the enclosing struct column, if any
    Returns:
        mapping of (possibly dotted) column name to the SQLAlchemy column
    """
    orm_columns = {}
    for position, column in enumerate(column_list, start=start):
        raw_name = column.name.__root__
        full_name = f"{parent}.{raw_name}" if parent else raw_name
        if full_name in SQA_RESERVED_ATTRIBUTES:
            # suffix to avoid clashing with SQLAlchemy reserved attributes
            full_name = f"{full_name}_"
        orm_columns[full_name] = build_orm_col(position, column, service_type, parent)
        if column.children:
            # recurse into struct fields, prefixing them with this column's name
            orm_columns.update(
                get_columns(
                    column.children, service_type, start=position, parent=full_name
                )
            )
    return orm_columns
def ometa_to_sqa_orm(
table: Table, metadata: OpenMetadata, sqa_metadata_obj: Optional[MetaData] = None
) -> DeclarativeMeta:
@ -142,14 +183,10 @@ def ometa_to_sqa_orm(
as the bases tuple for inheritance.
"""
cols = {
(
col.name.__root__ + "_"
if col.name.__root__ in SQA_RESERVED_ATTRIBUTES
else col.name.__root__
): build_orm_col(idx, col, table.serviceType)
for idx, col in enumerate(table.columns)
}
table.serviceType = cast(
databaseService.DatabaseServiceType, table.serviceType
) # satisfy mypy
cols = get_columns(table.columns, table.serviceType)
orm_database_name = get_orm_database(table, metadata)
orm_schema_name = get_orm_schema(table, metadata)

View File

@ -476,7 +476,7 @@ class Profiler(Generic[TMetric]):
return table_profile
def generate_sample_data(self) -> TableData:
def generate_sample_data(self) -> Optional[TableData]:
"""Fetch and ingest sample data
Returns:

View File

@ -17,11 +17,14 @@ supporting sqlalchemy abstraction layer
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union
from pydantic import BaseModel
from sqlalchemy import Column, MetaData
from sqlalchemy import Column
from typing_extensions import Self
from metadata.generated.schema.entity.data.table import PartitionProfilerConfig, Table
from metadata.generated.schema.entity.data.table import (
PartitionProfilerConfig,
Table,
TableData,
)
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
@ -225,6 +228,6 @@ class ProfilerProtocol(ABC):
raise NotImplementedError
@abstractmethod
def fetch_sample_data(self, table) -> dict:
def fetch_sample_data(self, table) -> TableData:
"""run profiler metrics"""
raise NotImplementedError

View File

@ -488,7 +488,17 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin):
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return sampler.fetch_sqa_sample_data()
# Only fetch columns that are in the table entity
# with struct columns we create a column for each field in the ORM table
# but we only want to fetch the columns that are in the table entity
sample_columns = [
column.name
for column in table.__table__.columns
if column.name in {col.name.__root__ for col in self.table_entity.columns}
]
return sampler.fetch_sqa_sample_data(sample_columns)
def get_composed_metrics(
self, column: Column, metric: Metrics, column_results: Dict

View File

@ -132,7 +132,7 @@ class Sampler:
# Assign as an alias
return aliased(self.table, sampled)
def fetch_sqa_sample_data(self) -> TableData:
def fetch_sqa_sample_data(self, sample_columns: Optional[list] = None) -> TableData:
"""
Use the sampler to retrieve sample data rows as per limit given by user
:return: TableData to be added to the Table Entity
@ -142,7 +142,14 @@ class Sampler:
# Add new RandomNumFn column
rnd = self.get_sample_query()
sqa_columns = [col for col in inspect(rnd).c if col.name != RANDOM_LABEL]
sample_columns = (
sample_columns if sample_columns else [col.name for col in inspect(rnd).c]
)
sqa_columns = [
col
for col in inspect(rnd).c
if col.name != RANDOM_LABEL and col.name in sample_columns
]
sqa_sample = (
self.session.query(*sqa_columns)

View File

@ -17,12 +17,14 @@ from unittest.mock import patch
from uuid import UUID
from pytest import mark
from sqlalchemy import Column as SQAColumn
from sqlalchemy.sql.sqltypes import INTEGER, String
from metadata.generated.schema.entity.data.table import Column, DataType, Table
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.profiler.orm.converter import ometa_to_sqa_orm
from metadata.profiler.orm.converter import get_columns, ometa_to_sqa_orm
@patch("metadata.profiler.orm.converter.get_orm_schema", return_value="schema")
@ -125,3 +127,62 @@ def test_metadata_column(mock_schema, mock_database):
assert orm_table.__table_args__["schema"] == "schema"
for name, _ in column_definition:
assert hasattr(orm_table, name)
def test_get_columns_regular():
    """get_columns should map a flat column list to ORM columns"""
    columns = [
        Column(name="col1", dataType=DataType.STRING),
        Column(name="col2", dataType=DataType.INT),
    ]

    orm_cols = get_columns(columns, DatabaseServiceType.BigQuery)

    assert len(orm_cols) == 2

    first = orm_cols["col1"]
    second = orm_cols["col2"]
    assert first.name == "col1"
    assert isinstance(first.type, String)
    assert second.name == "col2"
    assert isinstance(second.type, INTEGER)
def test_get_columns_struct():
    """get_columns should flatten struct and nested struct children"""
    nested_struct = Column(
        name="structCol2",
        dataType=DataType.STRUCT,
        children=[Column(name="nestedStructCol1", dataType=DataType.STRING)],
    )
    struct_columns = [
        Column(name="col1", dataType=DataType.STRING),
        Column(
            name="col2",
            dataType=DataType.STRUCT,
            children=[
                Column(name="structCol1", dataType=DataType.STRING),
                nested_struct,
            ],
        ),
    ]

    orm_cols = get_columns(struct_columns, DatabaseServiceType.BigQuery)

    # 2 top-level + 2 struct fields + 1 nested struct field
    assert len(orm_cols) == 5
    for dotted_name in (
        "col2.structCol1",
        "col2.structCol2",
        "col2.structCol2.nestedStructCol1",
    ):
        assert dotted_name in orm_cols

View File

@ -275,6 +275,27 @@ public class TableRepository extends EntityRepository<Table> {
return table;
}
/**
 * Recursively look up the column — including struct children, addressed by dotted
 * names such as {@code parent.child} — whose name matches the given column profile.
 * Returns {@code null} when no column matches.
 */
private Column getColumnNameForProfiler(List<Column> columnList, ColumnProfile columnProfile, String parentName) {
  for (Column candidate : columnList) {
    String qualifiedName =
        parentName == null ? candidate.getName() : String.format("%s.%s", parentName, candidate.getName());
    if (qualifiedName.equals(columnProfile.getName())) {
      return candidate;
    }
    List<Column> children = candidate.getChildren();
    if (children != null) {
      Column match = getColumnNameForProfiler(children, columnProfile, qualifiedName);
      if (match != null) {
        return match;
      }
    }
  }
  return null;
}
@Transaction
public Table addTableProfileData(UUID tableId, CreateTableProfile createTableProfile) throws IOException {
// Validate the request content
@ -308,8 +329,7 @@ public class TableRepository extends EntityRepository<Table> {
for (ColumnProfile columnProfile : createTableProfile.getColumnProfile()) {
// Validate all the columns
Column column =
table.getColumns().stream().filter(c -> c.getName().equals(columnProfile.getName())).findFirst().orElse(null);
Column column = getColumnNameForProfiler(table.getColumns(), columnProfile, null);
if (column == null) {
throw new IllegalArgumentException("Invalid column name " + columnProfile.getName());
}
@ -432,6 +452,21 @@ public class TableRepository extends EntityRepository<Table> {
return new ResultList<>(systemProfiles, startTs.toString(), endTs.toString(), systemProfiles.size());
}
/**
 * Attach the latest stored column profile to every column in the list,
 * recursing into struct children so nested fields get their profile too.
 */
private void setColumnProfile(List<Column> columnList) throws IOException {
  for (Column col : columnList) {
    ColumnProfile profile =
        JsonUtils.readValue(
            daoCollection
                .entityExtensionTimeSeriesDao()
                .getLatestExtension(col.getFullyQualifiedName(), TABLE_COLUMN_PROFILE_EXTENSION),
            ColumnProfile.class);
    col.setProfile(profile);
    List<Column> children = col.getChildren();
    if (children != null) {
      setColumnProfile(children);
    }
  }
}
@Transaction
public Table getLatestTableProfile(String fqn) throws IOException {
Table table = dao.findEntityByName(fqn);
@ -442,14 +477,7 @@ public class TableRepository extends EntityRepository<Table> {
.getLatestExtension(table.getFullyQualifiedName(), TABLE_PROFILE_EXTENSION),
TableProfile.class);
table.setProfile(tableProfile);
for (Column c : table.getColumns()) {
c.setProfile(
JsonUtils.readValue(
daoCollection
.entityExtensionTimeSeriesDao()
.getLatestExtension(c.getFullyQualifiedName(), TABLE_COLUMN_PROFILE_EXTENSION),
ColumnProfile.class));
}
setColumnProfile(table.getColumns());
return table;
}