feat(recomendations): Make top platforms account only for searchable entities (#9240)

This commit is contained in:
Pedro Silva 2023-11-15 19:53:50 +00:00 committed by GitHub
parent 4201e541ca
commit 6655918923
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 31 deletions

View File

@ -142,11 +142,11 @@ public class ElasticSearchService implements EntitySearchService, ElasticSearchI
@Nonnull @Nonnull
@Override @Override
public Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field, public Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
@Nullable Filter requestParams, int limit) { @Nullable Filter requestParams, int limit) {
log.debug("Aggregating by value: {}, field: {}, requestParams: {}, limit: {}", entityName, field, requestParams, log.debug("Aggregating by value: {}, field: {}, requestParams: {}, limit: {}", entityNames.toString(), field,
limit); requestParams, limit);
return esSearchDAO.aggregateByValue(entityName, field, requestParams, limit); return esSearchDAO.aggregateByValue(entityNames, field, requestParams, limit);
} }
@Nonnull @Nonnull

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@ -263,17 +264,16 @@ public class ESSearchDAO {
* @return * @return
*/ */
@Nonnull @Nonnull
public Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field, public Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
@Nullable Filter requestParams, int limit) { @Nullable Filter requestParams, int limit) {
final SearchRequest searchRequest = SearchRequestHandler.getAggregationRequest(field, transformFilterForEntities(requestParams, indexConvention), limit); final SearchRequest searchRequest = SearchRequestHandler.getAggregationRequest(field, transformFilterForEntities(requestParams, indexConvention), limit);
String indexName; if (entityNames == null) {
if (entityName == null) { String indexName = indexConvention.getAllEntityIndicesPattern();
indexName = indexConvention.getAllEntityIndicesPattern(); searchRequest.indices(indexName);
} else { } else {
EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName); Stream<String> stream = entityNames.stream().map(entityRegistry::getEntitySpec).map(indexConvention::getIndexName);
indexName = indexConvention.getIndexName(entitySpec); searchRequest.indices(stream.toArray(String[]::new));
} }
searchRequest.indices(indexName);
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "aggregateByValue_search").time()) { try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "aggregateByValue_search").time()) {
final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);

View File

@ -3,6 +3,7 @@ package com.linkedin.metadata.search;
import com.datahub.test.Snapshot; import com.datahub.test.Snapshot;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.ImmutableList;
import com.linkedin.common.urn.TestEntityUrn; import com.linkedin.common.urn.TestEntityUrn;
import com.linkedin.common.urn.Urn; import com.linkedin.common.urn.Urn;
import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor; import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
@ -99,7 +100,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
BrowseResult browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10); BrowseResult browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0); assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0); assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
Urn urn = new TestEntityUrn("test", "urn1", "VALUE_1"); Urn urn = new TestEntityUrn("test", "urn1", "VALUE_1");
ObjectNode document = JsonNodeFactory.instance.objectNode(); ObjectNode document = JsonNodeFactory.instance.objectNode();
@ -124,7 +125,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1); assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1);
assertEquals(browseResult.getGroups().get(0).getName(), "b"); assertEquals(browseResult.getGroups().get(0).getName(), "b");
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10), assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L)); ImmutableMap.of("textFieldOverride", 1L));
Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2"); Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2");
@ -147,7 +148,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1); assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 1);
assertEquals(browseResult.getGroups().get(0).getName(), "b"); assertEquals(browseResult.getGroups().get(0).getName(), "b");
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10), assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L)); ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L));
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString()); _elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
@ -158,7 +159,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10); browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0); assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0); assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
} }
@Test @Test
@ -181,7 +182,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getEntities().get(0).getEntity(), urn); assertEquals(searchResult.getEntities().get(0).getEntity(), urn);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 1);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10), assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L)); ImmutableMap.of("textFieldOverride", 1L));
Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2"); Urn urn2 = new TestEntityUrn("test2", "urn2", "VALUE_2");
@ -198,7 +199,7 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getEntities().get(0).getEntity(), urn2); assertEquals(searchResult.getEntities().get(0).getEntity(), urn2);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 2);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textFieldOverride", null, 10), assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textFieldOverride", null, 10),
ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L)); ImmutableMap.of("textFieldOverride", 1L, "textFieldOverride2", 1L));
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString()); _elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
@ -208,6 +209,6 @@ abstract public class TestEntityTestBase extends AbstractTestNGSpringContextTest
assertEquals(searchResult.getNumEntities().intValue(), 0); assertEquals(searchResult.getNumEntities().intValue(), 0);
assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0); assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
assertEquals(_elasticSearchService.aggregateByValue(ENTITY_NAME, "textField", null, 10).size(), 0); assertEquals(_elasticSearchService.aggregateByValue(ImmutableList.of(ENTITY_NAME), "textField", null, 10).size(), 0);
} }
} }

View File

