feat(reasoner): support value type inference in udf (#132)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-03-05 11:59:50 +08:00 committed by GitHub
parent eb2590aada
commit 258b0e7dfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 376 additions and 182 deletions

View File

@ -13,6 +13,8 @@
package com.antgroup.openspg.reasoner.common.types
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
trait KgType {
def isNullable: Boolean = false
}
@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType
* @param elementType
*/
final case class KTMultiVersion(elementType: KgType) extends KgType
object KgType {

  /**
   * Rank a numeric type by its width so that two numeric types can be compared.
   * A larger value means a wider type.
   *
   * @param kgType the type to rank
   * @return 1 for `KTInteger`, 2 for `KTLong`, 3 for `KTDouble`
   * @throws UnsupportedOperationException if `kgType` is not one of the numeric types above
   */
  def getNumberSeq(kgType: KgType): Int = kgType match {
    case KTDouble => 3
    case KTLong => 2
    case KTInteger => 1
    case _ =>
      throw UnsupportedOperationException(s"cannot support number type $kgType")
  }

}

View File

@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
IRProperty(s.alias, propertyName) ->
ProjectRule(
IRProperty(s.alias, propertyName),
propertyType,
Ref(ddlBlockWithNodes._3.target.alias)))))
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
case AddPredicate(predicate) =>
@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
ProjectBlock(
List.apply(opBlock),
ProjectFields(Map.apply(lValueName ->
ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain))))
ProjectRule(lValueName, opChain))))
}
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
ProjectBlock(
@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
lValueName ->
ProjectRule(
lValueName,
exprParser.parseRetType(opChain.curExpr),
opChain.curExpr))))
case _ => null
}
@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
List.empty)
case _ =>
rule match {
case ProjectRule(_, lvalueType, _) =>
val projectRule = ProjectRule(lvalueFiled, lvalueType, expr)
case ProjectRule(_, _) =>
val projectRule = ProjectRule(lvalueFiled, expr)
ProjectBlock(
List.apply(preBlock),
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
(ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true)
(ProjectRule(IRVariable(defaultName), expr), columnName, true)
}
def parseGraphStructure(
@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
(
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}
@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
(
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}

View File

@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
}
}
def parseRetType(expr: Expr): KgType = {
KTObject
}
def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
ctx.getChild(0) match {
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
if (ctx.property_name() != null) {
ProjectRule(
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
parseRetType(expr),
expr)
} else {
ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr)
ProjectRule(IRVariable(ctx.identifier().getText), expr)
}
}

View File

@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph._
@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}
it("addproperies with constraint") {
@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}
it("addproperies2") {
@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addproperies") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addNode") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
.asInstanceOf[AddPredicate]
.predicate
.fields
.keySet should contain ("same_domain_num")
.keySet should contain("same_domain_num")
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
blocks(1)
@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o)))))
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))

View File

@ -33,6 +33,10 @@
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-common</artifactId>
</dependency>
<dependency>
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-udf</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>

View File

