From e177b7c6c9b8b9c92d70bc8648c2695ed85bccdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E5=9F=B9=E9=BE=99?= Date: Mon, 15 Apr 2024 20:25:34 +0800 Subject: [PATCH] feat(reasoner): add create node in define block (#191) Co-authored-by: wenchengyao Co-authored-by: FishJoy --- .../reasoner/parser/OpenSPGDslParser.scala | 35 +++-- .../parser/OpenSPGDslParserTest.scala | 8 +- .../reasoner/lube/utils/BlockUtils.scala | 9 +- .../reasoner/lube/logical/SolvedModel.scala | 8 +- .../logical/planning/LogicalPlanner.scala | 18 ++- .../logical/planning/SubQueryPlanner.scala | 4 +- .../local/main/KgReasonerLeadToTest.java | 2 +- .../transitive/TransitiveOptionalTest.java | 131 ++++++++++++++++++ .../rdg/common/ExtractRelationImpl.java | 9 +- 9 files changed, 194 insertions(+), 30 deletions(-) diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala index 7c1c1d32..febe5e4e 100644 --- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala +++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala @@ -139,9 +139,6 @@ class OpenSPGDslParser extends ParserInterface { parseBaseRuleDefine(ctx.base_rule_define(), ddlBlockWithNodes._2, ddlBlockWithNodes._3) val ddlBlockOp = ddlBlockWithNodes._1.ddlOp.head val ruleBlock = ddlInfo._1 - if (ddlInfo._2.nonEmpty) { - return DDLBlock(ddlInfo._2, List.apply(ruleBlock)) - } ddlBlockOp match { case AddProperty(s, propertyName, _) => val isLastAssignTargetAlis = ruleBlock match { @@ -176,7 +173,7 @@ class OpenSPGDslParser extends ParserInterface { ProjectRule( IRProperty(s.alias, propertyName), Ref(ddlBlockWithNodes._3.target.alias))))) - DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk)) + DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(prjBlk)) case AddPredicate(predicate) => val attrFields = new mutable.HashMap[String, Expr]() addPropertiesMap.foreach(x => @@ -209,9 +206,9 @@ class OpenSPGDslParser extends ParserInterface { predicate.source, predicate.target, attrFields.toMap, - predicate.direction))), + predicate.direction))) ++ ddlInfo._2, List.apply(depBlk)) - case _ => DDLBlock(Set.apply(ddlBlockOp), List.apply(ruleBlock)) + case _ => DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(ruleBlock)) } } @@ -298,7 +295,7 @@ class OpenSPGDslParser extends ParserInterface { predicate: PredicateElement): (Block, Set[DDLOp]) = { val matchBlock = parseGraphStructure(ctx.the_graph_structure(), head, predicate) val ruleBlock = parseRule(ctx.the_rule(), matchBlock) - val ddlOp = parseCreateAction(ctx.create_action()) + val ddlOp = parseCreateAction(ctx.create_action(), matchBlock) (ruleBlock, ddlOp) } @@ -600,11 +597,31 @@ class OpenSPGDslParser extends ParserInterface { curBlock } - def parseCreateAction(ctx: Create_actionContext): Set[DDLOp] = { + def parseCreateAction(ctx: Create_actionContext, matchBlock: MatchBlock): Set[DDLOp] = { if (ctx == null) { Set.empty } else { - ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet + val ddlBlockSet = ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet + val matchEleInfo = matchBlock.patterns.map(x => x._2.graphPattern.nodes).flatten + val allEleInfo = ddlBlockSet.map { + case AddVertex(s, _) => s.alias -> s + case _ => null + }.filter(_ != null).toMap ++ matchEleInfo + ddlBlockSet.map { + case c: AddVertex => c + case c: AddProperty => c + case c: AddPredicate => AddPredicate( + PredicateElement( + c.predicate.label, + c.predicate.alias, + allEleInfo(c.predicate.source.alias), + allEleInfo(c.predicate.target.alias), + c.predicate.fields, + c.predicate.direction + ) + ) + }.toSet + } } diff --git a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala index 789e899b..a02ff209 100644 --- a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala +++ b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala @@ -387,8 +387,8 @@ class OpenSPGDslParserTest extends AnyFunSpec { val block = parser.parse(dsl) print(block.pretty) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) - block.asInstanceOf[DDLBlock].ddlOp.size should equal(2) - block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true) + block.asInstanceOf[DDLBlock].ddlOp.size should equal(3) + block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true) } it("addNodeException") { @@ -504,8 +504,8 @@ class OpenSPGDslParserTest extends AnyFunSpec { val block = parser.parse(dsl) print(block.pretty) block.dependencies.head.isInstanceOf[MatchBlock] should equal(true) - block.asInstanceOf[DDLBlock].ddlOp.size should equal(1) - block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true) + block.asInstanceOf[DDLBlock].ddlOp.size should equal(2) + block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddPredicate] should equal(true) } diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/BlockUtils.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/BlockUtils.scala index 6e421109..4e167912 100644 --- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/BlockUtils.scala +++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/BlockUtils.scala @@ -41,15 +41,16 @@ object BlockUtils { predicate.target.typeNames.head).toString) case AddProperty(s, propertyName, _) => defines.add(s.typeNames.head + "." + propertyName) - case AddVertex(s, _) => - // defines.add(s.typeNames.head) - return Set.apply("result") case _ => } }) case _ => defines.add("result") } - defines.toSet + if (defines.isEmpty) { + Set.apply("result") + } else { + defines.toSet + } } } diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/SolvedModel.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/SolvedModel.scala index 97e7ad69..70f19503 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/SolvedModel.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/SolvedModel.scala @@ -65,7 +65,13 @@ case class SolvedModel( } def solve: SolvedModel = { - val tmp = tmpFields.values.map(p => fields(p.name).merge(Option.apply(p))).toList + val tmp = tmpFields.values.map(p => { + if (!fields.contains(p.name)) { + throw InvalidRefVariable(s"not found fields name : ${p} in solved fields") + } + fields(p.name).merge(Option.apply(p)) + } + ).toList var newFields = fields for (t <- tmp) { newFields = newFields.updated(t.name, newFields(t.name).merge(Option.apply(t))) diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/LogicalPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/LogicalPlanner.scala index 8e0765e3..1603c130 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/LogicalPlanner.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/LogicalPlanner.scala @@ -402,10 +402,22 @@ object LogicalPlanner { val starts = new mutable.HashSet[String]() for (ddl <- ddlOp) { ddl match { - case AddProperty(s, _, _) => starts.add(s.alias) + case AddProperty(s, _, _) => + if (starts.isEmpty) { + starts.add(s.alias) + } else { + val common = starts.intersect(Set.apply(s.alias)) + starts.clear() + starts.++=(common) + } case AddPredicate(p) => - starts.add(p.source.alias) - starts.add(p.target.alias) + if (starts.isEmpty) { + starts.++=(Set.apply(p.source.alias, p.target.alias)) + } else { + val common = starts.intersect(Set.apply(p.source.alias, p.target.alias)) + starts.clear() + starts.++=(common) + } case _ => } } diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/SubQueryPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/SubQueryPlanner.scala index 73b0648b..d9ac4a73 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/SubQueryPlanner.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/SubQueryPlanner.scala @@ -254,13 +254,11 @@ class SubQueryPlanner(val dag: Dag[Block])(implicit context: LogicalPlannerConte rootAlias = predicate.target.alias } else if (direction == Direction.OUT && predicate.direction == Direction.IN) { rootAlias == predicate.target.alias - } else { + } else if (direction != null) { rootAlias == predicate.source.alias } case AddProperty(s, _, _) => rootAlias = s.alias - case AddVertex(s, _) => - rootAlias = s.alias case _ => } }) diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerLeadToTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerLeadToTest.java index a8b8ef8b..97ddac4f 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerLeadToTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerLeadToTest.java @@ -84,7 +84,7 @@ public class KgReasonerLeadToTest { LocalReasonerRunner runner = new LocalReasonerRunner(); LocalReasonerResult result = runner.run(task); System.out.println(result); - Assert.assertEquals(1, result.getVertexList().size()); + Assert.assertEquals(2, result.getVertexList().size()); } public static class GraphLoaderForAddVertex extends AbstractLocalGraphLoader { diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java index 7b67b691..a3491436 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/transitive/TransitiveOptionalTest.java @@ -14,6 +14,9 @@ package com.antgroup.openspg.reasoner.runner.local.main.transitive; import com.antgroup.openspg.reasoner.common.constants.Constants; +import com.antgroup.openspg.reasoner.common.graph.edge.IEdge; +import com.antgroup.openspg.reasoner.common.graph.property.IProperty; +import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex; import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState; import com.antgroup.openspg.reasoner.lube.catalog.Catalog; import com.antgroup.openspg.reasoner.lube.catalog.impl.PropertyGraphCatalog; @@ -25,9 +28,11 @@ import com.antgroup.openspg.reasoner.runner.local.loader.MockLocalGraphLoader; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask; import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import org.junit.Assert; import org.junit.Test; @@ -1286,4 +1291,130 @@ public class TransitiveOptionalTest { LocalReasonerResult rst = runTest(schema, dsl, dataGraphStr2); Assert.assertEquals(1, rst.getRows().size()); } + + @Test + public void testCreateInstance() { + String dsl = + "Define (s:Custid)-[p:strNum]->(o:Int) {\n" + + " GraphStructure {\n" + + " (s)<-[pp:hasCust]-(str:STR)\n" + + " }\n" + + " Rule {\n" + + " o = group(s).countIf(str.status == 'CLOSE', str)\n" + + " }\n" + + "}\n" + + "\n" + + "Define (s:Custid)-[p:isInWhiteBlack]->(o:Boolean) {\n" + + " GraphStructure {\n" + + " (s)<-[:hasCust]-(str:STR)\n" + + " }\n" + + " Rule {\n" + + " R1: str.matchrule == '0202'\n" + + " R2: str.isreport == '1' \n" + + "\n" + + " o = (R1 and R2)\n" + + " }\n" + + "}\n" + + "\n" + + "Define (s:Custid)-[p:isAggregator]->(o:Boolean) {\n" + + " GraphStructure {\n" + + " (s)<-[e:complained]-(u1:Custid)\n" + + " }\n" + + " Rule {\n" + + " R0: s.isInWhiteBlack == null or s.isInWhiteBlack == false\n" + + " complainNum = group(s).count(e)\n" + + " R5(\"被投诉大于20条\"): complainNum >=20\n" + + "\n" + + " o = true\n" + + " }\n" + + " Action {\n" + + " gang = createNodeInstance(\n" + + " type=Gang,\n" + + " value={\n" + + " id=concat(s.id, \"_gang\")\n" + + " }\n" + + " )\n" + + " createEdgeInstance(\n" + + " src=gang,\n" + + " dst=s,\n" + + " type=has,\n" + + " value={\n" + + " }\n" + + " )\n" + + " }\n" + + "}\n" + + "GraphStructure {" + + " A [Custid, __start__='true']\n" + + " B [Gang]\n" + + " B->A [has]\n" + + "}\n" + + "Rule {\n" + + "}\n" + + "Action {\n" + + " get(A.id, A.isAggregator, B.id) \n" + + "}"; + + System.out.println(dsl); + LocalReasonerTask task = new LocalReasonerTask(); + task.setDsl(dsl); + + // add mock catalog + Map> schema = new HashMap<>(); + schema.put( + "Custid", + Convert2ScalaUtil.toScalaImmutableSet( + Sets.newHashSet( + "trdAmtIn90d", + "trdAmt90d", + "trdCntCustIn90d", + "custcntpty90CustNum90dInGenderFemale", + "custcntpty90CustNum90dInGenderMale", + "name"))); + schema.put( + "STR", + Convert2ScalaUtil.toScalaImmutableSet( + Sets.newHashSet("conclusion", "name", "status", "matchrule", "isreport"))); + + schema.put("Gang", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("cid", "name"))); + schema.put("Gang_has_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("info"))); + schema.put( + "Custid_complained_Custid", + Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo"))); + schema.put( + "STR_hasCust_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo"))); + + Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema)); + catalog.init(); + task.setCatalog(catalog); + task.setGraphLoadClass( + "com.antgroup.openspg.reasoner.runner.local.main.transitive.TransitiveOptionalTest$GangGraphLoader"); + + // enable subquery + Map params = new HashMap<>(); + params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true); + params.put(Constants.SPG_REASONER_MULTI_VERSION_ENABLE, "true"); + task.setParams(params); + + LocalReasonerRunner runner = new LocalReasonerRunner(); + LocalReasonerResult result = runner.run(task); + Assert.assertEquals(2, result.getRows().size()); + } + + public static class GangGraphLoader extends AbstractLocalGraphLoader { + @Override + public List> genVertexList() { + return Lists.newArrayList( + constructionVertex("A1", "Custid", "name", "A1", "cid", "a1"), + constructionVertex("A2", "Custid", "name", "A2", "cid", "a2"), + constructionVertex("B1", "Gang", "name", "B2", "cid", "b1")); + } + + @Override + public List> genEdgeList() { + return Lists.newArrayList( + constructionEdge("B1", "has", "A1", "info", "b1_a1"), + constructionEdge("B1", "has", "A2", "info", "b1_a2"), + constructionEdge("A1", "complained", "A2", "info", "a1ca2")); + } + } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/ExtractRelationImpl.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/ExtractRelationImpl.java index 5bbbee14..1c5b2b65 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/ExtractRelationImpl.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/ExtractRelationImpl.java @@ -60,7 +60,7 @@ public class ExtractRelationImpl implements Serializable { private final String taskId; - private final PatternElement sourceElement; + private final Element sourceElement; private final EntityElement targetEntityElement; private final PatternElement targetPatternElement; @@ -84,7 +84,7 @@ public class ExtractRelationImpl implements Serializable { this.propertyRuleMap.put(propertyName, rule); } - PatternElement sourceElement = (PatternElement) addPredicate.predicate().source(); + Element se = addPredicate.predicate().source(); Element te = addPredicate.predicate().target(); EntityElement targetEntityElement = null; PatternElement targetPatternElement = null; @@ -114,11 +114,10 @@ public class ExtractRelationImpl implements Serializable { } } if (!this.propertyRuleMap.containsKey(Constants.EDGE_FROM_ID_KEY)) { - this.propertyRuleMap.put( - Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(sourceElement.alias() + ".id")); + this.propertyRuleMap.put(Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(se.alias() + ".id")); } - this.sourceElement = sourceElement; + this.sourceElement = se; this.targetEntityElement = targetEntityElement; this.targetPatternElement = targetPatternElement; }