feat(recommendations): Make top platforms account only for searchable entities (#9240)

This commit is contained in:
Pedro Silva 2023-11-15 19:53:50 +00:00 committed by GitHub
parent 4201e541ca
commit 6655918923
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 31 deletions

View File

@ -142,11 +142,11 @@ public class ElasticSearchService implements EntitySearchService, ElasticSearchI
@Nonnull
@Override
public Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field,
public Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
@Nullable Filter requestParams, int limit) {
log.debug("Aggregating by value: {}, field: {}, requestParams: {}, limit: {}", entityName, field, requestParams,
limit);
return esSearchDAO.aggregateByValue(entityName, field, requestParams, limit);
// entityNames is @Nullable, so hand it to the logger as-is instead of calling toString() on it,
// which would throw a NullPointerException whenever no entity list is supplied.
log.debug("Aggregating by value: {}, field: {}, requestParams: {}, limit: {}", entityNames, field,
requestParams, limit);
return esSearchDAO.aggregateByValue(entityNames, field, requestParams, limit);
}
@Nonnull

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
@ -263,17 +264,16 @@ public class ESSearchDAO {
* @return
*/
@Nonnull
public Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field,
public Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
@Nullable Filter requestParams, int limit) {
final SearchRequest searchRequest = SearchRequestHandler.getAggregationRequest(field, transformFilterForEntities(requestParams, indexConvention), limit);
String indexName;
if (entityName == null) {
indexName = indexConvention.getAllEntityIndicesPattern();
if (entityNames == null) {
String indexName = indexConvention.getAllEntityIndicesPattern();
searchRequest.indices(indexName);
} else {
EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName);
indexName = indexConvention.getIndexName(entitySpec);
Stream<String> stream = entityNames.stream().map(entityRegistry::getEntitySpec).map(indexConvention::getIndexName);
searchRequest.indices(stream.toArray(String[]::new));
}
searchRequest.indices(indexName);
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "aggregateByValue_search").time()) {
final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);

View File

@ -3,6 +3,7 @@ package com.linkedin.metadata.search;
import com.datahub.test.Snapshot;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.ImmutableList;
import com.linkedin.common.urn.TestEntityUrn;
import com.linkedin.common.urn.Urn;
import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
@ -99,7 +100,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
BrowseResult browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0);
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
Urn urn = new TestEntityUrn("test", "urn1", "VALUE_1");
ObjectNode document = JsonNodeFactory.instance.objectNode();
@ -124,7 +125,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1);
assertEquals(browseResult.getGroups().get(0).getName(), "b");
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10),
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L));
Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2");
@ -147,7 +148,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1);
assertEquals(browseResult.getGroups().get(0).getName(), "b");
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10),
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L));
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
@ -158,7 +159,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0);
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
}
@Test
@ -181,7 +182,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getEntities().get(0).getEntity(), urn);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10),
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L));
Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2");
@ -198,7 +199,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getEntities().get(0).getEntity(), urn2);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10),
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L));
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
@ -208,6 +209,6 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getNumEntities().intValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0);
assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
}
}

View File

