feat(reasoner): add create node in define block (#191)

Co-authored-by: wenchengyao <wenchengyao.wcy@antgroup.com>
Co-authored-by: FishJoy <chengqiang.cq@antgroup.com>
This commit is contained in:
赵培龙 2024-04-15 20:25:34 +08:00 committed by GitHub
parent 9c84f14708
commit e177b7c6c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 194 additions and 30 deletions

View File

@ -139,9 +139,6 @@ class OpenSPGDslParser extends ParserInterface {
parseBaseRuleDefine(ctx.base_rule_define(), ddlBlockWithNodes._2, ddlBlockWithNodes._3)
val ddlBlockOp = ddlBlockWithNodes._1.ddlOp.head
val ruleBlock = ddlInfo._1
if (ddlInfo._2.nonEmpty) {
return DDLBlock(ddlInfo._2, List.apply(ruleBlock))
}
ddlBlockOp match {
case AddProperty(s, propertyName, _) =>
val isLastAssignTargetAlis = ruleBlock match {
@ -176,7 +173,7 @@ class OpenSPGDslParser extends ParserInterface {
ProjectRule(
IRProperty(s.alias, propertyName),
Ref(ddlBlockWithNodes._3.target.alias)))))
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(prjBlk))
case AddPredicate(predicate) =>
val attrFields = new mutable.HashMap[String, Expr]()
addPropertiesMap.foreach(x =>
@ -209,9 +206,9 @@ class OpenSPGDslParser extends ParserInterface {
predicate.source,
predicate.target,
attrFields.toMap,
predicate.direction))),
predicate.direction))) ++ ddlInfo._2,
List.apply(depBlk))
case _ => DDLBlock(Set.apply(ddlBlockOp), List.apply(ruleBlock))
case _ => DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(ruleBlock))
}
}
@ -298,7 +295,7 @@ class OpenSPGDslParser extends ParserInterface {
predicate: PredicateElement): (Block, Set[DDLOp]) = {
val matchBlock = parseGraphStructure(ctx.the_graph_structure(), head, predicate)
val ruleBlock = parseRule(ctx.the_rule(), matchBlock)
val ddlOp = parseCreateAction(ctx.create_action())
val ddlOp = parseCreateAction(ctx.create_action(), matchBlock)
(ruleBlock, ddlOp)
}
@ -600,11 +597,31 @@ class OpenSPGDslParser extends ParserInterface {
curBlock
}
def parseCreateAction(ctx: Create_actionContext): Set[DDLOp] = {
def parseCreateAction(ctx: Create_actionContext, matchBlock: MatchBlock): Set[DDLOp] = {
if (ctx == null) {
Set.empty
} else {
ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet
val ddlBlockSet = ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet
val matchEleInfo = matchBlock.patterns.map(x => x._2.graphPattern.nodes).flatten
val allEleInfo = ddlBlockSet.map {
case AddVertex(s, _) => s.alias -> s
case _ => null
}.filter(_ != null).toMap ++ matchEleInfo
ddlBlockSet.map {
case c: AddVertex => c
case c: AddProperty => c
case c: AddPredicate => AddPredicate(
PredicateElement(
c.predicate.label,
c.predicate.alias,
allEleInfo(c.predicate.source.alias),
allEleInfo(c.predicate.target.alias),
c.predicate.fields,
c.predicate.direction
)
)
}.toSet
}
}

View File

@ -387,8 +387,8 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
block.asInstanceOf[DDLBlock].ddlOp.size should equal(2)
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true)
block.asInstanceOf[DDLBlock].ddlOp.size should equal(3)
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
}
it("addNodeException") {
@ -504,8 +504,8 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
block.dependencies.head.isInstanceOf[MatchBlock] should equal(true)
block.asInstanceOf[DDLBlock].ddlOp.size should equal(1)
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true)
block.asInstanceOf[DDLBlock].ddlOp.size should equal(2)
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddPredicate] should equal(true)
}

View File

@ -41,15 +41,16 @@ object BlockUtils {
predicate.target.typeNames.head).toString)
case AddProperty(s, propertyName, _) =>
defines.add(s.typeNames.head + "." + propertyName)
case AddVertex(s, _) =>
// defines.add(s.typeNames.head)
return Set.apply("result")
case _ =>
}
})
case _ => defines.add("result")
}
defines.toSet
if (defines.isEmpty) {
Set.apply("result")
} else {
defines.toSet
}
}
}

