feat(reasoner): support id equal push down. (#201)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-04-16 14:29:19 +08:00 committed by GitHub
parent e177b7c6c9
commit 28eca52d73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 398 additions and 19 deletions

View File

@ -15,8 +15,12 @@ package com.antgroup.openspg.reasoner.lube.utils
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.graph.edge.SPO
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.common.pattern.GraphPath
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Block2GraphPathTransformer
@ -53,4 +57,83 @@ object BlockUtils {
}
}
def getStarts(block: Block): Set[String] = {
val start = block.transform[Set[String]] {
case (AggregationBlock(_, _, group), groupList) =>
val groupAlias = group.map(_.name).toSet
if (groupList.head.isEmpty) {
groupAlias
} else {
val commonGroups = groupList.head.intersect(groupAlias)
if (commonGroups.isEmpty) {
throw UnsupportedOperationException(
s"cannot support groups ${groupAlias}, ${groupList.head}")
} else {
commonGroups
}
}
case (DDLBlock(ddlOp, _), list) =>
val starts = new mutable.HashSet[String]()
for (ddl <- ddlOp) {
ddl match {
case AddProperty(s, _, _) => starts.add(s.alias)
case AddPredicate(p) =>
starts.add(p.source.alias)
starts.add(p.target.alias)
case _ =>
}
}
if (list.head.isEmpty) {
starts.toSet
} else if (starts.isEmpty) {
list.head
} else {
val commonStart = list.head.intersect(starts)
if (commonStart.isEmpty) {
throw UnsupportedOperationException(
s"cannot support non-common starts ${list.head}, ${starts}")
} else {
commonStart
}
}
case (SourceBlock(_), _) => Set.empty
case (_, groupList) => groupList.head
}
if (start.isEmpty) {
getFilterStarts(block)
} else {
start
}
}
private def getFilterStarts(block: Block): Set[String] = {
block.transform[Set[String]] {
case (FilterBlock(_, rule), list) =>
rule.getExpr match {
case BinaryOpExpr(BEqual, _, _) =>
val irFields = ExprUtils.getAllInputFieldInRule(rule.getExpr, null, null)
if (irFields.size != 1 || !irFields.head.isInstanceOf[IRNode] || !irFields.head
.asInstanceOf[IRNode]
.fields
.equals(Set.apply(Constants.NODE_ID_KEY))) {
list.head
} else {
if (list.head.isEmpty) {
Set.apply(irFields.head.name)
} else {
val commonStart = list.head.intersect(Set.apply(irFields.head.name))
if (commonStart.isEmpty) {
list.head
} else {
commonStart
}
}
}
case _ => list.head
}
case (SourceBlock(_), _) => Set.empty
case (_, groupList) => groupList.head
}
}
}

View File

