datahub/metadata-ingestion/tests/unit/redshift/test_redshift_lineage.py

836 lines
31 KiB
Python

from datetime import datetime
from functools import partial
from typing import List
from unittest.mock import MagicMock
import datahub.sql_parsing.sqlglot_lineage as sqlglot_l
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.redshift.config import RedshiftConfig
from datahub.ingestion.source.redshift.lineage import (
LineageCollectorType,
LineageDataset,
LineageDatasetPlatform,
LineageItem,
RedshiftLineageExtractor,
parse_alter_table_rename,
)
from datahub.ingestion.source.redshift.redshift_schema import (
RedshiftSchema,
TempTableRow,
)
from datahub.ingestion.source.redshift.report import RedshiftReport
from datahub.metadata.schema_classes import NumberTypeClass, SchemaFieldDataTypeClass
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
ColumnLineageInfo,
DownstreamColumnRef,
SqlParsingDebugInfo,
SqlParsingResult,
)
from tests.unit.redshift.redshift_query_mocker import mock_cursor
def test_get_sources_from_query():
    """A ``schema.table`` reference yields a single source under database ``test``."""
    redshift_config = RedshiftConfig(host_port="localhost:5439", database="test")
    redshift_report = RedshiftReport()
    sql = "\nselect * from my_schema.my_table\n"
    expected_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test.my_schema.my_table,PROD)"
    )

    extractor = RedshiftLineageExtractor(
        redshift_config, redshift_report, PipelineContext(run_id="foo")
    )
    # The second element (column-level lineage) is not checked by this test.
    sources, _ = extractor._get_sources_from_query(db_name="test", query=sql)

    assert len(sources) == 1
    only_source = sources[0]
    assert only_source.urn == expected_urn
def test_get_sources_from_query_with_only_table_name():
    """A bare table name is expected to land under ``test.public``."""
    redshift_config = RedshiftConfig(host_port="localhost:5439", database="test")
    redshift_report = RedshiftReport()
    sql = "\nselect * from my_table\n"
    expected_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.my_table,PROD)"
    )

    extractor = RedshiftLineageExtractor(
        redshift_config, redshift_report, PipelineContext(run_id="foo")
    )
    # The second element (column-level lineage) is not checked by this test.
    sources, _ = extractor._get_sources_from_query(db_name="test", query=sql)

    assert len(sources) == 1
    only_source = sources[0]
    assert only_source.urn == expected_urn
def test_get_sources_from_query_with_database():
    """A fully qualified ``test.my_schema.my_table`` reference keeps all three parts."""
    redshift_config = RedshiftConfig(host_port="localhost:5439", database="test")
    redshift_report = RedshiftReport()
    sql = "\nselect * from test.my_schema.my_table\n"
    expected_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test.my_schema.my_table,PROD)"
    )

    extractor = RedshiftLineageExtractor(
        redshift_config, redshift_report, PipelineContext(run_id="foo")
    )
    # The second element (column-level lineage) is not checked by this test.
    sources, _ = extractor._get_sources_from_query(db_name="test", query=sql)

    assert len(sources) == 1
    only_source = sources[0]
    assert only_source.urn == expected_urn
def test_get_sources_from_query_with_non_default_database():
    """A table qualified with ``test2`` keeps that database even though ``db_name`` is ``test``."""
    redshift_config = RedshiftConfig(host_port="localhost:5439", database="test")
    redshift_report = RedshiftReport()
    sql = "\nselect * from test2.my_schema.my_table\n"
    expected_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test2.my_schema.my_table,PROD)"
    )

    extractor = RedshiftLineageExtractor(
        redshift_config, redshift_report, PipelineContext(run_id="foo")
    )
    # The second element (column-level lineage) is not checked by this test.
    sources, _ = extractor._get_sources_from_query(db_name="test", query=sql)

    assert len(sources) == 1
    only_source = sources[0]
    assert only_source.urn == expected_urn
def test_get_sources_from_query_with_only_table():
    """A bare table name is expected to land under ``test.public``.

    NOTE(review): the query and the expectation are the same as in
    ``test_get_sources_from_query_with_only_table_name``.
    """
    redshift_config = RedshiftConfig(host_port="localhost:5439", database="test")
    redshift_report = RedshiftReport()
    sql = "\nselect * from my_table\n"
    expected_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.my_table,PROD)"
    )

    extractor = RedshiftLineageExtractor(
        redshift_config, redshift_report, PipelineContext(run_id="foo")
    )
    # The second element (column-level lineage) is not checked by this test.
    sources, _ = extractor._get_sources_from_query(db_name="test", query=sql)

    assert len(sources) == 1
    only_source = sources[0]
    assert only_source.urn == expected_urn
