Fixes #16983: can't sample data from Trino tables with complex types (#23478)

* Update test data for `tests.integration.trino`

This creates tables with complex data types.

Raw SQL is used because creating the tables through pandas did not produce the correct types for the struct columns.

* Update tests to reproduce the issue

The new tables were also included in the other tests to make sure complex data types do not break anything else.

Reference: [issue 16983](https://github.com/open-metadata/OpenMetadata/issues/16983)

* Added `TypeDecorator`s to handle `trino.types.NamedRowTuple`

This is because Pydantic could not figure out how to build Python objects when receiving `NamedRowTuple`s, which broke the sampling process.

This makes sure the data we receive from the Trino interface is compatible with Pydantic.
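A minimal sketch of the failure mode — illustrative only; the `NamedRowTuple(values, names, types)` constructor below mirrors trino-python-client at the time of writing and should be treated as an assumption:

```python
from pydantic import BaseModel
from trino.types import NamedRowTuple


class Sample(BaseModel):
    payload: dict  # sampling hands Pydantic plain builtins


# A ROW value as returned by the Trino driver
row = NamedRowTuple(["test_value", 123.45], ["foobar", "foobaz"], ["varchar", "double"])
Sample(payload=row)  # raises a pydantic ValidationError: a (named) tuple is not a dict
```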
Eugenio 2025-09-26 08:13:28 +02:00 committed by GitHub
parent cc265f956b
commit bb50514a00
11 changed files with 360 additions and 17 deletions

View File

@@ -24,6 +24,7 @@ from metadata.profiler.orm.converter.mariadb.converter import MariaDBMapTypes
from metadata.profiler.orm.converter.mssql.converter import MssqlMapTypes
from metadata.profiler.orm.converter.redshift.converter import RedshiftMapTypes
from metadata.profiler.orm.converter.snowflake.converter import SnowflakeMapTypes
from metadata.profiler.orm.converter.trino import TrinoMapTypes

converter_registry = defaultdict(lambda: CommonMapTypes)
converter_registry[DatabaseServiceType.BigQuery] = BigqueryMapTypes
@@ -32,3 +33,4 @@ converter_registry[DatabaseServiceType.Redshift] = RedshiftMapTypes
converter_registry[DatabaseServiceType.Mssql] = MssqlMapTypes
converter_registry[DatabaseServiceType.AzureSQL] = AzureSqlMapTypes
converter_registry[DatabaseServiceType.MariaDB] = MariaDBMapTypes
converter_registry[DatabaseServiceType.Trino] = TrinoMapTypes
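A quick sanity check of the registry wiring above — hypothetical usage; the `DatabaseServiceType` import path is assumed from the OpenMetadata codebase and is not part of this hunk:

```python
from metadata.generated.schema.entity.services.databaseService import DatabaseServiceType

# Trino now resolves to its dedicated converter...
assert converter_registry[DatabaseServiceType.Trino] is TrinoMapTypes
# ...while any service without an explicit entry still falls back to
# CommonMapTypes through the defaultdict factory.
assert converter_registry[DatabaseServiceType.Postgres] is CommonMapTypes
```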

View File

@@ -0,0 +1,14 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .converter import TrinoMapTypes
__all__ = ("TrinoMapTypes",)

View File

@@ -0,0 +1,39 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Set
from sqlalchemy.sql.type_api import TypeEngine
from metadata.generated.schema.entity.data.table import DataType
from metadata.profiler.orm.converter.common import CommonMapTypes
from metadata.profiler.orm.types.trino import TrinoArray, TrinoMap, TrinoStruct
class TrinoMapTypes(CommonMapTypes):
    """Map OpenMetadata data types to Trino-specific SQLAlchemy types."""

    _TYPE_MAP_OVERRIDE = {
        DataType.ARRAY: TrinoArray,
        DataType.MAP: TrinoMap,
        DataType.STRUCT: TrinoStruct,
    }

    _TYPE_MAP = {
        **CommonMapTypes._TYPE_MAP,
        **_TYPE_MAP_OVERRIDE,
    }

    @classmethod
    def map_sqa_to_om_types(cls) -> Dict[TypeEngine, Set[DataType]]:
        """Return the reverse mapping from SQLAlchemy ORM types to OpenMetadata data types"""
        return {
            **CommonMapTypes.map_sqa_to_om_types(),
            **{v: {k} for k, v in cls._TYPE_MAP_OVERRIDE.items()},
        }
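An illustrative check of the override wiring — a sketch only, exercising the names defined in this file:

```python
# The reverse map now points the Trino-specific SQLAlchemy types back at
# their OpenMetadata data types, on top of the common mapping.
mapping = TrinoMapTypes.map_sqa_to_om_types()
assert mapping[TrinoStruct] == {DataType.STRUCT}
assert mapping[TrinoArray] == {DataType.ARRAY}
assert TrinoMapTypes._TYPE_MAP[DataType.MAP] is TrinoMap
```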

View File

@@ -0,0 +1,10 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,70 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Type adapter for Trino to handle NamedRowTuple serialization
"""
from typing import Any
from sqlalchemy import ARRAY
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.sqltypes import String, TypeDecorator
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class TrinoTypesMixin:
    """Recursively convert Trino result values into plain Python builtins for Pydantic."""

    def process_result_value(self, value: Any, dialect: Dialect) -> Any:
        # pylint: disable=import-outside-toplevel
        from trino.types import NamedRowTuple

        def _convert_value(obj: Any) -> Any:
            if isinstance(obj, NamedRowTuple):
                # NamedRowTuple keeps its field names in __annotations__["names"];
                # unpack it into a plain dict
                return {
                    k: _convert_value(getattr(obj, k))
                    for k in obj.__annotations__["names"]
                }
            if isinstance(obj, (list, tuple)):
                return type(obj)(_convert_value(v) for v in obj)
            if isinstance(obj, dict):
                return {k: _convert_value(v) for k, v in obj.items()}
            return obj

        return _convert_value(value)


class TrinoArray(TrinoTypesMixin, TypeDecorator):
    impl = ARRAY
    cache_ok = True

    @property
    def python_type(self):
        return list


class TrinoMap(TrinoTypesMixin, TypeDecorator):
    impl = String
    cache_ok = True

    @property
    def python_type(self):
        return dict


class TrinoStruct(TrinoTypesMixin, TypeDecorator):
    impl = String
    cache_ok = True

    @property
    def python_type(self):
        return dict
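A minimal sketch of the mixin in action — assumptions: `NamedRowTuple(values, names, types)` is the constructor shipped by trino-python-client, and since `process_result_value` never touches `dialect`, `None` is passed for brevity:

```python
from trino.types import NamedRowTuple

row = NamedRowTuple(
    ["txn_001", [NamedRowTuple(["item_001"], ["itemId"], ["varchar"])]],
    ["id", "items"],
    ["varchar", "array(row)"],
)
print(TrinoStruct().process_result_value(row, dialect=None))
# -> {'id': 'txn_001', 'items': [{'itemId': 'item_001'}]}
```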

View File

@@ -1,5 +1,6 @@
 import os.path
 import random
+from pathlib import Path
 from time import sleep
 
 import docker
@@ -7,7 +8,7 @@ import pandas as pd
 import pytest
 import testcontainers.core.network
 from sqlalchemy import create_engine, insert
-from sqlalchemy.engine import make_url
+from sqlalchemy.engine import Engine, make_url
 from tenacity import retry, stop_after_delay, wait_fixed
 from testcontainers.core.container import DockerContainer
 from testcontainers.core.generic import DbContainer
@@ -192,28 +193,51 @@ def create_test_data(trino_container):
     )
     data_dir = os.path.dirname(__file__) + "/data"
     for file in os.listdir(data_dir):
-        df = pd.read_parquet(f"{data_dir}/{file}")
-        for col in df.columns:
-            if pd.api.types.is_datetime64tz_dtype(df[col]):
-                df[col] = df[col].dt.tz_convert(None)
-        df.to_sql(
-            file.replace(".parquet", ""),
-            engine,
-            schema="my_schema",
-            if_exists="fail",
-            index=False,
-            method=custom_insert,
-        )
+        file_path = Path(os.path.join(data_dir, file))
+        if file_path.suffix == ".sql":
+            # Creating test data with complex fields with pandas breaks
+            create_test_data_from_sql(engine, file_path)
+        else:
+            create_test_data_from_parquet(engine, file_path)
         sleep(1)
-        engine.execute(
-            "ANALYZE " + f'minio."my_schema"."{file.replace(".parquet", "")}"'
-        )
+        engine.execute("ANALYZE " + f'minio."my_schema"."{file_path.stem}"')
     engine.execute(
         "CALL system.drop_stats(schema_name => 'my_schema', table_name => 'empty')"
     )
     return
 
 
+def create_test_data_from_parquet(engine: Engine, file_path: Path):
+    df = pd.read_parquet(file_path)
+    # Convert data types
+    for col in df.columns:
+        if pd.api.types.is_datetime64tz_dtype(df[col]):
+            df[col] = df[col].dt.tz_convert(None)
+    df.to_sql(
+        Path(file_path).stem,
+        engine,
+        schema="my_schema",
+        if_exists="fail",
+        index=False,
+        method=custom_insert,
+    )
+
+
+def create_test_data_from_sql(engine: Engine, file_path: Path):
+    with open(file_path, "r") as f:
+        sql = f.read()
+    sql = sql.format(catalog="minio", schema="my_schema", table_name=file_path.stem)
+    for statement in sql.split(";"):
+        if statement.strip() == "":
+            continue
+        engine.execute(statement)
+
+
 def custom_insert(self, conn, keys: list[str], data_iter):
     """
     Hack pandas.io.sql.SQLTable._execute_insert_multi to retry until rows are inserted.

View File

@@ -0,0 +1,67 @@
CREATE TABLE IF NOT EXISTS {catalog}.{schema}.{table_name} (
    promotiontransactionid BIGINT,
    validto TIMESTAMP,
    vouchercode VARCHAR,
    payload ROW(
        id VARCHAR,
        amount DOUBLE,
        currency VARCHAR,
        metadata MAP(VARCHAR, VARCHAR),
        items ARRAY(ROW(
            itemId VARCHAR,
            quantity INTEGER,
            price DOUBLE
        ))
    ),
    adjustmenthistory ARRAY(ROW(
        adjustmentId VARCHAR,
        timestamp TIMESTAMP,
        amount DOUBLE,
        reason VARCHAR
    )),
    simplemap MAP(VARCHAR, VARCHAR),
    simplearray ARRAY(VARCHAR)
);

INSERT INTO {catalog}.{schema}.{table_name} VALUES
(
    1001,
    TIMESTAMP '2024-12-31 23:59:59',
    'PROMO2024',
    ROW(
        'txn_001',
        99.99,
        'USD',
        MAP(ARRAY['store', 'region'], ARRAY['Store123', 'US-West']),
        ARRAY[
            ROW('item_001', 2, 29.99),
            ROW('item_002', 1, 40.01)
        ]
    ),
    ARRAY[
        ROW('adj_001', TIMESTAMP '2024-01-15 10:30:00', -5.00, 'Discount'),
        ROW('adj_002', TIMESTAMP '2024-01-15 11:00:00', 2.50, 'Tax adjustment')
    ],
    MAP(ARRAY['status', 'type'], ARRAY['active', 'promotion']),
    ARRAY['tag1', 'tag2', 'tag3']
),
(
    1002,
    TIMESTAMP '2024-06-30 23:59:59',
    'SUMMER2024',
    ROW(
        'txn_002',
        150.75,
        'EUR',
        MAP(ARRAY['campaign', 'channel'], ARRAY['Summer Sale', 'Online']),
        ARRAY[
            ROW('item_003', 3, 45.25),
            ROW('item_004', 1, 15.00)
        ]
    ),
    ARRAY[
        ROW('adj_003', TIMESTAMP '2024-02-20 14:15:00', -10.00, 'Coupon')
    ],
    MAP(ARRAY['status', 'priority'], ARRAY['completed', 'high']),
    ARRAY['summer', 'sale', 'online']
)

View File

@@ -0,0 +1,26 @@
CREATE TABLE IF NOT EXISTS {catalog}.{schema}.{table_name} (
    payload ROW(
        foobar VARCHAR,
        foobaz DOUBLE,
        foos ARRAY(ROW(bars VARCHAR))
    ),
    foobars ARRAY(VARCHAR)
);

INSERT INTO {catalog}.{schema}.{table_name} VALUES
(
    ROW(
        'test_value',
        123.45,
        ARRAY[ROW('bar1'), ROW('bar2')]
    ),
    ARRAY['foo1', 'foo2', 'foo3']
),
(
    ROW(
        'another_value',
        678.90,
        ARRAY[ROW('bar3')]
    ),
    ARRAY['foo4']
);
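For context, a hypothetical spot-check against these fixtures — it assumes the SQLAlchemy `engine` built in the conftest above; querying a ROW column returns a `trino.types.NamedRowTuple`, which is exactly the value the new `TypeDecorator`s normalize before sampling hands rows to Pydantic:

```python
rows = engine.execute(
    'SELECT payload FROM minio."my_schema"."only_complex" LIMIT 1'
).fetchall()
print(type(rows[0][0]))  # trino.types.NamedRowTuple without the decorators
```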

View File

@@ -0,0 +1,63 @@
from copy import deepcopy

import pytest

from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.metadataIngestion.databaseServiceAutoClassificationPipeline import (
    DatabaseServiceAutoClassificationPipeline,
)
from metadata.ingestion.lineage.sql_lineage import search_cache
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.workflow.classification import AutoClassificationWorkflow
from metadata.workflow.metadata import MetadataWorkflow


@pytest.fixture(scope="module")
def sampling_only_classifier_config(
    db_service, sink_config, workflow_config, classifier_config
):
    config = deepcopy(classifier_config)
    config["source"]["sourceConfig"]["config"]["enableAutoClassification"] = False
    return config


@pytest.fixture(
    scope="module",
)
def run_classifier(
    patch_passwords_for_db_services,
    run_workflow,
    ingestion_config,
    sampling_only_classifier_config,
    create_test_data,
    request,
):
    search_cache.clear()
    run_workflow(MetadataWorkflow, ingestion_config)
    run_workflow(AutoClassificationWorkflow, sampling_only_classifier_config)
    return ingestion_config


@pytest.mark.parametrize(
    "table_name",
    (
        "{database_service}.minio.my_schema.table",
        "{database_service}.minio.my_schema.titanic",
        "{database_service}.minio.my_schema.iris",
        "{database_service}.minio.my_schema.userdata",
        "{database_service}.minio.my_schema.empty",
        "{database_service}.minio.my_schema.complex_and_simple",
        "{database_service}.minio.my_schema.only_complex",
    ),
)
def test_auto_classification_workflow(
    run_classifier,
    metadata: OpenMetadata,
    table_name: str,
    db_service: DatabaseServiceAutoClassificationPipeline,
):
    table = metadata.get_by_name(
        Table, table_name.format(database_service=db_service.fullyQualifiedName.root)
    )
    assert metadata.get_sample_data(table) is not None

View File

@@ -18,6 +18,8 @@ def run_workflow(run_workflow, ingestion_config, create_test_data):
        "{database_service}.minio.my_schema.iris",
        "{database_service}.minio.my_schema.userdata",
        "{database_service}.minio.my_schema.empty",
        "{database_service}.minio.my_schema.complex_and_simple",
        "{database_service}.minio.my_schema.only_complex",
    ],
    ids=lambda x: x.split(".")[-1],
)

View File

@@ -132,6 +132,32 @@ class ProfilerTestParameters:
            ],
            lambda x: x.useStatistics == True,
        ),
        ProfilerTestParameters(
            "{database_service}.minio.my_schema.complex_and_simple",  # complex types ignored
            TableProfile(timestamp=Timestamp(0), rowCount=2),
            [
                ColumnProfile(
                    name="promotiontransactionid",
                    timestamp=Timestamp(0),
                    valuesCount=2,
                    nullCount=0,
                ),
                ColumnProfile(
                    name="validto", timestamp=Timestamp(0), valuesCount=2, nullCount=0
                ),
                ColumnProfile(
                    name="vouchercode",
                    timestamp=Timestamp(0),
                    valuesCount=2,
                    nullCount=0,
                ),
            ],
        ),
        ProfilerTestParameters(
            "{database_service}.minio.my_schema.only_complex",  # complex types ignored
            TableProfile(timestamp=Timestamp(0), rowCount=2),
            [],
        ),
    ],
    ids=lambda x: x.table_fqn.split(".")[-1],
)
@@ -144,7 +170,7 @@ def test_profiler(
         )
     ):
         pytest.skip(
-            "Skipping test becuase its not supported for this profiler configuation"
+            "Skipping test because it's not supported for this profiler configuration"
         )
     table: Table = metadata.get_latest_table_profile(
         parameters.table_fqn.format(database_service=db_service.fullyQualifiedName.root)