Centralize OIDC Flow to handler, and refresh logic (#17082)

* Centralize OIDC Flow to handler, and refresh logic

* Remove forced condition

* Return on success
This commit is contained in:
Mohit Yadav 2024-07-18 23:53:50 +05:30 committed by GitHub
parent c340fe94f3
commit ebdd7f7fd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 945 additions and 845 deletions

View File

@ -192,6 +192,7 @@ authenticationConfiguration:
clientAuthenticationMethod: ${OIDC_CLIENT_AUTH_METHOD:-"client_secret_post"}
tenant: ${OIDC_TENANT:-""}
maxClockSkew: ${OIDC_MAX_CLOCK_SKEW:-""}
tokenValidity: ${OIDC_OM_REFRESH_TOKEN_VALIDITY:-"3600"} # in seconds
customParams: ${OIDC_CUSTOM_PARAMS:-}
samlConfiguration:
debugMode: ${SAML_DEBUG_MODE:-false}

View File

@ -13,8 +13,6 @@
package org.openmetadata.service;
import static org.openmetadata.service.security.SecurityUtil.tryCreateOidcClient;
import io.dropwizard.Application;
import io.dropwizard.configuration.EnvironmentVariableSubstitutor;
import io.dropwizard.configuration.SubstitutingSourceProvider;
@ -105,6 +103,7 @@ import org.openmetadata.service.security.AuthCallbackServlet;
import org.openmetadata.service.security.AuthLoginServlet;
import org.openmetadata.service.security.AuthLogoutServlet;
import org.openmetadata.service.security.AuthRefreshServlet;
import org.openmetadata.service.security.AuthenticationCodeFlowHandler;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.NoopAuthorizer;
import org.openmetadata.service.security.NoopFilter;
@ -127,7 +126,6 @@ import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverity
import org.openmetadata.service.util.jdbi.DatabaseAuthenticationProviderFactory;
import org.openmetadata.service.util.jdbi.OMSqlLogger;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.oidc.client.OidcClient;
import org.quartz.SchedulerException;
/** Main catalog application */
@ -273,55 +271,32 @@ public class OpenMetadataApplication extends Application<OpenMetadataApplication
contextHandler.setSessionHandler(new SessionHandler());
}
AuthenticationCodeFlowHandler authenticationCodeFlowHandler =
new AuthenticationCodeFlowHandler(
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration());
// Register Servlets
OidcClient oidcClient =
tryCreateOidcClient(config.getAuthenticationConfiguration().getOidcConfiguration());
oidcClient.setCallbackUrl(
config.getAuthenticationConfiguration().getOidcConfiguration().getCallbackUrl());
ServletRegistration.Dynamic authLogin =
environment
.servlets()
.addServlet(
"oauth_login",
new AuthLoginServlet(
oidcClient,
config.getAuthenticationConfiguration(),
config.getAuthorizerConfiguration()));
.addServlet("oauth_login", new AuthLoginServlet(authenticationCodeFlowHandler));
authLogin.addMapping("/api/v1/auth/login");
ServletRegistration.Dynamic authCallback =
environment
.servlets()
.addServlet(
"auth_callback",
new AuthCallbackServlet(
oidcClient,
config.getAuthenticationConfiguration(),
config.getAuthorizerConfiguration()));
.addServlet("auth_callback", new AuthCallbackServlet(authenticationCodeFlowHandler));
authCallback.addMapping("/callback");
ServletRegistration.Dynamic authLogout =
environment
.servlets()
.addServlet(
"auth_logout",
new AuthLogoutServlet(
config
.getAuthenticationConfiguration()
.getOidcConfiguration()
.getServerUrl()));
.addServlet("auth_logout", new AuthLogoutServlet(authenticationCodeFlowHandler));
authLogout.addMapping("/api/v1/auth/logout");
ServletRegistration.Dynamic refreshServlet =
environment
.servlets()
.addServlet(
"auth_refresh",
new AuthRefreshServlet(
oidcClient,
config
.getAuthenticationConfiguration()
.getOidcConfiguration()
.getServerUrl()));
.addServlet("auth_refresh", new AuthRefreshServlet(authenticationCodeFlowHandler));
refreshServlet.addMapping("/api/v1/auth/refresh");
}
}

View File

