# Mirror of https://github.com/open-metadata/OpenMetadata.git
# synced 2025-12-05 20:15:15 +00:00 (263 lines, 8.6 KiB, Python)
# Copyright 2021 Collate
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Postgres source module
|
|
"""
|
|
import traceback
|
|
from collections import namedtuple
|
|
from typing import Iterable, Tuple
|
|
|
|
from sqlalchemy import sql
|
|
from sqlalchemy.dialects.postgresql.base import PGDialect, ischema_names
|
|
from sqlalchemy.engine import reflection
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
from sqlalchemy.sql.sqltypes import String
|
|
|
|
from metadata.generated.schema.api.classification.createClassification import (
|
|
CreateClassificationRequest,
|
|
)
|
|
from metadata.generated.schema.api.classification.createTag import CreateTagRequest
|
|
from metadata.generated.schema.entity.data.database import Database
|
|
from metadata.generated.schema.entity.data.table import (
|
|
IntervalType,
|
|
TablePartition,
|
|
TableType,
|
|
)
|
|
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
|
|
PostgresConnection,
|
|
)
|
|
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
|
|
OpenMetadataConnection,
|
|
)
|
|
from metadata.generated.schema.metadataIngestion.workflow import (
|
|
Source as WorkflowSource,
|
|
)
|
|
from metadata.ingestion.api.source import InvalidSourceException
|
|
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
|
|
from metadata.ingestion.source.database.common_db_source import (
|
|
CommonDbSourceService,
|
|
TableNameAndType,
|
|
)
|
|
from metadata.ingestion.source.database.postgres.queries import (
|
|
POSTGRES_GET_ALL_TABLE_PG_POLICY,
|
|
POSTGRES_GET_DB_NAMES,
|
|
POSTGRES_GET_TABLE_NAMES,
|
|
POSTGRES_PARTITION_DETAILS,
|
|
POSTGRES_TABLE_COMMENTS,
|
|
POSTGRES_VIEW_DEFINITIONS,
|
|
)
|
|
from metadata.utils import fqn
|
|
from metadata.utils.filters import filter_by_database
|
|
from metadata.utils.logger import ingestion_logger
|
|
from metadata.utils.sqlalchemy_utils import (
|
|
get_all_table_comments,
|
|
get_all_view_definitions,
|
|
get_table_comment_wrapper,
|
|
get_view_definition_wrapper,
|
|
)
|
|
|
|
TableKey = namedtuple("TableKey", ["schema", "table_name"])
|
|
|
|
logger = ingestion_logger()
|
|
|
|
|
|
INTERVAL_TYPE_MAP = {
|
|
"list": IntervalType.COLUMN_VALUE.value,
|
|
"hash": IntervalType.COLUMN_VALUE.value,
|
|
"range": IntervalType.TIME_UNIT.value,
|
|
}
|
|
|
|
RELKIND_MAP = {
|
|
"r": TableType.Regular,
|
|
"p": TableType.Partitioned,
|
|
"f": TableType.Foreign,
|
|
}
|
|
|
|
|
|
class GEOMETRY(String):
|
|
"""The SQL GEOMETRY type."""
|
|
|
|
__visit_name__ = "GEOMETRY"
|
|
|
|
|
|
class POINT(String):
|
|
"""The SQL POINT type."""
|
|
|
|
__visit_name__ = "POINT"
|
|
|
|
|
|
class POLYGON(String):
|
|
"""The SQL GEOMETRY type."""
|
|
|
|
__visit_name__ = "POLYGON"
|
|
|
|
|
|
ischema_names.update({"geometry": GEOMETRY, "point": POINT, "polygon": POLYGON})
|
|
|
|
|
|
@reflection.cache
|
|
def get_table_comment(
|
|
self, connection, table_name, schema=None, **kw
|
|
): # pylint: disable=unused-argument
|
|
return get_table_comment_wrapper(
|
|
self,
|
|
connection,
|
|
table_name=table_name,
|
|
schema=schema,
|
|
query=POSTGRES_TABLE_COMMENTS,
|
|
)
|
|
|
|
|
|
PGDialect.get_all_table_comments = get_all_table_comments
|
|
PGDialect.get_table_comment = get_table_comment
|
|
|
|
|
|
@reflection.cache
|
|
def get_view_definition(
|
|
self, connection, table_name, schema=None, **kw
|
|
): # pylint: disable=unused-argument
|
|
return get_view_definition_wrapper(
|
|
self,
|
|
connection,
|
|
table_name=table_name,
|
|
schema=schema,
|
|
query=POSTGRES_VIEW_DEFINITIONS,
|
|
)
|
|
|
|
|
|
PGDialect.get_view_definition = get_view_definition
|
|
PGDialect.get_all_view_definitions = get_all_view_definitions
|
|
|
|
PGDialect.ischema_names = ischema_names
|
|
|
|
|
|
class PostgresSource(CommonDbSourceService):
|
|
"""
|
|
Implements the necessary methods to extract
|
|
Database metadata from Postgres Source
|
|
"""
|
|
|
|
@classmethod
|
|
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
|
|
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
|
|
connection: PostgresConnection = config.serviceConnection.__root__.config
|
|
if not isinstance(connection, PostgresConnection):
|
|
raise InvalidSourceException(
|
|
f"Expected PostgresConnection, but got {connection}"
|
|
)
|
|
return cls(config, metadata_config)
|
|
|
|
def query_table_names_and_types(
|
|
self, schema_name: str
|
|
) -> Iterable[TableNameAndType]:
|
|
"""
|
|
Overwrite the inspector implementation to handle partitioned
|
|
and foreign types
|
|
"""
|
|
result = self.connection.execute(
|
|
sql.text(POSTGRES_GET_TABLE_NAMES),
|
|
{"schema": schema_name},
|
|
)
|
|
|
|
return [
|
|
TableNameAndType(
|
|
name=name, type_=RELKIND_MAP.get(relkind, TableType.Regular)
|
|
)
|
|
for name, relkind in result
|
|
]
|
|
|
|
def get_database_names(self) -> Iterable[str]:
|
|
if not self.config.serviceConnection.__root__.config.ingestAllDatabases:
|
|
configured_db = self.config.serviceConnection.__root__.config.database
|
|
self.set_inspector(database_name=configured_db)
|
|
yield configured_db
|
|
else:
|
|
results = self.connection.execute(POSTGRES_GET_DB_NAMES)
|
|
for res in results:
|
|
row = list(res)
|
|
new_database = row[0]
|
|
database_fqn = fqn.build(
|
|
self.metadata,
|
|
entity_type=Database,
|
|
service_name=self.context.database_service.name.__root__,
|
|
database_name=new_database,
|
|
)
|
|
|
|
if filter_by_database(
|
|
self.source_config.databaseFilterPattern,
|
|
database_fqn
|
|
if self.source_config.useFqnForFiltering
|
|
else new_database,
|
|
):
|
|
self.status.filter(database_fqn, "Database Filtered Out")
|
|
continue
|
|
|
|
try:
|
|
self.set_inspector(database_name=new_database)
|
|
yield new_database
|
|
except Exception as exc:
|
|
logger.debug(traceback.format_exc())
|
|
logger.error(
|
|
f"Error trying to connect to database {new_database}: {exc}"
|
|
)
|
|
|
|
def get_table_partition_details(
|
|
self, table_name: str, schema_name: str, inspector: Inspector
|
|
) -> Tuple[bool, TablePartition]:
|
|
result = self.engine.execute(
|
|
POSTGRES_PARTITION_DETAILS.format(
|
|
table_name=table_name, schema_name=schema_name
|
|
)
|
|
).all()
|
|
if result:
|
|
partition_details = TablePartition(
|
|
intervalType=INTERVAL_TYPE_MAP.get(
|
|
result[0].partition_strategy, IntervalType.COLUMN_VALUE.value
|
|
),
|
|
columns=[row.column_name for row in result if row.column_name],
|
|
)
|
|
return True, partition_details
|
|
return False, None
|
|
|
|
def yield_tag(self, schema_name: str) -> Iterable[OMetaTagAndClassification]:
|
|
"""
|
|
Fetch Tags
|
|
"""
|
|
try:
|
|
result = self.engine.execute(
|
|
POSTGRES_GET_ALL_TABLE_PG_POLICY.format(
|
|
database_name=self.context.database.name.__root__,
|
|
schema_name=schema_name,
|
|
)
|
|
).all()
|
|
|
|
for res in result:
|
|
row = list(res)
|
|
fqn_elements = [name for name in row[2:] if name]
|
|
yield OMetaTagAndClassification(
|
|
fqn=fqn._build( # pylint: disable=protected-access
|
|
self.context.database_service.name.__root__, *fqn_elements
|
|
),
|
|
classification_request=CreateClassificationRequest(
|
|
name=self.service_connection.classificationName,
|
|
description="Postgres Tag Name",
|
|
),
|
|
tag_request=CreateTagRequest(
|
|
classification=self.service_connection.classificationName,
|
|
name=row[1],
|
|
description="Postgres Tag Value",
|
|
),
|
|
)
|
|
|
|
except Exception as exc:
|
|
logger.debug(traceback.format_exc())
|
|
logger.warning(f"Skipping Policy Tag: {exc}")
|