feat(reasoner): support value type inference in udf (#132)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-03-05 11:59:50 +08:00 committed by GitHub
parent eb2590aada
commit 258b0e7dfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 376 additions and 182 deletions

View File

@ -13,6 +13,8 @@
package com.antgroup.openspg.reasoner.common.types package com.antgroup.openspg.reasoner.common.types
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
trait KgType { trait KgType {
def isNullable: Boolean = false def isNullable: Boolean = false
} }
@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType
* @param elementType * @param elementType
*/ */
final case class KTMultiVersion(elementType: KgType) extends KgType final case class KTMultiVersion(elementType: KgType) extends KgType
object KgType {

  /**
   * Rank of a numeric type, used to decide which of two numeric types is wider.
   * KTInteger ranks 1, KTLong ranks 2 and KTDouble ranks 3.
   *
   * @param kgType the type to rank
   * @return the rank of the type
   * @throws UnsupportedOperationException for any other type
   */
  def getNumberSeq(kgType: KgType): Int =
    if (kgType == KTInteger) {
      1
    } else if (kgType == KTLong) {
      2
    } else if (kgType == KTDouble) {
      3
    } else {
      throw UnsupportedOperationException(s"cannot support number type $kgType")
    }

}

View File

@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
IRProperty(s.alias, propertyName) -> IRProperty(s.alias, propertyName) ->
ProjectRule( ProjectRule(
IRProperty(s.alias, propertyName), IRProperty(s.alias, propertyName),
propertyType,
Ref(ddlBlockWithNodes._3.target.alias))))) Ref(ddlBlockWithNodes._3.target.alias)))))
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk)) DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
case AddPredicate(predicate) => case AddPredicate(predicate) =>
@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
ProjectBlock( ProjectBlock(
List.apply(opBlock), List.apply(opBlock),
ProjectFields(Map.apply(lValueName -> ProjectFields(Map.apply(lValueName ->
ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain)))) ProjectRule(lValueName, opChain))))
} }
case AggIfOpExpr(_, _) | AggOpExpr(_, _) => case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
ProjectBlock( ProjectBlock(
@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
lValueName -> lValueName ->
ProjectRule( ProjectRule(
lValueName, lValueName,
exprParser.parseRetType(opChain.curExpr),
opChain.curExpr)))) opChain.curExpr))))
case _ => null case _ => null
} }
@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
List.empty) List.empty)
case _ => case _ =>
rule match { rule match {
case ProjectRule(_, lvalueType, _) => case ProjectRule(_, _) =>
val projectRule = ProjectRule(lvalueFiled, lvalueType, expr) val projectRule = ProjectRule(lvalueFiled, expr)
ProjectBlock( ProjectBlock(
List.apply(preBlock), List.apply(preBlock),
ProjectFields(Map.apply(lvalueFiled -> projectRule))) ProjectFields(Map.apply(lvalueFiled -> projectRule)))
@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal())) exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
val defaultName = "const_output_" + patternParser.getDefaultAliasNum val defaultName = "const_output_" + patternParser.getDefaultAliasNum
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName) val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
(ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true) (ProjectRule(IRVariable(defaultName), expr), columnName, true)
} }
def parseGraphStructure( def parseGraphStructure(
@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr) val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName) val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
( (
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr), ProjectRule(IRVariable(defaultColumnName), expr),
columnName, columnName,
false) false)
} }
@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr) val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName) val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
( (
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr), ProjectRule(IRVariable(defaultColumnName), expr),
columnName, columnName,
false) false)
} }

View File

@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
} }
} }
def parseRetType(expr: Expr): KgType = {
KTObject
}
def parseRuleExpression(ctx: Rule_expressionContext): Rule = { def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
ctx.getChild(0) match { ctx.getChild(0) match {
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c) case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
if (ctx.property_name() != null) { if (ctx.property_name() != null) {
ProjectRule( ProjectRule(
IRProperty(ctx.identifier().getText, ctx.property_name().getText), IRProperty(ctx.identifier().getText, ctx.property_name().getText),
parseRetType(expr),
expr) expr)
} else { } else {
ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr) ProjectRule(IRVariable(ctx.identifier().getText), expr)
} }
} }

View File

@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser
import com.antgroup.openspg.reasoner.common.constants.Constants import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException} import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
import com.antgroup.openspg.reasoner.lube.block._ import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr._ import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph._ import com.antgroup.openspg.reasoner.lube.common.graph._
@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty) print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock] val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal( proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
} }
it("addproperies with constraint") { it("addproperies with constraint") {
@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty) print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock] val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal( proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
} }
it("addproperies2") { it("addproperies2") {
@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock] val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal( proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o"))) ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
} }
it("addproperies") { it("addproperies") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) { val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock] val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal( proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o"))) ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
} }
it("addNode") { it("addNode") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) { val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
.asInstanceOf[AddPredicate] .asInstanceOf[AddPredicate]
.predicate .predicate
.fields .fields
.keySet should contain ("same_domain_num") .keySet should contain("same_domain_num")
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true) blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
blocks(1) blocks(1)
@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl) val block = parser.parse(dsl)
print(block.pretty) print(block.pretty)
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean))) val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o))))) | └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set()))) | └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan))) | └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false))) | └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl) val block = parser.parse(dsl)
print(block.pretty) print(block.pretty)
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b)) val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value))))) * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr))) * └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan))))) * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual))))) * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))

