diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
index 1c22f4d064..e91e4d14a0 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
@@ -519,9 +519,20 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
         else:
             return f"{schema}.{entity}"
 
-    def get_foreign_key_metadata(self, dataset_urn, fk_dict, inspector):
+    def get_foreign_key_metadata(
+        self,
+        dataset_urn: str,
+        schema: str,
+        fk_dict: Dict[str, str],
+        inspector: Inspector,
+    ) -> ForeignKeyConstraint:
+        referred_schema: Optional[str] = fk_dict.get("referred_schema")
+
+        if not referred_schema:
+            referred_schema = schema
+
         referred_dataset_name = self.get_identifier(
-            schema=fk_dict["referred_schema"],
+            schema=referred_schema,
             entity=fk_dict["referred_table"],
             inspector=inspector,
         )
@@ -628,7 +639,9 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
             pk_constraints: dict = inspector.get_pk_constraint(table, schema)
             try:
                 foreign_keys = [
-                    self.get_foreign_key_metadata(dataset_urn, fk_rec, inspector)
+                    self.get_foreign_key_metadata(
+                        dataset_urn, schema, fk_rec, inspector
+                    )
                     for fk_rec in inspector.get_foreign_keys(table, schema)
                 ]
             except KeyError:
diff --git a/metadata-ingestion/tests/unit/test_sql_common.py b/metadata-ingestion/tests/unit/test_sql_common.py
new file mode 100644
index 0000000000..9378518062
--- /dev/null
+++ b/metadata-ingestion/tests/unit/test_sql_common.py
@@ -0,0 +1,75 @@
+from typing import Dict
+from unittest.mock import Mock
+
+from sqlalchemy.engine.reflection import Inspector
+
+from datahub.ingestion.api.source import Source
+from datahub.ingestion.source.sql.sql_common import (
+    PipelineContext,
+    SQLAlchemyConfig,
+    SQLAlchemySource,
+)
+
+
+class TestSQLAlchemyConfig(SQLAlchemyConfig):
+    def get_sql_alchemy_url(self):
+        pass
+
+
+class TestSQLAlchemySource(SQLAlchemySource):
+    @classmethod
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
+        pass
+
+
+def test_generate_foreign_key():
+    config: SQLAlchemyConfig = TestSQLAlchemyConfig()
+    ctx: PipelineContext = PipelineContext(run_id="test_ctx")
+    platform: str = "TEST"
+    inspector: Inspector = Mock()
+    source = TestSQLAlchemySource(config=config, ctx=ctx, platform=platform)
+    fk_dict: Dict[str, str] = {
+        "name": "test_constraint",
+        "referred_table": "test_table",
+        "referred_schema": "test_referred_schema",
+        "constrained_columns": ["test_column"],  # type: ignore
+        "referred_columns": ["test_referred_column"],  # type: ignore
+    }
+    foreign_key = source.get_foreign_key_metadata(
+        dataset_urn="test_urn",
+        schema="test_schema",
+        fk_dict=fk_dict,
+        inspector=inspector,
+    )
+
+    assert fk_dict.get("name") == foreign_key.name
+    assert [
+        "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_referred_schema.test_table,PROD),test_referred_column)"
+    ] == foreign_key.foreignFields
+    assert ["urn:li:schemaField:(test_urn,test_column)"] == foreign_key.sourceFields
+
+
+def test_use_source_schema_for_foreign_key_if_not_specified():
+    config: SQLAlchemyConfig = TestSQLAlchemyConfig()
+    ctx: PipelineContext = PipelineContext(run_id="test_ctx")
+    platform: str = "TEST"
+    inspector: Inspector = Mock()
+    source = TestSQLAlchemySource(config=config, ctx=ctx, platform=platform)
+    fk_dict: Dict[str, str] = {
+        "name": "test_constraint",
+        "referred_table": "test_table",
+        "constrained_columns": ["test_column"],  # type: ignore
+        "referred_columns": ["test_referred_column"],  # type: ignore
+    }
+    foreign_key = source.get_foreign_key_metadata(
+        dataset_urn="test_urn",
+        schema="test_schema",
+        fk_dict=fk_dict,
+        inspector=inspector,
+    )
+
+    assert fk_dict.get("name") == foreign_key.name
+    assert [
+        "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.test_table,PROD),test_referred_column)"
+    ] == foreign_key.foreignFields
+    assert ["urn:li:schemaField:(test_urn,test_column)"] == foreign_key.sourceFields