Mirror of https://github.com/datahub-project/datahub.git
	feat(ingest): athena - set Athena location as upstream (#4503)
parent 37aedfc87c
commit 4358d8fb01
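In short: when Athena's table metadata reports an s3:// storage location, the connector now converts that location into an S3 dataset URN and attaches it to the Athena table as upstream lineage of type COPY. A minimal sketch of the URI-to-URN mapping, assuming the usual DataHub S3 URN shape (the real conversion is DataHub's make_s3_urn, which also handles details such as trailing slashes and file extensions):

    def sketch_s3_urn(s3_uri: str, env: str) -> str:
        # Hypothetical helper for illustration only; real code should use
        # datahub.ingestion.source.aws.s3_util.make_s3_urn.
        assert s3_uri.startswith("s3://")
        path = s3_uri[len("s3://"):].rstrip("/")
        return f"urn:li:dataset:(urn:li:dataPlatform:s3,{path},{env})"

    print(sketch_s3_urn("s3://my-bucket/tables/orders", "PROD"))
    # urn:li:dataset:(urn:li:dataPlatform:s3,my-bucket/tables/orders,PROD)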
					
@@ -1,4 +1,5 @@
 import json
+import logging
 import typing
 from typing import Dict, List, Optional, Tuple
 
@@ -8,6 +9,7 @@ from sqlalchemy.engine.reflection import Inspector
 
 from datahub.emitter.mcp_builder import DatabaseKey, gen_containers
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.s3_util import make_s3_urn
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemyConfig,
     SQLAlchemySource,
@@ -52,7 +54,7 @@ class AthenaSource(SQLAlchemySource):
 
     def get_table_properties(
         self, inspector: Inspector, schema: str, table: str
-    ) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
+    ) -> Tuple[Optional[str], Optional[Dict[str, str]], Optional[str]]:
         if not self.cursor:
             self.cursor = inspector.dialect._raw_connection(inspector.engine).cursor()
 
@@ -87,7 +89,17 @@ class AthenaSource(SQLAlchemySource):
             metadata.table_type if metadata.table_type else ""
         )
 
-        return description, custom_properties
+        location: Optional[str] = custom_properties.get("location", None)
+        if location is not None:
+            if location.startswith("s3://"):
+                location = make_s3_urn(location, self.config.env)
+            else:
+                logging.debug(
+                    f"Only s3 url supported for location. Skipping {location}"
+                )
+                location = None
+
+        return description, custom_properties, location
 
     # It seems like database/schema filter in the connection string does not work and this to work around that
     def get_schema_names(self, inspector: Inspector) -> List[str]:
@@ -105,10 +117,11 @@ class AthenaSource(SQLAlchemySource):
 
     def gen_schema_key(self, db_name: str, schema: str) -> DatabaseKey:
         return DatabaseKey(
-            platform=self.platform,
-            environment=self.config.env,
-            instance=self.config.platform_instance,
             database=schema,
+            platform=self.platform,
+            instance=self.config.platform_instance
+            if self.config.platform_instance is not None
+            else self.config.env,
         )
 
     def gen_schema_containers(
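The hunks above modify the Athena connector. get_table_properties now returns a third element: the table's storage location, converted to an S3 dataset URN when it is an s3:// URI and dropped (with a debug log) otherwise. The gen_schema_key change also makes the container key use platform_instance when configured, falling back to env. A standalone sketch of the location handling, with normalize_location as a hypothetical name that is not part of the commit:

    import logging
    from typing import Optional

    from datahub.ingestion.source.aws.s3_util import make_s3_urn

    def normalize_location(location: Optional[str], env: str) -> Optional[str]:
        # Same branch structure as the new code in AthenaSource.get_table_properties:
        # only s3:// URIs become URNs; anything else is logged and skipped.
        if location is None:
            return None
        if location.startswith("s3://"):
            return make_s3_urn(location, env)
        logging.debug(f"Only s3 url supported for location. Skipping {location}")
        return None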
@@ -55,6 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionSourceBase,
 )
 from datahub.metadata.com.linkedin.pegasus2avro.common import StatusClass
+from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
 from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.com.linkedin.pegasus2avro.schema import (
@@ -77,9 +78,11 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 from datahub.metadata.schema_classes import (
     ChangeTypeClass,
     DataPlatformInstanceClass,
+    DatasetLineageTypeClass,
     DatasetPropertiesClass,
     JobStatusClass,
     SubTypesClass,
+    UpstreamClass,
     ViewPropertiesClass,
 )
 from datahub.telemetry import telemetry
@@ -846,13 +849,35 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
                     BaseSQLAlchemyCheckpointState, cur_checkpoint.state
                 )
                 checkpoint_state.add_table_urn(dataset_urn)
-        description, properties = self.get_table_properties(inspector, schema, table)
+
+        description, properties, location_urn = self.get_table_properties(
+            inspector, schema, table
+        )
         dataset_properties = DatasetPropertiesClass(
             name=table,
             description=description,
             customProperties=properties,
         )
         dataset_snapshot.aspects.append(dataset_properties)
+
+        if location_urn:
+            external_upstream_table = UpstreamClass(
+                dataset=location_urn,
+                type=DatasetLineageTypeClass.COPY,
+            )
+            lineage_mcpw = MetadataChangeProposalWrapper(
+                entityType="dataset",
+                changeType=ChangeTypeClass.UPSERT,
+                entityUrn=dataset_snapshot.urn,
+                aspectName="upstreamLineage",
+                aspect=UpstreamLineage(upstreams=[external_upstream_table]),
+            )
+            lineage_wu = MetadataWorkUnit(
+                id=f"{self.platform}-{lineage_mcpw.entityUrn}-{lineage_mcpw.aspectName}",
+                mcp=lineage_mcpw,
+            )
+            yield lineage_wu
+
         pk_constraints: dict = inspector.get_pk_constraint(table, schema)
         foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
         schema_fields = self.get_schema_fields(dataset_name, columns, pk_constraints)
@@ -896,8 +921,9 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
 
     def get_table_properties(
         self, inspector: Inspector, schema: str, table: str
-    ) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
+    ) -> Tuple[Optional[str], Optional[Dict[str, str]], Optional[str]]:
         try:
+            location: Optional[str] = None
             # SQLALchemy stubs are incomplete and missing this method.
             # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
             table_info: dict = inspector.get_table_comment(table, schema)  # type: ignore
@@ -918,7 +944,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
 
             # The "properties" field is a non-standard addition to SQLAlchemy's interface.
             properties = table_info.get("properties", {})
-        return description, properties
+        return description, properties, location
 
     def get_dataplatform_instance_aspect(
         self, dataset_urn: str
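These hunks are in the shared SQLAlchemy source (sql_common). The base get_table_properties signature gains the same third return value, and when a location URN comes back, the source emits an upstreamLineage aspect for the dataset as an MCP work unit. Outside an ingestion pipeline, emitting an equivalent aspect directly would look roughly like this; the URNs and GMS endpoint are placeholders, not values from the commit:

    from datahub.emitter.mcp import MetadataChangeProposalWrapper
    from datahub.emitter.rest_emitter import DatahubRestEmitter
    from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
    from datahub.metadata.schema_classes import (
        ChangeTypeClass,
        DatasetLineageTypeClass,
        UpstreamClass,
    )

    # Hypothetical URNs for illustration.
    athena_urn = "urn:li:dataset:(urn:li:dataPlatform:athena,my_db.orders,PROD)"
    s3_urn = "urn:li:dataset:(urn:li:dataPlatform:s3,my-bucket/tables/orders,PROD)"

    # Same aspect shape as the new code path in SQLAlchemySource.
    mcpw = MetadataChangeProposalWrapper(
        entityType="dataset",
        changeType=ChangeTypeClass.UPSERT,
        entityUrn=athena_urn,
        aspectName="upstreamLineage",
        aspect=UpstreamLineage(
            upstreams=[UpstreamClass(dataset=s3_urn, type=DatasetLineageTypeClass.COPY)]
        ),
    )

    emitter = DatahubRestEmitter("http://localhost:8080")  # assumed local GMS endpoint
    emitter.emit_mcp(mcpw)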
@@ -5,6 +5,8 @@ from unittest import mock
 import pytest
 from freezegun import freeze_time
 
+from src.datahub.ingestion.source.aws.s3_util import make_s3_urn
+
 FROZEN_TIME = "2020-04-14 07:00:00"
 
 
@@ -56,7 +58,7 @@ def test_athena_get_table_properties():
             ],
             "Parameters": {
                 "comment": "testComment",
-                "location": "testLocation",
+                "location": "s3://testLocation",
                 "inputformat": "testInputFormat",
                 "outputformat": "testOutputFormat",
                 "serde.serialization.lib": "testSerde",
@@ -74,7 +76,7 @@ def test_athena_get_table_properties():
 
     ctx = PipelineContext(run_id="test")
     source = AthenaSource(config=config, ctx=ctx)
-    description, custom_properties = source.get_table_properties(
+    description, custom_properties, location = source.get_table_properties(
         inspector=mock_inspector, table=table, schema=schema
     )
     assert custom_properties == {
@@ -82,9 +84,11 @@ def test_athena_get_table_properties():
         "create_time": "2020-04-14 07:00:00",
         "inputformat": "testInputFormat",
         "last_access_time": "2020-04-14 07:00:00",
-        "location": "testLocation",
+        "location": "s3://testLocation",
         "outputformat": "testOutputFormat",
         "partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]',
         "serde.serialization.lib": "testSerde",
         "table_type": "testType",
     }
+
+    assert location == make_s3_urn("s3://testLocation", "PROD")
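The updated test unpacks the new third return value and checks it against make_s3_urn. Assuming the URN shape sketched earlier, the asserted value should come out approximately as shown below:

    from datahub.ingestion.source.aws.s3_util import make_s3_urn

    # Expected to print approximately:
    # urn:li:dataset:(urn:li:dataPlatform:s3,testLocation,PROD)
    print(make_s3_urn("s3://testLocation", "PROD"))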