@ -13,10 +13,12 @@
package com.antgroup.openspg.reasoner.lube.catalog
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
import scala.collection.mutable
import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}
/**
@ -27,6 +29,7 @@ import scala.collection.mutable
*/
abstract class Catalog() extends Serializable {
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
// Transient fields are skipped by Java serialization and a plain `val` is not re-initialized on
// deserialization, which would leave this reference null. Declaring it lazy makes the UDF manager
// be fetched again on first access after the catalog has been deserialized.
@transient private lazy val udfRepo: UdfMng = UdfMngFactory.getUdfMng
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
/**
@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable {
graphRepository.get(graphName).orNull
}
/**
 * Get the UDF manager held by this catalog.
 *
 * @return the `UdfMng` instance obtained from `UdfMngFactory`
 */
def getUdfRepo: UdfMng = udfRepo
/**
* Get schema from knowledge graph
*/

View File

@ -13,7 +13,6 @@
package com.antgroup.openspg.reasoner.lube.common.rule
import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.IRField
@ -39,13 +38,6 @@ trait Rule extends Cloneable{
*/
def getExpr: Expr
/**
* get lvalue type
* @return
*/
def getLvalueType: KgType
/**
* get dependencies
* @return

View File

@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
*/
override def getExpr: Expr = expr
/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = KTBoolean
override def andRule(rule: Rule): Rule = {
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
* @param lvalueType
* @param expr
*/
final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
final case class ProjectRule(output: IRField, expr: Expr)
extends DependencyRule {
/**
@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
*/
override def getExpr: Expr = expr
/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = lvalueType
override def andRule(rule: Rule): Rule = {
throw UnsupportedOperationException("ProjectRule cannot support andRule")

View File

@ -138,7 +138,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
case _ =>
ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr)
ProjectRule(IRVariable(ruleNameStr), expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {
@ -162,7 +162,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(rule.getName, logicRule.ruleExplain, expr)
case _ =>
ProjectRule(rule.getOutput, rule.getLvalueType, expr)
ProjectRule(rule.getOutput, expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {

View File

@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
false)))),
ProjectFields(
Map.apply(IRVariable("total_domain_num") ->
ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o")))))
ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
val p = BlockUtils.transBlock2Graph(block)
p.size should equal(1)
p.head.graphPattern.nodes.size should equal(2)
@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
it("rename_rule") {
val rule = ProjectRule(
IRVariable("a"),
KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("birthDate"), Ref("e")),
@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
it("variable_rule") {
val rule = ProjectRule(
IRVariable("a"),
KTObject,
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
val rule2 = ProjectRule(
IRVariable("b"),
KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("attr1"), Ref("e")),
@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
def getDependenceRule(): Rule = {
val r0 = ProjectRule(
IRVariable("r0"),
KTLong,
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
)
val r1 = LogicRule(
@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
val r0 = LogicRule("tmp", "",
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
val r = ProjectRule(IRVariable("g"),
KTLong,
OpChainExpr(
GraphAggregatorExpr(
"unresolved_default_path",

View File

@ -13,9 +13,16 @@
package com.antgroup.openspg.reasoner.lube.logical
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.trees.BottomUp
import com.antgroup.openspg.reasoner.common.types._
import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
import com.antgroup.openspg.reasoner.udf.UdfMng
object ExprUtil {
@ -46,15 +53,13 @@ object ExprUtil {
}
def needResolved(rule: Expr): Boolean = {
!getReferProperties(rule).filter(_._1 == null).isEmpty
}
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
def rewriter: PartialFunction[Expr, Expr] = {
case Ref(refName) =>
def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
if (replaceVar.contains(refName)) {
val propertyVar = replaceVar(refName)
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
@ -76,4 +81,104 @@ object ExprUtil {
newRule
}
/**
 * Infer the value type produced by an expression.
 *
 * @param expr the expression to analyse
 * @param referVars known types of the variables and properties the expression may refer to
 * @param udfRepo UDF manager used to resolve the result type of UDF and UDAF calls
 * @return the inferred type; `KTObject` when a referenced variable or property has no known type
 * @throws UnsupportedOperationException when an operator, cast target, UDF or UDAF cannot be typed
 */
def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
  expr match {
    case Ref(name) =>
      referVars.getOrElse(IRVariable(name), KTObject)
    case UnaryOpExpr(GetField(name), Ref(alis)) =>
      // Fall back to KTObject like the plain Ref case instead of failing with a bare
      // "key not found" when the property type is unknown.
      referVars.getOrElse(IRProperty(alis, name), KTObject)
    case BinaryOpExpr(name, l, r) =>
      name match {
        case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |
            BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign =>
          KTBoolean
        case BAdd | BSub | BMul | BDiv | BMod =>
          // Arithmetic yields the wider of the two operand types.
          val left = getTargetType(l, referVars, udfRepo)
          val right = getTargetType(r, referVars, udfRepo)
          getUpperType(left, right)
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case UnaryOpExpr(name, arg) =>
      name match {
        case Not | Exists => KTBoolean
        case Abs | Neg => getTargetType(arg, referVars, udfRepo)
        case Floor | Ceil => KTDouble
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case FunctionExpr(name, funcArgs) =>
      val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
      name match {
        case "rule_value" => types(1)
        case "cast_type" | "Cast" =>
          // Match on the argument instead of casting it, so a non-string cast target is reported
          // as an unsupported operation rather than a ClassCastException.
          funcArgs(1) match {
            case VString("int" | "bigint" | "long") => KTLong
            case VString("float" | "double") => KTDouble
            case VString("varchar" | "string") => KTString
            case _ =>
              throw UnsupportedOperationException(s"cannot support ${name} to ${funcArgs(1)}")
          }
        case _ =>
          val udf = udfRepo.getUdfMeta(name, types.asJava)
          if (udf != null) {
            udf.getResultType
          } else {
            throw UnsupportedOperationException(s"cannot find UDF: ${name}")
          }
      }
    case AggOpExpr(name, args) =>
      name match {
        case Min | Max | Sum | Avg | First | Accumulate(_) =>
          getTargetType(args, referVars, udfRepo)
        case StrJoin(_) => KTString
        case Count => KTLong
        case AggUdf(udafName, _) =>
          val types = getTargetType(args.head, referVars, udfRepo)
          val udf = udfRepo.getUdafMeta(udafName, types)
          if (udf != null) {
            udf.getResultType
          } else {
            throw UnsupportedOperationException(s"cannot find UDAF ${udafName}")
          }
        case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
      }
    case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo)
    case ListOpExpr(name, _) =>
      name match {
        case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo)
        case Constraint(_, _, _) => KTBoolean
        case Get(_) | Slice(_, _) => KTObject
      }
    case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo)
    case VNull | VString(_) => KTString
    case VLong(_) => KTLong
    case VDouble(_) => KTDouble
    case VBoolean(_) => KTBoolean
    case VList(_, listType) => KTList(listType)
    case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}")
  }
}
/**
 * Infer the value type produced by a rule.
 *
 * Each dependency rule is typed first and its output is registered as a known variable, then the
 * rule's own expression is typed against the enlarged variable map.
 *
 * @param rule the rule to analyse
 * @param referVars known types of the variables and properties the rule may refer to
 * @param udfRepo UDF manager used to resolve the result type of UDF and UDAF calls
 * @return the inferred type of the rule's expression
 */
def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
  val newReferVars = new mutable.HashMap[IRField, KgType]
  newReferVars ++= referVars
  // Guard against a null dependency collection, the same way RuleUtils does before iterating it.
  val dependencies = rule.getDependencies
  if (dependencies != null) {
    for (r <- dependencies) {
      newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo))
    }
  }
  getTargetType(rule.getExpr, newReferVars.toMap, udfRepo)
}
/**
 * Pick the wider of two numeric types, as ranked by `KgType.getNumberSeq`.
 * When both types have the same rank the left one is returned.
 */
private def getUpperType(left: KgType, right: KgType): KgType = {
  val leftRank = KgType.getNumberSeq(left)
  val rightRank = KgType.getNumberSeq(right)
  if (rightRank > leftRank) right else left
}
}

View File

@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.types.KTObject
import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject}
import com.antgroup.openspg.reasoner.lube.block.Aggregations
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
import org.apache.commons.lang3.StringUtils
class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit
context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = {
val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
v
}
})
val referTypes = new mutable.HashMap[IRField, KgType]()
for (v <- ruleFields) {
resolved.getVar(v.name) match {
case p: PropertyVar => referTypes.put(v, p.field.kgType)
case node: NodeVar =>
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case edge: EdgeVar =>
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case _ => throw UnsupportedOperationException(s"cannot support $v")
}
}
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo)
val renameVar = ruleFields
.filter(_.isInstanceOf[IRVariable])
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
val field = getAggregateTarget(referFields, resolved, dependency)
field match {
case IRNode(alias, _) =>
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case IREdge(alias, _) =>
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
} else {
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
aggMap.put(propertyVar, newAggExpr)
}
case IRVariable(alias) =>
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
val propertyVar =
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true))
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case _ =>

View File

@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.types.KTObject
import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
import com.antgroup.openspg.reasoner.lube.common.graph._
import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var}
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
import com.antgroup.openspg.reasoner.lube.logical._
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
import org.apache.commons.lang3.StringUtils
class ProjectPlanner(projects: ProjectFields) {
class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = {
val projectMap = new mutable.HashMap[Var, Expr]()
@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) {
v
}
})
val propertyVar = getTarget(rule._1, referVars, resolved, dependency)
val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency)
val transformer = new Rule2ExprTransformer()
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
val replaceVar = reference
@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) {
private def getTarget(
left: IRField,
referVars: List[IRField],
rule: Rule,
resolved: SolvedModel,
dependency: LogicalOperator): PropertyVar = {
val referTypes = new mutable.HashMap[IRField, KgType]()
for (v <- referVars) {
resolved.getVar(v.name) match {
case p: PropertyVar => referTypes.put(v, p.field.kgType)
case node: NodeVar =>
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case edge: EdgeVar =>
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
case _ =>
}
}
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo)
left match {
case IRVariable(name) =>
if (referVars.size == 1) {
PropertyVar(referVars.head.name, new Field(name, KTObject, true))
PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
} else {
val aliasSet = new mutable.HashSet[String]()
for (rVar <- referVars) {
@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) {
}
}
val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
PropertyVar(targetAlias, new Field(left.name, KTObject, true))
PropertyVar(targetAlias, new Field(left.name, ruleRetType, true))
}
case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true))
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true))
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
case _ => throw UnsupportedOperationException(s"cannot support $left")
}