@ -1,281 +1,22 @@
package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE;
import static org.openmetadata.service.security.SecurityUtil.getClientAuthentication;
import static org.openmetadata.service.security.SecurityUtil.getErrorMessage;
import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken;
import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.ErrorObject;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.BadJWTExceptions;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.api.security.AuthorizerConfiguration;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.credentials.OidcCredentials;
@WebServlet("/callback")
@Slf4j
public class AuthCallbackServlet extends HttpServlet {
private final OidcClient client;
private final ClientAuthentication clientAuthentication;
private final List<String> claimsOrder;
private final Map<String, String> claimsMapping;
private final String serverUrl;
private final String principalDomain;
private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler;
public AuthCallbackServlet(
OidcClient oidcClient,
AuthenticationConfiguration authenticationConfiguration,
AuthorizerConfiguration authorizerConfiguration) {
CommonHelper.assertNotBlank(
"ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl());
this.client = oidcClient;
this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
this.claimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(claimsMapping);
this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
this.clientAuthentication = getClientAuthentication(client.getConfiguration());
this.principalDomain = authorizerConfiguration.getPrincipalDomain();
public AuthCallbackServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) {
this.authenticationCodeFlowHandler = authenticationCodeFlowHandler;
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
try {
LOG.debug("Performing Auth Callback For User Session: {} ", req.getSession().getId());
String computedCallbackUrl = client.getCallbackUrl();
Map<String, List<String>> parameters = retrieveParameters(req);
AuthenticationResponse response =
AuthenticationResponseParser.parse(new URI(computedCallbackUrl), parameters);
if (response instanceof AuthenticationErrorResponse authenticationErrorResponse) {
LOG.error(
"Bad authentication response, error={}", authenticationErrorResponse.getErrorObject());
throw new TechnicalException("Bad authentication response");
}
LOG.debug("Authentication response successful");
AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) response;
OIDCProviderMetadata metadata = client.getConfiguration().getProviderMetadata();
if (metadata.supportsAuthorizationResponseIssuerParam()
&& !metadata.getIssuer().equals(successResponse.getIssuer())) {
throw new TechnicalException("Issuer mismatch, possible mix-up attack.");
}
// Optional state validation
validateStateIfRequired(req, resp, successResponse);
// Build Credentials
OidcCredentials credentials = buildCredentials(successResponse);
// Validations
validateAndSendTokenRequest(req, credentials, computedCallbackUrl);
// Log Error if the Refresh Token is null
if (credentials.getRefreshToken() == null) {
LOG.error("Refresh token is null for user session: {}", req.getSession().getId());
}
validateNonceIfRequired(req, credentials.getIdToken().getJWTClaimsSet());
// Put Credentials in Session
req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
// Redirect
sendRedirectWithToken(
resp, credentials, serverUrl, claimsMapping, claimsOrder, principalDomain);
} catch (Exception e) {
getErrorMessage(resp, e);
}
}
private OidcCredentials buildCredentials(AuthenticationSuccessResponse successResponse) {
OidcCredentials credentials = new OidcCredentials();
// get authorization code
AuthorizationCode code = successResponse.getAuthorizationCode();
if (code != null) {
credentials.setCode(code);
}
// get ID token
JWT idToken = successResponse.getIDToken();
if (idToken != null) {
credentials.setIdToken(idToken);
}
// get access token
AccessToken accessToken = successResponse.getAccessToken();
if (accessToken != null) {
credentials.setAccessToken(accessToken);
}
return credentials;
}
private void validateNonceIfRequired(HttpServletRequest req, JWTClaimsSet claimsSet)
throws BadJOSEException {
if (client.getConfiguration().isUseNonce()) {
String expectedNonce =
(String) req.getSession().getAttribute(client.getNonceSessionAttributeName());
if (CommonHelper.isNotBlank(expectedNonce)) {
String tokenNonce;
try {
tokenNonce = claimsSet.getStringClaim("nonce");
} catch (java.text.ParseException var10) {
throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + var10.getMessage());
}
if (tokenNonce == null) {
throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION;
}
if (!expectedNonce.equals(tokenNonce)) {
throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + tokenNonce);
}
} else {
throw new TechnicalException("Missing nonce parameter from Session.");
}
}
}
private void validateStateIfRequired(
HttpServletRequest req,
HttpServletResponse resp,
AuthenticationSuccessResponse successResponse) {
if (client.getConfiguration().isWithState()) {
// Validate state for CSRF mitigation
State requestState =
(State) req.getSession().getAttribute(client.getStateSessionAttributeName());
if (requestState == null || CommonHelper.isBlank(requestState.getValue())) {
getErrorMessage(resp, new TechnicalException("Missing state parameter"));
return;
}
State responseState = successResponse.getState();
if (responseState == null) {
throw new TechnicalException("Missing state parameter");
}
LOG.debug("Request state: {}/response state: {}", requestState, responseState);
if (!requestState.equals(responseState)) {
throw new TechnicalException(
"State parameter is different from the one sent in authentication request.");
}
}
}
private void validateAndSendTokenRequest(
HttpServletRequest req, OidcCredentials oidcCredentials, String computedCallbackUrl)
throws IOException, ParseException, URISyntaxException {
if (oidcCredentials.getCode() != null) {
LOG.debug("Initiating Token Request for User Session: {} ", req.getSession().getId());
CodeVerifier verifier =
(CodeVerifier)
req.getSession().getAttribute(client.getCodeVerifierSessionAttributeName());
// Token request
TokenRequest request =
createTokenRequest(
new AuthorizationCodeGrant(
oidcCredentials.getCode(), new URI(computedCallbackUrl), verifier));
executeTokenRequest(request, oidcCredentials);
}
}
protected Map<String, List<String>> retrieveParameters(HttpServletRequest request) {
Map<String, String[]> requestParameters = request.getParameterMap();
Map<String, List<String>> map = new HashMap<>();
for (var entry : requestParameters.entrySet()) {
map.put(entry.getKey(), Arrays.asList(entry.getValue()));
}
return map;
}
protected TokenRequest createTokenRequest(final AuthorizationGrant grant) {
if (client.getConfiguration().getClientAuthenticationMethod() != null) {
return new TokenRequest(
client.getConfiguration().findProviderMetadata().getTokenEndpointURI(),
this.clientAuthentication,
grant);
} else {
return new TokenRequest(
client.getConfiguration().findProviderMetadata().getTokenEndpointURI(),
new ClientID(client.getConfiguration().getClientId()),
grant);
}
}
private void executeTokenRequest(TokenRequest request, OidcCredentials credentials)
throws IOException, ParseException {
HTTPRequest tokenHttpRequest = request.toHTTPRequest();
client.getConfiguration().configureHttpRequest(tokenHttpRequest);
HTTPResponse httpResponse = tokenHttpRequest.send();
LOG.debug(
"Token response: status={}, content={}",
httpResponse.getStatusCode(),
httpResponse.getContent());
TokenResponse response = OIDCTokenResponseParser.parse(httpResponse);
if (response instanceof TokenErrorResponse tokenErrorResponse) {
ErrorObject errorObject = tokenErrorResponse.getErrorObject();
throw new TechnicalException(
"Bad token response, error="
+ errorObject.getCode()
+ ","
+ " description="
+ errorObject.getDescription());
}
LOG.debug("Token response successful");
OIDCTokenResponse tokenSuccessResponse = (OIDCTokenResponse) response;
OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens();
credentials.setAccessToken(oidcTokens.getAccessToken());
credentials.setRefreshToken(oidcTokens.getRefreshToken());
if (oidcTokens.getIDToken() != null) {
credentials.setIdToken(oidcTokens.getIDToken());
}
authenticationCodeFlowHandler.handleCallback(req, resp);
}
}

View File

@ -1,166 +1,22 @@
package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.service.security.SecurityUtil.getErrorMessage;
import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession;
import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken;
import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallenge;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.Nonce;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.api.security.AuthorizerConfiguration;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.oidc.client.GoogleOidcClient;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.config.OidcConfiguration;
import org.pac4j.oidc.credentials.OidcCredentials;
@WebServlet("/api/v1/auth/login")
@Slf4j
public class AuthLoginServlet extends HttpServlet {
public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile";
private final OidcClient client;
private final List<String> claimsOrder;
private final Map<String, String> claimsMapping;
private final String serverUrl;
private final String principalDomain;
private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler;
public AuthLoginServlet(
OidcClient oidcClient,
AuthenticationConfiguration authenticationConfiguration,
AuthorizerConfiguration authorizerConfiguration) {
this.client = oidcClient;
this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
this.claimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(claimsMapping);
this.principalDomain = authorizerConfiguration.getPrincipalDomain();
public AuthLoginServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) {
this.authenticationCodeFlowHandler = authenticationCodeFlowHandler;
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
try {
LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId());
Optional<OidcCredentials> credentials = getUserCredentialsFromSession(req, client);
if (credentials.isPresent()) {
LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId());
sendRedirectWithToken(
resp, credentials.get(), serverUrl, claimsMapping, claimsOrder, principalDomain);
} else {
LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId());
Map<String, String> params = buildParams();
params.put(OidcConfiguration.REDIRECT_URI, client.getCallbackUrl());
addStateAndNonceParameters(req, params);
// This is always used to prompt the user to login
if (client instanceof GoogleOidcClient) {
params.put(OidcConfiguration.PROMPT, "consent");
} else {
params.put(OidcConfiguration.PROMPT, "login");
}
params.put(OidcConfiguration.MAX_AGE, "0");
String location = buildAuthenticationRequestUrl(params);
LOG.debug("Authentication request url: {}", location);
resp.sendRedirect(location);
}
} catch (Exception e) {
getErrorMessage(resp, new TechnicalException(e));
}
}
protected Map<String, String> buildParams() {
Map<String, String> authParams = new HashMap<>();
authParams.put(OidcConfiguration.SCOPE, client.getConfiguration().getScope());
authParams.put(OidcConfiguration.RESPONSE_TYPE, client.getConfiguration().getResponseType());
authParams.put(OidcConfiguration.RESPONSE_MODE, "query");
authParams.putAll(client.getConfiguration().getCustomParams());
authParams.put(OidcConfiguration.CLIENT_ID, client.getConfiguration().getClientId());
return new HashMap<>(authParams);
}
protected void addStateAndNonceParameters(
final HttpServletRequest request, final Map<String, String> params) {
// Init state for CSRF mitigation
if (client.getConfiguration().isWithState()) {
State state = new State(CommonHelper.randomString(10));
params.put(OidcConfiguration.STATE, state.getValue());
request.getSession().setAttribute(client.getStateSessionAttributeName(), state);
}
// Init nonce for replay attack mitigation
if (client.getConfiguration().isUseNonce()) {
Nonce nonce = new Nonce();
params.put(OidcConfiguration.NONCE, nonce.getValue());
request.getSession().setAttribute(client.getNonceSessionAttributeName(), nonce.getValue());
}
CodeChallengeMethod pkceMethod = client.getConfiguration().findPkceMethod();
// Use Default PKCE method if not disabled
if (pkceMethod == null && !client.getConfiguration().isDisablePkce()) {
pkceMethod = CodeChallengeMethod.S256;
}
if (pkceMethod != null) {
CodeVerifier verfifier = new CodeVerifier(CommonHelper.randomString(43));
request.getSession().setAttribute(client.getCodeVerifierSessionAttributeName(), verfifier);
params.put(
OidcConfiguration.CODE_CHALLENGE,
CodeChallenge.compute(pkceMethod, verfifier).getValue());
params.put(OidcConfiguration.CODE_CHALLENGE_METHOD, pkceMethod.getValue());
}
}
protected String buildAuthenticationRequestUrl(final Map<String, String> params) {
// Build authentication request query string
String queryString;
try {
queryString =
AuthenticationRequest.parse(
params.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey, e -> Collections.singletonList(e.getValue()))))
.toQueryString();
} catch (Exception e) {
throw new TechnicalException(e);
}
return client.getConfiguration().getProviderMetadata().getAuthorizationEndpointURI().toString()
+ '?'
+ queryString;
}
public static void writeJsonResponse(HttpServletResponse response, String message)
throws IOException {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.getOutputStream().print(message);
response.getOutputStream().flush();
response.setStatus(HttpServletResponse.SC_OK);
authenticationCodeFlowHandler.handleLogin(req, resp);
}
}

