[Issue-16642] Add Claim Mapping to uniquely identifty username and email from claims (#16643)

* - Add Claim Mapping to uniquely identift username and email from claims

* - Null Check

* - Add field to yaml

* - Fix issue with token being null

* - Auth Code Flow Fix

* support jwtPrincipleClaimMapping from UI

---------

Co-authored-by: Chira Madlani <chirag@getcollate.io>
This commit is contained in:
Mohit Yadav 2024-06-19 13:13:09 +05:30 committed by GitHub
parent a5295396bd
commit 53407fb681
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 324 additions and 133 deletions

View File

@ -35,6 +35,7 @@ import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.TimeZone; import java.util.TimeZone;
@ -183,6 +184,10 @@ public final class CommonUtil {
return list == null || list.isEmpty(); return list == null || list.isEmpty();
} }
public static boolean nullOrEmpty(Map<?, ?> m) {
return m == null || m.isEmpty();
}
public static boolean nullOrEmpty(Object object) { public static boolean nullOrEmpty(Object object) {
return object == null || nullOrEmpty(object.toString()); return object == null || nullOrEmpty(object.toString());
} }

View File

@ -175,6 +175,7 @@ authenticationConfiguration:
clientId: ${AUTHENTICATION_CLIENT_ID:-""} clientId: ${AUTHENTICATION_CLIENT_ID:-""}
callbackUrl: ${AUTHENTICATION_CALLBACK_URL:-""} callbackUrl: ${AUTHENTICATION_CALLBACK_URL:-""}
jwtPrincipalClaims: ${AUTHENTICATION_JWT_PRINCIPAL_CLAIMS:-[email,preferred_username,sub]} jwtPrincipalClaims: ${AUTHENTICATION_JWT_PRINCIPAL_CLAIMS:-[email,preferred_username,sub]}
jwtPrincipalClaimsMapping: ${AUTHENTICATION_JWT_PRINCIPAL_CLAIMS_MAPPING:-[]}
enableSelfSignup : ${AUTHENTICATION_ENABLE_SELF_SIGNUP:-true} enableSelfSignup : ${AUTHENTICATION_ENABLE_SELF_SIGNUP:-true}
oidcConfiguration: oidcConfiguration:
id: ${OIDC_CLIENT_ID:-""} id: ${OIDC_CLIENT_ID:-""}

View File

