diff --git a/datahub-dao/src/main/java/com/linkedin/datahub/dao/DaoFactory.java b/datahub-dao/src/main/java/com/linkedin/datahub/dao/DaoFactory.java index 7f711c7006..50346c2a1d 100644 --- a/datahub-dao/src/main/java/com/linkedin/datahub/dao/DaoFactory.java +++ b/datahub-dao/src/main/java/com/linkedin/datahub/dao/DaoFactory.java @@ -16,6 +16,8 @@ public class DaoFactory { private static final String GMS_HOST_ENV_VAR = "DATAHUB_GMS_HOST"; private static final String GMS_PORT_ENV_VAR = "DATAHUB_GMS_PORT"; + private static final String GMS_USE_SSL_ENV_VAR = "DATAHUB_GMS_USE_SSL"; + private static final String GMS_SSL_PROTOCOL_VAR = "DATAHUB_GMS_SSL_PROTOCOL"; private static GmsDao _gmsDao; private static DocumentSearchDao datasetDocumentSearchDao; @@ -35,7 +37,9 @@ public class DaoFactory { private static GmsDao getGmsDao() { if (_gmsDao == null) { _gmsDao = new GmsDao(Configuration.getEnvironmentVariable(GMS_HOST_ENV_VAR), - Integer.valueOf(Configuration.getEnvironmentVariable(GMS_PORT_ENV_VAR))); + Integer.parseInt(Configuration.getEnvironmentVariable(GMS_PORT_ENV_VAR)), + Boolean.parseBoolean(Configuration.getEnvironmentVariable(GMS_USE_SSL_ENV_VAR, "False")), + Configuration.getEnvironmentVariable(GMS_SSL_PROTOCOL_VAR)); } return _gmsDao; } diff --git a/datahub-dao/src/main/java/com/linkedin/datahub/dao/table/GmsDao.java b/datahub-dao/src/main/java/com/linkedin/datahub/dao/table/GmsDao.java index fc3f00db65..5fa4493fb2 100644 --- a/datahub-dao/src/main/java/com/linkedin/datahub/dao/table/GmsDao.java +++ b/datahub-dao/src/main/java/com/linkedin/datahub/dao/table/GmsDao.java @@ -40,4 +40,8 @@ public class GmsDao { public GmsDao(@Nonnull String restliHostName, @Nonnull int restliHostPort) { this(DefaultRestliClientFactory.getRestLiClient(restliHostName, restliHostPort)); } + + public GmsDao(@Nonnull String restliHostName, @Nonnull int restliHostPort, boolean useSSL, String sslProtocol) { + this(DefaultRestliClientFactory.getRestLiClient(restliHostName, restliHostPort, useSSL, sslProtocol)); + } } diff --git a/docker/datahub-frontend/env/docker.env b/docker/datahub-frontend/env/docker.env index c132d9a0ac..c83c19e9d5 100644 --- a/docker/datahub-frontend/env/docker.env +++ b/docker/datahub-frontend/env/docker.env @@ -3,3 +3,9 @@ DATAHUB_GMS_PORT=8080 DATAHUB_SECRET=YouKnowNothing DATAHUB_APP_VERSION=1.0 DATAHUB_PLAY_MEM_BUFFER_SIZE=10MB + +# Uncomment and set these to support SSL connection to GMS +# NOTE: Currently GMS itself does not offer SSL support, these settings are intended for when there is a proxy in front +# of GMS that handles SSL, such as an EC2 Load Balancer. +#DATAHUB_GMS_USE_SSL=true +#DATAHUB_GMS_SSL_PROTOCOL= \ No newline at end of file diff --git a/docker/datahub-mce-consumer/env/docker.env b/docker/datahub-mce-consumer/env/docker.env index 59907e8278..fc707bc5c6 100644 --- a/docker/datahub-mce-consumer/env/docker.env +++ b/docker/datahub-mce-consumer/env/docker.env @@ -2,3 +2,9 @@ KAFKA_BOOTSTRAP_SERVER=broker:29092 KAFKA_SCHEMAREGISTRY_URL=http://schema-registry:8081 GMS_HOST=datahub-gms GMS_PORT=8080 + +# Uncomment and set these to support SSL connection to GMS +# NOTE: Currently GMS itself does not offer SSL support, these settings are intended for when there is a proxy in front +# of GMS that handles SSL, such as an EC2 Load Balancer. +#GMS_USE_SSL=true +#GMS_SSL_PROTOCOL= \ No newline at end of file diff --git a/metadata-jobs/mce-consumer-job/src/main/java/com/linkedin/metadata/kafka/config/RemoteWriterConfig.java b/metadata-jobs/mce-consumer-job/src/main/java/com/linkedin/metadata/kafka/config/RemoteWriterConfig.java index 5a53cab2cc..7125eb1fd2 100644 --- a/metadata-jobs/mce-consumer-job/src/main/java/com/linkedin/metadata/kafka/config/RemoteWriterConfig.java +++ b/metadata-jobs/mce-consumer-job/src/main/java/com/linkedin/metadata/kafka/config/RemoteWriterConfig.java @@ -16,10 +16,14 @@ public class RemoteWriterConfig { private String gmsHost; @Value("${GMS_PORT:8080}") private int gmsPort; + @Value("${GMS_USE_SSL:false}") + private boolean gmsUseSSL; + @Value("${GMS_SSL_PROTOCOL:#{null}}") + private String gmsSslProtocol; @Bean public BaseRemoteWriterDAO remoteWriterDAO() { - Client restClient = DefaultRestliClientFactory.getRestLiClient(gmsHost, gmsPort); + Client restClient = DefaultRestliClientFactory.getRestLiClient(gmsHost, gmsPort, gmsUseSSL, gmsSslProtocol); return new RestliRemoteWriterDAO(restClient); } } diff --git a/metadata-utils/src/main/java/com/linkedin/metadata/restli/DefaultRestliClientFactory.java b/metadata-utils/src/main/java/com/linkedin/metadata/restli/DefaultRestliClientFactory.java index 5d67adec00..b2e023527a 100644 --- a/metadata-utils/src/main/java/com/linkedin/metadata/restli/DefaultRestliClientFactory.java +++ b/metadata-utils/src/main/java/com/linkedin/metadata/restli/DefaultRestliClientFactory.java @@ -12,8 +12,14 @@ import com.linkedin.restli.client.RestClient; import org.apache.commons.lang.StringUtils; import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import java.security.InvalidParameterException; +import java.security.NoSuchAlgorithmException; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; public class DefaultRestliClientFactory { @@ -36,14 +42,55 @@ public class DefaultRestliClientFactory { @Nonnull public static RestClient getRestLiClient(@Nonnull String restLiServerHost, int restLiServerPort) { + return getRestLiClient(restLiServerHost, restLiServerPort, false, null); + } + + @Nonnull + public static RestClient getRestLiClient(@Nonnull String restLiServerHost, int restLiServerPort, boolean useSSL, + @Nullable String sslProtocol) { if (StringUtils.isBlank(restLiServerHost) || restLiServerPort <= 0) { throw new InvalidParameterException("Invalid restli server host name or port!"); } + if (useSSL) { + return getHttpsRestClient(restLiServerHost, restLiServerPort, sslProtocol); + } else { + return getHttpRestClient(restLiServerHost, restLiServerPort); + } + } + + private static RestClient getHttpsRestClient(@Nonnull String restLiServerHost, int restLiServerPort, + @Nullable String sslProtocol) { + Map params = new HashMap<>(); + + try { + params.put(HttpClientFactory.HTTP_SSL_CONTEXT, SSLContext.getDefault()); + } catch (NoSuchAlgorithmException ex) { + throw new RuntimeException(ex); + } + + SSLParameters sslParameters = new SSLParameters(); + if (sslProtocol != null) { + sslParameters.setProtocols(new String[]{sslProtocol}); + } + params.put(HttpClientFactory.HTTP_SSL_PARAMS, sslParameters); + + return getHttpRestClient("https", restLiServerHost, restLiServerPort, params); + } + + private static RestClient getHttpRestClient(@Nonnull String restLiServerHost, int restLiServerPort) { + return getHttpRestClient("http", restLiServerHost, restLiServerPort, new HashMap<>()); + } + + private static RestClient getHttpRestClient(@Nonnull String scheme, @Nonnull String restLiServerHost, + int restLiServerPort, @Nonnull Map params) { + Map finalParams = new HashMap<>(); + finalParams.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT_IN_MS); + finalParams.putAll(params); + HttpClientFactory http = new HttpClientFactory.Builder().build(); - TransportClient transportClient = http - .getClient(Collections.singletonMap(HttpClientFactory.HTTP_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT_IN_MS)); + TransportClient transportClient = http.getClient(Collections.unmodifiableMap(finalParams)); Client r2Client = new TransportClientAdapter(transportClient); - return new RestClient(r2Client, "http://" + restLiServerHost + ":" + restLiServerPort + "/"); + return new RestClient(r2Client, scheme + "://" + restLiServerHost + ":" + restLiServerPort + "/"); } }