def test_parse_alter_table_rename():
    """Check the tuple returned by ``parse_alter_table_rename`` for two statements.

    The first statement has no schema prefix and the expected schema is the
    ``"public"`` argument; the second names ``second_schema`` explicitly (and
    ends with ``"; "``) and that schema is expected instead.
    """
    unqualified = parse_alter_table_rename("public", "alter table foo rename to bar")
    assert unqualified == ("public", "foo", "bar")

    qualified = parse_alter_table_rename(
        "public", "alter table second_schema.storage_v2_stg rename to storage_v2; "
    )
    assert qualified == ("second_schema", "storage_v2_stg", "storage_v2")
def get_lineage_extractor() -> RedshiftLineageExtractor:
    """Build an extractor with temp-table resolution on and a fixed January 2024 window.

    The pipeline context carries the mocked graph returned by ``mock_graph()``.
    """
    window_start = datetime(2024, 1, 1, 12, 0, 0).isoformat() + "Z"
    window_end = datetime(2024, 1, 10, 12, 0, 0).isoformat() + "Z"
    redshift_config = RedshiftConfig(
        host_port="localhost:5439",
        database="test",
        resolve_temp_table_in_lineage=True,
        start_time=window_start,
        end_time=window_end,
    )
    return RedshiftLineageExtractor(
        redshift_config,
        RedshiftReport(),
        PipelineContext(run_id="foo", graph=mock_graph()),
    )
def test_cll():
    """Column lineage for a join is expected to list a, b, c with no table and no upstreams."""
    sql = "\nselect a,b,c from db.public.customer inner join db.public.order on db.public.customer.id = db.public.order.customer_id\n"
    extractor = get_lineage_extractor()

    _, column_lineage = extractor._get_sources_from_query(db_name="db", query=sql)

    # One entry per selected column, in select order.
    expected = [
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(table=None, column=column_name),
            upstreams=[],
            logic=None,
        )
        for column_name in ("a", "b", "c")
    ]
    assert column_lineage == expected
def cursor_execute_side_effect(cursor: MagicMock, query: str) -> None:
    """Forward ``cursor`` and ``query`` to ``mock_cursor``.

    Used (via ``functools.partial``) as the ``side_effect`` of a mocked
    ``cursor.execute``.
    """
    mock_cursor(query=query, cursor=cursor)
def mock_redshift_connection() -> MagicMock:
    """Return a mocked connection whose cursor routes ``execute`` to ``cursor_execute_side_effect``."""
    fake_cursor = MagicMock()
    # Bind the cursor itself so the side effect receives it along with the query.
    fake_cursor.execute.side_effect = partial(cursor_execute_side_effect, fake_cursor)

    fake_connection = MagicMock()
    fake_connection.cursor.return_value = fake_cursor
    return fake_connection
def mock_graph() -> DataHubGraph:
    """Return a ``MagicMock`` graph whose ``_make_schema_resolver`` hands back a real ``SchemaResolver``.

    The resolver is built for platform ``redshift`` and env ``PROD`` with no
    platform instance and no backing graph.
    """
    resolver = SchemaResolver(
        platform="redshift",
        env="PROD",
        platform_instance=None,
        graph=None,
    )
    fake_graph = MagicMock()
    fake_graph._make_schema_resolver.return_value = resolver
    return fake_graph
def test_collapse_temp_lineage():
    """Populate the lineage map through the mocked connection and check the target's upstream.

    The entry for ``test.public.player_price_with_hike_v6`` is expected to
    have ``test.public.player_activity`` as its first upstream, and its first
    column-lineage entry is expected to map ``price`` to ``price``.
    """
    extractor = get_lineage_extractor()
    connection: MagicMock = mock_redshift_connection()

    extractor._init_temp_table_schema(
        database=extractor.config.database,
        temp_tables=list(extractor.get_temp_tables(connection=connection)),
    )

    known_tables = {
        extractor.config.database: {"public": {"player_price_with_hike_v6"}}
    }
    extractor._populate_lineage_map(
        query="select * from test_collapse_temp_lineage",
        database=extractor.config.database,
        all_tables_set=known_tables,
        connection=connection,
        lineage_type=LineageCollectorType.QUERY_SQL_PARSER,
    )

    print(extractor._lineage_map)

    target_urn: str = "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.player_price_with_hike_v6,PROD)"
    source_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.player_activity,PROD)"
    )

    assert extractor._lineage_map.get(target_urn) is not None
    lineage_item: LineageItem = extractor._lineage_map[target_urn]

    assert list(lineage_item.upstreams)[0].urn == source_urn

    assert lineage_item.cll is not None
    first_entry = lineage_item.cll[0]

    assert first_entry.downstream.table == target_urn
    assert first_entry.downstream.column == "price"

    assert first_entry.upstreams[0].table == source_urn
    assert first_entry.upstreams[0].column == "price"
