mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-12-26 06:36:59 +00:00
fix(reasoner): fix bug in FilterPushDown (#113)
Co-authored-by: peilong <peilong.zpl@antgroup.com> Co-authored-by: youdonghai <donghai.ydh@antgroup.com>
This commit is contained in:
parent
1c85965fa9
commit
1b3b78b02e
8
pom.xml
8
pom.xml
@ -135,6 +135,12 @@
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
<version>30.1-jre</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
@ -317,7 +323,7 @@
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.18.12</version>
|
||||
<version>1.18.22</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
|
||||
@ -55,6 +55,7 @@ object FilterPushDown extends SimpleRule {
|
||||
}
|
||||
|
||||
private def pushDown(filter: Filter, aliasMap: Map[String, Set[String]]): LogicalOperator = {
|
||||
var keepOriginal = false
|
||||
var hasPushDown = false
|
||||
def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = {
|
||||
case logicalOp: StackingLogicalOperator =>
|
||||
@ -65,19 +66,28 @@ object FilterPushDown extends SimpleRule {
|
||||
} else {
|
||||
logicalOp
|
||||
}
|
||||
case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) =>
|
||||
if (hasPushDown) {
|
||||
val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias)
|
||||
if (!aliasMap.keySet.intersect(alias).isEmpty) {
|
||||
keepOriginal = true
|
||||
}
|
||||
}
|
||||
boundedVarLenExpand
|
||||
}
|
||||
|
||||
val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
|
||||
if (hasPushDown) {
|
||||
newRoot.in
|
||||
if (keepOriginal || !hasPushDown) {
|
||||
filter
|
||||
} else {
|
||||
newRoot
|
||||
newRoot.in
|
||||
}
|
||||
}
|
||||
|
||||
private def pushDown2Pattern(
|
||||
filter: Filter,
|
||||
aliasMap: Map[String, Set[String]]): (Boolean, LogicalOperator) = {
|
||||
var keepOriginal = false
|
||||
var hasPushDown: Boolean = false
|
||||
|
||||
def rewriter: PartialFunction[LogicalOperator, LogicalOperator] = {
|
||||
@ -105,13 +115,21 @@ object FilterPushDown extends SimpleRule {
|
||||
patternScan
|
||||
}
|
||||
}
|
||||
case boundedVarLenExpand @ BoundedVarLenExpand(_, expandInto: ExpandInto, edgePattern, _) =>
|
||||
if (hasPushDown) {
|
||||
val alias = Set.apply(edgePattern.edge.alias, edgePattern.dst.alias)
|
||||
if (!aliasMap.keySet.intersect(alias).isEmpty) {
|
||||
keepOriginal = true
|
||||
}
|
||||
}
|
||||
boundedVarLenExpand
|
||||
}
|
||||
|
||||
val newRoot = BottomUp[LogicalOperator](rewriter).transform(filter).asInstanceOf[Filter]
|
||||
if (hasPushDown) {
|
||||
(hasPushDown, newRoot.in)
|
||||
if (keepOriginal || !hasPushDown) {
|
||||
(hasPushDown, filter)
|
||||
} else {
|
||||
(hasPushDown, newRoot)
|
||||
(hasPushDown, newRoot.in)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -124,6 +124,52 @@ public class KgReasonerTransitiveTest {
|
||||
Assert.assertEquals(result.getRows().get(0)[1], "C5");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransitiveWithPathLongestWithTailRule() {
|
||||
String dsl =
|
||||
"GraphStructure {\n"
|
||||
+ " A [RelatedParty, __start__='true']\n"
|
||||
+ " B [RelatedParty]\n"
|
||||
+ " A->B [holdShare] repeat(1,3) as e\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ " R0: B.name == 'C4'\n"
|
||||
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id,B.id) \n"
|
||||
+ "}";
|
||||
LocalReasonerResult result = doProcess(dsl);
|
||||
// check result
|
||||
Assert.assertEquals(1, result.getRows().size());
|
||||
Assert.assertEquals(2, result.getRows().get(0).length);
|
||||
Assert.assertEquals(result.getRows().get(0)[0], "P1");
|
||||
Assert.assertEquals(result.getRows().get(0)[1], "C4");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransitiveWithPathLongestWithHeadRule() {
|
||||
String dsl =
|
||||
"GraphStructure {\n"
|
||||
+ " A [RelatedParty, __start__='true']\n"
|
||||
+ " B [RelatedParty]\n"
|
||||
+ " A->B [holdShare] repeat(1,10) as e\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ " R0: A.name == 'P1'\n"
|
||||
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id,B.id) \n"
|
||||
+ "}";
|
||||
LocalReasonerResult result = doProcess(dsl);
|
||||
// check result
|
||||
Assert.assertEquals(1, result.getRows().size());
|
||||
Assert.assertEquals(2, result.getRows().get(0).length);
|
||||
Assert.assertEquals(result.getRows().get(0)[0], "P1");
|
||||
Assert.assertEquals(result.getRows().get(0)[1], "C5");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransitiveWithPathShortest() {
|
||||
String dsl =
|
||||
@ -146,6 +192,28 @@ public class KgReasonerTransitiveTest {
|
||||
Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransitiveWithPathShortestWithPath() {
|
||||
String dsl =
|
||||
"GraphStructure {\n"
|
||||
+ " A [RelatedParty, __start__='true']\n"
|
||||
+ " B [RelatedParty]\n"
|
||||
+ " A->B [holdShare] repeat(1,10) as e\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ " R1: group(A).keep_shortest_path(e)\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id,B.id, __path__) \n"
|
||||
+ "}";
|
||||
LocalReasonerResult result = doProcess(dsl);
|
||||
// check result
|
||||
Assert.assertEquals(1, result.getRows().size());
|
||||
Assert.assertEquals(3, result.getRows().get(0).length);
|
||||
Assert.assertEquals(result.getRows().get(0)[0], "P1");
|
||||
Assert.assertTrue("C2,C1".contains(result.getRows().get(0)[1].toString()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransitiveWithPathAll() {
|
||||
String dsl =
|
||||
|
||||
@ -715,9 +715,9 @@ public class TransitiveOptionalTest {
|
||||
+ " R4: e2.relatedReason == 'CONTROL'\n"
|
||||
+ " R5: e3.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n"
|
||||
+ " R6: C.entityType == 'CORPORATION'\n"
|
||||
+ " R7: D.entityType == 'CORPORATION'\n"
|
||||
+ " R7: (not exists(D)) or (exists(D) and D.entityType == 'CORPORATION')\n"
|
||||
+ " R8: C.belongCategory != 'MY_GROUP' && C.belongCategory != 'MY_BB'\n"
|
||||
+ " R9: D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB'\n"
|
||||
+ " R9: (not exists(D)) or (exists(D) && D.belongCategory != 'MY_GROUP' && D.belongCategory != 'MY_BB')\n"
|
||||
+ " R10: A.id == 'A_89'\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
@ -944,10 +944,10 @@ public class TransitiveOptionalTest {
|
||||
+ " R41: X6.entityType == 'CORPORATION'\n"
|
||||
+ " R42: X7.entityType == 'CORPORATION'\n"
|
||||
+ " R43: X8.entityType == 'CORPORATION'\n"
|
||||
+ " R48: Y5.entityType == 'CORPORATION'\n"
|
||||
+ " R49: Y6.entityType == 'CORPORATION'\n"
|
||||
+ " R50: Y7.entityType == 'CORPORATION'\n"
|
||||
+ " R51: Y8.entityType == 'CORPORATION'\n"
|
||||
+ " R48: (exists(Y5) and Y5.entityType == 'CORPORATION') or (not exists(Y5))\n"
|
||||
+ " R49: (exists(Y6) and Y6.entityType == 'CORPORATION') or (not exists(Y6))\n"
|
||||
+ " R50: (exists(Y7) and Y7.entityType == 'CORPORATION') or (not exists(Y7))\n"
|
||||
+ " R51: (exists(Y8) and Y8.entityType == 'CORPORATION') or (not exists(Y8))\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id, D.id, B3.id, C3.id, B4.id, C4.id, X5.id, X6.id, X7.id, X8.id, Y5.id, Y6.id, Y7.id, Y8.id)\n"
|
||||
@ -1153,7 +1153,7 @@ public class TransitiveOptionalTest {
|
||||
+ " totalRate = e1.edges().reduce((x,y) => y.shareholdingRatio/100.0 * x, 1)\n"
|
||||
+ " R1: totalRate > 0.05\n"
|
||||
+ " R0(\"只保留最长的路径\"): group(A).keep_longest_path(e1)\n"
|
||||
+ "R2: B.entityType == 'CORPORATION'\n"
|
||||
+ "R2: (exists(B) and B.entityType == 'CORPORATION') or (not exist(B))\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id, B.id)\n"
|
||||
@ -1260,7 +1260,7 @@ public class TransitiveOptionalTest {
|
||||
+ " R6: F4.votingRatio >= 30 && F6.votingRatio >= 10\n"
|
||||
+ " R17: F8.edges().constraint((pre,cur) => cur.relatedReason == 'CONTROL')\n"
|
||||
+ " R8: X.entityType == 'CORPORATION'\n"
|
||||
+ " R9: Y.entityType == 'CORPORATION'\n"
|
||||
+ " R9: (exists(Y) and Y.entityType == 'CORPORATION') or (not exists(Y))\n"
|
||||
+ " R10: X.belongCategory != 'MY_GROUP' && X.belongCategory != 'MY_BB'\n"
|
||||
+ " R11: A.id == 'A_810'\n"
|
||||
+ "}\n"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user