View File

@ -33,6 +33,10 @@
<groupId>com.antgroup.openspg.reasoner</groupId> <groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-common</artifactId> <artifactId>reasoner-common</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-udf</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId> <artifactId>scala-library</artifactId>

View File

@ -13,10 +13,12 @@
package com.antgroup.openspg.reasoner.lube.catalog package com.antgroup.openspg.reasoner.lube.catalog
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException} import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
import scala.collection.mutable import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}
/** /**
@ -27,6 +29,7 @@ import scala.collection.mutable
*/ */
abstract class Catalog() extends Serializable { abstract class Catalog() extends Serializable {
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]() protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
// Catalog is Serializable and this field is @transient, so a plain val would come back
// as null after deserialization. A lazy val is looked up again on first access instead.
@transient private lazy val udfRepo: UdfMng = UdfMngFactory.getUdfMng
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]() private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
/** /**
@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable {
graphRepository.get(graphName).orNull graphRepository.get(graphName).orNull
} }
/**
 * Get the UDF manager held by this catalog.
 *
 * @return the UDF manager
 */
def getUdfRepo: UdfMng = {
  udfRepo
}
/** /**
* Get schema from knowledge graph * Get schema from knowledge graph
*/ */

View File

@ -13,7 +13,6 @@
package com.antgroup.openspg.reasoner.lube.common.rule package com.antgroup.openspg.reasoner.lube.common.rule
import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.common.expr.Expr import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.IRField import com.antgroup.openspg.reasoner.lube.common.graph.IRField
@ -39,13 +38,6 @@ trait Rule extends Cloneable{
*/ */
def getExpr: Expr def getExpr: Expr
/**
* get lvalue type
* @return
*/
def getLvalueType: KgType
/** /**
* get dependencies * get dependencies
* @return * @return

View File

@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
*/ */
override def getExpr: Expr = expr override def getExpr: Expr = expr
/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = KTBoolean
override def andRule(rule: Rule): Rule = { override def andRule(rule: Rule): Rule = {
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr) val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
* @param lvalueType * @param lvalueType
* @param expr * @param expr
*/ */
final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr) final case class ProjectRule(output: IRField, expr: Expr)
extends DependencyRule { extends DependencyRule {
/** /**
@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
*/ */
override def getExpr: Expr = expr override def getExpr: Expr = expr
/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = lvalueType
override def andRule(rule: Rule): Rule = { override def andRule(rule: Rule): Rule = {
throw UnsupportedOperationException("ProjectRule cannot support andRule") throw UnsupportedOperationException("ProjectRule cannot support andRule")

View File

@ -138,7 +138,7 @@ object RuleUtils {
case logicRule: LogicRule => case logicRule: LogicRule =>
LogicRule(ruleNameStr, logicRule.ruleExplain, expr) LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
case _ => case _ =>
ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr) ProjectRule(IRVariable(ruleNameStr), expr)
} }
val oldDependencies = rule.getDependencies val oldDependencies = rule.getDependencies
if (oldDependencies != null) { if (oldDependencies != null) {
@ -162,7 +162,7 @@ object RuleUtils {
case logicRule: LogicRule => case logicRule: LogicRule =>
LogicRule(rule.getName, logicRule.ruleExplain, expr) LogicRule(rule.getName, logicRule.ruleExplain, expr)
case _ => case _ =>
ProjectRule(rule.getOutput, rule.getLvalueType, expr) ProjectRule(rule.getOutput, expr)
} }
val oldDependencies = rule.getDependencies val oldDependencies = rule.getDependencies
if (oldDependencies != null) { if (oldDependencies != null) {

View File

@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
false)))), false)))),
ProjectFields( ProjectFields(
Map.apply(IRVariable("total_domain_num") -> Map.apply(IRVariable("total_domain_num") ->
ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o"))))) ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
val p = BlockUtils.transBlock2Graph(block) val p = BlockUtils.transBlock2Graph(block)
p.size should equal(1) p.size should equal(1)
p.head.graphPattern.nodes.size should equal(2) p.head.graphPattern.nodes.size should equal(2)
@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
it("rename_rule") { it("rename_rule") {
val rule = ProjectRule( val rule = ProjectRule(
IRVariable("a"), IRVariable("a"),
KTObject,
BinaryOpExpr( BinaryOpExpr(
BEqual, BEqual,
UnaryOpExpr(GetField("birthDate"), Ref("e")), UnaryOpExpr(GetField("birthDate"), Ref("e")),
@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
it("variable_rule") { it("variable_rule") {
val rule = ProjectRule( val rule = ProjectRule(
IRVariable("a"), IRVariable("a"),
KTObject,
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b"))) BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
val rule2 = ProjectRule( val rule2 = ProjectRule(
IRVariable("b"), IRVariable("b"),
KTObject,
BinaryOpExpr( BinaryOpExpr(
BEqual, BEqual,
UnaryOpExpr(GetField("attr1"), Ref("e")), UnaryOpExpr(GetField("attr1"), Ref("e")),
@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
def getDependenceRule(): Rule = { def getDependenceRule(): Rule = {
val r0 = ProjectRule( val r0 = ProjectRule(
IRVariable("r0"), IRVariable("r0"),
KTLong,
BinaryOpExpr(BAssign, Ref("r0"), VLong("123")) BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
) )
val r1 = LogicRule( val r1 = LogicRule(
@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
val r0 = LogicRule("tmp", "", val r0 = LogicRule("tmp", "",
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10"))) BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
val r = ProjectRule(IRVariable("g"), val r = ProjectRule(IRVariable("g"),
KTLong,
OpChainExpr( OpChainExpr(
GraphAggregatorExpr( GraphAggregatorExpr(
"unresolved_default_path", "unresolved_default_path",

View File

@ -13,9 +13,16 @@
package com.antgroup.openspg.reasoner.lube.logical package com.antgroup.openspg.reasoner.lube.logical
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.trees.BottomUp import com.antgroup.openspg.reasoner.common.trees.BottomUp
import com.antgroup.openspg.reasoner.common.types._
import com.antgroup.openspg.reasoner.lube.common.expr._ import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.Rule import com.antgroup.openspg.reasoner.lube.common.rule.Rule
import com.antgroup.openspg.reasoner.udf.UdfMng
object ExprUtil { object ExprUtil {
@ -46,15 +53,13 @@ object ExprUtil {
} }
def needResolved(rule: Expr): Boolean = { def needResolved(rule: Expr): Boolean = {
!getReferProperties(rule).filter(_._1 == null).isEmpty !getReferProperties(rule).filter(_._1 == null).isEmpty
} }
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = { def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
def rewriter: PartialFunction[Expr, Expr] = { def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
case Ref(refName) =>
if (replaceVar.contains(refName)) { if (replaceVar.contains(refName)) {
val propertyVar = replaceVar(refName) val propertyVar = replaceVar(refName)
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name)) UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
@ -76,4 +81,104 @@ object ExprUtil {
newRule newRule
} }
/**
 * Infer the result type of an expression.
 *
 * @param expr      the expression to type
 * @param referVars known types of the variables and properties the expression may refer to
 * @param udfRepo   UDF registry used to look up the result type of UDF / UDAF calls
 * @return the inferred type; references that are not present in `referVars` are typed
 *         as KTObject
 * @throws UnsupportedOperationException when the expression, operator, cast target or
 *                                       function is not supported
 */
def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
  expr match {
    case Ref(name) =>
      // a variable whose type is not known yet is treated as an untyped object
      referVars.getOrElse(IRVariable(name), KTObject)
    case UnaryOpExpr(GetField(name), Ref(alis)) =>
      // same fallback as for plain references, instead of failing with a bare
      // NoSuchElementException when the property type is unknown
      referVars.getOrElse(IRProperty(alis, name), KTObject)
    case BinaryOpExpr(name, l, r) =>
      name match {
        case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |
            BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign =>
          KTBoolean
        case BAdd | BSub | BMul | BDiv | BMod =>
          // arithmetic yields the wider of the two operand types
          val left = getTargetType(l, referVars, udfRepo)
          val right = getTargetType(r, referVars, udfRepo)
          getUpperType(left, right)
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case UnaryOpExpr(name, arg) =>
      name match {
        case Not | Exists => KTBoolean
        case Abs | Neg => getTargetType(arg, referVars, udfRepo)
        case Floor | Ceil => KTDouble
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case FunctionExpr(name, funcArgs) =>
      val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
      name match {
        case "rule_value" =>
          // the result takes the type of the second argument
          types
            .lift(1)
            .getOrElse(
              throw UnsupportedOperationException(s"${name} requires at least 2 arguments"))
        case "cast_type" | "Cast" =>
          // the second argument names the target type and must be a string literal
          funcArgs.lift(1) match {
            case Some(VString("int" | "bigint" | "long")) => KTLong
            case Some(VString("float" | "double")) => KTDouble
            case Some(VString("varchar" | "string")) => KTString
            case Some(other) =>
              throw UnsupportedOperationException(s"cannot support ${name} to ${other}")
            case None =>
              throw UnsupportedOperationException(s"${name} requires a target type argument")
          }
        case _ =>
          val udf = udfRepo.getUdfMeta(name, types.asJava)
          if (udf != null) {
            udf.getResultType
          } else {
            throw UnsupportedOperationException(s"cannot find UDF: ${name}")
          }
      }
    case AggOpExpr(name, args) =>
      name match {
        case Min | Max | Sum | Avg | First | Accumulate(_) =>
          // these aggregators are typed as their aggregated expression
          getTargetType(args, referVars, udfRepo)
        case StrJoin(_) => KTString
        case Count => KTLong
        case AggUdf(name, _) =>
          val types = getTargetType(args.head, referVars, udfRepo)
          val udf = udfRepo.getUdafMeta(name, types)
          if (udf != null) {
            udf.getResultType
          } else {
            throw UnsupportedOperationException(s"cannot find UDAF ${name}")
          }
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo)
    case ListOpExpr(name, _) =>
      name match {
        case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo)
        case Constraint(_, _, _) => KTBoolean
        case Get(_) | Slice(_, _) => KTObject
      }
    case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo)
    case VNull | VString(_) => KTString
    case VLong(_) => KTLong
    case VDouble(_) => KTDouble
    case VBoolean(_) => KTBoolean
    case VList(_, listType) => KTList(listType)
    case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}")
  }
}
/**
 * Infer the result type of a rule.
 *
 * The type of every dependency of the rule is inferred first and registered under the
 * dependency's output field, then the rule's own expression is typed against the
 * combined map.
 *
 * @param rule      the rule to type
 * @param referVars known types of the variables and properties the rule may refer to
 * @param udfRepo   UDF registry used to look up the result type of UDF / UDAF calls
 * @return the inferred type of the rule expression
 */
def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
  val newReferVars = new mutable.HashMap[IRField, KgType]
  newReferVars ++= referVars
  // other callers of getDependencies guard against null, so do the same here
  val dependencies = rule.getDependencies
  if (dependencies != null) {
    for (r <- dependencies) {
      newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo))
    }
  }
  getTargetType(rule.getExpr, newReferVars.toMap, udfRepo)
}
/**
 * Pick the wider of two numeric types according to KgType.getNumberSeq.
 * When both rank the same, the left type is returned.
 *
 * @param left  type of the left operand
 * @param right type of the right operand
 * @return the operand type with the higher rank
 */
private def getUpperType(left: KgType, right: KgType): KgType = {
  val leftRank = KgType.getNumberSeq(left)
  val rightRank = KgType.getNumberSeq(right)
  if (leftRank < rightRank) right else left
}
} }

View File

@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.types.KTObject import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject}
import com.antgroup.openspg.reasoner.lube.block.Aggregations import com.antgroup.openspg.reasoner.lube.block.Aggregations
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
import org.apache.commons.lang3.StringUtils import org.apache.commons.lang3.StringUtils
class AggregationPlanner(group: List[IRField], aggregations: Aggregations) { class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit
context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = { def plan(dependency: LogicalOperator): LogicalOperator = {
val groupVar: List[Var] = group.map(toVar(_, dependency.solved)) val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
v v
} }
}) })
val referTypes = new mutable.HashMap[IRField, KgType]()
for (v <- ruleFields) {
resolved.getVar(v.name) match {
case p: PropertyVar => referTypes.put(v, p.field.kgType)
case node: NodeVar =>
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case edge: EdgeVar =>
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case _ => throw UnsupportedOperationException(s"cannot support $v")
}
}
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo)
val renameVar = ruleFields val renameVar = ruleFields
.filter(_.isInstanceOf[IRVariable]) .filter(_.isInstanceOf[IRVariable])
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable])))) .map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
val field = getAggregateTarget(referFields, resolved, dependency) val field = getAggregateTarget(referFields, resolved, dependency)
field match { field match {
case IRNode(alias, _) => case IRNode(alias, _) =>
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true)) val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr) aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case IREdge(alias, _) => case IREdge(alias, _) =>
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) { if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr) aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
} else { } else {
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true)) val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
aggMap.put(propertyVar, newAggExpr) aggMap.put(propertyVar, newAggExpr)
} }
case IRVariable(alias) => case IRVariable(alias) =>
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias)) val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
val propertyVar = val propertyVar =
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true)) PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr) aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case _ => case _ =>

View File

@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.types.KTObject import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.block.ProjectFields import com.antgroup.openspg.reasoner.lube.block.ProjectFields
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr} import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
import com.antgroup.openspg.reasoner.lube.common.graph._ import com.antgroup.openspg.reasoner.lube.common.graph._
import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var} import com.antgroup.openspg.reasoner.lube.common.rule.Rule
import com.antgroup.openspg.reasoner.lube.logical._
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator} import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
import org.apache.commons.lang3.StringUtils import org.apache.commons.lang3.StringUtils
class ProjectPlanner(projects: ProjectFields) { class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = { def plan(dependency: LogicalOperator): LogicalOperator = {
val projectMap = new mutable.HashMap[Var, Expr]() val projectMap = new mutable.HashMap[Var, Expr]()
@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) {
v v
} }
}) })
val propertyVar = getTarget(rule._1, referVars, resolved, dependency) val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency)
val transformer = new Rule2ExprTransformer() val transformer = new Rule2ExprTransformer()
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable]) val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
val replaceVar = reference val replaceVar = reference
@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) {
private def getTarget( private def getTarget(
left: IRField, left: IRField,
referVars: List[IRField], referVars: List[IRField],
rule: Rule,
resolved: SolvedModel, resolved: SolvedModel,
dependency: LogicalOperator): PropertyVar = { dependency: LogicalOperator): PropertyVar = {
val referTypes = new mutable.HashMap[IRField, KgType]()
for (v <- referVars) {
resolved.getVar(v.name) match {
case p: PropertyVar => referTypes.put(v, p.field.kgType)
case node: NodeVar =>
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case edge: EdgeVar =>
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case _ =>
}
}
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo)
left match { left match {
case IRVariable(name) => case IRVariable(name) =>
if (referVars.size == 1) { if (referVars.size == 1) {
PropertyVar(referVars.head.name, new Field(name, KTObject, true)) PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
} else { } else {
val aliasSet = new mutable.HashSet[String]() val aliasSet = new mutable.HashSet[String]()
for (rVar <- referVars) { for (rVar <- referVars) {
@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) {
} }
} }
val targetAlias = getTargetAlias(aliasSet.toSet, dependency) val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
PropertyVar(targetAlias, new Field(left.name, KTObject, true)) PropertyVar(targetAlias, new Field(left.name, ruleRetType, true))
} }
case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true)) case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true))
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true)) case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true)) case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
case _ => throw UnsupportedOperationException(s"cannot support $left") case _ => throw UnsupportedOperationException(s"cannot support $left")
} }