View File

@ -13,11 +13,13 @@
package com.antgroup.openspg.reasoner.lube.logical
import com.antgroup.openspg.reasoner.common.types.KTInteger
import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable
import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser
import com.antgroup.openspg.reasoner.udf.UdfMngFactory
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec {
val replaceMap = Map.apply(
"a" -> PropertyVar("b", new Field("id", KTInteger, true)),
"c" -> PropertyVar("d", new Field("id", KTInteger, true)))
val r1 = ProjectRule(IRVariable("test"),
KTInteger, Ref("a"))
val r2 = ProjectRule(IRVariable("test"),
KTInteger, Ref("c"))
val r1 = ProjectRule(IRVariable("test"), Ref("a"))
val r2 = ProjectRule(IRVariable("test"), Ref("c"))
r1.addDependency(r2)
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
print(newRule.getExpr.pretty)
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
newRule.getExpr.asInstanceOf[UnaryOpExpr]
.arg.asInstanceOf[Ref].refName should equal("b")
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b")
}
// Checks ExprUtil.getTargetType on single expressions, with A.age declared as KTInteger.
it("test expr output type") {
val parser = new RuleExprParser()
val udfRepo = UdfMngFactory.getUdfMng
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
// a property reference resolves to its declared type
var expr: Expr = parser.parse("A.age")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
// arithmetic yields the wider operand type (presumably the literal 1 is parsed as a long)
expr = parser.parse("A.age + 1")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTLong)
// comparisons are boolean
expr = parser.parse("A.age > 10")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTBoolean)
// floor is typed as double regardless of its argument
expr = parser.parse("floor(A.age)")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTDouble)
// abs keeps the type of its argument
expr = parser.parse("abs(A.age)")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
// concat is expected to be typed through the UDF repository
expr = parser.parse("concat(A.age, \",\")")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
// cast_type takes its result type from the target type name
expr = parser.parse("cast_type(A.age, 'string')")
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
}
// Checks ExprUtil.getTargetType on a rule whose expression refers to the output of a dependency.
it("test rule output type") {
val parser = new RuleExprParser()
val udfRepo = UdfMngFactory.getUdfMng
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10"))
// "age" is not in the map; its type has to come from this dependency rule
val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age"))
rule.addDependency(r1)
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
}
}