View File

@ -4,33 +4,20 @@ import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import lombok.extern.slf4j.Slf4j;
@WebServlet("/api/v1/auth/logout")
@Slf4j
public class AuthLogoutServlet extends HttpServlet {
private final String url;
private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler;
public AuthLogoutServlet(String url) {
this.url = url;
public AuthLogoutServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) {
this.authenticationCodeFlowHandler = authenticationCodeFlowHandler;
}
@Override
protected void doGet(
final HttpServletRequest httpServletRequest, final HttpServletResponse httpServletResponse) {
try {
LOG.debug("Performing application logout");
HttpSession session = httpServletRequest.getSession(false);
if (session != null) {
LOG.debug("Invalidating the session for logout");
session.invalidate();
httpServletResponse.sendRedirect(url);
} else {
LOG.error("No session store available for this web context");
}
} catch (Exception ex) {
LOG.error("[Auth Logout] Error while performing logout", ex);
}
authenticationCodeFlowHandler.handleLogout(httpServletRequest, httpServletResponse);
}
}

View File

@ -1,58 +1,22 @@
package org.openmetadata.service.security;
import static org.openmetadata.service.security.AuthLoginServlet.writeJsonResponse;
import static org.openmetadata.service.security.SecurityUtil.getErrorMessage;
import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession;
import java.util.Optional;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.auth.JwtResponse;
import org.openmetadata.service.util.JsonUtils;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.credentials.OidcCredentials;
@WebServlet("/api/v1/auth/refresh")
@Slf4j
public class AuthRefreshServlet extends HttpServlet {
private final OidcClient client;
private final String baseUrl;
private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler;
public AuthRefreshServlet(OidcClient oidcClient, String url) {
this.client = oidcClient;
this.baseUrl = url;
public AuthRefreshServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) {
this.authenticationCodeFlowHandler = authenticationCodeFlowHandler;
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
try {
LOG.debug("Performing Auth Refresh For User Session: {} ", req.getSession().getId());
Optional<OidcCredentials> credentials = getUserCredentialsFromSession(req, client);
if (credentials.isPresent()) {
LOG.debug("Credentials Found For User Session: {} ", req.getSession().getId());
JwtResponse jwtResponse = new JwtResponse();
jwtResponse.setAccessToken(credentials.get().getIdToken().getParsedString());
jwtResponse.setExpiryDuration(
credentials
.get()
.getIdToken()
.getJWTClaimsSet()
.getExpirationTime()
.toInstant()
.getEpochSecond());
writeJsonResponse(resp, JsonUtils.pojoToJson(jwtResponse));
} else {
LOG.debug(
"Credentials Not Found For User Session: {}, Redirect to Logout ",
req.getSession().getId());
resp.sendRedirect(String.format("%s/logout", baseUrl));
}
} catch (Exception e) {
getErrorMessage(resp, new TechnicalException(e));
}
authenticationCodeFlowHandler.handleRefresh(req, resp);
}
}

View File

