fix(frontend): Add JSON list oidc group extraction logic (#9495)

Co-authored-by: Ethan Cartwright <ethan.cartwright@acryl.io>
This commit is contained in:
ethan-cartwright 2023-12-26 09:04:05 -05:00 committed by GitHub
parent d399a53057
commit 1e64a75339
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 17 deletions

View File

@ -10,6 +10,8 @@ import auth.CookieConfigs;
import auth.sso.SsoManager;
import client.AuthServiceClient;
import com.datahub.authentication.Authentication;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linkedin.common.AuditStamp;
import com.linkedin.common.CorpGroupUrnArray;
import com.linkedin.common.CorpuserUrnArray;
@ -300,6 +302,29 @@ public class OidcCallbackLogic extends DefaultCallbackLogic<Result, PlayWebConte
return corpUserSnapshot;
}
public static Collection<String> getGroupNames(CommonProfile profile, Object groupAttribute, String groupsClaimName) {
Collection<String> groupNames = Collections.emptyList();
try {
if (groupAttribute instanceof Collection) {
// List of group names
groupNames = (Collection<String>) profile.getAttribute(groupsClaimName, Collection.class);
} else if (groupAttribute instanceof String) {
String groupString = (String) groupAttribute;
ObjectMapper objectMapper = new ObjectMapper();
try {
// Json list of group names
groupNames = objectMapper.readValue(groupString, new TypeReference<List<String>>(){});
} catch (Exception e) {
groupNames = Arrays.asList(groupString.split(","));
}
}
} catch (Exception e) {
log.error(String.format(
"Failed to parse group names: Expected to find a list of strings for attribute with name %s, found %s",
groupsClaimName, profile.getAttribute(groupsClaimName).getClass()));
}
return groupNames;
}
private List<CorpGroupSnapshot> extractGroups(CommonProfile profile) {
log.debug(
@ -320,23 +345,7 @@ public class OidcCallbackLogic extends DefaultCallbackLogic<Result, PlayWebConte
if (profile.containsAttribute(groupsClaimName)) {
try {
final List<CorpGroupSnapshot> groupSnapshots = new ArrayList<>();
final Collection<String> groupNames;
final Object groupAttribute = profile.getAttribute(groupsClaimName);
if (groupAttribute instanceof Collection) {
// List of group names
groupNames =
(Collection<String>) profile.getAttribute(groupsClaimName, Collection.class);
} else if (groupAttribute instanceof String) {
// Single group name
groupNames = Collections.singleton(profile.getAttribute(groupsClaimName, String.class));
} else {
log.error(
String.format(
"Fail to parse OIDC group claim with name %s. Unknown type %s provided.",
groupsClaimName, groupAttribute.getClass()));
// Skip over group attribute. Do not throw.
groupNames = Collections.emptyList();
}
Collection<String> groupNames = getGroupNames(profile, profile.getAttribute(groupsClaimName), groupsClaimName);
for (String groupName : groupNames) {
// Create a basic CorpGroupSnapshot from the information.

View File

@ -0,0 +1,64 @@
package oidc;
import auth.sso.oidc.OidcConfigs;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import static auth.sso.oidc.OidcCallbackLogic.getGroupNames;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.pac4j.core.profile.CommonProfile;
public class OidcCallbackLogicTest {
@Test
public void testGetGroupsClaimNamesJsonArray() {
CommonProfile profile = createMockProfileWithAttribute("[\"group1\", \"group2\"]", "groupsClaimName");
Collection<String> result = getGroupNames(profile, "[\"group1\", \"group2\"]", "groupsClaimName");
assertEquals(Arrays.asList("group1", "group2"), result);
}
@Test
public void testGetGroupNamesWithSingleGroup() {
CommonProfile profile = createMockProfileWithAttribute("group1", "groupsClaimName");
Collection<String> result = getGroupNames(profile, "group1", "groupsClaimName");
assertEquals(Arrays.asList("group1"), result);
}
@Test
public void testGetGroupNamesWithCommaSeparated() {
CommonProfile profile = createMockProfileWithAttribute("group1,group2", "groupsClaimName");
Collection<String> result = getGroupNames(profile, "group1,group2", "groupsClaimName");
assertEquals(Arrays.asList("group1", "group2"), result);
}
@Test
public void testGetGroupNamesWithCollection() {
CommonProfile profile = createMockProfileWithAttribute(Arrays.asList("group1", "group2"), "groupsClaimName");
Collection<String> result = getGroupNames(profile, Arrays.asList("group1", "group2"), "groupsClaimName");
assertEquals(Arrays.asList("group1", "group2"), result);
}
// Helper method to create a mock CommonProfile with given attribute
private CommonProfile createMockProfileWithAttribute(Object attribute, String attributeName) {
CommonProfile profile = mock(CommonProfile.class);
// Mock for getAttribute(String)
when(profile.getAttribute(attributeName)).thenReturn(attribute);
// Mock for getAttribute(String, Class<T>)
if (attribute instanceof Collection) {
when(profile.getAttribute(attributeName, Collection.class)).thenReturn((Collection) attribute);
} else if (attribute instanceof String) {
when(profile.getAttribute(attributeName, String.class)).thenReturn((String) attribute);
}
// Add more conditions here if needed for other types
return profile;
}
}