mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-09-27 09:55:36 +00:00
Add ssl support to hive (#22831)
* Add ssl support to hive * Added missing ts files * Added version to pure transport * Added Tests * fix tests add missing files
This commit is contained in:
parent
d33ffe40a3
commit
20e18d4f9f
@ -83,6 +83,7 @@ COMMONS = {
|
||||
"cramjam~=2.7",
|
||||
},
|
||||
"hive": {
|
||||
"pure-transport==0.2.0",
|
||||
"presto-types-parser>=0.0.2",
|
||||
VERSIONS["pyhive"],
|
||||
},
|
||||
|
@ -48,11 +48,20 @@ from metadata.ingestion.connections.test_connections import (
|
||||
test_connection_db_schema_sources,
|
||||
)
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.source.database.hive.custom_hive_connection import (
|
||||
CustomHiveConnection,
|
||||
)
|
||||
from metadata.utils.constants import THREE_MIN
|
||||
from metadata.utils.ssl_manager import check_ssl_and_init
|
||||
|
||||
HIVE_POSTGRES_SCHEME = "hive+postgres"
|
||||
HIVE_MYSQL_SCHEME = "hive+mysql"
|
||||
|
||||
# Monkey-patch the pyhive.hive module to use our custom connection
|
||||
import pyhive.hive
|
||||
|
||||
pyhive.hive.Connection = CustomHiveConnection
|
||||
|
||||
|
||||
def get_connection_url(connection: HiveConnection) -> str:
|
||||
"""
|
||||
@ -113,6 +122,34 @@ def get_connection(connection: HiveConnection) -> Engine:
|
||||
"kerberos_service_name"
|
||||
] = connection.kerberosServiceName
|
||||
|
||||
# Handle SSL using SSL manager (following established patterns)
|
||||
ssl_manager = check_ssl_and_init(connection)
|
||||
if ssl_manager:
|
||||
connection = ssl_manager.setup_ssl(connection)
|
||||
# Store SSL manager for cleanup
|
||||
connection._ssl_manager = ssl_manager
|
||||
|
||||
# Add SSL configuration to connection arguments if SSL is enabled
|
||||
if hasattr(connection, "useSSL") and connection.useSSL:
|
||||
if not connection.connectionArguments:
|
||||
connection.connectionArguments = init_empty_connection_arguments()
|
||||
connection.connectionArguments.root["use_ssl"] = True
|
||||
|
||||
# Add SSL certificate configuration if available
|
||||
if hasattr(connection, "sslConfig") and connection.sslConfig:
|
||||
if connection.sslConfig.root.sslCertificate:
|
||||
connection.connectionArguments.root[
|
||||
"ssl_certfile"
|
||||
] = connection.sslConfig.root.sslCertificate
|
||||
if connection.sslConfig.root.sslKey:
|
||||
connection.connectionArguments.root[
|
||||
"ssl_keyfile"
|
||||
] = connection.sslConfig.root.sslKey
|
||||
if connection.sslConfig.root.caCertificate:
|
||||
connection.connectionArguments.root[
|
||||
"ssl_ca_certs"
|
||||
] = connection.sslConfig.root.caCertificate
|
||||
|
||||
return create_generic_db_connection(
|
||||
connection=connection,
|
||||
get_connection_url_fn=get_connection_url,
|
||||
|
@ -0,0 +1,195 @@
|
||||
import contextlib
|
||||
import getpass
|
||||
import ssl
|
||||
|
||||
import thrift.protocol.TBinaryProtocol
|
||||
import thrift.transport.TSocket
|
||||
import thrift.transport.TTransport
|
||||
from pyhive.hive import Connection as BaseConnection
|
||||
from pyhive.hive import _check_status, get_installed_sasl
|
||||
from TCLIService import TCLIService, ttypes
|
||||
|
||||
|
||||
class CustomHiveConnection(BaseConnection):
|
||||
"""Custom Hive connection that integrates puretransport and SSL certificate support"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host=None,
|
||||
port=None,
|
||||
scheme=None,
|
||||
username=None,
|
||||
database="default",
|
||||
auth=None,
|
||||
configuration=None,
|
||||
kerberos_service_name=None,
|
||||
password=None,
|
||||
check_hostname=None,
|
||||
ssl_cert=None,
|
||||
thrift_transport=None,
|
||||
use_ssl=False,
|
||||
ssl_certfile=None,
|
||||
ssl_keyfile=None,
|
||||
ssl_ca_certs=None,
|
||||
ssl_cert_reqs=None,
|
||||
ssl_check_hostname=None,
|
||||
):
|
||||
"""Connect to HiveServer2 with integrated puretransport and SSL support"""
|
||||
|
||||
# Handle HTTPS scheme with SSL context
|
||||
if scheme in ("https", "http") and thrift_transport is None:
|
||||
port = port or 1000
|
||||
ssl_context = None
|
||||
if scheme == "https":
|
||||
from ssl import create_default_context
|
||||
|
||||
ssl_context = create_default_context()
|
||||
ssl_context.check_hostname = check_hostname == "true"
|
||||
ssl_cert = ssl_cert or "none"
|
||||
ssl_cert_parameter_map = {
|
||||
"none": 0, # CERT_NONE
|
||||
"optional": 1, # CERT_OPTIONAL
|
||||
"required": 2, # CERT_REQUIRED
|
||||
}
|
||||
ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, 0)
|
||||
thrift_transport = thrift.transport.THttpClient.THttpClient(
|
||||
uri_or_host="{scheme}://{host}:{port}/cliservice/".format(
|
||||
scheme=scheme, host=host, port=port
|
||||
),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
|
||||
if auth in ("BASIC", "NOSASL", "NONE", None):
|
||||
# Always needs the Authorization header
|
||||
self._set_authorization_header(thrift_transport, username, password)
|
||||
elif auth == "KERBEROS" and kerberos_service_name:
|
||||
self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Authentication is not valid use one of:"
|
||||
"BASIC, NOSASL, KERBEROS, NONE"
|
||||
)
|
||||
host, port, auth, kerberos_service_name, password = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
username = username or getpass.getuser()
|
||||
configuration = configuration or {}
|
||||
|
||||
if (password is not None) != (auth in ("LDAP", "CUSTOM")):
|
||||
raise ValueError(
|
||||
"Password should be set if and only if in LDAP or CUSTOM mode; "
|
||||
"Remove password or use one of those modes"
|
||||
)
|
||||
if (kerberos_service_name is not None) != (auth == "KERBEROS"):
|
||||
raise ValueError(
|
||||
"kerberos_service_name should be set if and only if in KERBEROS mode"
|
||||
)
|
||||
|
||||
# Use puretransport if SSL is enabled or if thrift_transport is provided
|
||||
if use_ssl or thrift_transport is not None:
|
||||
if thrift_transport is not None:
|
||||
# Use provided thrift_transport
|
||||
self._transport = thrift_transport
|
||||
else:
|
||||
# Create puretransport with SSL
|
||||
import puretransport
|
||||
|
||||
# Prepare socket_kwargs for SSL
|
||||
socket_kwargs = {}
|
||||
if ssl_certfile:
|
||||
socket_kwargs["certfile"] = ssl_certfile
|
||||
if ssl_keyfile:
|
||||
socket_kwargs["keyfile"] = ssl_keyfile
|
||||
if ssl_ca_certs:
|
||||
socket_kwargs["ca_certs"] = ssl_ca_certs
|
||||
if ssl_cert_reqs is not None:
|
||||
socket_kwargs["cert_reqs"] = ssl_cert_reqs
|
||||
elif use_ssl:
|
||||
socket_kwargs["cert_reqs"] = ssl.CERT_NONE
|
||||
|
||||
# Create puretransport
|
||||
self._transport = puretransport.transport_factory(
|
||||
host=host or "localhost",
|
||||
port=port or 10000,
|
||||
username=username,
|
||||
password=password or username,
|
||||
use_ssl=use_ssl,
|
||||
socket_kwargs=socket_kwargs if socket_kwargs else None,
|
||||
)
|
||||
else:
|
||||
# Use standard connection logic
|
||||
if port is None:
|
||||
port = 10000
|
||||
if auth is None:
|
||||
auth = "NONE"
|
||||
socket = thrift.transport.TSocket.TSocket(host, port)
|
||||
if auth == "NOSASL":
|
||||
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
|
||||
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
|
||||
elif auth in ("LDAP", "KERBEROS", "NONE", "CUSTOM"):
|
||||
# Defer import so package dependency is optional
|
||||
import thrift_sasl
|
||||
|
||||
if auth == "KERBEROS":
|
||||
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
|
||||
sasl_auth = "GSSAPI"
|
||||
else:
|
||||
sasl_auth = "PLAIN"
|
||||
if password is None:
|
||||
# Password doesn't matter in NONE mode, just needs to be nonempty.
|
||||
password = "x"
|
||||
|
||||
self._transport = thrift_sasl.TSaslClientTransport(
|
||||
lambda: get_installed_sasl(
|
||||
host=host,
|
||||
sasl_auth=sasl_auth,
|
||||
service=kerberos_service_name,
|
||||
username=username,
|
||||
password=password,
|
||||
),
|
||||
sasl_auth,
|
||||
socket,
|
||||
)
|
||||
else:
|
||||
# All HS2 config options:
|
||||
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
|
||||
# PAM currently left to end user via thrift_transport option.
|
||||
raise NotImplementedError(
|
||||
"Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM "
|
||||
"authentication are supported, got {}".format(auth)
|
||||
)
|
||||
|
||||
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
|
||||
self._client = TCLIService.Client(protocol)
|
||||
# oldest version that still contains features we care about
|
||||
# "V6 uses binary type for binary payload (was string) and uses columnar result set"
|
||||
protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6
|
||||
|
||||
try:
|
||||
self._transport.open()
|
||||
open_session_req = ttypes.TOpenSessionReq(
|
||||
client_protocol=protocol_version,
|
||||
configuration=configuration,
|
||||
username=username,
|
||||
)
|
||||
response = self._client.OpenSession(open_session_req)
|
||||
_check_status(response)
|
||||
assert (
|
||||
response.sessionHandle is not None
|
||||
), "Expected a session from OpenSession"
|
||||
self._sessionHandle = response.sessionHandle
|
||||
assert (
|
||||
response.serverProtocolVersion == protocol_version
|
||||
), "Unable to handle protocol version {}".format(
|
||||
response.serverProtocolVersion
|
||||
)
|
||||
with contextlib.closing(self.cursor()) as cursor:
|
||||
cursor.execute("USE `{}`".format(database))
|
||||
except:
|
||||
self._transport.close()
|
||||
raise
|
@ -37,6 +37,9 @@ from metadata.generated.schema.entity.services.connections.database.dorisConnect
|
||||
from metadata.generated.schema.entity.services.connections.database.greenplumConnection import (
|
||||
GreenplumConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.hiveConnection import (
|
||||
HiveConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.mongoDBConnection import (
|
||||
MongoDBConnection,
|
||||
)
|
||||
@ -247,6 +250,25 @@ class SSLManager:
|
||||
connection.connectionArguments.root["ssl_context"] = ssl_context
|
||||
return connection
|
||||
|
||||
@setup_ssl.register(HiveConnection)
|
||||
def _(self, connection):
|
||||
connection = cast(HiveConnection, connection)
|
||||
|
||||
if not connection.connectionArguments:
|
||||
connection.connectionArguments = init_empty_connection_arguments()
|
||||
|
||||
# Add certificate paths if available (following MySQL pattern)
|
||||
ssl_args = connection.connectionArguments.root.get("ssl", {})
|
||||
if self.ca_file_path:
|
||||
ssl_args["ssl_ca"] = self.ca_file_path
|
||||
if self.cert_file_path:
|
||||
ssl_args["ssl_cert"] = self.cert_file_path
|
||||
if self.key_file_path:
|
||||
ssl_args["ssl_key"] = self.key_file_path
|
||||
connection.connectionArguments.root["ssl"] = ssl_args
|
||||
|
||||
return connection
|
||||
|
||||
|
||||
@singledispatch
|
||||
def check_ssl_and_init(
|
||||
@ -375,6 +397,25 @@ def _(connection):
|
||||
return None
|
||||
|
||||
|
||||
@check_ssl_and_init.register(HiveConnection)
|
||||
def _(connection):
|
||||
service_connection = cast(HiveConnection, connection)
|
||||
if hasattr(service_connection, "useSSL") and service_connection.useSSL:
|
||||
# Check if SSL config is provided in sslConfig (following MySQL pattern)
|
||||
if hasattr(service_connection, "sslConfig") and service_connection.sslConfig:
|
||||
if (
|
||||
service_connection.sslConfig.root.caCertificate
|
||||
or service_connection.sslConfig.root.sslCertificate
|
||||
or service_connection.sslConfig.root.sslKey
|
||||
):
|
||||
return SSLManager(
|
||||
ca=service_connection.sslConfig.root.caCertificate,
|
||||
cert=service_connection.sslConfig.root.sslCertificate,
|
||||
key=service_connection.sslConfig.root.sslKey,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_ssl_connection(service_config):
|
||||
try:
|
||||
# To be cleaned up as part of https://github.com/open-metadata/OpenMetadata/issues/15913
|
||||
|
@ -14,7 +14,7 @@ Test Hive using the topology
|
||||
|
||||
import types
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from sqlalchemy.types import INTEGER, VARCHAR, Integer, String
|
||||
|
||||
@ -32,6 +32,11 @@ from metadata.generated.schema.entity.data.table import (
|
||||
DataType,
|
||||
TableType,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.hiveConnection import (
|
||||
Auth,
|
||||
HiveConnection,
|
||||
HiveScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.databaseService import (
|
||||
DatabaseConnection,
|
||||
DatabaseService,
|
||||
@ -40,8 +45,17 @@ from metadata.generated.schema.entity.services.databaseService import (
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
OpenMetadataWorkflowConfig,
|
||||
)
|
||||
from metadata.generated.schema.security.ssl.validateSSLClientConfig import (
|
||||
ValidateSslClientConfig,
|
||||
)
|
||||
from metadata.generated.schema.security.ssl.verifySSLConfig import SslConfig
|
||||
from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
|
||||
from metadata.ingestion.source.database.hive.connection import (
|
||||
get_connection,
|
||||
get_connection_url,
|
||||
)
|
||||
from metadata.ingestion.source.database.hive.metadata import HiveSource
|
||||
|
||||
mock_hive_config = {
|
||||
@ -162,6 +176,7 @@ EXPECTED_TABLE = [
|
||||
dataLength=1,
|
||||
dataTypeDisplay="varchar(50)",
|
||||
constraint="NULL",
|
||||
tags=None,
|
||||
),
|
||||
Column(
|
||||
name=ColumnName("sample_col_2"),
|
||||
@ -169,6 +184,7 @@ EXPECTED_TABLE = [
|
||||
dataLength=1,
|
||||
dataTypeDisplay="int",
|
||||
constraint="NULL",
|
||||
tags=None,
|
||||
),
|
||||
Column(
|
||||
name=ColumnName("sample_col_3"),
|
||||
@ -176,6 +192,7 @@ EXPECTED_TABLE = [
|
||||
dataLength=1,
|
||||
dataTypeDisplay="varchar(50)",
|
||||
constraint="NULL",
|
||||
tags=None,
|
||||
),
|
||||
Column(
|
||||
name=ColumnName("sample_col_4"),
|
||||
@ -183,6 +200,7 @@ EXPECTED_TABLE = [
|
||||
dataLength=1,
|
||||
dataTypeDisplay="varchar(50)",
|
||||
constraint="NULL",
|
||||
tags=None,
|
||||
),
|
||||
],
|
||||
tableConstraints=[],
|
||||
@ -222,6 +240,89 @@ EXPECTED_COMPLEX_COL_TYPE = [
|
||||
},
|
||||
]
|
||||
|
||||
# SSL-specific mock configurations
|
||||
mock_hive_ssl_config = {
|
||||
"source": {
|
||||
"type": "hive",
|
||||
"serviceName": "sample_hive_ssl",
|
||||
"serviceConnection": {
|
||||
"config": {
|
||||
"type": "Hive",
|
||||
"databaseSchema": "test_database_schema",
|
||||
"username": "username",
|
||||
"hostPort": "localhost:1466",
|
||||
"useSSL": True,
|
||||
"sslConfig": {
|
||||
"sslCertificate": "test_cert.pem",
|
||||
"sslKey": "test_key.pem",
|
||||
"caCertificate": "test_ca.pem",
|
||||
},
|
||||
}
|
||||
},
|
||||
"sourceConfig": {"config": {"type": "DatabaseMetadata"}},
|
||||
},
|
||||
"sink": {"type": "metadata-rest", "config": {}},
|
||||
"workflowConfig": {
|
||||
"openMetadataServerConfig": {
|
||||
"hostPort": "http://localhost:8585/api",
|
||||
"authProvider": "openmetadata",
|
||||
"securityConfig": {"jwtToken": "hive"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
mock_hive_https_config = {
|
||||
"source": {
|
||||
"type": "hive",
|
||||
"serviceName": "sample_hive_https",
|
||||
"serviceConnection": {
|
||||
"config": {
|
||||
"type": "Hive",
|
||||
"scheme": "hive+https",
|
||||
"databaseSchema": "test_database_schema",
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"hostPort": "localhost:1000",
|
||||
"auth": "BASIC",
|
||||
}
|
||||
},
|
||||
"sourceConfig": {"config": {"type": "DatabaseMetadata"}},
|
||||
},
|
||||
"sink": {"type": "metadata-rest", "config": {}},
|
||||
"workflowConfig": {
|
||||
"openMetadataServerConfig": {
|
||||
"hostPort": "http://localhost:8585/api",
|
||||
"authProvider": "openmetadata",
|
||||
"securityConfig": {"jwtToken": "hive"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# SSL configuration objects for testing
|
||||
mock_ssl_config = ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("test_cert.pem"),
|
||||
sslKey=CustomSecretStr("test_key.pem"),
|
||||
caCertificate=CustomSecretStr("test_ca.pem"),
|
||||
)
|
||||
|
||||
mock_hive_connection_ssl = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
sslConfig=SslConfig(root=mock_ssl_config),
|
||||
)
|
||||
|
||||
mock_hive_connection_https = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
password=CustomSecretStr("password"),
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.BASIC,
|
||||
)
|
||||
|
||||
|
||||
class HiveUnitTest(TestCase):
|
||||
"""
|
||||
@ -286,10 +387,11 @@ class HiveUnitTest(TestCase):
|
||||
self.hive.inspector.get_columns = (
|
||||
lambda table_name, schema_name, table_type, db_name: MOCK_COLUMN_VALUE
|
||||
)
|
||||
assert EXPECTED_TABLE == [
|
||||
results = [
|
||||
either.right
|
||||
for either in self.hive.yield_table(("sample_table", "Regular"))
|
||||
]
|
||||
assert EXPECTED_TABLE == results
|
||||
|
||||
def test_col_data_type(self):
|
||||
"""
|
||||
@ -324,3 +426,416 @@ class HiveUnitTest(TestCase):
|
||||
|
||||
String.__eq__ = custom_eq
|
||||
self.assertEqual(expected, original)
|
||||
|
||||
def test_ssl_connection_configuration(self):
|
||||
"""
|
||||
Test SSL configuration in Hive connection
|
||||
"""
|
||||
# Test SSL configuration with certificates
|
||||
ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
sslConfig=SslConfig(
|
||||
root=ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("test_cert.pem"),
|
||||
sslKey=CustomSecretStr("test_key.pem"),
|
||||
caCertificate=CustomSecretStr("test_ca.pem"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.assertTrue(ssl_connection.useSSL)
|
||||
self.assertIsNotNone(ssl_connection.sslConfig)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.sslCertificate.get_secret_value(),
|
||||
"test_cert.pem",
|
||||
)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.sslKey.get_secret_value(), "test_key.pem"
|
||||
)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.caCertificate.get_secret_value(),
|
||||
"test_ca.pem",
|
||||
)
|
||||
|
||||
def test_https_scheme_configuration(self):
|
||||
"""
|
||||
Test HTTPS scheme configuration in Hive connection
|
||||
"""
|
||||
https_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
password=CustomSecretStr("password"),
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.BASIC,
|
||||
)
|
||||
|
||||
self.assertEqual(https_connection.scheme, HiveScheme.hive_https)
|
||||
self.assertEqual(https_connection.auth, Auth.BASIC)
|
||||
self.assertEqual(https_connection.username, "username")
|
||||
self.assertEqual(https_connection.password.get_secret_value(), "password")
|
||||
|
||||
@patch("metadata.ingestion.source.database.hive.connection.check_ssl_and_init")
|
||||
@patch(
|
||||
"metadata.ingestion.source.database.hive.connection.create_generic_db_connection"
|
||||
)
|
||||
def test_get_connection_with_ssl(self, mock_create_connection, mock_ssl_manager):
|
||||
"""
|
||||
Test get_connection function with SSL configuration
|
||||
"""
|
||||
# Mock SSL manager
|
||||
mock_ssl_manager_instance = Mock()
|
||||
mock_ssl_manager_instance.setup_ssl.return_value = mock_hive_connection_ssl
|
||||
mock_ssl_manager.return_value = mock_ssl_manager_instance
|
||||
|
||||
# Mock create_generic_db_connection
|
||||
mock_engine = Mock()
|
||||
mock_create_connection.return_value = mock_engine
|
||||
|
||||
# Test SSL connection
|
||||
result = get_connection(mock_hive_connection_ssl)
|
||||
|
||||
# Verify SSL manager was called
|
||||
mock_ssl_manager.assert_called_once()
|
||||
mock_ssl_manager_instance.setup_ssl.assert_called_once_with(
|
||||
mock_hive_connection_ssl
|
||||
)
|
||||
|
||||
# Verify connection was created
|
||||
mock_create_connection.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
self.assertEqual(result, mock_engine)
|
||||
|
||||
@patch("metadata.ingestion.source.database.hive.connection.check_ssl_and_init")
|
||||
@patch(
|
||||
"metadata.ingestion.source.database.hive.connection.create_generic_db_connection"
|
||||
)
|
||||
def test_get_connection_without_ssl(self, mock_create_connection, mock_ssl_manager):
|
||||
"""
|
||||
Test get_connection function without SSL configuration
|
||||
"""
|
||||
# Mock SSL manager returns None (no SSL)
|
||||
mock_ssl_manager.return_value = None
|
||||
|
||||
# Mock create_generic_db_connection
|
||||
mock_engine = Mock()
|
||||
mock_create_connection.return_value = mock_engine
|
||||
|
||||
# Test non-SSL connection
|
||||
non_ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=False,
|
||||
)
|
||||
|
||||
result = get_connection(non_ssl_connection)
|
||||
|
||||
# Verify SSL manager was called but returned None
|
||||
mock_ssl_manager.assert_called_once()
|
||||
|
||||
# Verify connection was created
|
||||
mock_create_connection.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
self.assertEqual(result, mock_engine)
|
||||
|
||||
def test_connection_url_with_ssl(self):
|
||||
"""
|
||||
Test connection URL generation with SSL configuration
|
||||
"""
|
||||
# Test basic SSL connection
|
||||
ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
)
|
||||
|
||||
url = get_connection_url(ssl_connection)
|
||||
self.assertEqual(url, "hive://username@localhost:1466")
|
||||
|
||||
# Test HTTPS scheme connection
|
||||
https_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
password=CustomSecretStr("password"),
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.BASIC,
|
||||
)
|
||||
|
||||
url = get_connection_url(https_connection)
|
||||
self.assertEqual(url, "hive+https://username:password@localhost:1000")
|
||||
|
||||
def test_custom_hive_connection_ssl_initialization(self):
|
||||
"""
|
||||
Test CustomHiveConnection SSL initialization
|
||||
"""
|
||||
# Test SSL connection with certificates
|
||||
ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
sslConfig=SslConfig(
|
||||
root=ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("test_cert.pem"),
|
||||
sslKey=CustomSecretStr("test_key.pem"),
|
||||
caCertificate=CustomSecretStr("test_ca.pem"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Test the configuration parsing
|
||||
self.assertTrue(ssl_connection.useSSL)
|
||||
self.assertIsNotNone(ssl_connection.sslConfig)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.sslCertificate.get_secret_value(),
|
||||
"test_cert.pem",
|
||||
)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.sslKey.get_secret_value(), "test_key.pem"
|
||||
)
|
||||
self.assertEqual(
|
||||
ssl_connection.sslConfig.root.caCertificate.get_secret_value(),
|
||||
"test_ca.pem",
|
||||
)
|
||||
|
||||
def test_ssl_config_validation(self):
|
||||
"""
|
||||
Test SSL configuration validation
|
||||
"""
|
||||
# Test valid SSL config
|
||||
valid_ssl_config = ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("valid_cert.pem"),
|
||||
sslKey=CustomSecretStr("valid_key.pem"),
|
||||
caCertificate=CustomSecretStr("valid_ca.pem"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
valid_ssl_config.sslCertificate.get_secret_value(), "valid_cert.pem"
|
||||
)
|
||||
self.assertEqual(valid_ssl_config.sslKey.get_secret_value(), "valid_key.pem")
|
||||
self.assertEqual(
|
||||
valid_ssl_config.caCertificate.get_secret_value(), "valid_ca.pem"
|
||||
)
|
||||
|
||||
# Test SSL config with only some certificates
|
||||
partial_ssl_config = ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("cert_only.pem")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
partial_ssl_config.sslCertificate.get_secret_value(), "cert_only.pem"
|
||||
)
|
||||
self.assertIsNone(partial_ssl_config.sslKey)
|
||||
self.assertIsNone(partial_ssl_config.caCertificate)
|
||||
|
||||
def test_hive_scheme_enum_values(self):
|
||||
"""
|
||||
Test HiveScheme enum values for SSL support
|
||||
"""
|
||||
self.assertEqual(HiveScheme.hive.value, "hive")
|
||||
self.assertEqual(HiveScheme.hive_http.value, "hive+http")
|
||||
self.assertEqual(HiveScheme.hive_https.value, "hive+https")
|
||||
|
||||
# Verify all schemes are available
|
||||
schemes = [scheme.value for scheme in HiveScheme]
|
||||
self.assertIn("hive", schemes)
|
||||
self.assertIn("hive+http", schemes)
|
||||
self.assertIn("hive+https", schemes)
|
||||
|
||||
def test_auth_enum_values(self):
|
||||
"""
|
||||
Test Auth enum values for SSL authentication
|
||||
"""
|
||||
self.assertEqual(Auth.NONE.value, "NONE")
|
||||
self.assertEqual(Auth.LDAP.value, "LDAP")
|
||||
self.assertEqual(Auth.KERBEROS.value, "KERBEROS")
|
||||
self.assertEqual(Auth.CUSTOM.value, "CUSTOM")
|
||||
self.assertEqual(Auth.NOSASL.value, "NOSASL")
|
||||
self.assertEqual(Auth.BASIC.value, "BASIC")
|
||||
self.assertEqual(Auth.GSSAPI.value, "GSSAPI")
|
||||
self.assertEqual(Auth.JWT.value, "JWT")
|
||||
self.assertEqual(Auth.PLAIN.value, "PLAIN")
|
||||
|
||||
@patch("metadata.ingestion.source.database.hive.connection.check_ssl_and_init")
|
||||
def test_ssl_manager_integration(self, mock_ssl_manager):
|
||||
"""
|
||||
Test SSL manager integration with Hive connection
|
||||
"""
|
||||
# Mock SSL manager
|
||||
mock_ssl_manager_instance = Mock()
|
||||
mock_ssl_manager_instance.setup_ssl.return_value = mock_hive_connection_ssl
|
||||
mock_ssl_manager.return_value = mock_ssl_manager_instance
|
||||
|
||||
# Test that SSL manager is called when SSL is enabled
|
||||
ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
)
|
||||
|
||||
# Test the configuration
|
||||
self.assertTrue(ssl_connection.useSSL)
|
||||
|
||||
# Note: The SSL manager would be called when get_connection is actually invoked
|
||||
# This test just verifies the SSL configuration is properly set
|
||||
|
||||
def test_custom_hive_connection_ssl_parameters(self):
|
||||
"""
|
||||
Test CustomHiveConnection SSL parameter handling
|
||||
"""
|
||||
# Test SSL parameters that would be passed to CustomHiveConnection
|
||||
ssl_params = {
|
||||
"use_ssl": True,
|
||||
"ssl_certfile": "test_cert.pem",
|
||||
"ssl_keyfile": "test_key.pem",
|
||||
"ssl_ca_certs": "test_ca.pem",
|
||||
"ssl_cert_reqs": 0, # ssl.CERT_NONE
|
||||
}
|
||||
|
||||
# Verify SSL parameters are properly structured
|
||||
self.assertTrue(ssl_params["use_ssl"])
|
||||
self.assertEqual(ssl_params["ssl_certfile"], "test_cert.pem")
|
||||
self.assertEqual(ssl_params["ssl_keyfile"], "test_key.pem")
|
||||
self.assertEqual(ssl_params["ssl_ca_certs"], "test_ca.pem")
|
||||
self.assertEqual(ssl_params["ssl_cert_reqs"], 0)
|
||||
|
||||
def test_https_scheme_authentication_modes(self):
|
||||
"""
|
||||
Test HTTPS scheme with different authentication modes
|
||||
"""
|
||||
# Test BASIC authentication
|
||||
basic_auth_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
password=CustomSecretStr("password"),
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.BASIC,
|
||||
)
|
||||
|
||||
self.assertEqual(basic_auth_connection.auth, Auth.BASIC)
|
||||
self.assertEqual(basic_auth_connection.scheme, HiveScheme.hive_https)
|
||||
|
||||
# Test NOSASL authentication
|
||||
nosasl_auth_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.NOSASL,
|
||||
)
|
||||
|
||||
self.assertEqual(nosasl_auth_connection.auth, Auth.NOSASL)
|
||||
|
||||
# Test NONE authentication
|
||||
none_auth_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive_https,
|
||||
username="username",
|
||||
hostPort="localhost:1000",
|
||||
auth=Auth.NONE,
|
||||
)
|
||||
|
||||
self.assertEqual(none_auth_connection.auth, Auth.NONE)
|
||||
|
||||
def test_ssl_certificate_parameter_mapping(self):
|
||||
"""
|
||||
Test SSL certificate parameter mapping for HTTPS scheme
|
||||
"""
|
||||
# Test SSL certificate parameter mapping as used in CustomHiveConnection
|
||||
ssl_cert_parameter_map = {
|
||||
"none": 0, # CERT_NONE
|
||||
"optional": 1, # CERT_OPTIONAL
|
||||
"required": 2, # CERT_REQUIRED
|
||||
}
|
||||
|
||||
self.assertEqual(ssl_cert_parameter_map["none"], 0)
|
||||
self.assertEqual(ssl_cert_parameter_map["optional"], 1)
|
||||
self.assertEqual(ssl_cert_parameter_map["required"], 2)
|
||||
|
||||
# Test default value handling
|
||||
default_ssl_cert = "none"
|
||||
self.assertEqual(ssl_cert_parameter_map.get(default_ssl_cert, 0), 0)
|
||||
|
||||
def test_connection_arguments_ssl_setup(self):
|
||||
"""
|
||||
Test SSL setup in connection arguments
|
||||
"""
|
||||
# Test that SSL configuration is properly added to connection arguments
|
||||
ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
useSSL=True,
|
||||
sslConfig=SslConfig(
|
||||
root=ValidateSslClientConfig(
|
||||
sslCertificate=CustomSecretStr("test_cert.pem"),
|
||||
sslKey=CustomSecretStr("test_key.pem"),
|
||||
caCertificate=CustomSecretStr("test_ca.pem"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Verify SSL configuration is present
|
||||
self.assertTrue(ssl_connection.useSSL)
|
||||
self.assertIsNotNone(ssl_connection.sslConfig)
|
||||
|
||||
# Test that SSL config values are accessible
|
||||
ssl_config = ssl_connection.sslConfig.root
|
||||
self.assertEqual(ssl_config.sslCertificate.get_secret_value(), "test_cert.pem")
|
||||
self.assertEqual(ssl_config.sslKey.get_secret_value(), "test_key.pem")
|
||||
self.assertEqual(ssl_config.caCertificate.get_secret_value(), "test_ca.pem")
|
||||
|
||||
def test_kerberos_ssl_integration(self):
|
||||
"""
|
||||
Test Kerberos authentication with SSL
|
||||
"""
|
||||
# Test Kerberos connection with SSL
|
||||
kerberos_ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
hostPort="localhost:1466",
|
||||
auth=Auth.KERBEROS,
|
||||
kerberosServiceName="hive",
|
||||
useSSL=True,
|
||||
)
|
||||
|
||||
self.assertEqual(kerberos_ssl_connection.auth, Auth.KERBEROS)
|
||||
self.assertEqual(kerberos_ssl_connection.kerberosServiceName, "hive")
|
||||
self.assertTrue(kerberos_ssl_connection.useSSL)
|
||||
|
||||
def test_ldap_ssl_integration(self):
|
||||
"""
|
||||
Test LDAP authentication with SSL
|
||||
"""
|
||||
# Test LDAP connection with SSL
|
||||
ldap_ssl_connection = HiveConnection(
|
||||
type="Hive",
|
||||
scheme=HiveScheme.hive,
|
||||
username="username",
|
||||
password=CustomSecretStr("password"),
|
||||
hostPort="localhost:1466",
|
||||
auth=Auth.LDAP,
|
||||
useSSL=True,
|
||||
)
|
||||
|
||||
self.assertEqual(ldap_ssl_connection.auth, Auth.LDAP)
|
||||
self.assertEqual(ldap_ssl_connection.username, "username")
|
||||
self.assertEqual(ldap_ssl_connection.password.get_secret_value(), "password")
|
||||
self.assertTrue(ldap_ssl_connection.useSSL)
|
||||
|
@ -75,6 +75,17 @@
|
||||
"description": "Authentication options to pass to Hive connector. These options are based on SQLAlchemy.",
|
||||
"type": "string"
|
||||
},
|
||||
"useSSL": {
|
||||
"title": "Use SSL",
|
||||
"description": "Enable SSL connection to Hive server. When enabled, SSL transport will be used for secure communication.",
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"sslConfig": {
|
||||
"title": "SSL Configuration",
|
||||
"description": "SSL Configuration details.",
|
||||
"$ref": "../../../../security/ssl/verifySSLConfig.json#/definitions/sslConfig"
|
||||
},
|
||||
"metastoreConnection":{
|
||||
"title": "Hive Metastore Connection Details",
|
||||
"description": "Hive Metastore Connection Details",
|
||||
|
@ -63,7 +63,11 @@ export interface HiveConnection {
|
||||
/**
|
||||
* SQLAlchemy driver scheme options.
|
||||
*/
|
||||
scheme?: HiveScheme;
|
||||
scheme?: HiveScheme;
|
||||
/**
|
||||
* SSL Configuration details.
|
||||
*/
|
||||
sslConfig?: Config;
|
||||
supportsDBTExtraction?: boolean;
|
||||
supportsMetadataExtraction?: boolean;
|
||||
supportsProfiler?: boolean;
|
||||
@ -81,6 +85,11 @@ export interface HiveConnection {
|
||||
* Hive.
|
||||
*/
|
||||
username?: string;
|
||||
/**
|
||||
* Enable SSL connection to Hive server. When enabled, SSL transport will be used for secure
|
||||
* communication.
|
||||
*/
|
||||
useSSL?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
x
Reference in New Issue
Block a user