diff --git a/ingestion/src/metadata/ingestion/source/database/trino.py b/ingestion/src/metadata/ingestion/source/database/trino.py index 05a18ad14d8..4b78cde4837 100644 --- a/ingestion/src/metadata/ingestion/source/database/trino.py +++ b/ingestion/src/metadata/ingestion/source/database/trino.py @@ -10,15 +10,18 @@ # limitations under the License. import logging +import re import sys -from typing import Iterable +from textwrap import dedent +from typing import Any, Dict, Iterable, List, Optional, Tuple import click -from sqlalchemy import inspect -from sqlalchemy.engine.reflection import Inspector +from sqlalchemy import inspect, sql, util +from sqlalchemy.engine.base import Connection +from sqlalchemy.sql import sqltypes +from trino.sqlalchemy import datatype +from trino.sqlalchemy.dialect import TrinoDialect -from metadata.generated.schema.entity.data.database import Database -from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( TrinoConnection, ) @@ -31,8 +34,92 @@ from metadata.generated.schema.metadataIngestion.workflow import ( from metadata.ingestion.api.source import InvalidSourceException from metadata.ingestion.source.database.common_db_source import CommonDbSourceService from metadata.utils.logger import ingestion_logger +from metadata.utils.sql_queries import TRINO_GET_COLUMNS logger = ingestion_logger() +ROW_DATA_TYPE = "row" +ARRAY_DATA_TYPE = "array" + + +def get_type_name_and_opts(type_str: str) -> Tuple[str, Optional[str]]: + match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str) + if not match: + util.warn(f"Could not parse type name '{type_str}'") + return sqltypes.NULLTYPE + type_name = match.group("type") + type_opts = match.group("options") + return type_name, type_opts + + +def parse_array_data_type(type_str: str) -> str: + """ + This method is used to convert the complex array datatype to the format that is supported by
OpenMetadata + For Example: + If we have a row type as array(row(col1 bigint, col2 string)) + this method will return type as -> array<struct<col1:bigint,col2:string>> + """ + type_name, type_opts = get_type_name_and_opts(type_str) + final = type_name + "<" + if type_opts: + if type_opts.startswith(ROW_DATA_TYPE): + final += parse_row_data_type(type_opts) + elif type_opts.startswith(ARRAY_DATA_TYPE): + final += parse_array_data_type(type_opts) + else: + final += type_opts + return final + ">" + + +def parse_row_data_type(type_str: str) -> str: + """ + This method is used to convert the complex row datatype to the format that is supported by OpenMetadata + For Example: + If we have a row type as row(col1 bigint, col2 bigint, col3 row(col4 string, col5 bigint)) + this method will return type as -> struct<col1:bigint,col2:bigint,col3:struct<col4:string,col5:bigint>> + """ + type_name, type_opts = get_type_name_and_opts(type_str) + final = type_name.replace(ROW_DATA_TYPE, "struct") + "<" + if type_opts: + for i in datatype.aware_split(type_opts) or []: + attr_name, attr_type_str = datatype.aware_split( + i.strip(), delimiter=" ", maxsplit=1 + ) + if attr_type_str.startswith(ROW_DATA_TYPE): + final += attr_name + ":" + parse_row_data_type(attr_type_str) + "," + elif attr_type_str.startswith(ARRAY_DATA_TYPE): + final += attr_name + ":" + parse_array_data_type(attr_type_str) + "," + else: + final += attr_name + ":" + attr_type_str + "," + return final[:-1] + ">" + + +def _get_columns( + self, connection: Connection, table_name: str, schema: str = None, **kw +) -> List[Dict[str, Any]]: + schema = schema or self._get_default_schema_name(connection) + query = dedent(TRINO_GET_COLUMNS).strip() + res = connection.execute(sql.text(query), schema=schema, table=table_name) + columns = [] + + for record in res: + col_type = datatype.parse_sqltype(record.data_type) + column = dict( + name=record.column_name, + type=col_type, + nullable=record.is_nullable == "YES", + default=record.column_default, + ) + type_str = record.data_type.strip().lower() + type_name, type_opts =
get_type_name_and_opts(type_str) + if type_opts and type_name == ROW_DATA_TYPE: + column["raw_data_type"] = parse_row_data_type(type_str) + elif type_opts and type_name == ARRAY_DATA_TYPE: + column["raw_data_type"] = parse_array_data_type(type_str) + columns.append(column) + return columns + + +TrinoDialect._get_columns = _get_columns # pylint: disable=protected-access class TrinoSource(CommonDbSourceService): diff --git a/ingestion/src/metadata/utils/column_type_parser.py b/ingestion/src/metadata/utils/column_type_parser.py index 50e1f2ed5c0..4305a03bbba 100644 --- a/ingestion/src/metadata/utils/column_type_parser.py +++ b/ingestion/src/metadata/utils/column_type_parser.py @@ -144,6 +144,7 @@ class ColumnTypeParser: "SQL_VARIANT": "VARBINARY", "STRING": "STRING", "STRUCT": "STRUCT", + "ROW": "STRUCT", "TABLE": "BINARY", "TEXT": "TEXT", "TIME": "TIME", diff --git a/ingestion/src/metadata/utils/sql_queries.py b/ingestion/src/metadata/utils/sql_queries.py index 739cd2967f5..ad7be0c88c3 100644 --- a/ingestion/src/metadata/utils/sql_queries.py +++ b/ingestion/src/metadata/utils/sql_queries.py @@ -357,3 +357,16 @@ WHERE creation_time BETWEEN "{start_time}" AND "{end_time}" AND state = "DONE" AND IFNULL(statement_type, "NO") not in ("NO", "DROP_TABLE", "CREATE_TABLE") """ + + +TRINO_GET_COLUMNS = """ + SELECT + "column_name", + "data_type", + "column_default", + UPPER("is_nullable") AS "is_nullable" + FROM "information_schema"."columns" + WHERE "table_schema" = :schema + AND "table_name" = :table + ORDER BY "ordinal_position" ASC +""" diff --git a/ingestion/tests/unit/test_trino_complex_types.py b/ingestion/tests/unit/test_trino_complex_types.py new file mode 100644 index 00000000000..6eb29a1d3a5 --- /dev/null +++ b/ingestion/tests/unit/test_trino_complex_types.py @@ -0,0 +1,44 @@ +from unittest import TestCase + +from metadata.ingestion.source.database.trino import ( + parse_array_data_type, + parse_row_data_type, +) + +RAW_ARRAY_DATA_TYPES = [ + "array(string)", + 
"array(row(check_datatype array(string)))", +] + +EXPECTED_ARRAY_DATA_TYPES = [ + "array", + "array>>", +] + +RAW_ROW_DATA_TYPES = [ + "row(a int, b string)", + "row(a row(b array(string),c bigint))", + "row(a array(string))", + "row(bigquerytestdatatype51 array(row(bigquery_test_datatype_511 array(string))))", + "row(record_1 row(record_2 row(record_3 row(record_4 string))))", +] + +EXPECTED_ROW_DATA_TYPES = [ + "struct", + "struct,c:bigint>>", + "struct>", + "struct>>>", + "struct>>>", +] + + +class SouceConnectionTest(TestCase): + def test_array_datatype(self): + for i in range(len(RAW_ARRAY_DATA_TYPES)): + parsed_type = parse_array_data_type(RAW_ARRAY_DATA_TYPES[i]) + assert parsed_type == EXPECTED_ARRAY_DATA_TYPES[i] + + def test_row_datatype(self): + for i in range(len(RAW_ROW_DATA_TYPES)): + parsed_type = parse_row_data_type(RAW_ROW_DATA_TYPES[i]) + assert parsed_type == EXPECTED_ROW_DATA_TYPES[i]