fix(reasoner): support std prop during type inference && fix some bugs in NodeIdToEdgeProperty (#159)

This commit is contained in:
FishJoy 2024-03-15 15:28:20 +08:00 committed by GitHub
parent 3b214401ec
commit 9eac15c4de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 134 additions and 1866 deletions

View File

@ -38,7 +38,7 @@
<properties>
<antlr4.version>4.8</antlr4.version>
<dropwizard.metrics.version>4.2.9</dropwizard.metrics.version>
<geotools.version>27.0</geotools.version>
<geotools.version>28.0</geotools.version>
<google.s2.version>2.0.0</google.s2.version>
<groovy.version>3.0.9</groovy.version>
<gson.version>2.10.1</gson.version>

View File

@ -15,7 +15,7 @@ package com.antgroup.openspg.reasoner.catalog.impl
import scala.language.implicitConversions
import com.antgroup.openspg.core.schema.model.`type`.{BaseSPGType, BasicType, ConceptType, EntityType, EventType, SPGTypeEnum, StandardType}
import com.antgroup.openspg.core.schema.model.`type`._
import com.antgroup.openspg.reasoner.catalog.impl.struct.PropertyMeta
import com.antgroup.openspg.reasoner.common.exception.KGValueException
import com.antgroup.openspg.reasoner.common.types._
@ -30,9 +30,15 @@ object PropertySchemaOps {
case "CONCEPT" =>
KTConcept(propertySchema.getPropRange.getRangeEntityName)
case "STANDARD" =>
KTStd(propertySchema.getPropRange.getRangeEntityName, propertySchema.isSpreadable)
KTStd(
propertySchema.getPropRange.getRangeEntityName,
toKgType(propertySchema.getPropRange.getAttrRangeTypeEnum),
propertySchema.isSpreadable)
case "PROPERTY" =>
KTStd(propertySchema.getPropRange.getRangeEntityName, propertySchema.isSpreadable)
KTStd(
propertySchema.getPropRange.getRangeEntityName,
toKgType(propertySchema.getPropRange.getAttrRangeTypeEnum),
propertySchema.isSpreadable)
case "ENTITY" =>
KTAdvanced(propertySchema.getPropRange.getRangeEntityName)
case _ => throw KGValueException(s"unsupported type: ${propertySchema.getCategory}")
@ -49,7 +55,8 @@ object PropertySchemaOps {
// todo
KTAdvanced(eventType.getName)
case standardType: StandardType =>
KTStd(spgType.getName, standardType.getSpreadable)
// todo basicType support
KTStd(spgType.getName, null, standardType.getSpreadable)
case basicType: BasicType =>
toKgType(basicType.getBasicType.name())
case _ =>
@ -57,7 +64,7 @@ object PropertySchemaOps {
}
}
private def toKgType(basicType: String): KgType = {
private def toKgType(basicType: String): BasicKgType = {
basicType.toUpperCase() match {
case "INTEGER" => KTLong
case "LONG" => KTLong

View File

@ -19,53 +19,57 @@ trait KgType {
def isNullable: Boolean = false
}
case object KTString extends KgType
case object KTCharacter extends KgType
case object KTInteger extends KgType
case object KTLong extends KgType
case object KTDouble extends KgType
trait BasicKgType extends KgType
trait AdvancedKgType extends KgType
case object KTString extends BasicKgType
case object KTCharacter extends BasicKgType
case object KTInteger extends BasicKgType
case object KTLong extends BasicKgType
case object KTDouble extends BasicKgType
// corresponding to java object
case object KTObject extends KgType
case object KTDate extends KgType
case object KTBoolean extends KgType
case object KTParameter extends KgType
case object KTObject extends BasicKgType
case object KTDate extends BasicKgType
case object KTBoolean extends BasicKgType
/**
* list type
* @param elementType element type
*/
final case class KTList(elementType: KgType) extends KgType
final case class KTList(elementType: KgType) extends AdvancedKgType
/**
* array type
* @param elementType element type
*/
final case class KTArray(elementType: KgType) extends KgType
final case class KTArray(elementType: KgType) extends AdvancedKgType
/**
* Standard entity in Knowledge Graph.
* @param label entity type name
* @param spreadable is spreadable.
*/
final case class KTStd(label: String, spreadable: Boolean) extends KgType
final case class KTStd(label: String, basicType: BasicKgType, spreadable: Boolean)
extends AdvancedKgType
/**
* Meta concept in Knowledge Graph.
* @param label meta concept name
*/
final case class KTConcept(label: String) extends KgType
final case class KTConcept(label: String) extends AdvancedKgType
/**
* Custom semantic type, which linked to entity in Knowledge Graph.
* @param label entity type TODO add link function
*/
final case class KTAdvanced(label: String) extends KgType
final case class KTAdvanced(label: String) extends AdvancedKgType
/**
* multi version property, default version number unit is ms
* @param elementType
*/
final case class KTMultiVersion(elementType: KgType) extends KgType
final case class KTMultiVersion(elementType: KgType) extends AdvancedKgType
object KgType {

View File

@ -91,7 +91,16 @@ object ExprUtil {
} else {
KTObject
}
case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name))
case UnaryOpExpr(GetField(name), Ref(alis)) =>
val kgType = referVars(IRProperty(alis, name))
if (kgType.isInstanceOf[BasicKgType]) {
kgType
} else {
kgType match {
case KTStd(_, basicType, _) => basicType
case _ => KTObject
}
}
case BinaryOpExpr(name, l, r) =>
name match {
case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |

View File

@ -20,7 +20,7 @@ import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.graph.edge
import com.antgroup.openspg.reasoner.common.types.KTString
import com.antgroup.openspg.reasoner.lube.block.{AddPredicate, DDLOp}
import com.antgroup.openspg.reasoner.lube.block.{AddPredicate, AddProperty, AddVertex, DDLOp}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRNode, IRProperty, IRVariable}
@ -85,8 +85,10 @@ object NodeIdToEdgeProperty extends Rule {
private def genField(direction: edge.Direction, fieldName: String): String = {
(direction, fieldName) match {
case (edge.Direction.OUT, Constants.NODE_ID_KEY) => Constants.EDGE_TO_ID_KEY
case (edge.Direction.OUT, Constants.CONTEXT_LABEL) => Constants.EDGE_TO_ID_TYPE_KEY
case (edge.Direction.OUT | edge.Direction.BOTH, Constants.NODE_ID_KEY) =>
Constants.EDGE_TO_ID_KEY
case (edge.Direction.OUT | edge.Direction.BOTH, Constants.CONTEXT_LABEL) =>
Constants.EDGE_TO_ID_TYPE_KEY
case (edge.Direction.IN, Constants.NODE_ID_KEY) => Constants.EDGE_FROM_ID_KEY
case (edge.Direction.IN, Constants.CONTEXT_LABEL) => Constants.EDGE_FROM_ID_TYPE_KEY
case (_, _) =>
@ -124,31 +126,56 @@ object NodeIdToEdgeProperty extends Rule {
private def projectUpdate(project: Project, map: Map[String, Object]): Project = {
val exprMap = new mutable.HashMap[Var, Expr]()
for (expr <- project.expr) {
val input = ExprUtils.getAllInputFieldInRule(
expr._2,
project.solved.getNodeAliasSet,
project.solved.getEdgeAliasSet)
val replaceVar = new mutable.HashMap[IRField, IRField]
for (irField <- input) {
if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) {
for (propName <- irField.asInstanceOf[IRNode].fields) {
if (NODE_DEFAULT_PROPS.contains(propName)) {
val edgeInfo = map(irField.name).asInstanceOf[Connection]
replaceVar.put(
IRProperty(irField.name, propName),
IRProperty(edgeInfo.alias, genField(edgeInfo.direction, propName)))
replaceVar.put(IRVariable(irField.name), IRVariable(edgeInfo.alias))
}
exprMap.put(
expr._1,
exprRewrite(expr._2, project.solved.getNodeAliasSet, project.solved.getEdgeAliasSet, map))
}
project.copy(expr = exprMap.toMap)
}
private def exprRewrite(
expr: Expr,
nodes: Set[String],
edges: Set[String],
map: Map[String, Object]): Expr = {
val input = ExprUtils.getAllInputFieldInRule(expr, nodes, edges)
val replaceVar = new mutable.HashMap[IRField, IRField]
for (irField <- input) {
if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) {
for (propName <- irField.asInstanceOf[IRNode].fields) {
if (NODE_DEFAULT_PROPS.contains(propName)) {
val edgeInfo = map(irField.name).asInstanceOf[Connection]
replaceVar.put(
IRProperty(irField.name, propName),
IRProperty(edgeInfo.alias, genField(edgeInfo.direction, propName)))
replaceVar.put(IRVariable(irField.name), IRVariable(edgeInfo.alias))
}
}
}
if (replaceVar.isEmpty) {
exprMap.+=(expr)
} else {
exprMap.put(expr._1, ExprUtils.renameVariableInExpr(expr._2, replaceVar.toMap))
}
if (replaceVar.isEmpty) {
expr
} else {
ExprUtils.renameVariableInExpr(expr, replaceVar.toMap)
}
}
private def ddlUpdate(ddl: DDL, map: Map[String, Object]): DDL = {
val nodes = ddl.solved.getNodeAliasSet
val edges = ddl.solved.getEdgeAliasSet
val newOps = new mutable.HashSet[DDLOp]()
for (ddlOp <- ddl.ddlOp) {
ddlOp match {
case ddlOp: AddProperty => newOps.add(ddlOp)
case AddVertex(s, props) =>
val newProps = props.map(p => (p._1, exprRewrite(p._2, nodes, edges, map)))
newOps.add(AddVertex(s, newProps))
case AddPredicate(predicate) =>
val newProps = predicate.fields.map(p => (p._1, exprRewrite(p._2, nodes, edges, map)))
newOps.add(AddPredicate(predicate.copy(fields = newProps)))
}
}
project.copy(expr = exprMap.toMap)
ddl.copy(ddlOp = newOps.toSet)
}
private def selectUpdate(select: Select, map: Map[String, Object]): Select = {
@ -172,40 +199,6 @@ object NodeIdToEdgeProperty extends Rule {
select.copy(fields = newFields.toList)
}
private def ddlUpdate(ddl: DDL, map: Map[String, Object]): DDL = {
val ddlOps = new mutable.HashSet[DDLOp]()
for (ddlOp <- ddl.ddlOp) {
ddlOp match {
case AddPredicate(predicate) =>
val newFields = new mutable.HashMap[String, Expr]()
for (field <- predicate.fields) {
val input = ExprUtils.getAllInputFieldInRule(field._2, null, null)
val replaceVar = new mutable.HashMap[IRField, IRProperty]
for (irField <- input) {
if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) {
for (propName <- irField.asInstanceOf[IRNode].fields) {
if (NODE_DEFAULT_PROPS.contains(propName)) {
val edgeInfo = map(irField.name).asInstanceOf[Connection]
replaceVar.put(
IRProperty(irField.name, propName),
IRProperty(edgeInfo.alias, genField(edgeInfo.direction, propName)))
}
}
}
}
if (replaceVar.isEmpty) {
newFields.put(field._1, field._2)
} else {
newFields.put(field._1, ExprUtils.renameVariableInExpr(field._2, replaceVar.toMap))
}
}
ddlOps.add(AddPredicate(predicate.copy(fields = newFields.toMap)))
case _ => ddlOps.add(ddlOp)
}
}
ddl.copy(ddlOp = ddlOps.toSet)
}
private def targetConnection(expandInto: ExpandInto): Connection = {
val alias = expandInto.pattern.root.alias
val edgeAlias = expandInto.transform[Connection] {
@ -217,6 +210,12 @@ object NodeIdToEdgeProperty extends Rule {
} else {
targetConnection(alias, expandInto.pattern)
}
case (linkedExpand: LinkedExpand, list) =>
if (!list.isEmpty && list.head != null) {
list.head
} else {
targetConnection(alias, linkedExpand.edgePattern)
}
case (_, list) =>
if (list.isEmpty) {
null

View File

@ -39,8 +39,8 @@ object Pure extends SimpleRule {
val projects = select.refFields.map((_, Directly)).toMap
select.withNewChildren(Array.apply(Project(in, projects, in.solved)))
case project @ Project(in, _, _) =>
if (in.isInstanceOf[Project] || in.isInstanceOf[ExpandInto] || in
.isInstanceOf[PatternScan] || in.isInstanceOf[BinaryLogicalOperator]) {
if (in.isInstanceOf[Project] || in.isInstanceOf[ExpandInto] || in.isInstanceOf[LinkedExpand]
|| in.isInstanceOf[PatternScan] || in.isInstanceOf[BinaryLogicalOperator]) {
project
} else {
val projectOutput: List[Var] = project.fields
@ -71,17 +71,20 @@ object Pure extends SimpleRule {
}
}
varMap.values.map(f => {
if (f.isInstanceOf[PathVar]) {
f
} else if (!solved.fields.contains(f.name)) {
throw InvalidRefVariable(s"can not find $f")
} else if (solved.fields.get(f.name).get.isInstanceOf[RepeatPathVar]) {
f.intersect(solved.fields.get(f.name).get.asInstanceOf[RepeatPathVar].pathVar.elements(1))
} else {
f.intersect(solved.fields.get(f.name).get)
}
}).toList
varMap.values
.map(f => {
if (f.isInstanceOf[PathVar]) {
f
} else if (!solved.fields.contains(f.name)) {
throw InvalidRefVariable(s"can not find $f")
} else if (solved.fields.get(f.name).get.isInstanceOf[RepeatPathVar]) {
f.intersect(
solved.fields.get(f.name).get.asInstanceOf[RepeatPathVar].pathVar.elements(1))
} else {
f.intersect(solved.fields.get(f.name).get)
}
})
.toList
}
override def direction: Direction = Down

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,7 @@
package com.antgroup.openspg.reasoner.lube.logical
import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
import com.antgroup.openspg.reasoner.common.types._
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
@ -76,4 +76,23 @@ class ExprUtilTests extends AnyFunSpec {
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
}
it("test advanced type") {
val parser = new RuleExprParser()
val udfRepo = UdfMngFactory.getUdfMng
val map = Map
.apply(IRProperty("e", "eventTime") -> KTStd("STD.Timestamp", KTLong, false))
.asInstanceOf[Map[IRField, KgType]]
val r1 = ProjectRule(
IRVariable("eventDay"),
parser.parse("from_unix_time_ms(e.eventTime, 'yyyyMMdd')"))
val r2 = ProjectRule(
IRVariable("timeInDay"),
parser.parse("from_unix_time_ms(e.eventTime, 'yyyyMMdd') in ['20240311', '20240312']"))
ExprUtil.getTargetType(r1, map, udfRepo) should equal(KTString)
ExprUtil.getTargetType(r2, map, udfRepo) should equal(KTBoolean)
}
}

View File

@ -63,101 +63,6 @@ class OptimizerTests extends AnyFunSpec {
}
}
it("testEdgeToProperty") {
val schema = ResourceLoader.loadResourceFile("TuringSchema.json")
val catalog = new JSONGraphCatalog(schema)
val dsl =
"""
|Define (user:TuringCore.AlipayUser)-[bt:belongTo]->(tc:`TuringCrowd`/`通勤用户`) {
| GraphStructure {
| (user) -[pwl:workLoc]-> (aa1:CKG.AdministrativeArea)
| (te:TuringCore.TravelEvent) -[ptler:traveler]-> (user)
| (te) -[ptm:travelMode]-> (tm:TuringCore.TravelMode)
| (te) -[pte:travelEndpoint]-> (aa1:CKG.AdministrativeArea)
| }
| Rule {
| R1('常驻地在杭州'): aa1.id == '中国-浙江省-杭州市'
| R2('工作日上班时间通勤用户'): dayOfWeek(te.eventTime) in [1, 2, 3, 4, 5]
| and hourOfDay(te.eventTime) in [6, 7, 8, 9, 10, 17, 18, 19, 20, 21]
| R3('公交地铁'): tm.id in ['bus', 'subway']
| tmCount('出行次数') = group(user).count(te.id)
| R4('出行次数大于3次'): tmCount >= 3
| }
|}
|""".stripMargin
catalog.init()
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(catalog, parser, Map.empty)
val logicalPlan = LogicalPlanner.plan(block).head
val finalOp = LogicalOptimizer.optimize(
logicalPlan,
Seq.apply(FilterPushDown, EdgeToProperty, SolvedModelPure))
val qlTransformer = new Expr2QlexpressTransformer()
finalOp.findExactlyOne { case ExpandInto(_, _, pattern) =>
qlTransformer.transform(pattern.getNode("te").rule).head should equal(
"((te.travelMode in [\"bus\",\"subway\"]) && ((dayOfWeek(te.eventTime) in [1,2,3,4,5]) && (hourOfDay(te.eventTime) in [6,7,8,9,10,17,18,19,20,21]))) && (te.travelEndpoint == \"中国-浙江省-杭州市\")")
}
finalOp.findExactlyOne { case PatternScan(_, pattern) =>
qlTransformer.transform(pattern.getNode("user").rule).head should equal(
"user.workLoc == \"中国-浙江省-杭州市\"")
}
finalOp.findExactlyOne { case Start(_, _, _, solved) =>
solved.alias2Types.keys.toSet should equal(Set.apply("user", "te", "ptler"))
solved.fields("user").asInstanceOf[NodeVar].fields.map(_.name) should equal(
Set.apply("workLoc"))
solved.fields("te").asInstanceOf[NodeVar].fields.map(_.name) should equal(
Set.apply("eventTime", "id", "travelMode", "travelEndpoint"))
}
}
it("concept to property") {
val schema = ResourceLoader.loadResourceFile("TuringSchema.json")
val catalog = new JSONGraphCatalog(schema)
catalog.init()
val start = Start(
catalog.getGraph(Catalog.defaultGraphName),
null,
Set.empty,
SolvedModel(
Map.empty,
Map.apply(
("te", NodeVar("te", Set.apply(new Field("eventTime", KTString, true)))),
("user", NodeVar("user", Set.apply())),
("tm", NodeVar("tm", Set.apply(new Field("id", KTString, true))))),
Map.empty))
val r1 = LogicRule(
"R1",
"xx",
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("eventTime"), Ref("te")), VString("1")))
val r2 = LogicRule(
"R2",
"xx",
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("id"), Ref("tm")), VString("bus")))
val patternElementMap = Map.apply(
("te", PatternElement("te", Set.apply("TuringCore.TravelEvent"), r1)),
("tm", PatternElement("tm", Set.apply("TuringCore.TravelMode"), r2)))
val edges: Map[String, Set[Connection]] = Map.apply((
"te",
Set.apply(
new PatternConnection("ptm", "te", Set.apply("travelMode"), "tm", Direction.OUT, null))))
val expand = ExpandInto(
start,
patternElementMap("te"),
PartialGraphPattern("te", patternElementMap, edges))
val logicalOp =
ExpandInto(expand, patternElementMap("tm"), NodePattern(patternElementMap("tm")))
val finalOp = BottomUp[LogicalOperator](EdgeToProperty.rule(null)).transform(logicalOp)
val qlTransformer = new Expr2QlexpressTransformer()
finalOp.findExactlyOne { case ExpandInto(_, _, pattern) =>
pattern.root.alias should equal("te")
qlTransformer.transform(pattern.getNode("te").rule).head should equal(
"(te.travelMode == \"bus\") && (te.eventTime == \"1\")")
}
}
it("expandInto pure") {
val dsl =
"""