diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolver.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolver.java
index ecf36769df..2519d91aa3 100644
--- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolver.java
+++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolver.java
@@ -3,16 +3,20 @@ package com.linkedin.datahub.graphql.resolvers.load;
 import com.linkedin.datahub.graphql.generated.Entity;
 import com.linkedin.datahub.graphql.generated.EntityType;
 import com.linkedin.datahub.graphql.resolvers.BatchLoadUtils;
+import graphql.execution.DataFetcherResult;
 import graphql.schema.DataFetcher;
 import graphql.schema.DataFetchingEnvironment;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import lombok.extern.slf4j.Slf4j;
 
+@Slf4j
 public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<List<Entity>>> {
 
   private final List<com.linkedin.datahub.graphql.types.EntityType<?, ?>> _entityTypes;
@@ -30,13 +34,21 @@ public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<List<Entity>>> {
     final List<Entity> entities = _entitiesProvider.apply(environment);
     Map<EntityType, List<Entity>> entityTypeToEntities = new HashMap<>();
 
-    entities.forEach(
-        (entity) -> {
-          EntityType type = entity.getType();
-          List<Entity> entitiesList = entityTypeToEntities.getOrDefault(type, new ArrayList<>());
-          entitiesList.add(entity);
-          entityTypeToEntities.put(type, entitiesList);
-        });
+    Map<String, List<Integer>> entityIndexMap = new HashMap<>();
+    int index = 0;
+    for (Entity entity : entities) {
+      List<Integer> indexList = new ArrayList<>();
+      if (entityIndexMap.containsKey(entity.getUrn())) {
+        indexList = entityIndexMap.get(entity.getUrn());
+      }
+      indexList.add(index);
+      entityIndexMap.put(entity.getUrn(), indexList);
+      index++;
+      EntityType type = entity.getType();
+      List<Entity> entitiesList = entityTypeToEntities.getOrDefault(type, new ArrayList<>());
+      entitiesList.add(entity);
+      entityTypeToEntities.put(type, entitiesList);
+    }
 
     List<CompletableFuture<List<Entity>>> entitiesFutures = new ArrayList<>();
 
@@ -49,9 +61,32 @@ public class BatchGetEntitiesResolver implements DataFetcher<CompletableFuture<List<Entity>>> {
 
     return CompletableFuture.allOf(entitiesFutures.toArray(new CompletableFuture[0]))
         .thenApply(
-            v ->
-                entitiesFutures.stream()
-                    .flatMap(future -> future.join().stream())
-                    .collect(Collectors.toList()));
+            v -> {
+              Entity[] finalEntityList = new Entity[entities.size()];
+              // Returned objects can be either of type Entity or wrapped as
+              // DataFetcherResult.
+              // Therefore we need to be working with raw Objects in this area of the code.
+              List<Object> returnedList =
+                  entitiesFutures.stream()
+                      .flatMap(future -> future.join().stream())
+                      .collect(Collectors.toList());
+              for (Object element : returnedList) {
+                Entity entity = null;
+                if (element instanceof DataFetcherResult) {
+                  entity = ((DataFetcherResult<Entity>) element).getData();
+                } else if (element instanceof Entity) {
+                  entity = (Entity) element;
+                } else {
+                  throw new RuntimeException(
+                      String.format(
+                          "Cannot process entity because it is neither an Entity nor a DataFetcherResult. %s",
+                          element));
+                }
+                for (int idx : entityIndexMap.get(entity.getUrn())) {
+                  finalEntityList[idx] = entity;
+                }
+              }
+              return Arrays.asList(finalEntityList);
+            });
   }
 }
diff --git a/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolverTest.java b/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolverTest.java
new file mode 100644
index 0000000000..6bd5b4f8c3
--- /dev/null
+++ b/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/load/BatchGetEntitiesResolverTest.java
@@ -0,0 +1,117 @@
+package com.linkedin.datahub.graphql.resolvers.load;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.*;
+import static org.testng.Assert.*;
+
+import com.google.common.collect.ImmutableList;
+import com.linkedin.datahub.graphql.generated.Dashboard;
+import com.linkedin.datahub.graphql.generated.Dataset;
+import com.linkedin.datahub.graphql.generated.Entity;
+import com.linkedin.datahub.graphql.types.dataset.DatasetType;
+import com.linkedin.entity.client.EntityClient;
+import com.linkedin.metadata.entity.EntityService;
+import graphql.schema.DataFetchingEnvironment;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import org.dataloader.DataLoader;
+import org.dataloader.DataLoaderRegistry;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+public class BatchGetEntitiesResolverTest {
+  private EntityClient _entityClient;
+  private EntityService _entityService;
+  private DataFetchingEnvironment _dataFetchingEnvironment;
+
+  @BeforeMethod
+  public void setupTest() {
+    _entityService = mock(EntityService.class);
+    _dataFetchingEnvironment = mock(DataFetchingEnvironment.class);
+    _entityClient = mock(EntityClient.class);
+  }
+
+  List<Entity> getRequestEntities(List<String> urnList) {
+
+    return urnList.stream()
+        .map(
+            urn -> {
+              if (urn.startsWith("urn:li:dataset")) {
+                Dataset entity = new Dataset();
+                entity.setUrn(urn);
+                return entity;
+              } else if (urn.startsWith("urn:li:dashboard")) {
+                Dashboard entity = new Dashboard();
+                entity.setUrn(urn);
+                return entity;
+              } else {
+                throw new RuntimeException("Can't handle urn " + urn);
+              }
+            })
+        .collect(Collectors.toList());
+  }
+
+  @Test
+  /** Tests that if responses come back out of order, we stitch them back correctly */
+  public void testReordering() throws Exception {
+    Function entityProvider = mock(Function.class);
+    List<Entity> inputEntities =
+        getRequestEntities(ImmutableList.of("urn:li:dataset:1", "urn:li:dataset:2"));
+    when(entityProvider.apply(any())).thenReturn(inputEntities);
+    BatchGetEntitiesResolver resolver =
+        new BatchGetEntitiesResolver(
+            ImmutableList.of(new DatasetType(_entityClient)), entityProvider);
+
+    DataLoaderRegistry mockDataLoaderRegistry = mock(DataLoaderRegistry.class);
+    when(_dataFetchingEnvironment.getDataLoaderRegistry()).thenReturn(mockDataLoaderRegistry);
+    DataLoader mockDataLoader = mock(DataLoader.class);
+    when(mockDataLoaderRegistry.getDataLoader(any())).thenReturn(mockDataLoader);
+
+    Dataset mockResponseEntity1 = new Dataset();
+    mockResponseEntity1.setUrn("urn:li:dataset:1");
+
+    Dataset mockResponseEntity2 = new Dataset();
+    mockResponseEntity2.setUrn("urn:li:dataset:2");
+
+    CompletableFuture mockFuture =
+        CompletableFuture.completedFuture(
+            ImmutableList.of(mockResponseEntity2, mockResponseEntity1));
+    when(mockDataLoader.loadMany(any())).thenReturn(mockFuture);
+    when(_entityService.exists(any())).thenReturn(true);
+    List<Entity> batchGetResponse = resolver.get(_dataFetchingEnvironment).join();
+    assertEquals(batchGetResponse.size(), 2);
+    assertEquals(batchGetResponse.get(0), mockResponseEntity1);
+    assertEquals(batchGetResponse.get(1), mockResponseEntity2);
+  }
+
+  @Test
+  /** Tests that if the input list contains duplicates, we stitch them back correctly */
+  public void testDuplicateUrns() throws Exception {
+    Function entityProvider = mock(Function.class);
+    List<Entity> inputEntities =
+        getRequestEntities(ImmutableList.of("urn:li:dataset:foo", "urn:li:dataset:foo"));
+    when(entityProvider.apply(any())).thenReturn(inputEntities);
+    BatchGetEntitiesResolver resolver =
+        new BatchGetEntitiesResolver(
+            ImmutableList.of(new DatasetType(_entityClient)), entityProvider);
+
+    DataLoaderRegistry mockDataLoaderRegistry = mock(DataLoaderRegistry.class);
+    when(_dataFetchingEnvironment.getDataLoaderRegistry()).thenReturn(mockDataLoaderRegistry);
+    DataLoader mockDataLoader = mock(DataLoader.class);
+    when(mockDataLoaderRegistry.getDataLoader(any())).thenReturn(mockDataLoader);
+
+    Dataset mockResponseEntity = new Dataset();
+    mockResponseEntity.setUrn("urn:li:dataset:foo");
+
+    CompletableFuture mockFuture =
+        CompletableFuture.completedFuture(ImmutableList.of(mockResponseEntity));
+    when(mockDataLoader.loadMany(any())).thenReturn(mockFuture);
+    when(_entityService.exists(any())).thenReturn(true);
+    List<Entity> batchGetResponse = resolver.get(_dataFetchingEnvironment).join();
+    assertEquals(batchGetResponse.size(), 2);
+    assertEquals(batchGetResponse.get(0), mockResponseEntity);
+    assertEquals(batchGetResponse.get(1), mockResponseEntity);
+  }
+}
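
The technique the diff applies can be seen in isolation in the sketch below. It is illustrative only, not DataHub code: `ReorderSketch` and `SimpleEntity` are hypothetical stand-ins for the resolver and the generated `Entity` types. The idea is the one the resolver now implements: record every index at which a URN appears in the request, then write each resolved entity back into all of those slots, so that out-of-order or deduplicated batch responses still come back in request order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Minimal, self-contained sketch of the index-map reordering technique.
// SimpleEntity is a hypothetical stand-in for DataHub's generated Entity classes.
public class ReorderSketch {
  record SimpleEntity(String urn) {}

  static List<SimpleEntity> reorder(List<String> requestedUrns, List<SimpleEntity> returned) {
    // Record every index at which a urn was requested (duplicates keep all their positions).
    Map<String, List<Integer>> indexByUrn = new HashMap<>();
    for (int i = 0; i < requestedUrns.size(); i++) {
      indexByUrn.computeIfAbsent(requestedUrns.get(i), k -> new ArrayList<>()).add(i);
    }
    // Place each returned entity into every slot its urn occupied in the request.
    SimpleEntity[] ordered = new SimpleEntity[requestedUrns.size()];
    for (SimpleEntity entity : returned) {
      for (int idx : indexByUrn.getOrDefault(entity.urn(), List.of())) {
        ordered[idx] = entity;
      }
    }
    return Arrays.asList(ordered);
  }

  public static void main(String[] args) {
    List<String> request = List.of("urn:li:dataset:1", "urn:li:dataset:2", "urn:li:dataset:1");
    // The batch response arrives out of order and deduplicated.
    List<SimpleEntity> response =
        List.of(new SimpleEntity("urn:li:dataset:2"), new SimpleEntity("urn:li:dataset:1"));
    // Prints the entities back in request order, with the duplicate urn filled twice.
    System.out.println(reorder(request, response));
  }
}
```

Duplicated request URNs map to multiple indices, which is the case exercised by `testDuplicateUrns` above; `testReordering` covers the out-of-order case.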