View File

@ -13,11 +13,13 @@
package com.antgroup.openspg.reasoner.lube.logical package com.antgroup.openspg.reasoner.lube.logical
import com.antgroup.openspg.reasoner.common.types.KTInteger import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr} import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule} import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser
import com.antgroup.openspg.reasoner.udf.UdfMngFactory
import org.scalatest.funspec.AnyFunSpec import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal} import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec {
val replaceMap = Map.apply( val replaceMap = Map.apply(
"a" -> PropertyVar("b", new Field("id", KTInteger, true)), "a" -> PropertyVar("b", new Field("id", KTInteger, true)),
"c" -> PropertyVar("d", new Field("id", KTInteger, true))) "c" -> PropertyVar("d", new Field("id", KTInteger, true)))
val r1 = ProjectRule(IRVariable("test"), val r1 = ProjectRule(IRVariable("test"), Ref("a"))
KTInteger, Ref("a")) val r2 = ProjectRule(IRVariable("test"), Ref("c"))
val r2 = ProjectRule(IRVariable("test"),
KTInteger, Ref("c"))
r1.addDependency(r2) r1.addDependency(r2)
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap) val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
print(newRule.getExpr.pretty) print(newRule.getExpr.pretty)
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true) newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true) newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
newRule.getExpr.asInstanceOf[UnaryOpExpr] newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b")
.arg.asInstanceOf[Ref].refName should equal("b") }
// Checks the result type inferred for plain property access, arithmetic,
// comparison and UDF calls when A.age is declared as an integer.
it("test expr output type") {
  val exprParser = new RuleExprParser()
  val udfs = UdfMngFactory.getUdfMng
  // Only A.age is known to the inference; it is declared as KTInteger.
  val fieldTypes: Map[IRField, KgType] = Map(IRProperty("A", "age") -> KTInteger)
  // (expression text, expected inferred type), evaluated in this order.
  val expectations: List[(String, KgType)] = List(
    "A.age" -> KTInteger,
    "A.age + 1" -> KTLong,
    "A.age > 10" -> KTBoolean,
    "floor(A.age)" -> KTDouble,
    "abs(A.age)" -> KTInteger,
    "concat(A.age, \",\")" -> KTString,
    "cast_type(A.age, 'string')" -> KTString)
  expectations.foreach { case (text, expectedType) =>
    val parsed: Expr = exprParser.parse(text)
    ExprUtil.getTargetType(parsed, fieldTypes, udfs) should equal(expectedType)
  }
}
it("test rule output type") {
val parser = new RuleExprParser()
val udfRepo = UdfMngFactory.getUdfMng
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10"))
val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age"))
rule.addDependency(r1)
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
} }
} }