def test_collapse_temp_recursive_cll_lineage():
    """Resolve a two-level chain of temp tables down to the permanent table.

    ``#player_price`` is registered as reading from ``#player_activity_temp``,
    which is registered as reading from ``player_activity``. After
    ``_get_upstream_lineages`` runs, the single returned dataset and the
    column-level upstream of the target are expected to be
    ``player_activity`` / ``price``.
    """
    player_activity_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
    activity_temp_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
    player_price_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"

    def numeric_column_lineage(
        downstream_table: str,
        column: str,
        native_type: str,
        upstream_columns: List[tuple],
    ) -> ColumnLineageInfo:
        # Every temp-table column in this test is a number type with logic=None.
        return ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table=downstream_table,
                column=column,
                column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
                native_column_type=native_type,
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=up_table, column=up_column)
                for up_table, up_column in upstream_columns
            ],
            logic=None,
        )

    extractor = get_lineage_extractor()

    price_temp_table: TempTableRow = TempTableRow(
        transaction_id=126,
        query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd "
        "from #player_activity_temp group by player_id",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_price",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[activity_temp_urn],
            out_tables=[player_price_urn],
            debug_info=SqlParsingDebugInfo(),
            column_lineage=[
                numeric_column_lineage(
                    player_price_urn,
                    "player_id",
                    "INTEGER",
                    [(activity_temp_urn, "player_id")],
                ),
                numeric_column_lineage(
                    player_price_urn,
                    "price_usd",
                    "BIGINT",
                    [(activity_temp_urn, "price_usd")],
                ),
            ],
        ),
        urn=player_price_urn,
    )

    activity_temp_table: TempTableRow = TempTableRow(
        transaction_id=127,
        query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd "
        "from player_activity",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_activity_temp",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[player_activity_urn],
            out_tables=[activity_temp_urn],
            debug_info=SqlParsingDebugInfo(),
            column_lineage=[
                numeric_column_lineage(
                    activity_temp_urn,
                    "player_id",
                    "INTEGER",
                    [(player_activity_urn, "player_id")],
                ),
                numeric_column_lineage(
                    activity_temp_urn,
                    "price_usd",
                    "BIGINT",
                    [(player_activity_urn, "price")],
                ),
            ],
        ),
        urn=activity_temp_urn,
    )

    assert price_temp_table.urn
    assert activity_temp_table.urn

    extractor.temp_tables[price_temp_table.urn] = price_temp_table
    extractor.temp_tables[activity_temp_table.urn] = activity_temp_table

    # Starts out pointing at the temp table; the assertions below read it back
    # after the call.
    target_cll: List[sqlglot_l.ColumnLineageInfo] = [
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
                column="price",
                column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
                native_column_type="DOUBLE PRECISION",
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=player_price_urn, column="price_usd")
            ],
            logic=None,
        )
    ]

    temp_source = LineageDataset(
        platform=LineageDatasetPlatform.REDSHIFT,
        urn=player_price_urn,
    )
    datasets = extractor._get_upstream_lineages(
        sources=[temp_source],
        target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
        raw_db_name="dev",
        alias_db_name="dev",
        all_tables_set={"dev": {"public": set()}},
        connection=MagicMock(),
        target_dataset_cll=target_cll,
    )

    assert len(datasets) == 1
    assert datasets[0].urn == player_activity_urn

    resolved_upstream = target_cll[0].upstreams[0]
    assert resolved_upstream.table == player_activity_urn
    assert resolved_upstream.column == "price"
