fix(token-service): extend validation for actor (#13947)

This commit is contained in:
david-leifker 2025-07-02 16:27:42 -05:00 committed by GitHub
parent 5292a268c9
commit e7463ac1f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 253 additions and 23 deletions

View File

@ -46,7 +46,7 @@ public class StatefulTokenService extends StatelessTokenService {
@Nullable final String iss,
@Nonnull final EntityService<?> entityService,
@Nonnull final String salt) {
super(signingKey, signingAlgorithm, iss);
super(systemOperationContext, signingKey, signingAlgorithm, iss);
this.systemOperationContext = systemOperationContext;
this._entityService = entityService;
this._revokedTokenCache =

View File

@ -2,6 +2,10 @@ package com.datahub.authentication.token;
import com.datahub.authentication.Actor;
import com.datahub.authentication.ActorType;
import com.datahub.authentication.Authentication;
import com.linkedin.metadata.aspect.AspectRetriever;
import io.datahubproject.metadata.context.ActorContext;
import io.datahubproject.metadata.context.OperationContext;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtBuilder;
@ -37,19 +41,24 @@ public class StatelessTokenService {
private final String signingKey;
private final SignatureAlgorithm signingAlgorithm;
private final String iss;
private final OperationContext systemOperationContext;
public StatelessTokenService(
@Nonnull final String signingKey, @Nonnull final String signingAlgorithm) {
this(signingKey, signingAlgorithm, null);
@Nonnull OperationContext systemOperationContext,
@Nonnull final String signingKey,
@Nonnull final String signingAlgorithm) {
this(systemOperationContext, signingKey, signingAlgorithm, null);
}
public StatelessTokenService(
@Nonnull OperationContext systemOperationContext,
@Nonnull final String signingKey,
@Nonnull final String signingAlgorithm,
@Nullable final String iss) {
this.signingKey = Objects.requireNonNull(signingKey);
this.signingAlgorithm = validateAlgorithm(Objects.requireNonNull(signingAlgorithm));
this.iss = iss;
this.systemOperationContext = systemOperationContext;
}
/**
@ -131,6 +140,9 @@ public class StatelessTokenService {
final String actorId = claims.get(TokenClaims.ACTOR_ID_CLAIM_NAME, String.class);
final String actorType = claims.get(TokenClaims.ACTOR_TYPE_CLAIM_NAME, String.class);
if (tokenType != null && actorId != null && actorType != null) {
// Validate the actor is active before returning claims
validateActor(actorId);
return new TokenClaims(
TokenVersion.fromNumericStringValue(tokenVersion),
TokenType.valueOf(tokenType),
@ -140,13 +152,38 @@ public class StatelessTokenService {
}
} catch (io.jsonwebtoken.ExpiredJwtException e) {
throw new TokenExpiredException("Failed to validate DataHub token. Token has expired.", e);
} catch (TokenException e) {
throw e;
} catch (Exception e) {
throw new TokenException("Failed to validate DataHub token", e);
}
throw new TokenException(
"Failed to validate DataHub token: Found malformed or missing 'actor' claim.");
}
/** Validates that the actor is active using the OperationContext's built-in validation */
private void validateActor(@Nonnull final String actorId) throws TokenException {
try {
AspectRetriever aspectRetriever = systemOperationContext.getAspectRetriever();
ActorContext actorContext =
ActorContext.builder()
.authentication(new Authentication(new Actor(ActorType.USER, actorId), ""))
.enforceExistenceEnabled(true)
.build();
// Use the existing isActive check from ActorContext
if (!actorContext.isActive(aspectRetriever)) {
throw new TokenException("Actor is not active");
}
} catch (Exception e) {
if (e instanceof TokenException) {
throw (TokenException) e;
}
throw new TokenException("Failed to validate actor status", e);
}
}
private void validateTokenAlgorithm(final String algorithm) throws TokenException {
try {
validateAlgorithm(algorithm);

View File

@ -6,38 +6,52 @@ import static org.testng.Assert.*;
import com.datahub.authentication.Actor;
import com.datahub.authentication.ActorType;
import com.datahub.authentication.authenticator.DataHubTokenAuthenticator;
import com.linkedin.common.Status;
import com.linkedin.common.urn.Urn;
import com.linkedin.common.urn.UrnUtils;
import com.linkedin.entity.Aspect;
import com.linkedin.identity.CorpUserStatus;
import com.linkedin.metadata.aspect.AspectRetriever;
import com.linkedin.metadata.key.CorpUserKey;
import io.datahubproject.metadata.context.OperationContext;
import io.datahubproject.test.metadata.context.TestOperationContexts;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import javax.crypto.spec.SecretKeySpec;
import org.mockito.Mockito;
import org.testng.annotations.Test;
public class StatelessTokenServiceTest {
private static final String TEST_SIGNING_KEY = "WnEdIeTG/VVCLQqGwC/BAkqyY0k+H8NEAtWGejrBI94=";
private OperationContext opContext = TestOperationContexts.systemContextNoValidate();
@Test
public void testConstructor() {
final DataHubTokenAuthenticator authenticator = new DataHubTokenAuthenticator();
assertThrows(() -> new StatelessTokenService(null, null, null));
assertThrows(() -> new StatelessTokenService(TEST_SIGNING_KEY, null, null));
assertThrows(() -> new StatelessTokenService(TEST_SIGNING_KEY, "UNSUPPORTED_ALG", null));
assertThrows(() -> new StatelessTokenService(opContext, null, null, null));
assertThrows(() -> new StatelessTokenService(opContext, TEST_SIGNING_KEY, null, null));
assertThrows(
() -> new StatelessTokenService(opContext, TEST_SIGNING_KEY, "UNSUPPORTED_ALG", null));
// Succeeds:
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(TEST_SIGNING_KEY, "HS256", null);
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256", null);
}
@Test
public void testGenerateAccessTokenPersonalToken() throws Exception {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "datahub"));
@ -62,7 +76,7 @@ public class StatelessTokenServiceTest {
@Test
public void testGenerateAccessTokenPersonalTokenEternal() throws Exception {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "datahub"), null);
@ -87,7 +101,7 @@ public class StatelessTokenServiceTest {
@Test
public void testGenerateAccessTokenSessionToken() throws Exception {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
String token =
statelessTokenService.generateAccessToken(
TokenType.SESSION, new Actor(ActorType.USER, "datahub"));
@ -112,7 +126,7 @@ public class StatelessTokenServiceTest {
@Test
public void testValidateAccessTokenFailsDueToExpiration() {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
// Generate token that expires immediately.
String token =
statelessTokenService.generateAccessToken(
@ -127,7 +141,7 @@ public class StatelessTokenServiceTest {
@Test
public void testValidateAccessTokenFailsDueToManipulation() {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "datahub"));
@ -149,7 +163,7 @@ public class StatelessTokenServiceTest {
+ "CJ0eXBlIjoiU0VTU0lPTiIsInZlcnNpb24iOiIxIiwianRpIjoiN2VmOTkzYjQtMjBiOC00Y2Y5LTljNm"
+ "YtMTE2NjNjZWVmOTQzIiwic3ViIjoiZGF0YWh1YiIsImlzcyI6ImRhdGFodWItbWV0YWRhdGEtc2VydmljZSJ9.";
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
// Validation should fail.
assertThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(badToken));
}
@ -157,7 +171,7 @@ public class StatelessTokenServiceTest {
@Test
public void testValidateAccessTokenFailsDueToUnsupportedSigningAlgorithm() throws Exception {
StatelessTokenService statelessTokenService =
new StatelessTokenService(TEST_SIGNING_KEY, "HS256");
new StatelessTokenService(opContext, TEST_SIGNING_KEY, "HS256");
Map<String, Object> claims = new HashMap<>();
claims.put(
@ -184,4 +198,186 @@ public class StatelessTokenServiceTest {
// Validation should fail.
assertThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(badToken));
}
@Test
public void testValidateAccessTokenSystemActorAlwaysActive() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a token for system actor
String token =
statelessTokenService.generateAccessToken(
TokenType.SESSION, new Actor(ActorType.USER, "__datahub_system"));
assertNotNull(token);
// System actor should always be active, regardless of aspect retriever response
// No need to mock anything - system actor bypasses all checks
TokenClaims claims = statelessTokenService.validateAccessToken(token);
assertEquals(claims.getActorId(), "__datahub_system");
}
@Test
public void testValidateAccessTokenFailsDueToHardDeletedUser() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a valid token
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "deleteduser"));
assertNotNull(token);
// Mock to return empty aspect map - user has no CorpUserKey aspect (hard deleted)
Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any()))
.thenReturn(Map.of(UrnUtils.getUrn("urn:li:corpuser:deleteduser"), Collections.emptyMap()));
// Validation should fail due to missing CorpUserKey aspect
TokenException exception =
expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token));
assertEquals(exception.getMessage(), "Actor is not active");
}
@Test
public void testValidateAccessTokenFailsDueToRemovedStatus() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a valid token
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "removeduser"));
assertNotNull(token);
// Create a removed status
Status removedStatus = new Status().setRemoved(true);
CorpUserKey corpUserKey = new CorpUserKey().setUsername("removeduser");
// Mock to return removed status
Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:removeduser");
Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any()))
.thenReturn(
Map.of(
userUrn,
Map.of(
"status", new Aspect(removedStatus.data()),
"corpUserKey", new Aspect(corpUserKey.data()))));
// Validation should fail due to removed status
TokenException exception =
expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token));
assertEquals(exception.getMessage(), "Actor is not active");
}
@Test
public void testValidateAccessTokenFailsDueToSuspendedStatus() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a valid token
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "suspendeduser"));
assertNotNull(token);
// Create a suspended corp user status
Status activeStatus = new Status().setRemoved(false);
CorpUserStatus suspendedStatus = new CorpUserStatus().setStatus("SUSPENDED");
CorpUserKey corpUserKey = new CorpUserKey().setUsername("suspendeduser");
// Mock to return suspended status
Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:suspendeduser");
Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any()))
.thenReturn(
Map.of(
userUrn,
Map.of(
"status", new Aspect(activeStatus.data()),
"corpUserStatus", new Aspect(suspendedStatus.data()),
"corpUserKey", new Aspect(corpUserKey.data()))));
// Validation should fail due to suspended status
TokenException exception =
expectThrows(TokenException.class, () -> statelessTokenService.validateAccessToken(token));
assertEquals(exception.getMessage(), "Actor is not active");
}
@Test
public void testValidateAccessTokenSucceedsForActiveUser() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a valid token
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "activeuser"));
assertNotNull(token);
// Create active user aspects
Status activeStatus = new Status().setRemoved(false);
CorpUserStatus activeUserStatus = new CorpUserStatus().setStatus("ACTIVE");
CorpUserKey corpUserKey = new CorpUserKey().setUsername("activeuser");
// Mock to return active user
Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:activeuser");
Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any()))
.thenReturn(
Map.of(
userUrn,
Map.of(
"status", new Aspect(activeStatus.data()),
"corpUserStatus", new Aspect(activeUserStatus.data()),
"corpUserKey", new Aspect(corpUserKey.data()))));
// Validation should succeed
TokenClaims claims = statelessTokenService.validateAccessToken(token);
assertEquals(claims.getActorId(), "activeuser");
}
@Test
public void testValidateAccessTokenSucceedsWithMissingOptionalAspects() throws Exception {
AspectRetriever mockAspectRetriever = Mockito.mock(AspectRetriever.class);
OperationContext mockContext =
TestOperationContexts.systemContextNoSearchAuthorization(mockAspectRetriever);
StatelessTokenService statelessTokenService =
new StatelessTokenService(mockContext, TEST_SIGNING_KEY, "HS256");
// Generate a valid token
String token =
statelessTokenService.generateAccessToken(
TokenType.PERSONAL, new Actor(ActorType.USER, "minimaluser"));
assertNotNull(token);
// Only provide the required corpUserKey aspect - status and corpUserStatus will use defaults
CorpUserKey corpUserKey = new CorpUserKey().setUsername("minimaluser");
// Mock to return only corpUserKey
Urn userUrn = UrnUtils.getUrn("urn:li:corpuser:minimaluser");
Mockito.when(mockAspectRetriever.getLatestAspectObjects(Mockito.any(), Mockito.any()))
.thenReturn(Map.of(userUrn, Map.of("corpUserKey", new Aspect(corpUserKey.data()))));
// Validation should succeed - missing aspects use defaults (not removed, not suspended)
TokenClaims claims = statelessTokenService.validateAccessToken(token);
assertEquals(claims.getActorId(), "minimaluser");
}
}

