mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-12-28 07:33:59 +00:00
fix(reasoner): bugfix for mock data to run (#162)
Co-authored-by: Donghai <donghai.ydh@antgroup.com> Co-authored-by: FishJoy <chengqiang.cq@antgroup.com> Co-authored-by: wangshaofei <wangshaofei.wsf@antgroup.com>
This commit is contained in:
parent
57f74c5447
commit
1982db26ad
@ -33,7 +33,8 @@ import com.antgroup.openspg.reasoner.lube.common.pattern.{
|
||||
EntityElement,
|
||||
GraphPath,
|
||||
PatternElement,
|
||||
PredicateElement
|
||||
PredicateElement,
|
||||
VariablePatternConnection
|
||||
}
|
||||
import com.antgroup.openspg.reasoner.lube.common.rule.{LogicRule, ProjectRule, Rule}
|
||||
import com.antgroup.openspg.reasoner.lube.parser.ParserInterface
|
||||
@ -142,7 +143,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
return DDLBlock(ddlInfo._2, List.apply(ruleBlock))
|
||||
}
|
||||
ddlBlockOp match {
|
||||
case AddProperty(s, propertyName, propertyType) =>
|
||||
case AddProperty(s, propertyName, _) =>
|
||||
val isLastAssignTargetAlis = ruleBlock match {
|
||||
case ProjectBlock(_, projects) =>
|
||||
var tmpIsAssign = false
|
||||
@ -242,7 +243,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
ctx.full_edge_pointing_right().element_pattern_declaration_and_filler(),
|
||||
ctx.full_edge_pointing_right().edge_pattern_pernodelimit_clause(),
|
||||
Direction.OUT,
|
||||
false)
|
||||
isOptional = false)
|
||||
|
||||
val predicateElement =
|
||||
PredicateElement(p.relTypes.head, p.alias, s, o, Map.empty, Direction.OUT)
|
||||
@ -278,23 +279,16 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
}
|
||||
}
|
||||
|
||||
def isNeedDependenceExpr(rule: Rule, ruleRefRelate: Map[Rule, Set[Rule]]): Boolean = {
|
||||
if (!ruleRefRelate.contains(rule) || ruleRefRelate(rule).size != 1) {
|
||||
return false
|
||||
}
|
||||
val refRule = ruleRefRelate(rule).head
|
||||
refRule.getExpr match {
|
||||
case _: OrderAndLimit => false
|
||||
case _ => true
|
||||
}
|
||||
}
|
||||
|
||||
def isGenerateOneStepBlockExpr(rule: Rule): Boolean = {
|
||||
def isFilter2ProjectBlock(rule: Rule, ruleRefRelate: Map[Rule, Set[Rule]]): Boolean = {
|
||||
rule.getExpr match {
|
||||
case _: OrderAndLimit => true
|
||||
case _: GraphAggregatorExpr => true
|
||||
case _: OpChainExpr => true
|
||||
case _: OrderAndLimit => true
|
||||
case _ => false
|
||||
case _ => if (!ruleRefRelate.contains(rule)) {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -497,18 +491,12 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
rule: Rule,
|
||||
preBlock: Block,
|
||||
kg: IRGraph): Block = {
|
||||
val isGenerateStep = isGenerateOneStepBlockExpr(rule)
|
||||
if (isNeedDependenceExpr(rule, ruleRefRelate) && !isGenerateStep) {
|
||||
// if ref equal 1, and without graph group we add to dependencies
|
||||
ruleRefRelate(rule).head.addDependency(rule)
|
||||
null
|
||||
} else {
|
||||
genBlockOp(
|
||||
rule,
|
||||
preBlock,
|
||||
(ruleRefRelate.contains(rule) && ruleRefRelate(rule).size > 1) || isGenerateStep,
|
||||
kg)
|
||||
}
|
||||
val isFilter2ProjectStep = isFilter2ProjectBlock(rule, ruleRefRelate)
|
||||
genBlockOp(
|
||||
rule,
|
||||
preBlock,
|
||||
isFilter2ProjectStep,
|
||||
kg)
|
||||
}
|
||||
|
||||
def parseRule(ctx: The_ruleContext, matchBlock: MatchBlock): Block = {
|
||||
@ -527,18 +515,43 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
parseRuleBlock(rules, matchBlock.patterns)
|
||||
}
|
||||
|
||||
def findAllPathAlias(patterns: Map[String, GraphPath]): Map[String, IRPath] = {
|
||||
patterns.values
|
||||
.flatMap(path =>
|
||||
path.graphPattern.edges.flatMap(pair => pair._2.map {
|
||||
case connection: VariablePatternConnection =>
|
||||
val start = IRNode(connection.source, Set.empty)
|
||||
val end = IRNode(connection.target, Set.empty)
|
||||
val irEdge = IREdge(connection.alias, Set.empty)
|
||||
connection.alias -> IRPath(connection.alias, List.apply(start, irEdge, end))
|
||||
case _ => null
|
||||
}.filter(_ != null).toMap))
|
||||
.toMap
|
||||
}
|
||||
|
||||
def addIntoRefFieldsMap(fieldName: String,
|
||||
fields: Set[String],
|
||||
refFieldsMap: Map[String, Set[String]]): Map[String, Set[String]] = {
|
||||
var attrs = fields
|
||||
if (refFieldsMap.contains(fieldName)) {
|
||||
attrs = attrs ++ refFieldsMap(fieldName)
|
||||
}
|
||||
refFieldsMap + (fieldName -> attrs)
|
||||
}
|
||||
|
||||
def parseRuleBlock(rules: List[Rule], patterns: Map[String, GraphPath]): Block = {
|
||||
var refFieldsMap: Map[String, Set[String]] = Map.empty
|
||||
val allRepeatPath = findAllPathAlias(patterns)
|
||||
if (rules.nonEmpty) {
|
||||
rules.foreach(rule => {
|
||||
val irFields = RuleUtils.getAllInputFieldInRule(rule, Set.empty, Set.empty)
|
||||
irFields.foreach {
|
||||
val repeatIrFields = ExprUtils.getRepeatPathInputFieldInRule(rule.getExpr, allRepeatPath)
|
||||
val totalIrFields = irFields ++ repeatIrFields
|
||||
totalIrFields.foreach {
|
||||
case c: IRNode =>
|
||||
var attrs = c.fields
|
||||
if (refFieldsMap.contains(c.name)) {
|
||||
attrs = attrs ++ refFieldsMap(c.name)
|
||||
}
|
||||
refFieldsMap += (c.name -> attrs)
|
||||
refFieldsMap = addIntoRefFieldsMap(c.name, c.fields, refFieldsMap)
|
||||
case c: IREdge =>
|
||||
refFieldsMap = addIntoRefFieldsMap(c.name, c.fields, refFieldsMap)
|
||||
case c => c
|
||||
}
|
||||
})
|
||||
@ -549,7 +562,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
var ruleInstructs = Map[Rule, Set[Rule]]()
|
||||
|
||||
rules.foreach(rule => {
|
||||
val refRules = getRefRules(rule, rules.toList)
|
||||
val refRules = getRefRules(rule, rules)
|
||||
for (ref <- refRules) {
|
||||
if (ruleInstructs.contains(ref)) {
|
||||
ruleInstructs += ref -> ruleInstructs(ref).union(Set.apply(rule))
|
||||
@ -809,7 +822,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
def parseBaseJob(ctx: Base_jobContext, param: Map[String, Object]): Block = {
|
||||
var head: PatternElement = null
|
||||
if (param.contains(Constants.START_LABEL)) {
|
||||
head = PatternElement(null, Set.apply(param(Constants.START_LABEL).toString), null);
|
||||
head = PatternElement(null, Set.apply(param(Constants.START_LABEL).toString), null)
|
||||
}
|
||||
|
||||
if (param.contains(Constants.START_ALIAS)) {
|
||||
|
||||
@ -1083,14 +1083,16 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
print(block.pretty)
|
||||
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b), distinct=false)
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
|
||||
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R3) -> LogicRule(R3,男性,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R2) -> LogicRule(R2,有车,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R1) -> LogicRule(R1,有房,BinaryOpExpr(name=BEqual)))))
|
||||
* └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(null,Map(s -> (s:User)),Map(),Map(s -> Set(beautiful, haveCar, haveHouse, height, id, gender))),false)))
|
||||
* └─SourceBlock(graph=KG(Map(s -> IRNode(s,Set(beautiful, haveCar, haveHouse, height, id, gender))),Map()))"""
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R8) -> LogicRule(R8,白富美,BinaryOpExpr(name=BAnd)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R7) -> LogicRule(R7,高富帅,BinaryOpExpr(name=BAnd)))))
|
||||
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R3) -> LogicRule(R3,男性,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R2) -> LogicRule(R2,有车,BinaryOpExpr(name=BEqual)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R1) -> LogicRule(R1,有房,BinaryOpExpr(name=BEqual)))))
|
||||
* └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(null,Map(s -> (s:User)),Map(),Map(s -> Set(beautiful, haveCar, haveHouse, height, id, gender))),false)))
|
||||
* └─SourceBlock(graph=KG(Map(s -> IRNode(s,Set(beautiful, haveCar, haveHouse, height, id, gender))),Map()))"""
|
||||
.stripMargin('*')
|
||||
block.pretty should equal(text)
|
||||
}
|
||||
|
||||
@ -16,6 +16,8 @@ package com.antgroup.openspg.reasoner.parser.expr
|
||||
import com.antgroup.openspg.reasoner.common.exception.KGDSLGrammarException
|
||||
import com.antgroup.openspg.reasoner.common.types.{KTObject, KTString}
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{BinaryOpExpr, _}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.{IREdge, IRNode, IRPath}
|
||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
|
||||
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Expr2QlexpressTransformer
|
||||
import org.scalatest.funspec.AnyFunSpec
|
||||
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
|
||||
@ -553,6 +555,75 @@ class ExprTest extends AnyFunSpec {
|
||||
Ref("e")
|
||||
), null).pretty
|
||||
expr.pretty should equal(expectResult)
|
||||
val repeatPath = Map.apply("e" -> IRPath("e", List.apply(
|
||||
IRNode("A", Set.empty),
|
||||
IREdge("e", Set.empty),
|
||||
IRNode("B", Set.empty)
|
||||
)))
|
||||
val irFields = ExprUtils.getRepeatPathInputFieldInRule(expr, repeatPath)
|
||||
irFields.size should equal(0)
|
||||
}
|
||||
|
||||
it ("e.edges().reduce((x, y) => x + y.times, 0)") {
|
||||
val exprParser = new RuleExprParser()
|
||||
val expr = exprParser.parse("e.edges().reduce((x, y) => x + y.times, 0)")
|
||||
print(expr.pretty)
|
||||
val expectResult = OpChainExpr(
|
||||
ListOpExpr(
|
||||
Reduce(
|
||||
"y",
|
||||
"x",
|
||||
BinaryOpExpr(BAdd, Ref("x"), UnaryOpExpr(GetField("times"), Ref("y"))),
|
||||
VLong("0")
|
||||
),
|
||||
Ref("e")
|
||||
),
|
||||
OpChainExpr(
|
||||
PathOpExpr(
|
||||
GetEdgesExpr,
|
||||
Ref("e")
|
||||
),
|
||||
null
|
||||
)
|
||||
).pretty
|
||||
expr.pretty should equal(expectResult)
|
||||
val repeatPath = Map.apply("e" -> IRPath("e", List.apply(
|
||||
IRNode("A", Set.empty),
|
||||
IREdge("e", Set.empty),
|
||||
IRNode("B", Set.empty)
|
||||
)))
|
||||
val irFields = ExprUtils.getRepeatPathInputFieldInRule(expr, repeatPath)
|
||||
irFields.size should equal(1)
|
||||
}
|
||||
|
||||
it ("e.edges().constraint((cur, pre) => cur.logId == pre.logId)") {
|
||||
val exprParser = new RuleExprParser()
|
||||
val expr = exprParser.parse("e.edges().constraint((cur, pre) => cur.logId == pre.logId)")
|
||||
print(expr.pretty)
|
||||
val expectResult = OpChainExpr(
|
||||
ListOpExpr(
|
||||
Constraint(
|
||||
"cur",
|
||||
"pre",
|
||||
BinaryOpExpr(BEqual,
|
||||
UnaryOpExpr(GetField("logId"), Ref("cur")), UnaryOpExpr(GetField("logId"), Ref("pre")))
|
||||
),
|
||||
Ref("e")
|
||||
),
|
||||
OpChainExpr(
|
||||
PathOpExpr(
|
||||
GetEdgesExpr,
|
||||
Ref("e")
|
||||
), null)
|
||||
).pretty
|
||||
expr.pretty should equal(expectResult)
|
||||
val repeatPath = Map.apply("e" -> IRPath("e", List.apply(
|
||||
IRNode("A", Set.empty),
|
||||
IREdge("e", Set.empty),
|
||||
IRNode("B", Set.empty)
|
||||
)))
|
||||
val irFields = ExprUtils.getRepeatPathInputFieldInRule(expr, repeatPath)
|
||||
irFields.size should equal(1)
|
||||
}
|
||||
|
||||
it ("e.nodes().reduce((x, y) => x + y.times, 0)") {
|
||||
@ -578,6 +649,13 @@ class ExprTest extends AnyFunSpec {
|
||||
)
|
||||
).pretty
|
||||
expr.pretty should equal(expectResult)
|
||||
val repeatPath = Map.apply("e" -> IRPath("e", List.apply(
|
||||
IRNode("A", Set.empty),
|
||||
IREdge("e", Set.empty),
|
||||
IRNode("B", Set.empty)
|
||||
)))
|
||||
val irFields = ExprUtils.getRepeatPathInputFieldInRule(expr, repeatPath)
|
||||
irFields.size should equal(2)
|
||||
}
|
||||
|
||||
it ("e.nodes().constraint((cur, pre) => cur.logId == pre.logId)") {
|
||||
@ -601,6 +679,13 @@ class ExprTest extends AnyFunSpec {
|
||||
), null)
|
||||
).pretty
|
||||
expr.pretty should equal(expectResult)
|
||||
val repeatPath = Map.apply("e" -> IRPath("e", List.apply(
|
||||
IRNode("A", Set.empty),
|
||||
IREdge("e", Set.empty),
|
||||
IRNode("B", Set.empty)
|
||||
)))
|
||||
val irFields = ExprUtils.getRepeatPathInputFieldInRule(expr, repeatPath)
|
||||
irFields.size should equal(2)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
import com.antgroup.openspg.reasoner.common.trees.{BottomUp, TopDown, Transform}
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{Constraint, Expr, GetField, ListOpExpr, Ref, UnaryOpExpr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr._
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||
|
||||
/**
|
||||
@ -43,6 +43,48 @@ object ExprUtils {
|
||||
}).transform(expr)
|
||||
}
|
||||
|
||||
def getRepeatPathInputFieldInRule(expr: Expr,
|
||||
repeatPathMap: Map[String, IRPath]): List[IRField] = {
|
||||
expr match {
|
||||
case OpChainExpr(ListOpExpr(listOp, _), OpChainExpr(PathOpExpr(name, ref), _)) =>
|
||||
val irPath = repeatPathMap(ref.refName)
|
||||
val props = listOp match {
|
||||
case constraint: Constraint =>
|
||||
getAllInputFieldInRule(constraint.reduceFunc, Set.empty, Set.empty).filter(
|
||||
ir => ir.name.equals(constraint.cur) && ir.name.equals(constraint.pre))
|
||||
.flatMap(t => t match {
|
||||
case IRNode(_, fields) => fields.toList
|
||||
case IREdge(_, fields) => fields.toList
|
||||
case _ => List.empty
|
||||
})
|
||||
|
||||
case compute: Reduce =>
|
||||
getAllInputFieldInRule(compute.reduceFunc, Set.empty, Set.empty).filter(
|
||||
ir => ir.name.equals(compute.ele)
|
||||
).flatMap(t => t match {
|
||||
case IRNode(_, fields) => fields.toList
|
||||
case IREdge(_, fields) => fields.toList
|
||||
case _ => List.empty
|
||||
})
|
||||
|
||||
case _ => List.empty
|
||||
}
|
||||
name match {
|
||||
case GetNodesExpr =>
|
||||
irPath.elements.filter(ele => ele.isInstanceOf[IRNode]).map {
|
||||
case IRNode(irName, fields) => IRNode(irName, fields ++ props)
|
||||
case _ => null
|
||||
}.filter(_ != null)
|
||||
case GetEdgesExpr =>
|
||||
irPath.elements.filter(ele => ele.isInstanceOf[IREdge]).map {
|
||||
case IREdge(irName, fields) => IREdge(irName, fields ++ props)
|
||||
case _ => null
|
||||
}.filter(_ != null)
|
||||
}
|
||||
case _ => List.empty
|
||||
}
|
||||
}
|
||||
|
||||
def getAllInputFieldInRule(
|
||||
expr: Expr,
|
||||
nodesAlias: Set[String],
|
||||
@ -69,6 +111,12 @@ object ExprUtils {
|
||||
getAllInputFieldInRule(constraint.reduceFunc, nodesAlias, edgeAlias).filter(
|
||||
ir => !ir.name.equals(constraint.cur) && !ir.name.equals(constraint.pre))
|
||||
mergeListIRField(c.flatten ++ irList)
|
||||
case compute: Reduce =>
|
||||
val irList =
|
||||
getAllInputFieldInRule(compute.reduceFunc, nodesAlias, edgeAlias).filter(
|
||||
ir => !ir.name.equals(compute.ele) && !ir.name.equals(compute.res)
|
||||
)
|
||||
mergeListIRField(c.flatten ++ irList)
|
||||
case _ =>
|
||||
mergeListIRField(c.flatten)
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ package com.antgroup.openspg.reasoner.lube.logical.optimizer.rules
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.constants.Constants
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.{Catalog, SemanticPropertyGraph}
|
||||
import com.antgroup.openspg.reasoner.lube.common.pattern.NodePattern
|
||||
import com.antgroup.openspg.reasoner.lube.logical.{EdgeVar, NodeVar, PropertyVar, Var}
|
||||
@ -42,14 +43,15 @@ object ExpandIntoPure extends Rule {
|
||||
case (order: OrderAndLimit, map) =>
|
||||
order -> merge(map, order.refFields, order.solved.getNodeAliasSet)
|
||||
case (expandInto @ ExpandInto(in, _, _), map) =>
|
||||
val newMap = merge(map, expandInto.refFields, expandInto.solved.getNodeAliasSet)
|
||||
val needPure = canPure(
|
||||
expandInto,
|
||||
map.asInstanceOf[Map[String, Var]],
|
||||
newMap.asInstanceOf[Map[String, Var]],
|
||||
context.catalog.getGraph(Catalog.defaultGraphName))
|
||||
if (needPure) {
|
||||
in -> map
|
||||
in -> newMap
|
||||
} else {
|
||||
expandInto -> map
|
||||
expandInto -> newMap
|
||||
}
|
||||
|
||||
}
|
||||
@ -92,8 +94,9 @@ object ExpandIntoPure extends Rule {
|
||||
if (!map.contains(alias) || map(alias).isEmpty) {
|
||||
true
|
||||
} else {
|
||||
val usedPros = map(alias).asInstanceOf[NodeVar].fields
|
||||
val originalProps = types.map(graph.getNode(_).properties).flatten
|
||||
if (map(alias).asInstanceOf[NodeVar].fields.intersect(originalProps).isEmpty) {
|
||||
if (usedPros.intersect(originalProps).isEmpty) {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
|
||||
@ -24,12 +24,7 @@ import com.antgroup.openspg.reasoner.lube.block.{AddPredicate, AddProperty, AddV
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRNode, IRProperty, IRVariable}
|
||||
import com.antgroup.openspg.reasoner.lube.common.pattern.{
|
||||
Connection,
|
||||
NodePattern,
|
||||
Pattern,
|
||||
VariablePatternConnection
|
||||
}
|
||||
import com.antgroup.openspg.reasoner.lube.common.pattern.{Connection, NodePattern, Pattern, VariablePatternConnection}
|
||||
import com.antgroup.openspg.reasoner.lube.logical.{NodeVar, PropertyVar, Var}
|
||||
import com.antgroup.openspg.reasoner.lube.logical.operators._
|
||||
import com.antgroup.openspg.reasoner.lube.logical.optimizer.{Direction, Rule, Up}
|
||||
@ -48,7 +43,7 @@ object NodeIdToEdgeProperty extends Rule {
|
||||
} else {
|
||||
val toEdge = targetConnection(expandInto)
|
||||
if (toEdge != null) {
|
||||
expandInto -> (map + (expandInto.pattern.root.alias -> toEdge))
|
||||
expandInto.in -> (map + (expandInto.pattern.root.alias -> toEdge))
|
||||
} else {
|
||||
expandInto -> map
|
||||
}
|
||||
@ -251,7 +246,12 @@ object NodeIdToEdgeProperty extends Rule {
|
||||
if (!expandInto.pattern.isInstanceOf[NodePattern]) {
|
||||
false
|
||||
} else {
|
||||
val fieldNames = expandInto.refFields.head.asInstanceOf[NodeVar].fields.map(_.name)
|
||||
val fieldNames = expandInto.fields
|
||||
.filter(_.name.equals(expandInto.pattern.root.alias))
|
||||
.head
|
||||
.asInstanceOf[NodeVar]
|
||||
.fields
|
||||
.map(_.name)
|
||||
val normalNames = fieldNames.filter(!NODE_DEFAULT_PROPS.contains(_))
|
||||
if (normalNames.isEmpty) {
|
||||
true
|
||||
|
||||
@ -61,9 +61,9 @@ object LogicalPlanner {
|
||||
val source = resolve(input)
|
||||
val groups = getStarts(input)
|
||||
val planWithoutResult = if (groups.isEmpty) {
|
||||
planBlock(input.dependencies.head, None, source)
|
||||
planBlock(input.dependencies.head, input, None, source)
|
||||
} else {
|
||||
planBlock(input.dependencies.head, None, source)(
|
||||
planBlock(input.dependencies.head, input, None, source)(
|
||||
context.addParam(Constants.START_ALIAS, groups.head))
|
||||
}
|
||||
val plan = input match {
|
||||
@ -200,14 +200,17 @@ object LogicalPlanner {
|
||||
}
|
||||
}
|
||||
|
||||
private def planBlock(input: Block, plan: Option[LogicalOperator], solvedModel: SolvedModel)(
|
||||
implicit context: LogicalPlannerContext): LogicalOperator = {
|
||||
private def planBlock(
|
||||
input: Block,
|
||||
root: Block,
|
||||
plan: Option[LogicalOperator],
|
||||
solvedModel: SolvedModel)(implicit context: LogicalPlannerContext): LogicalOperator = {
|
||||
if (input.dependencies.isEmpty) {
|
||||
planLeaf(input, solvedModel)
|
||||
} else {
|
||||
// plan one of the block dependencies
|
||||
val dependency = planBlock(input.dependencies.head, plan, solvedModel)
|
||||
planNonLeaf(input, solvedModel, dependency)
|
||||
val dependency = planBlock(input.dependencies.head, root, plan, solvedModel)
|
||||
planNonLeaf(input, root, solvedModel, dependency)
|
||||
}
|
||||
}
|
||||
|
||||
@ -222,8 +225,11 @@ object LogicalPlanner {
|
||||
}
|
||||
}
|
||||
|
||||
private def planNonLeaf(block: Block, solvedModel: SolvedModel, plan: LogicalOperator)(implicit
|
||||
context: LogicalPlannerContext): LogicalOperator = {
|
||||
private def planNonLeaf(
|
||||
block: Block,
|
||||
root: Block,
|
||||
solvedModel: SolvedModel,
|
||||
plan: LogicalOperator)(implicit context: LogicalPlannerContext): LogicalOperator = {
|
||||
block match {
|
||||
case MatchBlock(_, matches) =>
|
||||
// TODO: plan the first one in current
|
||||
@ -231,7 +237,7 @@ object LogicalPlanner {
|
||||
case FilterBlock(_, rule) =>
|
||||
planFilter(rule, plan)
|
||||
case ProjectBlock(_, projects) =>
|
||||
planProject(projects, plan)
|
||||
planProject(projects, root, plan)
|
||||
case AggregationBlock(_, aggregations, group) =>
|
||||
planAggregate(aggregations, group, plan)
|
||||
case OrderAndSliceBlock(_, orderBy, limit, group) =>
|
||||
@ -371,9 +377,9 @@ object LogicalPlanner {
|
||||
* @param context
|
||||
* @return
|
||||
*/
|
||||
private def planProject(projects: ProjectFields, dependency: LogicalOperator)(implicit
|
||||
context: LogicalPlannerContext): LogicalOperator = {
|
||||
val projectPlanner = new ProjectPlanner(projects)
|
||||
private def planProject(projects: ProjectFields, root: Block, dependency: LogicalOperator)(
|
||||
implicit context: LogicalPlannerContext): LogicalOperator = {
|
||||
val projectPlanner = new ProjectPlanner(projects, root)
|
||||
projectPlanner.plan(dependency)
|
||||
}
|
||||
|
||||
|
||||
@ -17,18 +17,19 @@ import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
import com.antgroup.openspg.reasoner.common.types.KgType
|
||||
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
|
||||
import com.antgroup.openspg.reasoner.lube.block.{AggregationBlock, Block, ProjectFields}
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
||||
import com.antgroup.openspg.reasoner.lube.logical._
|
||||
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
|
||||
import com.antgroup.openspg.reasoner.lube.logical.operators.{Driving, LogicalOperator, Project, Start}
|
||||
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
|
||||
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
||||
class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
|
||||
class ProjectPlanner(projects: ProjectFields, root: Block)(implicit
|
||||
context: LogicalPlannerContext) {
|
||||
|
||||
def plan(dependency: LogicalOperator): LogicalOperator = {
|
||||
val projectMap = new mutable.HashMap[Var, Expr]()
|
||||
@ -88,7 +89,7 @@ class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerCo
|
||||
|
||||
left match {
|
||||
case IRVariable(name) =>
|
||||
if (referVars.size == 1) {
|
||||
if (referVars.map(_.name).toSet.size == 1) {
|
||||
PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
|
||||
} else {
|
||||
val aliasSet = new mutable.HashSet[String]()
|
||||
@ -110,11 +111,38 @@ class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerCo
|
||||
|
||||
}
|
||||
|
||||
private def getGroups: Set[String] = {
|
||||
root.transform[Set[String]] {
|
||||
case (AggregationBlock(_, _, group), groupList) =>
|
||||
val groupAlias = group.map(_.name).toSet
|
||||
if (groupList.head.isEmpty) {
|
||||
groupAlias
|
||||
} else {
|
||||
val commonGroups = groupList.head.intersect(groupAlias)
|
||||
if (commonGroups.isEmpty) {
|
||||
throw UnsupportedOperationException(
|
||||
s"cannot support groups ${groupAlias}, ${groupList.head}")
|
||||
} else {
|
||||
commonGroups
|
||||
}
|
||||
}
|
||||
case (_, groupList) =>
|
||||
if (groupList.isEmpty) {
|
||||
Set.empty
|
||||
} else {
|
||||
groupList.head
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getTargetAlias(aliasSet: Set[String], dependency: LogicalOperator): String = {
|
||||
val groups = getGroups
|
||||
val aliasOrder = dependency.transform[String] {
|
||||
case (stackOp: StackingLogicalOperator, list) =>
|
||||
case (start: Start, _) => ""
|
||||
case (driving: Driving, _) => ""
|
||||
case (stackOp: LogicalOperator, list) =>
|
||||
if (StringUtils.isEmpty(list.head)) {
|
||||
val curAliasSet = stackOp.fields.map(_.name).toSet
|
||||
val curAliasSet = stackOp.fields.map(_.name).toSet.diff(groups)
|
||||
if (aliasSet.diff(curAliasSet).isEmpty) {
|
||||
if (aliasSet.isEmpty) {
|
||||
curAliasSet.head
|
||||
@ -127,12 +155,6 @@ class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerCo
|
||||
} else {
|
||||
list.head
|
||||
}
|
||||
case (_, list) =>
|
||||
if (!list.isEmpty && StringUtils.isNotEmpty(list.head)) {
|
||||
list.head
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
aliasOrder
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@ import com.antgroup.openspg.reasoner.runner.local.impl.LocalPropertyGraph;
|
||||
import com.antgroup.openspg.reasoner.runner.local.impl.LocalReasonerSession;
|
||||
import com.antgroup.openspg.reasoner.runner.local.impl.LocalRunnerThreadPool;
|
||||
import com.antgroup.openspg.reasoner.runner.local.load.graph.AbstractLocalGraphLoader;
|
||||
import com.antgroup.openspg.reasoner.runner.local.loader.MockLocalGraphLoader;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
|
||||
import com.antgroup.openspg.reasoner.runner.local.rdg.LocalRDG;
|
||||
@ -157,17 +158,24 @@ public class LocalReasonerRunner {
|
||||
localPropertyGraph.setStartIdTuple2List(null);
|
||||
}
|
||||
|
||||
String isGraphOutput =
|
||||
String.valueOf(
|
||||
task.getParams().computeIfAbsent(ConfigKey.KG_REASONER_OUTPUT_GRAPH, k -> "false"));
|
||||
if ("true".equals(isGraphOutput)) {
|
||||
localPropertyGraph.setCarryTraversalGraph(true);
|
||||
}
|
||||
|
||||
// judge is need add same mock graph
|
||||
if (task.getParams().containsKey(ConfigKey.KG_REASONER_MOCK_GRAPH_DATA)) {
|
||||
String demoGraph = task.getParams().get(ConfigKey.KG_REASONER_MOCK_GRAPH_DATA).toString();
|
||||
MockLocalGraphLoader mockLocalGraphLoader = new MockLocalGraphLoader(demoGraph);
|
||||
mockLocalGraphLoader.setGraphState(localPropertyGraph.getGraphState());
|
||||
mockLocalGraphLoader.load();
|
||||
}
|
||||
|
||||
if (physicalOpRoot instanceof Select) {
|
||||
String isGraphOutput =
|
||||
String.valueOf(
|
||||
task.getParams().computeIfAbsent(ConfigKey.KG_REASONER_OUTPUT_GRAPH, k -> "false"));
|
||||
if ("true".equals(isGraphOutput)) {
|
||||
LocalRDG rdg = ((Select<LocalRDG>) physicalOpRoot).in().rdg();
|
||||
result = rdg.getRDGGraph();
|
||||
} else {
|
||||
LocalRow row = (LocalRow) ((Select<LocalRDG>) physicalOpRoot).row();
|
||||
result = row.getResult();
|
||||
}
|
||||
LocalRow row = (LocalRow) ((Select<LocalRDG>) physicalOpRoot).row();
|
||||
result = row.getResult();
|
||||
} else {
|
||||
LocalRDG rdg = physicalOpRoot.rdg();
|
||||
result = rdg.getResult();
|
||||
|
||||
@ -57,6 +57,9 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
|
||||
/** default path limit */
|
||||
private long defaultPathLimit = 3000;
|
||||
|
||||
/** carry traversal graph data */
|
||||
private boolean isCarryTraversalGraph = false;
|
||||
|
||||
/** local property graph */
|
||||
public LocalPropertyGraph(GraphState<IVertexId> graphState) {
|
||||
this.graphState = graphState;
|
||||
@ -72,7 +75,8 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
|
||||
executorTimeoutMs,
|
||||
alias,
|
||||
getTaskId(),
|
||||
getExecutionRecorder());
|
||||
getExecutionRecorder(),
|
||||
isCarryTraversalGraph);
|
||||
result.setMaxPathLimit(getMaxPathLimit());
|
||||
result.setStrictMaxPathLimit(getStrictMaxPathLimit());
|
||||
return result;
|
||||
@ -101,7 +105,9 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
|
||||
executorTimeoutMs,
|
||||
alias,
|
||||
getTaskId(),
|
||||
getExecutionRecorder());
|
||||
// subquery can not carry all graph
|
||||
getExecutionRecorder(),
|
||||
false);
|
||||
result.setMaxPathLimit(getMaxPathLimit());
|
||||
result.setStrictMaxPathLimit(getStrictMaxPathLimit());
|
||||
return result;
|
||||
@ -168,6 +174,15 @@ public class LocalPropertyGraph implements PropertyGraph<LocalRDG> {
|
||||
this.executorTimeoutMs = executorTimeoutMs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>isCarryTraversalGraph</tt>.
|
||||
*
|
||||
* @param carryTraversalGraph value to be assigned to property isCarryTraversalGraph
|
||||
*/
|
||||
public void setCarryTraversalGraph(boolean carryTraversalGraph) {
|
||||
isCarryTraversalGraph = carryTraversalGraph;
|
||||
}
|
||||
|
||||
/** max path limit */
|
||||
private Long getMaxPathLimit() {
|
||||
Object maxPathLimitObj = null;
|
||||
|
||||
@ -72,6 +72,21 @@ public class LocalReasonerResult {
|
||||
this.errMsg = "";
|
||||
}
|
||||
|
||||
/** output graph and row */
|
||||
public LocalReasonerResult(
|
||||
List<String> columns,
|
||||
List<Object[]> rows,
|
||||
List<IVertex<IVertexId, IProperty>> vertexList,
|
||||
List<IEdge<IVertexId, IProperty>> edgeList,
|
||||
boolean graphResult) {
|
||||
this.columns = columns;
|
||||
this.rows = rows;
|
||||
this.graphResult = graphResult;
|
||||
this.vertexList = vertexList;
|
||||
this.edgeList = edgeList;
|
||||
this.errMsg = "";
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>ddlResult</tt>.
|
||||
*
|
||||
|
||||
@ -18,6 +18,7 @@ import com.antgroup.openspg.reasoner.common.Utils;
|
||||
import com.antgroup.openspg.reasoner.common.exception.NotImplementedException;
|
||||
import com.antgroup.openspg.reasoner.common.graph.edge.Direction;
|
||||
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
|
||||
import com.antgroup.openspg.reasoner.common.graph.edge.impl.PathEdge;
|
||||
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
|
||||
import com.antgroup.openspg.reasoner.common.graph.property.impl.VertexProperty;
|
||||
import com.antgroup.openspg.reasoner.common.graph.type.GraphItemType;
|
||||
@ -146,6 +147,9 @@ public class LocalRDG extends RDG<LocalRDG> {
|
||||
|
||||
private final IExecutionRecorder executionRecorder;
|
||||
|
||||
/** carry all tranversal graph data */
|
||||
protected boolean isCarryTraversalGraph = false;
|
||||
|
||||
/** local rdg with graph state */
|
||||
public LocalRDG(
|
||||
GraphState<IVertexId> graphState,
|
||||
@ -154,7 +158,8 @@ public class LocalRDG extends RDG<LocalRDG> {
|
||||
long executorTimeoutMs,
|
||||
String startVertexAlias,
|
||||
String taskId,
|
||||
IExecutionRecorder executionRecorder) {
|
||||
IExecutionRecorder executionRecorder,
|
||||
boolean carryTraversalGraph) {
|
||||
this.graphState = graphState;
|
||||
Pattern startIdPattern = new NodePattern(new PatternElement(startVertexAlias, null, null));
|
||||
for (IVertexId vertexId : startIdList) {
|
||||
@ -167,6 +172,7 @@ public class LocalRDG extends RDG<LocalRDG> {
|
||||
this.startVertexAlias = startVertexAlias;
|
||||
this.taskId = taskId;
|
||||
this.patternMatcher = new PatternMatcher(this.taskId, graphState);
|
||||
this.isCarryTraversalGraph = carryTraversalGraph;
|
||||
|
||||
if (null == executionRecorder) {
|
||||
this.executionRecorder = new EmptyRecorder();
|
||||
@ -534,7 +540,7 @@ public class LocalRDG extends RDG<LocalRDG> {
|
||||
log.info("LocalRDG select,,matchCount=" + rows.size());
|
||||
this.executionRecorder.stageResultWithDesc(
|
||||
"select(" + RunnerUtil.getReadableAsList(as) + ")", this.kgGraphList.size(), "select");
|
||||
return new LocalRow(cols, this, as, rows);
|
||||
return new LocalRow(cols, this, as, rows, getRDGGraph());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -1216,19 +1222,34 @@ public class LocalRDG extends RDG<LocalRDG> {
|
||||
Lists.newArrayList(resultVertexSet), Lists.newArrayList(resultEdgeSet), true);
|
||||
}
|
||||
|
||||
/**
|
||||
* get all RDG Edges and Nodes
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public LocalReasonerResult getRDGGraph() {
|
||||
LocalReasonerResult localReasonerResult = this.getResult();
|
||||
/** add graph to traversal graph */
|
||||
private LocalReasonerResult getRDGGraph() {
|
||||
if (!isCarryTraversalGraph) {
|
||||
return new LocalReasonerResult(Lists.newArrayList(), Lists.newArrayList(), false);
|
||||
}
|
||||
LocalReasonerResult localReasonerResult = getResult();
|
||||
this.kgGraphList.forEach(
|
||||
graph -> {
|
||||
for (String alias : graph.getVertexAlias()) {
|
||||
localReasonerResult.getVertexList().addAll(graph.getVertex(alias));
|
||||
}
|
||||
for (String alias : graph.getEdgeAlias()) {
|
||||
java.util.List<IEdge<IVertexId, IProperty>> edges = graph.getEdge(alias);
|
||||
for (IEdge<IVertexId, IProperty> edge : edges) {
|
||||
if (edge instanceof PathEdge) {
|
||||
if (((PathEdge<?, ?, ?>) edge).getVertexList() != null) {
|
||||
localReasonerResult
|
||||
.getVertexList()
|
||||
.addAll(((PathEdge<IVertexId, IProperty, IProperty>) edge).getVertexList());
|
||||
}
|
||||
if (((PathEdge<?, ?, ?>) edge).getEdgeList() != null) {
|
||||
localReasonerResult
|
||||
.getEdgeList()
|
||||
.addAll(((PathEdge<IVertexId, IProperty, IProperty>) edge).getEdgeList());
|
||||
}
|
||||
}
|
||||
localReasonerResult.getEdgeList().add(edge);
|
||||
}
|
||||
localReasonerResult.getEdgeList().addAll(graph.getEdge(alias));
|
||||
}
|
||||
});
|
||||
|
||||
@ -30,6 +30,7 @@ import scala.collection.JavaConversions;
|
||||
@Slf4j
|
||||
public class LocalRow extends Row<LocalRDG> {
|
||||
private final List<String> columns;
|
||||
private final LocalReasonerResult graphRst;
|
||||
private List<Object[]> rowList;
|
||||
|
||||
/** row implement */
|
||||
@ -37,11 +38,13 @@ public class LocalRow extends Row<LocalRDG> {
|
||||
scala.collection.immutable.List<Var> orderedFields,
|
||||
LocalRDG rdg,
|
||||
scala.collection.immutable.List<String> as,
|
||||
List<Object[]> rows) {
|
||||
List<Object[]> rows,
|
||||
LocalReasonerResult graphResult) {
|
||||
super(orderedFields, rdg);
|
||||
this.columns = new ArrayList<>();
|
||||
this.columns.addAll(Lists.newArrayList(JavaConversions.asJavaCollection(as)));
|
||||
this.rowList = rows;
|
||||
this.graphRst = graphResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -64,6 +67,11 @@ public class LocalRow extends Row<LocalRDG> {
|
||||
|
||||
/** get select result */
|
||||
public LocalReasonerResult getResult() {
|
||||
return new LocalReasonerResult(columns, rowList);
|
||||
return new LocalReasonerResult(
|
||||
columns,
|
||||
rowList,
|
||||
graphRst.getVertexList(),
|
||||
graphRst.getEdgeList(),
|
||||
graphRst.isGraphResult());
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,7 +51,8 @@ public class LocalRDGTest {
|
||||
1000,
|
||||
"",
|
||||
"",
|
||||
null);
|
||||
null,
|
||||
false);
|
||||
localRDG.limit(1);
|
||||
localRDG.show(10);
|
||||
}
|
||||
@ -66,7 +67,8 @@ public class LocalRDGTest {
|
||||
1000,
|
||||
"",
|
||||
"",
|
||||
null);
|
||||
null,
|
||||
false);
|
||||
|
||||
List<KgGraph<IVertexId>> kgGraphList = new ArrayList<>();
|
||||
Map<String, Set<IVertex<IVertexId, IProperty>>> alias2VertexMap = new HashMap<>();
|
||||
|
||||
@ -37,6 +37,61 @@ import scala.Tuple2;
|
||||
|
||||
public class LocalRunnerTest {
|
||||
|
||||
@Test
|
||||
public void doTestFilter() {
|
||||
String dsl =
|
||||
"GraphStructure {\n"
|
||||
+ " (s1:Road.Event)-[p1:subject]-(o1:Road.Researcher)\n"
|
||||
+ " (s1:Road.Event)-[p3:object]-(o3:Road.Area)\n"
|
||||
+ " (s1:Road.Event)-[p2:province]-(o2:Road.AdministrativeRegion)\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ " R0: o1.id == \"张三\"\n"
|
||||
+ " R1: o2.name rlike \"江西省\"\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(o3.name)\n"
|
||||
+ "}";
|
||||
|
||||
LocalReasonerTask task = new LocalReasonerTask();
|
||||
task.setDsl(dsl);
|
||||
task.setGraphLoadClass("com.antgroup.openspg.reasoner.runner.local.loader.TestRoadGraphLoader");
|
||||
task.getParams().put(Constants.SPG_REASONER_PLAN_PRETTY_PRINT_LOGGER_ENABLE, true);
|
||||
task.getParams().put(Constants.SPG_REASONER_LUBE_SUBQUERY_ENABLE, true);
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("张三", "Road.Researcher")));
|
||||
|
||||
// add mock catalog
|
||||
Map<String, scala.collection.immutable.Set<String>> schema = new HashMap<>();
|
||||
schema.put("Road.Researcher", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id")));
|
||||
schema.put(
|
||||
"Road.Event",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id", "kgstartDateRaw")));
|
||||
schema.put("Road.Area", Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id", "name")));
|
||||
schema.put(
|
||||
"Road.AdministrativeRegion",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("id", "name")));
|
||||
|
||||
schema.put(
|
||||
"Road.Event_subject_Road.Researcher",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("holdRet")));
|
||||
schema.put(
|
||||
"Road.Event_object_Road.Area",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("holdRet")));
|
||||
schema.put(
|
||||
"Road.Event_province_Road.AdministrativeRegion",
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet("holdRet")));
|
||||
Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema));
|
||||
catalog.init();
|
||||
task.setCatalog(catalog);
|
||||
|
||||
LocalReasonerRunner runner = new LocalReasonerRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println(result);
|
||||
Assert.assertEquals(result.getRows().size(), 1);
|
||||
Assert.assertEquals(result.getRows().get(0)[0], "江西yy校");
|
||||
clear();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doTestAvg2() {
|
||||
String dsl =
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright 2023 OpenSPG Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
* or implied.
|
||||
*/
|
||||
|
||||
package com.antgroup.openspg.reasoner.runner.local.loader;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
|
||||
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
|
||||
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex;
|
||||
import com.antgroup.openspg.reasoner.runner.local.load.graph.AbstractLocalGraphLoader;
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
|
||||
public class TestRoadGraphLoader extends AbstractLocalGraphLoader {
|
||||
@Override
|
||||
public List<IVertex<String, IProperty>> genVertexList() {
|
||||
return Lists.newArrayList(
|
||||
constructionVertex("张三", "Road.Researcher"),
|
||||
constructionVertex("江西yy校", "Road.Area", "name", "江西yy校"),
|
||||
constructionVertex("湖北xx地", "Road.Area", "name", "湖北xx地"),
|
||||
constructionVertex("江西省", "Road.AdministrativeRegion", "name", "江西省"),
|
||||
constructionVertex("湖北省", "Road.AdministrativeRegion", "name", "湖北省"),
|
||||
constructionVertex("E1", "Road.Event", "kgstartDateRaw", "20230901"),
|
||||
constructionVertex("E2", "Road.Event", "kgstartDateRaw", "20230901"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<IEdge<String, IProperty>> genEdgeList() {
|
||||
return Lists.newArrayList(
|
||||
constructionEdge("E1", "subject", "张三"),
|
||||
constructionEdge("E1", "object", "江西yy校"),
|
||||
constructionEdge("E1", "province", "江西省"),
|
||||
constructionEdge("E2", "subject", "张三"),
|
||||
constructionEdge("E2", "object", "湖北xx地"),
|
||||
constructionEdge("E2", "province", "湖北省"));
|
||||
}
|
||||
}
|
||||
@ -13,6 +13,8 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.runner.local.main;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.constants.Constants;
|
||||
import com.antgroup.openspg.reasoner.runner.ConfigKey;
|
||||
import com.antgroup.openspg.reasoner.runner.local.main.basetest.TransBaseTestData;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@ -35,6 +37,42 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
@Before
|
||||
public void init() {}
|
||||
|
||||
@Test
|
||||
public void test0() {
|
||||
FileMutex.runTestWithMutex(this::doTest0);
|
||||
}
|
||||
|
||||
private void doTest0() {
|
||||
String dsl =
|
||||
"\n"
|
||||
+ "GraphStructure {\n"
|
||||
+ " (A:User)-[p1:trans]->(B:User)\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ " R1: A.id == $id"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id, B.id)\n"
|
||||
+ "}";
|
||||
List<String[]> result =
|
||||
TransBaseTestData.runTestResult(
|
||||
dsl,
|
||||
new HashMap<String, Object>() {
|
||||
{
|
||||
put("id", "'A'");
|
||||
put(Constants.START_ALIAS, "A");
|
||||
put(
|
||||
ConfigKey.KG_REASONER_MOCK_GRAPH_DATA,
|
||||
"Graph {\n" + " A [User]\n" + " B [User]\n" + " A->B [trans]\n" + "}");
|
||||
put(ConfigKey.KG_REASONER_OUTPUT_GRAPH, "true");
|
||||
}
|
||||
});
|
||||
Assert.assertEquals(1, result.size());
|
||||
Assert.assertEquals(2, result.get(0).length);
|
||||
Assert.assertEquals("A", result.get(0)[0]);
|
||||
Assert.assertEquals("B", result.get(0)[1]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1() {
|
||||
FileMutex.runTestWithMutex(this::doTest1);
|
||||
@ -199,4 +237,47 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
Assert.assertEquals("3", result.get(0)[2]);
|
||||
Assert.assertEquals("700", result.get(0)[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test5() {
|
||||
FileMutex.runTestWithMutex(this::doTest5);
|
||||
}
|
||||
|
||||
private void doTest5() {
|
||||
String dsl =
|
||||
"\n"
|
||||
+ "GraphStructure {\n"
|
||||
+ " (A:User)-[p1:trans]->(B:User)-[p2:trans]->(C:User)-[p3:trans]->(A)\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ "p1_amt = cast_type(p1.amount,'long')\n"
|
||||
+ "p2_amt = cast_type(p2.amount,'long')\n"
|
||||
+ "p3_amt = cast_type(p3.amount,'long')\n"
|
||||
+ "R1: A.id == $idSet1\n"
|
||||
+ "R2: B.id in $idSet2\n"
|
||||
+ "R3: C.id in $idSet2\n"
|
||||
+ "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
|
||||
+ "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
|
||||
+ "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
|
||||
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
|
||||
+ "R2('取top2'): top(totalTrans, 2)"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ " get(A.id, B.id, C.id, totalTrans, p1_amt)\n"
|
||||
+ "}";
|
||||
List<String[]> result =
|
||||
TransBaseTestData.runTestResult(
|
||||
dsl,
|
||||
new HashMap<String, Object>() {
|
||||
{
|
||||
put("idSet1", "'1'");
|
||||
put("idSet2", "['2', '3']");
|
||||
}
|
||||
});
|
||||
Assert.assertEquals(1, result.size());
|
||||
Assert.assertEquals(5, result.get(0).length);
|
||||
Assert.assertEquals("1", result.get(0)[0]);
|
||||
Assert.assertEquals("2", result.get(0)[1]);
|
||||
Assert.assertEquals("3", result.get(0)[2]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -93,6 +93,7 @@ import java.util.function.Predicate;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.fs.shell.Count;
|
||||
import scala.Tuple2;
|
||||
import scala.collection.JavaConversions;
|
||||
|
||||
@ -1363,8 +1364,10 @@ public class RunnerUtil {
|
||||
|
||||
public static void updateUdafDataFromProperty(
|
||||
LazyUdaf udaf, IProperty property, String propertyName) {
|
||||
if (property.isKeyExist(propertyName)) {
|
||||
if (property != null && property.isKeyExist(propertyName)) {
|
||||
udaf.update(property.get(propertyName));
|
||||
} else if (udaf.getName().equalsIgnoreCase(Count.NAME)) {
|
||||
udaf.update(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,18 +14,16 @@
|
||||
package com.antgroup.openspg.reasoner.session
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.constants.Constants
|
||||
import com.antgroup.openspg.reasoner.common.exception.SchemaException
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.impl.PropertyGraphCatalog
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{AggIfOpExpr, BAnd, BinaryOpExpr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{AggIfOpExpr, BAnd, BinaryOpExpr, GetField, Ref, UnaryOpExpr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
|
||||
import com.antgroup.openspg.reasoner.lube.logical.optimizer.LogicalOptimizer
|
||||
import com.antgroup.openspg.reasoner.lube.logical.planning.{LogicalPlanner, LogicalPlannerContext}
|
||||
import com.antgroup.openspg.reasoner.lube.physical.operators._
|
||||
import com.antgroup.openspg.reasoner.lube.physical.rdg.RDG
|
||||
import com.antgroup.openspg.reasoner.parser.OpenSPGDslParser
|
||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner
|
||||
import com.antgroup.openspg.reasoner.util.LoaderUtil
|
||||
import org.scalatest.funspec.AnyFunSpec
|
||||
import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, equal}
|
||||
@ -616,7 +614,9 @@ class ReasonerSessionTests extends AnyFunSpec {
|
||||
a._2 match {
|
||||
case AggIfOpExpr(
|
||||
_,
|
||||
BinaryOpExpr(BAnd, BinaryOpExpr(_, _, _), BinaryOpExpr(_, _, _))) =>
|
||||
BinaryOpExpr(BAnd,
|
||||
UnaryOpExpr(GetField("R1"), Ref("t")),
|
||||
UnaryOpExpr(GetField("R2"), Ref("t")))) =>
|
||||
1
|
||||
case _ => 0
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user