@ -14,6 +14,7 @@
package com.antgroup.openspg.reasoner.lube.logical.operators
import com.antgroup.openspg.reasoner.lube.catalog.SemanticPropertyGraph
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.logical.{SolvedModel, Var}
abstract class Source extends LogicalLeafOperator {
@ -31,6 +32,18 @@ final case class Start(
override def fields: List[Var] = solved.fields.values.toList
}
final case class StartFromVertex(
graph: SemanticPropertyGraph,
id: Expr,
types: Set[String],
alias: String,
solved: SolvedModel)
extends Source {
override def refFields: List[Var] = fields
override def fields: List[Var] = solved.fields.values.toList
}
final case class Driving(graph: SemanticPropertyGraph, alias: String, solved: SolvedModel)
extends Source {
override def refFields: List[Var] = fields

View File

@ -26,6 +26,7 @@ object LogicalOptimizer {
var LOGICAL_OPT_RULES: Seq[Rule] =
Seq(
PatternJoinPure,
IdEqualPushDown,
GroupNode,
DistinctGet,
NodeIdToEdgeProperty,

View File

@ -0,0 +1,74 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.lube.logical.optimizer.rules
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.trees.BottomUp
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr, Expr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.logical.operators._
import com.antgroup.openspg.reasoner.lube.logical.optimizer.{Direction, Rule, Up}
import com.antgroup.openspg.reasoner.lube.logical.planning.LogicalPlannerContext
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
case object IdEqualPushDown extends Rule {
override def ruleWithContext(implicit context: LogicalPlannerContext): PartialFunction[
(LogicalOperator, Map[String, Object]),
(LogicalOperator, Map[String, Object])] = {
case (filter: Filter, map) =>
val start = map.get(Constants.START_ALIAS)
val idExpr = getIdExpr(filter, start)
if (idExpr == null) {
filter -> map
} else {
def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = { case start: Start =>
StartFromVertex(start.graph, idExpr, start.types, start.alias, start.solved)
}
val newFilter = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
newFilter.in -> map
}
case (start: Start, _) =>
start -> Map.apply(Constants.START_ALIAS -> start.alias)
}
private def getIdExpr(filter: Filter, start: Option[Object]): Expr = {
if (start.isEmpty) {
return null
}
filter.rule.getExpr match {
case BinaryOpExpr(BEqual, left, right) =>
val irFields = ExprUtils.getAllInputFieldInRule(
filter.rule.getExpr,
filter.solved.getNodeAliasSet,
filter.solved.getEdgeAliasSet)
if (irFields.size != 1 || !irFields.head.isInstanceOf[IRNode] || !irFields.head
.asInstanceOf[IRNode]
.name
.equals(start.get) || !irFields.head
.asInstanceOf[IRNode]
.fields
.equals(Set.apply(Constants.NODE_ID_KEY))) {
null
} else {
right
}
case _ => null
}
}
override def direction: Direction = Up
override def maxIterations: Int = 1
}

View File

@ -16,11 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.{
NotImplementedException,
SchemaException,
UnsupportedOperationException
}
import com.antgroup.openspg.reasoner.common.exception.{NotImplementedException, SchemaException, UnsupportedOperationException}
import com.antgroup.openspg.reasoner.common.graph.edge.SPO
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.catalog.{Catalog, SemanticPropertyGraph}
@ -31,7 +27,7 @@ import com.antgroup.openspg.reasoner.lube.common.rule.Rule
import com.antgroup.openspg.reasoner.lube.logical._
import com.antgroup.openspg.reasoner.lube.logical.operators._
import com.antgroup.openspg.reasoner.lube.logical.validate.Dag
import com.antgroup.openspg.reasoner.lube.utils.{ExprUtils, RuleUtils}
import com.antgroup.openspg.reasoner.lube.utils.{BlockUtils, ExprUtils, RuleUtils}
/**
* Logical planner for KGReasoner, generate an optimal logical plan for KGDSL or GQL.
@ -59,7 +55,7 @@ object LogicalPlanner {
*/
def plan(input: Block)(implicit context: LogicalPlannerContext): List[LogicalOperator] = {
val source = resolve(input)
val groups = getStarts(input)
val groups = BlockUtils.getStarts(input)
val planWithoutResult = if (groups.isEmpty) {
planBlock(input.dependencies.head, input, None, source)
} else {

View File

@ -0,0 +1,58 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.lube.logical
import com.antgroup.openspg.reasoner.lube.utils.BlockUtils
import com.antgroup.openspg.reasoner.parser.OpenSPGDslParser
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
class BlockUtilTests extends AnyFunSpec{
it("group start test") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| amt = group(s).sum(p.amt)
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
BlockUtils.getStarts(block) should equal (Set.apply("s"))
}
it("group filter with id test") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: o.id == '1111111'
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
BlockUtils.getStarts(block) should equal (Set.apply("o"))
}
}

View File

@ -249,4 +249,42 @@ class OptimizerTests extends AnyFunSpec {
group should equal(List.apply(NodeVar("A", Set.empty)))
}
}
it("test id equal push down") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: o.id == '1111111'
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
val schema: Map[String, Set[String]] =
Map.apply(
"test" -> Set.apply("id"),
"test_abc_test" -> Set.empty)
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(
catalog,
parser,
Map
.apply((Constants.SPG_REASONER_MULTI_VERSION_ENABLE, true))
.asInstanceOf[Map[String, Object]])
val dag = Validator.validate(List.apply(block))
val logicalPlan = LogicalPlanner.plan(dag).popRoot()
val rule = Seq(IdEqualPushDown)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan, rule)
optimizedLogicalPlan.findExactlyOne { case start: StartFromVertex =>
start.id should equal(VString("1111111"))
start.alias should equal("o")
}
}
}

