feat(cassandra): Support ssl auth with cassandra (#13465)

This commit is contained in:
Gabe Lyons 2025-05-12 13:15:49 -04:00 committed by GitHub
parent a97af36e93
commit aeda8f4c95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 127 additions and 1 deletions

View File

@ -1,3 +1,4 @@
import ssl
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@ -128,6 +129,23 @@ class CassandraAPI:
self._cassandra_session = cluster.connect()
return True
ssl_context = None
if self.config.ssl_ca_certs:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(self.config.ssl_ca_certs)
if self.config.ssl_certfile and self.config.ssl_keyfile:
ssl_context.load_cert_chain(
certfile=self.config.ssl_certfile,
keyfile=self.config.ssl_keyfile,
)
elif self.config.ssl_certfile or self.config.ssl_keyfile:
# If one is provided, the other must be too.
# This is a simplification; real-world scenarios might allow one without the other depending on setup.
raise ValueError(
"Both ssl_certfile and ssl_keyfile must be provided if one is specified."
)
if self.config.username and self.config.password:
auth_provider = PlainTextAuthProvider(
username=self.config.username, password=self.config.password
@ -136,12 +154,14 @@ class CassandraAPI:
[self.config.contact_point],
port=self.config.port,
auth_provider=auth_provider,
ssl_context=ssl_context,
load_balancing_policy=None,
)
else:
cluster = Cluster(
[self.config.contact_point],
port=self.config.port,
ssl_context=ssl_context,
load_balancing_policy=None,
)

View File

@ -79,6 +79,21 @@ class CassandraSourceConfig(
description="Configuration for cloud-based Cassandra, such as DataStax Astra DB.",
)
ssl_ca_certs: Optional[str] = Field(
default=None,
description="Path to the CA certificate file for SSL connections.",
)
ssl_certfile: Optional[str] = Field(
default=None,
description="Path to the SSL certificate file for SSL connections.",
)
ssl_keyfile: Optional[str] = Field(
default=None,
description="Path to the SSL key file for SSL connections.",
)
keyspace_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="Regex patterns to filter keyspaces for ingestion.",

View File

@ -3,10 +3,20 @@ import logging
import re
from typing import Any, Dict, List, Tuple
# Unit tests for CassandraAPI SSL Configuration
from unittest.mock import ANY, MagicMock, patch
import pytest
from datahub.ingestion.api.source import SourceReport
from datahub.ingestion.source.cassandra.cassandra import CassandraToSchemaFieldConverter
from datahub.ingestion.source.cassandra.cassandra_api import CassandraColumn
from datahub.ingestion.source.cassandra.cassandra_api import (
CassandraAPI,
CassandraColumn,
)
from datahub.ingestion.source.cassandra.cassandra_config import (
CassandraSourceConfig,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
logger = logging.getLogger(__name__)
@ -78,3 +88,84 @@ def test_cassandra_schema_conversion(
def test_no_properties_in_mappings_schema() -> None:
    """Converting an empty input list must yield no schema fields."""
    converted = [*CassandraToSchemaFieldConverter.get_schema_fields([])]
    assert converted == []
def _get_base_config_dict() -> dict:
return {
"contact_point": "localhost",
"port": 9042,
}
def test_authenticate_no_ssl():
    """With no SSL options configured, Cluster must not receive an SSL context."""
    source_config = CassandraSourceConfig.parse_obj(_get_base_config_dict())
    source_report = MagicMock(spec=SourceReport)
    cassandra_api = CassandraAPI(source_config, source_report)

    cluster_target = "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    with patch(cluster_target) as cluster_cls:
        # connect() hands back a stand-in session object.
        cluster_cls.return_value.connect.return_value = MagicMock()
        authenticated = cassandra_api.authenticate()

    # The mock keeps its recorded calls after the patch context exits.
    assert authenticated
    cluster_cls.assert_called_once()
    _, cluster_kwargs = cluster_cls.call_args
    assert cluster_kwargs.get("ssl_context") is None
    source_report.failure.assert_not_called()
def test_authenticate_ssl_ca_certs() -> None:
    """With only ``ssl_ca_certs`` set, a CA-loading SSL context reaches Cluster.

    Checks that the context is created with ``ssl.PROTOCOL_TLS_CLIENT``, that
    the CA file is loaded via ``load_verify_locations``, that no client
    certificate chain is loaded, and that the very same context object is
    passed to ``Cluster`` as ``ssl_context``.
    """
    # Local import: only needed to name the expected protocol constant.
    import ssl

    config_dict = _get_base_config_dict()
    config_dict["ssl_ca_certs"] = "ca.pem"
    config = CassandraSourceConfig.parse_obj(config_dict)
    report = MagicMock(spec=SourceReport)
    api = CassandraAPI(config, report)

    # The patch below replaces only ssl.SSLContext; the protocol constant
    # itself stays real, so it can be compared against directly.
    expected_protocol = ssl.PROTOCOL_TLS_CLIENT

    with patch(
        "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    ) as mock_cluster, patch(
        "datahub.ingestion.source.cassandra.cassandra_api.ssl.SSLContext"
    ) as mock_ssl_context:
        mock_ssl_instance = MagicMock()
        mock_ssl_context.return_value = mock_ssl_instance
        mock_cluster.return_value.connect.return_value = MagicMock()

        assert api.authenticate()

        # Pin the protocol instead of accepting ANY: a different protocol
        # constant could silently change certificate-verification defaults.
        mock_ssl_context.assert_called_once_with(expected_protocol)
        mock_ssl_instance.load_verify_locations.assert_called_once_with("ca.pem")
        mock_ssl_instance.load_cert_chain.assert_not_called()
        mock_cluster.assert_called_once()
        # Identity check: Cluster must receive the context that was configured.
        assert mock_cluster.call_args[1].get("ssl_context") is mock_ssl_instance
        report.failure.assert_not_called()
def test_authenticate_ssl_all_certs() -> None:
    """With CA, client cert and key all set, the full SSL context reaches Cluster.

    Checks that the context is created with ``ssl.PROTOCOL_TLS_CLIENT``, that
    the CA file is loaded via ``load_verify_locations``, that the client
    certificate and key are loaded via ``load_cert_chain``, and that the very
    same context object is passed to ``Cluster`` as ``ssl_context``.
    """
    # Local import: only needed to name the expected protocol constant.
    import ssl

    config_dict = _get_base_config_dict()
    config_dict["ssl_ca_certs"] = "ca.pem"
    config_dict["ssl_certfile"] = "client.crt"
    config_dict["ssl_keyfile"] = "client.key"
    config = CassandraSourceConfig.parse_obj(config_dict)
    report = MagicMock(spec=SourceReport)
    api = CassandraAPI(config, report)

    # The patch below replaces only ssl.SSLContext; the protocol constant
    # itself stays real, so it can be compared against directly.
    expected_protocol = ssl.PROTOCOL_TLS_CLIENT

    with patch(
        "datahub.ingestion.source.cassandra.cassandra_api.Cluster"
    ) as mock_cluster, patch(
        "datahub.ingestion.source.cassandra.cassandra_api.ssl.SSLContext"
    ) as mock_ssl_context:
        mock_ssl_instance = MagicMock()
        mock_ssl_context.return_value = mock_ssl_instance
        mock_cluster.return_value.connect.return_value = MagicMock()

        assert api.authenticate()

        # Pin the protocol instead of accepting ANY: a different protocol
        # constant could silently change certificate-verification defaults.
        mock_ssl_context.assert_called_once_with(expected_protocol)
        mock_ssl_instance.load_verify_locations.assert_called_once_with("ca.pem")
        mock_ssl_instance.load_cert_chain.assert_called_once_with(
            certfile="client.crt", keyfile="client.key"
        )
        mock_cluster.assert_called_once()
        # Identity check: Cluster must receive the context that was configured.
        assert mock_cluster.call_args[1].get("ssl_context") is mock_ssl_instance
        report.failure.assert_not_called()