@ -322,7 +322,7 @@ public class SystemRepository {
OpenMetadataConnection openMetadataServerConnection = OpenMetadataConnection openMetadataServerConnection =
new OpenMetadataConnectionBuilder(applicationConfig).build(); new OpenMetadataConnectionBuilder(applicationConfig).build();
try { try {
jwtFilter.validateAndReturnDecodedJwtToken( jwtFilter.validateJwtAndGetClaims(
openMetadataServerConnection.getSecurityConfig().getJwtToken()); openMetadataServerConnection.getSecurityConfig().getJwtToken());
return new StepValidation() return new StepValidation()
.withDescription(ValidationStepDescription.JWT_TOKEN.key) .withDescription(ValidationStepDescription.JWT_TOKEN.key)

View File

@ -1,9 +1,11 @@
package org.openmetadata.service.security; package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE; import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE;
import static org.openmetadata.service.security.SecurityUtil.getClientAuthentication; import static org.openmetadata.service.security.SecurityUtil.getClientAuthentication;
import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; import static org.openmetadata.service.security.SecurityUtil.getErrorMessage;
import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken; import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken;
import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping;
import com.nimbusds.jose.proc.BadJOSEException; import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
@ -40,6 +42,7 @@ import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
import javax.servlet.annotation.WebServlet; import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -58,6 +61,7 @@ public class AuthCallbackServlet extends HttpServlet {
private final OidcClient client; private final OidcClient client;
private final ClientAuthentication clientAuthentication; private final ClientAuthentication clientAuthentication;
private final List<String> claimsOrder; private final List<String> claimsOrder;
private final Map<String, String> claimsMapping;
private final String serverUrl; private final String serverUrl;
private final String principalDomain; private final String principalDomain;
@ -69,6 +73,11 @@ public class AuthCallbackServlet extends HttpServlet {
"ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl()); "ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl());
this.client = oidcClient; this.client = oidcClient;
this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
this.claimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(claimsMapping);
this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
this.clientAuthentication = getClientAuthentication(client.getConfiguration()); this.clientAuthentication = getClientAuthentication(client.getConfiguration());
this.principalDomain = authorizerConfiguration.getPrincipalDomain(); this.principalDomain = authorizerConfiguration.getPrincipalDomain();
@ -118,7 +127,8 @@ public class AuthCallbackServlet extends HttpServlet {
req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
// Redirect // Redirect
sendRedirectWithToken(resp, credentials, serverUrl, claimsOrder, principalDomain); sendRedirectWithToken(
resp, credentials, serverUrl, claimsMapping, claimsOrder, principalDomain);
} catch (Exception e) { } catch (Exception e) {
getErrorMessage(resp, e); getErrorMessage(resp, e);
} }

View File

@ -1,8 +1,10 @@
package org.openmetadata.service.security; package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.service.security.SecurityUtil.getErrorMessage; import static org.openmetadata.service.security.SecurityUtil.getErrorMessage;
import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession; import static org.openmetadata.service.security.SecurityUtil.getUserCredentialsFromSession;
import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken; import static org.openmetadata.service.security.SecurityUtil.sendRedirectWithToken;
import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping;
import com.nimbusds.oauth2.sdk.id.State; import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallenge; import com.nimbusds.oauth2.sdk.pkce.CodeChallenge;
@ -37,6 +39,7 @@ public class AuthLoginServlet extends HttpServlet {
public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile"; public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile";
private final OidcClient client; private final OidcClient client;
private final List<String> claimsOrder; private final List<String> claimsOrder;
private final Map<String, String> claimsMapping;
private final String serverUrl; private final String serverUrl;
private final String principalDomain; private final String principalDomain;
@ -47,6 +50,11 @@ public class AuthLoginServlet extends HttpServlet {
this.client = oidcClient; this.client = oidcClient;
this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl(); this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims(); this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
this.claimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(claimsMapping);
this.principalDomain = authorizerConfiguration.getPrincipalDomain(); this.principalDomain = authorizerConfiguration.getPrincipalDomain();
} }
@ -57,7 +65,8 @@ public class AuthLoginServlet extends HttpServlet {
Optional<OidcCredentials> credentials = getUserCredentialsFromSession(req, client); Optional<OidcCredentials> credentials = getUserCredentialsFromSession(req, client);
if (credentials.isPresent()) { if (credentials.isPresent()) {
LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId()); LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId());
sendRedirectWithToken(resp, credentials.get(), serverUrl, claimsOrder, principalDomain); sendRedirectWithToken(
resp, credentials.get(), serverUrl, claimsMapping, claimsOrder, principalDomain);
} else { } else {
LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId()); LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId());
Map<String, String> params = buildParams(); Map<String, String> params = buildParams();

View File

@ -13,7 +13,12 @@
package org.openmetadata.service.security; package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import static org.openmetadata.service.security.SecurityUtil.findUserNameFromClaims;
import static org.openmetadata.service.security.SecurityUtil.isBot;
import static org.openmetadata.service.security.SecurityUtil.validateDomainEnforcement;
import static org.openmetadata.service.security.SecurityUtil.validatePrincipalClaimsMapping;
import static org.openmetadata.service.security.jwt.JWTTokenGenerator.ROLES_CLAIM; import static org.openmetadata.service.security.jwt.JWTTokenGenerator.ROLES_CLAIM;
import static org.openmetadata.service.security.jwt.JWTTokenGenerator.TOKEN_TYPE; import static org.openmetadata.service.security.jwt.JWTTokenGenerator.TOKEN_TYPE;
@ -24,18 +29,19 @@ import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTDecodeException; import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.DecodedJWT;
import com.fasterxml.jackson.databind.node.TextNode;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import java.net.URL; import java.net.URL;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter; import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.SecurityContext; import javax.ws.rs.core.SecurityContext;
import javax.ws.rs.core.UriInfo; import javax.ws.rs.core.UriInfo;
import javax.ws.rs.ext.Provider; import javax.ws.rs.ext.Provider;
import lombok.Getter;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
@ -52,10 +58,13 @@ import org.openmetadata.service.security.saml.JwtTokenCacheManager;
@Slf4j @Slf4j
@Provider @Provider
public class JwtFilter implements ContainerRequestFilter { public class JwtFilter implements ContainerRequestFilter {
public static final String EMAIL_CLAIM_KEY = "email";
public static final String USERNAME_CLAIM_KEY = "username";
public static final String AUTHORIZATION_HEADER = "Authorization"; public static final String AUTHORIZATION_HEADER = "Authorization";
public static final String TOKEN_PREFIX = "Bearer"; public static final String TOKEN_PREFIX = "Bearer";
public static final String BOT_CLAIM = "isBot"; public static final String BOT_CLAIM = "isBot";
private List<String> jwtPrincipalClaims; @Getter private List<String> jwtPrincipalClaims;
@Getter private Map<String, String> jwtPrincipalClaimsMapping;
private JwkProvider jwkProvider; private JwkProvider jwkProvider;
private String principalDomain; private String principalDomain;
private boolean enforcePrincipalDomain; private boolean enforcePrincipalDomain;
@ -90,7 +99,13 @@ public class JwtFilter implements ContainerRequestFilter {
AuthenticationConfiguration authenticationConfiguration, AuthenticationConfiguration authenticationConfiguration,
AuthorizerConfiguration authorizerConfiguration) { AuthorizerConfiguration authorizerConfiguration) {
this.providerType = authenticationConfiguration.getProvider(); this.providerType = authenticationConfiguration.getProvider();
// Cannot remove Principal Claims listing since that is , breaking change for existing users
this.jwtPrincipalClaims = authenticationConfiguration.getJwtPrincipalClaims(); this.jwtPrincipalClaims = authenticationConfiguration.getJwtPrincipalClaims();
this.jwtPrincipalClaimsMapping =
listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
.map(s -> s.split(":"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
validatePrincipalClaimsMapping(jwtPrincipalClaimsMapping);
ImmutableList.Builder<URL> publicKeyUrlsBuilder = ImmutableList.builder(); ImmutableList.Builder<URL> publicKeyUrlsBuilder = ImmutableList.builder();
for (String publicKeyUrlStr : authenticationConfiguration.getPublicKeyUrls()) { for (String publicKeyUrlStr : authenticationConfiguration.getPublicKeyUrls()) {
@ -131,25 +146,52 @@ public class JwtFilter implements ContainerRequestFilter {
} }
// Extract token from the header // Extract token from the header
MultivaluedMap<String, String> headers = requestContext.getHeaders(); String tokenFromHeader = extractToken(requestContext.getHeaders());
String tokenFromHeader = extractToken(headers);
LOG.debug("Token from header:{}", tokenFromHeader); LOG.debug("Token from header:{}", tokenFromHeader);
// the case where OMD generated the Token for the Client Map<String, Claim> claims = validateJwtAndGetClaims(tokenFromHeader);
if (AuthProvider.BASIC.equals(providerType) || AuthProvider.SAML.equals(providerType)) { String userName = findUserNameFromClaims(jwtPrincipalClaimsMapping, jwtPrincipalClaims, claims);
validateTokenIsNotUsedAfterLogout(tokenFromHeader);
// Check Validations
checkValidationsForToken(claims, tokenFromHeader, userName);
// Setting Security Context
CatalogPrincipal catalogPrincipal = new CatalogPrincipal(userName);
String scheme = requestContext.getUriInfo().getRequestUri().getScheme();
CatalogSecurityContext catalogSecurityContext =
new CatalogSecurityContext(
catalogPrincipal,
scheme,
SecurityContext.DIGEST_AUTH,
getUserRolesFromClaims(claims, isBot(claims)));
LOG.debug("SecurityContext {}", catalogSecurityContext);
requestContext.setSecurityContext(catalogSecurityContext);
} }
DecodedJWT jwt = validateAndReturnDecodedJwtToken(tokenFromHeader); public void checkValidationsForToken(
Map<String, Claim> claims, String tokenFromHeader, String userName) {
// the case where OMD generated the Token for the Client in case OM generated Token
validateTokenIsNotUsedAfterLogout(tokenFromHeader);
Map<String, Claim> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); // Validate Domain
claims.putAll(jwt.getClaims()); validateDomainEnforcement(
jwtPrincipalClaimsMapping,
jwtPrincipalClaims,
claims,
principalDomain,
enforcePrincipalDomain);
String userName = validateAndReturnUsername(claims); // Validate Bot token matches what was created in OM
if (isBot(claims)) {
validateBotToken(tokenFromHeader, userName);
}
// validate personal access token
validatePersonalAccessToken(claims, tokenFromHeader, userName);
}
private Set<String> getUserRolesFromClaims(Map<String, Claim> claims, boolean isBot) {
Set<String> userRoles = new HashSet<>(); Set<String> userRoles = new HashSet<>();
boolean isBot =
claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean());
// Re-sync user roles from token // Re-sync user roles from token
if (useRolesFromProvider && !isBot && claims.containsKey(ROLES_CLAIM)) { if (useRolesFromProvider && !isBot && claims.containsKey(ROLES_CLAIM)) {
List<String> roles = claims.get(ROLES_CLAIM).asList(String.class); List<String> roles = claims.get(ROLES_CLAIM).asList(String.class);
@ -157,30 +199,11 @@ public class JwtFilter implements ContainerRequestFilter {
userRoles = new HashSet<>(claims.get(ROLES_CLAIM).asList(String.class)); userRoles = new HashSet<>(claims.get(ROLES_CLAIM).asList(String.class));
} }
} }
return userRoles;
// validate bot token
if (isBot) {
validateBotToken(tokenFromHeader, userName);
}
// validate access token
if (claims.containsKey(TOKEN_TYPE)
&& ServiceTokenType.PERSONAL_ACCESS.value().equals(claims.get(TOKEN_TYPE).asString())) {
validatePersonalAccessToken(tokenFromHeader, userName);
}
// Setting Security Context
CatalogPrincipal catalogPrincipal = new CatalogPrincipal(userName);
String scheme = requestContext.getUriInfo().getRequestUri().getScheme();
CatalogSecurityContext catalogSecurityContext =
new CatalogSecurityContext(
catalogPrincipal, scheme, SecurityContext.DIGEST_AUTH, userRoles);
LOG.debug("SecurityContext {}", catalogSecurityContext);
requestContext.setSecurityContext(catalogSecurityContext);
} }
@SneakyThrows @SneakyThrows
public DecodedJWT validateAndReturnDecodedJwtToken(String token) { public Map<String, Claim> validateJwtAndGetClaims(String token) {
// Decode JWT Token // Decode JWT Token
DecodedJWT jwt; DecodedJWT jwt;
try { try {
@ -204,43 +227,11 @@ public class JwtFilter implements ContainerRequestFilter {
} catch (RuntimeException runtimeException) { } catch (RuntimeException runtimeException) {
throw new AuthenticationException("Invalid token", runtimeException); throw new AuthenticationException("Invalid token", runtimeException);
} }
return jwt;
}
@SneakyThrows Map<String, Claim> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
public String validateAndReturnUsername(Map<String, Claim> claims) { claims.putAll(jwt.getClaims());
// Get email from JWT token
String jwtClaim =
jwtPrincipalClaims.stream()
.filter(claims::containsKey)
.findFirst()
.map(claims::get)
.map(claim -> claim.as(TextNode.class).asText())
.orElseThrow(
() ->
new AuthenticationException(
"Invalid JWT token, none of the following claims are present "
+ jwtPrincipalClaims));
String userName; return claims;
String domain;
if (jwtClaim.contains("@")) {
userName = jwtClaim.split("@")[0];
domain = jwtClaim.split("@")[1];
} else {
userName = jwtClaim;
domain = StringUtils.EMPTY;
}
// validate principal domain, for users
boolean isBot =
claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean());
if (!isBot && (enforcePrincipalDomain && !domain.equals(principalDomain))) {
throw new AuthenticationException(
String.format(
"Not Authorized! Email does not match the principal domain %s", principalDomain));
}
return userName;
} }
protected static String extractToken(MultivaluedMap<String, String> headers) { protected static String extractToken(MultivaluedMap<String, String> headers) {
@ -275,14 +266,26 @@ public class JwtFilter implements ContainerRequestFilter {
throw AuthenticationException.getInvalidTokenException(); throw AuthenticationException.getInvalidTokenException();
} }
private void validatePersonalAccessToken(String tokenFromHeader, String userName) { private void validatePersonalAccessToken(
if (UserTokenCache.getToken(userName).contains(tokenFromHeader)) { Map<String, Claim> claims, String tokenFromHeader, String userName) {
if (claims.containsKey(TOKEN_TYPE)
&& ServiceTokenType.PERSONAL_ACCESS
.value()
.equals(
claims.get(TOKEN_TYPE) != null
? StringUtils.EMPTY
: claims.get(TOKEN_TYPE).asString())) {
Set<String> userTokens = UserTokenCache.getToken(userName);
if (userTokens != null && userTokens.contains(tokenFromHeader)) {
return; return;
} }
throw AuthenticationException.getInvalidTokenException(); throw AuthenticationException.getInvalidTokenException();
} }
}
private void validateTokenIsNotUsedAfterLogout(String authToken) { private void validateTokenIsNotUsedAfterLogout(String authToken) {
// Only OMD generated Tokens
if (AuthProvider.BASIC.equals(providerType) || AuthProvider.SAML.equals(providerType)) {
LogoutRequest previouslyLoggedOutEvent = LogoutRequest previouslyLoggedOutEvent =
JwtTokenCacheManager.getInstance().getLogoutEventForToken(authToken); JwtTokenCacheManager.getInstance().getLogoutEventForToken(authToken);
if (previouslyLoggedOutEvent != null) { if (previouslyLoggedOutEvent != null) {
@ -290,3 +293,4 @@ public class JwtFilter implements ContainerRequestFilter {
} }
} }
} }
}

View File

@ -13,10 +13,15 @@
package org.openmetadata.service.security; package org.openmetadata.service.security;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE; import static org.openmetadata.service.security.AuthLoginServlet.OIDC_CREDENTIAL_PROFILE;
import static org.openmetadata.service.security.JwtFilter.BOT_CLAIM;
import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY;
import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY;
import static org.pac4j.core.util.CommonHelper.assertNotNull; import static org.pac4j.core.util.CommonHelper.assertNotNull;
import static org.pac4j.core.util.CommonHelper.isNotEmpty; import static org.pac4j.core.util.CommonHelper.isNotEmpty;
import com.auth0.jwt.interfaces.Claim;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder; import com.google.common.collect.ImmutableMap.Builder;
@ -56,6 +61,7 @@ import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.SecurityContext; import javax.ws.rs.core.SecurityContext;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.openmetadata.common.utils.CommonUtil; import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.schema.security.client.OidcClientConfig; import org.openmetadata.schema.security.client.OidcClientConfig;
import org.openmetadata.service.OpenMetadataApplicationConfig; import org.openmetadata.service.OpenMetadataApplicationConfig;
@ -327,31 +333,16 @@ public final class SecurityUtil {
HttpServletResponse response, HttpServletResponse response,
OidcCredentials credentials, OidcCredentials credentials,
String serverUrl, String serverUrl,
Map<String, String> claimsMapping,
List<String> claimsOrder, List<String> claimsOrder,
String defaultDomain) String defaultDomain)
throws ParseException, IOException { throws ParseException, IOException {
JWT jwt = credentials.getIdToken(); JWT jwt = credentials.getIdToken();
Map<String, Object> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); Map<String, Object> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
claims.putAll(jwt.getJWTClaimsSet().getClaims()); claims.putAll(jwt.getJWTClaimsSet().getClaims());
String preferredJwtClaim =
claimsOrder.stream()
.filter(claims::containsKey)
.findFirst()
.map(claims::get)
.map(String.class::cast)
.orElseThrow(
() ->
new AuthenticationException(
"Invalid JWT token, none of the following claims are present "
+ claimsOrder));
String userName; String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims);
if (preferredJwtClaim.contains("@")) { String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, defaultDomain);
userName = preferredJwtClaim.split("@")[0];
} else {
userName = preferredJwtClaim;
}
String email = String.format("%s@%s", userName, defaultDomain);
String url = String url =
String.format( String.format(
@ -426,4 +417,133 @@ public final class SecurityUtil {
HttpUtils.closeConnection(connection); HttpUtils.closeConnection(connection);
} }
} }
public static String findUserNameFromClaims(
Map<String, String> jwtPrincipalClaimsMapping,
List<String> jwtPrincipalClaimsOrder,
Map<String, ?> claims) {
if (!nullOrEmpty(jwtPrincipalClaimsMapping)) {
// We have a mapping available so we will use that
String usernameClaim = jwtPrincipalClaimsMapping.get(USERNAME_CLAIM_KEY);
String userNameClaimValue = getClaimOrObject(claims.get(usernameClaim));
if (!nullOrEmpty(userNameClaimValue)) {
return userNameClaimValue;
} else {
throw new AuthenticationException("Invalid JWT token, 'username' claim is not present");
}
} else {
String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims);
String userName;
if (jwtClaim.contains("@")) {
userName = jwtClaim.split("@")[0];
} else {
userName = jwtClaim;
}
return userName;
}
}
public static String findEmailFromClaims(
Map<String, String> jwtPrincipalClaimsMapping,
List<String> jwtPrincipalClaimsOrder,
Map<String, ?> claims,
String defaulPrincipalClaim) {
if (!nullOrEmpty(jwtPrincipalClaimsMapping)) {
// We have a mapping available so we will use that
String emailClaim = jwtPrincipalClaimsMapping.get(EMAIL_CLAIM_KEY);
String emailClaimValue = getClaimOrObject(claims.get(emailClaim));
if (!nullOrEmpty(emailClaimValue) && emailClaimValue.contains("@")) {
return emailClaimValue;
} else {
throw new AuthenticationException(
String.format(
"Invalid JWT token, 'email' claim is not present or invalid : %s",
emailClaimValue));
}
} else {
String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims);
if (jwtClaim.contains("@")) {
return jwtClaim;
} else {
return String.format("%s@%s", jwtClaim, defaulPrincipalClaim);
}
}
}
private static String getClaimOrObject(Object obj) {
if (obj == null) {
return "";
}
if (obj instanceof Claim c) {
return c.asString();
} else if (obj instanceof String s) {
return s;
}
return StringUtils.EMPTY;
}
public static String getFirstMatchJwtClaim(
List<String> jwtPrincipalClaimsOrder, Map<String, ?> claims) {
return jwtPrincipalClaimsOrder.stream()
.filter(claims::containsKey)
.findFirst()
.map(claims::get)
.map(SecurityUtil::getClaimOrObject)
.orElseThrow(
() ->
new AuthenticationException(
"Invalid JWT token, none of the following claims are present "
+ jwtPrincipalClaimsOrder));
}
public static void validatePrincipalClaimsMapping(Map<String, String> mapping) {
if (!nullOrEmpty(mapping)) {
String username = mapping.get(USERNAME_CLAIM_KEY);
String email = mapping.get(EMAIL_CLAIM_KEY);
if (nullOrEmpty(username) || nullOrEmpty(email)) {
throw new IllegalArgumentException(
"Invalid JWT Principal Claims Mapping. Both username and email should be present");
}
}
// If emtpy, jwtPrincipalClaims will be used so no need to validate
}
public static void validateDomainEnforcement(
Map<String, String> jwtPrincipalClaimsMapping,
List<String> jwtPrincipalClaimsOrder,
Map<String, Claim> claims,
String principalDomain,
boolean enforcePrincipalDomain) {
String domain = StringUtils.EMPTY;
if (!nullOrEmpty(jwtPrincipalClaimsMapping)) {
// We have a mapping available so we will use that
String emailClaim = jwtPrincipalClaimsMapping.get(EMAIL_CLAIM_KEY);
String emailClaimValue = getClaimOrObject(claims.get(emailClaim));
if (!nullOrEmpty(emailClaimValue)) {
if (emailClaimValue.contains("@")) {
domain = emailClaimValue.split("@")[1];
}
} else {
throw new AuthenticationException("Invalid JWT token, 'email' claim is not present");
}
} else {
String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims);
if (jwtClaim.contains("@")) {
domain = jwtClaim.split("@")[1];
}
}
// Validate
if (!isBot(claims) && (enforcePrincipalDomain && !domain.equals(principalDomain))) {
throw new AuthenticationException(
String.format(
"Not Authorized! Email does not match the principal domain %s", principalDomain));
}
}
public static boolean isBot(Map<String, Claim> claims) {
return claims.containsKey(BOT_CLAIM) && Boolean.TRUE.equals(claims.get(BOT_CLAIM).asBoolean());
}
} }

