fix(oidc settings): effective JWS algorithm setting (#9712)

This commit is contained in:
Davi Arnaut 2024-01-24 17:36:30 -08:00 committed by GitHub
parent 9d8e2b9067
commit 23277f8dc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 179 additions and 6 deletions

View File

@ -76,6 +76,9 @@ public class AuthUtils {
public static final String USE_NONCE = "useNonce";
public static final String READ_TIMEOUT = "readTimeout";
public static final String EXTRACT_JWT_ACCESS_TOKEN_CLAIMS = "extractJwtAccessTokenClaims";
// Retained for backwards compatibility
public static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
public static final String PREFERRED_JWS_ALGORITHM_2 = "preferredJwsAlgorithm2";
/**
* Determines whether the inbound request should be forward to downstream Metadata Service. Today,

View File

@ -226,8 +226,8 @@ public class OidcConfigs extends SsoConfigs {
extractJwtAccessTokenClaims =
Optional.of(jsonNode.get(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS).asBoolean());
}
if (jsonNode.has(OIDC_PREFERRED_JWS_ALGORITHM)) {
preferredJwsAlgorithm = Optional.of(jsonNode.get(OIDC_PREFERRED_JWS_ALGORITHM).asText());
if (jsonNode.has(PREFERRED_JWS_ALGORITHM_2)) {
preferredJwsAlgorithm = Optional.of(jsonNode.get(PREFERRED_JWS_ALGORITHM_2).asText());
} else {
preferredJwsAlgorithm =
Optional.ofNullable(getOptional(configs, OIDC_PREFERRED_JWS_ALGORITHM, null));

View File

@ -101,6 +101,9 @@ play {
test {
useJUnitPlatform()
testLogging.showStandardStreams = true
testLogging.exceptionFormat = 'full'
def playJava17CompatibleJvmArgs = [
"--add-opens=java.base/java.lang=ALL-UNNAMED",
//"--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",

View File

@ -1,5 +1,6 @@
package security;
import static auth.AuthUtils.*;
import static auth.sso.oidc.OidcConfigs.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -24,6 +25,7 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.pac4j.oidc.client.OidcClient;
import org.json.JSONObject;
public class OidcConfigurationTest {
@ -317,4 +319,26 @@ public class OidcConfigurationTest {
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals(10000, ((OidcClient) oidcProvider.client()).getConfiguration().getReadTimeout());
}
@Test
public void readPreferredJwsAlgorithmPropagationFromConfig() {
final String SSO_SETTINGS_JSON_STR = new JSONObject().put(PREFERRED_JWS_ALGORITHM, "HS256").toString();
CONFIG.withValue(OIDC_PREFERRED_JWS_ALGORITHM, ConfigValueFactory.fromAnyRef("RS256"));
OidcConfigs.Builder oidcConfigsBuilder = new OidcConfigs.Builder();
oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR);
OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder);
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals("RS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString());
}
@Test
public void readPreferredJwsAlgorithmPropagationFromJSON() {
final String SSO_SETTINGS_JSON_STR = new JSONObject().put(PREFERRED_JWS_ALGORITHM, "Unused").put(PREFERRED_JWS_ALGORITHM_2, "HS256").toString();
CONFIG.withValue(OIDC_PREFERRED_JWS_ALGORITHM, ConfigValueFactory.fromAnyRef("RS256"));
OidcConfigs.Builder oidcConfigsBuilder = new OidcConfigs.Builder();
oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR);
OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder);
OidcProvider oidcProvider = new OidcProvider(oidcConfigs);
assertEquals("HS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString());
}
}

View File

@ -90,7 +90,12 @@ record OidcSettings {
extractJwtAccessTokenClaims: optional boolean
/**
* ADVANCED. Which jws algorithm to use.
* ADVANCED. Which jws algorithm to use. Unused.
*/
preferredJwsAlgorithm: optional string
}
/**
* ADVANCED. Which jws algorithm to use.
*/
preferredJwsAlgorithm2: optional string
}

View File

@ -18,4 +18,12 @@ dependencies {
compileOnly externalDependency.lombok
annotationProcessor externalDependency.lombok
testImplementation externalDependency.testng
testImplementation externalDependency.springBootTest
}
test {
testLogging.showStandardStreams = true
testLogging.exceptionFormat = 'full'
}

View File