View File

@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec {
println(optimizedLogicalPlan.pretty) println(optimizedLogicalPlan.pretty)
} }
it("online") {
val dsl = """Define (s:User)-[p:redPacket]->(o:Int) {
| GraphStructure {
| (s)
| }
| Rule {
|LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m
|Latest30DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest7DayPayCount=s.ngfe_tag__pay_cnt_d
|LatestTTT = Latest7DayPayCount.accumulate(+)
|LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0
|Latest7DayPayCountSum=Latest7DayPayCount
|Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0
|HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount
|HighLost("高频降频100%"):HighReduceValue == 1
|HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1
|HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8
|HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5
|HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3
|Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest30DayPayDayCount=size(Latest30DayPayCount)
|Latest3060DayPayDayCount=size(Latest3060DayPayCount)
|High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13
|High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13
|Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest6090DayPayDayCount=size(Latest6090DayPayCount)
|Latest60DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest60DayPayDayCount=size(Latest60DayPayCount)
|Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0
|Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0
|HistoricallyPay=s.ngfe_tag__pay_cnt_total
|HistoricallyPayCount=size(HistoricallyPay)
|New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0
|Latest90DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest90DayPayDayCount=size(Latest90DayPayCount)
|Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0
|o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost"))
| }
|}""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
println(block.pretty)
val schema: Map[String, Set[String]] = Map.apply(
"User" -> Set
.apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d"))
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(catalog, parser, Map.empty)
val logicalPlan = LogicalPlanner.plan(block)
println(logicalPlan.head.pretty)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head)
println(optimizedLogicalPlan.pretty)
}
it("test start flag") { it("test start flag") {
val dsl = val dsl =
""" """
@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec {
| (s)-[p2:followPM]->(o) | (s)-[p2:followPM]->(o)
|} |}
|Rule { |Rule {
| c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0) | c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0)
| |
|} |}
|Action { |Action {

View File

@ -41,7 +41,7 @@
<properties> <properties>
<antlr4.version>4.8</antlr4.version> <antlr4.version>4.8</antlr4.version>
<fastjson.version>1.2.71_noneautotype</fastjson.version> <fastjson.version>1.2.71_noneautotype</fastjson.version>
<geotools.version>27.0</geotools.version> <geotools.version>28.0</geotools.version>
<google.s2.version>2.0.0</google.s2.version> <google.s2.version>2.0.0</google.s2.version>
<hadoop.version>2.7.2</hadoop.version> <hadoop.version>2.7.2</hadoop.version>
<hive.version>3.1.0</hive.version> <hive.version>3.1.0</hive.version>

View File

@ -137,7 +137,7 @@ public class KgReasonerABMLocalTest {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n" + " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n" + " }\n"
+ " Rule {\n" + " Rule {\n"
+ " \tv = t.stock/t.total\n" + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n" + " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n" + " o = v\n"
+ " }\n" + " }\n"
@ -197,7 +197,7 @@ public class KgReasonerABMLocalTest {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n" + " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n" + " }\n"
+ " Rule {\n" + " Rule {\n"
+ " \tv = t.stock/t.total\n" + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n" + " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n" + " o = v\n"
+ " }\n" + " }\n"

View File

@ -50,9 +50,12 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id == $idSet1\n" + "R1: A.id == $idSet1\n"
+ "R2: B.id in $idSet2\n" + "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n" + "R3: C.id in $idSet2\n"
+ "totalTrans1 = group(A,B,C).sum(p1.amount)\n" + "p1_amt = cast_type(p1.amount,'long')\n"
+ "totalTrans2 = group(A,B,C).sum(p2.amount)\n" + "p2_amt = cast_type(p2.amount,'long')\n"
+ "totalTrans3 = group(A,B,C).sum(p3.amount)\n" + "p3_amt = cast_type(p3.amount,'long')\n"
+ "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
+ "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
+ "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n" + "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
+ "R2('取top2'): top(totalTrans, 2)" + "R2('取top2'): top(totalTrans, 2)"
+ "}\n" + "}\n"
@ -90,7 +93,7 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id in $idSet1\n" + "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n" + "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n" + "R3: C.id in $idSet2\n"
+ "totalTrans = p1.amount + p2.amount + p3.amount\n" + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)" + "R2('取top2'): top(totalTrans, 3)"
+ "}\n" + "}\n"
+ "Action {\n" + "Action {\n"
@ -127,7 +130,7 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id == $idSet1\n" + "R1: A.id == $idSet1\n"
+ "R2: B.id == $idSet2\n" + "R2: B.id == $idSet2\n"
+ "R3: C.id == $idSet3\n" + "R3: C.id == $idSet3\n"
+ "totalTrans = p1.amount + p2.amount + p3.amount\n" + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)" + "R2('取top2'): top(totalTrans, 3)"
+ "}\n" + "}\n"
+ "Action {\n" + "Action {\n"
@ -169,9 +172,12 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id in $idSet1\n" + "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n" + "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n" + "R3: C.id in $idSet2\n"
+ "t1 = group(A,B,C).sum(p1.amount)\n" + "p1_amt = cast_type(p1.amount,'long')\n"
+ "t2 = group(A,B,C).sum(p2.amount)\n" + "p2_amt = cast_type(p2.amount,'long')\n"
+ "t3 = group(A,B,C).sum(p3.amount)\n" + "p3_amt = cast_type(p3.amount,'long')\n"
+ "t1 = group(A,B,C).sum(p1_amt)\n"
+ "t2 = group(A,B,C).sum(p2_amt)\n"
+ "t3 = group(A,B,C).sum(p3_amt)\n"
+ "totalSum = t1 + t2 + t3" + "totalSum = t1 + t2 + t3"
+ "}\n" + "}\n"
+ "Action {\n" + "Action {\n"

View File

@ -403,7 +403,7 @@ public class KgReasonerTopKFilmTest {
+ " o->star [starOfFilm] as sf2\n" + " o->star [starOfFilm] as sf2\n"
+ "}\n" + "}\n"
+ "Rule {\n" + "Rule {\n"
+ "total = sf.joinTs + sf2.joinTs\n" + "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n"
+ "R2: top(total, 1)\n" + "R2: top(total, 1)\n"
+ "}\n" + "}\n"
+ "Action {\n" + "Action {\n"

View File

@ -25,13 +25,7 @@ import com.antgroup.openspg.reasoner.kggraph.AggregationSchemaInfo;
import com.antgroup.openspg.reasoner.kggraph.KgGraph; import com.antgroup.openspg.reasoner.kggraph.KgGraph;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl; import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters; import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr; import com.antgroup.openspg.reasoner.lube.common.expr.*;
import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr;
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern; import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar; import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
import com.antgroup.openspg.reasoner.lube.logical.NodeVar; import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
@ -41,6 +35,7 @@ import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf; import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdafMeta; import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner; import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
@ -353,26 +348,14 @@ public class KgGraphAggregateImpl implements Serializable {
if (null != udafInitParams) { if (null != udafInitParams) {
udaf.initialize(udafInitParams); udaf.initialize(udafInitParams);
} }
ParsedAggEle parsedAggEle;
String sourceAlias = null;
String sourcePropertyName = null;
Set<String> aliasList = aggInfo.getExprUseAliasSet(); Set<String> aliasList = aggInfo.getExprUseAliasSet();
if (aliasList.size() <= 1) { if (aliasList.size() <= 1) {
Expr sourceExpr = aggInfo.getAggEle(); parsedAggEle = aggInfo.getParsedAggEle();
// aggregate by vertex subgraph
if (sourceExpr instanceof Ref) {
Ref sourceRef = (Ref) sourceExpr;
sourceAlias = sourceRef.refName();
} else if (sourceExpr instanceof UnaryOpExpr) {
UnaryOpExpr expr = (UnaryOpExpr) sourceExpr;
GetField getField = (GetField) expr.name();
sourceAlias = ((Ref) expr.arg()).refName();
sourcePropertyName = getField.fieldName();
}
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) { if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) { for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) { if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
StringBuffer sb = new StringBuffer(); StringBuilder sb = new StringBuilder();
for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) { for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) {
sb.append(valueFiltered2).append("## "); sb.append(valueFiltered2).append("## ");
} }
@ -381,19 +364,41 @@ public class KgGraphAggregateImpl implements Serializable {
} }
} }
} }
String finalSourcePropertyName = sourcePropertyName; String finalSourcePropertyName = parsedAggEle.getSourcePropertyName();
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) { for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
if (valueFiltered.getVertexAlias().contains(sourceAlias)) { if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) {
List<IVertex<IVertexId, IProperty>> vertexList = valueFiltered.getVertex(sourceAlias); List<IVertex<IVertexId, IProperty>> vertexList =
if (sourcePropertyName == null) { valueFiltered.getVertex(parsedAggEle.getSourceAlias());
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
vertexList.forEach(
vertex -> {
Map<String, Object> context =
RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias());
Object value =
RuleRunner.getInstance()
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
udaf.update(value);
});
} else if (finalSourcePropertyName == null) {
vertexList.forEach(udaf::update); vertexList.forEach(udaf::update);
} else { } else {
vertexList.forEach( vertexList.forEach(
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName)); v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
} }
} else { } else {
List<IEdge<IVertexId, IProperty>> edgeList = valueFiltered.getEdge(sourceAlias); List<IEdge<IVertexId, IProperty>> edgeList =
if (sourcePropertyName == null) { valueFiltered.getEdge(parsedAggEle.getSourceAlias());
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
edgeList.forEach(
edge -> {
Map<String, Object> context =
RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias());
Object value =
RuleRunner.getInstance()
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
udaf.update(value);
});
} else if (finalSourcePropertyName == null) {
edgeList.forEach(udaf::update); edgeList.forEach(udaf::update);
} else { } else {
edgeList.forEach( edgeList.forEach(

View File

@ -23,6 +23,7 @@ import java.io.Serializable;
import java.util.List; import java.util.List;
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable { public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
private final ParsedAggEle parsedAggEle;
/** /**
* constructor * constructor
@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
*/ */
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) { public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator); super(taskId, var, aggregator);
parsedAggEle = parsedAggEle();
} }
/** /**
@ -58,4 +60,9 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
public Expr getAggEle() { public Expr getAggEle() {
return getAggIfOpExpr().aggOpExpr().aggEleExpr(); return getAggIfOpExpr().aggOpExpr().aggEleExpr();
} }
/** Returns the aggregated-element breakdown that the constructor computed via parsedAggEle(). */
@Override
public ParsedAggEle getParsedAggEle() {
  return this.parsedAggEle;
}
} }

View File

@ -22,9 +22,11 @@ import java.io.Serializable;
import java.util.List; import java.util.List;
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable { public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
private final ParsedAggEle parsedAggEle;
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) { public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator); super(taskId, var, aggregator);
parsedAggEle = parsedAggEle();
} }
public AggOpExpr getAggOpExpr() { public AggOpExpr getAggOpExpr() {
@ -45,4 +47,9 @@ public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Se
public Expr getAggEle() { public Expr getAggEle() {
return getAggOpExpr().aggEleExpr(); return getAggOpExpr().aggEleExpr();
} }
/** Returns the aggregated-element breakdown that the constructor computed via parsedAggEle(). */
@Override
public ParsedAggEle getParsedAggEle() {
  return parsedAggEle;
}
} }

View File

@ -19,6 +19,9 @@ import com.antgroup.openspg.reasoner.lube.common.expr.AggUdf;
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator; import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet; import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr; import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar; import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
import com.antgroup.openspg.reasoner.lube.logical.Var; import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils; import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
@ -139,6 +142,9 @@ public abstract class BaseGroupProcess implements Serializable {
*/ */
public abstract Expr getAggEle(); public abstract Expr getAggEle();
/** get parsed agg ele */
public abstract ParsedAggEle getParsedAggEle();
public Set<String> parseExprUseAliasSet() { public Set<String> parseExprUseAliasSet() {
scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle()); scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList)); return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
@ -148,6 +154,26 @@ public abstract class BaseGroupProcess implements Serializable {
return WareHouseUtils.getRuleList(getAggEle()); return WareHouseUtils.getRuleList(getAggEle());
} }
/**
 * Breaks the aggregated element expression returned by {@link #getAggEle()} into a source alias,
 * an optional property name and an optional list of rule expression strings.
 *
 * <ul>
 *   <li>a bare {@code Ref}: only the alias is set;
 *   <li>a {@code UnaryOpExpr} whose operator is {@code GetField} and whose argument is a {@code
 *       Ref}: alias and property name are set;
 *   <li>any other expression that uses exactly one alias: the alias and the rule expression
 *       strings are set;
 *   <li>otherwise every component is {@code null}.
 * </ul>
 *
 * @return the parsed components; individual components may be {@code null}
 */
protected ParsedAggEle parsedAggEle() {
  Expr aggEle = getAggEle();
  if (aggEle instanceof Ref) {
    return new ParsedAggEle(((Ref) aggEle).refName(), null, null);
  }
  if (aggEle instanceof UnaryOpExpr) {
    UnaryOpExpr unary = (UnaryOpExpr) aggEle;
    // Only the "alias.property" shape is a plain property read. Any other unary expression
    // must not be cast blindly; it falls through to the expression-based path below.
    if (unary.name() instanceof GetField && unary.arg() instanceof Ref) {
      String alias = ((Ref) unary.arg()).refName();
      String propertyName = ((GetField) unary.name()).fieldName();
      return new ParsedAggEle(alias, propertyName, null);
    }
  }
  if (1 == this.exprUseAliasSet.size()) {
    String onlyAlias = this.exprUseAliasSet.iterator().next();
    List<String> ruleStrings = this.exprRuleString;
    return new ParsedAggEle(onlyAlias, null, ruleStrings);
  }
  return new ParsedAggEle(null, null, null);
}
/** /**
* getter * getter
* *

View File

@ -0,0 +1,40 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.rdg.common.groupProcess;
import java.util.List;
public class ParsedAggEle {
private final String sourceAlias;
private final String sourcePropertyName;
private final List<String> exprStrList;
public ParsedAggEle(String sourceAlias, String sourcePropertyName, List<String> exprStrList) {
this.sourceAlias = sourceAlias;
this.sourcePropertyName = sourcePropertyName;
this.exprStrList = exprStrList;
}
public String getSourceAlias() {
return sourceAlias;
}
public String getSourcePropertyName() {
return sourcePropertyName;
}
public List<String> getExprStrList() {
return exprStrList;
}
}