From 53407fb681912734d10cfe8364d215055dede985 Mon Sep 17 00:00:00 2001 From: Mohit Yadav <105265192+mohityadav766@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:13:09 +0530 Subject: [PATCH] [Issue-16642] Add Claim Mapping to uniquely identifty username and email from claims (#16643) * - Add Claim Mapping to uniquely identift username and email from claims * - Null Check * - Add field to yaml * - Fix issue with token being null * - Auth Code Flow Fix * support jwtPrincipleClaimMapping from UI --------- Co-authored-by: Chira Madlani --- .../openmetadata/common/utils/CommonUtil.java | 5 + conf/openmetadata.yaml | 1 + .../service/jdbi3/SystemRepository.java | 2 +- .../service/security/AuthCallbackServlet.java | 12 +- .../service/security/AuthLoginServlet.java | 11 +- .../service/security/JwtFilter.java | 162 +++++++++--------- .../service/security/SecurityUtil.java | 156 +++++++++++++++-- .../service/socket/SocketAddressFilter.java | 14 +- .../authenticationConfiguration.json | 7 + .../AuthProviders/AuthProvider.interface.ts | 1 + .../Auth/AuthProviders/AuthProvider.tsx | 2 + .../ui/src/hooks/useApplicationStore.ts | 6 + .../ui/src/interface/store.interface.ts | 3 + .../ui/src/pages/SignUp/SignUpPage.tsx | 4 +- .../ui/src/utils/AuthProvider.util.ts | 71 +++++--- 15 files changed, 324 insertions(+), 133 deletions(-) diff --git a/common/src/main/java/org/openmetadata/common/utils/CommonUtil.java b/common/src/main/java/org/openmetadata/common/utils/CommonUtil.java index c4e5c691a9a..08e0b5cc37e 100644 --- a/common/src/main/java/org/openmetadata/common/utils/CommonUtil.java +++ b/common/src/main/java/org/openmetadata/common/utils/CommonUtil.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.Date; import java.util.Enumeration; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.TimeZone; @@ -183,6 +184,10 @@ public final class CommonUtil { return list == null || list.isEmpty(); } + public static boolean nullOrEmpty(Map m) { + return m == null || m.isEmpty(); + } + public static boolean nullOrEmpty(Object object) { return object == null || nullOrEmpty(object.toString()); } diff --git a/conf/openmetadata.yaml b/conf/openmetadata.yaml index 9776ffa62fb..b92a7dd7a40 100644 --- a/conf/openmetadata.yaml +++ b/conf/openmetadata.yaml @@ -175,6 +175,7 @@ authenticationConfiguration: clientId: ${AUTHENTICATION_CLIENT_ID:-""} callbackUrl: ${AUTHENTICATION_CALLBACK_URL:-""} jwtPrincipalClaims: ${AUTHENTICATION_JWT_PRINCIPAL_CLAIMS:-[email,preferred_username,sub]} + jwtPrincipalClaimsMapping: ${AUTHENTICATION_JWT_PRINCIPAL_CLAIMS_MAPPING:-[]} enableSelfSignup : ${AUTHENTICATION_ENABLE_SELF_SIGNUP:-true} oidcConfiguration: id: ${OIDC_CLIENT_ID:-""} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java index 89490f0ba55..16ce43d90d1 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java @@ -322,7 +322,7 @@ public class SystemRepository { OpenMetadataConnection openMetadataServerConnection = new OpenMetadataConnectionBuilder(applicationConfig).build(); try { - jwtFilter.validateAndReturnDecodedJwtToken( + jwtFilter.validateJwtAndGetClaims( openMetadataServerConnection.getSecurityConfig().getJwtToken()); return new StepValidation() .withDescription(ValidationStepDescription.JWT_TOKEN.key) diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthCallbackServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthCallbackServlet.java index e5d7507fc64..0419c35c7b4 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthCallbackServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthCallbackServlet.java @@ -1,9 +1,11 @@ package org.openmetadata.service.security; +import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE; import static org.openmetadata.service.security.SecurityUtil.getClientAuthentication; import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken; +import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping; import com.nimbusds.jose.proc.BadJOSEException; import com.nimbusds.jwt.JWT; @@ -40,6 +42,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; @@ -58,6 +61,7 @@ public class AuthCallbackServlet extends HttpServlet { private final OidcClient client; private final ClientAuthentication clientAuthentication; private final List claimsOrder; + private final Map claimsMapping; private final String serverUrl; private final String principalDomain; @@ -69,6 +73,11 @@ public class AuthCallbackServlet extends HttpServlet { "ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl()); this.client = oidcClient; this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); + this.claimsMapping = + listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() + .map(s -> s.split(":")) + .collect(Collectors.toMap(s -> s[0], s -> s[1])); + validatePrincipalClaimsMapping(claimsMapping); this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); this.clientAuthentication = getClientAuthentication(client.getConfiguration()); this.principalDomain = authorizerConfiguration.getPrincipalDomain(); @@ -118,7 +127,8 @@ public class AuthCallbackServlet extends HttpServlet { req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); // Redirect - sendRedirectWithToken(resp, credentials, serverUrl, claimsOrder, principalDomain); + sendRedirectWithToken( + resp, credentials, serverUrl, claimsMapping, claimsOrder, principalDomain); } catch (Exception e) { getErrorMessage(resp, e); } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java index f1b5ee3df17..b3e4ff54c13 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthLoginServlet.java @@ -1,8 +1,10 @@ package org.openmetadata.service.security; +import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession; import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken; +import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping; import com.nimbusds.oauth2.sdk.id.State; import com.nimbusds.oauth2.sdk.pkce.CodeChallenge; @@ -37,6 +39,7 @@ public class AuthLoginServlet extends HttpServlet { public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile"; private final OidcClient client; private final List claimsOrder; + private final Map claimsMapping; private final String serverUrl; private final String principalDomain; @@ -47,6 +50,11 @@ public class AuthLoginServlet extends HttpServlet { this.client = oidcClient; this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); + this.claimsMapping = + listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() + .map(s -> s.split(":")) + .collect(Collectors.toMap(s -> s[0], s -> s[1])); + validatePrincipalClaimsMapping(claimsMapping); this.principalDomain = authorizerConfiguration.getPrincipalDomain(); } @@ -57,7 +65,8 @@ public class AuthLoginServlet extends HttpServlet { Optional credentials = getUserCredentialsFromSession(req, client); if (credentials.isPresent()) { LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId()); - sendRedirectWithToken(resp, credentials.get(), serverUrl, claimsOrder, principalDomain); + sendRedirectWithToken( + resp, credentials.get(), serverUrl, claimsMapping, claimsOrder, principalDomain); } else { LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId()); Map params = buildParams(); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/JwtFilter.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/JwtFilter.java index 6685d39330c..7543c5ec8db 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/JwtFilter.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/JwtFilter.java @@ -13,7 +13,12 @@ package org.openmetadata.service.security; +import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; +import static org.openmetadata.service.security.SecurityUtil.findUserNameFromClaims; +import static org.openmetadata.service.security.SecurityUtil.isBot; +import static org.openmetadata.service.security.SecurityUtil.validateDomainEnforcement; +import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping; import static org.openmetadata.service.security.jwt.JWTTokenGenerator.ROLES_CLAIM; import static org.openmetadata.service.security.jwt.JWTTokenGenerator.TOKEN_TYPE; @@ -24,18 +29,19 @@ import com.auth0.jwt.algorithms.Algorithm; import com.auth0.jwt.exceptions.JWTDecodeException; import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.DecodedJWT; -import com.fasterxml.jackson.databind.node.TextNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import java.net.URL; import java.security.interfaces.RSAPublicKey; import java.util.*; +import java.util.stream.Collectors; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerRequestFilter; import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.SecurityContext; import javax.ws.rs.core.UriInfo; import javax.ws.rs.ext.Provider; +import lombok.Getter; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang.StringUtils; @@ -52,10 +58,13 @@ import org.openmetadata.service.security.saml.JwtTokenCacheManager; @Slf4j @Provider public class JwtFilter implements ContainerRequestFilter { + public static final String EMAIL_CLAIM_KEY = "email"; + public static final String USERNAME_CLAIM_KEY = "username"; public static final String AUTHORIZATION_HEADER = "Authorization"; public static final String TOKEN_PREFIX = "Bearer"; public static final String BOT_CLAIM = "isBot"; - private List jwtPrincipalClaims; + @Getter private List jwtPrincipalClaims; + @Getter private Map jwtPrincipalClaimsMapping; private JwkProvider jwkProvider; private String principalDomain; private boolean enforcePrincipalDomain; @@ -90,7 +99,13 @@ public class JwtFilter implements ContainerRequestFilter { AuthenticationConfiguration authenticationConfiguration, AuthorizerConfiguration authorizerConfiguration) { this.providerType = authenticationConfiguration.getProvider(); + // Cannot remove Principal Claims listing since that is , breaking change for existing users this.jwtPrincipalClaims = authenticationConfiguration.getJwtPrincipalClaims(); + this.jwtPrincipalClaimsMapping = + listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream() + .map(s -> s.split(":")) + .collect(Collectors.toMap(s -> s[0], s -> s[1])); + validatePrincipalClaimsMapping(jwtPrincipalClaimsMapping); ImmutableList.Builder publicKeyUrlsBuilder = ImmutableList.builder(); for (String publicKeyUrlStr : authenticationConfiguration.getPublicKeyUrls()) { @@ -131,25 +146,52 @@ public class JwtFilter implements ContainerRequestFilter { } // Extract token from the header - MultivaluedMap headers = requestContext.getHeaders(); - String tokenFromHeader = extractToken(headers); + String tokenFromHeader = extractToken(requestContext.getHeaders()); LOG.debug("Token from header:{}", tokenFromHeader); - // the case where OMD generated the Token for the Client - if (AuthProvider.BASIC.equals(providerType) || AuthProvider.SAML.equals(providerType)) { - validateTokenIsNotUsedAfterLogout(tokenFromHeader); + Map claims = validateJwtAndGetClaims(tokenFromHeader); + String userName = findUserNameFromClaims(jwtPrincipalClaimsMapping, jwtPrincipalClaims, claims); + + // Check Validations + checkValidationsForToken(claims, tokenFromHeader, userName); + + // Setting Security Context + CatalogPrincipal catalogPrincipal = new CatalogPrincipal(userName); + String scheme = requestContext.getUriInfo().getRequestUri().getScheme(); + CatalogSecurityContext catalogSecurityContext = + new CatalogSecurityContext( + catalogPrincipal, + scheme, + SecurityContext.DIGEST_AUTH, + getUserRolesFromClaims(claims, isBot(claims))); + LOG.debug("SecurityContext {}", catalogSecurityContext); + requestContext.setSecurityContext(catalogSecurityContext); + } + + public void checkValidationsForToken( + Map claims, String tokenFromHeader, String userName) { + // the case where OMD generated the Token for the Client in case OM generated Token + validateTokenIsNotUsedAfterLogout(tokenFromHeader); + + // Validate Domain + validateDomainEnforcement( + jwtPrincipalClaimsMapping, + jwtPrincipalClaims, + claims, + principalDomain, + enforcePrincipalDomain); + + // Validate Bot token matches what was created in OM + if (isBot(claims)) { + validateBotToken(tokenFromHeader, userName); } - DecodedJWT jwt = validateAndReturnDecodedJwtToken(tokenFromHeader); - - Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); - claims.putAll(jwt.getClaims()); - - String userName = validateAndReturnUsername(claims); + // validate personal access token + validatePersonalAccessToken(claims, tokenFromHeader, userName); + } + private Set getUserRolesFromClaims(Map claims, boolean isBot) { Set userRoles = new HashSet<>(); - boolean isBot = - claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean()); // Re-sync user roles from token if (useRolesFromProvider && !isBot && claims.containsKey(ROLES_CLAIM)) { List roles = claims.get(ROLES_CLAIM).asList(String.class); @@ -157,30 +199,11 @@ public class JwtFilter implements ContainerRequestFilter { userRoles = new HashSet<>(claims.get(ROLES_CLAIM).asList(String.class)); } } - - // validate bot token - if (isBot) { - validateBotToken(tokenFromHeader, userName); - } - - // validate access token - if (claims.containsKey(TOKEN_TYPE) - && ServiceTokenType.PERSONAL_ACCESS.value().equals(claims.get(TOKEN_TYPE).asString())) { - validatePersonalAccessToken(tokenFromHeader, userName); - } - - // Setting Security Context - CatalogPrincipal catalogPrincipal = new CatalogPrincipal(userName); - String scheme = requestContext.getUriInfo().getRequestUri().getScheme(); - CatalogSecurityContext catalogSecurityContext = - new CatalogSecurityContext( - catalogPrincipal, scheme, SecurityContext.DIGEST_AUTH, userRoles); - LOG.debug("SecurityContext {}", catalogSecurityContext); - requestContext.setSecurityContext(catalogSecurityContext); + return userRoles; } @SneakyThrows - public DecodedJWT validateAndReturnDecodedJwtToken(String token) { + public Map validateJwtAndGetClaims(String token) { // Decode JWT Token DecodedJWT jwt; try { @@ -204,43 +227,11 @@ public class JwtFilter implements ContainerRequestFilter { } catch (RuntimeException runtimeException) { throw new AuthenticationException("Invalid token", runtimeException); } - return jwt; - } - @SneakyThrows - public String validateAndReturnUsername(Map claims) { - // Get email from JWT token - String jwtClaim = - jwtPrincipalClaims.stream() - .filter(claims::containsKey) - .findFirst() - .map(claims::get) - .map(claim -> claim.as(TextNode.class).asText()) - .orElseThrow( - () -> - new AuthenticationException( - "Invalid JWT token, none of the following claims are present " - + jwtPrincipalClaims)); + Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + claims.putAll(jwt.getClaims()); - String userName; - String domain; - if (jwtClaim.contains("@")) { - userName = jwtClaim.split("@")[0]; - domain = jwtClaim.split("@")[1]; - } else { - userName = jwtClaim; - domain = StringUtils.EMPTY; - } - - // validate principal domain, for users - boolean isBot = - claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean()); - if (!isBot && (enforcePrincipalDomain && !domain.equals(principalDomain))) { - throw new AuthenticationException( - String.format( - "Not Authorized! Email does not match the principal domain %s", principalDomain)); - } - return userName; + return claims; } protected static String extractToken(MultivaluedMap headers) { @@ -275,18 +266,31 @@ public class JwtFilter implements ContainerRequestFilter { throw AuthenticationException.getInvalidTokenException(); } - private void validatePersonalAccessToken(String tokenFromHeader, String userName) { - if (UserTokenCache.getToken(userName).contains(tokenFromHeader)) { - return; + private void validatePersonalAccessToken( + Map claims, String tokenFromHeader, String userName) { + if (claims.containsKey(TOKEN_TYPE) + && ServiceTokenType.PERSONAL_ACCESS + .value() + .equals( + claims.get(TOKEN_TYPE) != null + ? StringUtils.EMPTY + : claims.get(TOKEN_TYPE).asString())) { + Set userTokens = UserTokenCache.getToken(userName); + if (userTokens != null && userTokens.contains(tokenFromHeader)) { + return; + } + throw AuthenticationException.getInvalidTokenException(); } - throw AuthenticationException.getInvalidTokenException(); } private void validateTokenIsNotUsedAfterLogout(String authToken) { - LogoutRequest previouslyLoggedOutEvent = - JwtTokenCacheManager.getInstance().getLogoutEventForToken(authToken); - if (previouslyLoggedOutEvent != null) { - throw new AuthenticationException("Expired token!"); + // Only OMD generated Tokens + if (AuthProvider.BASIC.equals(providerType) || AuthProvider.SAML.equals(providerType)) { + LogoutRequest previouslyLoggedOutEvent = + JwtTokenCacheManager.getInstance().getLogoutEventForToken(authToken); + if (previouslyLoggedOutEvent != null) { + throw new AuthenticationException("Expired token!"); + } } } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java index d8ee199c8b0..f8055248484 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/SecurityUtil.java @@ -13,10 +13,15 @@ package org.openmetadata.service.security; +import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE; +import static org.openmetadata.service.security.JwtFilter.BOT_CLAIM; +import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY; +import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY; import static org.pac4j.core.util.CommonHelper.assertNotNull; import static org.pac4j.core.util.CommonHelper.isNotEmpty; +import com.auth0.jwt.interfaces.Claim; import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap.Builder; @@ -56,6 +61,7 @@ import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.SecurityContext; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang.StringUtils; import org.openmetadata.common.utils.CommonUtil; import org.openmetadata.schema.security.client.OidcClientConfig; import org.openmetadata.service.OpenMetadataApplicationConfig; @@ -327,31 +333,16 @@ public final class SecurityUtil { HttpServletResponse response, OidcCredentials credentials, String serverUrl, + Map claimsMapping, List claimsOrder, String defaultDomain) throws ParseException, IOException { JWT jwt = credentials.getIdToken(); Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); claims.putAll(jwt.getJWTClaimsSet().getClaims()); - String preferredJwtClaim = - claimsOrder.stream() - .filter(claims::containsKey) - .findFirst() - .map(claims::get) - .map(String.class::cast) - .orElseThrow( - () -> - new AuthenticationException( - "Invalid JWT token, none of the following claims are present " - + claimsOrder)); - String userName; - if (preferredJwtClaim.contains("@")) { - userName = preferredJwtClaim.split("@")[0]; - } else { - userName = preferredJwtClaim; - } - String email = String.format("%s@%s", userName, defaultDomain); + String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims); + String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, defaultDomain); String url = String.format( @@ -426,4 +417,133 @@ public final class SecurityUtil { HttpUtils.closeConnection(connection); } } + + public static String findUserNameFromClaims( + Map jwtPrincipalClaimsMapping, + List jwtPrincipalClaimsOrder, + Map claims) { + if (!nullOrEmpty(jwtPrincipalClaimsMapping)) { + // We have a mapping available so we will use that + String usernameClaim = jwtPrincipalClaimsMapping.get(USERNAME_CLAIM_KEY); + String userNameClaimValue = getClaimOrObject(claims.get(usernameClaim)); + if (!nullOrEmpty(userNameClaimValue)) { + return userNameClaimValue; + } else { + throw new AuthenticationException("Invalid JWT token, 'username' claim is not present"); + } + } else { + String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims); + String userName; + if (jwtClaim.contains("@")) { + userName = jwtClaim.split("@")[0]; + } else { + userName = jwtClaim; + } + return userName; + } + } + + public static String findEmailFromClaims( + Map jwtPrincipalClaimsMapping, + List jwtPrincipalClaimsOrder, + Map claims, + String defaulPrincipalClaim) { + if (!nullOrEmpty(jwtPrincipalClaimsMapping)) { + // We have a mapping available so we will use that + String emailClaim = jwtPrincipalClaimsMapping.get(EMAIL_CLAIM_KEY); + String emailClaimValue = getClaimOrObject(claims.get(emailClaim)); + if (!nullOrEmpty(emailClaimValue) && emailClaimValue.contains("@")) { + return emailClaimValue; + } else { + throw new AuthenticationException( + String.format( + "Invalid JWT token, 'email' claim is not present or invalid : %s", + emailClaimValue)); + } + } else { + String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims); + if (jwtClaim.contains("@")) { + return jwtClaim; + } else { + return String.format("%s@%s", jwtClaim, defaulPrincipalClaim); + } + } + } + + private static String getClaimOrObject(Object obj) { + if (obj == null) { + return ""; + } + + if (obj instanceof Claim c) { + return c.asString(); + } else if (obj instanceof String s) { + return s; + } + + return StringUtils.EMPTY; + } + + public static String getFirstMatchJwtClaim( + List jwtPrincipalClaimsOrder, Map claims) { + return jwtPrincipalClaimsOrder.stream() + .filter(claims::containsKey) + .findFirst() + .map(claims::get) + .map(SecurityUtil::getClaimOrObject) + .orElseThrow( + () -> + new AuthenticationException( + "Invalid JWT token, none of the following claims are present " + + jwtPrincipalClaimsOrder)); + } + + public static void validatePrincipalClaimsMapping(Map mapping) { + if (!nullOrEmpty(mapping)) { + String username = mapping.get(USERNAME_CLAIM_KEY); + String email = mapping.get(EMAIL_CLAIM_KEY); + if (nullOrEmpty(username) || nullOrEmpty(email)) { + throw new IllegalArgumentException( + "Invalid JWT Principal Claims Mapping. Both username and email should be present"); + } + } + // If emtpy, jwtPrincipalClaims will be used so no need to validate + } + + public static void validateDomainEnforcement( + Map jwtPrincipalClaimsMapping, + List jwtPrincipalClaimsOrder, + Map claims, + String principalDomain, + boolean enforcePrincipalDomain) { + String domain = StringUtils.EMPTY; + if (!nullOrEmpty(jwtPrincipalClaimsMapping)) { + // We have a mapping available so we will use that + String emailClaim = jwtPrincipalClaimsMapping.get(EMAIL_CLAIM_KEY); + String emailClaimValue = getClaimOrObject(claims.get(emailClaim)); + if (!nullOrEmpty(emailClaimValue)) { + if (emailClaimValue.contains("@")) { + domain = emailClaimValue.split("@")[1]; + } + } else { + throw new AuthenticationException("Invalid JWT token, 'email' claim is not present"); + } + } else { + String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims); + if (jwtClaim.contains("@")) { + domain = jwtClaim.split("@")[1]; + } + } + + // Validate + if (!isBot(claims) && (enforcePrincipalDomain && !domain.equals(principalDomain))) { + throw new AuthenticationException( + String.format( + "Not Authorized! Email does not match the principal domain %s", principalDomain)); + } + } + + public static boolean isBot(Map claims) { + return claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean()); + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/socket/SocketAddressFilter.java b/openmetadata-service/src/main/java/org/openmetadata/service/socket/SocketAddressFilter.java index 9c3f867cefa..bdb049e0ca2 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/socket/SocketAddressFilter.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/socket/SocketAddressFilter.java @@ -14,11 +14,9 @@ package org.openmetadata.service.socket; import com.auth0.jwt.interfaces.Claim; -import com.auth0.jwt.interfaces.DecodedJWT; import io.socket.engineio.server.utils.ParseQS; import java.io.IOException; import java.util.Map; -import java.util.TreeMap; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -29,6 +27,7 @@ import lombok.extern.slf4j.Slf4j; import org.openmetadata.schema.api.security.AuthenticationConfiguration; import org.openmetadata.schema.api.security.AuthorizerConfiguration; import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.security.SecurityUtil; @Slf4j public class SocketAddressFilter implements Filter { @@ -82,11 +81,10 @@ public class SocketAddressFilter implements Filter { public static void validatePrefixedTokenRequest(JwtFilter jwtFilter, String prefixedToken) { String token = JwtFilter.extractToken(prefixedToken); - // validate token - DecodedJWT jwt = jwtFilter.validateAndReturnDecodedJwtToken(token); - // validate Domain and Username - Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); - claims.putAll(jwt.getClaims()); - jwtFilter.validateAndReturnUsername(claims); + Map claims = jwtFilter.validateJwtAndGetClaims(token); + String userName = + SecurityUtil.findUserNameFromClaims( + jwtFilter.getJwtPrincipalClaimsMapping(), jwtFilter.getJwtPrincipalClaims(), claims); + jwtFilter.checkValidationsForToken(claims, token, userName); } } diff --git a/openmetadata-spec/src/main/resources/json/schema/configuration/authenticationConfiguration.json b/openmetadata-spec/src/main/resources/json/schema/configuration/authenticationConfiguration.json index 70412355bf4..670401107ca 100644 --- a/openmetadata-spec/src/main/resources/json/schema/configuration/authenticationConfiguration.json +++ b/openmetadata-spec/src/main/resources/json/schema/configuration/authenticationConfiguration.json @@ -65,6 +65,13 @@ "type": "string" } }, + "jwtPrincipalClaimsMapping": { + "description": "Jwt Principal Claim Mapping", + "type": "array", + "items": { + "type": "string" + } + }, "enableSelfSignup": { "description": "Enable Self Sign Up", "type": "boolean", diff --git a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.interface.ts b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.interface.ts index 5e4b2c0dbec..c1308e6dd74 100644 --- a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.interface.ts +++ b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.interface.ts @@ -64,6 +64,7 @@ export interface IAuthContext { updateAxiosInterceptors: () => void; updateCurrentUser: (user: User) => void; jwtPrincipalClaims: string[]; + jwtPrincipalClaimsMapping: string[]; } export type AuthenticationConfigurationWithScope = diff --git a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.tsx b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.tsx index be53b7d4f0e..3fd3661be54 100644 --- a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.tsx +++ b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AuthProviders/AuthProvider.tsx @@ -115,6 +115,7 @@ export const AuthProvider = ({ setAuthorizerConfig, setIsSigningUp, setJwtPrincipalClaims, + setJwtPrincipalClaimsMapping, removeRefreshToken, removeOidcToken, getOidcToken, @@ -538,6 +539,7 @@ export const AuthProvider = ({ if (provider && Object.values(AuthProviderEnum).includes(provider)) { const configJson = getAuthConfig(authConfig); setJwtPrincipalClaims(authConfig.jwtPrincipalClaims); + setJwtPrincipalClaimsMapping(authConfig.jwtPrincipalClaimsMapping); setAuthConfig(configJson); setAuthorizerConfig(authorizerConfig); updateAuthInstance(configJson); diff --git a/openmetadata-ui/src/main/resources/ui/src/hooks/useApplicationStore.ts b/openmetadata-ui/src/main/resources/ui/src/hooks/useApplicationStore.ts index a9756729fd3..63d2c853ea8 100644 --- a/openmetadata-ui/src/main/resources/ui/src/hooks/useApplicationStore.ts +++ b/openmetadata-ui/src/main/resources/ui/src/hooks/useApplicationStore.ts @@ -43,6 +43,7 @@ export const useApplicationStore = create()( authorizerConfig: undefined, isSigningUp: false, jwtPrincipalClaims: [], + jwtPrincipalClaimsMapping: [], userProfilePics: {}, cachedEntityData: {}, selectedPersona: {} as EntityReference, @@ -75,6 +76,11 @@ export const useApplicationStore = create()( ) => { set({ jwtPrincipalClaims: claims }); }, + setJwtPrincipalClaimsMapping: ( + claimMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping'] + ) => { + set({ jwtPrincipalClaimsMapping: claimMapping }); + }, setIsAuthenticated: (authenticated: boolean) => { set({ isAuthenticated: authenticated }); }, diff --git a/openmetadata-ui/src/main/resources/ui/src/interface/store.interface.ts b/openmetadata-ui/src/main/resources/ui/src/interface/store.interface.ts index da54d84dc16..e7b201a2c24 100644 --- a/openmetadata-ui/src/main/resources/ui/src/interface/store.interface.ts +++ b/openmetadata-ui/src/main/resources/ui/src/interface/store.interface.ts @@ -61,6 +61,9 @@ export interface ApplicationStore setJwtPrincipalClaims: ( claims: AuthenticationConfiguration['jwtPrincipalClaims'] ) => void; + setJwtPrincipalClaimsMapping: ( + claimsMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping'] + ) => void; setHelperFunctionsRef: (helperFunctions: HelperFunctions) => void; updateUserProfilePics: (data: { id: string; user: User }) => void; updateCachedEntityData: (data: { diff --git a/openmetadata-ui/src/main/resources/ui/src/pages/SignUp/SignUpPage.tsx b/openmetadata-ui/src/main/resources/ui/src/pages/SignUp/SignUpPage.tsx index db0743d0a34..047fb74a7f6 100644 --- a/openmetadata-ui/src/main/resources/ui/src/pages/SignUp/SignUpPage.tsx +++ b/openmetadata-ui/src/main/resources/ui/src/pages/SignUp/SignUpPage.tsx @@ -43,6 +43,7 @@ const SignUp = () => { const { setIsSigningUp, jwtPrincipalClaims = [], + jwtPrincipalClaimsMapping = [], authorizerConfig, updateCurrentUser, newUser, @@ -116,7 +117,8 @@ const SignUp = () => { ...getNameFromUserData( newUser as UserProfile, jwtPrincipalClaims, - authorizerConfig?.principalDomain + authorizerConfig?.principalDomain, + jwtPrincipalClaimsMapping ), }} layout="vertical" diff --git a/openmetadata-ui/src/main/resources/ui/src/utils/AuthProvider.util.ts b/openmetadata-ui/src/main/resources/ui/src/utils/AuthProvider.util.ts index 60ac5aa2647..c5cc1c3422e 100644 --- a/openmetadata-ui/src/main/resources/ui/src/utils/AuthProvider.util.ts +++ b/openmetadata-ui/src/main/resources/ui/src/utils/AuthProvider.util.ts @@ -18,7 +18,7 @@ import { } from '@azure/msal-browser'; import { CookieStorage } from 'cookie-storage'; import jwtDecode, { JwtPayload } from 'jwt-decode'; -import { first, isNil } from 'lodash'; +import { first, get, isEmpty, isNil } from 'lodash'; import { WebStorageStateStore } from 'oidc-client'; import { AuthenticationConfigurationWithScope, @@ -231,36 +231,59 @@ export const getNameFromEmail = (email: string) => { export const getNameFromUserData = ( user: UserProfile, jwtPrincipalClaims: AuthenticationConfiguration['jwtPrincipalClaims'] = [], - principleDomain = '' + principleDomain = '', + jwtPrincipalClaimsMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping'] = [] ) => { - // filter and extract the present claims in user profile - const jwtClaims = jwtPrincipalClaims.reduce( - (prev: string[], curr: string) => { - const currentClaim = user[curr as keyof UserProfile]; - if (currentClaim) { - return [...prev, currentClaim]; - } else { - return prev; - } - }, - [] - ); - - // get the first claim from claims list - const firstClaim = first(jwtClaims); - let userName = ''; let domain = principleDomain; + let email = ''; + if (isEmpty(jwtPrincipalClaimsMapping)) { + // filter and extract the present claims in user profile + const jwtClaims = jwtPrincipalClaims.reduce( + (prev: string[], curr: string) => { + const currentClaim = user[curr as keyof UserProfile]; + if (currentClaim) { + return [...prev, currentClaim]; + } else { + return prev; + } + }, + [] + ); - // if claims contains the "@" then split it out otherwise assign it to username as it is - if (firstClaim?.includes('@')) { - userName = firstClaim.split('@')[0]; - domain = firstClaim.split('@')[1]; + // get the first claim from claims list + const firstClaim = first(jwtClaims); + + // if claims contains the "@" then split it out otherwise assign it to username as it is + if (firstClaim?.includes('@')) { + userName = firstClaim.split('@')[0]; + domain = firstClaim.split('@')[1]; + } else { + userName = firstClaim ?? ''; + } + + email = userName + '@' + domain; } else { - userName = firstClaim ?? ''; + const mappingObj: Record = {}; + jwtPrincipalClaimsMapping.reduce((acc, value) => { + const [key, claim] = value.split(':'); + acc[key] = claim; + + return acc; + }, mappingObj); + + if (mappingObj['username'] && mappingObj['email']) { + userName = get(user, mappingObj['username'], ''); + email = get(user, mappingObj['email']); + } else { + // eslint-disable-next-line no-console + console.error( + 'username or email is not present in jwtPrincipalClaimsMapping' + ); + } } - return { name: userName, email: userName + '@' + domain }; + return { name: userName, email: email }; }; export const isProtectedRoute = (pathname: string) => {