View File

@ -14,11 +14,9 @@
package org.openmetadata.service.socket; package org.openmetadata.service.socket;
import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import io.socket.engineio.server.utils.ParseQS; import io.socket.engineio.server.utils.ParseQS;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.TreeMap;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.FilterConfig; import javax.servlet.FilterConfig;
@ -29,6 +27,7 @@ import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.api.security.AuthenticationConfiguration; import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.api.security.AuthorizerConfiguration; import org.openmetadata.schema.api.security.AuthorizerConfiguration;
import org.openmetadata.service.security.JwtFilter; import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.security.SecurityUtil;
@Slf4j @Slf4j
public class SocketAddressFilter implements Filter { public class SocketAddressFilter implements Filter {
@ -82,11 +81,10 @@ public class SocketAddressFilter implements Filter {
public static void validatePrefixedTokenRequest(JwtFilter jwtFilter, String prefixedToken) { public static void validatePrefixedTokenRequest(JwtFilter jwtFilter, String prefixedToken) {
String token = JwtFilter.extractToken(prefixedToken); String token = JwtFilter.extractToken(prefixedToken);
// validate token Map<String, Claim> claims = jwtFilter.validateJwtAndGetClaims(token);
DecodedJWT jwt = jwtFilter.validateAndReturnDecodedJwtToken(token); String userName =
// validate Domain and Username SecurityUtil.findUserNameFromClaims(
Map<String, Claim> claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); jwtFilter.getJwtPrincipalClaimsMapping(), jwtFilter.getJwtPrincipalClaims(), claims);
claims.putAll(jwt.getClaims()); jwtFilter.checkValidationsForToken(claims, token, userName);
jwtFilter.validateAndReturnUsername(claims);
} }
} }

