diff --git a/conf/openmetadata.yaml b/conf/openmetadata.yaml index 527840ac282..bb4a99248e9 100644 --- a/conf/openmetadata.yaml +++ b/conf/openmetadata.yaml @@ -192,6 +192,7 @@ authenticationConfiguration: clientAuthenticationMethod: ${OIDC_CLIENT_AUTH_METHOD:-"client_secret_post"} tenant: ${OIDC_TENANT:-""} maxClockSkew: ${OIDC_MAX_CLOCK_SKEW:-""} + tokenValidity: ${OIDC_OM_REFRESH_TOKEN_VALIDITY:-"3600"} # in seconds customParams: ${OIDC_CUSTOM_PARAMS:-} samlConfiguration: debugMode: ${SAML_DEBUG_MODE:-false} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java index ed02a510efc..c89e2d0078c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java @@ -13,8 +13,6 @@ package org.openmetadata.service; -import static org.openmetadata.service.security.SecurityUtil.tryCreateOidcClient; - import io.dropwizard.Application; import io.dropwizard.configuration.EnvironmentVariableSubstitutor; import io.dropwizard.configuration.SubstitutingSourceProvider; @@ -105,6 +103,7 @@ import org.openmetadata.service.security.AuthCallbackServlet; import org.openmetadata.service.security.AuthLoginServlet; import org.openmetadata.service.security.AuthLogoutServlet; import org.openmetadata.service.security.AuthRefreshServlet; +import org.openmetadata.service.security.AuthenticationCodeFlowHandler; import org.openmetadata.service.security.Authorizer; import org.openmetadata.service.security.NoopAuthorizer; import org.openmetadata.service.security.NoopFilter; @@ -127,7 +126,6 @@ import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverity import org.openmetadata.service.util.jdbi.DatabaseAuthenticationProviderFactory; import org.openmetadata.service.util.jdbi.OMSqlLogger; import org.pac4j.core.util.CommonHelper; -import org.pac4j.oidc.client.OidcClient; import org.quartz.SchedulerException; /** Main catalog application */ @@ -273,55 +271,32 @@ public class OpenMetadataApplication extends Application claimsOrder; - private final Map claimsMapping; - private final String serverUrl; - private final String principalDomain; + private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler; - public AuthCallbackServlet( - OidcClient oidcClient, - AuthenticationConfiguration authenticationConfiguration, - AuthorizerConfiguration authorizerConfiguration) { - CommonHelper.assertNotBlank( - "ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl()); - this.client = oidcClient; - this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); - this.claimsMapping = - listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() - .map(s -> s.split(":")) - .collect(Collectors.toMap(s -> s[0], s -> s[1])); - validatePrincipalClaimsMapping(claimsMapping); - this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); - this.clientAuthentication = getClientAuthentication(client.getConfiguration()); - this.principalDomain = authorizerConfiguration.getPrincipalDomain(); + public AuthCallbackServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) { + this.authenticationCodeFlowHandler = authenticationCodeFlowHandler; } @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) { - try { - LOG.debug("Performing Auth Callback For User Session: {} ", req.getSession().getId()); - String computedCallbackUrl = client.getCallbackUrl(); - Map> parameters = retrieveParameters(req); - AuthenticationResponse response = - AuthenticationResponseParser.parse(new URI(computedCallbackUrl), parameters); - - if (response instanceof AuthenticationErrorResponse authenticationErrorResponse) { - LOG.error( - "Bad authentication response, error={}", authenticationErrorResponse.getErrorObject()); - throw new TechnicalException("Bad authentication response"); - } - - LOG.debug("Authentication response successful"); - AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) response; - - OIDCProviderMetadata metadata = client.getConfiguration().getProviderMetadata(); - if (metadata.supportsAuthorizationResponseIssuerParam() - && !metadata.getIssuer().equals(successResponse.getIssuer())) { - throw new TechnicalException("Issuer mismatch, possible mix-up attack."); - } - - // Optional state validation - validateStateIfRequired(req, resp, successResponse); - - // Build Credentials - OidcCredentials credentials = buildCredentials(successResponse); - - // Validations - validateAndSendTokenRequest(req, credentials, computedCallbackUrl); - - // Log Error if the Refresh Token is null - if (credentials.getRefreshToken() == null) { - LOG.error("Refresh token is null for user session: {}", req.getSession().getId()); - } - - validateNonceIfRequired(req, credentials.getIdToken().getJWTClaimsSet()); - - // Put Credentials in Session - req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); - - // Redirect - sendRedirectWithToken( - resp, credentials, serverUrl, claimsMapping, claimsOrder, principalDomain); - } catch (Exception e) { - getErrorMessage(resp, e); - } - } - - private OidcCredentials buildCredentials(AuthenticationSuccessResponse successResponse) { - OidcCredentials credentials = new OidcCredentials(); - // get authorization code - AuthorizationCode code = successResponse.getAuthorizationCode(); - if (code != null) { - credentials.setCode(code); - } - // get ID token - JWT idToken = successResponse.getIDToken(); - if (idToken != null) { - credentials.setIdToken(idToken); - } - // get access token - AccessToken accessToken = successResponse.getAccessToken(); - if (accessToken != null) { - credentials.setAccessToken(accessToken); - } - - return credentials; - } - - private void validateNonceIfRequired(HttpServletRequest req, JWTClaimsSet claimsSet) - throws BadJOSEException { - if (client.getConfiguration().isUseNonce()) { - String expectedNonce = - (String) req.getSession().getAttribute(client.getNonceSessionAttributeName()); - if (CommonHelper.isNotBlank(expectedNonce)) { - String tokenNonce; - try { - tokenNonce = claimsSet.getStringClaim("nonce"); - } catch (java.text.ParseException var10) { - throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + var10.getMessage()); - } - - if (tokenNonce == null) { - throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION; - } - - if (!expectedNonce.equals(tokenNonce)) { - throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + tokenNonce); - } - } else { - throw new TechnicalException("Missing nonce parameter from Session."); - } - } - } - - private void validateStateIfRequired( - HttpServletRequest req, - HttpServletResponse resp, - AuthenticationSuccessResponse successResponse) { - if (client.getConfiguration().isWithState()) { - // Validate state for CSRF mitigation - State requestState = - (State) req.getSession().getAttribute(client.getStateSessionAttributeName()); - if (requestState == null || CommonHelper.isBlank(requestState.getValue())) { - getErrorMessage(resp, new TechnicalException("Missing state parameter")); - return; - } - - State responseState = successResponse.getState(); - if (responseState == null) { - throw new TechnicalException("Missing state parameter"); - } - - LOG.debug("Request state: {}/response state: {}", requestState, responseState); - if (!requestState.equals(responseState)) { - throw new TechnicalException( - "State parameter is different from the one sent in authentication request."); - } - } - } - - private void validateAndSendTokenRequest( - HttpServletRequest req, OidcCredentials oidcCredentials, String computedCallbackUrl) - throws IOException, ParseException, URISyntaxException { - if (oidcCredentials.getCode() != null) { - LOG.debug("Initiating Token Request for User Session: {} ", req.getSession().getId()); - CodeVerifier verifier = - (CodeVerifier) - req.getSession().getAttribute(client.getCodeVerifierSessionAttributeName()); - // Token request - TokenRequest request = - createTokenRequest( - new AuthorizationCodeGrant( - oidcCredentials.getCode(), new URI(computedCallbackUrl), verifier)); - executeTokenRequest(request, oidcCredentials); - } - } - - protected Map> retrieveParameters(HttpServletRequest request) { - Map requestParameters = request.getParameterMap(); - Map> map = new HashMap<>(); - for (var entry : requestParameters.entrySet()) { - map.put(entry.getKey(), Arrays.asList(entry.getValue())); - } - return map; - } - - protected TokenRequest createTokenRequest(final AuthorizationGrant grant) { - if (client.getConfiguration().getClientAuthenticationMethod() != null) { - return new TokenRequest( - client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), - this.clientAuthentication, - grant); - } else { - return new TokenRequest( - client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), - new ClientID(client.getConfiguration().getClientId()), - grant); - } - } - - private void executeTokenRequest(TokenRequest request, OidcCredentials credentials) - throws IOException, ParseException { - HTTPRequest tokenHttpRequest = request.toHTTPRequest(); - client.getConfiguration().configureHttpRequest(tokenHttpRequest); - - HTTPResponse httpResponse = tokenHttpRequest.send(); - LOG.debug( - "Token response: status={}, content={}", - httpResponse.getStatusCode(), - httpResponse.getContent()); - - TokenResponse response = OIDCTokenResponseParser.parse(httpResponse); - if (response instanceof TokenErrorResponse tokenErrorResponse) { - ErrorObject errorObject = tokenErrorResponse.getErrorObject(); - throw new TechnicalException( - "Bad token response, error=" - + errorObject.getCode() - + "," - + " description=" - + errorObject.getDescription()); - } - LOG.debug("Token response successful"); - OIDCTokenResponse tokenSuccessResponse = (OIDCTokenResponse) response; - - OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens(); - credentials.setAccessToken(oidcTokens.getAccessToken()); - credentials.setRefreshToken(oidcTokens.getRefreshToken()); - if (oidcTokens.getIDToken() != null) { - credentials.setIdToken(oidcTokens.getIDToken()); - } + authenticationCodeFlowHandler.handleCallback(req, resp); } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java index b3e4ff54c13..16e2f8bad7c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java @@ -1,166 +1,22 @@ package org.openmetadata.service.security; -import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; -import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; -import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession; -import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken; -import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping; - -import com.nimbusds.oauth2.sdk.id.State; -import com.nimbusds.oauth2.sdk.pkce.CodeChallenge; -import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; -import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; -import com.nimbusds.openid.connect.sdk.AuthenticationRequest; -import com.nimbusds.openid.connect.sdk.Nonce; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; -import org.openmetadata.schema.api.security.AuthenticationConfiguration; -import org.openmetadata.schema.api.security.AuthorizerConfiguration; -import org.pac4j.core.exception.TechnicalException; -import org.pac4j.core.util.CommonHelper; -import org.pac4j.oidc.client.GoogleOidcClient; -import org.pac4j.oidc.client.OidcClient; -import org.pac4j.oidc.config.OidcConfiguration; -import org.pac4j.oidc.credentials.OidcCredentials; @WebServlet("/api/v1/auth/login") @Slf4j public class AuthLoginServlet extends HttpServlet { - public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile"; - private final OidcClient client; - private final List claimsOrder; - private final Map claimsMapping; - private final String serverUrl; - private final String principalDomain; + private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler; - public AuthLoginServlet( - OidcClient oidcClient, - AuthenticationConfiguration authenticationConfiguration, - AuthorizerConfiguration authorizerConfiguration) { - this.client = oidcClient; - this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); - this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); - this.claimsMapping = - listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() - .map(s -> s.split(":")) - .collect(Collectors.toMap(s -> s[0], s -> s[1])); - validatePrincipalClaimsMapping(claimsMapping); - this.principalDomain = authorizerConfiguration.getPrincipalDomain(); + public AuthLoginServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) { + this.authenticationCodeFlowHandler = authenticationCodeFlowHandler; } @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) { - try { - LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId()); - Optional credentials = getUserCredentialsFromSession(req, client); - if (credentials.isPresent()) { - LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId()); - sendRedirectWithToken( - resp, credentials.get(), serverUrl, claimsMapping, claimsOrder, principalDomain); - } else { - LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId()); - Map params = buildParams(); - - params.put(OidcConfiguration.REDIRECT_URI, client.getCallbackUrl()); - - addStateAndNonceParameters(req, params); - - // This is always used to prompt the user to login - if (client instanceof GoogleOidcClient) { - params.put(OidcConfiguration.PROMPT, "consent"); - } else { - params.put(OidcConfiguration.PROMPT, "login"); - } - params.put(OidcConfiguration.MAX_AGE, "0"); - - String location = buildAuthenticationRequestUrl(params); - LOG.debug("Authentication request url: {}", location); - - resp.sendRedirect(location); - } - } catch (Exception e) { - getErrorMessage(resp, new TechnicalException(e)); - } - } - - protected Map buildParams() { - Map authParams = new HashMap<>(); - authParams.put(OidcConfiguration.SCOPE, client.getConfiguration().getScope()); - authParams.put(OidcConfiguration.RESPONSE_TYPE, client.getConfiguration().getResponseType()); - authParams.put(OidcConfiguration.RESPONSE_MODE, "query"); - authParams.putAll(client.getConfiguration().getCustomParams()); - authParams.put(OidcConfiguration.CLIENT_ID, client.getConfiguration().getClientId()); - - return new HashMap<>(authParams); - } - - protected void addStateAndNonceParameters( - final HttpServletRequest request, final Map params) { - // Init state for CSRF mitigation - if (client.getConfiguration().isWithState()) { - State state = new State(CommonHelper.randomString(10)); - params.put(OidcConfiguration.STATE, state.getValue()); - request.getSession().setAttribute(client.getStateSessionAttributeName(), state); - } - - // Init nonce for replay attack mitigation - if (client.getConfiguration().isUseNonce()) { - Nonce nonce = new Nonce(); - params.put(OidcConfiguration.NONCE, nonce.getValue()); - request.getSession().setAttribute(client.getNonceSessionAttributeName(), nonce.getValue()); - } - - CodeChallengeMethod pkceMethod = client.getConfiguration().findPkceMethod(); - - // Use Default PKCE method if not disabled - if (pkceMethod == null && !client.getConfiguration().isDisablePkce()) { - pkceMethod = CodeChallengeMethod.S256; - } - if (pkceMethod != null) { - CodeVerifier verfifier = new CodeVerifier(CommonHelper.randomString(43)); - request.getSession().setAttribute(client.getCodeVerifierSessionAttributeName(), verfifier); - params.put( - OidcConfiguration.CODE_CHALLENGE, - CodeChallenge.compute(pkceMethod, verfifier).getValue()); - params.put(OidcConfiguration.CODE_CHALLENGE_METHOD, pkceMethod.getValue()); - } - } - - protected String buildAuthenticationRequestUrl(final Map params) { - // Build authentication request query string - String queryString; - try { - queryString = - AuthenticationRequest.parse( - params.entrySet().stream() - .collect( - Collectors.toMap( - Map.Entry::getKey, e -> Collections.singletonList(e.getValue())))) - .toQueryString(); - } catch (Exception e) { - throw new TechnicalException(e); - } - return client.getConfiguration().getProviderMetadata().getAuthorizationEndpointURI().toString() - + '?' - + queryString; - } - - public static void writeJsonResponse(HttpServletResponse response, String message) - throws IOException { - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - response.getOutputStream().print(message); - response.getOutputStream().flush(); - response.setStatus(HttpServletResponse.SC_OK); + authenticationCodeFlowHandler.handleLogin(req, resp); } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLogoutServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLogoutServlet.java index 4703a750033..a0f79c16c1f 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLogoutServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLogoutServlet.java @@ -4,33 +4,20 @@ import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; import lombok.extern.slf4j.Slf4j; @WebServlet("/api/v1/auth/logout") @Slf4j public class AuthLogoutServlet extends HttpServlet { - private final String url; + private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler; - public AuthLogoutServlet(String url) { - this.url = url; + public AuthLogoutServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) { + this.authenticationCodeFlowHandler = authenticationCodeFlowHandler; } @Override protected void doGet( final HttpServletRequest httpServletRequest, final HttpServletResponse httpServletResponse) { - try { - LOG.debug("Performing application logout"); - HttpSession session = httpServletRequest.getSession(false); - if (session != null) { - LOG.debug("Invalidating the session for logout"); - session.invalidate(); - httpServletResponse.sendRedirect(url); - } else { - LOG.error("No session store available for this web context"); - } - } catch (Exception ex) { - LOG.error("[Auth Logout] Error while performing logout", ex); - } + authenticationCodeFlowHandler.handleLogout(httpServletRequest, httpServletResponse); } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthRefreshServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthRefreshServlet.java index a40a7614b7e..e1a669bc895 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthRefreshServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthRefreshServlet.java @@ -1,58 +1,22 @@ package org.openmetadata.service.security; -import static org.openmetadata.service.security.AuthLoginServlet.writeJsonResponse; -import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; -import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession; - -import java.util.Optional; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; -import org.openmetadata.service.auth.JwtResponse; -import org.openmetadata.service.util.JsonUtils; -import org.pac4j.core.exception.TechnicalException; -import org.pac4j.oidc.client.OidcClient; -import org.pac4j.oidc.credentials.OidcCredentials; @WebServlet("/api/v1/auth/refresh") @Slf4j public class AuthRefreshServlet extends HttpServlet { - private final OidcClient client; - private final String baseUrl; + private final AuthenticationCodeFlowHandler authenticationCodeFlowHandler; - public AuthRefreshServlet(OidcClient oidcClient, String url) { - this.client = oidcClient; - this.baseUrl = url; + public AuthRefreshServlet(AuthenticationCodeFlowHandler authenticationCodeFlowHandler) { + this.authenticationCodeFlowHandler = authenticationCodeFlowHandler; } @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) { - try { - LOG.debug("Performing Auth Refresh For User Session: {} ", req.getSession().getId()); - Optional credentials = getUserCredentialsFromSession(req, client); - if (credentials.isPresent()) { - LOG.debug("Credentials Found For User Session: {} ", req.getSession().getId()); - JwtResponse jwtResponse = new JwtResponse(); - jwtResponse.setAccessToken(credentials.get().getIdToken().getParsedString()); - jwtResponse.setExpiryDuration( - credentials - .get() - .getIdToken() - .getJWTClaimsSet() - .getExpirationTime() - .toInstant() - .getEpochSecond()); - writeJsonResponse(resp, JsonUtils.pojoToJson(jwtResponse)); - } else { - LOG.debug( - "Credentials Not Found For User Session: {}, Redirect to Logout ", - req.getSession().getId()); - resp.sendRedirect(String.format("%s/logout", baseUrl)); - } - } catch (Exception e) { - getErrorMessage(resp, new TechnicalException(e)); - } + authenticationCodeFlowHandler.handleRefresh(req, resp); } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java new file mode 100644 index 00000000000..1f7511e074e --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java @@ -0,0 +1,912 @@ +package org.openmetadata.service.security; + +import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; +import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; +import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY; +import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY; +import static org.openmetadata.service.security.SecurityUtil.findEmailFromClaims; +import static org.openmetadata.service.security.SecurityUtil.getClaimOrObject; +import static org.openmetadata.service.security.SecurityUtil.getFirstMatchJwtClaim; +import static org.openmetadata.service.util.UserUtil.getRoleListFromUser; +import static org.pac4j.core.util.CommonHelper.assertNotNull; +import static org.pac4j.core.util.CommonHelper.isNotEmpty; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jwt.JWT; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.proc.BadJWTException; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.ErrorObject; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.ClientSecretPost; +import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.http.HTTPRequest; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.pkce.CodeChallenge; +import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; +import com.nimbusds.oauth2.sdk.util.JSONObjectUtils; +import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse; +import com.nimbusds.openid.connect.sdk.AuthenticationRequest; +import com.nimbusds.openid.connect.sdk.AuthenticationResponse; +import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser; +import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse; +import com.nimbusds.openid.connect.sdk.Nonce; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import com.nimbusds.openid.connect.sdk.token.OIDCTokens; +import com.nimbusds.openid.connect.sdk.validators.BadJWTExceptions; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.text.ParseException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; +import java.util.stream.Collectors; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.ws.rs.BadRequestException; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import net.minidev.json.JSONObject; +import org.openmetadata.schema.api.security.AuthenticationConfiguration; +import org.openmetadata.schema.api.security.AuthorizerConfiguration; +import org.openmetadata.schema.auth.JWTAuthMechanism; +import org.openmetadata.schema.auth.ServiceTokenType; +import org.openmetadata.schema.entity.teams.User; +import org.openmetadata.schema.security.client.OidcClientConfig; +import org.openmetadata.schema.type.Include; +import org.openmetadata.service.Entity; +import org.openmetadata.service.auth.JwtResponse; +import org.openmetadata.service.security.jwt.JWTTokenGenerator; +import org.openmetadata.service.util.JsonUtils; +import org.pac4j.core.context.HttpConstants; +import org.pac4j.core.exception.TechnicalException; +import org.pac4j.core.util.CommonHelper; +import org.pac4j.core.util.HttpUtils; +import org.pac4j.oidc.client.AzureAd2Client; +import org.pac4j.oidc.client.GoogleOidcClient; +import org.pac4j.oidc.client.OidcClient; +import org.pac4j.oidc.config.AzureAd2OidcConfiguration; +import org.pac4j.oidc.config.OidcConfiguration; +import org.pac4j.oidc.config.PrivateKeyJWTClientAuthnMethodConfig; +import org.pac4j.oidc.credentials.OidcCredentials; + +@Slf4j +public class AuthenticationCodeFlowHandler { + private static final Collection SUPPORTED_METHODS = + Arrays.asList( + ClientAuthenticationMethod.CLIENT_SECRET_POST, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC, + ClientAuthenticationMethod.PRIVATE_KEY_JWT, + ClientAuthenticationMethod.NONE); + + public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org"; + public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile"; + private final OidcClient client; + private final List claimsOrder; + private final Map claimsMapping; + private final String serverUrl; + private final ClientAuthentication clientAuthentication; + private final String principalDomain; + private final int tokenValidity; + + public AuthenticationCodeFlowHandler( + AuthenticationConfiguration authenticationConfiguration, + AuthorizerConfiguration authorizerConfiguration) { + // Assert oidcConfig and Callback Url + CommonHelper.assertNotNull( + "OidcConfiguration", authenticationConfiguration.getOidcConfiguration()); + CommonHelper.assertNotBlank( + "CallbackUrl", authenticationConfiguration.getOidcConfiguration().getCallbackUrl()); + CommonHelper.assertNotBlank( + "ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl()); + + // Build Required Params + this.client = buildOidcClient(authenticationConfiguration.getOidcConfiguration()); + client.setCallbackUrl(authenticationConfiguration.getOidcConfiguration().getCallbackUrl()); + this.clientAuthentication = getClientAuthentication(client.getConfiguration()); + this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); + this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); + this.claimsMapping = + listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() + .map(s -> s.split(":")) + .collect(Collectors.toMap(s -> s[0], s -> s[1])); + validatePrincipalClaimsMapping(claimsMapping); + this.principalDomain = authorizerConfiguration.getPrincipalDomain(); + this.tokenValidity = authenticationConfiguration.getOidcConfiguration().getTokenValidity(); + } + + private OidcClient buildOidcClient(OidcClientConfig clientConfig) { + String id = clientConfig.getId(); + String secret = clientConfig.getSecret(); + if (CommonHelper.isNotBlank(id) && CommonHelper.isNotBlank(secret)) { + OidcConfiguration configuration = new OidcConfiguration(); + configuration.setClientId(id); + + configuration.setResponseMode("query"); + + // Add Secret + if (CommonHelper.isNotBlank(secret)) { + configuration.setSecret(secret); + } + + // Response Type + String responseType = clientConfig.getResponseType(); + if (CommonHelper.isNotBlank(responseType)) { + configuration.setResponseType(responseType); + } + + String scope = clientConfig.getScope(); + if (CommonHelper.isNotBlank(scope)) { + configuration.setScope(scope); + } + + String discoveryUri = clientConfig.getDiscoveryUri(); + if (CommonHelper.isNotBlank(discoveryUri)) { + configuration.setDiscoveryURI(discoveryUri); + } + + String useNonce = clientConfig.getUseNonce(); + if (CommonHelper.isNotBlank(useNonce)) { + configuration.setUseNonce(Boolean.parseBoolean(useNonce)); + } + + String jwsAlgo = clientConfig.getPreferredJwsAlgorithm(); + if (CommonHelper.isNotBlank(jwsAlgo)) { + configuration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(jwsAlgo)); + } + + String maxClockSkew = clientConfig.getMaxClockSkew(); + if (CommonHelper.isNotBlank(maxClockSkew)) { + configuration.setMaxClockSkew(Integer.parseInt(maxClockSkew)); + } + + String clientAuthenticationMethod = clientConfig.getClientAuthenticationMethod().value(); + if (CommonHelper.isNotBlank(clientAuthenticationMethod)) { + configuration.setClientAuthenticationMethod( + ClientAuthenticationMethod.parse(clientAuthenticationMethod)); + } + + // Disable PKCE + configuration.setDisablePkce(clientConfig.getDisablePkce()); + + // Add Custom Params + if (clientConfig.getCustomParams() != null) { + for (int j = 1; j <= 5; ++j) { + if (clientConfig.getCustomParams().containsKey(String.format("customParamKey%d", j))) { + configuration.addCustomParam( + clientConfig.getCustomParams().get(String.format("customParamKey%d", j)), + clientConfig.getCustomParams().get(String.format("customParamValue%d", j))); + } + } + } + + String type = clientConfig.getType(); + OidcClient oidcClient; + if ("azure".equalsIgnoreCase(type)) { + AzureAd2OidcConfiguration azureAdConfiguration = + new AzureAd2OidcConfiguration(configuration); + String tenant = clientConfig.getTenant(); + if (CommonHelper.isNotBlank(tenant)) { + azureAdConfiguration.setTenant(tenant); + } + + oidcClient = new AzureAd2Client(azureAdConfiguration); + } else if ("google".equalsIgnoreCase(type)) { + oidcClient = new GoogleOidcClient(configuration); + // Google needs it as param + oidcClient.getConfiguration().getCustomParams().put("access_type", "offline"); + } else { + oidcClient = new OidcClient(configuration); + } + + oidcClient.setName(String.format("OMOidcClient%s", oidcClient.getName())); + return oidcClient; + } + throw new IllegalArgumentException( + "Client ID and Client Secret is required to create OidcClient"); + } + + // Login + public void handleLogin(HttpServletRequest req, HttpServletResponse resp) { + try { + LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId()); + Optional credentials = getUserCredentialsFromSession(req); + if (credentials.isPresent()) { + LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId()); + sendRedirectWithToken(resp, credentials.get()); + } else { + LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId()); + Map params = buildLoginParams(); + + params.put(OidcConfiguration.REDIRECT_URI, client.getCallbackUrl()); + + addStateAndNonceParameters(client, req, params); + + // This is always used to prompt the user to login + if (client instanceof GoogleOidcClient) { + params.put(OidcConfiguration.PROMPT, "consent"); + } else { + params.put(OidcConfiguration.PROMPT, "login"); + } + params.put(OidcConfiguration.MAX_AGE, "0"); + + String location = buildLoginAuthenticationRequestUrl(params); + LOG.debug("Authentication request url: {}", location); + + resp.sendRedirect(location); + } + } catch (Exception e) { + getErrorMessage(resp, new TechnicalException(e)); + } + } + + // Callback + public void handleCallback(HttpServletRequest req, HttpServletResponse resp) { + try { + LOG.debug("Performing Auth Callback For User Session: {} ", req.getSession().getId()); + String computedCallbackUrl = client.getCallbackUrl(); + Map> parameters = retrieveCallbackParameters(req); + AuthenticationResponse response = + AuthenticationResponseParser.parse(new URI(computedCallbackUrl), parameters); + + if (response instanceof AuthenticationErrorResponse authenticationErrorResponse) { + LOG.error( + "Bad authentication response, error={}", authenticationErrorResponse.getErrorObject()); + throw new TechnicalException("Bad authentication response"); + } + + LOG.debug("Authentication response successful"); + AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) response; + + OIDCProviderMetadata metadata = client.getConfiguration().getProviderMetadata(); + if (metadata.supportsAuthorizationResponseIssuerParam() + && !metadata.getIssuer().equals(successResponse.getIssuer())) { + throw new TechnicalException("Issuer mismatch, possible mix-up attack."); + } + + // Optional state validation + validateStateIfRequired(req, resp, successResponse); + + // Build Credentials + OidcCredentials credentials = buildCredentials(successResponse); + + // Validations + validateAndSendTokenRequest(req, credentials, computedCallbackUrl); + + // Log Error if the Refresh Token is null + if (credentials.getRefreshToken() == null) { + LOG.error("Refresh token is null for user session: {}", req.getSession().getId()); + } + + validateNonceIfRequired(req, credentials.getIdToken().getJWTClaimsSet()); + + // Put Credentials in Session + req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); + + // Redirect + sendRedirectWithToken(resp, credentials); + } catch (Exception e) { + getErrorMessage(resp, e); + } + } + + // Logout + public void handleLogout( + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + try { + LOG.debug("Performing application logout"); + HttpSession session = httpServletRequest.getSession(false); + if (session != null) { + LOG.debug("Invalidating the session for logout"); + session.invalidate(); + httpServletResponse.sendRedirect(serverUrl); + } else { + LOG.error("No session store available for this web context"); + } + } catch (Exception ex) { + LOG.error("[Auth Logout] Error while performing logout", ex); + } + } + + // Refresh + public void handleRefresh( + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + try { + LOG.debug( + "Performing Auth Refresh For User Session: {} ", httpServletRequest.getSession().getId()); + Optional credentials = getUserCredentialsFromSession(httpServletRequest); + if (credentials.isPresent()) { + LOG.debug( + "Credentials Found For User Session: {} ", httpServletRequest.getSession().getId()); + JwtResponse jwtResponse = new JwtResponse(); + jwtResponse.setAccessToken(credentials.get().getIdToken().getParsedString()); + jwtResponse.setExpiryDuration( + credentials + .get() + .getIdToken() + .getJWTClaimsSet() + .getExpirationTime() + .toInstant() + .getEpochSecond()); + writeJsonResponse(httpServletResponse, JsonUtils.pojoToJson(jwtResponse)); + } else { + LOG.debug( + "Credentials Not Found For User Session: {}, Redirect to Logout ", + httpServletRequest.getSession().getId()); + httpServletResponse.sendRedirect(String.format("%s/logout", serverUrl)); + } + } catch (Exception e) { + getErrorMessage(httpServletResponse, new TechnicalException(e)); + } + } + + private String buildLoginAuthenticationRequestUrl(final Map params) { + // Build authentication request query string + String queryString; + try { + queryString = + AuthenticationRequest.parse( + params.entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, e -> Collections.singletonList(e.getValue())))) + .toQueryString(); + } catch (Exception e) { + throw new TechnicalException(e); + } + return client.getConfiguration().getProviderMetadata().getAuthorizationEndpointURI().toString() + + '?' + + queryString; + } + + private Map buildLoginParams() { + Map authParams = new HashMap<>(); + authParams.put(OidcConfiguration.SCOPE, client.getConfiguration().getScope()); + authParams.put(OidcConfiguration.RESPONSE_TYPE, client.getConfiguration().getResponseType()); + authParams.put(OidcConfiguration.RESPONSE_MODE, "query"); + authParams.putAll(client.getConfiguration().getCustomParams()); + authParams.put(OidcConfiguration.CLIENT_ID, client.getConfiguration().getClientId()); + + return new HashMap<>(authParams); + } + + private Optional getUserCredentialsFromSession(HttpServletRequest request) + throws URISyntaxException { + OidcCredentials credentials = + (OidcCredentials) request.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE); + + if (credentials != null && credentials.getRefreshToken() != null) { + LOG.trace("Credentials found in session: {}", credentials); + renewOidcCredentials(request, credentials); + return Optional.of(credentials); + } else { + if (credentials == null) { + LOG.error("No credentials found against session. ID: {}", request.getSession().getId()); + } else { + LOG.error("No refresh token found against session. ID: {}", request.getSession().getId()); + } + } + return Optional.empty(); + } + + private void validateAndSendTokenRequest( + HttpServletRequest req, OidcCredentials oidcCredentials, String computedCallbackUrl) + throws IOException, com.nimbusds.oauth2.sdk.ParseException, URISyntaxException { + if (oidcCredentials.getCode() != null) { + LOG.debug("Initiating Token Request for User Session: {} ", req.getSession().getId()); + CodeVerifier verifier = + (CodeVerifier) + req.getSession().getAttribute(client.getCodeVerifierSessionAttributeName()); + // Token request + TokenRequest request = + createTokenRequest( + new AuthorizationCodeGrant( + oidcCredentials.getCode(), new URI(computedCallbackUrl), verifier)); + executeAuthorizationCodeTokenRequest(request, oidcCredentials); + } + } + + private void validateStateIfRequired( + HttpServletRequest req, + HttpServletResponse resp, + AuthenticationSuccessResponse successResponse) { + if (client.getConfiguration().isWithState()) { + // Validate state for CSRF mitigation + State requestState = + (State) req.getSession().getAttribute(client.getStateSessionAttributeName()); + if (requestState == null || CommonHelper.isBlank(requestState.getValue())) { + getErrorMessage(resp, new TechnicalException("Missing state parameter")); + return; + } + + State responseState = successResponse.getState(); + if (responseState == null) { + throw new TechnicalException("Missing state parameter"); + } + + LOG.debug("Request state: {}/response state: {}", requestState, responseState); + if (!requestState.equals(responseState)) { + throw new TechnicalException( + "State parameter is different from the one sent in authentication request."); + } + } + } + + private OidcCredentials buildCredentials(AuthenticationSuccessResponse successResponse) { + OidcCredentials credentials = new OidcCredentials(); + // get authorization code + AuthorizationCode code = successResponse.getAuthorizationCode(); + if (code != null) { + credentials.setCode(code); + } + // get ID token + JWT idToken = successResponse.getIDToken(); + if (idToken != null) { + credentials.setIdToken(idToken); + } + // get access token + AccessToken accessToken = successResponse.getAccessToken(); + if (accessToken != null) { + credentials.setAccessToken(accessToken); + } + + return credentials; + } + + private void validateNonceIfRequired(HttpServletRequest req, JWTClaimsSet claimsSet) + throws BadJOSEException { + if (client.getConfiguration().isUseNonce()) { + String expectedNonce = + (String) req.getSession().getAttribute(client.getNonceSessionAttributeName()); + if (CommonHelper.isNotBlank(expectedNonce)) { + String tokenNonce; + try { + tokenNonce = claimsSet.getStringClaim("nonce"); + } catch (java.text.ParseException var10) { + throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + var10.getMessage()); + } + + if (tokenNonce == null) { + throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION; + } + + if (!expectedNonce.equals(tokenNonce)) { + throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + tokenNonce); + } + } else { + throw new TechnicalException("Missing nonce parameter from Session."); + } + } + } + + protected Map> retrieveCallbackParameters(HttpServletRequest request) { + Map requestParameters = request.getParameterMap(); + Map> map = new HashMap<>(); + for (var entry : requestParameters.entrySet()) { + map.put(entry.getKey(), Arrays.asList(entry.getValue())); + } + return map; + } + + private void writeJsonResponse(HttpServletResponse response, String message) throws IOException { + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.getOutputStream().print(message); + response.getOutputStream().flush(); + response.setStatus(HttpServletResponse.SC_OK); + } + + private ClientAuthentication getClientAuthentication(OidcConfiguration configuration) { + ClientID clientID = new ClientID(configuration.getClientId()); + ClientAuthentication clientAuthenticationMechanism = null; + if (configuration.getSecret() != null) { + // check authentication methods + List metadataMethods = + configuration.findProviderMetadata().getTokenEndpointAuthMethods(); + + ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration); + + final ClientAuthenticationMethod chosenMethod; + if (isNotEmpty(metadataMethods)) { + if (preferredMethod != null) { + if (metadataMethods.contains(preferredMethod)) { + chosenMethod = preferredMethod; + } else { + throw new TechnicalException( + "Preferred authentication method (" + + preferredMethod + + ") not supported " + + "by provider according to provider metadata (" + + metadataMethods + + ")."); + } + } else { + chosenMethod = firstSupportedMethod(metadataMethods); + } + } else { + chosenMethod = + preferredMethod != null ? preferredMethod : ClientAuthenticationMethod.getDefault(); + LOG.info( + "Provider metadata does not provide Token endpoint authentication methods. Using: {}", + chosenMethod); + } + + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(chosenMethod)) { + Secret clientSecret = new Secret(configuration.getSecret()); + clientAuthenticationMechanism = new ClientSecretPost(clientID, clientSecret); + } else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(chosenMethod)) { + Secret clientSecret = new Secret(configuration.getSecret()); + clientAuthenticationMechanism = new ClientSecretBasic(clientID, clientSecret); + } else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(chosenMethod)) { + PrivateKeyJWTClientAuthnMethodConfig privateKetJwtConfig = + configuration.getPrivateKeyJWTClientAuthnMethodConfig(); + assertNotNull("privateKetJwtConfig", privateKetJwtConfig); + JWSAlgorithm jwsAlgo = privateKetJwtConfig.getJwsAlgorithm(); + assertNotNull("privateKetJwtConfig.getJwsAlgorithm()", jwsAlgo); + PrivateKey privateKey = privateKetJwtConfig.getPrivateKey(); + assertNotNull("privateKetJwtConfig.getPrivateKey()", privateKey); + String keyID = privateKetJwtConfig.getKeyID(); + try { + clientAuthenticationMechanism = + new PrivateKeyJWT( + clientID, + configuration.findProviderMetadata().getTokenEndpointURI(), + jwsAlgo, + privateKey, + keyID, + null); + } catch (final JOSEException e) { + throw new TechnicalException( + "Cannot instantiate private key JWT client authentication method", e); + } + } + } + + return clientAuthenticationMechanism; + } + + private static ClientAuthenticationMethod getPreferredAuthenticationMethod( + OidcConfiguration config) { + ClientAuthenticationMethod configurationMethod = config.getClientAuthenticationMethod(); + if (configurationMethod == null) { + return null; + } + + if (!SUPPORTED_METHODS.contains(configurationMethod)) { + throw new TechnicalException( + "Configured authentication method (" + configurationMethod + ") is not supported."); + } + + return configurationMethod; + } + + private ClientAuthenticationMethod firstSupportedMethod( + final List metadataMethods) { + Optional firstSupported = + metadataMethods.stream().filter(SUPPORTED_METHODS::contains).findFirst(); + if (firstSupported.isPresent()) { + return firstSupported.get(); + } else { + throw new TechnicalException( + "None of the Token endpoint provider metadata authentication methods are supported: " + + metadataMethods); + } + } + + @SneakyThrows + public static void getErrorMessage(HttpServletResponse resp, Exception e) { + resp.setContentType("text/html; charset=UTF-8"); + LOG.error("[Auth Callback Servlet] Failed in Auth Login : {}", e.getMessage()); + resp.getOutputStream() + .println( + String.format( + "