@ -82,7 +82,7 @@ public abstract class EntitySearchAggregationSource implements RecommendationSou
public List<RecommendationContent> getRecommendations(@Nonnull Urn userUrn,
@Nullable RecommendationRequestContext requestContext) {
Map<String, Long> aggregationResult =
_entitySearchService.aggregateByValue(null, getSearchFieldName(), null, getMaxContent());
_entitySearchService.aggregateByValue(getEntityNames(), getSearchFieldName(), null, getMaxContent());
if (aggregationResult.isEmpty()) {
return Collections.emptyList();
@ -116,6 +116,11 @@ public abstract class EntitySearchAggregationSource implements RecommendationSou
.collect(Collectors.toList());
}
/**
 * Returns the entity names that are passed to the aggregation call made by this source.
 *
 * @return the list of entity names to aggregate across, or {@code null} (the default) when no list is
 *     applied, which means searching across entities
 */
@Nullable
protected List<String> getEntityNames() {
// By default, no list is applied which means searching across entities.
return null;
}
// Get top K entries with the most count
private <T> List<Map.Entry<T, Long>> getTopKValues(Map<T, Long> countMap) {
final PriorityQueue<Map.Entry<T, Long>> queue =

View File

@ -1,15 +1,16 @@
package com.linkedin.metadata.recommendation.candidatesource;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableList;
import com.linkedin.common.urn.Urn;
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.dataplatform.DataPlatformInfo;
import com.linkedin.metadata.Constants;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.RecommendationRenderType;
import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.ScenarioType;
import com.linkedin.metadata.search.EntitySearchService;
import java.util.Set;
import java.util.List;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
@ -18,12 +19,24 @@ import lombok.extern.slf4j.Slf4j;
public class TopPlatformsSource extends EntitySearchAggregationSource {
/**
* TODO: Remove this once we permit specifying set of entities in aggregation API (filter out assertions)
* Set of entities that we want to consider for defining the top platform sources.
* This must match SearchUtils.SEARCHABLE_ENTITY_TYPES
*/
private static final Set<String> FILTERED_DATA_PLATFORM_URNS = ImmutableSet.of(
"urn:li:dataPlatform:great-expectations"
private static final List<String> SEARCHABLE_ENTITY_TYPES = ImmutableList.of(
Constants.DATASET_ENTITY_NAME,
Constants.DASHBOARD_ENTITY_NAME,
Constants.CHART_ENTITY_NAME,
Constants.ML_MODEL_ENTITY_NAME,
Constants.ML_MODEL_GROUP_ENTITY_NAME,
Constants.ML_FEATURE_TABLE_ENTITY_NAME,
Constants.ML_FEATURE_ENTITY_NAME,
Constants.ML_PRIMARY_KEY_ENTITY_NAME,
Constants.DATA_FLOW_ENTITY_NAME,
Constants.DATA_JOB_ENTITY_NAME,
Constants.TAG_ENTITY_NAME,
Constants.CONTAINER_ENTITY_NAME,
Constants.NOTEBOOK_ENTITY_NAME
);
private final EntityService _entityService;
private static final String PLATFORM = "platform";
@ -52,6 +65,10 @@ public class TopPlatformsSource extends EntitySearchAggregationSource {
return requestContext.getScenario() == ScenarioType.HOME;
}
/**
 * Supplies the entity types considered when computing top platforms.
 *
 * @return the fixed {@code SEARCHABLE_ENTITY_TYPES} list
 */
@Override
protected List<String> getEntityNames() {
return SEARCHABLE_ENTITY_TYPES;
}
@Override
protected String getSearchFieldName() {
return PLATFORM;
@ -69,9 +86,6 @@ public class TopPlatformsSource extends EntitySearchAggregationSource {
@Override
protected boolean isValidCandidateUrn(Urn urn) {
if (FILTERED_DATA_PLATFORM_URNS.contains(urn.toString())) {
return false;
}
RecordTemplate dataPlatformInfo = _entityService.getLatestAspect(urn, "dataPlatformInfo");
if (dataPlatformInfo == null) {
return false;

View File

@ -131,15 +131,15 @@ public interface EntitySearchService {
/**
* Returns number of documents per field value given the field and filters
*
* @param entityName name of the entity, if empty aggregate over all entities
* @param entityNames list of names of the entities to aggregate across; if null, aggregate over all entities
* @param field the field name for aggregate
* @param requestParams filters to apply before aggregating
* @param limit the number of aggregations to return
* @return map from field value to the number of documents with that value
*/
@Nonnull
Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field, @Nullable Filter requestParams,
int limit);
Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
@Nullable Filter requestParams, int limit);
/**
* Gets a list of groups/entities that match given browse request.