feat(ingestion): Add test_connection methods for important sources (#9334)

Shubham Jagtap 2023-12-14 23:01:51 +05:30 committed by GitHub
parent ecef50f8fc
commit 1741c07d76
15 changed files with 685 additions and 382 deletions
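Every file in this diff follows the same pattern: the source class additionally subclasses TestableSource and gains a test_connection entry point that parses the raw config dict, makes the cheapest possible call against the upstream system, and wraps the outcome in a TestConnectionReport. A minimal sketch of that contract, inferred from the imports and implementations below (HypotheticalSource and its ping step are illustrative, not part of this commit):

from datahub.ingestion.api.source import (
    CapabilityReport,
    TestableSource,
    TestConnectionReport,
)


class HypotheticalSource(TestableSource):
    @staticmethod
    def test_connection(config_dict: dict) -> TestConnectionReport:
        test_report = TestConnectionReport()
        try:
            # A real source parses config_dict here (typically via
            # parse_obj_allow_extras) and issues one lightweight request.
            test_report.basic_connectivity = CapabilityReport(capable=True)
        except Exception as e:
            test_report.basic_connectivity = CapabilityReport(
                capable=False, failure_reason=str(e)
            )
        return test_report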

View File

@ -14,7 +14,12 @@ from datahub.ingestion.api.decorators import (
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
DBTCommonConfig,
@ -177,7 +182,7 @@ query DatahubMetadataQuery_{type}($jobId: BigInt!, $runId: BigInt) {{
@support_status(SupportStatus.INCUBATING)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCloudSource(DBTSourceBase):
class DBTCloudSource(DBTSourceBase, TestableSource):
"""
This source pulls dbt metadata directly from the dbt Cloud APIs.
@ -199,6 +204,57 @@ class DBTCloudSource(DBTSourceBase):
config = DBTCloudConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCloudConfig.parse_obj_allow_extras(config_dict)
DBTCloudSource._send_graphql_query(
metadata_endpoint=source_config.metadata_endpoint,
token=source_config.token,
query=_DBT_GRAPHQL_QUERY.format(type="tests", fields="jobId"),
variables={
"jobId": source_config.job_id,
"runId": source_config.run_id,
},
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report
@staticmethod
def _send_graphql_query(
metadata_endpoint: str, token: str, query: str, variables: Dict
) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {token}",
"X-dbt-partner-source": "acryldatahub",
},
)
try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e
return data
def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# TODO: In dbt Cloud, commands are scheduled as part of jobs, where
# each job can have multiple runs. We currently only fully support
@ -213,6 +269,8 @@ class DBTCloudSource(DBTSourceBase):
for node_type, fields in _DBT_FIELDS_BY_TYPE.items():
logger.info(f"Fetching {node_type} from dbt Cloud")
data = self._send_graphql_query(
metadata_endpoint=self.config.metadata_endpoint,
token=self.config.token,
query=_DBT_GRAPHQL_QUERY.format(type=node_type, fields=fields),
variables={
"jobId": self.config.job_id,
@ -232,33 +290,6 @@ class DBTCloudSource(DBTSourceBase):
return nodes, additional_metadata
def _send_graphql_query(self, query: str, variables: Dict) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
self.config.metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {self.config.token}",
"X-dbt-partner-source": "acryldatahub",
},
)
try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e
return data
def _parse_into_dbt_node(self, node: Dict) -> DBTNode:
key = node["uniqueId"]

View File

@ -18,7 +18,12 @@ from datahub.ingestion.api.decorators import (
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
@ -60,11 +65,6 @@ class DBTCoreConfig(DBTCommonConfig):
_github_info_deprecated = pydantic_renamed_field("github_info", "git_info")
@property
def s3_client(self):
assert self.aws_connection
return self.aws_connection.get_s3_client()
@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
@ -363,7 +363,7 @@ def load_test_results(
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCoreSource(DBTSourceBase):
class DBTCoreSource(DBTSourceBase, TestableSource):
"""
The artifacts used by this source are:
- [dbt manifest file](https://docs.getdbt.com/reference/artifacts/manifest-json)
@ -387,12 +387,34 @@ class DBTCoreSource(DBTSourceBase):
config = DBTCoreConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")
def load_file_as_json(self, uri: str) -> Any:
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict)
DBTCoreSource.load_file_as_json(
source_config.manifest_path, source_config.aws_connection
)
DBTCoreSource.load_file_as_json(
source_config.catalog_path, source_config.aws_connection
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report
@staticmethod
def load_file_as_json(
uri: str, aws_connection: Optional[AwsConnectionConfig]
) -> Dict:
if re.match("^https?://", uri):
return json.loads(requests.get(uri).text)
elif re.match("^s3://", uri):
u = urlparse(uri)
response = self.config.s3_client.get_object(
assert aws_connection
response = aws_connection.get_s3_client().get_object(
Bucket=u.netloc, Key=u.path.lstrip("/")
)
return json.loads(response["Body"].read().decode("utf-8"))
@ -410,12 +432,18 @@ class DBTCoreSource(DBTSourceBase):
Optional[str],
Optional[str],
]:
dbt_manifest_json = self.load_file_as_json(self.config.manifest_path)
dbt_manifest_json = self.load_file_as_json(
self.config.manifest_path, self.config.aws_connection
)
dbt_catalog_json = self.load_file_as_json(self.config.catalog_path)
dbt_catalog_json = self.load_file_as_json(
self.config.catalog_path, self.config.aws_connection
)
if self.config.sources_path is not None:
dbt_sources_json = self.load_file_as_json(self.config.sources_path)
dbt_sources_json = self.load_file_as_json(
self.config.sources_path, self.config.aws_connection
)
sources_results = dbt_sources_json["results"]
else:
sources_results = {}
@ -491,7 +519,9 @@ class DBTCoreSource(DBTSourceBase):
# This will populate the test_results field on each test node.
all_nodes = load_test_results(
self.config,
self.load_file_as_json(self.config.test_results_path),
self.load_file_as_json(
self.config.test_results_path, self.config.aws_connection
),
all_nodes,
)
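For dbt Core the connection test boils down to loading the manifest and catalog, so it can be exercised without building a pipeline. A sketch using the same config shape as the integration test added further below (paths are placeholders; s3:// and http(s):// URIs are handled by load_file_as_json as well):

from datahub.ingestion.source.dbt.dbt_core import DBTCoreSource

report = DBTCoreSource.test_connection(
    {
        "manifest_path": "./dbt_manifest.json",  # placeholder path
        "catalog_path": "./dbt_catalog.json",  # placeholder path
        "target_platform": "postgres",
    }
)
print(report.basic_connectivity)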

View File

@ -15,6 +15,7 @@ from confluent_kafka.admin import (
ConfigResource,
TopicMetadata,
)
from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
@ -40,7 +41,13 @@ from datahub.ingestion.api.decorators import (
support_status,
)
from datahub.ingestion.api.registry import import_path
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
@ -133,6 +140,18 @@ class KafkaSourceConfig(
)
def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
**connection.consumer_config,
}
)
@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):
topics_scanned: int = 0
@ -145,6 +164,45 @@ class KafkaSourceReport(StaleEntityRemovalSourceReport):
self.filtered.append(topic)
class KafkaConnectionTest:
def __init__(self, config_dict: dict):
self.config = KafkaSourceConfig.parse_obj_allow_extras(config_dict)
self.report = KafkaSourceReport()
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.config.connection
)
def get_connection_test(self) -> TestConnectionReport:
capability_report = {
SourceCapability.SCHEMA_METADATA: self.schema_registry_connectivity(),
}
return TestConnectionReport(
basic_connectivity=self.basic_connectivity(),
capability_report={
k: v for k, v in capability_report.items() if v is not None
},
)
def basic_connectivity(self) -> CapabilityReport:
try:
self.consumer.list_topics(timeout=10)
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))
def schema_registry_connectivity(self) -> CapabilityReport:
try:
SchemaRegistryClient(
{
"url": self.config.connection.schema_registry_url,
**self.config.connection.schema_registry_config,
}
).get_subjects()
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))
@platform_name("Kafka")
@config_class(KafkaSourceConfig)
@support_status(SupportStatus.CERTIFIED)
@ -160,7 +218,7 @@ class KafkaSourceReport(StaleEntityRemovalSourceReport):
SourceCapability.SCHEMA_METADATA,
"Schemas associated with each topic are extracted from the schema registry. Avro and Protobuf (certified), JSON (incubating). Schema references are supported.",
)
class KafkaSource(StatefulIngestionSourceBase):
class KafkaSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Topics from the Kafka broker
@ -183,12 +241,8 @@ class KafkaSource(StatefulIngestionSourceBase):
def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext):
super().__init__(config, ctx)
self.source_config: KafkaSourceConfig = config
self.consumer: confluent_kafka.Consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": self.source_config.connection.bootstrap,
**self.source_config.connection.consumer_config,
}
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.source_config.connection
)
self.init_kafka_admin_client()
self.report: KafkaSourceReport = KafkaSourceReport()
@ -226,6 +280,10 @@ class KafkaSource(StatefulIngestionSourceBase):
f"Failed to create Kafka Admin Client due to error {e}.",
)
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
return KafkaConnectionTest(config_dict).get_connection_test()
@classmethod
def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource":
config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict)
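Because the Kafka check also probes the schema registry, the returned report carries a per-capability breakdown on top of basic connectivity. A sketch of driving it directly, mirroring the connection settings used in the integration test below (host/port values are placeholders):

from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.kafka import KafkaSource

report = KafkaSource.test_connection(
    {
        "connection": {
            "bootstrap": "localhost:9092",  # placeholder broker
            "schema_registry_url": "http://localhost:8081",  # placeholder registry
        },
    }
)
print(report.basic_connectivity)
if report.capability_report:
    # SCHEMA_METADATA reflects whether the schema registry was reachable.
    print(report.capability_report[SourceCapability.SCHEMA_METADATA])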

View File

@ -19,7 +19,13 @@ from datahub.ingestion.api.decorators import (
platform_name,
support_status,
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceReport
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceReport,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import (
@ -1147,7 +1153,7 @@ class Mapper:
SourceCapability.LINEAGE_FINE,
"Disabled by default, configured using `extract_column_level_lineage`. ",
)
class PowerBiDashboardSource(StatefulIngestionSourceBase):
class PowerBiDashboardSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Power BI dashboards, tiles and datasets
@ -1186,6 +1192,18 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
self, self.source_config, self.ctx
)
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
PowerBiAPI(PowerBiDashboardSourceConfig.parse_obj_allow_extras(config_dict))
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report
@classmethod
def create(cls, config_dict, ctx):
config = PowerBiDashboardSourceConfig.parse_obj(config_dict)

View File

@ -15,6 +15,7 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
import sqlalchemy.dialects.postgresql.base
@ -35,7 +36,12 @@ from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.sql_parsing_builder import SqlParsingBuilder
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.incremental_lineage_helper import auto_incremental_lineage
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import (
DatasetContainerSubTypes,
@ -298,7 +304,7 @@ class ProfileMetadata:
dataset_name_to_storage_bytes: Dict[str, int] = field(default_factory=dict)
class SQLAlchemySource(StatefulIngestionSourceBase):
class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
"""A Base class for all SQL Sources that use SQLAlchemy to extend"""
def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str):
@ -348,6 +354,22 @@ class SQLAlchemySource(StatefulIngestionSourceBase):
else:
self._view_definition_cache = {}
@classmethod
def test_connection(cls, config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source = cast(
SQLAlchemySource,
cls.create(config_dict, PipelineContext(run_id="test_connection")),
)
list(source.get_inspectors())
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report
def warn(self, log: logging.Logger, key: str, reason: str) -> None:
self.report.report_warning(key, reason[:100])
log.warning(f"{key} => {reason}")
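SQLAlchemySource implements test_connection as a classmethod rather than a staticmethod: it instantiates a real source via cls.create and then lists inspectors, so every concrete SQLAlchemy-based source inherits a working connection test. The MySQL integration test below relies on exactly that; a sketch of the same call (credentials are placeholders):

from datahub.ingestion.source.sql.mysql import MySQLSource

report = MySQLSource.test_connection(
    {
        "host_port": "localhost:3306",  # placeholder credentials
        "database": "northwind",
        "username": "root",
        "password": "example",
    }
)
print(report.basic_connectivity)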

View File

@ -58,7 +58,13 @@ from datahub.ingestion.api.decorators import (
platform_name,
support_status,
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, Source
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
Source,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source import tableau_constant as c
from datahub.ingestion.source.common.subtypes import (
@ -469,7 +475,7 @@ class TableauSourceReport(StaleEntityRemovalSourceReport):
SourceCapability.LINEAGE_FINE,
"Enabled by default, configure using `extract_column_level_lineage`",
)
class TableauSource(StatefulIngestionSourceBase):
class TableauSource(StatefulIngestionSourceBase, TestableSource):
platform = "tableau"
def __hash__(self):
@ -509,6 +515,19 @@ class TableauSource(StatefulIngestionSourceBase):
self._authenticate()
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = TableauConfig.parse_obj_allow_extras(config_dict)
source_config.make_tableau_client()
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report
def close(self) -> None:
try:
if self.server is not None:

View File

@ -143,7 +143,7 @@ class BaseSnowflakeConfig(ConfigModel):
"'oauth_config' is none but should be set when using OAUTH_AUTHENTICATOR authentication"
)
if oauth_config.use_certificate is True:
if oauth_config.provider == OAuthIdentityProvider.OKTA.value:
if oauth_config.provider == OAuthIdentityProvider.OKTA:
raise ValueError(
"Certificate authentication is not supported for Okta."
)

View File

@ -10,20 +10,25 @@ from datahub.configuration.common import DynamicTypedConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.dbt.dbt_common import DBTEntitiesEnabled, EmitDirective
from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig
from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig, DBTCoreSource
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_trino_modified_type,
)
from tests.test_helpers import mce_helpers
from tests.test_helpers import mce_helpers, test_connection_helpers
FROZEN_TIME = "2022-02-03 07:00:00"
GMS_PORT = 8080
GMS_SERVER = f"http://localhost:{GMS_PORT}"
@pytest.fixture(scope="module")
def test_resources_dir(pytestconfig):
return pytestconfig.rootpath / "tests/integration/dbt"
@dataclass
class DbtTestConfig:
run_id: str
@ -195,7 +200,14 @@ class DbtTestConfig:
)
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_ingest(dbt_test_config, pytestconfig, tmp_path, mock_time, requests_mock):
def test_dbt_ingest(
dbt_test_config,
test_resources_dir,
pytestconfig,
tmp_path,
mock_time,
requests_mock,
):
config: DbtTestConfig = dbt_test_config
test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt"
@ -233,11 +245,48 @@ def test_dbt_ingest(dbt_test_config, pytestconfig, tmp_path, mock_time, requests
)
@pytest.mark.parametrize(
"config_dict, is_success",
[
(
{
"manifest_path": "dbt_manifest.json",
"catalog_path": "dbt_catalog.json",
"target_platform": "postgres",
},
True,
),
(
{
"manifest_path": "dbt_manifest.json",
"catalog_path": "dbt_catalog-this-file-does-not-exist.json",
"target_platform": "postgres",
},
False,
),
],
)
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_tests(pytestconfig, tmp_path, mock_time, **kwargs):
test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt"
def test_dbt_test_connection(test_resources_dir, config_dict, is_success):
config_dict["manifest_path"] = str(
(test_resources_dir / config_dict["manifest_path"]).resolve()
)
config_dict["catalog_path"] = str(
(test_resources_dir / config_dict["catalog_path"]).resolve()
)
report = test_connection_helpers.run_test_connection(DBTCoreSource, config_dict)
if is_success:
test_connection_helpers.assert_basic_connectivity_success(report)
else:
test_connection_helpers.assert_basic_connectivity_failure(
report, "No such file or directory"
)
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs):
# Run the metadata ingestion pipeline.
output_file = tmp_path / "dbt_test_events.json"
golden_path = test_resources_dir / "dbt_test_events_golden.json"
@ -340,9 +389,9 @@ def test_resolve_athena_modified_type(data_type, expected_data_type):
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_tests_only_assertions(pytestconfig, tmp_path, mock_time, **kwargs):
test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt"
def test_dbt_tests_only_assertions(
test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs
):
# Run the metadata ingestion pipeline.
output_file = tmp_path / "test_only_assertions.json"
@ -418,10 +467,8 @@ def test_dbt_tests_only_assertions(pytestconfig, tmp_path, mock_time, **kwargs):
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_only_test_definitions_and_results(
pytestconfig, tmp_path, mock_time, **kwargs
test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs
):
test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt"
# Run the metadata ingestion pipeline.
output_file = tmp_path / "test_only_definitions_and_assertions.json"

View File

@ -3,18 +3,22 @@ import subprocess
import pytest
from freezegun import freeze_time
from tests.test_helpers import mce_helpers
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.kafka import KafkaSource
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
FROZEN_TIME = "2020-04-14 07:00:00"
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time):
test_resources_dir = pytestconfig.rootpath / "tests/integration/kafka"
@pytest.fixture(scope="module")
def test_resources_dir(pytestconfig):
return pytestconfig.rootpath / "tests/integration/kafka"
@pytest.fixture(scope="module")
def mock_kafka_service(docker_compose_runner, test_resources_dir):
with docker_compose_runner(
test_resources_dir / "docker-compose.yml", "kafka", cleanup=False
) as docker_services:
@ -31,14 +35,67 @@ def test_kafka_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time):
command = f"{test_resources_dir}/send_records.sh {test_resources_dir}"
subprocess.run(command, shell=True, check=True)
# Run the metadata ingestion pipeline.
config_file = (test_resources_dir / "kafka_to_file.yml").resolve()
run_datahub_cmd(["ingest", "-c", f"{config_file}"], tmp_path=tmp_path)
yield docker_compose_runner
# Verify the output.
mce_helpers.check_golden_file(
pytestconfig,
output_path=tmp_path / "kafka_mces.json",
golden_path=test_resources_dir / "kafka_mces_golden.json",
ignore_paths=[],
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_ingest(
mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time
):
# Run the metadata ingestion pipeline.
config_file = (test_resources_dir / "kafka_to_file.yml").resolve()
run_datahub_cmd(["ingest", "-c", f"{config_file}"], tmp_path=tmp_path)
# Verify the output.
mce_helpers.check_golden_file(
pytestconfig,
output_path=tmp_path / "kafka_mces.json",
golden_path=test_resources_dir / "kafka_mces_golden.json",
ignore_paths=[],
)
@pytest.mark.parametrize(
"config_dict, is_success",
[
(
{
"connection": {
"bootstrap": "localhost:29092",
"schema_registry_url": "http://localhost:28081",
},
},
True,
),
(
{
"connection": {
"bootstrap": "localhost:2909",
"schema_registry_url": "http://localhost:2808",
},
},
False,
),
],
)
@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_kafka_test_connection(mock_kafka_service, config_dict, is_success):
report = test_connection_helpers.run_test_connection(KafkaSource, config_dict)
if is_success:
test_connection_helpers.assert_basic_connectivity_success(report)
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
success_capabilities=[SourceCapability.SCHEMA_METADATA],
)
else:
test_connection_helpers.assert_basic_connectivity_failure(
report, "Failed to get metadata"
)
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
failure_capabilities={
SourceCapability.SCHEMA_METADATA: "Failed to establish a new connection"
},
)

View File

@ -3,7 +3,8 @@ import subprocess
import pytest
from freezegun import freeze_time
from tests.test_helpers import mce_helpers
from datahub.ingestion.source.sql.mysql import MySQLSource
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
@ -75,3 +76,38 @@ def test_mysql_ingest_no_db(
output_path=tmp_path / "mysql_mces.json",
golden_path=test_resources_dir / golden_file,
)
@pytest.mark.parametrize(
"config_dict, is_success",
[
(
{
"host_port": "localhost:53307",
"database": "northwind",
"username": "root",
"password": "example",
},
True,
),
(
{
"host_port": "localhost:5330",
"database": "wrong_db",
"username": "wrong_user",
"password": "wrong_pass",
},
False,
),
],
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_mysql_test_connection(mysql_runner, config_dict, is_success):
report = test_connection_helpers.run_test_connection(MySQLSource, config_dict)
if is_success:
test_connection_helpers.assert_basic_connectivity_success(report)
else:
test_connection_helpers.assert_basic_connectivity_failure(
report, "Connection refused"
)

View File

@ -21,7 +21,7 @@ from datahub.ingestion.source.powerbi.rest_api_wrapper.data_classes import (
Report,
Workspace,
)
from tests.test_helpers import mce_helpers
from tests.test_helpers import mce_helpers, test_connection_helpers
pytestmark = pytest.mark.integration_batch_2
FROZEN_TIME = "2022-02-03 07:00:00"
@ -681,6 +681,27 @@ def test_powerbi_ingest(
)
@freeze_time(FROZEN_TIME)
@mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca)
@pytest.mark.integration
def test_powerbi_test_connection_success(mock_msal):
report = test_connection_helpers.run_test_connection(
PowerBiDashboardSource, default_source_config()
)
test_connection_helpers.assert_basic_connectivity_success(report)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_powerbi_test_connection_failure():
report = test_connection_helpers.run_test_connection(
PowerBiDashboardSource, default_source_config()
)
test_connection_helpers.assert_basic_connectivity_failure(
report, "Unable to get authority configuration"
)
@freeze_time(FROZEN_TIME)
@mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca)
@pytest.mark.integration

View File

@ -28,7 +28,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
)
from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
from datahub.utilities.sqlglot_lineage import SqlParsingResult
from tests.test_helpers import mce_helpers
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
@ -290,6 +290,25 @@ def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph):
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_tableau_test_connection_success():
with mock.patch("datahub.ingestion.source.tableau.Server"):
report = test_connection_helpers.run_test_connection(
TableauSource, config_source_default
)
test_connection_helpers.assert_basic_connectivity_success(report)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_tableau_test_connection_failure():
report = test_connection_helpers.run_test_connection(
TableauSource, config_source_default
)
test_connection_helpers.assert_basic_connectivity_failure(report, "Unable to login")
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_tableau_cll_ingest(pytestconfig, tmp_path, mock_datahub_graph):

View File

@ -0,0 +1,47 @@
from typing import Dict, List, Optional, Type, Union
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
def run_test_connection(
source_cls: Type[TestableSource], config_dict: Dict
) -> TestConnectionReport:
return source_cls.test_connection(config_dict)
def assert_basic_connectivity_success(report: TestConnectionReport) -> None:
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
def assert_basic_connectivity_failure(
report: TestConnectionReport, expected_reason: str
) -> None:
assert report is not None
assert report.basic_connectivity
assert not report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason
assert expected_reason in report.basic_connectivity.failure_reason
def assert_capability_report(
capability_report: Optional[Dict[Union[SourceCapability, str], CapabilityReport]],
success_capabilities: List[SourceCapability] = [],
failure_capabilities: Dict[SourceCapability, str] = {},
) -> None:
assert capability_report
for capability in success_capabilities:
assert capability_report[capability]
assert capability_report[capability].failure_reason is None
for capability, expected_reason in failure_capabilities.items():
assert not capability_report[capability].capable
failure_reason = capability_report[capability].failure_reason
assert failure_reason
assert expected_reason in failure_reason
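Since these helpers only inspect report fields, they can be sanity-checked without any source at all. A small sketch, constructing reports by hand (the failure strings are arbitrary examples):

from datahub.ingestion.api.source import (
    CapabilityReport,
    SourceCapability,
    TestConnectionReport,
)
from tests.test_helpers import test_connection_helpers

ok = TestConnectionReport(basic_connectivity=CapabilityReport(capable=True))
test_connection_helpers.assert_basic_connectivity_success(ok)

bad = TestConnectionReport(
    basic_connectivity=CapabilityReport(
        capable=False, failure_reason="Connection refused"
    ),
    capability_report={
        SourceCapability.SCHEMA_METADATA: CapabilityReport(
            capable=False, failure_reason="registry unreachable"
        )
    },
)
test_connection_helpers.assert_basic_connectivity_failure(bad, "Connection refused")
test_connection_helpers.assert_capability_report(
    capability_report=bad.capability_report,
    failure_capabilities={SourceCapability.SCHEMA_METADATA: "registry unreachable"},
)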

View File

@ -1,3 +1,4 @@
from typing import Any, Dict
from unittest.mock import MagicMock, patch
import pytest
@ -24,10 +25,20 @@ from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
SnowflakeObjectAccessEntry,
)
from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source
from tests.test_helpers import test_connection_helpers
default_oauth_dict: Dict[str, Any] = {
"client_id": "client_id",
"client_secret": "secret",
"use_certificate": False,
"provider": "microsoft",
"scopes": ["datahub_role"],
"authority_url": "https://dev-abc.okta.com/oauth2/def/v1/token",
}
def test_snowflake_source_throws_error_on_account_id_missing():
with pytest.raises(ValidationError):
with pytest.raises(ValidationError, match="account_id\n field required"):
SnowflakeV2Config.parse_obj(
{
"username": "user",
@ -37,27 +48,21 @@ def test_snowflake_source_throws_error_on_account_id_missing():
def test_no_client_id_invalid_oauth_config():
oauth_dict = {
"provider": "microsoft",
"scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"],
"client_secret": "6Hb9apkbc6HD7",
"authority_url": "https://login.microsoftonline.com/yourorganisation.com",
}
with pytest.raises(ValueError):
oauth_dict = default_oauth_dict.copy()
del oauth_dict["client_id"]
with pytest.raises(ValueError, match="client_id\n field required"):
OAuthConfiguration.parse_obj(oauth_dict)
def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_false():
oauth_dict = {
"client_id": "882e9831-7ea51cb2b954",
"provider": "microsoft",
"scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"],
"use_certificate": False,
"authority_url": "https://login.microsoftonline.com/yourorganisation.com",
}
oauth_dict = default_oauth_dict.copy()
del oauth_dict["client_secret"]
OAuthConfiguration.parse_obj(oauth_dict)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="'oauth_config.client_secret' was none but should be set when using use_certificate false for oauth_config",
):
SnowflakeV2Config.parse_obj(
{
"account_id": "test",
@ -68,16 +73,13 @@ def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_f
def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_certificate_is_true():
oauth_dict = {
"client_id": "882e9831-7ea51cb2b954",
"provider": "microsoft",
"scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"],
"use_certificate": True,
"authority_url": "https://login.microsoftonline.com/yourorganisation.com",
"encoded_oauth_public_key": "fkdsfhkshfkjsdfiuwrwfkjhsfskfhksjf==",
}
oauth_dict = default_oauth_dict.copy()
oauth_dict["use_certificate"] = True
OAuthConfiguration.parse_obj(oauth_dict)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="'base64_encoded_oauth_private_key' was none but should be set when using certificate for oauth_config",
):
SnowflakeV2Config.parse_obj(
{
"account_id": "test",
@ -88,16 +90,13 @@ def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_cert
def test_snowflake_oauth_okta_does_not_support_certificate():
oauth_dict = {
"client_id": "882e9831-7ea51cb2b954",
"provider": "okta",
"scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"],
"use_certificate": True,
"authority_url": "https://login.microsoftonline.com/yourorganisation.com",
"encoded_oauth_public_key": "fkdsfhkshfkjsdfiuwrwfkjhsfskfhksjf==",
}
oauth_dict = default_oauth_dict.copy()
oauth_dict["use_certificate"] = True
oauth_dict["provider"] = "okta"
OAuthConfiguration.parse_obj(oauth_dict)
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="Certificate authentication is not supported for Okta."
):
SnowflakeV2Config.parse_obj(
{
"account_id": "test",
@ -108,79 +107,52 @@ def test_snowflake_oauth_okta_does_not_support_certificate():
def test_snowflake_oauth_happy_paths():
okta_dict = {
"client_id": "client_id",
"client_secret": "secret",
"provider": "okta",
"scopes": ["datahub_role"],
"authority_url": "https://dev-abc.okta.com/oauth2/def/v1/token",
}
oauth_dict = default_oauth_dict.copy()
oauth_dict["provider"] = "okta"
assert SnowflakeV2Config.parse_obj(
{
"account_id": "test",
"authentication_type": "OAUTH_AUTHENTICATOR",
"oauth_config": okta_dict,
"oauth_config": oauth_dict,
}
)
oauth_dict["use_certificate"] = True
oauth_dict["provider"] = "microsoft"
oauth_dict["encoded_oauth_public_key"] = "publickey"
oauth_dict["encoded_oauth_private_key"] = "privatekey"
assert SnowflakeV2Config.parse_obj(
{
"account_id": "test",
"authentication_type": "OAUTH_AUTHENTICATOR",
"oauth_config": oauth_dict,
}
)
microsoft_dict = {
"client_id": "client_id",
"provider": "microsoft",
"scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"],
"use_certificate": True,
"authority_url": "https://login.microsoftonline.com/yourorganisation.com",
"encoded_oauth_public_key": "publickey",
"encoded_oauth_private_key": "privatekey",
}
assert SnowflakeV2Config.parse_obj(
{
"account_id": "test",
"authentication_type": "OAUTH_AUTHENTICATOR",
"oauth_config": microsoft_dict,
}
)
default_config_dict: Dict[str, Any] = {
"username": "user",
"password": "password",
"account_id": "https://acctname.snowflakecomputing.com",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
def test_account_id_is_added_when_host_port_is_present():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"host_port": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
config_dict = default_config_dict.copy()
del config_dict["account_id"]
config_dict["host_port"] = "acctname"
config = SnowflakeV2Config.parse_obj(config_dict)
assert config.account_id == "acctname"
def test_account_id_with_snowflake_host_suffix():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "https://acctname.snowflakecomputing.com",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
config = SnowflakeV2Config.parse_obj(default_config_dict)
assert config.account_id == "acctname"
def test_snowflake_uri_default_authentication():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
config = SnowflakeV2Config.parse_obj(default_config_dict)
assert config.get_sql_alchemy_url() == (
"snowflake://user:password@acctname"
"?application=acryl_datahub"
@ -191,17 +163,10 @@ def test_snowflake_uri_default_authentication():
def test_snowflake_uri_external_browser_authentication():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
"authentication_type": "EXTERNAL_BROWSER_AUTHENTICATOR",
}
)
config_dict = default_config_dict.copy()
del config_dict["password"]
config_dict["authentication_type"] = "EXTERNAL_BROWSER_AUTHENTICATOR"
config = SnowflakeV2Config.parse_obj(config_dict)
assert config.get_sql_alchemy_url() == (
"snowflake://user@acctname"
"?application=acryl_datahub"
@ -212,18 +177,12 @@ def test_snowflake_uri_external_browser_authentication():
def test_snowflake_uri_key_pair_authentication():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
"authentication_type": "KEY_PAIR_AUTHENTICATOR",
"private_key_path": "/a/random/path",
"private_key_password": "a_random_password",
}
)
config_dict = default_config_dict.copy()
del config_dict["password"]
config_dict["authentication_type"] = "KEY_PAIR_AUTHENTICATOR"
config_dict["private_key_path"] = "/a/random/path"
config_dict["private_key_password"] = "a_random_password"
config = SnowflakeV2Config.parse_obj(config_dict)
assert config.get_sql_alchemy_url() == (
"snowflake://user@acctname"
@ -235,63 +194,35 @@ def test_snowflake_uri_key_pair_authentication():
def test_options_contain_connect_args():
config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
config = SnowflakeV2Config.parse_obj(default_config_dict)
connect_args = config.get_options().get("connect_args")
assert connect_args is not None
def test_snowflake_config_with_view_lineage_no_table_lineage_throws_error():
with pytest.raises(ValidationError):
SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
"include_view_lineage": True,
"include_table_lineage": False,
}
)
config_dict = default_config_dict.copy()
config_dict["include_view_lineage"] = True
config_dict["include_table_lineage"] = False
with pytest.raises(
ValidationError,
match="include_table_lineage must be True for include_view_lineage to be set",
):
SnowflakeV2Config.parse_obj(config_dict)
def test_snowflake_config_with_column_lineage_no_table_lineage_throws_error():
with pytest.raises(ValidationError):
SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
"include_column_lineage": True,
"include_table_lineage": False,
}
)
config_dict = default_config_dict.copy()
config_dict["include_column_lineage"] = True
config_dict["include_table_lineage"] = False
with pytest.raises(
ValidationError,
match="include_table_lineage must be True for include_column_lineage to be set",
):
SnowflakeV2Config.parse_obj(config_dict)
def test_snowflake_config_with_no_connect_args_returns_base_connect_args():
config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(default_config_dict)
assert config.get_options()["connect_args"] is not None
assert config.get_options()["connect_args"] == {
CLIENT_PREFETCH_THREADS: 10,
@ -300,7 +231,10 @@ def test_snowflake_config_with_no_connect_args_returns_base_connect_args():
def test_private_key_set_but_auth_not_changed():
with pytest.raises(ValidationError):
with pytest.raises(
ValidationError,
match="Either `private_key` and `private_key_path` is set but `authentication_type` is DEFAULT_AUTHENTICATOR. Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication",
):
SnowflakeV2Config.parse_obj(
{
"account_id": "acctname",
@ -310,19 +244,11 @@ def test_private_key_set_but_auth_not_changed():
def test_snowflake_config_with_connect_args_overrides_base_connect_args():
config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(
{
"username": "user",
"password": "password",
"account_id": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
"connect_args": {
CLIENT_PREFETCH_THREADS: 5,
},
}
)
config_dict = default_config_dict.copy()
config_dict["connect_args"] = {
CLIENT_PREFETCH_THREADS: 5,
}
config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(config_dict)
assert config.get_options()["connect_args"] is not None
assert config.get_options()["connect_args"][CLIENT_PREFETCH_THREADS] == 5
assert config.get_options()["connect_args"][CLIENT_SESSION_KEEP_ALIVE] is True
@ -331,35 +257,20 @@ def test_snowflake_config_with_connect_args_overrides_base_connect_args():
@patch("snowflake.connector.connect")
def test_test_connection_failure(mock_connect):
mock_connect.side_effect = Exception("Failed to connect to snowflake")
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
report = SnowflakeV2Source.test_connection(config)
assert report is not None
assert report.basic_connectivity
assert not report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason
assert "Failed to connect to snowflake" in report.basic_connectivity.failure_reason
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_failure(
report, "Failed to connect to snowflake"
)
@patch("snowflake.connector.connect")
def test_test_connection_basic_success(mock_connect):
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
report = SnowflakeV2Source.test_connection(config)
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_success(report)
def setup_mock_connect(mock_connect, query_results=None):
@ -400,31 +311,18 @@ def test_test_connection_no_warehouse(mock_connect):
return []
raise ValueError(f"Unexpected query: {query}")
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
setup_mock_connect(mock_connect, query_results)
report = SnowflakeV2Source.test_connection(config)
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_success(report)
assert report.capability_report
assert report.capability_report[SourceCapability.CONTAINERS].capable
assert not report.capability_report[SourceCapability.SCHEMA_METADATA].capable
failure_reason = report.capability_report[
SourceCapability.SCHEMA_METADATA
].failure_reason
assert failure_reason
assert (
"Current role TEST_ROLE does not have permissions to use warehouse"
in failure_reason
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
success_capabilities=[SourceCapability.CONTAINERS],
failure_capabilities={
SourceCapability.SCHEMA_METADATA: "Current role TEST_ROLE does not have permissions to use warehouse"
},
)
@ -445,25 +343,17 @@ def test_test_connection_capability_schema_failure(mock_connect):
setup_mock_connect(mock_connect, query_results)
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
report = SnowflakeV2Source.test_connection(config)
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
assert report.capability_report
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_success(report)
assert report.capability_report[SourceCapability.CONTAINERS].capable
assert not report.capability_report[SourceCapability.SCHEMA_METADATA].capable
assert (
report.capability_report[SourceCapability.SCHEMA_METADATA].failure_reason
is not None
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
success_capabilities=[SourceCapability.CONTAINERS],
failure_capabilities={
SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them"
},
)
@ -488,24 +378,19 @@ def test_test_connection_capability_schema_success(mock_connect):
setup_mock_connect(mock_connect, query_results)
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
report = SnowflakeV2Source.test_connection(config)
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_success(report)
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
assert report.capability_report
assert report.capability_report[SourceCapability.CONTAINERS].capable
assert report.capability_report[SourceCapability.SCHEMA_METADATA].capable
assert report.capability_report[SourceCapability.DESCRIPTIONS].capable
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
success_capabilities=[
SourceCapability.CONTAINERS,
SourceCapability.SCHEMA_METADATA,
SourceCapability.DESCRIPTIONS,
],
)
@patch("snowflake.connector.connect")
@ -538,25 +423,21 @@ def test_test_connection_capability_all_success(mock_connect):
setup_mock_connect(mock_connect, query_results)
config = {
"username": "user",
"password": "password",
"account_id": "missing",
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
report = SnowflakeV2Source.test_connection(config)
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
assert report.capability_report
report = test_connection_helpers.run_test_connection(
SnowflakeV2Source, default_config_dict
)
test_connection_helpers.assert_basic_connectivity_success(report)
assert report.capability_report[SourceCapability.CONTAINERS].capable
assert report.capability_report[SourceCapability.SCHEMA_METADATA].capable
assert report.capability_report[SourceCapability.DATA_PROFILING].capable
assert report.capability_report[SourceCapability.DESCRIPTIONS].capable
assert report.capability_report[SourceCapability.LINEAGE_COARSE].capable
test_connection_helpers.assert_capability_report(
capability_report=report.capability_report,
success_capabilities=[
SourceCapability.CONTAINERS,
SourceCapability.SCHEMA_METADATA,
SourceCapability.DATA_PROFILING,
SourceCapability.DESCRIPTIONS,
SourceCapability.LINEAGE_COARSE,
],
)
def test_aws_cloud_region_from_snowflake_region_id():
@ -610,11 +491,10 @@ def test_azure_cloud_region_from_snowflake_region_id():
def test_unknown_cloud_region_from_snowflake_region_id():
with pytest.raises(Exception) as e:
with pytest.raises(Exception, match="Unknown snowflake region"):
SnowflakeV2Source.get_cloud_region_from_snowflake_region_id(
"somecloud_someregion"
)
assert "Unknown snowflake region" in str(e)
def test_snowflake_object_access_entry_missing_object_id():

View File

@ -1,8 +1,7 @@
from typing import Dict
from unittest.mock import Mock
from unittest import mock
import pytest
from sqlalchemy.engine.reflection import Inspector
from datahub.ingestion.source.sql.sql_common import PipelineContext, SQLAlchemySource
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
@ -13,19 +12,24 @@ from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
class _TestSQLAlchemyConfig(SQLCommonConfig):
def get_sql_alchemy_url(self):
pass
return "mysql+pymysql://user:pass@localhost:5330"
class _TestSQLAlchemySource(SQLAlchemySource):
pass
@classmethod
def create(cls, config_dict, ctx):
config = _TestSQLAlchemyConfig.parse_obj(config_dict)
return cls(config, ctx, "TEST")
def get_test_sql_alchemy_source():
return _TestSQLAlchemySource.create(
config_dict={}, ctx=PipelineContext(run_id="test_ctx")
)
def test_generate_foreign_key():
config: SQLCommonConfig = _TestSQLAlchemyConfig()
ctx: PipelineContext = PipelineContext(run_id="test_ctx")
platform: str = "TEST"
inspector: Inspector = Mock()
source = _TestSQLAlchemySource(config=config, ctx=ctx, platform=platform)
source = get_test_sql_alchemy_source()
fk_dict: Dict[str, str] = {
"name": "test_constraint",
"referred_table": "test_table",
@ -37,7 +41,7 @@ def test_generate_foreign_key():
dataset_urn="test_urn",
schema="test_schema",
fk_dict=fk_dict,
inspector=inspector,
inspector=mock.Mock(),
)
assert fk_dict.get("name") == foreign_key.name
@ -48,11 +52,7 @@ def test_generate_foreign_key():
def test_use_source_schema_for_foreign_key_if_not_specified():
config: SQLCommonConfig = _TestSQLAlchemyConfig()
ctx: PipelineContext = PipelineContext(run_id="test_ctx")
platform: str = "TEST"
inspector: Inspector = Mock()
source = _TestSQLAlchemySource(config=config, ctx=ctx, platform=platform)
source = get_test_sql_alchemy_source()
fk_dict: Dict[str, str] = {
"name": "test_constraint",
"referred_table": "test_table",
@ -63,7 +63,7 @@ def test_use_source_schema_for_foreign_key_if_not_specified():
dataset_urn="test_urn",
schema="test_schema",
fk_dict=fk_dict,
inspector=inspector,
inspector=mock.Mock(),
)
assert fk_dict.get("name") == foreign_key.name
@ -105,14 +105,32 @@ def test_get_platform_from_sqlalchemy_uri(uri: str, expected_platform: str) -> N
def test_get_db_schema_with_dots_in_view_name():
config: SQLCommonConfig = _TestSQLAlchemyConfig()
ctx: PipelineContext = PipelineContext(run_id="test_ctx")
platform: str = "TEST"
source = _TestSQLAlchemySource(config=config, ctx=ctx, platform=platform)
source = get_test_sql_alchemy_source()
database, schema = source.get_db_schema(
dataset_identifier="database.schema.long.view.name1"
)
assert database == "database"
assert schema == "schema"
def test_test_connection_success():
source = get_test_sql_alchemy_source()
with mock.patch(
"datahub.ingestion.source.sql.sql_common.SQLAlchemySource.get_inspectors",
side_effect=lambda: [],
):
report = source.test_connection({})
assert report is not None
assert report.basic_connectivity
assert report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason is None
def test_test_connection_failure():
source = get_test_sql_alchemy_source()
report = source.test_connection({})
assert report is not None
assert report.basic_connectivity
assert not report.basic_connectivity.capable
assert report.basic_connectivity.failure_reason
assert "Connection refused" in report.basic_connectivity.failure_reason