diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
index 5f870affbf..66becd9830 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import typing
 from typing import Dict, List, Optional, Tuple
 
@@ -8,6 +9,7 @@ from sqlalchemy.engine.reflection import Inspector
 
 from datahub.emitter.mcp_builder import DatabaseKey, gen_containers
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.s3_util import make_s3_urn
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemyConfig,
     SQLAlchemySource,
@@ -52,7 +54,7 @@ class AthenaSource(SQLAlchemySource):
 
     def get_table_properties(
         self, inspector: Inspector, schema: str, table: str
-    ) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
+    ) -> Tuple[Optional[str], Optional[Dict[str, str]], Optional[str]]:
         if not self.cursor:
             self.cursor = inspector.dialect._raw_connection(inspector.engine).cursor()
 
@@ -87,7 +89,17 @@ class AthenaSource(SQLAlchemySource):
             metadata.table_type if metadata.table_type else ""
         )
 
-        return description, custom_properties
+        location: Optional[str] = custom_properties.get("location", None)
+        if location is not None:
+            if location.startswith("s3://"):
+                location = make_s3_urn(location, self.config.env)
+            else:
+                logging.debug(
+                    f"Only s3 url supported for location. Skipping {location}"
+                )
+                location = None
+
+        return description, custom_properties, location
 
     # It seems like database/schema filter in the connection string does not work and this to work around that
     def get_schema_names(self, inspector: Inspector) -> List[str]:
@@ -105,10 +117,11 @@
 
     def gen_schema_key(self, db_name: str, schema: str) -> DatabaseKey:
         return DatabaseKey(
-            platform=self.platform,
-            environment=self.config.env,
-            instance=self.config.platform_instance,
             database=schema,
+            platform=self.platform,
+            instance=self.config.platform_instance
+            if self.config.platform_instance is not None
+            else self.config.env,
         )
 
     def gen_schema_containers(
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
index 8f751cde74..f28dc4ee64 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
@@ -55,6 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionSourceBase,
 )
 from datahub.metadata.com.linkedin.pegasus2avro.common import StatusClass
+from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
 from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.com.linkedin.pegasus2avro.schema import (
@@ -77,9 +78,11 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 from datahub.metadata.schema_classes import (
     ChangeTypeClass,
     DataPlatformInstanceClass,
+    DatasetLineageTypeClass,
     DatasetPropertiesClass,
     JobStatusClass,
     SubTypesClass,
+    UpstreamClass,
     ViewPropertiesClass,
 )
 from datahub.telemetry import telemetry
@@ -846,13 +849,35 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
                     BaseSQLAlchemyCheckpointState, cur_checkpoint.state
                 )
                 checkpoint_state.add_table_urn(dataset_urn)
-        description, properties = self.get_table_properties(inspector, schema, table)
+
+        description, properties, location_urn = self.get_table_properties(
+            inspector, schema, table
+        )
         dataset_properties = DatasetPropertiesClass(
             name=table,
             description=description,
             customProperties=properties,
         )
         dataset_snapshot.aspects.append(dataset_properties)
+
+        if location_urn:
+            external_upstream_table = UpstreamClass(
+                dataset=location_urn,
+                type=DatasetLineageTypeClass.COPY,
+            )
+            lineage_mcpw = MetadataChangeProposalWrapper(
+                entityType="dataset",
+                changeType=ChangeTypeClass.UPSERT,
+                entityUrn=dataset_snapshot.urn,
+                aspectName="upstreamLineage",
+                aspect=UpstreamLineage(upstreams=[external_upstream_table]),
+            )
+            lineage_wu = MetadataWorkUnit(
+                id=f"{self.platform}-{lineage_mcpw.entityUrn}-{lineage_mcpw.aspectName}",
+                mcp=lineage_mcpw,
+            )
+            yield lineage_wu
+
         pk_constraints: dict = inspector.get_pk_constraint(table, schema)
         foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
         schema_fields = self.get_schema_fields(dataset_name, columns, pk_constraints)
@@ -896,8 +921,9 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
 
     def get_table_properties(
         self, inspector: Inspector, schema: str, table: str
-    ) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
+    ) -> Tuple[Optional[str], Optional[Dict[str, str]], Optional[str]]:
         try:
+            location: Optional[str] = None
             # SQLALchemy stubs are incomplete and missing this method.
             # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
             table_info: dict = inspector.get_table_comment(table, schema)  # type: ignore
@@ -918,7 +944,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
 
         # The "properties" field is a non-standard addition to SQLAlchemy's interface.
         properties = table_info.get("properties", {})
-        return description, properties
+        return description, properties, location
 
     def get_dataplatform_instance_aspect(
         self, dataset_urn: str
diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py
index cd38821f4c..812e724eea 100644
--- a/metadata-ingestion/tests/unit/test_athena_source.py
+++ b/metadata-ingestion/tests/unit/test_athena_source.py
@@ -5,6 +5,8 @@ from unittest import mock
 import pytest
 from freezegun import freeze_time
 
+from datahub.ingestion.source.aws.s3_util import make_s3_urn
+
 FROZEN_TIME = "2020-04-14 07:00:00"
 
@@ -56,7 +58,7 @@
         ],
         "Parameters": {
             "comment": "testComment",
-            "location": "testLocation",
+            "location": "s3://testLocation",
             "inputformat": "testInputFormat",
             "outputformat": "testOutputFormat",
             "serde.serialization.lib": "testSerde",
@@ -74,7 +76,7 @@
     ctx = PipelineContext(run_id="test")
     source = AthenaSource(config=config, ctx=ctx)
 
-    description, custom_properties = source.get_table_properties(
+    description, custom_properties, location = source.get_table_properties(
         inspector=mock_inspector, table=table, schema=schema
     )
     assert custom_properties == {
@@ -82,9 +84,11 @@
         "create_time": "2020-04-14 07:00:00",
         "inputformat": "testInputFormat",
         "last_access_time": "2020-04-14 07:00:00",
-        "location": "testLocation",
+        "location": "s3://testLocation",
         "outputformat": "testOutputFormat",
         "partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]',
         "serde.serialization.lib": "testSerde",
         "table_type": "testType",
     }
+
+    assert location == make_s3_urn("s3://testLocation", "PROD")
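
For illustration (not part of the patch itself): the new third return value of get_table_properties is an S3 dataset URN that sql_common.py wraps into an upstreamLineage aspect. The snippet below is a minimal, self-contained sketch of that location handling; make_s3_urn_sketch and location_to_urn are hypothetical stand-ins, and the real conversion lives in datahub.ingestion.source.aws.s3_util.make_s3_urn.

from typing import Optional


def make_s3_urn_sketch(s3_uri: str, env: str) -> str:
    # Assumed simplification of make_s3_urn: drop the "s3://" scheme and any
    # trailing slash, then wrap the path in an S3 dataset URN.
    name = s3_uri[len("s3://") :].rstrip("/")
    return f"urn:li:dataset:(urn:li:dataPlatform:s3,{name},{env})"


def location_to_urn(location: Optional[str], env: str = "PROD") -> Optional[str]:
    # Mirrors the guard the patch adds to AthenaSource.get_table_properties:
    # only s3:// locations become lineage upstreams; anything else is skipped.
    if location is None or not location.startswith("s3://"):
        return None
    return make_s3_urn_sketch(location, env)


print(location_to_urn("s3://testLocation"))  # urn:li:dataset:(urn:li:dataPlatform:s3,testLocation,PROD)
print(location_to_urn("hdfs://other/path"))  # None

Under these assumptions, the upstream side of the emitted upstreamLineage aspect is an S3 dataset URN, which is what the updated test asserts via make_s3_urn("s3://testLocation", "PROD").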