mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-02 19:58:59 +00:00
feat(cassandra): Support ssl auth with cassandra (#13465)
This commit is contained in:
parent
a97af36e93
commit
aeda8f4c95
@ -1,3 +1,4 @@
|
||||
import ssl
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@ -128,6 +129,23 @@ class CassandraAPI:
|
||||
|
||||
self._cassandra_session = cluster.connect()
|
||||
return True
|
||||
|
||||
ssl_context = None
|
||||
if self.config.ssl_ca_certs:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ssl_context.load_verify_locations(self.config.ssl_ca_certs)
|
||||
if self.config.ssl_certfile and self.config.ssl_keyfile:
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=self.config.ssl_certfile,
|
||||
keyfile=self.config.ssl_keyfile,
|
||||
)
|
||||
elif self.config.ssl_certfile or self.config.ssl_keyfile:
|
||||
# If one is provided, the other must be too.
|
||||
# This is a simplification; real-world scenarios might allow one without the other depending on setup.
|
||||
raise ValueError(
|
||||
"Both ssl_certfile and ssl_keyfile must be provided if one is specified."
|
||||
)
|
||||
|
||||
if self.config.username and self.config.password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.config.username, password=self.config.password
|
||||
@ -136,12 +154,14 @@ class CassandraAPI:
|
||||
[self.config.contact_point],
|
||||
port=self.config.port,
|
||||
auth_provider=auth_provider,
|
||||
ssl_context=ssl_context,
|
||||
load_balancing_policy=None,
|
||||
)
|
||||
else:
|
||||
cluster = Cluster(
|
||||
[self.config.contact_point],
|
||||
port=self.config.port,
|
||||
ssl_context=ssl_context,
|
||||
load_balancing_policy=None,
|
||||
)
|
||||
|
||||
|
||||
@ -79,6 +79,21 @@ class CassandraSourceConfig(
|
||||
description="Configuration for cloud-based Cassandra, such as DataStax Astra DB.",
|
||||
)
|
||||
|
||||
ssl_ca_certs: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the CA certificate file for SSL connections.",
|
||||
)
|
||||
|
||||
ssl_certfile: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the SSL certificate file for SSL connections.",
|
||||
)
|
||||
|
||||
ssl_keyfile: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the SSL key file for SSL connections.",
|
||||
)
|
||||
|
||||
keyspace_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="Regex patterns to filter keyspaces for ingestion.",
|
||||
|
||||
@ -3,10 +3,20 @@ import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Unit tests for CassandraAPI SSL Configuration
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from datahub.ingestion.api.source import SourceReport
|
||||
from datahub.ingestion.source.cassandra.cassandra import CassandraToSchemaFieldConverter
|
||||
from datahub.ingestion.source.cassandra.cassandra_api import CassandraColumn
|
||||
from datahub.ingestion.source.cassandra.cassandra_api import (
|
||||
CassandraAPI,
|
||||
CassandraColumn,
|
||||
)
|
||||
from datahub.ingestion.source.cassandra.cassandra_config import (
|
||||
CassandraSourceConfig,
|
||||
)
|
||||
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -78,3 +88,84 @@ def test_cassandra_schema_conversion(
|
||||
def test_no_properties_in_mappings_schema() -> None:
    """An empty column list must yield no schema fields."""
    converted = CassandraToSchemaFieldConverter.get_schema_fields([])
    schema_fields = list(converted)
    assert schema_fields == []
|
||||
|
||||
|
||||
def _get_base_config_dict() -> dict:
|
||||
return {
|
||||
"contact_point": "localhost",
|
||||
"port": 9042,
|
||||
}
|
||||
|
||||
|
||||
def test_authenticate_no_ssl():
    """With no ssl_* option set, authenticate() hands Cluster no SSL context."""
    source_report = MagicMock(spec=SourceReport)
    source_config = CassandraSourceConfig.parse_obj(_get_base_config_dict())
    cassandra_api = CassandraAPI(source_config, source_report)

    cluster_target = "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    with patch(cluster_target) as mock_cluster:
        # Make the mocked cluster return a usable session from connect().
        mock_cluster.return_value.connect.return_value = MagicMock()

        assert cassandra_api.authenticate()

        mock_cluster.assert_called_once()
        _, cluster_kwargs = mock_cluster.call_args
        assert cluster_kwargs.get("ssl_context") is None
        source_report.failure.assert_not_called()
|
||||
|
||||
|
||||
def test_authenticate_ssl_ca_certs():
    """With only ssl_ca_certs set, a TLS-client context is built and passed on.

    Checks that the SSLContext is created with ``ssl.PROTOCOL_TLS_CLIENT``,
    that the CA bundle is loaded, that no client cert chain is loaded, and
    that the resulting context is passed to ``Cluster`` as ``ssl_context``.
    """
    # Local import: the real constant is needed for the protocol assertion.
    # Only ssl.SSLContext is patched below, so PROTOCOL_TLS_CLIENT stays real.
    import ssl

    config_dict = _get_base_config_dict()
    config_dict["ssl_ca_certs"] = "ca.pem"
    config = CassandraSourceConfig.parse_obj(config_dict)
    report = MagicMock(spec=SourceReport)
    api = CassandraAPI(config, report)

    with patch(
        "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    ) as mock_cluster, patch(
        "datahub.ingestion.source.cassandra.cassandra_api.ssl.SSLContext"
    ) as mock_ssl_context:
        mock_ssl_instance = MagicMock()
        mock_ssl_context.return_value = mock_ssl_instance
        mock_cluster.return_value.connect.return_value = MagicMock()

        assert api.authenticate()

        # Pin the exact protocol instead of ANY, so a regression to a
        # different protocol constant is caught.
        mock_ssl_context.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT)
        mock_ssl_instance.load_verify_locations.assert_called_once_with("ca.pem")
        mock_ssl_instance.load_cert_chain.assert_not_called()

        mock_cluster.assert_called_once()
        assert mock_cluster.call_args[1].get("ssl_context") == mock_ssl_instance
        report.failure.assert_not_called()