View File

@ -65,7 +65,13 @@ case class SolvedModel(
}
def solve: SolvedModel = {
val tmp = tmpFields.values.map(p => fields(p.name).merge(Option.apply(p))).toList
val tmp = tmpFields.values.map(p => {
if (!fields.contains(p.name)) {
throw InvalidRefVariable(s"not found fields name : ${p} in solved fields")
}
fields(p.name).merge(Option.apply(p))
}
).toList
var newFields = fields
for (t <- tmp) {
newFields = newFields.updated(t.name, newFields(t.name).merge(Option.apply(t)))

View File

@ -402,10 +402,22 @@ object LogicalPlanner {
val starts = new mutable.HashSet[String]()
for (ddl <- ddlOp) {
ddl match {
case AddProperty(s, _, _) => starts.add(s.alias)
case AddProperty(s, _, _) =>
if (starts.isEmpty) {
starts.add(s.alias)
} else {
val common = starts.intersect(Set.apply(s.alias))
starts.clear()
starts.++=(common)
}
case AddPredicate(p) =>
starts.add(p.source.alias)
starts.add(p.target.alias)
if (starts.isEmpty) {
starts.++=(Set.apply(p.source.alias, p.target.alias))
} else {
val common = starts.intersect(Set.apply(p.source.alias, p.target.alias))
starts.clear()
starts.++=(common)
}
case _ =>
}
}

View File

@ -254,13 +254,11 @@ class SubQueryPlanner(val dag: Dag[Block])(implicit context: LogicalPlannerConte
rootAlias = predicate.target.alias
} else if (direction == Direction.OUT && predicate.direction == Direction.IN) {
rootAlias == predicate.target.alias
} else {
} else if (direction != null) {
rootAlias == predicate.source.alias
}
case AddProperty(s, _, _) =>
rootAlias = s.alias
case AddVertex(s, _) =>
rootAlias = s.alias
case _ =>
}
})

View File