def test_collapse_temp_recursive_with_compex_column_cll_lineage():
    """Resolve a temp-table chain where one column has two upstream columns.

    ``#player_price.price_usd`` is registered with upstreams ``price`` and
    ``tax`` from ``#player_activity_temp``, which in turn maps both to
    ``player_activity``. After ``_get_upstream_lineages`` runs, the target's
    ``price`` column is expected to have ``price`` and ``tax`` upstreams and
    the ``player_id`` column a ``player_id`` upstream.
    """
    player_activity_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
    activity_temp_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
    player_price_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"
    hike_v6_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)"

    def numeric_column_lineage(
        downstream_table: str,
        column: str,
        native_type: str,
        upstream_columns: List[tuple],
    ) -> ColumnLineageInfo:
        # Every temp-table column in this test is a number type with logic=None.
        return ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table=downstream_table,
                column=column,
                column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
                native_column_type=native_type,
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=up_table, column=up_column)
                for up_table, up_column in upstream_columns
            ],
            logic=None,
        )

    extractor = get_lineage_extractor()

    price_temp_table: TempTableRow = TempTableRow(
        transaction_id=126,
        query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price+tax) AS price_usd "
        "from #player_activity_temp group by player_id",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_price",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[activity_temp_urn],
            out_tables=[player_price_urn],
            debug_info=SqlParsingDebugInfo(),
            column_lineage=[
                numeric_column_lineage(
                    player_price_urn,
                    "player_id",
                    "INTEGER",
                    [(activity_temp_urn, "player_id")],
                ),
                numeric_column_lineage(
                    player_price_urn,
                    "price_usd",
                    "BIGINT",
                    [(activity_temp_urn, "price"), (activity_temp_urn, "tax")],
                ),
            ],
        ),
        urn=player_price_urn,
    )

    activity_temp_table: TempTableRow = TempTableRow(
        transaction_id=127,
        query_text="CREATE TABLE #player_activity_temp SELECT player_id, price, tax "
        "from player_activity",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_activity_temp",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[player_activity_urn],
            out_tables=[activity_temp_urn],
            debug_info=SqlParsingDebugInfo(),
            column_lineage=[
                numeric_column_lineage(
                    activity_temp_urn,
                    "player_id",
                    "INTEGER",
                    [(player_activity_urn, "player_id")],
                ),
                numeric_column_lineage(
                    activity_temp_urn,
                    "price",
                    "BIGINT",
                    [(player_activity_urn, "price")],
                ),
                numeric_column_lineage(
                    activity_temp_urn,
                    "tax",
                    "BIGINT",
                    [(player_activity_urn, "tax")],
                ),
            ],
        ),
        urn=activity_temp_urn,
    )

    assert price_temp_table.urn
    assert activity_temp_table.urn

    extractor.temp_tables[price_temp_table.urn] = price_temp_table
    extractor.temp_tables[activity_temp_table.urn] = activity_temp_table

    # Both entries start out pointing at the temp table; the assertions below
    # read them back after the call.
    target_cll: List[sqlglot_l.ColumnLineageInfo] = [
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table=hike_v6_urn,
                column="price",
                column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
                native_column_type="DOUBLE PRECISION",
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=player_price_urn, column="price_usd")
            ],
            logic=None,
        ),
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table=hike_v6_urn,
                column="player_id",
                column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
                native_column_type="BIGINT",
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=player_price_urn, column="player_id")
            ],
            logic=None,
        ),
    ]

    temp_source = LineageDataset(
        platform=LineageDatasetPlatform.REDSHIFT,
        urn=player_price_urn,
    )
    datasets = extractor._get_upstream_lineages(
        sources=[temp_source],
        target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
        raw_db_name="dev",
        alias_db_name="dev",
        all_tables_set={"dev": {"public": set()}},
        connection=MagicMock(),
        target_dataset_cll=target_cll,
    )

    assert len(datasets) == 1
    assert datasets[0].urn == player_activity_urn

    price_upstreams = target_cll[0].upstreams
    assert price_upstreams[0].table == player_activity_urn
    assert price_upstreams[0].column == "price"
    assert price_upstreams[1].column == "tax"
    assert target_cll[1].upstreams[0].column == "player_id"
