From 1e64a75339cd1e4e99ef0ab4b926057a2cceb511 Mon Sep 17 00:00:00 2001 From: ethan-cartwright Date: Tue, 26 Dec 2023 09:04:05 -0500 Subject: [PATCH] fix(frontend): Add JSON list oidc group extraction logic (#9495) Co-authored-by: Ethan Cartwright --- .../app/auth/sso/oidc/OidcCallbackLogic.java | 43 ++++++++----- .../test/oidc/OidcCallbackLogicTest.java | 64 +++++++++++++++++++ 2 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 datahub-frontend/test/oidc/OidcCallbackLogicTest.java diff --git a/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java b/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java index fa562f5431..c72c353708 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java @@ -10,6 +10,8 @@ import auth.CookieConfigs; import auth.sso.SsoManager; import client.AuthServiceClient; import com.datahub.authentication.Authentication; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.linkedin.common.AuditStamp; import com.linkedin.common.CorpGroupUrnArray; import com.linkedin.common.CorpuserUrnArray; @@ -300,6 +302,29 @@ public class OidcCallbackLogic extends DefaultCallbackLogic getGroupNames(CommonProfile profile, Object groupAttribute, String groupsClaimName) { + Collection groupNames = Collections.emptyList(); + try { + if (groupAttribute instanceof Collection) { + // List of group names + groupNames = (Collection) profile.getAttribute(groupsClaimName, Collection.class); + } else if (groupAttribute instanceof String) { + String groupString = (String) groupAttribute; + ObjectMapper objectMapper = new ObjectMapper(); + try { + // Json list of group names + groupNames = objectMapper.readValue(groupString, new TypeReference>(){}); + } catch (Exception e) { + groupNames = Arrays.asList(groupString.split(",")); + } + } + } catch (Exception e) { + log.error(String.format( + "Failed to parse group names: Expected to find a list of strings for attribute with name %s, found %s", + groupsClaimName, profile.getAttribute(groupsClaimName).getClass())); + } + return groupNames; + } private List extractGroups(CommonProfile profile) { log.debug( @@ -320,23 +345,7 @@ public class OidcCallbackLogic extends DefaultCallbackLogic groupSnapshots = new ArrayList<>(); - final Collection groupNames; - final Object groupAttribute = profile.getAttribute(groupsClaimName); - if (groupAttribute instanceof Collection) { - // List of group names - groupNames = - (Collection) profile.getAttribute(groupsClaimName, Collection.class); - } else if (groupAttribute instanceof String) { - // Single group name - groupNames = Collections.singleton(profile.getAttribute(groupsClaimName, String.class)); - } else { - log.error( - String.format( - "Fail to parse OIDC group claim with name %s. Unknown type %s provided.", - groupsClaimName, groupAttribute.getClass())); - // Skip over group attribute. Do not throw. - groupNames = Collections.emptyList(); - } + Collection groupNames = getGroupNames(profile, profile.getAttribute(groupsClaimName), groupsClaimName); for (String groupName : groupNames) { // Create a basic CorpGroupSnapshot from the information. diff --git a/datahub-frontend/test/oidc/OidcCallbackLogicTest.java b/datahub-frontend/test/oidc/OidcCallbackLogicTest.java new file mode 100644 index 0000000000..f4784c29e9 --- /dev/null +++ b/datahub-frontend/test/oidc/OidcCallbackLogicTest.java @@ -0,0 +1,64 @@ +package oidc; + +import auth.sso.oidc.OidcConfigs; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static auth.sso.oidc.OidcCallbackLogic.getGroupNames; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.pac4j.core.profile.CommonProfile; + +public class OidcCallbackLogicTest { + + @Test + public void testGetGroupsClaimNamesJsonArray() { + CommonProfile profile = createMockProfileWithAttribute("[\"group1\", \"group2\"]", "groupsClaimName"); + Collection result = getGroupNames(profile, "[\"group1\", \"group2\"]", "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } + @Test + public void testGetGroupNamesWithSingleGroup() { + CommonProfile profile = createMockProfileWithAttribute("group1", "groupsClaimName"); + Collection result = getGroupNames(profile, "group1", "groupsClaimName"); + assertEquals(Arrays.asList("group1"), result); + } + + @Test + public void testGetGroupNamesWithCommaSeparated() { + CommonProfile profile = createMockProfileWithAttribute("group1,group2", "groupsClaimName"); + Collection result = getGroupNames(profile, "group1,group2", "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } + + @Test + public void testGetGroupNamesWithCollection() { + CommonProfile profile = createMockProfileWithAttribute(Arrays.asList("group1", "group2"), "groupsClaimName"); + Collection result = getGroupNames(profile, Arrays.asList("group1", "group2"), "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } + // Helper method to create a mock CommonProfile with given attribute + private CommonProfile createMockProfileWithAttribute(Object attribute, String attributeName) { + CommonProfile profile = mock(CommonProfile.class); + + // Mock for getAttribute(String) + when(profile.getAttribute(attributeName)).thenReturn(attribute); + + // Mock for getAttribute(String, Class) + if (attribute instanceof Collection) { + when(profile.getAttribute(attributeName, Collection.class)).thenReturn((Collection) attribute); + } else if (attribute instanceof String) { + when(profile.getAttribute(attributeName, String.class)).thenReturn((String) attribute); + } + // Add more conditions here if needed for other types + + return profile; + } +}