@ -84,7 +84,7 @@ public class KgReasonerLeadToTest {
LocalReasonerRunner runner = new LocalReasonerRunner();
LocalReasonerResult result = runner.run(task);
System.out.println(result);
Assert.assertEquals(1, result.getVertexList().size());
Assert.assertEquals(2, result.getVertexList().size());
}
public static class GraphLoaderForAddVertex extends AbstractLocalGraphLoader {

View File

@ -14,6 +14,9 @@
package com.antgroup.openspg.reasoner.runner.local.main.transitive;
import com.antgroup.openspg.reasoner.common.constants.Constants;
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex;
import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState;
import com.antgroup.openspg.reasoner.lube.catalog.Catalog;
import com.antgroup.openspg.reasoner.lube.catalog.impl.PropertyGraphCatalog;
@ -25,9 +28,11 @@ import com.antgroup.openspg.reasoner.runner.local.loader.MockLocalGraphLoader;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
@ -1286,4 +1291,130 @@ public class TransitiveOptionalTest {
LocalReasonerResult rst = runTest(schema, dsl, dataGraphStr2);
Assert.assertEquals(1, rst.getRows().size());
}
@Test
public void testCreateInstance() {
String dsl =
"Define (s:Custid)-[p:strNum]->(o:Int) {\n"
+ " GraphStructure {\n"
+ " (s)<-[pp:hasCust]-(str:STR)\n"
+ " }\n"
+ " Rule {\n"
+ " o = group(s).countIf(str.status == 'CLOSE', str)\n"
+ " }\n"
+ "}\n"
+ "\n"
+ "Define (s:Custid)-[p:isInWhiteBlack]->(o:Boolean) {\n"
+ " GraphStructure {\n"
+ " (s)<-[:hasCust]-(str:STR)\n"
+ " }\n"
+ " Rule {\n"
+ " R1: str.matchrule == '0202'\n"
+ " R2: str.isreport == '1' \n"
+ "\n"
+ " o = (R1 and R2)\n"
+ " }\n"
+ "}\n"
+ "\n"
+ "Define (s:Custid)-[p:isAggregator]->(o:Boolean) {\n"
+ " GraphStructure {\n"
+ " (s)<-[e:complained]-(u1:Custid)\n"
+ " }\n"
+ " Rule {\n"
+ " R0: s.isInWhiteBlack == null or s.isInWhiteBlack == false\n"
+ " complainNum = group(s).count(e)\n"
+ " R5(\"被投诉大于20条\"): complainNum >=20\n"
+ "\n"
+ " o = true\n"
+ " }\n"
+ " Action {\n"
+ " gang = createNodeInstance(\n"
+ " type=Gang,\n"
+ " value={\n"
+ " id=concat(s.id, \"_gang\")\n"
+ " }\n"
+ " )\n"
+ " createEdgeInstance(\n"
+ " src=gang,\n"
+ " dst=s,\n"
+ " type=has,\n"
+ " value={\n"
+ " }\n"
+ " )\n"
+ " }\n"
+ "}\n"
+ "GraphStructure {"
+ " A [Custid, __start__='true']\n"
+ " B [Gang]\n"
+ " B->A [has]\n"
+ "}\n"
+ "Rule {\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id, A.isAggregator, B.id) \n"
+ "}";
System.out.println(dsl);
LocalReasonerTask task = new LocalReasonerTask();
task.setDsl(dsl);
// add mock catalog
Map<String, Set<String>> schema = new HashMap<>();
schema.put(
"Custid",
Convert2ScalaUtil.toScalaImmutableSet(
Sets.newHashSet(
"trdAmtIn90d",
"trdAmt90d",
"trdCntCustIn90d",
"custcntpty90CustNum90dInGenderFemale",
"custcntpty90CustNum90dInGenderMale",
"name")));
schema.put(
"STR",
Convert2ScalaUtil.toScalaImmutableSet(
Sets.newHashSet("conclusion", "name", "status", "matchrule", "isreport")));
schema.put("Gang", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("cid", "name")));
schema.put("Gang_has_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("info")));
schema.put(
"Custid_complained_Custid",
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo")));
schema.put(
"STR_hasCust_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo")));
Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema));
catalog.init();
task.setCatalog(catalog);
task.setGraphLoadClass(
"com.antgroup.openspg.reasoner.runner.local.main.transitive.TransitiveOptionalTest$GangGraphLoader");
// enable subquery
Map<String, Object> params = new HashMap<>();
params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true);
params.put(Constants.SPG_REASONER_MULTI_VERSION_ENABLE, "true");
task.setParams(params);
LocalReasonerRunner runner = new LocalReasonerRunner();
LocalReasonerResult result = runner.run(task);
Assert.assertEquals(2, result.getRows().size());
}
public static class GangGraphLoader extends AbstractLocalGraphLoader {
@Override
public List<IVertex<String, IProperty>> genVertexList() {
return Lists.newArrayList(
constructionVertex("A1", "Custid", "name", "A1", "cid", "a1"),
constructionVertex("A2", "Custid", "name", "A2", "cid", "a2"),
constructionVertex("B1", "Gang", "name", "B2", "cid", "b1"));
}
@Override
public List<IEdge<String, IProperty>> genEdgeList() {
return Lists.newArrayList(
constructionEdge("B1", "has", "A1", "info", "b1_a1"),
constructionEdge("B1", "has", "A2", "info", "b1_a2"),
constructionEdge("A1", "complained", "A2", "info", "a1ca2"));
}
}
}

View File

@ -60,7 +60,7 @@ public class ExtractRelationImpl implements Serializable {
private final String taskId;
private final PatternElement sourceElement;
private final Element sourceElement;
private final EntityElement targetEntityElement;
private final PatternElement targetPatternElement;
@ -84,7 +84,7 @@ public class ExtractRelationImpl implements Serializable {
this.propertyRuleMap.put(propertyName, rule);
}
PatternElement sourceElement = (PatternElement) addPredicate.predicate().source();
Element se = addPredicate.predicate().source();
Element te = addPredicate.predicate().target();
EntityElement targetEntityElement = null;
PatternElement targetPatternElement = null;
@ -114,11 +114,10 @@ public class ExtractRelationImpl implements Serializable {
}
}
if (!this.propertyRuleMap.containsKey(Constants.EDGE_FROM_ID_KEY)) {
this.propertyRuleMap.put(
Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(sourceElement.alias() + ".id"));
this.propertyRuleMap.put(Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(se.alias() + ".id"));
}
this.sourceElement = sourceElement;
this.sourceElement = se;
this.targetEntityElement = targetEntityElement;
this.targetPatternElement = targetPatternElement;
}