[Auth Callback Servlet] Failed in Auth Login : %s

", e.getMessage())); + } + + private void sendRedirectWithToken(HttpServletResponse response, OidcCredentials credentials) + throws ParseException, IOException { + JWT jwt = credentials.getIdToken(); + Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + claims.putAll(jwt.getJWTClaimsSet().getClaims()); + + String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims); + String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, principalDomain); + + String url = + String.format( + "%s/auth/callback?id_token=%s&email=%s&name=%s", + serverUrl, credentials.getIdToken().getParsedString(), email, userName); + response.sendRedirect(url); + } + + private void renewOidcCredentials(HttpServletRequest request, OidcCredentials credentials) { + LOG.debug("Renewing Credentials for User Session {}", request.getSession().getId()); + if (client.getConfiguration() instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) { + refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials); + } else { + refreshTokenRequest(request, credentials); + } + request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); + } + + public void refreshTokenRequest( + final HttpServletRequest httpServletRequest, final OidcCredentials credentials) { + final var refreshToken = credentials.getRefreshToken(); + if (refreshToken != null) { + try { + final var request = createTokenRequest(new RefreshTokenGrant(refreshToken)); + HTTPResponse httpResponse = executeTokenHttpRequest(request); + if (httpResponse.getStatusCode() == 200) { + JSONObject jsonObjectResponse = httpResponse.getContentAsJSONObject(); + String idTokenKey = "id_token"; + if (jsonObjectResponse.containsKey(idTokenKey)) { + Object value = jsonObjectResponse.get(idTokenKey); + if (value == null) { + throw new com.nimbusds.oauth2.sdk.ParseException( + "JSON object member with key " + idTokenKey + " has null value"); + } else { + LOG.info("Found a JWT token in the response, trying to parse it"); + OIDCTokenResponse tokenSuccessResponse = + parseTokenResponseFromHttpResponse(httpResponse); + // Populate credentials + populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); + } + } else { + // Note: since the id_token is not present, we must receive accessToken + // We can do better and get userInfo from + // client.getConfiguration().findProviderMetadata().getUserInfoEndpointURI() + // but currently we are just return the OM created token in the response + String accessToken = JSONObjectUtils.getString(jsonObjectResponse, "access_token"); + LOG.info( + "Found an access token in the response, trying to parse it, Value : {}", + accessToken); + OIDCTokenResponse tokenSuccessResponse = + parseTokenResponseFromHttpResponse(httpResponse); + // Populate credentials + populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); + + OidcCredentials storedCredentials = + (OidcCredentials) + httpServletRequest.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE); + + // Get the claims from the stored credentials + Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + claims.putAll(storedCredentials.getIdToken().getJWTClaimsSet().getClaims()); + + String username = + SecurityUtil.findUserNameFromClaims(claimsMapping, claimsOrder, claims); + User user = Entity.getEntityByName(Entity.USER, username, "id", Include.NON_DELETED); + + // Create a JWT here + JWTAuthMechanism jwtAuthMechanism = + JWTTokenGenerator.getInstance() + .generateJWTToken( + username, + getRoleListFromUser(user), + !nullOrEmpty(user.getIsAdmin()) && user.getIsAdmin(), + user.getEmail(), + tokenValidity, + false, + ServiceTokenType.OM_USER); + // Set the access token to the new JWT token + credentials.setIdToken(SignedJWT.parse(jwtAuthMechanism.getJWTToken())); + } + return; + } else { + throw new TechnicalException( + String.format( + "Failed to refresh id_token, response code:%s , Error : %s", + httpResponse.getStatusCode(), httpResponse.getContent())); + } + + } catch (final IOException | com.nimbusds.oauth2.sdk.ParseException e) { + throw new TechnicalException(e); + } catch (ParseException e) { + throw new RuntimeException(e); + } + } + throw new BadRequestException("No refresh token available"); + } + + public static boolean isJWT(String token) { + return token.split("\\.").length == 3; + } + + private void refreshAccessTokenAzureAd2Token( + AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) { + HttpURLConnection connection = null; + try { + Map headers = new HashMap<>(); + headers.put( + HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE); + headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON); + // get the token endpoint from discovery URI + URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL(); + connection = HttpUtils.openPostConnection(tokenEndpointURL, headers); + + BufferedWriter out = + new BufferedWriter( + new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)); + out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue())); + out.close(); + + int responseCode = connection.getResponseCode(); + if (responseCode != 200) { + throw new TechnicalException( + "request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection)); + } + var body = HttpUtils.readBody(connection); + Map res = JsonUtils.readValue(body, new TypeReference<>() {}); + azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token"))); + } catch (final IOException e) { + throw new TechnicalException(e); + } finally { + HttpUtils.closeConnection(connection); + } + } + + public static String findUserNameFromClaims( + Map jwtPrincipalClaimsMapping, + List jwtPrincipalClaimsOrder, + Map claims) { + if (!nullOrEmpty(jwtPrincipalClaimsMapping)) { + // We have a mapping available so we will use that + String usernameClaim = jwtPrincipalClaimsMapping.get(USERNAME_CLAIM_KEY); + String userNameClaimValue = getClaimOrObject(claims.get(usernameClaim)); + if (!nullOrEmpty(userNameClaimValue)) { + return userNameClaimValue; + } else { + throw new AuthenticationException("Invalid JWT token, 'username' claim is not present"); + } + } else { + String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims); + String userName; + if (jwtClaim.contains("@")) { + userName = jwtClaim.split("@")[0]; + } else { + userName = jwtClaim; + } + return userName; + } + } + + public static void validatePrincipalClaimsMapping(Map mapping) { + if (!nullOrEmpty(mapping)) { + String username = mapping.get(USERNAME_CLAIM_KEY); + String email = mapping.get(EMAIL_CLAIM_KEY); + if (nullOrEmpty(username) || nullOrEmpty(email)) { + throw new IllegalArgumentException( + "Invalid JWT Principal Claims Mapping. Both username and email should be present"); + } + } + // If emtpy, jwtPrincipalClaims will be used so no need to validate + } + + private HTTPResponse executeTokenHttpRequest(TokenRequest request) throws IOException { + HTTPRequest tokenHttpRequest = request.toHTTPRequest(); + client.getConfiguration().configureHttpRequest(tokenHttpRequest); + + HTTPResponse httpResponse = tokenHttpRequest.send(); + LOG.debug( + "Token response: status={}, content={}", + httpResponse.getStatusCode(), + httpResponse.getContent()); + + return httpResponse; + } + + private TokenRequest createTokenRequest(final AuthorizationGrant grant) { + if (clientAuthentication != null) { + return new TokenRequest( + client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), + this.clientAuthentication, + grant); + } else { + return new TokenRequest( + client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), + new ClientID(client.getConfiguration().getClientId()), + grant); + } + } + + private void addStateAndNonceParameters( + final OidcClient client, final HttpServletRequest request, final Map params) { + // Init state for CSRF mitigation + if (client.getConfiguration().isWithState()) { + State state = new State(CommonHelper.randomString(10)); + params.put(OidcConfiguration.STATE, state.getValue()); + request.getSession().setAttribute(client.getStateSessionAttributeName(), state); + } + + // Init nonce for replay attack mitigation + if (client.getConfiguration().isUseNonce()) { + Nonce nonce = new Nonce(); + params.put(OidcConfiguration.NONCE, nonce.getValue()); + request.getSession().setAttribute(client.getNonceSessionAttributeName(), nonce.getValue()); + } + + CodeChallengeMethod pkceMethod = client.getConfiguration().findPkceMethod(); + + // Use Default PKCE method if not disabled + if (pkceMethod == null && !client.getConfiguration().isDisablePkce()) { + pkceMethod = CodeChallengeMethod.S256; + } + if (pkceMethod != null) { + CodeVerifier verfifier = new CodeVerifier(CommonHelper.randomString(43)); + request.getSession().setAttribute(client.getCodeVerifierSessionAttributeName(), verfifier); + params.put( + OidcConfiguration.CODE_CHALLENGE, + CodeChallenge.compute(pkceMethod, verfifier).getValue()); + params.put(OidcConfiguration.CODE_CHALLENGE_METHOD, pkceMethod.getValue()); + } + } + + private void executeAuthorizationCodeTokenRequest( + TokenRequest request, OidcCredentials credentials) + throws IOException, com.nimbusds.oauth2.sdk.ParseException { + HTTPResponse httpResponse = executeTokenHttpRequest(request); + OIDCTokenResponse tokenSuccessResponse = parseTokenResponseFromHttpResponse(httpResponse); + + // Populate credentials + populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); + } + + private void populateCredentialsFromTokenResponse( + OIDCTokenResponse tokenSuccessResponse, OidcCredentials credentials) { + OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens(); + credentials.setAccessToken(oidcTokens.getAccessToken()); + credentials.setRefreshToken(oidcTokens.getRefreshToken()); + if (oidcTokens.getIDToken() != null) { + credentials.setIdToken(oidcTokens.getIDToken()); + } + } + + private OIDCTokenResponse parseTokenResponseFromHttpResponse(HTTPResponse httpResponse) + throws com.nimbusds.oauth2.sdk.ParseException { + TokenResponse response = OIDCTokenResponseParser.parse(httpResponse); + if (response instanceof TokenErrorResponse tokenErrorResponse) { + ErrorObject errorObject = tokenErrorResponse.getErrorObject(); + throw new TechnicalException( + "Bad token response, error=" + + errorObject.getCode() + + "," + + " description=" + + errorObject.getDescription()); + } + LOG.debug("Token response successful"); + return (OIDCTokenResponse) response; + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java index f8055248484..c5ad60f725c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java @@ -14,82 +14,28 @@ package org.openmetadata.service.security; import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; -import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE; import static org.openmetadata.service.security.JwtFilter.BOT_CLAIM; import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY; import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY; -import static org.pac4j.core.util.CommonHelper.assertNotNull; -import static org.pac4j.core.util.CommonHelper.isNotEmpty; import com.auth0.jwt.interfaces.Claim; -import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap.Builder; -import com.nimbusds.jose.JOSEException; -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jwt.JWT; -import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; -import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod; -import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; -import com.nimbusds.oauth2.sdk.auth.ClientSecretPost; -import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT; -import com.nimbusds.oauth2.sdk.auth.Secret; -import com.nimbusds.oauth2.sdk.id.ClientID; -import com.nimbusds.oauth2.sdk.token.BearerAccessToken; -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.net.HttpURLConnection; -import java.net.URL; -import java.nio.charset.StandardCharsets; import java.security.Principal; -import java.security.PrivateKey; -import java.text.ParseException; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collection; -import java.util.Date; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.TreeMap; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import javax.ws.rs.client.Invocation; import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.SecurityContext; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang.StringUtils; import org.openmetadata.common.utils.CommonUtil; -import org.openmetadata.schema.security.client.OidcClientConfig; import org.openmetadata.service.OpenMetadataApplicationConfig; -import org.openmetadata.service.util.JsonUtils; -import org.pac4j.core.context.HttpConstants; -import org.pac4j.core.exception.TechnicalException; -import org.pac4j.core.util.CommonHelper; -import org.pac4j.core.util.HttpUtils; -import org.pac4j.oidc.client.AzureAd2Client; -import org.pac4j.oidc.client.GoogleOidcClient; -import org.pac4j.oidc.client.OidcClient; -import org.pac4j.oidc.config.AzureAd2OidcConfiguration; -import org.pac4j.oidc.config.OidcConfiguration; -import org.pac4j.oidc.config.PrivateKeyJWTClientAuthnMethodConfig; -import org.pac4j.oidc.credentials.OidcCredentials; -import org.pac4j.oidc.credentials.authenticator.OidcAuthenticator; @Slf4j public final class SecurityUtil { public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org"; - private static final Collection SUPPORTED_METHODS = - Arrays.asList( - ClientAuthenticationMethod.CLIENT_SECRET_POST, - ClientAuthenticationMethod.CLIENT_SECRET_BASIC, - ClientAuthenticationMethod.PRIVATE_KEY_JWT, - ClientAuthenticationMethod.NONE); - private SecurityUtil() {} public static String getUserName(SecurityContext securityContext) { @@ -131,293 +77,6 @@ public final class SecurityUtil { return target.request(); } - public static OidcClient tryCreateOidcClient(OidcClientConfig clientConfig) { - String id = clientConfig.getId(); - String secret = clientConfig.getSecret(); - if (CommonHelper.isNotBlank(id) && CommonHelper.isNotBlank(secret)) { - OidcConfiguration configuration = new OidcConfiguration(); - configuration.setClientId(id); - - configuration.setResponseMode("query"); - - // Add Secret - if (CommonHelper.isNotBlank(secret)) { - configuration.setSecret(secret); - } - - // Response Type - String responseType = clientConfig.getResponseType(); - if (CommonHelper.isNotBlank(responseType)) { - configuration.setResponseType(responseType); - } - - String scope = clientConfig.getScope(); - if (CommonHelper.isNotBlank(scope)) { - configuration.setScope(scope); - } - - String discoveryUri = clientConfig.getDiscoveryUri(); - if (CommonHelper.isNotBlank(discoveryUri)) { - configuration.setDiscoveryURI(discoveryUri); - } - - String useNonce = clientConfig.getUseNonce(); - if (CommonHelper.isNotBlank(useNonce)) { - configuration.setUseNonce(Boolean.parseBoolean(useNonce)); - } - - String jwsAlgo = clientConfig.getPreferredJwsAlgorithm(); - if (CommonHelper.isNotBlank(jwsAlgo)) { - configuration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(jwsAlgo)); - } - - String maxClockSkew = clientConfig.getMaxClockSkew(); - if (CommonHelper.isNotBlank(maxClockSkew)) { - configuration.setMaxClockSkew(Integer.parseInt(maxClockSkew)); - } - - String clientAuthenticationMethod = clientConfig.getClientAuthenticationMethod().value(); - if (CommonHelper.isNotBlank(clientAuthenticationMethod)) { - configuration.setClientAuthenticationMethod( - ClientAuthenticationMethod.parse(clientAuthenticationMethod)); - } - - // Disable PKCE - configuration.setDisablePkce(clientConfig.getDisablePkce()); - - // Add Custom Params - if (clientConfig.getCustomParams() != null) { - for (int j = 1; j <= 5; ++j) { - if (clientConfig.getCustomParams().containsKey(String.format("customParamKey%d", j))) { - configuration.addCustomParam( - clientConfig.getCustomParams().get(String.format("customParamKey%d", j)), - clientConfig.getCustomParams().get(String.format("customParamValue%d", j))); - } - } - } - - String type = clientConfig.getType(); - OidcClient oidcClient; - if ("azure".equalsIgnoreCase(type)) { - AzureAd2OidcConfiguration azureAdConfiguration = - new AzureAd2OidcConfiguration(configuration); - String tenant = clientConfig.getTenant(); - if (CommonHelper.isNotBlank(tenant)) { - azureAdConfiguration.setTenant(tenant); - } - - oidcClient = new AzureAd2Client(azureAdConfiguration); - } else if ("google".equalsIgnoreCase(type)) { - oidcClient = new GoogleOidcClient(configuration); - // Google needs it as param - oidcClient.getConfiguration().getCustomParams().put("access_type", "offline"); - } else { - oidcClient = new OidcClient(configuration); - } - - oidcClient.setName(String.format("OMOidcClient%s", oidcClient.getName())); - return oidcClient; - } - throw new IllegalArgumentException( - "Client ID and Client Secret is required to create OidcClient"); - } - - public static ClientAuthentication getClientAuthentication(OidcConfiguration configuration) { - ClientID clientID = new ClientID(configuration.getClientId()); - ClientAuthentication clientAuthenticationMechanism = null; - if (configuration.getSecret() != null) { - // check authentication methods - List metadataMethods = - configuration.findProviderMetadata().getTokenEndpointAuthMethods(); - - ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration); - - final ClientAuthenticationMethod chosenMethod; - if (isNotEmpty(metadataMethods)) { - if (preferredMethod != null) { - if (metadataMethods.contains(preferredMethod)) { - chosenMethod = preferredMethod; - } else { - throw new TechnicalException( - "Preferred authentication method (" - + preferredMethod - + ") not supported " - + "by provider according to provider metadata (" - + metadataMethods - + ")."); - } - } else { - chosenMethod = firstSupportedMethod(metadataMethods); - } - } else { - chosenMethod = - preferredMethod != null ? preferredMethod : ClientAuthenticationMethod.getDefault(); - LOG.info( - "Provider metadata does not provide Token endpoint authentication methods. Using: {}", - chosenMethod); - } - - if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(chosenMethod)) { - Secret clientSecret = new Secret(configuration.getSecret()); - clientAuthenticationMechanism = new ClientSecretPost(clientID, clientSecret); - } else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(chosenMethod)) { - Secret clientSecret = new Secret(configuration.getSecret()); - clientAuthenticationMechanism = new ClientSecretBasic(clientID, clientSecret); - } else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(chosenMethod)) { - PrivateKeyJWTClientAuthnMethodConfig privateKetJwtConfig = - configuration.getPrivateKeyJWTClientAuthnMethodConfig(); - assertNotNull("privateKetJwtConfig", privateKetJwtConfig); - JWSAlgorithm jwsAlgo = privateKetJwtConfig.getJwsAlgorithm(); - assertNotNull("privateKetJwtConfig.getJwsAlgorithm()", jwsAlgo); - PrivateKey privateKey = privateKetJwtConfig.getPrivateKey(); - assertNotNull("privateKetJwtConfig.getPrivateKey()", privateKey); - String keyID = privateKetJwtConfig.getKeyID(); - try { - clientAuthenticationMechanism = - new PrivateKeyJWT( - clientID, - configuration.findProviderMetadata().getTokenEndpointURI(), - jwsAlgo, - privateKey, - keyID, - null); - } catch (final JOSEException e) { - throw new TechnicalException( - "Cannot instantiate private key JWT client authentication method", e); - } - } - } - - return clientAuthenticationMechanism; - } - - private static ClientAuthenticationMethod getPreferredAuthenticationMethod( - OidcConfiguration config) { - ClientAuthenticationMethod configurationMethod = config.getClientAuthenticationMethod(); - if (configurationMethod == null) { - return null; - } - - if (!SUPPORTED_METHODS.contains(configurationMethod)) { - throw new TechnicalException( - "Configured authentication method (" + configurationMethod + ") is not supported."); - } - - return configurationMethod; - } - - private static ClientAuthenticationMethod firstSupportedMethod( - final List metadataMethods) { - Optional firstSupported = - metadataMethods.stream().filter(SUPPORTED_METHODS::contains).findFirst(); - if (firstSupported.isPresent()) { - return firstSupported.get(); - } else { - throw new TechnicalException( - "None of the Token endpoint provider metadata authentication methods are supported: " - + metadataMethods); - } - } - - @SneakyThrows - public static void getErrorMessage(HttpServletResponse resp, Exception e) { - resp.setContentType("text/html; charset=UTF-8"); - LOG.error("[Auth Callback Servlet] Failed in Auth Login : {}", e.getMessage()); - resp.getOutputStream() - .println( - String.format( - "

[Auth Callback Servlet] Failed in Auth Login : %s

", e.getMessage())); - } - - public static void sendRedirectWithToken( - HttpServletResponse response, - OidcCredentials credentials, - String serverUrl, - Map claimsMapping, - List claimsOrder, - String defaultDomain) - throws ParseException, IOException { - JWT jwt = credentials.getIdToken(); - Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); - claims.putAll(jwt.getJWTClaimsSet().getClaims()); - - String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims); - String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, defaultDomain); - - String url = - String.format( - "%s/auth/callback?id_token=%s&email=%s&name=%s", - serverUrl, credentials.getIdToken().getParsedString(), email, userName); - response.sendRedirect(url); - } - - public static boolean isCredentialsExpired(OidcCredentials credentials) throws ParseException { - Date expiration = credentials.getIdToken().getJWTClaimsSet().getExpirationTime(); - return expiration != null && expiration.toInstant().isBefore(Instant.now().plusSeconds(30)); - } - - public static Optional getUserCredentialsFromSession( - HttpServletRequest request, OidcClient client) throws ParseException { - OidcCredentials credentials = - (OidcCredentials) request.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE); - if (credentials != null && credentials.getRefreshToken() != null) { - removeOrRenewOidcCredentials(request, client, credentials); - return Optional.of(credentials); - } else { - if (credentials == null) { - LOG.error("No credentials found against session. ID: {}", request.getSession().getId()); - } else { - LOG.error("No refresh token found against session. ID: {}", request.getSession().getId()); - } - } - return Optional.empty(); - } - - private static void removeOrRenewOidcCredentials( - HttpServletRequest request, OidcClient client, OidcCredentials credentials) { - LOG.debug("Expired credentials found, trying to renew."); - if (client.getConfiguration() instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) { - refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials); - } else { - OidcAuthenticator authenticator = new OidcAuthenticator(client.getConfiguration(), client); - authenticator.refresh(credentials); - } - request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); - } - - private static void refreshAccessTokenAzureAd2Token( - AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) { - HttpURLConnection connection = null; - try { - Map headers = new HashMap<>(); - headers.put( - HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE); - headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON); - // get the token endpoint from discovery URI - URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL(); - connection = HttpUtils.openPostConnection(tokenEndpointURL, headers); - - BufferedWriter out = - new BufferedWriter( - new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)); - out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue())); - out.close(); - - int responseCode = connection.getResponseCode(); - if (responseCode != 200) { - throw new TechnicalException( - "request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection)); - } - var body = HttpUtils.readBody(connection); - Map res = JsonUtils.readValue(body, new TypeReference<>() {}); - azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token"))); - } catch (final IOException e) { - throw new TechnicalException(e); - } finally { - HttpUtils.closeConnection(connection); - } - } - public static String findUserNameFromClaims( Map jwtPrincipalClaimsMapping, List jwtPrincipalClaimsOrder, @@ -470,7 +129,7 @@ public final class SecurityUtil { } } - private static String getClaimOrObject(Object obj) { + public static String getClaimOrObject(Object obj) { if (obj == null) { return ""; } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/saml/SamlAssertionConsumerServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/saml/SamlAssertionConsumerServlet.java index a17842e129f..0a0b7c6ce3c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/saml/SamlAssertionConsumerServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/saml/SamlAssertionConsumerServlet.java @@ -92,7 +92,7 @@ public class SamlAssertionConsumerServlet extends HttpServlet { username, getRoleListFromUser(user), !nullOrEmpty(user.getIsAdmin()) && user.getIsAdmin(), - email, + user.getEmail(), SamlSettingsHolder.getInstance().getTokenValidity(), false, ServiceTokenType.OM_USER); diff --git a/openmetadata-spec/src/main/resources/json/schema/security/client/oidcClientConfig.json b/openmetadata-spec/src/main/resources/json/schema/security/client/oidcClientConfig.json index 2e45e5ea3a0..528da87ac36 100644 --- a/openmetadata-spec/src/main/resources/json/schema/security/client/oidcClientConfig.json +++ b/openmetadata-spec/src/main/resources/json/schema/security/client/oidcClientConfig.json @@ -56,6 +56,11 @@ "type": "string", "enum": ["client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt"] }, + "tokenValidity": { + "description": "Validity for the JWT Token created from SAML Response", + "type": "integer", + "default": "3600" + }, "customParams": { "description": "Custom Params.", "existingJavaType" : "java.util.Map",