diff --git a/build.gradle b/build.gradle index 60c04c60d3..c54714802f 100644 --- a/build.gradle +++ b/build.gradle @@ -41,6 +41,7 @@ project.ext.externalDependency = [ 'commonsLang': 'commons-lang:commons-lang:2.6', 'commonsCollections': 'commons-collections:commons-collections:3.2.2', 'data' : 'com.linkedin.pegasus:data:' + pegasusVersion, + 'dgraph4j' : 'io.dgraph:dgraph4j:21.03.1', 'dropwizardMetricsCore': 'io.dropwizard.metrics:metrics-core:4.2.3', 'dropwizardMetricsJmx': 'io.dropwizard.metrics:metrics-jmx:4.2.3', 'ebean': 'io.ebean:ebean:11.33.3', diff --git a/metadata-io/build.gradle b/metadata-io/build.gradle index 342d94610e..898bfbf673 100644 --- a/metadata-io/build.gradle +++ b/metadata-io/build.gradle @@ -15,6 +15,7 @@ dependencies { compile spec.product.pegasus.data compile spec.product.pegasus.generator + compile externalDependency.dgraph4j exclude group: 'com.google.guava', module: 'guava' compile externalDependency.lombok compile externalDependency.elasticSearchRest compile externalDependency.elasticSearchTransport @@ -25,6 +26,7 @@ dependencies { compile externalDependency.ebean enhance externalDependency.ebeanAgent compile externalDependency.opentelemetryAnnotations + compile externalDependency.resilience4j compile externalDependency.springContext annotationProcessor externalDependency.lombok @@ -39,6 +41,7 @@ dependencies { testCompile externalDependency.testContainers testCompile externalDependency.testContainersJunit testCompile externalDependency.testContainersElasticsearch + testCompile externalDependency.lombok testCompile project(':test-models') testAnnotationProcessor externalDependency.lombok diff --git a/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphExecutor.java b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphExecutor.java new file mode 100644 index 0000000000..dc267d16e4 --- /dev/null +++ b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphExecutor.java @@ -0,0 +1,98 @@ +package com.linkedin.metadata.graph; + +import io.dgraph.DgraphClient; +import io.dgraph.TxnConflictException; +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.retry.Retry; +import io.github.resilience4j.retry.RetryConfig; +import io.grpc.StatusRuntimeException; +import lombok.extern.slf4j.Slf4j; + +import java.time.Duration; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.function.Function; + +@Slf4j +public class DgraphExecutor { + + // requests are retried with an exponential randomized backoff + // wait 0.01s, 0.02s, 0.04s, 0.08s, ..., 10s, all ±50% + private static final Duration INITIAL_DURATION = Duration.ofMillis(10); + private static final Duration MAX_DURATION = Duration.ofSeconds(10); + private static final double BACKOFF_MULTIPLIER = 2.0; + private static final double RANDOMIZATION_FACTOR = 0.5; + + private final DgraphClient _client; + private final Retry _retry; + + public DgraphExecutor(DgraphClient client, int maxAttempts) { + this._client = client; + + RetryConfig config = RetryConfig.custom() + .intervalFunction(IntervalFunction.ofExponentialRandomBackoff(INITIAL_DURATION, BACKOFF_MULTIPLIER, RANDOMIZATION_FACTOR, MAX_DURATION)) + .retryOnException(DgraphExecutor::isRetryableException) + .failAfterMaxAttempts(true) + .maxAttempts(maxAttempts) + .build(); + this._retry = Retry.of("DgraphExecutor", config); + } + + /** + * Executes the given DgraphClient call and retries retry-able exceptions. + * Subsequent executions will experience an exponential randomized backoff. + * + * @param func call on the provided DgraphClient + * @param return type of the function + * @return return value of the function + * @throws io.github.resilience4j.retry.MaxRetriesExceeded if max attempts exceeded + */ + public T executeFunction(Function func) { + return Retry.decorateFunction(this._retry, func).apply(_client); + } + + /** + * Executes the given DgraphClient call and retries retry-able exceptions. + * Subsequent executions will experience an exponential randomized backoff. + * + * @param func call on the provided DgraphClient + * @throws io.github.resilience4j.retry.MaxRetriesExceeded if max attempts exceeded + */ + public void executeConsumer(Consumer func) { + this._retry.executeSupplier(() -> { + func.accept(_client); + return null; + }); + } + + /** + * Defines which DgraphClient exceptions are being retried. + * + * @param t exception from DgraphClient + * @return true if this exception can be retried + */ + private static boolean isRetryableException(Throwable t) { + // unwrap RuntimeException and ExecutionException + while (true) { + if ((t instanceof RuntimeException || t instanceof ExecutionException) && t.getCause() != null) { + t = t.getCause(); + continue; + } + break; + } + + // retry-able exceptions + if (t instanceof TxnConflictException + || t instanceof StatusRuntimeException && ( + t.getMessage().contains("operation opIndexing is already running") + || t.getMessage().contains("Please retry") + || t.getMessage().contains("DEADLINE_EXCEEDED:") + || t.getMessage().contains("context deadline exceeded") + || t.getMessage().contains("Only leader can decide to commit or abort") + )) { + log.debug("retrying request due to {}", t.getMessage()); + return true; + } + return false; + } +} diff --git a/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphGraphService.java b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphGraphService.java new file mode 100644 index 0000000000..6767229515 --- /dev/null +++ b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphGraphService.java @@ -0,0 +1,667 @@ +package com.linkedin.metadata.graph; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.ByteString; +import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.query.filter.Criterion; +import com.linkedin.metadata.query.filter.CriterionArray; +import com.linkedin.metadata.query.filter.Filter; +import com.linkedin.metadata.query.filter.RelationshipDirection; +import com.linkedin.metadata.query.filter.RelationshipFilter; +import io.dgraph.DgraphClient; +import io.dgraph.DgraphProto.Mutation; +import io.dgraph.DgraphProto.NQuad; +import io.dgraph.DgraphProto.Operation; +import io.dgraph.DgraphProto.Request; +import io.dgraph.DgraphProto.Response; +import io.dgraph.DgraphProto.Value; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.tuple.Pair; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +@Slf4j +public class DgraphGraphService implements GraphService { + + // calls to Dgraph cluster will be retried if they throw retry-able exceptions + // with a max number of attempts of 160 a call will finally fail after around 15 minutes + private static final int MAX_ATTEMPTS = 160; + + private final @Nonnull DgraphExecutor _dgraph; + + private static final String URN_RELATIONSHIP_TYPE = "urn"; + private static final String TYPE_RELATIONSHIP_TYPE = "type"; + private static final String KEY_RELATIONSHIP_TYPE = "key"; + + + @Getter(lazy = true) + // we want to defer initialization of schema (accessing Dgraph server) to the first time accessing _schema + private final DgraphSchema _schema = getSchema(); + + public DgraphGraphService(@Nonnull DgraphClient client) { + this._dgraph = new DgraphExecutor(client, MAX_ATTEMPTS); + } + + protected @Nonnull DgraphSchema getSchema() { + Response response = _dgraph.executeFunction(dgraphClient -> + dgraphClient.newReadOnlyTransaction().doRequest( + Request.newBuilder().setQuery("schema { predicate }").build() + ) + ); + DgraphSchema schema = getSchema(response.getJson().toStringUtf8()).withDgraph(_dgraph); + + if (schema.isEmpty()) { + Operation setSchema = Operation.newBuilder() + .setSchema("" + + ": string @index(hash) @upsert .\n" + + ": string @index(hash) .\n" + + ": string @index(hash) .\n" + ) + .build(); + _dgraph.executeConsumer(dgraphClient -> dgraphClient.alter(setSchema)); + } + + return schema; + } + + protected static @Nonnull DgraphSchema getSchema(@Nonnull String json) { + Map data = getDataFromResponseJson(json); + + Object schemaObj = data.get("schema"); + if (!(schemaObj instanceof List)) { + log.info("The result from Dgraph did not contain a 'schema' field, or that field is not a List"); + return DgraphSchema.empty(); + } + + List schemaList = (List) schemaObj; + Set fieldNames = schemaList.stream().flatMap(fieldObj -> { + if (!(fieldObj instanceof Map)) { + return Stream.empty(); + } + + Map fieldMap = (Map) fieldObj; + if (!(fieldMap.containsKey("predicate") && fieldMap.get("predicate") instanceof String)) { + return Stream.empty(); + } + + String fieldName = (String) fieldMap.get("predicate"); + return Stream.of(fieldName); + }).filter(f -> !f.startsWith("dgraph.")).collect(Collectors.toSet()); + + Object typesObj = data.get("types"); + if (!(typesObj instanceof List)) { + log.info("The result from Dgraph did not contain a 'types' field, or that field is not a List"); + return DgraphSchema.empty(); + } + + List types = (List) typesObj; + Map> typeFields = types.stream().flatMap(typeObj -> { + if (!(typeObj instanceof Map)) { + return Stream.empty(); + } + + Map typeMap = (Map) typeObj; + if (!(typeMap.containsKey("fields") + && typeMap.containsKey("name") + && typeMap.get("fields") instanceof List + && typeMap.get("name") instanceof String)) { + return Stream.empty(); + } + + String typeName = (String) typeMap.get("name"); + List fieldsList = (List) typeMap.get("fields"); + + Set fields = fieldsList.stream().flatMap(fieldObj -> { + if (!(fieldObj instanceof Map)) { + return Stream.empty(); + } + + Map fieldMap = (Map) fieldObj; + if (!(fieldMap.containsKey("name") && fieldMap.get("name") instanceof String)) { + return Stream.empty(); + } + + String fieldName = (String) fieldMap.get("name"); + return Stream.of(fieldName); + }).filter(f -> !f.startsWith("dgraph.")).collect(Collectors.toSet()); + return Stream.of(Pair.of(typeName, fields)); + }).filter(t -> !t.getKey().startsWith("dgraph.")).collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + + return new DgraphSchema(fieldNames, typeFields); + } + + @Override + public void addEdge(Edge edge) { + log.debug(String.format("Adding Edge source: %s, destination: %s, type: %s", + edge.getSource(), + edge.getDestination(), + edge.getRelationshipType())); + + // add the relationship type to the schema + // TODO: translate edge name to allowed dgraph uris + String sourceEntityType = getDgraphType(edge.getSource()); + String relationshipType = edge.getRelationshipType(); + get_schema().ensureField(sourceEntityType, relationshipType, URN_RELATIONSHIP_TYPE, TYPE_RELATIONSHIP_TYPE, KEY_RELATIONSHIP_TYPE); + + // lookup the source and destination nodes + // TODO: add escape for string values + String query = String.format("query {\n" + + " src as var(func: eq(urn, \"%s\"))\n" + + " dst as var(func: eq(urn, \"%s\"))\n" + + "}", edge.getSource(), edge.getDestination()); + String srcVar = "uid(src)"; + String dstVar = "uid(dst)"; + + // edge case: source and destination are same node + if (edge.getSource().equals(edge.getDestination())) { + query = String.format("query {\n" + + " node as var(func: eq(urn, \"%s\"))\n" + + "}", edge.getSource()); + srcVar = "uid(node)"; + dstVar = "uid(node)"; + } + + // create source and destination nodes if they do not exist + // and create the new edge between them + // TODO: add escape for string values + // TODO: translate edge name to allowed dgraph uris + StringJoiner mutations = new StringJoiner("\n"); + mutations.add(String.format("%s \"%s\" .", srcVar, getDgraphType(edge.getSource()))); + mutations.add(String.format("%s \"%s\" .", srcVar, edge.getSource())); + mutations.add(String.format("%s \"%s\" .", srcVar, edge.getSource().getEntityType())); + mutations.add(String.format("%s \"%s\" .", srcVar, edge.getSource().getEntityKey())); + if (!edge.getSource().equals(edge.getDestination())) { + mutations.add(String.format("%s \"%s\" .", dstVar, getDgraphType(edge.getDestination()))); + mutations.add(String.format("%s \"%s\" .", dstVar, edge.getDestination())); + mutations.add(String.format("%s \"%s\" .", dstVar, edge.getDestination().getEntityType())); + mutations.add(String.format("%s \"%s\" .", dstVar, edge.getDestination().getEntityKey())); + } + mutations.add(String.format("%s <%s> %s .", srcVar, edge.getRelationshipType(), dstVar)); + + log.debug("Query: " + query); + log.debug("Mutations: " + mutations); + + // construct the upsert + Mutation mutation = Mutation.newBuilder() + .setSetNquads(ByteString.copyFromUtf8(mutations.toString())) + .build(); + Request request = Request.newBuilder() + .setQuery(query) + .addMutations(mutation) + .setCommitNow(true) + .build(); + + // run the request + _dgraph.executeFunction(client -> client.newTransaction().doRequest(request)); + } + + private static @Nonnull String getDgraphType(@Nonnull Urn urn) { + return urn.getNamespace() + ":" + urn.getEntityType(); + } + + // Returns reversed and directed relationship types: + // returns <~rel> on outgoing and on incoming and both on undirected + private static List getDirectedRelationshipTypes(List relationships, + RelationshipDirection direction) { + + if (direction == RelationshipDirection.OUTGOING || direction == RelationshipDirection.UNDIRECTED) { + List outgoingRelationships = relationships.stream() + .map(type -> "~" + type).collect(Collectors.toList()); + + if (direction == RelationshipDirection.OUTGOING) { + return outgoingRelationships; + } else { + relationships = new ArrayList<>(relationships); + relationships.addAll(outgoingRelationships); + } + } + + return relationships; + } + + protected static String getQueryForRelatedEntities(@Nullable String sourceType, + @Nonnull Filter sourceEntityFilter, + @Nullable String destinationType, + @Nonnull Filter destinationEntityFilter, + @Nonnull List relationshipTypes, + @Nonnull RelationshipFilter relationshipFilter, + int offset, + int count) { + if (relationshipTypes.isEmpty()) { + // we would have to construct a query that never returns any results + // just do not call this method in the first place + throw new IllegalArgumentException("The relationship types must not be empty"); + } + + + if (sourceEntityFilter.hasCriteria() || destinationEntityFilter.hasCriteria()) { + throw new IllegalArgumentException("The DgraphGraphService does not support criteria in source or destination entity filter"); + } + + //noinspection ConstantConditions + if (sourceEntityFilter.hasOr() && sourceEntityFilter.getOr().size() > 1 + || destinationEntityFilter.hasOr() && destinationEntityFilter.getOr().size() > 1) { + throw new IllegalArgumentException("The DgraphGraphService does not support multiple OR criteria in source or destination entity filter"); + } + + //noinspection ConstantConditions + if (relationshipFilter.hasCriteria() || relationshipFilter.hasOr() && relationshipFilter.getOr().size() > 0) { + throw new IllegalArgumentException("The DgraphGraphService does not support any criteria for the relationship filter"); + } + + // We are not querying for and return + // but we reverse the relationship and query for <~relationship> + // this guarantees there are no duplicates among the returned s + final List directedRelationshipTypes = getDirectedRelationshipTypes( + relationshipTypes, relationshipFilter.getDirection() + ); + + List filters = new ArrayList<>(); + + Set destinationNodeFilterNames = new HashSet<>(); + String sourceTypeFilterName = null; + String destinationTypeFilterName = null; + List sourceFilterNames = new ArrayList<>(); + List destinationFilterNames = new ArrayList<>(); + List relationshipTypeFilterNames = new ArrayList<>(); + + if (sourceType != null) { + sourceTypeFilterName = "sourceType"; + // TODO: escape string value + filters.add(String.format("%s as var(func: eq(, \"%s\"))", sourceTypeFilterName, sourceType)); + } + + if (destinationType != null) { + destinationTypeFilterName = "destinationType"; + // TODO: escape string value + filters.add(String.format("%s as var(func: eq(, \"%s\"))", destinationTypeFilterName, destinationType)); + } + + //noinspection ConstantConditions + if (sourceEntityFilter.hasOr() && sourceEntityFilter.getOr().size() == 1) { + CriterionArray sourceCriteria = sourceEntityFilter.getOr().get(0).getAnd(); + IntStream.range(0, sourceCriteria.size()) + .forEach(idx -> { + String sourceFilterName = "sourceFilter" + (idx + 1); + sourceFilterNames.add(sourceFilterName); + Criterion criterion = sourceCriteria.get(idx); + // TODO: escape field name and string value + filters.add(String.format("%s as var(func: eq(<%s>, \"%s\"))", sourceFilterName, criterion.getField(), criterion.getValue())); + }); + } + + //noinspection ConstantConditions + if (destinationEntityFilter.hasOr() && destinationEntityFilter.getOr().size() == 1) { + CriterionArray destinationCriteria = destinationEntityFilter.getOr().get(0).getAnd(); + IntStream.range(0, destinationCriteria.size()) + .forEach(idx -> { + String sourceFilterName = "destinationFilter" + (idx + 1); + destinationFilterNames.add(sourceFilterName); + Criterion criterion = destinationCriteria.get(idx); + // TODO: escape field name and string value + filters.add(String.format("%s as var(func: eq(<%s>, \"%s\"))", sourceFilterName, criterion.getField(), criterion.getValue())); + }); + } + + IntStream.range(0, directedRelationshipTypes.size()) + .forEach(idx -> { + String relationshipTypeFilterName = "relationshipType" + (idx + 1); + relationshipTypeFilterNames.add(relationshipTypeFilterName); + // TODO: escape string value + filters.add(String.format("%s as var(func: has(<%s>))", relationshipTypeFilterName, directedRelationshipTypes.get(idx))); + }); + + // the destination node filter is the first filter that is being applied on the destination node + // we can add multiple filters, they will combine as OR + if (destinationTypeFilterName != null) { + destinationNodeFilterNames.add(destinationTypeFilterName); + } + destinationNodeFilterNames.addAll(destinationFilterNames); + destinationNodeFilterNames.addAll(relationshipTypeFilterNames); + + StringJoiner destinationNodeFilterJoiner = new StringJoiner(", "); + destinationNodeFilterNames.stream().sorted().forEach(destinationNodeFilterJoiner::add); + String destinationNodeFilter = destinationNodeFilterJoiner.toString(); + + String filterConditions = getFilterConditions( + sourceTypeFilterName, destinationTypeFilterName, + sourceFilterNames, destinationFilterNames, + relationshipTypeFilterNames, directedRelationshipTypes + ); + + StringJoiner relationshipsJoiner = new StringJoiner("\n "); + getRelationships(sourceTypeFilterName, sourceFilterNames, directedRelationshipTypes) + .forEach(relationshipsJoiner::add); + String relationships = relationshipsJoiner.toString(); + + StringJoiner filterJoiner = new StringJoiner("\n "); + filters.forEach(filterJoiner::add); + String filterExpressions = filterJoiner.toString(); + + return String.format("query {\n" + + " %s\n" + + "\n" + + " result (func: uid(%s), first: %d, offset: %d) %s {\n" + + " \n" + + " %s\n" + + " }\n" + + "}", + filterExpressions, + destinationNodeFilter, + count, offset, + filterConditions, + relationships); + } + + @Nonnull + @Override + public RelatedEntitiesResult findRelatedEntities(@Nullable String sourceType, + @Nonnull Filter sourceEntityFilter, + @Nullable String destinationType, + @Nonnull Filter destinationEntityFilter, + @Nonnull List relationshipTypes, + @Nonnull RelationshipFilter relationshipFilter, + int offset, + int count) { + if (relationshipTypes.isEmpty() || relationshipTypes.stream().noneMatch(relationship -> get_schema().hasField(relationship))) { + return new RelatedEntitiesResult(offset, 0, 0, Collections.emptyList()); + } + + String query = getQueryForRelatedEntities( + sourceType, sourceEntityFilter, + destinationType, destinationEntityFilter, + relationshipTypes.stream().filter(get_schema()::hasField).collect(Collectors.toList()), + relationshipFilter, + offset, count + ); + + Request request = Request.newBuilder() + .setQuery(query) + .build(); + + log.debug("Query: " + query); + Response response = _dgraph.executeFunction(client -> client.newReadOnlyTransaction().doRequest(request)); + String json = response.getJson().toStringUtf8(); + Map data = getDataFromResponseJson(json); + + List entities = getRelatedEntitiesFromResponseData(data); + int total = offset + entities.size(); + if (entities.size() == count) { + // indicate that there might be more results + total++; + } + return new RelatedEntitiesResult(offset, entities.size(), total, entities); + } + + // Creates filter conditions from destination to source nodes + protected static @Nonnull String getFilterConditions(@Nullable String sourceTypeFilterName, + @Nullable String destinationTypeFilterName, + @Nonnull List sourceFilterNames, + @Nonnull List destinationFilterNames, + @Nonnull List relationshipTypeFilterNames, + @Nonnull List relationshipTypes) { + if (relationshipTypes.size() != relationshipTypeFilterNames.size()) { + throw new IllegalArgumentException("relationshipTypeFilterNames and relationshipTypes " + + "must have same size: " + relationshipTypeFilterNames + " vs. " + relationshipTypes); + } + + if (sourceTypeFilterName == null + && destinationTypeFilterName == null + && sourceFilterNames.isEmpty() + && destinationFilterNames.isEmpty() + && relationshipTypeFilterNames.isEmpty()) { + return ""; + } + + StringJoiner andJoiner = new StringJoiner(" AND\n "); + if (destinationTypeFilterName != null) { + andJoiner.add(String.format("uid(%s)", destinationTypeFilterName)); + } + + destinationFilterNames.forEach(filter -> andJoiner.add(String.format("uid(%s)", filter))); + + if (!relationshipTypes.isEmpty()) { + StringJoiner orJoiner = new StringJoiner(" OR\n "); + IntStream.range(0, relationshipTypes.size()).forEach(idx -> orJoiner.add(getRelationshipCondition( + relationshipTypes.get(idx), relationshipTypeFilterNames.get(idx), + sourceTypeFilterName, sourceFilterNames + ))); + String relationshipCondition = orJoiner.toString(); + andJoiner.add(String.format("(\n %s\n )", relationshipCondition)); + } + + String conditions = andJoiner.toString(); + return String.format("@filter(\n %s\n )", conditions); + } + + protected static String getRelationshipCondition(@Nonnull String relationshipType, + @Nonnull String relationshipTypeFilterName, + @Nullable String objectFilterName, + @Nonnull List destinationFilterNames) { + StringJoiner andJoiner = new StringJoiner(" AND "); + andJoiner.add(String.format("uid(%s)", relationshipTypeFilterName)); + if (objectFilterName != null) { + andJoiner.add(String.format("uid_in(<%s>, uid(%s))", relationshipType, objectFilterName)); + } + destinationFilterNames.forEach(filter -> andJoiner.add(String.format("uid_in(<%s>, uid(%s))", relationshipType, filter))); + return andJoiner.toString(); + } + + + // Creates filter conditions from destination to source nodes + protected static @Nonnull List getRelationships(@Nullable String sourceTypeFilterName, + @Nonnull List sourceFilterNames, + @Nonnull List relationshipTypes) { + return relationshipTypes.stream().map(relationshipType -> { + StringJoiner andJoiner = new StringJoiner(" AND "); + if (sourceTypeFilterName != null) { + andJoiner.add(String.format("uid(%s)", sourceTypeFilterName)); + } + sourceFilterNames.forEach(filterName -> andJoiner.add(String.format("uid(%s)", filterName))); + + if (andJoiner.length() > 0) { + return String.format("<%s> @filter( %s ) { }", relationshipType, andJoiner); + } else { + return String.format("<%s> { }", relationshipType); + } + }).collect(Collectors.toList()); + } + + protected static Map getDataFromResponseJson(String json) { + ObjectMapper mapper = new ObjectMapper(); + TypeReference> typeRef = new TypeReference>() { }; + try { + return mapper.readValue(json, typeRef); + } catch (IOException e) { + throw new RuntimeException("Failed to parse response json: " + json.substring(0, 1000), e); + } + } + + protected static List getRelatedEntitiesFromResponseData(Map data) { + Object obj = data.get("result"); + if (!(obj instanceof List)) { + throw new IllegalArgumentException( + "The result from Dgraph did not contain a 'result' field, or that field is not a List" + ); + } + + List results = (List) obj; + return results.stream().flatMap(destinationObj -> { + if (!(destinationObj instanceof Map)) { + return Stream.empty(); + } + + Map destination = (Map) destinationObj; + if (destination.containsKey("urn") && destination.get("urn") instanceof String) { + String urn = (String) destination.get("urn"); + + return destination.entrySet().stream() + .filter(entry -> !entry.getKey().equals("urn")) + .flatMap(entry -> { + Object relationshipObj = entry.getKey(); + Object sourcesObj = entry.getValue(); + if (!(relationshipObj instanceof String && sourcesObj instanceof List)) { + return Stream.empty(); + } + + String relationship = (String) relationshipObj; + List sources = (List) sourcesObj; + + if (sources.size() == 0) { + return Stream.empty(); + } + + if (relationship.startsWith("~")) { + relationship = relationship.substring(1); + } + + return Stream.of(relationship); + }) + // for undirected we get duplicate relationships + .distinct() + .map(relationship -> new RelatedEntity(relationship, urn)); + } + + return Stream.empty(); + }).collect(Collectors.toList()); + } + + @Override + public void removeNode(@Nonnull Urn urn) { + String query = String.format("query {\n" + + " node as var(func: eq(urn, \"%s\"))\n" + + "}", urn); + String deletion = "uid(node) * * ."; + + log.debug("Query: " + query); + log.debug("Delete: " + deletion); + + Mutation mutation = Mutation.newBuilder() + .setDelNquads(ByteString.copyFromUtf8(deletion)) + .build(); + Request request = Request.newBuilder() + .setQuery(query) + .addMutations(mutation) + .setCommitNow(true) + .build(); + + _dgraph.executeConsumer(client -> client.newTransaction().doRequest(request)); + } + + @Override + public void removeEdgesFromNode(@Nonnull Urn urn, + @Nonnull List relationshipTypes, + @Nonnull RelationshipFilter relationshipFilter) { + if (relationshipTypes.isEmpty()) { + return; + } + + RelationshipDirection direction = relationshipFilter.getDirection(); + + if (direction == RelationshipDirection.OUTGOING || direction == RelationshipDirection.UNDIRECTED) { + removeOutgoingEdgesFromNode(urn, relationshipTypes); + } + + if (direction == RelationshipDirection.INCOMING || direction == RelationshipDirection.UNDIRECTED) { + removeIncomingEdgesFromNode(urn, relationshipTypes); + } + } + + private void removeOutgoingEdgesFromNode(@Nonnull Urn urn, + @Nonnull List relationshipTypes) { + // TODO: add escape for string values + String query = String.format("query {\n" + + " node as var(func: eq(, \"%s\"))\n" + + "}", urn); + + Value star = Value.newBuilder().setDefaultVal("_STAR_ALL").build(); + List deletions = relationshipTypes.stream().map(relationshipType -> + NQuad.newBuilder() + .setSubject("uid(node)") + .setPredicate(relationshipType) + .setObjectValue(star) + .build() + ).collect(Collectors.toList()); + + log.debug("Query: " + query); + log.debug("Deletions: " + deletions); + + Mutation mutation = Mutation.newBuilder() + .addAllDel(deletions) + .build(); + Request request = Request.newBuilder() + .setQuery(query) + .addMutations(mutation) + .setCommitNow(true) + .build(); + + _dgraph.executeConsumer(client -> client.newTransaction().doRequest(request)); + } + + private void removeIncomingEdgesFromNode(@Nonnull Urn urn, + @Nonnull List relationshipTypes) { + // TODO: add escape for string values + StringJoiner reverseEdges = new StringJoiner("\n "); + IntStream.range(0, relationshipTypes.size()).forEach(idx -> + reverseEdges.add("<~" + relationshipTypes.get(idx) + "> { uids" + (idx + 1) + " as uid }") + ); + String query = String.format("query {\n" + + " node as var(func: eq(, \"%s\"))\n" + + "\n" + + " var(func: uid(node)) @normalize {\n" + + " %s\n" + + " }\n" + + "}", urn, reverseEdges); + + StringJoiner deletions = new StringJoiner("\n"); + IntStream.range(0, relationshipTypes.size()).forEach(idx -> + deletions.add("uid(uids" + (idx + 1) + ") <" + relationshipTypes.get(idx) + "> uid(node) .") + ); + + log.debug("Query: " + query); + log.debug("Deletions: " + deletions); + + Mutation mutation = Mutation.newBuilder() + .setDelNquads(ByteString.copyFromUtf8(deletions.toString())) + .build(); + Request request = Request.newBuilder() + .setQuery(query) + .addMutations(mutation) + .setCommitNow(true) + .build(); + + _dgraph.executeConsumer(client -> client.newTransaction().doRequest(request)); + } + + @Override + public void configure() { } + + @Override + public void clear() { + log.debug("dropping Dgraph data"); + + Operation dropAll = Operation.newBuilder().setDropOp(Operation.DropOp.ALL).build(); + _dgraph.executeConsumer(client -> client.alter(dropAll)); + + // drop schema cache + get_schema().clear(); + + // setup urn, type and key relationships + getSchema(); + } +} diff --git a/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphSchema.java b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphSchema.java new file mode 100644 index 0000000000..1dfc811365 --- /dev/null +++ b/metadata-io/src/main/java/com/linkedin/metadata/graph/DgraphSchema.java @@ -0,0 +1,128 @@ +package com.linkedin.metadata.graph; + +import io.dgraph.DgraphProto; +import lombok.extern.slf4j.Slf4j; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; +import java.util.stream.Collectors; + +/** + * Provides a thread-safe Dgraph schema. Returned data structures are immutable. + */ +@Slf4j +public class DgraphSchema { + private final @Nonnull Set fields; + private final @Nonnull Map> types; + private final DgraphExecutor dgraph; + + public static DgraphSchema empty() { + return new DgraphSchema(Collections.emptySet(), Collections.emptyMap(), null); + } + + public DgraphSchema(@Nonnull Set fields, @Nonnull Map> types) { + this(fields, types, null); + } + + public DgraphSchema(@Nonnull Set fields, @Nonnull Map> types, DgraphExecutor dgraph) { + this.fields = fields; + this.types = types; + this.dgraph = dgraph; + } + + /** + * Adds the given DgraphExecutor to this schema returning a new instance. + * Be aware this and the new instance share the underlying fields and types datastructures. + * + * @param dgraph dgraph executor to add + * @return new instance + */ + public DgraphSchema withDgraph(DgraphExecutor dgraph) { + return new DgraphSchema(this.fields, this.types, dgraph); + } + + synchronized public boolean isEmpty() { + return fields.isEmpty(); + } + + synchronized public Set getFields() { + // Provide an unmodifiable copy + return Collections.unmodifiableSet(new HashSet<>(fields)); + } + + synchronized public Set getFields(String typeName) { + // Provide an unmodifiable copy + return Collections.unmodifiableSet(new HashSet<>(types.getOrDefault(typeName, Collections.emptySet()))); + } + + synchronized public Map> getTypes() { + // Provide an unmodifiable copy of the map and contained sets + return Collections.unmodifiableMap( + new HashSet<>(types.entrySet()).stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> Collections.unmodifiableSet(new HashSet<>(e.getValue())) + )) + ); + } + + synchronized public boolean hasType(String typeName) { + return types.containsKey(typeName); + } + + synchronized public boolean hasField(String fieldName) { + return fields.contains(fieldName); + } + + synchronized public boolean hasField(String typeName, String fieldName) { + return types.getOrDefault(typeName, Collections.emptySet()).contains(fieldName); + } + + synchronized public void ensureField(String typeName, String fieldName, String... existingFieldNames) { + // quickly check if the field is known for this type + if (hasField(typeName, fieldName)) { + return; + } + + // add type and field to schema + StringJoiner schema = new StringJoiner("\n"); + + if (!fields.contains(fieldName)) { + schema.add(String.format("<%s>: [uid] @reverse .", fieldName)); + } + + // update the schema on the Dgraph cluster + Set allTypesFields = new HashSet<>(Arrays.asList(existingFieldNames)); + allTypesFields.addAll(types.getOrDefault(typeName, Collections.emptySet())); + allTypesFields.add(fieldName); + + if (dgraph != null) { + log.info("Adding predicate {} for type {} to schema", fieldName, typeName); + + StringJoiner type = new StringJoiner("\n "); + allTypesFields.stream().map(t -> "<" + t + ">").forEach(type::add); + schema.add(String.format("type <%s> {\n %s\n}", typeName, type)); + log.debug("Adding to schema: " + schema); + DgraphProto.Operation setSchema = DgraphProto.Operation.newBuilder().setSchema(schema.toString()).setRunInBackground(true).build(); + dgraph.executeConsumer(dgraphClient -> dgraphClient.alter(setSchema)); + } + + // now that the schema has been updated on dgraph we can cache this new type / field + // ensure type and fields of type exist + if (!types.containsKey(typeName)) { + types.put(typeName, new HashSet<>()); + } + types.get(typeName).add(fieldName); + fields.add(fieldName); + } + + synchronized public void clear() { + types.clear(); + fields.clear(); + } +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphContainer.java b/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphContainer.java new file mode 100644 index 0000000000..6847b9bb93 --- /dev/null +++ b/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphContainer.java @@ -0,0 +1,238 @@ +package com.linkedin.metadata.graph; + +import com.github.dockerjava.api.command.InspectContainerResponse; +import lombok.NonNull; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.HttpWaitStrategy; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.containers.wait.strategy.WaitAllStrategy; +import org.testcontainers.containers.wait.strategy.WaitStrategy; +import org.testcontainers.utility.DockerImageName; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; +import java.util.stream.Stream; + +import static java.net.HttpURLConnection.HTTP_OK; +import static java.util.stream.Collectors.toSet; + +public class DgraphContainer extends GenericContainer { + + /** + * The image defaults to the official Dgraph image: Dgraph. + */ + public static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("dgraph/dgraph"); + + private static final int HTTP_PORT = 8080; + + private static final int GRPC_PORT = 9080; + + private boolean started = false; + + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + super.containerIsStarted(containerInfo); + started = true; + } + + @Override + protected void containerIsStopped(InspectContainerResponse containerInfo) { + super.containerIsStopped(containerInfo); + started = false; + } + + private final Map zeroArguments = new HashMap<>(); + + private final Map alphaArguments = new HashMap<>(); + + /** + * Creates a DgraphContainer using a specific docker image. Connect the container + * to another DgraphContainer to form a cluster via `peerAlias`. + * + * @param dockerImageName The docker image to use. + */ + public DgraphContainer(@NonNull final DockerImageName dockerImageName) { + super(dockerImageName); + + dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME); + + WaitStrategy waitForLeader = new LogMessageWaitStrategy() + .withRegEx(".* Got Zero leader: .*\n"); + WaitStrategy waitForCluster = new LogMessageWaitStrategy() + .withRegEx(".* Server is ready\n"); + WaitStrategy waitForHttp = new HttpWaitStrategy() + .forPort(HTTP_PORT) + .forStatusCodeMatching(response -> response == HTTP_OK); + + this.waitStrategy = new WaitAllStrategy() + .withStrategy(waitForLeader) + .withStrategy(waitForCluster) + .withStrategy(waitForHttp) + .withStartupTimeout(Duration.ofMinutes(1)); + + if (dockerImageName.getVersionPart().compareTo("v21.03.0") < 0) { + withAlphaArgument("whitelist", "0.0.0.0/0"); + } else { + withAlphaArgumentValues("security", "whitelist=0.0.0.0/0"); + } + + addExposedPorts(HTTP_PORT, GRPC_PORT); + } + + /** + * Adds an argument to the zero command. + * + * @param argument name of the argument + * @param value value, null if argument is a flag + * @return this + */ + public DgraphContainer withZeroArgument(@NonNull String argument, String value) { + addArgument(zeroArguments, argument, value); + return this; + } + + /** + * Adds a value to an argument list to the zero command. + * + * Some arguments of the zero command form a list of values, e.g. `audit` or `raft`. + * These values are separated by a ";". Setting multiple values for those arguments should + * be done via this method. + * + * @param argument name of the argument + * @param values values to add to the argument + * @return this + */ + public DgraphContainer withZeroArgumentValues(@NonNull String argument, @NonNull String... values) { + addArgumentValues(zeroArguments, argument, values); + return this; + } + + /** + * Adds an argument to the alpha command. + * + * @param argument name of the argument + * @param value value, null if argument is a flag + * @return this + */ + public DgraphContainer withAlphaArgument(@NonNull String argument, String value) { + addArgument(alphaArguments, argument, value); + return this; + } + + /** + * Adds a value to an argument list to the alpha command. + * + * Some arguments of the alpha command form a list of values, e.g. `audit` or `raft`. + * These values are separated by a ";". Setting multiple values for those arguments should + * be done via this method. + * + * @param argument name of the argument + * @param values values to add to the argument + * @return this + */ + public DgraphContainer withAlphaArgumentValues(@NonNull String argument, @NonNull String... values) { + addArgumentValues(alphaArguments, argument, values); + return this; + } + + private void addArgument(Map arguments, @NonNull String argument, String value) { + if (started) { + throw new IllegalStateException("The container started already, cannot amend command arguments"); + } + + arguments.put(argument, value); + } + + private void addArgumentValues(Map arguments, @NonNull String argument, @NonNull String... values) { + if (started) { + throw new IllegalStateException("The container started already, cannot amend command arguments"); + } + + StringJoiner joiner = new StringJoiner("; "); + Arrays.stream(values).forEach(joiner::add); + String value = joiner.toString(); + + if (arguments.containsKey(argument)) { + arguments.put(argument, arguments.get(argument) + "; " + value); + } else { + arguments.put(argument, value); + } + } + + /** + * Provides the command used to start the zero process. Command line arguments can be added + * by calling `withZeroArgument` and `withZeroArgumentValues` before calling this method. + * @return command string + */ + public @NonNull String getZeroCommand() { + return getCommand("dgraph zero", zeroArguments); + } + + /** + * Provides the command used to start the alpha process. Command line arguments can be added + * by calling `withAlphaArgument` and `withAlphaArgumentValues` before calling this method. + * @return command string + */ + public @NonNull String getAlphaCommand() { + return getCommand("dgraph alpha", alphaArguments); + } + + private @NonNull String getCommand(@NonNull String command, @NonNull Map arguments) { + StringJoiner joiner = new StringJoiner(" --"); + + arguments.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(argument -> { + if (argument.getValue() == null) { + return argument.getKey(); + } else { + return argument.getKey() + " \"" + argument.getValue() + "\""; + } + }).forEach(joiner::add); + + if (joiner.length() == 0) { + return command; + } else { + return command + " --" + joiner; + } + } + + @Override + public void start() { + String zeroCommand = this.getZeroCommand(); + String alhpaCommand = this.getAlphaCommand(); + this.setCommand("/bin/bash", "-c", zeroCommand + " & " + alhpaCommand); + super.start(); + } + + @Override + public Set getLivenessCheckPortNumbers() { + return Stream.of(getHttpPort(), getGrpcPort()) + .map(this::getMappedPort) + .collect(toSet()); + } + + @Override + protected void configure() { } + + public int getHttpPort() { + return getMappedPort(HTTP_PORT); + } + + public int getGrpcPort() { + return getMappedPort(GRPC_PORT); + } + + public String getHttpUrl() { + return String.format("http://%s:%d", getHost(), getHttpPort()); + } + + public String getGrpcUrl() { + return String.format("%s:%d", getHost(), getGrpcPort()); + } + +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphGraphServiceTest.java b/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphGraphServiceTest.java new file mode 100644 index 0000000000..894cb00a10 --- /dev/null +++ b/metadata-io/src/test/java/com/linkedin/metadata/graph/DgraphGraphServiceTest.java @@ -0,0 +1,779 @@ +package com.linkedin.metadata.graph; + +import com.linkedin.metadata.query.filter.RelationshipDirection; +import io.dgraph.DgraphClient; +import io.dgraph.DgraphGrpc; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; +import lombok.extern.slf4j.Slf4j; +import org.testcontainers.containers.output.Slf4jLogConsumer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import javax.annotation.Nonnull; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static com.linkedin.metadata.search.utils.QueryUtils.EMPTY_FILTER; +import static com.linkedin.metadata.search.utils.QueryUtils.newFilter; +import static com.linkedin.metadata.search.utils.QueryUtils.newRelationshipFilter; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +@SuppressWarnings("ArraysAsListWithZeroOrOneArgument") +@Slf4j +public class DgraphGraphServiceTest extends GraphServiceTestBase { + + private ManagedChannel _channel; + private DgraphGraphService _service; + private DgraphContainer _container; + + @Override + protected Duration getTestConcurrentOpTimeout() { + return Duration.ofMinutes(5); + } + + @BeforeTest + public void setup() { + _container = new DgraphContainer(DgraphContainer.DEFAULT_IMAGE_NAME.withTag("v21.03.0")) + .withTmpFs(Collections.singletonMap("/dgraph", "rw,noexec,nosuid,size=1g")) + .withStartupTimeout(Duration.ofMinutes(1)) + .withStartupAttempts(3); + _container.start(); + + Slf4jLogConsumer logConsumer = new Slf4jLogConsumer(log); + _container.followOutput(logConsumer); + } + + @BeforeMethod + public void connect() { + _channel = ManagedChannelBuilder + .forAddress(_container.getHost(), _container.getGrpcPort()) + .usePlaintext() + .build(); + + // https://discuss.dgraph.io/t/dgraph-java-client-setting-deadlines-per-call/3056 + ClientInterceptor timeoutInterceptor = new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions.withDeadlineAfter(30, TimeUnit.SECONDS)); + } + }; + + DgraphGrpc.DgraphStub stub = DgraphGrpc.newStub(_channel).withInterceptors(timeoutInterceptor); + _service = new DgraphGraphService(new DgraphClient(stub)); + } + + @AfterMethod + public void disconnect() throws InterruptedException { + try { + _channel.shutdownNow(); + _channel.awaitTermination(10, TimeUnit.SECONDS); + } finally { + _channel = null; + _service = null; + } + } + + @AfterTest + public void tearDown() { + _container.stop(); + } + + @Nonnull + @Override + protected GraphService getGraphService() { + _service.clear(); + return _service; + } + + @Override + protected void syncAfterWrite() { } + + @Test + public void testGetSchema() { + DgraphSchema schema = DgraphGraphService.getSchema("{\n" + + " \"schema\": [\n" + + " {\n" + + " \"predicate\": \"PredOne\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"PredTwo\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"dgraph.type\"\n" + + " }\n" + + " ],\n" + + " \"types\": [\n" + + " {\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"dgraph.type\"\n" + + " }\n" + + " ],\n" + + " \"name\": \"dgraph.meta\"\n" + + " },\n" + + " {\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"PredOne\"\n" + + " },\n" + + " {\n" + + " \"name\": \"PredTwo\"\n" + + " }\n" + + " ],\n" + + " \"name\": \"ns:typeOne\"\n" + + " },\n" + + " {\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"PredTwo\"\n" + + " }\n" + + " ],\n" + + " \"name\": \"ns:typeTwo\"\n" + + " }\n" + + " ]\n" + + " }"); + assertEquals(schema.getFields(), new HashSet<>(Arrays.asList("PredOne", "PredTwo"))); + + assertEquals(schema.getTypes(), new HashMap>() {{ + put("ns:typeOne", new HashSet<>(Arrays.asList("PredOne", "PredTwo"))); + put("ns:typeTwo", new HashSet<>(Arrays.asList("PredTwo"))); + }}); + + assertEquals(schema.getFields("ns:typeOne"), new HashSet<>(Arrays.asList("PredOne", "PredTwo"))); + assertEquals(schema.getFields("ns:typeTwo"), new HashSet<>(Arrays.asList("PredTwo"))); + assertEquals(schema.getFields("ns:unknown"), Collections.emptySet()); + + schema.ensureField("newType", "newField"); + assertEquals(schema.getFields(), new HashSet<>(Arrays.asList("PredOne", "PredTwo", "newField"))); + assertEquals(schema.getTypes(), new HashMap>() {{ + put("ns:typeOne", new HashSet<>(Arrays.asList("PredOne", "PredTwo"))); + put("ns:typeTwo", new HashSet<>(Arrays.asList("PredTwo"))); + put("newType", new HashSet<>(Arrays.asList("newField"))); + }}); + + schema.ensureField("ns:typeOne", "otherField"); + assertEquals(schema.getFields(), new HashSet<>(Arrays.asList("PredOne", "PredTwo", "newField", "otherField"))); + assertEquals(schema.getTypes(), new HashMap>() {{ + put("ns:typeOne", new HashSet<>(Arrays.asList("PredOne", "PredTwo", "otherField"))); + put("ns:typeTwo", new HashSet<>(Arrays.asList("PredTwo"))); + put("newType", new HashSet<>(Arrays.asList("newField"))); + }}); + + schema.ensureField("ns:typeTwo", "PredTwo"); + assertEquals(schema.getFields(), new HashSet<>(Arrays.asList("PredOne", "PredTwo", "newField", "otherField"))); + assertEquals(schema.getTypes(), new HashMap>() {{ + put("ns:typeOne", new HashSet<>(Arrays.asList("PredOne", "PredTwo", "otherField"))); + put("ns:typeTwo", new HashSet<>(Arrays.asList("PredTwo"))); + put("newType", new HashSet<>(Arrays.asList("newField"))); + }}); + } + + @Test + public void testGetSchemaIncomplete() { + DgraphSchema schemaWithNonListTypes = DgraphGraphService.getSchema("{\n" + + " \"schema\": [\n" + + " {\n" + + " \"predicate\": \"PredOne\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"PredTwo\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"dgraph.type\"\n" + + " }\n" + + " ],\n" + + " \"types\": \"not a list\"\n" + + " }"); + assertTrue(schemaWithNonListTypes.isEmpty(), "Should be empty if type field is not a list"); + + DgraphSchema schemaWithoutTypes = DgraphGraphService.getSchema("{\n" + + " \"schema\": [\n" + + " {\n" + + " \"predicate\": \"PredOne\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"PredTwo\"\n" + + " },\n" + + " {\n" + + " \"predicate\": \"dgraph.type\"\n" + + " }\n" + + " ]" + + " }"); + assertTrue(schemaWithoutTypes.isEmpty(), "Should be empty if no type field exists"); + + DgraphSchema schemaWithNonListSchema = DgraphGraphService.getSchema("{\n" + + " \"schema\": \"not a list\"" + + " }"); + assertTrue(schemaWithNonListSchema.isEmpty(), "Should be empty if schema field is not a list"); + + DgraphSchema schemaWithoutSchema = DgraphGraphService.getSchema("{ }"); + assertTrue(schemaWithoutSchema.isEmpty(), "Should be empty if no schema field exists"); + } + + @Test + public void testGetSchemaDgraph() { + // TODO: test that dgraph schema gets altered + } + + @Test + public void testGetFilterConditions() { + // no filters + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()), + "" + ); + + // source type not supported without restricting relationship types + // there must be as many relation type filter names as there are relationships + assertEquals( + DgraphGraphService.getFilterConditions( + "sourceTypeFilter", + null, + Collections.emptyList(), + Collections.emptyList(), + Arrays.asList("RelationshipTypeFilter"), + Arrays.asList("relationship")), + "@filter(\n" + + " (\n" + + " uid(RelationshipTypeFilter) AND uid_in(, uid(sourceTypeFilter))\n" + + " )\n" + + " )" + ); + + // destination type + assertEquals( + DgraphGraphService.getFilterConditions( + null, + "destinationTypeFilter", + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()), + "@filter(\n" + + " uid(destinationTypeFilter)\n" + + " )" + ); + + // source filter not supported without restricting relationship types + // there must be as many relation type filter names as there are relationships + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Arrays.asList("sourceFilter"), + Collections.emptyList(), + Arrays.asList("RelationshipTypeFilter"), + Arrays.asList("relationship")), + "@filter(\n" + + " (\n" + + " uid(RelationshipTypeFilter) AND uid_in(, uid(sourceFilter))\n" + + " )\n" + + " )" + ); + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Arrays.asList("sourceFilter1", "sourceFilter2"), + Collections.emptyList(), + Arrays.asList("RelationshipTypeFilter"), + Arrays.asList("relationship")), + "@filter(\n" + + " (\n" + + " uid(RelationshipTypeFilter) AND uid_in(, uid(sourceFilter1)) AND " + + "uid_in(, uid(sourceFilter2))\n" + + " )\n" + + " )" + ); + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Arrays.asList("sourceFilter1", "sourceFilter2"), + Collections.emptyList(), + Arrays.asList("RelationshipTypeFilter1", "RelationshipTypeFilter2"), + Arrays.asList("relationship1", "relationship2")), + "@filter(\n" + + " (\n" + + " uid(RelationshipTypeFilter1) AND uid_in(, uid(sourceFilter1)) AND " + + "uid_in(, uid(sourceFilter2)) OR\n" + + " uid(RelationshipTypeFilter2) AND uid_in(, uid(sourceFilter1)) AND " + + "uid_in(, uid(sourceFilter2))\n" + + " )\n" + + " )" + ); + + // destination filters + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Collections.emptyList(), + Arrays.asList("destinationFilter"), + Collections.emptyList(), + Collections.emptyList()), + "@filter(\n" + + " uid(destinationFilter)\n" + + " )" + ); + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Collections.emptyList(), + Arrays.asList("destinationFilter1", "destinationFilter2"), + Collections.emptyList(), + Collections.emptyList()), + "@filter(\n" + + " uid(destinationFilter1) AND\n" + + " uid(destinationFilter2)\n" + + " )" + ); + + // relationship type filters require relationship types + assertEquals( + DgraphGraphService.getFilterConditions( + null, + null, + Collections.emptyList(), + Collections.emptyList(), + Arrays.asList("relationshipTypeFilter1", "relationshipTypeFilter2"), + Arrays.asList("relationship1", "relationship2")), + "@filter(\n" + + " (\n" + + " uid(relationshipTypeFilter1) OR\n" + + " uid(relationshipTypeFilter2)\n" + + " )\n" + + " )" + ); + + // all filters at once + assertEquals( + DgraphGraphService.getFilterConditions( + "sourceTypeFilter", + "destinationTypeFilter", + Arrays.asList("sourceFilter1", "sourceFilter2"), + Arrays.asList("destinationFilter1", "destinationFilter2"), + Arrays.asList("relationshipTypeFilter1", "relationshipTypeFilter2"), + Arrays.asList("relationship1", "relationship2")), + "@filter(\n" + + " uid(destinationTypeFilter) AND\n" + + " uid(destinationFilter1) AND\n" + + " uid(destinationFilter2) AND\n" + + " (\n" + + " uid(relationshipTypeFilter1) AND uid_in(, uid(sourceTypeFilter)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2)) OR\n" + + " uid(relationshipTypeFilter2) AND uid_in(, uid(sourceTypeFilter)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2))\n" + + " )\n" + + " )" + ); + + // TODO: check getFilterConditions throws an exception when relationshipTypes and + // relationshipTypeFilterNames do not have the same size + } + + @Test + public void testGetRelationships() { + // no relationships + assertEquals( + DgraphGraphService.getRelationships( + null, + Collections.emptyList(), + Collections.emptyList()), + Collections.emptyList() + ); + + // one relationship but no filters + assertEquals( + DgraphGraphService.getRelationships( + null, + Collections.emptyList(), + Arrays.asList("relationship") + ), + Arrays.asList(" { }") + ); + + // more relationship and source type filter + assertEquals( + DgraphGraphService.getRelationships( + "sourceTypeFilter", + Collections.emptyList(), + Arrays.asList("relationship1", "~relationship2") + ), + Arrays.asList( + " @filter( uid(sourceTypeFilter) ) { }", + "<~relationship2> @filter( uid(sourceTypeFilter) ) { }" + ) + ); + + // more relationship, source type and source filters + assertEquals( + DgraphGraphService.getRelationships( + "sourceTypeFilter", + Arrays.asList("sourceFilter1", "sourceFilter2"), + Arrays.asList("relationship1", "~relationship2") + ), + Arrays.asList( + " @filter( uid(sourceTypeFilter) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }", + "<~relationship2> @filter( uid(sourceTypeFilter) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }" + ) + ); + + // more relationship and only source filters + assertEquals( + DgraphGraphService.getRelationships( + null, + Arrays.asList("sourceFilter1", "sourceFilter2"), + Arrays.asList("relationship1", "~relationship2", "relationship3") + ), + Arrays.asList( + " @filter( uid(sourceFilter1) AND uid(sourceFilter2) ) { }", + "<~relationship2> @filter( uid(sourceFilter1) AND uid(sourceFilter2) ) { }", + " @filter( uid(sourceFilter1) AND uid(sourceFilter2) ) { }" + ) + ); + + // two relationship and only one source filter + assertEquals( + DgraphGraphService.getRelationships( + null, + Arrays.asList("sourceFilter"), + Arrays.asList("~relationship1", "~relationship2") + ), + Arrays.asList( + "<~relationship1> @filter( uid(sourceFilter) ) { }", + "<~relationship2> @filter( uid(sourceFilter) ) { }" + ) + ); + } + + @Test + public void testGetRelationshipCondition() { + assertEquals( + DgraphGraphService.getRelationshipCondition( + "relationship", + "relationshipFilter", + null, + Collections.emptyList()), + "uid(relationshipFilter)" + ); + + assertEquals( + DgraphGraphService.getRelationshipCondition( + "relationship", + "relationshipFilter", + "destinationTypeFilter", + Collections.emptyList()), + "uid(relationshipFilter) AND uid_in(, uid(destinationTypeFilter))" + ); + + assertEquals( + DgraphGraphService.getRelationshipCondition( + "relationship", + "relationshipFilter", + "destinationTypeFilter", + Arrays.asList("destinationFilter")), + "uid(relationshipFilter) AND uid_in(, uid(destinationTypeFilter)) AND " + + "uid_in(, uid(destinationFilter))" + ); + + assertEquals( + DgraphGraphService.getRelationshipCondition( + "relationship", + "relationshipFilter", + "destinationTypeFilter", + Arrays.asList("destinationFilter1", "destinationFilter2")), + "uid(relationshipFilter) AND uid_in(, uid(destinationTypeFilter)) AND " + + "uid_in(, uid(destinationFilter1)) AND uid_in(, uid(destinationFilter2))" + ); + + assertEquals( + DgraphGraphService.getRelationshipCondition( + "relationship", + "relationshipFilter", + null, + Arrays.asList("destinationFilter1", "destinationFilter2")), + "uid(relationshipFilter) AND uid_in(, uid(destinationFilter1)) AND " + + "uid_in(, uid(destinationFilter2))" + ); + } + + @Test + public void testGetQueryForRelatedEntitiesOutgoing() { + doTestGetQueryForRelatedEntitiesDirection(RelationshipDirection.OUTGOING, + "query {\n" + + " sourceType as var(func: eq(, \"sourceType\"))\n" + + " destinationType as var(func: eq(, \"destinationType\"))\n" + + " sourceFilter1 as var(func: eq(, \"urn:ns:type:source-key\"))\n" + + " sourceFilter2 as var(func: eq(, \"source-key\"))\n" + + " destinationFilter1 as var(func: eq(, \"urn:ns:type:dest-key\"))\n" + + " destinationFilter2 as var(func: eq(, \"dest-key\"))\n" + + " relationshipType1 as var(func: has(<~relationship1>))\n" + + " relationshipType2 as var(func: has(<~relationship2>))\n" + + "\n" + + " result (func: uid(destinationFilter1, destinationFilter2, destinationType, relationshipType1, relationshipType2), " + + "first: 100, offset: 0) @filter(\n" + + " uid(destinationType) AND\n" + + " uid(destinationFilter1) AND\n" + + " uid(destinationFilter2) AND\n" + + " (\n" + + " uid(relationshipType1) AND uid_in(<~relationship1>, uid(sourceType)) AND " + + "uid_in(<~relationship1>, uid(sourceFilter1)) AND uid_in(<~relationship1>, uid(sourceFilter2)) OR\n" + + " uid(relationshipType2) AND uid_in(<~relationship2>, uid(sourceType)) AND " + + "uid_in(<~relationship2>, uid(sourceFilter1)) AND uid_in(<~relationship2>, uid(sourceFilter2))\n" + + " )\n" + + " ) {\n" + + " \n" + + " <~relationship1> @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " <~relationship2> @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " }\n" + + "}" + ); + } + + @Test + public void testGetQueryForRelatedEntitiesIncoming() { + doTestGetQueryForRelatedEntitiesDirection(RelationshipDirection.INCOMING, + "query {\n" + + " sourceType as var(func: eq(, \"sourceType\"))\n" + + " destinationType as var(func: eq(, \"destinationType\"))\n" + + " sourceFilter1 as var(func: eq(, \"urn:ns:type:source-key\"))\n" + + " sourceFilter2 as var(func: eq(, \"source-key\"))\n" + + " destinationFilter1 as var(func: eq(, \"urn:ns:type:dest-key\"))\n" + + " destinationFilter2 as var(func: eq(, \"dest-key\"))\n" + + " relationshipType1 as var(func: has())\n" + + " relationshipType2 as var(func: has())\n" + + "\n" + + " result (func: uid(destinationFilter1, destinationFilter2, destinationType, relationshipType1, relationshipType2), " + + "first: 100, offset: 0) @filter(\n" + + " uid(destinationType) AND\n" + + " uid(destinationFilter1) AND\n" + + " uid(destinationFilter2) AND\n" + + " (\n" + + " uid(relationshipType1) AND uid_in(, uid(sourceType)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2)) OR\n" + + " uid(relationshipType2) AND uid_in(, uid(sourceType)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2))\n" + + " )\n" + + " ) {\n" + + " \n" + + " @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " }\n" + + "}" + ); + } + + @Test + public void testGetQueryForRelatedEntitiesUndirected() { + doTestGetQueryForRelatedEntitiesDirection(RelationshipDirection.UNDIRECTED, + "query {\n" + + " sourceType as var(func: eq(, \"sourceType\"))\n" + + " destinationType as var(func: eq(, \"destinationType\"))\n" + + " sourceFilter1 as var(func: eq(, \"urn:ns:type:source-key\"))\n" + + " sourceFilter2 as var(func: eq(, \"source-key\"))\n" + + " destinationFilter1 as var(func: eq(, \"urn:ns:type:dest-key\"))\n" + + " destinationFilter2 as var(func: eq(, \"dest-key\"))\n" + + " relationshipType1 as var(func: has())\n" + + " relationshipType2 as var(func: has())\n" + + " relationshipType3 as var(func: has(<~relationship1>))\n" + + " relationshipType4 as var(func: has(<~relationship2>))\n" + + "\n" + + " result (func: uid(destinationFilter1, destinationFilter2, destinationType, " + + "relationshipType1, relationshipType2, relationshipType3, relationshipType4), first: 100, offset: 0) @filter(\n" + + " uid(destinationType) AND\n" + + " uid(destinationFilter1) AND\n" + + " uid(destinationFilter2) AND\n" + + " (\n" + + " uid(relationshipType1) AND uid_in(, uid(sourceType)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2)) OR\n" + + " uid(relationshipType2) AND uid_in(, uid(sourceType)) AND " + + "uid_in(, uid(sourceFilter1)) AND uid_in(, uid(sourceFilter2)) OR\n" + + " uid(relationshipType3) AND uid_in(<~relationship1>, uid(sourceType)) AND " + + "uid_in(<~relationship1>, uid(sourceFilter1)) AND uid_in(<~relationship1>, uid(sourceFilter2)) OR\n" + + " uid(relationshipType4) AND uid_in(<~relationship2>, uid(sourceType)) AND " + + "uid_in(<~relationship2>, uid(sourceFilter1)) AND uid_in(<~relationship2>, uid(sourceFilter2))\n" + + " )\n" + + " ) {\n" + + " \n" + + " @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " <~relationship1> @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " <~relationship2> @filter( uid(sourceType) AND uid(sourceFilter1) AND uid(sourceFilter2) ) { }\n" + + " }\n" + + "}" + ); + } + + private void doTestGetQueryForRelatedEntitiesDirection(@Nonnull RelationshipDirection direction, @Nonnull String expectedQuery) { + assertEquals( + DgraphGraphService.getQueryForRelatedEntities( + "sourceType", + newFilter(new HashMap() {{ + put("urn", "urn:ns:type:source-key"); + put("key", "source-key"); + }}), + "destinationType", + newFilter(new HashMap() {{ + put("urn", "urn:ns:type:dest-key"); + put("key", "dest-key"); + }}), + Arrays.asList("relationship1", "relationship2"), + newRelationshipFilter(EMPTY_FILTER, direction), + 0, 100 + ), + expectedQuery + ); + } + + @Test + public void testGetDestinationUrnsFromResponseData() { + // no results + assertEquals( + DgraphGraphService.getRelatedEntitiesFromResponseData( + new HashMap() {{ + put("result", Collections.emptyList()); + }} + ), + Collections.emptyList() + ); + + // one result and one relationship with two sources + assertEquals( + DgraphGraphService.getRelatedEntitiesFromResponseData( + new HashMap() {{ + put("result", Arrays.asList( + new HashMap() {{ + put("urn", "urn:ns:type:dest-key"); + put("~pred", Arrays.asList( + new HashMap() {{ + put("uid", "0x1"); + }}, + new HashMap() {{ + put("uid", "0x2"); + }} + )); + }} + )); + }} + ), + Arrays.asList(new RelatedEntity("pred", "urn:ns:type:dest-key")) + ); + + // multiple results and one relationship + assertEquals( + DgraphGraphService.getRelatedEntitiesFromResponseData( + new HashMap() {{ + put("result", Arrays.asList( + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-1"); + put("~pred", Arrays.asList( + new HashMap() {{ + put("uid", "0x1"); + }}, + new HashMap() {{ + put("uid", "0x2"); + }} + )); + }}, + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-2"); + put("~pred", Arrays.asList( + new HashMap() {{ + put("uid", "0x2"); + }} + )); + }} + )); + }} + ), + Arrays.asList( + new RelatedEntity("pred", "urn:ns:type:dest-key-1"), + new RelatedEntity("pred", "urn:ns:type:dest-key-2") + ) + ); + + // multiple results and relationships + assertEqualsAnyOrder( + DgraphGraphService.getRelatedEntitiesFromResponseData( + new HashMap() {{ + put("result", Arrays.asList( + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-1"); + put("~pred1", Arrays.asList( + new HashMap() {{ + put("uid", "0x1"); + }}, + new HashMap() {{ + put("uid", "0x2"); + }} + )); + }}, + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-2"); + put("~pred1", Arrays.asList( + new HashMap() {{ + put("uid", "0x2"); + }} + )); + }}, + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-3"); + put("pred1", Arrays.asList( + new HashMap() {{ + put("uid", "0x3"); + }} + )); + put("~pred1", Arrays.asList( + new HashMap() {{ + put("uid", "0x1"); + }}, + new HashMap() {{ + put("uid", "0x4"); + }} + )); + }}, + new HashMap() {{ + put("urn", "urn:ns:type:dest-key-4"); + put("pred2", Arrays.asList( + new HashMap() {{ + put("uid", "0x5"); + }} + )); + }} + )); + }} + ), + Arrays.asList( + new RelatedEntity("pred1", "urn:ns:type:dest-key-1"), + new RelatedEntity("pred1", "urn:ns:type:dest-key-2"), + new RelatedEntity("pred1", "urn:ns:type:dest-key-3"), + new RelatedEntity("pred2", "urn:ns:type:dest-key-4") + ), + RELATED_ENTITY_COMPARATOR + ); + } +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/graph/GraphServiceTestBase.java b/metadata-io/src/test/java/com/linkedin/metadata/graph/GraphServiceTestBase.java index 172d4df4d6..f1ccc4d847 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/graph/GraphServiceTestBase.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/graph/GraphServiceTestBase.java @@ -11,6 +11,7 @@ import org.testng.annotations.Test; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.net.URISyntaxException; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.Collections; @@ -127,6 +128,13 @@ abstract public class GraphServiceTestBase { */ protected static @Nullable String anyType = null; + /** + * Timeout used to test concurrent ops in doTestConcurrentOp. + */ + protected Duration getTestConcurrentOpTimeout() { + return Duration.ofMinutes(1); + } + @Test public void testStaticUrns() { assertNotNull(datasetOneUrn); @@ -1293,7 +1301,7 @@ abstract public class GraphServiceTestBase { // too many edges may cause too many threads throwing // java.util.concurrent.RejectedExecutionException: Thread limit exceeded replacing blocked worker int nodes = 5; - int relationshipTypes = 5; + int relationshipTypes = 3; List allRelationships = IntStream.range(1, relationshipTypes + 1).mapToObj(id -> "relationship" + id).collect(Collectors.toList()); List edges = getFullyConnectedGraph(nodes, allRelationships); @@ -1324,8 +1332,8 @@ abstract public class GraphServiceTestBase { public void testConcurrentRemoveEdgesFromNode() throws Exception { final GraphService service = getGraphService(); - int nodes = 10; - int relationshipTypes = 5; + int nodes = 5; + int relationshipTypes = 3; List allRelationships = IntStream.range(1, relationshipTypes + 1).mapToObj(id -> "relationship" + id).collect(Collectors.toList()); List edges = getFullyConnectedGraph(nodes, allRelationships); @@ -1368,8 +1376,8 @@ abstract public class GraphServiceTestBase { // too many edges may cause too many threads throwing // java.util.concurrent.RejectedExecutionException: Thread limit exceeded replacing blocked worker - int nodes = 10; - int relationshipTypes = 5; + int nodes = 5; + int relationshipTypes = 3; List allRelationships = IntStream.range(1, relationshipTypes + 1).mapToObj(id -> "relationship" + id).collect(Collectors.toList()); List edges = getFullyConnectedGraph(nodes, allRelationships); @@ -1426,15 +1434,16 @@ abstract public class GraphServiceTestBase { } operation.run(); - finished.countDown(); } catch (Throwable t) { t.printStackTrace(); throwables.add(t); } + finished.countDown(); } }).start()); - assertTrue(finished.await(10, TimeUnit.SECONDS)); + assertTrue(finished.await(getTestConcurrentOpTimeout().toMillis(), TimeUnit.MILLISECONDS)); + throwables.forEach(throwable -> System.err.printf(System.currentTimeMillis() + ": exception occurred: %s%n", throwable)); assertEquals(throwables.size(), 0); }