From 82402af74f39435664bd5cff0965654583beadc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E5=9F=B9=E9=BE=99?= Date: Wed, 27 Dec 2023 20:11:39 +0800 Subject: [PATCH] fix(reasoner): query language qgl concept parser carry empty properties lead to plan failed (#47) --- .../openspg/reasoner/parser/KgDslParser.scala | 72 ++------ .../parser/pattern/PatternParser.scala | 54 +++++- .../parser/pattern/PatternParserTest.scala | 19 +- .../semantic/rules/ConceptExplain.scala | 7 +- .../local/main/KgReasonerABMLocalTest.java | 153 ---------------- .../runner/local/main/KgReasonerAggTest.java | 4 +- .../local/main/KgReasonerZijinLocalTest.java | 166 ------------------ .../session/ReasonerSessionTests.scala | 79 +++++++++ 8 files changed, 164 insertions(+), 390 deletions(-) diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/KgDslParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/KgDslParser.scala index 6182ebd8..fd328a4e 100644 --- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/KgDslParser.scala +++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/KgDslParser.scala @@ -446,9 +446,6 @@ class KgDslParser extends ParserInterface { } def parseRule(ctx: The_ruleContext, matchBlock: MatchBlock): Block = { - val nodesProp = new mutable.HashMap[String, IRNode]() - val edgesProp = new mutable.HashMap[String, IREdge]() - var refFieldsMap: Map[String, Set[String]] = Map.empty var rules: List[Rule] = List.empty if (ctx.rule_expression_body().rule_expression().size() != 0) { rules = ctx @@ -456,6 +453,13 @@ class KgDslParser extends ParserInterface { .rule_expression() .asScala .map(x => exprParser.parseRuleExpression(x)).toList + } + parseRuleBlock(rules, matchBlock.patterns) + } + + def parseRuleBlock(rules: List[Rule], patterns: Map[String, GraphPath]): Block = { + var refFieldsMap: Map[String, Set[String]] = Map.empty + if (rules.nonEmpty) { rules.foreach(rule => { val irFields = RuleUtils.getAllInputFieldInRule(rule, Set.empty, Set.empty) irFields.foreach { @@ -470,51 +474,7 @@ class KgDslParser extends ParserInterface { }) } - val patternMaps = matchBlock.patterns.map(pattern => { - val nodeMaps = pattern._2.graphPattern.nodes.keySet - .map(nodeAlias => { - if (!nodesProp.contains(nodeAlias)) { - nodesProp += (nodeAlias -> IRNode(nodeAlias, Set.empty)) - } - if (refFieldsMap.contains(nodeAlias)) { - val irNode = nodesProp(nodeAlias).copy(fields = - nodesProp(nodeAlias).fields ++ refFieldsMap(nodeAlias)) - nodesProp.put(nodeAlias, irNode) - nodeAlias -> refFieldsMap(nodeAlias) - } else { - nodeAlias -> Set.empty[String] - } - }) - .filter(_ != null) - .toMap - val edgeMaps = pattern._2.graphPattern.edges - .map(edgeSet => { - edgeSet._2 - .map(edge => { - if (!edgesProp.contains(edge.alias)) { - edgesProp += - (edge.alias -> IREdge(edge.alias, Set.empty)) - } - if (refFieldsMap.contains(edge.alias)) { - val irEdge = edgesProp(edge.alias).copy(fields = - edgesProp(edge.alias).fields ++ refFieldsMap(edge.alias)) - edgesProp.put(edge.alias, irEdge) - edge.alias -> refFieldsMap(edge.alias) - } else { - edge.alias -> Set.empty[String] - } - }) - .filter(_ != null) - }) - .flatten - val updatedGraphPattern = - pattern._2.graphPattern.copy(properties = (nodeMaps ++ edgeMaps)) - pattern._1 -> pattern._2.copy(graphPattern = updatedGraphPattern) - }) - - val updatedMatch = matchBlock.copy( - dependencies = List.apply(SourceBlock(KG(nodesProp.toMap, edgesProp.toMap))), - patterns = patternMaps) + val updatedMatch = patternParser.parseSourceAndMatchBlock(refFieldsMap, patterns) var ruleInstructs = Map[Rule, Set[Rule]]() @@ -756,8 +716,6 @@ class KgDslParser extends ParserInterface { } }) - val matchBlock = MatchBlock(List.apply(patternParser.parseSourceBlock(pathMaps)), pathMaps) - if (ctx.element_pattern_where_clause() != null) { val trans: PartialFunction[Expr, Expr] = { case BinaryOpExpr(name, l, r) => @@ -769,14 +727,14 @@ class KgDslParser extends ParserInterface { } val expr = BottomUp(trans) .transform(patternParser.parseElePatternWhereClause(ctx.element_pattern_where_clause())) - FilterBlock( - List.apply(matchBlock), - LogicRule( - "anonymous_rule_" + patternParser.getDefaultAliasNum, - "anonymous_rule_" + patternParser.getDefaultAliasNum, - expr)) + + val rule = LogicRule( + "anonymous_rule_" + patternParser.getDefaultAliasNum, + "anonymous_rule_" + patternParser.getDefaultAliasNum, + expr) + parseRuleBlock(List.apply(rule), pathMaps) } else { - matchBlock + patternParser.parseSourceAndMatchBlock(Map.empty, pathMaps) } } diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParser.scala index bf520b0f..5e7bfa5a 100644 --- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParser.scala +++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParser.scala @@ -151,12 +151,13 @@ class PatternParser extends Serializable { Map.empty), false)) } - MatchBlock(List.apply(parseSourceBlock(pathMaps)), pathMaps) + parseSourceAndMatchBlock(Map.empty, pathMaps) } - def parseSourceBlock(patterns: Map[String, GraphPath]): SourceBlock = { - var nodesProp: Map[String, IRNode] = Map.empty - var edgesProp: Map[String, IREdge] = Map.empty + def parseSourceAndMatchBlock(refFieldsMap: Map[String, Set[String]], + patterns: Map[String, GraphPath]): MatchBlock = { + val nodesProp = new mutable.HashMap[String, IRNode]() + val edgesProp = new mutable.HashMap[String, IREdge]() patterns.foreach(path => { path._2.graphPattern.nodes.foreach(node => { nodesProp += (node._1 -> IRNode(node._1, Set.empty)) @@ -167,7 +168,50 @@ class PatternParser extends Serializable { }) }) }) - SourceBlock(KG(nodesProp, edgesProp)) + + val patternMaps = patterns.map(pattern => { + val nodeMaps = pattern._2.graphPattern.nodes.keySet + .map(nodeAlias => { + if (!nodesProp.contains(nodeAlias)) { + nodesProp += (nodeAlias -> IRNode(nodeAlias, Set.empty)) + } + if (refFieldsMap.contains(nodeAlias)) { + val irNode = nodesProp(nodeAlias).copy(fields = + nodesProp(nodeAlias).fields ++ refFieldsMap(nodeAlias)) + nodesProp.put(nodeAlias, irNode) + nodeAlias -> refFieldsMap(nodeAlias) + } else { + nodeAlias -> Set.empty[String] + } + }) + .filter(_ != null) + .toMap + val edgeMaps = pattern._2.graphPattern.edges + .map(edgeSet => { + edgeSet._2 + .map(edge => { + if (!edgesProp.contains(edge.alias)) { + edgesProp += + (edge.alias -> IREdge(edge.alias, Set.empty)) + } + if (refFieldsMap.contains(edge.alias)) { + val irEdge = edgesProp(edge.alias).copy(fields = + edgesProp(edge.alias).fields ++ refFieldsMap(edge.alias)) + edgesProp.put(edge.alias, irEdge) + edge.alias -> refFieldsMap(edge.alias) + } else { + edge.alias -> Set.empty[String] + } + }) + .filter(_ != null) + }) + .flatten + val updatedGraphPattern = + pattern._2.graphPattern.copy(properties = (nodeMaps ++ edgeMaps)) + pattern._1 -> pattern._2.copy(graphPattern = updatedGraphPattern) + }) + + MatchBlock(List.apply(SourceBlock(KG(nodesProp.toMap, edgesProp.toMap))), patternMaps) } def parseGraphStructureBody( ctx: Graph_structure_bodyContext, diff --git a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParserTest.scala b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParserTest.scala index 5785a6db..f40b98ad 100644 --- a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParserTest.scala +++ b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/pattern/PatternParserTest.scala @@ -13,6 +13,7 @@ package com.antgroup.openspg.reasoner.parser.pattern +import com.antgroup.openspg.reasoner.lube.block.{MatchBlock, SourceBlock} import com.antgroup.openspg.reasoner.parser.{DemoGraphParser, LexerInit} import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal} @@ -44,9 +45,12 @@ class PatternParserTest extends AnyFunSpec { .the_graph_structure() .graph_structure_define()) print(block.pretty) - val text = """└─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(null,Map(A -> (A:FilmPerson), C -> (C:FilmDirector), D -> (D:FilmDirector)),Map(A -> Set((A)->[e1:test]-(C))), C -> Set((C)->[e2:t1]-(D),[1,20]))),Map()),false))) - * └─SourceBlock(graph=KG(Map(A -> IRNode(A,Set()), C -> IRNode(C,Set()), D -> IRNode(D,Set())),Map(e1 -> IREdge(e1,Set()), e2 -> IREdge(e2,Set()))))""".stripMargin('*') - block.pretty should equal(text) + block.isInstanceOf[MatchBlock] should equal(true) + block.asInstanceOf[MatchBlock].dependencies.head.isInstanceOf[SourceBlock] should equal(true) + block.asInstanceOf[MatchBlock] + .dependencies.head.asInstanceOf[SourceBlock].graph.nodes.size should equal(3) + block.asInstanceOf[MatchBlock] + .dependencies.head.asInstanceOf[SourceBlock].graph.edges.size should equal(2) } it("gql") { val s = @@ -74,9 +78,12 @@ class PatternParserTest extends AnyFunSpec { .the_graph_structure() .graph_structure_define()) print(block.pretty) - val text = """└─MatchBlock(patterns=Map(path1 -> GraphPath(path1,GraphPattern(null,Map(A -> (A:Film), B -> (B:FilmStar,FilmDirector)),Map(A -> Set((A)->[p:starOfFilm,starOfDirector]-(B)))),Map()),false), path2 -> GraphPath(path2,GraphPattern(null,Map(B -> (B:FilmStar,FilmDirector), C -> (C:Film)),Map(B -> Set((B)<-[p2:starOfFilm]-(C),BinaryOpExpr(name=BGreaterThan),[1,4]))),Map()),false), path3 -> GraphPath(path3,GraphPattern(null,Map(B -> (B:FilmStar,FilmDirector), D -> (D:Robot.Film)),Map(B -> Set((B)<-[p3:starOfFilm]-(D),BinaryOpExpr(name=BGreaterThan),[1,4]))),Map()),false))) - * └─SourceBlock(graph=KG(Map(A -> IRNode(A,Set()), B -> IRNode(B,Set()), C -> IRNode(C,Set()), D -> IRNode(D,Set())),Map(p -> IREdge(p,Set()), p2 -> IREdge(p2,Set()), p3 -> IREdge(p3,Set()))))""".stripMargin('*') - block.pretty should equal(text) + block.isInstanceOf[MatchBlock] should equal(true) + block.asInstanceOf[MatchBlock].dependencies.head.isInstanceOf[SourceBlock] should equal(true) + block.asInstanceOf[MatchBlock] + .dependencies.head.asInstanceOf[SourceBlock].graph.nodes.size should equal(4) + block.asInstanceOf[MatchBlock] + .dependencies.head.asInstanceOf[SourceBlock].graph.edges.size should equal(3) } it("test demo graph 0") { diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/validate/semantic/rules/ConceptExplain.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/validate/semantic/rules/ConceptExplain.scala index e40ee8e8..1a373f24 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/validate/semantic/rules/ConceptExplain.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/validate/semantic/rules/ConceptExplain.scala @@ -70,7 +70,12 @@ object ConceptExplain extends Explain { graph.getTargetType(ele.label, "belongTo", Direction.IN), null))) newProps.+=((conceptAlias, Set.empty)) - newProps.+=((ele.alias, Set.apply(Constants.NODE_ID_KEY))) + + if (newProps.contains(ele.alias)) { + newProps.+=((ele.alias, Set.apply(Constants.NODE_ID_KEY) ++ newProps(ele.alias))) + } else { + newProps.+=((ele.alias, Set.apply(Constants.NODE_ID_KEY))) + } val connAlias = s"E_${conceptAlias}" val connection = new PatternConnection( connAlias, diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java index 1fd7922e..8736558a 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java @@ -19,7 +19,6 @@ import com.antgroup.openspg.reasoner.common.graph.property.IProperty; import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex; import com.antgroup.openspg.reasoner.lube.catalog.Catalog; import com.antgroup.openspg.reasoner.lube.catalog.impl.PropertyGraphCatalog; -import com.antgroup.openspg.reasoner.recorder.DefaultRecorder; import com.antgroup.openspg.reasoner.runner.ConfigKey; import com.antgroup.openspg.reasoner.runner.local.KGReasonerLocalRunner; import com.antgroup.openspg.reasoner.runner.local.load.graph.AbstractLocalGraphLoader; @@ -583,70 +582,6 @@ public class KgReasonerABMLocalTest { Assert.assertTrue(result.getErrMsg().contains("time unit need in s/ms/us, but this is xx")); } - @Test - public void test8() { - String dsl = - "Define (s:CustFundKG.Account)-[p:transInAmount]->(o:CustFundKG.Account) {\n" - + " GraphStructure {\n" - + " \t(o)<-[t:accountFundContact]-(s)\n" - + " }\n" - + " Rule {\n" - + " transAmount = group(s,o).sum(t.amount)\n" - + " \tp.transAmount = transAmount\n" - + " }\n" - + "}\n" - + "\n" - + "\n" - + "Define (s:CustFundKG.Account)-[p:centralizedTransfer]->(o:Boolean) {\n" - + " GraphStructure {\n" - + " \t(s)-[t:transInAmount]->(u:CustFundKG.Account)\n" - + " }\n" - + " Rule {\n" - + " \ttotalAmount = group(s).sum(t.transAmount)\n" - + " \ttop5Amount = group(s).order_edge_and_slice_sum(t.transAmount, \"desc\", 5)\n" - + " \tR2(\"top5流入资金占比\"): top5Amount*1.0/totalAmount > 0.5\n" - + " \to = rule_value(R2, true, false)\n" - + " }\n" - + "}\n" - + "GraphStructure {\n" - + "(A:CustFundKG.Account)\n" - + "}\n" - + "Rule {\n" - + "}\n" - + "Action {\n" - + " get(A.id, A.centralizedTransfer) \n" - + "}"; - System.out.println(dsl); - LocalReasonerTask task = new LocalReasonerTask(); - task.setDsl(dsl); - - // add mock catalog - Map> schema = new HashMap<>(); - schema.put("CustFundKG.Account", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id"))); - schema.put( - "CustFundKG.Account_accountFundContact_CustFundKG.Account", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("transDate", "amount"))); - Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema)); - catalog.init(); - task.setCatalog(catalog); - - task.setGraphLoadClass( - "com.antgroup.openspg.reasoner.runner.local.main.KgReasonerABMLocalTest$GraphLoader4"); - - // enable subquery - Map params = new HashMap<>(); - params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true); - params.put(ConfigKey.KG_REASONER_OUTPUT_GRAPH, true); - task.setParams(params); - task.setStartIdList(Lists.newArrayList(new Tuple2<>("A", "CustFundKG.Account"))); - - KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); - LocalReasonerResult result = runner.run(task); - - // only u1 - Assert.assertTrue(result != null); - } - public static class GraphLoader4 extends AbstractLocalGraphLoader { @Override @@ -866,94 +801,6 @@ public class KgReasonerABMLocalTest { return catalog; } - @Test - public void test2() { - String dsl = - "Define (s:ABM.Pkg)-[p:pkg2Family]->(o:ABM.BundleAppFamily) {\n" - + " GraphStructure {\n" - + " (s)-[:pkg2bundleApp]->(b:ABM.BundleApp)-[:belong]->(o)\n" - + " }\n" - + "\tRule{\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s:ABM.Pkg)-[p:blackSameFamilyPkg]->(o:ABM.Pkg) {\n" - + " GraphStructure {\n" - + " (s)-[:pkg2Family]->(b:ABM.BundleAppFamily)<-[:pkg2Family]-(o)\n" - + " }\n" - + "\tRule{\n" - + " R1(\"必须是黑包\"): s.algoMarkResult == 'UNSAFE'\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s:ABM.Pkg)-[p:blackRelateApdid]->(o:ABM.Apdid) {\n" - + " GraphStructure {\n" - + " (s)-[:blackSameFamilyPkg]->(p1:ABM.Pkg),\n" - + " (o)-[r:install]->(p1)\n" - + " }\n" - + "\tRule {\n" - + " R0(\"前几位安装\"): r.rn <= 20\n" - + " R3: o.insPkgToolsLabel != ''\n" - + " R4: o.insPkgToolsLabelCnt >= 2\n" - + " \n" - + " num = group(o).count(p1.id)\n" - + " R1(\"数目必须大于2\"): num >=2\n" - + " //排序暂未实现\n" - + " p.pkgNum = num\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s:ABM.Pkg)-[p:relateUser]->(o:ABM.User) {\n" - + " GraphStructure {\n" - + " (s)-[:blackRelateApdid]->(d:ABM.Apdid),\n" - + " (o)-[:acc2apdid]->(d)\n" - + " }\n" - + "\tRule {\n" - + " R1: o.hunterLabel != ''\n" - + " R2: o.hunterLabelCount >=2\n" - + " }\n" - + "}\n" - + "\n" - + "GraphStructure {\n" - + "\t(pkg:ABM.Pkg)-[p:blackRelateApdid]->(did:ABM.Apdid),\n" - + " (pkg)-[:relateUser]->(u:ABM.User)\n" - + "}\n" - + "Rule {\n" - + "}\n" - + "Action {\n" - + "\tget(pkg.id, p.pkgNum, did.id, u.id)\n" - + "}"; - - LocalReasonerTask task = new LocalReasonerTask(); - task.setDsl(dsl); - - Map dslParams = new HashMap<>(); - // use test catalog - Catalog catalog = initABMSchema(); - task.setCatalog(catalog); - - task.setGraphLoadClass( - "com.antgroup.openspg.reasoner.runner.local.main.KgReasonerABMLocalTest$GraphLoader"); - task.setExecutionRecorder(new DefaultRecorder()); - - // enable subquery - Map params = new HashMap<>(); - params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true); - task.setParams(params); - - KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); - LocalReasonerResult result = runner.run(task); - System.out.println(result); - Assert.assertEquals(1, result.getRows().size()); - Assert.assertEquals(4, result.getRows().get(0).length); - Assert.assertEquals("Pkg", result.getRows().get(0)[0]); - Assert.assertEquals("3", result.getRows().get(0)[1]); - Assert.assertEquals("Apdid", result.getRows().get(0)[2]); - Assert.assertEquals("user", result.getRows().get(0)[3]); - - System.out.println(task.getExecutionRecorder().toReadableString()); - } - public static class GraphLoader extends AbstractLocalGraphLoader { @Override diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAggTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAggTest.java index a01f7e50..e208dcf9 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAggTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAggTest.java @@ -53,7 +53,7 @@ public class KgReasonerAggTest { + "} \n" + "Rule {\n" + " totalAmt = group(s, s2).sum(e.payAmt90d)\n" - + " R1(\"必须大于1万\"): totalAmt > 1000\n" + + " R1: totalAmt > 1000\n" + " nums = group(s).count(s2)\n" + " result = rule_value(nums >= 1, true, false)\n" + "}\n" @@ -171,7 +171,7 @@ public class KgReasonerAggTest { + "} \n" + "Rule {\n" + " totalAmt = group(s, s2).sum(e.payAmt90d)\n" - + " R1(\"必须大于1万\"): totalAmt > 1000\n" + + " R1: totalAmt > 1000\n" + " nums = group(s).count(s2)\n" + " result = rule_value(nums >= 1, true, false)\n" + "}\n" diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerZijinLocalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerZijinLocalTest.java index 6575820a..6aacba78 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerZijinLocalTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerZijinLocalTest.java @@ -29,7 +29,6 @@ import com.google.common.collect.Sets; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.apache.commons.lang3.StringUtils; import org.junit.Assert; import org.junit.Test; import scala.Tuple2; @@ -185,171 +184,6 @@ public class KgReasonerZijinLocalTest { } } - @Test - public void test5() { - String dsl = - "Define (u1:AttributePOC.BrinsonAttribute)-[p:belongTo]->(o:`AttributePOC.TaxonomyOfBrinsonAttribute`/`市场贡献大`) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.BrinsonAttribute)-[p1:factorValue]->(u1)\n" - + " }\n" - + "\n" - + " Rule {\n" - + " R1: u1.factorType == \"market\"\n" - + " R4: s.factorType == \"total\"\n" - + " v = (u1.factorValue/ s.factorValue)\n" - + " R2(\"必须大于50%\"): v > 0.5\n" - + " }\n" - + "}\n" - + "\n" - + "Define (u2:AttributePOC.BrinsonAttribute)-[p:belongTo]->(o:`AttributePOC.TaxonomyOfBrinsonAttribute`/`选股贡献大`) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.BrinsonAttribute)-[p1:factorValue]->(u1:AttributePOC.BrinsonAttribute)\n" - + " (s)-[p2:factorValue]->(u2)\n" - + " (s)-[p3:factorValue]->(u3:AttributePOC.BrinsonAttribute)\n" - + " }\n" - + "\n" - + " Rule {\n" - + " R1: u1.factorType == \"cluster\"\n" - + " R2: u2.factorType == \"stock\"\n" - + " R3: u3.factorType == \"trade\"\n" - + " R4: s.factorType == \"total\"\n" - + " v = (u1.factorValue/ s.factorValue + u3.factorValue / s.factorValue)\n" - + " R6(\"必须大于50%\"): v < 0.5\n" - + " R5(\"交易收益大于选股\"): u2.factorValue > u3.factorValue\n" - + " }\n" - + "}\n" - + "\n" - + "Define (u2:AttributePOC.BrinsonAttribute)-[p:belongTo]->(o:`AttributePOC.TaxonomyOfBrinsonAttribute`/`交易贡献大`) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.BrinsonAttribute)-[p1:factorValue]->(u1:AttributePOC.BrinsonAttribute)\n" - + " (s)-[p2:factorValue]->(u2)\n" - + " (s)-[p3:factorValue]->(u3:AttributePOC.BrinsonAttribute)\n" - + " }\n" - + "\n" - + " Rule {\n" - + " R1: u1.factorType == \"cluster\"\n" - + " R2: u2.factorType == \"trade\"\n" - + " R3: u3.factorType == \"stock\"\n" - + " R4: s.factorType == \"total\"\n" - + " v = (u1.factorValue/ s.factorValue + u2.factorValue / s.factorValue)\n" - + " R5(\"必须大于50%\"): v > 0.5\n" - + " R6(\"交易收益大于选股\"): u2.factorValue > u3.factorValue\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s: AttributePOC.TracebackDay)-[p: market]->(o: Float) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.TracebackDay)-[:day]->(f: AttributePOC.BrinsonAttribute)-[:factorValue]->(u1:`AttributePOC.TaxonomyOfBrinsonAttribute`/`市场贡献大`)\n" - + "\t}\n" - + " Rule {\n" - + " o = u1.factorValue\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s: AttributePOC.TracebackDay)-[p: stock]->(o: Float) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.TracebackDay)-[:day]->(f: AttributePOC.BrinsonAttribute)-[:factorValue]->(u1:`AttributePOC.TaxonomyOfBrinsonAttribute`/`选股贡献大`)\n" - + "\t}\n" - + " Rule {\n" - + " o = u1.factorValue\n" - + " }\n" - + "}\n" - + "\n" - + "\n" - + "Define (s: AttributePOC.TracebackDay)-[p: trade]->(o: Float) {\n" - + " GraphStructure {\n" - + " (s:AttributePOC.TracebackDay)-[:day]->(f: AttributePOC.BrinsonAttribute)-[:factorValue]->(u1:`AttributePOC.TaxonomyOfBrinsonAttribute`/`交易贡献大`)\n" - + "\t}\n" - + " Rule {\n" - + " o = u1.factorValue\n" - + " }\n" - + "}\n" - + "\n" - + "Define (s: AttributePOC.TracebackDay)-[p: result]->(o: Text) {\n" - + " GraphStructure {\n" - + " (s)\n" - + "}\n" - + "Rule {\n" - + "// 按照选股、交易、市场的顺序输出\n" - + " str1 = rule_value(s.stock == null, \"\", concat(\"选股\", \": \", s.stock, ', '))\n" - + " str2 = concat(str1, rule_value(s.trade == null, \"\", concat(\"交易\", \": \", s.trade, ', ')))\n" - + " str3 = concat(str2, rule_value(s.market == null, \"\", concat(\"市场\", \": \", s.market)))\n" - + " o = str3\n" - + "}\n" - + "}\n" - + "\n" - + "Define (u1:AttributePOC.Scenario)-[p:belongTo]->(o:`AttributePOC.TaxonomyOfScenario`/`基金收益分析`) {\n" - + " GraphStructure {\n" - + " \t(u1)<-[p1:scConfig]-(s:AttributePOC.TracebackDay)\n" - + " }\n" - + " Rule {\n" - + " R1: s.result != null\n" - + " }\n" - + "}"; - dsl = - dsl - + "GraphStructure {\n" - + " (s: AttributePOC.TracebackDay)\n" - + "}\n" - + "Rule {\n" - + "// 按照选股、交易、市场的顺序输出\n" - + " str1 = rule_value(s.stock == null, \"\", concat(\"选股\", \": \", s.stock, ', '))\n" - + " str2 = rule_value(s.trade == null, \"\", concat(\"交易\", \": \", s.trade, ', '))\n" - + "}\n" - + "Action {\n" - + " get(s.id, str1, str2) \n" - + "}"; - - LocalReasonerTask task = new LocalReasonerTask(); - task.setExecutorTimeoutMs(60 * 1000 * 100); - task.setDsl(dsl); - - Map> schema = new HashMap<>(); - schema.put( - "AttributePOC.TracebackDay", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id"))); - schema.put( - "AttributePOC.Scenario", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id"))); - schema.put( - "AttributePOC.BrinsonAttribute", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id", "factorValue", "factorType"))); - schema.put( - "AttributePOC.TaxonomyOfBrinsonAttribute", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id"))); - schema.put( - "AttributePOC.TaxonomyOfScenario", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id"))); - schema.put( - "AttributePOC.TracebackDay_day_AttributePOC.BrinsonAttribute", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet())); - schema.put( - "AttributePOC.BrinsonAttribute_factorValue_AttributePOC.BrinsonAttribute", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet())); - schema.put( - "AttributePOC.Scenario_scConfig_AttributePOC.TracebackDay", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet())); - schema.put( - "AttributePOC.TracebackDay_scConfig_AttributePOC.Scenario", - Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet())); - - Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema)); - catalog.init(); - task.setCatalog(catalog); - - task.setGraphLoadClass( - "com.antgroup.openspg.reasoner.runner.local.main.KgReasonerZijinLocalTest$GraphLoader2"); - - // enable subquery - Map params = new HashMap<>(); - params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true); - task.setParams(params); - - KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); - LocalReasonerResult result = runner.run(task); - System.out.println(result); - Assert.assertEquals(1, result.getRows().size()); - Assert.assertTrue(StringUtils.isNotBlank(result.getRows().get(0)[1].toString())); - } - public static class GraphLoader2 extends AbstractLocalGraphLoader { @Override diff --git a/reasoner/runner/runner-common/src/test/scala/com/antgroup/reasoner/session/ReasonerSessionTests.scala b/reasoner/runner/runner-common/src/test/scala/com/antgroup/reasoner/session/ReasonerSessionTests.scala index 07dd4763..e039cc92 100644 --- a/reasoner/runner/runner-common/src/test/scala/com/antgroup/reasoner/session/ReasonerSessionTests.scala +++ b/reasoner/runner/runner-common/src/test/scala/com/antgroup/reasoner/session/ReasonerSessionTests.scala @@ -500,6 +500,85 @@ class ReasonerSessionTests extends AnyFunSpec { cnt should equal(5) } + it("test concept check") { + val dsl = + """ + |MATCH (s:`SupplyChain.Industry`/`商贸-资本品商贸`) + |RETURN s.id, s.name + |""".stripMargin + val schema: Map[String, Set[String]] = Map.apply( + "SupplyChain.Industry" -> Set.apply("id", "name"), + "SupplyChain.Product" -> Set.apply("id", "name"), + "SupplyChain.Product_belongTo_SupplyChain.Industry" -> Set.apply("payDate", "bizComment")) + val catalog = new PropertyGraphCatalog(schema) + catalog.init() + val session = new EmptySession(new KgDslParser(), catalog) + val rst = session.plan( + dsl, + Map + .apply( + (Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true), + (Constants.SPG_REASONER_PLAN_PRETTY_PRINT_LOGGER_ENABLE, true)) + .asInstanceOf[Map[String, Object]]) + val cnt = rst.head.transform[Int] { + case (expand: ExpandInto[EmptyRDG], cnt) => + cnt.sum + 1 + case (pattern: PatternScan[EmptyRDG], cnt) => + cnt.sum + 1 + case (_, cnt) => + if (cnt.isEmpty) { + 0 + } else { + cnt.sum + } + } + cnt should equal(2) + } + + + it("test concept check 2") { + val dsl = + """ + |MATCH + | (u:`RiskMining.TaxOfRiskUser`/`赌博App开发者`)-[:developed]->(app:`RiskMining.TaxOfRiskApp`/`赌博应用`), + | (b:`RiskMining.TaxOfRiskUser`/`赌博App老板`)-[:release]->(app) + |RETURN + | u.id, b.id ,app.id""".stripMargin + val schema: Map[String, Set[String]] = Map.apply( + "RiskMining.TaxOfRiskUser" -> Set.apply("id", "name"), + "RiskMining.TaxOfRiskApp" -> Set.apply("id", "name"), + "RiskMining.User" -> Set.apply("id", "name"), + "RiskMining.App" -> Set.apply("id", "name"), + "RiskMining.User_release_RiskMining.App" -> Set.apply("payDate", "bizComment"), + "RiskMining.User_developed_RiskMining.App" -> Set.apply("payDate", "bizComment"), + "RiskMining.User_belongTo_RiskMining.TaxOfRiskUser" -> Set.apply("payDate", "bizComment"), + "RiskMining.App_belongTo_RiskMining.TaxOfRiskApp" -> Set.apply("payDate", "bizComment")) + val catalog = new PropertyGraphCatalog(schema) + catalog.init() + val session = new EmptySession(new KgDslParser(), catalog) + val rst = session.plan( + dsl, + Map + .apply( + (Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true), + (Constants.SPG_REASONER_PLAN_PRETTY_PRINT_LOGGER_ENABLE, true)) + .asInstanceOf[Map[String, Object]]) + val cnt = rst.head.transform[Int] { + case (expand: ExpandInto[EmptyRDG], cnt) => + cnt.sum + 1 + case (pattern: PatternScan[EmptyRDG], cnt) => + cnt.sum + 1 + case (_, cnt) => + if (cnt.isEmpty) { + 0 + } else { + cnt.sum + } + } + cnt should equal(6) + } + + it("test agg count") { val dsl = """