@ -82,7 +82,7 @@ public abstract class EntitySearchAggregationSource implements RecommendationSou
public List<RecommendationContent> getRecommendations(@Nonnull Urn userUrn, public List<RecommendationContent> getRecommendations(@Nonnull Urn userUrn,
@Nullable RecommendationRequestContext requestContext) { @Nullable RecommendationRequestContext requestContext) {
Map<String, Long> aggregationResult = Map<String, Long> aggregationResult =
_entitySearchService.aggregateByValue(null, getSearchFieldName(), null, getMaxContent()); _entitySearchService.aggregateByValue(getEntityNames(), getSearchFieldName(), null, getMaxContent());
if (aggregationResult.isEmpty()) { if (aggregationResult.isEmpty()) {
return Collections.emptyList(); return Collections.emptyList();
@ -116,6 +116,11 @@ public abstract class EntitySearchAggregationSource implements RecommendationSou
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
protected List<String> getEntityNames() {
// By default, no list is applied which means searching across entities.
return null;
}
// Get top K entries with the most count // Get top K entries with the most count
private <T> List<Map.Entry<T, Long>> getTopKValues(Map<T, Long> countMap) { private <T> List<Map.Entry<T, Long>> getTopKValues(Map<T, Long> countMap) {
final PriorityQueue<Map.Entry<T, Long>> queue = final PriorityQueue<Map.Entry<T, Long>> queue =

View File

@ -1,15 +1,16 @@
package com.linkedin.metadata.recommendation.candidatesource; package com.linkedin.metadata.recommendation.candidatesource;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableList;
import com.linkedin.common.urn.Urn; import com.linkedin.common.urn.Urn;
import com.linkedin.data.template.RecordTemplate; import com.linkedin.data.template.RecordTemplate;
import com.linkedin.dataplatform.DataPlatformInfo; import com.linkedin.dataplatform.DataPlatformInfo;
import com.linkedin.metadata.Constants;
import com.linkedin.metadata.entity.EntityService; import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.RecommendationRenderType; import com.linkedin.metadata.recommendation.RecommendationRenderType;
import com.linkedin.metadata.recommendation.RecommendationRequestContext; import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.ScenarioType; import com.linkedin.metadata.recommendation.ScenarioType;
import com.linkedin.metadata.search.EntitySearchService; import com.linkedin.metadata.search.EntitySearchService;
import java.util.Set; import java.util.List;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -18,12 +19,24 @@ import lombok.extern.slf4j.Slf4j;
public class TopPlatformsSource extends EntitySearchAggregationSource { public class TopPlatformsSource extends EntitySearchAggregationSource {
/** /**
* TODO: Remove this once we permit specifying set of entities in aggregation API (filter out assertions) * Set of entities that we want to consider for defining the top platform sources.
* This must match SearchUtils.SEARCHABLE_ENTITY_TYPES
*/ */
private static final Set<String> FILTERED_DATA_PLATFORM_URNS = ImmutableSet.of( private static final List<String> SEARCHABLE_ENTITY_TYPES = ImmutableList.of(
"urn:li:dataPlatform:great-expectations" Constants.DATASET_ENTITY_NAME,
Constants.DASHBOARD_ENTITY_NAME,
Constants.CHART_ENTITY_NAME,
Constants.ML_MODEL_ENTITY_NAME,
Constants.ML_MODEL_GROUP_ENTITY_NAME,
Constants.ML_FEATURE_TABLE_ENTITY_NAME,
Constants.ML_FEATURE_ENTITY_NAME,
Constants.ML_PRIMARY_KEY_ENTITY_NAME,
Constants.DATA_FLOW_ENTITY_NAME,
Constants.DATA_JOB_ENTITY_NAME,
Constants.TAG_ENTITY_NAME,
Constants.CONTAINER_ENTITY_NAME,
Constants.NOTEBOOK_ENTITY_NAME
); );
private final EntityService _entityService; private final EntityService _entityService;
private static final String PLATFORM = "platform"; private static final String PLATFORM = "platform";
@ -52,6 +65,10 @@ public class TopPlatformsSource extends EntitySearchAggregationSource {
return requestContext.getScenario() == ScenarioType.HOME; return requestContext.getScenario() == ScenarioType.HOME;
} }
protected List<String> getEntityNames() {
return SEARCHABLE_ENTITY_TYPES;
}
@Override @Override
protected String getSearchFieldName() { protected String getSearchFieldName() {
return PLATFORM; return PLATFORM;
@ -69,9 +86,6 @@ public class TopPlatformsSource extends EntitySearchAggregationSource {
@Override @Override
protected boolean isValidCandidateUrn(Urn urn) { protected boolean isValidCandidateUrn(Urn urn) {
if (FILTERED_DATA_PLATFORM_URNS.contains(urn.toString())) {
return false;
}
RecordTemplate dataPlatformInfo = _entityService.getLatestAspect(urn, "dataPlatformInfo"); RecordTemplate dataPlatformInfo = _entityService.getLatestAspect(urn, "dataPlatformInfo");
if (dataPlatformInfo == null) { if (dataPlatformInfo == null) {
return false; return false;

View File

@ -131,15 +131,15 @@ public interface EntitySearchService {
/** /**
* Returns number of documents per field value given the field and filters * Returns number of documents per field value given the field and filters
* *
* @param entityName name of the entity, if empty aggregate over all entities * @param entityNames list of name of entities to aggregate across, if empty aggregate over all entities
* @param field the field name for aggregate * @param field the field name for aggregate
* @param requestParams filters to apply before aggregating * @param requestParams filters to apply before aggregating
* @param limit the number of aggregations to return * @param limit the number of aggregations to return
* @return * @return
*/ */
@Nonnull @Nonnull
Map<String, Long> aggregateByValue(@Nullable String entityName, @Nonnull String field, @Nullable Filter requestParams, Map<String, Long> aggregateByValue(@Nullable List<String> entityNames, @Nonnull String field,
int limit); @Nullable Filter requestParams, int limit);
/** /**
* Gets a list of groups/entities that match given browse request. * Gets a list of groups/entities that match given browse request.