Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2026-01-05 12:07:10 +00:00)
* fix: raise a more informative error message when the service is not found
* fix: profiling for struct tables
* fix: linting
* fix: added tests for struct and nested struct in get_columns
This commit is contained in:
parent 63edc5d5ca · commit e1b193a719
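At a glance: the struct-profiling fix flattens nested struct fields into dotted column names such as `col2.structCol1`, on both the Python ORM side and the Java repository side. A minimal standalone sketch of that naming convention (plain dicts stand in for the real OpenMetadata Column models; this is an illustration, not part of the diff below):

from typing import Optional


def flatten(columns: list, parent: Optional[str] = None) -> list:
    """Flatten nested struct columns into dotted names, parent first."""
    names = []
    for col in columns:
        name = f"{parent}.{col['name']}" if parent else col["name"]
        names.append(name)
        if col.get("children"):
            names += flatten(col["children"], parent=name)
    return names


print(flatten([{"name": "col2", "children": [{"name": "structCol1"}]}]))
# -> ['col2', 'col2.structCol1']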
@@ -497,10 +497,18 @@ class ProfilerWorkflow(WorkflowStatusMixin):
                     service_name,
                 ),
             )
+            if not service:
+                raise ConnectionError(
+                    f"Could not retrieve service with name `{service_name}`. "
+                    "This is typically caused by a `serviceName` that does not exist "
+                    "in OpenMetadata or by an invalid JWT token."
+                )
             if service:
                 self.config.source.serviceConnection = ServiceConnection(
                     __root__=service.connection
                 )
+        except ConnectionError as exc:
+            raise exc
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(
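Why the dedicated `except ConnectionError` clause matters: without it, the broad `except Exception` handler below it would swallow the new, informative error into a debug log. A minimal standalone illustration (the function and service names here are hypothetical):

def resolve_service(service, service_name="my_service"):
    try:
        if not service:
            raise ConnectionError(
                f"Could not retrieve service with name `{service_name}`."
            )
    except ConnectionError:
        raise  # surface the actionable message to the caller
    except Exception as exc:
        print(f"logged: {exc}")  # other failures are only logged


try:
    resolve_service(None)
except ConnectionError as exc:
    print(exc)  # -> Could not retrieve service with name `my_service`.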
@@ -14,7 +14,7 @@ Converter logic to transform an OpenMetadata Table Entity
 to an SQLAlchemy ORM class.
 """
 
-from typing import Optional
+from typing import List, Optional, cast
 
 import sqlalchemy
 from sqlalchemy import MetaData
@@ -105,7 +105,9 @@ def check_snowflake_case_sensitive(table_service_type, table_or_col) -> Optional
     return None
 
 
-def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Column:
+def build_orm_col(
+    idx: int, col: Column, table_service_type, parent: Optional[str] = None
+) -> sqlalchemy.Column:
     """
     Cook the ORM column from our metadata instance
     information.
@@ -118,8 +120,13 @@ def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Colum
     there is no impact for our read-only purposes.
     """
 
+    if parent:
+        name = f"{parent}.{col.name.__root__}"
+    else:
+        name = col.name.__root__
+
     return sqlalchemy.Column(
-        name=str(col.name.__root__),
+        name=str(name),
         type_=map_types(col, table_service_type),
         primary_key=not bool(idx),  # The first col seen is used as PK
         quote=check_snowflake_case_sensitive(table_service_type, col.name.__root__),
@@ -129,6 +136,40 @@ def build_orm_col(idx: int, col: Column, table_service_type) -> sqlalchemy.Colum
     )
 
 
+def get_columns(
+    column_list: List[Column],
+    service_type: databaseService.DatabaseServiceType,
+    start: int = 0,
+    parent: Optional[str] = None,
+) -> dict:
+    """Build a dictionary of ORM columns
+
+    Args:
+        column_list (List[Column]): list of columns
+        service_type (DatabaseServiceType): database service type
+        start (int): index of the column used to define the primary key
+        parent (str): parent column name
+    """
+    cols = {}
+
+    for idx, col in enumerate(column_list, start=start):
+        if parent:
+            name = f"{parent}.{col.name.__root__}"
+        else:
+            name = col.name.__root__
+        if name in SQA_RESERVED_ATTRIBUTES:
+            name = f"{name}_"
+
+        cols[name] = build_orm_col(idx, col, service_type, parent)
+        if col.children:
+            cols = {
+                **cols,
+                **get_columns(col.children, service_type, start=idx, parent=name),
+            }
+
+    return cols
+
+
 def ometa_to_sqa_orm(
     table: Table, metadata: OpenMetadata, sqa_metadata_obj: Optional[MetaData] = None
 ) -> DeclarativeMeta:
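One detail worth noting in get_columns: flattened names that collide with SQLAlchemy declarative attributes get an underscore suffix, exactly as the old dict comprehension did. A tiny standalone illustration; the contents of SQA_RESERVED_ATTRIBUTES are assumed here (SQLAlchemy's declarative API reserves at least `metadata`):

SQA_RESERVED_ATTRIBUTES = ["metadata"]  # assumed contents for this sketch

name = "metadata"
if name in SQA_RESERVED_ATTRIBUTES:
    name = f"{name}_"  # avoid clashing with DeclarativeMeta.metadata
print(name)  # -> metadata_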
@@ -142,14 +183,10 @@ def ometa_to_sqa_orm(
     as the bases tuple for inheritance.
     """
 
-    cols = {
-        (
-            col.name.__root__ + "_"
-            if col.name.__root__ in SQA_RESERVED_ATTRIBUTES
-            else col.name.__root__
-        ): build_orm_col(idx, col, table.serviceType)
-        for idx, col in enumerate(table.columns)
-    }
+    table.serviceType = cast(
+        databaseService.DatabaseServiceType, table.serviceType
+    )  # satisfy mypy
+    cols = get_columns(table.columns, table.serviceType)
 
     orm_database_name = get_orm_database(table, metadata)
     orm_schema_name = get_orm_schema(table, metadata)
@@ -476,7 +476,7 @@ class Profiler(Generic[TMetric]):
 
         return table_profile
 
-    def generate_sample_data(self) -> TableData:
+    def generate_sample_data(self) -> Optional[TableData]:
        """Fetch and ingest sample data
 
        Returns:
@@ -17,11 +17,14 @@ supporting sqlalchemy abstraction layer
 from abc import ABC, abstractmethod
 from typing import Dict, Optional, Union
 
 from pydantic import BaseModel
-from sqlalchemy import Column, MetaData
+from sqlalchemy import Column
 from typing_extensions import Self
 
-from metadata.generated.schema.entity.data.table import PartitionProfilerConfig, Table
+from metadata.generated.schema.entity.data.table import (
+    PartitionProfilerConfig,
+    Table,
+    TableData,
+)
 from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
     DatalakeConnection,
 )

@@ -225,6 +228,6 @@ class ProfilerProtocol(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def fetch_sample_data(self, table) -> dict:
+    def fetch_sample_data(self, table) -> TableData:
         """run profiler metrics"""
         raise NotImplementedError
@@ -488,7 +488,17 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin):
             partition_details=self.partition_details,
             profile_sample_query=self.profile_query,
         )
-        return sampler.fetch_sqa_sample_data()
+
+        # Only fetch columns that are in the table entity: for struct columns
+        # we create an ORM column per nested field, but sample data should only
+        # cover the columns declared on the table entity itself.
+        sample_columns = [
+            column.name
+            for column in table.__table__.columns
+            if column.name in {col.name.__root__ for col in self.table_entity.columns}
+        ]
+
+        return sampler.fetch_sqa_sample_data(sample_columns)
 
     def get_composed_metrics(
         self, column: Column, metric: Metrics, column_results: Dict
@@ -132,7 +132,7 @@ class Sampler:
         # Assign as an alias
         return aliased(self.table, sampled)
 
-    def fetch_sqa_sample_data(self) -> TableData:
+    def fetch_sqa_sample_data(self, sample_columns: Optional[list] = None) -> TableData:
        """
        Use the sampler to retrieve sample data rows as per limit given by user
        :return: TableData to be added to the Table Entity

@@ -142,7 +142,14 @@ class Sampler:
 
         # Add new RandomNumFn column
         rnd = self.get_sample_query()
-        sqa_columns = [col for col in inspect(rnd).c if col.name != RANDOM_LABEL]
+        sample_columns = (
+            sample_columns if sample_columns else [col.name for col in inspect(rnd).c]
+        )
+        sqa_columns = [
+            col
+            for col in inspect(rnd).c
+            if col.name != RANDOM_LABEL and col.name in sample_columns
+        ]
 
         sqa_sample = (
             self.session.query(*sqa_columns)
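The effect of the new sample_columns parameter, as a standalone sketch. The column names and the value of RANDOM_LABEL are stand-ins, not the module's real data:

RANDOM_LABEL = "random"  # stand-in for the sampler module's constant

inspected = ["col1", "col2", "col2.structCol1", RANDOM_LABEL]  # ORM columns
sample_columns = ["col1", "col2"]  # as passed by SQAProfilerInterface

sample_columns = sample_columns if sample_columns else list(inspected)
sqa_columns = [
    col
    for col in inspected
    if col != RANDOM_LABEL and col in sample_columns
]
print(sqa_columns)  # -> ['col1', 'col2']; None keeps every non-random column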
@@ -17,12 +17,14 @@ from unittest.mock import patch
 from uuid import UUID
 
 from pytest import mark
+from sqlalchemy import Column as SQAColumn
+from sqlalchemy.sql.sqltypes import INTEGER, String
 
 from metadata.generated.schema.entity.data.table import Column, DataType, Table
 from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
 )
-from metadata.profiler.orm.converter import ometa_to_sqa_orm
+from metadata.profiler.orm.converter import get_columns, ometa_to_sqa_orm
 
 
 @patch("metadata.profiler.orm.converter.get_orm_schema", return_value="schema")
@@ -125,3 +127,62 @@ def test_metadata_column(mock_schema, mock_database):
     assert orm_table.__table_args__["schema"] == "schema"
     for name, _ in column_definition:
         assert hasattr(orm_table, name)
+
+
+def test_get_columns_regular():
+    """Test get columns function reads columns correctly"""
+    regular_columns = [
+        Column(
+            name="col1",
+            dataType=DataType.STRING,
+        ),
+        Column(
+            name="col2",
+            dataType=DataType.INT,
+        ),
+    ]
+
+    cols = get_columns(regular_columns, DatabaseServiceType.BigQuery)
+    col1 = cols["col1"]
+    col2 = cols["col2"]
+    assert len(cols) == 2
+    assert col1.name == "col1"
+    assert isinstance(col1.type, String)
+    assert col2.name == "col2"
+    assert isinstance(col2.type, INTEGER)
+
+
+def test_get_columns_struct():
+    """Test get columns function reads columns correctly for struct"""
+    struct_columns = [
+        Column(
+            name="col1",
+            dataType=DataType.STRING,
+        ),
+        Column(
+            name="col2",
+            dataType=DataType.STRUCT,
+            children=[
+                Column(
+                    name="structCol1",
+                    dataType=DataType.STRING,
+                ),
+                Column(
+                    name="structCol2",
+                    dataType=DataType.STRUCT,
+                    children=[
+                        Column(
+                            name="nestedStructCol1",
+                            dataType=DataType.STRING,
+                        ),
+                    ],
+                ),
+            ],
+        ),
+    ]
+
+    cols = get_columns(struct_columns, DatabaseServiceType.BigQuery)
+    assert len(cols) == 5
+    assert "col2.structCol1" in cols
+    assert "col2.structCol2" in cols
+    assert "col2.structCol2.nestedStructCol1" in cols
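Note that the struct test expects exactly five flattened keys, one per column at every nesting level: col1, col2, col2.structCol1, col2.structCol2, and col2.structCol2.nestedStructCol1.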
@@ -275,6 +275,27 @@ public class TableRepository extends EntityRepository<Table> {
     return table;
   }
 
+  private Column getColumnNameForProfiler(List<Column> columnList, ColumnProfile columnProfile, String parentName) {
+    for (Column col : columnList) {
+      String columnName;
+      if (parentName != null) {
+        columnName = String.format("%s.%s", parentName, col.getName());
+      } else {
+        columnName = col.getName();
+      }
+      if (columnName.equals(columnProfile.getName())) {
+        return col;
+      }
+      if (col.getChildren() != null) {
+        Column childColumn = getColumnNameForProfiler(col.getChildren(), columnProfile, columnName);
+        if (childColumn != null) {
+          return childColumn;
+        }
+      }
+    }
+    return null;
+  }
+
   @Transaction
   public Table addTableProfileData(UUID tableId, CreateTableProfile createTableProfile) throws IOException {
     // Validate the request content

@@ -308,8 +329,7 @@ public class TableRepository extends EntityRepository<Table> {
 
     for (ColumnProfile columnProfile : createTableProfile.getColumnProfile()) {
       // Validate all the columns
-      Column column =
-          table.getColumns().stream().filter(c -> c.getName().equals(columnProfile.getName())).findFirst().orElse(null);
+      Column column = getColumnNameForProfiler(table.getColumns(), columnProfile, null);
       if (column == null) {
         throw new IllegalArgumentException("Invalid column name " + columnProfile.getName());
       }

@@ -432,6 +452,21 @@ public class TableRepository extends EntityRepository<Table> {
     return new ResultList<>(systemProfiles, startTs.toString(), endTs.toString(), systemProfiles.size());
   }
 
+  private void setColumnProfile(List<Column> columnList) throws IOException {
+    for (Column column : columnList) {
+      ColumnProfile columnProfile =
+          JsonUtils.readValue(
+              daoCollection
+                  .entityExtensionTimeSeriesDao()
+                  .getLatestExtension(column.getFullyQualifiedName(), TABLE_COLUMN_PROFILE_EXTENSION),
+              ColumnProfile.class);
+      column.setProfile(columnProfile);
+      if (column.getChildren() != null) {
+        setColumnProfile(column.getChildren());
+      }
+    }
+  }
+
   @Transaction
   public Table getLatestTableProfile(String fqn) throws IOException {
     Table table = dao.findEntityByName(fqn);

@@ -442,14 +477,7 @@ public class TableRepository extends EntityRepository<Table> {
                 .getLatestExtension(table.getFullyQualifiedName(), TABLE_PROFILE_EXTENSION),
             TableProfile.class);
     table.setProfile(tableProfile);
-    for (Column c : table.getColumns()) {
-      c.setProfile(
-          JsonUtils.readValue(
-              daoCollection
-                  .entityExtensionTimeSeriesDao()
-                  .getLatestExtension(c.getFullyQualifiedName(), TABLE_COLUMN_PROFILE_EXTENSION),
-              ColumnProfile.class));
-    }
+    setColumnProfile(table.getColumns());
     return table;
   }
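The Java-side lookup applies the same dotted-name convention when matching incoming ColumnProfile names against possibly nested table columns. A standalone Python rendering of the recursion in getColumnNameForProfiler (illustrative only; dicts stand in for the Java Column objects):

from typing import Optional


def find_column(columns: list, profile_name: str, parent: Optional[str] = None):
    """Depth-first search matching a dotted profile name such as 'col2.structCol1'."""
    for col in columns:
        name = f"{parent}.{col['name']}" if parent else col["name"]
        if name == profile_name:
            return col
        child = find_column(col.get("children") or [], profile_name, name)
        if child is not None:
            return child
    return None


cols = [{"name": "col2", "children": [{"name": "structCol1"}]}]
assert find_column(cols, "col2.structCol1") == {"name": "structCol1"}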