mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-06-27 03:20:10 +00:00
feat(reasoner): support value type inference in udf (#132)
Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
parent
eb2590aada
commit
258b0e7dfb
@ -13,6 +13,8 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.common.types
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
|
||||
trait KgType {
|
||||
def isNullable: Boolean = false
|
||||
}
|
||||
@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType
|
||||
* @param elementType
|
||||
*/
|
||||
final case class KTMultiVersion(elementType: KgType) extends KgType
|
||||
|
||||
object KgType {
|
||||
|
||||
def getNumberSeq(kgType: KgType): Int = {
|
||||
kgType match {
|
||||
case KTInteger => 1
|
||||
case KTLong => 2
|
||||
case KTDouble => 3
|
||||
case _ => throw UnsupportedOperationException(s"cannot support number type $kgType")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
IRProperty(s.alias, propertyName) ->
|
||||
ProjectRule(
|
||||
IRProperty(s.alias, propertyName),
|
||||
propertyType,
|
||||
Ref(ddlBlockWithNodes._3.target.alias)))))
|
||||
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
|
||||
case AddPredicate(predicate) =>
|
||||
@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
ProjectBlock(
|
||||
List.apply(opBlock),
|
||||
ProjectFields(Map.apply(lValueName ->
|
||||
ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain))))
|
||||
ProjectRule(lValueName, opChain))))
|
||||
}
|
||||
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
|
||||
ProjectBlock(
|
||||
@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
lValueName ->
|
||||
ProjectRule(
|
||||
lValueName,
|
||||
exprParser.parseRetType(opChain.curExpr),
|
||||
opChain.curExpr))))
|
||||
case _ => null
|
||||
}
|
||||
@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
List.empty)
|
||||
case _ =>
|
||||
rule match {
|
||||
case ProjectRule(_, lvalueType, _) =>
|
||||
val projectRule = ProjectRule(lvalueFiled, lvalueType, expr)
|
||||
case ProjectRule(_, _) =>
|
||||
val projectRule = ProjectRule(lvalueFiled, expr)
|
||||
ProjectBlock(
|
||||
List.apply(preBlock),
|
||||
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
|
||||
@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
|
||||
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
|
||||
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
|
||||
(ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true)
|
||||
(ProjectRule(IRVariable(defaultName), expr), columnName, true)
|
||||
}
|
||||
|
||||
def parseGraphStructure(
|
||||
@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
val defaultColumnName = parseExpr2ElementStr(expr)
|
||||
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
|
||||
(
|
||||
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
|
||||
ProjectRule(IRVariable(defaultColumnName), expr),
|
||||
columnName,
|
||||
false)
|
||||
}
|
||||
@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
|
||||
val defaultColumnName = parseExpr2ElementStr(expr)
|
||||
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
|
||||
(
|
||||
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
|
||||
ProjectRule(IRVariable(defaultColumnName), expr),
|
||||
columnName,
|
||||
false)
|
||||
}
|
||||
|
@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
def parseRetType(expr: Expr): KgType = {
|
||||
KTObject
|
||||
}
|
||||
|
||||
def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
|
||||
ctx.getChild(0) match {
|
||||
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
|
||||
@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
|
||||
if (ctx.property_name() != null) {
|
||||
ProjectRule(
|
||||
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
|
||||
parseRetType(expr),
|
||||
expr)
|
||||
} else {
|
||||
ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr)
|
||||
ProjectRule(IRVariable(ctx.identifier().getText), expr)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.constants.Constants
|
||||
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
|
||||
import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
|
||||
import com.antgroup.openspg.reasoner.lube.block._
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr._
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||
@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
print(block.pretty)
|
||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||
proj.projects.items.head._2 should equal(
|
||||
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
|
||||
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
|
||||
}
|
||||
|
||||
it("addproperies with constraint") {
|
||||
@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
print(block.pretty)
|
||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||
proj.projects.items.head._2 should equal(
|
||||
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
|
||||
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
|
||||
}
|
||||
|
||||
it("addproperies2") {
|
||||
@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||
proj.projects.items.head._2 should equal(
|
||||
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
|
||||
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
|
||||
}
|
||||
it("addproperies") {
|
||||
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
||||
@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||
proj.projects.items.head._2 should equal(
|
||||
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
|
||||
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
|
||||
}
|
||||
it("addNode") {
|
||||
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
||||
@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
.asInstanceOf[AddPredicate]
|
||||
.predicate
|
||||
.fields
|
||||
.keySet should contain ("same_domain_num")
|
||||
.keySet should contain("same_domain_num")
|
||||
|
||||
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
|
||||
blocks(1)
|
||||
@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
val block = parser.parse(dsl)
|
||||
print(block.pretty)
|
||||
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
|
||||
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o)))))
|
||||
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
|
||||
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
|
||||
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
|
||||
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
|
||||
@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
||||
val block = parser.parse(dsl)
|
||||
print(block.pretty)
|
||||
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
|
||||
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
|
||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
|
||||
|
@ -33,6 +33,10 @@
|
||||
<groupId>com.antgroup.openspg.reasoner</groupId>
|
||||
<artifactId>reasoner-common</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.antgroup.openspg.reasoner</groupId>
|
||||
<artifactId>reasoner-udf</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
|
@ -13,10 +13,12 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.lube.catalog
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
|
||||
import scala.collection.mutable
|
||||
import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}
|
||||
|
||||
|
||||
/**
|
||||
@ -27,6 +29,7 @@ import scala.collection.mutable
|
||||
*/
|
||||
abstract class Catalog() extends Serializable {
|
||||
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
|
||||
@transient private val udfRepo = UdfMngFactory.getUdfMng
|
||||
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
|
||||
|
||||
/**
|
||||
@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable {
|
||||
graphRepository.get(graphName).orNull
|
||||
}
|
||||
|
||||
def getUdfRepo: UdfMng = udfRepo
|
||||
|
||||
/**
|
||||
* Get schema from knowledge graph
|
||||
*/
|
||||
|
@ -13,7 +13,6 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.lube.common.rule
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.types.KgType
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRField
|
||||
|
||||
@ -39,13 +38,6 @@ trait Rule extends Cloneable{
|
||||
*/
|
||||
def getExpr: Expr
|
||||
|
||||
|
||||
/**
|
||||
* get lvalue type
|
||||
* @return
|
||||
*/
|
||||
def getLvalueType: KgType
|
||||
|
||||
/**
|
||||
* get dependencies
|
||||
* @return
|
||||
|
@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
|
||||
*/
|
||||
override def getExpr: Expr = expr
|
||||
|
||||
/**
|
||||
* get lvalue type
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
override def getLvalueType: KgType = KTBoolean
|
||||
|
||||
override def andRule(rule: Rule): Rule = {
|
||||
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
|
||||
|
||||
@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
|
||||
* @param lvalueType
|
||||
* @param expr
|
||||
*/
|
||||
final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
|
||||
final case class ProjectRule(output: IRField, expr: Expr)
|
||||
extends DependencyRule {
|
||||
|
||||
/**
|
||||
@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
|
||||
*/
|
||||
override def getExpr: Expr = expr
|
||||
|
||||
/**
|
||||
* get lvalue type
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
override def getLvalueType: KgType = lvalueType
|
||||
|
||||
override def andRule(rule: Rule): Rule = {
|
||||
throw UnsupportedOperationException("ProjectRule cannot support andRule")
|
||||
|
@ -138,7 +138,7 @@ object RuleUtils {
|
||||
case logicRule: LogicRule =>
|
||||
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
|
||||
case _ =>
|
||||
ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr)
|
||||
ProjectRule(IRVariable(ruleNameStr), expr)
|
||||
}
|
||||
val oldDependencies = rule.getDependencies
|
||||
if (oldDependencies != null) {
|
||||
@ -162,7 +162,7 @@ object RuleUtils {
|
||||
case logicRule: LogicRule =>
|
||||
LogicRule(rule.getName, logicRule.ruleExplain, expr)
|
||||
case _ =>
|
||||
ProjectRule(rule.getOutput, rule.getLvalueType, expr)
|
||||
ProjectRule(rule.getOutput, expr)
|
||||
}
|
||||
val oldDependencies = rule.getDependencies
|
||||
if (oldDependencies != null) {
|
||||
|
@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
|
||||
false)))),
|
||||
ProjectFields(
|
||||
Map.apply(IRVariable("total_domain_num") ->
|
||||
ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o")))))
|
||||
ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
|
||||
val p = BlockUtils.transBlock2Graph(block)
|
||||
p.size should equal(1)
|
||||
p.head.graphPattern.nodes.size should equal(2)
|
||||
@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
|
||||
it("rename_rule") {
|
||||
val rule = ProjectRule(
|
||||
IRVariable("a"),
|
||||
KTObject,
|
||||
BinaryOpExpr(
|
||||
BEqual,
|
||||
UnaryOpExpr(GetField("birthDate"), Ref("e")),
|
||||
@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
|
||||
it("variable_rule") {
|
||||
val rule = ProjectRule(
|
||||
IRVariable("a"),
|
||||
KTObject,
|
||||
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
|
||||
|
||||
val rule2 = ProjectRule(
|
||||
IRVariable("b"),
|
||||
KTObject,
|
||||
BinaryOpExpr(
|
||||
BEqual,
|
||||
UnaryOpExpr(GetField("attr1"), Ref("e")),
|
||||
@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
|
||||
def getDependenceRule(): Rule = {
|
||||
val r0 = ProjectRule(
|
||||
IRVariable("r0"),
|
||||
KTLong,
|
||||
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
|
||||
)
|
||||
val r1 = LogicRule(
|
||||
@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
|
||||
val r0 = LogicRule("tmp", "",
|
||||
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
|
||||
val r = ProjectRule(IRVariable("g"),
|
||||
KTLong,
|
||||
OpChainExpr(
|
||||
GraphAggregatorExpr(
|
||||
"unresolved_default_path",
|
||||
|
@ -13,9 +13,16 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.lube.logical
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
import com.antgroup.openspg.reasoner.common.trees.BottomUp
|
||||
import com.antgroup.openspg.reasoner.common.types._
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr._
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
|
||||
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMng
|
||||
|
||||
object ExprUtil {
|
||||
|
||||
@ -46,15 +53,13 @@ object ExprUtil {
|
||||
|
||||
}
|
||||
|
||||
|
||||
def needResolved(rule: Expr): Boolean = {
|
||||
!getReferProperties(rule).filter(_._1 == null).isEmpty
|
||||
}
|
||||
|
||||
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
|
||||
|
||||
def rewriter: PartialFunction[Expr, Expr] = {
|
||||
case Ref(refName) =>
|
||||
def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
|
||||
if (replaceVar.contains(refName)) {
|
||||
val propertyVar = replaceVar(refName)
|
||||
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
|
||||
@ -76,4 +81,104 @@ object ExprUtil {
|
||||
newRule
|
||||
}
|
||||
|
||||
def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
|
||||
expr match {
|
||||
case Ref(name) =>
|
||||
if (referVars.contains(IRVariable(name))) {
|
||||
referVars(IRVariable(name))
|
||||
} else {
|
||||
KTObject
|
||||
}
|
||||
case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name))
|
||||
case BinaryOpExpr(name, l, r) =>
|
||||
name match {
|
||||
case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |
|
||||
BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign =>
|
||||
KTBoolean
|
||||
case BAdd | BSub | BMul | BDiv | BMod =>
|
||||
val left = getTargetType(l, referVars, udfRepo)
|
||||
val right = getTargetType(r, referVars, udfRepo)
|
||||
getUpperType(left, right)
|
||||
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||
}
|
||||
case UnaryOpExpr(name, arg) =>
|
||||
name match {
|
||||
case Not | Exists => KTBoolean
|
||||
case Abs | Neg => getTargetType(arg, referVars, udfRepo)
|
||||
case Floor | Ceil => KTDouble
|
||||
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||
}
|
||||
case FunctionExpr(name, funcArgs) =>
|
||||
val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
|
||||
name match {
|
||||
case "rule_value" => types(1)
|
||||
case "cast_type" | "Cast" =>
|
||||
funcArgs(1).asInstanceOf[VString].value match {
|
||||
case "int" | "bigint" | "long" => KTLong
|
||||
case "float" | "double" => KTDouble
|
||||
case "varchar" | "string" => KTString
|
||||
case _ =>
|
||||
throw UnsupportedOperationException(s"cannot support ${name} to ${funcArgs(1)}")
|
||||
}
|
||||
case _ =>
|
||||
val udf = udfRepo.getUdfMeta(name, types.asJava)
|
||||
if (udf != null) {
|
||||
udf.getResultType
|
||||
} else {
|
||||
throw UnsupportedOperationException(s"cannot find UDF: ${name}")
|
||||
}
|
||||
}
|
||||
|
||||
case AggOpExpr(name, args) =>
|
||||
name match {
|
||||
case Min | Max | Sum | Avg | First | Accumulate(_) =>
|
||||
getTargetType(args, referVars, udfRepo)
|
||||
case StrJoin(_) => KTString
|
||||
case Count => KTLong
|
||||
case AggUdf(name, _) =>
|
||||
val types = getTargetType(args.head, referVars, udfRepo)
|
||||
val udf = udfRepo.getUdafMeta(name, types)
|
||||
if (udf != null) {
|
||||
udf.getResultType
|
||||
} else {
|
||||
throw UnsupportedOperationException(s"cannot find UDAF ${name}")
|
||||
}
|
||||
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||
}
|
||||
case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo)
|
||||
case ListOpExpr(name, _) =>
|
||||
name match {
|
||||
case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo)
|
||||
case Constraint(_, _, _) => KTBoolean
|
||||
case Get(_) | Slice(_, _) => KTObject
|
||||
}
|
||||
case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo)
|
||||
case VNull | VString(_) => KTString
|
||||
case VLong(_) => KTLong
|
||||
case VDouble(_) => KTDouble
|
||||
case VBoolean(_) => KTBoolean
|
||||
case VList(_, listType) => KTList(listType)
|
||||
case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}")
|
||||
}
|
||||
}
|
||||
|
||||
def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
|
||||
val newReferVars = new mutable.HashMap[IRField, KgType]
|
||||
newReferVars.++=(referVars)
|
||||
for (r <- rule.getDependencies) {
|
||||
newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo))
|
||||
}
|
||||
getTargetType(rule.getExpr, newReferVars.toMap, udfRepo)
|
||||
}
|
||||
|
||||
private def getUpperType(left: KgType, right: KgType): KgType = {
|
||||
val l = KgType.getNumberSeq(left)
|
||||
val r = KgType.getNumberSeq(right)
|
||||
if (l >= r) {
|
||||
left
|
||||
} else {
|
||||
right
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
import com.antgroup.openspg.reasoner.common.types.KTObject
|
||||
import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject}
|
||||
import com.antgroup.openspg.reasoner.lube.block.Aggregations
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
|
||||
@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL
|
||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
||||
class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
||||
class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit
|
||||
context: LogicalPlannerContext) {
|
||||
|
||||
def plan(dependency: LogicalOperator): LogicalOperator = {
|
||||
val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
|
||||
@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
||||
v
|
||||
}
|
||||
})
|
||||
|
||||
val referTypes = new mutable.HashMap[IRField, KgType]()
|
||||
for (v <- ruleFields) {
|
||||
resolved.getVar(v.name) match {
|
||||
case p: PropertyVar => referTypes.put(v, p.field.kgType)
|
||||
case node: NodeVar =>
|
||||
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||
case edge: EdgeVar =>
|
||||
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||
case _ => throw UnsupportedOperationException(s"cannot support $v")
|
||||
}
|
||||
}
|
||||
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
|
||||
val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo)
|
||||
|
||||
val renameVar = ruleFields
|
||||
.filter(_.isInstanceOf[IRVariable])
|
||||
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
|
||||
@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
||||
val field = getAggregateTarget(referFields, resolved, dependency)
|
||||
field match {
|
||||
case IRNode(alias, _) =>
|
||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
|
||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
|
||||
aggMap.put(propertyVar, newAggExpr)
|
||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||
case IREdge(alias, _) =>
|
||||
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
|
||||
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
|
||||
} else {
|
||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
|
||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
|
||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||
aggMap.put(propertyVar, newAggExpr)
|
||||
}
|
||||
case IRVariable(alias) =>
|
||||
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
|
||||
val propertyVar =
|
||||
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true))
|
||||
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true))
|
||||
aggMap.put(propertyVar, newAggExpr)
|
||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||
case _ =>
|
||||
|
@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||
import com.antgroup.openspg.reasoner.common.types.KTObject
|
||||
import com.antgroup.openspg.reasoner.common.types.KgType
|
||||
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||
import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var}
|
||||
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
||||
import com.antgroup.openspg.reasoner.lube.logical._
|
||||
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
|
||||
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
|
||||
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
||||
class ProjectPlanner(projects: ProjectFields) {
|
||||
class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
|
||||
|
||||
def plan(dependency: LogicalOperator): LogicalOperator = {
|
||||
val projectMap = new mutable.HashMap[Var, Expr]()
|
||||
@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) {
|
||||
v
|
||||
}
|
||||
})
|
||||
val propertyVar = getTarget(rule._1, referVars, resolved, dependency)
|
||||
val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency)
|
||||
val transformer = new Rule2ExprTransformer()
|
||||
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
|
||||
val replaceVar = reference
|
||||
@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) {
|
||||
private def getTarget(
|
||||
left: IRField,
|
||||
referVars: List[IRField],
|
||||
rule: Rule,
|
||||
resolved: SolvedModel,
|
||||
dependency: LogicalOperator): PropertyVar = {
|
||||
val referTypes = new mutable.HashMap[IRField, KgType]()
|
||||
for (v <- referVars) {
|
||||
resolved.getVar(v.name) match {
|
||||
case p: PropertyVar => referTypes.put(v, p.field.kgType)
|
||||
case node: NodeVar =>
|
||||
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||
case edge: EdgeVar =>
|
||||
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
|
||||
val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo)
|
||||
|
||||
left match {
|
||||
case IRVariable(name) =>
|
||||
if (referVars.size == 1) {
|
||||
PropertyVar(referVars.head.name, new Field(name, KTObject, true))
|
||||
PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
|
||||
} else {
|
||||
val aliasSet = new mutable.HashSet[String]()
|
||||
for (rVar <- referVars) {
|
||||
@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) {
|
||||
}
|
||||
}
|
||||
val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
|
||||
PropertyVar(targetAlias, new Field(left.name, KTObject, true))
|
||||
PropertyVar(targetAlias, new Field(left.name, ruleRetType, true))
|
||||
}
|
||||
case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true))
|
||||
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
|
||||
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
|
||||
case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true))
|
||||
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
|
||||
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
|
||||
case _ => throw UnsupportedOperationException(s"cannot support $left")
|
||||
}
|
||||
|
||||
|
@ -13,11 +13,13 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.lube.logical
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.types.KTInteger
|
||||
import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
|
||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
|
||||
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
|
||||
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
|
||||
import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMngFactory
|
||||
import org.scalatest.funspec.AnyFunSpec
|
||||
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
|
||||
|
||||
@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec {
|
||||
val replaceMap = Map.apply(
|
||||
"a" -> PropertyVar("b", new Field("id", KTInteger, true)),
|
||||
"c" -> PropertyVar("d", new Field("id", KTInteger, true)))
|
||||
val r1 = ProjectRule(IRVariable("test"),
|
||||
KTInteger, Ref("a"))
|
||||
val r2 = ProjectRule(IRVariable("test"),
|
||||
KTInteger, Ref("c"))
|
||||
val r1 = ProjectRule(IRVariable("test"), Ref("a"))
|
||||
val r2 = ProjectRule(IRVariable("test"), Ref("c"))
|
||||
r1.addDependency(r2)
|
||||
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
|
||||
print(newRule.getExpr.pretty)
|
||||
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
|
||||
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
|
||||
newRule.getExpr.asInstanceOf[UnaryOpExpr]
|
||||
.arg.asInstanceOf[Ref].refName should equal("b")
|
||||
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b")
|
||||
}
|
||||
|
||||
it("test expr output type") {
|
||||
val parser = new RuleExprParser()
|
||||
val udfRepo = UdfMngFactory.getUdfMng
|
||||
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
|
||||
|
||||
var expr: Expr = parser.parse("A.age")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
|
||||
|
||||
expr = parser.parse("A.age + 1")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTLong)
|
||||
|
||||
expr = parser.parse("A.age > 10")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTBoolean)
|
||||
|
||||
expr = parser.parse("floor(A.age)")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTDouble)
|
||||
|
||||
expr = parser.parse("abs(A.age)")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
|
||||
|
||||
expr = parser.parse("concat(A.age, \",\")")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
|
||||
|
||||
expr = parser.parse("cast_type(A.age, 'string')")
|
||||
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
|
||||
}
|
||||
|
||||
it("test rule output type") {
|
||||
val parser = new RuleExprParser()
|
||||
val udfRepo = UdfMngFactory.getUdfMng
|
||||
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
|
||||
|
||||
val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10"))
|
||||
val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age"))
|
||||
rule.addDependency(r1)
|
||||
|
||||
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
|
||||
}
|
||||
}
|
||||
|
@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec {
|
||||
println(optimizedLogicalPlan.pretty)
|
||||
}
|
||||
|
||||
it("online") {
|
||||
val dsl = """Define (s:User)-[p:redPacket]->(o:Int) {
|
||||
| GraphStructure {
|
||||
| (s)
|
||||
| }
|
||||
| Rule {
|
||||
|LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m
|
||||
|Latest30DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|Latest7DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|LatestTTT = Latest7DayPayCount.accumulate(+)
|
||||
|LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0
|
||||
|Latest7DayPayCountSum=Latest7DayPayCount
|
||||
|Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0
|
||||
|HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount
|
||||
|HighLost("高频降频100%"):HighReduceValue == 1
|
||||
|HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1
|
||||
|HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8
|
||||
|HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5
|
||||
|HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3
|
||||
|Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|Latest30DayPayDayCount=size(Latest30DayPayCount)
|
||||
|Latest3060DayPayDayCount=size(Latest3060DayPayCount)
|
||||
|High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13
|
||||
|High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13
|
||||
|Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
||||
|Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
||||
|Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
||||
|Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|
||||
|Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|
||||
|Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|Latest6090DayPayDayCount=size(Latest6090DayPayCount)
|
||||
|Latest60DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|Latest60DayPayDayCount=size(Latest60DayPayCount)
|
||||
|Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0
|
||||
|Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0
|
||||
|HistoricallyPay=s.ngfe_tag__pay_cnt_total
|
||||
|HistoricallyPayCount=size(HistoricallyPay)
|
||||
|New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0
|
||||
|Latest90DayPayCount=s.ngfe_tag__pay_cnt_d
|
||||
|Latest90DayPayDayCount=size(Latest90DayPayCount)
|
||||
|Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0
|
||||
|o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost"))
|
||||
| }
|
||||
|}""".stripMargin
|
||||
val parser = new OpenSPGDslParser()
|
||||
val block = parser.parse(dsl)
|
||||
println(block.pretty)
|
||||
val schema: Map[String, Set[String]] = Map.apply(
|
||||
"User" -> Set
|
||||
.apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d"))
|
||||
val catalog = new PropertyGraphCatalog(schema)
|
||||
catalog.init()
|
||||
implicit val context: LogicalPlannerContext =
|
||||
LogicalPlannerContext(catalog, parser, Map.empty)
|
||||
val logicalPlan = LogicalPlanner.plan(block)
|
||||
println(logicalPlan.head.pretty)
|
||||
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head)
|
||||
println(optimizedLogicalPlan.pretty)
|
||||
}
|
||||
|
||||
it("test start flag") {
|
||||
val dsl =
|
||||
"""
|
||||
@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec {
|
||||
| (s)-[p2:followPM]->(o)
|
||||
|}
|
||||
|Rule {
|
||||
| c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0)
|
||||
| c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0)
|
||||
|
|
||||
|}
|
||||
|Action {
|
||||
|
@ -41,7 +41,7 @@
|
||||
<properties>
|
||||
<antlr4.version>4.8</antlr4.version>
|
||||
<fastjson.version>1.2.71_noneautotype</fastjson.version>
|
||||
<geotools.version>27.0</geotools.version>
|
||||
<geotools.version>28.0</geotools.version>
|
||||
<google.s2.version>2.0.0</google.s2.version>
|
||||
<hadoop.version>2.7.2</hadoop.version>
|
||||
<hive.version>3.1.0</hive.version>
|
||||
|
@ -137,7 +137,7 @@ public class KgReasonerABMLocalTest {
|
||||
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
||||
+ " }\n"
|
||||
+ " Rule {\n"
|
||||
+ " \tv = t.stock/t.total\n"
|
||||
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
|
||||
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
||||
+ " o = v\n"
|
||||
+ " }\n"
|
||||
@ -197,7 +197,7 @@ public class KgReasonerABMLocalTest {
|
||||
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
||||
+ " }\n"
|
||||
+ " Rule {\n"
|
||||
+ " \tv = t.stock/t.total\n"
|
||||
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
|
||||
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
||||
+ " o = v\n"
|
||||
+ " }\n"
|
||||
|
@ -50,9 +50,12 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
+ "R1: A.id == $idSet1\n"
|
||||
+ "R2: B.id in $idSet2\n"
|
||||
+ "R3: C.id in $idSet2\n"
|
||||
+ "totalTrans1 = group(A,B,C).sum(p1.amount)\n"
|
||||
+ "totalTrans2 = group(A,B,C).sum(p2.amount)\n"
|
||||
+ "totalTrans3 = group(A,B,C).sum(p3.amount)\n"
|
||||
+ "p1_amt = cast_type(p1.amount,'long')\n"
|
||||
+ "p2_amt = cast_type(p2.amount,'long')\n"
|
||||
+ "p3_amt = cast_type(p3.amount,'long')\n"
|
||||
+ "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
|
||||
+ "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
|
||||
+ "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
|
||||
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
|
||||
+ "R2('取top2'): top(totalTrans, 2)"
|
||||
+ "}\n"
|
||||
@ -90,7 +93,7 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
+ "R1: A.id in $idSet1\n"
|
||||
+ "R2: B.id in $idSet2\n"
|
||||
+ "R3: C.id in $idSet2\n"
|
||||
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
|
||||
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
|
||||
+ "R2('取top2'): top(totalTrans, 3)"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
@ -127,7 +130,7 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
+ "R1: A.id == $idSet1\n"
|
||||
+ "R2: B.id == $idSet2\n"
|
||||
+ "R3: C.id == $idSet3\n"
|
||||
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
|
||||
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
|
||||
+ "R2('取top2'): top(totalTrans, 3)"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
@ -169,9 +172,12 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
+ "R1: A.id in $idSet1\n"
|
||||
+ "R2: B.id in $idSet2\n"
|
||||
+ "R3: C.id in $idSet2\n"
|
||||
+ "t1 = group(A,B,C).sum(p1.amount)\n"
|
||||
+ "t2 = group(A,B,C).sum(p2.amount)\n"
|
||||
+ "t3 = group(A,B,C).sum(p3.amount)\n"
|
||||
+ "p1_amt = cast_type(p1.amount,'long')\n"
|
||||
+ "p2_amt = cast_type(p2.amount,'long')\n"
|
||||
+ "p3_amt = cast_type(p3.amount,'long')\n"
|
||||
+ "t1 = group(A,B,C).sum(p1_amt)\n"
|
||||
+ "t2 = group(A,B,C).sum(p2_amt)\n"
|
||||
+ "t3 = group(A,B,C).sum(p3_amt)\n"
|
||||
+ "totalSum = t1 + t2 + t3"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
|
@ -403,7 +403,7 @@ public class KgReasonerTopKFilmTest {
|
||||
+ " o->star [starOfFilm] as sf2\n"
|
||||
+ "}\n"
|
||||
+ "Rule {\n"
|
||||
+ "total = sf.joinTs + sf2.joinTs\n"
|
||||
+ "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n"
|
||||
+ "R2: top(total, 1)\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
|
@ -25,13 +25,7 @@ import com.antgroup.openspg.reasoner.kggraph.AggregationSchemaInfo;
|
||||
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
|
||||
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
|
||||
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.*;
|
||||
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
|
||||
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
|
||||
import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
|
||||
@ -41,6 +35,7 @@ import com.antgroup.openspg.reasoner.lube.logical.Var;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
|
||||
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
|
||||
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
|
||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
||||
@ -353,26 +348,14 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
if (null != udafInitParams) {
|
||||
udaf.initialize(udafInitParams);
|
||||
}
|
||||
|
||||
String sourceAlias = null;
|
||||
String sourcePropertyName = null;
|
||||
ParsedAggEle parsedAggEle;
|
||||
Set<String> aliasList = aggInfo.getExprUseAliasSet();
|
||||
if (aliasList.size() <= 1) {
|
||||
Expr sourceExpr = aggInfo.getAggEle();
|
||||
// aggregate by vertex subgraph
|
||||
if (sourceExpr instanceof Ref) {
|
||||
Ref sourceRef = (Ref) sourceExpr;
|
||||
sourceAlias = sourceRef.refName();
|
||||
} else if (sourceExpr instanceof UnaryOpExpr) {
|
||||
UnaryOpExpr expr = (UnaryOpExpr) sourceExpr;
|
||||
GetField getField = (GetField) expr.name();
|
||||
sourceAlias = ((Ref) expr.arg()).refName();
|
||||
sourcePropertyName = getField.fieldName();
|
||||
}
|
||||
parsedAggEle = aggInfo.getParsedAggEle();
|
||||
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
|
||||
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
||||
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
|
||||
StringBuffer sb = new StringBuffer();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) {
|
||||
sb.append(valueFiltered2).append("## ");
|
||||
}
|
||||
@ -381,19 +364,41 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
}
|
||||
}
|
||||
}
|
||||
String finalSourcePropertyName = sourcePropertyName;
|
||||
String finalSourcePropertyName = parsedAggEle.getSourcePropertyName();
|
||||
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
||||
if (valueFiltered.getVertexAlias().contains(sourceAlias)) {
|
||||
List<IVertex<IVertexId, IProperty>> vertexList = valueFiltered.getVertex(sourceAlias);
|
||||
if (sourcePropertyName == null) {
|
||||
if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) {
|
||||
List<IVertex<IVertexId, IProperty>> vertexList =
|
||||
valueFiltered.getVertex(parsedAggEle.getSourceAlias());
|
||||
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
|
||||
vertexList.forEach(
|
||||
vertex -> {
|
||||
Map<String, Object> context =
|
||||
RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias());
|
||||
Object value =
|
||||
RuleRunner.getInstance()
|
||||
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
|
||||
udaf.update(value);
|
||||
});
|
||||
} else if (finalSourcePropertyName == null) {
|
||||
vertexList.forEach(udaf::update);
|
||||
} else {
|
||||
vertexList.forEach(
|
||||
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
|
||||
}
|
||||
} else {
|
||||
List<IEdge<IVertexId, IProperty>> edgeList = valueFiltered.getEdge(sourceAlias);
|
||||
if (sourcePropertyName == null) {
|
||||
List<IEdge<IVertexId, IProperty>> edgeList =
|
||||
valueFiltered.getEdge(parsedAggEle.getSourceAlias());
|
||||
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
|
||||
edgeList.forEach(
|
||||
edge -> {
|
||||
Map<String, Object> context =
|
||||
RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias());
|
||||
Object value =
|
||||
RuleRunner.getInstance()
|
||||
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
|
||||
udaf.update(value);
|
||||
});
|
||||
} else if (finalSourcePropertyName == null) {
|
||||
edgeList.forEach(udaf::update);
|
||||
} else {
|
||||
edgeList.forEach(
|
||||
|
@ -23,6 +23,7 @@ import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
||||
private final ParsedAggEle parsedAggEle;
|
||||
|
||||
/**
|
||||
* constructor
|
||||
@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
|
||||
*/
|
||||
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
||||
super(taskId, var, aggregator);
|
||||
parsedAggEle = parsedAggEle();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -58,4 +60,9 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
|
||||
public Expr getAggEle() {
|
||||
return getAggIfOpExpr().aggOpExpr().aggEleExpr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParsedAggEle getParsedAggEle() {
|
||||
return parsedAggEle;
|
||||
}
|
||||
}
|
||||
|
@ -22,9 +22,11 @@ import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
||||
private final ParsedAggEle parsedAggEle;
|
||||
|
||||
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
||||
super(taskId, var, aggregator);
|
||||
parsedAggEle = parsedAggEle();
|
||||
}
|
||||
|
||||
public AggOpExpr getAggOpExpr() {
|
||||
@ -45,4 +47,9 @@ public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Se
|
||||
public Expr getAggEle() {
|
||||
return getAggOpExpr().aggEleExpr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParsedAggEle getParsedAggEle() {
|
||||
return this.parsedAggEle;
|
||||
}
|
||||
}
|
||||
|
@ -19,6 +19,9 @@ import com.antgroup.openspg.reasoner.lube.common.expr.AggUdf;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
|
||||
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
|
||||
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
|
||||
import com.antgroup.openspg.reasoner.lube.logical.Var;
|
||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
|
||||
@ -139,6 +142,9 @@ public abstract class BaseGroupProcess implements Serializable {
|
||||
*/
|
||||
public abstract Expr getAggEle();
|
||||
|
||||
/** get parsed agg ele */
|
||||
public abstract ParsedAggEle getParsedAggEle();
|
||||
|
||||
public Set<String> parseExprUseAliasSet() {
|
||||
scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
|
||||
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
|
||||
@ -148,6 +154,26 @@ public abstract class BaseGroupProcess implements Serializable {
|
||||
return WareHouseUtils.getRuleList(getAggEle());
|
||||
}
|
||||
|
||||
protected ParsedAggEle parsedAggEle() {
|
||||
String sourceAlias = null;
|
||||
String sourcePropertyName = null;
|
||||
List<String> exprStrList = null;
|
||||
Expr aggEle = getAggEle();
|
||||
if (aggEle instanceof Ref) {
|
||||
Ref sourceRef = (Ref) aggEle;
|
||||
sourceAlias = sourceRef.refName();
|
||||
} else if (aggEle instanceof UnaryOpExpr) {
|
||||
UnaryOpExpr expr = (UnaryOpExpr) aggEle;
|
||||
GetField getField = (GetField) expr.name();
|
||||
sourceAlias = ((Ref) expr.arg()).refName();
|
||||
sourcePropertyName = getField.fieldName();
|
||||
} else if (1 == this.exprUseAliasSet.size()) {
|
||||
sourceAlias = this.exprUseAliasSet.iterator().next();
|
||||
exprStrList = this.exprRuleString;
|
||||
}
|
||||
return new ParsedAggEle(sourceAlias, sourcePropertyName, exprStrList);
|
||||
}
|
||||
|
||||
/**
|
||||
* getter
|
||||
*
|
||||
|
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2023 OpenSPG Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
* or implied.
|
||||
*/
|
||||
|
||||
package com.antgroup.openspg.reasoner.rdg.common.groupProcess;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class ParsedAggEle {
|
||||
private final String sourceAlias;
|
||||
private final String sourcePropertyName;
|
||||
private final List<String> exprStrList;
|
||||
|
||||
public ParsedAggEle(String sourceAlias, String sourcePropertyName, List<String> exprStrList) {
|
||||
this.sourceAlias = sourceAlias;
|
||||
this.sourcePropertyName = sourcePropertyName;
|
||||
this.exprStrList = exprStrList;
|
||||
}
|
||||
|
||||
public String getSourceAlias() {
|
||||
return sourceAlias;
|
||||
}
|
||||
|
||||
public String getSourcePropertyName() {
|
||||
return sourcePropertyName;
|
||||
}
|
||||
|
||||
public List<String> getExprStrList() {
|
||||
return exprStrList;
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user