fix(graphql): BatchGetEntitiesResolver respects order (#9557)

This commit is contained in:
Shirshanka Das 2024-01-03 16:51:35 -08:00 committed by GitHub
parent 4240578627
commit a8faa172c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 163 additions and 11 deletions

View File

@ -3,16 +3,20 @@ package com.linkedin.datahub.graphql.resolvers.load;
import com.linkedin.datahub.graphql.generated.Entity;
import com.linkedin.datahub.graphql.generated.EntityType;
import com.linkedin.datahub.graphql.resolvers.BatchLoadUtils;
import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<List<Entity>>> {
private final List<com.linkedin.datahub.graphql.types.EntityType<?, ?>> _entityTypes;
@ -30,13 +34,21 @@ public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<L
final List<Entity> entities = _entitiesProvider.apply(environment);
Map<EntityType, List<Entity>> entityTypeToEntities = new HashMap<>();
entities.forEach(
(entity) -> {
EntityType type = entity.getType();
List<Entity> entitiesList = entityTypeToEntities.getOrDefault(type, new ArrayList<>());
entitiesList.add(entity);
entityTypeToEntities.put(type, entitiesList);
});
Map<String, List<Integer>> entityIndexMap = new HashMap<>();
int index = 0;
for (Entity entity : entities) {
List<Integer> indexList = new ArrayList<>();
if (entityIndexMap.containsKey(entity.getUrn())) {
indexList = entityIndexMap.get(entity.getUrn());
}
indexList.add(index);
entityIndexMap.put(entity.getUrn(), indexList);
index++;
EntityType type = entity.getType();
List<Entity> entitiesList = entityTypeToEntities.getOrDefault(type, new ArrayList<>());
entitiesList.add(entity);
entityTypeToEntities.put(type, entitiesList);
}
List<CompletableFuture<List<Entity>>> entitiesFutures = new ArrayList<>();
@ -49,9 +61,32 @@ public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<L
return CompletableFuture.allOf(entitiesFutures.toArray(new CompletableFuture[0]))
.thenApply(
v ->
entitiesFutures.stream()
.flatMap(future -> future.join().stream())
.collect(Collectors.toList()));
v -> {
Entity[] finalEntityList = new Entity[entities.size()];
// Returned objects can be either of type Entity or wrapped as
// DataFetcherResult<Entity>
// Therefore we need to be working with raw Objects in this area of the code
List<Object> returnedList =
entitiesFutures.stream()
.flatMap(future -> future.join().stream())
.collect(Collectors.toList());
for (Object element : returnedList) {
Entity entity = null;
if (element instanceof DataFetcherResult) {
entity = ((DataFetcherResult<Entity>) element).getData();
} else if (element instanceof Entity) {
entity = (Entity) element;
} else {
throw new RuntimeException(
String.format(
"Cannot process entity because it is neither an Entity not a DataFetcherResult. %s",
element));
}
for (int idx : entityIndexMap.get(entity.getUrn())) {
finalEntityList[idx] = entity;
}
}
return Arrays.asList(finalEntityList);
});
}
}

View File