@ -0,0 +1,912 @@
package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY;
import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY;
import static org.openmetadata.service.security.SecurityUtil.findEmailFromClaims;
import static org.openmetadata.service.security.SecurityUtil.getClaimOrObject;
import static org.openmetadata.service.security.SecurityUtil.getFirstMatchJwtClaim;
import static org.openmetadata.service.util.UserUtil.getRoleListFromUser;
import static org.pac4j.core.util.CommonHelper.assertNotNull;
import static org.pac4j.core.util.CommonHelper.isNotEmpty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.ErrorObject;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.ClientSecretPost;
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallenge;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.util.JSONObjectUtils;
import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.AuthenticationResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.BadJWTExceptions;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.ws.rs.BadRequestException;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.minidev.json.JSONObject;
import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.api.security.AuthorizerConfiguration;
import org.openmetadata.schema.auth.JWTAuthMechanism;
import org.openmetadata.schema.auth.ServiceTokenType;
import org.openmetadata.schema.entity.teams.User;
import org.openmetadata.schema.security.client.OidcClientConfig;
import org.openmetadata.schema.type.Include;
import org.openmetadata.service.Entity;
import org.openmetadata.service.auth.JwtResponse;
import org.openmetadata.service.security.jwt.JWTTokenGenerator;
import org.openmetadata.service.util.JsonUtils;
import org.pac4j.core.context.HttpConstants;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.core.util.HttpUtils;
import org.pac4j.oidc.client.AzureAd2Client;
import org.pac4j.oidc.client.GoogleOidcClient;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.config.AzureAd2OidcConfiguration;
import org.pac4j.oidc.config.OidcConfiguration;
import org.pac4j.oidc.config.PrivateKeyJWTClientAuthnMethodConfig;
import org.pac4j.oidc.credentials.OidcCredentials;
@Slf4j
public class AuthenticationCodeFlowHandler {
private static final Collection<ClientAuthenticationMethod> SUPPORTED_METHODS =
Arrays.asList(
ClientAuthenticationMethod.CLIENT_SECRET_POST,
ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
ClientAuthenticationMethod.PRIVATE_KEY_JWT,
ClientAuthenticationMethod.NONE);
public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org";
public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile";
private final OidcClient client;
private final List<String> claimsOrder;
private final Map<String, String> claimsMapping;
private final String serverUrl;
private final ClientAuthentication clientAuthentication;
private final String principalDomain;
private final int tokenValidity;
public AuthenticationCodeFlowHandler(
AuthenticationConfiguration authenticationConfiguration,
AuthorizerConfiguration authorizerConfiguration) {
// Assert oidcConfig and Callback Url
CommonHelper.assertNotNull(
"OidcConfiguration", authenticationConfiguration.getOidcConfiguration());
CommonHelper.assertNotBlank(
"CallbackUrl", authenticationConfiguration.getOidcConfiguration().getCallbackUrl());
CommonHelper.assertNotBlank(
"ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl());
// Build Required Params
this.client = buildOidcClient(authenticationConfiguration.getOidcConfiguration());
client.setCallbackUrl(authenticationConfiguration.getOidcConfiguration().getCallbackUrl());
this.clientAuthentication = getClientAuthentication(client.getConfiguration());
this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
this.claimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(claimsMapping);
this.principalDomain = authorizerConfiguration.getPrincipalDomain();
this.tokenValidity = authenticationConfiguration.getOidcConfiguration().getTokenValidity();
}
private OidcClient buildOidcClient(OidcClientConfig clientConfig) {
String id = clientConfig.getId();
String secret = clientConfig.getSecret();
if (CommonHelper.isNotBlank(id) && CommonHelper.isNotBlank(secret)) {
OidcConfiguration configuration = new OidcConfiguration();
configuration.setClientId(id);
configuration.setResponseMode("query");
// Add Secret
if (CommonHelper.isNotBlank(secret)) {
configuration.setSecret(secret);
}
// Response Type
String responseType = clientConfig.getResponseType();
if (CommonHelper.isNotBlank(responseType)) {
configuration.setResponseType(responseType);
}
String scope = clientConfig.getScope();
if (CommonHelper.isNotBlank(scope)) {
configuration.setScope(scope);
}
String discoveryUri = clientConfig.getDiscoveryUri();
if (CommonHelper.isNotBlank(discoveryUri)) {
configuration.setDiscoveryURI(discoveryUri);
}
String useNonce = clientConfig.getUseNonce();
if (CommonHelper.isNotBlank(useNonce)) {
configuration.setUseNonce(Boolean.parseBoolean(useNonce));
}
String jwsAlgo = clientConfig.getPreferredJwsAlgorithm();
if (CommonHelper.isNotBlank(jwsAlgo)) {
configuration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(jwsAlgo));
}
String maxClockSkew = clientConfig.getMaxClockSkew();
if (CommonHelper.isNotBlank(maxClockSkew)) {
configuration.setMaxClockSkew(Integer.parseInt(maxClockSkew));
}
String clientAuthenticationMethod = clientConfig.getClientAuthenticationMethod().value();
if (CommonHelper.isNotBlank(clientAuthenticationMethod)) {
configuration.setClientAuthenticationMethod(
ClientAuthenticationMethod.parse(clientAuthenticationMethod));
}
// Disable PKCE
configuration.setDisablePkce(clientConfig.getDisablePkce());
// Add Custom Params
if (clientConfig.getCustomParams() != null) {
for (int j = 1; j <= 5; ++j) {
if (clientConfig.getCustomParams().containsKey(String.format("customParamKey%d", j))) {
configuration.addCustomParam(
clientConfig.getCustomParams().get(String.format("customParamKey%d", j)),
clientConfig.getCustomParams().get(String.format("customParamValue%d", j)));
}
}
}
String type = clientConfig.getType();
OidcClient oidcClient;
if ("azure".equalsIgnoreCase(type)) {
AzureAd2OidcConfiguration azureAdConfiguration =
new AzureAd2OidcConfiguration(configuration);
String tenant = clientConfig.getTenant();
if (CommonHelper.isNotBlank(tenant)) {
azureAdConfiguration.setTenant(tenant);
}
oidcClient = new AzureAd2Client(azureAdConfiguration);
} else if ("google".equalsIgnoreCase(type)) {
oidcClient = new GoogleOidcClient(configuration);
// Google needs it as param
oidcClient.getConfiguration().getCustomParams().put("access_type", "offline");
} else {
oidcClient = new OidcClient(configuration);
}
oidcClient.setName(String.format("OMOidcClient%s", oidcClient.getName()));
return oidcClient;
}
throw new IllegalArgumentException(
"Client ID and Client Secret is required to create OidcClient");
}
// Login
public void handleLogin(HttpServletRequest req, HttpServletResponse resp) {
try {
LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId());
Optional<OidcCredentials> credentials = getUserCredentialsFromSession(req);
if (credentials.isPresent()) {
LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId());
sendRedirectWithToken(resp, credentials.get());
} else {
LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId());
Map<String, String> params = buildLoginParams();
params.put(OidcConfiguration.REDIRECT_URI, client.getCallbackUrl());
addStateAndNonceParameters(client, req, params);
// This is always used to prompt the user to login
if (client instanceof GoogleOidcClient) {
params.put(OidcConfiguration.PROMPT, "consent");
} else {
params.put(OidcConfiguration.PROMPT, "login");
}
params.put(OidcConfiguration.MAX_AGE, "0");
String location = buildLoginAuthenticationRequestUrl(params);
LOG.debug("Authentication request url: {}", location);
resp.sendRedirect(location);
}
} catch (Exception e) {
getErrorMessage(resp, new TechnicalException(e));
}
}
// Callback
public void handleCallback(HttpServletRequest req, HttpServletResponse resp) {
try {
LOG.debug("Performing Auth Callback For User Session: {} ", req.getSession().getId());
String computedCallbackUrl = client.getCallbackUrl();
Map<String, List<String>> parameters = retrieveCallbackParameters(req);
AuthenticationResponse response =
AuthenticationResponseParser.parse(new URI(computedCallbackUrl), parameters);
if (response instanceof AuthenticationErrorResponse authenticationErrorResponse) {
LOG.error(
"Bad authentication response, error={}", authenticationErrorResponse.getErrorObject());
throw new TechnicalException("Bad authentication response");
}
LOG.debug("Authentication response successful");
AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) response;
OIDCProviderMetadata metadata = client.getConfiguration().getProviderMetadata();
if (metadata.supportsAuthorizationResponseIssuerParam()
&& !metadata.getIssuer().equals(successResponse.getIssuer())) {
throw new TechnicalException("Issuer mismatch, possible mix-up attack.");
}
// Optional state validation
validateStateIfRequired(req, resp, successResponse);
// Build Credentials
OidcCredentials credentials = buildCredentials(successResponse);
// Validations
validateAndSendTokenRequest(req, credentials, computedCallbackUrl);
// Log Error if the Refresh Token is null
if (credentials.getRefreshToken() == null) {
LOG.error("Refresh token is null for user session: {}", req.getSession().getId());
}
validateNonceIfRequired(req, credentials.getIdToken().getJWTClaimsSet());
// Put Credentials in Session
req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
// Redirect
sendRedirectWithToken(resp, credentials);
} catch (Exception e) {
getErrorMessage(resp, e);
}
}
// Logout
public void handleLogout(
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
try {
LOG.debug("Performing application logout");
HttpSession session = httpServletRequest.getSession(false);
if (session != null) {
LOG.debug("Invalidating the session for logout");
session.invalidate();
httpServletResponse.sendRedirect(serverUrl);
} else {
LOG.error("No session store available for this web context");
}
} catch (Exception ex) {
LOG.error("[Auth Logout] Error while performing logout", ex);
}
}
// Refresh
public void handleRefresh(
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
try {
LOG.debug(
"Performing Auth Refresh For User Session: {} ", httpServletRequest.getSession().getId());
Optional<OidcCredentials> credentials = getUserCredentialsFromSession(httpServletRequest);
if (credentials.isPresent()) {
LOG.debug(
"Credentials Found For User Session: {} ", httpServletRequest.getSession().getId());
JwtResponse jwtResponse = new JwtResponse();
jwtResponse.setAccessToken(credentials.get().getIdToken().getParsedString());
jwtResponse.setExpiryDuration(
credentials
.get()
.getIdToken()
.getJWTClaimsSet()
.getExpirationTime()
.toInstant()
.getEpochSecond());
writeJsonResponse(httpServletResponse, JsonUtils.pojoToJson(jwtResponse));
} else {
LOG.debug(
"Credentials Not Found For User Session: {}, Redirect to Logout ",
httpServletRequest.getSession().getId());
httpServletResponse.sendRedirect(String.format("%s/logout", serverUrl));
}
} catch (Exception e) {
getErrorMessage(httpServletResponse, new TechnicalException(e));
}
}
private String buildLoginAuthenticationRequestUrl(final Map<String, String> params) {
// Build authentication request query string
String queryString;
try {
queryString =
AuthenticationRequest.parse(
params.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey, e -> Collections.singletonList(e.getValue()))))
.toQueryString();
} catch (Exception e) {
throw new TechnicalException(e);
}
return client.getConfiguration().getProviderMetadata().getAuthorizationEndpointURI().toString()
+ '?'
+ queryString;
}
private Map<String, String> buildLoginParams() {
Map<String, String> authParams = new HashMap<>();
authParams.put(OidcConfiguration.SCOPE, client.getConfiguration().getScope());
authParams.put(OidcConfiguration.RESPONSE_TYPE, client.getConfiguration().getResponseType());
authParams.put(OidcConfiguration.RESPONSE_MODE, "query");
authParams.putAll(client.getConfiguration().getCustomParams());
authParams.put(OidcConfiguration.CLIENT_ID, client.getConfiguration().getClientId());
return new HashMap<>(authParams);
}
private Optional<OidcCredentials> getUserCredentialsFromSession(HttpServletRequest request)
throws URISyntaxException {
OidcCredentials credentials =
(OidcCredentials) request.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE);
if (credentials != null && credentials.getRefreshToken() != null) {
LOG.trace("Credentials found in session: {}", credentials);
renewOidcCredentials(request, credentials);
return Optional.of(credentials);
} else {
if (credentials == null) {
LOG.error("No credentials found against session. ID: {}", request.getSession().getId());
} else {
LOG.error("No refresh token found against session. ID: {}", request.getSession().getId());
}
}
return Optional.empty();
}
private void validateAndSendTokenRequest(
HttpServletRequest req, OidcCredentials oidcCredentials, String computedCallbackUrl)
throws IOException, com.nimbusds.oauth2.sdk.ParseException, URISyntaxException {
if (oidcCredentials.getCode() != null) {
LOG.debug("Initiating Token Request for User Session: {} ", req.getSession().getId());
CodeVerifier verifier =
(CodeVerifier)
req.getSession().getAttribute(client.getCodeVerifierSessionAttributeName());
// Token request
TokenRequest request =
createTokenRequest(
new AuthorizationCodeGrant(
oidcCredentials.getCode(), new URI(computedCallbackUrl), verifier));
executeAuthorizationCodeTokenRequest(request, oidcCredentials);
}
}
private void validateStateIfRequired(
HttpServletRequest req,
HttpServletResponse resp,
AuthenticationSuccessResponse successResponse) {
if (client.getConfiguration().isWithState()) {
// Validate state for CSRF mitigation
State requestState =
(State) req.getSession().getAttribute(client.getStateSessionAttributeName());
if (requestState == null || CommonHelper.isBlank(requestState.getValue())) {
getErrorMessage(resp, new TechnicalException("Missing state parameter"));
return;
}
State responseState = successResponse.getState();
if (responseState == null) {
throw new TechnicalException("Missing state parameter");
}
LOG.debug("Request state: {}/response state: {}", requestState, responseState);
if (!requestState.equals(responseState)) {
throw new TechnicalException(
"State parameter is different from the one sent in authentication request.");
}
}
}
private OidcCredentials buildCredentials(AuthenticationSuccessResponse successResponse) {
OidcCredentials credentials = new OidcCredentials();
// get authorization code
AuthorizationCode code = successResponse.getAuthorizationCode();
if (code != null) {
credentials.setCode(code);
}
// get ID token
JWT idToken = successResponse.getIDToken();
if (idToken != null) {
credentials.setIdToken(idToken);
}
// get access token
AccessToken accessToken = successResponse.getAccessToken();
if (accessToken != null) {
credentials.setAccessToken(accessToken);
}
return credentials;
}
private void validateNonceIfRequired(HttpServletRequest req, JWTClaimsSet claimsSet)
throws BadJOSEException {
if (client.getConfiguration().isUseNonce()) {
String expectedNonce =
(String) req.getSession().getAttribute(client.getNonceSessionAttributeName());
if (CommonHelper.isNotBlank(expectedNonce)) {
String tokenNonce;
try {
tokenNonce = claimsSet.getStringClaim("nonce");
} catch (java.text.ParseException var10) {
throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + var10.getMessage());
}
if (tokenNonce == null) {
throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION;
}
if (!expectedNonce.equals(tokenNonce)) {
throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + tokenNonce);
}
} else {
throw new TechnicalException("Missing nonce parameter from Session.");
}
}
}
protected Map<String, List<String>> retrieveCallbackParameters(HttpServletRequest request) {
Map<String, String[]> requestParameters = request.getParameterMap();
Map<String, List<String>> map = new HashMap<>();
for (var entry : requestParameters.entrySet()) {
map.put(entry.getKey(), Arrays.asList(entry.getValue()));
}
return map;
}
private void writeJsonResponse(HttpServletResponse response, String message) throws IOException {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.getOutputStream().print(message);
response.getOutputStream().flush();
response.setStatus(HttpServletResponse.SC_OK);
}
private ClientAuthentication getClientAuthentication(OidcConfiguration configuration) {
ClientID clientID = new ClientID(configuration.getClientId());
ClientAuthentication clientAuthenticationMechanism = null;
if (configuration.getSecret() != null) {
// check authentication methods
List<ClientAuthenticationMethod> metadataMethods =
configuration.findProviderMetadata().getTokenEndpointAuthMethods();
ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration);
final ClientAuthenticationMethod chosenMethod;
if (isNotEmpty(metadataMethods)) {
if (preferredMethod != null) {
if (metadataMethods.contains(preferredMethod)) {
chosenMethod = preferredMethod;
} else {
throw new TechnicalException(
"Preferred authentication method ("
+ preferredMethod
+ ") not supported "
+ "by provider according to provider metadata ("
+ metadataMethods
+ ").");
}
} else {
chosenMethod = firstSupportedMethod(metadataMethods);
}
} else {
chosenMethod =
preferredMethod != null ? preferredMethod : ClientAuthenticationMethod.getDefault();
LOG.info(
"Provider metadata does not provide Token endpoint authentication methods. Using: {}",
chosenMethod);
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(chosenMethod)) {
Secret clientSecret = new Secret(configuration.getSecret());
clientAuthenticationMechanism = new ClientSecretPost(clientID, clientSecret);
} else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(chosenMethod)) {
Secret clientSecret = new Secret(configuration.getSecret());
clientAuthenticationMechanism = new ClientSecretBasic(clientID, clientSecret);
} else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(chosenMethod)) {
PrivateKeyJWTClientAuthnMethodConfig privateKetJwtConfig =
configuration.getPrivateKeyJWTClientAuthnMethodConfig();
assertNotNull("privateKetJwtConfig", privateKetJwtConfig);
JWSAlgorithm jwsAlgo = privateKetJwtConfig.getJwsAlgorithm();
assertNotNull("privateKetJwtConfig.getJwsAlgorithm()", jwsAlgo);
PrivateKey privateKey = privateKetJwtConfig.getPrivateKey();
assertNotNull("privateKetJwtConfig.getPrivateKey()", privateKey);
String keyID = privateKetJwtConfig.getKeyID();
try {
clientAuthenticationMechanism =
new PrivateKeyJWT(
clientID,
configuration.findProviderMetadata().getTokenEndpointURI(),
jwsAlgo,
privateKey,
keyID,
null);
} catch (final JOSEException e) {
throw new TechnicalException(
"Cannot instantiate private key JWT client authentication method", e);
}
}
}
return clientAuthenticationMechanism;
}
private static ClientAuthenticationMethod getPreferredAuthenticationMethod(
OidcConfiguration config) {
ClientAuthenticationMethod configurationMethod = config.getClientAuthenticationMethod();
if (configurationMethod == null) {
return null;
}
if (!SUPPORTED_METHODS.contains(configurationMethod)) {
throw new TechnicalException(
"Configured authentication method (" + configurationMethod + ") is not supported.");
}
return configurationMethod;
}
private ClientAuthenticationMethod firstSupportedMethod(
final List<ClientAuthenticationMethod> metadataMethods) {
Optional<ClientAuthenticationMethod> firstSupported =
metadataMethods.stream().filter(SUPPORTED_METHODS::contains).findFirst();
if (firstSupported.isPresent()) {
return firstSupported.get();
} else {
throw new TechnicalException(
"None of the Token endpoint provider metadata authentication methods are supported: "
+ metadataMethods);
}
}
@SneakyThrows
public static void getErrorMessage(HttpServletResponse resp, Exception e) {
resp.setContentType("text/html; charset=UTF-8");
LOG.error("[Auth Callback Servlet] Failed in Auth Login : {}", e.getMessage());
resp.getOutputStream()
.println(
String.format(
"<p> [Auth Callback Servlet] Failed in Auth Login : %s </p>", e.getMessage()));
}
private void sendRedirectWithToken(HttpServletResponse response, OidcCredentials credentials)
throws ParseException, IOException {
JWT jwt = credentials.getIdToken();
Map<String, Object> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
claims.putAll(jwt.getJWTClaimsSet().getClaims());
String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims);
String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, principalDomain);
String url =
String.format(
"%s/auth/callback?id_token=%s&email=%s&name=%s",
serverUrl, credentials.getIdToken().getParsedString(), email, userName);
response.sendRedirect(url);
}
private void renewOidcCredentials(HttpServletRequest request, OidcCredentials credentials) {
LOG.debug("Renewing Credentials for User Session {}", request.getSession().getId());
if (client.getConfiguration() instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) {
refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials);
} else {
refreshTokenRequest(request, credentials);
}
request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
}
public void refreshTokenRequest(
final HttpServletRequest httpServletRequest, final OidcCredentials credentials) {
final var refreshToken = credentials.getRefreshToken();
if (refreshToken != null) {
try {
final var request = createTokenRequest(new RefreshTokenGrant(refreshToken));
HTTPResponse httpResponse = executeTokenHttpRequest(request);
if (httpResponse.getStatusCode() == 200) {
JSONObject jsonObjectResponse = httpResponse.getContentAsJSONObject();
String idTokenKey = "id_token";
if (jsonObjectResponse.containsKey(idTokenKey)) {
Object value = jsonObjectResponse.get(idTokenKey);
if (value == null) {
throw new com.nimbusds.oauth2.sdk.ParseException(
"JSON object member with key " + idTokenKey + " has null value");
} else {
LOG.info("Found a JWT token in the response, trying to parse it");
OIDCTokenResponse tokenSuccessResponse =
parseTokenResponseFromHttpResponse(httpResponse);
// Populate credentials
populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials);
}
} else {
// Note: since the id_token is not present, we must receive accessToken
// We can do better and get userInfo from
// client.getConfiguration().findProviderMetadata().getUserInfoEndpointURI()
// but currently we are just return the OM created token in the response
String accessToken = JSONObjectUtils.getString(jsonObjectResponse, "access_token");
LOG.info(
"Found an access token in the response, trying to parse it, Value : {}",
accessToken);
OIDCTokenResponse tokenSuccessResponse =
parseTokenResponseFromHttpResponse(httpResponse);
// Populate credentials
populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials);
OidcCredentials storedCredentials =
(OidcCredentials)
httpServletRequest.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE);
// Get the claims from the stored credentials
Map<String, Object> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
claims.putAll(storedCredentials.getIdToken().getJWTClaimsSet().getClaims());
String username =
SecurityUtil.findUserNameFromClaims(claimsMapping, claimsOrder, claims);
User user = Entity.getEntityByName(Entity.USER, username, "id", Include.NON_DELETED);
// Create a JWT here
JWTAuthMechanism jwtAuthMechanism =
JWTTokenGenerator.getInstance()
.generateJWTToken(
username,
getRoleListFromUser(user),
!nullOrEmpty(user.getIsAdmin()) && user.getIsAdmin(),
user.getEmail(),
tokenValidity,
false,
ServiceTokenType.OM_USER);
// Set the access token to the new JWT token
credentials.setIdToken(SignedJWT.parse(jwtAuthMechanism.getJWTToken()));
}
return;
} else {
throw new TechnicalException(
String.format(
"Failed to refresh id_token, response code:%s , Error : %s",
httpResponse.getStatusCode(), httpResponse.getContent()));
}
} catch (final IOException | com.nimbusds.oauth2.sdk.ParseException e) {
throw new TechnicalException(e);
} catch (ParseException e) {
throw new RuntimeException(e);
}
}
throw new BadRequestException("No refresh token available");
}
public static boolean isJWT(String token) {
return token.split("\\.").length == 3;
}
private void refreshAccessTokenAzureAd2Token(
AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) {
HttpURLConnection connection = null;
try {
Map<String, String> headers = new HashMap<>();
headers.put(
HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE);
headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON);
// get the token endpoint from discovery URI
URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL();
connection = HttpUtils.openPostConnection(tokenEndpointURL, headers);
BufferedWriter out =
new BufferedWriter(
new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8));
out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue()));
out.close();
int responseCode = connection.getResponseCode();
if (responseCode != 200) {
throw new TechnicalException(
"request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection));
}
var body = HttpUtils.readBody(connection);
Map<String, Object> res = JsonUtils.readValue(body, new TypeReference<>() {});
azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token")));
} catch (final IOException e) {
throw new TechnicalException(e);
} finally {
HttpUtils.closeConnection(connection);
}
}
public static String findUserNameFromClaims(
Map<String, String> jwtPrincipalClaimsMapping,
List<String> jwtPrincipalClaimsOrder,
Map<String, ?> claims) {
if (!nullOrEmpty(jwtPrincipalClaimsMapping)) {
// We have a mapping available so we will use that
String usernameClaim = jwtPrincipalClaimsMapping.get(USERNAME_CLAIM_KEY);
String userNameClaimValue = getClaimOrObject(claims.get(usernameClaim));
if (!nullOrEmpty(userNameClaimValue)) {
return userNameClaimValue;
} else {
throw new AuthenticationException("Invalid JWT token, 'username' claim is not present");
}
} else {
String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims);
String userName;
if (jwtClaim.contains("@")) {
userName = jwtClaim.split("@")[0];
} else {
userName = jwtClaim;
}
return userName;
}
}
public static void validatePrincipalClaimsMapping(Map<String, String> mapping) {
if (!nullOrEmpty(mapping)) {
String username = mapping.get(USERNAME_CLAIM_KEY);
String email = mapping.get(EMAIL_CLAIM_KEY);
if (nullOrEmpty(username) || nullOrEmpty(email)) {
throw new IllegalArgumentException(
"Invalid JWT Principal Claims Mapping. Both username and email should be present");
}
}
// If emtpy, jwtPrincipalClaims will be used so no need to validate
}
private HTTPResponse executeTokenHttpRequest(TokenRequest request) throws IOException {
HTTPRequest tokenHttpRequest = request.toHTTPRequest();
client.getConfiguration().configureHttpRequest(tokenHttpRequest);
HTTPResponse httpResponse = tokenHttpRequest.send();
LOG.debug(
"Token response: status={}, content={}",
httpResponse.getStatusCode(),
httpResponse.getContent());
return httpResponse;
}
private TokenRequest createTokenRequest(final AuthorizationGrant grant) {
if (clientAuthentication != null) {
return new TokenRequest(
client.getConfiguration().findProviderMetadata().getTokenEndpointURI(),
this.clientAuthentication,
grant);
} else {
return new TokenRequest(
client.getConfiguration().findProviderMetadata().getTokenEndpointURI(),
new ClientID(client.getConfiguration().getClientId()),
grant);
}
}
private void addStateAndNonceParameters(
final OidcClient client, final HttpServletRequest request, final Map<String, String> params) {
// Init state for CSRF mitigation
if (client.getConfiguration().isWithState()) {
State state = new State(CommonHelper.randomString(10));
params.put(OidcConfiguration.STATE, state.getValue());
request.getSession().setAttribute(client.getStateSessionAttributeName(), state);
}
// Init nonce for replay attack mitigation
if (client.getConfiguration().isUseNonce()) {
Nonce nonce = new Nonce();
params.put(OidcConfiguration.NONCE, nonce.getValue());
request.getSession().setAttribute(client.getNonceSessionAttributeName(), nonce.getValue());
}
CodeChallengeMethod pkceMethod = client.getConfiguration().findPkceMethod();
// Use Default PKCE method if not disabled
if (pkceMethod == null && !client.getConfiguration().isDisablePkce()) {
pkceMethod = CodeChallengeMethod.S256;
}
if (pkceMethod != null) {
CodeVerifier verfifier = new CodeVerifier(CommonHelper.randomString(43));
request.getSession().setAttribute(client.getCodeVerifierSessionAttributeName(), verfifier);
params.put(
OidcConfiguration.CODE_CHALLENGE,
CodeChallenge.compute(pkceMethod, verfifier).getValue());
params.put(OidcConfiguration.CODE_CHALLENGE_METHOD, pkceMethod.getValue());
}
}
private void executeAuthorizationCodeTokenRequest(
TokenRequest request, OidcCredentials credentials)
throws IOException, com.nimbusds.oauth2.sdk.ParseException {
HTTPResponse httpResponse = executeTokenHttpRequest(request);
OIDCTokenResponse tokenSuccessResponse = parseTokenResponseFromHttpResponse(httpResponse);
// Populate credentials
populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials);
}
private void populateCredentialsFromTokenResponse(
OIDCTokenResponse tokenSuccessResponse, OidcCredentials credentials) {
OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens();
credentials.setAccessToken(oidcTokens.getAccessToken());
credentials.setRefreshToken(oidcTokens.getRefreshToken());
if (oidcTokens.getIDToken() != null) {
credentials.setIdToken(oidcTokens.getIDToken());
}
}
private OIDCTokenResponse parseTokenResponseFromHttpResponse(HTTPResponse httpResponse)
throws com.nimbusds.oauth2.sdk.ParseException {
TokenResponse response = OIDCTokenResponseParser.parse(httpResponse);
if (response instanceof TokenErrorResponse tokenErrorResponse) {
ErrorObject errorObject = tokenErrorResponse.getErrorObject();
throw new TechnicalException(
"Bad token response, error="
+ errorObject.getCode()
+ ","
+ " description="
+ errorObject.getDescription());
}
LOG.debug("Token response successful");
return (OIDCTokenResponse) response;
}
}

