2021-12-20 16:11:55 +01:00
|
|
|
from typing import Dict
|
|
|
|
from unittest.mock import Mock
|
|
|
|
|
2022-04-12 03:48:15 +01:00
|
|
|
import pytest
|
2021-12-20 16:11:55 +01:00
|
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
|
|
|
2023-10-04 06:53:15 -04:00
|
|
|
from datahub.ingestion.source.sql.sql_common import PipelineContext, SQLAlchemySource
|
|
|
|
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
|
|
|
|
from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
|
2022-04-12 03:48:15 +01:00
|
|
|
get_platform_from_sqlalchemy_uri,
|
2021-12-20 16:11:55 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2023-08-15 17:49:20 -04:00
|
|
|
class _TestSQLAlchemyConfig(SQLCommonConfig):
    """Smallest concrete ``SQLCommonConfig`` the tests below can instantiate."""

    def get_sql_alchemy_url(self):
        """Return no URL.

        Overridden so the config can be constructed (presumably the base
        class leaves this method abstract — confirm). The tests in this
        module do not call it, so ``None`` is sufficient.
        """
        return None
|
|
|
|
|
|
|
|
|
2022-05-26 12:39:40 -04:00
|
|
|
class _TestSQLAlchemySource(SQLAlchemySource):
    """Concrete ``SQLAlchemySource`` with no overrides.

    Exists only so the tests can build a source instance and call its
    helpers (such as ``get_foreign_key_metadata``) directly.
    """
|
2021-12-20 16:11:55 +01:00
|
|
|
|
|
|
|
|
|
|
|
def test_generate_foreign_key():
    """Foreign-key metadata targets the constraint's explicit ``referred_schema``.

    The constraint names ``test_referred_schema``, so the foreign field URN
    must point at ``test_referred_schema.test_table`` rather than at the
    ``schema`` argument passed to the call.
    """
    source = _TestSQLAlchemySource(
        config=_TestSQLAlchemyConfig(),
        ctx=PipelineContext(run_id="test_ctx"),
        platform="TEST",
    )
    # A bare Mock stands in for a real SQLAlchemy Inspector in this test.
    mock_inspector: Inspector = Mock()

    constraint: Dict[str, str] = {
        "name": "test_constraint",
        "referred_table": "test_table",
        "referred_schema": "test_referred_schema",
        "constrained_columns": ["test_column"],  # type: ignore
        "referred_columns": ["test_referred_column"],  # type: ignore
    }

    fk = source.get_foreign_key_metadata(
        dataset_urn="test_urn",
        schema="test_schema",
        fk_dict=constraint,
        inspector=mock_inspector,
    )

    expected_foreign_field = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_referred_schema.test_table,PROD),test_referred_column)"
    expected_source_field = "urn:li:schemaField:(test_urn,test_column)"

    assert fk.name == constraint.get("name")
    assert fk.foreignFields == [expected_foreign_field]
    assert fk.sourceFields == [expected_source_field]
|
|
|
|
|
|
|
|
|
|
|
|
def test_use_source_schema_for_foreign_key_if_not_specified():
    """Without ``referred_schema``, the foreign dataset uses the ``schema`` argument.

    The constraint below deliberately omits ``referred_schema``; the expected
    foreign field URN therefore sits under ``test_schema``, the schema passed
    to ``get_foreign_key_metadata``.
    """
    source = _TestSQLAlchemySource(
        config=_TestSQLAlchemyConfig(),
        ctx=PipelineContext(run_id="test_ctx"),
        platform="TEST",
    )
    # A bare Mock stands in for a real SQLAlchemy Inspector in this test.
    mock_inspector: Inspector = Mock()

    constraint: Dict[str, str] = {
        "name": "test_constraint",
        "referred_table": "test_table",
        # Intentionally no "referred_schema" key: that is the case under test.
        "constrained_columns": ["test_column"],  # type: ignore
        "referred_columns": ["test_referred_column"],  # type: ignore
    }

    fk = source.get_foreign_key_metadata(
        dataset_urn="test_urn",
        schema="test_schema",
        fk_dict=constraint,
        inspector=mock_inspector,
    )

    expected_foreign_field = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.test_table,PROD),test_referred_column)"
    expected_source_field = "urn:li:schemaField:(test_urn,test_column)"

    assert fk.name == constraint.get("name")
    assert fk.foreignFields == [expected_foreign_field]
    assert fk.sourceFields == [expected_source_field]
|
2022-04-12 03:48:15 +01:00
|
|
|
|
|
|
|
|
|
|
|
# SQLAlchemy URI -> platform name that get_platform_from_sqlalchemy_uri is
# expected to return. Feeds the parametrized test_get_platform_from_sqlalchemy_uri;
# insertion order fixes the order in which those test cases are generated.
PLATFORM_FROM_SQLALCHEMY_URI_TEST_CASES: Dict[str, str] = {
    # The "awsathena" scheme is expected to report the platform as "athena".
    "awsathena://test_athena:3316/athenadb": "athena",
    "bigquery://test_bq:3316/bigquery": "bigquery",
    "clickhouse://test_clickhouse:3316/clickhousedb": "clickhouse",
    "druid://test_druid:1101/druiddb": "druid",
    "hive://test_hive:1201/hive": "hive",
    "mongodb://test_mongodb:1201/mongo": "mongodb",
    "mssql://test_mssql:1201/mssqldb": "mssql",
    "mysql://test_mysql:1201/mysql": "mysql",
    "oracle://test_oracle:3306/oracledb": "oracle",
    "pinot://test_pinot:3306/pinot": "pinot",
    # The "postgresql" scheme is expected to normalise to "postgres".
    "postgresql://test_postgres:5432/postgres": "postgres",
    "presto://test_presto:5432/prestodb": "presto",
    "redshift://test_redshift:5432/redshift": "redshift",
    # Postgres-style URIs that reference Redshift are expected to be
    # reported as "redshift" rather than "postgres".
    "jdbc:postgres://test_redshift:5432/redshift.amazonaws": "redshift",
    "postgresql://test_redshift:5432/redshift.amazonaws": "redshift",
    "snowflake://test_snowflake:5432/snowflakedb": "snowflake",
    "trino://test_trino:5432/trino": "trino",
}
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "uri, expected_platform",
    [
        # Each case is labelled with its URI, matching the original ids.
        pytest.param(case_uri, case_platform, id=case_uri)
        for case_uri, case_platform in PLATFORM_FROM_SQLALCHEMY_URI_TEST_CASES.items()
    ],
)
def test_get_platform_from_sqlalchemy_uri(uri: str, expected_platform: str) -> None:
    """Each URI in the case table resolves to the platform listed beside it."""
    assert get_platform_from_sqlalchemy_uri(uri) == expected_platform
|