@ -0,0 +1,117 @@
package com.linkedin.datahub.graphql.resolvers.load;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.testng.Assert.*;
import com.google.common.collect.ImmutableList;
import com.linkedin.datahub.graphql.generated.Dashboard;
import com.linkedin.datahub.graphql.generated.Dataset;
import com.linkedin.datahub.graphql.generated.Entity;
import com.linkedin.datahub.graphql.types.dataset.DatasetType;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.metadata.entity.EntityService;
import graphql.schema.DataFetchingEnvironment;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
public class BatchGetEntitiesResolverTest {
private EntityClient _entityClient;
private EntityService _entityService;
private DataFetchingEnvironment _dataFetchingEnvironment;
@BeforeMethod
public void setupTest() {
_entityService = mock(EntityService.class);
_dataFetchingEnvironment = mock(DataFetchingEnvironment.class);
_entityClient = mock(EntityClient.class);
}
List<Entity> getRequestEntities(List<String> urnList) {
return urnList.stream()
.map(
urn -> {
if (urn.startsWith("urn:li:dataset")) {
Dataset entity = new Dataset();
entity.setUrn(urn);
return entity;
} else if (urn.startsWith("urn:li:dashboard")) {
Dashboard entity = new Dashboard();
entity.setUrn(urn);
return entity;
} else {
throw new RuntimeException("Can't handle urn " + urn);
}
})
.collect(Collectors.toList());
}
@Test
/** Tests that if responses come back out of order, we stitch them back correctly */
public void testReordering() throws Exception {
Function entityProvider = mock(Function.class);
List<Entity> inputEntities =
getRequestEntities(ImmutableList.of("urn:li:dataset:1", "urn:li:dataset:2"));
when(entityProvider.apply(any())).thenReturn(inputEntities);
BatchGetEntitiesResolver resolver =
new BatchGetEntitiesResolver(
ImmutableList.of(new DatasetType(_entityClient)), entityProvider);
DataLoaderRegistry mockDataLoaderRegistry = mock(DataLoaderRegistry.class);
when(_dataFetchingEnvironment.getDataLoaderRegistry()).thenReturn(mockDataLoaderRegistry);
DataLoader mockDataLoader = mock(DataLoader.class);
when(mockDataLoaderRegistry.getDataLoader(any())).thenReturn(mockDataLoader);
Dataset mockResponseEntity1 = new Dataset();
mockResponseEntity1.setUrn("urn:li:dataset:1");
Dataset mockResponseEntity2 = new Dataset();
mockResponseEntity2.setUrn("urn:li:dataset:2");
CompletableFuture mockFuture =
CompletableFuture.completedFuture(
ImmutableList.of(mockResponseEntity2, mockResponseEntity1));
when(mockDataLoader.loadMany(any())).thenReturn(mockFuture);
when(_entityService.exists(any())).thenReturn(true);
List<Entity> batchGetResponse = resolver.get(_dataFetchingEnvironment).join();
assertEquals(batchGetResponse.size(), 2);
assertEquals(batchGetResponse.get(0), mockResponseEntity1);
assertEquals(batchGetResponse.get(1), mockResponseEntity2);
}
@Test
/** Tests that if input list contains duplicates, we stitch them back correctly */
public void testDuplicateUrns() throws Exception {
Function entityProvider = mock(Function.class);
List<Entity> inputEntities =
getRequestEntities(ImmutableList.of("urn:li:dataset:foo", "urn:li:dataset:foo"));
when(entityProvider.apply(any())).thenReturn(inputEntities);
BatchGetEntitiesResolver resolver =
new BatchGetEntitiesResolver(
ImmutableList.of(new DatasetType(_entityClient)), entityProvider);
DataLoaderRegistry mockDataLoaderRegistry = mock(DataLoaderRegistry.class);
when(_dataFetchingEnvironment.getDataLoaderRegistry()).thenReturn(mockDataLoaderRegistry);
DataLoader mockDataLoader = mock(DataLoader.class);
when(mockDataLoaderRegistry.getDataLoader(any())).thenReturn(mockDataLoader);
Dataset mockResponseEntity = new Dataset();
mockResponseEntity.setUrn("urn:li:dataset:foo");
CompletableFuture mockFuture =
CompletableFuture.completedFuture(ImmutableList.of(mockResponseEntity));
when(mockDataLoader.loadMany(any())).thenReturn(mockFuture);
when(_entityService.exists(any())).thenReturn(true);
List<Entity> batchGetResponse = resolver.get(_dataFetchingEnvironment).join();
assertEquals(batchGetResponse.size(), 2);
assertEquals(batchGetResponse.get(0), mockResponseEntity);
assertEquals(batchGetResponse.get(1), mockResponseEntity);
}
}