View File

@ -14,82 +14,28 @@
package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE;
import static org.openmetadata.service.security.JwtFilter.BOT_CLAIM;
import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY;
import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY;
import static org.pac4j.core.util.CommonHelper.assertNotNull;
import static org.pac4j.core.util.CommonHelper.isNotEmpty;
import com.auth0.jwt.interfaces.Claim;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.JWT;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.ClientSecretPost;
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.security.PrivateKey;
import java.text.ParseException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.SecurityContext;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.schema.security.client.OidcClientConfig;
import org.openmetadata.service.OpenMetadataApplicationConfig;
import org.openmetadata.service.util.JsonUtils;
import org.pac4j.core.context.HttpConstants;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.core.util.HttpUtils;
import org.pac4j.oidc.client.AzureAd2Client;
import org.pac4j.oidc.client.GoogleOidcClient;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.config.AzureAd2OidcConfiguration;
import org.pac4j.oidc.config.OidcConfiguration;
import org.pac4j.oidc.config.PrivateKeyJWTClientAuthnMethodConfig;
import org.pac4j.oidc.credentials.OidcCredentials;
import org.pac4j.oidc.credentials.authenticator.OidcAuthenticator;
@Slf4j
public final class SecurityUtil {
public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org";
private static final Collection<ClientAuthenticationMethod> SUPPORTED_METHODS =
Arrays.asList(
ClientAuthenticationMethod.CLIENT_SECRET_POST,
ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
ClientAuthenticationMethod.PRIVATE_KEY_JWT,
ClientAuthenticationMethod.NONE);
private SecurityUtil() {}
public static String getUserName(SecurityContext securityContext) {
@ -131,293 +77,6 @@ public final class SecurityUtil {
return target.request();
}
public static OidcClient tryCreateOidcClient(OidcClientConfig clientConfig) {
String id = clientConfig.getId();
String secret = clientConfig.getSecret();
if (CommonHelper.isNotBlank(id) && CommonHelper.isNotBlank(secret)) {
OidcConfiguration configuration = new OidcConfiguration();
configuration.setClientId(id);
configuration.setResponseMode("query");
// Add Secret
if (CommonHelper.isNotBlank(secret)) {
configuration.setSecret(secret);
}
// Response Type
String responseType = clientConfig.getResponseType();
if (CommonHelper.isNotBlank(responseType)) {
configuration.setResponseType(responseType);
}
String scope = clientConfig.getScope();
if (CommonHelper.isNotBlank(scope)) {
configuration.setScope(scope);
}
String discoveryUri = clientConfig.getDiscoveryUri();
if (CommonHelper.isNotBlank(discoveryUri)) {
configuration.setDiscoveryURI(discoveryUri);
}
String useNonce = clientConfig.getUseNonce();
if (CommonHelper.isNotBlank(useNonce)) {
configuration.setUseNonce(Boolean.parseBoolean(useNonce));
}
String jwsAlgo = clientConfig.getPreferredJwsAlgorithm();
if (CommonHelper.isNotBlank(jwsAlgo)) {
configuration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(jwsAlgo));
}
String maxClockSkew = clientConfig.getMaxClockSkew();
if (CommonHelper.isNotBlank(maxClockSkew)) {
configuration.setMaxClockSkew(Integer.parseInt(maxClockSkew));
}
String clientAuthenticationMethod = clientConfig.getClientAuthenticationMethod().value();
if (CommonHelper.isNotBlank(clientAuthenticationMethod)) {
configuration.setClientAuthenticationMethod(
ClientAuthenticationMethod.parse(clientAuthenticationMethod));
}
// Disable PKCE
configuration.setDisablePkce(clientConfig.getDisablePkce());
// Add Custom Params
if (clientConfig.getCustomParams() != null) {
for (int j = 1; j <= 5; ++j) {
if (clientConfig.getCustomParams().containsKey(String.format("customParamKey%d", j))) {
configuration.addCustomParam(
clientConfig.getCustomParams().get(String.format("customParamKey%d", j)),
clientConfig.getCustomParams().get(String.format("customParamValue%d", j)));
}
}
}
String type = clientConfig.getType();
OidcClient oidcClient;
if ("azure".equalsIgnoreCase(type)) {
AzureAd2OidcConfiguration azureAdConfiguration =
new AzureAd2OidcConfiguration(configuration);
String tenant = clientConfig.getTenant();
if (CommonHelper.isNotBlank(tenant)) {
azureAdConfiguration.setTenant(tenant);
}
oidcClient = new AzureAd2Client(azureAdConfiguration);
} else if ("google".equalsIgnoreCase(type)) {
oidcClient = new GoogleOidcClient(configuration);
// Google needs it as param
oidcClient.getConfiguration().getCustomParams().put("access_type", "offline");
} else {
oidcClient = new OidcClient(configuration);
}
oidcClient.setName(String.format("OMOidcClient%s", oidcClient.getName()));
return oidcClient;
}
throw new IllegalArgumentException(
"Client ID and Client Secret is required to create OidcClient");
}
public static ClientAuthentication getClientAuthentication(OidcConfiguration configuration) {
ClientID clientID = new ClientID(configuration.getClientId());
ClientAuthentication clientAuthenticationMechanism = null;
if (configuration.getSecret() != null) {
// check authentication methods
List<ClientAuthenticationMethod> metadataMethods =
configuration.findProviderMetadata().getTokenEndpointAuthMethods();
ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration);
final ClientAuthenticationMethod chosenMethod;
if (isNotEmpty(metadataMethods)) {
if (preferredMethod != null) {
if (metadataMethods.contains(preferredMethod)) {
chosenMethod = preferredMethod;
} else {
throw new TechnicalException(
"Preferred authentication method ("
+ preferredMethod
+ ") not supported "
+ "by provider according to provider metadata ("
+ metadataMethods
+ ").");
}
} else {
chosenMethod = firstSupportedMethod(metadataMethods);
}
} else {
chosenMethod =
preferredMethod != null ? preferredMethod : ClientAuthenticationMethod.getDefault();
LOG.info(
"Provider metadata does not provide Token endpoint authentication methods. Using: {}",
chosenMethod);
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(chosenMethod)) {
Secret clientSecret = new Secret(configuration.getSecret());
clientAuthenticationMechanism = new ClientSecretPost(clientID, clientSecret);
} else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(chosenMethod)) {
Secret clientSecret = new Secret(configuration.getSecret());
clientAuthenticationMechanism = new ClientSecretBasic(clientID, clientSecret);
} else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(chosenMethod)) {
PrivateKeyJWTClientAuthnMethodConfig privateKetJwtConfig =
configuration.getPrivateKeyJWTClientAuthnMethodConfig();
assertNotNull("privateKetJwtConfig", privateKetJwtConfig);
JWSAlgorithm jwsAlgo = privateKetJwtConfig.getJwsAlgorithm();
assertNotNull("privateKetJwtConfig.getJwsAlgorithm()", jwsAlgo);
PrivateKey privateKey = privateKetJwtConfig.getPrivateKey();
assertNotNull("privateKetJwtConfig.getPrivateKey()", privateKey);
String keyID = privateKetJwtConfig.getKeyID();
try {
clientAuthenticationMechanism =
new PrivateKeyJWT(
clientID,
configuration.findProviderMetadata().getTokenEndpointURI(),
jwsAlgo,
privateKey,
keyID,
null);
} catch (final JOSEException e) {
throw new TechnicalException(
"Cannot instantiate private key JWT client authentication method", e);
}
}
}
return clientAuthenticationMechanism;
}
private static ClientAuthenticationMethod getPreferredAuthenticationMethod(
OidcConfiguration config) {
ClientAuthenticationMethod configurationMethod = config.getClientAuthenticationMethod();
if (configurationMethod == null) {
return null;
}
if (!SUPPORTED_METHODS.contains(configurationMethod)) {
throw new TechnicalException(
"Configured authentication method (" + configurationMethod + ") is not supported.");
}
return configurationMethod;
}
private static ClientAuthenticationMethod firstSupportedMethod(
final List<ClientAuthenticationMethod> metadataMethods) {
Optional<ClientAuthenticationMethod> firstSupported =
metadataMethods.stream().filter(SUPPORTED_METHODS::contains).findFirst();
if (firstSupported.isPresent()) {
return firstSupported.get();
} else {
throw new TechnicalException(
"None of the Token endpoint provider metadata authentication methods are supported: "
+ metadataMethods);
}
}
@SneakyThrows
public static void getErrorMessage(HttpServletResponse resp, Exception e) {
resp.setContentType("text/html; charset=UTF-8");
LOG.error("[Auth Callback Servlet] Failed in Auth Login : {}", e.getMessage());
resp.getOutputStream()
.println(
String.format(
"<p> [Auth Callback Servlet] Failed in Auth Login : %s </p>", e.getMessage()));
}
public static void sendRedirectWithToken(
HttpServletResponse response,
OidcCredentials credentials,
String serverUrl,
Map<String, String> claimsMapping,
List<String> claimsOrder,
String defaultDomain)
throws ParseException, IOException {
JWT jwt = credentials.getIdToken();
Map<String, Object> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
claims.putAll(jwt.getJWTClaimsSet().getClaims());
String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims);
String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, defaultDomain);
String url =
String.format(
"%s/auth/callback?id_token=%s&email=%s&name=%s",
serverUrl, credentials.getIdToken().getParsedString(), email, userName);
response.sendRedirect(url);
}
public static boolean isCredentialsExpired(OidcCredentials credentials) throws ParseException {
Date expiration = credentials.getIdToken().getJWTClaimsSet().getExpirationTime();
return expiration != null && expiration.toInstant().isBefore(Instant.now().plusSeconds(30));
}
public static Optional<OidcCredentials> getUserCredentialsFromSession(
HttpServletRequest request, OidcClient client) throws ParseException {
OidcCredentials credentials =
(OidcCredentials) request.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE);
if (credentials != null && credentials.getRefreshToken() != null) {
removeOrRenewOidcCredentials(request, client, credentials);
return Optional.of(credentials);
} else {
if (credentials == null) {
LOG.error("No credentials found against session. ID: {}", request.getSession().getId());
} else {
LOG.error("No refresh token found against session. ID: {}", request.getSession().getId());
}
}
return Optional.empty();
}
private static void removeOrRenewOidcCredentials(
HttpServletRequest request, OidcClient client, OidcCredentials credentials) {
LOG.debug("Expired credentials found, trying to renew.");
if (client.getConfiguration() instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) {
refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials);
} else {
OidcAuthenticator authenticator = new OidcAuthenticator(client.getConfiguration(), client);
authenticator.refresh(credentials);
}
request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
}
private static void refreshAccessTokenAzureAd2Token(
AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) {
HttpURLConnection connection = null;
try {
Map<String, String> headers = new HashMap<>();
headers.put(
HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE);
headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON);
// get the token endpoint from discovery URI
URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL();
connection = HttpUtils.openPostConnection(tokenEndpointURL, headers);
BufferedWriter out =
new BufferedWriter(
new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8));
out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue()));
out.close();
int responseCode = connection.getResponseCode();
if (responseCode != 200) {
throw new TechnicalException(
"request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection));
}
var body = HttpUtils.readBody(connection);
Map<String, Object> res = JsonUtils.readValue(body, new TypeReference<>() {});
azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token")));
} catch (final IOException e) {
throw new TechnicalException(e);
} finally {
HttpUtils.closeConnection(connection);
}
}
public static String findUserNameFromClaims(
Map<String, String> jwtPrincipalClaimsMapping,
List<String> jwtPrincipalClaimsOrder,
@ -470,7 +129,7 @@ public final class SecurityUtil {
}
}
private static String getClaimOrObject(Object obj) {
public static String getClaimOrObject(Object obj) {
if (obj == null) {
return "";
}

View File

@ -92,7 +92,7 @@ public class SamlAssertionConsumerServlet extends HttpServlet {
username,
getRoleListFromUser(user),
!nullOrEmpty(user.getIsAdmin()) && user.getIsAdmin(),
email,
user.getEmail(),
SamlSettingsHolder.getInstance().getTokenValidity(),
false,
ServiceTokenType.OM_USER);

View File

@ -56,6 +56,11 @@
"type": "string",
"enum": ["client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt"]
},
"tokenValidity": {
"description": "Validity for the JWT Token created from SAML Response",
"type": "integer",
"default": "3600"
},
"customParams": {
"description": "Custom Params.",
"existingJavaType" : "java.util.Map<String,String>",