def test_collapse_temp_recursive_cll_lineage_with_circular_reference():
    """Run temp-table resolution on column lineage that refers back to its own table.

    The column lineage registered for ``#player_activity_temp`` lists
    ``#player_activity_temp`` itself as the upstream of both columns. The
    test only checks that ``_get_upstream_lineages`` returns one dataset.
    """
    player_activity_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)"
    activity_temp_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)"
    player_price_urn = "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)"

    def numeric_column_lineage(
        downstream_table: str,
        column: str,
        native_type: str,
        upstream_columns: List[tuple],
    ) -> ColumnLineageInfo:
        # Every temp-table column in this test is a number type with logic=None.
        return ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table=downstream_table,
                column=column,
                column_type=SchemaFieldDataTypeClass(NumberTypeClass()),
                native_column_type=native_type,
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=up_table, column=up_column)
                for up_table, up_column in upstream_columns
            ],
            logic=None,
        )

    extractor = get_lineage_extractor()

    price_temp_table: TempTableRow = TempTableRow(
        transaction_id=126,
        query_text="CREATE TABLE #player_price distkey(player_id) AS SELECT player_id, SUM(price_usd) AS price_usd "
        "from #player_activity_temp group by player_id",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_price",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[activity_temp_urn],
            out_tables=[player_price_urn],
            debug_info=SqlParsingDebugInfo(),
            column_lineage=[
                numeric_column_lineage(
                    player_price_urn,
                    "player_id",
                    "INTEGER",
                    [(activity_temp_urn, "player_id")],
                ),
                numeric_column_lineage(
                    player_price_urn,
                    "price_usd",
                    "BIGINT",
                    [(activity_temp_urn, "price_usd")],
                ),
            ],
        ),
        urn=player_price_urn,
    )

    activity_temp_table: TempTableRow = TempTableRow(
        transaction_id=127,
        query_text="CREATE TABLE #player_activity_temp SELECT player_id, SUM(price) AS price_usd "
        "from #player_price",
        start_time=datetime.now(),
        session_id="abc",
        create_command="CREATE TABLE #player_activity_temp",
        parsed_result=SqlParsingResult(
            query_type=QueryType.CREATE_TABLE_AS_SELECT,
            in_tables=[player_activity_urn],
            out_tables=[activity_temp_urn],
            debug_info=SqlParsingDebugInfo(),
            # Upstream columns deliberately name the table's own URN.
            column_lineage=[
                numeric_column_lineage(
                    activity_temp_urn,
                    "player_id",
                    "INTEGER",
                    [(activity_temp_urn, "player_id")],
                ),
                numeric_column_lineage(
                    activity_temp_urn,
                    "price_usd",
                    "BIGINT",
                    [(activity_temp_urn, "price_usd")],
                ),
            ],
        ),
        urn=activity_temp_urn,
    )

    assert price_temp_table.urn
    assert activity_temp_table.urn

    extractor.temp_tables[price_temp_table.urn] = price_temp_table
    extractor.temp_tables[activity_temp_table.urn] = activity_temp_table

    target_cll: List[sqlglot_l.ColumnLineageInfo] = [
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(
                table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v6,PROD)",
                column="price",
                column_type=SchemaFieldDataTypeClass(type=NumberTypeClass()),
                native_column_type="DOUBLE PRECISION",
            ),
            upstreams=[
                sqlglot_l.ColumnRef(table=player_price_urn, column="price_usd")
            ],
            logic=None,
        )
    ]

    temp_source = LineageDataset(
        platform=LineageDatasetPlatform.REDSHIFT,
        urn=player_price_urn,
    )
    datasets = extractor._get_upstream_lineages(
        sources=[temp_source],
        target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)",
        raw_db_name="dev",
        alias_db_name="dev",
        all_tables_set={"dev": {"public": set()}},
        connection=MagicMock(),
        target_dataset_cll=target_cll,
    )

    # Here we are only interested in whether the call fails or not.
    assert len(datasets) == 1
def test_external_schema_get_upstream_schema_success():
    """An external schema whose option JSON has a ``SCHEMA`` key reports that value."""
    external_schema = RedshiftSchema(
        name="schema",
        database="XXXXXXXX",
        type="external",
        option='{"SCHEMA":"sales_schema"}',
        external_platform="redshift",
    )

    upstream_name = external_schema.get_upstream_schema_name()

    assert upstream_name == "sales_schema"
def test_external_schema_no_upstream_schema():
    """An external schema with ``option=None`` reports no upstream schema name."""
    external_schema = RedshiftSchema(
        name="schema",
        database="XXXXXXXX",
        type="external",
        option=None,
        external_platform="redshift",
    )

    upstream_name = external_schema.get_upstream_schema_name()

    assert upstream_name is None
def test_local_schema_no_upstream_schema():
    """A ``local`` schema with an unrelated option key reports no upstream schema name."""
    local_schema = RedshiftSchema(
        name="schema",
        database="XXXXXXXX",
        type="local",
        option='{"some_other_option":"x"}',
        external_platform=None,
    )

    upstream_name = local_schema.get_upstream_schema_name()

    assert upstream_name is None