From 1b3b78b02e25371568b237cce6f5a0f024425fb6 Mon Sep 17 00:00:00 2001 From: FishJoy Date: Fri, 2 Feb 2024 11:31:19 +0800 Subject: [PATCH] fix(reasoner): fix bug in FilterPushDown (#113) Co-authored-by: peilong Co-authored-by: youdonghai --- pom.xml | 8 ++- .../optimizer/rules/FilterPushDown.scala | 30 ++++++-- .../transitive/KgReasonerTransitiveTest.java | 68 +++++++++++++++++++ .../transitive/TransitiveOptionalTest.java | 16 ++--- 4 files changed, 107 insertions(+), 15 deletions(-) diff --git a/pom.xml b/pom.xml index df1d9fa2..eda21e0f 100644 --- a/pom.xml +++ b/pom.xml @@ -135,6 +135,12 @@ com.google.guava guava 30.1-jre + + + org.projectlombok + lombok + + org.apache.commons @@ -317,7 +323,7 @@ org.projectlombok lombok - 1.18.12 + 1.18.22 ch.qos.logback diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/optimizer/rules/FilterPushDown.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/optimizer/rules/FilterPushDown.scala index 8752e5c6..fd3ab05d 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/optimizer/rules/FilterPushDown.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/optimizer/rules/FilterPushDown.scala @@ -55,6 +55,7 @@ object FilterPushDown extends SimpleRule { } private def pushDown(filter: Filter, aliasMap: Map[String, Set[String]]): LogicalOperator = { + var keepOriginal = false var hasPushDown = false def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = { case logicalOp: StackingLogicalOperator => @@ -65,19 +66,28 @@ object FilterPushDown extends SimpleRule { } else { logicalOp } + case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) => + if (hasPushDown) { + val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias) + if (!aliasMap.keySet.intersect(alias).isEmpty) { + keepOriginal = true + } + } + boundedVarLenExpand } val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter] - if (hasPushDown) { - newRoot.in + if (keepOriginal || !hasPushDown) { + filter } else { - newRoot + newRoot.in } } private def pushDown2Pattern( filter: Filter, aliasMap: Map[String, Set[String]]): (Boolean, LogicalOperator) = { + var keepOriginal = false var hasPushDown: Boolean = false def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = { @@ -105,13 +115,21 @@ object FilterPushDown extends SimpleRule { patternScan } } + case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) => + if (hasPushDown) { + val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias) + if (!aliasMap.keySet.intersect(alias).isEmpty) { + keepOriginal = true + } + } + boundedVarLenExpand } val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter] - if (hasPushDown) { - (hasPushDown, newRoot.in) + if (keepOriginal || !hasPushDown) { + (hasPushDown, filter) } else { - (hasPushDown, newRoot) + (hasPushDown, newRoot.in) } } diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/KgReasonerTransitiveTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/KgReasonerTransitiveTest.java index 10a1a0c4..b0f1f1cb 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/KgReasonerTransitiveTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/KgReasonerTransitiveTest.java @@ -124,6 +124,52 @@ public class KgReasonerTransitiveTest { Assert.assertEquals(result.getRows().get(0)[1], "C5"); } + @Test + public void testTransitiveWithPathLongestWithTailRule() { + String dsl = + "GraphStructure {\n" + + " A [RelatedParty, __start__='true']\n" + + " B [RelatedParty]\n" + + " A->B [holdShare] repeat(1,3) as e\n" + + "}\n" + + "Rule {\n" + + " R0: B.name == 'C4'\n" + + " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n" + + "}\n" + + "Action {\n" + + " get(A.id,B.id) \n" + + "}"; + LocalReasonerResult result = doProcess(dsl); + // check result + Assert.assertEquals(1, result.getRows().size()); + Assert.assertEquals(2, result.getRows().get(0).length); + Assert.assertEquals(result.getRows().get(0)[0], "P1"); + Assert.assertEquals(result.getRows().get(0)[1], "C4"); + } + + @Test + public void testTransitiveWithPathLongestWithHeadRule() { + String dsl = + "GraphStructure {\n" + + " A [RelatedParty, __start__='true']\n" + + " B [RelatedParty]\n" + + " A->B [holdShare] repeat(1,10) as e\n" + + "}\n" + + "Rule {\n" + + " R0: A.name == 'P1'\n" + + " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n" + + "}\n" + + "Action {\n" + + " get(A.id,B.id) \n" + + "}"; + LocalReasonerResult result = doProcess(dsl); + // check result + Assert.assertEquals(1, result.getRows().size()); + Assert.assertEquals(2, result.getRows().get(0).length); + Assert.assertEquals(result.getRows().get(0)[0], "P1"); + Assert.assertEquals(result.getRows().get(0)[1], "C5"); + } + @Test public void testTransitiveWithPathShortest() { String dsl = @@ -146,6 +192,28 @@ public class KgReasonerTransitiveTest { Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString())); } + @Test + public void testTransitiveWithPathShortestWithPath() { + String dsl = + "GraphStructure {\n" + + " A [RelatedParty, __start__='true']\n" + + " B [RelatedParty]\n" + + " A->B [holdShare] repeat(1,10) as e\n" + + "}\n" + + "Rule {\n" + + " R1: group(A).keep_shortest_path(e)\n" + + "}\n" + + "Action {\n" + + " get(A.id,B.id, __path__) \n" + + "}"; + LocalReasonerResult result = doProcess(dsl); + // check result + Assert.assertEquals(1, result.getRows().size()); + Assert.assertEquals(3, result.getRows().get(0).length); + Assert.assertEquals(result.getRows().get(0)[0], "P1"); + Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString())); + } + @Test public void testTransitiveWithPathAll() { String dsl = diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java index 0a451877..7b67b691 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java @@ -715,9 +715,9 @@ public class TransitiveOptionalTest { + " R4: e2.relatedReason == 'CONTROL'\n" + " R5: e3.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n" + " R6: C.entityType == 'CORPORATION'\n" - + " R7: D.entityType == 'CORPORATION'\n" + + " R7: (not exists(D)) or (exists(D) and D.entityType == 'CORPORATION')\n" + " R8: C.belongCategory != 'MY_GROUP' && C.belongCategory != 'MY_BB'\n" - + " R9: D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB'\n" + + " R9: (not exists(D)) or (exists(D) && D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB')\n" + " R10: A.id == 'A_89'\n" + "}\n" + "Action {\n" @@ -944,10 +944,10 @@ public class TransitiveOptionalTest { + " R41: X6.entityType == 'CORPORATION'\n" + " R42: X7.entityType == 'CORPORATION'\n" + " R43: X8.entityType == 'CORPORATION'\n" - + " R48: Y5.entityType == 'CORPORATION'\n" - + " R49: Y6.entityType == 'CORPORATION'\n" - + " R50: Y7.entityType == 'CORPORATION'\n" - + " R51: Y8.entityType == 'CORPORATION'\n" + + " R48: (exists(Y5) and Y5.entityType == 'CORPORATION') or (not exists(Y5))\n" + + " R49: (exists(Y6) and Y6.entityType == 'CORPORATION') or (not exists(Y6))\n" + + " R50: (exists(Y7) and Y7.entityType == 'CORPORATION') or (not exists(Y7))\n" + + " R51: (exists(Y8) and Y8.entityType == 'CORPORATION') or (not exists(Y8))\n" + "}\n" + "Action {\n" + " get(A.id, D.id, B3.id, C3.id, B4.id, C4.id, X5.id, X6.id, X7.id, X8.id, Y5.id, Y6.id, Y7.id, Y8.id)\n" @@ -1153,7 +1153,7 @@ public class TransitiveOptionalTest { + " totalRate = e1.edges().reduce((x,y) => y.shareholdingRatio/100.0 * x, 1)\n" + " R1: totalRate > 0.05\n" + " R0(\"只保留最长的路径\"): group(A).keep_longest_path(e1)\n" - + "R2: B.entityType == 'CORPORATION'\n" + + "R2: (exists(B) and B.entityType == 'CORPORATION') or (not exist(B))\n" + "}\n" + "Action {\n" + " get(A.id, B.id)\n" @@ -1260,7 +1260,7 @@ public class TransitiveOptionalTest { + " R6: F4.votingRatio >= 30 && F6.votingRatio >= 10\n" + " R17: F8.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n" + " R8: X.entityType == 'CORPORATION'\n" - + " R9: Y.entityType == 'CORPORATION'\n" + + " R9: (exists(Y) and Y.entityType == 'CORPORATION') or (not exists(Y))\n" + " R10: X.belongCategory != 'MY_GROUP' && X.belongCategory != 'MY_BB'\n" + " R11: A.id == 'A_810'\n" + "}\n"