View File

@ -13,6 +13,7 @@
package com.antgroup.openspg.reasoner.lube.physical
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.logical.RepeatPathVar
import com.antgroup.openspg.reasoner.lube.physical.rdg.RDG
@ -35,6 +36,16 @@ trait PropertyGraph[T <: RDG[T]] {
*/
def createRDG(alias: String, rdg: T): T
/**
* Start with specific vertex.
*
* @param alias
* @param id
* @param types
* @return
*/
def createRDG(alias: String, id: Expr, types: Set[String]): T
/**
* Start with specific rdg with specific alias which in [[RepeatPathVar]]
* @param repeatVar

View File

@ -50,6 +50,9 @@ abstract class PhysicalLeafOperator[T <: RDG[T]: TypeTag] extends PhysicalOperat
throw UnsupportedOperationException("LogicalLeafOperator cannot construct children")
}
def alias: String
def types: Set[String]
}
abstract class StackingPhysicalOperator[T <: RDG[T]: TypeTag] extends PhysicalOperator[T] {

View File

@ -15,6 +15,7 @@ package com.antgroup.openspg.reasoner.lube.physical.operators
import scala.reflect.runtime.universe.TypeTag
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.logical.Var
import com.antgroup.openspg.reasoner.lube.physical.planning.PhysicalPlannerContext
import com.antgroup.openspg.reasoner.lube.physical.rdg.RDG
@ -28,14 +29,27 @@ final case class Start[T <: RDG[T]: TypeTag](
override def rdg: T = context.graphSession.getGraph(graphName).createRDG(alias, types)
}
final case class StartFromVertex[T <: RDG[T]: TypeTag](
graphName: String,
alias: String,
meta: List[Var],
vId: Expr,
types: Set[String])(implicit override val context: PhysicalPlannerContext[T])
extends PhysicalLeafOperator[T] {
override def rdg: T = context.graphSession.getGraph(graphName).createRDG(alias, vId, types)
}
final case class DrivingRDG[T <: RDG[T]: TypeTag](
graphName: String,
meta: List[Var],
alias: String,
workingRdgName: String)(implicit override val context: PhysicalPlannerContext[T])
extends PhysicalLeafOperator[T] {
override def rdg: T = {
val workingRdg = context.graphSession.getWorkingRDG(workingRdgName)
context.graphSession.getGraph(graphName).createRDG(alias, workingRdg)
}
override def types: Set[String] = Set.empty
}

View File

@ -91,6 +91,10 @@ object PhysicalPlanner {
Start(start.graph.graphName, start.alias, start.fields, start.types)(
implicitly[TypeTag[T]],
context)
case start: LogicalOperators.StartFromVertex =>
StartFromVertex(start.graph.graphName, start.alias, start.fields, start.id, start.types)(
implicitly[TypeTag[T]],
context)
case driving: LogicalOperators.Driving =>
DrivingRDG(start.graph.graphName, start.fields, start.alias, workingRdgName)(
implicitly[TypeTag[T]],

View File

@ -15,19 +15,19 @@ package com.antgroup.openspg.reasoner.lube.physical.util
import scala.reflect.runtime.universe.TypeTag
import com.antgroup.openspg.reasoner.lube.physical.operators.{PhysicalOperator, Start}
import com.antgroup.openspg.reasoner.lube.physical.operators.{PhysicalLeafOperator, PhysicalOperator, Start, StartFromVertex}
import com.antgroup.openspg.reasoner.lube.physical.rdg.RDG
import com.antgroup.openspg.reasoner.lube.physical.util.PhysicalOperatorOps.RichPhysicalOperator
object PhysicalOperatorUtil {
def getStartTypes[T <: RDG[T]: TypeTag](physicalOp: PhysicalOperator[T]): Set[String] = {
getStartOp(physicalOp).types
}
def getStartOp[T <: RDG[T]: TypeTag](physicalOp: PhysicalOperator[T]): Start[T] = {
val op = physicalOp.findExactlyOne { case start: Start[T] => }
op.asInstanceOf[Start[T]]
def getStartOp[T <: RDG[T]: TypeTag](
physicalOp: PhysicalOperator[T]): PhysicalLeafOperator[T] = {
val op = physicalOp.findExactlyOne {
case start: Start[T] =>
case start: StartFromVertex[T] =>
}
op.asInstanceOf[PhysicalLeafOperator[T]]
}
}

View File

@ -20,9 +20,9 @@ import com.antgroup.openspg.reasoner.graphstate.GraphState;
import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState;
import com.antgroup.openspg.reasoner.lube.catalog.Catalog;
import com.antgroup.openspg.reasoner.lube.parser.ParserInterface;
import com.antgroup.openspg.reasoner.lube.physical.operators.PhysicalLeafOperator;
import com.antgroup.openspg.reasoner.lube.physical.operators.PhysicalOperator;
import com.antgroup.openspg.reasoner.lube.physical.operators.Select;
import com.antgroup.openspg.reasoner.lube.physical.operators.Start;
import com.antgroup.openspg.reasoner.lube.physical.util.PhysicalOperatorUtil;
import com.antgroup.openspg.reasoner.parser.OpenSPGDslParser;
import com.antgroup.openspg.reasoner.runner.ConfigKey;
@ -115,7 +115,7 @@ public class LocalReasonerRunner {
boolean isLastDsl = (i + 1 == dslDagList.size());
if (isLastDsl) {
Start<LocalRDG> start =
PhysicalLeafOperator<LocalRDG> start =
PhysicalOperatorUtil.getStartOp(
dslDagList.get(i),
com.antgroup.openspg.reasoner.runner.local.rdg.TypeTags.rdgTypeTag());

View File

@ -20,15 +20,19 @@ import com.antgroup.openspg.reasoner.common.graph.vertex.impl.MirrorVertex;
import com.antgroup.openspg.reasoner.common.graph.vertex.impl.NoneVertex;
import com.antgroup.openspg.reasoner.graphstate.GraphState;
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
import com.antgroup.openspg.reasoner.lube.logical.RepeatPathVar;
import com.antgroup.openspg.reasoner.lube.physical.PropertyGraph;
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Expr2QlexpressTransformer;
import com.antgroup.openspg.reasoner.recorder.EmptyRecorder;
import com.antgroup.openspg.reasoner.recorder.IExecutionRecorder;
import com.antgroup.openspg.reasoner.runner.ConfigKey;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
import com.antgroup.openspg.reasoner.runner.local.rdg.LocalRDG;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@ -113,6 +117,34 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
return result;
}
@Override
public LocalRDG createRDG(String alias, Expr id, Set<String> types) {
java.util.Set<IVertexId> startIdSet = new HashSet<>();
Expr2QlexpressTransformer transformer =
new Expr2QlexpressTransformer(RuleRunner::convertPropertyName);
List<String> exprQlList =
Lists.newArrayList(JavaConversions.seqAsJavaList(transformer.transform(id)));
String idStr =
String.valueOf(RuleRunner.getInstance().executeExpression(new HashMap<>(), exprQlList, ""));
for (String type : JavaConversions.asJavaCollection(types)) {
startIdSet.add(IVertexId.from(idStr, type));
}
LocalRDG result =
new LocalRDG(
graphState,
Lists.newArrayList(startIdSet),
threadPoolExecutor,
executorTimeoutMs,
alias,
getTaskId(),
// subquery can not carry all graph
getExecutionRecorder(),
false);
result.setMaxPathLimit(getMaxPathLimit());
result.setStrictMaxPathLimit(getStrictMaxPathLimit());
return result;
}
@Override
public LocalRDG createRDGFromPath(RepeatPathVar repeatVar, String alias, LocalRDG rdg) {
return null;

View File

@ -42,6 +42,11 @@ public class KgReasonerAliasSetKFilmTest {
FileMutex.runTestWithMutex(this::doTest0);
}
@Test
public void testUseRule0() {
FileMutex.runTestWithMutex(this::doTestUseRule0);
}
private void doTest0() {
String dsl =
"\n"
@ -73,6 +78,36 @@ public class KgReasonerAliasSetKFilmTest {
Assert.assertEquals("B", result.get(0)[1]);
}
private void doTestUseRule0() {
String dsl =
"\n"
+ "GraphStructure {\n"
+ " (A:User)-[p1:trans]->(B:User)\n"
+ "}\n"
+ "Rule {\n"
+ " R1: A.id == 'A'"
+ "}\n"
+ "Action {\n"
+ " get(A.id, B.id)\n"
+ "}";
List<String[]> result =
TransBaseTestData.runTestResult(
dsl,
new HashMap<String, Object>() {
{
put(Constants.START_ALIAS, "A");
put(
ConfigKey.KG_REASONER_MOCK_GRAPH_DATA,
"Graph {\n" + " A [User]\n" + " B [User]\n" + " A->B [trans]\n" + "}");
put(ConfigKey.KG_REASONER_OUTPUT_GRAPH, "true");
}
});
Assert.assertEquals(1, result.size());
Assert.assertEquals(2, result.get(0).length);
Assert.assertEquals("A", result.get(0)[0]);
Assert.assertEquals("B", result.get(0)[1]);
}
@Test
public void test1() {
FileMutex.runTestWithMutex(this::doTest1);

View File

@ -89,6 +89,7 @@ public class TransBaseTestData {
params.put(Constants.SPG_REASONER_PLAN_PRETTY_PRINT_LOGGER_ENABLE, true);
params.putAll(runParams);
task.setParams(params);
task.setExecutorTimeoutMs(5 * 60 * 1000);
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "User")));

View File

@ -864,7 +864,7 @@ public class TransitiveOptionalTest {
+ "B->A [relatedReason] as F1\n"
+ "\n"
+ "// 1.8的C\n"
+ "B->C [relatedReason] repeat(1,20) as F3\n"
+ "B->C [relatedReason] repeat(1,2) as F3\n"
+ "}\n"
+ "Rule {\n"
+ " R1: A.id == 'A_730'\n"

View File

@ -445,6 +445,12 @@ object LoaderUtil {
} else {
merge(solvedModel, list.head)
}
case (StartFromVertex(_, _, _, _, solvedModel), list) =>
if (list == null || list.isEmpty) {
solvedModel
} else {
merge(solvedModel, list.head)
}
case (LinkedExpand(_, edgePattern), list) =>
if (edgePattern.edge.funcName.equals(Constants.CONCEPT_EDGE_EXPAND_FUNC_NAME)) {
merge(getConceptEdgeExpandSolvedModel(logicalPlan.graph, edgePattern), list.head)

View File

@ -14,6 +14,7 @@
package com.antgroup.openspg.reasoner.session
import com.antgroup.openspg.reasoner.lube.catalog.Catalog
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.logical.RepeatPathVar
import com.antgroup.openspg.reasoner.lube.parser.ParserInterface
import com.antgroup.openspg.reasoner.lube.physical.PropertyGraph
@ -51,6 +52,15 @@ class EmptyPropertyGraph extends PropertyGraph[EmptyRDG] {
alias: String,
rdg: EmptyRDG): EmptyRDG = rdg
/**
* Start with specific vertex.
*
* @param alias
* @param id
* @param types
* @return
*/
override def createRDG(alias: String, id: Expr, types: Set[String]): EmptyRDG = new EmptyRDG()
}
class EmptySession(parser: ParserInterface, catalog: Catalog)

View File

@ -51,6 +51,6 @@ class PhysicalOpUtilTests extends AnyFunSpec {
catalog.init()
val session = new EmptySession(new OpenSPGDslParser(), catalog)
val rst = session.plan(dsl, Map.empty)
PhysicalOperatorUtil.getStartTypes(rst.head) should equal (Set.apply("User"))
PhysicalOperatorUtil.getStartOp(rst.head).alias should equal ("s")
}
}