View File

@ -100,8 +100,7 @@ def custom_user_session():
assert {"username": "sessionUser"} not in res_data["data"]["listUsers"]["users"]
@pytest.mark.dependency()
def test_soft_delete(graph_client, custom_user_session):
def test_01_soft_delete(graph_client, custom_user_session):
# assert initial access
assert getUserId(custom_user_session) == {"urn": user_urn}
@ -110,15 +109,14 @@ def test_soft_delete(graph_client, custom_user_session):
with pytest.raises(HTTPError) as req_info:
getUserId(custom_user_session)
assert "403 Client Error: Forbidden" in str(req_info.value)
assert "401 Client Error: Unauthorized" in str(req_info.value)
# undo soft delete
graph_client.set_soft_delete_status(urn=user_urn, delete=False)
wait_for_writes_to_sync()
@pytest.mark.dependency(depends=["test_soft_delete"])
def test_suspend(graph_client, custom_user_session):
def test_02_suspend(graph_client, custom_user_session):
# assert initial access
assert getUserId(custom_user_session) == {"urn": user_urn}
@ -140,7 +138,7 @@ def test_suspend(graph_client, custom_user_session):
with pytest.raises(HTTPError) as req_info:
getUserId(custom_user_session)
assert "403 Client Error: Forbidden" in str(req_info.value)
assert "401 Client Error: Unauthorized" in str(req_info.value)
# undo suspend
graph_client.emit(
@ -160,8 +158,7 @@ def test_suspend(graph_client, custom_user_session):
wait_for_writes_to_sync()
@pytest.mark.dependency(depends=["test_suspend"])
def test_hard_delete(graph_client, custom_user_session):
def test_03_hard_delete(graph_client, custom_user_session):
# assert initial access
assert getUserId(custom_user_session) == {"urn": user_urn}
@ -170,4 +167,4 @@ def test_hard_delete(graph_client, custom_user_session):
with pytest.raises(HTTPError) as req_info:
getUserId(custom_user_session)
assert "403 Client Error: Forbidden" in str(req_info.value)
assert "401 Client Error: Unauthorized" in str(req_info.value)