From e7463ac1f4449afb4fe6f382e32f48ce9b9a7483 Mon Sep 17 00:00:00 2001 From: david-leifker <114954101+david-leifker@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:27:42 -0500 Subject: [PATCH] fix(token-service): extend validation for actor (#13947) --- .../token/StatefulTokenService.java | 2 +- .../token/StatelessTokenService.java | 41 +++- .../token/StatelessTokenServiceTest.java | 218 +++++++++++++++++- .../tests/tokens/session_access_token_test.py | 15 +- 4 files changed, 253 insertions(+), 23 deletions(-) diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatefulTokenService.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatefulTokenService.java index 01a35a72e3..892633e0e5 100644 --- a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatefulTokenService.java +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatefulTokenService.java @@ -46,7 +46,7 @@ public class StatefulTokenService extends StatelessTokenService { @Nullable final String iss, @Nonnull final EntityService entityService, @Nonnull final String salt) { - super(signingKey, signingAlgorithm, iss); + super(systemOperationContext, signingKey, signingAlgorithm, iss); this.systemOperationContext = systemOperationContext; this._entityService = entityService; this._revokedTokenCache = diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatelessTokenService.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatelessTokenService.java index 71f12477a3..04a694793c 100644 --- a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatelessTokenService.java +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/StatelessTokenService.java @@ -2,6 +2,10 @@ package com.datahub.authentication.token; import com.datahub.authentication.Actor; import com.datahub.authentication.ActorType; +import com.datahub.authentication.Authentication; +import com.linkedin.metadata.aspect.AspectRetriever; +import io.datahubproject.metadata.context.ActorContext; +import io.datahubproject.metadata.context.OperationContext; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jws; import io.jsonwebtoken.JwtBuilder; @@ -37,19 +41,24 @@ public class StatelessTokenService { private final String signingKey; private final SignatureAlgorithm signingAlgorithm; private final String iss; + private final OperationContext systemOperationContext; public StatelessTokenService( - @Nonnull final String signingKey, @Nonnull final String signingAlgorithm) { - this(signingKey, signingAlgorithm, null); + @Nonnull OperationContext systemOperationContext, + @Nonnull final String signingKey, + @Nonnull final String signingAlgorithm) { + this(systemOperationContext, signingKey, signingAlgorithm, null); } public StatelessTokenService( + @Nonnull OperationContext systemOperationContext, @Nonnull final String signingKey, @Nonnull final String signingAlgorithm, @Nullable final String iss) { this.signingKey = Objects.requireNonNull(signingKey); this.signingAlgorithm = validateAlgorithm(Objects.requireNonNull(signingAlgorithm)); this.iss = iss; + this.systemOperationContext = systemOperationContext; } /** @@ -131,6 +140,9 @@ public class StatelessTokenService { final String actorId = claims.get(TokenClaims.ACTOR_ID_CLAIM_NAME, String.class); final String actorType = claims.get(TokenClaims.ACTOR_TYPE_CLAIM_NAME, String.class); if (tokenType != null && actorId != null && actorType != null) { + // Validate the actor is active before returning claims + validateActor(actorId); + return new TokenClaims( TokenVersion.fromNumericStringValue(tokenVersion), TokenType.valueOf(tokenType), @@ -140,13 +152,38 @@ public class StatelessTokenService { } } catch (io.jsonwebtoken.ExpiredJwtException e) { throw new TokenExpiredException("Failed to validate DataHub token. Token has expired.", e); + } catch (TokenException e) { + throw e; } catch (Exception e) { throw new TokenException("Failed to validate DataHub token", e); } + throw new TokenException( "Failed to validate DataHub token: Found malformed or missing 'actor' claim."); } + /** Validates that the actor is active using the OperationContext's built-in validation */ + private void validateActor(@Nonnull final String actorId) throws TokenException { + try { + AspectRetriever aspectRetriever = systemOperationContext.getAspectRetriever(); + ActorContext actorContext = + ActorContext.builder() + .authentication(new Authentication(new Actor(ActorType.USER, actorId), "")) + .enforceExistenceEnabled(true) + .build(); + + // Use the existing isActive check from ActorContext + if (!actorContext.isActive(aspectRetriever)) { + throw new TokenException("Actor is not active"); + } + } catch (Exception e) { + if (e instanceof TokenException) { + throw (TokenException) e; + } + throw new TokenException("Failed to validate actor status", e); + } + } + private void validateTokenAlgorithm(final String algorithm) throws TokenException { try { validateAlgorithm(algorithm); diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/StatelessTokenServiceTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/StatelessTokenServiceTest.java index 8413084415..e4c5a5fb6c 100644 --- a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/StatelessTokenServiceTest.java +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/StatelessTokenServiceTest.java @@ -6,38 +6,52 @@ import static org.testng.Assert.*; import com.datahub.authentication.Actor; import com.datahub.authentication.ActorType; import com.datahub.authentication.authenticator.DataHubTokenAuthenticator; +import com.linkedin.common.Status; +import com.linkedin.common.urn.Urn; +import com.linkedin.common.urn.UrnUtils; +import com.linkedin.entity.Aspect; +import com.linkedin.identity.CorpUserStatus; +import com.linkedin.metadata.aspect.AspectRetriever; +import com.linkedin.metadata.key.CorpUserKey; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.test.metadata.context.TestOperationContexts; import io.jsonwebtoken.JwtBuilder; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; import java.nio.charset.StandardCharsets; import java.security.Key; +import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.Map; import java.util.UUID; import javax.crypto.spec.SecretKeySpec; +import org.mockito.Mockito; import org.testng.annotations.Test; public class StatelessTokenServiceTest { private static final String TEST_SIGNING_KEY = "WnEdIeTG/VVCLQqGwC/BAkqyY0k+H8NEAtWGejrBI94="; + private OperationContext opContext = TestOperationContexts.systemContextNoValidate(); @Test public void testConstructor() { final DataHubTokenAuthenticator authenticator = new DataHubTokenAuthenticator(); assertThrows(() -> new StatelessTokenService(null, null, null)); - assertThrows(() -> new StatelessTokenService(TEST_SIGNING_KEY, null, null)); - assertThrows(() -> new StatelessTokenService(TEST_SIGNING_KEY, "UNSUPPORTED_ALG", null)); + assertThrows(() -> new StatelessTokenService(opContext, null, null, null)); + assertThrows(() -> new StatelessTokenService(opContext, TEST_SIGNING_KEY, null, null)); + assertThrows( + () -> new StatelessTokenService(opContext, TEST_SIGNING_KEY, "UNSUPPORTED_ALG", null)); // Succeeds: - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); - new StatelessTokenService(TEST_SIGNING_KEY, "HS256", null); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256", null); } @Test public void testGenerateAccessTokenPersonalToken() throws Exception { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); String token = statelessTokenService.generateAccessToken( TokenType.PERSONAL, new Actor(ActorType.USER, "datahub")); @@ -62,7 +76,7 @@ public class StatelessTokenServiceTest { @Test public void testGenerateAccessTokenPersonalTokenEternal() throws Exception { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); String token = statelessTokenService.generateAccessToken( TokenType.PERSONAL, new Actor(ActorType.USER, "datahub"), null); @@ -87,7 +101,7 @@ public class StatelessTokenServiceTest { @Test public void testGenerateAccessTokenSessionToken() throws Exception { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); String token = statelessTokenService.generateAccessToken( TokenType.SESSION, new Actor(ActorType.USER, "datahub")); @@ -112,7 +126,7 @@ public class StatelessTokenServiceTest { @Test public void testValidateAccessTokenFailsDueToExpiration() { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); // Generate token that expires immediately. String token = statelessTokenService.generateAccessToken( @@ -127,7 +141,7 @@ public class StatelessTokenServiceTest { @Test public void testValidateAccessTokenFailsDueToManipulation() { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); String token = statelessTokenService.generateAccessToken( TokenType.PERSONAL, new Actor(ActorType.USER, "datahub")); @@ -149,7 +163,7 @@ public class StatelessTokenServiceTest { + "CJ0eXBlIjoiU0VTU0lPTiIsInZlcnNpb24iOiIxIiwianRpIjoiN2VmOTkzYjQtMjBiOC00Y2Y5LTljNm" + "YtMTE2NjNjZWVmOTQzIiwic3ViIjoiZGF0YWh1YiIsImlzcyI6ImRhdGFodWItbWV0YWRhdGEtc2VydmljZSJ9."; StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); // Validation should fail. assertThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(badToken)); } @@ -157,7 +171,7 @@ public class StatelessTokenServiceTest { @Test public void testValidateAccessTokenFailsDueToUnsupportedSigningAlgorithm() throws Exception { StatelessTokenService statelessTokenService = - new StatelessTokenService(TEST_SIGNING_KEY, "HS256"); + new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256"); Map claims = new HashMap<>(); claims.put( @@ -184,4 +198,186 @@ public class StatelessTokenServiceTest { // Validation should fail. assertThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(badToken)); } + + @Test + public void testValidateAccessTokenSystemActorAlwaysActive() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a token for system actor + String token = + statelessTokenService.generateAccessToken( + TokenType.SESSION, new Actor(ActorType.USER, "__datahub_system")); + assertNotNull(token); + + // System actor should always be active, regardless of aspect retriever response + // No need to mock anything - system actor bypasses all checks + TokenClaims claims = statelessTokenService.validateAccessToken(token); + assertEquals(claims.getActorId(), "__datahub_system"); + } + + @Test + public void testValidateAccessTokenFailsDueToHardDeletedUser() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a valid token + String token = + statelessTokenService.generateAccessToken( + TokenType.PERSONAL, new Actor(ActorType.USER, "deleteduser")); + assertNotNull(token); + + // Mock to return empty aspect map - user has no CorpUserKey aspect (hard deleted) + Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any())) + .thenReturn(Map.of(UrnUtils.getUrn("urn:li:corpuser:deleteduser"), Collections.emptyMap())); + + // Validation should fail due to missing CorpUserKey aspect + TokenException exception = + expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token)); + assertEquals(exception.getMessage(), "Actor is not active"); + } + + @Test + public void testValidateAccessTokenFailsDueToRemovedStatus() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a valid token + String token = + statelessTokenService.generateAccessToken( + TokenType.PERSONAL, new Actor(ActorType.USER, "removeduser")); + assertNotNull(token); + + // Create a removed status + Status removedStatus = new Status().setRemoved(true); + CorpUserKey corpUserKey = new CorpUserKey().setUsername("removeduser"); + + // Mock to return removed status + Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:removeduser"); + Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any())) + .thenReturn( + Map.of( + userUrn, + Map.of( + "status", new Aspect(removedStatus.data()), + "corpUserKey", new Aspect(corpUserKey.data())))); + + // Validation should fail due to removed status + TokenException exception = + expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token)); + assertEquals(exception.getMessage(), "Actor is not active"); + } + + @Test + public void testValidateAccessTokenFailsDueToSuspendedStatus() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a valid token + String token = + statelessTokenService.generateAccessToken( + TokenType.PERSONAL, new Actor(ActorType.USER, "suspendeduser")); + assertNotNull(token); + + // Create a suspended corp user status + Status activeStatus = new Status().setRemoved(false); + CorpUserStatus suspendedStatus = new CorpUserStatus().setStatus("SUSPENDED"); + CorpUserKey corpUserKey = new CorpUserKey().setUsername("suspendeduser"); + + // Mock to return suspended status + Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:suspendeduser"); + Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any())) + .thenReturn( + Map.of( + userUrn, + Map.of( + "status", new Aspect(activeStatus.data()), + "corpUserStatus", new Aspect(suspendedStatus.data()), + "corpUserKey", new Aspect(corpUserKey.data())))); + + // Validation should fail due to suspended status + TokenException exception = + expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token)); + assertEquals(exception.getMessage(), "Actor is not active"); + } + + @Test + public void testValidateAccessTokenSucceedsForActiveUser() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a valid token + String token = + statelessTokenService.generateAccessToken( + TokenType.PERSONAL, new Actor(ActorType.USER, "activeuser")); + assertNotNull(token); + + // Create active user aspects + Status activeStatus = new Status().setRemoved(false); + CorpUserStatus activeUserStatus = new CorpUserStatus().setStatus("ACTIVE"); + CorpUserKey corpUserKey = new CorpUserKey().setUsername("activeuser"); + + // Mock to return active user + Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:activeuser"); + Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any())) + .thenReturn( + Map.of( + userUrn, + Map.of( + "status", new Aspect(activeStatus.data()), + "corpUserStatus", new Aspect(activeUserStatus.data()), + "corpUserKey", new Aspect(corpUserKey.data())))); + + // Validation should succeed + TokenClaims claims = statelessTokenService.validateAccessToken(token); + assertEquals(claims.getActorId(), "activeuser"); + } + + @Test + public void testValidateAccessTokenSucceedsWithMissingOptionalAspects() throws Exception { + AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class); + OperationContext mockContext = + TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever); + + StatelessTokenService statelessTokenService = + new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256"); + + // Generate a valid token + String token = + statelessTokenService.generateAccessToken( + TokenType.PERSONAL, new Actor(ActorType.USER, "minimaluser")); + assertNotNull(token); + + // Only provide the required corpUserKey aspect - status and corpUserStatus will use defaults + CorpUserKey corpUserKey = new CorpUserKey().setUsername("minimaluser"); + + // Mock to return only corpUserKey + Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:minimaluser"); + Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any())) + .thenReturn(Map.of(userUrn, Map.of("corpUserKey", new Aspect(corpUserKey.data())))); + + // Validation should succeed - missing aspects use defaults (not removed, not suspended) + TokenClaims claims = statelessTokenService.validateAccessToken(token); + assertEquals(claims.getActorId(), "minimaluser"); + } } diff --git a/smoke-test/tests/tokens/session_access_token_test.py b/smoke-test/tests/tokens/session_access_token_test.py index 5328fd4237..a39c1fbc9b 100644 --- a/smoke-test/tests/tokens/session_access_token_test.py +++ b/smoke-test/tests/tokens/session_access_token_test.py @@ -100,8 +100,7 @@ def custom_user_session(): assert {"username": "sessionUser"} not in res_data["data"]["listUsers"]["users"] -@pytest.mark.dependency() -def test_soft_delete(graph_client, custom_user_session): +def test_01_soft_delete(graph_client, custom_user_session): # assert initial access assert getUserId(custom_user_session) == {"urn": user_urn} @@ -110,15 +109,14 @@ def test_soft_delete(graph_client, custom_user_session): with pytest.raises(HTTPError) as req_info: getUserId(custom_user_session) - assert "403 Client Error: Forbidden" in str(req_info.value) + assert "401 Client Error: Unauthorized" in str(req_info.value) # undo soft delete graph_client.set_soft_delete_status(urn=user_urn, delete=False) wait_for_writes_to_sync() -@pytest.mark.dependency(depends=["test_soft_delete"]) -def test_suspend(graph_client, custom_user_session): +def test_02_suspend(graph_client, custom_user_session): # assert initial access assert getUserId(custom_user_session) == {"urn": user_urn} @@ -140,7 +138,7 @@ def test_suspend(graph_client, custom_user_session): with pytest.raises(HTTPError) as req_info: getUserId(custom_user_session) - assert "403 Client Error: Forbidden" in str(req_info.value) + assert "401 Client Error: Unauthorized" in str(req_info.value) # undo suspend graph_client.emit( @@ -160,8 +158,7 @@ def test_suspend(graph_client, custom_user_session): wait_for_writes_to_sync() -@pytest.mark.dependency(depends=["test_suspend"]) -def test_hard_delete(graph_client, custom_user_session): +def test_03_hard_delete(graph_client, custom_user_session): # assert initial access assert getUserId(custom_user_session) == {"urn": user_urn} @@ -170,4 +167,4 @@ def test_hard_delete(graph_client, custom_user_session): with pytest.raises(HTTPError) as req_info: getUserId(custom_user_session) - assert "403 Client Error: Forbidden" in str(req_info.value) + assert "401 Client Error: Unauthorized" in str(req_info.value)