@ -72,7 +72,9 @@ public class AuthServiceController {
private static final String USE_NONCE = "useNonce";
private static final String READ_TIMEOUT = "readTimeout";
private static final String EXTRACT_JWT_ACCESS_TOKEN_CLAIMS = "extractJwtAccessTokenClaims";
// Retained for backwards compatibility
private static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
private static final String PREFERRED_JWS_ALGORITHM_2 = "preferredJwsAlgorithm2";
@Inject StatelessTokenService _statelessTokenService;
@ -514,8 +516,8 @@ public class AuthServiceController {
if (oidcSettings.hasExtractJwtAccessTokenClaims()) {
json.put(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS, oidcSettings.isExtractJwtAccessTokenClaims());
}
if (oidcSettings.hasPreferredJwsAlgorithm()) {
json.put(PREFERRED_JWS_ALGORITHM, oidcSettings.getPreferredJwsAlgorithm());
if (oidcSettings.hasPreferredJwsAlgorithm2()) {
json.put(PREFERRED_JWS_ALGORITHM, oidcSettings.getPreferredJwsAlgorithm2());
}
}
}

View File

@ -0,0 +1,96 @@
package com.datahub.auth.authentication;
import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_INFO_ASPECT_NAME;
import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_URN;
import static org.mockito.Mockito.when;
import static org.testng.Assert.*;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.settings.global.GlobalSettingsInfo;
import com.linkedin.settings.global.OidcSettings;
import com.linkedin.settings.global.SsoSettings;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.testng.AbstractTestNGSpringContextTests;
import org.springframework.web.servlet.DispatcherServlet;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
@SpringBootTest(classes = {DispatcherServlet.class})
@ComponentScan(basePackages = {"com.datahub.auth.authentication"})
@Import({AuthServiceTestConfiguration.class})
public class AuthServiceControllerTest extends AbstractTestNGSpringContextTests {
@BeforeTest
public void disableAssert() {
PathSpecBasedSchemaAnnotationVisitor.class
.getClassLoader()
.setClassAssertionStatus(PathSpecBasedSchemaAnnotationVisitor.class.getName(), false);
}
@Autowired private AuthServiceController authServiceController;
@Autowired private EntityService mockEntityService;
private final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
@Test
public void initTest() {
assertNotNull(authServiceController);
assertNotNull(mockEntityService);
}
@Test
public void oldPreferredJwsAlgorithmIsNotReturned() throws IOException {
OidcSettings mockOidcSettings =
new OidcSettings()
.setEnabled(true)
.setClientId("1")
.setClientSecret("2")
.setDiscoveryUri("http://localhost")
.setPreferredJwsAlgorithm("test");
SsoSettings mockSsoSettings =
new SsoSettings().setBaseUrl("http://localhost").setOidcSettings(mockOidcSettings);
GlobalSettingsInfo mockGlobalSettingsInfo = new GlobalSettingsInfo().setSso(mockSsoSettings);
when(mockEntityService.getLatestAspect(GLOBAL_SETTINGS_URN, GLOBAL_SETTINGS_INFO_ASPECT_NAME))
.thenReturn(mockGlobalSettingsInfo);
ResponseEntity<String> httpResponse = authServiceController.getSsoSettings(null).join();
assertEquals(httpResponse.getStatusCode(), HttpStatus.OK);
JsonNode jsonNode = new ObjectMapper().readTree(httpResponse.getBody());
assertFalse(jsonNode.has(PREFERRED_JWS_ALGORITHM));
}
@Test
public void newPreferredJwsAlgorithmIsReturned() throws IOException {
OidcSettings mockOidcSettings =
new OidcSettings()
.setEnabled(true)
.setClientId("1")
.setClientSecret("2")
.setDiscoveryUri("http://localhost")
.setPreferredJwsAlgorithm("jws1")
.setPreferredJwsAlgorithm2("jws2");
SsoSettings mockSsoSettings =
new SsoSettings().setBaseUrl("http://localhost").setOidcSettings(mockOidcSettings);
GlobalSettingsInfo mockGlobalSettingsInfo = new GlobalSettingsInfo().setSso(mockSsoSettings);
when(mockEntityService.getLatestAspect(GLOBAL_SETTINGS_URN, GLOBAL_SETTINGS_INFO_ASPECT_NAME))
.thenReturn(mockGlobalSettingsInfo);
ResponseEntity<String> httpResponse = authServiceController.getSsoSettings(null).join();
assertEquals(httpResponse.getStatusCode(), HttpStatus.OK);
JsonNode jsonNode = new ObjectMapper().readTree(httpResponse.getBody());
assertTrue(jsonNode.has(PREFERRED_JWS_ALGORITHM));
assertEquals(jsonNode.get(PREFERRED_JWS_ALGORITHM).asText(), "jws2");
}
}

View File

@ -0,0 +1,32 @@
package com.datahub.auth.authentication;
import com.datahub.authentication.Authentication;
import com.datahub.authentication.invite.InviteTokenService;
import com.datahub.authentication.token.StatelessTokenService;
import com.datahub.authentication.user.NativeUserService;
import com.datahub.telemetry.TrackingService;
import com.linkedin.gms.factory.config.ConfigurationProvider;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.secret.SecretService;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.MockBean;
@TestConfiguration
public class AuthServiceTestConfiguration {
@MockBean StatelessTokenService _statelessTokenService;
@MockBean Authentication _systemAuthentication;
@MockBean(name = "configurationProvider")
ConfigurationProvider _configProvider;
@MockBean NativeUserService _nativeUserService;
@MockBean EntityService _entityService;
@MockBean SecretService _secretService;
@MockBean InviteTokenService _inviteTokenService;
@MockBean TrackingService _trackingService;
}