Fix #5153: Add referred columns for foreign key constraint and sort_key, dist_key support (#10433)

Co-authored-by: ulixius9 <mayursingal9@gmail.com>
Sriharsha Chintalapani 2023-03-15 06:25:51 -07:00 committed by GitHub
parent bf0d26922e
commit b33587041d
10 changed files with 347 additions and 71 deletions


@@ -11,15 +11,27 @@
"""
Table related pydantic definitions
"""
from typing import Optional
from typing import Dict, List, Optional
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.table import Table, TableConstraint
class DeleteTable(BaseModel):
"""Entity Reference of a table to be deleted"""
"""
Entity Reference of a table to be deleted
"""
table: Table
mark_deleted_tables: Optional[bool] = False
class OMetaTableConstraints(BaseModel):
"""
Model to pair a table with its constraints
"""
table_id: str
foreign_constraints: Optional[List[Dict]]
constraints: Optional[List[TableConstraint]]
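A minimal usage sketch of this model (hypothetical values; the foreign-constraint dicts mirror the shape SQLAlchemy's get_foreign_keys() returns, which is what _get_foreign_constraints consumes later in this diff):

from metadata.ingestion.models.table_metadata import OMetaTableConstraints

# Sketch only: pair a table id with its raw foreign key constraints.
record = OMetaTableConstraints(
    table_id="069a9d82-6dd0-4e71-b9b6-bf64d7f82a51",  # placeholder UUID
    foreign_constraints=[
        {
            "constrained_columns": ["order_id"],
            "referred_schema": "public",  # placeholder names
            "referred_table": "orders",
            "referred_columns": ["id"],
        }
    ],
    constraints=None,  # table-level constraints (PK, UNIQUE, ...) go here
)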


@@ -19,7 +19,7 @@ from typing import Dict, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.table import Table, TableConstraint
from metadata.generated.schema.type import basic
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.generated.schema.type.tagLabel import LabelType, State, TagSource
@@ -46,6 +46,8 @@ REMOVE = "remove"
# OM specific description handling
ENTITY_DESCRIPTION = "/description"
COL_DESCRIPTION = "/columns/{index}/description"
TABLE_CONSTRAINTS = "/tableConstraints"
ENTITY_TAG = "/tags/{tag_index}"
COL_TAG = "/columns/{index}/tags/{tag_index}"
@@ -65,7 +67,7 @@ class OMetaPatchMixin(Generic[T]):
client: REST
def _validate_instance_description(
def _fetch_entity_if_exists(
self, entity: Type[T], entity_id: Union[str, basic.Uuid]
) -> Optional[T]:
"""
@@ -111,9 +113,7 @@ class OMetaPatchMixin(Generic[T]):
Returns
Updated Entity
"""
instance = self._validate_instance_description(
entity=entity, entity_id=entity_id
)
instance = self._fetch_entity_if_exists(entity=entity, entity_id=entity_id)
if not instance:
return None
@@ -165,7 +165,7 @@ class OMetaPatchMixin(Generic[T]):
Returns
Updated Entity
"""
table: Table = self._validate_instance_description(
table: Table = self._fetch_entity_if_exists(
entity=Table,
entity_id=entity_id,
)
@@ -213,6 +213,61 @@
return None
def patch_table_constraints(
self,
entity_id: Union[str, basic.Uuid],
table_constraints: List[TableConstraint],
) -> Optional[T]:
"""Given an Entity ID, JSON PATCH the table constraints of table
Args
entity_id: ID
description: new description to add
table_constraints: table constraints to add
Returns
Updated Entity
"""
table: Table = self._fetch_entity_if_exists(
entity=Table,
entity_id=entity_id,
)
if not table:
return None
try:
res = self.client.patch(
path=f"{self.get_suffix(Table)}/{model_str(entity_id)}",
data=json.dumps(
[
{
OPERATION: ADD if not table.tableConstraints else REPLACE,
PATH: TABLE_CONSTRAINTS,
VALUE: [
{
"constraintType": constraint.constraintType.value,
"columns": constraint.columns,
"referredColumns": [
col.__root__
for col in constraint.referredColumns or []
],
}
for constraint in table_constraints
],
}
]
),
)
return Table(**res)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error trying to PATCH description for Table Constraint: {entity_id}: {exc}"
)
return None
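A hedged usage sketch of the new method, assuming an OpenMetadata client instance named metadata (the UUID and FQN are placeholders). The call issues a single JSON Patch operation against /tableConstraints:

from metadata.generated.schema.entity.data.table import (
    ConstraintType,
    TableConstraint,
)

# Sketch only: patch a FOREIGN_KEY constraint onto an existing table.
patched = metadata.patch_table_constraints(
    entity_id="069a9d82-6dd0-4e71-b9b6-bf64d7f82a51",  # placeholder UUID
    table_constraints=[
        TableConstraint(
            constraintType=ConstraintType.FOREIGN_KEY,
            columns=["order_id"],
            referredColumns=["svc.db.public.orders.id"],  # placeholder FQN
        )
    ],
)
# Approximate request body ("replace" instead of "add" if the table
# already has constraints):
# [{"op": "add", "path": "/tableConstraints",
#   "value": [{"constraintType": "FOREIGN_KEY",
#              "columns": ["order_id"],
#              "referredColumns": ["svc.db.public.orders.id"]}]}]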
def patch_tag(
self,
entity: Type[T],
@@ -233,9 +288,7 @@ class OMetaPatchMixin(Generic[T]):
Returns
Updated Entity
"""
instance = self._validate_instance_description(
entity=entity, entity_id=entity_id
)
instance = self._fetch_entity_if_exists(entity=entity, entity_id=entity_id)
if not instance:
return None
@@ -304,9 +357,7 @@ class OMetaPatchMixin(Generic[T]):
Returns
Updated Entity
"""
table: Table = self._validate_instance_description(
entity=Table, entity_id=entity_id
)
table: Table = self._fetch_entity_if_exists(entity=Table, entity_id=entity_id)
if not table:
return None
@@ -389,9 +440,7 @@ class OMetaPatchMixin(Generic[T]):
Returns
Updated Entity
"""
instance = self._validate_instance_description(
entity=entity, entity_id=entity_id
)
instance = self._fetch_entity_if_exists(entity=entity, entity_id=entity_id)
if not instance:
return None


@@ -38,7 +38,7 @@ from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.ometa_topic_data import OMetaTopicSampleData
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.models.profile_data import OMetaTableProfileSampleData
from metadata.ingestion.models.table_metadata import DeleteTable
from metadata.ingestion.models.table_metadata import DeleteTable, OMetaTableConstraints
from metadata.ingestion.models.tests_data import (
OMetaTestCaseResultsSample,
OMetaTestCaseSample,
@@ -102,6 +102,7 @@ class MetadataRestSink(Sink[Entity]):
self.write_record.register(DataModelLink, self.write_datamodel)
self.write_record.register(TableLocationLink, self.write_table_location_link)
self.write_record.register(DashboardUsage, self.write_dashboard_usage)
self.write_record.register(OMetaTableConstraints, self.write_table_constraints)
self.write_record.register(
OMetaTableProfileSampleData, self.write_profile_sample_data
)
@@ -461,6 +462,24 @@ class MetadataRestSink(Sink[Entity]):
f"Unexpected error while ingesting sample data for topic [{record.topic.name.__root__}]: {exc}"
)
def write_table_constraints(self, record: OMetaTableConstraints):
"""
Patch table constraints
"""
try:
self.metadata.patch_table_constraints(
record.table_id,
record.constraints,
)
logger.debug(
f"Successfully ingested table constraints for table id {record.table_id}"
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(
f"Unexpected error while ingesting table constraints for table id [{record.table_id}]: {exc}"
)
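The write_record.register(...) calls above follow a single-dispatch pattern: the sink routes each record to a handler based on its type, so a source only needs to emit an OMetaTableConstraints record to trigger the PATCH. A standalone sketch of that pattern (not the sink's actual implementation):

from functools import singledispatch

from metadata.ingestion.models.table_metadata import OMetaTableConstraints

@singledispatch
def write_record(record):
    raise NotImplementedError(f"no handler registered for {type(record)}")

def write_table_constraints(record):
    print(f"patching constraints for table id {record.table_id}")

# Same register(type, handler) call shape as in the sink above:
write_record.register(OMetaTableConstraints, write_table_constraints)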
def get_status(self):
return self.status


@@ -14,7 +14,7 @@ Generic source to build SQL connectors.
import traceback
from abc import ABC
from copy import deepcopy
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
from pydantic import BaseModel
from sqlalchemy.engine import Connection
@@ -28,7 +28,13 @@ from metadata.generated.schema.api.data.createDatabaseSchema import (
)
from metadata.generated.schema.api.data.createTable import CreateTableRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.table import Table, TablePartition, TableType
from metadata.generated.schema.entity.data.table import (
ConstraintType,
Table,
TableConstraint,
TablePartition,
TableType,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
@@ -41,10 +47,12 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper
from metadata.ingestion.lineage.parser import LineageParser
from metadata.ingestion.lineage.sql_lineage import (
get_column_fqn,
get_lineage_by_query,
get_lineage_via_table_entity,
)
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.table_metadata import OMetaTableConstraints
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection, get_test_connection_fn
from metadata.ingestion.source.database.database_service import (
@@ -104,6 +112,7 @@ class CommonDbSourceService(
self.table_constraints = None
self.database_source_state = set()
self.context.table_views = []
self.context.table_constrains = []
super().__init__()
def set_inspector(self, database_name: str) -> None:
@@ -350,7 +359,11 @@
db_name = self.context.database.name.__root__
try:
columns, table_constraints = self.get_columns_and_constraints(
(
columns,
table_constraints,
foreign_columns,
) = self.get_columns_and_constraints(
schema_name=schema_name,
table_name=table_name,
db_name=db_name,
@@ -374,7 +387,6 @@
),
columns=columns,
viewDefinition=view_definition,
tableConstraints=table_constraints if table_constraints else None,
databaseSchema=self.context.database_schema.fullyQualifiedName,
tags=self.get_tag_labels(
table_name=table_name
@@ -402,6 +414,15 @@
yield table_request
self.register_record(table_request=table_request)
if table_constraints or foreign_columns:
self.context.table_constrains.append(
OMetaTableConstraints(
foreign_constraints=foreign_columns,
constraints=table_constraints,
table_id=str(self.context.table.id.__root__),
)
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Unexpected exception to yield table [{table_name}]: {exc}")
@@ -459,6 +480,53 @@
f"Could not parse query [{view_definition}] ingesting lineage failed: {exc}"
)
def _get_foreign_constraints(
self, table_constraints: OMetaTableConstraints
) -> List[TableConstraint]:
"""
For each foreign key constraint, look up the referred table
and resolve the referred columns to their FQNs
"""
foreign_constraints = []
for constraint in table_constraints.foreign_constraints:
referred_column_fqns = []
referred_table = fqn.search_table_from_es(
metadata=self.metadata,
table_name=constraint.get("referred_table"),
schema_name=constraint.get("referred_schema"),
database_name=None,
service_name=self.context.database_service.name.__root__,
)
if referred_table:
for column in constraint.get("referred_columns"):
col_fqn = get_column_fqn(table_entity=referred_table, column=column)
if col_fqn:
referred_column_fqns.append(col_fqn)
foreign_constraints.append(
TableConstraint(
constraintType=ConstraintType.FOREIGN_KEY,
columns=constraint.get("constrained_columns"),
referredColumns=referred_column_fqns,
)
)
return foreign_constraints
def yield_table_constraints(self) -> Optional[Iterable[OMetaTableConstraints]]:
"""
From topology.
Process the table constraints of all tables
"""
for table_constraints in self.context.table_constrains:
foreign_constraints = self._get_foreign_constraints(table_constraints)
if foreign_constraints:
if table_constraints.constraints:
table_constraints.constraints.extend(foreign_constraints)
else:
table_constraints.constraints = foreign_constraints
yield table_constraints
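A hypothetical walkthrough of this post-processing step, showing how one collected record would be enriched (names and FQNs are placeholders):

# Input collected during yield_table:
#   constraints         = [TableConstraint(PRIMARY_KEY, columns=["id"])]
#   foreign_constraints = [{"constrained_columns": ["order_id"],
#                           "referred_schema": "public",
#                           "referred_table": "orders",
#                           "referred_columns": ["id"]}]
#
# After _get_foreign_constraints resolves "orders" through ElasticSearch,
# the yielded record.constraints would contain both entries:
#   [TableConstraint(PRIMARY_KEY, columns=["id"]),
#    TableConstraint(FOREIGN_KEY, columns=["order_id"],
#                    referredColumns=["svc.db.public.orders.id"])]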
def test_connection(self) -> None:
"""
Uses a time-bound function to test that the engine


@@ -59,7 +59,7 @@ from metadata.generated.schema.type.tagLabel import (
from metadata.ingestion.api.source import Source, SourceStatus
from metadata.ingestion.api.topology_runner import TopologyRunnerMixin
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.table_metadata import DeleteTable
from metadata.ingestion.models.table_metadata import DeleteTable, OMetaTableConstraints
from metadata.ingestion.models.topology import (
NodeStage,
ServiceTopology,
@@ -119,9 +119,7 @@ class DatabaseServiceTopology(ServiceTopology):
),
],
children=["database"],
post_process=[
"yield_view_lineage",
],
post_process=["yield_view_lineage", "yield_table_constraints"],
)
database = TopologyNode(
producer="get_database_names",
@@ -309,6 +307,14 @@
Parses view definition to get lineage information
"""
def yield_table_constraints(self) -> Optional[Iterable[OMetaTableConstraints]]:
"""
From topology.
Process the table constraints of all tables.
By default there is no need to process table constraints,
especially for non-SQLAlchemy (SQA) sources
"""
@abstractmethod
def yield_table(
self, table_name_and_type: Tuple[str, TableType]
@@ -327,8 +333,6 @@
"""
From topology.
Prepare a location request and pass it to the sink.
Also, update the self.inspector value to the current db.
"""
return


@@ -33,7 +33,9 @@ from sqlalchemy_redshift.dialect import (
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import (
ConstraintType,
IntervalType,
TableConstraint,
TablePartition,
TableType,
)
@@ -73,6 +75,40 @@ logger = ingestion_logger()
ischema_names = pg_ischema_names
ischema_names.update({"binary varying": sqltypes.VARBINARY})
# pylint: disable=protected-access
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
Return information about columns in `table_name`.
Overrides :meth:`~sqlalchemy.engine.interfaces.Dialect.get_columns`
to also include the distkey and sortkey info
"""
cols = self._get_redshift_columns(connection, table_name, schema, **kw)
if not self._domains:
self._domains = self._load_domains(connection)
domains = self._domains
columns = []
for col in cols:
column_info = self._get_column_info(
name=col.name,
format_type=col.format_type,
default=col.default,
notnull=col.notnull,
domains=domains,
enums=[],
schema=col.schema,
encode=col.encode,
comment=col.comment,
)
column_info["distkey"] = col.distkey
column_info["sortkey"] = col.sortkey
columns.append(column_info)
return columns
def _get_column_info(self, *args, **kwargs):
"""
@@ -118,6 +154,9 @@ def _get_schema_column_info(
schema:
**kw:
Returns:
This method is responsible for fetching all the column details,
such as name, type, constraints, distkey and sortkey.
"""
schema_clause = f"AND schema = '{schema if schema else ''}'"
all_columns = defaultdict(list)
@@ -137,6 +176,7 @@ RedshiftDialectMixin._get_column_info = ( # pylint: disable=protected-access
RedshiftDialectMixin._get_schema_column_info = ( # pylint: disable=protected-access
_get_schema_column_info
)
RedshiftDialectMixin.get_columns = get_columns
def _handle_array_type(attype):
@@ -483,3 +523,26 @@ class RedshiftSource(CommonDbSourceService):
)
return True, partition_details
return False, None
def process_additional_table_constraints(
self, column: dict, table_constraints: List[TableConstraint]
) -> None:
"""
Process DIST_KEY & SORT_KEY column properties
"""
if column.get("distkey"):
table_constraints.append(
TableConstraint(
constraintType=ConstraintType.DIST_KEY,
columns=[column.get("name")],
)
)
if column.get("sortkey"):
table_constraints.append(
TableConstraint(
constraintType=ConstraintType.SORT_KEY,
columns=[column.get("name")],
)
)
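For illustration, a hypothetical column dict produced by the patched get_columns above would drive this hook as follows (the values are made up, and source stands for any RedshiftSource instance):

from typing import List

from metadata.generated.schema.entity.data.table import TableConstraint

# Hypothetical column dict with the Redshift-specific keys this hook reads:
column = {"name": "user_id", "type": "INTEGER", "distkey": True, "sortkey": 1}

table_constraints: List[TableConstraint] = []
source.process_additional_table_constraints(
    column=column, table_constraints=table_constraints
)
# table_constraints now holds a DIST_KEY and a SORT_KEY constraint,
# each scoped to ["user_id"].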


@@ -13,7 +13,7 @@ Generic call to handle table columns for sql connectors.
"""
import re
import traceback
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from sqlalchemy.engine.reflection import Inspector
@@ -43,6 +43,13 @@ class SqlColumnHandlerMixin:
logger.info("Fetching tags not implemented for this connector")
self.source_config.includeTags = False
def process_additional_table_constraints(
self, column: dict, table_constraints: List[TableConstraint]
) -> None:
"""
By default, there are no additional table constraints
"""
def _get_display_datatype(
self,
data_type_display: str,
@@ -100,7 +107,7 @@
@staticmethod
def _get_columns_with_constraints(
schema_name: str, table_name: str, inspector: Inspector
) -> Tuple[List, List]:
) -> Tuple[List, List, List]:
pk_constraints = inspector.get_pk_constraint(table_name, schema_name)
try:
unique_constraints = inspector.get_unique_constraints(
@@ -130,24 +137,35 @@
if len(foreign_constraint) > 0 and foreign_constraint.get(
"constrained_columns"
):
foreign_columns.extend(foreign_constraint.get("constrained_columns"))
foreign_constraint.update(
{
"constrained_columns": [
clean_up_starting_ending_double_quotes_in_string(column)
for column in foreign_constraint.get("constrained_columns")
],
"referred_columns": [
clean_up_starting_ending_double_quotes_in_string(column)
for column in foreign_constraint.get("referred_columns")
],
}
)
foreign_columns.append(foreign_constraint)
unique_columns = []
for constraint in unique_constraints:
if constraint.get("column_names"):
unique_columns.extend(constraint.get("column_names"))
unique_columns.append(
[
clean_up_starting_ending_double_quotes_in_string(column)
for column in constraint.get("column_names")
]
)
pk_columns = [
clean_up_starting_ending_double_quotes_in_string(pk_column)
for pk_column in pk_columns
]
unique_columns = [
clean_up_starting_ending_double_quotes_in_string(unique_column)
for unique_column in unique_columns
]
foreign_columns = [
clean_up_starting_ending_double_quotes_in_string(foreign_column)
for foreign_column in foreign_columns
]
return pk_columns, unique_columns, foreign_columns
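To make the new return shape concrete (hypothetical values): primary key columns remain a flat list, unique constraints are now lists of column lists, and foreign keys are passed through as full reflection dicts instead of flattened column names:

# Hypothetical return value of _get_columns_with_constraints:
pk_columns = ["id"]
unique_columns = [["email"], ["tenant_id", "slug"]]  # one list per constraint
foreign_columns = [
    {
        "constrained_columns": ["order_id"],
        "referred_schema": "public",
        "referred_table": "orders",
        "referred_columns": ["id"],
    }
]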
def _process_complex_col_type(self, parsed_string: dict, column: dict) -> Column:
@@ -176,25 +194,43 @@
def get_columns_and_constraints( # pylint: disable=too-many-locals
self, schema_name: str, table_name: str, db_name: str, inspector: Inspector
) -> Tuple[Optional[List[Column]], Optional[List[TableConstraint]]]:
) -> Tuple[
Optional[List[Column]], Optional[List[TableConstraint]], Optional[List[Dict]]
]:
"""
Get column types and constraints information
"""
table_constraints = []
# Get inspector information:
(
pk_columns,
unique_columns,
foreign_columns,
) = self._get_columns_with_constraints(schema_name, table_name, inspector)
table_columns = []
table_constraints = []
if foreign_columns:
column_level_unique_constraints = set()
for col in unique_columns:
if len(col) == 1:
column_level_unique_constraints.add(col[0])
else:
table_constraints.append(
TableConstraint(
constraintType=ConstraintType.UNIQUE,
columns=col,
)
)
if len(pk_columns) > 1:
table_constraints.append(
TableConstraint(
constraintType=ConstraintType.FOREIGN_KEY,
columns=foreign_columns,
constraintType=ConstraintType.PRIMARY_KEY,
columns=pk_columns,
)
)
table_columns = []
columns = inspector.get_columns(table_name, schema_name, db_name=db_name)
for column in columns:
try:
@@ -204,18 +240,14 @@
arr_data_type,
parsed_string,
) = self._process_col_type(column, schema_name)
self.process_additional_table_constraints(
column=column, table_constraints=table_constraints
)
if parsed_string is None:
col_type = ColumnTypeParser.get_column_type(column["type"])
col_constraint = self._get_column_constraints(
column, pk_columns, unique_columns
column, pk_columns, column_level_unique_constraints
)
if not col_constraint and len(pk_columns) > 1:
table_constraints.append(
TableConstraint(
constraintType=ConstraintType.PRIMARY_KEY,
columns=[column["name"]],
)
)
col_data_length = self._check_col_length(col_type, column["type"])
precision = ColumnTypeParser.check_col_precision(
col_type, column["type"]
@@ -266,7 +298,7 @@
)
continue
table_columns.append(om_column)
return table_columns, table_constraints
return table_columns, table_constraints, foreign_columns
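A worked example of the constraint handling above (hypothetical inputs): single-column unique constraints become column-level flags, while multi-column unique constraints and composite primary keys become table-level constraints:

# Hypothetical inputs:
pk_columns = ["tenant_id", "user_id"]
unique_columns = [["email"], ["tenant_id", "slug"]]

# Resulting table-level constraints built above:
#   TableConstraint(UNIQUE, columns=["tenant_id", "slug"])
#   TableConstraint(PRIMARY_KEY, columns=["tenant_id", "user_id"])
# "email" goes into column_level_unique_constraints and is applied as a
# column-level constraint instead.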
@staticmethod
def _check_col_length(datatype: str, col_raw_type: object):


@@ -143,22 +143,19 @@ def _(
:param table_name: Table name
:return:
"""
fqn_search_string = build_es_fqn_search_string(
database_name, schema_name, service_name, table_name
)
es_result = (
metadata.es_search_from_fqn(
entity_type=Table,
fqn_search_string=fqn_search_string,
entity: Optional[Union[Table, List[Table]]] = None
if not skip_es_search:
entity = search_table_from_es(
metadata=metadata,
database_name=database_name,
schema_name=schema_name,
table_name=table_name,
fetch_multiple_entities=fetch_multiple_entities,
service_name=service_name,
)
if not skip_es_search
else None
)
entity: Optional[Union[Table, List[Table]]] = get_entity_from_es_result(
entity_list=es_result, fetch_multiple_entities=fetch_multiple_entities
)
# if the entity is not found in ES, proceed to build the FQN with database_name and schema_name
if not entity and database_name and schema_name:
fqn = _build(service_name, database_name, schema_name, table_name)
@@ -476,3 +473,25 @@ def build_es_fqn_search_string(
service_name, database_name or "*", schema_name or "*", table_name
)
return fqn_search_string
def search_table_from_es(
metadata: OpenMetadata,
database_name: str,
schema_name: str,
service_name: str,
table_name: str,
fetch_multiple_entities: bool = False,
):
"""
Find a Table entity in ElasticSearch from its FQN parts
"""
fqn_search_string = build_es_fqn_search_string(
database_name, schema_name, service_name, table_name
)
es_result = metadata.es_search_from_fqn(
entity_type=Table,
fqn_search_string=fqn_search_string,
)
return get_entity_from_es_result(
entity_list=es_result, fetch_multiple_entities=fetch_multiple_entities
)
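A minimal usage sketch of the extracted helper (the client and names are placeholders; passing database_name=None wildcards that part of the search string):

# Sketch only: resolve a Table entity from ElasticSearch by FQN parts.
table = search_table_from_es(
    metadata=metadata,           # assumed OpenMetadata client instance
    database_name=None,          # becomes "*" in the search string
    schema_name="public",        # placeholder
    service_name="my_redshift",  # placeholder
    table_name="orders",         # placeholder
)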


@@ -235,7 +235,7 @@ class PostgresUnitTest(TestCase):
inspector.get_pk_constraint = lambda table_name, schema_name: []
inspector.get_unique_constraints = lambda table_name, schema_name: []
inspector.get_foreign_keys = lambda table_name, schema_name: []
result, _ = self.postgres_source.get_columns_and_constraints(
result, _, _ = self.postgres_source.get_columns_and_constraints(
"public", "user", "postgres", inspector
)
for i in range(len(EXPECTED_COLUMN_VALUE)):


@@ -150,7 +150,9 @@
"enum": [
"UNIQUE",
"PRIMARY_KEY",
"FOREIGN_KEY"
"FOREIGN_KEY",
"SORT_KEY",
"DIST_KEY"
]
},
"columns": {
@@ -159,6 +161,14 @@
"items": {
"type": "string"
}
},
"referredColumns": {
"description": "List of referred columns for the constraint.",
"type": "array",
"items": {
"$ref": "../../type/basic.json#/definitions/fullyQualifiedEntityName"
},
"default": null
}
},
"additionalProperties": false