feat(reasoner): support id equal push down. (#202)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-04-16 17:51:16 +08:00 committed by GitHub
parent 28eca52d73
commit 2ac8fe2261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 114 additions and 11 deletions

View File

@ -19,7 +19,7 @@ import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.graph.edge.SPO
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BIn, BinaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.common.pattern.GraphPath
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Block2GraphPathTransformer
@ -110,7 +110,7 @@ object BlockUtils {
block.transform[Set[String]] {
case (FilterBlock(_, rule), list) =>
rule.getExpr match {
case BinaryOpExpr(BEqual, _, _) =>
case BinaryOpExpr(BEqual | BIn, _, _) =>
val irFields = ExprUtils.getAllInputFieldInRule(rule.getExpr, null, null)
if (irFields.size != 1 || !irFields.head.isInstanceOf[IRNode] || !irFields.head
.asInstanceOf[IRNode]

View File

@ -15,7 +15,7 @@ package com.antgroup.openspg.reasoner.lube.logical.optimizer.rules
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.trees.BottomUp
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BinaryOpExpr, Expr}
import com.antgroup.openspg.reasoner.lube.common.expr.{BEqual, BIn, BinaryOpExpr, Directly, Expr, GetField, OpChainExpr, TypeValidatedExpr, UnaryOpExpr, VConstant, VList}
import com.antgroup.openspg.reasoner.lube.common.graph.IRNode
import com.antgroup.openspg.reasoner.lube.logical.operators._
import com.antgroup.openspg.reasoner.lube.logical.optimizer.{Direction, Rule, Up}
@ -37,7 +37,7 @@ case object IdEqualPushDown extends Rule {
StartFromVertex(start.graph, idExpr, start.types, start.alias, start.solved)
}
val newFilter = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
newFilter.in -> map
newFilter -> map
}
case (start: Start, _) =>
start -> Map.apply(Constants.START_ALIAS -> start.alias)
@ -48,7 +48,7 @@ case object IdEqualPushDown extends Rule {
return null
}
filter.rule.getExpr match {
case BinaryOpExpr(BEqual, left, right) =>
case BinaryOpExpr(BEqual | BIn, left, right) =>
val irFields = ExprUtils.getAllInputFieldInRule(
filter.rule.getExpr,
filter.solved.getNodeAliasSet,
@ -62,7 +62,10 @@ case object IdEqualPushDown extends Rule {
.equals(Set.apply(Constants.NODE_ID_KEY))) {
null
} else {
right
left match {
case UnaryOpExpr(GetField(_), _) => right
case _ => left
}
}
case _ => null
}

View File

@ -287,4 +287,80 @@ class OptimizerTests extends AnyFunSpec {
start.alias should equal("o")
}
}
it("test id equal push down left") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: '1111111' == o.id
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
val schema: Map[String, Set[String]] =
Map.apply(
"test" -> Set.apply("id"),
"test_abc_test" -> Set.empty)
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(
catalog,
parser,
Map
.apply((Constants.SPG_REASONER_MULTI_VERSION_ENABLE, true))
.asInstanceOf[Map[String, Object]])
val dag = Validator.validate(List.apply(block))
val logicalPlan = LogicalPlanner.plan(dag).popRoot()
val rule = Seq(IdEqualPushDown)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan, rule)
optimizedLogicalPlan.findExactlyOne { case start: StartFromVertex =>
start.id should equal(VString("1111111"))
start.alias should equal("o")
}
}
it("test id equal push down with in") {
val dsl =
"""
|GraphStructure {
| (s: test)-[p: abc]->(o: test)
|}
|Rule {
| R1: o.id in ['1111111', '2222222']
|}
|Action {
| get(s.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
val schema: Map[String, Set[String]] =
Map.apply(
"test" -> Set.apply("id"),
"test_abc_test" -> Set.empty)
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(
catalog,
parser,
Map
.apply((Constants.SPG_REASONER_MULTI_VERSION_ENABLE, true))
.asInstanceOf[Map[String, Object]])
val dag = Validator.validate(List.apply(block))
val logicalPlan = LogicalPlanner.plan(dag).popRoot()
val rule = Seq(IdEqualPushDown)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan, rule)
optimizedLogicalPlan.findExactlyOne { case start: StartFromVertex =>
start.id.asInstanceOf[VList].list should equal(List.apply("1111111", "2222222"))
start.alias should equal("o")
}
}
}

View File

@ -133,7 +133,10 @@ class NodeIdToEdgePropertyTests extends AnyFunSpec {
val dsl =
"""
|GraphStructure {
| (A:User)-[e1:lk]->(B:User)-[e2:lk]->(C:User)
| B [User, __start__ = 'true']
| A, C [User]
| A -> B [lk] as e1
| B -> C [lk] as e2
|}
|Rule {
| R1(""): e1.weight < e2.weight

View File

@ -124,10 +124,31 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
new Expr2QlexpressTransformer(RuleRunner::convertPropertyName);
List<String> exprQlList =
Lists.newArrayList(JavaConversions.seqAsJavaList(transformer.transform(id)));
String idStr =
String.valueOf(RuleRunner.getInstance().executeExpression(new HashMap<>(), exprQlList, ""));
List<String> idStrList = new ArrayList<>();
Object idObj = RuleRunner.getInstance().executeExpression(new HashMap<>(), exprQlList, "");
if (idObj instanceof String) {
idStrList.add(String.valueOf(idObj));
} else if (idObj instanceof List) {
List idOList = (List) idObj;
for (Object ido : idOList) {
idStrList.add(String.valueOf(ido));
}
} else if (idObj instanceof String[]) {
String[] idArray = (String[]) idObj;
idStrList.addAll(Lists.newArrayList(idArray));
} else if (idObj instanceof Object[]) {
Object[] idArray = (Object[]) idObj;
for (Object idO : idArray) {
idStrList.add(String.valueOf(idO));
}
}
for (String type : JavaConversions.asJavaCollection(types)) {
startIdSet.add(IVertexId.from(idStr, type));
for (String idStr : idStrList) {
startIdSet.add(IVertexId.from(idStr, type));
}
}
if (startIdSet.isEmpty()) {
throw new RuntimeException("can not extract start id list");
}
LocalRDG result =
new LocalRDG(

View File

@ -85,7 +85,7 @@ public class KgReasonerAliasSetKFilmTest {
+ " (A:User)-[p1:trans]->(B:User)\n"
+ "}\n"
+ "Rule {\n"
+ " R1: A.id == 'A'"
+ " R1: A.id in ['A']"
+ "}\n"
+ "Action {\n"
+ " get(A.id, B.id)\n"