diff --git a/ingestion/src/metadata/ingestion/source/postgres.py b/ingestion/src/metadata/ingestion/source/postgres.py index ee80dd6933f..aec14a21745 100644 --- a/ingestion/src/metadata/ingestion/source/postgres.py +++ b/ingestion/src/metadata/ingestion/source/postgres.py @@ -15,7 +15,7 @@ from collections import namedtuple -import pymysql # noqa: F401 +import psycopg2 # This import verifies that the dependencies are available. from metadata.generated.schema.entity.services.databaseService import ( @@ -45,6 +45,7 @@ class PostgresSourceConfig(SQLConnectionConfig): class PostgresSource(SQLSource): def __init__(self, config, metadata_config, ctx): super().__init__(config, metadata_config, ctx) + self.pgconn = self.engine.raw_connection() @classmethod def create(cls, config_dict, metadata_config_dict, ctx): @@ -54,3 +55,17 @@ class PostgresSource(SQLSource): def get_status(self) -> SourceStatus: return self.status + + def type_of_column_name(self, sa_type, table_name: str, column_name: str): + cur = self.pgconn.cursor() + schema_table = table_name.split(".") + cur.execute( + """select data_type, udt_name + from information_schema.columns + where table_schema = %s and table_name = %s and column_name = %s""", + (schema_table[0], schema_table[1], column_name), + ) + pgtype = cur.fetchone()[1] + if pgtype == "geometry" or pgtype == "geography": + return "GEOGRAPHY" + return sa_type diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index cd29fe033b0..726db4ffdf9 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -167,6 +167,9 @@ class SQLSource(Source): ): pass + def type_of_column_name(self, sa_type, table_name: str, column_name: str): + return sa_type + def standardize_schema_table_names( self, schema: str, table: str ) -> Tuple[str, str]: @@ -409,6 +412,12 @@ class SQLSource(Source): if col_data_length is None: col_data_length = 1 try: + if col_type == "NULL": + col_type = self.type_of_column_name( + col_type, + column_name=column["name"], + table_name=dataset_name, + ) if col_type == "NULL": col_type = "VARCHAR" logger.warning(