View File

@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec {
println(optimizedLogicalPlan.pretty)
}
it("online") {
val dsl = """Define (s:User)-[p:redPacket]->(o:Int) {
| GraphStructure {
| (s)
| }
| Rule {
|LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m
|Latest30DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest7DayPayCount=s.ngfe_tag__pay_cnt_d
|LatestTTT = Latest7DayPayCount.accumulate(+)
|LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0
|Latest7DayPayCountSum=Latest7DayPayCount
|Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0
|HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount
|HighLost("高频降频100%"):HighReduceValue == 1
|HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1
|HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8
|HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5
|HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3
|Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest30DayPayDayCount=size(Latest30DayPayCount)
|Latest3060DayPayDayCount=size(Latest3060DayPayCount)
|High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13
|High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13
|Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest6090DayPayDayCount=size(Latest6090DayPayCount)
|Latest60DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest60DayPayDayCount=size(Latest60DayPayCount)
|Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0
|Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0
|HistoricallyPay=s.ngfe_tag__pay_cnt_total
|HistoricallyPayCount=size(HistoricallyPay)
|New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0
|Latest90DayPayCount=s.ngfe_tag__pay_cnt_d
|Latest90DayPayDayCount=size(Latest90DayPayCount)
|Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0
|o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost"))
| }
|}""".stripMargin
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
println(block.pretty)
val schema: Map[String, Set[String]] = Map.apply(
"User" -> Set
.apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d"))
val catalog = new PropertyGraphCatalog(schema)
catalog.init()
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(catalog, parser, Map.empty)
val logicalPlan = LogicalPlanner.plan(block)
println(logicalPlan.head.pretty)
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head)
println(optimizedLogicalPlan.pretty)
}
it("test start flag") {
val dsl =
"""
@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec {
| (s)-[p2:followPM]->(o)
|}
|Rule {
| c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0)
| c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0)
|
|}
|Action {

View File

@ -41,7 +41,7 @@
<properties>
<antlr4.version>4.8</antlr4.version>
<fastjson.version>1.2.71_noneautotype</fastjson.version>
<geotools.version>27.0</geotools.version>
<geotools.version>28.0</geotools.version>
<google.s2.version>2.0.0</google.s2.version>
<hadoop.version>2.7.2</hadoop.version>
<hive.version>3.1.0</hive.version>

View File

@ -137,7 +137,7 @@ public class KgReasonerABMLocalTest {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n"
+ " Rule {\n"
+ " \tv = t.stock/t.total\n"
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n"
+ " }\n"
@ -197,7 +197,7 @@ public class KgReasonerABMLocalTest {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n"
+ " Rule {\n"
+ " \tv = t.stock/t.total\n"
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n"
+ " }\n"

View File

@ -50,9 +50,12 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id == $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
+ "totalTrans1 = group(A,B,C).sum(p1.amount)\n"
+ "totalTrans2 = group(A,B,C).sum(p2.amount)\n"
+ "totalTrans3 = group(A,B,C).sum(p3.amount)\n"
+ "p1_amt = cast_type(p1.amount,'long')\n"
+ "p2_amt = cast_type(p2.amount,'long')\n"
+ "p3_amt = cast_type(p3.amount,'long')\n"
+ "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
+ "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
+ "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
+ "R2('取top2'): top(totalTrans, 2)"
+ "}\n"
@ -90,7 +93,7 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)"
+ "}\n"
+ "Action {\n"
@ -127,7 +130,7 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id == $idSet1\n"
+ "R2: B.id == $idSet2\n"
+ "R3: C.id == $idSet3\n"
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)"
+ "}\n"
+ "Action {\n"
@ -169,9 +172,12 @@ public class KgReasonerAliasSetKFilmTest {
+ "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
+ "t1 = group(A,B,C).sum(p1.amount)\n"
+ "t2 = group(A,B,C).sum(p2.amount)\n"
+ "t3 = group(A,B,C).sum(p3.amount)\n"
+ "p1_amt = cast_type(p1.amount,'long')\n"
+ "p2_amt = cast_type(p2.amount,'long')\n"
+ "p3_amt = cast_type(p3.amount,'long')\n"
+ "t1 = group(A,B,C).sum(p1_amt)\n"
+ "t2 = group(A,B,C).sum(p2_amt)\n"
+ "t3 = group(A,B,C).sum(p3_amt)\n"
+ "totalSum = t1 + t2 + t3"
+ "}\n"
+ "Action {\n"

View File

@ -403,7 +403,7 @@ public class KgReasonerTopKFilmTest {
+ " o->star [starOfFilm] as sf2\n"
+ "}\n"
+ "Rule {\n"
+ "total = sf.joinTs + sf2.joinTs\n"
+ "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n"
+ "R2: top(total, 1)\n"
+ "}\n"
+ "Action {\n"

View File

@ -25,13 +25,7 @@ import com.antgroup.openspg.reasoner.kggraph.AggregationSchemaInfo;
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr;
import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr;
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
import com.antgroup.openspg.reasoner.lube.common.expr.*;
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
@ -41,6 +35,7 @@ import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
@ -353,26 +348,14 @@ public class KgGraphAggregateImpl implements Serializable {
if (null != udafInitParams) {
udaf.initialize(udafInitParams);
}
String sourceAlias = null;
String sourcePropertyName = null;
ParsedAggEle parsedAggEle;
Set<String> aliasList = aggInfo.getExprUseAliasSet();
if (aliasList.size() <= 1) {
Expr sourceExpr = aggInfo.getAggEle();
// aggregate by vertex subgraph
if (sourceExpr instanceof Ref) {
Ref sourceRef = (Ref) sourceExpr;
sourceAlias = sourceRef.refName();
} else if (sourceExpr instanceof UnaryOpExpr) {
UnaryOpExpr expr = (UnaryOpExpr) sourceExpr;
GetField getField = (GetField) expr.name();
sourceAlias = ((Ref) expr.arg()).refName();
sourcePropertyName = getField.fieldName();
}
parsedAggEle = aggInfo.getParsedAggEle();
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
StringBuffer sb = new StringBuffer();
StringBuilder sb = new StringBuilder();
for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) {
sb.append(valueFiltered2).append("## ");
}
@ -381,19 +364,41 @@ public class KgGraphAggregateImpl implements Serializable {
}
}
}
String finalSourcePropertyName = sourcePropertyName;
String finalSourcePropertyName = parsedAggEle.getSourcePropertyName();
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
if (valueFiltered.getVertexAlias().contains(sourceAlias)) {
List<IVertex<IVertexId, IProperty>> vertexList = valueFiltered.getVertex(sourceAlias);
if (sourcePropertyName == null) {
if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) {
List<IVertex<IVertexId, IProperty>> vertexList =
valueFiltered.getVertex(parsedAggEle.getSourceAlias());
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
vertexList.forEach(
vertex -> {
Map<String, Object> context =
RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias());
Object value =
RuleRunner.getInstance()
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
udaf.update(value);
});
} else if (finalSourcePropertyName == null) {
vertexList.forEach(udaf::update);
} else {
vertexList.forEach(
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
}
} else {
List<IEdge<IVertexId, IProperty>> edgeList = valueFiltered.getEdge(sourceAlias);
if (sourcePropertyName == null) {
List<IEdge<IVertexId, IProperty>> edgeList =
valueFiltered.getEdge(parsedAggEle.getSourceAlias());
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
edgeList.forEach(
edge -> {
Map<String, Object> context =
RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias());
Object value =
RuleRunner.getInstance()
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
udaf.update(value);
});
} else if (finalSourcePropertyName == null) {
edgeList.forEach(udaf::update);
} else {
edgeList.forEach(

View File

@ -23,6 +23,7 @@ import java.io.Serializable;
import java.util.List;
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
private final ParsedAggEle parsedAggEle;
/**
* constructor
@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
*/
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator);
// Decompose the aggregated element once, up front, so getParsedAggEle() is a plain field read.
// NOTE(review): parsedAggEle() reads exprUseAliasSet / exprRuleString and calls getAggEle();
// this presumably relies on the base constructor having populated that state - confirm.
parsedAggEle = parsedAggEle();
}
/**
@ -58,4 +60,9 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
public Expr getAggEle() {
// The aggregated element lives on the AggOpExpr nested inside the conditional AggIfOpExpr.
return getAggIfOpExpr().aggOpExpr().aggEleExpr();
}
/**
 * Returns the decomposed aggregated element that was computed in the constructor.
 *
 * @return the cached {@link ParsedAggEle}
 */
@Override
public ParsedAggEle getParsedAggEle() {
return parsedAggEle;
}
}

View File

@ -22,9 +22,11 @@ import java.io.Serializable;
import java.util.List;
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
private final ParsedAggEle parsedAggEle;
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator);
// Decompose the aggregated element once, up front, so getParsedAggEle() is a plain field read.
// NOTE(review): parsedAggEle() reads exprUseAliasSet / exprRuleString and calls getAggEle();
// this presumably relies on the base constructor having populated that state - confirm.
parsedAggEle = parsedAggEle();
}
public AggOpExpr getAggOpExpr() {
@ -45,4 +47,9 @@ public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Se
public Expr getAggEle() {
// The aggregated element is read directly from this process's AggOpExpr.
return getAggOpExpr().aggEleExpr();
}
/**
 * Returns the decomposed aggregated element that was computed in the constructor.
 *
 * @return the cached {@link ParsedAggEle}
 */
@Override
public ParsedAggEle getParsedAggEle() {
return this.parsedAggEle;
}
}

View File

@ -19,6 +19,9 @@ import com.antgroup.openspg.reasoner.lube.common.expr.AggUdf;
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
@ -139,6 +142,9 @@ public abstract class BaseGroupProcess implements Serializable {
*/
public abstract Expr getAggEle();
/**
 * Returns the aggregated element in decomposed form: the source alias, an optional property
 * name, and an optional list of expression rule strings (see {@link ParsedAggEle}).
 *
 * @return the parsed aggregated element for this group process
 */
public abstract ParsedAggEle getParsedAggEle();
public Set<String> parseExprUseAliasSet() {
scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
@ -148,6 +154,26 @@ public abstract class BaseGroupProcess implements Serializable {
return WareHouseUtils.getRuleList(getAggEle());
}
/**
 * Decomposes the aggregated element returned by {@link #getAggEle()} into a {@link ParsedAggEle}.
 *
 * <ul>
 *   <li>a bare {@code Ref} yields only a source alias;
 *   <li>a {@code UnaryOpExpr} that is a {@code GetField} applied to a {@code Ref} yields the
 *       source alias plus the property name;
 *   <li>any other expression that uses exactly one alias yields that alias plus the expression
 *       rule strings, so the value can be evaluated per element;
 *   <li>otherwise every component is {@code null}.
 * </ul>
 *
 * <p>Unary expressions that are not a plain field access on a reference are routed to the
 * expression branch rather than being cast blindly, which would raise a ClassCastException.
 *
 * @return the parsed aggregated element; never {@code null}, but its components may be
 */
protected ParsedAggEle parsedAggEle() {
  Expr aggEle = getAggEle();

  // aggregate over the referenced vertex/edge itself
  if (aggEle instanceof Ref) {
    return new ParsedAggEle(((Ref) aggEle).refName(), null, null);
  }

  // aggregate over a single property: alias.property
  if (aggEle instanceof UnaryOpExpr) {
    UnaryOpExpr unaryExpr = (UnaryOpExpr) aggEle;
    if (unaryExpr.name() instanceof GetField && unaryExpr.arg() instanceof Ref) {
      String sourceAlias = ((Ref) unaryExpr.arg()).refName();
      String sourcePropertyName = ((GetField) unaryExpr.name()).fieldName();
      return new ParsedAggEle(sourceAlias, sourcePropertyName, null);
    }
  }

  // aggregate over a computed expression that touches exactly one alias
  if (1 == this.exprUseAliasSet.size()) {
    return new ParsedAggEle(this.exprUseAliasSet.iterator().next(), null, this.exprRuleString);
  }

  return new ParsedAggEle(null, null, null);
}
/**
* getter
*

View File

@ -0,0 +1,40 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.rdg.common.groupProcess;
import java.io.Serializable;
import java.util.List;
/**
 * Immutable holder for the decomposed form of an aggregated element.
 *
 * <p>Any of the three components may be {@code null}; the constructor stores its arguments
 * as-is. The class is {@link Serializable} because it is held in a field of group-process
 * classes that are themselves declared {@code Serializable}.
 */
public class ParsedAggEle implements Serializable {
  private static final long serialVersionUID = 1L;

  /** alias of the vertex or edge the aggregation reads from */
  private final String sourceAlias;

  /** property read from the source element, when the aggregation targets a single property */
  private final String sourcePropertyName;

  /** expression rule strings to evaluate per element, when the aggregation targets an expression */
  private final List<String> exprStrList;

  /**
   * @param sourceAlias alias of the source element, may be {@code null}
   * @param sourcePropertyName property name on the source element, may be {@code null}
   * @param exprStrList expression rule strings, may be {@code null}
   */
  public ParsedAggEle(String sourceAlias, String sourcePropertyName, List<String> exprStrList) {
    this.sourceAlias = sourceAlias;
    this.sourcePropertyName = sourcePropertyName;
    this.exprStrList = exprStrList;
  }

  /** @return the source alias, or {@code null} if none was supplied */
  public String getSourceAlias() {
    return sourceAlias;
  }

  /** @return the source property name, or {@code null} if none was supplied */
  public String getSourcePropertyName() {
    return sourcePropertyName;
  }

  /** @return the expression rule strings, or {@code null} if none were supplied */
  public List<String> getExprStrList() {
    return exprStrList;
  }
}