fix(reasoner): fix bug in FilterPushDown (#113)

Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: youdonghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-02-02 11:31:19 +08:00 committed by GitHub
parent 1c85965fa9
commit 1b3b78b02e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 107 additions and 15 deletions

View File

@ -135,6 +135,12 @@
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>30.1-jre</version>
<exclusions>
<exclusion>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
@ -317,7 +323,7 @@
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.12</version>
<version>1.18.22</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>

View File

@ -55,6 +55,7 @@ object FilterPushDown extends SimpleRule {
}
private def pushDown(filter: Filter, aliasMap: Map[String, Set[String]]): LogicalOperator = {
var keepOriginal = false
var hasPushDown = false
def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = {
case logicalOp: StackingLogicalOperator =>
@ -65,19 +66,28 @@ object FilterPushDown extends SimpleRule {
} else {
logicalOp
}
case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) =>
if (hasPushDown) {
val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias)
if (!aliasMap.keySet.intersect(alias).isEmpty) {
keepOriginal = true
}
}
boundedVarLenExpand
}
val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
if (hasPushDown) {
newRoot.in
if (keepOriginal || !hasPushDown) {
filter
} else {
newRoot
newRoot.in
}
}
private def pushDown2Pattern(
filter: Filter,
aliasMap: Map[String, Set[String]]): (Boolean, LogicalOperator) = {
var keepOriginal = false
var hasPushDown: Boolean = false
def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = {
@ -105,13 +115,21 @@ object FilterPushDown extends SimpleRule {
patternScan
}
}
case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) =>
if (hasPushDown) {
val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias)
if (!aliasMap.keySet.intersect(alias).isEmpty) {
keepOriginal = true
}
}
boundedVarLenExpand
}
val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
if (hasPushDown) {
(hasPushDown, newRoot.in)
if (keepOriginal || !hasPushDown) {
(hasPushDown, filter)
} else {
(hasPushDown, newRoot)
(hasPushDown, newRoot.in)
}
}

View File

@ -124,6 +124,52 @@ public class KgReasonerTransitiveTest {
Assert.assertEquals(result.getRows().get(0)[1], "C5");
}
@Test
public void testTransitiveWithPathLongestWithTailRule() {
String dsl =
"GraphStructure {\n"
+ " A [RelatedParty, __start__='true']\n"
+ " B [RelatedParty]\n"
+ " A->B [holdShare] repeat(1,3) as e\n"
+ "}\n"
+ "Rule {\n"
+ " R0: B.name == 'C4'\n"
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id,B.id) \n"
+ "}";
LocalReasonerResult result = doProcess(dsl);
// check result
Assert.assertEquals(1, result.getRows().size());
Assert.assertEquals(2, result.getRows().get(0).length);
Assert.assertEquals(result.getRows().get(0)[0], "P1");
Assert.assertEquals(result.getRows().get(0)[1], "C4");
}
@Test
public void testTransitiveWithPathLongestWithHeadRule() {
String dsl =
"GraphStructure {\n"
+ " A [RelatedParty, __start__='true']\n"
+ " B [RelatedParty]\n"
+ " A->B [holdShare] repeat(1,10) as e\n"
+ "}\n"
+ "Rule {\n"
+ " R0: A.name == 'P1'\n"
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id,B.id) \n"
+ "}";
LocalReasonerResult result = doProcess(dsl);
// check result
Assert.assertEquals(1, result.getRows().size());
Assert.assertEquals(2, result.getRows().get(0).length);
Assert.assertEquals(result.getRows().get(0)[0], "P1");
Assert.assertEquals(result.getRows().get(0)[1], "C5");
}
@Test
public void testTransitiveWithPathShortest() {
String dsl =
@ -146,6 +192,28 @@ public class KgReasonerTransitiveTest {
Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString()));
}
@Test
public void testTransitiveWithPathShortestWithPath() {
String dsl =
"GraphStructure {\n"
+ " A [RelatedParty, __start__='true']\n"
+ " B [RelatedParty]\n"
+ " A->B [holdShare] repeat(1,10) as e\n"
+ "}\n"
+ "Rule {\n"
+ " R1: group(A).keep_shortest_path(e)\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id,B.id, __path__) \n"
+ "}";
LocalReasonerResult result = doProcess(dsl);
// check result
Assert.assertEquals(1, result.getRows().size());
Assert.assertEquals(3, result.getRows().get(0).length);
Assert.assertEquals(result.getRows().get(0)[0], "P1");
Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString()));
}
@Test
public void testTransitiveWithPathAll() {
String dsl =

View File

@ -715,9 +715,9 @@ public class TransitiveOptionalTest {
+ " R4: e2.relatedReason == 'CONTROL'\n"
+ " R5: e3.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n"
+ " R6: C.entityType == 'CORPORATION'\n"
+ " R7: D.entityType == 'CORPORATION'\n"
+ " R7: (not exists(D)) or (exists(D) and D.entityType == 'CORPORATION')\n"
+ " R8: C.belongCategory != 'MY_GROUP' && C.belongCategory != 'MY_BB'\n"
+ " R9: D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB'\n"
+ " R9: (not exists(D)) or (exists(D) && D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB')\n"
+ " R10: A.id == 'A_89'\n"
+ "}\n"
+ "Action {\n"
@ -944,10 +944,10 @@ public class TransitiveOptionalTest {
+ " R41: X6.entityType == 'CORPORATION'\n"
+ " R42: X7.entityType == 'CORPORATION'\n"
+ " R43: X8.entityType == 'CORPORATION'\n"
+ " R48: Y5.entityType == 'CORPORATION'\n"
+ " R49: Y6.entityType == 'CORPORATION'\n"
+ " R50: Y7.entityType == 'CORPORATION'\n"
+ " R51: Y8.entityType == 'CORPORATION'\n"
+ " R48: (exists(Y5) and Y5.entityType == 'CORPORATION') or (not exists(Y5))\n"
+ " R49: (exists(Y6) and Y6.entityType == 'CORPORATION') or (not exists(Y6))\n"
+ " R50: (exists(Y7) and Y7.entityType == 'CORPORATION') or (not exists(Y7))\n"
+ " R51: (exists(Y8) and Y8.entityType == 'CORPORATION') or (not exists(Y8))\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id, D.id, B3.id, C3.id, B4.id, C4.id, X5.id, X6.id, X7.id, X8.id, Y5.id, Y6.id, Y7.id, Y8.id)\n"
@ -1153,7 +1153,7 @@ public class TransitiveOptionalTest {
+ " totalRate = e1.edges().reduce((x,y) => y.shareholdingRatio/100.0 * x, 1)\n"
+ " R1: totalRate > 0.05\n"
+ " R0(\"只保留最长的路径\"): group(A).keep_longest_path(e1)\n"
+ "R2: B.entityType == 'CORPORATION'\n"
+ "R2: (exists(B) and B.entityType == 'CORPORATION') or (not exist(B))\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id, B.id)\n"
@ -1260,7 +1260,7 @@ public class TransitiveOptionalTest {
+ " R6: F4.votingRatio >= 30 && F6.votingRatio >= 10\n"
+ " R17: F8.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n"
+ " R8: X.entityType == 'CORPORATION'\n"
+ " R9: Y.entityType == 'CORPORATION'\n"
+ " R9: (exists(Y) and Y.entityType == 'CORPORATION') or (not exists(Y))\n"
+ " R10: X.belongCategory != 'MY_GROUP' && X.belongCategory != 'MY_BB'\n"
+ " R11: A.id == 'A_810'\n"
+ "}\n"