feat(reasoner): support id equal push down. (#202)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-04-16 17:51:16 +08:00 committed by GitHub
parent 28eca52d73
commit 2ac8fe2261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 114 additions and 11 deletions

View File

@ -19,7 +19,7 @@ import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.graph.edge.SPO import com.antgroup.openspg.reasoner.common.graph.edge.SPO
import com.antgroup.openspg.reasoner.lube.block._ import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr} import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BIn, BinaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.common.pattern.GraphPath import com.antgroup.openspg.reasoner.lube.common.pattern.GraphPath
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Block2GraphPathTransformer import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Block2GraphPathTransformer
@ -110,7 +110,7 @@ object BlockUtils {
block.transform[Set[String]] { block.transform[Set[String]] {
case (FilterBlock(_, rule), list) => case (FilterBlock(_, rule), list) =>
rule.getExpr match { rule.getExpr match {
case BinaryOpExpr(BEqual, _, _) => case BinaryOpExpr(BEqual | BIn, _, _) =>
val irFields = ExprUtils.getAllInputFieldInRule(rule.getExpr, null, null) val irFields = ExprUtils.getAllInputFieldInRule(rule.getExpr, null, null)
if (irFields.size != 1 || !irFields.head.isInstanceOf[IRNode] || !irFields.head if (irFields.size != 1 || !irFields.head.isInstanceOf[IRNode] || !irFields.head
.asInstanceOf[IRNode] .asInstanceOf[IRNode]

View File

@ -15,7 +15,7 @@ package com.antgroup.openspg.reasoner.lube.logical.optimizer.rules
import com.antgroup.openspg.reasoner.common.constants.Constants import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.trees.BottomUp import com.antgroup.openspg.reasoner.common.trees.BottomUp
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr, Expr} import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BIn, BinaryOpExpr, Directly, Expr, GetField, OpChainExpr, TypeValidatedExpr, UnaryOpExpr, VConstant, VList}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.logical.operators._ import com.antgroup.openspg.reasoner.lube.logical.operators._
import com.antgroup.openspg.reasoner.lube.logical.optimizer.{Direction, Rule, Up} import com.antgroup.openspg.reasoner.lube.logical.optimizer.{Direction, Rule, Up}
@ -37,7 +37,7 @@ case object IdEqualPushDown extends Rule {
StartFromVertex(start.graph, idExpr, start.types, start.alias, start.solved) StartFromVertex(start.graph, idExpr, start.types, start.alias, start.solved)
} }
val newFilter = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter] val newFilter = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
newFilter.in -> map newFilter -> map
} }
case (start: Start, _) => case (start: Start, _) =>
start -> Map.apply(Constants.START_ALIAS -> start.alias) start -> Map.apply(Constants.START_ALIAS -> start.alias)
@ -48,7 +48,7 @@ case object IdEqualPushDown extends Rule {
return null return null
} }
filter.rule.getExpr match { filter.rule.getExpr match {
case BinaryOpExpr(BEqual, left, right) => case BinaryOpExpr(BEqual | BIn, left, right) =>
val irFields = ExprUtils.getAllInputFieldInRule( val irFields = ExprUtils.getAllInputFieldInRule(
filter.rule.getExpr, filter.rule.getExpr,
filter.solved.getNodeAliasSet, filter.solved.getNodeAliasSet,
@ -62,7 +62,10 @@ case object IdEqualPushDown extends Rule {
.equals(Set.apply(Constants.NODE_ID_KEY))) { .equals(Set.apply(Constants.NODE_ID_KEY))) {
null null
} else { } else {
right left match {
case UnaryOpExpr(GetField(_), _) => right
case _ => left
}
} }
case _ => null case _ => null
} }

View File

@ -287,4 +287,80 @@ class OptimizerTests extends AnyFunSpec {
start.alias should equal("o") start.alias should equal("o")
} }
} }
it("test id equal push down left") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: '1111111' == o.id
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
val schema: Map[String, Set[String]] =
Map.apply(
"test" -> Set.apply("id"),
"test_abc_test" -> Set.empty)
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(
catalog,
parser,
Map
.apply((Constants.SPG_REASONER_MULTI_VERSION_ENABLE, true))
.asInstanceOf[Map[String, Object]])
val dag = Validator.validate(List.apply(block))
val logicalPlan = LogicalPlanner.plan(dag).popRoot()
val rule = Seq(IdEqualPushDown)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan, rule)
optimizedLogicalPlan.findExactlyOne { case start: StartFromVertex =>
start.id should equal(VString("1111111"))
start.alias should equal("o")
}
}
it("test id equal push down with in") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: o.id in ['1111111', '2222222']
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
val schema: Map[String, Set[String]] =
Map.apply(
"test" -> Set.apply("id"),
"test_abc_test" -> Set.empty)
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(
catalog,
parser,
Map
.apply((Constants.SPG_REASONER_MULTI_VERSION_ENABLE, true))
.asInstanceOf[Map[String, Object]])
val dag = Validator.validate(List.apply(block))
val logicalPlan = LogicalPlanner.plan(dag).popRoot()
val rule = Seq(IdEqualPushDown)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan, rule)
optimizedLogicalPlan.findExactlyOne { case start: StartFromVertex =>
start.id.asInstanceOf[VList].list should equal(List.apply("1111111", "2222222"))
start.alias should equal("o")
}
}
} }

