mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-06-27 03:20:10 +00:00
feat(reasoner): support value type inference in udf (#132)
Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
parent
eb2590aada
commit
258b0e7dfb
@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
package com.antgroup.openspg.reasoner.common.types
|
package com.antgroup.openspg.reasoner.common.types
|
||||||
|
|
||||||
|
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||||
|
|
||||||
trait KgType {
|
trait KgType {
|
||||||
def isNullable: Boolean = false
|
def isNullable: Boolean = false
|
||||||
}
|
}
|
||||||
@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType
|
|||||||
* @param elementType
|
* @param elementType
|
||||||
*/
|
*/
|
||||||
final case class KTMultiVersion(elementType: KgType) extends KgType
|
final case class KTMultiVersion(elementType: KgType) extends KgType
|
||||||
|
|
||||||
|
object KgType {
|
||||||
|
|
||||||
|
def getNumberSeq(kgType: KgType): Int = {
|
||||||
|
kgType match {
|
||||||
|
case KTInteger => 1
|
||||||
|
case KTLong => 2
|
||||||
|
case KTDouble => 3
|
||||||
|
case _ => throw UnsupportedOperationException(s"cannot support number type $kgType")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
IRProperty(s.alias, propertyName) ->
|
IRProperty(s.alias, propertyName) ->
|
||||||
ProjectRule(
|
ProjectRule(
|
||||||
IRProperty(s.alias, propertyName),
|
IRProperty(s.alias, propertyName),
|
||||||
propertyType,
|
|
||||||
Ref(ddlBlockWithNodes._3.target.alias)))))
|
Ref(ddlBlockWithNodes._3.target.alias)))))
|
||||||
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
|
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
|
||||||
case AddPredicate(predicate) =>
|
case AddPredicate(predicate) =>
|
||||||
@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
ProjectBlock(
|
ProjectBlock(
|
||||||
List.apply(opBlock),
|
List.apply(opBlock),
|
||||||
ProjectFields(Map.apply(lValueName ->
|
ProjectFields(Map.apply(lValueName ->
|
||||||
ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain))))
|
ProjectRule(lValueName, opChain))))
|
||||||
}
|
}
|
||||||
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
|
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
|
||||||
ProjectBlock(
|
ProjectBlock(
|
||||||
@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
lValueName ->
|
lValueName ->
|
||||||
ProjectRule(
|
ProjectRule(
|
||||||
lValueName,
|
lValueName,
|
||||||
exprParser.parseRetType(opChain.curExpr),
|
|
||||||
opChain.curExpr))))
|
opChain.curExpr))))
|
||||||
case _ => null
|
case _ => null
|
||||||
}
|
}
|
||||||
@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
List.empty)
|
List.empty)
|
||||||
case _ =>
|
case _ =>
|
||||||
rule match {
|
rule match {
|
||||||
case ProjectRule(_, lvalueType, _) =>
|
case ProjectRule(_, _) =>
|
||||||
val projectRule = ProjectRule(lvalueFiled, lvalueType, expr)
|
val projectRule = ProjectRule(lvalueFiled, expr)
|
||||||
ProjectBlock(
|
ProjectBlock(
|
||||||
List.apply(preBlock),
|
List.apply(preBlock),
|
||||||
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
|
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
|
||||||
@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
|
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
|
||||||
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
|
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
|
||||||
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
|
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
|
||||||
(ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true)
|
(ProjectRule(IRVariable(defaultName), expr), columnName, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
def parseGraphStructure(
|
def parseGraphStructure(
|
||||||
@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
val defaultColumnName = parseExpr2ElementStr(expr)
|
val defaultColumnName = parseExpr2ElementStr(expr)
|
||||||
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
|
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
|
||||||
(
|
(
|
||||||
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
|
ProjectRule(IRVariable(defaultColumnName), expr),
|
||||||
columnName,
|
columnName,
|
||||||
false)
|
false)
|
||||||
}
|
}
|
||||||
@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
|
|||||||
val defaultColumnName = parseExpr2ElementStr(expr)
|
val defaultColumnName = parseExpr2ElementStr(expr)
|
||||||
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
|
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
|
||||||
(
|
(
|
||||||
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
|
ProjectRule(IRVariable(defaultColumnName), expr),
|
||||||
columnName,
|
columnName,
|
||||||
false)
|
false)
|
||||||
}
|
}
|
||||||
|
@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def parseRetType(expr: Expr): KgType = {
|
|
||||||
KTObject
|
|
||||||
}
|
|
||||||
|
|
||||||
def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
|
def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
|
||||||
ctx.getChild(0) match {
|
ctx.getChild(0) match {
|
||||||
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
|
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
|
||||||
@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
|
|||||||
if (ctx.property_name() != null) {
|
if (ctx.property_name() != null) {
|
||||||
ProjectRule(
|
ProjectRule(
|
||||||
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
|
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
|
||||||
parseRetType(expr),
|
|
||||||
expr)
|
expr)
|
||||||
} else {
|
} else {
|
||||||
ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr)
|
ProjectRule(IRVariable(ctx.identifier().getText), expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser
|
|||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.constants.Constants
|
import com.antgroup.openspg.reasoner.common.constants.Constants
|
||||||
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
|
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
|
||||||
import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
|
|
||||||
import com.antgroup.openspg.reasoner.lube.block._
|
import com.antgroup.openspg.reasoner.lube.block._
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr._
|
import com.antgroup.openspg.reasoner.lube.common.expr._
|
||||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||||
@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
print(block.pretty)
|
print(block.pretty)
|
||||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||||
proj.projects.items.head._2 should equal(
|
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
|
||||||
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
it("addproperies with constraint") {
|
it("addproperies with constraint") {
|
||||||
@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
print(block.pretty)
|
print(block.pretty)
|
||||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||||
proj.projects.items.head._2 should equal(
|
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
|
||||||
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
it("addproperies2") {
|
it("addproperies2") {
|
||||||
@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||||
proj.projects.items.head._2 should equal(
|
proj.projects.items.head._2 should equal(
|
||||||
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
|
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
|
||||||
}
|
}
|
||||||
it("addproperies") {
|
it("addproperies") {
|
||||||
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
||||||
@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
|
||||||
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
|
||||||
proj.projects.items.head._2 should equal(
|
proj.projects.items.head._2 should equal(
|
||||||
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
|
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
|
||||||
}
|
}
|
||||||
it("addNode") {
|
it("addNode") {
|
||||||
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
|
||||||
@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
.asInstanceOf[AddPredicate]
|
.asInstanceOf[AddPredicate]
|
||||||
.predicate
|
.predicate
|
||||||
.fields
|
.fields
|
||||||
.keySet should contain ("same_domain_num")
|
.keySet should contain("same_domain_num")
|
||||||
|
|
||||||
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
|
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
|
||||||
blocks(1)
|
blocks(1)
|
||||||
@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
val block = parser.parse(dsl)
|
val block = parser.parse(dsl)
|
||||||
print(block.pretty)
|
print(block.pretty)
|
||||||
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
|
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
|
||||||
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o)))))
|
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
|
||||||
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
|
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
|
||||||
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
|
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
|
||||||
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
|
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
|
||||||
@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
|
|||||||
val block = parser.parse(dsl)
|
val block = parser.parse(dsl)
|
||||||
print(block.pretty)
|
print(block.pretty)
|
||||||
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
|
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
|
||||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value)))))
|
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
|
||||||
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
|
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
|
||||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
|
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
|
||||||
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
|
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
|
||||||
|
@ -33,6 +33,10 @@
|
|||||||
<groupId>com.antgroup.openspg.reasoner</groupId>
|
<groupId>com.antgroup.openspg.reasoner</groupId>
|
||||||
<artifactId>reasoner-common</artifactId>
|
<artifactId>reasoner-common</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.antgroup.openspg.reasoner</groupId>
|
||||||
|
<artifactId>reasoner-udf</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.scala-lang</groupId>
|
<groupId>org.scala-lang</groupId>
|
||||||
<artifactId>scala-library</artifactId>
|
<artifactId>scala-library</artifactId>
|
||||||
|
@ -13,10 +13,12 @@
|
|||||||
|
|
||||||
package com.antgroup.openspg.reasoner.lube.catalog
|
package com.antgroup.openspg.reasoner.lube.catalog
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
|
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
|
||||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
|
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
|
||||||
import scala.collection.mutable
|
import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -27,6 +29,7 @@ import scala.collection.mutable
|
|||||||
*/
|
*/
|
||||||
abstract class Catalog() extends Serializable {
|
abstract class Catalog() extends Serializable {
|
||||||
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
|
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
|
||||||
|
@transient private val udfRepo = UdfMngFactory.getUdfMng
|
||||||
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
|
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable {
|
|||||||
graphRepository.get(graphName).orNull
|
graphRepository.get(graphName).orNull
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def getUdfRepo: UdfMng = udfRepo
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get schema from knowledge graph
|
* Get schema from knowledge graph
|
||||||
*/
|
*/
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
|
|
||||||
package com.antgroup.openspg.reasoner.lube.common.rule
|
package com.antgroup.openspg.reasoner.lube.common.rule
|
||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.types.KgType
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
|
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
|
||||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRField
|
import com.antgroup.openspg.reasoner.lube.common.graph.IRField
|
||||||
|
|
||||||
@ -39,13 +38,6 @@ trait Rule extends Cloneable{
|
|||||||
*/
|
*/
|
||||||
def getExpr: Expr
|
def getExpr: Expr
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get lvalue type
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
def getLvalueType: KgType
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get dependencies
|
* get dependencies
|
||||||
* @return
|
* @return
|
||||||
|
@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
|
|||||||
*/
|
*/
|
||||||
override def getExpr: Expr = expr
|
override def getExpr: Expr = expr
|
||||||
|
|
||||||
/**
|
|
||||||
* get lvalue type
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
override def getLvalueType: KgType = KTBoolean
|
|
||||||
|
|
||||||
override def andRule(rule: Rule): Rule = {
|
override def andRule(rule: Rule): Rule = {
|
||||||
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
|
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
|
||||||
|
|
||||||
@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
|
|||||||
* @param lvalueType
|
* @param lvalueType
|
||||||
* @param expr
|
* @param expr
|
||||||
*/
|
*/
|
||||||
final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
|
final case class ProjectRule(output: IRField, expr: Expr)
|
||||||
extends DependencyRule {
|
extends DependencyRule {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
|
|||||||
*/
|
*/
|
||||||
override def getExpr: Expr = expr
|
override def getExpr: Expr = expr
|
||||||
|
|
||||||
/**
|
|
||||||
* get lvalue type
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
override def getLvalueType: KgType = lvalueType
|
|
||||||
|
|
||||||
override def andRule(rule: Rule): Rule = {
|
override def andRule(rule: Rule): Rule = {
|
||||||
throw UnsupportedOperationException("ProjectRule cannot support andRule")
|
throw UnsupportedOperationException("ProjectRule cannot support andRule")
|
||||||
|
@ -138,7 +138,7 @@ object RuleUtils {
|
|||||||
case logicRule: LogicRule =>
|
case logicRule: LogicRule =>
|
||||||
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
|
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
|
||||||
case _ =>
|
case _ =>
|
||||||
ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr)
|
ProjectRule(IRVariable(ruleNameStr), expr)
|
||||||
}
|
}
|
||||||
val oldDependencies = rule.getDependencies
|
val oldDependencies = rule.getDependencies
|
||||||
if (oldDependencies != null) {
|
if (oldDependencies != null) {
|
||||||
@ -162,7 +162,7 @@ object RuleUtils {
|
|||||||
case logicRule: LogicRule =>
|
case logicRule: LogicRule =>
|
||||||
LogicRule(rule.getName, logicRule.ruleExplain, expr)
|
LogicRule(rule.getName, logicRule.ruleExplain, expr)
|
||||||
case _ =>
|
case _ =>
|
||||||
ProjectRule(rule.getOutput, rule.getLvalueType, expr)
|
ProjectRule(rule.getOutput, expr)
|
||||||
}
|
}
|
||||||
val oldDependencies = rule.getDependencies
|
val oldDependencies = rule.getDependencies
|
||||||
if (oldDependencies != null) {
|
if (oldDependencies != null) {
|
||||||
|
@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
|
|||||||
false)))),
|
false)))),
|
||||||
ProjectFields(
|
ProjectFields(
|
||||||
Map.apply(IRVariable("total_domain_num") ->
|
Map.apply(IRVariable("total_domain_num") ->
|
||||||
ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o")))))
|
ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
|
||||||
val p = BlockUtils.transBlock2Graph(block)
|
val p = BlockUtils.transBlock2Graph(block)
|
||||||
p.size should equal(1)
|
p.size should equal(1)
|
||||||
p.head.graphPattern.nodes.size should equal(2)
|
p.head.graphPattern.nodes.size should equal(2)
|
||||||
@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
|
|||||||
it("rename_rule") {
|
it("rename_rule") {
|
||||||
val rule = ProjectRule(
|
val rule = ProjectRule(
|
||||||
IRVariable("a"),
|
IRVariable("a"),
|
||||||
KTObject,
|
|
||||||
BinaryOpExpr(
|
BinaryOpExpr(
|
||||||
BEqual,
|
BEqual,
|
||||||
UnaryOpExpr(GetField("birthDate"), Ref("e")),
|
UnaryOpExpr(GetField("birthDate"), Ref("e")),
|
||||||
@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
|
|||||||
it("variable_rule") {
|
it("variable_rule") {
|
||||||
val rule = ProjectRule(
|
val rule = ProjectRule(
|
||||||
IRVariable("a"),
|
IRVariable("a"),
|
||||||
KTObject,
|
|
||||||
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
|
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
|
||||||
|
|
||||||
val rule2 = ProjectRule(
|
val rule2 = ProjectRule(
|
||||||
IRVariable("b"),
|
IRVariable("b"),
|
||||||
KTObject,
|
|
||||||
BinaryOpExpr(
|
BinaryOpExpr(
|
||||||
BEqual,
|
BEqual,
|
||||||
UnaryOpExpr(GetField("attr1"), Ref("e")),
|
UnaryOpExpr(GetField("attr1"), Ref("e")),
|
||||||
@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
|
|||||||
def getDependenceRule(): Rule = {
|
def getDependenceRule(): Rule = {
|
||||||
val r0 = ProjectRule(
|
val r0 = ProjectRule(
|
||||||
IRVariable("r0"),
|
IRVariable("r0"),
|
||||||
KTLong,
|
|
||||||
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
|
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
|
||||||
)
|
)
|
||||||
val r1 = LogicRule(
|
val r1 = LogicRule(
|
||||||
@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
|
|||||||
val r0 = LogicRule("tmp", "",
|
val r0 = LogicRule("tmp", "",
|
||||||
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
|
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
|
||||||
val r = ProjectRule(IRVariable("g"),
|
val r = ProjectRule(IRVariable("g"),
|
||||||
KTLong,
|
|
||||||
OpChainExpr(
|
OpChainExpr(
|
||||||
GraphAggregatorExpr(
|
GraphAggregatorExpr(
|
||||||
"unresolved_default_path",
|
"unresolved_default_path",
|
||||||
|
@ -13,9 +13,16 @@
|
|||||||
|
|
||||||
package com.antgroup.openspg.reasoner.lube.logical
|
package com.antgroup.openspg.reasoner.lube.logical
|
||||||
|
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||||
import com.antgroup.openspg.reasoner.common.trees.BottomUp
|
import com.antgroup.openspg.reasoner.common.trees.BottomUp
|
||||||
|
import com.antgroup.openspg.reasoner.common.types._
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr._
|
import com.antgroup.openspg.reasoner.lube.common.expr._
|
||||||
|
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
|
||||||
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
||||||
|
import com.antgroup.openspg.reasoner.udf.UdfMng
|
||||||
|
|
||||||
object ExprUtil {
|
object ExprUtil {
|
||||||
|
|
||||||
@ -46,15 +53,13 @@ object ExprUtil {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def needResolved(rule: Expr): Boolean = {
|
def needResolved(rule: Expr): Boolean = {
|
||||||
!getReferProperties(rule).filter(_._1 == null).isEmpty
|
!getReferProperties(rule).filter(_._1 == null).isEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
|
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
|
||||||
|
|
||||||
def rewriter: PartialFunction[Expr, Expr] = {
|
def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
|
||||||
case Ref(refName) =>
|
|
||||||
if (replaceVar.contains(refName)) {
|
if (replaceVar.contains(refName)) {
|
||||||
val propertyVar = replaceVar(refName)
|
val propertyVar = replaceVar(refName)
|
||||||
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
|
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
|
||||||
@ -76,4 +81,104 @@ object ExprUtil {
|
|||||||
newRule
|
newRule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
|
||||||
|
expr match {
|
||||||
|
case Ref(name) =>
|
||||||
|
if (referVars.contains(IRVariable(name))) {
|
||||||
|
referVars(IRVariable(name))
|
||||||
|
} else {
|
||||||
|
KTObject
|
||||||
|
}
|
||||||
|
case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name))
|
||||||
|
case BinaryOpExpr(name, l, r) =>
|
||||||
|
name match {
|
||||||
|
case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |
|
||||||
|
BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign =>
|
||||||
|
KTBoolean
|
||||||
|
case BAdd | BSub | BMul | BDiv | BMod =>
|
||||||
|
val left = getTargetType(l, referVars, udfRepo)
|
||||||
|
val right = getTargetType(r, referVars, udfRepo)
|
||||||
|
getUpperType(left, right)
|
||||||
|
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||||
|
}
|
||||||
|
case UnaryOpExpr(name, arg) =>
|
||||||
|
name match {
|
||||||
|
case Not | Exists => KTBoolean
|
||||||
|
case Abs | Neg => getTargetType(arg, referVars, udfRepo)
|
||||||
|
case Floor | Ceil => KTDouble
|
||||||
|
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||||
|
}
|
||||||
|
case FunctionExpr(name, funcArgs) =>
|
||||||
|
val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
|
||||||
|
name match {
|
||||||
|
case "rule_value" => types(1)
|
||||||
|
case "cast_type" | "Cast" =>
|
||||||
|
funcArgs(1).asInstanceOf[VString].value match {
|
||||||
|
case "int" | "bigint" | "long" => KTLong
|
||||||
|
case "float" | "double" => KTDouble
|
||||||
|
case "varchar" | "string" => KTString
|
||||||
|
case _ =>
|
||||||
|
throw UnsupportedOperationException(s"cannot support ${name} to ${funcArgs(1)}")
|
||||||
|
}
|
||||||
|
case _ =>
|
||||||
|
val udf = udfRepo.getUdfMeta(name, types.asJava)
|
||||||
|
if (udf != null) {
|
||||||
|
udf.getResultType
|
||||||
|
} else {
|
||||||
|
throw UnsupportedOperationException(s"cannot find UDF: ${name}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case AggOpExpr(name, args) =>
|
||||||
|
name match {
|
||||||
|
case Min | Max | Sum | Avg | First | Accumulate(_) =>
|
||||||
|
getTargetType(args, referVars, udfRepo)
|
||||||
|
case StrJoin(_) => KTString
|
||||||
|
case Count => KTLong
|
||||||
|
case AggUdf(name, _) =>
|
||||||
|
val types = getTargetType(args.head, referVars, udfRepo)
|
||||||
|
val udf = udfRepo.getUdafMeta(name, types)
|
||||||
|
if (udf != null) {
|
||||||
|
udf.getResultType
|
||||||
|
} else {
|
||||||
|
throw UnsupportedOperationException(s"cannot find UDAF ${name}")
|
||||||
|
}
|
||||||
|
case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
|
||||||
|
}
|
||||||
|
case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo)
|
||||||
|
case ListOpExpr(name, _) =>
|
||||||
|
name match {
|
||||||
|
case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo)
|
||||||
|
case Constraint(_, _, _) => KTBoolean
|
||||||
|
case Get(_) | Slice(_, _) => KTObject
|
||||||
|
}
|
||||||
|
case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo)
|
||||||
|
case VNull | VString(_) => KTString
|
||||||
|
case VLong(_) => KTLong
|
||||||
|
case VDouble(_) => KTDouble
|
||||||
|
case VBoolean(_) => KTBoolean
|
||||||
|
case VList(_, listType) => KTList(listType)
|
||||||
|
case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
|
||||||
|
val newReferVars = new mutable.HashMap[IRField, KgType]
|
||||||
|
newReferVars.++=(referVars)
|
||||||
|
for (r <- rule.getDependencies) {
|
||||||
|
newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo))
|
||||||
|
}
|
||||||
|
getTargetType(rule.getExpr, newReferVars.toMap, udfRepo)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def getUpperType(left: KgType, right: KgType): KgType = {
|
||||||
|
val l = KgType.getNumberSeq(left)
|
||||||
|
val r = KgType.getNumberSeq(right)
|
||||||
|
if (l >= r) {
|
||||||
|
left
|
||||||
|
} else {
|
||||||
|
right
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
|
|||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||||
import com.antgroup.openspg.reasoner.common.types.KTObject
|
import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject}
|
||||||
import com.antgroup.openspg.reasoner.lube.block.Aggregations
|
import com.antgroup.openspg.reasoner.lube.block.Aggregations
|
||||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
|
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
|
||||||
@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL
|
|||||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
|
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
|
||||||
import org.apache.commons.lang3.StringUtils
|
import org.apache.commons.lang3.StringUtils
|
||||||
|
|
||||||
class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit
|
||||||
|
context: LogicalPlannerContext) {
|
||||||
|
|
||||||
def plan(dependency: LogicalOperator): LogicalOperator = {
|
def plan(dependency: LogicalOperator): LogicalOperator = {
|
||||||
val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
|
val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
|
||||||
@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
|||||||
v
|
v
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
val referTypes = new mutable.HashMap[IRField, KgType]()
|
||||||
|
for (v <- ruleFields) {
|
||||||
|
resolved.getVar(v.name) match {
|
||||||
|
case p: PropertyVar => referTypes.put(v, p.field.kgType)
|
||||||
|
case node: NodeVar =>
|
||||||
|
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||||
|
case edge: EdgeVar =>
|
||||||
|
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||||
|
case _ => throw UnsupportedOperationException(s"cannot support $v")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
|
||||||
|
val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo)
|
||||||
|
|
||||||
val renameVar = ruleFields
|
val renameVar = ruleFields
|
||||||
.filter(_.isInstanceOf[IRVariable])
|
.filter(_.isInstanceOf[IRVariable])
|
||||||
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
|
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
|
||||||
@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
|
|||||||
val field = getAggregateTarget(referFields, resolved, dependency)
|
val field = getAggregateTarget(referFields, resolved, dependency)
|
||||||
field match {
|
field match {
|
||||||
case IRNode(alias, _) =>
|
case IRNode(alias, _) =>
|
||||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
|
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
|
||||||
aggMap.put(propertyVar, newAggExpr)
|
aggMap.put(propertyVar, newAggExpr)
|
||||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||||
case IREdge(alias, _) =>
|
case IREdge(alias, _) =>
|
||||||
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
|
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
|
||||||
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
|
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
|
||||||
} else {
|
} else {
|
||||||
val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
|
val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
|
||||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||||
aggMap.put(propertyVar, newAggExpr)
|
aggMap.put(propertyVar, newAggExpr)
|
||||||
}
|
}
|
||||||
case IRVariable(alias) =>
|
case IRVariable(alias) =>
|
||||||
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
|
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
|
||||||
val propertyVar =
|
val propertyVar =
|
||||||
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true))
|
PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true))
|
||||||
aggMap.put(propertyVar, newAggExpr)
|
aggMap.put(propertyVar, newAggExpr)
|
||||||
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
|
||||||
case _ =>
|
case _ =>
|
||||||
|
@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
|
|||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
|
||||||
import com.antgroup.openspg.reasoner.common.types.KTObject
|
import com.antgroup.openspg.reasoner.common.types.KgType
|
||||||
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
|
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
|
||||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
|
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
|
||||||
import com.antgroup.openspg.reasoner.lube.common.graph._
|
import com.antgroup.openspg.reasoner.lube.common.graph._
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var}
|
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
|
||||||
|
import com.antgroup.openspg.reasoner.lube.logical._
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
|
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
|
||||||
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
|
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
|
||||||
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
|
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
|
||||||
import org.apache.commons.lang3.StringUtils
|
import org.apache.commons.lang3.StringUtils
|
||||||
|
|
||||||
class ProjectPlanner(projects: ProjectFields) {
|
class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
|
||||||
|
|
||||||
def plan(dependency: LogicalOperator): LogicalOperator = {
|
def plan(dependency: LogicalOperator): LogicalOperator = {
|
||||||
val projectMap = new mutable.HashMap[Var, Expr]()
|
val projectMap = new mutable.HashMap[Var, Expr]()
|
||||||
@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) {
|
|||||||
v
|
v
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
val propertyVar = getTarget(rule._1, referVars, resolved, dependency)
|
val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency)
|
||||||
val transformer = new Rule2ExprTransformer()
|
val transformer = new Rule2ExprTransformer()
|
||||||
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
|
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
|
||||||
val replaceVar = reference
|
val replaceVar = reference
|
||||||
@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) {
|
|||||||
private def getTarget(
|
private def getTarget(
|
||||||
left: IRField,
|
left: IRField,
|
||||||
referVars: List[IRField],
|
referVars: List[IRField],
|
||||||
|
rule: Rule,
|
||||||
resolved: SolvedModel,
|
resolved: SolvedModel,
|
||||||
dependency: LogicalOperator): PropertyVar = {
|
dependency: LogicalOperator): PropertyVar = {
|
||||||
|
val referTypes = new mutable.HashMap[IRField, KgType]()
|
||||||
|
for (v <- referVars) {
|
||||||
|
resolved.getVar(v.name) match {
|
||||||
|
case p: PropertyVar => referTypes.put(v, p.field.kgType)
|
||||||
|
case node: NodeVar =>
|
||||||
|
node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||||
|
case edge: EdgeVar =>
|
||||||
|
edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
|
||||||
|
val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo)
|
||||||
|
|
||||||
left match {
|
left match {
|
||||||
case IRVariable(name) =>
|
case IRVariable(name) =>
|
||||||
if (referVars.size == 1) {
|
if (referVars.size == 1) {
|
||||||
PropertyVar(referVars.head.name, new Field(name, KTObject, true))
|
PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
|
||||||
} else {
|
} else {
|
||||||
val aliasSet = new mutable.HashSet[String]()
|
val aliasSet = new mutable.HashSet[String]()
|
||||||
for (rVar <- referVars) {
|
for (rVar <- referVars) {
|
||||||
@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
|
val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
|
||||||
PropertyVar(targetAlias, new Field(left.name, KTObject, true))
|
PropertyVar(targetAlias, new Field(left.name, ruleRetType, true))
|
||||||
}
|
}
|
||||||
case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true))
|
case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true))
|
||||||
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
|
case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
|
||||||
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
|
case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
|
||||||
case _ => throw UnsupportedOperationException(s"cannot support $left")
|
case _ => throw UnsupportedOperationException(s"cannot support $left")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,11 +13,13 @@
|
|||||||
|
|
||||||
package com.antgroup.openspg.reasoner.lube.logical
|
package com.antgroup.openspg.reasoner.lube.logical
|
||||||
|
|
||||||
import com.antgroup.openspg.reasoner.common.types.KTInteger
|
import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
|
||||||
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr}
|
import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
|
||||||
import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable
|
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
|
||||||
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
|
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
|
||||||
|
import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser
|
||||||
|
import com.antgroup.openspg.reasoner.udf.UdfMngFactory
|
||||||
import org.scalatest.funspec.AnyFunSpec
|
import org.scalatest.funspec.AnyFunSpec
|
||||||
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
|
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
|
||||||
|
|
||||||
@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec {
|
|||||||
val replaceMap = Map.apply(
|
val replaceMap = Map.apply(
|
||||||
"a" -> PropertyVar("b", new Field("id", KTInteger, true)),
|
"a" -> PropertyVar("b", new Field("id", KTInteger, true)),
|
||||||
"c" -> PropertyVar("d", new Field("id", KTInteger, true)))
|
"c" -> PropertyVar("d", new Field("id", KTInteger, true)))
|
||||||
val r1 = ProjectRule(IRVariable("test"),
|
val r1 = ProjectRule(IRVariable("test"), Ref("a"))
|
||||||
KTInteger, Ref("a"))
|
val r2 = ProjectRule(IRVariable("test"), Ref("c"))
|
||||||
val r2 = ProjectRule(IRVariable("test"),
|
|
||||||
KTInteger, Ref("c"))
|
|
||||||
r1.addDependency(r2)
|
r1.addDependency(r2)
|
||||||
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
|
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
|
||||||
print(newRule.getExpr.pretty)
|
print(newRule.getExpr.pretty)
|
||||||
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
|
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
|
||||||
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
|
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
|
||||||
newRule.getExpr.asInstanceOf[UnaryOpExpr]
|
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b")
|
||||||
.arg.asInstanceOf[Ref].refName should equal("b")
|
}
|
||||||
|
|
||||||
|
it("test expr output type") {
|
||||||
|
val parser = new RuleExprParser()
|
||||||
|
val udfRepo = UdfMngFactory.getUdfMng
|
||||||
|
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
|
||||||
|
|
||||||
|
var expr: Expr = parser.parse("A.age")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
|
||||||
|
|
||||||
|
expr = parser.parse("A.age + 1")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTLong)
|
||||||
|
|
||||||
|
expr = parser.parse("A.age > 10")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTBoolean)
|
||||||
|
|
||||||
|
expr = parser.parse("floor(A.age)")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTDouble)
|
||||||
|
|
||||||
|
expr = parser.parse("abs(A.age)")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
|
||||||
|
|
||||||
|
expr = parser.parse("concat(A.age, \",\")")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
|
||||||
|
|
||||||
|
expr = parser.parse("cast_type(A.age, 'string')")
|
||||||
|
ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
|
||||||
|
}
|
||||||
|
|
||||||
|
it("test rule output type") {
|
||||||
|
val parser = new RuleExprParser()
|
||||||
|
val udfRepo = UdfMngFactory.getUdfMng
|
||||||
|
val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
|
||||||
|
|
||||||
|
val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10"))
|
||||||
|
val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age"))
|
||||||
|
rule.addDependency(r1)
|
||||||
|
|
||||||
|
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec {
|
|||||||
println(optimizedLogicalPlan.pretty)
|
println(optimizedLogicalPlan.pretty)
|
||||||
}
|
}
|
||||||
|
|
||||||
it("online") {
|
|
||||||
val dsl = """Define (s:User)-[p:redPacket]->(o:Int) {
|
|
||||||
| GraphStructure {
|
|
||||||
| (s)
|
|
||||||
| }
|
|
||||||
| Rule {
|
|
||||||
|LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m
|
|
||||||
|Latest30DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|Latest7DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|LatestTTT = Latest7DayPayCount.accumulate(+)
|
|
||||||
|LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0
|
|
||||||
|Latest7DayPayCountSum=Latest7DayPayCount
|
|
||||||
|Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0
|
|
||||||
|HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount
|
|
||||||
|HighLost("高频降频100%"):HighReduceValue == 1
|
|
||||||
|HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1
|
|
||||||
|HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8
|
|
||||||
|HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5
|
|
||||||
|HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3
|
|
||||||
|Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|Latest30DayPayDayCount=size(Latest30DayPayCount)
|
|
||||||
|Latest3060DayPayDayCount=size(Latest3060DayPayCount)
|
|
||||||
|High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13
|
|
||||||
|High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13
|
|
||||||
|Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
|
||||||
|Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
|
||||||
|Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
|
|
||||||
|Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|
|
||||||
|Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
|
|
||||||
|Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|Latest6090DayPayDayCount=size(Latest6090DayPayCount)
|
|
||||||
|Latest60DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|Latest60DayPayDayCount=size(Latest60DayPayCount)
|
|
||||||
|Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0
|
|
||||||
|Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0
|
|
||||||
|HistoricallyPay=s.ngfe_tag__pay_cnt_total
|
|
||||||
|HistoricallyPayCount=size(HistoricallyPay)
|
|
||||||
|New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0
|
|
||||||
|Latest90DayPayCount=s.ngfe_tag__pay_cnt_d
|
|
||||||
|Latest90DayPayDayCount=size(Latest90DayPayCount)
|
|
||||||
|Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0
|
|
||||||
|o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost"))
|
|
||||||
| }
|
|
||||||
|}""".stripMargin
|
|
||||||
val parser = new OpenSPGDslParser()
|
|
||||||
val block = parser.parse(dsl)
|
|
||||||
println(block.pretty)
|
|
||||||
val schema: Map[String, Set[String]] = Map.apply(
|
|
||||||
"User" -> Set
|
|
||||||
.apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d"))
|
|
||||||
val catalog = new PropertyGraphCatalog(schema)
|
|
||||||
catalog.init()
|
|
||||||
implicit val context: LogicalPlannerContext =
|
|
||||||
LogicalPlannerContext(catalog, parser, Map.empty)
|
|
||||||
val logicalPlan = LogicalPlanner.plan(block)
|
|
||||||
println(logicalPlan.head.pretty)
|
|
||||||
val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head)
|
|
||||||
println(optimizedLogicalPlan.pretty)
|
|
||||||
}
|
|
||||||
|
|
||||||
it("test start flag") {
|
it("test start flag") {
|
||||||
val dsl =
|
val dsl =
|
||||||
"""
|
"""
|
||||||
@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec {
|
|||||||
| (s)-[p2:followPM]->(o)
|
| (s)-[p2:followPM]->(o)
|
||||||
|}
|
|}
|
||||||
|Rule {
|
|Rule {
|
||||||
| c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0)
|
| c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0)
|
||||||
|
|
|
|
||||||
|}
|
|}
|
||||||
|Action {
|
|Action {
|
||||||
|
@ -41,7 +41,7 @@
|
|||||||
<properties>
|
<properties>
|
||||||
<antlr4.version>4.8</antlr4.version>
|
<antlr4.version>4.8</antlr4.version>
|
||||||
<fastjson.version>1.2.71_noneautotype</fastjson.version>
|
<fastjson.version>1.2.71_noneautotype</fastjson.version>
|
||||||
<geotools.version>27.0</geotools.version>
|
<geotools.version>28.0</geotools.version>
|
||||||
<google.s2.version>2.0.0</google.s2.version>
|
<google.s2.version>2.0.0</google.s2.version>
|
||||||
<hadoop.version>2.7.2</hadoop.version>
|
<hadoop.version>2.7.2</hadoop.version>
|
||||||
<hive.version>3.1.0</hive.version>
|
<hive.version>3.1.0</hive.version>
|
||||||
|
@ -137,7 +137,7 @@ public class KgReasonerABMLocalTest {
|
|||||||
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
||||||
+ " }\n"
|
+ " }\n"
|
||||||
+ " Rule {\n"
|
+ " Rule {\n"
|
||||||
+ " \tv = t.stock/t.total\n"
|
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
|
||||||
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
||||||
+ " o = v\n"
|
+ " o = v\n"
|
||||||
+ " }\n"
|
+ " }\n"
|
||||||
@ -197,7 +197,7 @@ public class KgReasonerABMLocalTest {
|
|||||||
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
|
||||||
+ " }\n"
|
+ " }\n"
|
||||||
+ " Rule {\n"
|
+ " Rule {\n"
|
||||||
+ " \tv = t.stock/t.total\n"
|
+ " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
|
||||||
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
+ " R1(\"必须大于20%\"): v > 0.2\n"
|
||||||
+ " o = v\n"
|
+ " o = v\n"
|
||||||
+ " }\n"
|
+ " }\n"
|
||||||
|
@ -50,9 +50,12 @@ public class KgReasonerAliasSetKFilmTest {
|
|||||||
+ "R1: A.id == $idSet1\n"
|
+ "R1: A.id == $idSet1\n"
|
||||||
+ "R2: B.id in $idSet2\n"
|
+ "R2: B.id in $idSet2\n"
|
||||||
+ "R3: C.id in $idSet2\n"
|
+ "R3: C.id in $idSet2\n"
|
||||||
+ "totalTrans1 = group(A,B,C).sum(p1.amount)\n"
|
+ "p1_amt = cast_type(p1.amount,'long')\n"
|
||||||
+ "totalTrans2 = group(A,B,C).sum(p2.amount)\n"
|
+ "p2_amt = cast_type(p2.amount,'long')\n"
|
||||||
+ "totalTrans3 = group(A,B,C).sum(p3.amount)\n"
|
+ "p3_amt = cast_type(p3.amount,'long')\n"
|
||||||
|
+ "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
|
||||||
|
+ "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
|
||||||
|
+ "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
|
||||||
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
|
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
|
||||||
+ "R2('取top2'): top(totalTrans, 2)"
|
+ "R2('取top2'): top(totalTrans, 2)"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
@ -90,7 +93,7 @@ public class KgReasonerAliasSetKFilmTest {
|
|||||||
+ "R1: A.id in $idSet1\n"
|
+ "R1: A.id in $idSet1\n"
|
||||||
+ "R2: B.id in $idSet2\n"
|
+ "R2: B.id in $idSet2\n"
|
||||||
+ "R3: C.id in $idSet2\n"
|
+ "R3: C.id in $idSet2\n"
|
||||||
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
|
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
|
||||||
+ "R2('取top2'): top(totalTrans, 3)"
|
+ "R2('取top2'): top(totalTrans, 3)"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
+ "Action {\n"
|
+ "Action {\n"
|
||||||
@ -127,7 +130,7 @@ public class KgReasonerAliasSetKFilmTest {
|
|||||||
+ "R1: A.id == $idSet1\n"
|
+ "R1: A.id == $idSet1\n"
|
||||||
+ "R2: B.id == $idSet2\n"
|
+ "R2: B.id == $idSet2\n"
|
||||||
+ "R3: C.id == $idSet3\n"
|
+ "R3: C.id == $idSet3\n"
|
||||||
+ "totalTrans = p1.amount + p2.amount + p3.amount\n"
|
+ "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
|
||||||
+ "R2('取top2'): top(totalTrans, 3)"
|
+ "R2('取top2'): top(totalTrans, 3)"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
+ "Action {\n"
|
+ "Action {\n"
|
||||||
@ -169,9 +172,12 @@ public class KgReasonerAliasSetKFilmTest {
|
|||||||
+ "R1: A.id in $idSet1\n"
|
+ "R1: A.id in $idSet1\n"
|
||||||
+ "R2: B.id in $idSet2\n"
|
+ "R2: B.id in $idSet2\n"
|
||||||
+ "R3: C.id in $idSet2\n"
|
+ "R3: C.id in $idSet2\n"
|
||||||
+ "t1 = group(A,B,C).sum(p1.amount)\n"
|
+ "p1_amt = cast_type(p1.amount,'long')\n"
|
||||||
+ "t2 = group(A,B,C).sum(p2.amount)\n"
|
+ "p2_amt = cast_type(p2.amount,'long')\n"
|
||||||
+ "t3 = group(A,B,C).sum(p3.amount)\n"
|
+ "p3_amt = cast_type(p3.amount,'long')\n"
|
||||||
|
+ "t1 = group(A,B,C).sum(p1_amt)\n"
|
||||||
|
+ "t2 = group(A,B,C).sum(p2_amt)\n"
|
||||||
|
+ "t3 = group(A,B,C).sum(p3_amt)\n"
|
||||||
+ "totalSum = t1 + t2 + t3"
|
+ "totalSum = t1 + t2 + t3"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
+ "Action {\n"
|
+ "Action {\n"
|
||||||
|
@ -403,7 +403,7 @@ public class KgReasonerTopKFilmTest {
|
|||||||
+ " o->star [starOfFilm] as sf2\n"
|
+ " o->star [starOfFilm] as sf2\n"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
+ "Rule {\n"
|
+ "Rule {\n"
|
||||||
+ "total = sf.joinTs + sf2.joinTs\n"
|
+ "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n"
|
||||||
+ "R2: top(total, 1)\n"
|
+ "R2: top(total, 1)\n"
|
||||||
+ "}\n"
|
+ "}\n"
|
||||||
+ "Action {\n"
|
+ "Action {\n"
|
||||||
|
@ -25,13 +25,7 @@ import com.antgroup.openspg.reasoner.kggraph.AggregationSchemaInfo;
|
|||||||
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
|
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
|
||||||
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
|
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
|
||||||
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
|
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr;
|
import com.antgroup.openspg.reasoner.lube.common.expr.*;
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
|
|
||||||
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
|
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
|
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
|
import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
|
||||||
@ -41,6 +35,7 @@ import com.antgroup.openspg.reasoner.lube.logical.Var;
|
|||||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
|
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
|
||||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
|
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
|
||||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
|
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
|
||||||
|
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
|
||||||
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
|
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
|
||||||
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
|
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
|
||||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
||||||
@ -353,26 +348,14 @@ public class KgGraphAggregateImpl implements Serializable {
|
|||||||
if (null != udafInitParams) {
|
if (null != udafInitParams) {
|
||||||
udaf.initialize(udafInitParams);
|
udaf.initialize(udafInitParams);
|
||||||
}
|
}
|
||||||
|
ParsedAggEle parsedAggEle;
|
||||||
String sourceAlias = null;
|
|
||||||
String sourcePropertyName = null;
|
|
||||||
Set<String> aliasList = aggInfo.getExprUseAliasSet();
|
Set<String> aliasList = aggInfo.getExprUseAliasSet();
|
||||||
if (aliasList.size() <= 1) {
|
if (aliasList.size() <= 1) {
|
||||||
Expr sourceExpr = aggInfo.getAggEle();
|
parsedAggEle = aggInfo.getParsedAggEle();
|
||||||
// aggregate by vertex subgraph
|
|
||||||
if (sourceExpr instanceof Ref) {
|
|
||||||
Ref sourceRef = (Ref) sourceExpr;
|
|
||||||
sourceAlias = sourceRef.refName();
|
|
||||||
} else if (sourceExpr instanceof UnaryOpExpr) {
|
|
||||||
UnaryOpExpr expr = (UnaryOpExpr) sourceExpr;
|
|
||||||
GetField getField = (GetField) expr.name();
|
|
||||||
sourceAlias = ((Ref) expr.arg()).refName();
|
|
||||||
sourcePropertyName = getField.fieldName();
|
|
||||||
}
|
|
||||||
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
|
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
|
||||||
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
||||||
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
|
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
|
||||||
StringBuffer sb = new StringBuffer();
|
StringBuilder sb = new StringBuilder();
|
||||||
for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) {
|
for (KgGraph<IVertexId> valueFiltered2 : valueFilteredList) {
|
||||||
sb.append(valueFiltered2).append("## ");
|
sb.append(valueFiltered2).append("## ");
|
||||||
}
|
}
|
||||||
@ -381,19 +364,41 @@ public class KgGraphAggregateImpl implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
String finalSourcePropertyName = sourcePropertyName;
|
String finalSourcePropertyName = parsedAggEle.getSourcePropertyName();
|
||||||
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
for (KgGraph<IVertexId> valueFiltered : valueFilteredList) {
|
||||||
if (valueFiltered.getVertexAlias().contains(sourceAlias)) {
|
if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) {
|
||||||
List<IVertex<IVertexId, IProperty>> vertexList = valueFiltered.getVertex(sourceAlias);
|
List<IVertex<IVertexId, IProperty>> vertexList =
|
||||||
if (sourcePropertyName == null) {
|
valueFiltered.getVertex(parsedAggEle.getSourceAlias());
|
||||||
|
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
|
||||||
|
vertexList.forEach(
|
||||||
|
vertex -> {
|
||||||
|
Map<String, Object> context =
|
||||||
|
RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias());
|
||||||
|
Object value =
|
||||||
|
RuleRunner.getInstance()
|
||||||
|
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
|
||||||
|
udaf.update(value);
|
||||||
|
});
|
||||||
|
} else if (finalSourcePropertyName == null) {
|
||||||
vertexList.forEach(udaf::update);
|
vertexList.forEach(udaf::update);
|
||||||
} else {
|
} else {
|
||||||
vertexList.forEach(
|
vertexList.forEach(
|
||||||
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
|
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
List<IEdge<IVertexId, IProperty>> edgeList = valueFiltered.getEdge(sourceAlias);
|
List<IEdge<IVertexId, IProperty>> edgeList =
|
||||||
if (sourcePropertyName == null) {
|
valueFiltered.getEdge(parsedAggEle.getSourceAlias());
|
||||||
|
if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
|
||||||
|
edgeList.forEach(
|
||||||
|
edge -> {
|
||||||
|
Map<String, Object> context =
|
||||||
|
RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias());
|
||||||
|
Object value =
|
||||||
|
RuleRunner.getInstance()
|
||||||
|
.executeExpression(context, parsedAggEle.getExprStrList(), taskId);
|
||||||
|
udaf.update(value);
|
||||||
|
});
|
||||||
|
} else if (finalSourcePropertyName == null) {
|
||||||
edgeList.forEach(udaf::update);
|
edgeList.forEach(udaf::update);
|
||||||
} else {
|
} else {
|
||||||
edgeList.forEach(
|
edgeList.forEach(
|
||||||
|
@ -23,6 +23,7 @@ import java.io.Serializable;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
||||||
|
private final ParsedAggEle parsedAggEle;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* constructor
|
* constructor
|
||||||
@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
|
|||||||
*/
|
*/
|
||||||
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
||||||
super(taskId, var, aggregator);
|
super(taskId, var, aggregator);
|
||||||
|
parsedAggEle = parsedAggEle();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -58,4 +60,9 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
|
|||||||
public Expr getAggEle() {
|
public Expr getAggEle() {
|
||||||
return getAggIfOpExpr().aggOpExpr().aggEleExpr();
|
return getAggIfOpExpr().aggOpExpr().aggEleExpr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParsedAggEle getParsedAggEle() {
|
||||||
|
return parsedAggEle;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,9 +22,11 @@ import java.io.Serializable;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
|
||||||
|
private final ParsedAggEle parsedAggEle;
|
||||||
|
|
||||||
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
|
||||||
super(taskId, var, aggregator);
|
super(taskId, var, aggregator);
|
||||||
|
parsedAggEle = parsedAggEle();
|
||||||
}
|
}
|
||||||
|
|
||||||
public AggOpExpr getAggOpExpr() {
|
public AggOpExpr getAggOpExpr() {
|
||||||
@ -45,4 +47,9 @@ public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Se
|
|||||||
public Expr getAggEle() {
|
public Expr getAggEle() {
|
||||||
return getAggOpExpr().aggEleExpr();
|
return getAggOpExpr().aggEleExpr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParsedAggEle getParsedAggEle() {
|
||||||
|
return this.parsedAggEle;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,9 @@ import com.antgroup.openspg.reasoner.lube.common.expr.AggUdf;
|
|||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
|
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
|
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
|
||||||
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
|
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
|
||||||
|
import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
|
||||||
|
import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
|
||||||
|
import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
|
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
|
||||||
import com.antgroup.openspg.reasoner.lube.logical.Var;
|
import com.antgroup.openspg.reasoner.lube.logical.Var;
|
||||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
|
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
|
||||||
@ -139,6 +142,9 @@ public abstract class BaseGroupProcess implements Serializable {
|
|||||||
*/
|
*/
|
||||||
public abstract Expr getAggEle();
|
public abstract Expr getAggEle();
|
||||||
|
|
||||||
|
/** get parsed agg ele */
|
||||||
|
public abstract ParsedAggEle getParsedAggEle();
|
||||||
|
|
||||||
public Set<String> parseExprUseAliasSet() {
|
public Set<String> parseExprUseAliasSet() {
|
||||||
scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
|
scala.collection.immutable.List<String> aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
|
||||||
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
|
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
|
||||||
@ -148,6 +154,26 @@ public abstract class BaseGroupProcess implements Serializable {
|
|||||||
return WareHouseUtils.getRuleList(getAggEle());
|
return WareHouseUtils.getRuleList(getAggEle());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected ParsedAggEle parsedAggEle() {
|
||||||
|
String sourceAlias = null;
|
||||||
|
String sourcePropertyName = null;
|
||||||
|
List<String> exprStrList = null;
|
||||||
|
Expr aggEle = getAggEle();
|
||||||
|
if (aggEle instanceof Ref) {
|
||||||
|
Ref sourceRef = (Ref) aggEle;
|
||||||
|
sourceAlias = sourceRef.refName();
|
||||||
|
} else if (aggEle instanceof UnaryOpExpr) {
|
||||||
|
UnaryOpExpr expr = (UnaryOpExpr) aggEle;
|
||||||
|
GetField getField = (GetField) expr.name();
|
||||||
|
sourceAlias = ((Ref) expr.arg()).refName();
|
||||||
|
sourcePropertyName = getField.fieldName();
|
||||||
|
} else if (1 == this.exprUseAliasSet.size()) {
|
||||||
|
sourceAlias = this.exprUseAliasSet.iterator().next();
|
||||||
|
exprStrList = this.exprRuleString;
|
||||||
|
}
|
||||||
|
return new ParsedAggEle(sourceAlias, sourcePropertyName, exprStrList);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* getter
|
* getter
|
||||||
*
|
*
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 OpenSPG Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
* in compliance with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
* or implied.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.antgroup.openspg.reasoner.rdg.common.groupProcess;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class ParsedAggEle {
|
||||||
|
private final String sourceAlias;
|
||||||
|
private final String sourcePropertyName;
|
||||||
|
private final List<String> exprStrList;
|
||||||
|
|
||||||
|
public ParsedAggEle(String sourceAlias, String sourcePropertyName, List<String> exprStrList) {
|
||||||
|
this.sourceAlias = sourceAlias;
|
||||||
|
this.sourcePropertyName = sourcePropertyName;
|
||||||
|
this.exprStrList = exprStrList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getSourceAlias() {
|
||||||
|
return sourceAlias;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getSourcePropertyName() {
|
||||||
|
return sourcePropertyName;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getExprStrList() {
|
||||||
|
return exprStrList;
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user