mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-11-02 11:04:15 +00:00
feat(reasoner): add create node in define block (#191)
Co-authored-by: wenchengyao <wenchengyao.wcy@antgroup.com> Co-authored-by: FishJoy <chengqiang.cq@antgroup.com>
This commit is contained in:
parent
9c84f14708
commit
e177b7c6c9
@ -139,9 +139,6 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
parseBaseRuleDefine(ctx.base_rule_define(), ddlBlockWithNodes._2, ddlBlockWithNodes._3)
|
||||
val ddlBlockOp = ddlBlockWithNodes._1.ddlOp.head
|
||||
val ruleBlock = ddlInfo._1
|
||||
if (ddlInfo._2.nonEmpty) {
|
||||
return DDLBlock(ddlInfo._2, List.apply(ruleBlock))
|
||||
}
|
||||
ddlBlockOp match {
|
||||
case AddProperty(s, propertyName, _) =>
|
||||
val isLastAssignTargetAlis = ruleBlock match {
|
||||
@ -176,7 +173,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
ProjectRule(
|
||||
IRProperty(s.alias, propertyName),
|
||||
Ref(ddlBlockWithNodes._3.target.alias)))))
|
||||
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
|
||||
DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(prjBlk))
|
||||
case AddPredicate(predicate) =>
|
||||
val attrFields = new mutable.HashMap[String, Expr]()
|
||||
addPropertiesMap.foreach(x =>
|
||||
@ -209,9 +206,9 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
predicate.source,
|
||||
predicate.target,
|
||||
attrFields.toMap,
|
||||
predicate.direction))),
|
||||
predicate.direction))) ++ ddlInfo._2,
|
||||
List.apply(depBlk))
|
||||
case _ => DDLBlock(Set.apply(ddlBlockOp), List.apply(ruleBlock))
|
||||
case _ => DDLBlock(Set.apply(ddlBlockOp) ++ ddlInfo._2, List.apply(ruleBlock))
|
||||
}
|
||||
}
|
||||
|
||||
@ -298,7 +295,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
predicate: PredicateElement): (Block, Set[DDLOp]) = {
|
||||
val matchBlock = parseGraphStructure(ctx.the_graph_structure(), head, predicate)
|
||||
val ruleBlock = parseRule(ctx.the_rule(), matchBlock)
|
||||
val ddlOp = parseCreateAction(ctx.create_action())
|
||||
val ddlOp = parseCreateAction(ctx.create_action(), matchBlock)
|
||||
(ruleBlock, ddlOp)
|
||||
}
|
||||
|
||||
@ -600,11 +597,31 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
curBlock
|
||||
}
|
||||
|
||||
def parseCreateAction(ctx: Create_actionContext): Set[DDLOp] = {
|
||||
def parseCreateAction(ctx: Create_actionContext, matchBlock: MatchBlock): Set[DDLOp] = {
|
||||
if (ctx == null) {
|
||||
Set.empty
|
||||
} else {
|
||||
ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet
|
||||
val ddlBlockSet = ctx.create_action_body().asScala.map(x => parseCreateActionBody(x)).toSet
|
||||
val matchEleInfo = matchBlock.patterns.map(x => x._2.graphPattern.nodes).flatten
|
||||
val allEleInfo = ddlBlockSet.map {
|
||||
case AddVertex(s, _) => s.alias -> s
|
||||
case _ => null
|
||||
}.filter(_ != null).toMap ++ matchEleInfo
|
||||
ddlBlockSet.map {
|
||||
case c: AddVertex => c
|
||||
case c: AddProperty => c
|
||||
case c: AddPredicate => AddPredicate(
|
||||
PredicateElement(
|
||||
c.predicate.label,
|
||||
c.predicate.alias,
|
||||
allEleInfo(c.predicate.source.alias),
|
||||
allEleInfo(c.predicate.target.alias),
|
||||
c.predicate.fields,
|
||||
c.predicate.direction
|
||||
)
|
||||
)
|
||||
}.toSet
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -387,8 +387,8 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
val block = parser.parse(dsl)
|
||||
print(block.pretty)
|
||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.size should equal(2)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.size should equal(3)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
|
||||
}
|
||||
|
||||
it("addNodeException") {
|
||||
@ -504,8 +504,8 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
val block = parser.parse(dsl)
|
||||
print(block.pretty)
|
||||
block.dependencies.head.isInstanceOf[MatchBlock] should equal(true)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.size should equal(1)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddVertex] should equal(true)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.size should equal(2)
|
||||
block.asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddPredicate] should equal(true)
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -41,15 +41,16 @@ object BlockUtils {
|
||||
predicate.target.typeNames.head).toString)
|
||||
case AddProperty(s, propertyName, _) =>
|
||||
defines.add(s.typeNames.head + "." + propertyName)
|
||||
case AddVertex(s, _) =>
|
||||
// defines.add(s.typeNames.head)
|
||||
return Set.apply("result")
|
||||
case _ =>
|
||||
}
|
||||
})
|
||||
case _ => defines.add("result")
|
||||
}
|
||||
defines.toSet
|
||||
if (defines.isEmpty) {
|
||||
Set.apply("result")
|
||||
} else {
|
||||
defines.toSet
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -65,7 +65,13 @@ case class SolvedModel(
|
||||
}
|
||||
|
||||
def solve: SolvedModel = {
|
||||
val tmp = tmpFields.values.map(p => fields(p.name).merge(Option.apply(p))).toList
|
||||
val tmp = tmpFields.values.map(p => {
|
||||
if (!fields.contains(p.name)) {
|
||||
throw InvalidRefVariable(s"not found fields name : ${p} in solved fields")
|
||||
}
|
||||
fields(p.name).merge(Option.apply(p))
|
||||
}
|
||||
).toList
|
||||
var newFields = fields
|
||||
for (t <- tmp) {
|
||||
newFields = newFields.updated(t.name, newFields(t.name).merge(Option.apply(t)))
|
||||
|
||||
@ -402,10 +402,22 @@ object LogicalPlanner {
|
||||
val starts = new mutable.HashSet[String]()
|
||||
for (ddl <- ddlOp) {
|
||||
ddl match {
|
||||
case AddProperty(s, _, _) => starts.add(s.alias)
|
||||
case AddProperty(s, _, _) =>
|
||||
if (starts.isEmpty) {
|
||||
starts.add(s.alias)
|
||||
} else {
|
||||
val common = starts.intersect(Set.apply(s.alias))
|
||||
starts.clear()
|
||||
starts.++=(common)
|
||||
}
|
||||
case AddPredicate(p) =>
|
||||
starts.add(p.source.alias)
|
||||
starts.add(p.target.alias)
|
||||
if (starts.isEmpty) {
|
||||
starts.++=(Set.apply(p.source.alias, p.target.alias))
|
||||
} else {
|
||||
val common = starts.intersect(Set.apply(p.source.alias, p.target.alias))
|
||||
starts.clear()
|
||||
starts.++=(common)
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
@ -254,13 +254,11 @@ class SubQueryPlanner(val dag: Dag[Block])(implicit context: LogicalPlannerConte
|
||||
rootAlias = predicate.target.alias
|
||||
} else if (direction == Direction.OUT && predicate.direction == Direction.IN) {
|
||||
rootAlias == predicate.target.alias
|
||||
} else {
|
||||
} else if (direction != null) {
|
||||
rootAlias == predicate.source.alias
|
||||
}
|
||||
case AddProperty(s, _, _) =>
|
||||
rootAlias = s.alias
|
||||
case AddVertex(s, _) =>
|
||||
rootAlias = s.alias
|
||||
case _ =>
|
||||
}
|
||||
})
|
||||
|
||||
@ -84,7 +84,7 @@ public class KgReasonerLeadToTest {
|
||||
LocalReasonerRunner runner = new LocalReasonerRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println(result);
|
||||
Assert.assertEquals(1, result.getVertexList().size());
|
||||
Assert.assertEquals(2, result.getVertexList().size());
|
||||
}
|
||||
|
||||
public static class GraphLoaderForAddVertex extends AbstractLocalGraphLoader {
|
||||
|
||||
@ -14,6 +14,9 @@
|
||||
package com.antgroup.openspg.reasoner.runner.local.main.transitive;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.constants.Constants;
|
||||
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
|
||||
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
|
||||
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex;
|
||||
import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState;
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.Catalog;
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.impl.PropertyGraphCatalog;
|
||||
@ -25,9 +28,11 @@ import com.antgroup.openspg.reasoner.runner.local.loader.MockLocalGraphLoader;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
|
||||
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
@ -1286,4 +1291,130 @@ public class TransitiveOptionalTest {
|
||||
LocalReasonerResult rst = runTest(schema, dsl, dataGraphStr2);
|
||||
Assert.assertEquals(1, rst.getRows().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateInstance() {
|
||||
String dsl =
|
||||
"Define (s:Custid)-[p:strNum]->(o:Int) {\n"
|
||||
+ " GraphStructure {\n"
|
||||
+ " (s)<-[pp:hasCust]-(str:STR)\n"
|
||||
+ " }\n"
|
||||
+ " Rule {\n"
|
||||
+ " o = group(s).countIf(str.status == 'CLOSE', str)\n"
|
||||
+ " }\n"
|
||||
+ "}\n"
|
||||
+ "\n"
|
||||
+ "Define (s:Custid)-[p:isInWhiteBlack]->(o:Boolean) {\n"
|
||||
+ " GraphStructure {\n"
|
||||
+ " (s)<-[:hasCust]-(str:STR)\n"
|
||||
+ " }\n"
|
||||
+ " Rule {\n"
|
||||
+ " R1: str.matchrule == '0202'\n"
|
||||
+ " R2: str.isreport == '1' \n"
|
||||
+ "\n"
|
||||
+ " o = (R1 and R2)\n"
|
||||
+ " }\n"
|
||||
+ "}\n"
|
||||
+ "\n"
|
||||
+ "Define (s:Custid)-[p:isAggregator]->(o:Boolean) {\n"
|
||||
+ " GraphStructure {\n"
|
||||
+ " (s)<-[e:complained]-(u1:Custid)\n"
|
||||
+ " }\n"
|
||||
+ " Rule {\n"
|
||||
+ " R0: s.isInWhiteBlack == null or s.isInWhiteBlack == false\n"
|
||||
+ " complainNum = group(s).count(e)\n"
|
||||
+ " R5(\"被投诉大于20条\"): complainNum >=20\n"
|
||||
+ "\n"
|
||||
+ " o = true\n"
|
||||
+ " }\n"
|
||||
+ " Action {\n"
|
||||
+ " gang = createNodeInstance(\n"
|
||||
+ " type=Gang,\n"
|
||||
+ " value={\n"
|
||||
+ " id=concat(s.id, \"_gang\")\n"
|
||||
+ " }\n"
|
||||
+ " )\n"
|
||||
+ " createEdgeInstance(\n"
|
||||
+ " src=gang,\n"
|
||||
+ " dst=s,\n"
|
||||
+ " type=has,\n"
|
||||
+ " value={\n"
|
||||
+ " }\n"
|
||||
+ " )\n"
|
||||
+ " }\n"
|
||||
+ "}\n"
|
||||
+ "GraphStructure {"
|
||||
+ " A [Custid, __start__='true']\n"
|
||||
+ " B [Gang]\n"
|
||||
+ " B->A [has]\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id, A.isAggregator, B.id) \n"
|
||||
+ "}";
|
||||
|
||||
System.out.println(dsl);
|
||||
LocalReasonerTask task = new LocalReasonerTask();
|
||||
task.setDsl(dsl);
|
||||
|
||||
// add mock catalog
|
||||
Map<String, Set<String>> schema = new HashMap<>();
|
||||
schema.put(
|
||||
"Custid",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(
|
||||
Sets.newHashSet(
|
||||
"trdAmtIn90d",
|
||||
"trdAmt90d",
|
||||
"trdCntCustIn90d",
|
||||
"custcntpty90CustNum90dInGenderFemale",
|
||||
"custcntpty90CustNum90dInGenderMale",
|
||||
"name")));
|
||||
schema.put(
|
||||
"STR",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(
|
||||
Sets.newHashSet("conclusion", "name", "status", "matchrule", "isreport")));
|
||||
|
||||
schema.put("Gang", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("cid", "name")));
|
||||
schema.put("Gang_has_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("info")));
|
||||
schema.put(
|
||||
"Custid_complained_Custid",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo")));
|
||||
schema.put(
|
||||
"STR_hasCust_Custid", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("createMemo")));
|
||||
|
||||
Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema));
|
||||
catalog.init();
|
||||
task.setCatalog(catalog);
|
||||
task.setGraphLoadClass(
|
||||
"com.antgroup.openspg.reasoner.runner.local.main.transitive.TransitiveOptionalTest$GangGraphLoader");
|
||||
|
||||
// enable subquery
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true);
|
||||
params.put(Constants.SPG_REASONER_MULTI_VERSION_ENABLE, "true");
|
||||
task.setParams(params);
|
||||
|
||||
LocalReasonerRunner runner = new LocalReasonerRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
Assert.assertEquals(2, result.getRows().size());
|
||||
}
|
||||
|
||||
public static class GangGraphLoader extends AbstractLocalGraphLoader {
|
||||
@Override
|
||||
public List<IVertex<String, IProperty>> genVertexList() {
|
||||
return Lists.newArrayList(
|
||||
constructionVertex("A1", "Custid", "name", "A1", "cid", "a1"),
|
||||
constructionVertex("A2", "Custid", "name", "A2", "cid", "a2"),
|
||||
constructionVertex("B1", "Gang", "name", "B2", "cid", "b1"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<IEdge<String, IProperty>> genEdgeList() {
|
||||
return Lists.newArrayList(
|
||||
constructionEdge("B1", "has", "A1", "info", "b1_a1"),
|
||||
constructionEdge("B1", "has", "A2", "info", "b1_a2"),
|
||||
constructionEdge("A1", "complained", "A2", "info", "a1ca2"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ public class ExtractRelationImpl implements Serializable {
|
||||
|
||||
private final String taskId;
|
||||
|
||||
private final PatternElement sourceElement;
|
||||
private final Element sourceElement;
|
||||
private final EntityElement targetEntityElement;
|
||||
private final PatternElement targetPatternElement;
|
||||
|
||||
@ -84,7 +84,7 @@ public class ExtractRelationImpl implements Serializable {
|
||||
this.propertyRuleMap.put(propertyName, rule);
|
||||
}
|
||||
|
||||
PatternElement sourceElement = (PatternElement) addPredicate.predicate().source();
|
||||
Element se = addPredicate.predicate().source();
|
||||
Element te = addPredicate.predicate().target();
|
||||
EntityElement targetEntityElement = null;
|
||||
PatternElement targetPatternElement = null;
|
||||
@ -114,11 +114,10 @@ public class ExtractRelationImpl implements Serializable {
|
||||
}
|
||||
}
|
||||
if (!this.propertyRuleMap.containsKey(Constants.EDGE_FROM_ID_KEY)) {
|
||||
this.propertyRuleMap.put(
|
||||
Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(sourceElement.alias() + ".id"));
|
||||
this.propertyRuleMap.put(Constants.EDGE_FROM_ID_KEY, Lists.newArrayList(se.alias() + ".id"));
|
||||
}
|
||||
|
||||
this.sourceElement = sourceElement;
|
||||
this.sourceElement = se;
|
||||
this.targetEntityElement = targetEntityElement;
|
||||
this.targetPatternElement = targetPatternElement;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user