View File

@ -133,7 +133,10 @@ class NodeIdToEdgePropertyTests extends AnyFunSpec {
val dsl = val dsl =
""" """
|GraphStructure { |GraphStructure {
| (A:User)-[e1:lk]->(B:User)-[e2:lk]->(C:User) | B [User, __start__ = 'true']
| A, C [User]
| A -> B [lk] as e1
| B -> C [lk] as e2
|} |}
|Rule { |Rule {
| R1(""): e1.weight < e2.weight | R1(""): e1.weight < e2.weight

View File

@ -124,11 +124,32 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
new Expr2QlexpressTransformer(RuleRunner::convertPropertyName); new Expr2QlexpressTransformer(RuleRunner::convertPropertyName);
List<String> exprQlList = List<String> exprQlList =
Lists.newArrayList(JavaConversions.seqAsJavaList(transformer.transform(id))); Lists.newArrayList(JavaConversions.seqAsJavaList(transformer.transform(id)));
String idStr = List<String> idStrList = new ArrayList<>();
String.valueOf(RuleRunner.getInstance().executeExpression(new HashMap<>(), exprQlList, "")); Object idObj = RuleRunner.getInstance().executeExpression(new HashMap<>(), exprQlList, "");
if (idObj instanceof String) {
idStrList.add(String.valueOf(idObj));
} else if (idObj instanceof List) {
List idOList = (List) idObj;
for (Object ido : idOList) {
idStrList.add(String.valueOf(ido));
}
} else if (idObj instanceof String[]) {
String[] idArray = (String[]) idObj;
idStrList.addAll(Lists.newArrayList(idArray));
} else if (idObj instanceof Object[]) {
Object[] idArray = (Object[]) idObj;
for (Object idO : idArray) {
idStrList.add(String.valueOf(idO));
}
}
for (String type : JavaConversions.asJavaCollection(types)) { for (String type : JavaConversions.asJavaCollection(types)) {
for (String idStr : idStrList) {
startIdSet.add(IVertexId.from(idStr, type)); startIdSet.add(IVertexId.from(idStr, type));
} }
}
if (startIdSet.isEmpty()) {
throw new RuntimeException("can not extract start id list");
}
LocalRDG result = LocalRDG result =
new LocalRDG( new LocalRDG(
graphState, graphState,

View File

@ -85,7 +85,7 @@ public class KgReasonerAliasSetKFilmTest {
+ " (A:User)-[p1:trans]->(B:User)\n" + " (A:User)-[p1:trans]->(B:User)\n"
+ "}\n" + "}\n"
+ "Rule {\n" + "Rule {\n"
+ " R1: A.id == 'A'" + " R1: A.id in ['A']"
+ "}\n" + "}\n"
+ "Action {\n" + "Action {\n"
+ " get(A.id, B.id)\n" + " get(A.id, B.id)\n"