|
||||
|
||||
|
||||
def test_authenticate_ssl_all_certs():
    """With CA, client cert and key set, the context also loads the cert chain.

    Checks that the SSLContext is created with ``ssl.PROTOCOL_TLS_CLIENT``,
    that the CA bundle and the client cert/key pair are both loaded, and
    that the resulting context is passed to ``Cluster`` as ``ssl_context``.
    """
    # Local import: the real constant is needed for the protocol assertion.
    # Only ssl.SSLContext is patched below, so PROTOCOL_TLS_CLIENT stays real.
    import ssl

    config_dict = _get_base_config_dict()
    config_dict["ssl_ca_certs"] = "ca.pem"
    config_dict["ssl_certfile"] = "client.crt"
    config_dict["ssl_keyfile"] = "client.key"
    config = CassandraSourceConfig.parse_obj(config_dict)
    report = MagicMock(spec=SourceReport)
    api = CassandraAPI(config, report)

    with patch(
        "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    ) as mock_cluster, patch(
        "datahub.ingestion.source.cassandra.cassandra_api.ssl.SSLContext"
    ) as mock_ssl_context:
        mock_ssl_instance = MagicMock()
        mock_ssl_context.return_value = mock_ssl_instance
        mock_cluster.return_value.connect.return_value = MagicMock()

        assert api.authenticate()

        # Pin the exact protocol instead of ANY, so a regression to a
        # different protocol constant is caught.
        mock_ssl_context.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT)
        mock_ssl_instance.load_verify_locations.assert_called_once_with("ca.pem")
        mock_ssl_instance.load_cert_chain.assert_called_once_with(
            certfile="client.crt", keyfile="client.key"
        )

        mock_cluster.assert_called_once()
        assert mock_cluster.call_args[1].get("ssl_context") == mock_ssl_instance
        report.failure.assert_not_called()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user