test(graphql): fix searchFlags in searchAcrossLineage (#11448)

This commit is contained in:
david-leifker 2024-09-20 15:26:25 -05:00 committed by GitHub
parent c5112af573
commit 3995140f9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 159 additions and 8 deletions

View File

@ -14,6 +14,7 @@ import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageInput;
import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageResults;
import com.linkedin.datahub.graphql.resolvers.ResolverUtils;
import com.linkedin.datahub.graphql.types.common.mappers.LineageFlagsInputMapper;
import com.linkedin.datahub.graphql.types.common.mappers.SearchFlagsInputMapper;
import com.linkedin.datahub.graphql.types.entitytype.EntityTypeMapper;
import com.linkedin.datahub.graphql.types.mappers.UrnScrollAcrossLineageResultsMapper;
import com.linkedin.entity.client.EntityClient;
@ -89,7 +90,6 @@ public class ScrollAcrossLineageResolver
if (lineageFlags.getEndTimeMillis() == null && endTimeMillis != null) {
lineageFlags.setEndTimeMillis(endTimeMillis);
}
;
com.linkedin.metadata.graph.LineageDirection resolvedDirection =
com.linkedin.metadata.graph.LineageDirection.valueOf(lineageDirection.toString());
@ -107,17 +107,13 @@ public class ScrollAcrossLineageResolver
count);
final SearchFlags searchFlags;
final com.linkedin.datahub.graphql.generated.SearchFlags inputFlags =
input.getSearchFlags();
com.linkedin.datahub.graphql.generated.SearchFlags inputFlags = input.getSearchFlags();
if (inputFlags != null) {
searchFlags =
new SearchFlags()
.setSkipCache(inputFlags.getSkipCache())
.setFulltext(inputFlags.getFulltext())
.setMaxAggValues(inputFlags.getMaxAggValues());
searchFlags = SearchFlagsInputMapper.INSTANCE.apply(context, inputFlags);
} else {
searchFlags = null;
}
return UrnScrollAcrossLineageResultsMapper.map(
context,
_entityClient.scrollAcrossLineage(

View File

@ -0,0 +1,155 @@
package com.linkedin.datahub.graphql.resolvers.search;
import static com.linkedin.datahub.graphql.TestUtils.getMockAllowContext;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import com.datahub.authentication.Authentication;
import com.linkedin.common.UrnArrayArray;
import com.linkedin.common.urn.UrnUtils;
import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
import com.linkedin.datahub.graphql.QueryContext;
import com.linkedin.datahub.graphql.generated.EntityType;
import com.linkedin.datahub.graphql.generated.LineageDirection;
import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageInput;
import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageResults;
import com.linkedin.datahub.graphql.generated.SearchAcrossLineageResult;
import com.linkedin.datahub.graphql.generated.SearchFlags;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.metadata.models.registry.ConfigEntityRegistry;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.search.AggregationMetadataArray;
import com.linkedin.metadata.search.LineageScrollResult;
import com.linkedin.metadata.search.LineageSearchEntity;
import com.linkedin.metadata.search.LineageSearchEntityArray;
import com.linkedin.metadata.search.MatchedFieldArray;
import com.linkedin.metadata.search.SearchResultMetadata;
import graphql.schema.DataFetchingEnvironment;
import io.datahubproject.metadata.context.OperationContext;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import org.mockito.ArgumentCaptor;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
public class ScrollAcrossLineageResolverTest {
private static final String SOURCE_URN_STRING =
"urn:li:dataset:(urn:li:dataPlatform:foo,bar,PROD)";
private static final String TARGET_URN_STRING =
"urn:li:dataset:(urn:li:dataPlatform:foo,baz,PROD)";
private static final String QUERY = "";
private static final int START = 0;
private static final int COUNT = 10;
private static final Long START_TIMESTAMP_MILLIS = 0L;
private static final Long END_TIMESTAMP_MILLIS = 1000L;
private EntityClient _entityClient;
private DataFetchingEnvironment _dataFetchingEnvironment;
private Authentication _authentication;
private ScrollAcrossLineageResolver _resolver;
@BeforeTest
public void disableAssert() {
PathSpecBasedSchemaAnnotationVisitor.class
.getClassLoader()
.setClassAssertionStatus(PathSpecBasedSchemaAnnotationVisitor.class.getName(), false);
}
@BeforeMethod
public void setupTest() {
_entityClient = mock(EntityClient.class);
_dataFetchingEnvironment = mock(DataFetchingEnvironment.class);
_authentication = mock(Authentication.class);
_resolver = new ScrollAcrossLineageResolver(_entityClient);
}
@Test
public void testAllEntitiesInitialization() {
InputStream inputStream = ClassLoader.getSystemResourceAsStream("entity-registry.yml");
EntityRegistry entityRegistry = new ConfigEntityRegistry(inputStream);
SearchAcrossLineageResolver resolver =
new SearchAcrossLineageResolver(_entityClient, entityRegistry);
assertTrue(resolver._allEntities.contains("dataset"));
assertTrue(resolver._allEntities.contains("dataFlow"));
// Test for case sensitivity
assertFalse(resolver._allEntities.contains("dataflow"));
}
@Test
public void testSearchAcrossLineage() throws Exception {
final QueryContext mockContext = getMockAllowContext();
when(mockContext.getAuthentication()).thenReturn(_authentication);
when(_dataFetchingEnvironment.getContext()).thenReturn(mockContext);
final SearchFlags searchFlags = new SearchFlags();
searchFlags.setFulltext(true);
final ScrollAcrossLineageInput input = new ScrollAcrossLineageInput();
input.setCount(COUNT);
input.setDirection(LineageDirection.DOWNSTREAM);
input.setOrFilters(Collections.emptyList());
input.setQuery(QUERY);
input.setTypes(Collections.emptyList());
input.setStartTimeMillis(START_TIMESTAMP_MILLIS);
input.setEndTimeMillis(END_TIMESTAMP_MILLIS);
input.setUrn(SOURCE_URN_STRING);
input.setSearchFlags(searchFlags);
when(_dataFetchingEnvironment.getArgument(eq("input"))).thenReturn(input);
final LineageScrollResult lineageSearchResult = new LineageScrollResult();
lineageSearchResult.setNumEntities(1);
lineageSearchResult.setPageSize(10);
final SearchResultMetadata searchResultMetadata = new SearchResultMetadata();
searchResultMetadata.setAggregations(new AggregationMetadataArray());
lineageSearchResult.setMetadata(searchResultMetadata);
final LineageSearchEntity lineageSearchEntity = new LineageSearchEntity();
lineageSearchEntity.setEntity(UrnUtils.getUrn(TARGET_URN_STRING));
lineageSearchEntity.setScore(15.0);
lineageSearchEntity.setDegree(1);
lineageSearchEntity.setMatchedFields(new MatchedFieldArray());
lineageSearchEntity.setPaths(new UrnArrayArray());
lineageSearchResult.setEntities(new LineageSearchEntityArray(lineageSearchEntity));
ArgumentCaptor<OperationContext> opContext = ArgumentCaptor.forClass(OperationContext.class);
when(_entityClient.scrollAcrossLineage(
opContext.capture(),
eq(UrnUtils.getUrn(SOURCE_URN_STRING)),
eq(com.linkedin.metadata.graph.LineageDirection.DOWNSTREAM),
anyList(),
eq(QUERY),
eq(null),
any(),
eq(null),
nullable(String.class),
nullable(String.class),
eq(COUNT)))
.thenReturn(lineageSearchResult);
final ScrollAcrossLineageResults results = _resolver.get(_dataFetchingEnvironment).join();
assertEquals(results.getCount(), 10);
assertEquals(results.getTotal(), 1);
assertEquals(
opContext.getValue().getSearchContext().getLineageFlags().getStartTimeMillis(),
START_TIMESTAMP_MILLIS);
assertEquals(
opContext.getValue().getSearchContext().getLineageFlags().getEndTimeMillis(),
END_TIMESTAMP_MILLIS);
final List<SearchAcrossLineageResult> entities = results.getSearchResults();
assertEquals(entities.size(), 1);
final SearchAcrossLineageResult entity = entities.get(0);
assertEquals(entity.getEntity().getUrn(), TARGET_URN_STRING);
assertEquals(entity.getEntity().getType(), EntityType.DATASET);
}
}