View File

@ -65,6 +65,13 @@
"type": "string" "type": "string"
} }
}, },
"jwtPrincipalClaimsMapping": {
"description": "Jwt Principal Claim Mapping",
"type": "array",
"items": {
"type": "string"
}
},
"enableSelfSignup": { "enableSelfSignup": {
"description": "Enable Self Sign Up", "description": "Enable Self Sign Up",
"type": "boolean", "type": "boolean",

View File

@ -64,6 +64,7 @@ export interface IAuthContext {
updateAxiosInterceptors: () => void; updateAxiosInterceptors: () => void;
updateCurrentUser: (user: User) => void; updateCurrentUser: (user: User) => void;
jwtPrincipalClaims: string[]; jwtPrincipalClaims: string[];
jwtPrincipalClaimsMapping: string[];
} }
export type AuthenticationConfigurationWithScope = export type AuthenticationConfigurationWithScope =

View File

@ -115,6 +115,7 @@ export const AuthProvider = ({
setAuthorizerConfig, setAuthorizerConfig,
setIsSigningUp, setIsSigningUp,
setJwtPrincipalClaims, setJwtPrincipalClaims,
setJwtPrincipalClaimsMapping,
removeRefreshToken, removeRefreshToken,
removeOidcToken, removeOidcToken,
getOidcToken, getOidcToken,
@ -538,6 +539,7 @@ export const AuthProvider = ({
if (provider && Object.values(AuthProviderEnum).includes(provider)) { if (provider && Object.values(AuthProviderEnum).includes(provider)) {
const configJson = getAuthConfig(authConfig); const configJson = getAuthConfig(authConfig);
setJwtPrincipalClaims(authConfig.jwtPrincipalClaims); setJwtPrincipalClaims(authConfig.jwtPrincipalClaims);
setJwtPrincipalClaimsMapping(authConfig.jwtPrincipalClaimsMapping);
setAuthConfig(configJson); setAuthConfig(configJson);
setAuthorizerConfig(authorizerConfig); setAuthorizerConfig(authorizerConfig);
updateAuthInstance(configJson); updateAuthInstance(configJson);

View File

@ -43,6 +43,7 @@ export const useApplicationStore = create<ApplicationStore>()(
authorizerConfig: undefined, authorizerConfig: undefined,
isSigningUp: false, isSigningUp: false,
jwtPrincipalClaims: [], jwtPrincipalClaims: [],
jwtPrincipalClaimsMapping: [],
userProfilePics: {}, userProfilePics: {},
cachedEntityData: {}, cachedEntityData: {},
selectedPersona: {} as EntityReference, selectedPersona: {} as EntityReference,
@ -75,6 +76,11 @@ export const useApplicationStore = create<ApplicationStore>()(
) => { ) => {
set({ jwtPrincipalClaims: claims }); set({ jwtPrincipalClaims: claims });
}, },
setJwtPrincipalClaimsMapping: (
claimMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping']
) => {
set({ jwtPrincipalClaimsMapping: claimMapping });
},
setIsAuthenticated: (authenticated: boolean) => { setIsAuthenticated: (authenticated: boolean) => {
set({ isAuthenticated: authenticated }); set({ isAuthenticated: authenticated });
}, },

View File

@ -61,6 +61,9 @@ export interface ApplicationStore
setJwtPrincipalClaims: ( setJwtPrincipalClaims: (
claims: AuthenticationConfiguration['jwtPrincipalClaims'] claims: AuthenticationConfiguration['jwtPrincipalClaims']
) => void; ) => void;
setJwtPrincipalClaimsMapping: (
claimsMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping']
) => void;
setHelperFunctionsRef: (helperFunctions: HelperFunctions) => void; setHelperFunctionsRef: (helperFunctions: HelperFunctions) => void;
updateUserProfilePics: (data: { id: string; user: User }) => void; updateUserProfilePics: (data: { id: string; user: User }) => void;
updateCachedEntityData: (data: { updateCachedEntityData: (data: {

View File

@ -43,6 +43,7 @@ const SignUp = () => {
const { const {
setIsSigningUp, setIsSigningUp,
jwtPrincipalClaims = [], jwtPrincipalClaims = [],
jwtPrincipalClaimsMapping = [],
authorizerConfig, authorizerConfig,
updateCurrentUser, updateCurrentUser,
newUser, newUser,
@ -116,7 +117,8 @@ const SignUp = () => {
...getNameFromUserData( ...getNameFromUserData(
newUser as UserProfile, newUser as UserProfile,
jwtPrincipalClaims, jwtPrincipalClaims,
authorizerConfig?.principalDomain authorizerConfig?.principalDomain,
jwtPrincipalClaimsMapping
), ),
}} }}
layout="vertical" layout="vertical"

View File

@ -18,7 +18,7 @@ import {
} from '@azure/msal-browser'; } from '@azure/msal-browser';
import { CookieStorage } from 'cookie-storage'; import { CookieStorage } from 'cookie-storage';
import jwtDecode, { JwtPayload } from 'jwt-decode'; import jwtDecode, { JwtPayload } from 'jwt-decode';
import { first, isNil } from 'lodash'; import { first, get, isEmpty, isNil } from 'lodash';
import { WebStorageStateStore } from 'oidc-client'; import { WebStorageStateStore } from 'oidc-client';
import { import {
AuthenticationConfigurationWithScope, AuthenticationConfigurationWithScope,
@ -231,8 +231,13 @@ export const getNameFromEmail = (email: string) => {
export const getNameFromUserData = ( export const getNameFromUserData = (
user: UserProfile, user: UserProfile,
jwtPrincipalClaims: AuthenticationConfiguration['jwtPrincipalClaims'] = [], jwtPrincipalClaims: AuthenticationConfiguration['jwtPrincipalClaims'] = [],
principleDomain = '' principleDomain = '',
jwtPrincipalClaimsMapping: AuthenticationConfiguration['jwtPrincipalClaimsMapping'] = []
) => { ) => {
let userName = '';
let domain = principleDomain;
let email = '';
if (isEmpty(jwtPrincipalClaimsMapping)) {
// filter and extract the present claims in user profile // filter and extract the present claims in user profile
const jwtClaims = jwtPrincipalClaims.reduce( const jwtClaims = jwtPrincipalClaims.reduce(
(prev: string[], curr: string) => { (prev: string[], curr: string) => {
@ -249,9 +254,6 @@ export const getNameFromUserData = (
// get the first claim from claims list // get the first claim from claims list
const firstClaim = first(jwtClaims); const firstClaim = first(jwtClaims);
let userName = '';
let domain = principleDomain;
// if claims contains the "@" then split it out otherwise assign it to username as it is // if claims contains the "@" then split it out otherwise assign it to username as it is
if (firstClaim?.includes('@')) { if (firstClaim?.includes('@')) {
userName = firstClaim.split('@')[0]; userName = firstClaim.split('@')[0];
@ -260,7 +262,28 @@ export const getNameFromUserData = (
userName = firstClaim ?? ''; userName = firstClaim ?? '';
} }
return { name: userName, email: userName + '@' + domain }; email = userName + '@' + domain;
} else {
const mappingObj: Record<string, string> = {};
jwtPrincipalClaimsMapping.reduce((acc, value) => {
const [key, claim] = value.split(':');
acc[key] = claim;
return acc;
}, mappingObj);
if (mappingObj['username'] && mappingObj['email']) {
userName = get(user, mappingObj['username'], '');
email = get(user, mappingObj['email']);
} else {
// eslint-disable-next-line no-console
console.error(
'username or email is not present in jwtPrincipalClaimsMapping'
);
}
}
return { name: userName, email: email };
}; };
export const isProtectedRoute = (pathname: string) => { export const isProtectedRoute = (pathname: string) => {